From c67d8f6f7321252b2d435e78c1b996e4477f663e Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Sat, 24 Apr 2021 18:44:41 -0500
Subject: [PATCH] feat: support the case when the injected node is not
 supported in dependency analysis

Signed-off-by: Bo Wang
---
 core/compiler.cpp                  |  2 +-
 core/partitioning/partitioning.cpp | 87 +++++++++++++++++++++++++-----
 2 files changed, 75 insertions(+), 14 deletions(-)

diff --git a/core/compiler.cpp b/core/compiler.cpp
index a362b1d639..964de26da9 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -274,7 +274,7 @@ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
   auto new_g = std::make_shared<torch::jit::Graph>();
   AddEngineToGraph(new_mod, new_g, engine);
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
-  auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+  auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
 
   new_mod.type()->addMethod(new_method);
   new_method->setSchema(schema);
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index f28ba11e81..024812d714 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -9,19 +9,46 @@ namespace trtorch {
 namespace core {
 namespace partitioning {
 
-inline bool isTensorOrTensorList(torch::jit::Value* val) {
-  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
-      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
-}
-
 struct usage_info {
   int produce_id = -1;
   std::vector<int> torch_use_id;
   std::vector<int> tensorrt_use_id;
 };
 
+inline bool isTensorOrTensorList(torch::jit::Value* val) {
+  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
+      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
+}
+
+bool isAllNodesSupported(const std::vector<torch::jit::Node*>& nodes) {
+  for (auto node : nodes) {
+    if (!conversion::OpSupported(node)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool containNonTensorInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
+  for (auto input : n->inputs()) {
+    if (!isTensorOrTensorList(input) && target_inputs.count(input)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool containNonTensorOutputs(torch::jit::Node* n) {
+  for (auto output : n->outputs()) {
+    if (!isTensorOrTensorList(output)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
-  // using bfs to get the DAG dependency nodes for input value
+  // use bfs to get the DAG dependency nodes for input value
   std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
       std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
   std::unordered_set<torch::jit::Node*> visited;
@@ -43,7 +70,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value
   return stk;
 }
 
-SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
   // reconstruct segmented_block if this block requires nonTensor input
   std::vector<torch::jit::Value*> nontensor_inputs;
   for (auto input : seg_block.raw_inputs()) {
@@ -51,9 +78,42 @@ SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
       nontensor_inputs.push_back(input);
     }
   }
-  std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
-  new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-  return std::move(SegmentedBlock(seg_block.target(), new_block_nodes));
+  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
+
+  std::vector<SegmentedBlock> new_seg_blocks;
+  // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
+  // one new block
+  if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
+    dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+    new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+  } else {
+    // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
+    std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
+    new_seg_blocks.emplace_back(SegmentedBlock::kTorch, dependency_nodes);
+    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+    bool prev_non_tensor_outputs = false;
+    for (auto n : seg_block.raw_nodes()) {
+      // it's a kTorch block if it uses the nonTensor input and the nonTensor input is produced in kTorch block
+      if (containNonTensorInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
+        if (!tensorrt_nodes.empty()) {
+          new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+        }
+        pytorch_nodes.push_back(n);
+        prev_non_tensor_outputs = containNonTensorOutputs(n);
+      } else {
+        if (!pytorch_nodes.empty()) {
+          new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+        }
+        tensorrt_nodes.push_back(n);
+      }
+    }
+    if (!tensorrt_nodes.empty()) {
+      new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+    } else {
+      new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+    }
+  }
+  return std::move(new_seg_blocks);
 }
 
 void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
@@ -80,7 +140,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
     if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
       int first_torch_id = use_info.torch_use_id.front();
       if (!updated_segments.count(first_torch_id)) {
-        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
+        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
         segmented_blocks[first_torch_id] = new_torch_block;
         updated_segments.insert(first_torch_id);
       }
@@ -88,8 +148,9 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
     // KTensorRT segments always need to inject nodes for the nonTensor inputs
     for (int i : use_info.tensorrt_use_id) {
       if (!updated_segments.count(i)) {
-        auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
-        segmented_blocks[i] = new_seg_block;
+        auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[i]);
+        segmented_blocks.erase(segmented_blocks.begin() + i);
+        segmented_blocks.insert(segmented_blocks.begin() + i, to_inject_blocks.begin(), to_inject_blocks.end());
         updated_segments.insert(i);
       }
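
Note for reviewers: the re-segmentation loop added to injectNodesForNonTensorInputs is easier to follow outside diff context. Below is a minimal, self-contained sketch of the same regrouping idea. Node, Target, Block, and regroup are simplified stand-ins for illustration only, not the TRTorch types, and the example node list (shape/view producing and consuming a non-tensor value) is hypothetical; the sketch also clears each buffer after flushing it into a block so it runs standalone.

// Standalone sketch: split a node sequence into alternating kTensorRT /
// kTorch blocks. A node goes to a kTorch block when it consumes a
// non-tensor input produced in a kTorch block, or when the previous node
// produced a non-tensor output.
#include <iostream>
#include <string>
#include <vector>

enum class Target { kTorch, kTensorRT };

struct Node {
  std::string name;
  bool uses_nontensor_input;       // consumes a non-tensor value from a kTorch block
  bool produces_nontensor_output;  // e.g. a List[int] shape
};

struct Block {
  Target target;
  std::vector<Node> nodes;
};

std::vector<Block> regroup(const std::vector<Node>& raw_nodes) {
  std::vector<Block> blocks;
  std::vector<Node> trt_nodes, torch_nodes;
  bool prev_non_tensor_outputs = false;
  for (const auto& n : raw_nodes) {
    if (n.uses_nontensor_input || prev_non_tensor_outputs) {
      // flush any pending TensorRT block, then route this node to Torch
      if (!trt_nodes.empty()) {
        blocks.push_back({Target::kTensorRT, trt_nodes});
        trt_nodes.clear();
      }
      torch_nodes.push_back(n);
      prev_non_tensor_outputs = n.produces_nontensor_output;
    } else {
      // flush any pending Torch block, then route this node to TensorRT
      if (!torch_nodes.empty()) {
        blocks.push_back({Target::kTorch, torch_nodes});
        torch_nodes.clear();
      }
      trt_nodes.push_back(n);
    }
  }
  // flush whichever buffer is still pending at the end
  if (!trt_nodes.empty()) {
    blocks.push_back({Target::kTensorRT, trt_nodes});
  } else if (!torch_nodes.empty()) {
    blocks.push_back({Target::kTorch, torch_nodes});
  }
  return blocks;
}

int main() {
  std::vector<Node> nodes = {
      {"conv1", false, false},
      {"shape", false, true},  // produces a non-tensor (e.g. List[int]) output
      {"view", true, false},   // consumes that non-tensor value
      {"relu", false, false},
  };
  for (const auto& b : regroup(nodes)) {
    std::cout << (b.target == Target::kTorch ? "Torch:" : "TensorRT:");
    for (const auto& n : b.nodes) std::cout << " " << n.name;
    std::cout << "\n";
  }
}

Running the sketch prints "TensorRT: conv1 shape", "Torch: view", "TensorRT: relu", mirroring how the patch re-splits a kTensorRT block around a node that depends on a Torch-produced non-tensor value, which is why injectNodesForNonTensorInputs now returns a vector of blocks that resolveNonTensorInputs splices back into the partition.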