Skip to content

Commit

Permalink
update check for large model
Browse files Browse the repository at this point in the history
Signed-off-by: yuwenzho <[email protected]>
  • Loading branch information
yuwenzho committed Sep 21, 2023
1 parent c050880 commit c5c9229
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
5 changes: 4 additions & 1 deletion neural_compressor/model/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,17 @@ def check_large_model(self):
"""Check model > 2GB."""
init_size = 0
for init in self._model.graph.initializer:
# if initializer has external data location, return True
if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
return True
# if serializing the initializer raises the "exceeds maximum protobuf size of 2GB" error, return True
try:
init_bytes = init.SerializeToString()
init_size += sys.getsizeof(init_bytes)
except Exception as e:
if "exceeds maximum protobuf size of 2GB" in str(e):
return True
else:
else: # pragma: no cover
raise e
if init_size > MAXIMUM_PROTOBUF:
return True
Expand Down
19 changes: 16 additions & 3 deletions test/model/test_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,6 @@ def test_check_large_model(self):
import onnx
import torch
import torch.nn as nn

from neural_compressor.model.onnx_model import ONNXModel

class Net(nn.Module):
Expand All @@ -423,13 +422,19 @@ def __init__(self, in_features, out_features):
def forward(self, x):
x = self.fc(x)
return x

# model > 2GB
model = Net(512, 1024 * 1024)
input = torch.randn(512, requires_grad=True)
with torch.no_grad():
torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
model = onnx.load("model.onnx")
model = ONNXModel(model) # pass ModelProto
self.assertTrue(model.check_large_model())

model = ONNXModel("model.onnx") # pass string
self.assertTrue(model.check_large_model())

model = onnx.load("model.onnx", load_external_data=False) # not load init
model = ONNXModel(model)
self.assertTrue(model.check_large_model())

Expand All @@ -439,9 +444,17 @@ def forward(self, x):
with torch.no_grad():
torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
model = onnx.load("model.onnx")
model = ONNXModel(model)
model = ONNXModel(model) # pass ModelProto
self.assertFalse(model.check_large_model())

model = ONNXModel("model.onnx") # pass string
self.assertFalse(model.check_large_model())

model = ONNXModel("model.onnx", load_external_data_for_model=False) # not load init
self.assertFalse(model.check_large_model())




if __name__ == "__main__":
unittest.main()

0 comments on commit c5c9229

Please sign in to comment.