From 13988638596e8ea4b35e29625ee400fa49612a3f Mon Sep 17 00:00:00 2001 From: JZZ-NOTE Date: Thu, 2 Mar 2023 06:49:32 +0000 Subject: [PATCH] add function to disable trt op by output name --- .../analysis/ir_passes/tensorrt_subgraph_pass.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 2883cb4dfe157..e9832c1d94a21 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -135,6 +135,16 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( << " is diabled by config in TensorRT"; return false; } + for (const auto &out_var : node->Op()->OutputNames()) { + for (const auto &var_name : node->Op()->Output(out_var)) { + if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(), var_name) != + trt_disabled_ops.end()) { + VLOG(3) << node->Op()->Type().c_str() + << " is diabled by config in TensorRT"; + return false; + } + } + } bool is_ok = tensorrt::OpTeller::Global().Tell( node, no_calib_int8, with_dynamic_shape); if (!is_ok)