From dfa9ae8764fd1465cc6dad5e4b8d513ee2ffdf21 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 19 May 2021 15:46:46 -0700 Subject: [PATCH] fix(//core/conversion/conversionctx): Guard final engine building Fixes an issue where if final network validation fails, a segfault occurs. Now an exception is thrown which can be handled by the user Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversionctx/ConversionCtx.cpp | 12 +++++++----- core/util/Exception.cpp | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 67435252b7..163962ee19 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -23,16 +23,15 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) { << "\n Max Workspace Size: " << s.workspace_size; if (s.max_batch_size != 0) { - os << "\n Max Batch Size: " << s.max_batch_size; + os << "\n Max Batch Size: " << s.max_batch_size; } else { - os << "\n Max Batch Size: Not set"; + os << "\n Max Batch Size: Not set"; } os << "\n Device Type: " << s.device.device_type \ << "\n GPU ID: " << s.device.gpu_id; - if (s.device.device_type == nvinfer1::DeviceType::kDLA) - { - os << "\n DLACore: " << s.device.dla_core; + if (s.device.device_type == nvinfer1::DeviceType::kDLA) { + os << "\n DLACore: " << s.device.dla_core; } os << "\n Engine Capability: " << s.capability \ << "\n Calibrator Created: " << (s.calibrator != nullptr); @@ -146,6 +145,9 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val std::string ConversionCtx::SerializeEngine() { auto engine = builder->buildEngineWithConfig(*net, *cfg); + if (!engine) { + TRTORCH_THROW_ERROR("Building TensorRT engine failed"); + } auto serialized_engine = engine->serialize(); engine->destroy(); auto engine_str = std::string((const char*)serialized_engine->data(), serialized_engine->size()); diff --git a/core/util/Exception.cpp b/core/util/Exception.cpp index 45ab9f8ba5..902dfbae25 100644 --- a/core/util/Exception.cpp +++ b/core/util/Exception.cpp @@ -11,7 +11,7 @@ Error::Error(const std::string& new_msg, const void* caller) : msg_stack_{new_ms } Error::Error(const char* file, const uint32_t line, const std::string& msg, const void* caller) - : Error(str("[enforce fail at ", file, ":", line, "] ", msg, "\n"), caller) {} + : Error(str("[Error thrown at ", file, ":", line, "] ", msg, "\n"), caller) {} std::string Error::msg() const { return std::accumulate(msg_stack_.begin(), msg_stack_.end(), std::string(""));