[Feature] Add RLHF code #736
Conversation
Hi XTuner Team, could you please add a citation for the source of the Ray+vLLM-based RLHF architecture, OpenRLHF, for example in the README.md acknowledgement section: https://github.com/InternLM/xtuner?tab=readme-ov-file#%EF%B8%8F-acknowledgement. An example:

Thank you
Force-pushed from 45702b6 to 1ffe51e (Compare)
examples/rlhf/four_model_8gpu.py (Outdated)

```python
),
)

dataset_config = {
```
Suggest placing the settings that are adjusted most often near the top of the config file, and adding more comments to the config so it is easier for users to understand. The existing xtuner configs are a good reference.
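A rough sketch of the intent (the key names below are hypothetical, not the actual `four_model_8gpu.py` settings):

```python
# Frequently tuned knobs first, each with a short comment.
max_prompt_len = 1024        # prompts longer than this are truncated
rollout_batch_size = 64      # prompts sampled per PPO step
actor_learning_rate = 1e-6   # PPO policy learning rate
critic_learning_rate = 5e-6  # value model learning rate

# Rarely touched plumbing (model paths, parallelism, dataset_config, ...)
# can then follow further down in the file.
```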
xtuner/rlhf/envs/txt_env.py (Outdated)

```python
if sample_data[i].rm_meta != 'default':
    cur_rm_data = [{
        'role': 'system',
        'content': META_PROMPT[sample_data[i].rm_meta]
```
You could add a comment here explaining that this is a conditional system prompt, with a link to the paper.
Also, the variable name `META_PROMPT` is too abstract; better to follow the naming used in the paper.
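Something along these lines (the rename to `CONDITIONAL_SYSTEM_PROMPTS` is only a placeholder suggestion, and the paper link would need to be filled in):

```python
# Conditional system prompt: prepend a system message tied to this sample's
# reward-model meta tag, so the RM scores the response under that condition.
# See: <paper link>
if sample_data[i].rm_meta != 'default':
    cur_rm_data = [{
        'role': 'system',
        'content': CONDITIONAL_SYSTEM_PROMPTS[sample_data[i].rm_meta]
    }]
```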
```python
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration."""
import logging
```
Why do some places use loguru's logger while this file uses a custom logger? Better to unify on one.
xtuner/rlhf/policy_output.py (Outdated)

```python
if not gather or labels is None:
    return logp
logpy = torch.gather(logp, -1, labels.unsqueeze(2)).squeeze(-1)
return logpy.cuda()
```
Why call `.cuda()` here? Can the tensor end up on the CPU?
xtuner/rlhf/policy_output.py (Outdated)

```python
if policy_output[key] is None:
    continue
batch_size, seq_len = policy_output[
    key].shape  # assert: only support 2d tensor
```
Suggested change:

```python
batch_size, seq_len = policy_output[key].shape[:2]
```

This might be a bit more compatible.
xtuner/rlhf/policy_output.py (Outdated)

```python
def padding_policy_outputs(policy_outputs: list[PolicyOutput],
                           padding_token_map={}):
    DEFAULT_PADDING_ID = 0
    RIGHT_PADDING = True
```
Are these meant to become configurable? Right now they are hard-coded.
xtuner/rlhf/repeaters/base.py (Outdated)

```python
import time

import torch
from loguru import logger
```
Please settle on which logger to use.
xtuner/rlhf/repeaters/base.py (Outdated)

```python
num_actions = action_mask.size(1)
if sft_model is not None:
    self.sft_model: BaseModelServer = sft_model
    kl_rewards, entropy, kl_distance, policy_logprobs, sft_logprobs = self._get_kl_rewards(  # noqa: E501
```
Suggested change:

```python
(kl_rewards, entropy, kl_distance, policy_logprobs, sft_logprobs) = self._get_kl_rewards(
```

Adding parentheses here lets yapf wrap the line for you.
xtuner/rlhf/repeaters/base.py (Outdated)

```python
s_t = time.time()
value_output = value_model.infer_get(value_output_ref)
raw_values = value_output.logits.squeeze(-1)
logger.info(
```
There are many places in the code with this kind of timing need; adding a timing context manager to utils would make it more convenient, e.g.:

```python
import time

from loguru import logger


class TimeLogger:

    def __init__(self, message: str):
        self.message = message

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = round(time.time() - self.start_time, 2)
        logger.info(f'{self.message} duration: {duration} s')
```
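For example, the manual timing in the snippet above could then become (sketch):

```python
with TimeLogger('value_model.infer_get'):
    value_output = value_model.infer_get(value_output_ref)
raw_values = value_output.logits.squeeze(-1)
```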
```python
from loguru import logger


class Timer:
```
Hey, there is already a timing context manager here — why isn't it being used?
setup.py (Outdated)

```python
'rlhf':
parse_requirements('requirements/rlhf.txt'),
```
Suggested change:

```python
'rlhf':
parse_requirements('requirements/deepspeed.txt') +
parse_requirements('requirements/rlhf.txt'),
```
requirements/rlhf.txt (Outdated)

```
-r requirements/deepspeed.txt
loguru
ray[default,train]==2.9.1
```
Suggested change (drop the `-r requirements/deepspeed.txt` line):

```
loguru
ray[default,train]==2.9.1
```
xtuner/rlhf/config/config_utils.py (Outdated)

```python
num_gpus = 1
if 'parallel' in trainer_config:
    parallel = trainer_config['parallel']
    data = parallel.get('data', {'size': 1})
    tensor = parallel.get('tensor', {'size': 1})
    pipeline = parallel.get('pipeline', {'size': 1})
    num_gpus = data['size'] * tensor['size'] * pipeline['size']
return num_gpus
```
Suggested change:

```python
return get_dp_size(trainer_config) * get_tp_size(trainer_config) * get_pp_size(trainer_config)
```
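If helpers like `get_dp_size` don't already exist in config_utils.py, each one could look roughly like this sketch (`get_tp_size`/`get_pp_size` analogous):

```python
def get_dp_size(trainer_config) -> int:
    parallel = trainer_config.get('parallel', {})
    return parallel.get('data', {'size': 1})['size']
```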
xtuner/rlhf/coordinator.py (Outdated)

```python
logger.info(
    f'{model_name} {model.__class__.__name__}.is_initialized: {model.is_initialized}'  # noqa: E501
)
```
This log line doesn't seem all that necessary, and it could easily be misread.
```python
try:
    client_context = ray.init(
        address=self.cluster_address,
        runtime_env=runtime_env,
        ignore_reinit_error=True,
    )
    logger.info(
        f'Connected to a running ray cluster at {self.cluster_address}'
    )
    self.context_type = 'client'
    self.context = client_context

except ConnectionError:
    logger.info(
        f'Error connecting to {self.cluster_address}, try initializing a new ray cluster.'  # noqa: E501
    )
    ray_context = ray.init(
        address=None,
        resources=resources,
        runtime_env=runtime_env,
        ignore_reinit_error=True,
    )
    node_ip_address = ray_context.address_info['node_ip_address']
    logger.info(f'Initialize a ray cluster at {node_ip_address}')
    self.context_type = 'server'
    self.context = ray_context
```
So this is an automatic fallback to a local ray server? It seems better not to fall back and to raise the error directly instead.
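A minimal sketch of the no-fallback variant (assuming the same `ray.init` arguments as in the snippet above):

```python
try:
    self.context = ray.init(
        address=self.cluster_address,
        runtime_env=runtime_env,
        ignore_reinit_error=True,
    )
    self.context_type = 'client'
except ConnectionError as e:
    raise RuntimeError(
        f'Failed to connect to the ray cluster at {self.cluster_address}; '
        'start the cluster first or pass a reachable address.') from e
```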
xtuner/rlhf/dataset/base.py (Outdated)

```python
class FileDataset(IterableDataset):
    """Single json file dataset."""

    def __init__(self,
                 filename,
                 tokenizer,
                 sys_meta='default',
                 rm_meta='default'):
        self._filename = filename
        self.tokenizer = tokenizer
        self.data_list = []
        self.sys_meta = sys_meta
        self.rm_meta = rm_meta
        with open_file(self._filename) as fin:
            for lineno, line in enumerate(fin):
                data = json.loads(line)
                self.data_list.append(data)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index: int):
        data = self.data_list[index]
        try:
            self.tokenizer.apply_chat_template(data, tokenize=True)
            return {
                'data': data,
                'sys_meta': self.sys_meta,
                'rm_meta': self.rm_meta
            }
        except Exception:
            print(f'[data tokenize check] skip dirty data: {data}')
            return None

    def __iter__(self):
        with open_file(self._filename) as fin:
            for lineno, line in enumerate(fin):
                data = json.loads(line)
                try:
                    self.tokenizer.apply_chat_template(data, tokenize=True)
                except Exception:
                    print(f'[data tokenize check] skip dirty data: {data}')
                    continue
                if data is None:
                    continue
                yield {
                    'data': data,
                    'sys_meta': self.sys_meta,
                    'rm_meta': self.rm_meta
                }
```
- Implementing MapDataset-style interfaces like `__len__` and `__getitem__` on an `IterableDataset` doesn't seem very reasonable.
- As an `IterableDataset`, loading the whole dataset from disk into memory in `__init__` while still reading from disk again in `__iter__` is inefficient.
- Renaming it and updating the docstring would probably be better. The file being read is JSON Lines rather than JSON, so suggest renaming it to `JsonlDataset`.
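A rough sketch of what this could look like as a pure streaming `JsonlDataset` (names and details are illustrative, not the final implementation):

```python
import json

from torch.utils.data import IterableDataset


class JsonlDataset(IterableDataset):
    """Dataset over a single JSON Lines file, read lazily from disk."""

    def __init__(self, filename, tokenizer, sys_meta='default', rm_meta='default'):
        self._filename = filename
        self.tokenizer = tokenizer
        self.sys_meta = sys_meta
        self.rm_meta = rm_meta

    def __iter__(self):
        with open(self._filename) as fin:
            for line in fin:
                data = json.loads(line)
                try:
                    # tokenize check: skip dirty data that cannot be templated
                    self.tokenizer.apply_chat_template(data, tokenize=True)
                except Exception:
                    continue
                yield {
                    'data': data,
                    'sys_meta': self.sys_meta,
                    'rm_meta': self.rm_meta
                }
```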
The code below has the same problems.
xtuner/rlhf/dataset/base.py (Outdated)

```python
def __iter__(self):
    while True:
        self.rng.shuffle(self.indices)
        for i in self.indices:
            yield self.data[i]
```
Consider logging a message when an epoch ends, to make debugging easier.
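For example (sketch, assuming the loguru logger is available in this module):

```python
def __iter__(self):
    epoch = 0
    while True:
        self.rng.shuffle(self.indices)
        for i in self.indices:
            yield self.data[i]
        epoch += 1
        logger.info(f'Finished dataset epoch {epoch}, reshuffling indices.')
```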
This needs refactoring. Suggestions (see the sketch after this list):
- Change `__init__` to follow the same usage as Anthropic/hh-rlhf, instead of parsing `data_dir` out of `path`.
- Remove `save_to_disk` and `load_from_disk`.
- Update the docstring.
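A rough sketch of the suggested `__init__`, mirroring the `datasets.load_dataset(path, data_dir=...)` usage of Anthropic/hh-rlhf (parameter names are illustrative):

```python
from datasets import load_dataset


def __init__(self, path: str, data_dir: str = None, split: str = 'train'):
    # Let `datasets` resolve the repo/subset instead of parsing data_dir out of path.
    self.dataset = load_dataset(path, data_dir=data_dir, split=split)
```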
xtuner/rlhf/dataset/txt_loader.py (Outdated)

```python
        self.epoch_index = 0

    def _init_in_data(self):
        print(f"========================= Init in data sampler =========================")
```
The print statements in this file need cleaning up; remove the unnecessary ones.
xtuner/rlhf/dataset/txt_loader.py (Outdated)

```python
prompt_datasets: list[str] = None,
pretrain_datasets: list[str] = None,
```
For discussion: it might be better to separate `prompt_datasets` and `pretrain_datasets`?
xtuner/rlhf/envs/prompt_utils.py (Outdated)
For discussion: whether to keep this, or change its form.
The original intent was to share it between RM training and PPO training, but with the current repo structure RM and PPO don't live together, so this no longer makes sense.
xtuner/rlhf/envs/txt_env.py (Outdated)

```python
def __init__(
    self,
    dataloader: IterableDataset,
```
Having `dataloader` actually be a `Dataset` doesn't feel very reasonable.
xtuner/rlhf/envs/txt_env.py (Outdated)

```python
        dataloader (IterableDataset): generate rl data iteratively
        reward_function: reward function that computes scalar reward for each episode  # noqa: E501
    """
    self.dataloader: IterableDataset = iter(dataloader)
```
If the `dataloader` argument is actually a `Dataset`, it would make more sense to wrap it in a `torch.utils.data.DataLoader` here, which also gives you more features for free.
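Sketch of the idea (argument values are illustrative):

```python
from torch.utils.data import DataLoader

# Wrap the incoming dataset in a DataLoader instead of calling iter() on it directly.
self.dataloader = iter(DataLoader(dataset, batch_size=None))
```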
xtuner/rlhf/envs/txt_env.py (Outdated)

```python
        'role':
        'assistant',
        'content':
        policyout.output_ans_str[i]
    }]
else:
    cur_rm_data = sample_data[i].message + [{
        'role':
        'assistant',
        'content':
        policyout.output_ans_str[i]
```
The formatting looks a bit off?
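Presumably the intended layout is closer to this (based on the `else` branch above):

```python
else:
    cur_rm_data = sample_data[i].message + [{
        'role': 'assistant',
        'content': policyout.output_ans_str[i]
    }]
```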
xtuner/rlhf/loss/actor_loss.py (Outdated)

```python
from ..policy_output import logprobs_from_logits


class ActorLoss(torch.nn.Module):
```
Suggested change:

```python
class PPOPolicyLoss(torch.nn.Module):
```
xtuner/rlhf/loss/actor_loss.py (Outdated)

```python
def forward(self, logits: torch.Tensor, labels: dict[str, Any]):
    """Forward function of ActorLoss.

    Args:
        logits (Tensor): Forward result of the model. Its shape may be varied.  # noqa: E501
            For packed forward: (micro_bsz * seqlen, 1), where micro_bsz = 1  # noqa: E501
            For non packed forward: (micro_bsz, seqlen, 1)

        labels (tuple[dict]): Label values which are split by pipeline
            schedule into pieces. The length of the list is micro_bsz. Each
            element is a dict, representing labels to a batch.

    Note:
        The parameter `labels` seems strange because of pj-colossalai's
        pipeline schedule mechanism. Labels are delivered to colosslai.Engine  # noqa: E501
        in List format, so pipeline schedule split it into micro_bsz pieces,  # noqa: E501
        and deliver them to loss_fn by `*args`.

    Returns:
        Tensor: Return the final loss
```
docstring mismatch
xtuner/rlhf/loss/actor_loss.py (Outdated)
For discussion: this code looks like it has been refactored, but it may be incompatible with `InterEvo`'s megatron-style training. Needs extra confirmation.
Also, should the `per_seq` and `per_token` modes both be kept, and do they need extra documentation?
xtuner/rlhf/policy_output.py (Outdated)

```python
if isinstance(v, torch.Tensor):
    if not torch.equal(v, vother):
        return False
elif isinstance(v, tuple):  # tuple(torch.Tensor)
```
Is there actually a case where `v` is a `Tuple[torch.Tensor]`? It looks like `v` can only be `torch.Tensor` or `None`.
If such a case does exist, the implementation of `to(self, device)` below needs to change as well.
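If the tuple case is real, `to(self, device)` would need something like this sketch as well:

```python
def to(self, device):
    for k, v in self.items():
        if isinstance(v, torch.Tensor):
            self[k] = v.to(device)
        elif isinstance(v, tuple):  # tuple(torch.Tensor)
            self[k] = tuple(
                t.to(device) if isinstance(t, torch.Tensor) else t for t in v)
    return self
```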
xtuner/rlhf/policy_output.py (Outdated)

```python
if len(self.keys()) != len(other.keys()):
    return False
for k, v in self.items():
    if k not in other:
        return False
    vother = other[k]
```
If `other`'s keys are a superset of `self`'s keys, this check still passes — is that the intended behavior?
If not, consider:

```python
if self.keys() != other.keys():
    return False
```
```python
if model_path == 'internlm/internlm2-chat-1_8b-sft':
    return InternLM2ForCausalLM
```
Is there a particular reason for this hardcoding?
```python
trainer_type = self.trainer_config.get('trainer_type',
                                       'huggingface').lower()
```
Suggested change:

```python
trainer_type = self.trainer_config.get('trainer_type',
                                       ENGINE_HUGGINGFACE).lower()
```
```python
def init_tokenizer_and_config(self, model_config):
    super().init_tokenizer_and_config(self.model_config)

    self.reward_token_id = self.tokenizer.pad_token_id
```
Defaulting `self.reward_token_id = self.tokenizer.pad_token_id` is clever, but a one-line comment explaining the reason would help.
Force-pushed from 9a4609f to 69fef1f (Compare)
```bash
# Launch the job. For the first run, it is recommended to set HF_ENDPOINT=https://hf-mirror.com to make dataset loading easier
HF_ENDPOINT=https://hf-mirror.com xtuner rlhf -c examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py
```
Remember to add a link to the detailed documentation on Read the Docs.
"""Post process sequence: tokenization & truncation.""" | ||
message_data = message['data'] | ||
new_meaasage_data = [] | ||
if self.message_type == 'prompt': |
move post-process to dataset class
```python
trajectories['rewards'] = rewards

# pretrain data
if self.pretrain_mes_iter is not None:
```
move pretrain loss tokenize to pretrain dataset