diff --git a/core/compiler.cpp b/core/compiler.cpp index 91cb0f8281..f578758a50 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -184,7 +184,7 @@ void AddSegmentedBlockToGraph(std::shared_ptr& g, partitionin old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]]; } - LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n"); +// LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n"); return; } @@ -208,7 +208,10 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo // segment the graph and convert segmented TensorRT block - auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges); + auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges, convert_cfg.engine_settings.torch_fallback); + if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) { + return mod; + } int trt_engine_id = 0; std::unordered_map old_to_new_g; @@ -233,7 +236,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo new_g->registerOutput(old_to_new_g[output]); } - LOG_INFO(*new_g << "(After CompileGraph)\n"); + LOG_INFO(*new_g << "(StitchSegmentedGraph)\n"); auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g); auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g); diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 038ea3874c..97c7606bdb 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -39,7 +39,13 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) { os << "\n Torch Fallback: " << s.torch_fallback.enabled; if (s.torch_fallback.enabled) { - os << "\n Fallback min block size: " << s.torch_fallback.min_block_size; + os << "\n Fallback Min Block Size: " << s.torch_fallback.min_block_size; + if (!s.torch_fallback.forced_fallback_operators.empty()) { + os << "\n Forced Fallback Operators:"; + for (auto it = s.torch_fallback.forced_fallback_operators.begin(); it != s.torch_fallback.forced_fallback_operators.end(); ++it) { + os << " " << *it; + } + } } return os; } diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index f9248d436a..23f61d9f80 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -62,6 +62,14 @@ void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_mapcopy(); + if (seg_block.raw_outputs().size() > 1) { + auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs())); + for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) { + copy_g->eraseOutput(idx); + } + copy_g->registerOutput(new_output_node->outputs()[0]); + } + torch::jit::script::Module cur_mod(c10::QualifiedName("module")); auto self = copy_g->insertInput(0, "self_1"); @@ -140,25 +148,45 @@ void registerSegmentsInputsOutputs(std::vector &segmented_blocks return; } -std::vector segment_graph(std::shared_ptr g, std::vector& input_ranges) { +void merge_nodes(std::vector &pytorch_nodes, std::vector &tensorrt_nodes, + std::vector &segmented_blocks, size_t min_block_size) { + if (!tensorrt_nodes.empty()) { + if (tensorrt_nodes.size() < min_block_size) { + pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end()); + } else { + if (!pytorch_nodes.empty()) segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); + segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes); + pytorch_nodes.clear(); + } + tensorrt_nodes.clear(); + } +} + +std::vector segment_graph(std::shared_ptr g, + std::vector& input_ranges, + const conversion::TorchFallback &fallback_info) { + auto min_block_size = fallback_info.min_block_size; + std::unordered_set forced_fallback_operators(fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end()); std::vector segmented_blocks; auto nodes = g->block()->nodes(); // segment the nodes + std::vector tensorrt_nodes, pytorch_nodes; + for (const auto n : nodes) { if (n->kind() == torch::jit::prim::Constant) continue; + std::string node_string(n->kind().toQualString()); - auto block_target = conversion::OpSupported(n) ? SegmentedBlock::kTensorRT : SegmentedBlock::kTorch; - - if (segmented_blocks.empty() || block_target != segmented_blocks.back().target()) { - SegmentedBlock cur_block(block_target); - cur_block.appendNode(n); - segmented_blocks.push_back(cur_block); + if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) { + tensorrt_nodes.push_back(n); } else { - segmented_blocks.back().appendNode(n); + merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size); + pytorch_nodes.push_back(n); } } + merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size); + if (!pytorch_nodes.empty()) segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); registerSegmentsInputsOutputs(segmented_blocks, g); diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h index 31183de241..35e298ebcd 100644 --- a/core/partitioning/partitioning.h +++ b/core/partitioning/partitioning.h @@ -20,8 +20,17 @@ struct SegmentedBlock { kTensorRT, }; + SegmentedBlock() = default; + SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared()) {} + SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector &nodes) : + target_(blk_target), g_(std::make_shared()) { + for (auto &node : nodes) { + appendNode(node); + } + } + SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr g) : target_(blk_target), g_(g) {} enum SegmentedBlockTarget target() { @@ -40,6 +49,7 @@ struct SegmentedBlock { void registerOutput(torch::jit::Value* raw_input) { outputs_.push_back(raw_input); + g_->registerOutput(old_to_new_[raw_input]); } @@ -108,8 +118,9 @@ struct SegmentedBlock { }; -std::vector segment_graph(std::shared_ptr g, std::vector& input_ranges); - +std::vector segment_graph(std::shared_ptr g, + std::vector& input_ranges, + const conversion::TorchFallback &fallback_info); } } } \ No newline at end of file