From 78a1c614130030e0ca8bf477dfd8c35bccdebe25 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sun, 3 May 2020 20:42:25 -0700 Subject: [PATCH] feat(//core/conversion/conversionctx): Make op precision available at conversion time through ctx Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversion.cpp | 4 ++++ core/conversion/conversionctx/ConversionCtx.cpp | 1 + core/conversion/conversionctx/ConversionCtx.h | 1 + 3 files changed, 6 insertions(+) diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 248ad52b98..c602e48b89 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -160,6 +160,10 @@ void AddInputs(ConversionCtx* ctx, TRTORCH_CHECK(profile->isValid(), "Optimization profile is invalid, please check the input range provided (conversion.AddInputs)"); ctx->cfg->addOptimizationProfile(profile); + // TODO: Enable in TRT 7.1 + // if (ctx->op_precision == nvinfer1::DataType::kINT8) { + // ctx->cfg->setCalibrationProfile(profile); + // } } void MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 4afeb8238a..acde8024c3 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -60,6 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) input_type = nvinfer1::DataType::kFLOAT; break; } + op_precision = settings.op_precision; if (settings.refit) { cfg->setFlag(nvinfer1::BuilderFlag::kREFIT); diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index 81bf99ca7c..1d2581fdc9 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -47,6 +47,7 @@ struct ConversionCtx { nvinfer1::INetworkDefinition* net; nvinfer1::IBuilderConfig* cfg; nvinfer1::DataType input_type; + nvinfer1::DataType op_precision; BuilderSettings settings; util::logging::TRTorchLogger logger; // Pointers to data that needs to remain alive until conversion is done