Unifying the FX and TS Frontends #1404

Merged
merged 23 commits into master from shared_core
Nov 21, 2022

Conversation

narendasan
Collaborator

Description

This PR merges the backends for FX and TS so they use a common set of resources, including a shared Torch-TRT runtime. This enables users to TorchScript and serialize compiled modules after FX compilation, and adds C++ execution to FX.

Fixes #1372
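
As a rough, hedged sketch of the intended user flow (the import path mirrors py/torch_tensorrt/fx/trt_module_next.py from this diff; the engine file, engine name, and input shape are placeholders, not part of the PR):

# Minimal sketch: wrap a pre-built, serialized TensorRT engine in the
# runtime-backed TRTModule added by this PR. "engine.trt" is an assumed
# artifact from an earlier engine-build step.
import torch
from torch_tensorrt import Device
from torch_tensorrt.fx.trt_module_next import TRTModule

engine_bytes = bytearray(open("engine.trt", "rb").read())

trt_mod = TRTModule(
    engine_name="example_engine",   # placeholder name
    device_info=Device(gpu_id=0),   # GPU the engine will execute on
    serialized_engine=engine_bytes,
)

# forward() dispatches to torch.ops.tensorrt.execute_engine in the shared
# C++ runtime; inputs are a flattened list of tensors.
out = trt_mod(torch.randn(1, 3, 224, 224, device="cuda"))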

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@github-actions github-actions bot added component: api [Python] Issues re: Python API component: api [C++] Issues re: C++ API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: runtime component: tests Issues re: Tests labels Oct 13, 2022
@github-actions github-actions bot left a comment

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/expand.cpp b/tmp/changes.txt
index e6a6c01..b88ebc8 100644
--- a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/expand.cpp
+++ b/tmp/changes.txt
@@ -376,7 +376,7 @@ auto expand_registrations TORCHTRT_UNUSED =
               std::vector<int64_t> collapse_shape_vec;
               for (int64_t k = 0; k < repeat_shape_dims.nbDims; k++) {
                 if (k == dim) {
-                   int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k+1];
+                   int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k + 1];
                   // Set dim size to -1 if repeat is being done on dynamic dim
                   collapse_dim = std::max(collapse_dim, (int64_t)-1);
                   collapse_shape_vec.push_back(collapse_dim);
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp b/tmp/changes.txt
index 9f7dd16..a761f28 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp
+++ b/tmp/changes.txt
@@ -89,8 +89,8 @@ void TRTEngine::set_paths() {
  execution_profile_path = profile_path + "/" + name + "_execution_profile.trace";
  device_profile_path = profile_path + "/" + name + "_device_config_profile.trace";
  input_profile_path = profile_path + "/" + name + "_input_profile.trace";
-  output_profile_path =  profile_path + "/" + name + "_output_profile.trace";
-  enqueue_profile_path =  profile_path + "/" + name + "_enqueue_profile.trace";
+  output_profile_path = profile_path + "/" + name + "_output_profile.trace";
+  enqueue_profile_path = profile_path + "/" + name + "_enqueue_profile.trace";
}

TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp b/tmp/changes.txt
index 727c36f..9896dbe 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp
+++ b/tmp/changes.txt
@@ -62,7 +62,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

  std::unique_ptr<torch::autograd::profiler::RecordProfile> execution_profiler_guard;
  if (compiled_engine->debug) {
-    execution_profiler_guard.reset(new torch::autograd::profiler::RecordProfile(compiled_engine->execution_profile_path));
+    execution_profiler_guard.reset(
+        new torch::autograd::profiler::RecordProfile(compiled_engine->execution_profile_path));
  }

  {
@@ -113,11 +114,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
      compiled_engine->exec_ctx->setBindingDimensions(i, dims);
      gpu_handles.push_back(contig_inputs.back().data_ptr());
    }
-  TORCHTRT_CHECK(
-      compiled_engine->exec_ctx->allInputDimensionsSpecified(), "Not enough inputs provided (torch.ops.tensorrt.execute_engine)");
+    TORCHTRT_CHECK(
+        compiled_engine->exec_ctx->allInputDimensionsSpecified(),
+        "Not enough inputs provided (torch.ops.tensorrt.execute_engine)");
  }

-
  std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
  {
    std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/CUDADevice.h b/tmp/changes.txt
index 6959780..b46ba32 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/CUDADevice.h
+++ b/tmp/changes.txt
@@ -28,6 +28,6 @@ void set_cuda_device(CUDADevice& cuda_device);
// Gets the current active GPU (DLA will not show up through this)
CUDADevice get_current_device();

-} // namespace torch_tensorrt
-} // namespace core
} // namespace runtime
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/runtime.h b/tmp/changes.txt
index 420e373..37afb99 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/runtime.h
+++ b/tmp/changes.txt
@@ -5,10 +5,10 @@
#include <utility>
#include "ATen/core/function_schema.h"
#include "NvInfer.h"
-#include "core/util/prelude.h"
-#include "torch/custom_class.h"
#include "core/runtime/CUDADevice.h"
#include "core/runtime/TRTEngine.h"
+#include "core/util/prelude.h"
+#include "torch/custom_class.h"

namespace torch_tensorrt {
namespace core {
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.h b/tmp/changes.txt
index dd31fdd..19e5d9b 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.h
+++ b/tmp/changes.txt
@@ -50,6 +50,6 @@ struct TRTEngine : torch::CustomClassHolder {
  // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
};

-} // namespace torch_tensorrt
-} // namespace core
} // namespace runtime
+} // namespace core
+} // namespace torch_tensorrt
ERROR: Some files do not conform to style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@github-actions github-actions bot left a comment

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp b/tmp/changes.txt
index 727c36f..9896dbe 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/execute_engine.cpp
+++ b/tmp/changes.txt
@@ -62,7 +62,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

  std::unique_ptr<torch::autograd::profiler::RecordProfile> execution_profiler_guard;
  if (compiled_engine->debug) {
-    execution_profiler_guard.reset(new torch::autograd::profiler::RecordProfile(compiled_engine->execution_profile_path));
+    execution_profiler_guard.reset(
+        new torch::autograd::profiler::RecordProfile(compiled_engine->execution_profile_path));
  }

  {
@@ -113,11 +114,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
      compiled_engine->exec_ctx->setBindingDimensions(i, dims);
      gpu_handles.push_back(contig_inputs.back().data_ptr());
    }
-  TORCHTRT_CHECK(
-      compiled_engine->exec_ctx->allInputDimensionsSpecified(), "Not enough inputs provided (torch.ops.tensorrt.execute_engine)");
+    TORCHTRT_CHECK(
+        compiled_engine->exec_ctx->allInputDimensionsSpecified(),
+        "Not enough inputs provided (torch.ops.tensorrt.execute_engine)");
  }

-
  std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
  {
    std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp b/tmp/changes.txt
index 9f7dd16..a761f28 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp
+++ b/tmp/changes.txt
@@ -89,8 +89,8 @@ void TRTEngine::set_paths() {
  execution_profile_path = profile_path + "/" + name + "_execution_profile.trace";
  device_profile_path = profile_path + "/" + name + "_device_config_profile.trace";
  input_profile_path = profile_path + "/" + name + "_input_profile.trace";
-  output_profile_path =  profile_path + "/" + name + "_output_profile.trace";
-  enqueue_profile_path =  profile_path + "/" + name + "_enqueue_profile.trace";
+  output_profile_path = profile_path + "/" + name + "_output_profile.trace";
+  enqueue_profile_path = profile_path + "/" + name + "_enqueue_profile.trace";
}

TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
diff --git a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/expand.cpp b/tmp/changes.txt
index e6a6c01..b88ebc8 100644
--- a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/expand.cpp
+++ b/tmp/changes.txt
@@ -376,7 +376,7 @@ auto expand_registrations TORCHTRT_UNUSED =
               std::vector<int64_t> collapse_shape_vec;
               for (int64_t k = 0; k < repeat_shape_dims.nbDims; k++) {
                 if (k == dim) {
-                   int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k+1];
+                   int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k + 1];
                   // Set dim size to -1 if repeat is being done on dynamic dim
                   collapse_dim = std::max(collapse_dim, (int64_t)-1);
                   collapse_shape_vec.push_back(collapse_dim);
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.h b/tmp/changes.txt
index f76880b..6adb5a7 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.h
+++ b/tmp/changes.txt
@@ -50,6 +50,6 @@ struct TRTEngine : torch::CustomClassHolder {
  // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
};

-} // namespace torch_tensorrt
-} // namespace core
} // namespace runtime
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/CUDADevice.h b/tmp/changes.txt
index 6959780..b46ba32 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/CUDADevice.h
+++ b/tmp/changes.txt
@@ -28,6 +28,6 @@ void set_cuda_device(CUDADevice& cuda_device);
// Gets the current active GPU (DLA will not show up through this)
CUDADevice get_current_device();

-} // namespace torch_tensorrt
-} // namespace core
} // namespace runtime
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/runtime.h b/tmp/changes.txt
index 420e373..37afb99 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/runtime.h
+++ b/tmp/changes.txt
@@ -5,10 +5,10 @@
#include <utility>
#include "ATen/core/function_schema.h"
#include "NvInfer.h"
-#include "core/util/prelude.h"
-#include "torch/custom_class.h"
#include "core/runtime/CUDADevice.h"
#include "core/runtime/TRTEngine.h"
+#include "core/util/prelude.h"
+#include "torch/custom_class.h"

namespace torch_tensorrt {
namespace core {
ERROR: Some files do not conform to style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

…hich should give TS for free

Signed-off-by: Naren Dasan <[email protected]>

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- py/torch_tensorrt/fx/trt_module_next.py	2022-10-13 23:32:49.728214 +0000
+++ py/torch_tensorrt/fx/trt_module_next.py	2022-10-13 23:33:15.516818 +0000
@@ -4,42 +4,45 @@
import torch

from torch.classes.tensorrt import Engine
from torch.ops.tensorrt import execute_engine

-from torch_tensorrt import (_C, Device)
+from torch_tensorrt import _C, Device
+

class TRTModule(torch.nn.module):
    def __init__(
        self,
        engine_name: str,
        device_info: Device,
        serialized_engine: bytearray,
    ):
        super(TRTModule, self).__init__()
-        self.engine = Engine([
-            _C.rt.ABI_VERSION,
-            engine_name,
-            device_info._to_internal_cuda_device_str(),
-            serialized_engine
-        ])
+        self.engine = Engine(
+            [
+                _C.rt.ABI_VERSION,
+                engine_name,
+                device_info._to_internal_cuda_device_str(),
+                serialized_engine,
+            ]
+        )

    def forward(self, *inputs):
        try:
            assert all([i.issubclass(torch.Tensor) for i in inputs])
        except:
            raise RuntimeError("TRTModule expects a flattened list of tensors as input")
        outputs = execute_engine(list(inputs), self.engine)
        return tuple(outputs)

    def enable_profiling(self, profiler: None):
-        #TODO: CHANGE THIS SO IT MAKE MORE SENSE
+        # TODO: CHANGE THIS SO IT MAKE MORE SENSE
        self.engine.debug = True

    def disable_profiling(self):
-        #TODO: HERE TOO
+        # TODO: HERE TOO
        self.engine.debug = False

    def get_layer_info(self) -> str:
        raise RuntimeError("Engine Inspector needs to be implemented")
-        #assert TRT VERSION > 8.2
-        return self.engine.get_engine_information(_C.LayerInformationFormat.JSON)
\ No newline at end of file
+        # assert TRT VERSION > 8.2
+        return self.engine.get_engine_information(_C.LayerInformationFormat.JSON)


@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- py/torch_tensorrt/_Input.py	2022-10-13 23:49:25.765299 +0000
+++ py/torch_tensorrt/_Input.py	2022-10-13 23:49:48.468560 +0000
@@ -281,16 +281,30 @@
        return cls(shape=t.shape, dtype=t.dtype, format=frmt)

    def example_tensor(self, optimization_profile_field: str = None):
        if optimization_profile_field is not None:
            try:
-                assert any([optimization_profile_field == field_name for field_name in ["min_shape", "opt_shape", "max_shape"]])
+                assert any(
+                    [
+                        optimization_profile_field == field_name
+                        for field_name in ["min_shape", "opt_shape", "max_shape"]
+                    ]
+                )
            except:
-                raise ValueError("Invalid field name, expected one of min_shape, opt_shape, max_shape")
-
-        if optimization_profile_field is not None and self.shape_mode == Input._ShapeMode.STATIC:
-            raise ValueError("Specified a optimization profile field but the input is static")
+                raise ValueError(
+                    "Invalid field name, expected one of min_shape, opt_shape, max_shape"
+                )
+
+        if (
+            optimization_profile_field is not None
+            and self.shape_mode == Input._ShapeMode.STATIC
+        ):
+            raise ValueError(
+                "Specified a optimization profile field but the input is static"
+            )

        if self.shape_mode == Input._ShapeMode.STATIC:
            return torch.randn(self.shape).to(dtype=self.dtype)
        else:
-            return torch.randn(self.shape[optimization_profile_field]).to(dtype=self.dtype)
+            return torch.randn(self.shape[optimization_profile_field]).to(
+                dtype=self.dtype
+            )
--- py/torch_tensorrt/fx/trt_module_next.py	2022-10-13 23:49:25.773299 +0000
+++ py/torch_tensorrt/fx/trt_module_next.py	2022-10-13 23:49:50.800654 +0000
@@ -4,42 +4,45 @@
import torch

from torch.classes.tensorrt import Engine
from torch.ops.tensorrt import execute_engine

-from torch_tensorrt import (_C, Device)
+from torch_tensorrt import _C, Device
+

class TRTModule(torch.nn.module):
    def __init__(
        self,
        engine_name: str,
        device_info: Device,
        serialized_engine: bytearray,
    ):
        super(TRTModule, self).__init__()
-        self.engine = Engine([
-            _C.rt.ABI_VERSION,
-            engine_name,
-            device_info._to_internal_cuda_device_str(),
-            serialized_engine
-        ])
+        self.engine = Engine(
+            [
+                _C.rt.ABI_VERSION,
+                engine_name,
+                device_info._to_internal_cuda_device_str(),
+                serialized_engine,
+            ]
+        )

    def forward(self, *inputs):
        try:
            assert all([i.issubclass(torch.Tensor) for i in inputs])
        except:
            raise RuntimeError("TRTModule expects a flattened list of tensors as input")
        outputs = execute_engine(list(inputs), self.engine)
        return tuple(outputs)

    def enable_profiling(self, profiler: None):
-        #TODO: CHANGE THIS SO IT MAKE MORE SENSE
+        # TODO: CHANGE THIS SO IT MAKE MORE SENSE
        self.engine.debug = True

    def disable_profiling(self):
-        #TODO: HERE TOO
+        # TODO: HERE TOO
        self.engine.debug = False

    def get_layer_info(self) -> str:
        raise RuntimeError("Engine Inspector needs to be implemented")
-        #assert TRT VERSION > 8.2
-        return self.engine.get_engine_information(_C.LayerInformationFormat.JSON)
\ No newline at end of file
+        # assert TRT VERSION > 8.2
+        return self.engine.get_engine_information(_C.LayerInformationFormat.JSON)


from torch_tensorrt import (_C, Device)

class TRTModule(torch.nn.module):

Is this some staging module?

Collaborator Author

Yeah, this would be an alternative to the existing TRTModule, so it is opt-in for users to evaluate in the next release.

@narendasan narendasan self-assigned this Oct 25, 2022
@narendasan narendasan added the release: v1.3 Tagged to be included in v1.3 label Oct 25, 2022
@narendasan
Collaborator Author

@frank-wei can you take a look at this PR? Hopefully it doesn't introduce any breaking changes

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@narendasan
Collaborator Author

Also @peri044 can you review on the TS side?

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Nov 18, 2022
@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

Contributor

@frank-wei frank-wei left a comment

Generally, the PR looks good to me. It provides a way for users who want to deploy in C++, but users can also choose to use the fx TRTModule to wrap the engine and JIT-trace it later.
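
A hedged sketch of that deployment path, assuming trt_mod is a runtime-backed TRTModule as in this PR and that libtorchtrt_runtime is available wherever the saved module is loaded:

# Trace the module so the embedded engine is serialized with it.
import torch

example = torch.randn(1, 3, 224, 224, device="cuda")
traced = torch.jit.trace(trt_mod, (example,))
traced.save("trt_module.ts")

# Reload later (in Python here; C++ would use torch::jit::load). The shared
# runtime re-registers the engine when the module is deserialized.
reloaded = torch.jit.load("trt_module.ts")
out = reloaded(example)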

from torch_tensorrt._Device import Device


class TRTModule(torch.nn.Module):
Contributor

Can we consider a different name, or adding a suffix, to differentiate this from the TRTModule in the fx folder? That one is still used internally.

Collaborator Author

How about TRTModuleNext?

Contributor

Good for me.
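
For illustration, a hypothetical opt-in import under the new name; the file landed as py/torch_tensorrt/_TRTModuleNext.py (see the review file list below), so the top-level export shown here is an assumption:

# Hypothetical: opt into the new runtime-backed module under its final name.
from torch_tensorrt import TRTModuleNext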

@narendasan narendasan requested a review from gs-olive November 21, 2022 16:23
@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

experimental

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

Collaborator

@gs-olive gs-olive left a comment

Added a few minor comments; otherwise this looks good and is great to use! Adding this review as comments since these are primarily usability-based suggestions.

py/torch_tensorrt/_TRTModuleNext.py (review thread, resolved)
py/torch_tensorrt/_TRTModuleNext.py (review thread, outdated and resolved)
py/torch_tensorrt/_TRTModuleNext.py (review thread, resolved)
@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

Also new destructor to order cleanup

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@narendasan narendasan merged commit 0e7f4fe into master Nov 21, 2022
@narendasan narendasan deleted the shared_core branch November 21, 2022 23:40