Skip to content

Commit

Permalink
add phi2 model conversion (microsoft#864)
Browse files Browse the repository at this point in the history
## Describe your changes
add phi2 model example to convert the microsoft/phi2 model. 

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
  • Loading branch information
guotuofeng authored Jan 15, 2024
1 parent bf493a4 commit 26d9287
Show file tree
Hide file tree
Showing 4 changed files with 368 additions and 0 deletions.
118 changes: 118 additions & 0 deletions examples/phi2/phi2_optimize.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
{
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "user_script.py",
"dummy_inputs_func": "dummy_inputs",
"io_config": "get_io_config",
"hf_config": {
"model_name": "microsoft/phi-2",
"task": "text-generation",
"from_pretrained_args": {
"trust_remote_code": true
}
}
}
},
"systems": {
"local_system": {
"type": "LocalSystem",
"config": {
"accelerators": [
"cpu"
]
}
}
},
"evaluators": {
"common_evaluator": {
"metrics": [
{
"name": "latency",
"type": "latency",
"sub_types": [
{
"name": "avg",
"priority": 1
}
],
"user_config": {
"user_script": "user_script.py",
"dataloader_func": "create_dataloader",
"batch_size": 2,
"inference_settings": {
"onnx": {
"session_options": {
"enable_profiling": false
}
}
}
}
}
]
}
},
"passes": {
"convert_optimum": {
"type": "OptimumConversion",
"config": {
"target_opset": 17,
"extra_args": {
"legacy": false,
"no_post_process": true
}
}
},
"convert": {
"type": "OnnxConversion",
"config": {
"target_opset": 17,
"save_as_external_data": true,
"all_tensors_to_one_file": true
}
},
"optimize": {
"type": "OrtTransformersOptimization",
"config": {
"model_type": "gpt2",
"use_gpu": false,
"keep_io_types": true,
"num_heads": 32,
"hidden_size": 2560,
"optimization_options": {
"use_multi_head_attention": false
},
"save_as_external_data": true,
"all_tensors_to_one_file": true
}
},
"perf_tuning": {
"type": "OrtPerfTuning",
"config": {
"user_script": "user_script.py",
"dataloader_func": "create_dataloader",
"batch_size": 2,
"enable_profiling": false
}
}
},
"pass_flows": [
[
"convert",
"optimize",
"perf_tuning"
]
],
"engine": {
"search_strategy": false,
"evaluate_input_model": true,
"evaluator": "common_evaluator",
"target": "local_system",
"cache_dir": "cache",
"output_name": "phi2",
"output_dir": "phi2",
"clean_cache": false,
"log_severity_level": 0,
"log_to_file": false
}
}
10 changes: 10 additions & 0 deletions examples/phi2/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
## Prerequisites
* einops

## Usage
```bash
python -m olive.workflows.run --config phi2_optimize.json
```

## Limitations
when https://github.com/huggingface/optimum/issues/1642 is fixed, we need turn on the post_process by changing `no_post_process` to False in the config file.
1 change: 1 addition & 0 deletions examples/phi2/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
einops
239 changes: 239 additions & 0 deletions examples/phi2/user_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from itertools import chain
from typing import TYPE_CHECKING, List, Tuple

import numpy as np
import torch
from transformers import AutoConfig

from olive.constants import Framework

if TYPE_CHECKING:
from transformers import PhiConfig

model_id = "microsoft/phi-2"
config: "PhiConfig" = AutoConfig.from_pretrained(model_id, trust_remote_code=True)


def dummy_inputs(model):
"""Get dummy inputs for merged decoder model with past_key_values."""
batch_size, sequence_length, past_sequence_length = 2, 8, config.num_hidden_layers
max_sequence_length = 512

return get_merged_sample_with_past_kv_inputs(
config,
torch.device("cpu"),
batch_size,
sequence_length,
past_sequence_length,
max_sequence_length,
use_fp16=False,
use_gqa=False,
engine="pt",
return_dict=True,
world_size=1,
)


def get_io_config(model):
input_names = [
"input_ids",
"attention_mask",
*list(
chain.from_iterable(
(f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(config.num_hidden_layers)
)
),
]
output_names = [
"logits",
*list(chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(config.num_hidden_layers))),
]
dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
return {
"input_names": input_names,
"output_names": output_names,
"dynamic_axes": dynamic_axes,
}


