From bef099b9889d265332b90865e7a0dd06364a8b1d Mon Sep 17 00:00:00 2001 From: yuwenz Date: Tue, 19 Sep 2023 08:22:44 +0000 Subject: [PATCH 1/8] reduce memory consumption Signed-off-by: yuwenz --- neural_compressor/adaptor/onnxrt.py | 16 +++++----------- neural_compressor/model/onnx_model.py | 23 +++++++++++++---------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 8ff0673955a..5ca63e51240 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -712,13 +712,7 @@ def _detect_domain(self, model): # 2. according to input # typically, NLP models have multiple inputs, # and the dimension of each input is usually 2 (batch_size, max_seq_len) - if not model.is_large_model: - sess = ort.InferenceSession(model.model.SerializeToString(), providers=["CPUExecutionProvider"]) - elif model.model_path is not None: # pragma: no cover - sess = ort.InferenceSession(model.model_path, providers=["CPUExecutionProvider"]) - else: # pragma: no cover - assert False, "Please use model path instead of onnx model object to quantize." - input_shape_lens = [len(input.shape) for input in sess.get_inputs()] + input_shape_lens = [len(inp.type.tensor_type.shape.dim) for inp in model.model.graph.input] if len(input_shape_lens) > 1 and all(shape_len == 2 for shape_len in input_shape_lens): is_nlp = True @@ -747,7 +741,6 @@ def _pre_optimize(self, model, level=1): return self.pre_optimized_model from neural_compressor import options from neural_compressor.adaptor.ox_utils.util import remove_init_from_model_input, split_shared_bias - remove_init_from_model_input(model) sess_options = ort.SessionOptions() optimization_levels = { @@ -775,14 +768,15 @@ def _pre_optimize(self, model, level=1): sess_options.optimized_model_filepath = os.path.join(self.work_space, "Optimized_model.onnx") if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover from onnxruntime_extensions import get_library_path - sess_options.register_custom_ops_library(get_library_path()) if not model.is_large_model: - ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]) + sess = ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]) elif model.model_path is not None: # pragma: no cover - ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"]) + model.model = onnx.ModelProto() # clean memory for large model + sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"]) else: # pragma: no cover logger.warning("Please use model path instead of onnx model object to quantize") + del sess tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False) diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index 3d85a7a9423..c61fc388dad 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -18,6 +18,7 @@ import logging import os +import sys from pathlib import Path from neural_compressor.model.base_model import BaseModel @@ -41,16 +42,9 @@ def __init__(self, model, **kwargs): """ self._model = model if not isinstance(model, str) else onnx.load(model) self._model_path = None if not isinstance(model, str) else model - self._is_large_model = False - try: - ort.InferenceSession(self._model.SerializeToString(), providers=["CPUExecutionProvider"]) - except Exception as e: # pragma: no cover - if self._model_path is not None: - ort.InferenceSession(self._model_path, providers=["CPUExecutionProvider"]) - self._is_large_model = True - else: - logger.warning("Please use model path instead of onnx model object to quantize") - + self._is_large_model = self.check_large_model() + if self._is_large_model and self._model_path is None: + logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize") self._config = None if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()): from transformers import PretrainedConfig @@ -65,6 +59,15 @@ def __init__(self, model, **kwargs): self._graph_info = {} self._get_graph_info() self._q_config = None + + def check_large_model(self): + """Check model > 2GB.""" + init_size = 0 + for init in self._model.graph.initializer: + init_size += sys.getsizeof(init.SerializeToString()) + if init_size > 2147483648: + return True + return False @property def is_large_model(self): From 213bff23d9ce5460326c1498b0f965c054fa8315 Mon Sep 17 00:00:00 2001 From: yuwenz Date: Tue, 19 Sep 2023 08:53:17 +0000 Subject: [PATCH 2/8] enhance check for large model Signed-off-by: yuwenz --- neural_compressor/adaptor/ox_utils/util.py | 1 + neural_compressor/model/onnx_model.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py index 0f750f429f1..b9b00a3dafe 100644 --- a/neural_compressor/adaptor/ox_utils/util.py +++ b/neural_compressor/adaptor/ox_utils/util.py @@ -85,6 +85,7 @@ "DmlExecutionProvider": "onnxrt_dml_ep", } +MAXIMUM_PROTOBUF = 2147483648 def dtype_to_name(dtype_mapping, dtype): """Map data type and its string representation.""" diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index c61fc388dad..b4166705655 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -23,6 +23,7 @@ from neural_compressor.model.base_model import BaseModel from neural_compressor.utils.utility import LazyImport +from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF onnx = LazyImport("onnx") ort = LazyImport("onnxruntime") @@ -65,7 +66,7 @@ def check_large_model(self): init_size = 0 for init in self._model.graph.initializer: init_size += sys.getsizeof(init.SerializeToString()) - if init_size > 2147483648: + if init_size > MAXIMUM_PROTOBUF: return True return False From 036d4ad5344ea7cdc3cdab0a391757b26e277e3e Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 19 Sep 2023 18:49:45 +0800 Subject: [PATCH 3/8] update UT of check large model Signed-off-by: yuwenzho --- test/model/test_onnx_model.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index 97a784f7315..0b03b48e72a 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -203,6 +203,7 @@ def setUp(self): def tearDownClass(self): shutil.rmtree("./gptj", ignore_errors=True) shutil.rmtree("./hf_test", ignore_errors=True) + os.remove("model.onnx") def test_hf_model(self): from optimum.onnxruntime import ORTModelForCausalLM @@ -407,6 +408,40 @@ def test_remove_unused_nodes(self): self.model.remove_unused_nodes() self.assertEqual(len(self.model.nodes()), 6) + def test_check_large_model(self): + import onnx + import torch + import torch.nn as nn + from neural_compressor.model.onnx_model import ONNXModel + + class Net(nn.Module): + def __init__(self, in_features, out_features): + super(Net, self).__init__() + self.fc = nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.fc(x) + return x + + # model > 2GB + model = Net(512, 1024 * 1024) + input = torch.randn(512, requires_grad=True) + with torch.no_grad(): + torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + model = onnx.load("model.onnx") + model = ONNXModel(model) + self.assertTrue(model.check_large_model()) + + # model < 2GB + model = Net(10, 10 * 10) + input = torch.randn(10, requires_grad=True) + with torch.no_grad(): + torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + model = onnx.load("model.onnx") + model = ONNXModel(model) + self.assertFalse(model.check_large_model()) + + if __name__ == "__main__": unittest.main() From 05bfe0206dcb2d54fcc9d1504ec9317aa14bac78 Mon Sep 17 00:00:00 2001 From: yuwenz Date: Tue, 19 Sep 2023 10:53:23 +0000 Subject: [PATCH 4/8] update ut of check large model Signed-off-by: yuwenz --- test/model/test_onnx_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index 0b03b48e72a..4b5624291db 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -441,7 +441,5 @@ def forward(self, x): model = ONNXModel(model) self.assertFalse(model.check_large_model()) - - if __name__ == "__main__": unittest.main() From 8a9a8aa5594a90d6ba29f1828d7f7a0b357f57a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 10:56:31 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/adaptor/onnxrt.py | 8 ++++++-- neural_compressor/adaptor/ox_utils/util.py | 1 + neural_compressor/model/onnx_model.py | 4 ++-- test/model/test_onnx_model.py | 8 +++++--- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 5ca63e51240..cf7ca7475fd 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -741,6 +741,7 @@ def _pre_optimize(self, model, level=1): return self.pre_optimized_model from neural_compressor import options from neural_compressor.adaptor.ox_utils.util import remove_init_from_model_input, split_shared_bias + remove_init_from_model_input(model) sess_options = ort.SessionOptions() optimization_levels = { @@ -768,11 +769,14 @@ def _pre_optimize(self, model, level=1): sess_options.optimized_model_filepath = os.path.join(self.work_space, "Optimized_model.onnx") if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover from onnxruntime_extensions import get_library_path + sess_options.register_custom_ops_library(get_library_path()) if not model.is_large_model: - sess = ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]) + sess = ort.InferenceSession( + model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"] + ) elif model.model_path is not None: # pragma: no cover - model.model = onnx.ModelProto() # clean memory for large model + model.model = onnx.ModelProto() # clean memory for large model sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"]) else: # pragma: no cover logger.warning("Please use model path instead of onnx model object to quantize") diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py index b9b00a3dafe..508b8ef4d10 100644 --- a/neural_compressor/adaptor/ox_utils/util.py +++ b/neural_compressor/adaptor/ox_utils/util.py @@ -87,6 +87,7 @@ MAXIMUM_PROTOBUF = 2147483648 + def dtype_to_name(dtype_mapping, dtype): """Map data type and its string representation.""" return list(dtype_mapping.keys())[list(dtype_mapping.values()).index(dtype)] diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index b4166705655..a2169008c1c 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -21,9 +21,9 @@ import sys from pathlib import Path +from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF from neural_compressor.model.base_model import BaseModel from neural_compressor.utils.utility import LazyImport -from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF onnx = LazyImport("onnx") ort = LazyImport("onnxruntime") @@ -60,7 +60,7 @@ def __init__(self, model, **kwargs): self._graph_info = {} self._get_graph_info() self._q_config = None - + def check_large_model(self): """Check model > 2GB.""" init_size = 0 diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index 4b5624291db..56532e3c6dd 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -412,6 +412,7 @@ def test_check_large_model(self): import onnx import torch import torch.nn as nn + from neural_compressor.model.onnx_model import ONNXModel class Net(nn.Module): @@ -422,12 +423,12 @@ def __init__(self, in_features, out_features): def forward(self, x): x = self.fc(x) return x - + # model > 2GB model = Net(512, 1024 * 1024) input = torch.randn(512, requires_grad=True) with torch.no_grad(): - torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") model = ONNXModel(model) self.assertTrue(model.check_large_model()) @@ -436,10 +437,11 @@ def forward(self, x): model = Net(10, 10 * 10) input = torch.randn(10, requires_grad=True) with torch.no_grad(): - torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") model = ONNXModel(model) self.assertFalse(model.check_large_model()) + if __name__ == "__main__": unittest.main() From c050880e5b5fd051a3430da42a77efa570835bd7 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 20 Sep 2023 15:05:58 +0800 Subject: [PATCH 6/8] fix typo Signed-off-by: yuwenzho --- neural_compressor/model/onnx_model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index a2169008c1c..ad031a29c96 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -65,7 +65,15 @@ def check_large_model(self): """Check model > 2GB.""" init_size = 0 for init in self._model.graph.initializer: - init_size += sys.getsizeof(init.SerializeToString()) + # if raise error of initializer size > 2GB, return True + try: + init_bytes = init.SerializeToString() + init_size += sys.getsizeof(init_bytes) + except Exception as e: + if "exceeds maximum protobuf size of 2GB" in str(e): + return True + else: + raise e if init_size > MAXIMUM_PROTOBUF: return True return False From c5c9229132881b0562a5a653d7a79a387064dbef Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 21 Sep 2023 11:36:19 +0800 Subject: [PATCH 7/8] update check for large model Signed-off-by: yuwenzho --- neural_compressor/model/onnx_model.py | 5 ++++- test/model/test_onnx_model.py | 19 ++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index ad031a29c96..dc9976e20cc 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -65,6 +65,9 @@ def check_large_model(self): """Check model > 2GB.""" init_size = 0 for init in self._model.graph.initializer: + # if initializer has external data location, return True + if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: + return True # if raise error of initializer size > 2GB, return True try: init_bytes = init.SerializeToString() @@ -72,7 +75,7 @@ def check_large_model(self): except Exception as e: if "exceeds maximum protobuf size of 2GB" in str(e): return True - else: + else: # pragma: no cover raise e if init_size > MAXIMUM_PROTOBUF: return True diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index 56532e3c6dd..5de077cb5f8 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -412,7 +412,6 @@ def test_check_large_model(self): import onnx import torch import torch.nn as nn - from neural_compressor.model.onnx_model import ONNXModel class Net(nn.Module): @@ -423,13 +422,19 @@ def __init__(self, in_features, out_features): def forward(self, x): x = self.fc(x) return x - # model > 2GB model = Net(512, 1024 * 1024) input = torch.randn(512, requires_grad=True) with torch.no_grad(): torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") + model = ONNXModel(model) # pass ModelProto + self.assertTrue(model.check_large_model()) + + model = ONNXModel("model.onnx") # pass string + self.assertTrue(model.check_large_model()) + + model = onnx.load("model.onnx", load_external_data=False) # not load init model = ONNXModel(model) self.assertTrue(model.check_large_model()) @@ -439,9 +444,17 @@ def forward(self, x): with torch.no_grad(): torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") - model = ONNXModel(model) + model = ONNXModel(model) # pass ModelProto + self.assertFalse(model.check_large_model()) + + model = ONNXModel("model.onnx") # pass string self.assertFalse(model.check_large_model()) + model = ONNXModel("model.onnx", load_external_data_for_model=False) # not load init + self.assertFalse(model.check_large_model()) + + + if __name__ == "__main__": unittest.main() From ed9ff63703fb4a69f201ad9fc5a56bf838c2789e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Sep 2023 03:39:35 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/model/onnx_model.py | 2 +- test/model/test_onnx_model.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index dc9976e20cc..6dc4a74b707 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -75,7 +75,7 @@ def check_large_model(self): except Exception as e: if "exceeds maximum protobuf size of 2GB" in str(e): return True - else: # pragma: no cover + else: # pragma: no cover raise e if init_size > MAXIMUM_PROTOBUF: return True diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index 5de077cb5f8..4c891781c35 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -412,6 +412,7 @@ def test_check_large_model(self): import onnx import torch import torch.nn as nn + from neural_compressor.model.onnx_model import ONNXModel class Net(nn.Module): @@ -422,19 +423,20 @@ def __init__(self, in_features, out_features): def forward(self, x): x = self.fc(x) return x + # model > 2GB model = Net(512, 1024 * 1024) input = torch.randn(512, requires_grad=True) with torch.no_grad(): torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") - model = ONNXModel(model) # pass ModelProto + model = ONNXModel(model) # pass ModelProto self.assertTrue(model.check_large_model()) - model = ONNXModel("model.onnx") # pass string + model = ONNXModel("model.onnx") # pass string self.assertTrue(model.check_large_model()) - model = onnx.load("model.onnx", load_external_data=False) # not load init + model = onnx.load("model.onnx", load_external_data=False) # not load init model = ONNXModel(model) self.assertTrue(model.check_large_model()) @@ -444,17 +446,15 @@ def forward(self, x): with torch.no_grad(): torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") - model = ONNXModel(model) # pass ModelProto + model = ONNXModel(model) # pass ModelProto self.assertFalse(model.check_large_model()) - model = ONNXModel("model.onnx") # pass string + model = ONNXModel("model.onnx") # pass string self.assertFalse(model.check_large_model()) - model = ONNXModel("model.onnx", load_external_data_for_model=False) # not load init + model = ONNXModel("model.onnx", load_external_data_for_model=False) # not load init self.assertFalse(model.check_large_model()) - - if __name__ == "__main__": unittest.main()