Commit 80421cd

worked

yangw1234 committed Aug 14, 2024
1 parent 464138a commit 80421cd
Showing 4 changed files with 40 additions and 18 deletions.
@@ -806,9 +806,9 @@ def pipeline_parallel_generate(self,
if self.device.type == 'xpu':
torch.xpu.synchronize()
self.rest_cost_mean = np.mean(self.next_token_time)
- return output_ids, _past_key_values
+ return output_ids, _past_key_values, model_kwargs

import types
def run_decode(model, rank, world_size, layer_start, layer_end,
max_seq_len, transpose_value_cache,
input_queue, result_queue):
@@ -829,6 +829,14 @@ def run_decode(model, rank, world_size, layer_start, layer_end,
# trust_remote_code=True, attn_implementation="eager",
# load_in_low_bit="sym_int4", pipeline_parallel_stages=world_size)


+ from ipex_llm.transformers.npu_models.pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
+ model = pipeline_parallel(model, 2,
+                           torch.float16, device="cpu")
+
+ # add pipeline_parallel_generate to pretrained model dynamically
+ model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
+                                                     model)
from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
optimize_llm(model)

@@ -882,6 +890,12 @@ def run_decode(model, rank, world_size, layer_start, layer_end,
transpose_value=transpose_value_cache
)

+ if rank == 0:
+     print(model)
+ dist.barrier()
+ if rank == 1:
+     print(model)

model.model.multi_decoder = multi_decoder

result_queue.put("loading success")
@@ -893,12 +907,12 @@ def run_decode(model, rank, world_size, layer_start, layer_end,
result = input_queue.get()
if result == "stop":
break
- input_ids, past_key_value, n_predict = result
- output = model.generate(input_ids, num_beams=1, do_sample=False, max_new_tokens=n_predict, past_key_values=past_key_value)
+ input_ids, model_kwargs, n_predict = result
+ model_kwargs.pop("max_new_tokens", None)
+ output = model.generate(input_ids, **model_kwargs, num_beams=1, do_sample=False, max_new_tokens=n_predict)
result_queue.put(output)



def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_queue):


@@ -970,20 +984,20 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_queue):
with torch.inference_mode():


- output, past_key_value = generate(model, input_ids, num_beams=1, do_sample=False, max_new_tokens=n_predict)
+ output, past_key_value, model_kwargs = generate(model, input_ids, num_beams=1, do_sample=False, max_new_tokens=n_predict)

- result_queue.put((output, past_key_value))
+ result_queue.put((output, past_key_value, model_kwargs))


import torch.multiprocessing as mp
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for npu model')
- parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+ parser.add_argument('--repo-id-or-model-path', type=str, default="D:\llm-models\Llama-2-7b-chat-hf",
help='The huggingface repo id for the Llama2 model to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
help='Prompt to infer')
- parser.add_argument('--n-predict', type=int, default=1,
+ parser.add_argument('--n-predict', type=int, default=32,
help='Max tokens to predict')
parser.add_argument('--max-seq-len', type=int, default=1024)
parser.add_argument('--transpose-value-cache', action="store_true", default=False)
@@ -1030,20 +1044,21 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_queue):

output = prefill_result_queue.get()

- print(Fore.RED + f"prefill process output: {output}")
+ print(Fore.GREEN + f"prefill process output: {output}")
print(Style.RESET_ALL)

- prefill_input_queue.put((input_ids, args.n_predict, args.max_seq_len, args.transpose_value_cache))
- prefill_output, past_key_value = prefill_result_queue.get()
+ prefill_input_queue.put((input_ids, 1, args.max_seq_len, args.transpose_value_cache))
+ prefill_output, past_key_value, model_kwargs = prefill_result_queue.get()

print("finish prefill")
print("output tokens", prefill_output)
print("past_key_value", past_key_value)
print("model_kwargs", model_kwargs)

decode_input_ids = prefill_output[:, -1:]

- decode_input_queue_0.put((decode_input_ids, past_key_value, args.n_predict))
- decode_input_queue_1.put((decode_input_ids, past_key_value, args.n_predict))
+ decode_input_queue_0.put((decode_input_ids, model_kwargs, args.n_predict))
+ decode_input_queue_1.put((decode_input_ids, model_kwargs, args.n_predict))

output0 = decode_result_queue_0.get()
output1 = decode_result_queue_1.get()
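Taken together, the example-script changes above move the prefill-to-decode handoff from passing raw past_key_values to passing the whole model_kwargs dict, which the decode workers then splat back into generate(). The following is a minimal, self-contained sketch of that queue protocol; fake_prefill, fake_decode, and the placeholder cache handle are illustrative stand-ins and are not part of this commit.

# Sketch of the prefill -> decode model_kwargs handoff (stand-in workers only).
import multiprocessing as mp

def fake_prefill(input_queue, result_queue):
    input_ids, n_predict = input_queue.get()
    # Pretend prefill produced one token plus the kwargs decode needs to continue.
    output_ids = input_ids + [101]
    model_kwargs = {"past_key_values": "opaque-cache-handle",
                    "attention_mask": [1] * len(output_ids)}
    result_queue.put((output_ids, model_kwargs))

def fake_decode(rank, input_queue, result_queue):
    while True:
        item = input_queue.get()
        if item == "stop":
            break
        decode_input_ids, model_kwargs, n_predict = item
        # The real worker calls model.generate(input_ids, **model_kwargs,
        # max_new_tokens=n_predict); here we just echo what arrived.
        result_queue.put((rank, decode_input_ids, n_predict, sorted(model_kwargs)))

if __name__ == "__main__":
    prefill_in, prefill_out = mp.Queue(), mp.Queue()
    decode_in, decode_out = mp.Queue(), mp.Queue()
    prefill_proc = mp.Process(target=fake_prefill, args=(prefill_in, prefill_out))
    decode_proc = mp.Process(target=fake_decode, args=(0, decode_in, decode_out))
    prefill_proc.start()
    decode_proc.start()

    prefill_in.put(([1, 2, 3], 32))
    prefill_output, model_kwargs = prefill_out.get()

    # Hand only the last prefill token plus the full model_kwargs to decode.
    decode_in.put(([prefill_output[-1]], model_kwargs, 32))
    print(decode_out.get())

    decode_in.put("stop")
    prefill_proc.join()
    decode_proc.join()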
1 change: 0 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -41,7 +41,6 @@ def wrapper(model: torch.nn.Module, qtype, device, *args, **kwargs):
"""
for name, layer in model.named_children():
print(f"converting layers {name}")
new_layer = func(layer, qtype, device, *args, **kwargs)
if new_layer:
model.add_module(name, new_layer)
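The convert.py hunk only drops a per-layer debug print, but the surrounding wrapper is the core of the conversion: it walks named_children() and rebinds any converted child under its original name via add_module. Below is a self-contained sketch of that replacement pattern; the Linear-to-Identity swap is a toy stand-in for the real low-bit conversion, and the qtype/device arguments are elided.

import torch

def convert_children(model, func, *args, **kwargs):
    # Mirror of the wrapper loop above: convert each direct child and,
    # if the conversion returns a new module, swap it in under the same name.
    for name, layer in model.named_children():
        new_layer = func(layer, *args, **kwargs)
        if new_layer:
            model.add_module(name, new_layer)
    return model

def to_identity(layer, *args, **kwargs):
    # Toy conversion: replace Linear layers, leave everything else untouched.
    return torch.nn.Identity() if isinstance(layer, torch.nn.Linear) else None

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
print(convert_children(model, to_identity))  # (0) becomes Identity(), (1) stays ReLU()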
3 changes: 3 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/llama.py
@@ -233,6 +233,9 @@ def llama_fused_model_forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# print("attention_mask", attention_mask)
# print("cache_position", cache_position)
# print('past_seen_tokens', past_seen_tokens)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
cache_position, past_seen_tokens)

@@ -325,13 +325,18 @@ def pipeline_parallel_generate(self,

if _input_ids is None:
_input_ids = input_ids

- model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs)
+ model_inputs = self.prepare_inputs_for_generation(_input_ids, **model_kwargs)
+ if local_rank == 1:
+     from colorama import Fore, Back, Style
+     # print(Fore.GREEN + f"model_inputs: ", model_inputs, Style.RESET_ALL)
+     # print(Fore.GREEN + f"_input_ids: ", _input_ids, Style.RESET_ALL)

tic = time.time()
if local_rank == 0:
outputs = self(**model_inputs)
else:
# _input_ids = model_inputs.pop("input_ids")
_inputs_shape = _input_ids.shape + (self.config.hidden_size,)
if step == 0 and self.config.model_type == "chatglm" \
and hasattr(self.config, "vision_config"):
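In this final hunk, prepare_inputs_for_generation is now fed the per-step _input_ids rather than the accumulated output_ids, and the non-zero pipeline ranks keep building their hidden-states placeholder from that same per-step shape (_inputs_shape = _input_ids.shape + (self.config.hidden_size,)). A tiny sketch of that shape bookkeeping, with a made-up hidden size standing in for self.config.hidden_size:

import torch

hidden_size = 4096                 # stand-in for self.config.hidden_size
_input_ids = torch.tensor([[42]])  # one freshly generated token per sequence

# Mirrors `_inputs_shape = _input_ids.shape + (self.config.hidden_size,)`
_inputs_shape = tuple(_input_ids.shape) + (hidden_size,)
hidden_states = torch.empty(_inputs_shape, dtype=torch.float16)
print(_inputs_shape, hidden_states.shape)  # (1, 1, 4096) torch.Size([1, 1, 4096])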
