diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 6d4f4904c7..9d579b28d1 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -41,18 +41,14 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe device_info = cuda_device; set_cuda_device(device_info); - rt = nvinfer1::createInferRuntime(util::logging::get_logger()); + rt = std::shared_ptr(nvinfer1::createInferRuntime(util::logging::get_logger())); name = slugify(mod_name) + "_engine"; - cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()); + cuda_engine = std::shared_ptr(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine"); - // Easy way to get a unique name for each engine, maybe there is a more - // descriptive way (using something associated with the graph maybe) - id = reinterpret_cast(cuda_engine); - - exec_ctx = cuda_engine->createExecutionContext(); + exec_ctx = std::shared_ptr(cuda_engine->createExecutionContext()); uint64_t inputs = 0; uint64_t outputs = 0; @@ -74,7 +70,6 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe } TRTEngine& TRTEngine::operator=(const TRTEngine& other) { - id = other.id; rt = other.rt; cuda_engine = other.cuda_engine; device_info = other.device_info; @@ -83,12 +78,6 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) { return (*this); } -TRTEngine::~TRTEngine() { - delete exec_ctx; - delete cuda_engine; - delete rt; -} - // TODO: Implement a call method // c10::List TRTEngine::Run(c10::List inputs) { // auto input_vec = inputs.vec(); diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 7677e5b349..83e724bf8e 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include "ATen/core/function_schema.h" #include "NvInfer.h" @@ -37,18 +38,17 @@ CudaDevice deserialize_device(std::string device_info); struct TRTEngine : torch::CustomClassHolder { // Each engine needs it's own runtime object - nvinfer1::IRuntime* rt; - nvinfer1::ICudaEngine* cuda_engine; - nvinfer1::IExecutionContext* exec_ctx; + std::shared_ptr rt; + std::shared_ptr cuda_engine; + std::shared_ptr exec_ctx; std::pair num_io; - EngineID id; std::string name; CudaDevice device_info; std::unordered_map in_binding_map; std::unordered_map out_binding_map; - ~TRTEngine(); + ~TRTEngine() = default; TRTEngine(std::string serialized_engine, CudaDevice cuda_device); TRTEngine(std::vector serialized_info); TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device);