From 6d1087d6488e16b5d23f8de04f07bb394749b6da Mon Sep 17 00:00:00 2001 From: yuwenz Date: Tue, 19 Sep 2023 08:53:17 +0000 Subject: [PATCH 1/3] enhance check for large model Signed-off-by: yuwenz --- neural_compressor/adaptor/ox_utils/util.py | 1 + neural_compressor/model/onnx_model.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py index 0f750f429f1..b9b00a3dafe 100644 --- a/neural_compressor/adaptor/ox_utils/util.py +++ b/neural_compressor/adaptor/ox_utils/util.py @@ -85,6 +85,7 @@ "DmlExecutionProvider": "onnxrt_dml_ep", } +MAXIMUM_PROTOBUF = 2147483648 def dtype_to_name(dtype_mapping, dtype): """Map data type and its string representation.""" diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index c61fc388dad..b4166705655 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -23,6 +23,7 @@ from neural_compressor.model.base_model import BaseModel from neural_compressor.utils.utility import LazyImport +from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF onnx = LazyImport("onnx") ort = LazyImport("onnxruntime") @@ -65,7 +66,7 @@ def check_large_model(self): init_size = 0 for init in self._model.graph.initializer: init_size += sys.getsizeof(init.SerializeToString()) - if init_size > 2147483648: + if init_size > MAXIMUM_PROTOBUF: return True return False From aa04c3afa8708ea6730198625cf48982e70a5f5f Mon Sep 17 00:00:00 2001 From: yuwenz Date: Tue, 19 Sep 2023 10:53:23 +0000 Subject: [PATCH 2/3] update ut of check large model Signed-off-by: yuwenz --- test/model/test_onnx_model.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index 97a784f7315..4b5624291db 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -203,6 +203,7 @@ def setUp(self): def tearDownClass(self): shutil.rmtree("./gptj", ignore_errors=True) shutil.rmtree("./hf_test", ignore_errors=True) + os.remove("model.onnx") def test_hf_model(self): from optimum.onnxruntime import ORTModelForCausalLM @@ -407,6 +408,38 @@ def test_remove_unused_nodes(self): self.model.remove_unused_nodes() self.assertEqual(len(self.model.nodes()), 6) + def test_check_large_model(self): + import onnx + import torch + import torch.nn as nn + from neural_compressor.model.onnx_model import ONNXModel + + class Net(nn.Module): + def __init__(self, in_features, out_features): + super(Net, self).__init__() + self.fc = nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.fc(x) + return x + + # model > 2GB + model = Net(512, 1024 * 1024) + input = torch.randn(512, requires_grad=True) + with torch.no_grad(): + torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + model = onnx.load("model.onnx") + model = ONNXModel(model) + self.assertTrue(model.check_large_model()) + + # model < 2GB + model = Net(10, 10 * 10) + input = torch.randn(10, requires_grad=True) + with torch.no_grad(): + torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + model = onnx.load("model.onnx") + model = ONNXModel(model) + self.assertFalse(model.check_large_model()) if __name__ == "__main__": unittest.main() From d3e204e99d4f5b572f8437c2cb34bc6d5d2692f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 10:56:31 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/adaptor/onnxrt.py | 8 ++++++-- neural_compressor/adaptor/ox_utils/util.py | 1 + neural_compressor/model/onnx_model.py | 4 ++-- test/model/test_onnx_model.py | 8 +++++--- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 5ca63e51240..cf7ca7475fd 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -741,6 +741,7 @@ def _pre_optimize(self, model, level=1): return self.pre_optimized_model from neural_compressor import options from neural_compressor.adaptor.ox_utils.util import remove_init_from_model_input, split_shared_bias + remove_init_from_model_input(model) sess_options = ort.SessionOptions() optimization_levels = { @@ -768,11 +769,14 @@ def _pre_optimize(self, model, level=1): sess_options.optimized_model_filepath = os.path.join(self.work_space, "Optimized_model.onnx") if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover from onnxruntime_extensions import get_library_path + sess_options.register_custom_ops_library(get_library_path()) if not model.is_large_model: - sess = ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]) + sess = ort.InferenceSession( + model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"] + ) elif model.model_path is not None: # pragma: no cover - model.model = onnx.ModelProto() # clean memory for large model + model.model = onnx.ModelProto() # clean memory for large model sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"]) else: # pragma: no cover logger.warning("Please use model path instead of onnx model object to quantize") diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py index b9b00a3dafe..508b8ef4d10 100644 --- a/neural_compressor/adaptor/ox_utils/util.py +++ b/neural_compressor/adaptor/ox_utils/util.py @@ -87,6 +87,7 @@ MAXIMUM_PROTOBUF = 2147483648 + def dtype_to_name(dtype_mapping, dtype): """Map data type and its string representation.""" return list(dtype_mapping.keys())[list(dtype_mapping.values()).index(dtype)] diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index b4166705655..a2169008c1c 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -21,9 +21,9 @@ import sys from pathlib import Path +from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF from neural_compressor.model.base_model import BaseModel from neural_compressor.utils.utility import LazyImport -from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF onnx = LazyImport("onnx") ort = LazyImport("onnxruntime") @@ -60,7 +60,7 @@ def __init__(self, model, **kwargs): self._graph_info = {} self._get_graph_info() self._q_config = None - + def check_large_model(self): """Check model > 2GB.""" init_size = 0 diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index 4b5624291db..56532e3c6dd 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -412,6 +412,7 @@ def test_check_large_model(self): import onnx import torch import torch.nn as nn + from neural_compressor.model.onnx_model import ONNXModel class Net(nn.Module): @@ -422,12 +423,12 @@ def __init__(self, in_features, out_features): def forward(self, x): x = self.fc(x) return x - + # model > 2GB model = Net(512, 1024 * 1024) input = torch.randn(512, requires_grad=True) with torch.no_grad(): - torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") model = ONNXModel(model) self.assertTrue(model.check_large_model()) @@ -436,10 +437,11 @@ def forward(self, x): model = Net(10, 10 * 10) input = torch.randn(10, requires_grad=True) with torch.no_grad(): - torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13) + torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13) model = onnx.load("model.onnx") model = ONNXModel(model) self.assertFalse(model.check_large_model()) + if __name__ == "__main__": unittest.main()