Support baichuan2 for level0 pipeline #12289

Merged · 10 commits · Oct 29, 2024
Changes from 9 commits
@@ -8,6 +8,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
|------------|----------------------------------------------------------------|
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |

## 0. Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the latest Intel NPU driver.
@@ -43,6 +44,9 @@ python llama2.py

:: to run Meta-Llama-3-8B-Instruct
python llama3.py

:: to run Baichuan2-7B-Chat
python baichuan2.py
```

Arguments info:
@@ -0,0 +1,101 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import torch
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

def get_prompt(message: str, chat_history: list[tuple[str, str]],
system_prompt: str) -> str:
texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
# The first user input is _not_ stripped
do_strip = False
for user_input, response in chat_history:
user_input = user_input.strip() if do_strip else user_input
do_strip = True
texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
message = message.strip() if do_strip else message
texts.append(f'{message} [/INST]')
return ''.join(texts)

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Predict Tokens using `generate()` API for npu model"
)
parser.add_argument(
"--repo-id-or-model-path",
type=str,
default="baichuan-inc/Baichuan2-7B-Chat",
help="The huggingface repo id for the Baichuan2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--quantization_group_size", type=int, default=0)
Contributor:

Maybe remove this argument for now, as we do not yet support GW for baichuan2?

Contributor Author:

> Maybe remove this argument for now, as we do not yet support GW for baichuan2?

Sure, it has been removed.

parser.add_argument("--max-prompt-len", type=int, default=960)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path

model = AutoModelForCausalLM.from_pretrained(model_path,
optimize_model=True,
pipeline=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
quantization_group_size=args.quantization_group_size,
torch_dtype=torch.float16,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache,
trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

DEFAULT_SYSTEM_PROMPT = """\
"""

print("-" * 80)
print("done")
with torch.inference_mode():
print("finish to load")
for i in range(5):
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
_input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
)
end = time.time()
print(f"Inference time: {end-st} s")
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str)

print("-" * 80)
print("done")
print("success shut down")
66 changes: 36 additions & 30 deletions python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -112,32 +112,14 @@ def __init__(

# Self Attention
if mode == "decode":
attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
dtype=np.int64)
else:
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
dtype=np.int64)

position_ids = self.create_input_op((self.batch_size, self.seq_len))
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
# self.num_key_value_heads = num_key_value_heads
past_keys = []
past_values = []
if mode == "decode":
for i in range(num_layers):
past_key = self.create_cache_op(
(self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
)
if transpose_value:
past_value = self.create_cache_op(
(self.batch_size, self.num_heads, self.head_dim, self.max_seq_len)
)
else:
past_value = self.create_cache_op(
(self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
)
past_keys.append(past_key)
past_values.append(past_value)
else:
past_keys = [None] * num_layers
past_values = [None] * num_layers

if input_layernorm_weights is None:
input_layernorm_weights = []
@@ -163,6 +145,27 @@
input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights]
post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]

past_keys = []
past_values = []
if mode == "decode":
for i in range(num_layers):
past_key = self.create_cache_op(
(self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
)
if transpose_value:
past_value = self.create_cache_op(
(self.batch_size, self.num_heads, self.head_dim, self.max_seq_len)
)
else:
past_value = self.create_cache_op(
(self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
)
past_keys.append(past_key)
past_values.append(past_value)
else:
past_keys = [None] * num_layers
past_values = [None] * num_layers

hidden_states = input

curr_key_values = []
@@ -251,6 +254,7 @@ def attention(self,

attn_weight = self.matmul(query_states, key_states, False, True) / (
math.sqrt(self.head_dim))
attention_mask = self.convert_to_fp16(attention_mask)
attn_weight = self.eltwise_add(attn_weight, attention_mask)
attn_weight = self.convert_to_fp32(attn_weight)
attn_weight = self.softmax(attn_weight, -1)
@@ -395,8 +399,8 @@ def forward(

inputs = (
hidden_states.to(torch.float16),
attention_mask,
position_ids.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64),
)

for i in range(self.intra_stages):
@@ -502,7 +506,9 @@ def forward(
seq_len = hidden_states.shape[1]

backend_cls = self.backend_cls_prefill
inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
inputs = (hidden_states.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64))
inputs += (self.layer_norm_0, self.layer_norm_1)
hidden_states, past_key, past_value = run_model(
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
@@ -625,9 +631,9 @@ def run_decode(

pad_mask = (0, pad_len)
padded_causal_mask = F.pad(
attention_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
attention_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
)
padded_causal_mask[:, :, :, -1] = 0.0
padded_causal_mask[:, :, :, -1] = 0
dist.recv(hidden_states, src=rank - 1)
layer_outputs = multi_decoder(
hidden_states,
@@ -869,9 +875,9 @@ def forward(
hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
position_ids = F.pad(position_ids, (0, pad_len), value=0)
attention_mask = F.pad(
attention_mask.to(torch.float16),
attention_mask.to(torch.int64),
(0, pad_len, 0, pad_len),
value=torch.finfo(torch.float16).min,
value=torch.iinfo(torch.int64).min,
)

args = (hidden_states, position_ids, attention_mask, past_key_value)
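The common thread in the `baichuan_mp.py` hunks above is a dtype change: the attention mask and position ids now enter the NPU graph as int64 (padded with `torch.iinfo(torch.int64).min` instead of the fp16 minimum) and are cast to fp16 only inside the attention op, right before being added to the attention scores. A minimal plain-PyTorch sketch of that convention (the real code goes through the NPU graph ops `create_input_op` / `convert_to_fp16`):

```python
# Minimal sketch of the int64 mask convention used in this PR (plain PyTorch).
import torch
import torch.nn.functional as F

seq_len, max_seq_len = 3, 8
MASK_MIN = torch.iinfo(torch.int64).min   # marks blocked positions

# causal mask for the current tokens: 0 = visible, MASK_MIN = masked
mask = torch.full((1, 1, seq_len, seq_len), MASK_MIN, dtype=torch.int64)
mask = torch.triu(mask, diagonal=1)

# pad the key dimension up to the static max sequence length,
# mirroring the F.pad calls in the hunks above
pad_len = max_seq_len - seq_len
mask = F.pad(mask, (0, pad_len), value=MASK_MIN)
mask[:, :, :, -1] = 0   # the slot for the newest token stays visible

# inside the attention op the mask is cast to fp16 just before it is added
scores = torch.randn(1, 1, seq_len, max_seq_len).to(torch.float16)
scores = scores + mask.to(torch.float16)   # MASK_MIN overflows to -inf in fp16
probs = torch.softmax(scores.float(), dim=-1)
print(probs.shape)   # torch.Size([1, 1, 3, 8])
```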
64 changes: 42 additions & 22 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -192,6 +192,41 @@ def convert_llama(
convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)


def convert_baichuan(
model: torch.nn.Module,
max_output_len=1024,
max_prompt_len=1024,
decoder=False,
inter_pp=None,
intra_pp=None,
transpose_value_cache=True,
):
from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
if decoder:
decode_runner = DecodeRunner(
model,
max_seq_len=max_output_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
)
else:
decode_runner = None
prefill_runner = PrefillRunner(
model,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)
baichuan_model_forward = gen_baichuan_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(model, module.BaichuanModel, baichuan_model_forward)


def optimize_llm(
model: torch.nn.Module,
max_context_len=1024,
@@ -297,28 +332,13 @@ def optimize_llm(
intra_pp = 2
if inter_pp is None:
inter_pp = 2
from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
decode_runner = DecodeRunner(
model,
max_seq_len=max_context_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
)
prefill_runner = PrefillRunner(
model,
max_output_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)
baichuan_model_forward = gen_baichuan_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(model, module.BaichuanModel, baichuan_model_forward)

convert_baichuan(model,
max_output_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
decoder=True,
transpose_value_cache=transpose_value_cache)
if isinstance(model.lm_head, SlicedLMHead):
model.lm_head.get_fused_lm_head()
