diff --git a/core/compiler.cpp b/core/compiler.cpp index 8a0621c38e..8384433299 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -27,24 +27,6 @@ namespace trtorch { namespace core { -static std::unordered_map cuda_device_list; - -void update_cuda_device_list(void) { - int num_devices = 0; - auto status = cudaGetDeviceCount(&num_devices); - TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status); - cudaDeviceProp device_prop; - for (int i = 0; i < num_devices; i++) { - TRTORCH_CHECK( - (cudaGetDeviceProperties(&device_prop, i) == cudaSuccess), - "Unable to read CUDA Device Properies for device id: " << i); - std::string device_name(device_prop.name); - runtime::CudaDevice device = { - i, device_prop.major, device_prop.minor, nvinfer1::DeviceType::kGPU, device_name.size(), device_name}; - cuda_device_list[i] = device; - } -} - void AddEngineToGraph( torch::jit::script::Module mod, std::shared_ptr& g, @@ -52,8 +34,6 @@ void AddEngineToGraph( runtime::CudaDevice& device_info, std::string engine_id = "", bool fallback = false) { - // Scan and Update the list of available cuda devices - update_cuda_device_list(); auto engine_ptr = c10::make_intrusive(mod._ivalue()->name() + engine_id, serialized_engine, device_info); // Get required metadata about the engine out diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 077dbac6ba..7bc39a0749 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -50,6 +50,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe name = slugify(mod_name) + "_engine"; cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()); + TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine"); + // Easy way to get a unique name for each engine, maybe there is a more // descriptive way (using something associated with the graph maybe) id = reinterpret_cast(cuda_engine); diff --git a/core/runtime/register_trt_op.cpp b/core/runtime/register_trt_op.cpp index 2b37760a7f..566caff773 100644 --- a/core/runtime/register_trt_op.cpp +++ b/core/runtime/register_trt_op.cpp @@ -49,9 +49,9 @@ int select_cuda_device(const CudaDevice& conf_device) { int device_id = 0; auto dla_supported = get_dla_supported_SM(); - auto cuda_device_list = DeviceList::instance().get_devices(); + auto device_list = cuda_device_list.instance().get_devices(); - for (auto device : cuda_device_list) { + for (auto device : device_list) { auto compute_cap = std::to_string(device.second.major) + "." + std::to_string(device.second.minor); // In case of DLA select the DLA supported device ID if (conf_device.device_type == nvinfer1::DeviceType::kDLA) { diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 8a9c210671..b40f7e3b1f 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -78,28 +78,6 @@ CudaDevice deserialize_device(std::string device_info); CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type); -class DeviceList { - using DeviceMap = std::unordered_map; - DeviceMap device_list; - DeviceList() {} - - public: - static DeviceList& instance() { - static DeviceList obj; - return obj; - } - - void insert(int device_id, CudaDevice cuda_device) { - device_list[device_id] = cuda_device; - } - CudaDevice find(int device_id) { - return device_list[device_id]; - } - DeviceMap get_devices() { - return device_list; - } -}; - struct TRTEngine : torch::CustomClassHolder { // Each engine needs it's own runtime object nvinfer1::IRuntime* rt; @@ -125,6 +103,49 @@ struct TRTEngine : torch::CustomClassHolder { std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine); +class DeviceList { + using DeviceMap = std::unordered_map; + DeviceMap device_list; + + public: + // Scans and updates the list of available CUDA devices + DeviceList(void) { + int num_devices = 0; + auto status = cudaGetDeviceCount(&num_devices); + TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status); + cudaDeviceProp device_prop; + for (int i = 0; i < num_devices; i++) { + TRTORCH_CHECK( + (cudaGetDeviceProperties(&device_prop, i) == cudaSuccess), + "Unable to read CUDA Device Properies for device id: " << i); + std::string device_name(device_prop.name); + CudaDevice device = { + i, device_prop.major, device_prop.minor, nvinfer1::DeviceType::kGPU, device_name.size(), device_name}; + device_list[i] = device; + } + } + + public: + static DeviceList& instance() { + static DeviceList obj; + return obj; + } + + void insert(int device_id, CudaDevice cuda_device) { + device_list[device_id] = cuda_device; + } + CudaDevice find(int device_id) { + return device_list[device_id]; + } + DeviceMap get_devices() { + return device_list; + } +}; + +namespace { +static DeviceList cuda_device_list; +} + } // namespace runtime } // namespace core } // namespace trtorch diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index e42aaf4447..7b5d03418a 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -517,7 +517,7 @@ TRTORCH_API std::string ConvertGraphToTRTEngine( * in a TorchScript module * * @param engine: std::string - Pre-built serialized TensorRT engine - * @param info: CompileSepc::Device - Device information + * @param device: CompileSepc::Device - Device information * * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript * module. Registers execution of the engine as the forward method of the module