Skip to content

Commit

Permalink
refactor: Review comments incorporated
Browse files Browse the repository at this point in the history
Signed-off-by: Anurag Dixit <[email protected]>
  • Loading branch information
Anurag Dixit committed Jun 10, 2021
1 parent 0bd8d28 commit 611f6a1
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 45 deletions.
20 changes: 0 additions & 20 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,13 @@
namespace trtorch {
namespace core {

// File-local registry of CUDA capable devices, keyed by CUDA device id.
static std::unordered_map<int, runtime::CudaDevice> cuda_device_list;

// Queries the CUDA runtime for all visible devices and (re)populates
// cuda_device_list with each device's id, compute capability
// (major/minor) and name.
// Aborts via TRTORCH_ASSERT / TRTORCH_CHECK if any CUDA runtime call fails.
void update_cuda_device_list(void) {
  int num_devices = 0;
  auto status = cudaGetDeviceCount(&num_devices);
  TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
  cudaDeviceProp device_prop;
  for (int i = 0; i < num_devices; i++) {
    TRTORCH_CHECK(
        (cudaGetDeviceProperties(&device_prop, i) == cudaSuccess),
        // Fixed typo in the error message: "Properies" -> "Properties".
        "Unable to read CUDA Device Properties for device id: " << i);
    std::string device_name(device_prop.name);
    runtime::CudaDevice device = {
        i, device_prop.major, device_prop.minor, nvinfer1::DeviceType::kGPU, device_name.size(), device_name};
    cuda_device_list[i] = device;
  }
}

void AddEngineToGraph(
torch::jit::script::Module mod,
std::shared_ptr<torch::jit::Graph>& g,
const std::string& serialized_engine,
runtime::CudaDevice& device_info,
std::string engine_id = "",
bool fallback = false) {
// Scan and Update the list of available cuda devices
update_cuda_device_list();
auto engine_ptr =
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
// Get required metadata about the engine out
Expand Down
2 changes: 2 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
name = slugify(mod_name) + "_engine";

cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");

// Easy way to get a unique name for each engine, maybe there is a more
// descriptive way (using something associated with the graph maybe)
id = reinterpret_cast<EngineID>(cuda_engine);
Expand Down
4 changes: 2 additions & 2 deletions core/runtime/register_trt_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ int select_cuda_device(const CudaDevice& conf_device) {
int device_id = 0;
auto dla_supported = get_dla_supported_SM();

auto cuda_device_list = DeviceList::instance().get_devices();
auto device_list = cuda_device_list.instance().get_devices();

for (auto device : cuda_device_list) {
for (auto device : device_list) {
auto compute_cap = std::to_string(device.second.major) + "." + std::to_string(device.second.minor);
// In case of DLA select the DLA supported device ID
if (conf_device.device_type == nvinfer1::DeviceType::kDLA) {
Expand Down
65 changes: 43 additions & 22 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,6 @@ CudaDevice deserialize_device(std::string device_info);

CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type);

// Singleton registry mapping CUDA device ids to their CudaDevice records.
class DeviceList {
using DeviceMap = std::unordered_map<int, CudaDevice>;
DeviceMap device_list;
// Private constructor: instances are only created through instance().
DeviceList() {}

public:
// Meyers-singleton accessor; the single shared registry for the process.
static DeviceList& instance() {
static DeviceList obj;
return obj;
}

// Adds or overwrites the entry for device_id.
void insert(int device_id, CudaDevice cuda_device) {
device_list[device_id] = cuda_device;
}
// Returns the entry for device_id by value. NOTE: unordered_map::operator[]
// default-inserts an empty CudaDevice when device_id is unknown, so a miss
// silently grows the registry rather than failing.
CudaDevice find(int device_id) {
return device_list[device_id];
}
// Returns a copy of the whole id -> device map.
DeviceMap get_devices() {
return device_list;
}
};

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs it's own runtime object
nvinfer1::IRuntime* rt;
Expand All @@ -125,6 +103,49 @@ struct TRTEngine : torch::CustomClassHolder {

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

// Registry mapping CUDA device ids to their CudaDevice records, populated
// eagerly at construction by scanning the CUDA runtime.
class DeviceList {
  using DeviceMap = std::unordered_map<int, CudaDevice>;
  DeviceMap device_list;

 public:
  // Scans all CUDA capable devices visible to the runtime and records each
  // device's id, compute capability (major/minor) and name.
  // Aborts via TRTORCH_ASSERT / TRTORCH_CHECK if a CUDA runtime call fails.
  DeviceList(void) {
    int num_devices = 0;
    auto status = cudaGetDeviceCount(&num_devices);
    TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
    cudaDeviceProp device_prop;
    for (int i = 0; i < num_devices; i++) {
      TRTORCH_CHECK(
          (cudaGetDeviceProperties(&device_prop, i) == cudaSuccess),
          // Fixed typo in the error message: "Properies" -> "Properties".
          "Unable to read CUDA Device Properties for device id: " << i);
      std::string device_name(device_prop.name);
      CudaDevice device = {
          i, device_prop.major, device_prop.minor, nvinfer1::DeviceType::kGPU, device_name.size(), device_name};
      device_list[i] = device;
    }
  }

  // Meyers-singleton accessor. NOTE(review): the constructor is public and a
  // separate namespace-scope instance also exists in this header, so this is
  // not a strict singleton — confirm which instance callers should use.
  static DeviceList& instance() {
    static DeviceList obj;
    return obj;
  }

  // Adds or overwrites the entry for device_id.
  void insert(int device_id, CudaDevice cuda_device) {
    device_list[device_id] = cuda_device;
  }

  // Returns the entry for device_id by value. NOTE: unordered_map::operator[]
  // default-inserts an empty CudaDevice when device_id is unknown, so a miss
  // silently grows the registry rather than failing.
  CudaDevice find(int device_id) {
    return device_list[device_id];
  }

  // Returns a copy of the whole id -> device map.
  DeviceMap get_devices() {
    return device_list;
  }
};

namespace {
// NOTE(review): this lives in a header (runtime.h); an unnamed namespace in a
// header gives every translation unit that includes it its own copy of
// cuda_device_list, each of which rescans the CUDA devices during static
// initialization — confirm a single shared instance (e.g. via
// DeviceList::instance()) isn't what was intended.
static DeviceList cuda_device_list;
}

} // namespace runtime
} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
* in a TorchScript module
*
* @param engine: std::string - Pre-built serialized TensorRT engine
* @param info: CompileSpec::Device - Device information
* @param device: CompileSpec::Device - Device information
*
* Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
* module. Registers execution of the engine as the forward method of the module
Expand Down

0 comments on commit 611f6a1

Please sign in to comment.