From cbcd63aa15039b37edbb74c9763aa2df1b5a1cfb Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 10 Nov 2021 23:57:38 -0500 Subject: [PATCH] Update README.md Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- py/README.md | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/py/README.md b/py/README.md index 77c0461a67..3f6ee91ca1 100644 --- a/py/README.md +++ b/py/README.md @@ -7,26 +7,24 @@ Torch-TensorRT is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via ## Example Usage ``` python -import torch -import torchvision import torch_tensorrt -# Get a model -model = torchvision.models.alexnet(pretrained=True).eval().cuda() +... -# Create some example data -data = torch.randn((1, 3, 224, 224)).to("cuda") +trt_ts_module = torch_tensorrt.compile(torch_script_module, + inputs = [example_tensor, # Provide example tensor for input shape or... + torch_tensorrt.Input( # Specify input object with shape and dtype + min_shape=[1, 3, 224, 224], + opt_shape=[1, 3, 512, 512], + max_shape=[1, 3, 1024, 1024], + # For static size shape=[1, 3, 224, 224] + dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool) + ], + enabled_precisions = {torch.half}, # Run with FP16) -# Trace the module with example data -traced_model = torch.jit.trace(model, [data]) +result = trt_ts_module(input_data) # run inference +torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript -# Compile module -compiled_trt_model = torch_tensorrt.compile(traced_model, { - "inputs": [torch_tensorrt.Input(data.shape)], - "enabled_precisions": {torch.float, torch.half}, # Run with FP16 -}) - -results = compiled_trt_model(data.half()) ``` ## Installation