def create_dataloader(data_dir, batch_size, *args, **kwargs):
sequence_length, past_sequence_length = 8, config.num_hidden_layers
max_sequence_length = 512
model_framework = kwargs.get("model_framework", Framework.PYTORCH)
engine = "ort" if model_framework == Framework.ONNX else "pt"

return RandomDataLoader(batch_size, sequence_length, past_sequence_length, max_sequence_length, engine=engine)


class RandomDataLoader:
def __init__(
self,
batch_size: int,
seq_len: int,
past_seq_len: int,
max_seq_len: int,
engine: str = "pt",
use_fp16: bool = False,
use_gqa: bool = False,
):
self.model_id = model_id
self.batch_size = batch_size
self.seq_len = seq_len
self.past_seq_len = past_seq_len
self.max_seq_len = max_seq_len
self.engine = engine
if use_gqa and (engine != "ort" or not use_fp16):
raise ValueError("GQA is only supported for ONNX model with FP16")
self.use_fp16 = use_fp16
self.use_gqa = use_gqa

def __getitem__(self, idx):
inputs = get_merged_sample_with_past_kv_inputs(
config,
device=torch.device("cpu"),
batch_size=self.batch_size,
seq_len=self.seq_len,
past_seq_len=self.past_seq_len,
max_seq_len=self.max_seq_len,
use_fp16=self.use_fp16,
use_gqa=self.use_gqa,
engine=self.engine,
return_dict=True,
)
return (inputs, None)


def get_position_ids(attention_mask: torch.Tensor, past_seq_len: int):
"""Get position_ids from attention_mask."""
# this is generic but in practice we only expect to see two scenarios for (past_seq_len, seq_len)
# prompt generation: (0, seq_len) -> position_ids = (batch_size, seq_len)
# token generation: (past_seq_len, 1) -> position_ids = (batch_size, 1)
# Note: The merged model only works in these two scenarios
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
return position_ids[:, past_seq_len:]


def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]):
dynamic_axes = {}
for name in input_names + output_names:
if name in {"input_ids", "position_ids"}:
# shape is (batch_size, sequence_length)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif name == "attention_mask":
# shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length)
# for prompt generation, past_sequence_length = 0
# for token generation, sequence_length = 1
dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"}
elif "past" in name:
# shape is (batch_size, num_heads, past_sequence_length, head_size)
dynamic_axes[name] = {0: "batch_size", 1: "max_sequence_length"}
elif name == "logits":
# shape is (batch_size, sequence_length, vocab_size)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif "present" in name:
# shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
# = (batch_size, num_heads, total_sequence_length, head_size)
# for prompt generation, past_sequence_length = 0
# for token generation, sequence_length = 1
dynamic_axes[name] = {0: "batch_size", 1: "max_sequence_length"}
else:
raise Exception("Unknown input or output name found")
return dynamic_axes


# Inputs for all passes with past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, past_sequence_length + sequence_length)
# past_kv: (batch_size, max_sequence_length, 2, num_heads, head_size)
def get_merged_sample_with_past_kv_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
past_seq_len: int,
max_seq_len: int,
use_fp16: bool = False,
use_gqa: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
position_ids = get_position_ids(attention_mask, past_seq_len)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)

# ruff: noqa: C417
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)

if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, past_kv, attention_mask, position_ids)

inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)

if use_gqa:
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)

else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv

return inputs


def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]):
"""Flatten past_key_values to a dict of past_key and past_value. For ONNX model only."""
past_kv = {}
# Convert list of past_kv to dict of past_key and past_value
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k
past_kv[f"past_key_values.{i}.value"] = past_v
return past_kv


def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
for k, v in ort_inputs.items():
# Allocate new buffers with max_sequence_length for GQA
if "cache" in k or "past_key_values" in k:
# Copy v (BxSxPxH) into new_v (BxSxMxH)
batch_size, num_heads, _, head_size = v.shape
new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
ort_inputs[k] = new_v
return ort_inputs


def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
num_heads = config.num_attention_heads // world_size
head_size = config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
return [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
)
for _ in range(config.num_hidden_layers)
]

0 comments on commit 26d9287

Please sign in to comment.