diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 1feae52b5d..2993ee593e 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) { << "\n Operating Precision: " << s.op_precision \ << "\n Make Refittable Engine: " << s.refit \ << "\n Debuggable Engine: " << s.debug \ - << "\n Strict Type: " << s.strict_types \ + << "\n Strict Types: " << s.strict_types \ << "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \ << "\n Min Timing Iterations: " << s.num_min_timing_iters \ << "\n Avg Timing Iterations: " << s.num_avg_timing_iters \ @@ -51,7 +51,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) case nvinfer1::DataType::kINT8: TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8"); cfg->setFlag(nvinfer1::BuilderFlag::kINT8); - cfg->setFlag(nvinfer1::BuilderFlag::kFP16); + if (!settings.strict_types) { + cfg->setFlag(nvinfer1::BuilderFlag::kFP16); + } input_type = nvinfer1::DataType::kFLOAT; TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator"); cfg->setInt8Calibrator(settings.calibrator);