From 14dbbfd33a8e2a12361ce00759a12ebee140a749 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Mon, 30 Oct 2023 00:36:28 -0700 Subject: [PATCH] docs: Refactoring the docs for 2.1 Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- docsrc/conf.py | 15 +- docsrc/contributors/dynamo_converters.rst | 131 ++++++++ docsrc/contributors/fx_converters.rst | 211 ------------ ...iting_converters.rst => ts_converters.rst} | 6 +- .../{user_guide => dynamo}/dynamo_export.rst | 34 +- .../{user_guide => dynamo}/torch_compile.rst | 2 +- .../getting_started_with_fx_path.rst | 0 docsrc/getting_started/installation.rst | 9 +- docsrc/index.rst | 76 +++-- docsrc/py_api/dynamo.rst | 36 +++ docsrc/py_api/torch_tensorrt.rst | 1 + .../creating_torchscript_module_in_python.rst | 0 .../getting_started_with_cpp_api.rst | 4 +- .../getting_started_with_python_api.rst | 0 .../torchscript_frontend_from_pytorch.rst} | 2 +- docsrc/user_guide/dynamic_shapes.rst | 59 ++-- docsrc/user_guide/saving_models.rst | 23 +- py/torch_tensorrt/_compile.py | 2 +- py/torch_tensorrt/dynamo/__init__.py | 12 +- .../dynamo/{compile.py => _compiler.py} | 64 +++- .../dynamo/{export.py => _exporter.py} | 37 +++ .../dynamo/{aten_tracer.py => _tracer.py} | 58 +++- py/torch_tensorrt/dynamo/backend/backends.py | 2 +- ...rter_registry.py => _ConverterRegistry.py} | 4 +- .../dynamo/conversion/_TRTInterpreter.py | 16 +- .../dynamo/conversion/__init__.py | 7 +- .../{conversion.py => _conversion.py} | 7 +- .../dynamo/conversion/aten_ops_converters.py | 299 ++++++++++-------- .../dynamo/conversion/converter_utils.py | 5 +- .../dynamo/conversion/impl/cast.py | 2 +- .../dynamo/conversion/ops_evaluators.py | 12 +- .../dynamo/conversion/prims_ops_converters.py | 9 +- py/torch_tensorrt/dynamo/lowering/__init__.py | 4 + .../partitioning/_adjacency_partitioner.py | 4 +- .../partitioning/_global_partitioner.py | 4 +- .../dynamo/tools/opset_coverage.py | 5 +- py/torch_tensorrt/dynamo/utils.py | 4 +- 37 files changed, 663 insertions(+), 503 deletions(-) create mode 100644 docsrc/contributors/dynamo_converters.rst delete mode 100644 docsrc/contributors/fx_converters.rst rename docsrc/contributors/{writing_converters.rst => ts_converters.rst} (98%) rename docsrc/{user_guide => dynamo}/dynamo_export.rst (74%) rename docsrc/{user_guide => dynamo}/torch_compile.rst (99%) rename docsrc/{user_guide => fx}/getting_started_with_fx_path.rst (100%) create mode 100644 docsrc/py_api/dynamo.rst rename docsrc/{user_guide => ts}/creating_torchscript_module_in_python.rst (100%) rename docsrc/{getting_started => ts}/getting_started_with_cpp_api.rst (98%) rename docsrc/{getting_started => ts}/getting_started_with_python_api.rst (100%) rename docsrc/{user_guide/use_from_pytorch.rst => ts/torchscript_frontend_from_pytorch.rst} (97%) rename py/torch_tensorrt/dynamo/{compile.py => _compiler.py} (61%) rename py/torch_tensorrt/dynamo/{export.py => _exporter.py} (84%) rename py/torch_tensorrt/dynamo/{aten_tracer.py => _tracer.py} (50%) rename py/torch_tensorrt/dynamo/conversion/{converter_registry.py => _ConverterRegistry.py} (99%) rename py/torch_tensorrt/dynamo/conversion/{conversion.py => _conversion.py} (95%) diff --git a/docsrc/conf.py b/docsrc/conf.py index 0fd4acc2e0..87849f8e73 100644 --- a/docsrc/conf.py +++ b/docsrc/conf.py @@ -15,12 +15,12 @@ sys.path.append(os.path.join(os.path.dirname(__name__), "../py")) -import torch import pytorch_sphinx_theme +import torch import torch_tensorrt +from docutils import nodes from docutils.parsers.rst import Directive, directives from 
docutils.statemachine import StringList
-from docutils import nodes

 # -- Project information -----------------------------------------------------

@@ -175,12 +175,21 @@
 # Tell sphinx what the pygments highlight language should be.
 highlight_language = "cpp"

+autodoc_typehints_format = 'short'
+python_use_unqualified_type_names = True
+
+autodoc_type_aliases = {
+    'LegacyConverterImplSignature': 'LegacyConverterImplSignature',
+    'DynamoConverterImplSignature': 'DynamoConverterImplSignature',
+    'ConverterImplSignature': 'ConverterImplSignature',
+}
+
 # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
 # See http://stackoverflow.com/a/41184353/3343043

 from docutils import nodes
-from sphinx.util.docfields import TypedField
 from sphinx import addnodes
+from sphinx.util.docfields import TypedField


 def patched_make_field(self, types, domain, items, **kw):
diff --git a/docsrc/contributors/dynamo_converters.rst b/docsrc/contributors/dynamo_converters.rst
new file mode 100644
index 0000000000..3238d609f3
--- /dev/null
+++ b/docsrc/contributors/dynamo_converters.rst
@@ -0,0 +1,131 @@
+.. _dynamo_converters:
+
+Writing Dynamo Converters
+=============================
+The dynamo converter library in Torch-TensorRT is located in ``TensorRT/py/torch_tensorrt/dynamo/conversion``.
+
+Converter implementation
+------------------------
+
+Registration
+^^^^^^^^^^^^^^^^
+
+A converter is a function decorated with ``torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter`` that follows this signature:
+
+
+.. code-block:: python
+
+    @torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
+    def leaky_relu_converter(
+        ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]:
+
+
+
+The decorator takes a number of arguments:
+
+    * ``key``: Node target for which the converter is implemented (for example, ``torch.ops.aten.leaky_relu.default``)
+    * ``enabled``: Whether the converter should be enabled for use in the converter registry
+    * ``capability_validator``: A lambda that takes a ``torch.fx.Node`` and determines if the converter can properly handle this Node. If the validator returns ``False``, the subgraph partitioner will make sure this Node is run in PyTorch in the compiled graph.
+    * ``priority``: Allows developers to override existing converters in the converter registry
+
+Only the ``key`` is required.
+
+The function body is responsible for taking the current state of the network and adding the next subgraph to perform the op specified in the decorator with TensorRT operations.
+The function is provided the same arguments as the native PyTorch op would receive, with the addition that frozen Tensor attributes arrive as numpy arrays and the output Tensors of previous nodes arrive as TensorRT ITensors, corresponding to edges/output Tensors of intermediate operations in the graph.
+To determine the types expected as well as the return type of the converter, look at the definition of the op being converted. In the case of ``aten`` operations, this file is the source of truth: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+Since many converters a developer may write are a composition of lower-level operators, the ``torch_tensorrt.dynamo.conversion.impl`` subpackage contains many implementations of operations that can be chained to create a TensorRT subgraph, instead of needing to implement the converter in raw TensorRT.
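+
+For illustration, a complete converter that delegates to the implementation library could look like the following sketch (the import locations and the exact ``impl.activation.leaky_relu`` signature are adapted from the previous FX converter documentation and may differ slightly in practice):
+
+.. code-block:: python
+
+    from typing import Dict, Sequence, Tuple, Union
+
+    import tensorrt
+    import torch
+    import torch_tensorrt
+    from torch.fx.node import Argument, Target
+    from torch_tensorrt.dynamo import SourceIR
+    from torch_tensorrt.dynamo.conversion import impl
+
+    @torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
+    def aten_ops_leaky_relu(
+        ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]:
+        # args[0] is the input tensor and args[1] the negative_slope scalar;
+        # the shared implementation emits the corresponding TensorRT layers
+        return impl.activation.leaky_relu(
+            ctx, target, SourceIR.ATEN, name, args[0], args[1]
+        )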
+
+The arguments provided to the converter are as follows:
+
+ * ``ctx``: The current state of the compiler. Converters primarily manipulate ``ctx.net``, which is the ``tensorrt.INetworkDefinition`` being constructed. Additional metadata, including user-provided settings, is available in this struct as well.
+ * ``target``: Target key in the ``call_module`` or ``call_function`` node, e.g. ``torch.ops.aten.leaky_relu.default``. Note that ``torch.ops.aten.leaky_relu`` is the ``OpOverloadPacket`` while ``torch.ops.aten.leaky_relu.default`` is an ``OpOverload``.
+ * ``args``: The arguments being passed to a particular Node (as collected by the ``torch_tensorrt.dynamo.conversion.TRTInterpreter``). These arguments, along with the kwargs, are to be used to construct a specific TensorRT subgraph representing the current node in the INetworkDefinition.
+ * ``kwargs``: The keyword arguments being passed to a particular Node (as collected by the ``torch_tensorrt.dynamo.conversion.TRTInterpreter``).
+ * ``name``: String containing the name of the target
+
+The function is expected to return the ``tensorrt.ITensor`` or some collection of ``tensorrt.ITensor`` for use in the ``torch_tensorrt.dynamo.conversion.TRTInterpreter``, matching the output signature of the operation being converted.
+
+Capability Validation
+^^^^^^^^^^^^^^^^^^^^^^^
+
+There are some converters which have special cases to be accounted for. In those cases, one should use a ``capability_validator`` when registering the converter with ``@dynamo_tensorrt_converter``.
+We illustrate this through ``torch.ops.aten.embedding.default``. It has parameters, ``scale_grad_by_freq`` and ``sparse``, which are not currently supported by the implementation.
+In such cases we can write a validator, ``embedding_param_validator``, which reports the converter as unsupported when those parameters are set, and register the converter as shown below.
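+
+A sketch of what this validator and its registration could look like follows (the argument positions match the ``aten.embedding`` schema; keyword-argument handling and the converter body are elided here for brevity):
+
+.. code-block:: python
+
+    from torch.fx.node import Node
+
+    def embedding_param_validator(embedding_node: Node) -> bool:
+        # scale_grad_by_freq (args[3]) and sparse (args[4]) are not supported,
+        # so reject conversion whenever either is requested
+        scale_grad_by_freq = (
+            embedding_node.args[3] if len(embedding_node.args) > 3 else False
+        )
+        sparse = embedding_node.args[4] if len(embedding_node.args) > 4 else False
+        return not (scale_grad_by_freq or sparse)
+
+    @dynamo_tensorrt_converter(
+        torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
+    )
+    def aten_ops_embedding(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        ...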
+
+Type Contract
+^^^^^^^^^^^^^^^
+
+The function is expected to follow the type contract established by the signature. This includes accepting the union of valid PyTorch types + numpy arrays for constant tensors and TensorRT ITensors.
+In the case that only a subset of types is supported in the converter, you can also add the ``torch_tensorrt.dynamo.conversion.converter_utils.enforce_tensor_types`` decorator, which allows you to specify a dictionary mapping between input positions and types that those inputs can take. Where possible, the decorator will convert inputs to match these types, preferring the order provided.
+``int`` keys in the dictionary will refer to positional arguments in ``args``. ``str`` keys will refer to keyword arguments in ``kwargs``.
+
+
+Example: ``Convolution``
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The default convolution converter uses both a capability validator and type enforcement to prevent being run in unsupported situations.
+The capability validator is run during partitioning to determine if a particular convolution node can be converted to TensorRT or needs to run in PyTorch. Here the validator ensures that the convolution is no greater than 3D.
+The type enforcer will autocast inputs to the types supported by the converter before the converter is called, thereby limiting the number of cases an author must handle.
+
+.. code-block:: python
+
+    @dynamo_tensorrt_converter(
+        torch.ops.aten.convolution.default, capability_validator=lambda conv_node: conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
+    )  # type: ignore[misc]
+    @enforce_tensor_types(
+        {
+            0: (TRTTensor,),
+            1: (np.ndarray, torch.Tensor, TRTTensor),
+            2: (np.ndarray, torch.Tensor, TRTTensor),
+        }
+    )  # type: ignore[misc]
+    def aten_ops_convolution(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+Evaluators
+------------------------
+
+Some operations do not produce TensorRT subgraphs as a side effect; such operations are termed evaluators.
+
+    Example: ``operator.getitem``
+
+    Evaluators are categorized as such since they do not make any modification to the graph. They are implemented in ``py/torch_tensorrt/dynamo/conversion/ops_evaluators.py``, with the corresponding ``capability_validator``.
+    The opcode is ``operator.getitem``.
+
+
+Operator Decomposition
+-----------------------
+
+There are some operations which can be decomposed into suboperations in PyTorch and do not need separate converter registration.
+Such operations can be supported via a decomposition.
+
+Example: ``addmm``
+^^^^^^^^^^^^^^^^^^^^^^^
+
+The decompositions are registered via the ``register_torch_trt_decomposition`` decorator.
+We define ``addmm_replacement``, which replaces the operation with equivalent torch ops, whose corresponding converters will then be called.
+
+.. code-block:: python
+
+    @torch_tensorrt.dynamo.lowering.register_torch_trt_decomposition(torch.ops.aten.addmm)
+    def addmm_replacement(
+        input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
+    ) -> torch.Tensor:
+        return torch.add(
+            torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
+        )
+
+You can modify which decompositions are run by editing ``torch_tensorrt.dynamo.lowering.torch_enabled_decompositions`` and ``torch_tensorrt.dynamo.lowering.torch_disabled_decompositions``.
+
+    Note: ``torch_tensorrt.dynamo.lowering.torch_enabled_decompositions`` and ``torch_tensorrt.dynamo.lowering.torch_disabled_decompositions`` must be disjoint sets, and the decompositions already defined in ``torch_tensorrt.dynamo.lowering`` will take precedence over torch lowering ops.
+
+Much of the time, this is significantly easier than implementing a converter. So where possible, this is what should be tried first.
\ No newline at end of file
diff --git a/docsrc/contributors/fx_converters.rst b/docsrc/contributors/fx_converters.rst
deleted file mode 100644
index 75ee1c6341..0000000000
--- a/docsrc/contributors/fx_converters.rst
+++ /dev/null
@@ -1,211 +0,0 @@
-.. _dynamo_conversion:
-
-Dynamo Converters
-==================
-The dynamo converter library in Torch-TensorRT is located in ``TensorRT/py/torch_tensorrt/dynamo/conversion``.
-
-
-
-Steps
-==================
-
-Operation Set
--------------------
-The converters in dynamo are produced by ``aten_trace`` and falls under ``aten_ops_converters`` ( FX earlier had ``acc_ops_converters``, ``aten_ops_converters`` or ``nn_ops_converters`` depending on the trace through which it was produced). The converters are registered using ``dynamo_tensorrt_converter`` for dynamo.
The function decorated -has the arguments - ``network, target, args, kwargs, name``, which is common across all the operators schema. -These functions are mapped in the ``aten`` converter registry dictionary (at present a compilation of FX and dynamo converters, FX will be deprecated soon), with key as the function target name. - - * aten_trace is produced by ``torch_tensorrt.dynamo.trace(..)`` for the export path and ``torch_tensorrt.compile(ir=dynamo)`` for the compile path. - The export path makes use of ``aten_tracer`` whereas the alternate trace in compile is produced by the AOT Autograd library. - Both these simplify the torch operators to reduced set of Aten operations. - - -As mentioned above, if you would like to add a new converter, its implementation will be included in ``TensorRT/py/torch_tensorrt/dynamo/conversion/impl`` -Although there is a corresponding implementation of the converters included in the common implementation library present in ``TensorRT/py/torch_tensorrt/fx/impl`` for FX converters, this documentation focuses on the implementation of the ``aten_ops`` converters in dynamo. - - -Converter implementation ------------------------- -In this section, we illustrate the steps to be implemented for writing a converter. We divide them according to activation, operator, lowering pass implementation or evaluator. -Each of them is detailed with the help of an example - - * Registration - - The converter needs to be registered with the appropriate op code in the ``dynamo_tensorrt_converter``. - - * Activation type - - Example: ``leaky_relu`` - - - * aten_ops_converters: Dynamo_converters - - Define in ``py/torch_tensorrt/dynamo/conversion/aten_ops_converters``. One needs to register the opcode generated in the trace with ``dynamo_tensorrt_converter`` decorator. Op code to be used for the registration or the converter registry key in this case is ``torch.ops.aten.leaky_relu.default`` - - .. code-block:: python - - @dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default) - def aten_ops_leaky_relu( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, - ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1]) - - The ``tensorrt_converter`` (used for FX registration) and ``dynamo_tensorrt_converter`` are similar decorator functions with some differences. - - #. Both register the converters in the registeries (python dictionaries) - ``CONVERTERS`` and ``DYNAMO_CONVERTERS`` respectively. These are two dictioneries which are concatenated to form the overall converter registry - #. The dictionary is keyed on the ``OpOverLoad`` which is mentioned in more detail below with examples - #. Both return the decorated converter implementation - #. The ``CONVERTERS`` directly registers the decorated ``converter_implementation`` function, while ``DYNAMO_CONVERTERS`` has additionational arguments and registers the ``ConverterSupport`` object - #. The additional arguments are: - - .. code-block:: python - def dynamo_tensorrt_converter( - key: Target, - enabled: bool = True, - capability_validator: Optional[Callable[[Node], bool]] = None, - priority: ConverterPriority = ConverterPriority.STANDARD, - ) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]: - - #. key: Node target for which the converter is implemented for (for example, torch.ops.aten.leaky_relu.Tensor) - #. enabled: Whether the converter should be enabled/cached or not - #. 
capability_validator: Function which evaluates whether a node is valid for conversion by the decorated converter. It defaults to None, implying the capability_validator function is always true. This means all nodes of "key" kind can be supported by this converter by default. See ``embedding`` example for more details - #. priority: Converter's level of priority relative to other converters with the same target - - #. The ``ConverterSupport`` is a compilation of ``converter_implementation`` and ``capability_validator``. - - - The function decorated by ``tensorrt_converter`` and ``dynamo_tensorrt_converter`` has the following arguments which are automatically generated by the trace functions mentioned above. - - #. network : Node in the form of ``call_module`` or ``call_function`` having the target as the key - #. target: Target key in the ``call_module`` or ``call_function`` above. eg: ``torch.ops.aten_.leaky_relu.default``. Note that ``torch.ops.aten._leaky_relu`` is the ``OpOverloadPacket`` while ``torch.ops.aten_.leaky_relu.default`` is ``OpOverload``. - #. args: The arguments passed in the ``call_module`` or ``call_function`` above - #. kwargs: The kwargs passed in the ``call_module`` or ``call_function`` above - #. name: String containing the name of the target - - As a user writing new converters, one just needs to take care that the approriate arguments are extracted from the trace generated to the implementation function in the implementation lib function ``activation.leaky_relu`` (which we will discuss below in detail). - - * Operation type - - Example: ``fmod`` - - It follows the same steps as the above converter. In this case the opcode is ``torch.ops.aten.fmod.Scalar`` or ``torch.ops.aten.fmod.Tensor``. - Hence both the opcodes are registered in ``py/torch_tensorrt/dynamo/conversion/aten_ops_converters``. - Note that ``torch.ops.aten.fmod`` is the ``OpOverLoadPacket`` while the registry is keyed on ``torch.ops.aten.fmod.Scalar`` or ``torch.ops.aten.fmod.Tensor``, which is ``OpOverLoad`` - - Example: ``embedding`` - - It follows the same steps as the above converter. In this case the opcode is ``torch.ops.aten.embedding.default``. - There are some converters which have special cases to be accounted for. In those cases, one should use ``capability_validators`` to register the converter using ``@dynamo_tensorrt_converter`` - We illustrate this through ``torch.ops.aten.embedding.default``. It has parameters - ``scale_grad_by_freq`` and ``sparse`` which are not currently supported by the implementation. - In such cases we can write validator ``embedding_param_validator`` which implements that given those paramters the converter is not supported and register the converter by - - .. code-block:: python - @dynamo_tensorrt_converter( - torch.ops.aten.embedding.default, capability_validator=embedding_param_validator - ) - - So if there is a new converter in which certain special cases are not to be supported then they can be specified in the ``capability_validator``. - - * Evaluator type - - Example: ``operator.getitem`` - - Evaluators are categorized as so since they do not make any modification to the graph. This is implemented in ``py/torch_tensorrt/dynamo/conversion/op_evaluators.py``, with the corresponding ``capbility_validator``. - The opcode is ``operator.getitem``. 
- - - * Implementation Library - - The dynamo converters would be located in ``py/torch_tensorrt/dynamo/conversion/impl`` - - * Activation - - Example: ``leaky_relu`` - - The implementation is to be placed in present in ``py/torch_tensorrt/dynamo/conversion/impl/activation.py``. This is where all the activation functions are defined and implemented. - - .. code-block:: python - - def leaky_relu( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_val: TRTTensor, - alpha: Optional[Any], - ): - #implementation - - The implementation function has the following arguments. - - #. network : ``network`` passed from the decorated function registration - #. target: ``target`` passed from the decorated function registration - #. source_ir: Enum attribute. ``SourceIR`` enum is defined in ``py/torch_tensorrt/dynamo/conversion/impl/converter_utils`` - #. name: ``name`` passed from the decorated function registration - #. input_val: Approriate arguments extracted from the decorated function registration from args or kwargs - #. alpha: Approriate arguments extracted from the decorated function registration from args or kwargs. If not None, it will set the alpha attribute of the created TensorRT activation layer eg: Used in leaky_relu, elu, hardtanh - #. beta: Approriate arguments extracted from the decorated function registration from args or kwargs. If not None, it will set the beta attribute of the created TensorRT activation layer eg: Used in hardtanh - #. dyn_range_fn: A optional function which takes the dynamic range of a TensorRT Tensor and returns the output dynamic range - - The implementation functions call the ``convert_activation`` function in ``py/torch_tensorrt/dynamo/conversion/impl/activation.py``. This function will add the approriate activation layer via ``network.add_activation``. - - * Operator - - The implementation is to be placed in ``py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py`` for dynamo. This is where all the elementwise functions are defined and implemented. - For a new operator, one should identify the category to which it belongs. Following are some examples - - #. Elementwise operators like ``fmod`` is present in ``py/torch_tensorrt/dynamo/conversion/impl/elementwise``. The ``py/torch_tensorrt/dynamo/conversion/impl/elementwise/base`` contains base functions for elementwise operator. - #. Unary operators like ``sqrt`` will be present in ``py/torch_tensorrt/dynamo/conversion/impl/unary``. The ``py/torch_tensorrt/dynamo/conversion/impl/unary/base`` contains base functions for unary operator. - #. Normalization operators like ``softmax``, ``layer_norm``, ``batch_norm`` will be present in ``py/torch_tensorrt/dynamo/conversion/impl/normalization``. Since there are no base operations common to all, there is no base file. But one can choose to implement a base file, if there are common functions across all normalization operations - #. Individual operators like ``slice``, ``select``, ``where``, ``embedding`` will be present in ``py/torch_tensorrt/dynamo/conversion/impl/*.py``. They will have individual operator implementation with the same API structure as above but with different individual arguments - - Please note that the above operators would have common functions to be implemented which should be placed in - ``py/torch_tensorrt/dynamo/conversion/impl/converter_utils.py`` - - - * Lowering type - - There are some converters which can be decomposed into suboperations and need not have seperate converter registration. 
- Such converters can be implemented via ``lowering passes``
-
-   Example: ``addmm``
-
-   The decompositions are registered via ``register_decomposition`` in ``py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py``
-   We define ``addmm_replacement`` and replace it with the torch ops, which will have their corresponding converters called.
-
-   .. code-block:: python
-
-       @register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
-       def addmm_replacement(
-           input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
-       ) -> torch.Tensor:
-           return torch.add(
-               torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
-           )
-
-   Note that there are some pre-existing dynamo decompositions in torch directory, in which case they should be used,
-   In that case please enable the decompositions in ``py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py`` in ``torch_enabled_decompositions``.
-   Similarly you can choose to disable any in ``torch_disabled_decompositions``. Please note that the ones already defined in the lowering will take precedence over torch lowering ops.
-
-
-
-
-Tests
------
-
-* Dynamo testing:
-
-   Dynamo tests are present for the lowering ops in ``tests/py/dynamo/lowering/test_decompositions.py``. The above converters will soon be ported to dynamo tests
-
-   #. Compare the results for ``fx.symbolic_trace`` and ``torch_tensorrt.dynamo.compile``.
-   #. Test for the ``expected_op`` and the ``unexpected_op``.
-
-      #. ``expected_op``: Operations the operations are lowered to. eg: ``mul`` and ``add`` for ``addmm``
-      #. Note that specify that ``disable_passes= True`` for cases where you would not want lowering passes (which should be the default when testing converters)
-      #. ``unexpected_op``: Original operation. eg: ``addmm`` for ``addmm``
-
-The tests should fail if any of the above two conditions fail
diff --git a/docsrc/contributors/writing_converters.rst b/docsrc/contributors/ts_converters.rst
similarity index 98%
rename from docsrc/contributors/writing_converters.rst
rename to docsrc/contributors/ts_converters.rst
index 990c4dc77d..3266354ad8 100644
--- a/docsrc/contributors/writing_converters.rst
+++ b/docsrc/contributors/ts_converters.rst
@@ -1,7 +1,7 @@
-.. _writing_converters:
+.. _ts_converters:

-Writing Converters
-===================
+Writing TorchScript Converters
+=================================

 Background
 ------------
diff --git a/docsrc/user_guide/dynamo_export.rst b/docsrc/dynamo/dynamo_export.rst
similarity index 74%
rename from docsrc/user_guide/dynamo_export.rst
rename to docsrc/dynamo/dynamo_export.rst
index a5d430f8f2..7cbf5cad7c 100644
--- a/docsrc/user_guide/dynamo_export.rst
+++ b/docsrc/dynamo/dynamo_export.rst
@@ -1,6 +1,6 @@
 .. _dynamo_export:

-Torch-TensorRT Dynamo Backend
+Compiling ``ExportedPrograms`` with Torch-TensorRT
-=============================================
+==================================================

 .. currentmodule:: torch_tensorrt.dynamo
@@ -8,16 +8,13 @@ Torch-TensorRT Dynamo Backend
    :members:
    :undoc-members:
    :show-inheritance:
-
-This guide presents Torch-TensorRT dynamo backend which optimizes Pytorch models
-using TensorRT in an Ahead-Of-Time fashion.

-Using the Dynamo backend
-----------------------------------------
-Pytorch 2.1 introduced ``torch.export`` APIs which
-can export graphs from Pytorch programs into ``ExportedProgram`` objects. Torch-TensorRT dynamo
-backend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT.
Here's a simple
-usage of the dynamo backend
+Using the Torch-TensorRT Frontend for ``torch.export.ExportedPrograms``
+---------------------------------------------------------------------------
+PyTorch 2.1 introduced ``torch.export`` APIs which
+can export graphs from PyTorch programs into ``ExportedProgram`` objects. The Torch-TensorRT dynamo
+frontend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT. Here's a simple
+usage of the dynamo frontend

 .. code-block:: python

@@ -30,13 +27,13 @@ usage of the dynamo backend
     trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) # Output is a torch.fx.GraphModule
     trt_gm(*inputs)

-.. note:: ``torch_tensorrt.dynamo.compile`` is the main API for users to interact with Torch-TensorRT dynamo backend. The input type of the model should be ``ExportedProgram`` (ideally the output of ``torch.export.export`` or ``torch_tensorrt.dynamo.trace`` (discussed in the section below)) and output type is a ``torch.fx.GraphModule`` object.
+.. note:: ``torch_tensorrt.dynamo.compile`` is the main API for users to interact with the Torch-TensorRT dynamo frontend. The input type of the model should be ``ExportedProgram`` (ideally the output of ``torch.export.export`` or ``torch_tensorrt.dynamo.trace`` (discussed in the section below)) and the output type is a ``torch.fx.GraphModule`` object.

 Customizeable Settings
 ----------------------

-There are lot of options for users to customize their settings for optimizing with TensorRT.
-Some of the frequently used options are as follows:
+There are a lot of options for users to customize their settings for optimizing with TensorRT.
+Some of the frequently used options are as follows:

 * ``inputs`` - For static shapes, this can be a list of torch tensors or `torch_tensorrt.Input` objects. For dynamic shapes, this should be a list of ``torch_tensorrt.Input`` objects.
 * ``enabled_precisions`` - Set of precisions that TensorRT builder can use during optimization.
@@ -46,7 +43,7 @@ Some of the frequently used options are as follows:

 The complete list of options can be found `here `_

-.. note:: We do not support INT precision currently in Dynamo. Support for this currently exists in
+.. note:: We do not support INT precision currently in Dynamo. Support for this currently exists in
    our Torchscript IR. We plan to implement similar support for dynamo in our next release.

 Under the hood
--------------
@@ -62,10 +59,10 @@ Under the hood, ``torch_tensorrt.dynamo.compile`` performs the following on the

 Tracing
 -------
-``torch_tensorrt.dynamo.trace`` can be used to trace a Pytorch graphs and produce ``ExportedProgram``.
-This internally performs some decompositions of operators for downstream optimization.
+``torch_tensorrt.dynamo.trace`` can be used to trace a PyTorch graph and produce an ``ExportedProgram``.
+This internally performs some decompositions of operators for downstream optimization.
 The ``ExportedProgram`` can then be used with ``torch_tensorrt.dynamo.compile`` API.
-If you have dynamic input shapes in your model, you can use this ``torch_tensorrt.dynamo.trace`` to export
+If you have dynamic input shapes in your model, you can use ``torch_tensorrt.dynamo.trace`` to export
 the model with dynamic shapes. Alternatively, you can use ``torch.export`` `with constraints `_ directly as well.

 .. code-block:: python

@@ -78,5 +75,4 @@ the model with dynamic shapes.
Alternatively, you can use ``torch.export`` `with
     max_shape=(8, 3, 224, 224), dtype=torch.float32)]
     model = MyModel().eval().cuda()
-    exp_program = torch_tensorrt.dynamo.trace(model, inputs)
-
\ No newline at end of file
+    exp_program = torch_tensorrt.dynamo.trace(model, inputs)
diff --git a/docsrc/user_guide/torch_compile.rst b/docsrc/dynamo/torch_compile.rst
similarity index 99%
rename from docsrc/user_guide/torch_compile.rst
rename to docsrc/dynamo/torch_compile.rst
index a2d83cd52e..6e969092ae 100644
--- a/docsrc/user_guide/torch_compile.rst
+++ b/docsrc/dynamo/torch_compile.rst
@@ -1,6 +1,6 @@
 .. _torch_compile:

-Torch-TensorRT `torch.compile` Backend
+TensorRT Backend for ``torch.compile``
 ======================================================

 .. currentmodule:: torch_tensorrt.dynamo
diff --git a/docsrc/user_guide/getting_started_with_fx_path.rst b/docsrc/fx/getting_started_with_fx_path.rst
similarity index 100%
rename from docsrc/user_guide/getting_started_with_fx_path.rst
rename to docsrc/fx/getting_started_with_fx_path.rst
diff --git a/docsrc/getting_started/installation.rst b/docsrc/getting_started/installation.rst
index 1b6089645a..9f0088c3b8 100644
--- a/docsrc/getting_started/installation.rst
+++ b/docsrc/getting_started/installation.rst
@@ -82,8 +82,7 @@ Dependencies for Compilation

       cp output/bazel /usr/local/bin/

-* You will also need to have **CUDA** installed on the system (or if running in a container, the system must have
-the CUDA driver installed and the container must have CUDA)
+* You will also need to have **CUDA** installed on the system (or if running in a container, the system must have the CUDA driver installed and the container must have CUDA)

 * Specify your CUDA version here if not the version used in the branch being built: https://github.com/pytorch/TensorRT/blob/4e5b0f6e860910eb510fa70a76ee3eb9825e7a4d/WORKSPACE#L46

@@ -96,8 +95,7 @@ the CUDA driver installed and the container must have CUDA)

 * https://github.com/pytorch/TensorRT/blob/4e5b0f6e860910eb510fa70a76ee3eb9825e7a4d/WORKSPACE#L53C1-L53C1

-* **cuDNN and TensorRT** are not required to be installed on the system to build Torch-TensorRT, in fact this is preferable to ensure reproducable builds. Download the tarballs
-for cuDNN and TensorRT from https://developer.nvidia.com and update the paths in the WORKSPACE file here https://github.com/pytorch/TensorRT/blob/4e5b0f6e860910eb510fa70a76ee3eb9825e7a4d/WORKSPACE#L71
+* **cuDNN and TensorRT** are not required to be installed on the system to build Torch-TensorRT; in fact, this is preferable to ensure reproducible builds. Download the tarballs for cuDNN and TensorRT from https://developer.nvidia.com and update the paths in the WORKSPACE file here https://github.com/pytorch/TensorRT/blob/4e5b0f6e860910eb510fa70a76ee3eb9825e7a4d/WORKSPACE#L71

 For example:

@@ -137,9 +135,10 @@ Once the WORKSPACE has been configured properly, all that is required to build t

     python -m pip install --pre . --extra-index-url https://download.pytorch.org/whl/nightly/cu121

+
 To build the wheel file

-  .. code-bloc:: sh
+  .. code-block:: sh

     python -m pip wheel --no-deps --pre . --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -w dist

diff --git a/docsrc/index.rst b/docsrc/index.rst
index 55890727ad..1c0c9a0d9e 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -19,8 +19,6 @@ More Information / System Architecture:
 Getting Started
 ----------------
 * :ref:`installation`
-* :ref:`getting_started_with_python_api`
-* :ref:`getting_started_cpp`

 .. toctree::
    :caption: Getting Started
    :maxdepth: 1
    :hidden:

    getting_started/installation
-   getting_started/getting_started_with_python_api
-   getting_started/getting_started_with_cpp_api
    getting_started/getting_started_with_windows

-User Guide
-------------
-* :ref:`creating_a_ts_mod`
-* :ref:`getting_started_with_fx`
+Dynamo Frontend
+----------------
+
 * :ref:`torch_compile`
 * :ref:`dynamo_export`
+
+.. toctree::
+   :caption: Dynamo Frontend
+   :maxdepth: 1
+   :hidden:
+
+   dynamo/torch_compile
+   dynamo/dynamo_export
+
+TorchScript Frontend
+-----------------------
+* :ref:`creating_a_ts_mod`
+* :ref:`getting_started_with_python_api`
+* :ref:`getting_started_cpp`
+* :ref:`use_from_pytorch`
+
+.. toctree::
+   :caption: TorchScript Frontend
+   :maxdepth: 1
+   :hidden:
+
+   ts/creating_torchscript_module_in_python
+   ts/getting_started_with_python_api
+   ts/getting_started_with_cpp_api
+   ts/use_from_pytorch
+
+FX Frontend
+------------
+
+* :ref:`getting_started_with_fx`
+
+.. toctree::
+   :caption: FX Frontend
+   :maxdepth: 1
+   :hidden:
+
+   fx/getting_started_with_fx_path
+
+
+User Guide
+------------
+
+* :ref:`dynamic_shapes`
 * :ref:`ptq`
-* :ref:`runtime`
 * :ref:`saving_models`
-* :ref:`dynamic_shapes`
-* :ref:`use_from_pytorch`
 * :ref:`using_dla`

 .. toctree::
@@ -51,15 +87,11 @@ User Guide
    :maxdepth: 1
    :hidden:

-   user_guide/creating_torchscript_module_in_python
-   user_guide/getting_started_with_fx_path
-   user_guide/torch_compile
-   user_guide/dynamo_export
+
+   user_guide/dynamic_shapes
    user_guide/ptq
-   user_guide/runtime
    user_guide/saving_models
-   user_guide/dynamic_shapes
-   user_guide/use_from_pytorch
+   user_guide/runtime
    user_guide/using_dla

 Tutorials
@@ -84,6 +116,7 @@ Python API Documenation
 * :ref:`torch_tensorrt_py`
 * :ref:`torch_tensorrt_logging_py`
 * :ref:`torch_tensorrt_ptq_py`
+* :ref:`torch_tensorrt_dynamo_py`
 * :ref:`torch_tensorrt_ts_py`
 * :ref:`torch_tensorrt_fx_py`

@@ -95,6 +128,7 @@ Python API Documenation
    py_api/torch_tensorrt
    py_api/logging
    py_api/ptq
+   py_api/dynamo
    py_api/ts
    py_api/fx

@@ -132,8 +166,9 @@ CLI Documentation
 Contributor Documentation
 --------------------------------
 * :ref:`system_overview`
-* :ref:`writing_converters`
+* :ref:`dynamo_converters`
 * :ref:`writing_dynamo_aten_lowering_passes`
+* :ref:`ts_converters`
 * :ref:`useful_links`

 .. toctree::
@@ -142,8 +177,9 @@
    :hidden:

    contributors/system_overview
-   contributors/writing_converters
+   contributors/dynamo_converters
    contributors/writing_dynamo_aten_lowering_passes
+   contributors/ts_converters
    contributors/useful_links

 Indices
diff --git a/docsrc/py_api/dynamo.rst b/docsrc/py_api/dynamo.rst
new file mode 100644
index 0000000000..fce5372d0e
--- /dev/null
+++ b/docsrc/py_api/dynamo.rst
@@ -0,0 +1,36 @@
+.. _torch_tensorrt_dynamo_py:
+
+torch_tensorrt.dynamo
+======================
+
+.. currentmodule:: torch_tensorrt.dynamo
+
+.. automodule:: torch_tensorrt.dynamo
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Functions
+------------
+
+.. autofunction:: compile
+
+.. autofunction:: trace
+
+.. autofunction:: export
+
+
+
+Classes
+--------
+
+.. autoclass:: CompilationSettings
+
+.. autoclass:: SourceIR
+
+.. autoclass:: runtime.TorchTensorRTModule
+
+.. autoclass:: runtime.PythonTorchTensorRTModule
\ No newline at end of file
diff --git a/docsrc/py_api/torch_tensorrt.rst b/docsrc/py_api/torch_tensorrt.rst
index 98ccde8193..22fda13ba2 100644
--- a/docsrc/py_api/torch_tensorrt.rst
+++ b/docsrc/py_api/torch_tensorrt.rst
@@ -62,3 +62,4 @@ Submodules
    ptq
    ts
    fx
+   dynamo
diff --git a/docsrc/user_guide/creating_torchscript_module_in_python.rst b/docsrc/ts/creating_torchscript_module_in_python.rst
similarity index 100%
rename from docsrc/user_guide/creating_torchscript_module_in_python.rst
rename to docsrc/ts/creating_torchscript_module_in_python.rst
diff --git a/docsrc/getting_started/getting_started_with_cpp_api.rst b/docsrc/ts/getting_started_with_cpp_api.rst
similarity index 98%
rename from docsrc/getting_started/getting_started_with_cpp_api.rst
rename to docsrc/ts/getting_started_with_cpp_api.rst
index 7f7f60a669..70f439ea6d 100644
--- a/docsrc/getting_started/getting_started_with_cpp_api.rst
+++ b/docsrc/ts/getting_started_with_cpp_api.rst
@@ -8,7 +8,9 @@ If you haven't already, acquire a tarball of the library by following the instru
 Using Torch-TensorRT in C++
 ***************************
 Torch-TensorRT C++ API accepts TorchScript modules (generated either from ``torch.jit.script`` or ``torch.jit.trace``) as an input and returns
-a Torchscript module (optimized using TensorRT). This requires users to use Pytorch (in python) to generate torchscript modules beforehand.
+a Torchscript module (optimized using TensorRT). Dynamo compilation workflows are not supported in the C++ API; however, execution of
+compiled FX GraphModules that have been run through ``torch.jit.trace`` is supported, covering both the FX and Dynamo workflows.
+
 Please refer to `Creating TorchScript modules in Python `_ section to generate torchscript graphs.
diff --git a/docsrc/getting_started/getting_started_with_python_api.rst b/docsrc/ts/getting_started_with_python_api.rst
similarity index 100%
rename from docsrc/getting_started/getting_started_with_python_api.rst
rename to docsrc/ts/getting_started_with_python_api.rst
diff --git a/docsrc/user_guide/use_from_pytorch.rst b/docsrc/ts/torchscript_frontend_from_pytorch.rst
similarity index 97%
rename from docsrc/user_guide/use_from_pytorch.rst
rename to docsrc/ts/torchscript_frontend_from_pytorch.rst
index 41e81d181b..d0a403e93c 100644
--- a/docsrc/user_guide/use_from_pytorch.rst
+++ b/docsrc/ts/torchscript_frontend_from_pytorch.rst
@@ -1,6 +1,6 @@
 .. _use_from_pytorch:

-Using Torch-TensorRT Directly From PyTorch
+Using Torch-TensorRT TorchScript Frontend Directly From PyTorch
-============================================
+================================================================

 You will now be able to directly access TensorRT from PyTorch APIs. The process to use this feature
diff --git a/docsrc/user_guide/dynamic_shapes.rst b/docsrc/user_guide/dynamic_shapes.rst
index 4e1bf69631..73ed9b594a 100644
--- a/docsrc/user_guide/dynamic_shapes.rst
+++ b/docsrc/user_guide/dynamic_shapes.rst
@@ -3,12 +3,12 @@
 Dynamic shapes with Torch-TensorRT
 ====================================

-By default, you can run a pytorch model with varied input shapes and the output shapes are determined eagerly.
+By default, you can run a PyTorch model with varied input shapes and the output shapes are determined eagerly.
 However, Torch-TensorRT is an AOT compiler which requires some prior information about the input shapes to compile and optimize the model.
 In the case of dynamic input shapes, we must provide the (min_shape, opt_shape, max_shape) arguments so that the model can be optimized for
 this range of input shapes.
An example usage of static and dynamic shapes is as follows.

-NOTE: The following code uses dynamo IR. Incase of Torchscript IR, please swap out ``ir=dynamo`` with ``ir=ts`` and the behavior is exactly the same.
+NOTE: The following code uses the Dynamo frontend. In case of the TorchScript frontend, please swap out ``ir=dynamo`` with ``ir=ts``; the behavior is exactly the same.

 .. code-block:: python

@@ -19,7 +19,7 @@ NOTE: The following code uses dynamo IR. Incase of Torchscript IR, please swap o
     # Compile with static shapes
     inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
     # or compile with dynamic shapes
-    inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
+    inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                                   opt_shape=[4, 3, 224, 224],
                                   max_shape=[8, 3, 224, 224],
                                   dtype=torch.float32)
@@ -32,25 +32,25 @@ There are two phases of compilation when we use ``torch_tensorrt.compile`` API w

 - aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

-In the tracing phase, we use torch.export along with the constraints. In the case of
-dynamic shaped inputs, the range can be provided to the tracing via constraints. Please
+In the tracing phase, we use torch.export along with the constraints. In the case of
+dynamic shaped inputs, the range can be provided to the tracing via constraints. Please
 refer to this `docstring `_
-for detailed information on how to set constraints. In short, we create new inputs for
-torch.export tracing and provide constraints on the min and max values(provided by the user), a particular dimension can take.
-Please take a look at ``aten_tracer.py`` file to understand how this works under the hood.
+for detailed information on how to set constraints. In short, we create new inputs for
+torch.export tracing and provide constraints on the min and max values (provided by the user) that a particular dimension can take.
+Please take a look at the ``aten_tracer.py`` file to understand how this works under the hood.

 - dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

-In the conversion to TensorRT, we use the user provided dynamic shape inputs.
-We perform shape analysis using dummy inputs (across min, opt and max shapes) and store the
-intermediate output shapes which can be used in case the graph has a mix of Pytorch
+In the conversion to TensorRT, we use the user-provided dynamic shape inputs.
+We perform shape analysis using dummy inputs (across min, opt and max shapes) and store the
+intermediate output shapes which can be used in case the graph has a mix of PyTorch
 and TensorRT submodules.

 Custom Constraints
 ------------------
-Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
-Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows
+Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
+Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows.
 .. code-block:: python

@@ -80,18 +80,18 @@ If you have to provide any custom constraints to your model, the overall workflo
     torch_input_1 = torch.randn((1, 14), dtype=torch.int32).cuda()
     torch_input_2 = torch.randn((1, 14), dtype=torch.int32).cuda()

-    dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
+    dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                             opt_shape=[4, 14],
                             max_shape=[8, 14],
-                            dtype=torch.int32),
-                      torch_tensorrt.Input(min_shape=[1, 14],
+                            dtype=torch.int32),
+                      torch_tensorrt.Input(min_shape=[1, 14],
                             opt_shape=[4, 14],
                             max_shape=[8, 14],
                             dtype=torch.int32)]

     # Export the model with additional constraints
     constraints = []
-    # The following constraints are automatically added by Torch-TensorRT in the
+    # The following constraints are automatically added by Torch-TensorRT in the
     # general case when you call torch_tensorrt.compile directly on MyModel()
     constraints.append(dynamic_dim(torch_input_1, 0) < 8)
     constraints.append(dynamic_dim(torch_input_2, 0) < 8)
@@ -110,9 +110,9 @@ If you have to provide any custom constraints to your model, the overall workflo

 Limitations
 -----------

-If there are operations in the graph that use the dynamic dimension of the input, Pytorch
-introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators and
-the compilation results in undefined behavior. We plan to add support for these operators and implement
+If there are operations in the graph that use the dynamic dimension of the input, PyTorch
+introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators and
+the compilation results in undefined behavior. We plan to add support for these operators and implement
 robust support for shape tensors in the next release. Here is an example of the limitation described above:

 .. code-block:: python

@@ -132,7 +132,7 @@ robust support for shape tensors in the next release. Here is an example of the
     model = MyModel().eval().cuda()

     # Compile with dynamic shapes
-    inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
+    inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                  opt_shape=(4, 512, 1, 1),
                  max_shape=(8, 512, 1, 1),
                  dtype=torch.float32)
@@ -151,14 +151,14 @@ The traced graph of `MyModule()` looks as follows

         return (view,)

-Here the ``%sym_size`` node captures the dynamic batch and uses it in the ``aten.view`` layer. This requires shape tensors support
+Here the ``%sym_size`` node captures the dynamic batch and uses it in the ``aten.view`` layer. This requires shape tensor support,
 which would be a part of our next release.

 Workaround (BERT static compilation example)
 ------------------------------------------

-In the case where you encounter the issues mentioned in the **Limitations** section,
-you can compile the model (static mode) with max input size that can be provided. In the cases of smaller inputs,
+In the case where you encounter the issues mentioned in the **Limitations** section,
+you can compile the model (static mode) with the maximum input size that will be provided. For smaller inputs,
 we can pad them accordingly. This is only a workaround until we address the limitations.

 .. code-block:: python

@@ -172,11 +172,11 @@ we can pad them accordingly. This is only a workaround until we address the limi

     # Input sequence length is 20.
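+    # The engine below is built for this maximum sequence length;
+    # shorter inputs must be padded up to it (see further down).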
     input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
     input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
-
+
     model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
     trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
     model_outputs = model(input1, input2)
-
+
     # If you have a sequence of length 14, pad 6 zero tokens and run inference
     # or recompile for sequence length of 14.
     input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
@@ -188,9 +188,10 @@ we can pad them accordingly. This is only a workaround until we address the limi

 Dynamic shapes with ir=torch_compile
 ------------------------------------
-``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
-configured to Tensorrt. In the case of ``ir=torch_compile``, users have to recompile for different input shapes.
-In the future, we plan to explore the option of compiling with dynamic shapes in the first execution of the model.
+``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
+configured to TensorRT. In the case of ``ir=torch_compile``, when the input size changes, Dynamo will automatically trigger a recompilation
+of the TensorRT engine, giving dynamic shape behavior similar to native PyTorch eager mode, though at the cost of rebuilding the
+TRT engine. This limitation will be addressed in future versions of Torch-TensorRT.

 .. code-block:: python

diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst
index 3b50e7d761..c41e006b98 100644
--- a/docsrc/user_guide/saving_models.rst
+++ b/docsrc/user_guide/saving_models.rst
@@ -8,10 +8,10 @@ Saving models compiled with Torch-TensorRT
    :members:
    :undoc-members:
    :show-inheritance:
-
+
 Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.

-Dynamo IR
-------------

 Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
 The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There

 a) Converting to Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
-The following code illustrates this approach.
+`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
+The following code illustrates this approach.

 .. code-block:: python

@@ -41,29 +41,28 @@ The following code illustrates this approach.

 b) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
+`torch.export.ExportedProgram` is a new format introduced in PyTorch 2.1. After we compile a PyTorch module using Torch-TensorRT, the resultant
 `torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk.
 .. code-block:: python

     import torch
     import torch_tensorrt
-    from torch_tensorrt.dynamo.export import transform, create_exported_program

     model = MyModel().eval().cuda()
     inputs = torch.randn((1, 3, 224, 224)).cuda()
     trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
     # Transform and create an exported program
-    trt_gm = transform(trt_gm, inputs)
+    trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
     torch._export.save(trt_exp_program, "trt_model.ep")

-    # Later, you can load it and run inference
+    # Later, you can load it and run inference
     model = torch._export.load("trt_model.ep")
     model(inputs)

-`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
-This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
+`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
+This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).

 NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341

@@ -72,7 +71,7 @@ Torchscript IR
 -------------

 In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
-  This behavior stays the same in 2.X versions as well.
+  This behavior stays the same in 2.X versions as well.

 .. code-block:: python

@@ -87,4 +86,4 @@ Torchscript IR
   # Later, you can load it and run inference
   model = torch.jit.load("trt_model.ts").cuda()
   model(inputs)
-
+
diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 9b9f4c00a1..8391033281 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -22,7 +22,7 @@

 if DYNAMO_ENABLED:
     from torch._export import ExportedProgram
-    from torch_tensorrt.dynamo.compile import compile as dynamo_compile
+    from torch_tensorrt.dynamo._compiler import compile as dynamo_compile

 logger = logging.getLogger(__name__)
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index 63cc2af10a..c191be62fb 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -7,12 +7,8 @@
 logger = logging.getLogger(__name__)

 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._settings import *
+    from ._compiler import compile
+    from ._exporter import export
+    from ._settings import CompilationSettings
     from ._SourceIR import SourceIR
-    from .aten_tracer import trace
-    from .compile import compile
-    from .conversion import *
-    from .conversion.converter_registry import (
-        DYNAMO_CONVERTERS,
-        dynamo_tensorrt_converter,
-    )
+    from ._tracer import trace
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/_compiler.py
similarity index 61%
rename from py/torch_tensorrt/dynamo/compile.py
rename to py/torch_tensorrt/dynamo/_compiler.py
index 5394c1382e..d31be8a413 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -12,7 +12,7 @@
     EngineCapability,
 )
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo import CompilationSettings, partitioning
+from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
     DEVICE,
@@ -30,6 +30,7 @@
     WORKSPACE_SIZE,
 )
 from torch_tensorrt.dynamo.conversion import (
+    CompilationSettings,
     convert_module,
     repair_long_or_double_inputs,
 )
@@ -47,7 +48,7 @@

 def compile(
     exported_program: ExportedProgram,
-    inputs: Any,
+    inputs: Tuple[Any, ...],
     *,
     device: Optional[Union[Device, torch.device, str]] = DEVICE,
     disable_tf32: bool = False,
@@ -76,6 +77,65 @@
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
+    """Compile an ExportedProgram for NVIDIA GPUs using TensorRT
+
+    Takes an existing ``ExportedProgram`` and a set of settings to configure the compiler
+    and converts the graph into a ``torch.fx.GraphModule`` which calls equivalent TensorRT engines
+
+    Specifically, this converts the forward method of the original module
+
+    Arguments:
+        exported_program (torch.export.ExportedProgram): Source module, produced by running torch.export on a ``torch.nn.Module``
+        inputs (Tuple[Any, ...]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
+            to select device type. ::
+
+                input=[
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last
+                    ), # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 224)) # Use an example tensor and let torch_tensorrt infer settings
+                ]
+
+    Keyword Arguments:
+        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+
+            device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+
+        disable_tf32 (bool): Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
+        sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
+        enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
+        refit (bool): Enable refitting
+        debug (bool): Enable debuggable engine
+        capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        workspace_size (int): Maximum size of workspace given to TensorRT
+        dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
+        dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
+        dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
+        truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
+        calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
+        require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
+        min_block_size (int): The minimum number of contiguous TensorRT-convertible operations in order to run a set of operations in TensorRT
+        torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows)
+        max_aux_streams (Optional[int]): Maximum number of auxiliary streams in the engine
+        version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines)
+        optimization_level: (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level.
+        use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
+        use_fast_partitioner: (bool): Use the adjacency-based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
+        enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
+        **kwargs: Any,
+    Returns:
+        torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
+    """
+
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
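A minimal usage sketch for the ``compile`` API documented above (``MyModel`` is a stand-in module; the settings shown are illustrative)::

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder user module
    inputs = [torch_tensorrt.Input((8, 3, 224, 224), dtype=torch.half)]

    # trace produces the ExportedProgram consumed by compile (see _tracer.py below)
    exp_program = torch_tensorrt.dynamo.trace(model, inputs)
    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        enabled_precisions={torch.half},
        min_block_size=1,
    )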
diff --git a/py/torch_tensorrt/dynamo/export.py b/py/torch_tensorrt/dynamo/_exporter.py
similarity index 84%
rename from py/torch_tensorrt/dynamo/export.py
rename to py/torch_tensorrt/dynamo/_exporter.py
index 9bd1dbddb3..f8d1eceaf8 100644
--- a/py/torch_tensorrt/dynamo/export.py
+++ b/py/torch_tensorrt/dynamo/_exporter.py
@@ -10,6 +10,43 @@
 from torch_tensorrt.dynamo import partitioning


+# TODO: @peri044: Correct this implementation
+def export(
+    src_gm: torch.fx.GraphModule,
+    trt_gm: torch.fx.GraphModule,
+    inputs: Sequence[torch.Tensor],
+) -> ExportedProgram:
+    """Export a program (``torch.fx.GraphModule``) for serialization with the TensorRT engines embedded.
+
+    > Note: When ExportedProgram becomes stable, this function will get merged into ``torch_tensorrt.dynamo.compile``
+
+    Arguments:
+        src_gm (torch.fx.GraphModule): Source module, generated by torch.export (The module provided to ``torch_tensorrt.dynamo.compile``)
+        trt_gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
+        inputs (Sequence[torch.Tensor]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
+            to select device type. ::
+
+                input=[
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ), # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 224)) # Use an example tensor and let torch_tensorrt infer settings
+                ]
+
+    """
+
+    patched_module = transform(trt_gm, inputs)
+
+    return create_trt_exp_program(patched_module, src_gm.call_spec, src_gm.state_dict())
+
+
 def transform(
     gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
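An illustrative end-to-end flow for the exporter above (hypothetical: it assumes the source GraphModule can be read off the exported program via its ``graph_module`` attribute)::

    import torch
    import torch_tensorrt

    exp_program = torch_tensorrt.dynamo.trace(model, inputs)  # model/inputs as before
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)

    # Re-package the compiled module as an ExportedProgram for serialization
    trt_exp_program = torch_tensorrt.dynamo.export(exp_program.graph_module, trt_gm, inputs)
    torch._export.save(trt_exp_program, "trt_model.ep")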
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
similarity index 50%
rename from py/torch_tensorrt/dynamo/aten_tracer.py
rename to py/torch_tensorrt/dynamo/_tracer.py
index 0ef47ff2ef..5fdca08399 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -2,12 +2,15 @@

 import logging
 import unittest.mock
-from typing import Any, List, Tuple
+from typing import Any, List, Optional, Tuple, Union

 import torch
 from torch._export import dynamic_dim, export
+from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    DEVICE,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     default_device,
 )
@@ -32,18 +35,55 @@ def get_random_tensor(


 def trace(
-    model: torch.nn.Module | torch.fx.GraphModule,
+    mod: torch.nn.Module | torch.fx.GraphModule,
     inputs: Tuple[Any, ...],
+    device: Optional[Union[Device, torch.device, str]] = DEVICE,
+    debug: bool = DEBUG,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     **kwargs: Any,
-) -> torch.fx.GraphModule:
+) -> torch.export.ExportedProgram:
+    """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT
+
+    Exports a ``torch.export.ExportedProgram`` from either a ``torch.nn.Module`` or a ``torch.fx.GraphModule``. Runs specific operator decompositions geared towards
+    compilation by Torch-TensorRT's dynamo frontend.
+
+    Arguments:
+        mod (torch.nn.Module | torch.fx.GraphModule): Source module to later be compiled by Torch-TensorRT's dynamo frontend
+        inputs (Tuple[Any, ...]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
+            to select device type. ::
+
+                input=[
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ), # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 224)) # Use an example tensor and let torch_tensorrt infer settings
+                ]
+
+    Keyword Arguments:
+        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+
+            device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+
+        debug (bool): Enable debug logging
+        enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
+        **kwargs: Any,
+    Returns:
+        torch.export.ExportedProgram: The traced and decomposed program, ready to be compiled by Torch-TensorRT's dynamo frontend
+    """
+
     # Set log level at the top of compilation (torch_tensorrt.dynamo)
-    if "debug" in kwargs and kwargs["debug"]:
+    if debug:
         set_log_level(logger.parent, logging.DEBUG)

+    device = to_torch_device(device if device else default_device())
+
     # Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT
     # Torch dynamo does not allow 0/1 value for dynamic dimensions
     # for inputs during tracing. Hence we create new inputs for export
-    device = to_torch_device(kwargs.get("device", default_device()))
     torch_inputs = get_torch_inputs(inputs, device)
     trace_inputs = []
     constraints = []
@@ -77,12 +117,10 @@ def trace(
     else:
         trace_inputs.append(torch_inputs[idx])

-    experimental_decompositions = kwargs.get(
-        "enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS
-    )
     with unittest.mock.patch(
-        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
+        "torch._export.DECOMP_TABLE",
+        get_decompositions(enable_experimental_decompositions),
     ):
-        exp_program = export(model, tuple(trace_inputs), constraints=constraints)
+        exp_program = export(mod, tuple(trace_inputs), constraints=constraints)

     return exp_program
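The constraint setup in ``trace`` is exercised when an ``Input`` carries a shape range; a short sketch (``MyModel`` is again a stand-in)::

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder user module
    inputs = [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(4, 3, 224, 224),
            max_shape=(8, 3, 224, 224),
            dtype=torch.float32,
        )
    ]

    # The batch dimension is dynamic, so trace emits dynamic_dim constraints for it
    exp_program = torch_tensorrt.dynamo.trace(model, inputs, debug=True)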
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 246b7e3cb7..1fa2806181 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -9,7 +9,7 @@
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.compile import compile_module
+from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
     apply_lowering_passes,
     get_decompositions,
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
similarity index 99%
rename from py/torch_tensorrt/dynamo/conversion/converter_registry.py
rename to py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
index 45445f0f89..d689de3e54 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -19,7 +19,7 @@
 from torch._ops import OpOverloadPacket
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.fx.converter_registry import CONVERTERS
+from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

 logger = logging.getLogger(__name__)
@@ -441,7 +441,7 @@ def display_all_available_converters(self) -> str:

 # Initialize dynamo converter registry with the FX and Dynamo aten registries
 # Note the Dynamo registry is listed first, for precedence
 DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry(
-    [DYNAMO_ATEN_CONVERTERS, CONVERTERS],  # type: ignore[list-item]
+    [DYNAMO_ATEN_CONVERTERS, FX_CONVERTERS],  # type: ignore[list-item]
     ["Dynamo ATen Converters Registry", "FX Legacy ATen Converters Registry"],
     [CallingConvention.CTX, CallingConvention.LEGACY],
 )
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 61f3f6b6f5..0f1c3b0c42 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -4,9 +4,6 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

 import numpy as np
-
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -15,7 +12,10 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_registry import CallingConvention
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_CONVERTERS as CONVERTERS,
+)
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     get_node_name,
     get_trt_tensor,
@@ -23,10 +23,10 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
 from packaging import version

-from .converter_registry import DYNAMO_CONVERTERS as CONVERTERS
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)

 TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
@@ -49,9 +49,9 @@ class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
     def __init__(
         self,
         module: torch.fx.GraphModule,
-        input_specs: List[Input],
+        input_specs: Sequence[Input],
         logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
-        output_dtypes: Optional[List[torch.dtype]] = None,
+        output_dtypes: Optional[Sequence[torch.dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
     ):
         super().__init__(module)
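Since the registry and its helpers are re-exported from ``torch_tensorrt.dynamo.conversion`` (see the ``__init__.py`` change below), the merged registry can be inspected directly; for example::

    from torch_tensorrt.dynamo.conversion import DYNAMO_CONVERTERS

    # Dynamo ATen converters shadow FX legacy converters because the Dynamo
    # registry is listed first in the ConverterRegistry constructed above
    print(DYNAMO_CONVERTERS.display_all_available_converters())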
diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py
index 1261062cf4..83805d9a55 100644
--- a/py/torch_tensorrt/dynamo/conversion/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,7 +1,6 @@
+from . import aten_ops_converters, ops_evaluators, prims_ops_converters
+from ._conversion import convert_module
 from ._ConversionContext import ConversionContext
+from ._ConverterRegistry import *  # noqa: F403
 from ._TRTInterpreter import *  # noqa: F403
-from .aten_ops_converters import *  # noqa: F403
-from .conversion import *  # noqa: F403
-from .ops_evaluators import *  # noqa: F403
-from .prims_ops_converters import *  # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs
diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
similarity index 95%
rename from py/torch_tensorrt/dynamo/conversion/conversion.py
rename to py/torch_tensorrt/dynamo/conversion/_conversion.py
index 59ca3e3143..1cdea63680 100644
--- a/py/torch_tensorrt/dynamo/conversion/conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,14 +3,15 @@
 import io
 from typing import Sequence

-import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.conversion import TRTInterpreter
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs

+import tensorrt as trt
+

 def convert_module(
     module: torch.fx.GraphModule,
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index f7ebf9e9c5..06ccdf57ee 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -8,17 +8,16 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_registry import (
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    dynamic_unsupported_with_args,
     enforce_tensor_types,
     is_only_operator_on_placeholder,
 )
 from torch_tensorrt.fx.types import TRTTensor

-from .converter_utils import dynamic_unsupported_with_args
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -55,14 +54,16 @@ def one_user_validator(node: Node) -> bool:
 )


-@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.batch_norm)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator
+)
+@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default)
+@dynamo_tensorrt_converter(torch.ops.aten.batch_norm)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_batch_norm(
     ctx: ConversionContext,
     target: Target,
@@ -88,14 +89,16 @@ def aten_ops_batch_norm(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator
+)
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_layer_norm(
     ctx: ConversionContext,
     target: Target,
@@ -118,12 +121,14 @@ def aten_ops_layer_norm(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.native_group_norm.default, capability_validator=one_user_validator
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_native_group_norm(
     ctx: ConversionContext,
     target: Target,
@@ -147,13 +152,13 @@ def aten_ops_native_group_norm(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.group_norm)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)
+@dynamo_tensorrt_converter(torch.ops.aten.group_norm)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_group_norm(
     ctx: ConversionContext,
     target: Target,
@@ -175,7 +180,7 @@ def aten_ops_group_norm(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.cat.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
 def aten_ops_cat(
     ctx: ConversionContext,
     target: Target,
@@ -212,7 +217,7 @@ def embedding_param_validator(embedding_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
-)  # type: ignore[misc]
+)
 def aten_ops_embedding(
     ctx: ConversionContext,
     target: Target,
@@ -245,15 +250,19 @@ def embedding_bag_validator(node: Node) -> bool:
 )


-@dynamo_tensorrt_converter(torch.ops.aten.embedding_bag.default, capability_validator=embedding_bag_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten._embedding_bag.default, capability_validator=embedding_bag_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.embedding_bag.default, capability_validator=embedding_bag_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten._embedding_bag.default, capability_validator=embedding_bag_validator
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
         1: (TRTTensor,),
         2: (np.ndarray, torch.Tensor),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_embedding_bag(
     ctx: ConversionContext,
     target: Target,
@@ -278,8 +287,8 @@ def aten_ops_embedding_bag(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor)
 def aten_ops_fmod(
     ctx: ConversionContext,
     target: Target,
@@ -290,7 +299,7 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])


-@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
 def aten_ops_relu(
     ctx: ConversionContext,
     target: Target,
@@ -307,7 +316,7 @@ def aten_ops_relu(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
 def aten_ops_sigmoid(
     ctx: ConversionContext,
     target: Target,
@@ -324,12 +333,12 @@ def aten_ops_sigmoid(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_index(
     ctx: ConversionContext,
     target: Target,
@@ -347,7 +356,7 @@ def aten_ops_index(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     ctx: ConversionContext,
     target: Target,
@@ -364,7 +373,7 @@ def aten_ops_tanh(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
 def aten_ops_leaky_relu(
     ctx: ConversionContext,
     target: Target,
@@ -382,7 +391,7 @@ def aten_ops_leaky_relu(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.elu.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
 def aten_ops_elu(
     ctx: ConversionContext,
     target: Target,
@@ -401,7 +410,7 @@ def aten_ops_elu(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
 def aten_ops_softplus(
     ctx: ConversionContext,
     target: Target,
@@ -419,7 +428,7 @@ def aten_ops_softplus(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
 def aten_ops_clip(
     ctx: ConversionContext,
     target: Target,
@@ -438,7 +447,7 @@ def aten_ops_clip(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
 def aten_ops_hard_sigmoid(
     ctx: ConversionContext,
     target: Target,
@@ -457,10 +466,10 @@ def aten_ops_hard_sigmoid(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.mv.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.matmul)
+@dynamo_tensorrt_converter(torch.ops.aten.mm.default)
+@dynamo_tensorrt_converter(torch.ops.aten.mv.default)
+@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)
 def aten_ops_matmul(
     ctx: ConversionContext,
     target: Target,
@@ -478,7 +487,7 @@ def aten_ops_matmul(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default)
 def aten_ops_rsqrt(
     ctx: ConversionContext,
     target: Target,
@@ -495,7 +504,7 @@ def aten_ops_rsqrt(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.neg.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
 def aten_ops_neg(
     ctx: ConversionContext,
     target: Target,
@@ -512,8 +521,8 @@ def aten_ops_neg(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
+@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
 def aten_ops_squeeze(
     ctx: ConversionContext,
     target: Target,
@@ -524,7 +533,7 @@ def aten_ops_squeeze(
     return impl.squeeze.squeeze(ctx, target, SourceIR.ATEN, name, args[0], args[1])


-@dynamo_tensorrt_converter(torch.ops.aten.erf.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.erf.default)
 def aten_ops_erf(
     ctx: ConversionContext,
     target: Target,
@@ -541,7 +550,7 @@ def aten_ops_erf(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default)
 def aten_ops_unsqueeze(
     ctx: ConversionContext,
     target: Target,
@@ -554,7 +563,7 @@ def aten_ops_unsqueeze(
 )


-@dynamo_tensorrt_converter(torch.ops.aten._softmax.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten._softmax.default)
 def aten_ops_softmax(
     ctx: ConversionContext,
     target: Target,
@@ -569,14 +578,14 @@ def aten_ops_softmax(

 @dynamo_tensorrt_converter(
     torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
-)  # type: ignore[misc]
+)
 @dynamo_tensorrt_converter(
     torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
-)  # type: ignore[misc]
+)
 @dynamo_tensorrt_converter(
     torch.ops.aten.split_with_sizes.default,
     capability_validator=dynamic_unsupported_with_args([1]),
-)  # type: ignore[misc]
+)
 def aten_ops_split(
     ctx: ConversionContext,
     target: Target,
@@ -595,7 +604,7 @@ def aten_ops_split(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.where.self)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.where.self)
 def aten_ops_where(
     ctx: ConversionContext,
     target: Target,
@@ -614,7 +623,7 @@ def aten_ops_where(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
 def aten_ops_clamp(
     ctx: ConversionContext,
     target: Target,
@@ -633,7 +642,7 @@ def aten_ops_clamp(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.select.int)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.select.int)
 def aten_ops_select(
     ctx: ConversionContext,
     target: Target,
@@ -646,7 +655,7 @@ def aten_ops_select(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
 def aten_ops_slice(
     ctx: ConversionContext,
     target: Target,
@@ -667,12 +676,12 @@ def aten_ops_slice(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_chunk(
     ctx: ConversionContext,
     target: Target,
@@ -691,12 +700,12 @@ def aten_ops_chunk(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.permute.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_permute(
     ctx: ConversionContext,
     target: Target,
@@ -761,11 +770,11 @@ def validator(to_copy_node: Node) -> bool:
 @dynamo_tensorrt_converter(
     torch.ops.aten.clone.default,
     capability_validator=lambda node: not is_only_operator_on_placeholder(node),
-)  # type: ignore[misc]
+)
 @dynamo_tensorrt_converter(
     torch.ops.aten._to_copy.default,
     capability_validator=to_copy_dtype_validator(placeholder_only=False),
-)  # type: ignore[misc]
+)
 def aten_ops_clone_copy_dtype(
     ctx: ConversionContext,
     target: Target,
@@ -787,11 +796,11 @@ def aten_ops_clone_copy_dtype(
 @dynamo_tensorrt_converter(
     torch.ops.aten.clone.default,
     capability_validator=is_only_operator_on_placeholder,
-)  # type: ignore[misc]
+)
 @dynamo_tensorrt_converter(
     torch.ops.aten._to_copy.default,
     capability_validator=to_copy_dtype_validator(placeholder_only=True),
-)  # type: ignore[misc]
+)
 def aten_ops_clone_copy_placeholder(
     ctx: ConversionContext,
     target: Target,
@@ -813,7 +822,7 @@ def aten_ops_clone_copy_placeholder(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.expand.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
 def aten_ops_expand(
     ctx: ConversionContext,
     target: Target,
@@ -843,7 +852,7 @@ def amax_param_validator(amax_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.amax.default, capability_validator=amax_param_validator
-)  # type: ignore[misc]
+)
 def aten_ops_amax(
     ctx: ConversionContext,
     target: Target,
@@ -862,9 +871,9 @@ def aten_ops_amax(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.sum.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.prims.sum.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
+@dynamo_tensorrt_converter(torch.ops.prims.sum.default)
 def aten_ops_sum(
     ctx: ConversionContext,
     target: Target,
@@ -896,8 +905,8 @@ def aten_ops_sum(
     return sum_


-@dynamo_tensorrt_converter(torch.ops.aten.prod.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.prod.default)
+@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int)
 def aten_ops_prod(
     ctx: ConversionContext,
     target: Target,
@@ -916,8 +925,10 @@ def aten_ops_prod(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.max.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.max.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.max.dim, capability_validator=one_user_validator
+)
 def aten_ops_max(
     ctx: ConversionContext,
     target: Target,
@@ -937,8 +948,10 @@ def aten_ops_max(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.min.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.min.dim, capability_validator=one_user_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.min.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.min.dim, capability_validator=one_user_validator
+)
 def aten_ops_min(
     ctx: ConversionContext,
     target: Target,
@@ -958,8 +971,8 @@ def aten_ops_min(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.mean.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.mean.dim)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mean.default)
+@dynamo_tensorrt_converter(torch.ops.aten.mean.dim)
 def aten_ops_mean(
     ctx: ConversionContext,
     target: Target,
@@ -978,7 +991,7 @@ def aten_ops_mean(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.exp.default)
 def aten_ops_exp(
     ctx: ConversionContext,
     target: Target,
@@ -995,7 +1008,7 @@ def aten_ops_exp(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.log.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.log.default)
 def aten_ops_log(
     ctx: ConversionContext,
     target: Target,
@@ -1012,7 +1025,7 @@ def aten_ops_log(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.sqrt.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sqrt.default)
 def aten_ops_sqrt(
     ctx: ConversionContext,
     target: Target,
@@ -1029,7 +1042,7 @@ def aten_ops_sqrt(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.reciprocal.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.reciprocal.default)
 def aten_ops_recip(
     ctx: ConversionContext,
     target: Target,
@@ -1046,7 +1059,7 @@ def aten_ops_recip(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.abs.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.abs.default)
 def aten_ops_abs(
     ctx: ConversionContext,
     target: Target,
@@ -1063,7 +1076,7 @@ def aten_ops_abs(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.sin.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sin.default)
 def aten_ops_sin(
     ctx: ConversionContext,
     target: Target,
@@ -1080,7 +1093,7 @@ def aten_ops_sin(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.cos.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.cos.default)
 def aten_ops_cos(
     ctx: ConversionContext,
     target: Target,
@@ -1097,7 +1110,7 @@ def aten_ops_cos(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.tan.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.tan.default)
 def aten_ops_tan(
     ctx: ConversionContext,
     target: Target,
@@ -1114,7 +1127,7 @@ def aten_ops_tan(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.sinh.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sinh.default)
 def aten_ops_sinh(
     ctx: ConversionContext,
     target: Target,
@@ -1131,7 +1144,7 @@ def aten_ops_sinh(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.cosh.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.cosh.default)
 def aten_ops_cosh(
     ctx: ConversionContext,
     target: Target,
@@ -1148,7 +1161,7 @@ def aten_ops_cosh(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.asin.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.asin.default)
 def aten_ops_asin(
     ctx: ConversionContext,
     target: Target,
@@ -1165,7 +1178,7 @@ def aten_ops_asin(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.acos.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.acos.default)
 def aten_ops_acos(
     ctx: ConversionContext,
     target: Target,
@@ -1182,7 +1195,7 @@ def aten_ops_acos(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.atan.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.atan.default)
 def aten_ops_atan(
     ctx: ConversionContext,
     target: Target,
@@ -1199,7 +1212,7 @@ def aten_ops_atan(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.asinh.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.asinh.default)
 def aten_ops_asinh(
     ctx: ConversionContext,
     target: Target,
@@ -1216,7 +1229,7 @@ def aten_ops_asinh(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.acosh.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.acosh.default)
 def aten_ops_acosh(
     ctx: ConversionContext,
     target: Target,
@@ -1233,7 +1246,7 @@ def aten_ops_acosh(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.atanh.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.atanh.default)
 def aten_ops_atanh(
     ctx: ConversionContext,
     target: Target,
@@ -1250,7 +1263,7 @@ def aten_ops_atanh(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.ceil.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.ceil.default)
 def aten_ops_ceil(
     ctx: ConversionContext,
     target: Target,
@@ -1267,7 +1280,7 @@ def aten_ops_ceil(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.floor.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.floor.default)
 def aten_ops_floor(
     ctx: ConversionContext,
     target: Target,
@@ -1284,7 +1297,7 @@ def aten_ops_floor(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_not.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.logical_not.default)
 def aten_ops_logical_not(
     ctx: ConversionContext,
     target: Target,
@@ -1301,7 +1314,7 @@ def aten_ops_logical_not(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.sign.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sign.default)
 def aten_ops_sign(
     ctx: ConversionContext,
     target: Target,
@@ -1318,7 +1331,7 @@ def aten_ops_sign(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.round.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.round.default)
 def aten_ops_round(
     ctx: ConversionContext,
     target: Target,
@@ -1335,7 +1348,7 @@ def aten_ops_round(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.isinf.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.isinf.default)
 def aten_ops_isinf(
     ctx: ConversionContext,
     target: Target,
@@ -1352,8 +1365,8 @@ def aten_ops_isinf(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
 def aten_ops_add(
     ctx: ConversionContext,
     target: Target,
@@ -1384,8 +1397,8 @@ def aten_ops_add(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
 def aten_ops_mul(
     ctx: ConversionContext,
     target: Target,
@@ -1403,7 +1416,7 @@ def aten_ops_mul(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
 def aten_ops_maximum(
     ctx: ConversionContext,
     target: Target,
@@ -1421,7 +1434,7 @@ def aten_ops_maximum(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
 def aten_ops_minimum(
     ctx: ConversionContext,
     target: Target,
@@ -1439,8 +1452,8 @@ def aten_ops_minimum(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
 def aten_ops_sub(
     ctx: ConversionContext,
     target: Target,
@@ -1471,11 +1484,11 @@ def aten_ops_sub(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.prims.div.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
+@dynamo_tensorrt_converter(torch.ops.prims.div.default)
 def aten_ops_div(
     ctx: ConversionContext,
     target: Target,
@@ -1518,9 +1531,9 @@ def aten_ops_div(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
 def aten_ops_pow(
     ctx: ConversionContext,
     target: Target,
@@ -1538,8 +1551,8 @@ def aten_ops_pow(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
 def aten_ops_floor_div(
     ctx: ConversionContext,
     target: Target,
@@ -1557,7 +1570,7 @@ def aten_ops_floor_div(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
 def aten_ops_logical_and(
     ctx: ConversionContext,
     target: Target,
@@ -1575,7 +1588,7 @@ def aten_ops_logical_and(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
 def aten_ops_logical_or(
     ctx: ConversionContext,
     target: Target,
@@ -1593,7 +1606,7 @@ def aten_ops_logical_or(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
 def aten_ops_logical_xor(
     ctx: ConversionContext,
     target: Target,
@@ -1611,8 +1624,8 @@ def aten_ops_logical_xor(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
 def aten_ops_equal(
     ctx: ConversionContext,
     target: Target,
@@ -1630,8 +1643,8 @@ def aten_ops_equal(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
 def aten_ops_greater(
     ctx: ConversionContext,
     target: Target,
@@ -1649,8 +1662,8 @@ def aten_ops_greater(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
 def aten_ops_less(
     ctx: ConversionContext,
     target: Target,
@@ -1674,14 +1687,14 @@ def conv_param_validator(conv_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default, capability_validator=conv_param_validator
-)  # type: ignore[misc]
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
         1: (np.ndarray, torch.Tensor, TRTTensor),
         2: (np.ndarray, torch.Tensor, TRTTensor),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_convolution(
     ctx: ConversionContext,
     target: Target,
@@ -1722,8 +1735,8 @@ def aten_ops_convolution(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.linear.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.linear)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
+@dynamo_tensorrt_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     ctx: ConversionContext,
     target: Target,
@@ -1762,9 +1775,15 @@ def avg_pool_param_validator(pool_node: Node) -> bool:

 # Note: AvgPool1d uses avg_pool2d as it converts to 2D first.
-@dynamo_tensorrt_converter(torch.ops.aten.avg_pool1d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.avg_pool2d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.avg_pool3d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.avg_pool1d.default, capability_validator=avg_pool_param_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.avg_pool2d.default, capability_validator=avg_pool_param_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.avg_pool3d.default, capability_validator=avg_pool_param_validator
+)
 def aten_ops_avg_pool(
     ctx: ConversionContext,
     target: Target,
@@ -1805,9 +1824,15 @@ def max_pool_param_validator(pool_node: Node) -> bool:

 # Note: MaxPool1d uses max_pool2d as it converts to 2D first.
-@dynamo_tensorrt_converter(torch.ops.aten.max_pool1d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.max_pool1d.default, capability_validator=max_pool_param_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator
+)
 def aten_ops_max_pool(
     ctx: ConversionContext,
     target: Target,
@@ -1831,7 +1856,7 @@ def aten_ops_max_pool(

 @dynamo_tensorrt_converter(
     torch.nn.functional.scaled_dot_product_attention,
-)  # type: ignore[misc]
+)
 def tensorrt_scaled_dot_product_attention(
     ctx: ConversionContext,
     target: Target,
@@ -1844,13 +1869,13 @@ def tensorrt_scaled_dot_product_attention(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.reshape.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.view.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.reshape.default)
+@dynamo_tensorrt_converter(torch.ops.aten.view.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_reshape(
     ctx: ConversionContext,
     target: Target,
@@ -1868,8 +1893,8 @@ def aten_ops_reshape(
 )


-@enforce_tensor_types({0: (TRTTensor,)})  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)  # type: ignore[misc]
+@enforce_tensor_types({0: (TRTTensor,)})
+@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
 def aten_ops_argmax(
     ctx: ConversionContext,
     target: Target,
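The decorator stacks reformatted above compose op registration, capability validators, and input type enforcement; a hedged sketch of an out-of-tree converter following the same pattern (the validator and the re-registration of ``fmod.Tensor`` are purely illustrative)::

    from typing import Dict, Sequence, Tuple, Union

    import torch
    from torch.fx.node import Argument, Node, Target
    from torch_tensorrt.dynamo._SourceIR import SourceIR
    from torch_tensorrt.dynamo.conversion import impl
    from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
    from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
        dynamo_tensorrt_converter,
    )
    from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
    from torch_tensorrt.fx.types import TRTTensor


    def tensor_rhs_validator(node: Node) -> bool:
        # Illustrative gate: only claim nodes whose second argument is another node;
        # rejected nodes are partitioned back to PyTorch
        return isinstance(node.args[1], Node)


    @dynamo_tensorrt_converter(
        torch.ops.aten.fmod.Tensor, capability_validator=tensor_rhs_validator
    )
    @enforce_tensor_types({0: (TRTTensor,)})
    def my_fmod_converter(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        # Delegates to the same impl call used by aten_ops_fmod in this file
        return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])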
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index b382c5c329..7d133ceffa 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -4,13 +4,12 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

 import numpy as np
-import tensorrt as trt
 import torch
 from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_registry import (
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
@@ -21,6 +20,8 @@
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTTensor

+import tensorrt as trt
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
index 790f0d6f60..bc6af1a32d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -4,7 +4,7 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index 7a980327a2..3a67c47fa3 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -5,23 +5,25 @@
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    ConverterRegistry,
+    dynamo_tensorrt_converter,
+)
 from torch_tensorrt.fx.types import TRTTensor

-from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)


 def getitem_validator(getitem_node: Node) -> bool:
-    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
+    from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

     # Getitem nodes can only be converted if their parent node also can
     return getitem_node.args[0] in DYNAMO_CONVERTERS


 # TODO: Subsequent evaluators should be registered here with their own validators
-@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.detach.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+@dynamo_tensorrt_converter(torch.ops.aten.detach.default)
 def generic_evaluator(
     ctx: ConversionContext,
     target: Target,
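The body of ``generic_evaluator`` is not shown in this hunk; assuming an evaluator simply runs its target eagerly on already-evaluated arguments, a further registration might look like this sketch (the target choice is illustrative)::

    from typing import Any, Dict, Tuple

    import torch
    from torch.fx.node import Argument, Target
    from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
    from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
        dynamo_tensorrt_converter,
    )


    @dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)  # illustrative target
    def sym_size_evaluator(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Any:
        # Evaluate the op directly instead of emitting TensorRT layers
        return target(*args, **kwargs)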
diff --git a/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
index a8c0dfa6fd..9548dc287a 100644
--- a/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
@@ -6,10 +6,11 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
 from torch_tensorrt.fx.types import TRTTensor

-from .converter_registry import dynamo_tensorrt_converter
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -25,8 +26,8 @@ def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.prims.broadcast_in_dim.default, capability_validator=broadcast_checker
-)  # type: ignore[misc]
-def aten_ops_broadcast_in_dim(
+)
+def prim_ops_broadcast_in_dim(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 1fbe0cd120..7c4e9fdd2d 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,3 +1,7 @@
+from ._decomposition_groups import (
+    torch_disabled_decompositions,
+    torch_enabled_decompositions,
+)
 from ._decompositions import get_decompositions  # noqa: F401
 from ._fusers import *  # noqa: F401
 from ._repair_input_aliasing import repair_input_aliasing
diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
index 3f20c7efc8..55df3cb2b3 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -17,10 +17,10 @@
     MIN_BLOCK_SIZE,
     REQUIRE_FULL_COMPILATION,
 )
-from torch_tensorrt.dynamo.conversion.converter_registry import (
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry

 logger = logging.getLogger(__name__)
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index 49ee02b4cf..f6149a2271 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -10,10 +10,10 @@
     MIN_BLOCK_SIZE,
     REQUIRE_FULL_COMPILATION,
 )
-from torch_tensorrt.dynamo.conversion.converter_registry import (
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry

 logger = logging.getLogger(__name__)
diff --git a/py/torch_tensorrt/dynamo/tools/opset_coverage.py b/py/torch_tensorrt/dynamo/tools/opset_coverage.py
index 977d8db875..bfa57d3ed8 100644
--- a/py/torch_tensorrt/dynamo/tools/opset_coverage.py
+++ b/py/torch_tensorrt/dynamo/tools/opset_coverage.py
@@ -13,10 +13,7 @@
 import torchgen
 from torch._dynamo.variables import BuiltinVariable
 from torch._ops import OpOverload
-from torch_tensorrt.dynamo.conversion.converter_registry import (
-    DYNAMO_CONVERTERS,
-    ConverterRegistry,
-)
+from torch_tensorrt.dynamo.conversion import DYNAMO_CONVERTERS, ConverterRegistry
 from torch_tensorrt.dynamo.lowering import get_decompositions
 from torchgen.gen import parse_native_yaml
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index fa84a5d0c4..26de1fcb27 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -5,12 +5,12 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Union

 import torch
+import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._defaults import PRECISION
+from torch_tensorrt.dynamo._settings import CompilationSettings

-import torch_tensorrt
 from packaging import version

 logger = logging.getLogger(__name__)
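As a closing note on the new ``lowering`` exports: the decomposition table that ``trace`` patches into ``torch._export.DECOMP_TABLE`` (see ``_tracer.py`` above) can be compared across modes; a small sketch, assuming the returned table is dict-like::

    from torch_tensorrt.dynamo.lowering import get_decompositions

    default_table = get_decompositions(False)      # curated decompositions only
    experimental_table = get_decompositions(True)  # full experimental set
    print(len(default_table), len(experimental_table))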