Skip to content

Commit

Permalink
feat: Save target platform as part of TRTEngine Metadata
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 21, 2024
1 parent 6a38648 commit f2e902b
Show file tree
Hide file tree
Showing 18 changed files with 423 additions and 107 deletions.
18 changes: 15 additions & 3 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ cc_library(
name = "runtime",
srcs = [
"DeviceList.cpp",
"Platform.cpp",
"RTDevice.cpp",
"TRTEngine.cpp",
"TRTEngineProfiler.cpp",
Expand All @@ -29,6 +30,7 @@ cc_library(
"runtime.cpp",
],
hdrs = [
"Platform.h",
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
Expand All @@ -41,16 +43,26 @@ cc_library(
"//core/plugins:torch_tensorrt_plugins",
"//core/util:prelude",
] + select({
":windows": ["@tensorrt_win//:nvinfer", "@libtorch_win//:libtorch"],
":use_pre_cxx11_abi": ["@tensorrt//:nvinfer", "@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@tensorrt//:nvinfer", "@libtorch"],
":use_pre_cxx11_abi": [
"@libtorch_pre_cxx11_abi//:libtorch",
"@tensorrt//:nvinfer",
],
":windows": [
"@libtorch_win//:libtorch",
"@tensorrt_win//:nvinfer",
],
"//conditions:default": [
"@libtorch",
"@tensorrt//:nvinfer",
],
}),
alwayslink = True,
)

pkg_tar(
name = "include",
srcs = [
"Platform.h",
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
Expand Down
2 changes: 2 additions & 0 deletions core/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ set(CXX_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/register_jit_hooks.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.cpp"
)

set(HEADER_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.h"
)

target_sources(${lib_name}
Expand Down
102 changes: 102 additions & 0 deletions core/runtime/Platform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include "core/runtime/Platform.h"
#include "core/runtime/runtime.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

namespace {
const std::unordered_map<std::string, Platform::PlatformEnum>& get_name_to_platform_map() {
static const std::unordered_map<std::string, Platform::PlatformEnum> name_to_platform_map = {
{"linux_aarch64", Platform::PlatformEnum::kLINUX_AARCH64},
{"linux_x86_64", Platform::PlatformEnum::kLINUX_X86_64},
{"windows_x86_64", Platform::PlatformEnum::kWIN_X86_64},
{"unknown", Platform::PlatformEnum::kUNKNOWN},
};
return name_to_platform_map;
}

const std::unordered_map<Platform::PlatformEnum, std::string>& _get_platform_name_map() {
static const std::unordered_map<Platform::PlatformEnum, std::string> platform_name_map = {
{Platform::PlatformEnum::kLINUX_AARCH64, "linux_aarch64"},
{Platform::PlatformEnum::kLINUX_X86_64, "linux_x86_64"},
{Platform::PlatformEnum::kWIN_X86_64, "windows_x86_64"},
{Platform::PlatformEnum::kUNKNOWN, "unknown"}};
return platform_name_map;
}
} // namespace

const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map() {
return _get_platform_name_map();
}

Platform::Platform() : _platform{Platform::PlatformEnum::kUNKNOWN} {}

Platform::Platform(Platform::PlatformEnum val) : _platform{val} {}

Platform::Platform(const std::string& platform_str) {
LOG_ERROR("Platform constructor: " << platform_str);
auto name_map = get_name_to_platform_map();
auto it = name_map.find(platform_str);
if (it != name_map.end()) {
_platform = it->second;
} else {
LOG_WARNING("Unknown platform " << platform_str);
_platform = Platform::PlatformEnum::kUNKNOWN;
}
}

std::string Platform::serialize() const {
auto name_map = get_platform_name_map();
auto it = name_map.find(_platform);
if (it != name_map.end()) {
return it->second;
} else {
LOG_WARNING("Attempted to serialized unknown platform tag");
return std::string("unknown");
}
}

Platform& Platform::operator=(const Platform& other) {
_platform = other._platform;
return (*this);
}

bool operator==(const Platform& lhs, const Platform& rhs) {
return lhs._platform == rhs._platform;
}

std::ostream& operator<<(std::ostream& os, const Platform& platform) {
os << platform.serialize();
return os;
}

Platform get_current_platform() {
#if defined(__linux__) || defined(__gnu_linux__)
#if defined(__aarch64__)
return Platform(Platform::PlatformEnum::kLINUX_AARCH64);
#elif defined(__amd64__) || defined(__x86_64__)
return Platform(Platform::PlatformEnum::kLINUX_X86_64);
#else
return Platform(Platform::PlatformEnum::kLINUX_X86_64);
#endif
#elif defined(_WIN32) || defined(_WIN64)
#if defined(_M_AMD64) || defined(_M_X64)
return Platform(Platform::PlatformEnum::kWIN_X86_64);
#else
return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
#else
return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
}

bool is_supported_on_current_platform(Platform target) {
// Space for more complicated platform support calculations later
return target == get_current_platform();
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
35 changes: 35 additions & 0 deletions core/runtime/Platform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
#include <string>
#include <unordered_map>

namespace torch_tensorrt {
namespace core {
namespace runtime {

struct Platform {
typedef enum {
kLINUX_X86_64 = 0,
kLINUX_AARCH64,
kWIN_X86_64,
kUNKNOWN,
} PlatformEnum;

PlatformEnum _platform = Platform::kUNKNOWN;

Platform();
Platform(PlatformEnum val);
Platform(const std::string& platform_str);
std::string serialize() const;
Platform& operator=(const Platform& other);

friend std::ostream& operator<<(std::ostream& os, const Platform& device);
friend bool operator==(const Platform& lhs, const Platform& rhs);
};

const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map();
Platform get_current_platform();
bool is_supported_on_current_platform(Platform target);

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
17 changes: 15 additions & 2 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ TRTEngine::TRTEngine(
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
const std::string& serialized_metadata)
: TRTEngine(
Expand All @@ -42,6 +43,7 @@ TRTEngine::TRTEngine(
cuda_device,
_in_binding_names,
_out_binding_names,
target_platform,
hardware_compatible,
serialized_metadata) {}

Expand All @@ -52,6 +54,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
RTDevice(serialized_info[DEVICE_IDX]),
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
Platform(serialized_info[TARGET_PLATFORM_IDX]),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
serialized_info[SERIALIZED_METADATA_IDX]) {}

Expand All @@ -61,12 +64,22 @@ TRTEngine::TRTEngine(
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
const std::string& serialized_metadata) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
<< get_current_platform() << ")");
this->target_platform = target_platform;

this->cudagraph_mempool_id = at::cuda::graph_pool_handle();

this->hardware_compatible = hardware_compatible;
this->serialized_metadata = serialized_metadata;
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");

this->serialized_metadata = serialized_metadata;
device_info = most_compatible_device.value();
multi_gpu_device_check();
set_rt_device(device_info);
Expand Down Expand Up @@ -196,7 +209,6 @@ TRTEngine::TRTEngine(
}

TRTEngine::~TRTEngine() {
cudagraph.reset();
trt_engine_profiler.reset();
exec_ctx.reset();
cuda_engine.reset();
Expand Down Expand Up @@ -276,6 +288,7 @@ std::string TRTEngine::to_str() const {
ss << " ]" << std::endl;
ss << " Device: " << device_info << std::endl;
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
// clang-format on
return ss.str();
}
Expand Down
7 changes: 7 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,30 @@ struct TRTEngine : torch::CustomClassHolder {
bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
// in compilation
Platform target_platform;

~TRTEngine();
TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
const std::string& serialized_metadata = "");

TRTEngine(std::vector<std::string> serialized_info);

TRTEngine(
const std::string& mod_name,
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
const std::string& serialized_metadata = "");

TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
Expand All @@ -75,6 +81,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
at::cuda::MempoolId_t cudagraph_mempool_id;

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
Expand Down
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
if (need_cudagraphs_record) {
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
compiled_engine->cudagraph.capture_begin();
compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id);
compiled_engine->exec_ctx->enqueueV3(recording_stream);
compiled_engine->cudagraph.capture_end();

Expand Down
26 changes: 26 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <codecvt>

#include "core/runtime/Platform.h"
#include "core/runtime/runtime.h"
#include "core/util/macros.h"

namespace torch_tensorrt {
namespace core {
Expand Down Expand Up @@ -103,11 +105,14 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize();
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
LOG_DEBUG("Serialized Target Platform: " << self->target_platform);

return serialize_info;
},
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
LOG_ERROR(serialized_info[TARGET_PLATFORM_IDX]);
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
TRTEngine::verify_serialization_fmt(serialized_info);
return c10::make_intrusive<TRTEngine>(serialized_info);
Expand Down Expand Up @@ -137,7 +142,28 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; });
m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; });
m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; });
m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; });
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
m.def("_platform_linux_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
return it->second;
});
m.def("_platform_linux_aarch64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_AARCH64);
return it->second;
});
m.def("_platform_win_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kWIN_X86_64);
return it->second;
});
m.def("_platform_unknown", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kUNKNOWN);
return it->second;
});
m.def("get_current_platform", []() -> std::string {
auto it = get_platform_name_map().find(get_current_platform()._platform);
return it->second;
});
}

} // namespace
Expand Down
4 changes: 3 additions & 1 deletion core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <utility>
#include "ATen/core/function_schema.h"
#include "NvInfer.h"
#include "core/runtime/Platform.h"
#include "core/runtime/RTDevice.h"
#include "core/runtime/TRTEngine.h"
#include "core/util/prelude.h"
Expand All @@ -15,7 +16,7 @@ namespace core {
namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "5";
const std::string ABI_VERSION = "6";
extern bool MULTI_DEVICE_SAFE_MODE;
extern bool CUDAGRAPHS_MODE;

Expand All @@ -28,6 +29,7 @@ typedef enum {
OUTPUT_BINDING_NAMES_IDX,
HW_COMPATIBLE_IDX,
SERIALIZED_METADATA_IDX,
TARGET_PLATFORM_IDX,
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

Expand Down
Loading

0 comments on commit f2e902b

Please sign in to comment.