import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import coremltools as ct
import numpy as np
from datetime import datetime
from models.gpt2 import GPT as GPT2
from models.pythia import GPT as Pythia
from src.utils.psnr import compute_psnr
from src.utils.trace_warnings import silence_known_trace_warnings
import argparse
import gc
import sys
import platform
"""
Convert a modified nanoGPT or Huggingface pythia to CoreML.
"""
all_names = GPT2.model_names() + Pythia.model_names()
parser = argparse.ArgumentParser(description='Convert a model to CoreML.')
parser.add_argument('--model_name', choices=all_names, default="gpt2", type=str)
parser.add_argument('--low_memory', help="use less memory at the cost of slower conversion. useful for large models.", action="store_true")
parser.add_argument('--float16_mode', choices=['auto', 'force'], default="auto", type=str, help="whether the converted model uses float16 or float32 inputs. 'auto' chooses the fastest that the current device's OS supports")
args = parser.parse_args()
# float16 inference is only supported on macOS13/iOS16 and higher.
supports_float16 = int(platform.mac_ver()[0].split('.')[0]) >= 13
use_float16 = supports_float16 or args.float16_mode == "force"
if not supports_float16:
    print("float16 inputs and outputs are only supported on macOS13/iOS16 and higher.")
    print("Converting with float32 inputs and outputs instead, so you can run it on this device.")
    print("If you plan to deploy to a newer device, pass --float16_mode force to use float16 instead.")
if args.float16_mode == "force":
    print("Forcing conversion to use float16 inputs and outputs.")
file_suffix = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
model_name = args.model_name
model_filename = model_name.split("/")[-1] + "_" + file_suffix
retrace = True
if retrace:
    print(f"Loading model {model_name}...")
    model_class = GPT2 if model_filename.startswith("gpt2") else Pythia
    torch_model = model_class.from_pretrained(model_name).eval()
    sample_inputs = torch_model.sample_inputs()
    output_types = torch_model.output_types()
    if not use_float16:
        sample_inputs = {k: v.to(torch.float32) for k, v in sample_inputs.items()}
        output_types = {k: torch.float32 if v == torch.float16 else v for k, v in output_types.items()}
    torch_sample_inputs = list(sample_inputs.values())
    is_multi_output = len(torch_model.output_types()) > 1
    print(f"Tracing the model with {len(torch_sample_inputs)} inputs...")
    with silence_known_trace_warnings(model_name):
        traced_model = torch.jit.trace(torch_model, torch_sample_inputs)
else:
    print("Loading from saved file.")
    traced_model = torch.jit.load(f"{model_filename}.pt")
# print(traced_model)
print("Trace finished.")
print("Beginning conversion...")
def op_selector(op):
    """
    Return True to use float16 for the op. Ops must be float16 to run on the Neural Engine.
    You can find op_type by looking in Netron and/or printing the op type/name here
    (usually the names contain a variable name).
    """
    # LayerNorm is where we lose most of our precision. From experiments in
    # optimizing for ANE, it's most likely the computation of the first mean,
    # but with the non-ANE-optimized architecture we have to keep the whole
    # layer norm in float32.
    # TODO: This may no longer be necessary on iOS17+, try it.
    return op.op_type not in ["layer_norm"]

compute_precision = ct.precision.FLOAT16
if model_name in ["gpt2"]:
    print("Using float32 computation for layer_norm, otherwise the precision loss is too large.")
    print("Larger models can use all float16.")  # ...and run purely on the Neural Engine.
    compute_precision = ct.transform.FP16ComputePrecision(op_selector)
if args.low_memory:
    # Free the eager PyTorch model; only the traced copy is needed for conversion.
    del torch_model
    gc.collect()
coreml_sample_inputs = {k: v.numpy() for k,v in sample_inputs.items()}
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name=k, shape=v.shape, dtype=v.dtype, default_value=np.zeros(v.shape, dtype=v.dtype) if k == "kv_cache" else None)
        for k, v in coreml_sample_inputs.items()
    ],
    outputs=[
        # No better way to convert dtypes?
        ct.TensorType(name=k, dtype=torch.tensor(0, dtype=v).numpy().dtype)
        for k, v in output_types.items()
    ],
    compute_precision=compute_precision,
    minimum_deployment_target=ct.target.iOS16,  # To allow float16 inputs + outputs.
    convert_to="mlprogram",
)
print("Conversion finished.")
if args.low_memory:
    del traced_model
    gc.collect()
print("Saving...")
mlmodel.save(f"{model_filename}.mlpackage")
del mlmodel
gc.collect()
print("Adding metadata...")
mlmodel = ct.models.MLModel(f"{model_filename}.mlpackage", skip_model_load=True)
# TODO: Clean up.
pretty_name = {
    "gpt2": "gpt2 (124M)",
    "gpt2-medium": "gpt2-medium (350M)",
    "gpt2-large": "gpt2-large (774M)",
    "gpt2-xl": "gpt2-xl (1558M)",
}.get(model_name, model_name)
model_family = [x for x in ["gpt2", "pythia"] if x in model_name][0]
eos_token_id = {"gpt2": 50256, "pythia": 0}[model_family]
based_on = {"gpt2": "nanoGPT", "pythia": "the HuggingFace implementation"}[model_family]
vocab_size = {"gpt2": 50257, "pythia": 50304}[model_family] if model_name != "pythia-6.9b" else 50432
mlmodel.short_description = f"{pretty_name} for text generation. Based on {based_on}. Optimized for Apple Neural Engine."
input_keys = list(sample_inputs.keys())
input_pad_side = {"gpt2": "left", "pythia": "right"}[model_family]
has_output_mask = "output_mask" in input_keys
logits_element_description = "element of input_ids specified by output_mask" if has_output_mask else "next element after input_ids"
input_output_descriptions = {
    # Common
    "input_ids": f"Input tokens. e.g. from the huggingface {model_family} tokenizer. Pad to the full length with {eos_token_id} (eos) on the {input_pad_side}.",
    "logits": f"Predictions for the {logits_element_description} in the shape (1, 1, {vocab_size}).",
    # KV Cache (see the usage sketch after this dict)
    "full_sequence_length": "The length of the full token sequence. This length excludes padding and includes tokens that have moved outside of input_ids' sliding window.",
    "kv_cache": "Intermediary outputs from the prior prediction. For the first prediction, pass nothing to use the default array of all zeros. For subsequent predictions, pass the appropriate *_kv_cache output from the previous prediction.",
    "prompt_kv_cache": "Intermediary outputs for the next prediction. Pass as the kv_cache input to the next prediction when evaluating the initial prompt.",
    "generation_kv_cache": "Intermediary outputs for the next prediction. Pass as the kv_cache input to the next prediction after evaluating the initial prompt.",
    # No KV Cache
    "output_mask": "A single element array with the index of your sequence to predict. If your non-padded input length was N, pass [N-1].",
}
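
# A rough sketch of the KV cache flow described above, assuming the model variant
# that takes "input_ids", "full_sequence_length" and "kv_cache" inputs. The ids and
# length values here are placeholders; real shapes come from torch_model.sample_inputs().
#
#   first = mlmodel.predict({"input_ids": ids, "full_sequence_length": length})
#   # kv_cache omitted: CoreML substitutes the all-zeros default set at conversion time.
#   second = mlmodel.predict({
#       "input_ids": next_ids,
#       "full_sequence_length": next_length,
#       "kv_cache": first["prompt_kv_cache"],  # use generation_kv_cache after the prompt
#   })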
for k in input_keys:
    mlmodel.input_description[k] = input_output_descriptions[k]
for k in output_types.keys():
    mlmodel.output_description[k] = input_output_descriptions[k]
mlmodel.user_defined_metadata["Converted By"] = "http://twitter.com/flat"
mlmodel.user_defined_metadata["URL"] = "https://github.com/smpanaro/more-ane-transformers"
if not args.low_memory:
    print("Saving...")
    # Workaround to save metadata: https://github.com/apple/coremltools/issues/1680
    to_save = ct.models.MLModel(mlmodel._spec,
                                weights_dir=mlmodel._weights_dir,
                                is_temp_package=True)
    to_save.save(f"{model_filename}.mlpackage")
if args.low_memory:
    print("Skipping model comparison due to low memory mode.")
    print("Conversion complete.")
    sys.exit(0)
# Always compare in float32 so we don't overflow.
with torch.no_grad():
    og_out = torch_model(*torch_sample_inputs)
    og_out = og_out[0] if isinstance(og_out, tuple) else og_out
    og_out = og_out.to(torch.float32)
    tr_out = traced_model(*torch_sample_inputs)
    tr_out = tr_out[0] if isinstance(tr_out, tuple) else tr_out
    tr_out = tr_out.to(torch.float32)
# Hanging here? It's very likely your inputs are the wrong shape and/or types.
print("predicting with mlmodel")#, input_ids.shape, input_ids.dtype)
cm_out = mlmodel.predict(coreml_sample_inputs)
cm_out = torch.from_numpy(cm_out["logits"]).to(torch.float32)
assert og_out.shape == cm_out.shape, f"{og_out.shape} != {cm_out.shape}"
assert og_out.dtype == cm_out.dtype, f"{og_out.dtype} != {cm_out.dtype}"
trace_psnr = compute_psnr(og_out, tr_out)
if trace_psnr < 200:
    print(f"tracing PSNR too low ({trace_psnr}), CoreML model will likely be unusable")
print("\nfinished. these should be >60, ideally much higher (inf is perfect). lower and the model may not be usable")
print("coreml-traced psnr:", compute_psnr(tr_out.numpy(), cm_out.numpy()))
print("coreml-original psnr:", compute_psnr(og_out.numpy(), cm_out.numpy()))
if model_name in ["gpt2-xl"]:
print("\n👋 This model is big. It will run on CPU and GPU as-is, but to run on the Neural Engine there are a few extra steps.")
print("You can also download a version that runs on the Neural Engine from the releases tab on GitHub.")
print("If you want to build it yourself follow these steps:")
print("1. Install coremltools >= 6.3")
print(f"2. Run: python -m src.experiments.chunk_model --mlpackage-path {model_filename}.mlpackage -o .")
print(f"3. Run: python -m src.experiments.make_pipeline {model_filename}_chunk1.mlpackage")
print("Use the output *-pipeline.mlpackage with generate.py as usual.")