From 26d928786b5be309090a2c6d4f90afa32f191381 Mon Sep 17 00:00:00 2001
From: Mike Guo
Date: Mon, 15 Jan 2024 08:02:37 +0800
Subject: [PATCH] add phi2 model conversion (#864)

## Describe your changes

Add a phi2 model example that converts the microsoft/phi-2 model.

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
---
 examples/phi2/phi2_optimize.json | 118 +++++++++++++++
 examples/phi2/readme.md          |  10 ++
 examples/phi2/requirements.txt   |   1 +
 examples/phi2/user_script.py     | 239 +++++++++++++++++++++++++++++++
 4 files changed, 368 insertions(+)
 create mode 100644 examples/phi2/phi2_optimize.json
 create mode 100644 examples/phi2/readme.md
 create mode 100644 examples/phi2/requirements.txt
 create mode 100644 examples/phi2/user_script.py

diff --git a/examples/phi2/phi2_optimize.json b/examples/phi2/phi2_optimize.json
new file mode 100644
index 000000000..014d5dac4
--- /dev/null
+++ b/examples/phi2/phi2_optimize.json
@@ -0,0 +1,118 @@
+{
+    "input_model": {
+        "type": "PyTorchModel",
+        "config": {
+            "model_script": "user_script.py",
+            "dummy_inputs_func": "dummy_inputs",
+            "io_config": "get_io_config",
+            "hf_config": {
+                "model_name": "microsoft/phi-2",
+                "task": "text-generation",
+                "from_pretrained_args": {
+                    "trust_remote_code": true
+                }
+            }
+        }
+    },
+    "systems": {
+        "local_system": {
+            "type": "LocalSystem",
+            "config": {
+                "accelerators": [
+                    "cpu"
+                ]
+            }
+        }
+    },
+    "evaluators": {
+        "common_evaluator": {
+            "metrics": [
+                {
+                    "name": "latency",
+                    "type": "latency",
+                    "sub_types": [
+                        {
+                            "name": "avg",
+                            "priority": 1
+                        }
+                    ],
+                    "user_config": {
+                        "user_script": "user_script.py",
+                        "dataloader_func": "create_dataloader",
+                        "batch_size": 2,
+                        "inference_settings": {
+                            "onnx": {
+                                "session_options": {
+                                    "enable_profiling": false
+                                }
+                            }
+                        }
+                    }
+                }
+            ]
+        }
+    },
+    "passes": {
+        "convert_optimum": {
+            "type": "OptimumConversion",
+            "config": {
+                "target_opset": 17,
+                "extra_args": {
+                    "legacy": false,
+                    "no_post_process": true
+                }
+            }
+        },
+        "convert": {
+            "type": "OnnxConversion",
+            "config": {
+                "target_opset": 17,
+                "save_as_external_data": true,
+                "all_tensors_to_one_file": true
+            }
+        },
+        "optimize": {
+            "type": "OrtTransformersOptimization",
+            "config": {
+                "model_type": "gpt2",
+                "use_gpu": false,
+                "keep_io_types": true,
+                "num_heads": 32,
+                "hidden_size": 2560,
+                "optimization_options": {
+                    "use_multi_head_attention": false
+                },
+                "save_as_external_data": true,
+                "all_tensors_to_one_file": true
+            }
+        },
+        "perf_tuning": {
+            "type": "OrtPerfTuning",
+            "config": {
+                "user_script": "user_script.py",
+                "dataloader_func": "create_dataloader",
+                "batch_size": 2,
+                "enable_profiling": false
+            }
+        }
+    },
+    "pass_flows": [
+        [
+            "convert",
+            "optimize",
+            "perf_tuning"
+        ]
+    ],
+    "engine": {
+        "search_strategy": false,
+        "evaluate_input_model": true,
+        "evaluator": "common_evaluator",
+        "target": "local_system",
+        "cache_dir": "cache",
+        "output_name": "phi2",
+        "output_dir": "phi2",
+        "clean_cache": false,
+        "log_severity_level": 0,
+        "log_to_file": false
+    }
+}
diff --git a/examples/phi2/readme.md b/examples/phi2/readme.md
new file mode 100644
index 000000000..f8b89c742
--- /dev/null
+++ b/examples/phi2/readme.md
@@ -0,0 +1,10 @@
+## Prerequisites
+* einops
+
+## Usage
+```bash
+python -m olive.workflows.run --config phi2_optimize.json
+```
+
+## Limitations
+When https://github.com/huggingface/optimum/issues/1642 is fixed, post-processing should be turned back on by changing `no_post_process` to `false` in the config file.
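For reference, the change described in the Limitations note above would look roughly like the following excerpt of `phi2_optimize.json`. This is a sketch, not part of the patch: only the `convert_optimum` pass and `pass_flows` entries are shown, and switching the flow to start from `convert_optimum` instead of `convert` is an assumption (the checked-in config keeps `"no_post_process": true` and runs the `convert` → `optimize` → `perf_tuning` flow).

```json
{
    "passes": {
        "convert_optimum": {
            "type": "OptimumConversion",
            "config": {
                "target_opset": 17,
                "extra_args": {
                    "legacy": false,
                    "no_post_process": false
                }
            }
        }
    },
    "pass_flows": [
        [
            "convert_optimum",
            "optimize",
            "perf_tuning"
        ]
    ]
}
```

Note that in JSON the value is lowercase `false`.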
diff --git a/examples/phi2/requirements.txt b/examples/phi2/requirements.txt
new file mode 100644
index 000000000..d27fa26c6
--- /dev/null
+++ b/examples/phi2/requirements.txt
@@ -0,0 +1 @@
+einops
diff --git a/examples/phi2/user_script.py b/examples/phi2/user_script.py
new file mode 100644
index 000000000..13d6da123
--- /dev/null
+++ b/examples/phi2/user_script.py
@@ -0,0 +1,239 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+from itertools import chain
+from typing import TYPE_CHECKING, List, Tuple
+
+import numpy as np
+import torch
+from transformers import AutoConfig
+
+from olive.constants import Framework
+
+if TYPE_CHECKING:
+    from transformers import PhiConfig
+
+model_id = "microsoft/phi-2"
+config: "PhiConfig" = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+
+
+def dummy_inputs(model):
+    """Get dummy inputs for merged decoder model with past_key_values."""
+    batch_size, sequence_length, past_sequence_length = 2, 8, config.num_hidden_layers
+    max_sequence_length = 512
+
+    return get_merged_sample_with_past_kv_inputs(
+        config,
+        torch.device("cpu"),
+        batch_size,
+        sequence_length,
+        past_sequence_length,
+        max_sequence_length,
+        use_fp16=False,
+        use_gqa=False,
+        engine="pt",
+        return_dict=True,
+        world_size=1,
+    )
+
+
+def get_io_config(model):
+    input_names = [
+        "input_ids",
+        "attention_mask",
+        *list(
+            chain.from_iterable(
+                (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(config.num_hidden_layers)
+            )
+        ),
+    ]
+    output_names = [
+        "logits",
+        *list(chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(config.num_hidden_layers))),
+    ]
+    dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
+    return {
+        "input_names": input_names,
+        "output_names": output_names,
+        "dynamic_axes": dynamic_axes,
+    }
+
+
+def create_dataloader(data_dir, batch_size, *args, **kwargs):
+    sequence_length, past_sequence_length = 8, config.num_hidden_layers
+    max_sequence_length = 512
+    model_framework = kwargs.get("model_framework", Framework.PYTORCH)
+    engine = "ort" if model_framework == Framework.ONNX else "pt"
+
+    return RandomDataLoader(batch_size, sequence_length, past_sequence_length, max_sequence_length, engine=engine)
+
+
+class RandomDataLoader:
+    def __init__(
+        self,
+        batch_size: int,
+        seq_len: int,
+        past_seq_len: int,
+        max_seq_len: int,
+        engine: str = "pt",
+        use_fp16: bool = False,
+        use_gqa: bool = False,
+    ):
+        self.model_id = model_id
+        self.batch_size = batch_size
+        self.seq_len = seq_len
+        self.past_seq_len = past_seq_len
+        self.max_seq_len = max_seq_len
+        self.engine = engine
+        if use_gqa and (engine != "ort" or not use_fp16):
+            raise ValueError("GQA is only supported for ONNX model with FP16")
+        self.use_fp16 = use_fp16
+        self.use_gqa = use_gqa
+
+    def __getitem__(self, idx):
+        inputs = get_merged_sample_with_past_kv_inputs(
+            config,
+            device=torch.device("cpu"),
+            batch_size=self.batch_size,
+            seq_len=self.seq_len,
+            past_seq_len=self.past_seq_len,
+            max_seq_len=self.max_seq_len,
+            use_fp16=self.use_fp16,
+            use_gqa=self.use_gqa,
+            engine=self.engine,
+            return_dict=True,
+        )
+        return (inputs, None)
+
+
+def get_position_ids(attention_mask: torch.Tensor, past_seq_len: int):
+    """Get position_ids from attention_mask."""
+    # this is generic but in practice we only expect to see two scenarios for (past_seq_len, seq_len)
+    # prompt generation: (0, seq_len) -> position_ids = (batch_size, seq_len)
+    # token generation: (past_seq_len, 1) -> position_ids = (batch_size, 1)
+    # Note: The merged model only works in these two scenarios
+    position_ids = attention_mask.long().cumsum(-1) - 1
+    position_ids.masked_fill_(attention_mask == 0, 1)
+    return position_ids[:, past_seq_len:]
+
+
+def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]):
+    dynamic_axes = {}
+    for name in input_names + output_names:
+        if name in {"input_ids", "position_ids"}:
+            # shape is (batch_size, sequence_length)
+            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
+        elif name == "attention_mask":
+            # shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length)
+            # for prompt generation, past_sequence_length = 0
+            # for token generation, sequence_length = 1
+            dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"}
+        elif "past" in name:
+            # shape is (batch_size, num_heads, past_sequence_length, head_size)
+            dynamic_axes[name] = {0: "batch_size", 1: "max_sequence_length"}
+        elif name == "logits":
+            # shape is (batch_size, sequence_length, vocab_size)
+            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
+        elif "present" in name:
+            # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+            # = (batch_size, num_heads, total_sequence_length, head_size)
+            # for prompt generation, past_sequence_length = 0
+            # for token generation, sequence_length = 1
+            dynamic_axes[name] = {0: "batch_size", 1: "max_sequence_length"}
+        else:
+            raise Exception("Unknown input or output name found")
+    return dynamic_axes
+
+
+# Inputs for all passes with past_key_values
+# input_ids: (batch_size, sequence_length)
+# attention_mask: (batch_size, past_sequence_length + sequence_length)
+# past_kv: (batch_size, max_sequence_length, 2, num_heads, head_size)
+def get_merged_sample_with_past_kv_inputs(
+    config: AutoConfig,
+    device: torch.device,
+    batch_size: int,
+    seq_len: int,
+    past_seq_len: int,
+    max_seq_len: int,
+    use_fp16: bool = False,
+    use_gqa: bool = False,
+    engine: str = "pt",
+    return_dict: bool = False,
+    world_size: int = 1,
+):
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
+    # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
+    position_ids = get_position_ids(attention_mask, past_seq_len)
+    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
+
+    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+
+    # ruff: noqa: C417
+    past_kv = (
+        flatten_past_kv_inputs(past_kv)
+        if engine == "ort"
+        else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
+    )
+
+    if not return_dict:
+        # For export
+        assert isinstance(past_kv, list)
+        return (input_ids, past_kv, attention_mask, position_ids)
+
+    inputs = {
"input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + + if use_gqa: + inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) + + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + + return inputs + + +def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]): + """Flatten past_key_values to a dict of past_key and past_value. For ONNX model only.""" + past_kv = {} + # Convert list of past_kv to dict of past_key and past_value + for i, (past_k, past_v) in enumerate(past_key_values): + past_kv[f"past_key_values.{i}.key"] = past_k + past_kv[f"past_key_values.{i}.value"] = past_v + return past_kv + + +def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int): + for k, v in ort_inputs.items(): + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = new_v + return ort_inputs + + +def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1): + num_heads = config.num_attention_heads // world_size + head_size = config.hidden_size // config.num_attention_heads + torch_dtype = torch.float16 if use_fp16 else torch.float32 + return [ + ( + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + ) + for _ in range(config.num_hidden_layers) + ]