Unifying the FX and TS Frontends #1404

Merged
merged 23 commits from shared_core into master on Nov 21, 2022

Changes from 3 commits

Commits (23)
e84c8aa
feat: Adding profiling support to the runtime
narendasan Oct 13, 2022
65a7c13
refactor: A new TRTModule implementation using the internal runtime w…
narendasan Oct 13, 2022
808f3e2
feat: let Input generate random tensors following the spec
narendasan Oct 13, 2022
0f003b8
feat!(//core/runtime): Allow the Runtime to use binding names to alig…
narendasan Nov 3, 2022
df3ac77
fix(//core/runtime): Resolving some issues with the runtime ABI
narendasan Nov 4, 2022
e804455
feat(//core/runtime): Adding a TRT layer profiler
narendasan Nov 4, 2022
bbaf152
feat(//py): Exposed the new runtime in Python
narendasan Nov 4, 2022
71872df
feat(//py/torch_tensorrt/fx): Compliant TRTModule implementation based
narendasan Nov 11, 2022
10afcb2
refactor: CUDADevice -> RTDevice for better distinction from compile
narendasan Nov 11, 2022
d14c7c4
feat(//examples): Demo that you can compile using FX then deploy in
narendasan Nov 11, 2022
7df9032
refactor(//py/torch_tensorrt): Updates to existing APIs for use in fx
narendasan Nov 11, 2022
e3b01f3
feat(//core/runtime): Encode TRT engine in base64 instead of raw bytes
narendasan Nov 18, 2022
ea270e3
feat(//py/torch_tensorrt/fx): Adding the option to use the experimental
narendasan Nov 18, 2022
c7e757b
fix(//core/runtime): Fixing a bug where if an exception is thrown in
narendasan Nov 18, 2022
2e299db
feat(//py/torch_tensorrt/TRTModule): Allow state_dict extraction
narendasan Nov 18, 2022
e9ae43e
Merge remote-tracking branch 'origin/master' into shared_core
narendasan Nov 18, 2022
418bc6f
chore: Addressing merge conflicts
narendasan Nov 18, 2022
131759a
chore: lint
narendasan Nov 18, 2022
4c544a3
chore: remove print statements
narendasan Nov 18, 2022
bdc48d4
fix: Fix cmake build
narendasan Nov 18, 2022
58681f9
refactor: Add a suffix to the TRTModuleNext class while it's
narendasan Nov 21, 2022
71082d3
docs: Update docs and examples
narendasan Nov 21, 2022
e782cc9
refactor: Reorder the API since everything but the engine is optional
narendasan Nov 21, 2022
8 changes: 4 additions & 4 deletions core/compiler.cpp
@@ -31,7 +31,7 @@ void AddEngineToGraph(
torch::jit::script::Module mod,
std::shared_ptr<torch::jit::Graph>& g,
const std::string& serialized_engine,
- runtime::CudaDevice& device_info,
+ runtime::CUDADevice& device_info,
std::string engine_id = "",
bool fallback = false) {
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(
@@ -166,7 +166,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_info.engine_settings.device;
- auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+ auto cuda_device = runtime::CUDADevice(device_spec.gpu_id, device_spec.device_type);
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

seg_block.update_graph(temp_g);
@@ -283,7 +283,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");

auto device_spec = cfg.convert_info.engine_settings.device;
- auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+ auto cuda_device = runtime::CUDADevice(device_spec.gpu_id, device_spec.device_type);

for (const torch::jit::Method& method : mod.get_methods()) {
if (method.name().compare("forward") == 0) {
@@ -342,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
return new_mod;
}

- torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device) {
+ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CUDADevice cuda_device) {
std::ostringstream engine_id;
engine_id << reinterpret_cast<const int*>(&engine);
torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
2 changes: 1 addition & 1 deletion core/compiler.h
@@ -28,7 +28,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

- torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device);
+ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CUDADevice cuda_device);

void set_device(const int gpu_id);

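A minimal usage sketch of the EmbedEngineInNewModule entry point declared above (illustrative only, not part of this diff; the engine path, GPU id, and file-reading helper are assumptions):

```cpp
#include <fstream>
#include <sstream>
#include <string>

#include "core/compiler.h"
#include "core/runtime/runtime.h"

// Wrap an already-serialized TensorRT engine in a fresh TorchScript module
// whose forward() executes the engine.
torch::jit::script::Module load_embedded_engine(const std::string& engine_path) {
  // Read the raw serialized engine bytes from disk.
  std::ifstream file(engine_path, std::ios::binary);
  std::stringstream buffer;
  buffer << file.rdbuf();

  // Describe the GPU the engine should be bound to.
  auto device = torch_tensorrt::core::runtime::CUDADevice(0, nvinfer1::DeviceType::kGPU);

  return torch_tensorrt::core::EmbedEngineInNewModule(buffer.str(), device);
}
```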
5 changes: 3 additions & 2 deletions core/conversion/converters/impl/expand.cpp
@@ -374,12 +374,13 @@ auto expand_registrations TORCHTRT_UNUSED =

// Collapse repeated dimension back into desired dimension
std::vector<int64_t> collapse_shape_vec;
- for (int k = 0; k < repeat_shape_dims.nbDims; k++) {
+ for (int64_t k = 0; k < repeat_shape_dims.nbDims; k++) {
if (k == dim) {
- int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[++k];
+ int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k+1];
// Set dim size to -1 if repeat is being done on dynamic dim
collapse_dim = std::max(collapse_dim, (int64_t)-1);
collapse_shape_vec.push_back(collapse_dim);
+ k++;
} else {
collapse_shape_vec.push_back(repeat_shape_dims.d[k]);
}
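The removed line both read and incremented k in a single expression (repeat_shape_dims.d[k] * repeat_shape_dims.d[++k]); the two operands of * are unsequenced, so that is undefined behavior and could pair the wrong dimensions. The replacement indexes with k + 1 and advances k afterwards. A standalone sketch of the corrected collapse logic (illustrative only; a plain std::vector stands in for repeat_shape_dims):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Stand-in for repeat_shape_dims: dims after the repeat has been materialized.
  std::vector<int64_t> repeat_shape = {2, 3, 4, 5};
  int64_t dim = 1; // the repeated dimension to collapse back

  std::vector<int64_t> collapse_shape_vec;
  for (int64_t k = 0; k < static_cast<int64_t>(repeat_shape.size()); k++) {
    if (k == dim) {
      // Read k and k + 1 explicitly instead of mixing a read of k with ++k.
      int64_t collapse_dim = repeat_shape[k] * repeat_shape[k + 1];
      // -1 marks a dynamic dimension, so clamp there.
      collapse_dim = std::max(collapse_dim, static_cast<int64_t>(-1));
      collapse_shape_vec.push_back(collapse_dim);
      k++; // skip the dimension that was just folded in
    } else {
      collapse_shape_vec.push_back(repeat_shape[k]);
    }
  }

  for (auto d : collapse_shape_vec) {
    std::cout << d << " "; // prints: 2 12 5
  }
  std::cout << std::endl;
  return 0;
}
```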
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/select.cpp
@@ -280,7 +280,7 @@ auto select_registrations TORCHTRT_UNUSED =

std::vector<nvinfer1::ITensor*> tensors;
std::vector<int32_t> adv_idx_indices;
- for (auto i = 0; i < ts.size(); i++) {
+ for (size_t i = 0; i < ts.size(); i++) {
auto t = ts[i];
if (t.isTensor()) {
auto torch_tensor = t.toTensor().to(torch::kInt32);
10 changes: 8 additions & 2 deletions core/runtime/BUILD
@@ -13,7 +13,7 @@ config_setting(
cc_library(
name = "runtime",
srcs = [
"CudaDevice.cpp",
"CUDADevice.cpp",
"DeviceList.cpp",
"TRTEngine.cpp",
"execute_engine.cpp",
@@ -22,6 +22,8 @@ cc_library(
],
hdrs = [
"runtime.h",
"CUDADevice.h",
"TRTEngine.h"
],
deps = [
"@tensorrt//:nvinfer",
@@ -36,6 +38,10 @@ cc_library(

pkg_tar(
name = "include",
srcs = ["runtime.h"],
srcs = [
"runtime.h",
"CUDADevice.h",
"TRTEngine.h"
],
package_dir = "core/runtime/",
)
4 changes: 3 additions & 1 deletion core/runtime/CMakeLists.txt
@@ -2,7 +2,7 @@ set(lib_name "core_runtime")
add_library(${lib_name} OBJECT)

set(CXX_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/CudaDevice.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/CUDADevice.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/DeviceList.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.cpp"
@@ -12,6 +12,8 @@ set(CXX_SRCS

set(HEADER_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
"${CMAKE_CURRENT_SOURCE_DIR}/CUDADevice.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
)

target_sources(${lib_name}
16 changes: 8 additions & 8 deletions core/runtime/CudaDevice.cpp → core/runtime/CUDADevice.cpp
@@ -11,10 +11,10 @@ const std::string DEVICE_INFO_DELIM = "%";

typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex;

- CudaDevice::CudaDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}
+ CUDADevice::CUDADevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}

- CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
-   CudaDevice cuda_device;
+ CUDADevice::CUDADevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
+   CUDADevice cuda_device;
cudaDeviceProp device_prop;

// Device ID
@@ -41,7 +41,7 @@ CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
// NOTE: Serialization Format for Device Info:
// id%major%minor%(enum)device_type%device_name

- CudaDevice::CudaDevice(std::string device_info) {
+ CUDADevice::CUDADevice(std::string device_info) {
LOG_DEBUG("Deserializing Device Info: " << device_info);

std::vector<std::string> tokens;
@@ -66,7 +66,7 @@ CudaDevice::CudaDevice(std::string device_info) {
LOG_DEBUG("Deserialized Device Info: " << *this);
}

- CudaDevice& CudaDevice::operator=(const CudaDevice& other) {
+ CUDADevice& CUDADevice::operator=(const CUDADevice& other) {
id = other.id;
major = other.major;
minor = other.minor;
@@ -75,7 +75,7 @@ CudaDevice& CudaDevice::operator=(const CudaDevice& other) {
return (*this);
}

- std::string CudaDevice::serialize() {
+ std::string CUDADevice::serialize() {
std::vector<std::string> content;
content.resize(DEVICE_NAME_IDX + 1);

@@ -98,13 +98,13 @@ std::string CudaDevice::serialize() {
return serialized_device_info;
}

- std::string CudaDevice::getSMCapability() const {
+ std::string CUDADevice::getSMCapability() const {
std::stringstream ss;
ss << major << "." << minor;
return ss.str();
}

- std::ostream& operator<<(std::ostream& os, const CudaDevice& device) {
+ std::ostream& operator<<(std::ostream& os, const CUDADevice& device) {
os << "Device(ID: " << device.id << ", Name: " << device.device_name << ", SM Capability: " << device.major << '.'
<< device.minor << ", Type: " << device.device_type << ')';
return os;
33 changes: 33 additions & 0 deletions core/runtime/CUDADevice.h
@@ -0,0 +1,33 @@
#pragma once
#include <string>
#include "NvInfer.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

struct CUDADevice {
int64_t id; // CUDA device id
int64_t major; // CUDA compute major version
int64_t minor; // CUDA compute minor version
nvinfer1::DeviceType device_type;
std::string device_name;

CUDADevice();
CUDADevice(int64_t gpu_id, nvinfer1::DeviceType device_type);
CUDADevice(std::string serialized_device_info);
~CUDADevice() = default;
CUDADevice(const CUDADevice& other) = default;
CUDADevice& operator=(const CUDADevice& other);
std::string serialize();
std::string getSMCapability() const;
friend std::ostream& operator<<(std::ostream& os, const CUDADevice& device);
};

void set_cuda_device(CUDADevice& cuda_device);
// Gets the current active GPU (DLA will not show up through this)
CUDADevice get_current_device();

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
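The new header, together with the id%major%minor%(enum)device_type%device_name format documented in CUDADevice.cpp above, gives a simple serialize/deserialize roundtrip for device info. A hedged sketch (device 0 and the example output are illustrative; requires a CUDA-capable build):

```cpp
#include <iostream>
#include <string>

#include "NvInfer.h"
#include "core/runtime/CUDADevice.h"

int main() {
  using torch_tensorrt::core::runtime::CUDADevice;

  // Query device 0; the constructor fills in the SM major/minor version and
  // the device name from the CUDA device properties.
  CUDADevice dev(0, nvinfer1::DeviceType::kGPU);

  // Serializes to the delimited form id%major%minor%(enum)device_type%device_name,
  // e.g. "0%8%6%0%NVIDIA A10G" (values illustrative).
  std::string blob = dev.serialize();
  std::cout << "serialized: " << blob << std::endl;

  // The string constructor reverses the process, e.g. when an engine and its
  // target device are restored from a TorchScript archive.
  CUDADevice restored(blob);
  std::cout << restored << std::endl;
  return 0;
}
```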
6 changes: 3 additions & 3 deletions core/runtime/DeviceList.cpp
@@ -15,19 +15,19 @@ DeviceList::DeviceList() {
}

for (int i = 0; i < num_devices; i++) {
- device_list[i] = CudaDevice(i, nvinfer1::DeviceType::kGPU);
+ device_list[i] = CUDADevice(i, nvinfer1::DeviceType::kGPU);
}

// REVIEW: DO WE CARE ABOUT DLA?

LOG_DEBUG("Runtime:\n Available CUDA Devices: \n" << this->dump_list());
}

- void DeviceList::insert(int device_id, CudaDevice cuda_device) {
+ void DeviceList::insert(int device_id, CUDADevice cuda_device) {
device_list[device_id] = cuda_device;
}

- CudaDevice DeviceList::find(int device_id) {
+ CUDADevice DeviceList::find(int device_id) {
return device_list[device_id];
}

14 changes: 11 additions & 3 deletions core/runtime/TRTEngine.cpp
@@ -16,7 +16,7 @@ std::string slugify(std::string s) {
return s;
}

- TRTEngine::TRTEngine(std::string serialized_engine, CudaDevice cuda_device) {
+ TRTEngine::TRTEngine(std::string serialized_engine, CUDADevice cuda_device) {
std::string _name = "deserialized_trt";
new (this) TRTEngine(_name, serialized_engine, cuda_device);
}
@@ -33,11 +33,11 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) {
std::string _name = serialized_info[NAME_IDX];
std::string engine_info = serialized_info[ENGINE_IDX];

- CudaDevice cuda_device(serialized_info[DEVICE_IDX]);
+ CUDADevice cuda_device(serialized_info[DEVICE_IDX]);
new (this) TRTEngine(_name, engine_info, cuda_device);
}

- TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device) {
+ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CUDADevice cuda_device) {
auto most_compatible_device = get_most_compatible_device(cuda_device);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
device_info = most_compatible_device.value();
@@ -85,6 +85,14 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
LOG_DEBUG(*this);
}

+ void TRTEngine::set_paths() {
+   execution_profile_path = profile_path + "/" + name + "_execution_profile.trace";
+   device_profile_path = profile_path + "/" + name + "_device_config_profile.trace";
+   input_profile_path = profile_path + "/" + name + "_input_profile.trace";
+   output_profile_path = profile_path + "/" + name + "_output_profile.trace";
+   enqueue_profile_path = profile_path + "/" + name + "_enqueue_profile.trace";
+ }

TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
rt = other.rt;
cuda_engine = other.cuda_engine;
55 changes: 55 additions & 0 deletions core/runtime/TRTEngine.h
@@ -0,0 +1,55 @@
#pragma once
#include <map>
#include <memory>
#include <mutex>
#include <utility>
#include "ATen/core/function_schema.h"
#include "NvInfer.h"
#include "core/util/prelude.h"
#include "torch/custom_class.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
std::pair<uint64_t, uint64_t> num_io;
std::string name;
std::mutex mu;
CUDADevice device_info;

std::string execution_profile_path;
std::string device_profile_path;
std::string input_profile_path;
std::string output_profile_path;
std::string enqueue_profile_path;
std::string profile_path = "/tmp";

std::unordered_map<uint64_t, uint64_t> in_binding_map;
std::unordered_map<uint64_t, uint64_t> out_binding_map;

#ifndef NDEBUG
bool debug = true;
#else
bool debug = false;
#endif

~TRTEngine() = default;
TRTEngine(std::string serialized_engine, CUDADevice cuda_device);
TRTEngine(std::vector<std::string> serialized_info);
TRTEngine(std::string mod_name, std::string serialized_engine, CUDADevice cuda_device);
TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
void set_paths();
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
};

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
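The header above pins down the runtime-side engine object: the TensorRT runtime/engine/context handles, the in/out binding maps used to align I/O, and the profiling trace paths composed by set_paths() under profile_path ("/tmp" by default). A hedged construction sketch mirroring the c10::make_intrusive usage in compiler.cpp (the engine file path and read helper are assumptions):

```cpp
#include <fstream>
#include <sstream>
#include <string>

#include "core/runtime/CUDADevice.h"
#include "core/runtime/TRTEngine.h"

static std::string read_binary_file(const std::string& path) {
  std::ifstream file(path, std::ios::binary);
  std::stringstream buffer;
  buffer << file.rdbuf();
  return buffer.str();
}

int main() {
  using namespace torch_tensorrt::core::runtime;

  // The constructor asks get_most_compatible_device() for a GPU that matches
  // the requested one and refuses to deserialize the engine if none is found.
  CUDADevice dev(0, nvinfer1::DeviceType::kGPU);
  auto engine = c10::make_intrusive<TRTEngine>(
      "example_module", read_binary_file("/tmp/example.engine"), dev);

  // Layer/enqueue/input/output profiling traces are written under
  // engine->profile_path ("/tmp" by default); set_paths() composes the
  // individual *_profile.trace file names from the module name.
  return 0;
}
```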