From 93d945916a48decfc071fac4d6a5edae242f606f Mon Sep 17 00:00:00 2001 From: wejoncy Date: Tue, 17 Oct 2023 22:06:14 +0800 Subject: [PATCH 1/9] large model exporter --- .../transformers/large_model_exporter.py | 236 ++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 onnxruntime/python/tools/transformers/large_model_exporter.py diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py new file mode 100644 index 0000000000000..e41e96c92fa3f --- /dev/null +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -0,0 +1,236 @@ +import argparse +import os +import shutil +from pathlib import Path + +import onnx +import torch +import transformers + + +class Exporter: + """ + A class for exporting large transformer models to ONNX format. + """ + + def __init__(self): + self.model = None + + def disable_huggingface_init(self): + # do not init model twice as it slow initialization + import torch + import torch.nn.init + torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.normal_ = lambda x, *args, **kwargs: x + torch.nn.init.constant_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x + + def get_Model_Size(self): + param_size = 0 + param_sum = 0 + for param in self.model.parameters(): + param_size += param.nelement() * param.element_size() + param_sum += param.nelement() + buffer_size = 0 + buffer_sum = 0 + for buffer in self.model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + buffer_sum += buffer.nelement() + all_size = (param_size + buffer_size) / 1024 / 1024 + return all_size + + def set_model(self, hf_model, tokenizer=None): + self.onnx_name = Path(hf_model+"/").name + import re + self.onnx_name = re.sub(r'[^0-9a-zA-Z]', self.onnx_name, '_')+'.onnx' + self.disable_huggingface_init() + + self.model = transformers.AutoModelForCausalLM.from_pretrained( + hf_model, torch_dtype=torch.float16, trust_remote_code=True) + if tokenizer is None: + tokenizer = hf_model + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) + + self.sample_inputs = list(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) + + def run(self, onnx_path): + self.export_onnx(self.model, onnx_path, self.sample_inputs) + + def pipeline_to_multiple_gpu(self, model, gpulist: list, sample_inputs): + def input_gpu_device_hook(mod, inputs, kwargs): + modifyed_inputs = [] + first_dev = None + for layer_input in inputs: + if type(layer_input) is not torch.Tensor: + modifyed_inputs.append(layer_input) + elif hasattr(mod, 'weight'): + modifyed_inputs.append(layer_input.to(mod.weight.device)) + elif hasattr(mod, 'parameters'): + device = next(mod.parameters(), layer_input).device + modifyed_inputs.append(layer_input.to(device)) + elif hasattr(next(mod.children(), None), 'weight'): + modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device)) + elif first_dev is not None and layer_input.device != first_dev: + modifyed_inputs.append(layer_input.to(first_dev)) + else: + modifyed_inputs.append(layer_input) + if first_dev is None: + first_dev = modifyed_inputs[0].device + for key, value in kwargs.items(): + if type(value) is torch.Tensor: + kwargs[key] = 
value.to(first_dev) + + return (tuple(modifyed_inputs), kwargs) + + def move_layer_to_device_rurc(mod, dev): + mod.to(dev) + for layer in mod.named_children(): + move_layer_to_device_rurc(layer[1], dev) + + model = model.half() + all_hooks = [] + # model.register_module_forward_pre_hook(input_gpu_device_hook) + all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + pre_fix = list(model.named_children())[0][0] + for top_name, top_module in model.named_children(): + for name, module in top_module.named_children(): + all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + if type(module) in [torch.nn.ModuleList]: + import math + num_layers_on_each_gpu = math.floor(len(module)/len(gpulist)) + for idx, attn_layer in enumerate(module): + all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + + to_dev = gpulist[min(idx//num_layers_on_each_gpu, len(gpulist))] + attn_layer.to(to_dev) + move_layer_to_device_rurc(attn_layer, to_dev) + print(f"move {pre_fix}.{name}.{idx} to {to_dev}") + else: + module.to(gpulist[0]) + print(f"move {pre_fix}.{name} to {gpulist[0]}") + if len(list(top_module.named_children())) == 0: + top_module.to(gpulist[0]) + print(f"move {top_name} to {gpulist[0]}") + + # for hook in all_hooks: + # hook.remove() + with torch.no_grad(): + out = model(sample_inputs[0], attention_mask=sample_inputs[1]) + # print(out) + return model + + def retrieve_onnx_inputs(self, sample_inputs): + model = self.model + user_inputs = [] + + def hook_for_inputs(mod, inputs, kwargs): + user_inputs.append((inputs, kwargs)) + return user_inputs[0] + hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) + import inspect + forward_params = inspect.signature(model.forward).parameters + input_keys = list(forward_params.keys()) + default_values = [forward_params.get(key).default for key in input_keys] + model(sample_inputs[0], attention_mask=sample_inputs[1]) + hook_handle.remove() + user_inputs = user_inputs[0] + onnx_inputs = default_values + for idx, val in enumerate(user_inputs[0]): + onnx_inputs[idx] = user_inputs[0][idx] + for key, value in user_inputs[1].items(): + idx = input_keys.index(key) + onnx_inputs[idx] = value + for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): + if type(value) is torch.Tensor: + value.to(model.device) + if 'use_cache' in key: + onnx_inputs[idx] = False + + return input_keys, tuple(onnx_inputs) + + @torch.no_grad() + def export_onnx(self, model, onnx_path, sample_inputs: tuple): + total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory/1024/1024 + + print("Model_Size", self.get_Model_Size()) + print("total_mem_per_cpu=", total_mem_per_cpu) + if self.get_Model_Size() > total_mem_per_cpu*0.45: + if torch.cuda.device_count() > 1: + print("multi-gpu") + device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] + model = self.pipeline_to_multiple_gpu(model, device_collection, sample_inputs) + else: + print("cpu") + model = model.cpu().float() + else: + print("single GPU") + model = model.cuda().half() + + sample_inputs_ = [] + for ints in sample_inputs: + if type(ints) is torch.Tensor: + sample_inputs_.append(ints.to(model.device)) + else: + sample_inputs_.append(ints) + sample_inputs = sample_inputs_ + + input_keys, onnx_inputs = self.retrieve_onnx_inputs(sample_inputs) + + onnx_path = Path(onnx_path).absolute() + if onnx_path.suffix != '.onnx': + onnx_path = onnx_path/self.onnx_name + + 
onnx_filepath_export_multi_files_tmp = onnx_path.parent/'tmp/tmp.onnx' + onnx_filepath_export_multi_files_tmp.parent.exists() and shutil.rmtree(onnx_filepath_export_multi_files_tmp.parent) + os.makedirs(onnx_filepath_export_multi_files_tmp.parent) + + onnx_inp_names = ("input_ids", "attention_mask") + onnx_out_names = ("logits",) + onnx_dynamic_axes = {"input_ids": {0: 'batch_size', 1: "seq_len"}, + "attention_mask": {0: 'batch_size', 1: "seq_len"}} + torch.onnx.export(model=model, args=onnx_inputs, f=str(onnx_filepath_export_multi_files_tmp), + verbose=False, opset_version=16, + input_names=onnx_inp_names, output_names=onnx_out_names, dynamic_axes=onnx_dynamic_axes) + + onnx_model = onnx.load(str(onnx_filepath_export_multi_files_tmp)) + + onnx_path.exists() and onnx_path.unlink() + (onnx_path.parent/f'{self.onnx_name}_ext.data').exists() and (onnx_path.parent / + f'{self.onnx_name}_ext.data').unlink() + onnx.save_model(onnx_model, str(onnx_path), save_as_external_data=True, all_tensors_to_one_file=True, + location=f"{self.onnx_name}_ext.data", size_threshold=1024, convert_attribute=False) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model", + required=True, + type=str, + default=["meta-llama/Llama-2-70b-hf"], + help="Pre-trained models in huggingface model hub" + ) + parser.add_argument( + "-s", + "--saved_path", + required=False, + type=str, + default="./onnx_models/", + help="where the onnx model will be saved" + ) + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_arguments() + expoter = Exporter() + expoter.set_model(args.model) + expoter.run(args.saved_path) From 6fc808d31323a6147f9a4b4a45da0ca0421952bb Mon Sep 17 00:00:00 2001 From: JiCheng Date: Wed, 18 Oct 2023 03:27:04 +0000 Subject: [PATCH 2/9] fix --- .../transformers/large_model_exporter.py | 430 ++++++++++-------- 1 file changed, 238 insertions(+), 192 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index e41e96c92fa3f..4e4b5d6d67f05 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -1,3 +1,11 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +""" +Export LLM to onnx +""" import argparse import os import shutil @@ -6,206 +14,245 @@ import onnx import torch import transformers +from torch import nn + + +def disable_huggingface_init(): + # do not init model twice as it slow initialization + import torch + import torch.nn.init + + torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.normal_ = lambda x, *args, **kwargs: x + torch.nn.init.constant_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x + + +def get_model_size(model: nn.Module): + param_size = 0 + param_sum = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + param_sum += param.nelement() + buffer_size = 0 + buffer_sum = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + buffer_sum += buffer.nelement() + all_size = (param_size + buffer_size) / 1024 / 1024 + return all_size + + +def model_prepare(hf_model: str, tokenizer=None): + """ + prepare torch model, name, and inputs + """ + onnx_model_name = Path(hf_model + "/").name + import re + onnx_model_name = re.sub(r"[^0-9a-zA-Z]", onnx_model_name, "_") + ".onnx" + disable_huggingface_init() + + model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore + hf_model, torch_dtype=torch.float16, trust_remote_code=True + ) + if tokenizer is None: + tokenizer = hf_model + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore -class Exporter: + sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) + return onnx_model_name, model, sample_inputs + + +def auto_pipeline(model: nn.Module, gpulist: list, sample_inputs: tuple): """ - A class for exporting large transformer models to ONNX format. + make a model can be executed across multile-gpu. 
+ it's a pipeline method """ - def __init__(self): - self.model = None - - def disable_huggingface_init(self): - # do not init model twice as it slow initialization - import torch - import torch.nn.init - torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x - torch.nn.init.uniform_ = lambda x, *args, **kwargs: x - torch.nn.init.normal_ = lambda x, *args, **kwargs: x - torch.nn.init.constant_ = lambda x, *args, **kwargs: x - torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x - torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x - torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x - torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x - - def get_Model_Size(self): - param_size = 0 - param_sum = 0 - for param in self.model.parameters(): - param_size += param.nelement() * param.element_size() - param_sum += param.nelement() - buffer_size = 0 - buffer_sum = 0 - for buffer in self.model.buffers(): - buffer_size += buffer.nelement() * buffer.element_size() - buffer_sum += buffer.nelement() - all_size = (param_size + buffer_size) / 1024 / 1024 - return all_size - - def set_model(self, hf_model, tokenizer=None): - self.onnx_name = Path(hf_model+"/").name - import re - self.onnx_name = re.sub(r'[^0-9a-zA-Z]', self.onnx_name, '_')+'.onnx' - self.disable_huggingface_init() - - self.model = transformers.AutoModelForCausalLM.from_pretrained( - hf_model, torch_dtype=torch.float16, trust_remote_code=True) - if tokenizer is None: - tokenizer = hf_model - tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) - - self.sample_inputs = list(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) - - def run(self, onnx_path): - self.export_onnx(self.model, onnx_path, self.sample_inputs) - - def pipeline_to_multiple_gpu(self, model, gpulist: list, sample_inputs): - def input_gpu_device_hook(mod, inputs, kwargs): - modifyed_inputs = [] - first_dev = None - for layer_input in inputs: - if type(layer_input) is not torch.Tensor: - modifyed_inputs.append(layer_input) - elif hasattr(mod, 'weight'): - modifyed_inputs.append(layer_input.to(mod.weight.device)) - elif hasattr(mod, 'parameters'): - device = next(mod.parameters(), layer_input).device - modifyed_inputs.append(layer_input.to(device)) - elif hasattr(next(mod.children(), None), 'weight'): - modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device)) - elif first_dev is not None and layer_input.device != first_dev: - modifyed_inputs.append(layer_input.to(first_dev)) - else: - modifyed_inputs.append(layer_input) - if first_dev is None: - first_dev = modifyed_inputs[0].device - for key, value in kwargs.items(): - if type(value) is torch.Tensor: - kwargs[key] = value.to(first_dev) - - return (tuple(modifyed_inputs), kwargs) - - def move_layer_to_device_rurc(mod, dev): - mod.to(dev) - for layer in mod.named_children(): - move_layer_to_device_rurc(layer[1], dev) - - model = model.half() - all_hooks = [] - # model.register_module_forward_pre_hook(input_gpu_device_hook) - all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) - pre_fix = list(model.named_children())[0][0] - for top_name, top_module in model.named_children(): - for name, module in top_module.named_children(): - all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) - if type(module) in [torch.nn.ModuleList]: - import math - num_layers_on_each_gpu = math.floor(len(module)/len(gpulist)) - for idx, attn_layer in enumerate(module): - 
all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) - - to_dev = gpulist[min(idx//num_layers_on_each_gpu, len(gpulist))] - attn_layer.to(to_dev) - move_layer_to_device_rurc(attn_layer, to_dev) - print(f"move {pre_fix}.{name}.{idx} to {to_dev}") - else: - module.to(gpulist[0]) - print(f"move {pre_fix}.{name} to {gpulist[0]}") - if len(list(top_module.named_children())) == 0: - top_module.to(gpulist[0]) - print(f"move {top_name} to {gpulist[0]}") - - # for hook in all_hooks: - # hook.remove() - with torch.no_grad(): - out = model(sample_inputs[0], attention_mask=sample_inputs[1]) - # print(out) - return model - - def retrieve_onnx_inputs(self, sample_inputs): - model = self.model - user_inputs = [] - - def hook_for_inputs(mod, inputs, kwargs): - user_inputs.append((inputs, kwargs)) - return user_inputs[0] - hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) - import inspect - forward_params = inspect.signature(model.forward).parameters - input_keys = list(forward_params.keys()) - default_values = [forward_params.get(key).default for key in input_keys] - model(sample_inputs[0], attention_mask=sample_inputs[1]) - hook_handle.remove() - user_inputs = user_inputs[0] - onnx_inputs = default_values - for idx, val in enumerate(user_inputs[0]): - onnx_inputs[idx] = user_inputs[0][idx] - for key, value in user_inputs[1].items(): - idx = input_keys.index(key) - onnx_inputs[idx] = value - for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): - if type(value) is torch.Tensor: - value.to(model.device) - if 'use_cache' in key: - onnx_inputs[idx] = False - - return input_keys, tuple(onnx_inputs) - - @torch.no_grad() - def export_onnx(self, model, onnx_path, sample_inputs: tuple): - total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory/1024/1024 - - print("Model_Size", self.get_Model_Size()) - print("total_mem_per_cpu=", total_mem_per_cpu) - if self.get_Model_Size() > total_mem_per_cpu*0.45: - if torch.cuda.device_count() > 1: - print("multi-gpu") - device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] - model = self.pipeline_to_multiple_gpu(model, device_collection, sample_inputs) + def input_gpu_device_hook(mod, inputs, kwargs): + modifyed_inputs = [] + first_dev = None + for layer_input in inputs: + if type(layer_input) is not torch.Tensor: + modifyed_inputs.append(layer_input) + elif hasattr(mod, "weight"): + modifyed_inputs.append(layer_input.to(mod.weight.device)) + elif hasattr(mod, "parameters"): + device = next(mod.parameters(), layer_input).device + modifyed_inputs.append(layer_input.to(device)) + elif hasattr(next(mod.children(), None), "weight"): + modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device)) + elif first_dev is not None and layer_input.device != first_dev: + modifyed_inputs.append(layer_input.to(first_dev)) else: - print("cpu") - model = model.cpu().float() - else: - print("single GPU") - model = model.cuda().half() - - sample_inputs_ = [] - for ints in sample_inputs: - if type(ints) is torch.Tensor: - sample_inputs_.append(ints.to(model.device)) + modifyed_inputs.append(layer_input) + if first_dev is None: + first_dev = modifyed_inputs[0].device + for key, value in kwargs.items(): + if type(value) is torch.Tensor: + kwargs[key] = value.to(first_dev) + + return (tuple(modifyed_inputs), kwargs) + + def move_layer_to_device_rurc(mod, dev): + mod.to(dev) + for layer in mod.named_children(): + move_layer_to_device_rurc(layer[1], dev) + + model 
= model.half() + all_hooks = [] + # model.register_module_forward_pre_hook(input_gpu_device_hook) + all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + pre_fix = next(iter(model.named_children()))[0] + for top_name, top_module in model.named_children(): + for name, module in top_module.named_children(): + all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + if type(module) in [torch.nn.ModuleList]: + import math + + num_layers_on_each_gpu = math.floor(len(module) / len(gpulist)) + for idx, attn_layer in enumerate(module): + all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + + to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))] + attn_layer.to(to_dev) + move_layer_to_device_rurc(attn_layer, to_dev) + print(f"move {pre_fix}.{name}.{idx} to {to_dev}") else: - sample_inputs_.append(ints) - sample_inputs = sample_inputs_ - - input_keys, onnx_inputs = self.retrieve_onnx_inputs(sample_inputs) - - onnx_path = Path(onnx_path).absolute() - if onnx_path.suffix != '.onnx': - onnx_path = onnx_path/self.onnx_name + module.to(gpulist[0]) + print(f"move {pre_fix}.{name} to {gpulist[0]}") + if len(list(top_module.named_children())) == 0: + top_module.to(gpulist[0]) + print(f"move {top_name} to {gpulist[0]}") - onnx_filepath_export_multi_files_tmp = onnx_path.parent/'tmp/tmp.onnx' - onnx_filepath_export_multi_files_tmp.parent.exists() and shutil.rmtree(onnx_filepath_export_multi_files_tmp.parent) - os.makedirs(onnx_filepath_export_multi_files_tmp.parent) + with torch.no_grad(): + model(sample_inputs[0], attention_mask=sample_inputs[1]) + return model - onnx_inp_names = ("input_ids", "attention_mask") - onnx_out_names = ("logits",) - onnx_dynamic_axes = {"input_ids": {0: 'batch_size', 1: "seq_len"}, - "attention_mask": {0: 'batch_size', 1: "seq_len"}} - torch.onnx.export(model=model, args=onnx_inputs, f=str(onnx_filepath_export_multi_files_tmp), - verbose=False, opset_version=16, - input_names=onnx_inp_names, output_names=onnx_out_names, dynamic_axes=onnx_dynamic_axes) - onnx_model = onnx.load(str(onnx_filepath_export_multi_files_tmp)) +def retrieve_onnx_inputs(model, sample_inputs): + """ + auto retrieve onnx inputs from torch model as we can't enumlate all possibilities + for all models + """ + user_inputs = [] + + def hook_for_inputs(_, inputs, kwargs): + user_inputs.append((inputs, kwargs)) + return user_inputs[0] + + hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) + import inspect + + forward_params = inspect.signature(model.forward).parameters + input_keys = list(forward_params.keys()) + default_values = [forward_params.get(key).default for key in input_keys] + model(sample_inputs[0], attention_mask=sample_inputs[1]) + hook_handle.remove() + user_inputs = user_inputs[0] + onnx_inputs = default_values + for idx, _val in enumerate(user_inputs[0]): + onnx_inputs[idx] = user_inputs[0][idx] + for key, value in user_inputs[1].items(): + idx = input_keys.index(key) + onnx_inputs[idx] = value + for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): + if type(value) is torch.Tensor: + value.to(model.device) + # Didn't touch past_key_value now, please change it if you want + if "use_cache" in key: + onnx_inputs[idx] = False + + return input_keys, tuple(onnx_inputs) + + +@torch.no_grad() +def export_onnx(hf_model: str, onnx_path_str: str): + """ + do export + model: torch model + onnx_path: where the onnx model saved to + 
sample_inputs_tp: inputs for torch model + """ + onnx_model_name, model, sample_inputs_tp = model_prepare(hf_model) + total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 + + print("Model_Size", get_model_size(model)) + print("total_mem_per_cpu=", total_mem_per_cpu) + if get_model_size(model) > total_mem_per_cpu * 0.45: + if torch.cuda.device_count() > 1: + print("multi-gpu") + device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] + model = auto_pipeline(model, device_collection, sample_inputs_tp) + else: + print("cpu") + model = model.cpu().float() + else: + print("single GPU") + model = model.cuda().half() + + sample_inputs = [] + for ints in sample_inputs_tp: + if type(ints) is torch.Tensor: + sample_inputs.append(ints.to(model.device)) + else: + sample_inputs.append(ints) + + input_keys, onnx_inputs = retrieve_onnx_inputs(model, sample_inputs) + + onnx_path: Path = Path(onnx_path_str).absolute() + if onnx_path.suffix != ".onnx": + onnx_path = onnx_path / onnx_model_name + + onnx_filepath_export_multi_files_tmp = onnx_path.parent / "tmp/tmp.onnx" + onnx_filepath_export_multi_files_tmp.parent.exists() and shutil.rmtree(onnx_filepath_export_multi_files_tmp.parent) + os.makedirs(onnx_filepath_export_multi_files_tmp.parent) + + onnx_inp_names = ("input_ids", "attention_mask") + onnx_out_names = ("logits",) + onnx_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "attention_mask": {0: "batch_size", 1: "seq_len"}, + } + torch.onnx.export( + model=model, + args=onnx_inputs, + f=str(onnx_filepath_export_multi_files_tmp), + verbose=False, + opset_version=16, + input_names=onnx_inp_names, + output_names=onnx_out_names, + dynamic_axes=onnx_dynamic_axes, + ) - onnx_path.exists() and onnx_path.unlink() - (onnx_path.parent/f'{self.onnx_name}_ext.data').exists() and (onnx_path.parent / - f'{self.onnx_name}_ext.data').unlink() - onnx.save_model(onnx_model, str(onnx_path), save_as_external_data=True, all_tensors_to_one_file=True, - location=f"{self.onnx_name}_ext.data", size_threshold=1024, convert_attribute=False) + onnx_model = onnx.load(str(onnx_filepath_export_multi_files_tmp)) + + onnx_path.exists() and onnx_path.unlink() + (onnx_path.parent / f"{onnx_model_name}_ext.data").exists() and ( + onnx_path.parent / f"{onnx_model_name}_ext.data" + ).unlink() + onnx.save_model( + onnx_model, + str(onnx_path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"{onnx_model_name}_ext.data", + size_threshold=1024, + convert_attribute=False, + ) def parse_arguments(): + """ + args parse + model: + onnx_path: + """ parser = argparse.ArgumentParser() parser.add_argument( @@ -214,7 +261,7 @@ def parse_arguments(): required=True, type=str, default=["meta-llama/Llama-2-70b-hf"], - help="Pre-trained models in huggingface model hub" + help="Pre-trained models in huggingface model hub", ) parser.add_argument( "-s", @@ -222,15 +269,14 @@ def parse_arguments(): required=False, type=str, default="./onnx_models/", - help="where the onnx model will be saved" + help="where the onnx model will be saved", ) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments() - expoter = Exporter() - expoter.set_model(args.model) - expoter.run(args.saved_path) + + export_onnx(args.model, args.saved_path) From 05de1027a3d241ced40a28ff7671f8895c81b5b2 Mon Sep 17 00:00:00 2001 From: JiCheng Date: Wed, 18 Oct 2023 12:35:36 +0800 Subject: [PATCH 3/9] Apply suggestions from code review 
Co-authored-by: Justin Chu --- .../python/tools/transformers/large_model_exporter.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 4e4b5d6d67f05..ce267ecd05418 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -19,8 +19,6 @@ def disable_huggingface_init(): # do not init model twice as it slow initialization - import torch - import torch.nn.init torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x torch.nn.init.uniform_ = lambda x, *args, **kwargs: x @@ -47,12 +45,11 @@ def get_model_size(model: nn.Module): return all_size -def model_prepare(hf_model: str, tokenizer=None): +def initialize_model_and_sample_inputs(hf_model: str, tokenizer=None): """ prepare torch model, name, and inputs """ onnx_model_name = Path(hf_model + "/").name - import re onnx_model_name = re.sub(r"[^0-9a-zA-Z]", onnx_model_name, "_") + ".onnx" disable_huggingface_init() @@ -69,10 +66,7 @@ def model_prepare(hf_model: str, tokenizer=None): def auto_pipeline(model: nn.Module, gpulist: list, sample_inputs: tuple): - """ - make a model can be executed across multile-gpu. - it's a pipeline method - """ + """Make the model executable across multiple GPUs.""" def input_gpu_device_hook(mod, inputs, kwargs): modifyed_inputs = [] @@ -147,7 +141,6 @@ def hook_for_inputs(_, inputs, kwargs): return user_inputs[0] hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) - import inspect forward_params = inspect.signature(model.forward).parameters input_keys = list(forward_params.keys()) From 26d6374725788c826e359722ac0d4d719578bb8b Mon Sep 17 00:00:00 2001 From: JiCheng Date: Wed, 18 Oct 2023 12:38:07 +0800 Subject: [PATCH 4/9] Update onnxruntime/python/tools/transformers/large_model_exporter.py Co-authored-by: Justin Chu --- onnxruntime/python/tools/transformers/large_model_exporter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index ce267ecd05418..879d166f6dacc 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -226,9 +226,7 @@ def export_onnx(hf_model: str, onnx_path_str: str): onnx_model = onnx.load(str(onnx_filepath_export_multi_files_tmp)) onnx_path.exists() and onnx_path.unlink() - (onnx_path.parent / f"{onnx_model_name}_ext.data").exists() and ( - onnx_path.parent / f"{onnx_model_name}_ext.data" - ).unlink() + (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) onnx.save_model( onnx_model, str(onnx_path), From 99047e9c3c7583877f223911856a1bcc07753ec7 Mon Sep 17 00:00:00 2001 From: JiCheng Date: Wed, 18 Oct 2023 04:47:46 +0000 Subject: [PATCH 5/9] fix --- .../transformers/large_model_exporter.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 879d166f6dacc..d659f2b8ebe32 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -7,7 +7,10 @@ Export LLM to onnx """ import argparse +import inspect +import math import os 
+import re import shutil from pathlib import Path @@ -18,7 +21,7 @@ def disable_huggingface_init(): - # do not init model twice as it slow initialization + """do not init model twice as it slow initialization""" torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x torch.nn.init.uniform_ = lambda x, *args, **kwargs: x @@ -30,7 +33,8 @@ def disable_huggingface_init(): torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x -def get_model_size(model: nn.Module): +def get_model_parameter_size(model: nn.Module): + """to calculate how much memory this model needs""" param_size = 0 param_sum = 0 for param in model.parameters(): @@ -47,7 +51,8 @@ def get_model_size(model: nn.Module): def initialize_model_and_sample_inputs(hf_model: str, tokenizer=None): """ - prepare torch model, name, and inputs + get the pretrained torch model from hugginface, + and onnx_model_name, sample model-inputs """ onnx_model_name = Path(hf_model + "/").name @@ -65,7 +70,7 @@ def initialize_model_and_sample_inputs(hf_model: str, tokenizer=None): return onnx_model_name, model, sample_inputs -def auto_pipeline(model: nn.Module, gpulist: list, sample_inputs: tuple): +def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple): """Make the model executable across multiple GPUs.""" def input_gpu_device_hook(mod, inputs, kwargs): @@ -107,8 +112,6 @@ def move_layer_to_device_rurc(mod, dev): for name, module in top_module.named_children(): all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) if type(module) in [torch.nn.ModuleList]: - import math - num_layers_on_each_gpu = math.floor(len(module) / len(gpulist)) for idx, attn_layer in enumerate(module): all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) @@ -172,16 +175,16 @@ def export_onnx(hf_model: str, onnx_path_str: str): onnx_path: where the onnx model saved to sample_inputs_tp: inputs for torch model """ - onnx_model_name, model, sample_inputs_tp = model_prepare(hf_model) + onnx_model_name, model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model) total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 - print("Model_Size", get_model_size(model)) + print("Model_Size", get_model_parameter_size(model)) print("total_mem_per_cpu=", total_mem_per_cpu) - if get_model_size(model) > total_mem_per_cpu * 0.45: + if get_model_parameter_size(model) > total_mem_per_cpu * 0.45: if torch.cuda.device_count() > 1: print("multi-gpu") device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] - model = auto_pipeline(model, device_collection, sample_inputs_tp) + model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp) else: print("cpu") model = model.cpu().float() @@ -190,19 +193,20 @@ def export_onnx(hf_model: str, onnx_path_str: str): model = model.cuda().half() sample_inputs = [] - for ints in sample_inputs_tp: - if type(ints) is torch.Tensor: - sample_inputs.append(ints.to(model.device)) + for sample_int in sample_inputs_tp: + if isinstance(sample_int, torch.Tensor): + sample_inputs.append(sample_int.to(model.device)) else: - sample_inputs.append(ints) + sample_inputs.append(sample_int) + # input_keys would be usesful if the model has some special inputs input_keys, onnx_inputs = retrieve_onnx_inputs(model, sample_inputs) onnx_path: Path = Path(onnx_path_str).absolute() if onnx_path.suffix != ".onnx": onnx_path = onnx_path / onnx_model_name - onnx_filepath_export_multi_files_tmp = onnx_path.parent 
/ "tmp/tmp.onnx" + onnx_filepath_export_multi_files_tmp: Path = onnx_path.parent / "tmp/tmp.onnx" onnx_filepath_export_multi_files_tmp.parent.exists() and shutil.rmtree(onnx_filepath_export_multi_files_tmp.parent) os.makedirs(onnx_filepath_export_multi_files_tmp.parent) @@ -225,7 +229,7 @@ def export_onnx(hf_model: str, onnx_path_str: str): onnx_model = onnx.load(str(onnx_filepath_export_multi_files_tmp)) - onnx_path.exists() and onnx_path.unlink() + onnx_path.unlink(missing_ok=True) (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) onnx.save_model( onnx_model, @@ -240,7 +244,7 @@ def export_onnx(hf_model: str, onnx_path_str: str): def parse_arguments(): """ - args parse + arguments parsing. model: onnx_path: """ From 23ed3936c41c5232e1235eed67a811f714110b9b Mon Sep 17 00:00:00 2001 From: JiCheng Date: Thu, 19 Oct 2023 03:41:11 +0000 Subject: [PATCH 6/9] fix --- .../transformers/large_model_exporter.py | 145 +++++++++++------- 1 file changed, 88 insertions(+), 57 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index d659f2b8ebe32..63e52ba5cbfbf 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -11,7 +11,7 @@ import math import os import re -import shutil +import tempfile from pathlib import Path import onnx @@ -167,37 +167,58 @@ def hook_for_inputs(_, inputs, kwargs): return input_keys, tuple(onnx_inputs) -@torch.no_grad() -def export_onnx(hf_model: str, onnx_path_str: str): +def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: """ - do export - model: torch model - onnx_path: where the onnx model saved to - sample_inputs_tp: inputs for torch model + According to the model size, we will upload it to + CPU if has no GPU or enough GPU memory, + Single GPU if has only one GPU in local or model size is enough to fit one GPU + Multiple GPU if there is more than one gpu in local and model is too large """ - onnx_model_name, model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model) total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 - print("Model_Size", get_model_parameter_size(model)) - print("total_mem_per_cpu=", total_mem_per_cpu) + print(f"Model_Size = {get_model_parameter_size(model)/1024} GB") + print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB") if get_model_parameter_size(model) > total_mem_per_cpu * 0.45: - if torch.cuda.device_count() > 1: - print("multi-gpu") - device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] + device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] + if len(device_collection) > 1: + print( + f"{len(device_collection)} GPUs are used to export onnx, \ + Please set CUDA_VISIBLE_DEVICES to use specific GPU group" + ) model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp) else: - print("cpu") + print("!!!! 
convert model to float and export onnx using CPU") model = model.cpu().float() else: - print("single GPU") + print("Export model on a single GPU") model = model.cuda().half() + return model - sample_inputs = [] - for sample_int in sample_inputs_tp: + +def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple: + """move inputs to device""" + sample_inputs_ = [] + for sample_int in sample_inputs: if isinstance(sample_int, torch.Tensor): - sample_inputs.append(sample_int.to(model.device)) + sample_inputs_.append(sample_int.to(device)) else: - sample_inputs.append(sample_int) + sample_inputs_.append(sample_int) + return tuple(sample_inputs_) + + +@torch.no_grad() +def export_onnx(hf_model: str, onnx_path_str: str, opset: int = 17): + """ + do export + model: torch model + onnx_path: where the onnx model saved to + sample_inputs_tp: inputs for torch model + """ + onnx_model_name, model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model) + + model = move_to_approprate_device(model, sample_inputs_tp) + + sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) # input_keys would be usesful if the model has some special inputs input_keys, onnx_inputs = retrieve_onnx_inputs(model, sample_inputs) @@ -206,40 +227,42 @@ def export_onnx(hf_model: str, onnx_path_str: str): if onnx_path.suffix != ".onnx": onnx_path = onnx_path / onnx_model_name - onnx_filepath_export_multi_files_tmp: Path = onnx_path.parent / "tmp/tmp.onnx" - onnx_filepath_export_multi_files_tmp.parent.exists() and shutil.rmtree(onnx_filepath_export_multi_files_tmp.parent) - os.makedirs(onnx_filepath_export_multi_files_tmp.parent) - - onnx_inp_names = ("input_ids", "attention_mask") - onnx_out_names = ("logits",) - onnx_dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "attention_mask": {0: "batch_size", 1: "seq_len"}, - } - torch.onnx.export( - model=model, - args=onnx_inputs, - f=str(onnx_filepath_export_multi_files_tmp), - verbose=False, - opset_version=16, - input_names=onnx_inp_names, - output_names=onnx_out_names, - dynamic_axes=onnx_dynamic_axes, - ) - - onnx_model = onnx.load(str(onnx_filepath_export_multi_files_tmp)) - - onnx_path.unlink(missing_ok=True) - (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) - onnx.save_model( - onnx_model, - str(onnx_path), - save_as_external_data=True, - all_tensors_to_one_file=True, - location=f"{onnx_model_name}_ext.data", - size_threshold=1024, - convert_attribute=False, - ) + # two step to export onnx + # 1. export onnx with lots of pieces of weights + # 2. 
save all weights to external data + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_onnx = os.path.join(tmpdirname, "tmp.onnx") + + onnx_inp_names = ("input_ids", "attention_mask") + onnx_out_names = ("logits",) + onnx_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "attention_mask": {0: "batch_size", 1: "seq_len"}, + } + torch.onnx.export( + model=model, + args=onnx_inputs, + f=tmp_onnx, + verbose=False, + opset_version=opset, + input_names=onnx_inp_names, + output_names=onnx_out_names, + dynamic_axes=onnx_dynamic_axes, + ) + + onnx_path.unlink(missing_ok=True) + (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) + + onnx_model = onnx.load(str(tmp_onnx)) + onnx.save_model( + onnx_model, + str(onnx_path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"{onnx_model_name}_ext.data", + size_threshold=1024, + convert_attribute=False, + ) def parse_arguments(): @@ -266,12 +289,20 @@ def parse_arguments(): default="./onnx_models/", help="where the onnx model will be saved", ) - - args = parser.parse_args() - return args + parser.add_argument( + "--opset", + required=False, + type=int, + default=17, + help=( + "the opset to save onnx model, \ + try to increase it if this opset doens't have new features you want" + ), + ) + return parser.parse_args() if __name__ == "__main__": args = parse_arguments() - export_onnx(args.model, args.saved_path) + export_onnx(args.model, args.saved_path, args.opset) From 0c0c6e3c5275144c1d4f3ffa3b3ad725cb0ed287 Mon Sep 17 00:00:00 2001 From: JiCheng Date: Thu, 19 Oct 2023 03:58:32 +0000 Subject: [PATCH 7/9] fix --- .../transformers/large_model_exporter.py | 43 +++++++++++-------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 63e52ba5cbfbf..019e189783815 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -10,9 +10,9 @@ import inspect import math import os -import re import tempfile from pathlib import Path +from typing import Optional import onnx import torch @@ -49,25 +49,23 @@ def get_model_parameter_size(model: nn.Module): return all_size -def initialize_model_and_sample_inputs(hf_model: str, tokenizer=None): +def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None): """ get the pretrained torch model from hugginface, - and onnx_model_name, sample model-inputs + and sample model-inputs """ - onnx_model_name = Path(hf_model + "/").name - onnx_model_name = re.sub(r"[^0-9a-zA-Z]", onnx_model_name, "_") + ".onnx" disable_huggingface_init() model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore - hf_model, torch_dtype=torch.float16, trust_remote_code=True + hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True ) if tokenizer is None: tokenizer = hf_model tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) - return onnx_model_name, model, sample_inputs + return model, sample_inputs def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple): @@ -105,7 +103,6 @@ def move_layer_to_device_rurc(mod, dev): model = model.half() all_hooks = [] - # model.register_module_forward_pre_hook(input_gpu_device_hook) 
all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) pre_fix = next(iter(model.named_children()))[0] for top_name, top_module in model.named_children(): @@ -132,7 +129,7 @@ def move_layer_to_device_rurc(mod, dev): return model -def retrieve_onnx_inputs(model, sample_inputs): +def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool): """ auto retrieve onnx inputs from torch model as we can't enumlate all possibilities for all models @@ -162,7 +159,7 @@ def hook_for_inputs(_, inputs, kwargs): value.to(model.device) # Didn't touch past_key_value now, please change it if you want if "use_cache" in key: - onnx_inputs[idx] = False + onnx_inputs[idx] = with_past return input_keys, tuple(onnx_inputs) @@ -207,14 +204,14 @@ def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple: @torch.no_grad() -def export_onnx(hf_model: str, onnx_path_str: str, opset: int = 17): +def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): """ do export model: torch model onnx_path: where the onnx model saved to sample_inputs_tp: inputs for torch model """ - onnx_model_name, model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model) + model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) model = move_to_approprate_device(model, sample_inputs_tp) @@ -223,6 +220,7 @@ def export_onnx(hf_model: str, onnx_path_str: str, opset: int = 17): # input_keys would be usesful if the model has some special inputs input_keys, onnx_inputs = retrieve_onnx_inputs(model, sample_inputs) + onnx_model_name = "model.onnx" onnx_path: Path = Path(onnx_path_str).absolute() if onnx_path.suffix != ".onnx": onnx_path = onnx_path / onnx_model_name @@ -266,11 +264,7 @@ def export_onnx(hf_model: str, onnx_path_str: str, opset: int = 17): def parse_arguments(): - """ - arguments parsing. 
- model: - onnx_path: - """ + """arguments parsing.""" parser = argparse.ArgumentParser() parser.add_argument( @@ -289,6 +283,19 @@ def parse_arguments(): default="./onnx_models/", help="where the onnx model will be saved", ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=None, + help=("cache directy of huggingface, by setting this to avoid useless downloading if you have one"), + ) + parser.add_argument( + "--with_past", + action="store_true", + default=False, + help=("The tool will export onnx without past-key-value by default"), + ) parser.add_argument( "--opset", required=False, @@ -305,4 +312,4 @@ def parse_arguments(): if __name__ == "__main__": args = parse_arguments() - export_onnx(args.model, args.saved_path, args.opset) + export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset) From d7f9a408f39c4318c37d3dfcd32759e8439cd38a Mon Sep 17 00:00:00 2001 From: JiCheng Date: Thu, 19 Oct 2023 05:33:50 +0000 Subject: [PATCH 8/9] fix --- .../transformers/large_model_exporter.py | 135 +++++++++++++----- 1 file changed, 102 insertions(+), 33 deletions(-) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 019e189783815..ceeb5d218e334 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -145,7 +145,7 @@ def hook_for_inputs(_, inputs, kwargs): forward_params = inspect.signature(model.forward).parameters input_keys = list(forward_params.keys()) default_values = [forward_params.get(key).default for key in input_keys] - model(sample_inputs[0], attention_mask=sample_inputs[1]) + out = model(sample_inputs[0], attention_mask=sample_inputs[1]) hook_handle.remove() user_inputs = user_inputs[0] onnx_inputs = default_values @@ -161,7 +161,7 @@ def hook_for_inputs(_, inputs, kwargs): if "use_cache" in key: onnx_inputs[idx] = with_past - return input_keys, tuple(onnx_inputs) + return input_keys, onnx_inputs, out.past_key_values def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: @@ -192,7 +192,7 @@ def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.M return model -def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple: +def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple: """move inputs to device""" sample_inputs_ = [] for sample_int in sample_inputs: @@ -203,43 +203,76 @@ def adapt_inputs_to_device(sample_inputs: tuple, device: torch.Device) -> tuple: return tuple(sample_inputs_) -@torch.no_grad() -def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): - """ - do export - model: torch model - onnx_path: where the onnx model saved to - sample_inputs_tp: inputs for torch model - """ - model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) - - model = move_to_approprate_device(model, sample_inputs_tp) - - sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) - - # input_keys would be usesful if the model has some special inputs - input_keys, onnx_inputs = retrieve_onnx_inputs(model, sample_inputs) - - onnx_model_name = "model.onnx" - onnx_path: Path = Path(onnx_path_str).absolute() - if onnx_path.suffix != ".onnx": - onnx_path = onnx_path / onnx_model_name - +def fetch_onnx_inputs_outputs_name( + model: nn.Module, + onnx_inputs: list, + 
torch_input_names: tuple, + past_key_values: tuple, + with_past: bool, + input_with_past: bool, +): + """fetch onnx inputs and outputs name""" + num_of_past_key = 0 + # try get num_of_past_key and shape of past_key_value + if past_key_values is not None: + num_of_past_key = len(past_key_values) + seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1) + assert seq_index.numel() == 1 + kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"} + + if not num_of_past_key: + num_of_past_key = model.config.num_hidden_layers + + onnx_inp_names = ("input_ids", "attention_mask") + onnx_out_names = ("logits",) + onnx_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "attention_mask": {0: "batch_size", 1: "seq_len"}, + } + if input_with_past: + for i in range(num_of_past_key): + onnx_inp_names += (f"present_key.{i}",) + onnx_inp_names += (f"present_values.{i}",) + + onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis + + if with_past or input_with_past: + for i in range(num_of_past_key): + onnx_out_names += (f"past_key.{i}",) + onnx_out_names += (f"past_values.{i}",) + onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis + + for idx, name in enumerate(torch_input_names): + if input_with_past: + if name == "past_key_values": + onnx_inputs[idx] = past_key_values + elif name == "attention_mask": + attn_mask = onnx_inputs[idx] + onnx_inputs[idx] = torch.cat( + (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1 + ) + elif name == "input_ids": + input_ids = onnx_inputs[idx] + onnx_inputs[idx] = input_ids[:, -1:] + + return onnx_inp_names, onnx_out_names, onnx_dynamic_axes + + +def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int): + """do export with torch.onnx.export""" + onnx_model_name = onnx_path.name + onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple # two step to export onnx # 1. export onnx with lots of pieces of weights # 2. 
save all weights to external data with tempfile.TemporaryDirectory() as tmpdirname: tmp_onnx = os.path.join(tmpdirname, "tmp.onnx") - onnx_inp_names = ("input_ids", "attention_mask") - onnx_out_names = ("logits",) - onnx_dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "attention_mask": {0: "batch_size", 1: "seq_len"}, - } torch.onnx.export( model=model, - args=onnx_inputs, + args=tuple(onnx_inputs), f=tmp_onnx, verbose=False, opset_version=opset, @@ -255,7 +288,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit onnx.save_model( onnx_model, str(onnx_path), - save_as_external_data=True, + save_as_external_data=(len(os.listdir(tmpdirname)) > 1), all_tensors_to_one_file=True, location=f"{onnx_model_name}_ext.data", size_threshold=1024, @@ -263,6 +296,42 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit ) +@torch.no_grad() +def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): + """ + do export + model: torch model + onnx_path: where the onnx model saved to + sample_inputs_tp: inputs for torch model + """ + model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) + + model = move_to_approprate_device(model, sample_inputs_tp) + + sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) + + # input_keys would be usesful if the model has some special inputs + input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past) + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False) + + onnx_model_name = "model.onnx" + onnx_path: Path = Path(onnx_path_str).absolute() + if onnx_path.suffix != ".onnx": + onnx_path = onnx_path / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + if not with_past: + return + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True) + + onnx_model_name = "model_with_past.onnx" + onnx_path = onnx_path.parent / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + + def parse_arguments(): """arguments parsing.""" parser = argparse.ArgumentParser() From 02e887b0a93d5743ded438b8610ef4a89e0a306a Mon Sep 17 00:00:00 2001 From: JiCheng Date: Thu, 19 Oct 2023 06:03:21 +0000 Subject: [PATCH 9/9] fix --- onnxruntime/python/tools/transformers/large_model_exporter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index ceeb5d218e334..3b344d6dc9342 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -213,6 +213,7 @@ def fetch_onnx_inputs_outputs_name( ): """fetch onnx inputs and outputs name""" num_of_past_key = 0 + kv_cache_axis = {0: "batch_size"} # try get num_of_past_key and shape of past_key_value if past_key_values is not None: num_of_past_key = len(past_key_values)
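
Note: the snippet below is an illustrative sketch, not part of the patch series. It shows how the exporter introduced by these commits might be invoked and how the resulting ONNX file could be sanity-checked with onnxruntime. The model name ("meta-llama/Llama-2-7b-hf"), the save path, and the assumption that the final artifact is named "model.onnx" with its external weight file alongside it follow the defaults and I/O names introduced in the later patches; adjust them if your revision differs.

# Hypothetical usage sketch for the exporter added in this patch series.
# Export first (script invocation and flags taken from parse_arguments()):
#   python large_model_exporter.py -m meta-llama/Llama-2-7b-hf -s ./onnx_models/ --opset 17
import numpy as np
import onnxruntime as ort

# "model.onnx" is the default output name used by the later patches; the
# external-data file ("model.onnx_ext.data") sits next to it, so keep the
# directory together when loading.
session = ort.InferenceSession("./onnx_models/model.onnx", providers=["CPUExecutionProvider"])

# Token ids here are placeholders; in practice they would come from the same
# Hugging Face tokenizer used during export. The exporter uses int64 inputs.
input_ids = np.array([[1, 15043, 29892, 590, 11203]], dtype=np.int64)
attention_mask = np.ones_like(input_ids)

# Input/output names match the fixed names set by the exporter:
# ("input_ids", "attention_mask") -> ("logits",)
logits = session.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})[0]
print(logits.shape)  # expected: (batch_size, seq_len, vocab_size)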