enhance check of large model
Signed-off-by: yuwenzho <[email protected]>
yuwenzho committed Sep 20, 2023
Commit 33dd319 (2 parents: 6b72726 + d3e204e)
Showing 4 changed files with 22 additions and 9 deletions.
neural_compressor/adaptor/onnxrt.py (8 changes: 6 additions & 2 deletions)
@@ -741,6 +741,7 @@ def _pre_optimize(self, model, level=1):
             return self.pre_optimized_model
         from neural_compressor import options
         from neural_compressor.adaptor.ox_utils.util import remove_init_from_model_input, split_shared_bias
+
         remove_init_from_model_input(model)
         sess_options = ort.SessionOptions()
         optimization_levels = {
@@ -768,11 +769,14 @@ def _pre_optimize(self, model, level=1):
         sess_options.optimized_model_filepath = os.path.join(self.work_space, "Optimized_model.onnx")
         if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"):  # pragma: no cover
             from onnxruntime_extensions import get_library_path
+
             sess_options.register_custom_ops_library(get_library_path())
         if not model.is_large_model:
-            sess = ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"])
+            sess = ort.InferenceSession(
+                model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
+            )
         elif model.model_path is not None:  # pragma: no cover
-            model.model = onnx.ModelProto() # clean memory for large model
+            model.model = onnx.ModelProto()  # clean memory for large model
             sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
         else:  # pragma: no cover
             logger.warning("Please use model path instead of onnx model object to quantize")
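The hunk above is mostly line-wrapping, but it sits around the branch that matters for large models: when model.is_large_model is true, the serialized-bytes path is skipped and the InferenceSession is built from model.model_path, because a model over the 2GB protobuf cap cannot be serialized in memory. A minimal sketch of that decision, assuming a standalone helper (create_session is a hypothetical name, not part of neural_compressor):

# Illustrative sketch of the small-vs-large branch above; not the adaptor's actual code.
from typing import Optional

import onnx
import onnxruntime as ort


def create_session(model: onnx.ModelProto, model_path: Optional[str], is_large_model: bool):
    sess_options = ort.SessionOptions()
    if not is_large_model:
        # Models under the 2GB protobuf cap can be serialized and passed as bytes.
        return ort.InferenceSession(
            model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
        )
    if model_path is not None:
        # A model over 2GB cannot be serialized into a single protobuf message,
        # so the session has to be created from the on-disk path instead.
        return ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
    raise ValueError("Please use a model path instead of an onnx model object for models over 2GB")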
neural_compressor/adaptor/ox_utils/util.py (2 changes: 2 additions & 0 deletions)
@@ -85,6 +85,8 @@
     "DmlExecutionProvider": "onnxrt_dml_ep",
 }
 
+MAXIMUM_PROTOBUF = 2147483648
+
 
 def dtype_to_name(dtype_mapping, dtype):
     """Map data type and its string representation."""
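The new MAXIMUM_PROTOBUF constant names the limit that check_large_model previously hard-coded as a bare literal: 2,147,483,648 bytes is exactly 2**31, i.e. 2 GiB, the maximum size of a single serialized protobuf message. A quick sanity check of the arithmetic:

# Sanity check of the constant added above: 2147483648 bytes == 2**31 == 2 GiB.
MAXIMUM_PROTOBUF = 2147483648

assert MAXIMUM_PROTOBUF == 2**31 == 2 * 1024**3
print(MAXIMUM_PROTOBUF / 1024**3)  # 2.0 (GiB)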
neural_compressor/model/onnx_model.py (15 changes: 12 additions & 3 deletions)
@@ -21,6 +21,7 @@
 import sys
 from pathlib import Path
 
+from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF
 from neural_compressor.model.base_model import BaseModel
 from neural_compressor.utils.utility import LazyImport
 
@@ -59,13 +60,21 @@ def __init__(self, model, **kwargs):
         self._graph_info = {}
         self._get_graph_info()
         self._q_config = None
 
     def check_large_model(self):
         """Check model > 2GB."""
         init_size = 0
         for init in self._model.graph.initializer:
-            init_size += sys.getsizeof(init.SerializeToString())
-            if init_size > 2147483648:
+            # if size of the initializer > 2GB, return True
+            try:
+                init_bytes = init.SerializeToString()
+                init_size += sys.getsizeof(init_bytes)
+            except Exception as e:
+                if "exceeds maximum protobuf size of 2GB" in str(e):
+                    return True
+                else:
+                    raise e
+            if init_size > MAXIMUM_PROTOBUF:
                 return True
         return False
 
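The enhancement is the try/except: if a single initializer is already larger than 2GB, SerializeToString() raises instead of returning bytes, so the previous version of the check could fail before the size comparison ever ran. The new version treats that error as confirmation that the model is large, and otherwise compares the accumulated size against MAXIMUM_PROTOBUF. A usage sketch (the "model.onnx" path is a placeholder):

# Usage sketch for check_large_model; "model.onnx" is a placeholder path.
import onnx

from neural_compressor.model.onnx_model import ONNXModel

model = ONNXModel(onnx.load("model.onnx"))
if model.check_large_model():
    print("Model exceeds the 2GB protobuf limit; quantize it via its file path.")
else:
    print("Model fits within the 2GB protobuf limit.")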
test/model/test_onnx_model.py (6 changes: 2 additions & 4 deletions)
@@ -422,12 +422,11 @@ def __init__(self, in_features, out_features):
             def forward(self, x):
                 x = self.fc(x)
                 return x
 
         # model > 2GB
         model = Net(512, 1024 * 1024)
         input = torch.randn(512, requires_grad=True)
         with torch.no_grad():
-            torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13)
+            torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
         model = onnx.load("model.onnx")
         model = ONNXModel(model)
         self.assertTrue(model.check_large_model())
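For context on why Net(512, 1024 * 1024) trips the check, assuming Net wraps a single torch.nn.Linear(in_features, out_features): the weight matrix alone holds 512 * 1,048,576 float32 values, which already reaches the 2 GiB cap before the bias is counted.

# Rough size estimate for the large test model above (assumes Net is one nn.Linear).
weight_bytes = 512 * (1024 * 1024) * 4  # 536,870,912 float32 weights -> 2,147,483,648 bytes (2 GiB)
bias_bytes = (1024 * 1024) * 4          # 1,048,576 float32 biases    ->     4,194,304 bytes
print(weight_bytes + bias_bytes)        # 2,151,677,952 bytes, just over the 2 GiB protobuf limit

The small case in the next hunk, Net(10, 10 * 10), produces only a few kilobytes of initializers, so check_large_model() is expected to return False there.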
@@ -436,12 +435,11 @@ def forward(self, x):
         model = Net(10, 10 * 10)
         input = torch.randn(10, requires_grad=True)
         with torch.no_grad():
-            torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13)
+            torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
         model = onnx.load("model.onnx")
         model = ONNXModel(model)
         self.assertFalse(model.check_large_model())
 
 
-
 if __name__ == "__main__":
     unittest.main()
