diff --git a/_static/img/cat_resized.jpg b/_static/img/cat_resized.jpg
new file mode 100644
index 0000000000..c7746e6530
Binary files /dev/null and b/_static/img/cat_resized.jpg differ
diff --git a/advanced_source/super_resolution_with_onnxruntime.py b/advanced_source/super_resolution_with_onnxruntime.py
index 3f4dd43969..ecb0ba4fe4 100644
--- a/advanced_source/super_resolution_with_onnxruntime.py
+++ b/advanced_source/super_resolution_with_onnxruntime.py
@@ -107,7 +107,7 @@ def _initialize_weights(self):
 
 # Load pretrained model weights
 model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
-batch_size = 1    # just a random number
+batch_size = 64   # a larger batch makes the ONNX Runtime speedup easier to see
 
 # Initialize model with the pretrained weights
 map_location = lambda storage, loc: storage
@@ -218,6 +218,32 @@ def to_numpy(tensor):
 # ONNX exporter, so please contact us in that case.
 #
 
+######################################################################
+# Timing Comparison Between Models
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+
+######################################################################
+# Because ONNX Runtime optimizes the exported graph for inference, running
+# the same data through the ONNX model instead of the native PyTorch model
+# should be up to 2x faster, and the improvement grows with the batch size.
+
+
+import time
+
+x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
+
+start = time.time()
+torch_out = torch_model(x)
+end = time.time()
+print(f"Inference of PyTorch model used {end - start} seconds")
+
+ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+start = time.time()
+ort_outs = ort_session.run(None, ort_inputs)
+end = time.time()
+print(f"Inference of ONNX model used {end - start} seconds")
+
 
 ######################################################################
 # Running the model on an image using ONNX Runtime
@@ -301,10 +327,20 @@ def to_numpy(tensor):
 # Save the image, we will compare this with the output image from mobile device
 final_img.save("./_static/img/cat_superres_with_ort.jpg")
 
+# Save resized original image (without super-resolution)
+img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
+img.save("cat_resized.jpg")
 
 ######################################################################
+# Here is the comparison between the two images:
+#
+# .. figure:: /_static/img/cat_resized.jpg
+#
+#    Low-resolution image
+#
 # .. figure:: /_static/img/cat_superres_with_ort.jpg
-#    :alt: output\_cat
+#
+#    Image after super-resolution
 #
 #
 # ONNX Runtime being a cross platform engine, you can run it across
@@ -313,7 +349,7 @@ def to_numpy(tensor):
 # ONNX Runtime can also be deployed to the cloud for model inferencing
 # using Azure Machine Learning Services. More information `here `__.
 #
-# More information about ONNX Runtime's performance `here `__.
+# More information about ONNX Runtime's performance `here `__.
 #
 #
 # For more information about ONNX Runtime `here `__.
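
A possible follow-up to the timing block added above: a single forward pass is a noisy measurement, and building the input with requires_grad=True makes the PyTorch side pay autograd bookkeeping that pure inference would not. The sketch below is one way to average over warmed-up runs; it reuses the tutorial's torch_model, ort_session, to_numpy, and batch_size names, and the warm-up and iteration counts are arbitrary choices, not part of the patch.

import time

import torch


def average_seconds(fn, arg, warmup=3, iters=10):
    # A few warm-up calls absorb one-time costs before measuring
    for _ in range(warmup):
        fn(arg)
    start = time.time()
    for _ in range(iters):
        fn(arg)
    return (time.time() - start) / iters


x = torch.randn(batch_size, 1, 224, 224)  # inference input; no gradients needed

with torch.no_grad():  # skip autograd bookkeeping for the PyTorch baseline
    torch_seconds = average_seconds(torch_model, x)

ort_seconds = average_seconds(
    lambda t: ort_session.run(None, {ort_session.get_inputs()[0].name: to_numpy(t)}),
    x,
)

print(f"PyTorch: {torch_seconds:.4f} s/iter, ONNX Runtime: {ort_seconds:.4f} s/iter")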
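
Similarly, for the comparison figures at the end of the patch, the two saved images can be stitched into a single side-by-side picture. This is a sketch under the tutorial's names (the resized img and the super-resolved final_img); the output path is a hypothetical choice, not part of the patch.

from PIL import Image

# Paste the low-resolution and super-resolved images onto one shared canvas
canvas = Image.new("RGB", (img.width + final_img.width, max(img.height, final_img.height)))
canvas.paste(img, (0, 0))
canvas.paste(final_img, (img.width, 0))
canvas.save("./_static/img/cat_comparison.jpg")  # hypothetical output path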