Added dynamic shape support for kwargs in dynamo.trace
cehongwang committed Jul 18, 2024
1 parent 08f2cbb commit 164b934
Showing 3 changed files with 708 additions and 249 deletions.
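
In short: dynamo.trace now derives a torch.export dynamic_shapes mapping for keyword inputs as well as positional ones. A minimal usage sketch, assuming trace() accepts the arg_inputs/kwarg_inputs keywords suggested by the diff below (the module and shapes are hypothetical):

import torch
from torch_tensorrt import Input
from torch_tensorrt.dynamo import trace

class ScaleByMask(torch.nn.Module):  # hypothetical example module
    def forward(self, x, mask=None):
        return x if mask is None else x * mask

# Dynamic batch dimension: dim 0 ranges from 1 to 16 across min/opt/max.
x_in = Input(min_shape=(1, 3, 32, 32), opt_shape=(8, 3, 32, 32),
             max_shape=(16, 3, 32, 32), dtype=torch.float32)
mask_in = Input(min_shape=(1, 3, 32, 32), opt_shape=(8, 3, 32, 32),
                max_shape=(16, 3, 32, 32), dtype=torch.float32)

# trace() keys the dynamic_shapes dict by forward()'s parameter names, so
# the kwarg below lines up with the "mask" parameter.
exp_program = trace(ScaleByMask().eval(), arg_inputs=(x_in,),
                    kwarg_inputs={"mask": mask_in})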
67 changes: 51 additions & 16 deletions py/torch_tensorrt/dynamo/_tracer.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Optional, Tuple
+from inspect import signature
+from typing import Any, Optional, Tuple, Union
 
 import torch
 from torch.export import Dim, export
@@ -76,14 +77,58 @@ def trace(
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)
     torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
-    dynamic_shapes = []
-    for input in arg_inputs:  # type: ignore
-        if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
+    dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
+    dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
+    # breakpoint()
+    exp_program = export(
+        mod,
+        tuple(torch_arg_inputs),
+        kwargs=torch_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+    )
+
+    return exp_program
+
+
+def get_dynamic_shapes_kwargs(inputs: Any) -> Union[dict[str, Any], list[Any]]:
+    if isinstance(inputs, dict):
+        dynamic_shapes_kwarg = {}
+        for k, v in inputs.items():
+            dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v)
+        return dynamic_shapes_kwarg
+
+    elif isinstance(inputs, Input):
+        return get_dynamic_shapes(inputs)
+
+    elif isinstance(inputs, (list, tuple)):
+        dynamic_shapes = []
+        for input in inputs:
+            dynamic_shapes.append(get_dynamic_shapes(input))
+        return dynamic_shapes
+
+    raise TypeError(f"Unknown type {type(inputs)}.")
+
+
+def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]:
+    # dynamic_shapes is a dict and cannot work without keys. Use the names of
+    # the positional arguments of the forward function as the keys.
+    args = list(signature(mod.forward).parameters.keys())
+    dynamic_shapes = {}
+    for input, input_name in zip(inputs, args[: len(inputs)]):
+        dynamic_shapes[input_name] = get_dynamic_shapes(input)
+    return dynamic_shapes
+
+
+def get_dynamic_shapes(input: Input) -> dict[Any, Any]:
+    if not isinstance(input, Input):
+        raise TypeError(f"Expected type torch_trt.Input, but got {type(input)}")
+    else:
+        dynamic_dims = {}
+        if input.shape_mode == Input._ShapeMode.DYNAMIC:
             min_shape = input.shape["min_shape"]
             opt_shape = input.shape["opt_shape"]
             max_shape = input.shape["max_shape"]
             assert len(min_shape) == len(opt_shape) == len(max_shape)
-            dynamic_dims = {}
             for dim in range(len(min_shape)):
                 if min_shape[dim] == opt_shape[dim] == max_shape[dim]:
                     continue
@@ -93,14 +138,4 @@ def trace(
                         min=min_shape[dim],
                         max=max_shape[dim],
                     )
-
-            dynamic_shapes.append(dynamic_dims)
-
-    exp_program = export(
-        mod,
-        tuple(torch_arg_inputs),
-        kwargs=torch_kwarg_inputs,
-        dynamic_shapes=tuple(dynamic_shapes),
-    )
-
-    return exp_program
+        return dynamic_dims
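
For reference, a sketch (not part of the commit; the Dim names are illustrative) of the mapping these helpers hand to torch.export.export:

from torch.export import Dim

# For forward(self, x, mask=None) with dim 0 ranging from 1 to 16,
# get_dynamic_shapes_args/get_dynamic_shapes_kwargs produce roughly:
dynamic_shapes = {
    "x": {0: Dim("x_dim_0", min=1, max=16)},      # keys = forward() arg names
    "mask": {0: Dim("mask_dim_0", min=1, max=16)},
}
# Static dims are omitted, so export() treats only dim 0 of each tensor as
# a symbolic dimension constrained to [1, 16].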