Support >2G model export | torchlib(feat) (#1003)
Support >2 GB model export by caching the model to disk when necessary.

Tested locally with `test_save_initializer_to_files_for_large_model`

Fixes #493
justinchuby authored Aug 11, 2023
1 parent b7d2939 commit d9b64c5
Showing 2 changed files with 68 additions and 27 deletions.
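
For context on the change below: protobuf cannot serialize a single message larger than 2 GB, so models above that size must keep their tensor payloads in external data files next to the `.onnx` file. A minimal sketch of reloading such a model, assuming a hypothetical `exported_model.onnx` whose external data files sit in the same directory:

```python
import onnx

# Hypothetical path for illustration; not produced by this commit.
model_path = "exported_model.onnx"

# Load the graph structure without pulling tensor payloads into memory,
# then attach the external data files that live next to the model file.
model = onnx.load(model_path, load_external_data=False)
onnx.load_external_data_for_model(model, base_dir=".")
```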
81 changes: 64 additions & 17 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -3,6 +3,7 @@

import logging
import os
import tempfile
import typing
import warnings
from typing import Any, Dict, Final, List, Mapping, Optional, Sequence, Tuple, Union
@@ -63,6 +64,9 @@
None,
]

# Be sure to leave ample room for the rest of the proto fields.
_LARGE_MODEL_SIZE_THRESHOLD = int(2**30 * 1.8) # 1.8GB

# TODO(justinchuby): Build a context manager to handle source information.


@@ -342,6 +346,18 @@ def _create_op_call_in_torch_graph(
return node_ouputs


def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
"""Estimate the size of a tensor in bytes.
Args:
tensor: The tensor to estimate the size of.
Returns:
The estimated size of the tensor in bytes.
"""
return tensor.numel() * tensor.element_size()


class TorchScriptGraph:
_LOCAL_FUNCTION_DOMAIN_NAME: Final[str] = "torch_export"
"""The domain name for local functions."""
@@ -683,31 +699,58 @@ def to_model_proto(
# TODO(BowenBao): All local function domain versions are hardcoded as 1.
unique_custom_domains[function_proto.domain] = 1

(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
initializers_size = sum(
_tensor_rawdata_size(tensor) for tensor in self.initializers.values()
)

large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD

export_kwargs: dict[str, Any] = dict(
initializers=self.initializers if include_initializers else {},
onnx_opset_version=opset_version,
# TODO(justinchuby): Figure out how to get the dynamic axes from the inputs
dynamic_axes={},
defer_weight_export=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
strip_doc_string=False,
keep_initializers_as_inputs=False,
custom_opsets={},
add_node_names=True,
# TODO(#493): Passing in this instead of reading from env.
# User must put the exported model file in the same folder to launch ORT.
onnx_file_path=os.path.join(
os.getenv("EXTERNAL_ONNX_INITIALIZER_FOLDER", ""), "dummy_model_path.onnx"
),
node_attr_to_name={},
)

onnx_model = onnx.load_from_string(proto)
# We decided to cache the model to disk when the model is large.
# Alternatively, we could build the ONNX `TensorProto`s in memory
# and append them to the model proto.
# We did not do it because it is harder to get right (vs. PyTorch's battle-tested
# implementation) and creating the `TensorProto`s naively (by converting to numpy)
# is slow.
cache_model_to_disk = include_initializers and large_model

if cache_model_to_disk:
with tempfile.TemporaryDirectory() as temp_dir:
onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
export_kwargs["onnx_file_path"] = onnx_file_path
(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
**export_kwargs
)
onnx_model = onnx.load_from_string(proto)
onnx.load_external_data_for_model(onnx_model, temp_dir)
else:
(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
**export_kwargs
)
onnx_model = onnx.load_from_string(proto)

onnx_model.functions.extend(function_proto_dict.values())

# `_export_onnx` only exports opset_imports that is visible to it. It does not
Expand All @@ -725,10 +768,14 @@ def to_model_proto(
)

try:
onnx_model = onnx.shape_inference.infer_shapes(
onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
if not cache_model_to_disk:
# Only check the model if it is in memory.
# Otherwise the checker and shape_inference will fail because
# we cannot serialize the model.
onnx_model = onnx.shape_inference.infer_shapes(
onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
logging.debug(
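
A side note on the checker skip in the hunk above: the in-memory checker and shape inference fail once the proto cannot be serialized, but both have path-based variants that read external data from disk. A sketch, assuming a hypothetical `exported_model.onnx` saved with external data:

```python
import onnx

# Path-based checking and shape inference avoid serializing the full proto
# in memory, so they also work for models beyond the 2 GB protobuf limit.
onnx.checker.check_model("exported_model.onnx", full_check=True)
onnx.shape_inference.infer_shapes_path("exported_model.onnx", "inferred_model.onnx")
```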
14 changes: 4 additions & 10 deletions onnxscript/function_libs/torch_lib/graph_building_test.py
@@ -3,15 +3,14 @@
from __future__ import annotations

import os
import tempfile
import unittest

import torch

import onnxscript
import onnxscript.testing
from onnxscript import FLOAT, evaluator
from onnxscript import opset17 as op
from onnxscript import opset18 as op
from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib import graph_building, ops

@@ -165,14 +164,9 @@ def forward(self, x):
model = MLP(input_size, hidden_size, output_size)
x = torch.randn(batch_size, input_size)

with tempfile.TemporaryDirectory() as temp_dir:
os.environ["EXTERNAL_ONNX_INITIALIZER_FOLDER"] = temp_dir
torch.onnx.dynamo_export(
model,
x,
)
# 3 initializers are saved to files as external data.
self.assertEqual(len(os.listdir(temp_dir)), 3)
model_proto = torch.onnx.dynamo_export(model, x).model_proto
# Assert model is larger than 2GB (~=3GB)
self.assertGreater(model_proto.ByteSize(), 2**31)


if __name__ == "__main__":
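
For a rough sense of the sizes in the updated test (illustrative numbers, not taken from the test itself): raw initializer size is estimated as `numel() * element_size()`, matching `_tensor_rawdata_size` above, so a single float32 weight of 750 million elements already accounts for about 3 GB:

```python
import torch

# A 25_000 x 30_000 float32 matrix holds 750_000_000 elements at 4 bytes
# each, i.e. 3_000_000_000 bytes, above the ~1.8 GiB large-model threshold.
# The "meta" device keeps the tensor unallocated; the shape math still works.
weight = torch.empty(25_000, 30_000, dtype=torch.float32, device="meta")
print(weight.numel() * weight.element_size())  # 3000000000
```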
