diff --git a/.gitignore b/.gitignore index ffe3444b8..c13320a73 100644 --- a/.gitignore +++ b/.gitignore @@ -113,6 +113,7 @@ data *.pkl.json *.log.json work_dirs/ +rlhf_trainlog*/ # Pytorch *.pth diff --git a/examples/rlhf/demo_datas/pretrain_data.json b/examples/rlhf/demo_datas/pretrain_data.json new file mode 100644 index 000000000..ccc5e0628 --- /dev/null +++ b/examples/rlhf/demo_datas/pretrain_data.json @@ -0,0 +1,2 @@ +[{"role": "user", "content": ""}, {"role": "assistant", "content": "I am an artificial intelligence (AI) assistant named InternLM. I was created by the Shanghai AI Laboratory and my purpose is to assist users with various tasks through natural language processing technology."}] +[{"role": "user", "content": "Give three tips for staying healthy."}, {"role": "assistant", "content": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."}, {"role": "user", "content": "How to study English?"}, {"role": "assistant","content": "1. Set clear goals. 2. Create a study plan. 3. Build vocabulary. 4. Practice speaking."}] diff --git a/examples/rlhf/demo_datas/prompt_data.json b/examples/rlhf/demo_datas/prompt_data.json new file mode 100644 index 000000000..6bee0447b --- /dev/null +++ b/examples/rlhf/demo_datas/prompt_data.json @@ -0,0 +1,3 @@ +[{"role": "user", "content": "How to study English?"}] +[{"role": "system", "content": "You are a helpful AI assistant."}, {"role": "user", "content": "Give three tips for staying healthy."}] +[{"role": "user", "content": "Give three tips for staying healthy."}, {"role": "assistant", "content": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."}, {"role": "user", "content": "How to study English?"}] diff --git a/examples/rlhf/internlm2_chat_1_8b_ppo_ds_8gpu.py b/examples/rlhf/internlm2_chat_1_8b_ppo_ds_8gpu.py new file mode 100644 index 000000000..a058b1ab0 --- /dev/null +++ b/examples/rlhf/internlm2_chat_1_8b_ppo_ds_8gpu.py @@ -0,0 +1,234 @@ +####################################################################### +# Settings # +####################################################################### +RESUME_STEP = -1 +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 256 +PRETRAIN_BATCH_SIZE = 32 + +GENERATE_MICRO_BATCH_SIZE = 16 +INFER_MICRO_BATCH_SIZE = 8 +TRAIN_MICRO_BATCH_SIZE = 2 + +ZERO_STAGE = 3 +POLICY_DP_SIZE = 2 +CRITIC_DP_SIZE = 2 +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 + +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + +MODEL_DTYPE = 'auto' + +tokenizer_config = dict( + pad_token_id=0, + eos_token_id=92542, + padding_side='left', +) + +rollout_config = dict( + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=True, + resume_step=RESUME_STEP, + generate_kwargs={ + 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 
'min_new_tokens': 1, + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) + +repeater_config = dict( + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=INFER_MICRO_BATCH_SIZE, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, + norm_rewards=True, +) + +train_config = dict( + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + critic_warmup_step=20, + save_interval=40, + max_train_step=400, + resume_step=RESUME_STEP, +) + +model_configs = dict( + policy=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='policy', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=1e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), + generator_config=dict(shared_with_trainer=True, ), + ), + critic=dict( + model_path=None, + model_type='critic', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=5e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + reward=dict( + model_path=None, + model_type='reward', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + 
tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), +) + +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + './examples/rlhf/demo_datas/prompt_data.json::0.01[SYS_PROMPT]:summarization', # noqa: E501 + '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', + ]) + +pretrain_dataset_config = dict( + samples_each_epoch=PRETRAIN_BATCH_SIZE, + max_len=MAX_PRETRAIN_LEN, + message_type='pretrain', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + './examples/rlhf/demo_datas/pretrain_data.json::0.01', + '[HF]Anthropic/hh-rlhf/helpful-base::0.5', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', + ], +) diff --git a/examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py b/examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py new file mode 100644 index 000000000..d4ba5d4c3 --- /dev/null +++ b/examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py @@ -0,0 +1,243 @@ +####################################################################### +# Settings # +####################################################################### +RESUME_STEP = -1 +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 256 +PRETRAIN_BATCH_SIZE = 32 # 0 + +GENERATE_MICRO_BATCH_SIZE = 16 +INFER_MICRO_BATCH_SIZE = 8 +TRAIN_MICRO_BATCH_SIZE = 2 + +ZERO_STAGE = 3 +POLICY_DP_SIZE = 2 +CRITIC_DP_SIZE = 2 +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 + +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + +MODEL_DTYPE = 'auto' + +tokenizer_config = dict( + pad_token_id=0, + eos_token_id=92542, + padding_side='left', +) + +rollout_config = dict( + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=True, + resume_step=RESUME_STEP, + generate_kwargs={ + 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 'min_new_tokens': 1, + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) + +repeater_config = dict( + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=INFER_MICRO_BATCH_SIZE, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, + norm_rewards=True, +) + +train_config = dict( + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + critic_warmup_step=20, + save_interval=40, + max_train_step=400, + resume_step=RESUME_STEP, +) + +model_configs = dict( + policy=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='policy', + 
trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=1e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), + generator_config=dict( + shared_with_trainer=False, + generator_type='vllm', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=2, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + critic=dict( + model_path=None, + model_type='critic', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=5e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + reward=dict( + model_path=None, + model_type='reward', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), +) + +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + './examples/rlhf/demo_datas/prompt_data.json::0.01[SYS_PROMPT]:summarization', # noqa: E501 + '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', + ]) + +pretrain_dataset_config = dict( + 
samples_each_epoch=PRETRAIN_BATCH_SIZE, + max_len=MAX_PRETRAIN_LEN, + message_type='pretrain', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + './examples/rlhf/demo_datas/pretrain_data.json::0.01', + '[HF]Anthropic/hh-rlhf/helpful-base::0.5', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', + ], +) diff --git a/examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py b/examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py new file mode 100644 index 000000000..65b445630 --- /dev/null +++ b/examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py @@ -0,0 +1,244 @@ +####################################################################### +# Settings # +####################################################################### +RESUME_STEP = -1 +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 512 +PRETRAIN_BATCH_SIZE = 0 + +GENERATE_MICRO_BATCH_SIZE = 16 +INFER_MICRO_BATCH_SIZE = 16 +TRAIN_MICRO_BATCH_SIZE = 4 +REF_INFER_MICRO_BATCH_SIZE = 26 + +ZERO_STAGE = 3 +POLICY_DP_SIZE = 8 +CRITIC_DP_SIZE = 4 +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 + +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + +import torch # noqa: E402 + +MODEL_DTYPE = torch.float16 + +POLICY_MODEL_PATH = 'meta-llama/Llama-2-7b-chat-hf' +REWARD_MODEL_PATH = 'meta-llama/Llama-2-7b-chat-hf' # better using a well-trained reward model # noqa: E501 + +tokenizer_config = dict( + pad_token_id=2, + eos_token_id=2, + padding_side='left', + chat_template= # noqa: E251 + "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human:\n' + message['content'] + '\n'}}{% elif message['role'] == 'assistant' %}{{'Assistant:\n' + message['content'] + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:\n' }}{% endif %}", # noqa: E501 +) + +rollout_config = dict( + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=False, + resume_step=RESUME_STEP, + generate_kwargs={ + 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 'min_new_tokens': 1, + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) + +repeater_config = dict( + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, + norm_rewards=True, +) + +train_config = dict( + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + critic_warmup_step=0, + save_interval=40, + max_train_step=400, + resume_step=RESUME_STEP, + async_learn=True, +) + +model_configs = dict( + policy=dict( + model_path=POLICY_MODEL_PATH, + model_type='policy', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + 
trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=1e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True if MODEL_DTYPE == torch.bfloat16 else False + }, + 'fp16': { + 'enabled': True if MODEL_DTYPE == torch.float16 else False + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), + generator_config=dict( + shared_with_trainer=False, + generator_type='vllm', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + critic=dict( + model_path=REWARD_MODEL_PATH, + model_type='critic', + head_name='value_head', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=5e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True if MODEL_DTYPE == torch.bfloat16 else False + }, + 'fp16': { + 'enabled': True if MODEL_DTYPE == torch.float16 else False + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path=POLICY_MODEL_PATH, + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=2, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + reward=dict( + model_path=REWARD_MODEL_PATH, + model_type='reward', + head_name='value_head', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), +) + +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + 
'[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + ]) diff --git a/examples/rlhf/quick_start.md b/examples/rlhf/quick_start.md new file mode 100644 index 000000000..ab1fce0f7 --- /dev/null +++ b/examples/rlhf/quick_start.md @@ -0,0 +1,38 @@ +## Quick Start + +### step1: Prepare the environment + +``` +# install pytorch +pip install torch==2.1.2+cu118 torchvision --index-url https://download.pytorch.org/whl/cu118 + +# install the xtuner rlhf module +git clone https://github.com/2581543189/xtuner.git +cd xtuner +git checkout rlhf +pip install '.[rlhf]' +``` + +### step2: Launch an RLHF task with a single engine (huggingface) + +``` +# launch the task +xtuner rlhf -c examples/rlhf/internlm2_chat_1_8b_ppo_ds_8gpu.py +``` + +### step3: Launch an RLHF task with dual engines (vllm + huggingface) + +``` +# install vllm +export VLLM_VERSION=0.3.3 +export PYTHON_VERSION=310 +pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 +pip uninstall xformers -y +pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 +pip uninstall cupy-cuda12x -y +pip install cupy-cuda11x==12.1 +python -m cupyx.tools.install_library --library nccl --cuda 11.x + +# launch the task; for the first run, setting HF_ENDPOINT=https://hf-mirror.com is recommended to ease dataset downloading +HF_ENDPOINT=https://hf-mirror.com xtuner rlhf -c examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py +``` diff --git a/requirements/rlhf.txt b/requirements/rlhf.txt new file mode 100644 index 000000000..22cfbcaa3 --- /dev/null +++ b/requirements/rlhf.txt @@ -0,0 +1,2 @@ +loguru +ray[default,train]==2.9.1 diff --git a/setup.py b/setup.py index 7a95dfab4..3a95da067 100644 --- a/setup.py +++ b/setup.py @@ -132,6 +132,9 @@ def gen_packages_items(): 'modelscope': parse_requirements('requirements/runtime.txt') + parse_requirements('requirements/modelscope.txt'), + 'rlhf': + parse_requirements('requirements/deepspeed.txt') + + parse_requirements('requirements/rlhf.txt'), }, zip_safe=False, entry_points={'console_scripts': ['xtuner = xtuner:cli']}) diff --git a/xtuner/entry_point.py b/xtuner/entry_point.py index 2af774fd3..404263546 100644 --- a/xtuner/entry_point.py +++ b/xtuner/entry_point.py @@ -12,7 +12,7 @@ # Define valid modes MODES = ('list-cfg', 'copy-cfg', 'log-dataset', 'check-custom-dataset', 'train', 'test', 'chat', 'convert', 'preprocess', 'mmbench', - 'eval_refcoco') + 'eval_refcoco', 'rlhf') CLI_HELP_MSG = \ f""" @@ -207,6 +207,11 @@ def eval_refcoco(): return eval_refcoco.__file__ +def rlhf(): + from xtuner.rlhf import main as rlhf_main + return rlhf_main.__file__ + + modes = { 'list-cfg': list_cfg, 'copy-cfg': copy_cfg, @@ -230,14 +235,15 @@ def eval_refcoco(): '-h': preprocess_help_msg }, 'eval_refcoco': eval_refcoco, - 'list-dataset-format': list_dataset_format + 'list-dataset-format': list_dataset_format, + 'rlhf': rlhf, } HELP_FUNCS = [preprocess_help_msg, convert_help_msg] MAP_FILE_FUNCS = [ list_cfg, copy_cfg, log_dataset, check_custom_dataset, train, test, chat, mmbench, pth_to_hf, merge, split, arxiv_preprocess, eval_refcoco, - convert_refcoco, list_dataset_format + convert_refcoco, list_dataset_format, rlhf ] diff --git a/xtuner/rlhf/__init__.py b/xtuner/rlhf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/config/__init__.py b/xtuner/rlhf/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/config/config.py b/xtuner/rlhf/config/config.py new file mode 100644 index 000000000..038aa0aa1 --- /dev/null +++ 
b/xtuner/rlhf/config/config.py @@ -0,0 +1,110 @@ +# flake8: noqa: E501 +#!/usr/bin/env python + +# Adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/config.py + +import inspect +import sys +from importlib.machinery import SourceFileLoader +from pathlib import Path + + +class Config(dict): + """This is a wrapper class for dict objects so that values of which can be + accessed as attributes. + + Args: + config (dict): The dict object to be wrapped. + """ + + def __init__(self, config: dict = None): + if config is not None: + for k, v in config.items(): + self._add_item(k, v) + + def __missing__(self, key): + raise KeyError(key) + + def __getattr__(self, key): + try: + value = super().__getitem__(key) + return value + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + super().__setitem__(key, value) + + def _add_item(self, key, value): + if isinstance(value, dict): + self.__setattr__(key, Config(value)) + else: + self.__setattr__(key, value) + + def update(self, config): + assert isinstance( + config, + (Config, dict)), 'can only update dictionary or Config objects.' + for k, v in config.items(): + self._add_item(k, v) + return self + + @staticmethod + def from_file(filename: str): + """Reads a python file and constructs a corresponding :class:`Config` + object. + + Args: + filename (str): Name of the file to construct the return object. + + Returns: + :class:`Config`: A :class:`Config` object constructed with information in the file. + + Raises: + AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file + """ + + # check config path + if isinstance(filename, str): + filepath = Path(filename).absolute() + elif isinstance(filename, Path): + filepath = filename.absolute() + + assert filepath.exists( + ), f'{filename} is not found, please check your configuration path' + + # check extension + extension = filepath.suffix + assert extension == '.py', 'only .py files are supported' + + # import the config as module + remove_path = False + if filepath.parent not in sys.path: + sys.path.insert(0, (filepath)) + remove_path = True + + module_name = filepath.stem + source_file = SourceFileLoader( + fullname=str(module_name), path=str(filepath)) + module = source_file.load_module() + + # load into config + config = Config() + + for k, v in module.__dict__.items(): + if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): + continue + else: + config._add_item(k, v) + + # NOTE: variables which starts with __, is a module or class declaration are omitted in config file + # remove module + del sys.modules[module_name] + if remove_path: + sys.path.pop(0) + + return config + + +class ConfigException(Exception): + pass diff --git a/xtuner/rlhf/config/config_consts.py b/xtuner/rlhf/config/config_consts.py new file mode 100644 index 000000000..03c64aa43 --- /dev/null +++ b/xtuner/rlhf/config/config_consts.py @@ -0,0 +1,18 @@ +# keywords for config files + +# model type (policy, critic, reward, reference, ...) 
for `model_type` +MODEL_TYPE_POLICY = 'policy' +MODEL_TYPE_REFERENCE = 'reference' +MODEL_TYPE_REWARD = 'reward' +MODEL_TYPE_CRITIC = 'critic' + +# training or generation engines for `trainer_type` and `generator_type` +ENGINE_HUGGINGFACE = 'huggingface' +ENGINE_INTERNEVO = 'internevo' +ENGINE_VLLM = 'vllm' +ENGINE_LMDEPLOY = 'lmdeploy' + +# plugins for trainer engine (e.g., huggingface accelerate) +ENGINE_PLUGIN_DDP = 'ddp' +ENGINE_PLUGIN_FSDP = 'fsdp' +ENGINE_PLUGIN_DEEPSPEED = 'deepspeed' diff --git a/xtuner/rlhf/config/config_utils.py b/xtuner/rlhf/config/config_utils.py new file mode 100644 index 000000000..ae1ebb3f0 --- /dev/null +++ b/xtuner/rlhf/config/config_utils.py @@ -0,0 +1,65 @@ +from loguru import logger + + +def get_gpu_requirement(trainer_config: dict) -> int: + # Calculates the number of GPUs required for a given trainer configuration. + return get_dp_size(trainer_config) * get_tp_size( + trainer_config) * get_pp_size(trainer_config) + + +def get_resource_requirement(model_configs: dict) -> dict: + """Analyzes resource requirements for a list of model configs and returns a + dictionary with the total number of GPUs and CPUs required. + + Args: + model_configs (dict): A dictionary containing model configurations. + + Returns: + dict: A dictionary with the total number of GPUs and CPUs required. + """ + + resources = {'num_gpus': 0} + for name, model_config in model_configs.items(): + if 'trainer_config' not in model_config: + logger.warning(f'{name} has no trainer_config. SKIP.') + continue + trainer_config = model_config['trainer_config'] + num_gpus = get_gpu_requirement(trainer_config) + + if 'generator_config' in model_config: + generator_config = model_config['generator_config'] + if not generator_config.get( + 'shared_with_trainer'): # None or False + num_gpus += get_gpu_requirement(generator_config) + + resources['num_gpus'] += num_gpus + + resources['num_cpus'] = resources['num_gpus'] * 10 + return resources + + +def get_dp_size(trainer_config: dict) -> int: + dp_size = 1 + if 'parallel' in trainer_config: + parallel = trainer_config['parallel'] + data = parallel.get('data', {'size': 1}) + dp_size = data['size'] + return dp_size + + +def get_tp_size(trainer_config: dict) -> int: + tp_size = 1 + if 'parallel' in trainer_config: + parallel = trainer_config['parallel'] + data = parallel.get('tensor', {'size': 1}) + tp_size = data['size'] + return tp_size + + +def get_pp_size(trainer_config: dict) -> int: + pp_size = 1 + if 'parallel' in trainer_config: + parallel = trainer_config['parallel'] + data = parallel.get('pipeline', {'size': 1}) + pp_size = data['size'] + return pp_size diff --git a/xtuner/rlhf/coordinator.py b/xtuner/rlhf/coordinator.py new file mode 100644 index 000000000..baa1345c2 --- /dev/null +++ b/xtuner/rlhf/coordinator.py @@ -0,0 +1,96 @@ +from pathlib import Path + +import ray +from loguru import logger + +from .config.config_consts import (MODEL_TYPE_CRITIC, MODEL_TYPE_POLICY, + MODEL_TYPE_REFERENCE, MODEL_TYPE_REWARD) +from .config.config_utils import get_resource_requirement +from .model_server import (BaseModelServer, CriticModelServer, + PolicyModelServer, RefModelServer, + RewardModelServer) + +ROOT_PATH = Path(__file__).parents[1].resolve() + + +class Coordinator: + + def __init__(self, cluster_address: str, configs: dict): + self.cluster_address = cluster_address + self.model_configs = configs['model_configs'] + self.tokenizer_config = configs.get('tokenizer_config', {}) + self.model_dict = dict() + self.context_type: str = None # "client" or 
"server" + self.context: ray._private.workers.BaseContext = None + + resources = get_resource_requirement(self.model_configs) + logger.info(f'Required resources: {resources}') + runtime_env = {'working_dir': ROOT_PATH} + logger.info(f'working_dir (root_path): {ROOT_PATH}') + + try: + client_context = ray.init( + address=self.cluster_address, + runtime_env=runtime_env, + ignore_reinit_error=True, + ) + logger.info( + f'Connected to a running ray cluster at {self.cluster_address}' + ) + self.context_type = 'client' + self.context = client_context + + except ConnectionError: + logger.info( + f'Error connecting to {self.cluster_address}, try initializing a new ray cluster.' # noqa: E501 + ) + ray_context = ray.init( + address=None, + resources=resources, + runtime_env=runtime_env, + ignore_reinit_error=True, + ) + node_ip_address = ray_context.address_info['node_ip_address'] + logger.info(f'Initialize a ray cluster at {node_ip_address}') + self.context_type = 'server' + self.context = ray_context + + def create_models(self) -> dict[str, BaseModelServer]: + self.model_dict = {} + for model_name, model_config in self.model_configs.items(): + model_type = model_config['model_type'] + model_config['tokenizer_config'] = self.tokenizer_config + if model_type == MODEL_TYPE_POLICY: + self.model_dict[model_name] = PolicyModelServer( + model_name, model_config) + elif model_type == MODEL_TYPE_CRITIC: + self.model_dict[model_name] = CriticModelServer( + model_name, model_config) + elif model_type == MODEL_TYPE_REWARD: + self.model_dict[model_name] = RewardModelServer( + model_name, model_config) + elif model_type == MODEL_TYPE_REFERENCE: + self.model_dict[model_name] = RefModelServer( + model_name, model_config) + else: + raise NotImplementedError(f'Unknown model_type: {model_type}') + self._schedule() + return self.model_dict + + def _schedule(self): + for model_name, model in self.model_dict.items( + ): # naive serial initialize + model.initialize_async() + for model_name, model in self.model_dict.items( + ): # naive serial initialize + model.initialize_get() + logger.info( + f'{model_name} {model.__class__.__name__}.is_initialized: {model.is_initialized}' # noqa: E501 + ) + + def clean_up(self): + for _, model_server in self.model_dict.items(): + if model_server.trainer is not None: + model_server.trainer.release_resources() + if model_server.generator is not None: + model_server.generator.release_resources() diff --git a/xtuner/rlhf/dataset/__init__.py b/xtuner/rlhf/dataset/__init__.py new file mode 100644 index 000000000..dea11525b --- /dev/null +++ b/xtuner/rlhf/dataset/__init__.py @@ -0,0 +1,3 @@ +from .message_iter import MessageIter + +__all__ = ['MessageIter'] diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py new file mode 100644 index 000000000..68bddb695 --- /dev/null +++ b/xtuner/rlhf/dataset/base.py @@ -0,0 +1,305 @@ +"""Basic datasets implement.""" + +import gzip +import json +import random +from contextlib import contextmanager + +import numpy as np +from loguru import logger +from torch.utils.data import ConcatDataset, Dataset, IterableDataset, Subset + + +@contextmanager +def open_file(filename): + """Construct a file handler. + + The handler can read a normal file or a file compressed by `gzip`. 
+ """ + if filename.endswith('.gz'): + fp = gzip.open(filename, 'rt') + else: + fp = open(filename, encoding='utf-8') + yield fp + fp.close() + + +class InfiniteDataset(IterableDataset): + """Load infinite data from original dataset with shuffle.""" + + def __init__(self, dataset, rng=None): + logger.info(f'init [InfiniteDataset] for {dataset} ...') + self.data = list( + iter(dataset)) if dataset.data_list is None else dataset.data_list + self.tokenizer = dataset.tokenizer + self.sys_prompt = dataset.sys_prompt + self.rm_prompt = dataset.rm_prompt + + self.indices = list(range(len(self.data))) + if rng is None: + rng = random.Random() + self.rng = rng + + def __iter__(self): + while True: + self.rng.shuffle(self.indices) + for i in self.indices: + if isinstance(self.data[i], dict): + yield self.data[i] + elif isinstance(self.data[i], list): + try: + self.tokenizer.apply_chat_template( + self.data[i], tokenize=True) + except Exception: + logger.info('[data tokenize check] ' + f'skip dirty data: {self.data[i]}') + continue + yield dict( + data=self.data[i], + sys_prompt=self.sys_prompt, + rm_prompt=self.rm_prompt) + + +class IterDataset(IterableDataset): + """Single json file dataset.""" + + def __init__(self, + filename=None, + data_list=None, + tokenizer=None, + sys_prompt='default', + rm_prompt='default'): + assert filename is not None or data_list is not None + self._filename = filename + self.data_list = data_list + self.tokenizer = tokenizer + self.sys_prompt = sys_prompt + self.rm_prompt = rm_prompt + + def __iter__(self): + if self.data_list is not None: + for lineno, data in enumerate(self.data_list): + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + except Exception: + logger.info( + f'[data tokenize check] skip dirty data: {data}') + continue + yield dict( + data=data, + sys_prompt=self.sys_prompt, + rm_prompt=self.rm_prompt) + else: + with open_file(self._filename) as fin: + for lineno, line in enumerate(fin): + data = json.loads(line) + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + except Exception: + logger.info( + f'[data tokenize check] skip dirty data: {data}') + continue + yield dict( + data=data, + sys_prompt=self.sys_prompt, + rm_prompt=self.rm_prompt) + + +class MultiSourceInBatchDatset(IterableDataset): + """Multiple source dataset.""" + + def __init__(self, task_groups, tokenizer=None, random_seed=1024): + self._task_group = [] + for _task in task_groups: + file_path, extra_info = _task.split('::')[0], _task.split('::')[1] + prob = float(extra_info.split('[')[0]) + sys_prompt = 'default' + rm_prompt = 'default' + if '[SYS_PROMPT]:' in extra_info: + sys_prompt = extra_info.split('[SYS_PROMPT]:')[-1].split( + '[')[0] + if '[RM_PROMPT]:' in extra_info: + rm_prompt = extra_info.split('[RM_PROMPT]:')[-1].split('[')[0] + if prob > 0: + self._task_group.append( + dict( + prob=prob, + filepath=file_path, + sys_prompt=sys_prompt, + rm_prompt=rm_prompt)) + logger.info(f'[DataLoader] Load {_task} with prob:{prob}, ' + f'sys_prompt type: {sys_prompt}, ' + f'reward prompt type: {rm_prompt}') + else: + logger.warning('[DataLoader] skip file, ' + f'prob of {file_path} is {prob} ...') + assert len(self._task_group) > 0, 'No data to be trained' + + for task in self._task_group: + filepath = task['filepath'] + if '[HF]' in filepath: + from xtuner.rlhf.dataset.utils.from_hf import load_from_hf + + # loading & convert & save opensource datasets + hf_dir = filepath.split('[HF]')[-1] + logger.info(f'Loading {hf_dir} from huggingface ...') + dataset = 
load_from_hf(hf_dir, tokenizer=tokenizer) + task['dataset'] = IterDataset( + filename=hf_dir, + data_list=dataset['conversation'], + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + + else: + task['dataset'] = IterDataset( + filename=filepath, + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + + sum_prob = sum([task['prob'] for task in self._task_group]) + for task in self._task_group: + task['prob'] = task['prob'] / sum_prob + + self.random_seed = random_seed + + def __iter__(self): + rng = random.Random(self.random_seed) + probs = [task['prob'] for task in self._task_group] + # Initialize task iterator + for task in self._task_group: + task['iterator'] = iter(task['dataset']) + while True: + task = rng.choices(self._task_group, weights=probs)[0] + try: + yield from task['iterator'] + except StopIteration: + task['iterator'] = iter(task['dataset']) + yield from task['iterator'] + + +class JsonDataset(Dataset): + """Single json file dataset.""" + + def __init__(self, + filename=None, + data_list=None, + tokenizer=None, + sys_prompt='default', + rm_prompt='default'): + assert filename is not None or data_list is not None + self.tokenizer = tokenizer + self.sys_prompt = sys_prompt + self.rm_prompt = rm_prompt + + if filename is not None: + self.data_list = [] + with open_file(filename) as fin: + for lineno, line in enumerate(fin): + data = json.loads(line) + self.data_list.append(data) + else: + self.data_list = data_list + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index: int): + data = self.data_list[index] + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + return { + 'data': data, + 'sys_prompt': self.sys_prompt, + 'rm_prompt': self.rm_prompt + } + except Exception: + logger.info(f'[data tokenize check] skip dirty data: {data}') + return None + + +class MultiSourceInDataDatset(Dataset): + """Multi source dataset. + + Args: + task_groups: list of data path. + e.g. ['PATH_TO_XTUNER/examples/rlhf/demo_datas/prompt_data.json::0.9[SYS_PROMPT]:summarization', # noqa: E501 + 'PATH_TO_XTUNER/examples/rlhf/demo_datas/pretrain_data.json::0.1', + '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5' + ] + tokenizer: The tokenizer processes some raw text as input and outputs + an Encoding. This argument should not be None. Default to None. 
+ random_seed: + """ + + def __init__(self, task_groups, tokenizer=None, random_seed=1024): + self._task_group = [] + for _task in task_groups: + file_path, extra_info = _task.split('::')[0], _task.split('::')[1] + prob = float(extra_info.split('[')[0]) + sys_prompt = 'default' + rm_prompt = 'default' + if '[SYS_PROMPT]:' in extra_info: + sys_prompt = extra_info.split('[SYS_PROMPT]:')[-1].split( + '[')[0] + if '[RM_PROMPT]:' in extra_info: + rm_prompt = extra_info.split('[RM_PROMPT]:')[-1].split('[')[0] + if prob > 0: + self._task_group.append( + dict( + prob=prob, + filepath=file_path, + sys_prompt=sys_prompt, + rm_prompt=rm_prompt)) + logger.info( + f'[DataLoader] Load {_task} with prob:{prob}, ' + f'sys_prompt type: {sys_prompt}, reward meta: {rm_prompt}') + else: + logger.warning('[DataLoader] skip file, ' + f'prob of {file_path} is {prob} ...') + assert len(self._task_group) > 0, 'No data to be trained' + + datasets = [] + for task in self._task_group: + filepath = task['filepath'] + + if '[HF]' in filepath: + from xtuner.rlhf.dataset.utils.from_hf import load_from_hf + + # loading & convert & save opensource datasets + hf_dir = filepath.split('[HF]')[-1] + logger.info(f'Loading {hf_dir} with huggingface format ...') + dataset = load_from_hf(hf_dir, tokenizer=tokenizer) + task['dataset'] = JsonDataset( + filename=hf_dir, + data_list=dataset['conversation'], + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + else: + task['dataset'] = JsonDataset( + filename=filepath, + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + task['dataset'] = self._get_subset_by_ratio( + task['dataset'], task['prob'], random_seed) + datasets.append(task['dataset']) + + self.all_dataset = ConcatDataset(datasets) + self.iter_all_dataset = iter(self.all_dataset) + + self.random_seed = random_seed + + def _get_subset_by_ratio(self, dataset: Dataset, ratio: float, seed: int): + np_random = np.random.RandomState(seed) + indices = np.arange(len(dataset)) + np_random.shuffle(indices) + subset_indices = indices[:int(len(dataset) * ratio)] + subset_indices = list(subset_indices) + return Subset(dataset, subset_indices) + + def __iter__(self): + yield next(self.iter_all_dataset) diff --git a/xtuner/rlhf/dataset/message_iter.py b/xtuner/rlhf/dataset/message_iter.py new file mode 100644 index 000000000..94c78b83a --- /dev/null +++ b/xtuner/rlhf/dataset/message_iter.py @@ -0,0 +1,245 @@ +"""Finetuning dataset.""" +import random +from dataclasses import dataclass +from typing import List + +import numpy as np +from loguru import logger +from torch.utils.data import DataLoader, RandomSampler + +from xtuner.rlhf.dataset.base import (InfiniteDataset, + MultiSourceInBatchDatset, + MultiSourceInDataDatset) + + +@dataclass +class Message: + message: List[dict] + sys_prompt: str = 'default' + rm_prompt: str = 'default' + token_ids: List[int] = None + mes_type: str = 'prompt' + + +class MessageIter(): + """Create sequences from dataset. 
+ + Args: + sample_strategy (str) ["in_batch", "in_data"]: + "in_batch": sample data by ratio for every single training batch + "in_data": merge all data by ratio and then sample training batch + """ + + def __init__(self, + message_datasets: list[str] = None, + message_type: str = 'prompt', + tokenizer=None, + max_len: int = 4096, + samples_each_epoch: int = 0, + random_seed: int = 110, + sample_strategy: str = 'in_batch', + **kwargs): + assert message_type in ['prompt', 'pretrain'] + assert sample_strategy in [ + 'in_batch', 'in_data' + ], ("`sample_strategy` should in ['in_batch', 'in_data']," + f' but got {sample_strategy}') + if (message_datasets is None) or (samples_each_epoch == 0): + logger.warning(f'message_datasets: {message_datasets}' + f' samples_each_epoch: {samples_each_epoch}.') + self.message_datasets = None + self.samples_each_epoch = 0 + return None + assert message_datasets is not None + self.message_type = message_type + self.sample_strategy = sample_strategy + self.tokenizer = tokenizer + assert self.tokenizer.chat_template is not None, ( + 'Make sure tokenizer has chat_template.') + # message data + self.message_datasets = message_datasets + self.samples_each_epoch = samples_each_epoch + self.max_len = max_len + + self.random_seed = random_seed + self.rng = random.Random(self.random_seed) + np.random.seed(self.random_seed) + random.seed(self.random_seed) + + if self.sample_strategy == 'in_batch': + self._init_in_batch() + elif self.sample_strategy == 'in_data': + self._init_in_data() + else: + raise NotImplementedError( + "sample_strategy should in ['in_batch', 'in_data']," + f' but got {sample_strategy}') + logger.info(f'[MES_ITER] {self.message_type} dataset initialized, ' + f'random seed {self.random_seed}, ' + f'{self.samples_each_epoch} per epoch.\n') + + self.epoch_index = 0 + + def _init_in_data(self): + logger.info(f'Init {self.message_type} in data dataset ...') + self.message_dataset = MultiSourceInDataDatset( + task_groups=self.message_datasets, tokenizer=self.tokenizer) + + logger.info(f'Init {self.message_type} in data sampler ...') + assert hasattr(self.message_dataset, 'all_dataset') + mes_sampler = RandomSampler(self.message_dataset.all_dataset) + self.mes_dataloader = iter( + DataLoader( + self.message_dataset.all_dataset, + collate_fn=lambda x: x, + sampler=mes_sampler, + batch_size=self.samples_each_epoch)) + + def yield_in_data(self): + logger.info('yielding data from ' + f'{self.message_type} in_data sampler ...') + mes_sequence = [] + + mes_batch_messages = next(self.mes_dataloader) + for index, message in enumerate(mes_batch_messages): + if message is None: + continue + sequence = self._postprocess_sequence(message) + if sequence is not None: + mes_sequence.append(sequence) + if len(mes_sequence) == self.samples_each_epoch: + break + # TODO, len(mes_sequence) < self.samples_each_epoch, + # tmp: random sample from chosen data + if len(mes_sequence) < self.samples_each_epoch: + missed = self.samples_each_epoch - len(mes_sequence) + logger.warning( + f'[MES_ITER] {self.message_type} {missed} dirty data ...') + for i in range(missed): + mes_sequence.append(mes_sequence[i]) + + assert len( + mes_sequence + ) == self.samples_each_epoch, \ + f'{len(mes_sequence)} == {self.samples_each_epoch}' + + assert len(mes_sequence) == self.samples_each_epoch + logger.info(f'[Epoch {self.epoch_index}] ' + f'sample {len(mes_sequence)} {self.message_type}') + return mes_sequence + + def _init_in_batch(self): + logger.info(f'Init {self.message_type} in batch dataset 
...') + self.message_dataset = MultiSourceInBatchDatset( + task_groups=self.message_datasets, tokenizer=self.tokenizer) + + logger.info(f'Init {self.message_type} in batch sampler ...') + samples_cnts = [] + for task in self.message_dataset._task_group: + task['target_num_each_epoch'] = int( + task['prob'] * self.samples_each_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task['dataset'], self.rng) + task['iterator'] = iter(inner_dataset) + samples_cnts.append(task['target_num_each_epoch']) + logger.info( + f"[MES_ITER] {task['filepath']}: task prob: {task['prob']}" + f' original number of messages: {len(inner_dataset.data)}' + f" target_num_each_epoch: {task['target_num_each_epoch']}") + assert sum(samples_cnts) >= self.samples_each_epoch + + def yield_in_batch(self): + logger.info('yield data from ' + f'{self.message_type} in_batch sampler ...') + mes_sequence = [] + + # epoch_rng only use in this epoch. + epoch_rng = np.random.RandomState(self.epoch_index) + # prepare epoch data + mes_batch_messages = [] + for task in self.message_dataset._task_group: + messages = [] + for _ in range(task['target_num_each_epoch']): + messages.append(next(task['iterator'])) + logger.info(f'[MES_ITER] sample {len(messages)} ' + f"{self.message_type} from {task['filepath']}") + epoch_rng.shuffle(messages) + mes_batch_messages.extend(messages) + epoch_rng.shuffle(mes_batch_messages) + for index, message in enumerate(mes_batch_messages): + sequence = self._postprocess_sequence(message) + if sequence is not None: + mes_sequence.append(sequence) + if len(mes_sequence) == self.samples_each_epoch: + break + # TODO, len(mes_sequence) < self.samples_each_epoch, + # tmp: random sample from chosen data + if len(mes_sequence) < self.samples_each_epoch: + missed = self.samples_each_epoch - len(mes_sequence) + logger.warning( + f'[MES_ITER] {self.message_type} {missed} dirty data ...') + for i in range(missed): + mes_sequence.append(mes_sequence[i]) + + assert len(mes_sequence) == self.samples_each_epoch + logger.info(f'[Epoch {self.epoch_index}] sample ' + f'{len(mes_sequence)} {self.message_type}') + + return mes_sequence + + def __iter__(self): + while True: + if self.sample_strategy == 'in_batch': + yield self.yield_in_batch() + elif self.sample_strategy == 'in_data': + yield self.yield_in_data() + + self.epoch_index += 1 + + def _postprocess_sequence(self, message): + """Post process sequence: tokenization & truncation.""" + message_data = message['data'] + new_meaasage_data = [] + if self.message_type == 'prompt': + for _ in reversed(range(len(message_data))): + if message_data[_]['role'] == 'user': + new_meaasage_data = message_data[:_ + 1] + break + assert new_meaasage_data[-1]['role'] == 'user', \ + f'prompt data last role must user, {new_meaasage_data}' + token_ids = self.tokenizer.apply_chat_template( + new_meaasage_data, + tokenize=True, + add_generation_prompt=True, + return_tensors='pt') + if (token_ids.shape[-1] <= 4) or (token_ids.shape[-1] > + self.max_len): + # TODO truncation?? 
+ logger.warning( + f'[MES_ITER] {self.message_type} message {message} ' + 'is too short or long, skipped.') + return None + elif self.message_type == 'pretrain': + for _ in reversed(range(len(message_data))): + if message_data[_]['role'] == 'assistant': + new_meaasage_data = message_data[:_ + 1] + break + assert new_meaasage_data[-1]['role'] == 'assistant', \ + f'pretrain data last role must assistant, {new_meaasage_data}' + token_ids = self.tokenizer.apply_chat_template( + new_meaasage_data, + tokenize=True, + add_generation_prompt=False, + return_tensors='pt') + + if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_len: + # TODO truncation?? + logger.warning( + f'[MES_ITER] {self.message_type} message {message} ' + 'is too short or long, skipped.') + return None + return Message( + message=new_meaasage_data, + token_ids=token_ids, + sys_prompt=message['sys_prompt'], + rm_prompt=message['rm_prompt'], + mes_type=self.message_type) diff --git a/xtuner/rlhf/dataset/utils/__init__.py b/xtuner/rlhf/dataset/utils/__init__.py new file mode 100644 index 000000000..b16c4da71 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/__init__.py @@ -0,0 +1,15 @@ +from .collate_fns import message_data_collator, messages_collate_fn +from .map_fns import (FW_fineweb_edu_map_fn, H4_hhh_alignment_map_fn, + H4_summarize_map_fn, argilla_prompt_map_fn, + default_map_fn, hhrlhf_map_fn, nvidia_HelpSteer_map_fn, + nvidia_OpenMathInstruct_map_fn, + nvidia_sft_datablend_v1_map_fn, + stingning_ultrachat_map_fn) + +__all__ = [ + 'message_data_collator', 'messages_collate_fn', 'default_map_fn', + 'hhrlhf_map_fn', 'H4_summarize_map_fn', 'H4_hhh_alignment_map_fn', + 'stingning_ultrachat_map_fn', 'nvidia_HelpSteer_map_fn', + 'nvidia_OpenMathInstruct_map_fn', 'nvidia_sft_datablend_v1_map_fn', + 'argilla_prompt_map_fn', 'FW_fineweb_edu_map_fn' +] diff --git a/xtuner/rlhf/dataset/utils/collate_fns.py b/xtuner/rlhf/dataset/utils/collate_fns.py new file mode 100644 index 000000000..c3551f2a0 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/collate_fns.py @@ -0,0 +1,28 @@ +from collections import defaultdict +from functools import partial +from typing import Dict, Sequence + + +def messages_collate_fn( + instances: Sequence[Dict], + return_only_messages: bool = True, +): + + return_dict = defaultdict(list) + messages = [] + + for example in instances: + assert 'conversation' in example.keys() + messages.append(example['conversation']) + for k, v in example.items(): + return_dict[k].append(v) + + if return_only_messages: + return messages + else: + return return_dict + + +def message_data_collator(return_only_messages=True): + return partial( + messages_collate_fn, return_only_messages=return_only_messages) diff --git a/xtuner/rlhf/dataset/utils/from_hf.py b/xtuner/rlhf/dataset/utils/from_hf.py new file mode 100644 index 000000000..6242d4252 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/from_hf.py @@ -0,0 +1,177 @@ +from datasets import load_dataset +from loguru import logger + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.map_fns import template_map_fn_factory +# yapf: disable +from xtuner.rlhf.dataset.utils import (FW_fineweb_edu_map_fn, + H4_hhh_alignment_map_fn, + H4_summarize_map_fn, + argilla_prompt_map_fn, default_map_fn, + hhrlhf_map_fn, nvidia_HelpSteer_map_fn, + nvidia_OpenMathInstruct_map_fn, + nvidia_sft_datablend_v1_map_fn, + stingning_ultrachat_map_fn) +# yapf: enable +from xtuner.utils import PROMPT_TEMPLATE + + +def read_hf_dataset(tokenizer, + path: str = None, + data_dir: str = None, + 
name: str = None, + data_files: dict = None, + dataset_map_fn=None, + max_length=8192, + split='train', + prompt_template=PROMPT_TEMPLATE.internlm_chat, + remove_unused_columns=False, + shuffle_before_pack=False, + pack_to_max_length=False): + template_map_fn = template_map_fn_factory(template=prompt_template) + dataset_org = load_dataset( + path, + name=name, + data_dir=data_dir, + data_files=data_files, + trust_remote_code=True) + logger.info(f'load_dataset {path}, {dataset_org}') + dataset = process_hf_dataset( + dataset=dataset_org, + tokenizer=tokenizer, + max_length=max_length, + split=split, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + remove_unused_columns=remove_unused_columns, + shuffle_before_pack=shuffle_before_pack, + pack_to_max_length=pack_to_max_length) + return dataset + + +def load_from_hf(hf_dir, tokenizer, data_dir=None): + if 'Anthropic/hh-rlhf' in hf_dir: + if data_dir is not None: + data_dir = data_dir + elif 'helpful-base' in hf_dir: + data_dir = 'helpful-base' + elif 'harmless-base' in hf_dir: + data_dir = 'harmless-base' + logger.info(f'loading from `Anthropic/hh-rlhf`, data_dir={data_dir},' + ' split=`train`, map_fn=hhrlhf_map_fn...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path='Anthropic/hh-rlhf', + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=hhrlhf_map_fn) + elif 'HuggingFaceH4' in hf_dir: + if 'summarize_from_feedback' in hf_dir: + H4_path = 'HuggingFaceH4/summarize_from_feedback' + H4_map_fn = H4_summarize_map_fn + elif 'hhh_alignment' in hf_dir: + H4_path = 'HuggingFaceH4/hhh_alignment' + H4_map_fn = H4_hhh_alignment_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + H4_path = hf_dir + H4_map_fn = default_map_fn + logger.info(f'loading {H4_path}, data_dir={data_dir}, ' + f'split=`train_prefs`, map_fn={H4_map_fn}...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=H4_path, + data_dir=data_dir, + max_length=8192, + split='train_prefs', + dataset_map_fn=H4_map_fn) + elif 'ultrachat' in hf_dir: + logger.info( + f'loading from `stingning/ultrachat`, data_dir={data_dir}, ' + 'split=`train`, map_fn=stingning_ultrachat_map_fn...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path='stingning/ultrachat', + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=stingning_ultrachat_map_fn) + elif 'nvidia' in hf_dir: + if 'HelpSteer' in hf_dir: + nvidia_map_fn = nvidia_HelpSteer_map_fn + elif 'OpenMathInstruct' in hf_dir: + nvidia_map_fn = nvidia_OpenMathInstruct_map_fn + elif 'sft_datablend_v1' in hf_dir: + nvidia_map_fn = nvidia_sft_datablend_v1_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + nvidia_map_fn = default_map_fn + logger.info(f'loading from {hf_dir}, data_dir={data_dir}, ' + f'split=`train`, map_fn={nvidia_map_fn}...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=hf_dir, + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=nvidia_map_fn) + elif 'argilla' in hf_dir: + if 'prompt-collective' in hf_dir: + argilla_path = 'argilla/prompt-collective' + argilla_map_fn = argilla_prompt_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + argilla_path = hf_dir + argilla_map_fn = default_map_fn + logger.info(f'loading from {argilla_path}, data_dir={data_dir}, ' + f'split=`train`, map_fn={argilla_map_fn}...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=argilla_path, + data_dir=data_dir, + max_length=8192, + split='train', 
+ dataset_map_fn=argilla_map_fn) + elif 'HuggingFaceFW' in hf_dir: + if 'fineweb-edu' in hf_dir: + FW_path = 'HuggingFaceFW/fineweb-edu' + FW_name = 'CC-MAIN-2024-10' + FW_data_files = { + 'train': [ + 'data/CC-MAIN-2024-10/train-00000-of-00020.parquet', + ] + } + FW_map_fn = FW_fineweb_edu_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + FW_path = hf_dir + FW_map_fn = default_map_fn + logger.info(f'loading from {FW_path}, name={FW_name}, ' + f'data_files={FW_data_files}, data_dir={data_dir}, ' + f'split=`train`, map_fn={FW_map_fn}...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=FW_path, + name=FW_name, + data_files=FW_data_files, + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=FW_map_fn) + else: + try: + logger.warning(f'Please specify your dataset_map_fn with {hf_dir}') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=hf_dir, + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=default_map_fn) + except Exception as e: + logger.error(f'{e}') + logger.error(f'Cannot load {hf_dir}, ' + 'checkout your datapath or dataset_map_fn...') + logger.info(f'Loaded {hf_dir}, {dataset}') + return dataset diff --git a/xtuner/rlhf/dataset/utils/map_fns.py b/xtuner/rlhf/dataset/utils/map_fns.py new file mode 100644 index 000000000..4bb68c4a3 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/map_fns.py @@ -0,0 +1,117 @@ +import re + + +def default_map_fn(example): + return example + + +def hhrlhf_map_fn(example): + string = example['chosen'] + pattern = r'(\n\nHuman|\n\nAssistant)(.+?)(?=(\n\nHuman|\n\nAssistant|$))' + matches = re.findall(pattern, string, re.DOTALL) + messages = [] + for match in matches: + role, content = match[0].strip(), match[1].strip() + if role == 'Human': + messages.append(dict(role='user', content=content[2:])) + elif role == 'Assistant': + messages.append(dict(role='assistant', content=content[2:])) + else: + raise NotImplementedError('role must in Human or Assistant') + return {'conversation': messages} + + +def H4_hhh_alignment_map_fn(example): + input = example['input'] + choices = example['targets']['choices'] + labels = example['targets']['labels'] + for label, choice in zip(labels, choices): + if label == 1: + chosen = choice + messages = [ + dict(role='user', content=input), + dict(role='assistant', content=chosen) + ] + return {'conversation': messages} + + +def H4_summarize_map_fn(example): + # prompt = example['prompt'] + chosen = example['chosen'] + # rejected = example['rejected'] + return {'conversation': chosen} + + +def stingning_ultrachat_map_fn(example): + # id = example['id'] + data = example['data'] + messages = [] + for i, d in enumerate(data): + if i % 2 == 0: + role = 'user' + else: + role = 'assistant' + messages.append(dict(role=role, content=d)) + + return {'conversation': messages} + + +def nvidia_HelpSteer_map_fn(example): + prompt = example['prompt'] + response = example['response'] + messages = [ + dict(role='user', content=prompt), + dict(role='assistant', content=response) + ] + + return {'conversation': messages} + + +def nvidia_OpenMathInstruct_map_fn(example): + question = example['question'] + # expected_answer = example['expected_answer'] + generated_solution = example['generated_solution'] + messages = [ + dict(role='user', content=question), + dict(role='assistant', content=generated_solution) + ] + + return {'conversation': messages} + + +def nvidia_sft_datablend_v1_map_fn(example): + conversations = example['conversations'] + # system = 
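# A minimal sketch of a custom map function for a dataset that load_from_hf()
# does not recognize. The 'instruction' / 'output' column names below are
# illustrative assumptions, not part of this PR; any map_fn only needs to
# return a {'conversation': [...]} list of role/content messages, which can
# then be passed to read_hf_dataset() as `dataset_map_fn=my_instruction_map_fn`.
def my_instruction_map_fn(example):
    messages = [
        dict(role='user', content=example['instruction']),
        dict(role='assistant', content=example['output']),
    ]
    return {'conversation': messages}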
example['system'] + messages = [] + for conv in conversations: + if conv['from'] == 'User': + role = 'user' + elif conv['from'] == 'Assistant': + role = 'assistant' + messages.append(dict(role=role, content=conv['value'])) + + return {'conversation': messages} + + +def argilla_prompt_map_fn(example): + prompt = example['prompt'] + messages = [dict(role='user', content=prompt)] + return {'conversation': messages} + + +def dibt_prompt_map_fn(example): + prompt = example['prompt'] + messages = [dict(role='user', content=prompt)] + return {'conversation': messages} + + +def FW_fineweb_edu_map_fn(example): + question = '' + answer = example['text'] + token_count = example['token_count'] + messages = [ + dict(role='user', content=question), + dict(role='assistant', content=answer) + ] + + return {'conversation': messages, 'token_count': token_count} diff --git a/xtuner/rlhf/envs/__init__.py b/xtuner/rlhf/envs/__init__.py new file mode 100644 index 000000000..5175867a2 --- /dev/null +++ b/xtuner/rlhf/envs/__init__.py @@ -0,0 +1,3 @@ +from .txt_env import TxtEnv + +__all__ = ['TxtEnv'] diff --git a/xtuner/rlhf/envs/base.py b/xtuner/rlhf/envs/base.py new file mode 100644 index 000000000..6f3f1e84e --- /dev/null +++ b/xtuner/rlhf/envs/base.py @@ -0,0 +1,12 @@ +class EnvBase: + """`EnvBase` is the base class of different environments. + + `env` is responsible to generate the trajectory data. + """ + + def __init__(self): + pass + + def rollout(self, *args, **kwargs): + """define rollout.""" + raise NotImplementedError diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py new file mode 100644 index 000000000..2fcce30c5 --- /dev/null +++ b/xtuner/rlhf/envs/txt_env.py @@ -0,0 +1,181 @@ +from collections.abc import Iterable +from copy import deepcopy + +import torch +from loguru import logger + +from ..model_server.base_model_server import BaseModelServer +from ..timer import Timer +from .base import EnvBase +from .utils import SYSTEM_PROMPT + + +class TxtEnv(EnvBase): + """A generic RL environment to generate textual sequences.""" + + def __init__( + self, + policy_model: BaseModelServer, + reward_model: BaseModelServer, + prompt_mes_iter: Iterable, + pretrain_mes_iter: Iterable = None, + max_new_tokens: int = 1024, + policy_micro_bs: int = 32, + reward_micro_bs: int = 32, + async_reward: bool = True, + generate_kwargs: dict = None, + resume_step=-1, + **_ignored, + ): + self.policy_model = policy_model + self.reward_model = reward_model + + self.prompt_mes_iter = iter(prompt_mes_iter) + self.pretrain_mes_iter = iter( + pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None + + self.max_new_tokens = max_new_tokens + self.policy_micro_bs = policy_micro_bs + self.reward_micro_bs = reward_micro_bs + self.async_reward = async_reward + self.generate_kwargs: dict = generate_kwargs + self.resume_step = resume_step + + def rollout(self, display=True): + while self.resume_step > 0: + logger.info(f'[Resume] {self.resume_step} consuming data...') + next(self.prompt_mes_iter) + if self.pretrain_mes_iter is not None: + next(self.pretrain_mes_iter) + self.resume_step -= 1 + prompt_datas = deepcopy(next(self.prompt_mes_iter)) + prompt_input_messages = [] + for data in prompt_datas: + assert data.mes_type == 'prompt' + if data.sys_prompt != 'default': + message = deepcopy([ + dict( + role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + ] + data.message) + else: + message = deepcopy(data.message) + prompt_input_messages.append(message) + # prompt data + if display: + logger.info( + 
f'[TXT_ENV For Generate]: \n{prompt_input_messages[0]}') + with Timer('policy_model.generate'): + trajectories = self.policy_model.generate( + inputs=prompt_input_messages, + micro_batch_size=self.policy_micro_bs, + step=self.max_new_tokens, + output_str=True, + generate_kwargs=self.generate_kwargs) + logger.info(f'[Generate] len: {len(prompt_input_messages)}') + + if self.async_reward: + reward_output_ref = self.get_reward_async(prompt_datas, + trajectories) + trajectories['reward_output_ref'] = reward_output_ref + else: + rewards = self.get_reward(prompt_datas, trajectories) + trajectories['rewards'] = rewards + + # pretrain data + if self.pretrain_mes_iter is not None: + pretrain_datas = deepcopy(next(self.pretrain_mes_iter)) + pretrain_input_messages = [] + for data in pretrain_datas: + assert data.mes_type == 'pretrain' + pretrain_input_messages.append(message) + + from xtuner.rlhf.tokenizer import encode_inputs + pt_input_ids, pt_attention_mask = encode_inputs( + pretrain_input_messages, self.policy_model.tokenizer) + pretrain_labels = torch.nn.functional.pad( + pt_input_ids[:, 1:], (0, 1), mode='constant', value=-100) + + trajectories.pretrain_data = { + 'input_ids': pt_input_ids, + 'labels': pretrain_labels, + 'attention_mask': pt_attention_mask + } + logger.info(f'[TxtEnv] gets {pt_input_ids.shape} pretrain data.') + else: + trajectories.pretrain_data = None + + return trajectories + + # default get_reward() is blocking. + # get_reward_async() needs to call get_reward_collect() + def get_reward_async(self, prompt_datas, policyout): + rm_input_messages = [] + for i in range(len(prompt_datas)): + if prompt_datas[i].mes_type != 'prompt': + continue + if (prompt_datas[i].rm_prompt != + 'default') or (prompt_datas[i].sys_prompt != 'default'): + # Conditional Reward Model + # for queries from different domains, use appropriate conditional system prompts # noqa: E501 + # From Alignment section of the InternLM2 Technical Report: + # https://arxiv.org/pdf/2403.17297 + if prompt_datas[i].rm_prompt != 'default': + prompt = prompt_datas[i].rm_prompt + else: + prompt = prompt_datas[i].sys_prompt + cur_rm_data = [ + dict(role='system', content=SYSTEM_PROMPT[prompt]) + ] + prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] + else: + cur_rm_data = prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] + rm_input_messages.append(cur_rm_data) + + logger.info(f'[For Reward]: {rm_input_messages[0]}') + with Timer('reward_model.infer_async'): + reward_output_ref = self.reward_model.infer_async( + rm_input_messages, + output_logprobs=False, + micro_batch_size=self.reward_micro_bs) + return reward_output_ref + + def get_reward_collect(self, reward_output_ref): + with Timer('reward_model.infer_get'): + rm_out = self.reward_model.infer_get(reward_output_ref) + rewards = rm_out.logits.squeeze(-1) + return rewards + + def get_reward(self, prompt_datas, policyout): + rm_input_messages = [] + for i in range(len(prompt_datas)): + if prompt_datas[i].mes_type != 'prompt': + continue + if prompt_datas[i].rm_prompt != 'default': + cur_rm_data = [ + dict( + role='system', + content=SYSTEM_PROMPT[prompt_datas[i].rm_prompt]) + ] + prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] + else: + cur_rm_data = prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] + rm_input_messages.append(cur_rm_data) + + logger.info(f'[For Reward]: 
{rm_input_messages[0]}') + with Timer('reward_model.infer'): + rm_out = self.reward_model.infer( + rm_input_messages, + output_logprobs=False, + micro_batch_size=self.reward_micro_bs) + rewards = rm_out.logits.squeeze(-1) + return rewards diff --git a/xtuner/rlhf/envs/utils.py b/xtuner/rlhf/envs/utils.py new file mode 100644 index 000000000..9695ec92b --- /dev/null +++ b/xtuner/rlhf/envs/utils.py @@ -0,0 +1,44 @@ +SYSTEM_PROMPT = { + 'default': + '', + 'helpful': + 'Consider the helpfulness of the given sentence. ' + 'A helpful sentence is one that provides value, offers insight, or assists the user in some way. ' # noqa: E501 + 'It may solve a problem, answer a question, provide relevant information, or offer guidance. ' # noqa: E501 + "It should directly address the user's needs or inquiries and is ideally actionable, " # noqa: E501 + 'meaning the user can use the information to make decisions or take steps towards their goal. ' # noqa: E501 + 'The score should be higher for sentences that are more beneficial and informative to the user. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'harmless': + 'Evaluate the harmlessness of the provided sentence. ' + 'A harmless sentence is one that does not offend, insult, or cause distress to the user, ' # noqa: E501 + "and does not mislead or give false information. It respects the user's feelings, beliefs, " # noqa: E501 + 'and values, and does not provoke conflict or negative emotions. ' + 'Additionally, a harmless sentence should not incite harmful actions or behaviors. ' # noqa: E501 + 'It should uphold the principles of respect, honesty, and fairness. ' + 'Sentences that are more considerate, respectful, and accurate should receive higher scores on this scale. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'hhh': + 'Evaluate the given sentence based on its helpfulness, harmlessness, and honesty. ' # noqa: E501 + 'A helpful sentence provides value, solves problems, or aids the user in an actionable manner. ' # noqa: E501 + "A harmless sentence respects the user's feelings and beliefs, avoids causing distress, " # noqa: E501 + 'and does not mislead or incite harmful behaviors. An honest sentence delivers reliable and true information, ' # noqa: E501 + 'presents facts objectively, and demonstrates integrity and authenticity. Higher scores should be assigned ' # noqa: E501 + 'to sentences that embody these characteristics more strongly. ' + 'Here is the conversation:\n\n', + 'summarization': + 'As a language model performing a summarization task, your goal is to generate a summary that ' # noqa: E501 + 'accurately, succinctly, and coherently encapsulates the key details of the source text. Ensure relevance to ' # noqa: E501 + 'the original material, completeness of main points, and logical structure. Maintain conciseness and high ' # noqa: E501 + 'linguistic standards. Ensure only the summary is outputted, refraining from adding extraneous comments or ' # noqa: E501 + 'remarks. Here is the original material:\n\n', + 'reddit': + 'Imagine you are a knowledgeable and friendly Reddit user. ' + 'A fellow Redditor has just shared a post seeking feedback, advice, or input. ' # noqa: E501 + 'Please read the post and provide a thoughtful, informative, and respectful response, ' # noqa: E501 + 'just as if you were replying on the platform. Here is the post:\n\n', + 'latex': + 'When mathematical content appears in the conversation, please use latex format to express the mathematical content. 
Here is the conversation:\n\n', # noqa: E501 + 'math_ci': + "Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:\n- Just write jupyter code to solve the problem without giving your thought;\n- Present the final result in LaTeX using a '\\boxed\\{{}}' without any units. \n", # noqa: E501 +} diff --git a/xtuner/rlhf/logger.py b/xtuner/rlhf/logger.py new file mode 100644 index 000000000..d774f2923 --- /dev/null +++ b/xtuner/rlhf/logger.py @@ -0,0 +1,91 @@ +# Adapted from +# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py +"""Logging configuration.""" +import logging +import sys +from functools import wraps +from time import perf_counter + +_FORMAT = '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' +_DATE_FORMAT = '%m-%d %H:%M:%S' + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None): + logging.Formatter.__init__(self, fmt, datefmt) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != '': + parts = msg.split(record.message) + msg = msg.replace('\n', '\r\n' + parts[0]) + return msg + + +_root_logger = logging.getLogger('marl') +_default_handler = None + + +def _setup_logger(): + _root_logger.setLevel(logging.DEBUG) + global _default_handler + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.flush = sys.stdout.flush # type: ignore + _default_handler.setLevel(logging.INFO) + _root_logger.addHandler(_default_handler) + fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _default_handler.setFormatter(fmt) + # Setting this will avoid the message + # being propagated to the parent logger. + _root_logger.propagate = False + + +# The logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. 
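# Minimal usage sketch (module path assumed to be xtuner.rlhf.logger):
#
#   from xtuner.rlhf.logger import init_logger, log_decorator
#
#   logger = init_logger(__name__)   # reuses the shared 'marl' stdout handler
#
#   @log_decorator(logger)           # defined further below
#   def divide(a, b):
#       return a / b
#
#   divide(6, 3)   # logs the call, its result and the duration
#   divide(1, 0)   # the exception is logged via logger.exception, not re-raised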
+_setup_logger() + + +def init_logger(name: str): + # Use the same settings as above for root logger + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.addHandler(_default_handler) + logger.propagate = False + return logger + + +def log_decorator(logger): + """ + Usage: + @log_decorator(logger) + def func(a, b, ...): + return 1 / 0 + + """ + + def decorator(func): + + @wraps(func) + def wrapper(*args, **kwargs): + logger.info('----------- LOG DECORATOR -----------') + logger.info( + f'CALLED {func.__name__} ARGS: {args}; KWARGS:{kwargs}') + bgn = perf_counter() + try: + result = func(*args, **kwargs) + end = perf_counter() + dur = end - bgn + logger.info( + f'{func.__name__} RESULT: {result}; DURATION: {dur:4f}s') + return result + except Exception as e: + logger.exception(f'{func.__name__}: {e}') + logger.info('----------- LOG DECORATOR -----------') + + return wrapper + + return decorator diff --git a/xtuner/rlhf/loss/__init__.py b/xtuner/rlhf/loss/__init__.py new file mode 100644 index 000000000..ed50f738d --- /dev/null +++ b/xtuner/rlhf/loss/__init__.py @@ -0,0 +1,4 @@ +from .critic_loss import CriticLoss +from .policy_loss import PPOPolicyLoss, PretrainLoss + +__all__ = ['PPOPolicyLoss', 'PretrainLoss', 'CriticLoss'] diff --git a/xtuner/rlhf/loss/critic_loss.py b/xtuner/rlhf/loss/critic_loss.py new file mode 100644 index 000000000..f043cfe27 --- /dev/null +++ b/xtuner/rlhf/loss/critic_loss.py @@ -0,0 +1,31 @@ +from typing import Any + +import torch + + +class CriticLoss(torch.nn.Module): + """Loss function for critic model.""" + + def __init__(self, cliprange_value: float = 0.5): + super().__init__() + self.cliprange_value = cliprange_value + + def critic_loss_fn(self, values, old_values, returns, mask): + values_clipped = old_values + (values - old_values).clamp( + -self.cliprange_value, self.cliprange_value) + vf_loss1 = (values_clipped - returns)**2 + vf_loss2 = (values - returns)**2 + vf_loss = (torch.max(vf_loss1, vf_loss2) * mask).sum() / mask.sum() + return 0.5 * vf_loss.mean() + + def forward(self, values: torch.Tensor, labels: dict[str, Any]): + assert values.ndim == 2 + mask = labels['mask'] + num_actions = mask.size(1) + values = values[:, -num_actions:] + + old_values = labels['old_values'] + returns = labels['returns'] + loss = self.critic_loss_fn( + values=values, old_values=old_values, returns=returns, mask=mask) + return loss diff --git a/xtuner/rlhf/loss/policy_loss.py b/xtuner/rlhf/loss/policy_loss.py new file mode 100644 index 000000000..f09bfb76c --- /dev/null +++ b/xtuner/rlhf/loss/policy_loss.py @@ -0,0 +1,79 @@ +from typing import Any + +import torch +from loguru import logger + +from ..policy_output import logprobs_from_logits + + +class PretrainLoss(torch.nn.Module): + """Loss function for flash GPT Language Model.""" + + def __init__(self, label_smoothing=0): + super().__init__() + + if label_smoothing is not None and label_smoothing != 0: + logger.warning(f'Use label_smoothing: {label_smoothing}') + self.label_smoothing = label_smoothing + + # the output will gather output is set in the model, + # so use ordinary loss + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='mean', label_smoothing=label_smoothing) + + def forward(self, *args): + if len(args) == 3: + # residual is to match prenorm + logits, _, labels = args + elif len(args) == 2: + # When using postnorm + logits, labels = args + else: + raise RuntimeError( + f'The number of criterion inputs are:{len(args)}') + shift_logits = logits.contiguous().view(-1, 
logits.size(-1)) + shift_labels = labels.contiguous().view(-1) + loss = self.loss_fn(shift_logits, shift_labels) + # There is no need to consider the ignore_index problem here, + # because the loss calculation will be calculated through the calculation range, # noqa: E501 + # and -100 must be outside this range, + # so there is no problem + + return loss + + +class PPOPolicyLoss(torch.nn.Module): + """Loss function for policy model.""" + + def __init__(self, cliprange: float = 0.2): + super().__init__() + self.cliprange = cliprange + + def policy_loss_fn(self, logprobs, old_logprobs, advantages, mask): + ratio = (logprobs - old_logprobs).exp() + pg_loss1 = -ratio * advantages + pg_loss2 = -ratio.clamp(1 - self.cliprange, + 1 + self.cliprange) * advantages + pg_loss = (torch.max(pg_loss1, pg_loss2) * mask).sum() / mask.sum() + return pg_loss.mean() + + def forward(self, logits: torch.Tensor, labels: dict[str, Any]): + assert logits.ndim == 3 + mask = labels['mask'] + + assert logits.shape[0] == labels['input_ids'].shape[0] + input_ids = labels['input_ids'] + old_logprobs = labels['old_logprobs'] + advantages = labels['advantages'] + + logpy = logprobs_from_logits( + logits=logits[:, :-1, :], labels=input_ids[:, 1:], gather=True) + num_actions = mask.size(1) + logprobs = logpy[:, -num_actions:] + + loss = self.policy_loss_fn( + logprobs=logprobs, + old_logprobs=old_logprobs, + advantages=advantages, + mask=mask) + return loss diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py new file mode 100644 index 000000000..9d692f6bf --- /dev/null +++ b/xtuner/rlhf/main.py @@ -0,0 +1,183 @@ +import argparse +import json +import os +import shutil +import time + +from loguru import logger + +from xtuner.rlhf.config.config import Config +from xtuner.rlhf.coordinator import Coordinator +from xtuner.rlhf.dataset import MessageIter +from xtuner.rlhf.envs import TxtEnv +from xtuner.rlhf.repeaters import KLGAERepeater +from xtuner.rlhf.timer import Timer +from xtuner.rlhf.trainer import PPOTrainer + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + parser.add_argument( + '-c', + '--config', + help='config file name or path.', + type=str, + default='examples/rlhf/four_model_vllm_8gpu.py') + parser.add_argument( + '-w', + '--work_dir', + help='the dir to save logs and models', + type=str, + default=None) + parser.add_argument( + '-a', '--address', help='ray head address', type=str, default='auto') + args = parser.parse_args() + return args + + +def validate_config(config: Config): + assert config['model_configs'] is not None + assert config['model_configs']['policy'] is not None + assert config['model_configs']['policy']['model_path'] is not None + assert config['dataset_config'] is not None + assert config['rollout_config'] is not None + assert config['rollout_config']['generate_kwargs'] is not None + assert config['rollout_config']['max_new_tokens'] is not None + + +if __name__ == '__main__': + args = parse_args() + assert args.config is not None, 'config should not be None' + work_dir = args.work_dir + if work_dir is None: + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') + work_dir = os.path.abspath(work_dir) + logger.info(f'using work_dir: {work_dir}') + os.makedirs(work_dir, exist_ok=True) + # save original config + shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') + + logger.add( + f'{work_dir}/train_rlhf.log', + filter=lambda record: record['extra'].get('name') == 'train') + logger_train = 
logger.bind(name='train') + + config = Config.from_file(args.config) + logger.info('#################### CONFIG BGN ####################') + for k, v in config.items(): + logger.info(f'{k}: {v}') + logger.info('#################### CONFIG END ####################') + + # init model + cluster_address = args.address + if cluster_address != 'auto': + cluster_address = f'ray://{cluster_address}:10001' + logger.info(f'cluster_address={cluster_address}') + coordinator = Coordinator(cluster_address, config) + model_dict = coordinator.create_models() + ref_model = model_dict['reference'] + policy_model = model_dict['policy'] + reward_model = model_dict['reward'] + critic_model = model_dict['critic'] + + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + pretrain_dataset_config = config.get('pretrain_dataset_config', {}) + pretrain_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **pretrain_dataset_config) + + # init txt env + rollout_config = config.get('rollout_config', {}) + txt_env = TxtEnv( + policy_model=policy_model, + reward_model=reward_model, + prompt_mes_iter=prompt_mes_iter, + pretrain_mes_iter=pretrain_mes_iter, # None + **rollout_config, + ) + # init repeater + repeater_config = config.get('repeater_config', {}) + ppo_repeater = KLGAERepeater( + ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + env=txt_env, + **repeater_config, + ) + # init trainer + train_config = config.get('train_config', {}) + ppo = PPOTrainer( + policy_model=policy_model, critic_model=critic_model, **train_config) + critic_warmup_step = train_config['critic_warmup_step'] + save_interval = train_config['save_interval'] + max_train_step = train_config.get('max_train_step', float('inf')) + resume_step = train_config.get('resume_step', -1) + critic_warmup_step = min(critic_warmup_step, + critic_warmup_step - resume_step) + async_learn = train_config.get('async_learn', False) + + step = max(0, resume_step) + while step <= max_train_step: + s_t = time.time() + with Timer(f'step {step}: end_to_end'): + # generate trajectories + trajectories = txt_env.rollout(display=True) + + # deal with trajectories + trajectories = ppo_repeater.process(trajectories) + + # critic & policy learn + if async_learn: + critic_loss_ref = ppo.critic_learn_async(trajectories) + else: + critic_loss = ppo.critic_learn(trajectories) + + ppo_loss, pt_loss = None, None + if critic_warmup_step <= 0: + ppo_loss, pt_loss = ppo.policy_learn(trajectories) + logger_train.info( + f'[Policy Train] Step: {step}, ' + f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + + if async_learn: + critic_loss = ppo.critic_learn_get(critic_loss_ref) + + logger_train.info( + f'[Critic Train] step: {step}, critic loss: {critic_loss}') + logger_train.info(f'rewards: {trajectories.rewards.mean()}') + critic_warmup_step -= 1 + + if config['rollout_config'].get('write_to_file', True): + if not os.path.exists(f'{work_dir}/rollouts'): + os.makedirs(f'{work_dir}/rollouts') + with open(f'{work_dir}/rollouts/step{step}_rollout.log', + 'a') as file: + for output_s, r in zip(trajectories.output_str, + trajectories.rewards): + file.write(output_s + '\n' + 'Reward: ' + str(r.item()) + + '\n' + '=' * 30 + '\n') + summaries = dict( + reward_mean=trajectories.rewards.mean().item(), + reward_std=trajectories.rewards.std().item(), + new_tokens_mean=trajectories.action_mask.sum( + -1).float().mean().item(), + 
new_tokens_std=trajectories.action_mask.sum( + -1).float().std().item(), + kl=trajectories.kl.mean().item(), + entropy=trajectories.entropy.mean().item(), + step=step, + policy_loss=ppo_loss, + pretrain_loss=pt_loss, + critic_loss=critic_loss, + ) + with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: + f.write(json.dumps(summaries) + '\n') + logger_train.info(f'[end to end] duration: {time.time() - s_t} s') + + step += 1 + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') + critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') diff --git a/xtuner/rlhf/model_backend/__init__.py b/xtuner/rlhf/model_backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/model_backend/cuda_memory_stats.py b/xtuner/rlhf/model_backend/cuda_memory_stats.py new file mode 100644 index 000000000..cdf195fad --- /dev/null +++ b/xtuner/rlhf/model_backend/cuda_memory_stats.py @@ -0,0 +1,52 @@ +from loguru import logger + +GB_SHIFT = 30 +MB_SHIFT = 20 + + +class CudaMemoryStats(dict): + # see: https://pytorch.org/docs/stable/generated/torch.cuda.memory_stats.html # noqa: E501 + # def add_memory_stats(self, key, device): + # import torch + # status = torch.cuda.memory_stats(device=device) + # self.__setattr__(key, status) + + @property + def num_gpus(self): + return len(self.keys()) + + @property + def total_current_bytes(self): + CURRENT_BYTE_KEY = 'allocated_bytes.all.current' + total = 0 + for _, v in self.items(): + total += v.get(CURRENT_BYTE_KEY, 0) + return total + + @property + def total_current_gb(self): + return self.total_current_bytes >> GB_SHIFT + + @property + def total_current_mb(self): + return self.total_current_bytes >> MB_SHIFT + + @property + def avg_current_bytes(self): + return self.total_current_bytes / self.num_gpus if self.num_gpus != 0 else 0 # noqa: E501 + + def __repr__(self): + return f'CudaMemoryStats: {self.num_gpus} GPU takes {self.total_current_mb} MiB' # noqa: E501 + + +def merge_cuda_memory_stats_list( + dict_list: list[CudaMemoryStats]) -> CudaMemoryStats: + if isinstance(dict_list, CudaMemoryStats): + logger.warning('dict_list is a CudaMemoryStatus instead of a list') + return dict_list + memory_stats_dict: CudaMemoryStats = dict_list[0] + assert isinstance(memory_stats_dict, CudaMemoryStats) + if len(dict_list) > 1: + for m in dict_list[1:]: + memory_stats_dict.update(m) + return memory_stats_dict diff --git a/xtuner/rlhf/model_backend/dist_utils.py b/xtuner/rlhf/model_backend/dist_utils.py new file mode 100644 index 000000000..30e63a229 --- /dev/null +++ b/xtuner/rlhf/model_backend/dist_utils.py @@ -0,0 +1,63 @@ +from datetime import timedelta +from typing import Any, Optional, Union + +from torch.distributed.distributed_c10d import (Backend, PrefixStore, Store, + _new_process_group_helper, + _world, default_pg_timeout, + rendezvous) + + +# Adapted from https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py # noqa: E501 +def init_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = '', + pg_options: Optional[Any] = None, +): + assert (store is None) or ( + init_method is None), 'Cannot specify both init_method and store.' 
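# Illustrative call (address, port and sizes are placeholders): every
# participant passes the same init_method / world_size / group_name so that a
# standalone NCCL group can be formed outside the default process group,
# e.g. the trainer's rank 0 plus the vLLM workers it syncs weights to.
#
#   pg = init_process_group(
#       backend='nccl',
#       init_method='tcp://10.0.0.1:29500',
#       world_size=3,
#       rank=0,
#       group_name='vllm')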
+ + if store is not None: + assert world_size > 0, 'world_size must be positive if using store' + assert rank >= 0, 'rank must be non-negative if using store' + elif init_method is None: + init_method = 'env://' + + if backend: + backend = Backend(backend) + else: + backend = Backend('undefined') + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous( + init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + pg = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + pg_options=pg_options, + timeout=timeout, + ) + + pg = pg[0] if isinstance(pg, tuple) else pg + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/xtuner/rlhf/model_backend/generate_utils.py b/xtuner/rlhf/model_backend/generate_utils.py new file mode 100644 index 000000000..15fa0b669 --- /dev/null +++ b/xtuner/rlhf/model_backend/generate_utils.py @@ -0,0 +1,195 @@ +from typing import Optional, Union + +import torch +from transformers import PreTrainedTokenizer + + +def get_question_answer_mask( + input_ids: torch.Tensor, + output_ids: torch.Tensor, + tokenizer_pad_token_id: int, + generate_pad_token_id: int = None, +): + """ + Example: + input_ids = torch.tensor([[0, 1, 9]]) + output_ids = torch.tensor([[0, 1, 9, 2, 3, 4, 5]]) + tokenizer_pad_token_id = 0 # set 0 as neither question or answer + generate_pad_token_id = None + expected_qst_mask = torch.tensor([[0, 1, 1, 0, 0, 0, 0]]) + expected_ans_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1]]) + """ + # seq_mask yields zero where token == pad_token_id + seq_mask = output_ids.not_equal(tokenizer_pad_token_id).int() + if generate_pad_token_id is not None: + seq_mask *= output_ids.not_equal(generate_pad_token_id).int() + + question_len = input_ids.shape[-1] + question_mask = seq_mask.clone() + question_mask[:, question_len:] = 0 + answer_mask = seq_mask.clone() + answer_mask[:, :question_len] = 0 + return question_mask, answer_mask + + +def partition_by_micro_batch_size( + input_ids: Union[torch.Tensor, list[int]], + micro_batch_size: int, + attention_mask: torch.Tensor = None, + position_ids: torch.Tensor = None, + labels: Optional[Union[list[torch.Tensor], torch.Tensor, + dict[str, torch.Tensor]]] = None, +) -> list[dict[str, torch.Tensor]]: + max_inputs_length = get_longest_list_length(input_ids) if isinstance( + input_ids, list) else None + micro_batches: list[dict[str, torch.Tensor]] = [] + batch_size = input_ids.shape[0] if isinstance( + input_ids, torch.Tensor) else len(input_ids) + if micro_batch_size <= 0 or batch_size == micro_batch_size: + micro_batch = {} + micro_batch['input_ids'] = input_ids + micro_batch['attention_mask'] = attention_mask + micro_batch['position_ids'] = position_ids + micro_batch['labels'] = labels + micro_batch['max_inputs_length'] = max_inputs_length + micro_batches.append(micro_batch) + return micro_batches + if micro_batch_size > batch_size: + micro_batch_size = batch_size + + num_splits = int(batch_size // micro_batch_size) + ( + batch_size % micro_batch_size > 0) + if isinstance(input_ids, torch.Tensor): + input_ids_split = torch.split(input_ids, micro_batch_size, dim=0) + attention_mask_split = ( + 
torch.split(attention_mask, micro_batch_size, dim=0) if + attention_mask is not None else [None for _ in range(num_splits)]) + else: + input_ids_split = [ + input_ids[i:i + micro_batch_size] + for i in range(0, len(input_ids), micro_batch_size) + ] + attention_mask_split = [ + attention_mask[i:i + micro_batch_size] if attention_mask + is not None else [None for _ in range(num_splits)] for i in range( + 0, + len(attention_mask + ) if attention_mask is not None else num_splits * + micro_batch_size, micro_batch_size) + ] + position_ids_split = ( + torch.split(position_ids, micro_batch_size, dim=0) + if position_ids is not None else [None for _ in range(num_splits)]) + labels_split = ( + partition_label_by_micro_batch_size(labels, micro_batch_size, + num_splits) + if labels is not None else [None for _ in range(num_splits)]) + for i in range(num_splits): + micro_batch = {} + micro_batch['input_ids'] = input_ids_split[i] + micro_batch['attention_mask'] = attention_mask_split[i] + micro_batch['position_ids'] = position_ids_split[i] + micro_batch['labels'] = labels_split[i] + micro_batch['max_inputs_length'] = max_inputs_length + micro_batches.append(micro_batch) + return micro_batches + + +def partition_label_by_micro_batch_size( + labels: Union[list[torch.Tensor], torch.Tensor, dict[str, torch.Tensor]], + micro_batch_size: int, + num_splits: int = 1, +): + if isinstance(labels, torch.Tensor): + return torch.split(labels, micro_batch_size, dim=0) + if isinstance(labels, list): + return [ + labels[i:i + micro_batch_size] + for i in range(0, len(labels), micro_batch_size) + ] + if isinstance(labels, dict): + split = [{} for _ in range(num_splits)] + for key in labels.keys(): + if key == 'loss_factor': + for i in range(num_splits): + split[i][key] = labels[key] + else: + tensors = partition_label_by_micro_batch_size( + labels[key], micro_batch_size) + for i in range(num_splits): + split[i][key] = tensors[i] + return split + + +def partition_list_by_micro_batch_size( + input_ids: list[torch.Tensor], + micro_batch_size: list[int], + labels: list[torch.Tensor], + attention_mask: Optional[list[torch.Tensor]] = None, + position_ids: Optional[list[torch.Tensor]] = None, +) -> list[dict]: + length = len(input_ids) + batch_size = input_ids[0].shape[0] + num_splits = int(batch_size // micro_batch_size[0]) + ( + batch_size % micro_batch_size[0] > 0) + micro_batches = [[{} for i in range(length)] for _ in range(num_splits)] + if attention_mask is None: + attention_mask = [None for _ in range(length)] + if position_ids is None: + position_ids = [None for _ in range(length)] + for i in range(length): + sub_input_ids = input_ids[i] + sub_attention_mask = attention_mask[i] + sub_position_ids = position_ids[i] + sub_labels = labels[i] + sub_micro_batches = partition_by_micro_batch_size( + sub_input_ids, micro_batch_size[i], sub_attention_mask, + sub_position_ids, sub_labels) + for micro_batch_index, sub_micro_batch in enumerate(sub_micro_batches): + micro_batches[micro_batch_index][i]['input_ids'] = sub_micro_batch[ + 'input_ids'] + micro_batches[micro_batch_index][i][ + 'attention_mask'] = sub_micro_batch['attention_mask'] + micro_batches[micro_batch_index][i][ + 'position_ids'] = sub_micro_batch['position_ids'] + micro_batches[micro_batch_index][i]['labels'] = sub_micro_batch[ + 'labels'] + return micro_batches + + +def merge_loss_list(loss_list_mb: list[list[torch.Tensor]]): + micro_batch_num = len(loss_list_mb) + loss_num = len(loss_list_mb[0]) + loss_list = [i for i in range(loss_num)] + for loss_index in 
range(loss_num): + losses = [] + for batch_index in range(micro_batch_num): + losses.append(loss_list_mb[batch_index][loss_index]) + loss_list[loss_index] = sum(losses) / micro_batch_num + return loss_list + + +def get_answer_str( + tokenizer: PreTrainedTokenizer, + output_ids: torch.Tensor, + answer_mask: torch.Tensor, +): + answer_ids = output_ids * answer_mask + zero_mask = answer_ids.eq(0) + answer_ids = zero_mask * tokenizer.all_special_ids[0] + answer_ids + + answer_str = tokenizer.batch_decode( + answer_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + return answer_str + + +def get_longest_list_length(list_of_lists): + max_length = 0 + for int_list in list_of_lists: + current_length = len(int_list) + if current_length > max_length: + max_length = current_length + return max_length diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py new file mode 100644 index 000000000..4786d1038 --- /dev/null +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -0,0 +1,886 @@ +import glob +import os +import socket +from typing import Optional, Union + +import ray +import torch +from accelerate import Accelerator +from accelerate.utils import FullyShardedDataParallelPlugin +from loguru import logger +from ray.util.placement_group import placement_group as create_placement_group +from ray.util.placement_group import remove_placement_group +from torch.nn.modules.loss import _Loss +from torch.optim.lr_scheduler import _LRScheduler +from transformers import AutoModelForCausalLM, PreTrainedModel +from transformers import get_scheduler as transformers_get_scheduler +from transformers.dynamic_module_utils import init_hf_modules +from transformers.generation.utils import GenerateDecoderOnlyOutput + +from ..config.config_consts import (ENGINE_PLUGIN_DDP, ENGINE_PLUGIN_DEEPSPEED, + ENGINE_PLUGIN_FSDP) +from ..config.config_utils import get_dp_size, get_gpu_requirement +from ..policy_output import (PolicyOutput, concat_policy_outputs, + logprobs_from_logits) +from ..tokenizer import get_tokenizer +from ..utils import set_seed +from .dist_utils import init_process_group +from .generate_utils import (get_answer_str, get_question_answer_mask, + partition_by_micro_batch_size, + partition_list_by_micro_batch_size) +from .ray_actor_group import RayActorGroup +from .ray_actor_mixin import RayActorMixin +from .ray_utils import DEFAULT_NUM_CPUS, DEFAULT_NUM_GPUS, create_ray_actors + +DEFAULT_NEW_TOKENS = 64 +MAXIMUM_NEW_TOKENS = 1024 +""" +HfModelRunner can be individually called by other process +HfModelRunnerRayActor is called by ModelServer with .remote() +""" + + +class HfModelRunner: + """HfModelRunner is capable of training, inference, and generation.""" + + def __init__(self, model_config): + self.model_config: dict = model_config + + def initialize(self): + # 0. 
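# Illustrative behaviour of partition_by_micro_batch_size() imported above
# (tensor contents are placeholders): a batch of 5 sequences split with
# micro_batch_size=2 yields ceil(5 / 2) = 3 micro batches of sizes 2, 2 and 1.
#
#   import torch
#   ids = torch.zeros(5, 16, dtype=torch.long)
#   mbs = partition_by_micro_batch_size(ids, micro_batch_size=2)
#   assert [mb['input_ids'].shape[0] for mb in mbs] == [2, 2, 1]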
Environment + envs = self.model_config.get('envs', {}) + for key, value in envs.items(): + os.environ[key] = value + + # Parallel Settings + parallel: dict = self.model_config['parallel'] + assert parallel['tensor']['size'] == 1 # TODO: support TP + assert parallel['pipeline']['size'] == 1 # TODO: support PP + self.update_step = 0 + self.zero_stage = 1 + mixed_precision = self.model_config.get('mixed_precision', None) + if parallel['data'].get('mode') == ENGINE_PLUGIN_FSDP: + self.accelerator = Accelerator( + fsdp_plugin=FullyShardedDataParallelPlugin()) + self.zero_stage = 3 + elif parallel['data'].get('mode') == ENGINE_PLUGIN_DEEPSPEED: + from accelerate import DeepSpeedPlugin + + ds_config = self.model_config['deepspeed_config'] # requisite + self.accelerator = Accelerator( + deepspeed_plugin=DeepSpeedPlugin(ds_config)) + self.zero_stage = ds_config['zero_optimization']['stage'] + else: + self.accelerator = Accelerator(mixed_precision=mixed_precision) + self.zero_stage = 0 + + # 1. Model + model_path = self.model_config.get('model_path') + self.model_type = self.model_config.get('model_type', '').lower() + torch_dtype = self.model_config.get('torch_dtype', 'auto') + use_flash_attn = self.model_config.get('use_flash_attn', None) + model_class = self.model_config.get('model_class', + AutoModelForCausalLM) + self.model: PreTrainedModel = model_class.from_pretrained( + pretrained_model_name_or_path=model_path, + device_map=None if self.zero_stage == 3 else 'auto', + torch_dtype=torch_dtype, + trust_remote_code=True, + attn_implementation='flash_attention_2' + if use_flash_attn else None, + ) + + # Graident checkpointing + gradient_checkpointing = self.model_config.get( + 'gradient_checkpointing', False) + if gradient_checkpointing: + self.model.gradient_checkpointing_enable() + self.vocab_size = self.model.config.vocab_size + + # 2. Tokenizer + tokenizer_path = self.model_config.get('tokenizer_path', model_path) + tokenizer_config = self.model_config.get('tokenizer_config', {}) + self.tokenizer = get_tokenizer( + tokenizer_path, trust_remote_code=True, **tokenizer_config) + + # 3. 
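# Minimal standalone sketch (model path and parallel sizes are placeholders):
# as the module docstring notes, HfModelRunner can also be driven directly,
# without Ray. With data mode 'ddp' and no train_kwargs it only prepares the
# model and tokenizer for inference/generation.
#
#   runner = HfModelRunner(model_config=dict(
#       model_path='some/hf_model_or_local_path',
#       model_type='reference',
#       torch_dtype='auto',
#       parallel=dict(
#           data=dict(size=1, mode='ddp'),
#           tensor=dict(size=1, mode='1d'),
#           pipeline=dict(size=1, interleaved_overlap=False),
#           sequence=False)))
#   runner.initialize()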
Trainer + train_kwargs = self.model_config.get('train_kwargs') + if train_kwargs is None: # requires no training + self.model = self.accelerator.prepare( + self.model) if self.zero_stage == 3 else self.model + self.device = self.accelerator.device + logger.info( + f'[{self.model_type}] __init__() done without train_kwargs.') + return + optimizer_type = train_kwargs.get('optimizer', torch.optim.AdamW) + learning_rate = train_kwargs.get('lr', 1e-5) + self.clip_grad_norm = train_kwargs.get('clip_grad_norm', 1.0) + self.optimizer: torch.optim.Optimizer = optimizer_type( + params=self.model.parameters(), + lr=learning_rate, + ) + + lr_scheduler_type = train_kwargs.get('lr_scheduler', 'linear') + lr_scheduler_kwargs = train_kwargs.get( + 'lr_scheduler_kwargs', + { + 'num_warmup_steps': 0, + 'num_training_steps': 10000000000 + }, + ) + self.lr_scheduler: _LRScheduler = transformers_get_scheduler( + lr_scheduler_type, + optimizer=self.optimizer, + **lr_scheduler_kwargs, + ) + self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( # noqa: E501 + self.model, self.optimizer, self.lr_scheduler) + + # resume optimizer, lr_scheduler + if bool(len(glob.glob(os.path.join(model_path, '*.step')))): + self._resume_load_pretrained(model_path=model_path) + + # Others + self.device = self.accelerator.device + set_seed(self.model_config.get('seed')) + if mixed_precision is not None: + self.info_rank0( + f'[{self.model_type}]: Enable mixed_precision = {mixed_precision}' # noqa: E501 + ) + if gradient_checkpointing: + self.info_rank0( + f'[{self.model_type}]: Enable gradient_checkpointing') + self.info_rank0( + f'[{self.model_type}] __init__() done with optimizer {self.optimizer.optimizer}.' # noqa: E501 + ) + + def _resume_load_pretrained(self, model_path): + _, step_pt = os.path.split( + glob.glob(os.path.join(model_path, '*.step'))[0]) + self.update_step = int(step_pt.split('.step')[0]) + logger.info(f'Resume train step {self.update_step} from {model_path}') + assert os.path.exists(os.path.join(model_path, 'saved_state')) + self.accelerator.load_state(os.path.join(model_path, 'saved_state')) + + def compute_loss( + self, + input_ids: torch.Tensor, + labels: Optional[Union[torch.Tensor, dict[str, torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + criterion: Optional[_Loss] = None, + loss_weight: Optional[float] = None, + **_ignored, + ) -> torch.Tensor: + input_ids = input_ids.to(self.device) + labels = input_ids.clone() if labels is None else labels + if attention_mask is not None: + if position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + batch = { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask.to(self.device) + if attention_mask is not None else None, + 'position_ids': + position_ids.to(self.device) if position_ids is not None else None + } + self.model.train() + + if criterion is None: + # OPT. A) Default settings + assert isinstance( + labels, torch.Tensor + ), 'Please pass in `criterion` for non-tensor labels' + batch['labels'] = labels.to(self.device) + fwd_output = self.model(**batch, use_cache=False) + loss = fwd_output.loss + elif isinstance(labels, torch.Tensor): + # OPT. 
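# For reference, the three `labels` forms accepted by compute_loss()
# (tensor names are placeholders):
#   labels=None or a plain tensor with criterion=None -> OPT. A above,
#       the model's built-in causal-LM loss;
#   labels=shifted next-token ids (e.g. input_ids[:, 1:] padded with -100),
#       criterion=PretrainLoss()                      -> OPT. B below;
#   labels={'input_ids': ..., 'old_logprobs': ..., 'advantages': ...,
#           'mask': ...}, criterion=PPOPolicyLoss()   -> OPT. C below
#       (CriticLoss() analogously expects 'old_values'/'returns'/'mask').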
B) Use preset loss functions, e.g., torch.nn.CrossEntropyLoss() # noqa: E501 + # Adopted from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1199 # noqa: E501 + logits: torch.Tensor = self.model(**batch, use_cache=False).logits + labels = labels.to(self.device) + loss = criterion(logits, labels) + elif isinstance(labels, dict): + # OPT. C) Use customized loss function, see loss/policy_loss.py + logits: torch.Tensor = self.model( + **batch, use_cache=False, return_dict=True).logits + for k, v in labels.items(): + labels[k] = v.to(self.device) + loss = criterion(logits, labels) + else: + raise ValueError(f'labels of unsupported type: {type(labels)}') + + if loss_weight is not None: + loss *= loss_weight + return loss + + def parameter_update(self, step_interval=1): + self.info_rank0(f'[{self.model_type}] self.parameter_update()') + self.update_step += 1 + if self.update_step % step_interval == 0: + self.accelerator.clip_grad_norm_(self.model.parameters(), + self.clip_grad_norm) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + def train( + self, + input_ids: Union[list[torch.Tensor], torch.Tensor], + labels: Optional[Union[list[torch.Tensor], torch.Tensor, + dict[str, torch.Tensor]]] = None, + attention_mask: Optional[Union[list[torch.Tensor], + torch.Tensor]] = None, + position_ids: Optional[Union[list[torch.Tensor], torch.Tensor]] = None, + criterion: Optional[Union[list[_Loss], _Loss]] = None, + loss_weights: Optional[Union[list[float], float]] = None, + step_interval: int = 1, + # None means using the entire input as one batch + micro_batch_size: Optional[Union[list[int], int]] = None, + debug=False, + **_ignored, + ): + if isinstance(input_ids, torch.Tensor): + input_ids = [input_ids] + labels = [labels] + attention_mask = [attention_mask] + position_ids = [position_ids] + criterion = [criterion] + loss_weights = [loss_weights] + micro_batch_size = [micro_batch_size] + else: + if attention_mask is None: + attention_mask = [None for _ in range(len(input_ids))] + if position_ids is None: + position_ids = [None for _ in range(len(input_ids))] + if criterion is None: + criterion = [None for _ in range(len(input_ids))] + if loss_weights is None: + loss_weights = [None for _ in range(len(input_ids))] + if micro_batch_size is None: + micro_batch_size = [None for _ in range(len(input_ids))] + + assert isinstance(input_ids, list) + + loss_list = [[] for _ in range(len(input_ids))] + for index in range(len(input_ids)): + mb_size_entry = micro_batch_size[index] + if mb_size_entry is None: + micro_batches: list[dict[str, torch.Tensor]] = [] + micro_batches.append({ + 'input_ids': input_ids[index], + 'attention_mask': attention_mask[index], + 'position_ids': position_ids[index], + 'labels': labels[index] + }) + else: + micro_batches = partition_by_micro_batch_size( + input_ids=input_ids[index], + micro_batch_size=micro_batch_size[index], + attention_mask=attention_mask[index], + position_ids=position_ids[index], + labels=labels[index], + ) + loss_entry = [] + for mb_index, micro_batch in enumerate(micro_batches): + if mb_index == 0: + self.info_rank0( + f"[{self.model_type}] will train input_ids[{mb_index}] shape[{micro_batch['input_ids'].shape}] * {len(micro_batches)} times" # noqa: E501 + ) + # compute loss and backward + loss = self.compute_loss( + input_ids=micro_batch['input_ids'], + labels=micro_batch['labels'], + attention_mask=micro_batch['attention_mask'], + 
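# Illustrative call pattern for train() (tensors and loss instances are
# placeholders): passing lists lets a single optimizer step combine the PPO
# loss on rollout data with a weighted pretrain loss, each entry carrying its
# own criterion, loss weight and micro batch size.
#
#   from xtuner.rlhf.loss import PPOPolicyLoss, PretrainLoss
#   losses = runner.train(
#       input_ids=[ppo_input_ids, pretrain_input_ids],
#       labels=[ppo_label_dict, pretrain_labels],
#       attention_mask=[ppo_mask, pretrain_mask],
#       criterion=[PPOPolicyLoss(cliprange=0.2), PretrainLoss()],
#       loss_weights=[1.0, 0.5],
#       micro_batch_size=[2, 2])
#   ppo_loss, pretrain_loss = losses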
position_ids=micro_batch['position_ids'], + criterion=criterion[index], + loss_weight=loss_weights[index], + ) + self.accelerator.backward(loss) + loss_entry.append(loss) + if debug: + set_seed(1234) + loss_list[index] = sum(loss_entry) / len(loss_entry) + + self.parameter_update(step_interval) + return loss_list if len(loss_list) > 1 else loss_list[0] + + # Inference + @torch.no_grad() + def _infer( + self, + input_ids: torch.Tensor, + attention_mask=None, + output_logprobs=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + infer_kwargs: Optional[dict] = {}, + **_ignored, + ) -> PolicyOutput: + assert isinstance(input_ids, torch.Tensor) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_output = self.model( + input_ids.to(self.device), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids.to(self.device), + return_dict=True, + **infer_kwargs, + ) + + output = PolicyOutput() + if output_logits: + output['logits'] = model_output['logits'] + if output_attentions: + output['attentions'] = model_output['attentions'] + if output_hidden_states: + output['hidden_states'] = model_output['hidden_states'] + if output_logprobs: + log_probs = logprobs_from_logits( + logits=model_output['logits'][:, :-1, :], + labels=input_ids[:, 1:], + gather=True, + ) + output['logprobs'] = log_probs + output.to('cpu') + return output + + @torch.no_grad() + def infer( + self, + input_ids: torch.Tensor, + micro_batch_size: Optional[ + int] = -1, # -1: use the entire input as one batch + tokenizer=None, # Only used for reward models + attention_mask=None, + output_logprobs=False, + output_logits=True, + output_attentions=False, + output_hidden_states=False, + infer_kwargs: Optional[dict] = {}, + debug=False, + **_ignored, + ) -> PolicyOutput: + self.info_rank0( + f'[{self.model_type}] self.infer() kwargs: {infer_kwargs}') + input_ids = input_ids.to(self.device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.device) + # returns entire-input-as-one-batch inference results + if micro_batch_size < 0: + self.info_rank0( + f'[{self.model_type}] infer() input_ids.shape: {input_ids.shape}' # noqa: E501 + ) + return self._infer( + input_ids, + attention_mask, + output_logprobs, + output_logits, + output_attentions, + output_hidden_states, + infer_kwargs, + ) + + # Otherwise, partition the input into micro batches and run inference on each micro batch separately # noqa: E501 + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + policy_outputs = [] + for index, micro_batch in enumerate(micro_batches): + input_ids_mb = micro_batch['input_ids'] + attention_mask_mb = micro_batch['attention_mask'] + if index == 0: + self.info_rank0( + f'[{self.model_type}] will infer() input_ids_mb.shape: {input_ids_mb.shape} * {len(micro_batches)} times' # noqa: E501 + ) + policy_output_mb = self._infer( + input_ids_mb, + attention_mask_mb, + output_logprobs, + output_logits, + output_attentions, + output_hidden_states, + infer_kwargs, + ) + policy_outputs.append(policy_output_mb) + if debug: + self.set_seed(1234) + # Concatenate the policy outputs from each micro batch and return the result # noqa: E501 + return concat_policy_outputs(policy_outputs) + + # Generate + @torch.no_grad() + def _generate( + self, + input_ids: torch.Tensor, + attention_mask=None, + step=-1, + output_str=True, + 
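# Illustrative use of generate() defined further below (prompt ids and
# sampling settings are placeholders): the returned PolicyOutput bundles the
# generated ids, masks and decoded strings that TxtEnv consumes.
#
#   out = runner.generate(
#       input_ids=prompt_ids,       # (batch, prompt_len) LongTensor
#       micro_batch_size=16,
#       step=1024,                  # acts as max_new_tokens
#       output_str=True,
#       generate_kwargs={'do_sample': True, 'top_p': 0.9})
#   out.output_ids       # (batch, prompt_len + generated_len)
#   out.output_ans_str   # decoded answers only (question tokens masked out)
#   out.action_mask      # mask over generated positions, used by the PPO loss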
output_logits=False, + output_attentions=False, + output_hidden_states=False, + generate_kwargs: Optional[dict] = {}, + ) -> PolicyOutput: + assert isinstance(input_ids, torch.Tensor) + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model = self.accelerator.unwrap_model(self.model) + else: + model = self.model + + max_new_tokens = ( + MAXIMUM_NEW_TOKENS + if 'eos_token_id' in generate_kwargs else DEFAULT_NEW_TOKENS) + max_new_tokens = step if step > 0 else max_new_tokens + + # TODO: stop if meeting eos_token_id + model_output: GenerateDecoderOnlyOutput = model.generate( + input_ids.to(model.device), + use_cache=True, + max_new_tokens=max_new_tokens, + return_dict_in_generate=True, + output_logits=output_logits, # transformers >= 4.38.2 + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + attention_mask=attention_mask, + **generate_kwargs, + ) + + output_ids = model_output['sequences'] + self.info_rank0( + f'generate input_ids shape:[{input_ids.shape}], output_ids shape:[{output_ids.shape}]' # noqa: E501 + ) + output = PolicyOutput(output_ids=output_ids) + # masks + output['question_mask'], output[ + 'answer_mask'] = get_question_answer_mask( + input_ids, + output_ids, + tokenizer_pad_token_id=self.tokenizer.pad_token_id, + generate_pad_token_id=generate_kwargs.get('pad_token_id'), + ) + output['attention_mask'] = output.question_mask + output.answer_mask + output['action_mask'] = output['attention_mask'][:, + input_ids.size(1) - + 1:-1] + + if output_logits: + output['logits'] = model_output['logits'] # tuple(torch.Tensor, ) + if output_attentions: + output['attentions'] = model_output['attentions'] + if output_hidden_states: + output['hidden_states'] = model_output['hidden_states'] + if output_str: # customized post processing + output['output_str'] = self.tokenizer.batch_decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output['output_ans_str'] = get_answer_str( + tokenizer=self.tokenizer, + output_ids=output_ids, + answer_mask=output.answer_mask, + ) + + output.to('cpu') + return output + + # Generate + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + micro_batch_size: Optional[ + int] = -1, # -1: use the entire input as one batch + attention_mask=None, + step=-1, + output_str=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + chat_template=None, + generate_kwargs: Optional[dict] = {}, + debug=False, + **_ignored, + ) -> PolicyOutput: + self.info_rank0( + f'[{self.model_type}] self.generate() kwargs: {generate_kwargs}') + input_ids = input_ids.to(self.device) + if attention_mask is not None: + assert isinstance(attention_mask, torch.Tensor) + attention_mask = attention_mask.to(self.device) + + if micro_batch_size < 0: + return self._generate( + input_ids, + attention_mask, + step, + output_str, + output_logits, + output_attentions, + output_hidden_states, + generate_kwargs, + ) + + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + policy_outputs = [] + for micro_batch in micro_batches: + input_ids_mb = micro_batch['input_ids'] + attention_mask_mb = micro_batch['attention_mask'] + policy_output_mb = self._generate( + input_ids_mb, + attention_mask_mb, + step, + output_str, + output_logits, + output_attentions, + output_hidden_states, + generate_kwargs, + ) + policy_outputs.append(policy_output_mb) + if debug: + self.set_seed(1234) + + padding_token_map = {'output_ids': 
self.tokenizer.pad_token_id} + return concat_policy_outputs(policy_outputs, padding_token_map) + + def get_model(self): + parallel: dict = self.model_config['parallel'] + dp = parallel['data'].get('size') + dp_mode = parallel['data'].get('mode') + if dp > 1 and dp_mode != ENGINE_PLUGIN_DDP: + raise ('please use get_state_dict instead when using parallel') + _model = self.accelerator.unwrap_model(self.model) + return _model + + def get_state_dict(self): + state_dict = self.accelerator.get_state_dict(self.model) + if not self.accelerator.is_main_process: + return None + return state_dict + + def set_seed(self, seed=None): + set_seed(seed) + + def save(self, path): + # for resume + self.accelerator.wait_for_everyone() + self.accelerator.save_state(os.path.join(path, 'saved_state')) + + # save model, tokenizer, step + if not self.accelerator.is_main_process: + self.accelerator.get_state_dict(self.model) + return + else: + path = os.path.normpath(path) + logger.info(f'[Train step {self.update_step}] ' + f'Saving {self.model_type} to {path} ...') + # save model + unwrapped_model = self.accelerator.unwrap_model(self.model) + unwrapped_model.save_pretrained( + path, + is_main_process=True, + save_function=self.accelerator.save, + state_dict=self.accelerator.get_state_dict(self.model), + ) + # save tokenizer + if self.tokenizer is not None: + self.tokenizer.save_pretrained(path) + torch.save(self.update_step, + os.path.join(path, f'{self.update_step}.step')) + logger.info(f'{self.model_type} saved.') + + def info_rank0(self, content): + if self.accelerator.is_main_process: + logger.info(content) + + +# Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/ppo_actor.py # noqa: E501 +class HfModelRunnerRayActor(HfModelRunner, RayActorMixin): + """A ray.remote Actor Class initialized by HfModelRunnerRayActorGroup, + extending HfModelRunner with ray related method via RayActorMixin.""" + + def init_process_group(self, generator): + if self.accelerator.is_main_process: + # init process groups for vllm engine + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(('', 0)) + master_port = sock.getsockname()[1] + + world_size = generator.dp_size * generator.tp_size + 1 + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * generator.tp_size + 1, + world_size, + 'vllm', + ) for i, engine in enumerate(generator.ray_actors) + ] + self._model_update_group = init_process_group( + backend='nccl', + init_method=f'tcp://{master_address}:{master_port}', + world_size=world_size, + rank=0, + group_name='vllm', + ) + ray.get(refs) + + def broadcast_model_to_generator(self, generator): + # TODO: Support Pytorch FSDP. 
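# Sketch of the intended weight-sync flow (the engine-side `update_weight`
# call follows the usage in this file; its exact vLLM-side signature lives in
# the generator backend, not here):
#   1. the trainer's main process and every vLLM worker join an extra NCCL
#      group named 'vllm' via init_process_group() above, trainer as rank 0;
#   2. for each policy parameter, rank 0 issues engine.update_weight.remote()
#      so receivers know the parameter name, dtype and shape to expect;
#   3. torch.distributed.broadcast(param.data, 0, group=...) ships the data,
#      wrapped in deepspeed's GatheredParameters when ZeRO-3 shards weights.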
+ if self.model_config['parallel']['data'].get( + 'mode') == ENGINE_PLUGIN_FSDP: + raise NotImplementedError('FSDP is not supported yet.') + logger.info('Broadcast BEGIN') + model = self.accelerator.unwrap_model(self.model) + for name, param in model.named_parameters(): + if self.accelerator.is_main_process: + shape = param.shape if self.zero_stage != 3 else param.ds_shape + + for engine in generator.ray_actors: + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape) + + if self.zero_stage != 3: + if self.accelerator.is_main_process: + torch.distributed.broadcast( + param.data, 0, group=self._model_update_group) + else: + from deepspeed.runtime.zero.partition_parameters import \ + GatheredParameters + + with GatheredParameters([param]): + if self.accelerator.is_main_process: + torch.distributed.broadcast( + param.data, 0, group=self._model_update_group) + + logger.info('Broadcast END') + + +class HfModelRunnerRayActorGroup(RayActorGroup): + """HfModelRunnerRayActorGroup manages a list of HfModelRunnerRayActor + create ray actors.""" + + # avoid ModuleNotFoundError: No module named 'transformers_modules' + # refer to https://github.com/vllm-project/vllm/pull/871 + init_hf_modules() + + def __init__(self, name: str, config: dict): + super().__init__(name, config) + self.released = True + num_gpus = get_gpu_requirement(config) + self.dp_size = get_dp_size(config) + self.tokenizer_pad_token_id = config.tokenizer_config['pad_token_id'] + bundles = [{ + 'CPU': DEFAULT_NUM_CPUS, + 'GPU': DEFAULT_NUM_GPUS + } for _ in range(num_gpus)] + self.placement_group = create_placement_group(bundles) + self.ray_actors: list[HfModelRunnerRayActor] = create_ray_actors( + name_prefix=name, + config=config, + placement_group=self.placement_group, + trainer_class=ray.remote( + num_cpus=DEFAULT_NUM_CPUS, + num_gpus=DEFAULT_NUM_GPUS)(HfModelRunnerRayActor), + ) + self.released = False + + master_ip = ray.get(self.ray_actors[0].get_metadata.remote()).node_ip + master_port = ray.get(self.ray_actors[0].get_free_port.remote()) + ray.get([ + actor.inject_distribute_env.remote( + master_ip=master_ip, + master_port=master_port, + rank_id=rank, + world_size=len(self.ray_actors), + ) for rank, actor in enumerate(self.ray_actors) + ]) + self.initialize_ref = [ + actor.initialize.remote() for actor in self.ray_actors + ] + + def initialize_get(self): + if self.initialize_ref is not None: + ray.get(self.initialize_ref) + else: + logger.info( + 'self.initialize_get None, maybe self.generator==self.trainer') + self.initialize_ref = None + + # Training + def train_async(self, input_ids, labels, attention_mask, position_ids, + *args, **kwargs): + if isinstance(input_ids, torch.Tensor): + micro_batch_size = input_ids.shape[0] // self.dp_size + ( + input_ids.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size( + input_ids, micro_batch_size, attention_mask, position_ids, + labels) + assert len(micro_batches) == self.dp_size + return [ + self.ray_actors[index].train.remote( + input_ids=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + position_ids=micro_batch['position_ids'], + labels=micro_batch['labels'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + elif isinstance(input_ids, list): + """a list of tensors whose training loss will be taken average.""" + assert isinstance(input_ids[0], torch.Tensor) + micro_batch_size = [i for i in range(len(input_ids))] + for index, input_id in 
enumerate(input_ids): + micro_batch_size[index] = input_id.shape[0] // self.dp_size + ( + input_id.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_list_by_micro_batch_size( + input_ids=input_ids, + micro_batch_size=micro_batch_size, + labels=labels, + attention_mask=attention_mask, + position_ids=position_ids, + ) + assert len(micro_batches) == self.dp_size + object_refs = [] + for index, micro_batch in enumerate(micro_batches): + input_ids_mb = [] + attention_mask_mb = [] + position_ids_mb = [] + labels_mb = [] + for i in range(len(micro_batch)): + input_ids_mb.append(micro_batch[i]['input_ids']) + attention_mask_mb.append(micro_batch[i]['attention_mask']) + position_ids_mb.append(micro_batch[i]['position_ids']) + labels_mb.append(micro_batch[i]['labels']) + object_ref = self.ray_actors[index].train.remote( + input_ids=input_ids_mb, + attention_mask=attention_mask_mb, + position_ids=position_ids_mb, + labels=labels_mb, + *args, + **kwargs, + ) + object_refs.append(object_ref) + return object_refs + + def train_get(self, object_refs, timeout=None): + losses = ray.get(object_refs, timeout=timeout) + if isinstance(losses[0], list): + p_loss = [sub_loss[0] for sub_loss in losses] + pt_loss = [sub_loss[1] for sub_loss in losses] + return [sum(p_loss) / len(p_loss), sum(pt_loss) / len(pt_loss)] + else: + return sum(losses) / len(losses) + + def train(self, *args, **kwargs): + object_refs = self.train_async(*args, **kwargs) + return self.train_get(object_refs) + + # Inference + def infer_async(self, input_ids, attention_mask, *args, **kwargs): + micro_batch_size = input_ids.shape[0] // self.dp_size + ( + input_ids.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + assert len(micro_batches) == self.dp_size + return [ + self.ray_actors[index].infer.remote( + input_ids=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + + def infer_get(self, object_refs, timeout=None): + outputs = ray.get(object_refs, timeout=timeout) + return concat_policy_outputs(outputs) + + def infer(self, *args, **kwargs): + object_refs = self.infer_async(*args, **kwargs) + return self.infer_get(object_refs) + + # Generation + def generate_async(self, input_ids, attention_mask, *args, **kwargs): + micro_batch_size = input_ids.shape[0] // self.dp_size + ( + input_ids.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + assert len(micro_batches) == self.dp_size + return [ + self.ray_actors[index].generate.remote( + input_ids=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + + def generate_get(self, object_refs, timeout=None): + outputs = ray.get(object_refs, timeout=timeout) + padding_token_map = { + 'output_ids': self.config.tokenizer_config.pad_token_id + } + return concat_policy_outputs(outputs, padding_token_map) + + def generate(self, *args, **kwargs): + object_refs = self.generate_async(*args, **kwargs) + return self.generate_get(object_refs) + + # Others + def get_model(self): + return self.ray_actors[0].get_model.remote() + + def get_state_dict(self): + state_dicts = [ + actor.get_state_dict.remote() for actor in 
self.ray_actors + ] + return state_dicts[0] + + def set_seed(self, seed=None): + ray.get([actor.set_seed.remote(seed) for actor in self.ray_actors]) + + def release_resources(self): + """release ray resources.""" + if self.released: + return + for actor in self.ray_actors: + try: + ray.kill(actor=actor, no_restart=True) + except BaseException as exp: + logger.error(f'failed to kill ray actor {actor}. {exp}') + remove_placement_group(self.placement_group) + self.released = True + + def save(self, path): + ray.get([actor.save.remote(path) for actor in self.ray_actors]) + + def init_process_group(self, generator): + refs = [ + hfm.init_process_group.remote(generator) + for i, hfm in enumerate(self.ray_actors) + ] + ray.get(refs) + + def broadcast_model_to_generator(self, generator: None): + refs = [ + hfm.broadcast_model_to_generator.remote(generator) + for i, hfm in enumerate(self.ray_actors) + ] + ray.get(refs) diff --git a/xtuner/rlhf/model_backend/net_utils.py b/xtuner/rlhf/model_backend/net_utils.py new file mode 100644 index 000000000..7fc715836 --- /dev/null +++ b/xtuner/rlhf/model_backend/net_utils.py @@ -0,0 +1,31 @@ +import socket + + +def get_ip(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + try: + s.connect(('10.254.254.254', 1)) + local_ip = s.getsockname()[0] + except BaseException: + local_ip = '127.0.0.1' + finally: + s.close() + return local_ip + + +def get_ip_hostname(): + hostname = socket.gethostname() + return get_ip(), hostname + + +def get_free_port() -> int: + """Get a free port for the actor to use for DDP dist_init. + + Returns: A free port that could be used. + """ + tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + tcp.bind(('', 0)) + _, port = tcp.getsockname() + tcp.close() + return port diff --git a/xtuner/rlhf/model_backend/ray_actor_group.py b/xtuner/rlhf/model_backend/ray_actor_group.py new file mode 100644 index 000000000..a2da48a7e --- /dev/null +++ b/xtuner/rlhf/model_backend/ray_actor_group.py @@ -0,0 +1,19 @@ +import ray + +from .cuda_memory_stats import merge_cuda_memory_stats_list +from .ray_actor_mixin import RayActorMixin + + +class RayActorGroup: + + def __init__(self, name: str, config: dict): + self.config = config + self.name = name # name_prefix for ray_actors + self.ray_actors: list[RayActorMixin] = [] + + def get_cuda_mem_stats(self): + return merge_cuda_memory_stats_list( + ray.get([ + ray_actor.get_memory_stats_of_visible_devices.remote() + for ray_actor in self.ray_actors + ])) diff --git a/xtuner/rlhf/model_backend/ray_actor_mixin.py b/xtuner/rlhf/model_backend/ray_actor_mixin.py new file mode 100644 index 000000000..c075fc13e --- /dev/null +++ b/xtuner/rlhf/model_backend/ray_actor_mixin.py @@ -0,0 +1,92 @@ +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch + +from .cuda_memory_stats import CudaMemoryStats +from .net_utils import get_free_port, get_ip, get_ip_hostname + + +@dataclass +class RayActorMetadata: + """Metadata for Ray actor. + + This information is expected to stay the same throughout the lifetime of actor. # noqa: E501 + + Args: + node_ip (str): Node IP address that this actor is on. + hostname (str): Hostname that this actor is on. + gpu_ids (Optional[list[int]]): List of CUDA IDs available to this actor. # noqa: E501 + gpu_num (int): Number of used GPUs of this actor. 
+ """ + + node_ip: str + hostname: str + gpu_ids: Optional[list[int]] + gpu_num: int + + def __str__(self) -> str: + info = { + 'Node_IP': self.node_ip, + 'Hostname': self.hostname, + 'GPU_IDs': self.gpu_ids, + 'GPU_Num': self.gpu_num, + } + return json.dumps(info, indent=4, sort_keys=True) + + +class RayActorMixin: + + def inject_distribute_env( + self, + master_ip: Optional[str] = None, + master_port: int = 0, + rank_id: int = 0, + world_size: int = 0, + ) -> None: + """Inject Environment Variables before training. + + Args: + master_ip (Optional[str]): The ip address of the master node. + master_port (int): The port on the master node used for dist_init. + rank_id (int): The rank id of this actor. + world_size (int): Number of Actors for DDP training. + """ + os.environ['MASTER_ADDR'] = master_ip + os.environ['MASTER_PORT'] = str(master_port) + os.environ['RANK'] = str(rank_id) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = '0' + + def get_metadata(self) -> RayActorMetadata: + node_ip, hostname = get_ip_hostname() + gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'] + gpu_num = torch.cuda.device_count() + + return RayActorMetadata( + node_ip=node_ip, + hostname=hostname, + gpu_ids=gpu_ids, + gpu_num=gpu_num, + ) + + def get_free_port(self): + return get_free_port() + + def get_memory_stats_of_visible_devices(self) -> CudaMemoryStats: + visible_gpu_ids = [] + if 'CUDA_VISIBLE_DEVICES' in os.environ: + visible_gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'].split(',') + else: + visible_gpu_ids = [ + str(index) for index in range(torch.cuda.device_count()) + ] + + cuda_memory_stats = CudaMemoryStats() + for index, gpu_id in enumerate(visible_gpu_ids): + status = torch.cuda.memory_stats(device=index) + node_ip = get_ip() + cuda_memory_stats[f'ip{node_ip}-gpu{gpu_id}'] = status + return cuda_memory_stats diff --git a/xtuner/rlhf/model_backend/ray_utils.py b/xtuner/rlhf/model_backend/ray_utils.py new file mode 100644 index 000000000..c7dc4d1f2 --- /dev/null +++ b/xtuner/rlhf/model_backend/ray_utils.py @@ -0,0 +1,36 @@ +import uuid +from typing import TypeVar + +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +DEFAULT_NUM_CPUS = 1 +DEFAULT_NUM_GPUS = 1 +T = TypeVar('T') +UUID = uuid.uuid4() # may called multiple times in different ray instances + + +# Create Ray Actors +def create_ray_actors( + name_prefix: str, + config: dict, + placement_group: PlacementGroup, + trainer_class: T, +) -> list[T]: + ray_actors = [_ for _ in range(placement_group.bundle_count)] + for index in range(placement_group.bundle_count): + ray_actors[index] = trainer_class.options( + name=f'{name_prefix}_rank_{index}', + namespace=f'{UUID}_{trainer_class.__class__.__name__}', + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=index, + ), + runtime_env=set_runtime_env(), + ).remote(config) + return ray_actors + + +def set_runtime_env(): + runtime_env = {'env_vars': {'HF_ENDPOINT': 'https://hf-mirror.com'}} + return runtime_env diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py new file mode 100644 index 000000000..70f7c76e9 --- /dev/null +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -0,0 +1,339 @@ +import os +from typing import Optional, Union + +import ray +import torch +from loguru import logger +from ray.util.placement_group import placement_group as create_placement_group +from 
ray.util.placement_group import remove_placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from vllm import LLM, SamplingParams +from vllm.sampling_params import _SAMPLING_EPS + +from ..config.config_utils import get_dp_size, get_tp_size +from ..policy_output import PolicyOutput, concat_policy_outputs +from .generate_utils import (get_question_answer_mask, + partition_by_micro_batch_size) +from .ray_actor_group import RayActorGroup +from .ray_actor_mixin import RayActorMixin +from .ray_utils import DEFAULT_NUM_CPUS, DEFAULT_NUM_GPUS, set_runtime_env + +VLLM_DEFAULT_DEVICE = 'cuda' + + +class VllmGenerator: + + def __init__(self, model_config) -> None: + self.model_config: dict = model_config + + # Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 + def initialize(self) -> None: + model_path = self.model_config.get('model_path') + torch_dtype = self.model_config.get('torch_dtype', 'auto') + tokenizer_path = self.model_config.get('tokenizer_path', model_path) + parallel: dict = self.model_config.get('parallel') + tensor_parallel_size = 1 if parallel is None else parallel['tensor'][ + 'size'] + + import vllm + + if '0.2.7' <= vllm.__version__ <= '0.3.3' and tensor_parallel_size != 1: # noqa: E501 + # NOTE: In 0.2.7, vLLM made a major change to its architecture which move one worker into the driver process. # noqa: E501 + # Driver process will manually set CUDA_VISIBLE_DEVICES before worker init. To avoid importing torch before # noqa: E501 + # set CUDA_VISIBLE_DEVICES, we must defer monkey patch. + # For more detail, see: https://github.com/vllm-project/vllm/pull/2221 # noqa: E501 + def _set_cuda_visible_devices(device_ids: list[int]): + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join( + map(str, device_ids)) + from vllm.worker import worker + + from .vllm_worker_wrap import VllmWorkerWrap + + worker.Worker = VllmWorkerWrap + + vllm.engine.llm_engine.set_cuda_visible_devices = _set_cuda_visible_devices # noqa: E501 + else: + from vllm.worker import worker + + from .vllm_worker_wrap import VllmWorkerWrap + + worker.Worker = VllmWorkerWrap + + self.llm: LLM = vllm.LLM( + model=model_path, + tokenizer=tokenizer_path, + trust_remote_code=True, + dtype=torch_dtype, + swap_space=0, + tensor_parallel_size=tensor_parallel_size, + device=VLLM_DEFAULT_DEVICE, + ) + self.tokenizer = self.llm.get_tokenizer() + tokenizer_config = self.model_config.get('tokenizer_config', {}) + for key, value in tokenizer_config.items(): + setattr(self.tokenizer, key, value) + + @staticmethod + def get_sampling_params_from_dict(generate_kwargs: dict) -> SamplingParams: + sp = SamplingParams() + for k, v in generate_kwargs.items(): + if k in sp.__dict__: + sp.__dict__[k] = v + elif k == 'num_beams' and v > 1: + sp.__dict__['use_beam_search'] = True + elif k == 'eos_token_id': + sp.__dict__['stop_token_ids'] = [v] + + sp.top_k = -1 if sp.top_k <= 1 else sp.top_k + sp._verify_args() + + if sp.use_beam_search: + sp._verify_beam_search() + else: + sp.early_stopping = False + sp._verify_non_beam_search() + if sp.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. 
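+                # Editor's note: with greedy sampling the nucleus / top-k /
+                # min-p filters are irrelevant, so they are reset below to
+                # their no-op values (top_p=1.0, top_k=-1, min_p=0.0) before
+                # the final _verify_greedy_sampling() check runs.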
+ sp.top_p = 1.0 + sp.top_k = -1 + sp.min_p = 0.0 + sp._verify_greedy_sampling() + return sp + + def generate( + self, + inputs: Union[torch.Tensor, str, list[str]], + max_inputs_length: int, + step=-1, + output_str=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + generate_kwargs: Optional[dict] = {}, + **_ignored, + ) -> list[tuple[list[int], str]]: + sp = VllmGenerator.get_sampling_params_from_dict(generate_kwargs) + sp.max_tokens = step if step > 0 else None + logger.info( + f'[{self.__class__.__name__}] self.generate() SamplingParams: {sp}' + ) + + if isinstance(inputs, torch.Tensor): + if len(inputs.shape) == 2: # e.g., [batch_size, seq_len] + prompt = self.tokenizer.batch_decode( + inputs, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + elif len(inputs.shape) == 1: # e.g., [seq_len] + prompt = self.tokenizer.decode( + inputs, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + else: + raise ValueError( + f'Unsupported tensor inputs of shape({inputs.shape})') + + elif isinstance(inputs, str): + prompt = inputs # str + elif isinstance(inputs, list): + if isinstance(inputs[0], list): + prompt = inputs # list[int] + else: + raise ValueError( + f'Unsupported inputs[0] with type({type(inputs[0])})') + else: + raise ValueError(f'Unsupported inputs with type({type(inputs)})') + + # Calling vllm's generate + req_outputs = self.llm.generate( + prompt_token_ids=prompt, sampling_params=sp) + + def pad_list_with_pad_token(int_list, max_length, pad_token_id): + if len(int_list) < max_length: + num_pad_token_to_add = max_length - len(int_list) + padded_list = [pad_token_id] * num_pad_token_to_add + int_list + return padded_list + else: + return int_list + + policy_outputs = [] + for _, req_output in enumerate(req_outputs): + output = PolicyOutput() + input_ids = [item for item in req_output.prompt_token_ids] + input_ids = pad_list_with_pad_token(input_ids, max_inputs_length, + self.tokenizer.pad_token_id) + output_token_ids = [ + item for item in req_output.outputs[0].token_ids + ] + output_ids = input_ids + output_token_ids # concat + output['input_ids'] = torch.Tensor(input_ids).to( + torch.long).unsqueeze(0) + output['output_ids'] = torch.tensor(output_ids).to( + torch.long).unsqueeze(0) + + output['question_mask'], output[ + 'answer_mask'] = get_question_answer_mask( + output['input_ids'], + output['output_ids'], + tokenizer_pad_token_id=self.tokenizer.pad_token_id, + generate_pad_token_id=generate_kwargs.get('pad_token_id'), + ) + output[ + 'attention_mask'] = output.question_mask + output.answer_mask # noqa: E501 + output['action_mask'] = output[ + 'attention_mask'][:, max_inputs_length - 1:-1] + if output_logits: + raise NotImplementedError('TODO: output_logits') + if output_attentions: + raise NotImplementedError('TODO: output_attentions') + if output_hidden_states: + raise NotImplementedError('TODO: output_hidden_states') + if output_str: # return list[str] + output['output_ans_str'] = [req_output.outputs[0].text] + output_str = self.tokenizer.decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output['output_str'] = [output_str] + output.to('cpu') + + policy_outputs.append(output) + + padding_token_map = {'output_ids': self.tokenizer.pad_token_id} + concated_policy_out = concat_policy_outputs(policy_outputs, + padding_token_map) + return concated_policy_out + + +class VllmGeneratorRayActor(VllmGenerator, RayActorMixin): + + # Adapted from 
https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 + def init_process_group(self, master_address, master_port, rank_offset, + world_size, group_name): + return self.llm.llm_engine._run_workers( + 'init_process_group', + master_address, + master_port, + rank_offset, + world_size, + group_name, + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + return self.llm.llm_engine._run_workers('update_weight', name, dtype, + shape, empty_cache) + + +class VllmGeneratorRayActorGroup(RayActorGroup): + + def __init__(self, name: str, config: dict): + import uuid + self.released = True + self.config = config + self.tp_size = get_tp_size(config) # tensor parallelism + self.dp_size = get_dp_size(config) # num of vllm_engines + self.tokenizer_pad_token_id = config.tokenizer_config['pad_token_id'] + self.ray_actors: list[VllmGeneratorRayActor] = [] # i.e., vllm_engines + + # Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 + for dp_i in range(self.dp_size): + ray_actor_num_gpus = int(self.tp_size == 1) + scheduling_strategy = None + + if self.tp_size > 1: + bundles = [{ + 'CPU': DEFAULT_NUM_CPUS, + 'GPU': DEFAULT_NUM_GPUS + }] * self.tp_size + self.placement_group = create_placement_group(bundles) + ray.get(self.placement_group.ready()) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, + ) + + namespace = f'{uuid.uuid4()}_{VllmGeneratorRayActor.__class__.__name__}' # noqa: E501 + self.ray_actors.append( + ray.remote(VllmGeneratorRayActor).options( + name=f'{name}_rank_{dp_i}', + namespace=namespace, + num_cpus=1, + num_gpus=ray_actor_num_gpus, + scheduling_strategy=scheduling_strategy, + runtime_env=set_runtime_env(), + ).remote(config)) + + self.released = False + self.initialize_ref = [ + actor.initialize.remote() for actor in self.ray_actors + ] + + def initialize_get(self): + shared_with_trainer = self.config.get('shared_with_trainer', False) + if shared_with_trainer: + assert self.initialize_ref is None + return # assuming trainer.initialize_get() has been called + if self.initialize_ref is not None: + ray.get(self.initialize_ref) + else: + logger.warning( + 'self.initialize_ref is None when calling initialize_get()') + self.initialize_ref = None + + # Generation + def generate_async(self, input_ids, attention_mask, *args, **kwargs): + assert ( + len(input_ids) >= self.dp_size + ), f'The length of input_ids({len(input_ids)}) must not be less than dp_size({self.dp_size}).' 
# noqa: E501 + micro_batch_size = len(input_ids) // self.dp_size + ( + len(input_ids) % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + assert len(micro_batches + ) == self.dp_size, f'{len(micro_batches)}, :{self.dp_size}' + return [ + self.ray_actors[index].generate.remote( + inputs=micro_batch['input_ids'], + max_inputs_length=micro_batch['max_inputs_length'], + attention_mask=micro_batch['attention_mask'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + + def generate_get(self, object_refs, timeout=None): + outputs = ray.get(object_refs, timeout=timeout) + padding_token_map = { + 'output_ids': self.config.tokenizer_config['pad_token_id'] + } + return concat_policy_outputs(outputs, padding_token_map) + + def generate(self, *args, **kwargs): + object_refs = self.generate_async(*args, **kwargs) + return self.generate_get(object_refs) + + # Others + def get_model(self): + return self.ray_actors[0].get_model.remote() + + def set_seed(self, seed=None): + ray.get([actor.set_seed.remote(seed) for actor in self.ray_actors]) + + def release_resources(self): + """release ray resources.""" + if self.released: + return + for actor in self.ray_actors: + try: + ray.kill(actor=actor, no_restart=True) + except BaseException as exp: + logger.error(f'failed to kill ray actor {actor}. {exp}') + remove_placement_group(self.placement_group) + self.released = True diff --git a/xtuner/rlhf/model_backend/vllm_worker_wrap.py b/xtuner/rlhf/model_backend/vllm_worker_wrap.py new file mode 100644 index 000000000..daef742e4 --- /dev/null +++ b/xtuner/rlhf/model_backend/vllm_worker_wrap.py @@ -0,0 +1,77 @@ +# Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_worker_wrap.py # noqa: E501 +import importlib + +import torch +from vllm.model_executor.weight_utils import hf_model_weights_iterator +from vllm.worker.worker import Worker + +from ..logger import init_logger +from .dist_utils import init_process_group + +logger = init_logger(__name__) + + +def _hf_model_weights_iterator_wrap(model_name_or_path, *args, **kwargs): + if isinstance(model_name_or_path, dict): + yield from model_name_or_path.items() + else: + yield from hf_model_weights_iterator(model_name_or_path, *args, + **kwargs) + + +class VllmWorkerWrap(Worker): + + def __init__(self, *args, **kwargs): + # Monkey patch hf_model_weights_iterator to allow update single weight + # NOTE: In 0.2.5, vLLM introduce lazy model loader + # https://github.com/vllm-project/vllm/pull/2044 + from vllm.model_executor.models import _MODELS, ModelRegistry + + load_model_cls = ModelRegistry.load_model_cls + + def patched_load_model_cls(model_arch: str): + module_name, _ = _MODELS[model_arch] + module = importlib.import_module( + f'vllm.model_executor.models.{module_name}') + module.hf_model_weights_iterator = _hf_model_weights_iterator_wrap + logger.info( + f'Monkey patch hf_model_weights_iterator for module {module_name}' # noqa: E501 + ) + + return load_model_cls(model_arch) + + ModelRegistry.load_model_cls = patched_load_model_cls + + super().__init__(*args, **kwargs) + + def init_process_group(self, master_address, master_port, rank_offset, + world_size, group_name): + """Init torch process group for model weights update.""" + assert torch.distributed.is_initialized( + ), 'default torch process group must be initialized' + assert group_name != '', 'group name must not be empty' + + rank = 
torch.distributed.get_rank() + rank_offset + self._model_update_group = init_process_group( + backend='nccl', + init_method=f'tcp://{master_address}:{master_port}', + world_size=world_size, + rank=rank, + group_name=group_name, + ) + logger.info( + f'init_process_group: master_address={master_address}, master_port={master_port}, ' # noqa: E501 + f'rank={rank}, world_size={world_size}, group_name={group_name}') + + def update_weight(self, name, dtype, shape, empty_cache=False): + """Broadcast weight to all vllm workers from source rank 0 (policy + model)""" + if torch.distributed.get_rank() == 0: + logger.debug( + f'update weight: {name}, dtype: {dtype}, shape: {shape}') + + weight = torch.empty(shape, dtype=dtype, device='cuda') + torch.distributed.broadcast(weight, 0, group=self._model_update_group) + self.model_runner.model.load_weights(model_name_or_path={name: weight}) + + del weight diff --git a/xtuner/rlhf/model_server/__init__.py b/xtuner/rlhf/model_server/__init__.py new file mode 100644 index 000000000..d60547baa --- /dev/null +++ b/xtuner/rlhf/model_server/__init__.py @@ -0,0 +1,10 @@ +from .base_model_server import BaseModelServer +from .critic_model_server import CriticModelServer +from .policy_model_server import PolicyModelServer +from .ref_model_server import RefModelServer +from .reward_model_server import RewardModelServer + +__all__ = [ + 'BaseModelServer', 'PolicyModelServer', 'RefModelServer', + 'CriticModelServer', 'RewardModelServer' +] diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py new file mode 100644 index 000000000..63526233e --- /dev/null +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -0,0 +1,170 @@ +from typing import Optional + +import ray +import torch +from loguru import logger +from transformers import AutoModelForCausalLM + +from ..config.config_consts import ENGINE_HUGGINGFACE, ENGINE_INTERNEVO +from ..model_backend.hf_model_runner import HfModelRunnerRayActorGroup +from ..tokenizer import encode_inputs, get_tokenizer + +DEFAULT_GET_TIMEOUT = 600.0 # 10 min + + +class BaseModelServer: + # Initialize + def __init__(self, model_name: str, model_config: dict): + self.model_name = model_name + self.model_config = model_config + self.tokenizer = None + self.tokenizer_config = None + self.trainer = None + self.trainer_config = None + self.model_ref = None + self.is_initialized = False + self.show_cuda_mem_stats = self.model_config.get( + 'show_cuda_mem_stats', False) + logger.info(f'model_name={model_name}, model_config={model_config}') + + def init_tokenizer_and_config(self, model_config): + tokenizer_config = model_config.get('tokenizer_config', {}) + if 'tokenizer_path' in tokenizer_config: + tokenizer_path = tokenizer_config['tokenizer_path'] + elif 'tokenizer_path' in model_config: + tokenizer_path = model_config['tokenizer_path'] + else: + tokenizer_path = model_config['model_path'] + + self.tokenizer = get_tokenizer( + tokenizer_path, trust_remote_code=True, **tokenizer_config) + + tokenizer_config['tokenizer_path'] = tokenizer_path + tokenizer_config['pad_token_id'] = self.tokenizer.pad_token_id + self.tokenizer_config = tokenizer_config + + def init_trainer_config(self, model_config, tokenizer_config): + model_path = model_config['model_path'] + trainer_config: dict = model_config['trainer_config'] # requisite + trainer_config['tokenizer_config'] = tokenizer_config + trainer_config['tokenizer_path'] = tokenizer_config['tokenizer_path'] + trainer_config['model_path'] = model_path + 
trainer_config['model_type'] = model_config['model_type'] + trainer_config['model_class'] = self.get_model_class(model_path) + self.trainer_config = trainer_config + + def get_model_class(self, model_path): + # will be changed in subclasses + return AutoModelForCausalLM + + def initialize_async(self): + self.init_tokenizer_and_config(self.model_config) + self.init_trainer_config(self.model_config, self.tokenizer_config) + + trainer_type = self.trainer_config.get('trainer_type', + ENGINE_HUGGINGFACE).lower() + if trainer_type == ENGINE_HUGGINGFACE: + self.trainer = HfModelRunnerRayActorGroup( + name=f'{self.model_name}_trainer', config=self.trainer_config) + elif trainer_type == ENGINE_INTERNEVO: + raise NotImplementedError(f'{trainer_type}.') + else: + raise ValueError( + f'No trainer is registered with type: {trainer_type}') + + def initialize_get(self): + self.trainer.initialize_get() + self.is_initialized = True + logger.info(f'{self.model_name} has been initialized.') + + # Inference + def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): + if not isinstance(inputs, torch.Tensor): + input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) + else: + input_ids = inputs + return self.trainer.infer_async( + input_ids=input_ids, + attention_mask=attention_mask, + *args, + **infer_kwargs) + + def infer_get(self, object_refs, timeout: Optional[float] = None): + return self.trainer.infer_get(object_refs, timeout=timeout) + + def infer(self, inputs, *args, **infer_kwargs): + object_refs = self.infer_async(inputs, *args, **infer_kwargs) + results = self.infer_get(object_refs) + self.log_cuda_mem_stats(remark='[infer] ') + return results + + # Training + def train_async(self, + input_ids, + labels=None, + attention_mask=None, + position_ids=None, + *args, + **train_kwargs): + return self.trainer.train_async(input_ids, labels, attention_mask, + position_ids, *args, **train_kwargs) + + def train_get(self, object_refs, timeout: Optional[float] = None): + return self.trainer.train_get(object_refs, timeout=timeout) + + def train(self, + input_ids, + labels=None, + attention_mask=None, + position_ids=None, + *args, + **train_kwargs): + object_refs = self.train_async(input_ids, labels, attention_mask, + position_ids, *args, **train_kwargs) + loss = self.train_get(object_refs) + self.log_cuda_mem_stats(remark='[train] ') + return loss + + # Generation + def generate_async(self, + inputs, + attention_mask=None, + *args, + **generate_kwargs): + raise NotImplementedError + + def generate_get(self, object_refs, timeout: Optional[float] = None): + raise NotImplementedError + + def generate(self, inputs, *args, **generate_kwargs): + raise NotImplementedError + + # Model + def model_get(self): + if not self.model_ref: + self.model_ref = self.trainer.get_model() # an reference + return ray.get(self.model_ref, timeout=DEFAULT_GET_TIMEOUT) + + def state_dict_get(self): + return ray.get( + self.trainer.get_state_dict(), timeout=DEFAULT_GET_TIMEOUT) + + def save(self, path): + self.trainer.save(path) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(path) + + # Misc. 
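+    # Editor's note: a hedged usage sketch of the *_async / *_get pairs
+    # defined above; `server`, `inputs` and `mask` are placeholder names,
+    # not identifiers from this file:
+    #
+    #     refs = server.infer_async(inputs, attention_mask=mask)
+    #     ...  # overlap other work while Ray executes the call
+    #     out = server.infer_get(refs)
+    #
+    # infer() and train() are thin wrappers that issue the async call,
+    # immediately block on the matching *_get(), and then log CUDA memory
+    # via log_cuda_mem_stats().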
+ def set_seed(self, seed: int = None): + self.trainer.set_seed(seed) + + def log_cuda_mem_stats(self, remark=''): + if self.show_cuda_mem_stats: + trainer_mem = self.trainer.get_cuda_mem_stats() + logger.info( + f'{remark}{self.model_name} trainer allocated GPU memory: {trainer_mem.total_current_mb} MiB' # noqa: E501 + ) + + def clean_up(self): + self.trainer.release_resources() + logger.info(f'{self.model_name} is destroyed.') diff --git a/xtuner/rlhf/model_server/critic_model_server.py b/xtuner/rlhf/model_server/critic_model_server.py new file mode 100644 index 000000000..ee35aa829 --- /dev/null +++ b/xtuner/rlhf/model_server/critic_model_server.py @@ -0,0 +1,9 @@ +from .base_model_server import BaseModelServer +from .utils import get_critic_model + + +class CriticModelServer(BaseModelServer): + # Initialize + def get_model_class(self, model_path): + head_name = self.model_config.get('head_name', 'v_head') + return get_critic_model(model_path, head_name) diff --git a/xtuner/rlhf/model_server/policy_model_server.py b/xtuner/rlhf/model_server/policy_model_server.py new file mode 100644 index 000000000..bbe819347 --- /dev/null +++ b/xtuner/rlhf/model_server/policy_model_server.py @@ -0,0 +1,99 @@ +from typing import Optional + +import torch +from loguru import logger + +from ..config.config_consts import ENGINE_VLLM +from ..tokenizer import encode_inputs +from .base_model_server import BaseModelServer + + +class PolicyModelServer(BaseModelServer): + # Initialize + def initialize_async(self): + super().initialize_async() + + self.generator_eq_trainer = True + # use trainer for self.generate() by default + self.generator = self.trainer + if 'generator_config' not in self.model_config: + return # self.generator = self.trainer + + generator_config = self.model_config['generator_config'] # optional + if generator_config.get('shared_with_trainer', True): + return # self.generator = self.trainer + + generator_config['model_path'] = self.model_config['model_path'] + generator_config['tokenizer_config'] = self.tokenizer_config + generator_config['tokenizer_path'] = self.tokenizer_config[ + 'tokenizer_path'] + generator_type = generator_config.get('generator_type', None) + if generator_type == ENGINE_VLLM: + from ..model_backend.vllm_model_runner import \ + VllmGeneratorRayActorGroup + self.generator = VllmGeneratorRayActorGroup( + f'{self.model_name}_generator', generator_config) + # to sync model among trainer and generator + self.trainer.initialize_get() + self.trainer.init_process_group(self.generator) + else: + raise ValueError( + f"No generator is registered with type '{generator_type}'") + self.generator_eq_trainer = False + + def initialize_get(self): + self.generator.initialize_get() + self.is_initialized = True + logger.info(f'{self.model_name} has been initialized. 
') + + # Generation + def generate_async(self, + inputs, + attention_mask=None, + *args, + **generate_kwargs): + if isinstance(inputs, torch.Tensor): + input_ids = inputs + elif isinstance(inputs, list): + if not self.generator_eq_trainer: + input_ids, attention_mask = encode_inputs( + inputs, + self.tokenizer, + return_tensors=None, + padding=False, + add_generation_prompt=True) + else: + input_ids, attention_mask = encode_inputs( + inputs, self.tokenizer, add_generation_prompt=True) + else: + raise NotImplementedError(f'unknown inputs: {inputs}') + + return self.generator.generate_async( + input_ids=input_ids, + attention_mask=attention_mask, + *args, + **generate_kwargs) + + def generate_get(self, object_refs, timeout: Optional[float] = None): + return self.generator.generate_get(object_refs, timeout=timeout) + + def generate(self, inputs, *args, **generate_kwargs): + object_refs = self.generate_async(inputs, *args, **generate_kwargs) + policy_output = self.generate_get(object_refs) + self.log_cuda_mem_stats(remark='[generate] ') + return policy_output + + # Sync + def sync_model(self, *args, **kwargs): + if not self.generator_eq_trainer: + self.trainer.broadcast_model_to_generator(self.generator) + + # Misc. + def log_cuda_mem_stats(self, remark=''): + if self.show_cuda_mem_stats: + trainer_mem = self.trainer.get_cuda_mem_stats() + generator_mem = self.generator.get_cuda_mem_stats() + logger.info( + f'{remark}{self.model_name} trainer allocated GPU memory: {trainer_mem.total_current_mb} MiB, ' # noqa: E501 + f'generator allocated GPU memory: {generator_mem.total_current_mb} MiB, ' # noqa: E501 + f'generator_eq_trainer: {self.generator_eq_trainer}') diff --git a/xtuner/rlhf/model_server/ref_model_server.py b/xtuner/rlhf/model_server/ref_model_server.py new file mode 100644 index 000000000..90b1dcce3 --- /dev/null +++ b/xtuner/rlhf/model_server/ref_model_server.py @@ -0,0 +1,5 @@ +from .base_model_server import BaseModelServer + + +class RefModelServer(BaseModelServer): + pass # same as BaseModelServer diff --git a/xtuner/rlhf/model_server/reward_model_server.py b/xtuner/rlhf/model_server/reward_model_server.py new file mode 100644 index 000000000..84e5e42af --- /dev/null +++ b/xtuner/rlhf/model_server/reward_model_server.py @@ -0,0 +1,45 @@ +import torch +from transformers import AutoConfig + +from ..tokenizer import encode_inputs +from ..utils import expand_reward_token_id +from .base_model_server import BaseModelServer +from .utils import get_reward_model + + +class RewardModelServer(BaseModelServer): + # Initialize + def get_model_class(self, model_path): + head_name = self.model_config.get('head_name', 'v_head') + return get_reward_model(model_path, head_name) + + def init_tokenizer_and_config(self, model_config): + super().init_tokenizer_and_config(self.model_config) + + # specify `reward_token_id`` to get scalar reward of a sequence + # according to the `Rward Model` training strategy, + # which is set to `pad_token_id` by default + self.reward_token_id = self.tokenizer.pad_token_id + model_path = model_config['model_path'] + auto_config = AutoConfig.from_pretrained( + model_path, trust_remote_code=True) + if hasattr(auto_config, 'reward_token_id'): + self.reward_token_id = auto_config.reward_token_id + + # Inference + def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): + if not isinstance(inputs, torch.Tensor): + input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) + else: + input_ids = inputs + + # Reward model specific + if 
self.reward_token_id is not None: + input_ids, attention_mask = expand_reward_token_id( + self.reward_token_id, input_ids, attention_mask) + + return self.trainer.infer_async( + input_ids=input_ids, + attention_mask=attention_mask, + *args, + **infer_kwargs) diff --git a/xtuner/rlhf/model_server/utils.py b/xtuner/rlhf/model_server/utils.py new file mode 100644 index 000000000..8180f2278 --- /dev/null +++ b/xtuner/rlhf/model_server/utils.py @@ -0,0 +1,111 @@ +# Adopted from https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/models/model.py#L134 # noqa: E501 +from typing import Optional + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from transformers.modeling_outputs import SequenceClassifierOutputWithPast + + +def _get_model_class(model_name_or_path: str): + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=True) + config_class = type(config) + if config_class in AutoModel._model_mapping: + model_class = AutoModel._model_mapping[type(config)] + model_base_class = model_class.__base__ + return model_class, model_base_class + + if 'AutoModel' in config.auto_map: + module_file, causal_model_name = config.auto_map['AutoModel'].split( + '.') + elif 'AutoModelForCausalLM' in config.auto_map: + module_file, causal_model_name = config.auto_map[ + 'AutoModelForCausalLM'].split('.') + else: + raise Exception( + f'config of {model_name_or_path} has no AutoModel or AutoModelForCausalLM in auto_map' # noqa: E501 + ) + + model_class_name = (causal_model_name.split('For')[0] + 'Model' + ) # e.g., "InternLM2Model" + model_class = get_class_from_dynamic_module( + f'{module_file}.{model_class_name}', model_name_or_path) + model_base_class_name = (causal_model_name.split('For')[0] + + 'PreTrainedModel' + ) # e.g., "InternLM2PreTrainedModel" + model_base_class = get_class_from_dynamic_module( + f'{module_file}.{model_base_class_name}', model_name_or_path) + return model_class, model_base_class + + +def get_critic_model(model_name_or_path: str, head_name): + model_class, model_base_class = _get_model_class(model_name_or_path) + + class CriticModel(model_base_class): + supports_gradient_checkpointing = True + + def __init__(self, config: AutoConfig): + super().__init__(config) + self.model = model_class(config) + self.head_name = head_name + setattr(self, head_name, + nn.Linear(config.hidden_size, 1, bias=False)) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **_ignored, + ) -> torch.Tensor: + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = outputs[0] + logits = getattr(self, + self.head_name)(hidden_states).squeeze(-1)[:, :-1] + + return SequenceClassifierOutputWithPast(logits=logits, ) + + return CriticModel + + +def get_reward_model(model_name_or_path: str, head_name): + model_class, model_base_class = _get_model_class(model_name_or_path) + + class RewardModel(model_base_class): + supports_gradient_checkpointing = True + + def __init__(self, config: AutoConfig): + super().__init__(config) + self.model = model_class(config) + self.head_name = head_name + setattr(self, head_name, + nn.Linear(config.hidden_size, 1, bias=False)) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = 
None, + **_ignored, + ) -> torch.Tensor: + eos_indices = ( + attention_mask.size(1) - 1 - + attention_mask.long().fliplr().argmax(dim=1, keepdim=True)) + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = outputs[0] + values = getattr(self, self.head_name)(hidden_states).squeeze(-1) + reward_scores = values.gather(dim=1, index=eos_indices).squeeze(1) + + return SequenceClassifierOutputWithPast(logits=reward_scores, ) + + return RewardModel diff --git a/xtuner/rlhf/policy_output.py b/xtuner/rlhf/policy_output.py new file mode 100644 index 000000000..af7129410 --- /dev/null +++ b/xtuner/rlhf/policy_output.py @@ -0,0 +1,154 @@ +# Adopted from: https://github.com/huggingface/transformers/blob/HEAD/src/transformers/generation/utils.py # noqa: E501 +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.utils.generic import ModelOutput + + +@dataclass +class PolicyOutput(ModelOutput): + output_ids: Optional[torch.Tensor] = None + output_str: Optional[list[str]] = None + loss: Optional[torch.Tensor] = None + logits: Optional[torch.Tensor] = None + attentions: Optional[torch.Tensor] = None + hidden_states: Optional[torch.Tensor] = None + logits_entropy: Optional[torch.Tensor] = None + logprobs: Optional[torch.Tensor] = None + top_logprobs: Optional[torch.Tensor] = None + question_mask: Optional[torch.Tensor] = None + answer_mask: Optional[torch.Tensor] = None + + def to(self, device): + for k, v in self.items(): + if isinstance(v, torch.Tensor): + self[k] = v.to(device) + + def get_tensor_keys(self): + keys = [] + for k, v in self.items(): + if isinstance(v, torch.Tensor): + keys.append(k) + return keys + + +def union_keys_from_policy_outputs(policy_outputs: list[PolicyOutput]) -> list: + all_keys = set() + for po in policy_outputs: + all_keys = all_keys.union(set(po.keys())) + # e.g., return ["output_str", "output_ids", "loss", ...] + return list(all_keys) + + +def union_tensor_keys_from_policy_outputs( + policy_outputs: list[PolicyOutput]) -> list: + all_keys = set() + for po in policy_outputs: + all_keys = all_keys.union(set(po.get_tensor_keys())) + # e.g., return ["output_ids", "loss", ...] 
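+    # Editor's note: the key order here is arbitrary (it comes from a set
+    # union), which is harmless because the callers only iterate over the
+    # keys: padding_policy_outputs() right-pads each listed tensor along the
+    # sequence dimension to the per-key maximum length, after which
+    # concat_policy_outputs() concatenates the padded tensors along dim 0.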
+ return list(all_keys) + + +def concat_policy_outputs(policy_outputs: list[PolicyOutput], + padding_token_map: dict = None) -> PolicyOutput: + if isinstance(policy_outputs, PolicyOutput): + # Wrong input type + return policy_outputs + elif policy_outputs is None or len(policy_outputs) == 0: + return PolicyOutput(None) + elif len(policy_outputs) == 1: + return policy_outputs[0] + + # padding + if padding_token_map is not None: + policy_outputs = padding_policy_outputs(policy_outputs, + padding_token_map) + + concated = PolicyOutput() + all_keys = union_keys_from_policy_outputs(policy_outputs) + for key in all_keys: + for po in policy_outputs: + value = po[key] + if value is not None: + # get the first non-empty value + break + if value is None: + # skip if all values are None + continue + + if isinstance(value, torch.Tensor): + concated[key] = torch.cat( + [po[key] for po in policy_outputs if po[key] is not None], + dim=0) + elif isinstance(value, list): + # e.g., list[str] + concated[key] = [] + for po in policy_outputs: + if po[key] is not None: + concated[key].extend(po[key]) + elif isinstance(value, tuple) and isinstance(value[0], torch.Tensor): + results = [] + for i in range(len(value)): + beef = [ + po[key][i] for po in policy_outputs + if po[key][i] is not None + ] + tensor = torch.cat( + beef, dim=0) if len(beef) > 0 else torch.Tensor() + results.append(tensor) + concated[key] = tuple(results) + raise NotImplementedError( + f'{value}\n{[v.shape for v in value]}\n{results}') + else: + raise TypeError( + f'value: {value} with unsupported type: {type(value)}.') + return concated + + +def padding_policy_outputs(policy_outputs: list[PolicyOutput], + padding_token_map={}, + right_padding=True, + padding_id=0): + tensor_keys = union_tensor_keys_from_policy_outputs(policy_outputs) + for key in tensor_keys: + padding_id = padding_token_map.get(key, padding_id) + max_seq_len = find_max_seq_len(policy_outputs, key) + for policy_output in policy_outputs: + origin_tensor = policy_output[key] + padding_size = max_seq_len - origin_tensor.shape[1] + pad = (0, padding_size) if right_padding else (padding_size, 0) + padded_tensor = torch.nn.functional.pad( + origin_tensor, pad, mode='constant', value=padding_id) + policy_output[key] = padded_tensor + return policy_outputs + + +def find_max_seq_len(policy_outputs: list[PolicyOutput], key): + max_seq_len = 0 + for policy_output in policy_outputs: + if policy_output[key] is None: + continue + batch_size, seq_len = policy_output[key].shape[:2] + max_seq_len = seq_len if seq_len > max_seq_len else max_seq_len + return max_seq_len + + +def logprobs_from_logits(logits: torch.Tensor, + labels: torch.Tensor, + gather: bool = True) -> torch.Tensor: + r""" + Adapted from: https://github.com/huggingface/trl/blob/main/trl/core.py#L131 + + Example: + + ```python + >>> logits, _ = model(**input_kwargs) + >>> input_ids = input_kwargs["input_ids"] + >>> logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + ```""" + logp = torch.nn.functional.log_softmax(logits, dim=2) + if not gather: + return logp + logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1) + return logpy diff --git a/xtuner/rlhf/repeaters/__init__.py b/xtuner/rlhf/repeaters/__init__.py new file mode 100644 index 000000000..14ca68a07 --- /dev/null +++ b/xtuner/rlhf/repeaters/__init__.py @@ -0,0 +1,4 @@ +from .base import RepeaterBase +from .kl_gae import KLGAERepeater + +__all__ = ['RepeaterBase', 'KLGAERepeater'] diff --git a/xtuner/rlhf/repeaters/base.py 
b/xtuner/rlhf/repeaters/base.py new file mode 100644 index 000000000..5a6e63054 --- /dev/null +++ b/xtuner/rlhf/repeaters/base.py @@ -0,0 +1,15 @@ +from ..policy_output import PolicyOutput + + +class RepeaterBase: + """`RepeaterBase` is the base class of different repeaters. + + `repeater` is responsible to deal with the trajectory data. + """ + + def __init__(self): + pass + + def process(self, trajectories: PolicyOutput, *args, **kwargs): + """define process, such as get GAEs.""" + raise NotImplementedError diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py new file mode 100644 index 000000000..5ab69611a --- /dev/null +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -0,0 +1,208 @@ +import torch + +from ..model_server.base_model_server import BaseModelServer +from ..policy_output import PolicyOutput +from ..timer import Timer +from .base import RepeaterBase +from .utils import RunningStates + + +class KLGAERepeater(RepeaterBase): + + def __init__( + self, + ref_model: BaseModelServer, + policy_model: BaseModelServer, + critic_model: BaseModelServer, + policy_micro_bs: int = 8, + ref_micro_bs: int = 8, + critic_micro_bs: int = 32, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min: int = -5, + clip_reward_max: int = 5, + norm_rewards=True, + norm_adv=False, + env=None, + **_ignored, + ): + # models + self.ref_model = ref_model + self.policy_model = policy_model + self.critic_model = critic_model + + self.policy_micro_bs = policy_micro_bs + self.ref_micro_bs = ref_micro_bs + self.critic_micro_bs = critic_micro_bs + self.kl_coeff = kl_coeff + self.gamma = gamma + self.gae_lambda = gae_lambda + # rewards + self.clip_reward_min = clip_reward_min + self.clip_reward_max = clip_reward_max + self.norm_rewards = norm_rewards + if self.norm_rewards: + self.running_states = RunningStates(epsilon=0) + self.norm_adv = norm_adv + + # only used for async reward model.infer_get() in _get_kl_rewards + self.env = env + + def process(self, trajectories: PolicyOutput): + critic_output_ref = self._get_values_async(trajectories) + action_mask = trajectories['action_mask'] + num_actions = action_mask.size(1) + (kl_rewards, entropy, kl_distance, policy_logprobs, + ref_logprobs) = self._get_kl_rewards(trajectories) + trajectories['kl'] = (kl_distance * action_mask).sum( + axis=-1) / action_mask.sum(axis=-1) + trajectories['entropy'] = entropy + trajectories['kl_rewards'] = kl_rewards + trajectories['policy_logprobs'] = policy_logprobs + trajectories['ref_logprobs'] = ref_logprobs + + values = self._get_values_collect(critic_output_ref) + old_values = values[:, -num_actions:] + advantages, returns = self.get_advantages_and_returns( + old_values, kl_rewards, action_mask) + if self.norm_adv: + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8) + trajectories['advantages'] = advantages + trajectories['returns'] = returns + trajectories['old_values'] = old_values + + return trajectories + + def _get_kl_rewards(self, trajectories: PolicyOutput): + with Timer('policy_model.infer_async'): + policy_output = self.policy_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.policy_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + with Timer('ref_model.infer_async'): + ref_output = self.ref_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.ref_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + with 
Timer('policy_model.infer_get'): + policy_output = self.policy_model.infer_get(policy_output) + with Timer('ref_model.infer_get'): + ref_output = self.ref_model.infer_get(ref_output) + + # Experimental + if self.env.async_reward: + rewards = self.env.get_reward_collect( + trajectories['reward_output_ref']) + trajectories['reward_output_ref'] = None + trajectories['rewards'] = rewards + # Experimental + + clipped_rewards = torch.clamp( + rewards, min=self.clip_reward_min, max=self.clip_reward_max) + trajectories['clipped_rewards'] = clipped_rewards + + if self.norm_rewards: + self.running_states.update(clipped_rewards) + norm_reward_score = (clipped_rewards - + self.running_states.mean) / ( + self.running_states.var.sqrt() + 1e-8) + action_mask = trajectories.action_mask + num_actions = action_mask.size(1) + + policy_logprobs = policy_output.logprobs[:, -num_actions:] + ref_logprobs = ref_output.logprobs[:, -num_actions:] + + if self.kl_coeff <= 0.0: + self.kl_coeff = 0.0 + # compute_approx_kl + log_ratio = policy_logprobs - ref_logprobs + kl = log_ratio * action_mask + kl_reward = -self.kl_coeff * kl + + eos_indices = action_mask.size( + 1) - 1 - action_mask.long().fliplr().argmax( + dim=1, keepdim=True) + last_reward = torch.zeros_like(kl).scatter_( + dim=1, + index=eos_indices, + src=norm_reward_score.unsqueeze(1).to(kl.dtype)) + + reward = last_reward + kl_reward + + entropy = -(policy_logprobs * + action_mask).sum(axis=-1) / action_mask.sum(axis=-1) + return reward, entropy, kl, policy_logprobs, ref_logprobs + + def _get_values(self, trajectories: PolicyOutput): + with Timer('critic_model.infer'): + critic_output = self.critic_model.infer( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + raw_values = critic_output.logits.squeeze(-1) + return raw_values + + def _get_values_async(self, trajectories: PolicyOutput): + with Timer('critic_model.infer_async'): + critic_output_ref = self.critic_model.infer_async( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + return critic_output_ref + + def _get_values_collect(self, critic_output_ref): + with Timer('critic_model.infer_get'): + critic_output = self.critic_model.infer_get(critic_output_ref) + raw_values = critic_output.logits.squeeze(-1) + return raw_values + + def get_advantages_and_returns( + self, + values: torch.Tensor, + rewards: torch.Tensor, + action_mask: torch.Tensor, + ): + # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 + """Function that computes advantages and returns from rewards and + values. Calculated as in the original PPO paper: + https://arxiv.org/abs/1707.06347 Note that rewards may include a KL + divergence loss term. + + Advantages looks like this: + Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Returns looks like this: + Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + """ + lastgaelam = 0 + advantages_reversed = [] + response_length = rewards.size(1) + + # Mask invalid responses + values = action_mask * values + rewards = action_mask * rewards + + for t in reversed(range(response_length)): + nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 + # Since old_rewards and old_values are masked with action_mask, + # i.e. 
they have 0's at pad tokens, + # delta will be 0 if current t is at a pad token, + # so will lastgaelam + delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.gamma * self.gae_lambda * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + returns = advantages + values + return advantages.detach(), returns diff --git a/xtuner/rlhf/repeaters/utils.py b/xtuner/rlhf/repeaters/utils.py new file mode 100644 index 000000000..e8b3e2763 --- /dev/null +++ b/xtuner/rlhf/repeaters/utils.py @@ -0,0 +1,38 @@ +import torch + + +class RunningStates: + # adopt from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/running_mean_std.py # noqa: E501 + def __init__(self, epsilon: float = 1e-4): + self.mean = torch.tensor(0, dtype=torch.float32) + self.var = torch.tensor(0, dtype=torch.float32) + self.count = epsilon + + def update(self, x: torch.Tensor): + x_var, x_mean = torch.var_mean(x.cpu(), unbiased=False) + x_count = x.shape[0] + self.update_from_moments(x_mean, x_var, x_count) + + def update_from_other(self, other: 'RunningStates'): + self.update_from_moments(other.mean, other.var, other.count) + + def update_from_moments(self, mean: torch.Tensor, var: torch.Tensor, + count: int): + delta = mean - self.mean + tot_count = self.count + count + m_a = self.var * self.count + m_b = var * count + m_2 = m_a + m_b + delta**2 * self.count * count / (self.count + count) + new_var = m_2 / (self.count + count) + + self.mean += delta * count / tot_count + self.var = new_var + self.count = tot_count + + def state_dict(self): + return dict(mean=self.mean, var=self.var, count=self.count) + + def load_state_dict(self, states): + self.mean = states['mean'] + self.var = states['var'] + self.count = states['count'] diff --git a/xtuner/rlhf/timer.py b/xtuner/rlhf/timer.py new file mode 100644 index 000000000..4574ca8c7 --- /dev/null +++ b/xtuner/rlhf/timer.py @@ -0,0 +1,27 @@ +import time + +from loguru import logger + + +class Timer: + """Timer.""" + + def __init__(self, task_name: str): + self.task_name = task_name + self.duration = 0 + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end() + + def start(self): + logger.info(f'Start {self.task_name}') + self.start = time.time() + + def end(self): + self.duration = time.time() - self.start + logger.info( + f' End {self.task_name}, duration = {self.duration:.2f} seconds') diff --git a/xtuner/rlhf/tokenizer/__init__.py b/xtuner/rlhf/tokenizer/__init__.py new file mode 100644 index 000000000..e59b36fff --- /dev/null +++ b/xtuner/rlhf/tokenizer/__init__.py @@ -0,0 +1,3 @@ +from .tokenizer_utils import encode_inputs, get_tokenizer + +__all__ = ['get_tokenizer', 'encode_inputs'] diff --git a/xtuner/rlhf/tokenizer/tokenizer_utils.py b/xtuner/rlhf/tokenizer/tokenizer_utils.py new file mode 100644 index 000000000..2a1539aa7 --- /dev/null +++ b/xtuner/rlhf/tokenizer/tokenizer_utils.py @@ -0,0 +1,72 @@ +from typing import Optional, Union + +from loguru import logger +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +PADDING_SIDE = 'left' + + +def get_tokenizer( + tokenizer_name: str, + *args, + trust_remote_code: bool = False, + tokenizer_revision: Optional[str] = None, + padding_side: Optional[str] = PADDING_SIDE, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface.""" + + 
try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + padding_side=padding_side, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + 'does not exist or is not currently imported.' in str(e) + or 'requires you to execute the tokenizer file' in str(e)): + err_msg = 'Failed to load the tokenizer. Try `trust_remote_code=True`.' # noqa: E501 + raise RuntimeError(err_msg) from e + else: + raise e + except AttributeError as e: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + 'Using a slow tokenizer. This might cause a significant ' + 'slowdown. Consider using a fast tokenizer instead.') + for key, value in kwargs.items(): + setattr(tokenizer, key, value) + return tokenizer + + +def encode_inputs( + inputs: Union[list[str], list[list[dict]]], + tokenizer, + return_tensors='pt', + padding=True, + add_generation_prompt: bool = False, +): + if isinstance(inputs[0], list): + inputs = [ + tokenizer.apply_chat_template( + input, + tokenize=False, + add_generation_prompt=add_generation_prompt, + return_tensors=return_tensors, + ) for input in inputs + ] + output = tokenizer( + inputs, + return_tensors=return_tensors, + padding=padding, + add_special_tokens=False) + return output.input_ids, output.attention_mask diff --git a/xtuner/rlhf/trainer/__init__.py b/xtuner/rlhf/trainer/__init__.py new file mode 100644 index 000000000..855182fb3 --- /dev/null +++ b/xtuner/rlhf/trainer/__init__.py @@ -0,0 +1,3 @@ +from .ppo import PPOTrainer + +__all__ = ['PPOTrainer'] diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py new file mode 100644 index 000000000..655c91f94 --- /dev/null +++ b/xtuner/rlhf/trainer/ppo.py @@ -0,0 +1,195 @@ +from loguru import logger + +from ..loss import CriticLoss, PPOPolicyLoss, PretrainLoss +from ..model_server.base_model_server import BaseModelServer +from ..timer import Timer + + +class PPOTrainer: + + def __init__( + self, + policy_model: BaseModelServer, + critic_model: BaseModelServer, + policy_micro_bs=2, + critic_micro_bs=2, + policy_learn_time=1, + critic_learn_time=1, + policy_minibatch=None, + critic_minibatch=None, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_criterion=PretrainLoss(label_smoothing=0), + policy_criterion=PPOPolicyLoss(cliprange=0.2), + critic_criterion=CriticLoss(cliprange_value=0.5), + **kwargs, + ): + + # policy + self.policy_model = policy_model + self.policy_learn_time = policy_learn_time + self.policy_minibatch = policy_minibatch + self.policy_micro_bs = policy_micro_bs + + self.ppo_loss_weight = ppo_loss_weight + self.pretrain_loss_weight = pretrain_loss_weight + self.pretrain_criterion = pretrain_criterion + self.policy_criterion = policy_criterion + + # critic + self.critic_model = critic_model + self.critic_learn_time = critic_learn_time + self.critic_minibatch = critic_minibatch + self.critic_micro_bs = critic_micro_bs + + self.critic_criterion = critic_criterion + + def policy_learn(self, trajectories): + if self.policy_minibatch is None: + self.policy_minibatch = len(trajectories.output_ids) + assert len(trajectories.output_ids) % self.policy_minibatch == 0 + policy_updates = len(trajectories.output_ids) // self.policy_minibatch + ppo_loss = [] + pretrain_loss = [] + + for _ in range(self.policy_learn_time): 
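+ # Each PPO epoch walks the rollout buffer in policy_minibatch-sized slices: + # the PPO loss on prompt data is weighted by ppo_loss_weight and, when + # pretrain_data is attached, a pretraining batch is trained jointly with + # weight pretrain_loss_weight.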
+ for i in range(policy_updates): + logger.info( + '[Policy Train] start policy training {}/{} | {}'.format( + i + 1, policy_updates, _ + 1)) + # prompt train data + begin = i * self.policy_minibatch + end = begin + self.policy_minibatch + + train_input_ids = [trajectories.output_ids[begin:end, :]] + train_attention_mask = [ + trajectories.attention_mask[begin:end, :] + ] + train_criterion = [self.policy_criterion] + loss_weights = [self.ppo_loss_weight] + micro_batch_size = [self.policy_micro_bs] + + train_labels = [ + dict( + input_ids=trajectories.output_ids[begin:end, :], + old_logprobs=trajectories.policy_logprobs[ + begin:end, :], + advantages=trajectories.advantages[begin:end, :], + mask=trajectories.action_mask[begin:end, :], + ), + ] + # pretrain data + if trajectories.pretrain_data is not None: + logger.info( + '[Policy Train] pretrain data ' + f'{trajectories.pretrain_data["input_ids"].shape}') + train_input_ids.append( + trajectories.pretrain_data['input_ids']) + train_labels.append(trajectories.pretrain_data['labels']) + # train_position_ids.append(trajectories.pretrain_data["position_ids"]) + train_attention_mask.append( + trajectories.pretrain_data['attention_mask']) + train_criterion.append(self.pretrain_criterion) + loss_weights.append(self.pretrain_loss_weight) + micro_batch_size.append(self.policy_micro_bs) + + with Timer('policy_model.train'): + p_loss = self.policy_model.train( + input_ids=train_input_ids, + labels=train_labels, + attention_mask=train_attention_mask, + # position_ids=train_position_ids, + criterion=train_criterion, + loss_weights=loss_weights, + micro_batch_size=micro_batch_size) + if isinstance(p_loss, list): + ppo_loss.append(p_loss[0].item()) + pretrain_loss.append(p_loss[1].item()) + logger.info( + f'[Policy Train] prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss[0].item()}; pretrain data: {train_input_ids[1].shape}, pretrain loss: {p_loss[1].item()}' # noqa: E501 + ) + else: + ppo_loss.append(p_loss.item()) + logger.info( + f'[Policy Train] prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss.item()}' # noqa: E501 + ) + + with Timer('policy_model.sync_model'): + self.policy_model.sync_model() + return ppo_loss, pretrain_loss + + def critic_learn(self, trajectories): + if self.critic_minibatch is None: + self.critic_minibatch = len(trajectories.output_ids) + assert len(trajectories.output_ids) % self.critic_minibatch == 0 + critic_updates = len(trajectories.output_ids) // self.critic_minibatch + critic_loss = [] + + for learn_i in range(self.critic_learn_time): + for step_i in range(critic_updates): + logger.info( + '[Critic Train] start critic training {}/{} | {}'.format( + step_i + 1, critic_updates, learn_i + 1)) + with Timer('critic_model.train'): + critic_batch_inputs, labels = self._critic_learn_prepare( + step_i, learn_i, trajectories, critic_updates) + v_loss = self.critic_model.train( + input_ids=critic_batch_inputs['input_ids'], + labels=labels, + attention_mask=critic_batch_inputs['attention_mask'], + criterion=self.critic_criterion, + micro_batch_size=self.critic_micro_bs, + ) + logger.info(f'[Critic Train] {self.critic_minibatch} batch, ' + f'critic loss: {v_loss.item()}') + critic_loss.append(v_loss.item()) + return critic_loss + + def _critic_learn_prepare(self, step_i, learn_i, trajectories, + critic_updates): + logger.info('[Critic Train] start critic training {}/{} | {}'.format( + step_i + 1, critic_updates, learn_i + 1)) + begin = step_i * self.critic_minibatch + end = begin + self.critic_minibatch + critic_batch_inputs
= dict( + input_ids=trajectories.output_ids[begin:end, :], + old_values=trajectories.old_values[begin:end, :], + returns=trajectories.returns[begin:end, :], + action_mask=trajectories.action_mask[begin:end, :], + attention_mask=trajectories.attention_mask[begin:end, :]) + + labels = dict( + old_values=critic_batch_inputs['old_values'], + returns=critic_batch_inputs['returns'], + mask=critic_batch_inputs['action_mask'], + ) + return critic_batch_inputs, labels + + def critic_learn_async(self, trajectories): + if self.critic_minibatch is None: + self.critic_minibatch = len(trajectories.output_ids) + assert len(trajectories.output_ids) % self.critic_minibatch == 0 + critic_updates = len(trajectories.output_ids) // self.critic_minibatch + critic_loss = [] + assert critic_updates == 1 and self.policy_learn_time == 1, \ + '[WIP] `critic_learn_async` support learn async in loop' + with Timer('critic_model.train_async'): + critic_batch_inputs, labels = self._critic_learn_prepare( + 0, 0, trajectories, critic_updates) + v_loss_ref = self.critic_model.train_async( + input_ids=critic_batch_inputs['input_ids'], + labels=labels, + attention_mask=critic_batch_inputs['attention_mask'], + criterion=self.critic_criterion, + micro_batch_size=self.critic_micro_bs, + ) + logger.info(f'[critic train] {self.critic_minibatch} batch') + critic_loss.append(v_loss_ref) + return critic_loss + + def critic_learn_get(self, critic_loss_ref): + with Timer('critic_model.train_get'): + return [ + self.critic_model.train_get(ref).item() + for ref in critic_loss_ref + ] diff --git a/xtuner/rlhf/utils.py b/xtuner/rlhf/utils.py new file mode 100644 index 000000000..bee9ae986 --- /dev/null +++ b/xtuner/rlhf/utils.py @@ -0,0 +1,64 @@ +import os +import random +from typing import Optional + +import numpy as np +import torch + +DEFAULT_SEED_NUMBER = 1234 + + +def set_seed(seed: int = DEFAULT_SEED_NUMBER): + if seed is None or not isinstance(seed, int): + seed = DEFAULT_SEED_NUMBER + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + # refer to https://pytorch.org/docs/1.13/notes/randomness.html#reproducibility # noqa: E501 + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.backends.cudnn_deterministic = True + torch.backends.cudnn_benchmark = False + torch.use_deterministic_algorithms(True, warn_only=True) + # refer to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility # noqa: E501 + os.putenv('CUBLAS_WORKSPACE_CONFIG', + os.environ.get('CUBLAS_WORKSPACE_CONFIG', ':4096:8')) + + +def expand_reward_token_id(reward_token_id: int, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + pad_token_id=0): + assert len(input_ids.shape) == 2 + new_input_ids = torch.zeros((input_ids.shape[0], input_ids.shape[1] + 1), + dtype=input_ids.dtype).to(input_ids.device) + new_attention_mask = torch.zeros_like( + new_input_ids, dtype=torch.int64).to(input_ids.device) + for i in range(input_ids.size(0)): + row = input_ids[i] + nonzero_index = (row != pad_token_id).nonzero(as_tuple=False) + if nonzero_index.numel() > 0: + nonzero_index = nonzero_index[-1] + 1 + new_input_ids[i] = torch.cat( + (input_ids[i][:nonzero_index], + torch.tensor([reward_token_id], dtype=input_ids.dtype).to( + input_ids.device), input_ids[i][nonzero_index:]), + 0).to(input_ids.device) + if attention_mask is not None: + new_attention_mask[i] = torch.cat( + (attention_mask[i][:nonzero_index], + torch.tensor([1], dtype=torch.int64).to( + input_ids.device), attention_mask[i][nonzero_index:]), + 
0).to(input_ids.device) + else: + new_input_ids[i] = torch.cat( + (input_ids[i][:], + torch.tensor([reward_token_id], dtype=input_ids.dtype).to( + input_ids.device)), 0).to(input_ids.device) + if attention_mask is not None: + new_attention_mask[i] = torch.cat( + (attention_mask[i][:], torch.tensor( + [1], dtype=torch.int64).to(input_ids.device)), + 0).to(input_ids.device) + + return new_input_ids, new_attention_mask diff --git a/xtuner/tools/tokenize_ftdp_datasets.py b/xtuner/tools/tokenize_ftdp_datasets.py index 9327a91fe..769e60b4e 100644 --- a/xtuner/tools/tokenize_ftdp_datasets.py +++ b/xtuner/tools/tokenize_ftdp_datasets.py @@ -361,7 +361,7 @@ def tokenize_and_save(tokenizer, processed_dir, tokenized_dir): description=f'{os.path.basename(file_path)}...'): samples.append(sample) - train_tokens, valid_tokens, train_samples, valid_samples = write_bin_meta_bin( # noqa E501 + train_tokens, valid_tokens, train_samples, valid_samples = write_bin_meta_bin( # noqa: E501 path=tokenized_save_dir, dataset_name=dataset_name, samples=samples,
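A minimal usage sketch for the RunningStates helper added in xtuner/rlhf/repeaters/utils.py follows; it is illustrative only and assumes the class is importable from that path. It shows that merging per-batch moments via update() matches, up to the small epsilon prior on count, the population statistics of the concatenated batches (the same parallel mean/variance merge used by stable-baselines3's RunningMeanStd); the whitening line at the end is a hypothetical downstream use, not taken from this diff.

import torch

from xtuner.rlhf.repeaters.utils import RunningStates

# Two reward batches, e.g. collected from two rollout steps.
rewards_a = torch.randn(256)
rewards_b = torch.randn(256)

states = RunningStates(epsilon=1e-4)
states.update(rewards_a)
states.update(rewards_b)

# Reference statistics over the concatenated batches (population variance,
# matching unbiased=False inside RunningStates.update).
ref_var, ref_mean = torch.var_mean(
    torch.cat([rewards_a, rewards_b]), unbiased=False)

# Approximately equal; the tiny epsilon used to initialise `count` keeps the
# merge numerically safe while introducing only a negligible bias.
print(states.mean.item(), ref_mean.item())
print(states.var.item(), ref_var.item())

# Hypothetical whitening step (an assumption, not shown in this diff):
# whitened = (rewards_b - states.mean) / (states.var.sqrt() + 1e-8)

Because state_dict() and load_state_dict() round-trip mean, var and count, the running statistics can be checkpointed and restored together with the rest of training state.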