Skip to content

Commit

Permalink
feat(//core/runtime)!: Better and more portable names for engines
Browse files Browse the repository at this point in the history
BREAKING CHANGE: This bumps the TRTorch ABI version to 3 due to
a new field for engine name included in the serialized form of
TRTEngine. This lets deserialized engines have the same name they
serialized with

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 7, 2021
1 parent c54ed13 commit 6eb3bb2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void AddEngineToGraph(
std::string engine_id = "",
bool fallback = false) {
auto engine_ptr =
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + "_engine_" + engine_id, serialized_engine, device_info);
// Get required metadata about the engine out
auto num_io = engine_ptr->num_io;
auto name = engine_ptr->name;
Expand Down
21 changes: 12 additions & 9 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace trtorch {
namespace core {
namespace runtime {

typedef enum { ABI_TARGET_IDX = 0, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
typedef enum { ABI_TARGET_IDX = 0, NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;

std::string slugify(std::string s) {
std::replace(s.begin(), s.end(), '.', '_');
Expand All @@ -37,8 +37,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
TRTORCH_CHECK(
serialized_info[ABI_TARGET_IDX] == ABI_VERSION,
"Program to be deserialized targets a different TRTorch ABI Version ("
<< serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI (" << ABI_VERSION << ")");
std::string _name = "deserialized_trt";
<< serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI Version (" << ABI_VERSION << ")");
std::string _name = serialized_info[NAME_IDX];
std::string engine_info = serialized_info[ENGINE_IDX];

CudaDevice cuda_device = deserialize_device(serialized_info[DEVICE_IDX]);
Expand All @@ -55,7 +55,7 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe

rt = nvinfer1::createInferRuntime(logger);

name = slugify(mod_name) + "_engine";
name = slugify(mod_name);

cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");
Expand All @@ -70,8 +70,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
uint64_t outputs = 0;

for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
std::string name = cuda_engine->getBindingName(x);
std::string idx_s = name.substr(name.find("_") + 1);
std::string bind_name = cuda_engine->getBindingName(x);
std::string idx_s = bind_name.substr(bind_name.find("_") + 1);
uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));

if (cuda_engine->bindingIsInput(x)) {
Expand Down Expand Up @@ -124,9 +124,12 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

std::vector<std::string> serialize_info;
serialize_info.push_back(ABI_VERSION);
serialize_info.push_back(serialize_device(self->device_info));
serialize_info.push_back(trt_engine);
serialize_info.resize(ENGINE_IDX + 1);

serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
serialize_info[NAME_IDX] = self->name;
serialize_info[DEVICE_IDX] = serialize_device(self->device_info);
serialize_info[ENGINE_IDX] = trt_engine;
return serialize_info;
},
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
Expand Down
2 changes: 1 addition & 1 deletion core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace core {
namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "2";
const std::string ABI_VERSION = "3";

struct CudaDevice {
int64_t id; // CUDA device id
Expand Down

0 comments on commit 6eb3bb2

Please sign in to comment.