Skip to content

Commit

Permalink
feat(//core/conversion/conversionctx): Make op precision available at
Browse files Browse the repository at this point in the history
conversion time through ctx

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 4, 2020
1 parent cd6b1b9 commit 78a1c61
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ void AddInputs(ConversionCtx* ctx,
TRTORCH_CHECK(profile->isValid(), "Optimization profile is invalid, please check the input range provided (conversion.AddInputs)");

ctx->cfg->addOptimizationProfile(profile);
// TODO: Enable in TRT 7.1
// if (ctx->op_precision == nvinfer1::DataType::kINT8) {
// ctx->cfg->setCalibrationProfile(profile);
// }
}

void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
Expand Down
1 change: 1 addition & 0 deletions core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
input_type = nvinfer1::DataType::kFLOAT;
break;
}
op_precision = settings.op_precision;

if (settings.refit) {
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
Expand Down
1 change: 1 addition & 0 deletions core/conversion/conversionctx/ConversionCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct ConversionCtx {
nvinfer1::INetworkDefinition* net;
nvinfer1::IBuilderConfig* cfg;
nvinfer1::DataType input_type;
nvinfer1::DataType op_precision;
BuilderSettings settings;
util::logging::TRTorchLogger logger;
// Pointers to data that needs to remain alive until conversion is done
Expand Down

0 comments on commit 78a1c61

Please sign in to comment.