fix code style, update example
plusbang committed Oct 30, 2024
1 parent 1d30f2d commit d944a55
Showing 3 changed files with 39 additions and 11 deletions.
@@ -9,6 +9,7 @@ In this directory, you will find examples of how to directly run HuggingFace `transformers`
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |
+| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16) |

 ## 0. Requirements
 To run these examples with IPEX-LLM on Intel NPUs, make sure to install the latest version of the Intel NPU driver.
@@ -47,6 +48,9 @@ python llama3.py
 :: to run Baichuan2-7B-Chat
 python baichuan2.py
+:: to run MiniCPM-1B-sft-bf16
+python minicpm.py
 ```
 
 Arguments info:
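The detailed argument descriptions are collapsed in this view. Judging from the argparse definitions in the updated minicpm.py diff below, an illustrative invocation that spells out the script's defaults might be:

```
:: illustrative run; all values shown are the script's defaults
python minicpm.py --prompt "What is AI?" --n-predict 32 --max-context-len 1024 --max-prompt-len 512
```

The new `--lowbit-path` option is demonstrated after the minicpm.py diff below.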
@@ -21,6 +21,7 @@
 from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 from transformers import AutoTokenizer
 from transformers.utils import logging
+import os
 
 logger = logging.get_logger(__name__)

@@ -35,28 +36,49 @@
         help="The huggingface repo id for the MiniCPM model to be downloaded"
         ", or the path to the huggingface checkpoint folder",
     )
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
     parser.add_argument('--prompt', type=str, default="What is AI?",
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=960)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

     args = parser.parse_args()
     model_path = args.repo_id_or_model_path

-    model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                 optimize_model=True,
-                                                 pipeline=True,
-                                                 max_context_len=args.max_context_len,
-                                                 max_prompt_len=args.max_prompt_len,
-                                                 torch_dtype=torch.float16,
-                                                 trust_remote_code=True,
-                                                 attn_implementation="eager",
-                                                 transpose_value_cache=not args.disable_transpose_value_cache)
+    if not args.lowbit_path or not os.path.exists(args.lowbit_path):
+        model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                     optimize_model=True,
+                                                     pipeline=True,
+                                                     max_context_len=args.max_context_len,
+                                                     max_prompt_len=args.max_prompt_len,
+                                                     torch_dtype=torch.float16,
+                                                     attn_implementation="eager",
+                                                     transpose_value_cache=not args.disable_transpose_value_cache,
+                                                     trust_remote_code=True)
+    else:
+        model = AutoModelForCausalLM.load_low_bit(
+            args.lowbit_path,
+            attn_implementation="eager",
+            torch_dtype=torch.float16,
+            max_context_len=args.max_context_len,
+            max_prompt_len=args.max_prompt_len,
+            pipeline=True,
+            transpose_value_cache=not args.disable_transpose_value_cache,
+            trust_remote_code=True
+        )

     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

+    if args.lowbit_path and not os.path.exists(args.lowbit_path):
+        model.save_low_bit(args.lowbit_path)

print("-" * 80)
print("done")
with torch.inference_mode():
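The `--lowbit-path` flag added above enables a save-once, load-later workflow: if the given folder does not exist, the script loads and converts the original model, then saves the low-bit weights there; if it does exist, the saved low-bit model is loaded directly and conversion is skipped. A minimal command-line sketch, assuming an illustrative folder name:

```
:: first run: the folder does not exist yet, so the model is converted and saved there
python minicpm.py --lowbit-path .\minicpm-1b-lowbit

:: later runs: the folder exists, so the saved low-bit model is loaded directly
python minicpm.py --lowbit-path .\minicpm-1b-lowbit --prompt "What is AI?"
```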
@@ -448,7 +448,9 @@ def forward(
         seq_len = hidden_states.shape[1]
 
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask.to(torch.int64), position_ids.to(torch.int64))
+        inputs = (hidden_states.to(torch.float16),
+                  attention_mask.to(torch.int64),
+                  position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
