Torch-TensorRT v1.0.0

@narendasan released this 09 Nov 08:26

New Name!, Support for PyTorch 1.10, CUDA 11.3, New Packaging and Distribution Options, Stabilized APIs, Stabilized Partial Compilation, Adjusted Default Behavior, Usability Improvements, New Converters, Bug Fixes

This is the first stable release of Torch-TensorRT targeting PyTorch 1.10, CUDA 11.3 (on x86_64; CUDA 10.2 on aarch64), cuDNN 8.2 and TensorRT 8.0, with backwards-compatible source for TensorRT 7.1. On aarch64, Torch-TensorRT primarily targets JetPack 4.6, with backwards-compatible source for JetPack 4.5. This version also removes deprecated APIs such as InputRange and op_precision.

New Name

TRTorch is now Torch-TensorRT! TRTorch started out almost two years ago as a small experimental project compiling TorchScript to TensorRT. Now that we are hitting v1.0.0, with APIs and major features stabilizing, we felt that the name of the project should reflect the ecosystem of tools it is joining with this release, namely TF-TRT (https://blog.tensorflow.org/2021/01/leveraging-tensorflow-tensorrt-integration.html) and MXNet-TensorRT (https://mxnet.apache.org/versions/1.8.0/api/python/docs/tutorials/performance/backend/tensorrt/tensorrt). Since we were already significantly changing APIs in this release to reflect what we learned over the last two years of using TRTorch, we felt this was the right time to change the name as well.

The overall process to port forward from TRTorch is as follows:

  • Python

    • The library has been renamed from trtorch to torch_tensorrt
    • Components that used to live under the single trtorch namespace have now been separated. IR-agnostic components (torch_tensorrt.Input, torch_tensorrt.Device, torch_tensorrt.ptq, torch_tensorrt.logging) will continue to live in the top-level namespace. IR-specific components such as torch_tensorrt.ts.compile, torch_tensorrt.ts.convert_method_to_trt_engine and torch_tensorrt.ts.TensorRTCompileSpec will live in a TorchScript-specific namespace. This gives us room to explore other IRs that might be relevant to the project in the future. In place of the old top-level compile and convert_method_to_trt_engine there are new top-level functions which call the IR-specific versions based on what is provided to them. This also means that you can now provide a raw torch.nn.Module to torch_tensorrt.compile and Torch-TensorRT will handle the TorchScripting step for you. For the most part, the only change needed to move to the new namespaces is to replace trtorch with torch_tensorrt.
  • C++

    • As in Python, the C++ namespaces have changed from trtorch to torch_tensorrt. IR-specific components such as compile, convert_method_to_trt_engine and CompileSpec are in a torchscript namespace, while IR-agnostic components are at the top level. Namespace aliases torchtrt (for torch_tensorrt) and ts (for torchscript) are included. Again, porting namespaces forward should be a find and replace. Finally, the libraries libtrtorch.so, libtrtorchrt.so and libtrtorch_plugins.so have been renamed to libtorchtrt.so, libtorchtrt_runtime.so and libtorchtrt_plugins.so respectively.
  • CLI:

    • trtorch has been renamed to torchtrtc
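Because the Python and shared-library renames are mechanical, they can be scripted. The helper below is a hypothetical sketch (port_forward is not a tool shipped with Torch-TensorRT): it renames the three shared libraries first and then replaces the standalone token trtorch with torch_tensorrt. It does not handle the C++ function renames (for example CompileGraph) or the CLI rename, so the result should still be reviewed by hand.

```python
import re

# Hypothetical helper, not part of Torch-TensorRT. Library names are rewritten
# first so that the generic "trtorch" rule below never sees them.
_LIBRARY_RENAMES = [
    ("libtrtorch_plugins.so", "libtorchtrt_plugins.so"),
    ("libtrtorchrt.so", "libtorchtrt_runtime.so"),
    ("libtrtorch.so", "libtorchtrt.so"),
]


def port_forward(source):
    """Apply the find-and-replace port described in the release notes."""
    for old, new in _LIBRARY_RENAMES:
        source = source.replace(old, new)
    # \b limits the match to the standalone module/namespace name "trtorch".
    return re.sub(r"\btrtorch\b", "torch_tensorrt", source)


print(port_forward("import trtorch\ntrt_mod = trtorch.compile(mod, **spec)"))
```

Identifiers that merely contain the substring (such as a hypothetical mytrtorchutils) are left untouched because of the word boundaries.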

New Distribution Options and Packaging

Starting with nvcr.io/nvidia/pytorch:21.11, Torch-TensorRT will be distributed as part of the container (https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). The version of Torch-TensorRT in the container will be the state of the master branch at the time of building. Torch-TensorRT will be validated to run correctly with the versions of PyTorch, CUDA, cuDNN and TensorRT in the container. This will be the easiest way to get a fully validated, end-to-end PyTorch training-to-inference stack and is a great starting point for building DL applications.

Also, starting with this release, we are distributing the full C++ package within the wheel files for the Python package. Installing the wheel now gets you the Python API, the C++ libraries and headers, and the CLI binary, which makes it the easiest way to install Torch-TensorRT on your stack. After installing with pip:

pip3 install torch-tensorrt -f https://github.com/NVIDIA/Torch-TensorRT/releases

you can add the package's bin directory to your PATH to set up the CLI:

PATH=$PATH:<PATH TO TORCHTRT PYTHON PACKAGE>/bin

Stabilized APIs

Python

Many of the APIs have changed slightly in this release to be more self-consistent and more usable. These changes begin with the Python API, where compile, convert_method_to_trt_engine and TensorRTCompileSpec now take kwargs instead of dictionaries. As many features have come out of beta and experimental stability, the need for multiple levels of nesting in settings has decreased, so kwargs make much more sense. You can port forward to the new APIs by unpacking your existing compile_spec dict into the arguments of compile or similar functions.

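Example:
import torch
import torch_tensorrt

# torch_script_module below is the module to compile: a torch.jit.ScriptModule
# (the new top-level compile also accepts a plain torch.nn.Module)
compile_settings = {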
    "inputs": [torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[1, 3, 512, 512],
        max_shape=[1, 3, 1024, 1024],
        # For static size shape=[1, 3, 224, 224]
        dtype=torch.half, # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
    )],
    "enabled_precisions": {torch.half}, # Run with FP16
}

trt_ts_module = torch_tensorrt.compile(torch_script_module, **compile_settings)
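Since the settings are now ordinary keyword arguments, `**` unpacking is all that is needed to reuse an existing dict. The sketch below shows the mechanics with a hypothetical stand-in, compile_stub, instead of the real torch_tensorrt.compile so that it runs without PyTorch or a GPU; the string placeholders stand in for Input objects and torch dtypes, and the keyword-only signature is a choice made for this illustration.

```python
def compile_stub(module, *, inputs=None, enabled_precisions=None, **other_settings):
    """Hypothetical stand-in for a kwargs-based compile call: it echoes its arguments."""
    return {"module": module, "inputs": inputs,
            "enabled_precisions": enabled_precisions, **other_settings}


compile_settings = {
    "inputs": ["<Input spec>"],       # placeholder for torch_tensorrt.Input objects
    "enabled_precisions": {"fp16"},   # placeholder for {torch.half}
    "require_full_compilation": False,
    "min_block_size": 3,
}

# ** turns every key of the dict into a keyword argument of the same name.
result = compile_stub("my_module", **compile_settings)
print(result)

# With this keyword-only stub, passing the dict positionally raises TypeError.
try:
    compile_stub("my_module", compile_settings)
except TypeError as err:
    print("rejected:", err)
```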

This release also introduces support for providing tensors as examples to Torch-TensorRT. In place of a torch_tensorrt.Input in the list of inputs you can pass a Tensor. This can only be used to set a static input size. There are also some things to be aware of which will be discussed later in the release notes.

Now that Torch-TensorRT separates IR-specific components into their own namespaces, there are replacements for the old compile and convert_method_to_trt_engine functions at the top level. These functions accept any PyTorch-generated format, including torch.nn.Module, and decide the best way to compile it down to TensorRT. In v1.0.0 this means going through TorchScript and returning a torch.jit.ScriptModule. You can specify which IR to try using the ir argument of these functions.

Due to partial compilation becoming stable in v1.0.0, there are now four new fields which replace the old torch_fallback struct.

  • old:
compile_spec = {
  "torch_fallback": {
      "enabled": True, # Turn on or turn off falling back to PyTorch if operations are not supported in TensorRT
      "force_fallback_ops": [
          "aten::max_pool2d" # List of specific ops to require running in PyTorch
      ],
      "force_fallback_modules": [
          "mypymod.mytorchmod" # List of specific torch modules to require running in PyTorch
      ],
      "min_block_size": 3 # Minimum number of ops an engine must encapsulate to be run in TensorRT
  }
}
  • new:
torch_tensorrt.compile(...,
    require_full_compilation=False, 
    min_block_size=3, 
    torch_executed_ops=[ "aten::max_pool2d" ], 
    torch_executed_modules=["mypymod.mytorchmod"])
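Porting an existing dict can also be automated. The helper below is a hypothetical sketch (convert_legacy_spec is not part of Torch-TensorRT). The legacy key names are taken from the examples in these notes, which spell them both force_fallback_* and forced_fallback_*, so all variants are accepted; a missing "enabled" key is assumed to mean False, as in the old C++ TorchFallback struct.

```python
# Hypothetical migration helper, not part of Torch-TensorRT.
_LEGACY_OP_KEYS = ("force_fallback_ops", "forced_fallback_ops")
_LEGACY_MODULE_KEYS = ("force_fallback_modules", "forced_fallback_modules",
                       "forced_fallback_mods")


def _first_present(settings, keys):
    """Return the value of the first legacy key present, as a list, else None."""
    for key in keys:
        if key in settings:
            return list(settings[key])
    return None


def convert_legacy_spec(spec):
    """Turn a TRTorch-style compile_spec dict into kwargs for the v1.0.0 APIs."""
    kwargs = {k: v for k, v in spec.items() if k != "torch_fallback"}
    fallback = spec.get("torch_fallback")
    if fallback is None:
        return kwargs  # nothing to translate; the v1.0.0 defaults apply
    enabled = fallback.get("enabled", False)
    # Falling back to PyTorch is the opposite of requiring full compilation.
    kwargs["require_full_compilation"] = not enabled
    if not enabled:
        # Only require_full_compilation is kept: the v1.0.0 API raises an error
        # if torch_executed_* lists are combined with require_full_compilation=True.
        return kwargs
    if "min_block_size" in fallback:
        kwargs["min_block_size"] = fallback["min_block_size"]
    ops = _first_present(fallback, _LEGACY_OP_KEYS)
    if ops is not None:
        kwargs["torch_executed_ops"] = ops
    modules = _first_present(fallback, _LEGACY_MODULE_KEYS)
    if modules is not None:
        kwargs["torch_executed_modules"] = modules
    return kwargs


old_spec = {
    "torch_fallback": {
        "enabled": True,
        "force_fallback_ops": ["aten::max_pool2d"],
        "force_fallback_modules": ["mypymod.mytorchmod"],
        "min_block_size": 3,
    }
}
new_kwargs = convert_legacy_spec(old_spec)
print(new_kwargs)
# With Torch-TensorRT installed: torch_tensorrt.compile(module, **new_kwargs)
```

Converting the "old" example above yields exactly the four keyword arguments shown in the "new" example.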

C++

Other than the reorganization and renaming of the namespaces, the changes to the C++ API mostly serve to make Torch-TensorRT consistent between Python and C++, namely by renaming trtorch::CompileGraph to torch_tensorrt::ts::compile and trtorch::ConvertGraphToTRTEngine to torch_tensorrt::ts::convert_method_to_trt_engine. Beyond that, as in Python, the partial compilation struct TorchFallback has been removed and replaced by four fields in torch_tensorrt::ts::CompileSpec.

  • old:
  /**
   * @brief A struct to hold fallback info
   */
  struct TRTORCH_API TorchFallback {
    /// enable the automatic fallback feature
    bool enabled = false;

    /// minimum consecutive operation number that needs to be satisfied to convert to TensorRT
    uint64_t min_block_size = 1;

    /// A list of names of operations that will explicitly run in PyTorch
    std::vector<std::string> forced_fallback_ops;

    /// A list of names of modules that will explicitly run in PyTorch
    std::vector<std::string> forced_fallback_modules;

    /**
     * @brief Construct a default Torch Fallback object, fallback will be off
     */
    TorchFallback() = default;

    /**
     * @brief Construct from a bool
     */
    TorchFallback(bool enabled) : enabled(enabled) {}

    /**
     * @brief Constructor for setting min_block_size
     */
    TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
  };
  • new:
  /**
   * Require the full module be compiled to TensorRT instead of potentially running unsupported operations in PyTorch
   */
  bool require_full_compilation = false;

  /**
   * Minimum number of contiguous supported operators to compile a subgraph to TensorRT
   */
  uint64_t min_block_size = 3;

  /**
   * List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but
   * ``require_full_compilation`` is True
   */
  std::vector<std::string> torch_executed_ops;

  /**
   * List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but
   * ``require_full_compilation`` is True
   */
  std::vector<std::string> torch_executed_modules;
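The C++ renames listed above can be scripted in the same spirit. This is a hypothetical sketch (port_forward_cpp is not a Torch-TensorRT tool) written in Python so it can run over a source tree; it covers only the function and CompileSpec renames named in these notes plus the generic namespace change. It does not migrate the removed TorchFallback struct or any types nested inside the old CompileSpec, which need manual review.

```python
import re

# Hypothetical helper, not part of Torch-TensorRT. Specific renames run first so
# the generic namespace rule at the end does not pre-empt them.
_CPP_RENAMES = [
    (r"\btrtorch::CompileGraph\b", "torch_tensorrt::ts::compile"),
    (r"\btrtorch::ConvertGraphToTRTEngine\b",
     "torch_tensorrt::ts::convert_method_to_trt_engine"),
    (r"\btrtorch::CompileSpec\b", "torch_tensorrt::ts::CompileSpec"),
    (r"\btrtorch::", "torch_tensorrt::"),  # IR-agnostic names stay top level
]


def port_forward_cpp(source):
    """Rewrite TRTorch C++ identifiers to their Torch-TensorRT v1.0.0 names."""
    for pattern, replacement in _CPP_RENAMES:
        source = re.sub(pattern, replacement, source)
    return source


print(port_forward_cpp("auto trt_mod = trtorch::CompileGraph(mod, spec);"))
```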

CLI

Similarly, these partial compilation flags have been renamed in torchtrtc:

    --require-full-compilation        Require that the model should be fully
                                      compiled to TensorRT or throw an error
    --teo=[torch-executed-ops...],
    --torch-executed-ops=[torch-executed-ops...]
                                      (Repeatable) Operator in the graph that
                                      should always be run in PyTorch for
                                      execution (partial compilation must be
                                      enabled)
    --tem=[torch-executed-mods...],
    --torch-executed-mods=[torch-executed-mods...]
                                      (Repeatable) Module that should always
                                      be run in Pytorch for execution (partial
                                      compilation must be enabled)
    --mbs=consecutive_ops,
    --min-block-size=consecutive_ops
                                      Minimum number of contiguous TensorRT
                                      supported ops to compile a subgraph to
                                      TensorRT
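For scripted builds it can help to generate these flags from the same settings used in Python. The helper below is hypothetical: it emits only the partial-compilation flags shown above (long forms, one occurrence per entry for the repeatable options), while input and output paths and any other torchtrtc options still have to be added by the caller. Passing the resulting list to subprocess as an argument list avoids shell quoting issues.

```python
def partial_compilation_flags(require_full_compilation=False, min_block_size=None,
                              torch_executed_ops=(), torch_executed_modules=()):
    """Hypothetical helper: map partial-compilation settings to torchtrtc flags."""
    if require_full_compilation and (torch_executed_ops or torch_executed_modules):
        # Mirrors the documented rule: these lists need partial compilation enabled.
        raise ValueError("torch_executed_* requires require_full_compilation=False")
    flags = []
    if require_full_compilation:
        flags.append("--require-full-compilation")
    if min_block_size is not None:
        flags.append(f"--min-block-size={min_block_size}")
    # Both options are repeatable: emit one flag per entry.
    flags += [f"--torch-executed-ops={op}" for op in torch_executed_ops]
    flags += [f"--torch-executed-mods={mod}" for mod in torch_executed_modules]
    return flags


print(partial_compilation_flags(min_block_size=3,
                                torch_executed_ops=["aten::max_pool2d"],
                                torch_executed_modules=["mypymod.mytorchmod"]))
```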

Going forward, breaking changes to the API of the magnitude seen in this release will be accompanied by a major version bump.

Stabilized Partial Compilation

Partial compilation should be considered stable for static input shapes and is now enabled by default. For dynamic shapes, set require_full_compilation to True.
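To make the interaction between min_block_size, torch_executed_ops and require_full_compilation concrete, here is a deliberately simplified sketch. It treats a model as a straight-line list of op names, whereas the real partitioner operates on a TorchScript graph with data dependencies; the function name partition, the segment format and the example op list are all hypothetical.

```python
from itertools import groupby


def partition(ops, supported, torch_executed_ops=(), min_block_size=3,
              require_full_compilation=False):
    """Split a straight-line list of op names into TensorRT/PyTorch segments."""
    forced = set(torch_executed_ops)
    if require_full_compilation and forced:
        raise ValueError("torch_executed_ops must be empty when "
                         "require_full_compilation is True")

    def runs_in_trt(op):
        # An op can go to TensorRT only if it is supported and not forced to PyTorch.
        return op in supported and op not in forced

    if require_full_compilation:
        unsupported = [op for op in ops if not runs_in_trt(op)]
        if unsupported:
            raise ValueError(f"ops not supported in TensorRT: {unsupported}")
        return [("TensorRT", list(ops))]

    segments = []
    for is_trt, group in groupby(ops, key=runs_in_trt):
        block = list(group)
        # Runs of supported ops shorter than min_block_size stay in PyTorch.
        target = "TensorRT" if is_trt and len(block) >= min_block_size else "PyTorch"
        if segments and segments[-1][0] == target:
            segments[-1][1].extend(block)  # merge with the preceding PyTorch segment
        else:
            segments.append((target, block))
    return segments


ops = ["aten::_convolution", "aten::batch_norm", "aten::relu", "aten::max_pool2d",
       "aten::_convolution", "aten::relu", "aten::flatten", "aten::linear"]
supported = set(ops)  # pretend every op has a converter
segments = partition(ops, supported, torch_executed_ops=["aten::max_pool2d"])
for target, block in segments:
    print(target, block)
```

Here aten::max_pool2d is forced into PyTorch, splitting the model into two TensorRT blocks; raising min_block_size to 5 would leave both blocks too small, so the whole model would run in PyTorch.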

Adjusted Defaults

Input Types

The default behavior of Torch-TensorRT has shifted slightly. The most important of these changes concerns the inferred input type. In prior versions, unless it was set explicitly, the expected input type for a Tensor was based on op_precision. With that field removed in this release (superseded by enabled_precisions, introduced in v0.4.0), this behavior no longer makes sense. Torch-TensorRT therefore now follows these rules to determine the input type for a Tensor:

  1. If no dtype is specified for an Input, Torch-TensorRT will determine the input type by inspecting the uses of this Input. It traces the lifetime of the tensor to the first tensor operation that uses weights stored in the provided module. Because PyTorch requires matching types for tensor operations, the type of those weights becomes the inferred type of the Input. The goal of this behavior is to preserve the idea that Torch-TensorRT modules should feel no different from normal PyTorch modules. Therefore you can expect:

    Weight Type of Model      Expected Input Type for Tensor
    FP32                      FP32
    FP16                      FP16
    Quantization Workflows    FP32
    Unknown / Ambiguous       FP32 (with warning)
  2. Users can override this behavior and set the Input type to whatever they wish using the dtype field of torch_tensorrt.Input. Torch-TensorRT will always respect the user setting but may emit a warning stating that the provided model expects a different input type. This is mainly to notify you that simply dropping the compiled module in place of the raw torch.nn.Module might cause errors, and that casting inputs before inference might be necessary.

    • With Torch-TensorRT v1.0.0 you can provide example torch Tensors to set the input shape. However, this sets not only the input shape but also the input dtype and tensor format. So if you provide a half-precision 1x3x32x32 contiguous tensor, the expected input would be Input(shape=(1, 3, 32, 32), dtype=dtype.half, format=TensorFormat.contiguous). This is subject to the behavior described in rule 2.
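The two rules above can be summarized as a small decision function. This is a hypothetical sketch, not Torch-TensorRT's implementation: weight types and dtypes are plain strings, and the inspection of the module that determines weight_type is assumed to have happened already.

```python
# Hypothetical sketch, not the Torch-TensorRT implementation. Weight types and
# dtypes are plain strings; how weight_type is found is out of scope here.
_INFERRED_FROM_WEIGHTS = {
    "fp32": "fp32",
    "fp16": "fp16",
    "quantized": "fp32",  # quantization workflows expect FP32 inputs
}


def expected_input_dtype(weight_type, user_dtype=None):
    """Return (dtype, warnings) for a single input, following rules 1 and 2."""
    inferred = _INFERRED_FROM_WEIGHTS.get(weight_type)
    if user_dtype is not None:
        # Rule 2: an explicit dtype (or an example tensor's dtype) always wins,
        # but a mismatch with what the model expects is reported.
        if inferred is not None and user_dtype != inferred:
            return user_dtype, [f"model expects {inferred}, using {user_dtype}"]
        return user_dtype, []
    if inferred is None:
        # Rule 1, unknown / ambiguous weights: default to FP32 with a warning.
        return "fp32", [f"could not infer input type from {weight_type!r}; using fp32"]
    return inferred, []


# The rules apply per input, so a model's inputs can resolve to different types.
print([expected_input_dtype(w) for w in ("fp16", "unknown")])
```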

Workspace Size

By default, the workspace size is now set to 1GB for all Pascal-based and newer GPUs (compute capability 6 or above). Maxwell and older GPUs, including Jetson Nano, have a default workspace of 256MB. This value is user-settable.
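The default can be expressed as a function of the GPU's compute capability. This is a hypothetical sketch of the rule stated above (the real value is chosen inside Torch-TensorRT), and reading "1GB" and "256MB" as binary units is an assumption.

```python
# Hypothetical sketch of the default described above; the real value is chosen
# inside Torch-TensorRT. "1GB" and "256MB" are read as 2**30 and 256 * 2**20 bytes.
GiB = 1 << 30
MiB = 1 << 20


def default_workspace_size(compute_capability):
    """compute_capability is a (major, minor) tuple, e.g. (7, 5)."""
    major, _minor = compute_capability
    # Pascal (SM 6.x) and newer: 1GB. Maxwell and older, e.g. Jetson Nano (SM 5.3): 256MB.
    return 1 * GiB if major >= 6 else 256 * MiB


for cc in [(8, 6), (6, 1), (5, 3)]:
    print(cc, default_workspace_size(cc) // MiB, "MiB")
```

On a machine with PyTorch installed, torch.cuda.get_device_capability() returns a (major, minor) tuple of this form.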

Dependencies

- Bazel 4.2.1
- LibTorch 1.10.0
- CUDA 11.3 (on x86_64 by default; newer CUDA 11 versions are supported with a compatible PyTorch build), 10.2 (on aarch64)
- cuDNN 8.2.4.15
- TensorRT 8.0.3.4

1.0.0 (2021-11-09)

Bug Fixes

  • aten::gelu call was wrong in test (40bc4e3)
  • Fix a core partitioning algo bug where non-tensor input segments are not updated correctly (cc10876)
  • Fix modules_as_engines test case to use trt_mod instead of pyt_mod (282e98a)
  • Fix plugin registration macro (8afab22)
  • Fix python API tests for mobilenet v2 (e5a38ff)
  • Partial compilation translation to internal settings was incorrect (648bad3)
  • //py: Don't crash harshly on import when CUDA is not available (07e16fd)
  • Re-enable backtrace and make it less repetitive (1435845)
  • //core/lowering: Fixes module level fallback recursion (f94ae8f)
  • //core/partitioning: Fixing support for partially compiling (748ecf3)
  • //docker: Update docker container build script to use release path (9982855)
  • //py: Add new dirs to remove during clean (d2cc1e9)
  • //py: Fix some api import issues (840ca89)
  • //py: Fix trtorch.Device alternate constructor options (fa08311)
  • //py: Fix trtorch.Device alternate constructor options (ac26841)
  • Update notebooks with new library name Torch-TensorRT (8274fd9)
  • aten::conv1d: Update namespace, fix typo in dest IR for conv1d (d53f136)
  • eval: Rollback 1.11a0 change + namespace issues (ba743f5)
  • Use scripting instead of tracing for module fallback tests (32e8b53)
  • Workspace defaults for other apis and centralize cuda api use (930321e)

Features

  • Add functionality for tests to use precompiled libraries (b5c324a)

  • Add QAT patch which modifies scale factor dtype to INT32 (4a10673)

  • Add TF32 override flag in bazelrc for CI-Testing (7a0c9a5)

  • Add VGG QAT sample notebook which demonstrates the end-to-end workflow for QAT models (8bf6dd6)

  • Augment python package to include bin, lib, include directories (ddc0685)

  • handle scalar type of size [] in shape_analysis (fca53ce)

  • support aten::and.bool evaluator (6d73e43)

  • support aten::conv1d and aten::conv_transpose1d (c8dc6e9)

  • support aten::eq.str evaluator (5643972)

  • support setting input types of subgraph in fallback, handle Tensor type in evaluated_value_map branch in MarkOutputs (4778b2b)

  • support truncate_long_and_double in fallback subgraph input type (0bc3c05)

  • Update documentation with new library name Torch-TensorRT (e5f96d9)

  • Updating the pre_built to prebuilt (51412c7)

  • //:libtrtorch: Ship a WORKSPACE file and BUILD file with the (7ac6f1c)

  • //core/partitioning: Improved logging and code org for the (8927e77)

  • //cpp: Adding example tensors as a way to set input spec (70a7bb3)

  • //py: Add the git revision to non release builds (4a0a918)

  • //py: Allow example tensors from torch to set shape (01d525d)

  • feat!: Changing the default behavior for selecting the input type (a234335)

  • refactor!: Removing deprecated InputRange, op_precision and input_shapes (621bc67)

  • feat(//py)!: Porting forward the API to use kwargs (17e0e8a)

  • refactor(//py)!: Kwargs updates and support for shifting internal apis (2a0d1c8)

  • refactor!(//cpp): Inlining partial compilation settings since the (19ecc64)

  • refactor! : Update default workspace size based on platforms. (391a4c0)

  • feat!: Turning on partial compilation by default (52e2f05)

  • refactor!: API level rename (483ef59)

  • refactor!: Changing the C++ api to be snake case (f34e230)

  • refactor! : Update Pytorch version to 1.10 (cc7d0b7)

  • refactor!: Updating bazel version for py build container (06533fe)

BREAKING CHANGES

  • This removes the InputRange Class and op_precision and
    input shape fields which were deprecated in TRTorch v0.4.0

Signed-off-by: Naren Dasan [email protected]

  • This change updates the bazel version
    to build Torch-TensorRT to 4.2.1.

This was done since the only version of bazel available
in our build container for python apis is 4.2.1

Signed-off-by: Naren Dasan [email protected]

  • This changes the API for compile settings
    from a dictionary of settings to a set of kwargs for the various
    compilation functions. This will break existing code. However
    there is simple guidance to port forward your code:

Given a dict of valid TRTorch CompileSpec settings

spec = {
	"inputs": ...
	...
}

You can use this same dict with the new APIs by changing your code from:

trtorch.compile(mod, spec)

to:

trtorch.compile(mod, **spec)

which will unpack the dictionary as arguments to the function

Signed-off-by: Naren Dasan [email protected]

  • This commit changes the APIs from a dictionary of
    arguments to a set of kwargs. You can port forward using
trtorch.compile(mod, **spec)

Also, in preparation for partial compilation being enabled by default,
settings related to torch fallback have been moved to the top level

instead of

"torch_fallback": {
  "enabled": True,
  "min_block_size": 3,
  "forced_fallback_ops" : ["aten::add"],
  "forced_fallback_mods" : ["MySubModule"]
}

now there are new settings

require_full_compilation=False,
min_block_size=3,
torch_executed_ops=["aten::add"],
torch_executed_modules=["MySubModule"]

Signed-off-by: Naren Dasan [email protected]

  • This commit changes the API for automatic fallback
    to inline settings regarding partial compilation in preparation
    for it to be turned on by default

Now, instead of a torch_fallback field with its
associated struct, there are four new fields in the compile spec:

bool require_full_compilation = true;
uint64_t min_block_size = 3;
std::vector<std::string> torch_executed_ops = {};
std::vector<std::string> torch_executed_modules = {};

Signed-off-by: Naren Dasan [email protected]

  • This commit sets the default workspace size to 1GB for GPU platforms and 256MB for Jetson Nano/TX1 platforms whose compute capability is < 6.

Signed-off-by: Dheeraj Peri [email protected]

  • This commit turns on partial compilation
    by default. Torch-TensorRT will attempt to run modules containing
    unsupported operations partially in PyTorch and partially in TensorRT

Signed-off-by: Naren Dasan [email protected]

  • This commit renames the namespaces of all
    TRTorch/Torch-TensorRT APIs. Now torchscript specific functions
    are segregated into their own torch_tensorrt::torchscript /
    torch_tensorrt.ts namespaces. Generic utils will remain in the
    torch_tensorrt namespace. Guidance on how to port forward will follow in
    the next commits
  • This changes the C++ ::ts APIs to use
    snake case and renames CompileModules to
    just compile

Signed-off-by: Naren Dasan [email protected]

  • This commit updates the PyTorch version to 1.10. To use the Python API of torch_tensorrt, please upgrade your local PyTorch to 1.10 to avoid ABI incompatibility errors. WORKSPACE and requirements files are updated accordingly.

Signed-off-by: Dheeraj Peri [email protected]

  • This commit changes the default behavior of
    the compiler: if the user does not specify an input data
    type explicitly, instead of using the enabled precision,
    the compiler now inspects the provided model to infer an
    input data type that will not cause an error if the model
    were run in Torch. In practice this means:
  • If the weights are in FP32 for the first tensor calculation
    then default input type is FP32
  • If the weights are in FP16 for the first tensor calculation
    then default input type is FP16
  • etc.

If the data type cannot be determined the compiler will
default to FP32.

This calculation is done per input tensor, so if one input
is inferred to use FP32 and another INT32, the expected
types will be (FP32, INT32) respectively.

As before, if the user defines the data type
explicitly or provides an example tensor, the data type
specified there will be respected.

Signed-off-by: Naren Dasan [email protected]

Operators Supported

Operators Currently Supported Through Converters

  • aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor)
  • aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
  • aten::abs(Tensor self) -> (Tensor)
  • aten::acos(Tensor self) -> (Tensor)
  • aten::acosh(Tensor self) -> (Tensor)
  • aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)
  • aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)
  • aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
  • aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::asin(Tensor self) -> (Tensor)
  • aten::asinh(Tensor self) -> (Tensor)
  • aten::atan(Tensor self) -> (Tensor)
  • aten::atanh(Tensor self) -> (Tensor)
  • aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[0], bool ceil_mode=False, bool count_include_pad=True) -> (Tensor)
  • aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, Tensor? mean, Tensor? var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::bmm(Tensor self, Tensor mat2) -> (Tensor)
  • aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::ceil(Tensor self) -> (Tensor)
  • aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)
  • aten::clamp_max(Tensor self, Scalar max) -> (Tensor)
  • aten::clamp_min(Tensor self, Scalar min) -> (Tensor)
  • aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
  • aten::cos(Tensor self) -> (Tensor)
  • aten::cosh(Tensor self) -> (Tensor)
  • aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)
  • aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::div_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))
  • aten::div_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
  • aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)
  • aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::erf(Tensor self) -> (Tensor)
  • aten::exp(Tensor self) -> (Tensor)
  • aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))
  • aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))
  • aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)
  • aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)
  • aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
  • aten::floor(Tensor self) -> (Tensor)
  • aten::floor_divide(Tensor self, Tensor other) -> (Tensor)
  • aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::gelu(Tensor self) -> (Tensor)
  • aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor)
  • aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)
  • aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))
  • aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)
  • aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> (Tensor(a!))
  • aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor)
  • aten::log(Tensor self) -> (Tensor)
  • aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
  • aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)
  • aten::matmul(Tensor self, Tensor other) -> (Tensor)
  • aten::max(Tensor self) -> (Tensor)
  • aten::max.other(Tensor self, Tensor other) -> (Tensor)
  • aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], int[3] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::min(Tensor self) -> (Tensor)
  • aten::min.other(Tensor self, Tensor other) -> (Tensor)
  • aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::narrow(Tensor(a) self, int dim, int start, int length) -> (Tensor(a))
  • aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> (Tensor(a))
  • aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::neg(Tensor self) -> (Tensor)
  • aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)
  • aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))
  • aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)
  • aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)
  • aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)
  • aten::prelu(Tensor self, Tensor weight) -> (Tensor)
  • aten::prod(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::reciprocal(Tensor self) -> (Tensor)
  • aten::relu(Tensor input) -> (Tensor)
  • aten::relu_(Tensor(a!) self) -> (Tensor(a!))
  • aten::repeat(Tensor self, int[] repeats) -> (Tensor)
  • aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)
  • aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)
  • aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)
  • aten::reshape(Tensor self, int[] shape) -> (Tensor)
  • aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))
  • aten::sigmoid(Tensor input) -> (Tensor)
  • aten::sigmoid_(Tensor(a!) self) -> (Tensor(a!))
  • aten::sin(Tensor self) -> (Tensor)
  • aten::sinh(Tensor self) -> (Tensor)
  • aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a))
  • aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)
  • aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])
  • aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::sqrt(Tensor self) -> (Tensor)
  • aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::t(Tensor self) -> (Tensor)
  • aten::tan(Tensor self) -> (Tensor)
  • aten::tanh(Tensor input) -> (Tensor)
  • aten::tanh_(Tensor(a!) self) -> (Tensor(a!))
  • aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
  • aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
  • aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(a|b))
  • aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
  • aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))
  • aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)
  • aten::upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)
  • aten::upsample_nearest1d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_trilinear3d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::view(Tensor(a) self, int[] size) -> (Tensor(a))
  • trt::const(Tensor self) -> (Tensor)

Operators Currently Supported Through Evaluators

  • aten::Bool.float(float b) -> (bool)
  • aten::Bool.int(int a) -> (bool)
  • aten::Float.Scalar(Scalar a) -> float
  • aten::Float.bool(bool a) -> float
  • aten::Float.int(int a) -> float
  • aten::Int.Scalar(Scalar a) -> int
  • aten::Int.bool(bool a) -> int
  • aten::Int.float(float a) -> int
  • aten::Int.int(int a) -> int
  • aten::__and__(int a, int b) -> (bool)
  • aten::__and__.bool(bool a, bool b) -> (bool)
  • aten::__getitem__.t(t[](a) list, int idx) -> (t(*))
  • aten::__is__(t1 self, t2 obj) -> bool
  • aten::__isnot__(t1 self, t2 obj) -> bool
  • aten::__not__(bool self) -> bool
  • aten::__or__(int a, int b) -> (bool)
  • aten::__round_to_zero_floordiv(int a, int b) -> (int)
  • aten::__xor__(int a, int b) -> (bool)
  • aten::add.float(float a, float b) -> (float)
  • aten::add.int(int a, int b) -> (int)
  • aten::add_.t(t[](a!) self, t[] b) -> (t[])
  • aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))
  • aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::clone(Tensor self, *, int? memory_format=None) -> (Tensor)
  • aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!))
  • aten::dim(Tensor self) -> int
  • aten::div.float(float a, float b) -> (float)
  • aten::div.int(int a, int b) -> (float)
  • aten::eq.bool(bool a, bool b) -> (bool)
  • aten::eq.float(float a, float b) -> (bool)
  • aten::eq.float_int(float a, int b) -> (bool)
  • aten::eq.int(int a, int b) -> (bool)
  • aten::eq.int_float(int a, float b) -> (bool)
  • aten::eq.str(str a, str b) -> (bool)
  • aten::floor.float(float a) -> (int)
  • aten::floor.int(int a) -> (int)
  • aten::floordiv.float(float a, float b) -> (int)
  • aten::floordiv.int(int a, int b) -> (int)
  • aten::ge.bool(bool a, bool b) -> (bool)
  • aten::ge.float(float a, float b) -> (bool)
  • aten::ge.float_int(float a, int b) -> (bool)
  • aten::ge.int(int a, int b) -> (bool)
  • aten::ge.int_float(int a, float b) -> (bool)
  • aten::gt.bool(bool a, bool b) -> (bool)
  • aten::gt.float(float a, float b) -> (bool)
  • aten::gt.float_int(float a, int b) -> (bool)
  • aten::gt.int(int a, int b) -> (bool)
  • aten::gt.int_float(int a, float b) -> (bool)
  • aten::is_floating_point(Tensor self) -> (bool)
  • aten::le.bool(bool a, bool b) -> (bool)
  • aten::le.float(float a, float b) -> (bool)
  • aten::le.float_int(float a, int b) -> (bool)
  • aten::le.int(int a, int b) -> (bool)
  • aten::le.int_float(int a, float b) -> (bool)
  • aten::len.t(t[] a) -> (int)
  • aten::lt.bool(bool a, bool b) -> (bool)
  • aten::lt.float(float a, float b) -> (bool)
  • aten::lt.float_int(float a, int b) -> (bool)
  • aten::lt.int(int a, int b) -> (bool)
  • aten::lt.int_float(int a, float b) -> (bool)
  • aten::mul.float(float a, float b) -> (float)
  • aten::mul.int(int a, int b) -> (int)
  • aten::ne.bool(bool a, bool b) -> (bool)
  • aten::ne.float(float a, float b) -> (bool)
  • aten::ne.float_int(float a, int b) -> (bool)
  • aten::ne.int(int a, int b) -> (bool)
  • aten::ne.int_float(int a, float b) -> (bool)
  • aten::neg.int(int a) -> (int)
  • aten::numel(Tensor self) -> int
  • aten::size(Tensor self) -> (int[])
  • aten::size.int(Tensor self, int dim) -> (int)
  • aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])
  • aten::sqrt.float(float a) -> (float)
  • aten::sqrt.int(int a) -> (float)
  • aten::sub.float(float a, float b) -> (float)
  • aten::sub.int(int a, int b) -> (int)
  • aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)
  • prim::dtype(Tensor a) -> (int)
  • prim::max.bool(bool a, bool b) -> (bool)
  • prim::max.float(float a, float b) -> (bool)
  • prim::max.float_int(float a, int b) -> (bool)
  • prim::max.int(int a, int b) -> (bool)
  • prim::max.int_float(int a, float b) -> (bool)
  • prim::max.self_int(int[] self) -> (int)
  • prim::min.bool(bool a, bool b) -> (bool)
  • prim::min.float(float a, float b) -> (bool)
  • prim::min.float_int(float a, int b) -> (bool)
  • prim::min.int(int a, int b) -> (bool)
  • prim::min.int_float(int a, float b) -> (bool)
  • prim::min.self_int(int[] self) -> (int)
  • prim::shape(Tensor a) -> (int[])