Skip to content

Commit

Permalink
update UT of check large model
Browse files Browse the repository at this point in the history
Signed-off-by: yuwenzho <[email protected]>
  • Loading branch information
yuwenzho committed Sep 19, 2023
1 parent 3becf8a commit 6b72726
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions test/model/test_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def setUp(self):
def tearDownClass(self):
shutil.rmtree("./gptj", ignore_errors=True)
shutil.rmtree("./hf_test", ignore_errors=True)
os.remove("model.onnx")

def test_hf_model(self):
from optimum.onnxruntime import ORTModelForCausalLM
Expand Down Expand Up @@ -407,6 +408,40 @@ def test_remove_unused_nodes(self):
self.model.remove_unused_nodes()
self.assertEqual(len(self.model.nodes()), 6)

def test_check_large_model(self):
import onnx
import torch
import torch.nn as nn
from neural_compressor.model.onnx_model import ONNXModel

class Net(nn.Module):
def __init__(self, in_features, out_features):
super(Net, self).__init__()
self.fc = nn.Linear(in_features, out_features)

def forward(self, x):
x = self.fc(x)
return x

# model > 2GB
model = Net(512, 1024 * 1024)
input = torch.randn(512, requires_grad=True)
with torch.no_grad():
torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13)
model = onnx.load("model.onnx")
model = ONNXModel(model)
self.assertTrue(model.check_large_model())

# model < 2GB
model = Net(10, 10 * 10)
input = torch.randn(10, requires_grad=True)
with torch.no_grad():
torch.onnx.export(model, (input, ), 'model.onnx', do_constant_folding=True, opset_version=13)
model = onnx.load("model.onnx")
model = ONNXModel(model)
self.assertFalse(model.check_large_model())



if __name__ == "__main__":
unittest.main()

0 comments on commit 6b72726

Please sign in to comment.