[feat] Add dependency awareness to torch-trt partitioning (#40)
Adds a heuristic to torch-trt partitioning's segmentation that avoids materializing a segment until the traversal hits a node that depends on it. This can significantly reduce the number of segments/engines in cases where a linear traversal of TorchScript nodes would otherwise produce alternating Torch and TRT segments that are not dependent on each other.
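
To illustrate the idea, here is a minimal, self-contained sketch (the ToyNode type, the dependency bookkeeping, and the example graph are illustrative assumptions, not the torch-trt API; the real pass also handles min_block_size, conditionals, and loops): an in-progress block of one kind is only finalized when the traversal reaches a node that actually depends on it, so independent alternating nodes coalesce into two segments instead of four.

// Toy sketch of dependency-aware segmentation (assumed types, not torch-trt code).
#include <iostream>
#include <set>
#include <vector>

struct ToyNode {
  int id;
  bool runs_in_trt;
  std::set<int> inputs;  // ids of nodes this node consumes
};

int main() {
  // Alternating TRT / Torch nodes with no cross-kind dependencies: a purely
  // linear traversal would emit four segments; deferring finalization until a
  // dependency is hit emits two.
  std::vector<ToyNode> nodes = {
      {0, true, {}}, {1, false, {}}, {2, true, {0}}, {3, false, {1}}};

  std::vector<int> in_prog_trt, in_prog_torch;
  std::set<int> trt_uses, torch_uses;  // downstream consumers of each in-progress block
  int segments = 0;

  auto finalize = [&](std::vector<int>& blk, std::set<int>& uses, const char* kind) {
    if (blk.empty()) return;
    std::cout << "segment " << segments++ << " (" << kind << "): " << blk.size()
              << " node(s)\n";
    blk.clear();
    uses.clear();
  };

  for (const auto& n : nodes) {
    // Only materialize the other-kind block if the current node depends on it.
    if (n.runs_in_trt) {
      if (torch_uses.count(n.id)) finalize(in_prog_torch, torch_uses, "Torch");
      in_prog_trt.push_back(n.id);
    } else {
      if (trt_uses.count(n.id)) finalize(in_prog_trt, trt_uses, "TensorRT");
      in_prog_torch.push_back(n.id);
    }
    // Record which later nodes consume this node's outputs.
    for (const auto& m : nodes) {
      if (m.inputs.count(n.id)) (n.runs_in_trt ? trt_uses : torch_uses).insert(m.id);
    }
  }
  finalize(in_prog_trt, trt_uses, "TensorRT");
  finalize(in_prog_torch, torch_uses, "Torch");
  return 0;
}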

Fixes # (issue)

Type of change:

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

Checklist:

- [ ] My code follows the style guidelines of this project (you can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that the relevant reviewers are notified
mfeliz-cruise committed Oct 6, 2022
1 parent e608bc7 commit 3a33b6e
Showing 5 changed files with 466 additions and 329 deletions.
136 changes: 118 additions & 18 deletions core/partitioning/partitioning.cpp
@@ -111,10 +111,34 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
}
}

std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
std::set<torch::jit::Node*> dependent_nodes;
for (auto val : n->outputs()) {
for (auto use : val->uses()) {
dependent_nodes.insert(use.user);
}
}
if (const auto* schema = n->maybeSchema()) {
for (size_t i = 0; i < n->inputs().size(); ++i) {
const at::AliasInfo* formal = schema->arguments()[i].alias_info();
if (formal && formal->isWrite()) {
for (auto use : n->inputs()[i]->uses()) {
torch::jit::Node* use_node = use.user;
if (use_node->isAfter(n)) {
dependent_nodes.insert(use_node);
}
}
}
}
}
return dependent_nodes;
}

// Sub-function that traverses the entire block and checks if the TensorRT node sequence satisfies min_block_size
std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
auto nodes = block->nodes();
std::vector<torch::jit::Node*> cur_trt_nodes;
std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
std::vector<torch::jit::Node*> min_block_fallback_nodes;
for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant) {
@@ -124,11 +148,16 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
// check if current node fallback or not
if (!ctx->shouldNodeRunInTorch(n)) {
cur_trt_nodes.push_back(n);
auto dependent_nodes = getDependentNodes(n);
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
} else {
if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
if (cur_trt_nodes_uses.count(n)) {
if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
}
cur_trt_nodes.clear();
cur_trt_nodes_uses.clear();
}
cur_trt_nodes.clear();
}
}
if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
@@ -355,6 +384,59 @@ void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
setMinBlockFallbackNodes(ctx, block);
}

void merge_adjacent_segments_list_in_new_partition(
PartitionedGraph& original_partition,
PartitionedGraph& new_partition,
SegmentedBlock::SegmentedBlockTarget& segment_kind,
std::vector<size_t>& same_type_segment_idx) {
TORCHTRT_CHECK(!same_type_segment_idx.empty(), "Unable to merge empty segment list");
if (same_type_segment_idx.size() == 1) {
new_partition.push_back(original_partition[same_type_segment_idx[0]]);
} else {
auto first_idx = same_type_segment_idx[0];
for (size_t i = 1; i < same_type_segment_idx.size(); ++i) {
TORCHTRT_CHECK(
same_type_segment_idx[i] == (first_idx + i),
"Unable to merge non-sequential segments: " << same_type_segment_idx);
}
LOG_DEBUG(
"Merging adjacent " << SegmentedBlock::target_to_str(segment_kind) << " segments: " << same_type_segment_idx);
std::vector<torch::jit::Node*> nodes;
for (auto segment_to_merge : same_type_segment_idx) {
const auto& merge_nodes = original_partition[segment_to_merge].raw_nodes();
nodes.insert(nodes.end(), merge_nodes.begin(), merge_nodes.end());
}
new_partition.emplace_back(segment_kind, nodes);
}
}

PartitionedGraph merge_adjacent_segments_of_same_type(PartitionedGraph& original_partition) {
PartitionedGraph new_partition;
SegmentedBlock::SegmentedBlockTarget segment_kind = SegmentedBlock::SegmentedBlockTarget::kTorch;
std::vector<size_t> same_type_segment_idx;
for (size_t i = 0UL; i < original_partition.size(); ++i) {
auto& segment = original_partition[i];
if (same_type_segment_idx.empty()) {
segment_kind = segment.target();
} else if (segment_kind != segment.target() || segment.do_not_merge()) {
merge_adjacent_segments_list_in_new_partition(
original_partition, new_partition, segment_kind, same_type_segment_idx);
same_type_segment_idx.clear();
segment_kind = segment.target();
}
if (segment.do_not_merge()) {
new_partition.push_back(segment);
} else {
same_type_segment_idx.push_back(i);
}
}
if (!same_type_segment_idx.empty()) {
merge_adjacent_segments_list_in_new_partition(
original_partition, new_partition, segment_kind, same_type_segment_idx);
}
return new_partition;
}

void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
// Find all the fallback nodes and build execution decision LUT for all nodes
setNodeExecutorLUT(ctx, block);
@@ -365,58 +447,75 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
PartitionedGraph segmented_blocks;

std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
for (const auto n : nodes) {
// Skip constant nodes as they are resources for both kinds of modules
if (n->kind() == torch::jit::prim::Constant) {
continue;
}
auto dependent_nodes = getDependentNodes(n);
// the outputs of trt subgraph shouldn't be collections
if (ctx->shouldNodeRunInTensorRT(n)) {
in_prog_trt_blk_nodes.push_back(n);
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());

// If there is an active PyTorch block and we have passed the threshold for a valid TRT
// block then segment and reset the active PyTorch block
if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size && !in_prog_pyt_blk_nodes.empty()) {
// If we hit a TRT node that is dependent on nodes in the active PyTorch block, finalize the block to materialize
// those dependencies in the graph
if (cur_pyt_nodes_uses.count(n)) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
cur_pyt_nodes_uses.clear();
}
} else {
// If there is an active TRT block that is valid segment and reset the active TRT block
// otherwise add it to the active PyTorch block and reset
if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
} else {
LOG_DEBUG(
"In progress TRT block does not meet minimum block size requirements ("
<< in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
<< "), therefore folding into in progress PyTorch block");
in_prog_pyt_blk_nodes.insert(
in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
// The current node is dependent on the active TRT block, finalize it to materialize those dependencies in the
// graph or add them to the active PyTorch block
if (cur_trt_nodes_uses.count(n)) {
// If there is an active TRT block that is valid segment and reset the active TRT block
// otherwise add it to the active PyTorch block and reset
if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
} else {
LOG_DEBUG(
"In progress TRT block does not meet minimum block size requirements ("
<< in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
<< "), therefore folding into in progress PyTorch block");
in_prog_pyt_blk_nodes.insert(
in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
cur_pyt_nodes_uses.insert(cur_trt_nodes_uses.begin(), cur_trt_nodes_uses.end());
}
in_prog_trt_blk_nodes.clear();
cur_trt_nodes_uses.clear();
}
in_prog_trt_blk_nodes.clear();
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
// we shouldn't inject node for this block in dependency analysis process
if (n->kind() == torch::jit::prim::If) {
LOG_DEBUG(
"Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional");
if (!in_prog_pyt_blk_nodes.empty()) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
cur_pyt_nodes_uses.clear();
}
auto cond_node = std::vector<torch::jit::Node*>{n};
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
segmented_blocks.back().do_not_merge(true);
continue;
} else if (n->kind() == torch::jit::prim::Loop) {
if (!in_prog_pyt_blk_nodes.empty()) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
cur_pyt_nodes_uses.clear();
}
if (checkLoopEvaluatable(n)) {
in_prog_trt_blk_nodes.push_back(n);
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
} else {
auto loop_node = std::vector<torch::jit::Node*>{n};
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
segmented_blocks.back().do_not_merge(true);
}
continue;
}
in_prog_pyt_blk_nodes.push_back(n);
cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
}
}

@@ -432,6 +531,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
}

segmented_blocks = merge_adjacent_segments_of_same_type(segmented_blocks);
ctx->partitioned_blocks.insert({block, segmented_blocks});
return;
}
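
As a reference for the new merge pass above, here is a minimal sketch of merging adjacent same-kind segments (the Seg type and the merge_adjacent_same_kind helper are illustrative assumptions, not the torch-trt API): runs of adjacent segments with the same target are coalesced, while segments flagged do-not-merge (conditionals and non-evaluatable loops) are passed through untouched.

// Toy sketch of the adjacent-segment merge (assumed types, not torch-trt code).
#include <iostream>
#include <string>
#include <vector>

struct Seg {
  std::string kind;           // "TensorRT" or "Torch"
  int node_count;
  bool do_not_merge = false;  // e.g. prim::If / non-evaluatable prim::Loop blocks
};

std::vector<Seg> merge_adjacent_same_kind(const std::vector<Seg>& in) {
  std::vector<Seg> out;
  for (const auto& s : in) {
    if (!s.do_not_merge && !out.empty() && !out.back().do_not_merge &&
        out.back().kind == s.kind) {
      out.back().node_count += s.node_count;  // coalesce into the previous segment
    } else {
      out.push_back(s);
    }
  }
  return out;
}

int main() {
  std::vector<Seg> partition = {
      {"TensorRT", 3}, {"TensorRT", 2}, {"Torch", 1, true}, {"Torch", 4}, {"Torch", 2}};
  for (const auto& s : merge_adjacent_same_kind(partition)) {
    std::cout << s.kind << " segment with " << s.node_count << " node(s)"
              << (s.do_not_merge ? " [not merged]" : "") << "\n";
  }
  return 0;
}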
9 changes: 9 additions & 0 deletions core/partitioning/segmentedblock/SegmentedBlock.h
@@ -94,6 +94,14 @@ struct SegmentedBlock {
return target_;
}

bool do_not_merge(void) const {
return do_not_merge_;
}

void do_not_merge(bool x) {
do_not_merge_ = x;
}

friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);

private:
@@ -106,6 +114,7 @@ struct SegmentedBlock {
std::vector<torch::jit::Node*> nodes_;
std::shared_ptr<torch::jit::Graph> g_;
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
bool do_not_merge_ = false;
};

std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
2 changes: 1 addition & 1 deletion tests/core/partitioning/test_conditionals.cpp
@@ -40,7 +40,7 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {

auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);

ASSERT_TRUE(conditional_engines_count == 2);
ASSERT_TRUE(conditional_engines_count == 1);
}

TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
2 changes: 1 addition & 1 deletion tests/core/partitioning/test_resolve_nontensor_inputs.cpp
@@ -204,7 +204,7 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
}));
}
}
ASSERT_TRUE(trt_block_cnt == 2 && torch_block_cnt == 2);
ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1);
}

TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {