Unifying the FX and TS Runtimes (#1404)
* feat: Adding profiling support to the runtime

Signed-off-by: Naren Dasan <[email protected]>

* refactor: A new TRTModule implementation using the internal runtime which should give TS for free

Signed-off-by: Naren Dasan <[email protected]>

* feat: let Input generate random tensors following the spec

Signed-off-by: Naren Dasan <[email protected]>

* feat!(//core/runtime): Allow the Runtime to use binding names to align I/O

BREAKING CHANGE: This commit contains an ABI version upgrade meaning
that existing compiled modules will not work with this runtime.
Recompilation with a newer version of Torch-TensorRT will fix this.

This also amends the C++ API to allow users to explicitly set binding names
in the order in which they will be passed in and are expected to be returned.
This change is backwards compatible with the current API.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
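As a reference for the binding-name change above, here is a minimal C++ sketch of how the amended embedding API can be called. The signature matches the EmbedEngineInNewModule change in core/compiler.h further down, but the include paths, namespace qualification, engine path, and binding names are illustrative assumptions, not code from this commit:

```cpp
#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "core/compiler.h"        // assumed include path for the core compiler API
#include "core/runtime/runtime.h" // assumed include path for the runtime types

// Sketch: wrap a prebuilt serialized TensorRT engine in a fresh TorchScript
// module, telling the runtime explicitly which bindings map to which
// positional inputs and outputs.
torch::jit::script::Module embed_with_bindings(const std::string& engine_path) {
  std::ifstream file(engine_path, std::ios::binary);
  std::stringstream buffer;
  buffer << file.rdbuf();
  std::string serialized_engine = buffer.str();

  auto device = torch_tensorrt::core::runtime::RTDevice(0, nvinfer1::DeviceType::kGPU);

  // Binding names are illustrative; they must match the names baked into the engine.
  std::vector<std::string> input_bindings = {"input_0"};
  std::vector<std::string> output_bindings = {"output_0"};

  return torch_tensorrt::core::EmbedEngineInNewModule(
      serialized_engine, device, input_bindings, output_bindings);
}
```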

* fix(//core/runtime): Resolving some issues with the runtime ABI

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* feat(//core/runtime): Adding a TRT layer profiler

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
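The profiler added here hooks TensorRT's per-layer timing callbacks. As a rough illustration of the general pattern only (not the actual TRTEngineProfiler code, whose interface may differ), a custom nvinfer1::IProfiler can accumulate layer times like this:

```cpp
#include <iostream>
#include <map>
#include <string>

#include "NvInfer.h"

// Generic per-layer timing profiler; attach with context->setProfiler(&profiler)
// before executing the engine, then dump the accumulated times. The callback
// signature shown is the one used by recent TensorRT releases.
struct LayerTimeProfiler : public nvinfer1::IProfiler {
  std::map<std::string, float> total_ms;

  void reportLayerTime(const char* layer_name, float ms) noexcept override {
    total_ms[layer_name] += ms;
  }

  void dump() const {
    for (const auto& kv : total_ms) {
      std::cout << kv.first << ": " << kv.second << " ms" << std::endl;
    }
  }
};
```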

* feat(//py): Exposed the new runtime in Python

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* feat(//py/torch_tensorrt/fx): Compliant TRTModule implementation based on the shared Torch-TensorRT runtime

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* refactor: CudaDevice -> RTDevice for better distinction from the compile-time device

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* feat(//examples): Demo that you can compile using FX then deploy in
TS!!!

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
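On the deployment side this looks like loading any other TorchScript module from LibTorch. A minimal sketch, assuming the FX-compiled module was already serialized to a TorchScript file from Python (the path and input shape below are illustrative):

```cpp
#include <torch/script.h>

#include <iostream>
#include <vector>

int main() {
  // Load the TorchScript module that embeds the TensorRT engine produced by
  // the FX frontend (file name is illustrative).
  torch::jit::Module module = torch::jit::load("trt_fx_module.ts");
  module.to(torch::kCUDA);
  module.eval();

  // Build an example input on the GPU and run inference.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::randn({1, 3, 224, 224}, torch::kCUDA));

  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << output.sizes() << std::endl;
  return 0;
}
```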

* refactor(//py/torch_tensorrt): Updates to existing APIs for use in fx

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* feat(//core/runtime): Encode TRT engine in base64 instead of raw bytes

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* feat(//py/torch_tensorrt/fx): Adding the option to use the experimental
runtime

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* fix(//core/runtime): Fixing a bug where an exception thrown in a downstream constructor would cause a segfault

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* feat(//py/torch_tensorrt/TRTModule): Allow state_dict extraction

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* chore: Addressing merge conflicts

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* chore: lint

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* chore: remove print statements

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* fix: Fix cmake build

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* refactor: Add a suffix to the TRTModuleNext class while it's
experimental

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* docs: Update docs and examples

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

* refactor: Reorder the API since everything but the engine is optional

Also add a new destructor to order cleanup

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan authored Nov 21, 2022
1 parent e7bb8c2 commit 0e7f4fe
Showing 50 changed files with 1,704 additions and 281 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -2,8 +2,8 @@
 cmake_minimum_required(VERSION 3.17)
 project(Torch-TensorRT LANGUAGES CXX)
 
-# use c++17
-set(CMAKE_CXX_STANDARD 17)
+# use c++14 like PyTorch
+set(CMAKE_CXX_STANDARD 14)
 
 # Build the libraries with -fPIC
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
34 changes: 26 additions & 8 deletions core/compiler.cpp
@@ -31,11 +31,17 @@ void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
     const std::string& serialized_engine,
-    runtime::CudaDevice& device_info,
+    runtime::RTDevice& device_info,
+    const std::vector<std::string>& input_binding_names,
+    const std::vector<std::string>& output_binding_names,
     std::string engine_id = "",
     bool fallback = false) {
   auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(
-      mod._ivalue()->name() + "_engine_" + engine_id, serialized_engine, device_info);
+      mod._ivalue()->name() + "_engine_" + engine_id,
+      serialized_engine,
+      device_info,
+      input_binding_names,
+      output_binding_names);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
   auto name = engine_ptr->name;
@@ -162,8 +168,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
         auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
         auto temp_g = std::make_shared<torch::jit::Graph>();
         auto device_spec = convert_info.engine_settings.device;
-        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-        AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
+        auto cuda_device = runtime::RTDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(
+            new_mod,
+            temp_g,
+            engine,
+            cuda_device,
+            std::vector<std::string>(),
+            std::vector<std::string>(),
+            trt_engine_id.str(),
+            true);
 
         seg_block.update_graph(temp_g);
       }
@@ -279,7 +293,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
   auto device_spec = cfg.convert_info.engine_settings.device;
-  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+  auto cuda_device = runtime::RTDevice(device_spec.gpu_id, device_spec.device_type);
 
   for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
@@ -327,7 +341,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
           "Not all operations in graph are supported by the compiler");
       // TODO find the right
       auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
-      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+      AddEngineToGraph(new_mod, new_g, engine, cuda_device, std::vector<std::string>(), std::vector<std::string>());
     }
     auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
     auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
@@ -338,12 +352,16 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   return new_mod;
 }
 
-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device) {
+torch::jit::script::Module EmbedEngineInNewModule(
+    const std::string& engine,
+    runtime::RTDevice cuda_device,
+    const std::vector<std::string>& input_binding_names,
+    const std::vector<std::string>& output_binding_names) {
   std::ostringstream engine_id;
   engine_id << reinterpret_cast<const int*>(&engine);
   torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
   auto new_g = std::make_shared<torch::jit::Graph>();
-  AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+  AddEngineToGraph(new_mod, new_g, engine, cuda_device, input_binding_names, output_binding_names);
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
   auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
   new_mod.type()->addMethod(new_method);
6 changes: 5 additions & 1 deletion core/compiler.h
@@ -28,7 +28,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
 
-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device);
+torch::jit::script::Module EmbedEngineInNewModule(
+    const std::string& engine,
+    runtime::RTDevice cuda_device,
+    const std::vector<std::string>& input_binding_names,
+    const std::vector<std::string>& output_binding_names);
 
 void set_device(const int gpu_id);
 
5 changes: 3 additions & 2 deletions core/conversion/converters/impl/expand.cpp
@@ -374,12 +374,13 @@ auto expand_registrations TORCHTRT_UNUSED =
 
              // Collapse repeated dimension back into desired dimension
              std::vector<int64_t> collapse_shape_vec;
-             for (int k = 0; k < repeat_shape_dims.nbDims; k++) {
+             for (int64_t k = 0; k < repeat_shape_dims.nbDims; k++) {
                if (k == dim) {
-                 int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[++k];
+                 int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k + 1];
                  // Set dim size to -1 if repeat is being done on dynamic dim
                  collapse_dim = std::max(collapse_dim, (int64_t)-1);
                  collapse_shape_vec.push_back(collapse_dim);
+                 k++;
                } else {
                  collapse_shape_vec.push_back(repeat_shape_dims.d[k]);
                }
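The second change in this hunk is more than a style fix: in the old line, the read of k on the left of the multiplication and the ++k on the right are unsequenced, so the expression is undefined behavior. A standalone sketch of the corrected pattern (names and helper are illustrative, not the converter code itself):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Collapse the repeated dimension at `dim` into its neighbor, mirroring the
// fixed loop: index with k + 1 explicitly, then advance k afterwards.
std::vector<int64_t> collapse_repeat_dim(const std::vector<int64_t>& dims, int64_t dim) {
  std::vector<int64_t> out;
  for (int64_t k = 0; k < static_cast<int64_t>(dims.size()); k++) {
    if (k == dim) {
      // Old form: dims[k] * dims[++k] -- unsequenced modification and read of k.
      int64_t collapsed = dims[k] * dims[k + 1];
      out.push_back(std::max(collapsed, static_cast<int64_t>(-1))); // -1 marks a dynamic dim
      k++; // skip the dimension that was folded in
    } else {
      out.push_back(dims[k]);
    }
  }
  return out;
}
```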
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/select.cpp
@@ -287,7 +287,7 @@ auto select_registrations TORCHTRT_UNUSED =
 
              std::vector<nvinfer1::ITensor*> tensors;
              std::vector<int32_t> adv_idx_indices;
-             for (auto i = 0; i < ts.size(); i++) {
+             for (size_t i = 0; i < ts.size(); i++) {
               auto t = ts[i];
               if (t.isTensor()) {
                 auto torch_tensor = t.toTensor().to(torch::kInt32);
16 changes: 14 additions & 2 deletions core/runtime/BUILD
@@ -13,16 +13,23 @@ config_setting(
 cc_library(
     name = "runtime",
     srcs = [
-        "CudaDevice.cpp",
         "DeviceList.cpp",
+        "RTDevice.cpp",
         "TRTEngine.cpp",
+        "TRTEngineProfiler.cpp",
         "execute_engine.cpp",
         "register_jit_hooks.cpp",
         "runtime.cpp",
     ],
     hdrs = [
+        "RTDevice.h",
+        "TRTEngine.h",
+        "TRTEngineProfiler.h",
         "runtime.h",
     ],
+    linkopts = [
+        "-lstdc++fs",
+    ],
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
@@ -36,6 +43,11 @@ cc_library(
 
 pkg_tar(
     name = "include",
-    srcs = ["runtime.h"],
+    srcs = [
+        "RTDevice.h",
+        "TRTEngine.h",
+        "TRTEngineProfiler.h",
+        "runtime.h",
+    ],
     package_dir = "core/runtime/",
 )
9 changes: 7 additions & 2 deletions core/runtime/CMakeLists.txt
@@ -2,15 +2,19 @@ set(lib_name "core_runtime")
 add_library(${lib_name} OBJECT)
 
 set(CXX_SRCS
-    "${CMAKE_CURRENT_SOURCE_DIR}/CudaDevice.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/DeviceList.cpp"
-    "${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/register_jit_hooks.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/runtime.cpp"
 )
 
 set(HEADER_FILES
+    "${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.h"
+    "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
+    "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h"
     "${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
 )
 
@@ -29,6 +33,7 @@ target_link_libraries(${lib_name}
     TensorRT::nvinfer
     torch
     core_util
+    stdc++fs
 )
 
 # Install
6 changes: 3 additions & 3 deletions core/runtime/DeviceList.cpp
@@ -15,19 +15,19 @@ DeviceList::DeviceList() {
   }
 
   for (int i = 0; i < num_devices; i++) {
-    device_list[i] = CudaDevice(i, nvinfer1::DeviceType::kGPU);
+    device_list[i] = RTDevice(i, nvinfer1::DeviceType::kGPU);
   }
 
   // REVIEW: DO WE CARE ABOUT DLA?
 
   LOG_DEBUG("Runtime:\n Available CUDA Devices: \n" << this->dump_list());
 }
 
-void DeviceList::insert(int device_id, CudaDevice cuda_device) {
+void DeviceList::insert(int device_id, RTDevice cuda_device) {
   device_list[device_id] = cuda_device;
 }
 
-CudaDevice DeviceList::find(int device_id) {
+RTDevice DeviceList::find(int device_id) {
   return device_list[device_id];
 }
 
16 changes: 8 additions & 8 deletions core/runtime/CudaDevice.cpp → core/runtime/RTDevice.cpp
@@ -11,10 +11,10 @@ const std::string DEVICE_INFO_DELIM = "%";
 
 typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex;
 
-CudaDevice::CudaDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}
+RTDevice::RTDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}
 
-CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
-  CudaDevice cuda_device;
+RTDevice::RTDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
+  RTDevice cuda_device;
   cudaDeviceProp device_prop;
 
   // Device ID
@@ -41,7 +41,7 @@ CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
 // NOTE: Serialization Format for Device Info:
 // id%major%minor%(enum)device_type%device_name
 
-CudaDevice::CudaDevice(std::string device_info) {
+RTDevice::RTDevice(std::string device_info) {
   LOG_DEBUG("Deserializing Device Info: " << device_info);
 
   std::vector<std::string> tokens;
@@ -66,7 +66,7 @@ CudaDevice::CudaDevice(std::string device_info) {
   LOG_DEBUG("Deserialized Device Info: " << *this);
 }
 
-CudaDevice& CudaDevice::operator=(const CudaDevice& other) {
+RTDevice& RTDevice::operator=(const RTDevice& other) {
   id = other.id;
   major = other.major;
   minor = other.minor;
@@ -75,7 +75,7 @@ CudaDevice& CudaDevice::operator=(const CudaDevice& other) {
   return (*this);
 }
 
-std::string CudaDevice::serialize() {
+std::string RTDevice::serialize() {
   std::vector<std::string> content;
   content.resize(DEVICE_NAME_IDX + 1);
 
@@ -98,13 +98,13 @@ std::string CudaDevice::serialize() {
   return serialized_device_info;
 }
 
-std::string CudaDevice::getSMCapability() const {
+std::string RTDevice::getSMCapability() const {
   std::stringstream ss;
   ss << major << "." << minor;
   return ss.str();
 }
 
-std::ostream& operator<<(std::ostream& os, const CudaDevice& device) {
+std::ostream& operator<<(std::ostream& os, const RTDevice& device) {
   os << "Device(ID: " << device.id << ", Name: " << device.device_name << ", SM Capability: " << device.major << '.'
      << device.minor << ", Type: " << device.device_type << ')';
   return os;
29 changes: 29 additions & 0 deletions core/runtime/RTDevice.h
@@ -0,0 +1,29 @@
+#pragma once
+#include <string>
+#include "NvInfer.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace runtime {
+
+struct RTDevice {
+  int64_t id; // CUDA device id
+  int64_t major; // CUDA compute major version
+  int64_t minor; // CUDA compute minor version
+  nvinfer1::DeviceType device_type;
+  std::string device_name;
+
+  RTDevice();
+  RTDevice(int64_t gpu_id, nvinfer1::DeviceType device_type);
+  RTDevice(std::string serialized_device_info);
+  ~RTDevice() = default;
+  RTDevice(const RTDevice& other) = default;
+  RTDevice& operator=(const RTDevice& other);
+  std::string serialize();
+  std::string getSMCapability() const;
+  friend std::ostream& operator<<(std::ostream& os, const RTDevice& device);
+};
+
+} // namespace runtime
+} // namespace core
+} // namespace torch_tensorrt
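A short usage sketch for the new header, assuming it is consumed through the core runtime target; the include path and device id are illustrative:

```cpp
#include <iostream>
#include <string>

#include "core/runtime/RTDevice.h" // assumed include path

int main() {
  using torch_tensorrt::core::runtime::RTDevice;

  // Describe GPU 0; the constructor fills in the SM capability and device name
  // from the CUDA device properties.
  RTDevice device(0, nvinfer1::DeviceType::kGPU);

  // Round-trip through the serialized form stored alongside engines:
  // id%major%minor%(enum)device_type%device_name
  std::string blob = device.serialize();
  RTDevice restored(blob);

  std::cout << restored << std::endl;
  return 0;
}
```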