diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py
index ad031a29c96..dc9976e20cc 100644
--- a/neural_compressor/model/onnx_model.py
+++ b/neural_compressor/model/onnx_model.py
@@ -65,6 +65,9 @@ def check_large_model(self):
         """Check model > 2GB."""
         init_size = 0
         for init in self._model.graph.initializer:
+            # if initializer has external data location, return True
+            if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
+                return True
             # if raise error of initializer size > 2GB, return True
             try:
                 init_bytes = init.SerializeToString()
@@ -72,7 +75,7 @@
             except Exception as e:
                 if "exceeds maximum protobuf size of 2GB" in str(e):
                     return True
-                else:
+                else:  # pragma: no cover
                     raise e
         if init_size > MAXIMUM_PROTOBUF:
             return True
diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py
index 56532e3c6dd..5de077cb5f8 100644
--- a/test/model/test_onnx_model.py
+++ b/test/model/test_onnx_model.py
@@ -412,7 +412,6 @@ def test_check_large_model(self):
         import onnx
         import torch
         import torch.nn as nn
-
         from neural_compressor.model.onnx_model import ONNXModel
 
         class Net(nn.Module):
@@ -423,13 +422,19 @@ def __init__(self, in_features, out_features):
             def forward(self, x):
                 x = self.fc(x)
                 return x
-
         # model > 2GB
         model = Net(512, 1024 * 1024)
         input = torch.randn(512, requires_grad=True)
         with torch.no_grad():
             torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
         model = onnx.load("model.onnx")
+        model = ONNXModel(model)  # pass ModelProto
+        self.assertTrue(model.check_large_model())
+
+        model = ONNXModel("model.onnx")  # pass string
+        self.assertTrue(model.check_large_model())
+
+        model = onnx.load("model.onnx", load_external_data=False)  # not load init
         model = ONNXModel(model)
         self.assertTrue(model.check_large_model())
 
@@ -439,9 +444,17 @@ def forward(self, x):
         with torch.no_grad():
             torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
         model = onnx.load("model.onnx")
-        model = ONNXModel(model)
+        model = ONNXModel(model)  # pass ModelProto
+        self.assertFalse(model.check_large_model())
+
+        model = ONNXModel("model.onnx")  # pass string
         self.assertFalse(model.check_large_model())
 
+        model = ONNXModel("model.onnx", load_external_data_for_model=False)  # not load init
+        self.assertFalse(model.check_large_model())
+
+
 if __name__ == "__main__":
     unittest.main()