Skip to content

Commit

Permalink
feat: support the case when the injected node is not supported in dep…
Browse files Browse the repository at this point in the history
…endency analysis

Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Apr 24, 2021
1 parent de3ba23 commit c67d8f6
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
2 changes: 1 addition & 1 deletion core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
auto new_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, new_g, engine);
auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
new_method->setSchema(schema);

Expand Down
87 changes: 74 additions & 13 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,46 @@ namespace trtorch {
namespace core {
namespace partitioning {

inline bool isTensorOrTensorList(torch::jit::Value* val) {
return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
}

struct usage_info {
int produce_id = -1;
std::vector<int> torch_use_id;
std::vector<int> tensorrt_use_id;
};

inline bool isTensorOrTensorList(torch::jit::Value* val) {
return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
}

bool isAllNodesSupported(const std::vector<torch::jit::Node*>& nodes) {
for (auto node : nodes) {
if (!conversion::OpSupported(node)) {
return false;
}
}
return true;
}

bool containNonTensorInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
for (auto input : n->inputs()) {
if (!isTensorOrTensorList(input) && target_inputs.count(input)) {
return true;
}
}
return false;
}

bool containNonTensorOutputs(torch::jit::Node* n) {
for (auto output : n->outputs()) {
if (!isTensorOrTensorList(output)) {
return true;
}
}
return false;
}

std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
// using bfs to get the DAG dependency nodes for input value
// use bfs to get the DAG dependency nodes for input value
std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
std::unordered_set<torch::jit::Node*> visited;
Expand All @@ -43,17 +70,50 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
return stk;
}

SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
// reconstruct segmented_block if this block requires nonTensor input
std::vector<torch::jit::Value*> nontensor_inputs;
for (auto input : seg_block.raw_inputs()) {
if (!isTensorOrTensorList(input)) {
nontensor_inputs.push_back(input);
}
}
std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
return std::move(SegmentedBlock(seg_block.target(), new_block_nodes));
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);

std::vector<SegmentedBlock> new_seg_blocks;
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
// one new block
if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
} else {
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
new_seg_blocks.emplace_back(SegmentedBlock::kTorch, dependency_nodes);
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
bool prev_non_tensor_outputs = false;
for (auto n : seg_block.raw_nodes()) {
// it's a kTorch block if it uses the nonTensor input and the nonTensor input is produced in kTorch block
if (containNonTensorInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
if (!tensorrt_nodes.empty()) {
new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
}
pytorch_nodes.push_back(n);
prev_non_tensor_outputs = containNonTensorOutputs(n);
} else {
if (!pytorch_nodes.empty()) {
new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
}
tensorrt_nodes.push_back(n);
}
}
if (!tensorrt_nodes.empty()) {
new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
} else {
new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
}
}
return std::move(new_seg_blocks);
}

void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
Expand All @@ -80,16 +140,17 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
int first_torch_id = use_info.torch_use_id.front();
if (!updated_segments.count(first_torch_id)) {
auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
segmented_blocks[first_torch_id] = new_torch_block;
updated_segments.insert(first_torch_id);
}
} else {
// KTensorRT segments always need to inject nodes for the nonTensor inputs
for (int i : use_info.tensorrt_use_id) {
if (!updated_segments.count(i)) {
auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
segmented_blocks[i] = new_seg_block;
auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[i]);
segmented_blocks.erase(segmented_blocks.begin() + i);
segmented_blocks.insert(segmented_blocks.begin() + i, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(i);
}
}
Expand Down

0 comments on commit c67d8f6

Please sign in to comment.