Skip to content

Commit

Permalink
[FX] Changes done internally at Facebook (#1603)
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Wei <[email protected]>
  • Loading branch information
frank-wei and Wei Wei authored Jan 22, 2023
1 parent df65620 commit 2e21ce6
Show file tree
Hide file tree
Showing 31 changed files with 1,277 additions and 207 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ commands:
parameters:
torch-build:
type: string
default: "2.0.0.dev20230103+cu117"
default: "2.0.0.dev20230120+cu117"
torch-build-index:
type: string
default: "https://download.pytorch.org/whl/nightly/cu117"
Expand Down Expand Up @@ -992,7 +992,7 @@ parameters:
# Nightly platform config
torch-build:
type: string
default: "2.0.0.dev20230103+cu117"
default: "2.0.0.dev20230120+cu117"
torch-build-index:
type: string
default: "https://download.pytorch.org/whl/nightly/cu117"
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/fx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ FX2TRT is merged as FX module in Torch-TensorRT

* Method 1. Follow the instrucions for Torch-TensorRT
* Method 2. To install FX path only (Python path) and avoid the C++ build for torchscript path
```
`
$ conda create --name python_env python=3.8
$ conda activate python_env
# Recommend to install PyTorch 1.12 and later
Expand All @@ -18,4 +18,4 @@ FX2TRT is merged as FX module in Torch-TensorRT
$ pyton -c "import torch_tensorrt.fx"
# Test an example by
$ python py/torch_tensorrt/fx/example/lower_example.py
```
`
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2802,7 +2802,7 @@ def acc_ops_linear(

if isinstance(kwargs["weight"], torch.Tensor):
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
if target is not acc_ops.linear:
if target not in (acc_ops.linear, torch.ops.aten.linear):
weight_op = trt.MatrixOperation.TRANSPOSE
else:
weight_op = trt.MatrixOperation.NONE
Expand Down
53 changes: 39 additions & 14 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,27 +187,20 @@ def aten_ops_fmod(
return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.mm.default)
@tensorrt_converter(torch.ops.aten.addmm.default)
@tensorrt_converter(torch.ops.aten.linear)
def aten_ops_linear(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if target == torch.ops.aten.addmm.default:
kwargs_new = {
"bias": args[0],
"input": args[1],
"weight": args[2],
}
elif target == torch.ops.aten.mm.default:
kwargs_new = {
"bias": None,
"input": args[0],
"weight": args[1],
}
kwargs_new = {
"input": args[0],
"weight": args[1],
"bias": args[2],
}

return acc_ops_converters.acc_ops_linear(network, target, None, kwargs_new, name)


Expand Down Expand Up @@ -320,3 +313,35 @@ def aten_ops_reshape(
"acc_out_ty": acc_utils.build_raw_tensor_meta(shape=args[1]),
}
return acc_ops_converters.acc_ops_reshape(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.cat.default)
def aten_ops_cat(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"tensors": args[0],
"dim": args[1],
}
return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.expand.default)
def aten_ops_expand(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"input": args[0],
"sizes": args[1],
}
return acc_ops_converters.acc_ops_expand_tensor(
network, target, None, kwargs_new, name
)
7 changes: 6 additions & 1 deletion py/torch_tensorrt/fx/fx2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,12 @@ def call_method(self, target, args, kwargs):

def output(self, target, args, kwargs):
assert len(args) == 1
outputs = args[0] if isinstance(args[0], tuple) else (args[0],)
if isinstance(args[0], tuple):
outputs = args[0]
elif isinstance(args[0], list):
outputs = tuple(args[0])
else:
outputs = (args[0],)

if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
raise RuntimeError("TensorRT requires all outputs to be Tensor!")
Expand Down
5 changes: 2 additions & 3 deletions py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import dataclasses as dc
import logging
import dataclasses as dc
import logging
from typing import Any, Callable, Optional, Sequence

# @manual=//deeplearning/trt/python:py_tensorrt
Expand Down Expand Up @@ -180,8 +178,9 @@ def lower_pass(
interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
if lower_setting.use_experimental_rt:
import io
from torch_tensorrt._TRTModuleNext import TRTModuleNext

from torch_tensorrt._Device import Device
from torch_tensorrt._TRTModuleNext import TRTModuleNext

with io.BytesIO() as engine_bytes:
engine_bytes.write(interp_res.engine.serialize())
Expand Down
16 changes: 10 additions & 6 deletions py/torch_tensorrt/fx/lower_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class LowerSetting(LowerSettingBasic):
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
preset_lowerer (str): when specified, use a preset logic to build the
instance of Lowerer.
opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is
only used by explicit batch dim with dynamic shape mode.
only used by explicit batch dim with dynamic shape mode. In general, we use 2 GPU setting with
2 stream on each. Set total number to 8 as a safe default value.
dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.
tactic_sources: tactic sources for TensorRT kernel selection. Default to None,
meaning all possible tactic sources.
Expand All @@ -81,17 +81,21 @@ class LowerSetting(LowerSettingBasic):
explicit_precision: bool = False
max_workspace_size: int = 1 << 30
strict_type_constraints: bool = False
customized_fuse_pass: PassManager = PassManager.build_from_passlist([])
lower_basic_fuse_pass: PassManager = PassManager.build_from_passlist(
[fuse_permute_matmul, fuse_permute_linear]
customized_fuse_pass: PassManager = dc.field(
default_factory=lambda: PassManager.build_from_passlist([])
)
lower_basic_fuse_pass: PassManager = dc.field(
default_factory=lambda: PassManager.build_from_passlist(
[fuse_permute_matmul, fuse_permute_linear]
)
)
verbose_log: bool = False
algo_selector = None
timing_cache_prefix: str = ""
save_timing_cache: bool = False
cuda_graph_batch_size: int = -1
preset_lowerer: str = ""
opt_profile_replica: int = 1
opt_profile_replica: int = 8
dynamic_batch: bool = True
tactic_sources: Optional[int] = None
correctness_atol: float = 0.1
Expand Down
22 changes: 22 additions & 0 deletions py/torch_tensorrt/fx/passes/lower_basic_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,25 @@ def _get_shape(node: fx.Node) -> Optional[torch.Size]:
# shape info not available
return None
return node.meta["tensor_meta"].shape


@log_before_after
@validate_inference(atol=1e-3, rtol=1e-2)
def fix_clamp_numerical_limits_to_fp16(
mod: torch.fx.GraphModule, input: Input
) -> torch.fx.GraphModule:
MIN_FP16 = -65504.0
MAX_FP16 = 65504.0
for node in mod.graph.nodes:
if node.op == "call_function" and "clamp" in str(node.target):
input_kwargs = node.kwargs
if input_kwargs["min"] < MIN_FP16 and input_kwargs["max"] > MAX_FP16:
new_kwargs = {
"input": input_kwargs["input"],
"min": MIN_FP16,
"max": MAX_FP16,
}
node.kwargs = new_kwargs

mod.recompile()
return mod
Loading

0 comments on commit 2e21ce6

Please sign in to comment.