Exporting a Model from PyTorch to ONNX tutorial (#2935)
* added timing comparison
---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
diningeachox and svekars authored Jun 18, 2024
1 parent ef9750a commit be898cb
Showing 2 changed files with 39 additions and 3 deletions.
Binary file added _static/img/cat_resized.jpg
42 changes: 39 additions & 3 deletions advanced_source/super_resolution_with_onnxruntime.py
@@ -107,7 +107,7 @@ def _initialize_weights(self):

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
-batch_size = 1    # just a random number
+batch_size = 64   # just a random number

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
@@ -218,6 +218,32 @@ def to_numpy(tensor):
# ONNX exporter, so please contact us in that case.
#

######################################################################
# Timing Comparison Between Models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

######################################################################
# Because ONNX Runtime optimizes the model graph for inference, running the
# same data through the exported ONNX model instead of the native PyTorch
# model should yield a speedup of up to 2x. The improvement is more
# pronounced at higher batch sizes.


import time

x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

start = time.time()
torch_out = torch_model(x)
end = time.time()
print(f"Inference of Pytorch model used {end - start} seconds")

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
start = time.time()
ort_outs = ort_session.run(None, ort_inputs)
end = time.time()
print(f"Inference of ONNX model used {end - start} seconds")


######################################################################
# Running the model on an image using ONNX Runtime
@@ -301,10 +327,20 @@ def to_numpy(tensor):
# Save the image, we will compare this with the output image from mobile device
final_img.save("./_static/img/cat_superres_with_ort.jpg")

# Save resized original image (without super-resolution) for comparison.
# Note: PIL's ``Image.size`` is (width, height), while ``transforms.Resize``
# expects (height, width), so the dimensions are swapped here.
img = transforms.Resize([img_out_y.size[1], img_out_y.size[0]])(img)
img.save("./_static/img/cat_resized.jpg")

######################################################################
# Here is the comparison between the two images:
#
# .. figure:: /_static/img/cat_resized.jpg
#
# Low-resolution image
#
# .. figure:: /_static/img/cat_superres_with_ort.jpg
# :alt: output\_cat
#
# Image after super-resolution
#
#
# Since ONNX Runtime is a cross-platform engine, you can run it across
@@ -313,7 +349,7 @@ def to_numpy(tensor):
# ONNX Runtime can also be deployed to the cloud for model inferencing
# using Azure Machine Learning Services. More information `here <https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx>`__.
#
-# More information about ONNX Runtime's performance `here <https://github.com/microsoft/onnxruntime#high-performance>`__.
+# More information about ONNX Runtime's performance `here <https://onnxruntime.ai/docs/performance>`__.
#
#
# For more information about ONNX Runtime `here <https://github.com/microsoft/onnxruntime>`__.
