diff --git a/core/compiler.cpp b/core/compiler.cpp index 02a617a823..31fc2ce587 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -36,7 +36,7 @@ void AddEngineToGraph( std::string engine_id = "", bool fallback = false) { auto engine_ptr = - c10::make_intrusive(mod._ivalue()->name() + engine_id, serialized_engine, device_info); + c10::make_intrusive(mod._ivalue()->name() + "_engine_" + engine_id, serialized_engine, device_info); // Get required metadata about the engine out auto num_io = engine_ptr->num_io; auto name = engine_ptr->name; diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 6b480507e2..e70ad65420 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -11,7 +11,7 @@ namespace trtorch { namespace core { namespace runtime { -typedef enum { ABI_TARGET_IDX = 0, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex; +typedef enum { ABI_TARGET_IDX = 0, NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex; std::string slugify(std::string s) { std::replace(s.begin(), s.end(), '.', '_'); @@ -37,8 +37,8 @@ TRTEngine::TRTEngine(std::vector serialized_info) TRTORCH_CHECK( serialized_info[ABI_TARGET_IDX] == ABI_VERSION, "Program to be deserialized targets a different TRTorch ABI Version (" - << serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI (" << ABI_VERSION << ")"); - std::string _name = "deserialized_trt"; + << serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI Version (" << ABI_VERSION << ")"); + std::string _name = serialized_info[NAME_IDX]; std::string engine_info = serialized_info[ENGINE_IDX]; CudaDevice cuda_device = deserialize_device(serialized_info[DEVICE_IDX]); @@ -55,7 +55,7 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe rt = nvinfer1::createInferRuntime(logger); - name = slugify(mod_name) + "_engine"; + name = slugify(mod_name); cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()); TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine"); @@ -70,8 +70,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe uint64_t outputs = 0; for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) { - std::string name = cuda_engine->getBindingName(x); - std::string idx_s = name.substr(name.find("_") + 1); + std::string bind_name = cuda_engine->getBindingName(x); + std::string idx_s = bind_name.substr(bind_name.find("_") + 1); uint64_t idx = static_cast(std::stoi(idx_s)); if (cuda_engine->bindingIsInput(x)) { @@ -124,9 +124,12 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size()); std::vector serialize_info; - serialize_info.push_back(ABI_VERSION); - serialize_info.push_back(serialize_device(self->device_info)); - serialize_info.push_back(trt_engine); + serialize_info.resize(ENGINE_IDX + 1); + + serialize_info[ABI_TARGET_IDX] = ABI_VERSION; + serialize_info[NAME_IDX] = self->name; + serialize_info[DEVICE_IDX] = serialize_device(self->device_info); + serialize_info[ENGINE_IDX] = trt_engine; return serialize_info; }, [](std::vector seralized_info) -> c10::intrusive_ptr { diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 33ce64198b..73796bed0b 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -11,7 +11,7 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "2"; +const std::string ABI_VERSION = "3"; struct CudaDevice { int64_t id; // CUDA device id