Skip to content

Commit

Permalink
fix: Fix a core partitioning algo bug where non-tensor input segments are not updated correctly
Browse files Browse the repository at this point in the history

Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Oct 4, 2021
1 parent 1aa492f commit cc10876
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 38 deletions.
15 changes: 7 additions & 8 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
pytorch_nodes.push_back(n);
prev_non_tensor_outputs = containNonTensorOutputs(n);
} else {
// If pytorch_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
// If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
// Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
if (!pytorch_nodes.empty()) {
new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
Expand All @@ -132,6 +132,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
}
}

return std::move(new_seg_blocks);
}

Expand Down Expand Up @@ -159,6 +160,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
}
}

// For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block that has/produces it.
for (auto& use : usage_counts) {
// Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
if (segmented_blocks[i].contain_raw_value(use.first)) {
Expand All @@ -167,6 +169,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
}
}


std::unordered_set<int> updated_segments;
for (auto& use : usage_counts) {
auto use_info = use.second;
Expand All @@ -178,9 +181,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// TRTorch doesn't support non-tensor inputs for a module.
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
segmented_blocks.insert(
segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(first_torch_id);
}
}
Expand Down Expand Up @@ -314,6 +316,7 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
continue;
} else if (n->kind() == torch::jit::prim::Loop) {

if (!pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
pytorch_nodes.clear();
Expand Down Expand Up @@ -347,19 +350,15 @@ std::vector<SegmentedBlock> Partition(
const PartitionInfo& partition_info) {
LOG_DEBUG(partition_info);
// segment lowering global graph into blocks
LOG_DEBUG("Partitioning graph into PyTorch and TensorRT segmented blocks");
std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);

// resolve nonTensor inputs/outputs
LOG_DEBUG("Resolving non-tensor type inputs/outputs (eg: int/float types)");
resolveNonTensorInputs(segmented_blocks);

// register input/output torch::jit::Value for segmented graphs
LOG_DEBUG("Registering input/outputs for segmented blocks");
registerSegmentsOutputs(segmented_blocks, block);

// run shape analysis on each segmented block
LOG_DEBUG("Running shape analysis for all the segmented blocks");
runShapeAnalysis(segmented_blocks, input_ivalues_map);

return segmented_blocks;
Expand Down
4 changes: 1 addition & 3 deletions core/partitioning/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void getSegmentsOutputByRunning(
for (auto& input : seg_block.raw_inputs()) {
TRTORCH_CHECK(
ivalues_maps.count(input),
"Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
"Could not find torch::jit::Value* " << input->debugName() << " produced from " << util::node_info(input->node()) << " in lowering graph for mini graph input.\n");
if (input->node()->kind() == torch::jit::prim::Param) {
jit_inputs_ivalues.push_back(ivalues_maps[input]);
} else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
Expand Down Expand Up @@ -108,10 +108,8 @@ void runShapeAnalysis(
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
// register every segment's input shape, and it's running output IValues
for (auto& seg_block : segmented_blocks) {
LOG_DEBUG("Segmented graph: " << *seg_block.g());
torch::jit::ConstantPooling(seg_block.g());
getSegmentsOutputByRunning(seg_block, ivalues_maps);
LOG_DEBUG("=================");
}
return;
}
Expand Down
54 changes: 27 additions & 27 deletions tests/core/partitioning/test_loop_fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,30 @@ TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
}

// TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
// torch::jit::script::Module mod;
// try {
// mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
// } catch (const c10::Error& e) {
// std::cerr << "error loading the model\n";
// return;
// }
//
// const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
// std::vector<torch::jit::IValue> jit_inputs_ivalues;
// std::vector<torch::jit::IValue> trt_inputs_ivalues;
// for (auto in_shape : input_shapes) {
// auto in = at::randint(5, in_shape, {at::kCUDA});
// jit_inputs_ivalues.push_back(in.clone());
// trt_inputs_ivalues.push_back(in.clone());
// }
//
// std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
// trtorch::core::CompileSpec cfg(input_ranges);
// cfg.partition_info.enabled = true;
//
// auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
// auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
// auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
// ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
// }
// Loads the scripted loop-fallback (no-eval) module, runs it through the
// PyTorch JIT baseline and through the TRTorch-compiled graph with
// partitioning enabled, and asserts the two outputs agree.
TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
  torch::jit::script::Module module;
  try {
    module = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
  } catch (const c10::Error& e) {
    // Fixture missing/unloadable: report and return without asserting.
    std::cerr << "error loading the model\n";
    return;
  }

  // Build identical (cloned) CUDA inputs for the baseline run and the TRT run.
  const std::vector<std::vector<int64_t>> shapes = {{1, 10}};
  std::vector<torch::jit::IValue> jit_inputs;
  std::vector<torch::jit::IValue> trt_inputs;
  for (const auto& shape : shapes) {
    auto tensor = at::randint(5, shape, {at::kCUDA});
    jit_inputs.emplace_back(tensor.clone());
    trt_inputs.emplace_back(tensor.clone());
  }

  // Compile spec for a single 1x10 input with graph partitioning enabled.
  std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
  trtorch::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;

  auto jit_out = module.forward(jit_inputs).toTensor();
  auto compiled_mod = trtorch::core::CompileGraph(module, cfg);
  auto trt_out = compiled_mod.forward(trt_inputs).toTensor();
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out, trt_out, 2e-6));
}

0 comments on commit cc10876

Please sign in to comment.