From 1ffe51e58d20d72a994bd0779ced783205988e55 Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Thu, 13 Jun 2024 17:49:08 +0800 Subject: [PATCH] add/fix pretrain_loss --- .gitignore | 1 + examples/rlhf/four_model_8gpu.py | 163 ++++--- examples/rlhf/four_model_vllm_8gpu.py | 160 ++++--- examples/rlhf/quick_start.md | 2 +- xtuner/rlhf/dataset/base.py | 8 +- xtuner/rlhf/dataset/txt_loader.py | 408 ++++++++---------- xtuner/rlhf/envs/prompt_utils.py | 46 ++ xtuner/rlhf/envs/txt_env.py | 100 ++--- xtuner/rlhf/loss/actor_loss.py | 8 +- xtuner/rlhf/loss/critic_loss.py | 8 +- xtuner/rlhf/loss/pretrain_loss.py | 39 +- xtuner/rlhf/main.py | 25 +- xtuner/rlhf/model_backend/generate_utils.py | 21 +- xtuner/rlhf/model_backend/hf_model_runner.py | 230 +++++----- xtuner/rlhf/model_server/base_model_server.py | 6 +- xtuner/rlhf/repeaters/base.py | 159 ++----- xtuner/rlhf/repeaters/running_mean_std.py | 38 ++ xtuner/rlhf/trainer/ppo.py | 124 ++++-- 18 files changed, 747 insertions(+), 799 deletions(-) create mode 100644 xtuner/rlhf/envs/prompt_utils.py create mode 100644 xtuner/rlhf/repeaters/running_mean_std.py diff --git a/.gitignore b/.gitignore index ffe3444b8..c13320a73 100644 --- a/.gitignore +++ b/.gitignore @@ -113,6 +113,7 @@ data *.pkl.json *.log.json work_dirs/ +rlhf_trainlog*/ # Pytorch *.pth diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/four_model_8gpu.py index 2d96a832d..9ccfc3c4c 100644 --- a/examples/rlhf/four_model_8gpu.py +++ b/examples/rlhf/four_model_8gpu.py @@ -1,5 +1,26 @@ import torch +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 256 +PRETRAIN_BATCH_SIZE = 32 + +GENERATE_MICRO_BATCH_SIZE = 16 +AC_INFER_MICRO_BATCH_SIZE = 8 +REF_INFER_MICRO_BATCH_SIZE = 8 +TRAIN_MICRO_BATCH_SIZE = 2 + +ZERO_STAGE = 3 +ACTOR_DP_SIZE = 2 +CRITIC_DP_SIZE = 2 +ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE + +MODEL_DTYPE = 'auto' + tokenizer_config = dict( pad_token_id=0, eos_token_id=92542, @@ -7,56 +28,54 @@ ) rollout_config = dict( - actor_micro_bs=32, - reward_micro_bs=32, - clip_reward_min=-5, - clip_reward_max=5, - max_new_tokens=10, - async_reward=True, + actor_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=True, generate_kwargs={ 'do_sample': True, 'temperature': 1.0, 'top_k': 0, 'top_p': 0.9, - 'pad_token_id': 0, - 'eos_token_id': 92542, - 'early_stopping': True, - 'num_beams': 1, 'min_new_tokens': 1, - }) + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) repeater_config = dict( - actor_micro_bs=8, - ref_micro_bs=8, - critic_micro_bs=32, - reward_scale=False, - fine_grained_rm=False, - value_ema=False, + actor_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + critic_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, answer_end_id=92542, norm_rewards=True, ) + train_config = dict( - ppo_minibatch=64, - value_minibatch=64, - actor_micro_bs=2, - critic_micro_bs=2, - pretrain_step=0, - save_interval=800, + actor_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_step=20, + save_interval=40, ) -critic_model_path = 'internlm/internlm2-chat-1_8b-sft' - model_configs = dict( 
actor=dict( model_path='internlm/internlm2-chat-1_8b-sft', model_type='actor', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype=torch.float32, + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=1e-6, @@ -65,14 +84,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -91,34 +110,21 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), - generator_config=dict(shared_with_trainer=True, ), - ), - reference=dict( - model_path='internlm/internlm2-chat-1_8b-sft', - model_type='reference', - use_flash_attn=False, - trainer_config=dict( - torch_dtype=torch.float32, - trainer_type='huggingface', - parallel=dict( - data=dict(size=2, mode='ddp'), - tensor=dict(size=1, mode='1d'), - pipeline=dict(size=1, interleaved_overlap=False), - sequence=False, - ), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': ACTOR_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, ), + generator_config=dict(shared_with_trainer=True, ), ), critic=dict( - model_path=critic_model_path, + model_path=None, model_type='critic', - use_flash_attn=False, trainer_config=dict( - torch_dtype='auto', + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=5e-6, @@ -127,14 +133,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -152,20 +158,36 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), ), reward=dict( - model_path=critic_model_path, + model_path=None, model_type='reward', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype='auto', + use_flash_attn=True, parallel=dict( - data=dict(size=2, mode='ddp'), + data=dict(size=1, mode='ddp'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, @@ -175,14 +197,23 @@ ) dataset_config = { - 'num_samples_each_epoch': - 64, - 'max_seq_len': - 1024, + 'prompt_samples_each_epoch': + PROMPT_BATCH_SIZE, + 'max_prompt_len': + MAX_PROMPT_LEN, + 'pretrain_samples_each_epoch': + PRETRAIN_BATCH_SIZE, + 
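# --- Illustrative sketch (not part of the patch): the DeepSpeed settings above are
# derived so that train_batch_size == micro_batch_per_gpu * grad_acc_steps * dp_size.
# A quick sanity check with the constants defined at the top of this config:
PROMPT_BATCH_SIZE = 256
PRETRAIN_BATCH_SIZE = 32
TRAIN_MICRO_BATCH_SIZE = 2
ACTOR_DP_SIZE = 2
CRITIC_DP_SIZE = 2

ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE
                           ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE       # 288 // 2 // 2 = 72
CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE  # 64

assert TRAIN_MICRO_BATCH_SIZE * ACTOR_GRADIENT_ACC_STEP * ACTOR_DP_SIZE \
    == PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE   # 2 * 72 * 2 == 288
assert TRAIN_MICRO_BATCH_SIZE * CRITIC_GRADIENT_ACC_STEP * CRITIC_DP_SIZE \
    == PROMPT_BATCH_SIZE                         # 2 * 64 * 2 == 256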
'max_pretrain_len': + MAX_PRETRAIN_LEN, 'random_seed': 1024, - 'ppo_datas': [ + "sample_strategy": "in_data", + "ratio_within_datasets": False, + 'prompt_datasets': [ 'Anthropic/hh-rlhf/helpful-base::1.0', 'Anthropic/hh-rlhf/harmless-base::0.5', ], + 'pretrain_datasets': [ + 'Anthropic/hh-rlhf/helpful-base::1.0', + ], } diff --git a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/four_model_vllm_8gpu.py index 654f57691..9d8ea67fe 100644 --- a/examples/rlhf/four_model_vllm_8gpu.py +++ b/examples/rlhf/four_model_vllm_8gpu.py @@ -1,5 +1,26 @@ import torch +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 256 +PRETRAIN_BATCH_SIZE = 32 + +GENERATE_MICRO_BATCH_SIZE = 16 +AC_INFER_MICRO_BATCH_SIZE = 8 +REF_INFER_MICRO_BATCH_SIZE = 8 +TRAIN_MICRO_BATCH_SIZE = 2 + +ZERO_STAGE = 3 +ACTOR_DP_SIZE = 2 +CRITIC_DP_SIZE = 2 +ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE + +MODEL_DTYPE = 'auto' + tokenizer_config = dict( pad_token_id=0, eos_token_id=92542, @@ -7,55 +28,54 @@ ) rollout_config = dict( - actor_micro_bs=32, - reward_micro_bs=32, - clip_reward_min=-5, - clip_reward_max=5, - max_new_tokens=10, - async_reward=True, + actor_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=True, generate_kwargs={ 'do_sample': True, 'temperature': 1.0, 'top_k': 0, 'top_p': 0.9, - 'pad_token_id': 0, - 'eos_token_id': 92542, - 'early_stopping': True, - 'num_beams': 1, 'min_new_tokens': 1, - }) + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) repeater_config = dict( - actor_micro_bs=8, - ref_micro_bs=8, - critic_micro_bs=32, - reward_scale=False, - fine_grained_rm=False, - value_ema=False, + actor_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + critic_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, answer_end_id=92542, norm_rewards=True, ) + train_config = dict( - ppo_minibatch=64, - value_minibatch=64, - actor_micro_bs=2, - critic_micro_bs=2, - pretrain_step=0, - save_interval=800, + actor_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_step=20, + save_interval=40, ) -critic_model_path = 'internlm/internlm2-chat-1_8b-sft' model_configs = dict( actor=dict( model_path='internlm/internlm2-chat-1_8b-sft', model_type='actor', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype=torch.float32, + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=1e-6, @@ -64,14 +84,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -90,10 +110,11 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': ACTOR_GRADIENT_ACC_STEP, + 
'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), generator_config=dict( shared_with_trainer=False, generator_type='vllm', @@ -105,28 +126,14 @@ ), ), ), - reference=dict( - model_path='internlm/internlm2-chat-1_8b-sft', - model_type='reference', - use_flash_attn=False, - trainer_config=dict( - torch_dtype=torch.float32, - trainer_type='huggingface', - parallel=dict( - data=dict(size=1, mode='ddp'), - tensor=dict(size=1, mode='1d'), - pipeline=dict(size=1, interleaved_overlap=False), - sequence=False, - ), - ), - ), critic=dict( - model_path=critic_model_path, + model_path=None, model_type='critic', - use_flash_attn=False, trainer_config=dict( - torch_dtype='auto', + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=5e-6, @@ -135,14 +142,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -160,18 +167,34 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), ), reward=dict( - model_path=critic_model_path, + model_path=None, model_type='reward', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype='auto', + use_flash_attn=True, parallel=dict( data=dict(size=1, mode='ddp'), tensor=dict(size=1, mode='1d'), @@ -183,14 +206,23 @@ ) dataset_config = { - 'num_samples_each_epoch': - 64, - 'max_seq_len': - 1024, + 'prompt_samples_each_epoch': + PROMPT_BATCH_SIZE, + 'max_prompt_len': + MAX_PROMPT_LEN, + 'pretrain_samples_each_epoch': + PRETRAIN_BATCH_SIZE, + 'max_pretrain_len': + MAX_PRETRAIN_LEN, 'random_seed': 1024, - 'ppo_datas': [ + # "sample_strategy": "in_data", + # "ratio_within_datasets": False, + 'prompt_datasets': [ 'Anthropic/hh-rlhf/helpful-base::1.0', 'Anthropic/hh-rlhf/harmless-base::0.5', ], + 'pretrain_datasets': [ + 'Anthropic/hh-rlhf/helpful-base::1.0', + ], } diff --git a/examples/rlhf/quick_start.md b/examples/rlhf/quick_start.md index 823b08ad2..8cc5cb494 100644 --- a/examples/rlhf/quick_start.md +++ b/examples/rlhf/quick_start.md @@ -10,7 +10,7 @@ pip install torch==2.1.2+cu118 torchvision --index-url https://download.pytorch. 
git clone https://github.com/2581543189/xtuner.git cd xtuner git checkout rlhf -pip install .[rlhf] +pip install '.[rlhf]' ``` ### step2: 使用单引擎(huggingface)启动 rlhf 任务 diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py index b64107d8f..9f9a5cb69 100644 --- a/xtuner/rlhf/dataset/base.py +++ b/xtuner/rlhf/dataset/base.py @@ -153,7 +153,7 @@ def __init__(self, sub_dataset_type='file', tokenizer=None, random_seed=1024, - ratio_within_datas=True): + ratio_within_datasets=True): self._task_group = [] for _task in task_groups: file_path, extra_info = _task.split('::')[0], _task.split('::')[1] @@ -194,9 +194,9 @@ def __init__(self, else: raise NotImplementedError('Cannot support filelist now.') self.random_seed = random_seed - self.ratio_within_datas = ratio_within_datas + self.ratio_within_datasets = ratio_within_datasets - if self.ratio_within_datas: + if self.ratio_within_datasets: sum_prob = sum([task['prob'] for task in self._task_group]) for task in self._task_group: task['prob'] = task['prob'] / sum_prob @@ -220,7 +220,7 @@ def _get_subset_by_ratio(self, dataset: Dataset, ratio: float, seed: int): def __iter__(self): """sample data one task by probs.""" - if self.ratio_within_datas: + if self.ratio_within_datasets: rng = random.Random(self.random_seed) probs = [task['prob'] for task in self._task_group] # Initialize task iterator diff --git a/xtuner/rlhf/dataset/txt_loader.py b/xtuner/rlhf/dataset/txt_loader.py index 5f524206f..cc20a54da 100644 --- a/xtuner/rlhf/dataset/txt_loader.py +++ b/xtuner/rlhf/dataset/txt_loader.py @@ -1,325 +1,267 @@ -"""Finetuning dataset.""" +""" Finetuning dataset. """ import random -from dataclasses import dataclass from typing import List - import numpy as np -from torch.utils.data import DataLoader, IterableDataset, RandomSampler - -from .base import InfiniteDataset, MultiSourceDatset +from dataclasses import dataclass +from torch.utils.data import IterableDataset, DataLoader, RandomSampler +from .base import MultiSourceDatset, InfiniteDataset @dataclass class Message: message: List[dict] - sys_meta: str = 'default' - rm_meta: str = 'default' + sys_meta: str = "default" + rm_meta: str = "default" token_ids: List[int] = None - mes_type: str = 'ppo' + mes_type: str = "prompt" class TxtMessageDataset(IterableDataset): - """Create sequences from dataset. - + """ Create sequences from dataset. 
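# --- Illustrative sketch (not part of the patch): the 'path::prob' spec used in
# prompt_datasets / pretrain_datasets is split on '::' by MultiSourceDatset and,
# with ratio_within_datasets=True, the per-task probabilities are normalized to
# sum to 1. `parse_dataset_specs` is a hypothetical helper name for this example.
def parse_dataset_specs(specs: list[str]) -> list[dict]:
    tasks = []
    for spec in specs:
        path, prob = spec.split('::')
        tasks.append({'filepath': path, 'prob': float(prob)})
    total = sum(task['prob'] for task in tasks)
    for task in tasks:
        task['prob'] /= total  # e.g. 1.0 and 0.5 become 2/3 and 1/3
    return tasks


parse_dataset_specs([
    'Anthropic/hh-rlhf/helpful-base::1.0',
    'Anthropic/hh-rlhf/harmless-base::0.5',
])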
Args: - sample_strategy (str) ["in_batch", "in_data"]: - "in_batch": - sample data by ratio for every single training batch - "in_data": - merge all data by ratio first and then sample training batch + sample_strategy (str) ["in_batch", "in_data"]: "in_batch": sample data by ratio for every single training batch + "in_data": merge all data by ratio first and then sample training batch """ - def __init__(self, - ppo_datas: list[str] = None, - pt_datas: list[str] = None, + prompt_datasets: list[str] = None, + pretrain_datasets: list[str] = None, tokenizer=None, - max_seq_len: int = 4096, - num_samples_each_epoch: int = 64, - pt_data_samples: int = 0, + max_prompt_len: int = 4096, + max_pretrain_len: int = 4096, + prompt_samples_each_epoch: int = 64, + pretrain_samples_each_epoch: int = 0, random_seed: int = 110, - sample_strategy: str = 'in_batch', - ratio_within_datas: bool = True, - **kwargs): - - assert sample_strategy in [ - 'in_batch', 'in_data' - ], f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" # noqa: E501 + sample_strategy: str = "in_batch", + ratio_within_datasets: bool = True, + **kwargs + ): + assert sample_strategy in ["in_batch", "in_data"], f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" self.sample_strategy = sample_strategy - assert ppo_datas is not None, '[Data error] Specify your data task config' # noqa: E501 + assert prompt_datasets is not None, "[Data error] Specify your data task config" self.tokenizer = tokenizer - assert self.tokenizer.chat_template is not None, 'Make sure tokenizer has chat_template.' # noqa: E501 + assert self.tokenizer.chat_template is not None, "Make sure tokenizer has chat_template." - self.ppo_message_dataset = MultiSourceDatset( - task_groups=ppo_datas, - sub_dataset_type='file', - tokenizer=self.tokenizer, - ratio_within_datas=ratio_within_datas) - if pt_data_samples is not None and pt_data_samples != 0: - assert pt_datas is not None, f'[PT DATA error] samples num {pt_data_samples}, while pt_datas is None' # noqa: E501 - self.pt_message_dataset = MultiSourceDatset( - task_groups=pt_datas, - sub_dataset_type='file', - tokenizer=self.tokenizer, - ratio_within_datas=ratio_within_datas) - self.pt_data_per_epoch = pt_data_samples - self.ppo_data_per_epoch = num_samples_each_epoch - self.pt_data_per_epoch # noqa: E501 + self.prompt_message_dataset = MultiSourceDatset(task_groups=prompt_datasets, + sub_dataset_type="file", + tokenizer=self.tokenizer, + ratio_within_datasets=ratio_within_datasets + ) + if pretrain_samples_each_epoch is not None and pretrain_samples_each_epoch > 0: + assert pretrain_datasets is not None, f"[PT DATA error] samples num {pretrain_samples_each_epoch}, while pretrain_datasets is None" + self.pt_message_dataset = MultiSourceDatset(task_groups=pretrain_datasets, + sub_dataset_type="file", + tokenizer=self.tokenizer, + ratio_within_datasets=ratio_within_datasets + ) + self.pretrain_samples_each_epoch = pretrain_samples_each_epoch else: self.pt_message_dataset = None - self.pt_data_per_epoch = 0 - self.ppo_data_per_epoch = num_samples_each_epoch - - self.max_seq_len = max_seq_len - self.num_samples_each_epoch = num_samples_each_epoch + self.pretrain_samples_each_epoch = 0 + self.prompt_samples_each_epoch = prompt_samples_each_epoch + self.max_prompt_len = max_prompt_len + self.max_pretrain_len = max_pretrain_len + self.num_samples_each_epoch = self.pretrain_samples_each_epoch + self.prompt_samples_each_epoch + self.random_seed = random_seed self.rng = 
random.Random(self.random_seed) np.random.seed(self.random_seed) random.seed(self.random_seed) - if self.sample_strategy == 'in_batch': + if self.sample_strategy == "in_batch": self._init_in_batch() - elif self.sample_strategy == 'in_data': + elif self.sample_strategy == "in_data": self._init_in_data() else: - raise NotImplementedError( - f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" # noqa: E501 - ) + raise NotImplementedError(f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}") self.epoch_index = 0 def _init_in_data(self): - print( - '========================= Init in data sampler =========================' # noqa: E501 - ) - if self.pt_data_per_epoch != 0: - assert hasattr(self.pt_message_dataset, 'all_dataset') + print(f"========================= Init in data sampler =========================") + if self.pretrain_samples_each_epoch != 0: + assert hasattr(self.pt_message_dataset, "all_dataset") pt_sampler = RandomSampler(self.pt_message_dataset.all_dataset) - self.pt_dataloader = iter( - DataLoader( - self.pt_message_dataset.all_dataset, - collate_fn=lambda x: x, - sampler=pt_sampler, - batch_size=self.pt_data_per_epoch)) - print( - f'[PT data] pretrain data per epoch: {self.pt_data_per_epoch}') - - assert hasattr(self.ppo_message_dataset, 'all_dataset') - prompt_sampler = RandomSampler(self.ppo_message_dataset.all_dataset) - self.prompt_dataloader = iter( - DataLoader( - self.ppo_message_dataset.all_dataset, - collate_fn=lambda x: x, - sampler=prompt_sampler, - batch_size=self.ppo_data_per_epoch)) + self.pt_dataloader = iter(DataLoader( + self.pt_message_dataset.all_dataset, collate_fn=lambda x: x, sampler=pt_sampler, batch_size=self.pretrain_samples_each_epoch + )) + print(f"[PT data] pretrain data per epoch: {self.pretrain_samples_each_epoch}") - print(f'[PPO data] ppo data per epoch: {self.ppo_data_per_epoch}') - print( - f'[Txt] Training dataset initialized, random seed {self.random_seed}.\n' # noqa: E501 - ) + assert hasattr(self.prompt_message_dataset, "all_dataset") + prompt_sampler = RandomSampler(self.prompt_message_dataset.all_dataset) + self.prompt_dataloader = iter(DataLoader( + self.prompt_message_dataset.all_dataset, collate_fn=lambda x: x, sampler=prompt_sampler, batch_size=self.prompt_samples_each_epoch + )) + print(f"[Prompt data] prompt data per epoch: {self.prompt_samples_each_epoch}") + print(f"[Txt] Training dataset initialized, random seed {self.random_seed}.\n") + def yield_in_data(self): - print( - '========================= yield data from data sampler =========================' # noqa: E501 - ) + print(f"========================= yield data from data sampler =========================") batch_sequence = [] - ppo_sequence, pt_sequence = [], [] - if self.pt_data_per_epoch != 0: - pt_batch_messages = next(self.pt_dataloader) - for index, message in enumerate(pt_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='pt') + prompt_sequence, pretrain_sequence = [], [] + if self.pretrain_samples_each_epoch != 0: + pretrain_batch_messages = next(self.pt_dataloader) + for index, message in enumerate(pretrain_batch_messages): + sequence = self._postprocess_sequence(message, mes_type="pretrain") if sequence is not None: - assert sequence.mes_type == 'pt', f'Data type should be pt, but get {sequence.mes_type}' # noqa: E501 - pt_sequence.append(sequence) - if len(pt_sequence) == self.pt_data_per_epoch: + assert sequence.mes_type == 'pretrain', f"Data type should be pretrain, but get 
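# --- Illustrative sketch (not part of the patch): under the 'in_data' strategy,
# _init_in_data wraps the merged dataset in a single DataLoader whose batch_size
# equals the per-epoch sample count, so each next() yields one rollout epoch.
# The toy list below stands in for `.all_dataset`.
from torch.utils.data import DataLoader, RandomSampler

all_dataset = [{'data': i} for i in range(1000)]
prompt_samples_each_epoch = 256

prompt_loader = iter(
    DataLoader(
        all_dataset,
        collate_fn=lambda x: x,                 # keep raw samples as a list
        sampler=RandomSampler(all_dataset),
        batch_size=prompt_samples_each_epoch))

epoch_batch = next(prompt_loader)               # 256 shuffled samples for this epoch
assert len(epoch_batch) == prompt_samples_each_epoch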
{sequence.mes_type}" + pretrain_sequence.append(sequence) + if len(pretrain_sequence) == self.pretrain_samples_each_epoch: break - assert len( - pt_sequence - ) == self.pt_data_per_epoch, f'{len(pt_sequence)} != {self.pt_data_per_epoch}' # noqa: E501 + assert len(pretrain_sequence) == self.pretrain_samples_each_epoch, f"{len(pretrain_sequence)} != {self.pretrain_samples_each_epoch}" - ppo_batch_messages = next(self.prompt_dataloader) - for index, message in enumerate(ppo_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='ppo') + prompt_batch_messages = next(self.prompt_dataloader) + for index, message in enumerate(prompt_batch_messages): + if message is None: + continue + sequence = self._postprocess_sequence(message, mes_type="prompt") if sequence is not None: - assert sequence.mes_type == 'ppo', f'Data type should be ppo. but get {sequence.mes_type}' # noqa: E501 - ppo_sequence.append(sequence) - if len(ppo_sequence) == self.ppo_data_per_epoch: + assert sequence.mes_type == 'prompt', f"Data type should be prompt. but get {sequence.mes_type}" + prompt_sequence.append(sequence) + if len(prompt_sequence) == self.prompt_samples_each_epoch: break - if len(ppo_sequence) < self.ppo_data_per_epoch: - missed = self.ppo_data_per_epoch - len(ppo_sequence) - print( - f'[Warning] {missed} dirty data, use {missed} data from sampled data...' # noqa: E501 - ) + # TODO, len(prompt_sequence) < self.prompt_samples_each_epoch, random sample from chosen data + if len(prompt_sequence) < self.prompt_samples_each_epoch: + missed = self.prompt_samples_each_epoch - len(prompt_sequence) + print(f"[Warning] {missed} dirty data, use {missed} data from sampled data...") for i in range(missed): - ppo_sequence.append(ppo_sequence[i]) + prompt_sequence.append(prompt_sequence[i]) - assert len( - ppo_sequence - ) == self.ppo_data_per_epoch, f'{len(ppo_sequence)} == {self.ppo_data_per_epoch}' # noqa: E501 + assert len(prompt_sequence) == self.prompt_samples_each_epoch, f"{len(prompt_sequence)} == {self.prompt_samples_each_epoch}" - print( - f'prepare TxtMessageDataset done: {len(ppo_sequence)} ppo & {len(pt_sequence)} pretrain, for epoch {self.epoch_index}.' 
# noqa: E501 - ) - batch_sequence = ppo_sequence + pt_sequence - assert len( - batch_sequence - ) == self.num_samples_each_epoch, '[Epoch {self.epoch_index}] Wrong data len' # noqa: E501 + print(f"prepare TxtMessageDataset done: {len(prompt_sequence)} prompt & {len(pretrain_sequence)} pretrain, for epoch {self.epoch_index}.") + batch_sequence = prompt_sequence + pretrain_sequence + assert len(batch_sequence) == self.num_samples_each_epoch, "[Epoch {self.epoch_index}] Wrong data len" return batch_sequence def _init_in_batch(self): - print( - '========================= Init in batch sampler =========================' # noqa: E501 - ) + print(f"========================= Init in batch sampler =========================") samples_cnts = [] pt_data_len = 0 - if self.pt_data_per_epoch != 0: + if self.pretrain_samples_each_epoch != 0: for task in self.pt_message_dataset._task_group: - task['target_num_each_epoch'] = int( - task['prob'] * self.pt_data_per_epoch + 0.5) + 1 - inner_dataset = InfiniteDataset(task['dataset'], self.rng) - task['iterator'] = iter(inner_dataset) - samples_cnts.append(task['target_num_each_epoch']) - print( - f"[PT data] {task['filepath']}: task prob: {task['prob']}, " # noqa: E501 - f'ori number of messages: {len(inner_dataset.data)}, ' - f"target_num_each_epoch: {task['target_num_each_epoch']}" - ) # noqa: E501 + task["target_num_each_epoch"] = int(task["prob"] * self.pretrain_samples_each_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task["dataset"], self.rng) + task["iterator"] = iter(inner_dataset) + samples_cnts.append(task["target_num_each_epoch"]) + print(f"[Pretrain data] {task['filepath']}: task prob: {task['prob']}, " + f"ori number of messages: {len(inner_dataset.data)}, " + f"target_num_each_epoch: {task['target_num_each_epoch']}") pt_data_len = sum(samples_cnts) - assert pt_data_len >= self.pt_data_per_epoch, f'Make sure there are enough pretrain data, {pt_data_len} >= {self.pt_data_per_epoch}' # noqa: E501 - print( - f'[PT data] pretrain data per epoch: {self.pt_data_per_epoch}, sampled {pt_data_len}' # noqa: E501 - ) - for task in self.ppo_message_dataset._task_group: - task['target_num_each_epoch'] = int( - task['prob'] * self.ppo_data_per_epoch + 0.5) + 1 - inner_dataset = InfiniteDataset(task['dataset'], self.rng) - task['iterator'] = iter(inner_dataset) - samples_cnts.append(task['target_num_each_epoch']) - print(f"{task['filepath']}: task prob: {task['prob']}, " - f'ori number of messages: {len(inner_dataset.data)}, ' - f"target_num_each_epoch: {task['target_num_each_epoch']}") - assert ( - sum(samples_cnts) - pt_data_len - ) >= self.ppo_data_per_epoch, 'Make sure there are enough ppo datas' - print( - f'[PPO data] ppo data per epoch: {self.ppo_data_per_epoch}, sampled: {sum(samples_cnts) - pt_data_len}' # noqa: E501 - ) + # TODO + assert pt_data_len >= self.pretrain_samples_each_epoch, f"Make sure there are enough pretrain datas, {pt_data_len} >= {self.pretrain_samples_each_epoch}" + print(f"[PT data] pretrain data per epoch: {self.pretrain_samples_each_epoch}, sampled {pt_data_len}") - if sum(samples_cnts) <= self.num_samples_each_epoch: - print( - f'[Txt loader] Warning!!! 
sample nums {sum(samples_cnts)} <= samples {self.num_samples_each_epoch}' # noqa: E501 - ) - print( - f'[Txt] Training dataset initialized, random seed {self.random_seed}.\n' # noqa: E501 - ) + for task in self.prompt_message_dataset._task_group: + task["target_num_each_epoch"] = int(task["prob"] * self.prompt_samples_each_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task["dataset"], self.rng) + task["iterator"] = iter(inner_dataset) + samples_cnts.append(task["target_num_each_epoch"]) + print(f"{task['filepath']}: task prob: {task['prob']}, " + f"ori number of messages: {len(inner_dataset.data)}, " + f"target_num_each_epoch: {task['target_num_each_epoch']}") + assert (sum(samples_cnts) - pt_data_len) >= self.prompt_samples_each_epoch, "Make sure there are enough prompt datas" + print(f"[Prompt data] prompt data per epoch: {self.prompt_samples_each_epoch}, sampled: {sum(samples_cnts) - pt_data_len}") + assert sum(samples_cnts) >= self.num_samples_each_epoch, "[Dataset init] sample num error" + # if sum(samples_cnts) <= self.num_samples_each_epoch: + # print(f"[Txt loader] Warning!!! sample nums {sum(samples_cnts)} <= samples {self.num_samples_each_epoch}") + print(f"[Txt] Training dataset initialized, random seed {self.random_seed}.\n") + def yield_in_batch(self): - print( - '========================= yield data from batch sampler =========================' # noqa: E501 - ) + print(f"========================= yield data from batch sampler =========================") batch_sequence = [] - ppo_sequence, pt_sequence = [], [] + prompt_sequence, pretrain_sequence = [], [] # epoch_rng only use in this epoch. epoch_rng = np.random.RandomState(self.epoch_index) # prepare epoch data - if self.pt_data_per_epoch != 0: - pt_batch_messages = [] + # print(f"prepare TxtMessageDataset for epoch {self.epoch_index}...") + if self.pretrain_samples_each_epoch != 0 : + pretrain_batch_messages = [] for task in self.pt_message_dataset._task_group: messages = [] - for _ in range(task['target_num_each_epoch']): - messages.append(next(task['iterator'])) - print( - f"[PT] prepare {len(messages)} data from {task['filepath']}" # noqa: E501 - ) + for _ in range(task["target_num_each_epoch"]): + messages.append(next(task["iterator"])) + print(f"[Pretrain] prepare {len(messages)} data from {task['filepath']}") epoch_rng.shuffle(messages) - pt_batch_messages.extend(messages) - epoch_rng.shuffle(pt_batch_messages) - for index, message in enumerate(pt_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='pt') + pretrain_batch_messages.extend(messages) + # if len(pretrain_batch_messages) == self.pretrain_samples_each_epoch: + # break + epoch_rng.shuffle(pretrain_batch_messages) + for index, message in enumerate(pretrain_batch_messages): + sequence = self._postprocess_sequence(message, mes_type="pretrain") if sequence is not None: - assert sequence.mes_type == 'pt', f'Data type should be pt, but get {sequence.mes_type}' # noqa: E501 - pt_sequence.append(sequence) - if len(pt_sequence) == self.pt_data_per_epoch: + assert sequence.mes_type == 'pretrain', f"Data type should be pretrain, but get {sequence.mes_type}" + pretrain_sequence.append(sequence) + if len(pretrain_sequence) == self.pretrain_samples_each_epoch: break - assert len( - pt_sequence - ) == self.pt_data_per_epoch, f'{len(pt_sequence)} != {self.pt_data_per_epoch}' # noqa: E501 + assert len(pretrain_sequence) == self.pretrain_samples_each_epoch, f"{len(pretrain_sequence)} != {self.pretrain_samples_each_epoch}" - ppo_batch_messages = [] - 
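# --- Illustrative sketch (not part of the patch): the 'in_batch' sampler assigns
# every source a fixed per-epoch quota, target_num_each_epoch, by rounding
# prob * samples_each_epoch and adding one so no source starves. With the ratios
# from the example configs (1.0 and 0.5, normalized to 2/3 and 1/3):
prompt_samples_each_epoch = 256
tasks = [
    {'filepath': 'helpful-base', 'prob': 2 / 3},
    {'filepath': 'harmless-base', 'prob': 1 / 3},
]
for task in tasks:
    task['target_num_each_epoch'] = int(
        task['prob'] * prompt_samples_each_epoch + 0.5) + 1

# helpful-base -> 172, harmless-base -> 86; 172 + 86 = 258 >= 256
assert sum(t['target_num_each_epoch'] for t in tasks) >= prompt_samples_each_epoch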
for task in self.ppo_message_dataset._task_group: + prompt_batch_messages = [] + for task in self.prompt_message_dataset._task_group: messages = [] - for _ in range(task['target_num_each_epoch']): - messages.append(next(task['iterator'])) - print( - f"[PPO] prepare {len(messages)} data from {task['filepath']}") + for _ in range(task["target_num_each_epoch"]): + messages.append(next(task["iterator"])) + print(f"[Prompt] prepare {len(messages)} data from {task['filepath']}") epoch_rng.shuffle(messages) - ppo_batch_messages.extend(messages) - epoch_rng.shuffle(ppo_batch_messages) - for index, message in enumerate(ppo_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='ppo') + prompt_batch_messages.extend(messages) + epoch_rng.shuffle(prompt_batch_messages) + for index, message in enumerate(prompt_batch_messages): + sequence = self._postprocess_sequence(message, mes_type="prompt") if sequence is not None: - assert sequence.mes_type == 'ppo', f'Data type should be ppo. but get {sequence.mes_type}' # noqa: E501 - ppo_sequence.append(sequence) - if len(ppo_sequence) == self.ppo_data_per_epoch: + assert sequence.mes_type == 'prompt', f"Data type should be prompt. but get {sequence.mes_type}" + prompt_sequence.append(sequence) + if len(prompt_sequence) == self.prompt_samples_each_epoch: break - assert len( - ppo_sequence - ) == self.ppo_data_per_epoch, f'{len(ppo_sequence)} == {self.ppo_data_per_epoch}' # noqa: E501 + assert len(prompt_sequence) == self.prompt_samples_each_epoch, f"{len(prompt_sequence)} == {self.prompt_samples_each_epoch}" - print( - f'prepare TxtMessageDataset done: {len(ppo_sequence)} ppo & {len(pt_sequence)} pretrain, for epoch {self.epoch_index}.' # noqa: E501 - ) - batch_sequence = ppo_sequence + pt_sequence - assert len( - batch_sequence - ) == self.num_samples_each_epoch, '[Epoch {self.epoch_index}] Wrong data len' # noqa: E501 + print(f"prepare TxtMessageDataset done: {len(prompt_sequence)} prompt & {len(pretrain_sequence)} pretrain, for epoch {self.epoch_index}.") + batch_sequence = prompt_sequence + pretrain_sequence + assert len(batch_sequence) == self.num_samples_each_epoch, "[Epoch {self.epoch_index}] Wrong data len" return batch_sequence def __iter__(self): while True: - if self.sample_strategy == 'in_batch': + if self.sample_strategy == "in_batch": yield self.yield_in_batch() - elif self.sample_strategy == 'in_data': + elif self.sample_strategy == "in_data": yield self.yield_in_data() self.epoch_index += 1 - def _postprocess_sequence(self, message, mes_type='ppo'): + def _postprocess_sequence(self, message, mes_type=None): """Post process sequence: tokenization & truncation.""" message_data = message['data'] new_meaasage_data = [] - if mes_type == 'ppo': + if mes_type == "prompt": for _ in reversed(range(len(message_data))): - if message_data[_]['role'] == 'user': + if message_data[_]["role"] == "user": new_meaasage_data = message_data[:_ + 1] break - assert new_meaasage_data[-1][ - 'role'] == 'user', f'ppo data last role must user, {new_meaasage_data}' # noqa: E501 - token_ids = self.tokenizer.apply_chat_template( - new_meaasage_data, - tokenize=True, - add_generation_prompt=True, - return_tensors='pt') - elif mes_type == 'pt': + assert new_meaasage_data[-1]["role"] == "user", f"prompt data last role must user, {new_meaasage_data}" + token_ids = self.tokenizer.apply_chat_template(new_meaasage_data, tokenize=True, add_generation_prompt=True, return_tensors="pt") + if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_prompt_len: + # 
TODO truncation?? + # raise RuntimeError(f"token_ids is too long: {token_ids.shape[-1]}") + print(f"[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...") + return None + elif mes_type == "pretrain": for _ in reversed(range(len(message_data))): - if message_data[_]['role'] == 'assistant': + if message_data[_]["role"] == "assistant": new_meaasage_data = message_data[:_ + 1] break - assert new_meaasage_data[-1][ - 'role'] == 'assistant', f'pretrain data last role must assistant, {new_meaasage_data}' # noqa: E501 - token_ids = self.tokenizer.apply_chat_template( - new_meaasage_data, - tokenize=True, - add_generation_prompt=False, - return_tensors='pt') - if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_seq_len: - print( - f'[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...' # noqa: E501 - ) - return None - return Message( - message=new_meaasage_data, - token_ids=token_ids, - sys_meta=message['sys_meta'], - rm_meta=message['rm_meta'], - mes_type=mes_type) + assert new_meaasage_data[-1]["role"] == "assistant", f"pretrain data last role must assistant, {new_meaasage_data}" + token_ids = self.tokenizer.apply_chat_template(new_meaasage_data, tokenize=True, add_generation_prompt=False, return_tensors="pt") + + if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_pretrain_len: + # TODO truncation?? + # raise RuntimeError(f"token_ids is too long: {token_ids.shape[-1]}") + print(f"[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...") + return None + return Message(message=new_meaasage_data, + token_ids=token_ids, + sys_meta=message['sys_meta'], + rm_meta=message['rm_meta'], + mes_type=mes_type) diff --git a/xtuner/rlhf/envs/prompt_utils.py b/xtuner/rlhf/envs/prompt_utils.py new file mode 100644 index 000000000..54dfab3cf --- /dev/null +++ b/xtuner/rlhf/envs/prompt_utils.py @@ -0,0 +1,46 @@ + + +META_PROMPT = { + 'default': + '', + 'helpful': + 'Consider the helpfulness of the given sentence. ' + 'A helpful sentence is one that provides value, offers insight, or assists the user in some way. ' # noqa: E501 + 'It may solve a problem, answer a question, provide relevant information, or offer guidance. ' # noqa: E501 + "It should directly address the user's needs or inquiries and is ideally actionable, " # noqa: E501 + 'meaning the user can use the information to make decisions or take steps towards their goal. ' # noqa: E501 + 'The score should be higher for sentences that are more beneficial and informative to the user. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'harmless': + 'Evaluate the harmlessness of the provided sentence. ' + 'A harmless sentence is one that does not offend, insult, or cause distress to the user, ' # noqa: E501 + "and does not mislead or give false information. It respects the user's feelings, beliefs, " # noqa: E501 + 'and values, and does not provoke conflict or negative emotions. ' + 'Additionally, a harmless sentence should not incite harmful actions or behaviors. ' # noqa: E501 + 'It should uphold the principles of respect, honesty, and fairness. ' + 'Sentences that are more considerate, respectful, and accurate should receive higher scores on this scale. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'hhh': + 'Evaluate the given sentence based on its helpfulness, harmlessness, and honesty. ' # noqa: E501 + 'A helpful sentence provides value, solves problems, or aids the user in an actionable manner. 
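# --- Illustrative sketch (not part of the patch): the two branches of
# _postprocess_sequence differ only in where the message list is truncated and
# whether a generation prompt is appended. Condensed below, assuming any Hugging
# Face tokenizer with a chat_template; `tokenize_message` is a hypothetical name.
def tokenize_message(tokenizer, message_data, mes_type, max_len):
    if mes_type == 'prompt':
        # keep everything up to the last user turn and append a generation prompt
        last = max(i for i, m in enumerate(message_data) if m['role'] == 'user')
        add_generation_prompt = True
    else:  # 'pretrain'
        # keep everything up to the last assistant turn, no generation prompt
        last = max(i for i, m in enumerate(message_data) if m['role'] == 'assistant')
        add_generation_prompt = False
    kept = message_data[:last + 1]
    token_ids = tokenizer.apply_chat_template(
        kept,
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
        return_tensors='pt')
    if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > max_len:
        return None  # too short or too long: the loader skips it
    return kept, token_ids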
' # noqa: E501 + "A harmless sentence respects the user's feelings and beliefs, avoids causing distress, " # noqa: E501 + 'and does not mislead or incite harmful behaviors. An honest sentence delivers reliable and true information, ' # noqa: E501 + 'presents facts objectively, and demonstrates integrity and authenticity. Higher scores should be assigned ' # noqa: E501 + 'to sentences that embody these characteristics more strongly. ' + 'Here is the conversation:\n\n', + 'summarization': + 'As a language model performing a summarization task, your goal is to generate a summary that ' # noqa: E501 + 'accurately, succinctly, and coherently encapsulates the key details of the source text. Ensure relevance to ' # noqa: E501 + 'the original material, completeness of main points, and logical structure. Maintain conciseness and high ' # noqa: E501 + 'linguistic standards. Ensure only the summary is outputted, refraining from adding extraneous comments or ' # noqa: E501 + 'remarks. Here is the original material:\n\n', + 'reddit': + 'Imagine you are a knowledgeable and friendly Reddit user. ' + 'A fellow Redditor has just shared a post seeking feedback, advice, or input. ' # noqa: E501 + 'Please read the post and provide a thoughtful, informative, and respectful response, ' # noqa: E501 + 'just as if you were replying on the platform. Here is the post:\n\n', + 'latex': + 'When mathematical content appears in the conversation, please use latex format to express the mathematical content. Here is the conversation:\n\n', # noqa: E501 + 'math_ci': + "Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:\n- Just write jupyter code to solve the problem without giving your thought;\n- Present the final result in LaTeX using a '\\boxed\\{{}}' without any units. \n", # noqa: E501 +} diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 1a01d0741..7490151b8 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -6,51 +6,7 @@ from torch.utils.data import IterableDataset from ..model_server.base_model_server import BaseModelServer - -META_PROMPT = { - 'default': - '', - 'helpful': - 'Consider the helpfulness of the given sentence. ' - 'A helpful sentence is one that provides value, offers insight, or assists the user in some way. ' # noqa: E501 - 'It may solve a problem, answer a question, provide relevant information, or offer guidance. ' # noqa: E501 - "It should directly address the user's needs or inquiries and is ideally actionable, " # noqa: E501 - 'meaning the user can use the information to make decisions or take steps towards their goal. ' # noqa: E501 - 'The score should be higher for sentences that are more beneficial and informative to the user. ' # noqa: E501 - 'Here is the conversation:\n\n', - 'harmless': - 'Evaluate the harmlessness of the provided sentence. ' - 'A harmless sentence is one that does not offend, insult, or cause distress to the user, ' # noqa: E501 - "and does not mislead or give false information. It respects the user's feelings, beliefs, " # noqa: E501 - 'and values, and does not provoke conflict or negative emotions. ' - 'Additionally, a harmless sentence should not incite harmful actions or behaviors. ' # noqa: E501 - 'It should uphold the principles of respect, honesty, and fairness. ' - 'Sentences that are more considerate, respectful, and accurate should receive higher scores on this scale. 
' # noqa: E501 - 'Here is the conversation:\n\n', - 'hhh': - 'Evaluate the given sentence based on its helpfulness, harmlessness, and honesty. ' # noqa: E501 - 'A helpful sentence provides value, solves problems, or aids the user in an actionable manner. ' # noqa: E501 - "A harmless sentence respects the user's feelings and beliefs, avoids causing distress, " # noqa: E501 - 'and does not mislead or incite harmful behaviors. An honest sentence delivers reliable and true information, ' # noqa: E501 - 'presents facts objectively, and demonstrates integrity and authenticity. Higher scores should be assigned ' # noqa: E501 - 'to sentences that embody these characteristics more strongly. ' - 'Here is the conversation:\n\n', - 'summarization': - 'As a language model performing a summarization task, your goal is to generate a summary that ' # noqa: E501 - 'accurately, succinctly, and coherently encapsulates the key details of the source text. Ensure relevance to ' # noqa: E501 - 'the original material, completeness of main points, and logical structure. Maintain conciseness and high ' # noqa: E501 - 'linguistic standards. Ensure only the summary is outputted, refraining from adding extraneous comments or ' # noqa: E501 - 'remarks. Here is the original material:\n\n', - 'reddit': - 'Imagine you are a knowledgeable and friendly Reddit user. ' - 'A fellow Redditor has just shared a post seeking feedback, advice, or input. ' # noqa: E501 - 'Please read the post and provide a thoughtful, informative, and respectful response, ' # noqa: E501 - 'just as if you were replying on the platform. Here is the post:\n\n', - 'latex': - 'When mathematical content appears in the conversation, please use latex format to express the mathematical content. Here is the conversation:\n\n', # noqa: E501 - 'math_ci': - "Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:\n- Just write jupyter code to solve the problem without giving your thought;\n- Present the final result in LaTeX using a '\\boxed\\{{}}' without any units. 
\n", # noqa: E501 -} +from .prompt_utils import META_PROMPT class TxtEnv: @@ -62,8 +18,6 @@ def __init__( max_new_tokens: int = 1024, actor_micro_bs: int = 32, reward_micro_bs: int = 32, - clip_reward_min: int = -5, - clip_reward_max: int = 5, reward_function: BaseModelServer = None, async_reward: bool = True, generate_kwargs: dict = None, @@ -80,15 +34,13 @@ def __init__( self.max_new_tokens = max_new_tokens self.actor_micro_bs = actor_micro_bs self.reward_micro_bs = reward_micro_bs - self.clip_reward_min = clip_reward_min - self.clip_reward_max = clip_reward_max self.async_reward = async_reward self.generate_kwargs: dict = generate_kwargs def rollout(self, policy_model: BaseModelServer, display=False): sample_data = deepcopy(next(self.dataloader)) - ppo_input_messages = [] - pt_input_messages = [] + prompt_input_messages = [] + pretrain_input_messages = [] for data in sample_data: if data.sys_meta != 'default': message = deepcopy([{ @@ -97,23 +49,23 @@ def rollout(self, policy_model: BaseModelServer, display=False): }] + data.message) else: message = deepcopy(data.message) - if data.mes_type == 'ppo': - ppo_input_messages.append(message) - elif data.mes_type == 'pt': - pt_input_messages.append(message) + if data.mes_type == 'prompt': + prompt_input_messages.append(message) + elif data.mes_type == 'pretrain': + pretrain_input_messages.append(message) else: raise TypeError(f'Wrong message type {data.mes_type}') - # ppo data + # prompt data s_t = time.time() - print(f'[For Generate]: {ppo_input_messages[0]}') + print(f'[For Generate]: {prompt_input_messages[0]}') trajectories = policy_model.generate( - inputs=ppo_input_messages, + inputs=prompt_input_messages, micro_batch_size=self.actor_micro_bs, step=self.max_new_tokens, output_str=True, generate_kwargs=self.generate_kwargs) logger.info( - f'[actor generate] duration: {round(time.time() - s_t, 2)} s, len(inputs): {len(ppo_input_messages)} ' # noqa: E501 + f'[actor generate] duration: {round(time.time() - s_t, 2)} s, len(inputs): {len(prompt_input_messages)} ' # noqa: E501 ) if self.async_reward: @@ -122,25 +74,23 @@ def rollout(self, policy_model: BaseModelServer, display=False): trajectories['reward_output_ref'] = reward_output_ref else: rewards = self.get_reward(sample_data, trajectories) - clipped_rewards = torch.clamp( - rewards, min=self.clip_reward_min, max=self.clip_reward_max) trajectories['rewards'] = rewards - trajectories['clipped_rewards'] = clipped_rewards # pretrain data - if len(pt_input_messages) > 0: - pt_inputs = [ - policy_model.tokenizer.apply_chat_template( - mes, - tokenize=False, - add_generation_prompt=False, - return_tensors='pt') for mes in pt_input_messages - ] - trajectories.pt_data = policy_model.tokenizer( - pt_inputs, return_tensors='pt', padding=True) + if len(pretrain_input_messages) > 0: + from ..tokenizer import tokenizer_utils + pretrain_input_ids, pretrain_attention_mask = tokenizer_utils.encode( + pretrain_input_messages, policy_model.tokenizer) + pretrain_labels = torch.nn.functional.pad(pretrain_input_ids[:, 1:], (0, 1), mode="constant", value=-100) + + trajectories.pretrain_data = {"input_ids": pretrain_input_ids, + "labels": pretrain_labels, + "attention_mask": pretrain_attention_mask} print( - f'[TxtEnv & {policy_model.__class__.__name__}] gets {len(pt_input_messages)} pretrain episodes.' # noqa: E501 + f'[TxtEnv & {policy_model.__class__.__name__}] gets {len(pretrain_input_messages)} pretrain episodes.' 
# noqa: E501 ) + else: + trajectories.pretrain_data = None return trajectories @@ -149,6 +99,8 @@ def get_reward_async(self, sample_data, policyout): s_t = time.time() rm_input_messages = [] for i in range(len(sample_data)): + if sample_data[i].mes_type != "prompt": + continue if sample_data[i].rm_meta != 'default': cur_rm_data = [{ 'role': 'system', @@ -190,6 +142,8 @@ def get_reward(self, sample_data, policyout): s_t = time.time() rm_input_messages = [] for i in range(len(sample_data)): + if sample_data[i].mes_type != "prompt": + continue if sample_data[i].rm_meta != 'default': cur_rm_data = [{ 'role': 'system', diff --git a/xtuner/rlhf/loss/actor_loss.py b/xtuner/rlhf/loss/actor_loss.py index cd81c97db..e5c05cc01 100644 --- a/xtuner/rlhf/loss/actor_loss.py +++ b/xtuner/rlhf/loss/actor_loss.py @@ -53,12 +53,12 @@ def forward(self, logits: torch.Tensor, labels: dict[str, Any]): Tensor: Return the final loss """ assert logits.ndim == 3 - mask = labels['mask'] # (micro_bsz, seqlen) + mask = labels['mask'] assert logits.shape[0] == labels['input_ids'].shape[0] - input_ids = labels['input_ids'] # (micro_bsz, seqlen) - old_logprobs = labels['old_logprobs'] # (micro_bsz, seqlen) - advantages = labels['advantages'] # (micro_bsz, seqlen) + input_ids = labels['input_ids'] + old_logprobs = labels['old_logprobs'] + advantages = labels['advantages'] loss_factor = labels['loss_factor'] logpy = logprobs_from_logits( diff --git a/xtuner/rlhf/loss/critic_loss.py b/xtuner/rlhf/loss/critic_loss.py index 877c21c28..3ad4e2db6 100644 --- a/xtuner/rlhf/loss/critic_loss.py +++ b/xtuner/rlhf/loss/critic_loss.py @@ -7,7 +7,7 @@ class CriticLoss(torch.nn.Module): """Loss function for critic model.""" def __init__(self, - cliprange_value: float = 100, + cliprange_value: float = 0.5, loss_type: str = 'per_seq'): super().__init__() self.cliprange_value = cliprange_value @@ -53,12 +53,12 @@ def forward(self, values: torch.Tensor, labels: dict[str, Any]): Tensor: Return the final loss """ assert values.ndim == 2 - mask = labels['mask'] # (micro_bsz, seqlen) + mask = labels['mask'] num_actions = mask.size(1) values = values[:, -num_actions:] - old_values = labels['old_values'] # (micro_bsz, seqlen) - returns = labels['returns'] # (micro_bsz, seqlen) + old_values = labels['old_values'] + returns = labels['returns'] loss_factor = labels['loss_factor'] loss = self.critic_loss_fn( values=values, diff --git a/xtuner/rlhf/loss/pretrain_loss.py b/xtuner/rlhf/loss/pretrain_loss.py index fe08d2a0b..6356291d0 100644 --- a/xtuner/rlhf/loss/pretrain_loss.py +++ b/xtuner/rlhf/loss/pretrain_loss.py @@ -1,36 +1,20 @@ import torch from loguru import logger -try: - from flash_attn.losses.cross_entropy import \ - CrossEntropyLoss as FlashCrossEntropyLoss -except ImportError: - pass - -# Adapted from: https://gitlab.pjlab.org.cn/openmmlab/bigmodel/rl3m/-/blob/main/rl3m/layers/loss.py#L37 # noqa: E501 -class FlashGPTLMLoss(torch.nn.Module): +class PretrainLoss(torch.nn.Module): """Loss function for flash GPT Language Model.""" - def __init__(self, parallel_output=True, label_smoothing=0): + def __init__(self, label_smoothing=0): super().__init__() if label_smoothing is not None and label_smoothing != 0: logger.warning(f'Use label_smoothing: {label_smoothing}') self.label_smoothing = label_smoothing - if parallel_output: - # The loss in this place is bound to the gather_output initialized by VocabParallelClassifier1D # noqa: E501 - self.loss_fn = FlashCrossEntropyLoss( - reduction='mean', - inplace_backward=True, - process_group=None, - 
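# --- Illustrative sketch (not part of the patch): a cliprange_value of 0.5
# typically parameterizes the standard PPO-style clipped value objective shown
# below; this is the common formulation, not a copy of critic_loss_fn.
import torch


def clipped_value_loss(values, old_values, returns, mask, cliprange_value=0.5):
    values_clipped = old_values + (values - old_values).clamp(
        -cliprange_value, cliprange_value)
    surr1 = (values - returns)**2
    surr2 = (values_clipped - returns)**2
    loss = torch.max(surr1, surr2) * mask
    return 0.5 * loss.sum() / mask.sum()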
label_smoothing=label_smoothing, - ) - else: - # Here, the output will gather output is set in the model, so use ordinary loss # noqa: E501 - self.loss_fn = torch.nn.CrossEntropyLoss( - reduction='mean', label_smoothing=label_smoothing) + # Here, the output will gather output is set in the model, so use ordinary loss # noqa: E501 + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='mean', label_smoothing=label_smoothing) def forward(self, *args): if len(args) == 3: @@ -50,16 +34,3 @@ def forward(self, *args): return loss - -# Adapted from: https://gitlab.pjlab.org.cn/openmmlab/bigmodel/rl3m/-/blob/main/rl3m/layers/loss.py#L37 # noqa: E501 -class PretrainLoss(FlashGPTLMLoss): - """Modified from pretrain/sft loss, but with a loss factor term to balance - with ppo policy loss.""" - - def __init__(self, *args, loss_factor=1.0, **kwargs): - super().__init__(*args, **kwargs) - self.loss_factor = loss_factor - - def forward(self, *args, **kwargs): - loss = super().forward(*args, **kwargs) - return loss * self.loss_factor diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index cf3812fcc..ea64cc36a 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -22,7 +22,7 @@ def parse_args(): '--config', help='config file name or path.', type=str, - default='examples/rlhf/four_model_8gpu.py') + default='examples/rlhf/four_model_vllm_8gpu.py') parser.add_argument( '-w', '--work_dir', @@ -50,17 +50,15 @@ def validate_config(config: Config): assert args.config is not None, 'config should not be None' work_dir = args.work_dir if work_dir is None: - work_dir = os.getcwd() + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') work_dir = os.path.abspath(work_dir) logger.info(f'using work_dir: {work_dir}') os.makedirs(work_dir, exist_ok=True) logger.add( - f'{work_dir}/train.log', + f'{work_dir}/train_rlhf.log', filter=lambda record: record['extra'].get('name') == 'train') - logger.add( - f'{work_dir}/rollout.log', - filter=lambda record: record['extra'].get('name') == 'rollout') logger_train = logger.bind(name='train') configs_path = args.config @@ -131,7 +129,7 @@ def validate_config(config: Config): # # for value & policy learn value_loss_ref = ppo.value_learn_async(trajectories, critic_model) - ppo_loss = 0.0 + ppo_loss, pt_loss = None, None if pretrain_step <= 0: ppo_loss, pt_loss = ppo.policy_learn(trajectories, actor_model) logger_train.info( @@ -145,8 +143,14 @@ def validate_config(config: Config): pretrain_step -= 1 if config['rollout_config'].get('write_to_file', True): - with open(f'{work_dir}/rollout.log', 'a') as file: - file.write(f'generates: {trajectories.output_str}') + if not os.path.exists(f'{work_dir}/rollouts'): + os.makedirs(f'{work_dir}/rollouts') + with open(f'{work_dir}/rollouts/step{step}_rollout.log', + 'a') as file: + for output_s, r in zip(trajectories.output_str, + trajectories.rewards): + file.write(output_s + '\n' + 'Reward: ' + str(r.item()) + + '\n' + '=' * 30 + '\n') summaries = dict( reward_mean=trajectories.rewards.mean().item(), reward_std=trajectories.rewards.std().item(), @@ -158,9 +162,10 @@ def validate_config(config: Config): entropy=trajectories.entropy.mean().item(), step=step, policy_loss=ppo_loss, + pretrain_loss=pt_loss, critic_loss=value_loss, ) - with open(f'{work_dir}/train.log.jsonl', 'a') as f: + with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: f.write(json.dumps(summaries) + '\n') step += 1 diff --git a/xtuner/rlhf/model_backend/generate_utils.py b/xtuner/rlhf/model_backend/generate_utils.py index 
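# --- Illustrative sketch (not part of the patch): while pretrain_step > 0 the
# actor update combines the PPO policy loss with the pretrain language-modeling
# loss using the weights from train_config (ppo_loss_weight=1.0,
# pretrain_loss_weight=0.5). The tensors below are hypothetical placeholders.
import torch

ppo_loss_weight, pretrain_loss_weight = 1.0, 0.5
ppo_loss = torch.tensor(0.12)       # PPO policy loss on prompt data
pretrain_loss = torch.tensor(1.85)  # next-token CE loss on pretrain data

total_actor_loss = ppo_loss_weight * ppo_loss + pretrain_loss_weight * pretrain_loss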
e88995d28..15adb5640 100644 --- a/xtuner/rlhf/model_backend/generate_utils.py +++ b/xtuner/rlhf/model_backend/generate_utils.py @@ -36,6 +36,7 @@ def partition_by_micro_batch_size( input_ids: Union[list[str], torch.Tensor, list[int]], micro_batch_size: int, attention_mask: torch.Tensor = None, + position_ids: torch.Tensor = None, labels: Optional[Union[list[torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]] = None, ) -> list[dict[str, torch.Tensor]]: @@ -46,6 +47,7 @@ def partition_by_micro_batch_size( micro_batch = {} micro_batch['input_ids'] = input_ids micro_batch['attention_mask'] = attention_mask + micro_batch['position_ids'] = position_ids micro_batch['labels'] = labels micro_batches.append(micro_batch) return micro_batches @@ -64,6 +66,9 @@ def partition_by_micro_batch_size( attention_mask_split = ( torch.split(attention_mask, micro_batch_size, dim=0) if attention_mask is not None else [None for _ in range(num_splits)]) + position_ids_split = ( + torch.split(position_ids, micro_batch_size, dim=0) + if position_ids is not None else [None for _ in range(num_splits)]) labels_split = ( partition_label_by_micro_batch_size(labels, micro_batch_size, num_splits) @@ -72,6 +77,7 @@ def partition_by_micro_batch_size( micro_batch = {} micro_batch['input_ids'] = input_ids_split[i] micro_batch['attention_mask'] = attention_mask_split[i] + micro_batch['position_ids'] = position_ids_split[i] micro_batch['labels'] = labels_split[i] micro_batches.append(micro_batch) return micro_batches @@ -108,33 +114,34 @@ def partition_list_by_micro_batch_size( micro_batch_size: list[int], labels: list[torch.Tensor], attention_mask: Optional[list[torch.Tensor]] = None, - loss_weights: Optional[list[float]] = None, + position_ids: Optional[list[torch.Tensor]] = None, ) -> list[dict]: length = len(input_ids) batch_size = input_ids[0].shape[0] num_splits = int(batch_size // micro_batch_size[0]) + ( batch_size % micro_batch_size[0] > 0) micro_batches = [[{} for i in range(length)] for _ in range(num_splits)] - if loss_weights is None: - loss_weights = [None for _ in range(length)] if attention_mask is None: attention_mask = [None for _ in range(length)] + if position_ids == None: + position_ids = [None for _ in range(length)] for i in range(length): sub_input_ids = input_ids[i] sub_attention_mask = attention_mask[i] + sub_position_ids = position_ids[i] sub_labels = labels[i] - sub_loss_weights = loss_weights[i] sub_micro_batches = partition_by_micro_batch_size( - sub_input_ids, micro_batch_size[i], sub_attention_mask, sub_labels) + sub_input_ids, micro_batch_size[i], sub_attention_mask, + sub_position_ids, sub_labels) for micro_batch_index, sub_micro_batch in enumerate(sub_micro_batches): micro_batches[micro_batch_index][i]['input_ids'] = sub_micro_batch[ 'input_ids'] micro_batches[micro_batch_index][i][ 'attention_mask'] = sub_micro_batch['attention_mask'] + micro_batches[micro_batch_index][i][ + 'position_ids'] = sub_micro_batch['position_ids'] micro_batches[micro_batch_index][i]['labels'] = sub_micro_batch[ 'labels'] - micro_batches[micro_batch_index][i][ - 'loss_weights'] = sub_loss_weights return micro_batches diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index ca6a826f5..b873a3497 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -146,71 +146,30 @@ def initialize(self): f'[{self.model_type}] __init__() done with optimizer {self.optimizer.optimizer}.' 
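# --- Illustrative sketch (not part of the patch): partition_by_micro_batch_size
# is essentially torch.split along the batch dimension, applied consistently to
# input_ids, attention_mask and the newly threaded-through position_ids.
import torch

batch_size, seq_len, micro_bs = 8, 16, 2
input_ids = torch.randint(0, 100, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
position_ids = attention_mask.cumsum(-1) - 1

micro_batches = [
    {'input_ids': i, 'attention_mask': a, 'position_ids': p}
    for i, a, p in zip(
        torch.split(input_ids, micro_bs, dim=0),
        torch.split(attention_mask, micro_bs, dim=0),
        torch.split(position_ids, micro_bs, dim=0))
]
assert len(micro_batches) == batch_size // micro_bs  # 4 micro batches of size 2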
# noqa: E501 ) - # Training - def compute_loss_and_backward( - self, - input_ids: Union[list[torch.Tensor], torch.Tensor], - labels: Optional[Union[list[torch.Tensor], torch.Tensor, - dict[str, torch.Tensor]]] = None, - attention_mask: Optional[Union[list[torch.Tensor], - torch.Tensor]] = None, - criterion: Optional[Union[list[_Loss], _Loss]] = None, - loss_weights: Optional[list[float]] = None, - gradient_accumulation_steps=1, - **_ignored, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - """ - criterion: _Loss class, e.g., torch.nn.CrossEntropyLoss() - """ - if isinstance(input_ids, torch.Tensor): # returns torch.Tensor - # rarely, since self.train() changes all input_ids to [input_ids] - loss = self.compute_loss(input_ids, labels, attention_mask, - criterion) - self.accelerator.backward(loss) - return loss - - elif type(input_ids) == list: # returns list[torch.Tensor] - # multiple inputs grouped to compute loss, see: - # https://stackoverflow.com/questions/53994625/how-can-i-process-multi-loss-in-pytorch - assert ( - len(input_ids) == len(labels) == len(criterion) == - len(attention_mask) == len(loss_weights) - ), f'{len(input_ids)} {len(labels)} {len(criterion)} {len(attention_mask)} {len(loss_weights)} must equal' # noqa: E501 - loss_list = [0 for _ in range(len(input_ids))] - loss_weights = [ - x / float(len(loss_weights)) for x in loss_weights - ] # to 1 - - loss_sum = 0 - for i in range(len(input_ids)): - with self.accelerator.autocast(): - loss = self.compute_loss(input_ids[i], labels[i], - attention_mask[i], criterion[i]) - loss_sum += loss * loss_weights[i] - loss_list[i] = loss - self.accelerator.backward(loss_sum) - return loss_list - - else: - raise NotImplementedError(f'unknown input {input_ids}') - def compute_loss( self, input_ids: torch.Tensor, labels: Optional[Union[torch.Tensor, dict[str, torch.Tensor]]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, criterion: Optional[_Loss] = None, loss_weight: Optional[float] = None, **_ignored, ) -> torch.Tensor: input_ids = input_ids.to(self.device) labels = input_ids.clone() if labels is None else labels - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + if attention_mask is not None: + if position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) batch = { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'position_ids': position_ids.to(self.device) + 'input_ids': + input_ids, + 'attention_mask': + attention_mask.to(self.device) + if attention_mask is not None else None, + 'position_ids': + position_ids.to(self.device) if position_ids is not None else None } self.model.train() @@ -226,19 +185,12 @@ def compute_loss( # OPT. 
B) Use preset loss functions, e.g., torch.nn.CrossEntropyLoss() # noqa: E501 # Adopted from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1199 # noqa: E501 logits: torch.Tensor = self.model(**batch, use_cache=False).logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_logits = shift_logits.view(-1, self.vocab_size) - shift_labels = shift_labels.view(-1) - shift_labels = shift_labels.to( - shift_logits.device) # enable model para - # loss_fct = criterion() - loss = criterion(shift_logits, shift_labels) + labels = labels.to(self.device) + loss = criterion(logits, labels) elif isinstance(labels, dict): # OPT. C) Use customized loss function, see loss/actor_loss.py logits: torch.Tensor = self.model( **batch, use_cache=False, return_dict=True).logits - # loss_fct = criterion() for k, v in labels.items(): labels[k] = v.to(self.device) loss = criterion(logits, labels) @@ -266,6 +218,7 @@ def train( dict[str, torch.Tensor]]] = None, attention_mask: Optional[Union[list[torch.Tensor], torch.Tensor]] = None, + position_ids: Optional[Union[list[torch.Tensor], torch.Tensor]] = None, criterion: Optional[Union[list[_Loss], _Loss]] = None, loss_weights: Optional[Union[list[float], float]] = None, step_interval: int = 1, @@ -280,58 +233,66 @@ def train( input_ids = [input_ids] labels = [labels] attention_mask = [attention_mask] + position_ids = [position_ids] criterion = [criterion] - loss_weights = [1] if loss_weights is None else [loss_weights] - micro_batch_size = None if micro_batch_size is None else [ - micro_batch_size - ] - return_list = False - - if micro_batch_size is None: - for i in range(len(input_ids)): - self.info_rank0( - f'[{self.model_type}] train input_ids[{i}] shape[{input_ids[i].shape}]' # noqa: E501 - ) - origin_loss = self.compute_loss_and_backward( - input_ids, labels, attention_mask, criterion, loss_weights) + loss_weights = [loss_weights] + micro_batch_size = [micro_batch_size] else: - assert isinstance(input_ids, list) - micro_batches = partition_list_by_micro_batch_size( - input_ids, micro_batch_size, labels, attention_mask, - loss_weights) - origin_loss_list_mb = [] - for index, micro_batch in enumerate(micro_batches): - input_ids_mb = [] - attention_mask_mb = [] - labels_mb = [] - loss_weights_mb = [] - for i in range(len(micro_batch)): - input_ids_mb.append(micro_batch[i]['input_ids'].to( - self.device)) - attention_mask_mb.append( - micro_batch[i]['attention_mask'].to(self.device)) - labels_mb.append(micro_batch[i]['labels']) - loss_weights_mb.append(micro_batch[i]['loss_weights']) - if index == 0: - for i in range(len(input_ids_mb)): - self.info_rank0( - f'[{self.model_type}] will train input_ids_mb[{i}] shape[{input_ids_mb[i].shape}] * {len(micro_batches)} times' # noqa: E501 - ) - origin_loss_mb = self.compute_loss_and_backward( - input_ids_mb, - labels_mb, - attention_mask_mb, - criterion, - loss_weights_mb, - gradient_accumulation_steps=len(micro_batches), + if attention_mask is None: + attention_mask = [None for _ in range(len(input_ids))] + if position_ids is None: + position_ids = [None for _ in range(len(input_ids))] + if criterion is None: + criterion = [None for _ in range(len(input_ids))] + if loss_weights is None: + loss_weights = [None for _ in range(len(input_ids))] + if micro_batch_size is None: + micro_batch_size = [None for _ in range(len(input_ids))] + + assert isinstance(input_ids, list) + + loss_list = [[] for _ in range(len(input_ids))] 
+ for index in range(len(input_ids)): + mb_size_entry = micro_batch_size[index] + if mb_size_entry is None: + micro_batches: list[dict[str, torch.Tensor]] = [] + micro_batches.append({ + 'input_ids': input_ids[index], + 'attention_mask': attention_mask[index], + 'position_ids': position_ids[index], + 'labels': labels[index] + }) + else: + micro_batches = partition_by_micro_batch_size( + input_ids=input_ids[index], + micro_batch_size=micro_batch_size[index], + attention_mask=attention_mask[index], + position_ids=position_ids[index], + labels=labels[index], + ) + loss_entry = [] + for mb_index, micro_batch in enumerate(micro_batches): + if mb_index == 0: + self.info_rank0( + f"[{self.model_type}] will train input_ids[{mb_index}] shape[{micro_batch['input_ids'].shape}] * {len(micro_batches)} times" # noqa: E501 + ) + # compute loss and backward + loss = self.compute_loss( + input_ids=micro_batch['input_ids'], + labels=micro_batch['labels'], + attention_mask=micro_batch['attention_mask'], + position_ids=micro_batch['position_ids'], + criterion=criterion[index], + loss_weight=loss_weights[index], ) - origin_loss_list_mb.append(origin_loss_mb) + self.accelerator.backward(loss) + loss_entry.append(loss) if debug: set_seed(1234) - origin_loss = merge_loss_list(origin_loss_list_mb) + loss_list[index] = sum(loss_entry) / len(loss_entry) self.parameter_update(step_interval) - return origin_loss if return_list else origin_loss[0] + return loss_list if len(loss_list) > 1 else loss_list[0] # Inference @torch.no_grad() @@ -740,18 +701,21 @@ def initialize_get(self): self.initialize_ref = None # Training - def train_async(self, input_ids, labels, attention_mask, *args, **kwargs): + def train_async(self, input_ids, labels, attention_mask, position_ids, + *args, **kwargs): if isinstance(input_ids, torch.Tensor): micro_batch_size = input_ids.shape[0] // self.dp_size + ( input_ids.shape[0] % self.dp_size > 0 ) # round up division, i.e., math.ceil(a / b) micro_batches = partition_by_micro_batch_size( - input_ids, micro_batch_size, attention_mask, labels) + input_ids, micro_batch_size, attention_mask, position_ids, + labels) assert len(micro_batches) == self.dp_size return [ self.ray_actors[index].train.remote( input_ids=micro_batch['input_ids'], attention_mask=micro_batch['attention_mask'], + position_ids=micro_batch['position_ids'], labels=micro_batch['labels'], *args, **kwargs, @@ -762,39 +726,47 @@ def train_async(self, input_ids, labels, attention_mask, *args, **kwargs): assert isinstance(input_ids[0], torch.Tensor) micro_batch_size = [i for i in range(len(input_ids))] for index, input_id in enumerate(input_ids): - micro_batch_size[ - index] = input_id[index].shape[0] // self.dp_size + ( - input_id[index].shape[0] % self.dp_size > 0 - ) # round up division, i.e., math.ceil(a / b) + micro_batch_size[index] = input_id.shape[0] // self.dp_size + ( + input_id.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) micro_batches = partition_list_by_micro_batch_size( - input_ids, self.dp_size, attention_mask, labels) + input_ids=input_ids, + micro_batch_size=micro_batch_size, + labels=labels, + attention_mask=attention_mask, + position_ids=position_ids, + ) + assert len(micro_batches) == self.dp_size object_refs = [] for index, micro_batch in enumerate(micro_batches): input_ids_mb = [] attention_mask_mb = [] + position_ids_mb = [] labels_mb = [] - loss_weights_mb = [] - assert len(micro_batch) == self.dp_size for i in range(len(micro_batch)): input_ids_mb.append(micro_batch[i]['input_ids']) 
attention_mask_mb.append(micro_batch[i]['attention_mask']) + position_ids_mb.append(micro_batch[i]['position_ids']) labels_mb.append(micro_batch[i]['labels']) - loss_weights_mb.append(micro_batch[i]['loss_weights']) - - object_ref = self.ray_actors[index].train.remote( - inputs=input_ids_mb, - attention_mask=attention_mask_mb, - labels=labels_mb, - loss_weights=loss_weights_mb, - *args, - **kwargs, - ) - object_refs.append(object_ref) - return object_ref + object_ref = self.ray_actors[index].train.remote( + input_ids=input_ids_mb, + attention_mask=attention_mask_mb, + position_ids=position_ids_mb, + labels=labels_mb, + *args, + **kwargs, + ) + object_refs.append(object_ref) + return object_refs def train_get(self, object_refs, timeout=None): losses = ray.get(object_refs, timeout=timeout) - return sum(losses) / len(losses) + if isinstance(losses[0], list): + p_loss = [sub_loss[0] for sub_loss in losses] + pt_loss = [sub_loss[1] for sub_loss in losses] + return [sum(p_loss) / len(p_loss), sum(pt_loss) / len(pt_loss)] + else: + return sum(losses) / len(losses) def train(self, *args, **kwargs): object_refs = self.train_async(*args, **kwargs) diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py index ffb2426bd..884482f4e 100644 --- a/xtuner/rlhf/model_server/base_model_server.py +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -107,9 +107,10 @@ def train_async(self, input_ids, labels=None, attention_mask=None, + position_ids=None, *args, **train_kwargs): - return self.trainer.train_async(input_ids, labels, attention_mask, + return self.trainer.train_async(input_ids, labels, attention_mask, position_ids, *args, **train_kwargs) def train_get(self, object_refs, timeout: Optional[float] = None): @@ -119,9 +120,10 @@ def train(self, input_ids, labels=None, attention_mask=None, + position_ids=None, *args, **train_kwargs): - object_refs = self.train_async(input_ids, labels, attention_mask, + object_refs = self.train_async(input_ids, labels, attention_mask, position_ids, *args, **train_kwargs) loss = self.train_get(object_refs) self.log_cuda_mem_stats(remark='[train] ') diff --git a/xtuner/rlhf/repeaters/base.py b/xtuner/rlhf/repeaters/base.py index 0e68600b4..156fe3e02 100644 --- a/xtuner/rlhf/repeaters/base.py +++ b/xtuner/rlhf/repeaters/base.py @@ -1,66 +1,11 @@ import time -import numpy as np import torch from loguru import logger from ..model_server.base_model_server import BaseModelServer from ..policy_output import PolicyOutput - - -def find_mask_begin(padded_datas, mask_id=0): - """finding the mask id begin index and it's length.""" - begin_indexs = [] - lengths = [] - - for padded_data in padded_datas: - is_flag = 0 - for index, data in enumerate(padded_data): - if data != mask_id: - is_flag = 1 - begin_indexs.append(index) - length = (np.array(padded_data) != mask_id).sum() - lengths.append(length) - break - assert is_flag - return begin_indexs, lengths - - -class RunningStates: - # adopt from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/running_mean_std.py # noqa: E501 - def __init__(self, epsilon: float = 1e-4): - self.mean = torch.tensor(0, dtype=torch.float32) - self.var = torch.tensor(0, dtype=torch.float32) - self.count = epsilon - - def update(self, x: torch.Tensor): - x_var, x_mean = torch.var_mean(x.cpu(), unbiased=False) - x_count = x.shape[0] - self.update_from_moments(x_mean, x_var, x_count) - - def update_from_other(self, other: 'RunningStates'): - self.update_from_moments(other.mean, 
other.var, other.count) - - def update_from_moments(self, mean: torch.Tensor, var: torch.Tensor, - count: int): - delta = mean - self.mean - tot_count = self.count + count - m_a = self.var * self.count - m_b = var * count - m_2 = m_a + m_b + delta**2 * self.count * count / (self.count + count) - new_var = m_2 / (self.count + count) - - self.mean += delta * count / tot_count - self.var = new_var - self.count = tot_count - - def state_dict(self): - return dict(mean=self.mean, var=self.var, count=self.count) - - def load_state_dict(self, states): - self.mean = states['mean'] - self.var = states['var'] - self.count = states['count'] +from .running_mean_std import RunningStates class BaseRepeater: @@ -68,31 +13,30 @@ class BaseRepeater: def __init__( self, sft_model, - reward_scale: bool = False, - fine_grained_rm: bool = False, - value_ema: bool = False, actor_micro_bs: int = 8, ref_micro_bs: int = 8, critic_micro_bs: int = 32, kl_coeff=0.02, gamma=1.0, gae_lambda=0.95, - answer_end_id=92542, norm_adv=False, + clip_reward_min: int = -5, + clip_reward_max: int = 5, norm_rewards=True, + reward_scale: bool = False, + fine_grained_rm: bool = False, **_ignored, ): self.sft_model = sft_model self.actor_micro_bs = actor_micro_bs self.ref_micro_bs = ref_micro_bs self.critic_micro_bs = critic_micro_bs - self.reward_scale = reward_scale - self.fine_grained_rm = fine_grained_rm - self.value_ema = value_ema self.kl_coeff = kl_coeff self.gamma = gamma self.gae_lambda = gae_lambda - self.answer_end_id = answer_end_id + # rewards + self.clip_reward_min = clip_reward_min + self.clip_reward_max = clip_reward_max self.norm_rewards = norm_rewards if self.norm_rewards: self.running_states = RunningStates(epsilon=0) @@ -158,15 +102,16 @@ def _get_kl_rewards(self, if env.async_reward: rewards = env.get_reward_collect(trajectories['reward_output_ref']) trajectories['reward_output_ref'] = None - clipped_rewards = torch.clamp( - rewards, min=env.clip_reward_min, max=env.clip_reward_max) trajectories['rewards'] = rewards - trajectories['clipped_rewards'] = clipped_rewards # Experimental - rewards = trajectories.clipped_rewards + + clipped_rewards = torch.clamp( + rewards, min=self.clip_reward_min, max=self.clip_reward_max) + trajectories['clipped_rewards'] = clipped_rewards + if self.norm_rewards: - self.running_states.update(rewards) - norm_reward_score = (rewards - self.running_states.mean) / ( + self.running_states.update(clipped_rewards) + norm_reward_score = (clipped_rewards - self.running_states.mean) / ( self.running_states.var.sqrt() + 1e-8) action_mask = trajectories.action_mask num_actions = action_mask.size(1) @@ -232,60 +177,6 @@ def _get_values_collect(self, value_output_ref, ) return raw_values - def _get_advantages_and_returns(self, trajectories): - output_ids = trajectories.output_ids - answer_mask = trajectories.answer_mask - values_with_last_value = trajectories.values_with_last_value - kl_rewards = trajectories.kl_rewards - - begins_index, answers_length = find_mask_begin(answer_mask, 0) - count = 0 - advantages_padded, returns_padded = torch.zeros_like( - kl_rewards, dtype=values_with_last_value.dtype), torch.zeros_like( - kl_rewards, dtype=values_with_last_value.dtype) - for begin_index, ans_len, value_with_last_value, reward, output_id in zip( # noqa: E501 - begins_index, answers_length, values_with_last_value, - kl_rewards, output_ids): - # shape :ans_len + 1 - value_with_last_value = value_with_last_value[begin_index - - 1:begin_index + - ans_len] - # shape :ans_len - reward = 
reward[begin_index:begin_index + ans_len] - last_gae_lam = torch.zeros((1), dtype=values_with_last_value.dtype) - # shape :ans_len - advantages = torch.zeros_like( - reward, dtype=values_with_last_value.dtype) - step_nums = advantages.shape[-1] - # shape:ans_len + 1 - dones = self._build_dones(output_id[begin_index:begin_index + - ans_len]) - for step in reversed(range(step_nums)): - next_non_terminal = 1 - dones[step + 1] - next_values = value_with_last_value[step + 1] - # delta and last_gae_lam using value and reward - delta = reward[ - step] + self.gamma * next_values * next_non_terminal - value_with_last_value[ # noqa: E501 - step] - last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam # noqa: E501 - advantages[step] = last_gae_lam[0] - returns = advantages + value_with_last_value[:-1] - advantages_padded[count, - begin_index:begin_index + ans_len] = advantages - returns_padded[count, begin_index:begin_index + ans_len] = returns - count += 1 - return advantages_padded, returns_padded - - # ans_len + 1: dones - def _build_dones(self, answer_ids): - dones = torch.tensor( - (answer_ids == self.answer_end_id).numpy().astype(np.float32)) - # (1, )the first one is not done, so obs_0_dones=0 - obs_0_dones = torch.zeros((1), dtype=torch.float32) - # (ans_len + 1), - dones = torch.concat((obs_0_dones, dones), axis=0) - return dones - def get_advantages_and_returns( self, values: torch.Tensor, @@ -293,6 +184,24 @@ def get_advantages_and_returns( action_mask: torch.Tensor, ): # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 + """Function that computes advantages and returns from rewards and values. + Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 + Note that rewards may include a KL divergence loss term. + + Advantages looks like this: + Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Returns looks like this: + Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Args: + values: Tensor of shape (batch_size, response_size) + rewards: Tensor of shape (batch_size, response_size) + response_length: Length of the response sequence + use_whitening: Whether to use whitening (ie. normalize advantages) or not + """ lastgaelam = 0 advantages_reversed = [] response_length = rewards.size(1) @@ -303,6 +212,8 @@ def get_advantages_and_returns( for t in reversed(range(response_length)): nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 + # Since old_rewards and old_values are masked with action_mask, i.e. 
they have + # 0's at pad tokens, delta will be 0 if current t is at a pad token, so will lastgaelam delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] lastgaelam = delta + self.gamma * self.gae_lambda * lastgaelam advantages_reversed.append(lastgaelam) diff --git a/xtuner/rlhf/repeaters/running_mean_std.py b/xtuner/rlhf/repeaters/running_mean_std.py new file mode 100644 index 000000000..e8b3e2763 --- /dev/null +++ b/xtuner/rlhf/repeaters/running_mean_std.py @@ -0,0 +1,38 @@ +import torch + + +class RunningStates: + # adopt from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/running_mean_std.py # noqa: E501 + def __init__(self, epsilon: float = 1e-4): + self.mean = torch.tensor(0, dtype=torch.float32) + self.var = torch.tensor(0, dtype=torch.float32) + self.count = epsilon + + def update(self, x: torch.Tensor): + x_var, x_mean = torch.var_mean(x.cpu(), unbiased=False) + x_count = x.shape[0] + self.update_from_moments(x_mean, x_var, x_count) + + def update_from_other(self, other: 'RunningStates'): + self.update_from_moments(other.mean, other.var, other.count) + + def update_from_moments(self, mean: torch.Tensor, var: torch.Tensor, + count: int): + delta = mean - self.mean + tot_count = self.count + count + m_a = self.var * self.count + m_b = var * count + m_2 = m_a + m_b + delta**2 * self.count * count / (self.count + count) + new_var = m_2 / (self.count + count) + + self.mean += delta * count / tot_count + self.var = new_var + self.count = tot_count + + def state_dict(self): + return dict(mean=self.mean, var=self.var, count=self.count) + + def load_state_dict(self, states): + self.mean = states['mean'] + self.var = states['var'] + self.count = states['count'] diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index a4a81aad6..ccfbeb61f 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -5,6 +5,7 @@ from ..loss.actor_loss import ActorLoss from ..loss.critic_loss import CriticLoss +from ..loss.pretrain_loss import PretrainLoss from ..model_server.base_model_server import BaseModelServer from ..timer import Timer @@ -13,92 +14,125 @@ class PPOTrainer: def __init__( self, - policy_model, - value_model, actor_micro_bs=2, critic_micro_bs=2, policy_learn_time=1, value_learn_time=1, - ppo_minibatch=512, - value_minibatch=512, - pt_minibatch=None, - train_minibatch=None, - pt_criterion=None, + policy_minibatch=None, + value_minibatch=None, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_criterion=PretrainLoss(label_smoothing=0), policy_criterion=ActorLoss(cliprange=0.2, loss_type='per_seq'), value_criterion=CriticLoss(cliprange_value=0.5, loss_type='per_seq'), **kwargs, ): - self.ppo_minibatch = ppo_minibatch - self.value_minibatch = value_minibatch self.actor_micro_bs = actor_micro_bs self.critic_micro_bs = critic_micro_bs # policy - self.policy_model = policy_model self.policy_learn_time = policy_learn_time - self.pt_minibatch = pt_minibatch - self.train_minibatch = train_minibatch - self.policy_minibatch = ppo_minibatch + self.policy_minibatch = policy_minibatch # value - self.value_model = value_model self.value_learn_time = value_learn_time self.value_minibatch = value_minibatch - self.pt_criterion = pt_criterion + self.ppo_loss_weight = ppo_loss_weight + self.pretrain_loss_weight = pretrain_loss_weight + self.pretrain_criterion = pretrain_criterion self.policy_criterion = policy_criterion self.value_criterion = value_criterion def policy_learn(self, trajectories, policy_model: 
BaseModelServer):
+        if self.policy_minibatch is None:
+            self.policy_minibatch = len(trajectories.output_ids)
         policy_updates = len(trajectories.output_ids) // self.policy_minibatch
-        policy_loss = []
-        pt_loss = []
+        ppo_loss = []
+        pretrain_loss = []
         for _ in range(self.policy_learn_time):
             for i in range(policy_updates):
                 logger.info(
                     '[Policy Train] start policy trains {}/{} | {}'.format(
                         i + 1, policy_updates, _ + 1))
+                # prompt train data
                 begin = i * self.policy_minibatch
                 end = begin + self.policy_minibatch
-                policy_batch_inputs = {
-                    'input_ids': trajectories.output_ids[begin:end, :],
-                    'policy_logprobs':
-                    trajectories.policy_logprobs[begin:end, :],
-                    'advs': trajectories.advantages[begin:end, :],
-                    'action_mask': trajectories.action_mask[begin:end, :],
-                    'attention_mask': trajectories.attention_mask[begin:end, :]
-                }
+
+                train_input_ids = [
+                    trajectories.output_ids[begin:end, :],
+                ]
+                train_attention_mask = [
+                    trajectories.attention_mask[begin:end, :],
+                ]
+                train_criterion = [
+                    self.policy_criterion,
+                ]
+                loss_weights = [
+                    self.ppo_loss_weight,
+                ]
+                micro_batch_size = [
+                    self.actor_micro_bs,
+                ]
                 assert len(
-                    policy_batch_inputs['input_ids']
+                    trajectories.output_ids[begin:end, :]
                 ) == self.policy_minibatch, '[Policy learn] make sure len(policy_batch_inputs) == self.policy_minibatch'  # noqa: E501

                 loss_factor = 1.0
-                labels = dict(
-                    input_ids=policy_batch_inputs['input_ids'],
-                    old_logprobs=policy_batch_inputs['policy_logprobs'],
-                    advantages=policy_batch_inputs['advs'],
-                    mask=policy_batch_inputs['action_mask'],
-                    loss_factor=torch.tensor(loss_factor),
-                )
+                train_labels = [
+                    dict(
+                        input_ids=trajectories.output_ids[begin:end, :],
+                        old_logprobs=trajectories.policy_logprobs[
+                            begin:end, :],
+                        advantages=trajectories.advantages[begin:end, :],
+                        mask=trajectories.action_mask[begin:end, :],
+                        loss_factor=torch.tensor(loss_factor),
+                    ),
+                ]
+                # pretrain data
+                if trajectories.pretrain_data is not None:
+                    logger.info(
+                        f'[Policy Train] policy train with pretrain data {trajectories.pretrain_data["input_ids"].shape}'
+                    )
+                    train_input_ids.append(
+                        trajectories.pretrain_data['input_ids'])
+                    train_labels.append(trajectories.pretrain_data['labels'])
+                    # train_position_ids.append(trajectories.pretrain_data["position_ids"])
+                    train_attention_mask.append(
+                        trajectories.pretrain_data['attention_mask'])
+                    train_criterion.append(self.pretrain_criterion)
+                    loss_weights.append(self.pretrain_loss_weight)
+                    micro_batch_size.append(self.actor_micro_bs)
+
                 s_t = time.time()
                 p_loss = policy_model.train(
-                    input_ids=policy_batch_inputs['input_ids'],
-                    labels=labels,
-                    attention_mask=policy_batch_inputs['attention_mask'],
-                    criterion=self.policy_criterion,
-                    micro_batch_size=self.actor_micro_bs)
-
-                logger.info(
-                    f'[actor train] duration: {round(time.time() - s_t, 2)} s, {self.policy_minibatch} batch, Policy loss: {p_loss.item()}'  # noqa: E501
-                )
-                policy_loss.append(p_loss.item())
+                    input_ids=train_input_ids,
+                    labels=train_labels,
+                    attention_mask=train_attention_mask,
+                    # position_ids=train_position_ids,
+                    criterion=train_criterion,
+                    loss_weights=loss_weights,
+                    micro_batch_size=micro_batch_size)
+                if isinstance(p_loss, list):
+                    ppo_loss.append(p_loss[0].item())
+                    pretrain_loss.append(p_loss[1].item())
+                    logger.info(
+                        f'[Policy Train] duration: {round(time.time() - s_t, 2)} s, prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss[0].item()}; pretrain data: {train_input_ids[1].shape}, pretrain loss: {p_loss[1].item()}'
+                    )
+                else:
+                    ppo_loss.append(p_loss.item())
+                    logger.info(
+                        f'[Policy Train] duration: {round(time.time() - s_t, 2)} s, prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss.item()}'
+                    )

         with Timer('policy_model.sync_model'):
             policy_model.sync_model()
-        return policy_loss, pt_loss
+        return ppo_loss, pretrain_loss

     def value_learn_async(self, trajectories, value_model: BaseModelServer):
+        if self.value_minibatch is None:
+            self.value_minibatch = len(trajectories.output_ids)
         value_updates = len(trajectories.output_ids) // self.value_minibatch
         value_loss = []
         assert value_updates == 1 and self.policy_learn_time == 1, f'value_updates={value_updates} * self.policy_learn_time={self.policy_learn_time} > 1'  # noqa: E501
@@ -125,6 +159,8 @@ def value_learn_get(self, value_loss_ref, value_model: BaseModelServer):
         ]

     def value_learn(self, trajectories, value_model: BaseModelServer):
+        if self.value_minibatch is None:
+            self.value_minibatch = len(trajectories.output_ids)
         value_updates = len(trajectories.output_ids) // self.value_minibatch
         value_loss = []
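
Editor's note (illustration only, not part of the patch): with this change the actor update mixes the PPO loss on the prompt minibatch with a weighted language-model loss on the pretrain minibatch. PPOTrainer passes loss_weights=[ppo_loss_weight, pretrain_loss_weight] and one criterion per batch to policy_model.train(), and the runner is expected to apply each weight before calling backward() on its micro-batches. Below is a minimal sketch of the effective objective, assuming the pretrain criterion reduces to a standard shifted cross-entropy; the function and tensor names are illustrative stand-ins, not the repo's API.

import torch
import torch.nn.functional as F


def combined_actor_objective(ppo_loss: torch.Tensor,
                             pretrain_logits: torch.Tensor,
                             pretrain_labels: torch.Tensor,
                             ppo_loss_weight: float = 1.0,
                             pretrain_loss_weight: float = 0.5) -> torch.Tensor:
    # Next-token LM loss on the pretrain batch: predict token t+1 from
    # positions <= t, so logits and labels are shifted by one.
    shift_logits = pretrain_logits[..., :-1, :].contiguous()
    shift_labels = pretrain_labels[..., 1:].contiguous()
    lm_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100)  # ignore padded / masked label positions
    # The runner backpropagates each weighted term separately, so the
    # accumulated gradient corresponds to this weighted sum.
    return ppo_loss_weight * ppo_loss + pretrain_loss_weight * lm_loss

A pretrain_loss_weight below 1 lets the language-model term act as a regularizer that keeps the actor close to its SFT behaviour while the PPO term optimizes the reward.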