Skip to content

Commit

Permalink
fix: refactor the resegmentation for TensorRT segments in ResolveNonT…
Browse files Browse the repository at this point in the history
…ensorInput

Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed May 4, 2022
1 parent 10b55d4 commit 3cc2dfb
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 15 deletions.
66 changes: 51 additions & 15 deletions core/partitioning/partitioning.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
q.pop();
auto node = cur_val->node();
if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
visited.insert(node);
stk.push_back(node);
for (auto input : node->inputs()) {
if (!isTensorOrTensorList(input)) {
Expand All @@ -89,14 +90,14 @@ std::vector<torch::jit::Node*> getOutputNodes(
std::unordered_set<torch::jit::Node*> visited;
q.push(value);

// top-down order traveling
// top-down order traversing
while (!q.empty()) {
auto cur_val = q.front();
q.pop();
for (auto use : cur_val->uses()) {
auto node = use.user;
// use node must be in seg_block_nodes
if (seg_block_nodes.count(node) != 0 && !visited.count(node)) {
if (seg_block_nodes.count(node) && !visited.count(node)) {
stk.push_back(node);
visited.insert(node);
// travel its' all outputs
Expand All @@ -109,10 +110,41 @@ std::vector<torch::jit::Node*> getOutputNodes(
}
}

// top-down order and we don't need reverse it
// top-down order and we don't need to reverse it
return stk;
}

void getDirtyNodes(
std::unordered_set<torch::jit::Node*>& dirty_nodes,
const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
std::queue<torch::jit::Node*> q;
for (auto& node : dirty_nodes) {
q.push(node);
}
dirty_nodes.clear();

while (!q.empty()) {
auto cur_node = q.front();
q.pop();
if (!dirty_nodes.count(cur_node) && seg_block_nodes.count(cur_node)) {
dirty_nodes.insert(cur_node);
for (auto input : cur_node->inputs()) {
if (!isTensorOrTensorList(input)) {
q.push(input->node());
}
}
for (auto output : cur_node->outputs()) {
if (!isTensorOrTensorList(output)) {
for (auto use : output->uses()) {
auto node = use.user;
q.push(node);
}
}
}
}
}
}

std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock> segmentBlocksWithTensorListInputs(
SegmentedBlock& seg_block,
const std::unordered_map<torch::jit::Value*, SegmentedBlock>& tensorlist_inputs) {
Expand Down Expand Up @@ -163,25 +195,29 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
} else {
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;

bool prev_non_tensor_outputs = false;
// take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use
// dfs/bfs to find all dirty nodes that consume non_tensor values produced by dirty nodes or produces non_tensor
// values consumed by dirty nodes
std::unordered_set<torch::jit::Node*> dirty_nodes;
const std::unordered_set<torch::jit::Node*> seg_block_nodes(
seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());

for (auto n : seg_block.raw_nodes()) {
if (containTargetInputs(n, nontensor_inputs_set)) {
dirty_nodes.insert(n);
}
}
getDirtyNodes(dirty_nodes, seg_block_nodes);
for (auto n : seg_block.raw_nodes()) {
// Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node.
// In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
// SegmentedBlock.
if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
// If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
// TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
if (dirty_nodes.count(n)) {
if (!tensorrt_nodes.empty()) {
new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
tensorrt_nodes.clear();
}
pytorch_nodes.push_back(n);
prev_non_tensor_outputs = containNonTensorOutputs(n);
} else {
// If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
// Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
if (!pytorch_nodes.empty()) {
new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes);
pytorch_nodes.clear();
Expand All @@ -190,7 +226,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
}
}

// Form the last segmented_block with the left over nodes in tensorrt_nodes or pytorch_nodes correspondingly.
// Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly.
if (!tensorrt_nodes.empty()) {
new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
} else {
Expand Down
2 changes: 2 additions & 0 deletions core/partitioning/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ void getSegmentsOutputByRunning(
jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar());
} else if (input->type()->kind() == torch::jit::TypeKind::DictType) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toGenericDict());
} else if (input->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toDevice());
} else {
TORCHTRT_THROW_ERROR(
"Expected to find type " << input->type()->str() << " for value " << input->debugName()
Expand Down

0 comments on commit 3cc2dfb

Please sign in to comment.