Skip to content

Commit

Permalink
feat: Using shared_ptrs to manage TRT resources in runtime
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 7, 2021
1 parent 2cd9fad commit e336630
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
17 changes: 3 additions & 14 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,14 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
device_info = cuda_device;
set_cuda_device(device_info);

rt = nvinfer1::createInferRuntime(util::logging::get_logger());
rt = std::shared_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(util::logging::get_logger()));

name = slugify(mod_name) + "_engine";

cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
cuda_engine = std::shared_ptr<nvinfer1::ICudaEngine>(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");

// Easy way to get a unique name for each engine, maybe there is a more
// descriptive way (using something associated with the graph maybe)
id = reinterpret_cast<EngineID>(cuda_engine);

exec_ctx = cuda_engine->createExecutionContext();
exec_ctx = std::shared_ptr<nvinfer1::IExecutionContext>(cuda_engine->createExecutionContext());

uint64_t inputs = 0;
uint64_t outputs = 0;
Expand All @@ -74,7 +70,6 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
}

TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
id = other.id;
rt = other.rt;
cuda_engine = other.cuda_engine;
device_info = other.device_info;
Expand All @@ -83,12 +78,6 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
return (*this);
}

TRTEngine::~TRTEngine() {
delete exec_ctx;
delete cuda_engine;
delete rt;
}

// TODO: Implement a call method
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
// auto input_vec = inputs.vec();
Expand Down
10 changes: 5 additions & 5 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <map>
#include <memory>
#include <utility>
#include "ATen/core/function_schema.h"
#include "NvInfer.h"
Expand Down Expand Up @@ -37,18 +38,17 @@ CudaDevice deserialize_device(std::string device_info);

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs it's own runtime object
nvinfer1::IRuntime* rt;
nvinfer1::ICudaEngine* cuda_engine;
nvinfer1::IExecutionContext* exec_ctx;
std::shared_ptr<nvinfer1::IRuntime> rt;
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
std::pair<uint64_t, uint64_t> num_io;
EngineID id;
std::string name;
CudaDevice device_info;

std::unordered_map<uint64_t, uint64_t> in_binding_map;
std::unordered_map<uint64_t, uint64_t> out_binding_map;

~TRTEngine();
~TRTEngine() = default;
TRTEngine(std::string serialized_engine, CudaDevice cuda_device);
TRTEngine(std::vector<std::string> serialized_info);
TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device);
Expand Down

0 comments on commit e336630

Please sign in to comment.