diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py index ac943c9c644..56532e3c6dd 100644 --- a/test/model/test_onnx_model.py +++ b/test/model/test_onnx_model.py @@ -412,6 +412,7 @@ def test_check_large_model(self): import onnx import torch import torch.nn as nn + from neural_compressor.model.onnx_model import ONNXModel class Net(nn.Module): @@ -422,6 +423,7 @@ def __init__(self, in_features, out_features): def forward(self, x): x = self.fc(x) return x + # model > 2GB model = Net(512, 1024 * 1024) input = torch.randn(512, requires_grad=True)