Changes done internally at Facebook (#1178)
6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.prod
c822345d6d673e1653c2208435e34ab400bada3d Jason Park <[email protected]> Add support for generic torch ops to be used in training.

Co-authored-by: wwei6 <[email protected]>
Wei and wwei6 authored Jul 13, 2022
1 parent 2fd564e commit 5ad9826
Showing 7 changed files with 178 additions and 42 deletions.
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/__init__.py
@@ -6,5 +6,6 @@
     tensorrt_converter,
 )
-from .input_tensor_spec import InputTensorSpec  # noqa
+from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
+from .input_tensor_spec import generate_input_specs, InputTensorSpec  # noqa
 from .lower_setting import LowerSetting  # noqa
 from .trt_module import TRTModule  # noqa
76 changes: 71 additions & 5 deletions py/torch_tensorrt/fx/input_tensor_spec.py
@@ -1,11 +1,66 @@
-from typing import Iterable, List, NamedTuple, Sequence, Tuple
+from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
 
 import torch
 
 from .types import Shape, ShapeRange
 from .utils import get_dynamic_dims
 
 
+def generate_input_specs(
+    inputs, lower_setting, additional_inputs=None, fixed_shape=False
+):
+    # The AIT lower setting doesn't have an explicit_batch_dimension field,
+    # so we just return None.
+    if not hasattr(lower_setting, "explicit_batch_dimension"):
+        return None
+
+    if not lower_setting.explicit_batch_dimension or fixed_shape:
+        return InputTensorSpec.from_tensors(inputs)
+
+    # If we don't have additional inputs, we assume the first dimension
+    # is the dynamic batch dimension. Otherwise, we use the additional
+    # inputs to determine the batch dimension.
+    if additional_inputs is None:
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+        )
+    else:
+        batch_dims = []
+
+        for i, j in zip(inputs, additional_inputs):
+            found_batch_dim = False
+
+            for idx, values in enumerate(zip(i.shape, j.shape)):
+                if values[0] != values[1]:
+                    assert (
+                        found_batch_dim is False
+                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    batch_dims.append(idx)
+                    found_batch_dim = True
+
+            if not found_batch_dim:
+                raise RuntimeError(
+                    f"Failed to find a batch dimension because the shapes are identical, {i.shape}"
+                )
+
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+            batch_dims,
+        )
+
+
 class InputTensorSpec(NamedTuple):
     """
     This class contains the information of an input tensor.
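A minimal usage sketch of the new helper (the tensor shapes and LowerSetting values are hypothetical; the LowerSetting construction mirrors the test added at the bottom of this diff):

    import torch
    from torch_tensorrt.fx import generate_input_specs, LowerSetting

    lower_setting = LowerSetting(
        explicit_batch_dimension=True, max_batch_size=64, opt_profile_replica=1
    )

    # The two input sets differ only in dim 1, so dim 1 is inferred as
    # the dynamic batch dimension.
    inputs = [torch.randn(8, 2, 4)]
    additional_inputs = [torch.randn(8, 3, 4)]

    specs = generate_input_specs(inputs, lower_setting, additional_inputs)
    print(specs[0].shape)         # (8, -1, 4)
    print(specs[0].shape_ranges)  # [((8, 0, 4), (8, 64, 4), (8, 64, 4))]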
@@ -70,6 +125,7 @@ def from_tensors_with_dynamic_batch_size(
         tensors: Sequence[torch.Tensor],
         batch_size_range: Tuple[int, int, int],
         opt_profile_replica: int = 1,
+        batch_dims: Optional[List[int]] = None,
     ) -> List["InputTensorSpec"]:
         """
         Produce a list of InputTensorSpec named tuples which contain
@@ -83,20 +139,30 @@ def from_tensors_with_dynamic_batch_size(
                 the smallest batch size allowed. The second integer indicates
                 the batch size that we'll optimize for. The third integer indicates
                 the largest batch size allowed.
+            opt_profile_replica (int): If dynamic shape is enabled, each execution
+                context requires a different optimization profile. This arg determines
+                how many optimization profile replicas we want to produce.
+            batch_dims (Optional[List[int]]): The batch dim might not be the leading dim,
+                so this arg lets the user specify a batch dim per tensor. By default
+                we treat dim 0 as the batch dim.
         Returns:
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
+        if batch_dims is None:
+            batch_dims = [0] * len(tensors)
+
         input_specs = []
-        batch_size = tensors[0].size(0)
+        batch_size = tensors[0].size(batch_dims[0])
 
         for i, tensor in enumerate(tensors):
+            batch_dim = batch_dims[i]
             assert batch_size == tensor.size(
-                0
+                batch_dim
             ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
             shape = list(tensor.shape)
-            shape[0] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+            shape[batch_dim] = -1
+            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
             input_specs.append(
                 cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
            )
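For illustration, what the new batch_dims argument produces (hypothetical shapes; batch size 4 lives on dim 0 of the first tensor and dim 1 of the second):

    import torch
    from torch_tensorrt.fx import InputTensorSpec

    tensors = [torch.randn(4, 2, 3), torch.randn(2, 4, 3)]
    specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
        tensors, (1, 4, 8), batch_dims=[0, 1]
    )
    print(specs[0].shape)         # (-1, 2, 3)
    print(specs[1].shape)         # (2, -1, 3)
    print(specs[1].shape_ranges)  # [((2, 1, 3), (2, 4, 3), (2, 8, 3))]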
38 changes: 9 additions & 29 deletions py/torch_tensorrt/fx/lower.py
@@ -1,6 +1,6 @@
 import dataclasses as dc
 import logging
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
@@ -10,15 +10,9 @@
 from torch.fx.passes.splitter_base import SplitResult
 
 from .fx2trt import TRTInterpreter, TRTInterpreterResult
-from .input_tensor_spec import InputTensorSpec
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
-from .passes.pass_utils import (
-    chain_passes,
-    decorate_method,
-    PassFunc,
-    validate_inference,
-)
+from .passes.pass_utils import decorate_method, PassFunc, validate_inference
 from .tools.timing_cache_utils import TimingCacheManager
 from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting
 
@@ -91,25 +85,8 @@ def create(cls, lower_setting):
         return LowerTrtInterpreter(lower_setting, timing_cache_manager)
 
     def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
-        input_specs_val = (
-            self.lower_setting.input_specs
-            if self.lower_setting.input_specs
-            else (
-                InputTensorSpec.from_tensors_with_dynamic_batch_size(
-                    input,
-                    (
-                        0,
-                        self.lower_setting.max_batch_size,
-                        self.lower_setting.max_batch_size,
-                    ),
-                    self.lower_setting.opt_profile_replica,
-                )
-                if self.lower_setting.explicit_batch_dimension
-                and self.lower_setting.dynamic_batch
-                else InputTensorSpec.from_tensors(input)
-            )
-        )
-        logger.info(f"{split_name=} {input_specs_val=}")
+        assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
+        logger.info(f"{split_name=} {self.lower_setting.input_specs=}")
 
         # Prepare algorithm selector and timing_cache for TRTInterpreter
         algo_selector = None
@@ -125,7 +102,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
 
         interpreter = TRTInterpreter(
             mod,
-            input_specs=input_specs_val,
+            input_specs=self.lower_setting.input_specs,
             explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
             explicit_precision=self.lower_setting.explicit_precision,
             logger_level=trt.Logger.VERBOSE
@@ -242,6 +219,7 @@ def __call__(
         self,
         module: nn.Module,
         inputs: Input,
+        additional_inputs: Optional[Input] = None,
     ) -> nn.Module:
         module.eval()
 
@@ -254,7 +232,9 @@
                 x.half() if x is not None and x.dtype == torch.float32 else x
                 for x in inputs
             )
-        pm = self.lower_pass_manager_builder.build_lower_pipeline(inputs)
+        pm = self.lower_pass_manager_builder.build_lower_pipeline(
+            inputs, additional_inputs
+        )
 
         lower_result = pm(module)
 
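With this plumbing in place, a caller can hand the lowerer a second set of sample inputs whose batch size differs from the first, which is what drives the batch-dim inference above. A sketch (the model and shapes are hypothetical, and it assumes the Lowerer class defined in this file and its create() factory):

    import torch
    from torch_tensorrt.fx.lower import Lowerer
    from torch_tensorrt.fx.lower_setting import LowerSetting

    lower_setting = LowerSetting(explicit_batch_dimension=True, max_batch_size=32)
    lowerer = Lowerer.create(lower_setting=lower_setting)

    model = MyModel().cuda().eval()  # hypothetical nn.Module
    inputs = [torch.randn(4, 3, 224, 224, device="cuda")]
    additional_inputs = [torch.randn(8, 3, 224, 224, device="cuda")]  # differs on dim 0

    trt_model = lowerer(model, inputs, additional_inputs)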
31 changes: 28 additions & 3 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -1,11 +1,13 @@
 from functools import partial, wraps
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence
 
 import torch
 from torch import nn
 from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
 from torch.fx.passes.shape_prop import ShapeProp
-from torch.fx.passes.splitter_base import SplitResult
+from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
 
+from ..input_tensor_spec import generate_input_specs
+
 from ..lower_setting import LowerSetting
 from ..observer import Observer
@@ -120,13 +122,33 @@ def _split_pass(self) -> PassManager:
 
     def _lower_pass(self) -> PassManager:
         def lower_func(split_result: SplitResult) -> nn.Module:
+            if (
+                hasattr(self.lower_setting, "explicit_batch_dimension")
+                and self.lower_setting.explicit_batch_dimension
+                and self._additional_input
+            ):
+                additional_submodule_inputs = generate_inputs_for_submodules(
+                    split_result.split_module,
+                    self._additional_input,
+                    list(split_result.submodule_inputs.keys()),
+                )
+            else:
+                additional_submodule_inputs = None
+
             for submod_name, submod_inputs in split_result.submodule_inputs.items():
                 submod = getattr(split_result.split_module, submod_name)
 
                 LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
 
                 # Only acc submodules will be lowered.
                 if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    self.lower_setting.input_specs = generate_input_specs(
+                        submod_inputs,
+                        self.lower_setting,
+                        additional_submodule_inputs[submod_name]
+                        if additional_submodule_inputs
+                        else None,
+                    )
                     lowered_module = self._lower_func(
                         submod, submod_inputs, self.lower_setting, submod_name
                     )
@@ -139,8 +161,11 @@ def lower_func(split_result: SplitResult) -> nn.Module:
 
         return PassManager.build_from_passlist([lower_func])
 
-    def build_lower_pipeline(self, input: Input) -> PassManager:
+    def build_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
         self._input = input
+        self._additional_input = additional_input
         passes = []
 
         passes.append(self._const_fold_pass())
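generate_inputs_for_submodules (from torch.fx.passes.splitter_base) runs the split module once and records the inputs each named submodule receives, which is how the additional inputs get mapped onto every acc submodule above. Conceptually it behaves like this hook-based sketch (a simplification for illustration, not the actual implementation):

    import torch

    def capture_submodule_inputs(model, inputs, target_names):
        # Record the positional args each target submodule receives
        # during a single forward run.
        captured, handles = {}, []
        for name, submod in model.named_modules():
            if name in target_names:
                def make_hook(key):
                    def hook(module, args):
                        captured[key] = args
                    return hook
                handles.append(submod.register_forward_pre_hook(make_hook(name)))
        try:
            model(*inputs)
        finally:
            for handle in handles:
                handle.remove()
        return captured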
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/passes/pass_utils.py
@@ -41,10 +41,13 @@ def _validate_inference(pass_: PassFunc) -> PassFunc:
 
     @wraps(pass_)
     def pass_with_validation(
-        module: fx.GraphModule, input: Input
+        module: fx.GraphModule,
+        input: Input,
+        *args,
+        **kwargs,
     ) -> fx.GraphModule:
         res0 = module(*input)
-        processed_module = pass_(module, input)
+        processed_module = pass_(module, input, *args, **kwargs)
         res1 = processed_module(*input)
 
         tensor_res_0 = _collect_tensors(res0)
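The widened signature means a pass decorated for inference validation can now accept extra arguments, which the wrapper forwards untouched. A hypothetical sketch (assuming the validate_inference decorator exported from this module takes rtol/atol tolerances):

    from torch import fx
    from torch_tensorrt.fx.passes.pass_utils import validate_inference

    @validate_inference(atol=1e-3, rtol=1e-2)
    def my_pass(module: fx.GraphModule, input, extra_setting=None) -> fx.GraphModule:
        # ... transform the graph here; `extra_setting` now flows through
        # the validation wrapper instead of breaking its call signature.
        return module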
22 changes: 21 additions & 1 deletion py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
 
 # NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples
 
@@ -93,6 +93,26 @@ def forward(self, x):
             test_implicit_batch_dim=False,
         )
 
+    def test_prod_all_dims_with_dynamic_shape(
+        self,
+        op=torch.prod,
+    ):
+        class Prod(torch.nn.Module):
+            def forward(self, x):
+                return op(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Prod(), input_specs, expected_ops={acc_ops.prod}
+        )
+
 
 if __name__ == "__main__":
     run_tests()
43 changes: 42 additions & 1 deletion py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py
@@ -4,7 +4,7 @@
 
 import torch
 from torch.testing._internal.common_utils import run_tests, TestCase
-from torch_tensorrt.fx import InputTensorSpec
+from torch_tensorrt.fx import generate_input_specs, InputTensorSpec, LowerSetting
 
 
 class TestTRTModule(TestCase):
@@ -47,6 +47,47 @@ def test_from_tensors_with_dynamic_batch_size(self):
             self.assertEqual(batch_size, shape[0])
             self.assertSequenceEqual(tensor.shape[1:], shape[1:])
 
+    def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
+        tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
+        batch_size_range = [2, 3, 4]
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            tensors, batch_size_range, batch_dims=[0, 1]
+        )
+        for i, spec_and_tensor in enumerate(zip(specs, tensors)):
+            spec, tensor = spec_and_tensor
+            self._validate_spec(spec, tensor, dynamic_dims=[i])
+
+            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
+                self.assertEqual(batch_size, shape[i])
+                tensor_shape = list(tensor.shape)
+                tensor_shape[i] = batch_size
+                self.assertSequenceEqual(tensor_shape, shape)
+
+    def test_generate_input_specs(self):
+        lower_setting = LowerSetting(
+            explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
+        )
+
+        # Implicit batch dim.
+        inputs = [torch.randn(1, 2, 3)]
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor)
+
+        # Explicit batch dim without additional inputs.
+        lower_setting.explicit_batch_dimension = True
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[0])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+
+        # Explicit batch dim with additional inputs.
+        additional_inputs = [torch.randn(1, 1, 3)]
+        specs = generate_input_specs(inputs, lower_setting, additional_inputs)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[1])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+
 
 if __name__ == "__main__":
     run_tests()
