Skip to content

Commit

Permalink
fix(//core/conversion): Check for calibrator before setting int8 mode
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 24, 2020
1 parent 8d22bdd commit 3afd209
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8");
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
input_type = nvinfer1::DataType::kFLOAT;
// If the calibrator is nullptr then TRT will use default quantization
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
break;
case nvinfer1::DataType::kFLOAT:
Expand Down
16 changes: 8 additions & 8 deletions core/util/logging/TRTorchLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace trt = nvinfer1;

namespace util {
namespace logging {

TRTorchLogger::TRTorchLogger(std::string prefix, Severity severity, bool color)
: prefix_(prefix), reportable_severity_(severity), color_(color) {}

Expand All @@ -32,7 +32,7 @@ void TRTorchLogger::log(Severity severity, const char* msg) {
if (severity > reportable_severity_) {
return;
}

if (color_) {
switch (severity) {
case Severity::kINTERNAL_ERROR: std::cerr << TERM_RED; break;
Expand All @@ -41,9 +41,9 @@ void TRTorchLogger::log(Severity severity, const char* msg) {
case Severity::kINFO: std::cerr << TERM_GREEN; break;
case Severity::kVERBOSE: std::cerr << TERM_MAGENTA; break;
default: break;
}
}
}

switch (severity) {
case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
case Severity::kERROR: std::cerr << "ERROR: "; break;
Expand All @@ -52,11 +52,11 @@ void TRTorchLogger::log(Severity severity, const char* msg) {
case Severity::kVERBOSE: std::cerr << "DEBUG: "; break;
default: std::cerr << "UNKNOWN: "; break;
}

if (color_) {
std::cerr << TERM_NORMAL;
}

std::cerr << prefix_ << msg << std::endl;
}

Expand Down Expand Up @@ -92,7 +92,7 @@ bool TRTorchLogger::get_is_colored_output_on() {
return color_;
}


namespace {

TRTorchLogger& get_global_logger() {
Expand All @@ -104,7 +104,7 @@ TRTorchLogger& get_global_logger() {
static TRTorchLogger global_logger("[TRTorch] - ",
LogLevel::kERROR,
false);
#endif
#endif
return global_logger;
}

Expand Down

0 comments on commit 3afd209

Please sign in to comment.