
Support >2G model export | torchlib(feat) #1003

Merged: 7 commits into main from justinchu/big-models on Aug 11, 2023

Conversation

@justinchuby (Collaborator) commented Aug 10, 2023:

Support >2G model export by caching the model to disk when necessary.

Tested locally with test_save_initializer_to_files_for_large_model

Fixes #493

cc @wschin
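
For context, a minimal sketch of the saving pattern this relies on (illustrative only; the output path, data file name, and threshold are assumptions, not the PR's actual code). Protobuf cannot serialize a single file larger than 2 GB, so a large in-memory ModelProto is written with its initializers stored as external data:

import onnx

# `onnx_model` is an in-memory ModelProto, which may exceed 2 GB; only a single
# serialized protobuf file is subject to the 2 GB limit.
onnx.save_model(
    onnx_model,
    "model.onnx",                  # illustrative output path
    save_as_external_data=True,    # keep the graph file under the protobuf limit
    all_tensors_to_one_file=True,  # write all initializers into one side file
    location="model.data",         # data file name, relative to the model file
    size_threshold=1024,           # externalize tensors larger than 1 KB
)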

onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
if not cache_model_to_disk:
justinchuby (Collaborator, Author):

In a follow-up PR we will remove these checks altogether, per a discussion with Aaron: we should not check the model here.

Returns:
The estimated size of the tensor in bytes.
"""
return tensor.numel() * tensor.element_size()
justinchuby (Collaborator, Author):

Thanks GPT

codecov bot commented Aug 10, 2023:

Codecov Report

Merging #1003 (db58031) into main (b7d2939) will decrease coverage by 0.03%.
The diff coverage is 52.17%.

@@            Coverage Diff             @@
##             main    #1003      +/-   ##
==========================================
- Coverage   77.23%   77.20%   -0.03%     
==========================================
  Files         112      112              
  Lines       14009    14022      +13     
  Branches     1447     1450       +3     
==========================================
+ Hits        10820    10826       +6     
- Misses       2828     2833       +5     
- Partials      361      363       +2     
Files Changed                                            Coverage          Δ
...ipt/function_libs/torch_lib/graph_building_test.py    78.99% <33.33%>   (+1.12%) ⬆️
...nxscript/function_libs/torch_lib/graph_building.py    81.74% <55.00%>   (-1.79%) ⬇️

github-actions bot commented Aug 10, 2023:

Test Results

18 files ±0   18 suites ±0   1h 7m 49s ⏱️ +2m 49s
10 249 tests ±0   7 476 ✔️ ±0   2 772 💤 -1   0 ±0   1 🔥 +1
153 023 runs +15 236   33 536 ✔️ +3 309   119 486 💤 +11 926   0 ±0   1 🔥 +1

For more details on these errors, see this check.

Results for commit db58031. ± Comparison against base commit b7d2939.

♻️ This comment has been updated with latest results.

@justinchuby added the "topic: torch_lib" label (Related to the torch/aten function lib in development) on Aug 10, 2023
**export_kwargs
)
onnx_model = onnx.load_from_string(proto)
onnx.load_external_data_for_model(onnx_model, temp_dir)
Contributor:

So the trick is that the 2 GB limitation only applies when serializing, not to the in-memory ModelProto.

Makes me think _export_onnx should return a ModelProto instead of the serialized string, but that is not supported by pybind. We could probably create a C++ pybind API that returns the ModelProto and the initializers as separate serialized strings, then deserialize and combine them on the Python side. Since only the Python API is exposed and used here, that would give us the ModelProto directly.

The benefit is that we skip both the size check and the write to disk. What do you think?
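
For illustration, a rough sketch of what that split could look like on the Python side (entirely hypothetical; graph_bytes and initializer_bytes_list are assumed return values, not an existing API). Each piece stays under 2 GB when serialized on its own, and they are recombined into one in-memory ModelProto, which protobuf does not size-limit:

import onnx

# Hypothetical inputs: the graph and each initializer serialized separately,
# each individually under the 2 GB protobuf limit.
onnx_model = onnx.load_from_string(graph_bytes)
for tensor_bytes in initializer_bytes_list:
    tensor_proto = onnx.TensorProto()
    tensor_proto.ParseFromString(tensor_bytes)
    onnx_model.graph.initializer.append(tensor_proto)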

justinchuby (Collaborator, Author):

Do we return a list of initializers?

Taking this further, we wouldn't even need to pass the initializers to _export_onnx. We could serialize them ourselves outside with onnx. That way we don't need to change the PyTorch C++ implementation.

@justinchuby (Collaborator, Author) commented Aug 10, 2023:

Tested with

from typing import Mapping

import onnx
import onnx.helper
import torch


def _add_initializers(model_proto: onnx.ModelProto, initializers: Mapping[str, torch.Tensor]):
    """Append the given tensors to the model's graph as initializers."""
    tensor_protos = []
    for name, tensor in initializers.items():
        tensor_numpy = tensor.detach().numpy()
        tensor_proto = onnx.helper.make_tensor(
            name=name,
            data_type=onnx.helper.np_dtype_to_tensor_dtype(tensor_numpy.dtype),
            dims=tensor_numpy.shape,
            vals=tensor_numpy,
        )
        tensor_protos.append(tensor_proto)
    model_proto.graph.initializer.extend(tensor_protos)

But onnx.helper.make_tensor is very slow.

I think returning a list of TensorProtos could work. But for now it seems to me that keeping the compatibility may be nice.

justinchuby (Collaborator, Author):

*compatibility with torch 2.0

justinchuby (Collaborator, Author):

Because we will move away from TorchScript eventually.

Contributor:

Let's try make_tensor(...torch_tensor.data_ptr().to_bytes(torch_tensor.element_size() * torch_tensor.numel(), byteorder=sys.byteorder), raw=True), and if that doesn't work or is still slow, we can go back to the initial solution.

Contributor:

Should be

import ctypes
torch_tensor = torch.tensor([2, 3])
num_bytes = torch_tensor.element_size() * torch_tensor.numel()
# Build a ctypes ubyte array over the tensor's memory and copy it out as bytes.
raw_data = bytes((ctypes.c_ubyte * num_bytes).from_address(torch_tensor.data_ptr()))
tensor_proto = make_tensor(..., vals=raw_data, raw=True)

I was misled by someone doing data_ptr().to_bytes(), but what that really does is convert the pointer integer itself to bytes... lol

I'm not sure if it is worth it; it looks hacky, but this should resemble what _export_onnx is doing on the C++ side. If that is still slow, then there is nothing more we can do.
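
As a point of comparison (a hedged alternative, not something settled in this thread), onnx.numpy_helper.from_array also writes the data as raw bytes via NumPy's tobytes(), which avoids both the element-wise path in make_tensor and the ctypes pointer handling; the tensor name below is illustrative:

import onnx.numpy_helper
import torch

torch_tensor = torch.tensor([2, 3])
# from_array copies the array's raw bytes directly into the TensorProto.
tensor_proto = onnx.numpy_helper.from_array(
    torch_tensor.detach().cpu().numpy(), name="weight"
)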


onnx.checker.check_model(onnx_model, full_check=True)
if not cache_model_to_disk:
# Only check the model if it is in memory.
# Otherwise the checker and shape_inference will fail because
Contributor:

For shape inference, can we still load the shape and element type from the model file (not the initializer files) and then run infer_shape?

justinchuby (Collaborator, Author):

We could, but we also don’t need to because PyTorch supplies all the shape info.
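
For reference, if shape inference were ever needed on a model cached to disk, onnx ships a path-based variant that works on the files directly instead of a single in-memory serialized proto; a minimal sketch with illustrative paths:

import onnx.shape_inference

# Path-based shape inference: reads the model file and writes the inferred model
# back to disk without round-tripping a single >2 GB serialized proto in Python.
onnx.shape_inference.infer_shapes_path("model.onnx", "model_inferred.onnx")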

Contributor:

A drawback: due to onnx/onnx#5487, we don't have much inner-node shape info left now that modules are functions.

_estimate_tensor_size(tensor) for tensor in self.initializers.values()
)

# Treat models > 1GB as large models so that we have ample room
Contributor:

Hmm, maybe increase it to 1.8 GB? I've never seen a model > 100 MB without initializers.

justinchuby (Collaborator, Author):

Done
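
For clarity, a hypothetical sketch of the threshold decision discussed above (the constant and function names are illustrative; _estimate_tensor_size is the helper shown earlier in the diff). The initializer sizes are summed and compared against 1.8 GB, leaving headroom under protobuf's 2 GB limit for the rest of the graph:

from typing import Mapping

import torch

_LARGE_MODEL_THRESHOLD = int(1.8 * 1024**3)  # 1.8 GB, leaving headroom under 2 GB


def _estimate_tensor_size(tensor: torch.Tensor) -> int:
    """Estimate the size of the tensor in bytes."""
    return tensor.numel() * tensor.element_size()


def _is_large_model(initializers: Mapping[str, torch.Tensor]) -> bool:
    # Cache the model to disk only when the initializers alone approach the limit.
    total = sum(_estimate_tensor_size(tensor) for tensor in initializers.values())
    return total > _LARGE_MODEL_THRESHOLD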

onnx_model = onnx.load_from_string(proto)
cache_model_to_disk = include_initializers and large_model

if cache_model_to_disk:
Contributor:

Whether or not to store initializers should be controlled by a user flag. Suppose I export a 1 GB model on a remote machine and want to visualize it locally; I really don't want to download its initializers over my home internet connection. If this flag can be turned on, I will be able to download just the structure of the model and debug faster.

@justinchuby (Collaborator, Author) commented Aug 10, 2023:

I think once the user gets the ModelProto, they can do whatever they want (e.g. remove all the data)? A user has full control once they get the dynamo export output as an object.

What's more, include_initializers is already an argument.
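
As an illustration of that point (a sketch, not part of this PR): once the ModelProto is in hand, dropping the weights for a structure-only view takes a couple of lines; the output path is illustrative:

import onnx

# `onnx_model` is the exported ModelProto (assumed available).
del onnx_model.graph.initializer[:]           # drop all weights, keep the graph structure
onnx.save(onnx_model, "structure_only.onnx")  # small file, convenient for visualization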

Contributor:

Agree with @wschin's goal and @justinchuby's explanation. Something to consider for ExportOutput.save or ExportOutputSerializer.

@BowenBao (Contributor):

Would be nice to mention the perf impact / comparison too.

@justinchuby (Collaborator, Author):

> Would be nice to mention the perf impact / comparison too.

Done

@justinchuby merged commit d9b64c5 into main on Aug 11, 2023
29 of 31 checks passed
@justinchuby deleted the justinchu/big-models branch on August 11, 2023 at 05:22
@titaiwangms self-requested a review on August 21, 2023 at 16:16
Labels: topic: torch_lib (Related to the torch/aten function lib in development)
Linked issue closed by this PR: TorchScript Graph: Support model > 2GB (#493)
3 participants