Skip to content

Commit

Permalink
Merge pull request #706 from NVIDIA/update-py-readme
Browse files Browse the repository at this point in the history
Update py/README.md
  • Loading branch information
narendasan authored Nov 11, 2021
2 parents 6a4daef + cbcd63a commit 9eae269
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions py/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,24 @@ Torch-TensorRT is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via
## Example Usage

``` python
import torch
import torchvision
import torch_tensorrt

# Get a model
model = torchvision.models.alexnet(pretrained=True).eval().cuda()
...

# Create some example data
data = torch.randn((1, 3, 224, 224)).to("cuda")
trt_ts_module = torch_tensorrt.compile(torch_script_module,
inputs = [example_tensor, # Provide example tensor for input shape or...
torch_tensorrt.Input( # Specify input object with shape and dtype
min_shape=[1, 3, 224, 224],
opt_shape=[1, 3, 512, 512],
max_shape=[1, 3, 1024, 1024],
# For static size shape=[1, 3, 224, 224]
dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
],
enabled_precisions = {torch.half}, # Run with FP16)

# Trace the module with example data
traced_model = torch.jit.trace(model, [data])
result = trt_ts_module(input_data) # run inference
torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript

# Compile module
compiled_trt_model = torch_tensorrt.compile(traced_model, {
"inputs": [torch_tensorrt.Input(data.shape)],
"enabled_precisions": {torch.float, torch.half}, # Run with FP16
})

results = compiled_trt_model(data.half())
```

## Installation
Expand Down

0 comments on commit 9eae269

Please sign in to comment.