diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 7037e9a512..1feae52b5d 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -51,6 +51,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) case nvinfer1::DataType::kINT8: TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8"); cfg->setFlag(nvinfer1::BuilderFlag::kINT8); + cfg->setFlag(nvinfer1::BuilderFlag::kFP16); input_type = nvinfer1::DataType::kFLOAT; TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator"); cfg->setInt8Calibrator(settings.calibrator);