Skip to content

Commit

Permalink
Fixed folding constants and style check
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom committed Nov 15, 2024
1 parent 06bc7ab commit a76ddcb
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions nemo/export/tensorrt_lazy_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
lock_sm = threading.Lock()


# Map of TRT dtype -> Torch dtype
def trt_to_torch_dtype_dict():
"""
Map of TRT dtype -> Torch dtype
"""
return {
trt.int32: torch.int32,
trt.float32: torch.float32,
Expand Down Expand Up @@ -245,11 +247,16 @@ def infer(self, stream, use_cuda_graph=False):


def make_tensor(d):
    """
    Passes d through untouched when it is already a torch.Tensor;
    otherwise builds a tensor from d and moves it to the CUDA device.
    """
    if isinstance(d, torch.Tensor):
        # Existing tensors are returned as-is (same object, same device).
        return d
    return torch.tensor(d).cuda()


def unroll_input(input_names, input_example):
# Simulate list/tuple unrolling during ONNX export
"""
Simulates list/tuple unrolling during ONNX export
"""
unrolled_input = {}
for name in input_names:
val = input_example[name]
Expand Down Expand Up @@ -353,15 +360,15 @@ def __init__(
input_names: Optional list of input names. If None, will be read from the function signature.
output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list
of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
of their dimensions, like [[], [5], [-1]] for Tensor, list of 5 items and dynamic list.
export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.
build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.
input_profiles: Optional list of profiles for TRT builder and ONNX export.
Each profile is a map of the form: {"input id": [min_shape, opt_shape, max_shape], ...}.
dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be
dynamic_batchsize: A sequence with three elements to define the input batch size range for the model to be
converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH].
[note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine.
use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls!
[note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used.
use_cuda_graph: Use CUDA Graph for inference. Note: inputs have to be the same GPU memory between calls!
timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes).
fallback: Allow falling back to PyTorch when TRT inference fails (e.g., shapes exceed max profile).
"""
Expand Down Expand Up @@ -614,10 +621,10 @@ def add_profile(id, val):
**export_args,
)
if polygraphy_imported:
from polygraphy.backend.onnx.loader import fold_constants, onnx_from_path

fold_constants(onnx_from_path(onnx_path), size_threshold=16 * 1000 * 1000)
from polygraphy.backend.onnx.loader import fold_constants, onnx_from_path, save_onnx

onnx_model = fold_constants(onnx_from_path(onnx_path), size_threshold=16 * 1000 * 1000)
save_onnx(onnx_model, onnx_path)
self.logger.info("Export to ONNX successful.")
engine_bytes = self._onnx_to_trt(onnx_path)
if engine_bytes:
Expand Down

0 comments on commit a76ddcb

Please sign in to comment.