Commit d1af961: add support for chatglm2/3

xiangyuT committed Jun 28, 2024
1 parent bf788c3 commit d1af961
Showing 2 changed files with 59 additions and 18 deletions.
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -80,6 +80,7 @@ def chatglm2_model_forward(
    else:
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        seq_length, batch_size, _ = inputs_embeds.shape
        input_ids = torch.empty((batch_size, seq_length), device=inputs_embeds.device)

    if full_attention_mask is None:
        if (attention_mask is not None and not attention_mask.all()) or (
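
The single added line above reads as shape bookkeeping (a hedged reading based only on this hunk): once inputs_embeds has been transposed to the seq-first layout ChatGLM2/3 uses internally, a placeholder input_ids is rebuilt so that later code deriving (batch_size, seq_length) from input_ids still works when the caller passed embeddings instead of token ids. A minimal standalone sketch with dummy shapes:

import torch

inputs_embeds = torch.randn(4, 2, 4096)                      # assumed (batch, seq, hidden) from the caller
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()   # (seq, batch, hidden), as in the hunk
seq_length, batch_size, _ = inputs_embeds.shape
# placeholder whose only job is to expose (batch_size, seq_length) to downstream code
input_ids = torch.empty((batch_size, seq_length), device=inputs_embeds.device)
assert input_ids.shape == (4, 2)                              # (batch_size, seq_length)
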
76 changes: 58 additions & 18 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -417,7 +417,41 @@ def prepare_batch(self, cur_batch):
        cur_batch.partial_prefilling = 0

        return cur_batch

    def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
        if model_type in ["baichuan", "chatglm"]:
            result = []
            for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
                if sub_tuple1 is None:
                    sub_result = [sub_tuple2]
                elif sub_tuple2 is None:
                    sub_result = [sub_tuple1]
                else:
                    sub_result = []
                    for t1, t2 in zip(sub_tuple1, sub_tuple2):
                        if t1 is None:
                            sub_result.append(t2)
                        elif t2 is None:
                            sub_result.append(t1)
                        else:
                            if model_type == "chatglm":
                                sub_result.append(torch.cat((t1, t2), dim=1))
                            else:
                                sub_result.append(torch.cat((t1, t2), dim=0))
                result.append(tuple(sub_result))
            return tuple(result)
        else:
            num_layers = self.model.layer_end - self.model.layer_start
            for layer_idx in range(num_layers):
                kv_cache_1.key_cache[layer_idx] = \
                    torch.cat([kv_cache_1.key_cache[layer_idx],
                               kv_cache_2.key_cache[layer_idx]], dim=0)
                kv_cache_1.value_cache[layer_idx] = \
                    torch.cat([kv_cache_1.value_cache[layer_idx],
                               kv_cache_2.value_cache[layer_idx]], dim=0)

            return kv_cache_1

    @torch.no_grad()
    def model_step(self, input, cur_batch):
        if cur_batch is None or cur_batch.stopped or input is None:
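
The new cat_kv_cache helper merges the KV caches of two partially prefilled micro-batches. For baichuan/chatglm the cache is a nested tuple of tensors, and ChatGLM stores keys/values seq-first, so batches are joined along dim 1 (dim 0 for baichuan); other models expose per-layer key_cache/value_cache lists that are joined along dim 0. A self-contained sketch of the tuple branch, ignoring the None handling and assuming a ChatGLM-style layout of (seq_len, batch, num_kv_heads, head_dim):

import torch

def cat_tuple_kv(kv_1, kv_2, model_type="chatglm"):
    # merge layer by layer along the batch dimension, mirroring the tuple branch above
    dim = 1 if model_type == "chatglm" else 0
    return tuple(
        tuple(torch.cat((t1, t2), dim=dim) for t1, t2 in zip(layer_1, layer_2))
        for layer_1, layer_2 in zip(kv_1, kv_2)
    )

k1, v1 = torch.zeros(8, 1, 2, 64), torch.zeros(8, 1, 2, 64)   # micro-batch of 1 sequence
k2, v2 = torch.zeros(8, 3, 2, 64), torch.zeros(8, 3, 2, 64)   # micro-batch of 3 sequences
merged = cat_tuple_kv(((k1, v1),), ((k2, v2),))
print(merged[0][0].shape)   # torch.Size([8, 4, 2, 64]) -- batches stacked along dim 1
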
@@ -462,26 +496,31 @@ def model_step(self, input, cur_batch):
            if tmp_past_key_values is None:
                tmp_past_key_values = output.past_key_values
            else:
                if self.model.config.model_type in ["baichuan", "chatglm"]:
                    tmp_past_key_values = torch.cat((tmp_past_key_values, output.past_key_values), dim=0)
                else:
                    num_layers = self.model.layer_end - self.model.layer_start
                    for layer_idx in range(num_layers):
                        tmp_past_key_values.key_cache[layer_idx] = \
                            torch.cat([tmp_past_key_values.key_cache[layer_idx],
                                       output.past_key_values.key_cache[layer_idx]], dim=0)
                        tmp_past_key_values.value_cache[layer_idx] = \
                            torch.cat([tmp_past_key_values.value_cache[layer_idx],
                                       output.past_key_values.value_cache[layer_idx]], dim=0)
                tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type,
                                                        tmp_past_key_values,
                                                        output.past_key_values)
                torch.xpu.empty_cache()

            if cur_batch.prefilled_index == cur_batch.batch_size:
                if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
                    value_placeholder = torch.empty_like((tmp_past_key_values)[-1][0])
                    tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \
                        tuple(None for _ in range(layer_start)) + \
                        (tmp_past_key_values)[layer_start:]

                    # past_key_values_placeholder = tuple(
                    #     (value_placeholder, value_placeholder) for _ in range(layer_start)
                    # ) + (output.past_key_values)[layer_start:]
                    # _past_key_values = past_key_values_placeholder

            self.past_key_values_dict[cur_id] = tmp_past_key_values

            if self.pp_config.is_tail:
                _pre_output = self.partial_output_dict.get(cur_id, None)
                tmp_output = output.logits.to(self.dtype)
                tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
                if _pre_output is None:
                    _pre_output = output.logits.to(self.dtype)
                    _pre_output = tmp_output
                else:
                    _pre_output = torch.cat((_pre_output, output.logits.to(self.dtype)), dim=0)
                    _pre_output = torch.cat((_pre_output, tmp_output), dim=0)
                self.partial_output_dict[cur_id] = _pre_output
        else:
            if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
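
A side effect visible in this hunk (again a hedged reading): during chunked prefill the tail rank used to buffer full logits per micro-batch in partial_output_dict; it now takes the argmax of the last position immediately and buffers only token ids, concatenated along the batch dimension. A small sketch with assumed shapes (the vocab size is illustrative):

import torch

logits = torch.randn(2, 16, 65024)                        # assumed (micro_batch, seq_len, vocab)
next_tokens = torch.argmax(logits[:, -1:, :], dim=-1)     # (2, 1) int64 token ids instead of logits
buffered = torch.cat((next_tokens, next_tokens), dim=0)   # later micro-batches stack along dim 0
print(next_tokens.shape, buffered.shape)                  # torch.Size([2, 1]) torch.Size([4, 1])
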
@@ -498,11 +537,12 @@ def model_step(self, input, cur_batch):
                return output[0].to(self.dtype), cur_batch
            else:
                if cur_batch.partial_prefilling > 0 and cur_batch.prefilled_index == cur_batch.batch_size:
                    _output = self.partial_output_dict.get(cur_id, None)
                    _output = self.partial_output_dict.pop(cur_id, None)
                    cur_batch.partial_prefilling = 0
                    return _output, cur_batch
                else:
                    return output.logits, cur_batch
                    _output = torch.argmax(output.logits[:, -1:, :], dim=-1)
                    return _output, cur_batch

    def is_initialized(self):
        return True
@@ -681,8 +721,8 @@ async def process_step(self, tokenizer, result_dict):
                dist.recv(cur_input, src=self.pre_rank)

            output, cur_batch = self.model_step(cur_input, cur_batch)
            if output is not None and self.rank == self.world_size - 1:
                output = torch.argmax(output[:, -1:, :], dim=-1)
            # if output is not None and self.rank == self.world_size - 1:
            #     output = torch.argmax(output[:, -1:, :], dim=-1)

            if output is not None:
                # dist.send(output, dst=self.next_rank)
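
The last two hunks fit together (inferred only from this diff): on the tail rank model_step now returns argmaxed token ids, so the argmax that process_step used to apply is commented out rather than run a second time. A tiny sketch of that handoff, using stub names and assumed shapes:

import torch

def model_step_stub(logits):
    # tail rank: return token ids directly, as the patched model_step does
    return torch.argmax(logits[:, -1:, :], dim=-1)   # (batch, 1)

def process_step_stub(output):
    # no second argmax: output is already (batch, 1) token ids
    return output

tokens = process_step_stub(model_step_stub(torch.randn(2, 7, 65024)))
print(tokens.shape)   # torch.Size([2, 1])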
