diff --git a/core/compiler.cpp b/core/compiler.cpp index f8d2861e8e..2f94ba8ead 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -60,8 +60,12 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptrcreate(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef(engine_inputs), num_io.second); g->block()->appendNode(engine_node); - for (auto o : engine_node->outputs()) { - g->registerOutput(o); + if (engine_node->outputs().size() > 1) { + auto return_tuple_node = g->createTuple(engine_node->outputs()); + g->block()->appendNode(return_tuple_node); + g->registerOutput(return_tuple_node->outputs()[0]); + } else { + g->registerOutput(engine_node->outputs()[0]); } LOG_DEBUG(*g << "(AddEngineToGraph)\n");