
add transformer optimization
guotuofeng committed Jan 12, 2024
1 parent 7c4fed9 commit 844ca29
Showing 3 changed files with 24 additions and 5 deletions.
20 changes: 18 additions & 2 deletions examples/phi2/phi2_optimize.json
@@ -56,10 +56,10 @@
"convert_optimum": {
    "type": "OptimumConversion",
    "config": {
-        "target_opset": 14,
+        "target_opset": 17,
        "extra_args": {
            "legacy": false,
-            "no_post_process": false
+            "no_post_process": true
        }
    }
},
@@ -71,6 +71,21 @@
        "all_tensors_to_one_file": true
    }
},
+"optimize": {
+    "type": "OrtTransformersOptimization",
+    "config": {
+        "model_type": "gpt2",
+        "use_gpu": false,
+        "keep_io_types": true,
+        "num_heads": 32,
+        "hidden_size": 2560,
+        "optimization_options": {
+            "use_multi_head_attention": false
+        },
+        "save_as_external_data": true,
+        "all_tensors_to_one_file": true
+    }
+},
"perf_tuning": {
    "type": "OrtPerfTuning",
    "config": {
@@ -84,6 +99,7 @@
"pass_flows": [
    [
        "convert",
+        "optimize",
        "perf_tuning"
    ]
],
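For reference, the new `optimize` pass corresponds to onnxruntime's transformer optimizer. Below is a minimal standalone sketch of roughly what this pass invokes; the file paths are hypothetical, and Olive's `OrtTransformersOptimization` pass handles this wiring itself.

```python
# Rough standalone equivalent of the "optimize" pass config above.
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

fusion_options = FusionOptions("gpt2")
fusion_options.use_multi_head_attention = False  # "use_multi_head_attention": false

optimized = optimize_model(
    "phi2_converted.onnx",  # hypothetical path to the converted model
    model_type="gpt2",      # phi-2 reuses the gpt2 fusion patterns
    num_heads=32,           # phi-2 attention heads
    hidden_size=2560,       # phi-2 hidden size
    use_gpu=False,
    optimization_options=fusion_options,
)
optimized.save_model_to_file(
    "phi2_optimized.onnx",           # hypothetical output path
    use_external_data_format=True,   # "save_as_external_data": true
    all_tensors_to_one_file=True,    # "all_tensors_to_one_file": true
)
```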
2 changes: 2 additions & 0 deletions examples/phi2/readme.md
@@ -8,3 +8,5 @@ python -m olive.workflows.run --config phi2_optimize.json

## Limitations
For the time being, the official [phi2](https://huggingface.co/microsoft/phi-2) model cannot be exported to ONNX using the official model code. Therefore, we need to patch the forward method to do preprocessing and postprocessing for the past_key_values arguments. Once the official model can be exported to ONNX, we will remove this patch.

+When https://github.com/huggingface/optimum/issues/1642 is fixed, we need to turn post-processing back on by changing `no_post_process` to `false` in the config file.
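For context, the forward patch mentioned above typically follows one pattern: wrap the model's original forward so the past-key/value tensors arrive in the layout the exporter produces and leave as a flat list the ONNX graph can name as outputs. A minimal sketch, assuming a flat `[k0, v0, k1, v1, ...]` input layout; the actual patch in user_script.py may differ.

```python
# Illustrative only; not the actual patch in user_script.py.
def patch_forward(model):
    original_forward = model.forward

    def patched_forward(input_ids, past_key_values=None, attention_mask=None, position_ids=None):
        if past_key_values is not None:
            # Preprocess: repack the flat [k0, v0, k1, v1, ...] list into
            # per-layer (key, value) tuples, as HF models expect.
            past_key_values = tuple(
                (past_key_values[i], past_key_values[i + 1])
                for i in range(0, len(past_key_values), 2)
            )
        outputs = original_forward(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        # Postprocess: flatten the present key/values back out so each tensor
        # becomes a named ONNX graph output.
        present = [t for layer in outputs.past_key_values for t in layer]
        return (outputs.logits, *present)

    model.forward = patched_forward
    return model
```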
7 changes: 4 additions & 3 deletions examples/phi2/user_script.py
@@ -165,13 +165,13 @@ def get_merged_sample_with_past_kv_inputs(
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
    # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
-    # position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
+    position_ids = get_position_ids(attention_mask, past_seq_len)
    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
-    # position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)

    # ruff: noqa: C417
    past_kv = (
@@ -183,11 +183,12 @@ def get_merged_sample_with_past_kv_inputs(
    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
-        return (input_ids, past_kv, attention_mask)
+        return (input_ids, past_kv, attention_mask, position_ids)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
+        "position_ids": position_ids,
    }
    if engine == "ort":
        assert isinstance(past_kv, dict)
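For reference, a helper with the new `(attention_mask, past_seq_len)` signature used above can be derived from the mask alone. A sketch under that assumption follows; the real `get_position_ids` lives elsewhere in user_script.py and may differ.

```python
import torch


def get_position_ids(attention_mask: torch.Tensor, past_seq_len: int) -> torch.Tensor:
    # Each token's position is the number of non-padding tokens before it,
    # which a cumulative sum over the mask provides.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if past_seq_len > 0:
        # Token generation: keep only the new tokens' positions, giving the
        # (batch_size, 1) shape noted in the diff comment when seq_len == 1.
        position_ids = position_ids[:, past_seq_len:]
    return position_ids
```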
