From 12f545c1fc8266bbec06391a1220d55742b57b57 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Fri, 28 Apr 2023 13:00:59 -0700
Subject: [PATCH] refactor: Reorganizing to reduce code duplication and
 separating TRT implementation, example changes with ReLU

Signed-off-by: Naren Dasan
---
 .../fx/converters/acc_ops_converters.py       | 76 +++++++++++----
 py/torch_tensorrt/fx/converters/activation.py | 39 --------
 .../fx/converters/aten_ops_converters.py      | 13 ++-
 .../fx/converters/converter_utils.py          | 77 +++++++---------
 .../fx/converters/impl/__init__.py            |  0
 .../fx/converters/impl/activation.py          | 92 +++++++++++++++++++
 .../fx/converters/nn_ops_converters.py        | 24 +++++
 7 files changed, 217 insertions(+), 104 deletions(-)
 create mode 100644 py/torch_tensorrt/fx/converters/impl/__init__.py
 create mode 100644 py/torch_tensorrt/fx/converters/impl/activation.py
 create mode 100644 py/torch_tensorrt/fx/converters/nn_ops_converters.py

diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index 51b5d899eb..eb62d16379 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -26,6 +26,7 @@
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
+from torch_tensorrt.fx.converters.impl import activation
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -1004,9 +1005,14 @@ def acc_ops_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.RELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+    )
 
 
 @tensorrt_converter(acc_ops.leaky_relu)
@@ -1020,8 +1026,14 @@ def acc_ops_leaky_relu(
     input_val = kwargs["input"]
     negative_slope = kwargs["negative_slope"]
     operation_type = trt.ActivationType.LEAKY_RELU
-    return add_activation_layer(
-        network, input_val, operation_type, target, name, negative_slope
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+        alpha=negative_slope,
     )
 
 
@@ -1036,7 +1048,9 @@ def acc_ops_elu(
     input_val = kwargs["input"]
     alpha = kwargs["alpha"]
     operation_type = trt.ActivationType.ELU
-    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
+    return activation.convert_activation(
+        network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
+    )
 
 
 @tensorrt_converter(acc_ops.selu)
@@ -1049,7 +1063,14 @@ def acc_ops_selu(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.SELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )
 
 
 @tensorrt_converter(acc_ops.softsign)
@@ -1062,7 +1083,14 @@ def acc_ops_softsign(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.SOFTSIGN
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )
 
 
 @tensorrt_converter(acc_ops.sin)
@@ -1140,7 +1168,14 @@ def acc_ops_tanh(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.TANH
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )
 
 
 @tensorrt_converter(acc_ops.asin)
@@ -3137,12 +3172,13 @@ def acc_ops_hard_sigmoid(
             "of the TensorRT region!"
         )
 
-    return add_activation_layer(
+    return activation.convert_activation(
         network,
-        input_val,
-        trt.ActivationType.HARD_SIGMOID,
         target,
+        SourceIR.ACC,
         name,
+        trt.ActivationType.HARD_SIGMOID,
+        input_val,
         alpha=1 / 6,
         beta=0.5,
     )
@@ -3164,8 +3200,13 @@ def acc_ops_sigmoid(
             "of the TensorRT region!"
         )
 
-    return add_activation_layer(
-        network, input_val, trt.ActivationType.SIGMOID, target, name
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        trt.ActivationType.SIGMOID,
+        input_val,
     )
 
 
@@ -3557,12 +3598,13 @@ def acc_ops_hardtanh(
             "of the TensorRT region!"
         )
 
-    return add_activation_layer(
+    return activation.convert_activation(
         network,
-        input_val,
-        trt.ActivationType.CLIP,
         target,
+        SourceIR.ACC,
        name,
+        trt.ActivationType.CLIP,
+        input_val,
         alpha=kwargs["min_val"],
         beta=kwargs["max_val"],
     )
diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py
index a7ab25152c..4efddab5cd 100644
--- a/py/torch_tensorrt/fx/converters/activation.py
+++ b/py/torch_tensorrt/fx/converters/activation.py
@@ -9,45 +9,6 @@
 from .converter_utils import mark_as_int8_layer
 
 
-def common_activation(
-    network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
-):
-    layer = network.add_activation(input=input_val, type=activation_type)
-    layer.name = layer_name
-
-    if input_val.dynamic_range:
-        dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
-
-    return layer.get_output(0)
-
-
-@tensorrt_converter(torch.nn.functional.relu)
-@tensorrt_converter(torch.nn.modules.activation.ReLU)
-def relu(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"ReLU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def activation_dyn_range_fn(dyn_range):
-        return max(0, dyn_range[0]), max(0, dyn_range[1])
-
-    return common_activation(
-        network,
-        submod,
-        input_val,
-        trt.ActivationType.RELU,
-        activation_dyn_range_fn,
-        layer_name,
-    )
-
-
 @tensorrt_converter(torch.nn.modules.activation.Sigmoid)
 def sigmoid(network, submod, args, kwargs, layer_name):
     # args/kwargs should have already been normalized to kwargs
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index c86f2bd228..803db8b68c 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -22,6 +22,7 @@
 from .converter_utils import *  # noqa: F403
 
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+from torch_tensorrt.fx.converters.impl import activation
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -290,10 +291,14 @@ def aten_ops_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
 
 
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 17a0cef456..d13be41d05 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from enum import Enum, auto
 
 import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
@@ -22,6 +23,26 @@
 from ..utils import torch_dtype_from_trt
 
 
+class SourceIR(Enum):
+    NN = auto()
+    ACC = auto()
+    ATEN = auto()
+    PRIM = auto()
+    UNKNOWN = auto()
+
+    def __str__(self):
+        if self == SourceIR.NN:
+            return "nn"
+        elif self == SourceIR.ACC:
+            return "acc"
+        elif self == SourceIR.ATEN:
+            return "aten"
+        elif self == SourceIR.PRIM:
+            return "prim"
+        else:
+            return "unknown_ir"
+
+
 def get_trt_plugin(
     plugin_name: str,
     field_collection: List[TRTPluginFieldCollection],
@@ -77,7 +98,9 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
     return dim
 
 
-def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
+def set_layer_name(
+    layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
+) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with
     Suffix]"
@@ -86,8 +109,16 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
         target (Target): A fx node.target. For call_function node, it's the function that
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
+        source_ir (Optional[SourceIR]): The IR producing the op.
""" - target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" + + source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN + + target_name = ( + f"{source_ir}_ops.{target}" + if isinstance(target, str) + else f"{source_ir}_ops.{target.__name__}" + ) layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]" @@ -560,48 +591,6 @@ def add_unary_layer( return layer.get_output(0) -def add_activation_layer( - network: TRTNetwork, - input_val: TRTTensor, - operation_type: trt.ActivationType, - target: Target, - name: str, - alpha: Optional[Any] = None, - beta: Optional[Any] = None, -) -> TRTTensor: - """ - Add a TensorRT Activation layer to `network`. - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): Input to the activation op. - Must be a TensorRT tensor. - op_type (trt.ElementWiseOperation): Type of the TensorRT activation - operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - alpha (Optional[Any]): If not None, we will use it to set the alpha - attribute of the created TensorRT activation layer. - beta (Optional[Any]): If not None, we will use it to set the beta - attribute of the created TensorRT activation layer. - - Returns: - The output of TensorRT Activation layer. - """ - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"{operation_type} received input {input_val} that is not part " - "of the TensorRT region!" - ) - layer = network.add_activation(input_val, operation_type) - if alpha is not None: - layer.alpha = alpha - if beta is not None: - layer.beta = beta - set_layer_name(layer, target, name) - return layer.get_output(0) - - def add_reduce_layer( network: TRTNetwork, target: Target, diff --git a/py/torch_tensorrt/fx/converters/impl/__init__.py b/py/torch_tensorrt/fx/converters/impl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/fx/converters/impl/activation.py b/py/torch_tensorrt/fx/converters/impl/activation.py new file mode 100644 index 0000000000..ef7057b79f --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/activation.py @@ -0,0 +1,92 @@ +import numpy as np +import operator +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Argument, Target + + +from torch_tensorrt.fx.converters.converter_utils import mark_as_int8_layer +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.converters.converter_utils import SourceIR + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + + +def convert_activation( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + operation_type: trt.ActivationType, + input_val: TRTTensor, + alpha: Optional[Any] = None, + beta: Optional[Any] = None, + dyn_range_fn: Optional[Callable[[float, float], Any]] = None, +) -> TRTTensor: + """ + Add a TensorRT Activation layer to `network`. + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): Target of fx node. + source_ir (Optional[SourceIR]): Type of IR calling the converter + operation_type (trt.ElementWiseOperation): Type of the TensorRT activation operation. + name (str): The name we want to assign to the created TensorRT layer. + input_val (TRTTensor): Input to the activation op. + Must be a TensorRT tensor. 
+        alpha (Optional[Any]): If not None, we will use it to set the alpha
+            attribute of the created TensorRT activation layer.
+        beta (Optional[Any]): If not None, we will use it to set the beta
+            attribute of the created TensorRT activation layer.
+        dyn_range_fn (Optional[Callable[[float, float], Any]]): A function which takes the dynamic range of a TensorRT Tensor and returns the output dynamic range
+
+
+    Returns:
+        The output of TensorRT Activation layer.
+    """
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    if alpha is not None:
+        layer.alpha = alpha
+    if beta is not None:
+        layer.beta = beta
+    set_layer_name(layer, target, name, source_ir)
+
+    if input_val.dynamic_range is not None and dyn_range_fn is not None:
+        dyn_range = dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)
+    return layer.get_output(0)
+
+
+def relu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+):
+    operation_type = trt.ActivationType.RELU
+
+    def relu_dyn_range_fn(dyn_range):
+        return max(0, dyn_range[0]), max(0, dyn_range[1])
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=relu_dyn_range_fn,
+    )
diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py
new file mode 100644
index 0000000000..551a4f368e
--- /dev/null
+++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py
@@ -0,0 +1,24 @@
+import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+
+from torch_tensorrt.fx.converter_registry import tensorrt_converter
+from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.converter_utils import SourceIR
+
+
+@tensorrt_converter(torch.nn.functional.relu)
+@tensorrt_converter(torch.nn.modules.activation.ReLU)
+def relu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.relu(
+        network=network,
+        target="torch.nn.functional.relu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
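Follow-up sketch (not part of the patch above): the remaining activations could be migrated to the same pattern as the relu() helper. Assuming the existing imports of impl/activation.py plus "import math", a hypothetical tanh helper might look like the following; since tanh is monotonic, the dynamic-range function simply applies tanh to both ends of the input range, mirroring relu_dyn_range_fn.

import math

def tanh(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
):
    # Hypothetical port of the tanh activation to the new impl module,
    # following the relu() helper introduced in this patch.
    operation_type = trt.ActivationType.TANH

    def tanh_dyn_range_fn(dyn_range):
        # tanh is monotonic, so the output range is tanh of the input range.
        return math.tanh(dyn_range[0]), math.tanh(dyn_range[1])

    return convert_activation(
        network,
        target,
        source_ir,
        name,
        operation_type,
        input_val,
        dyn_range_fn=tanh_dyn_range_fn,
    )

The corresponding acc/aten/nn converters would then call activation.tanh(...) the same way acc_ops_relu, aten_ops_relu, and the nn relu converter call activation.relu(...) above.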