Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory consumption in ONNXRT adaptor #1266

Merged
merged 8 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,13 +712,7 @@ def _detect_domain(self, model):
# 2. according to input
# typically, NLP models have multiple inputs,
# and the dimension of each input is usually 2 (batch_size, max_seq_len)
if not model.is_large_model:
sess = ort.InferenceSession(model.model.SerializeToString(), providers=["CPUExecutionProvider"])
elif model.model_path is not None: # pragma: no cover
sess = ort.InferenceSession(model.model_path, providers=["CPUExecutionProvider"])
else: # pragma: no cover
assert False, "Please use model path instead of onnx model object to quantize."
input_shape_lens = [len(input.shape) for input in sess.get_inputs()]
input_shape_lens = [len(inp.type.tensor_type.shape.dim) for inp in model.model.graph.input]
if len(input_shape_lens) > 1 and all(shape_len == 2 for shape_len in input_shape_lens):
is_nlp = True

Expand Down Expand Up @@ -778,11 +772,15 @@ def _pre_optimize(self, model, level=1):

sess_options.register_custom_ops_library(get_library_path())
if not model.is_large_model:
ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"])
sess = ort.InferenceSession(
model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
)
elif model.model_path is not None: # pragma: no cover
ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
model.model = onnx.ModelProto() # clean memory for large model
sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
else: # pragma: no cover
logger.warning("Please use model path instead of onnx model object to quantize")
del sess

tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False)

Expand Down
2 changes: 2 additions & 0 deletions neural_compressor/adaptor/ox_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
"DmlExecutionProvider": "onnxrt_dml_ep",
}

MAXIMUM_PROTOBUF = 2147483648


def dtype_to_name(dtype_mapping, dtype):
"""Map data type and its string representation."""
Expand Down
35 changes: 25 additions & 10 deletions neural_compressor/model/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import logging
import os
import sys
from pathlib import Path

from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF
from neural_compressor.model.base_model import BaseModel
from neural_compressor.utils.utility import LazyImport

Expand All @@ -41,16 +43,9 @@ def __init__(self, model, **kwargs):
"""
self._model = model if not isinstance(model, str) else onnx.load(model)
self._model_path = None if not isinstance(model, str) else model
self._is_large_model = False
try:
ort.InferenceSession(self._model.SerializeToString(), providers=["CPUExecutionProvider"])
except Exception as e: # pragma: no cover
if self._model_path is not None:
ort.InferenceSession(self._model_path, providers=["CPUExecutionProvider"])
self._is_large_model = True
else:
logger.warning("Please use model path instead of onnx model object to quantize")

self._is_large_model = self.check_large_model()
if self._is_large_model and self._model_path is None:
logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize")
self._config = None
if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()):
from transformers import PretrainedConfig
Expand All @@ -66,6 +61,26 @@ def __init__(self, model, **kwargs):
self._get_graph_info()
self._q_config = None

def check_large_model(self):
"""Check model > 2GB."""
init_size = 0
for init in self._model.graph.initializer:
# if initializer has external data location, return True
if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
return True
# if raise error of initializer size > 2GB, return True
try:
init_bytes = init.SerializeToString()
mengniwang95 marked this conversation as resolved.
Show resolved Hide resolved
init_size += sys.getsizeof(init_bytes)
except Exception as e:
if "exceeds maximum protobuf size of 2GB" in str(e):
return True
else: # pragma: no cover
raise e
if init_size > MAXIMUM_PROTOBUF:
return True
return False

@property
def is_large_model(self):
"""Check the onnx model is over 2GB."""
Expand Down
48 changes: 48 additions & 0 deletions test/model/test_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def setUp(self):
def tearDownClass(self):
shutil.rmtree("./gptj", ignore_errors=True)
shutil.rmtree("./hf_test", ignore_errors=True)
os.remove("model.onnx")

def test_hf_model(self):
from optimum.onnxruntime import ORTModelForCausalLM
Expand Down Expand Up @@ -407,6 +408,53 @@ def test_remove_unused_nodes(self):
self.model.remove_unused_nodes()
self.assertEqual(len(self.model.nodes()), 6)

def test_check_large_model(self):
import onnx
import torch
import torch.nn as nn

from neural_compressor.model.onnx_model import ONNXModel

class Net(nn.Module):
def __init__(self, in_features, out_features):
super(Net, self).__init__()
self.fc = nn.Linear(in_features, out_features)

def forward(self, x):
x = self.fc(x)
return x

# model > 2GB
model = Net(512, 1024 * 1024)
input = torch.randn(512, requires_grad=True)
with torch.no_grad():
torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
model = onnx.load("model.onnx")
model = ONNXModel(model) # pass ModelProto
self.assertTrue(model.check_large_model())

model = ONNXModel("model.onnx") # pass string
self.assertTrue(model.check_large_model())

model = onnx.load("model.onnx", load_external_data=False) # not load init
model = ONNXModel(model)
self.assertTrue(model.check_large_model())

# model < 2GB
model = Net(10, 10 * 10)
input = torch.randn(10, requires_grad=True)
with torch.no_grad():
torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
model = onnx.load("model.onnx")
model = ONNXModel(model) # pass ModelProto
self.assertFalse(model.check_large_model())

model = ONNXModel("model.onnx") # pass string
self.assertFalse(model.check_large_model())

model = ONNXModel("model.onnx", load_external_data_for_model=False) # not load init
self.assertFalse(model.check_large_model())


if __name__ == "__main__":
unittest.main()