[Feature] Pipeline Parallelization of Different Stages in RLHF #877

Open

wants to merge 19 commits into base: main
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -113,6 +113,7 @@ data
*.pkl.json
*.log.json
work_dirs/
rlhf_trainlog*/

# Pytorch
*.pth
Expand All @@ -122,3 +123,6 @@ work_dirs/
# srun
*.out
batchscript-*

# custom
logs/
95 changes: 95 additions & 0 deletions README_pipeline.md
@@ -0,0 +1,95 @@
## Pipeline Optimization
### Principle

Each RLHF iteration consists of three stages: Generation, Forward, and Train. In the Generation stage, vLLM generates the responses; in the Forward stage, the actor, critic, reference, and reward models run inference; in the Train stage, the actor and critic models are trained.

While one stage is running, the GPUs of the other stages sit idle, wasting resources.

To address this, we borrow the idea of pipeline parallelism: the batch is split into several smaller micro-batches, and as soon as a stage finishes a micro-batch it passes the data to the next stage instead of waiting for the whole batch to complete. This shortens each stage's GPU idle time and improves resource utilization.
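
The sketch below illustrates the scheduling idea with plain Python threads and queues. It is only a toy model, not the implementation in this PR; the three stage functions are placeholders for vLLM generation, the four-model inference pass, and the PPO update.

```python
# Toy pipeline sketch: each stage runs in its own thread and forwards finished
# micro-batches through a queue, so later stages start as soon as the first
# micro-batch is ready instead of waiting for the whole batch.
import queue
import threading

PIPE_MICRO_BATCH_NUM = 4


def generation_stage(out_q):
    for i in range(PIPE_MICRO_BATCH_NUM):
        out_q.put(f'micro-batch {i} responses')  # stand-in for vLLM outputs
    out_q.put(None)                              # end-of-batch sentinel


def forward_stage(in_q, out_q):
    while (mb := in_q.get()) is not None:
        out_q.put(f'{mb} + scores')              # actor/critic/ref/reward inference
    out_q.put(None)


def train_stage(in_q):
    while (mb := in_q.get()) is not None:
        print(f'training on: {mb}')              # actor/critic update


gen2fwd, fwd2train = queue.Queue(), queue.Queue()
threads = [
    threading.Thread(target=generation_stage, args=(gen2fwd, )),
    threading.Thread(target=forward_stage, args=(gen2fwd, fwd2train)),
    threading.Thread(target=train_stage, args=(fwd2train, )),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```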

### Setup Steps

1) Add the interfaces to vLLM
- Locate the vLLM installation path
```shell
export vllm=$(pip show vllm | grep Location | awk '{print $2"/vllm"}')
```

- Edit $vllm/entrypoints/llm.py and add the following two methods to `class LLM`
```python
    def generate_to_queue(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        prefix_pos: Optional[Union[int, List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        queue=None,
    ) -> List[RequestOutput]:
        """Generate completions for the input prompts and put each finished
        output on the given queue."""
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the engine.
        num_requests = len(prompts) if prompts is not None else len(
            prompt_token_ids)
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
            token_ids = None if prompt_token_ids is None else prompt_token_ids[
                i]
            self._add_request(prompt,
                              sampling_params,
                              token_ids,
                              lora_request=lora_request,
                              prefix_pos=prefix_pos_i)
        return self._run_engine_to_queue(use_tqdm, queue)

    def _run_engine_to_queue(self, use_tqdm: bool,
                             queue) -> List[RequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the engine.
        outputs: List[RequestOutput] = []
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    # Push each finished output to the queue immediately so
                    # the next pipeline stage can start working on it.
                    queue.put(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID. This is necessary because some
        # requests may finish earlier than requests submitted before them.
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        return outputs
```
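
After patching, `generate_to_queue` can be driven from a consumer thread so that downstream stages receive outputs as individual requests finish. Below is a rough usage sketch; the model path, prompts, and the `None` end-of-batch sentinel are illustrative assumptions, not part of the patch.

```python
# Hypothetical usage (not part of the patch): a consumer thread drains the
# queue so the next stage can start on finished requests while vLLM is still
# generating the rest of the batch.
import queue
import threading

from vllm import LLM, SamplingParams

llm = LLM(model='internlm/internlm2-chat-20b', tensor_parallel_size=8)
output_queue = queue.Queue()


def consume(q):
    while (output := q.get()) is not None:
        # Hand the finished RequestOutput to the Forward stage here.
        print(output.request_id, output.outputs[0].text[:50])


consumer = threading.Thread(target=consume, args=(output_queue, ))
consumer.start()

prompts = ['How to study English?', 'Give three tips for staying healthy.']
llm.generate_to_queue(prompts,
                      sampling_params=SamplingParams(temperature=1.0,
                                                     top_p=0.9),
                      queue=output_queue)
output_queue.put(None)  # signal the consumer that generation is done
consumer.join()
```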

### Configuration
See the reference config file examples/rlhf/internlm2_20b_pipe_32gpu.py:
```python
...
PIPE_MICRO_BATCH_NUM = 4  # number of pipeline micro-batches
...
```
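
With the batch size used in that config (PROMPT_BATCH_SIZE = 128), each micro-batch then holds 32 prompts, as derived in the config itself:

```python
PROMPT_BATCH_SIZE = 128
PIPE_MICRO_BATCH_NUM = 4
assert PROMPT_BATCH_SIZE % PIPE_MICRO_BATCH_NUM == 0  # batch must split evenly
PIPE_MICRO_BATCH_SIZE = PROMPT_BATCH_SIZE // PIPE_MICRO_BATCH_NUM  # 32 prompts per micro-batch
```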

### Impact on Accuracy
When norm_rewards is enabled, results cannot be strictly aligned with the non-pipelined run. The reason is that norm_rewards normalizes the rewards: before this optimization the normalization was computed over the entire batch, whereas with pipelining it is computed separately within each micro-batch.
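
A toy illustration of the discrepancy, assuming plain mean/std normalization (the project's actual norm_rewards implementation, e.g. whether it uses running statistics, may differ):

```python
# Toy example (not the project's reward code): normalizing rewards per
# micro-batch uses different mean/std statistics than normalizing the whole
# batch at once, so the results are close but not bit-identical.
import torch

torch.manual_seed(0)
rewards = torch.randn(128)        # rewards for one full batch of prompts
micro_batches = rewards.chunk(4)  # PIPE_MICRO_BATCH_NUM = 4

whole_batch_norm = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
per_micro_norm = torch.cat(
    [(mb - mb.mean()) / (mb.std() + 1e-8) for mb in micro_batches])

print(torch.allclose(whole_batch_norm, per_micro_norm))  # False
print((whole_batch_norm - per_micro_norm).abs().max())   # small but nonzero
```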
2 changes: 2 additions & 0 deletions examples/rlhf/demo_datas/pretrain_data.json
@@ -0,0 +1,2 @@
[{"role": "user", "content": ""}, {"role": "assistant", "content": "I am an artificial intelligence (AI) assistant named InternLM. I was created by the Shanghai AI Laboratory and my purpose is to assist users with various tasks through natural language processing technology."}]
[{"role": "user", "content": "Give three tips for staying healthy."}, {"role": "assistant", "content": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."}, {"role": "user", "content": "How to study English?"}, {"role": "assistant","content": "1. Set clear goals. 2. Create a study plan. 3. Build vocabulary. 4. Practice speaking."}]
3 changes: 3 additions & 0 deletions examples/rlhf/demo_datas/prompt_data.json
@@ -0,0 +1,3 @@
[{"role": "user", "content": "How to study English?"}]
[{"role": "system", "content": "You are a helpful AI assistant."}, {"role": "user", "content": "Give three tips for staying healthy."}]
[{"role": "user", "content": "Give three tips for staying healthy."}, {"role": "assistant", "content": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."}, {"role": "user", "content": "How to study English?"}]
277 changes: 277 additions & 0 deletions examples/rlhf/internlm2_20b_pipe_32gpu.py
@@ -0,0 +1,277 @@
#######################################################################
# Settings #
#######################################################################
RESUME_STEP = -1
MAX_PROMPT_LEN = 1536
MAX_ANSWER_LEN = 512
MAX_PRETRAIN_LEN = 4096

PROMPT_BATCH_SIZE = 128
PRETRAIN_BATCH_SIZE = 128

PIPE_MICRO_BATCH_NUM = 4
assert PROMPT_BATCH_SIZE % PIPE_MICRO_BATCH_NUM == 0
PIPE_MICRO_BATCH_SIZE = PROMPT_BATCH_SIZE // PIPE_MICRO_BATCH_NUM  # 32

GENERATE_MICRO_BATCH_SIZE = 8
INFER_MICRO_BATCH_SIZE = 2
TRAIN_MICRO_BATCH_SIZE = 1

ZERO_STAGE = 3
POLICY_DP_SIZE = 8
CRITIC_DP_SIZE = 8
REF_DP_SIZE = 4
REWARD_DP_SIZE = 4
VLLM_TP_SIZE = 8
POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE
) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE
CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501

# check generate config
assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0
assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0
# check infer config
assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0
assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0
# check learn config
assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE *
POLICY_DP_SIZE) == 0
assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0

MODEL_DTYPE = 'auto'

tokenizer_config = dict(
pad_token_id=0,
eos_token_id=92542,
padding_side='left',
)

rollout_config = dict(
policy_micro_bs=GENERATE_MICRO_BATCH_SIZE,
max_new_tokens=MAX_ANSWER_LEN,
write_to_file=True,
resume_step=RESUME_STEP,
generate_kwargs={
'do_sample': True,
'temperature': 1.0,
'top_k': 0,
'top_p': 0.9,
'min_new_tokens': 1,
'num_beams': 1,
'early_stopping': True,
'eos_token_id': 92542,
'pad_token_id': 0,
},
async_reward=True,
)

repeater_config = dict(
ref_micro_bs=INFER_MICRO_BATCH_SIZE,
reward_micro_bs=GENERATE_MICRO_BATCH_SIZE,
kl_coeff=0.01,
gamma=1.0,
gae_lambda=0.99,
clip_reward_min=-5,
clip_reward_max=5,
norm_rewards=True,
)

train_config = dict(
pipe_micro_bs=PIPE_MICRO_BATCH_SIZE,
policy_train_micro_bs=TRAIN_MICRO_BATCH_SIZE,
critic_train_micro_bs=TRAIN_MICRO_BATCH_SIZE,
policy_infer_micro_bs=INFER_MICRO_BATCH_SIZE,
critic_infer_micro_bs=INFER_MICRO_BATCH_SIZE,
ppo_loss_weight=1.0,
pretrain_loss_weight=0.5,
# critic_warmup_step=40,
    critic_warmup_step=0,  # Debug-Only
save_interval=200,
max_train_step=800,
resume_step=RESUME_STEP,
)

model_configs = dict(
policy=dict(
model_path=None,
model_type='policy',
trainer_config=dict(
torch_dtype=MODEL_DTYPE,
trainer_type='huggingface',
use_flash_attn=True,
gradient_checkpointing=True,
train_kwargs=dict(
micro_bsz=1,
lr=5e-7,
total_steps=1e9,
lr_decay_rate=1,
),
parallel=dict(
data=dict(size=POLICY_DP_SIZE, mode='deepspeed'),
tensor=dict(size=1, mode='1d'),
pipeline=dict(size=1, interleaved_overlap=False),
sequence=False,
),
deepspeed_config={
"zero_optimization": {
"stage": 3,
"overlap_comm": True,
"stage3_gather_16bit_weights_on_model_save": True
},
'bf16': {
'enabled': True
},
'gradient_clipping': 1.0,
'prescale_gradients': False,
'wall_clock_breakdown': False,
'data_types': {
'grad_accum_dtype': 'fp32'
},
'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE,
'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP,
'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE,
},
),
generator_config=dict(
shared_with_trainer=False,
generator_type='vllm',
parallel=dict(
data=dict(size=1, mode='ddp'),
tensor=dict(size=VLLM_TP_SIZE, mode='1d'),
pipeline=dict(size=1, interleaved_overlap=False),
sequence=False,
),
),
),
critic=dict(
model_path=None,
model_type="critic",
trainer_config=dict(
torch_dtype=MODEL_DTYPE,
trainer_type='huggingface',
use_flash_attn=True,
gradient_checkpointing=True,
train_kwargs=dict(
micro_bsz=1,
lr=9e-6,
total_steps=1e9,
lr_decay_rate=1,
loss_type="per_seq",
),
parallel=dict(
data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'),
tensor=dict(size=1, mode='1d'),
pipeline=dict(size=1, interleaved_overlap=False),
sequence=False,
),
deepspeed_config={
"zero_optimization": {
"stage": 3,
"overlap_comm": True,
"stage3_gather_16bit_weights_on_model_save": True
},
'bf16': {
'enabled': True
},
'gradient_clipping': 1.0,
'prescale_gradients': False,
'wall_clock_breakdown': False,
'data_types': {
'grad_accum_dtype': 'fp32'
},
'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE,
'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP,
'train_batch_size': PROMPT_BATCH_SIZE,
},
),
),
reference=dict(
model_path=None,
model_type="reference",
trainer_config=dict(
torch_dtype=MODEL_DTYPE,
trainer_type='huggingface',
use_flash_attn=True,
parallel=dict(
data=dict(size=REF_DP_SIZE, mode="deepspeed"),
tensor=dict(size=1, mode="1d"),
pipeline=dict(size=1, interleaved_overlap=False),
sequence=False,
),
deepspeed_config={
"zero_optimization": {
"stage": 3,
"overlap_comm": True,
"stage3_gather_16bit_weights_on_model_save": True
},
"bf16": {
"enabled": True
},
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False,
"data_types": {
"grad_accum_dtype": "fp32"
},
"train_micro_batch_size_per_gpu": 2
},
),
),
reward=dict(
model_path=None,
model_type="reward",
trainer_config=dict(
torch_dtype=MODEL_DTYPE,
trainer_type='huggingface',
use_flash_attn=True,
parallel=dict(
data=dict(size=REWARD_DP_SIZE, mode="deepspeed"),
tensor=dict(size=1, mode='1d'),
pipeline=dict(size=1, interleaved_overlap=False),
sequence=False,
),
deepspeed_config={
"zero_optimization": {
"stage": 3,
"overlap_comm": True,
"stage3_gather_16bit_weights_on_model_save": True
},
"bf16": {
"enabled": True
},
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False,
"data_types": {
"grad_accum_dtype": "fp32"
},
"train_micro_batch_size_per_gpu": 2
},
),
),
)

prompt_dataset_config = dict(
samples_each_epoch=PROMPT_BATCH_SIZE,
max_len=MAX_PROMPT_LEN,
message_type='prompt',
random_seed=1024,
sample_strategy='in_batch', # 'in_data'
message_datasets=[
'./examples/rlhf/demo_datas/prompt_data.json::0.01[SYS_PROMPT]:summarization', # noqa: E501
'[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default',
'[HF]HuggingFaceH4/summarize_from_feedback::0.5',
])

pretrain_dataset_config = dict(
samples_each_epoch=PRETRAIN_BATCH_SIZE,
max_len=MAX_PRETRAIN_LEN,
message_type='pretrain',
random_seed=1024,
sample_strategy='in_batch', # 'in_data'
message_datasets=[
'./examples/rlhf/demo_datas/pretrain_data.json::0.01',
'[HF]Anthropic/hh-rlhf/helpful-base::0.5',
'[HF]HuggingFaceH4/summarize_from_feedback::0.5',
],
)