Skip to content

Commit

Permalink
feat(//core/partitioning): Improved logging and code org for the
Browse files Browse the repository at this point in the history
segmentation step of partitioning

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Oct 19, 2021
1 parent 17e0e8a commit 8927e77
Show file tree
Hide file tree
Showing 12 changed files with 249 additions and 43 deletions.
20 changes: 20 additions & 0 deletions core/partitioning/SegmentedBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ namespace trtorch {
namespace core {
namespace partitioning {

SegmentedBlock::SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
: id_(id), target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
for (auto& node : nodes) {
nodes_.push_back(node);
appendNode(node);
}
}

SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
: target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
for (auto& node : nodes) {
Expand Down Expand Up @@ -62,6 +70,18 @@ torch::jit::Node* SegmentedBlock::cloneNode(torch::jit::Node* node) {
return new_node;
}

std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b) {
os << "Segment Block @" << b.id_ << ":" << std::endl;
os << " Target: " << b.target_ << std::endl;
os << " Graph: " << *b.g_ << std::endl;
return os;
}

std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t) {
os << SegmentedBlock::target_to_str(t) << std::endl;
return os;
}

} // namespace partitioning
} // namespace core
} // namespace trtorch
17 changes: 17 additions & 0 deletions core/partitioning/SegmentedBlock.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <vector>
#include <ostream>

#include "NvInfer.h"
#include "core/ir/ir.h"
Expand All @@ -18,10 +19,21 @@ struct SegmentedBlock {
kTensorRT,
};

static std::string target_to_str(SegmentedBlockTarget t) {
if (t == SegmentedBlockTarget::kTorch) {
return "Torch";
} else {
return "TensorRT";
}
}

using BlockID = uint64_t;

SegmentedBlock() = default;
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);

torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
torch::jit::Node* cloneNode(torch::jit::Node* node);
Expand Down Expand Up @@ -74,7 +86,10 @@ struct SegmentedBlock {
return target_;
}

friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);

private:
BlockID id_;
SegmentedBlockTarget target_;
std::vector<ir::Input> in_shape_;
std::vector<torch::jit::Value*> inputs_;
Expand All @@ -84,6 +99,8 @@ struct SegmentedBlock {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
};

std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);

} // namespace partitioning
} // namespace core
} // namespace trtorch
118 changes: 84 additions & 34 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,81 +275,120 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
return compile_to_trt;
}

std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set<std::string>& torch_ops) {
// If the op is not supported by the conversion phase it should run in PyTorch
if (!conversion::OpSupported(n)) {
LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
return false;
}

// If the user specifies the op to run in Torch it should run in PyTorch
if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) {
LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
return false;
}

// If the user specifies the module containing this op to run in torch it should run in PyTorch
const auto to_compile_sym = c10::Symbol::attr("to_compile");
if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
return false;
}

LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n));
return true;
}

void finalize_block(PartitionedGraph& g, SegmentedBlock::SegmentedBlockTarget kind, std::vector<torch::jit::Node*>& nodes) {
SegmentedBlock::BlockID b_id= g.size();
LOG_DEBUG("Finalizing in progress " << SegmentedBlock::target_to_str(kind) << " block");
g.emplace_back(b_id, kind, nodes);
nodes.clear();
LOG_DEBUG(g.back());
}

PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
auto min_block_size = partition_info.min_block_size;
std::unordered_set<std::string> forced_fallback_operators(
std::unordered_set<std::string> forced_fallback_ops(
partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());

auto nodes = block->nodes();
std::vector<SegmentedBlock> segmented_blocks;
PartitionedGraph segmented_blocks;

// segment the nodes
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
for (const auto n : nodes) {
// Skip constant nodes as they are resources for both kinds of modules
if (n->kind() == torch::jit::prim::Constant) {
continue;
}

std::string node_string(n->kind().toQualString());
auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) &&
(!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
tensorrt_nodes.push_back(n);
if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
pytorch_nodes.clear();
if (should_run_in_trt(n, forced_fallback_ops)) {
in_prog_trt_blk_nodes.push_back(n);

// If there is an active PyTorch block and we have passed the threshold for a valid TRT
// block then segment and reset the active PyTorch block
if (in_prog_trt_blk_nodes.size() >= min_block_size && !in_prog_pyt_blk_nodes.empty()) {
finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
}
} else {
if (tensorrt_nodes.size() >= min_block_size) {
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
// If there is an active TRT block that is valid segment and reset the active TRT block
// otherwise add it to the active PyTorch block and reset
if (in_prog_trt_blk_nodes.size() >= min_block_size) {
finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
} else {
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
LOG_DEBUG("In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
}
tensorrt_nodes.clear();
in_prog_trt_blk_nodes.clear();
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
// we shouldn't inject node for this block in dependency analysis process
if (n->kind() == torch::jit::prim::If) {
if (!pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
pytorch_nodes.clear();
LOG_DEBUG("Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional");
if (!in_prog_pyt_blk_nodes.empty()) {
finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
}
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
auto cond_node = std::vector<torch::jit::Node*>{n};
finalize_block(segmented_blocks, SegmentedBlock::kTorch, cond_node);
continue;
} else if (n->kind() == torch::jit::prim::Loop) {
if (!pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
pytorch_nodes.clear();
if (!in_prog_pyt_blk_nodes.empty()) {
finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
}
if (checkLoopEvaluatable(n)) {
tensorrt_nodes.push_back(n);
in_prog_trt_blk_nodes.push_back(n);
} else {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
auto loop_node = std::vector<torch::jit::Node*>{n};
finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
}
continue;
}
pytorch_nodes.push_back(n);
in_prog_pyt_blk_nodes.push_back(n);
}
}

// if there is any kTorch nodes left, then either the last nodes are kTorch or last nodes are kTensorRT but num <
// min_block_size
if (!pytorch_nodes.empty()) {
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
} else {
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
if (in_prog_trt_blk_nodes.size() >= min_block_size) {
finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
}

if (!in_prog_pyt_blk_nodes.empty()) {
in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
}

return std::move(segmented_blocks);
}

std::vector<SegmentedBlock> Partition(
PartitionedGraph Partition(
torch::jit::Block* block,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
const PartitionInfo& partition_info) {
LOG_DEBUG(partition_info);
// segment lowering global graph into blocks
std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
PartitionedGraph segmented_blocks = segment_graph(block, partition_info);

// resolve nonTensor inputs/outputs
resolveNonTensorInputs(segmented_blocks);
Expand All @@ -358,11 +397,22 @@ std::vector<SegmentedBlock> Partition(
registerSegmentsOutputs(segmented_blocks, block);

// run shape analysis on each segmented block
runShapeAnalysis(segmented_blocks, input_ivalues_map);
runShapeAnalysis(segmented_blocks, input_ivalues_map, at::kFloat);

LOG_INFO(segmented_blocks);

return segmented_blocks;
}

std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g) {
os << "Partitioned Graph: [";
for (auto b : g) {
os << b;
}
os << "]";
return os;
}

} // namespace partitioning
} // namespace core
} // namespace trtorch
5 changes: 4 additions & 1 deletion core/partitioning/partitioning.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <vector>
#include <iostream>

#include "core/ir/ir.h"
#include "core/partitioning/PartitionInfo.h"
Expand All @@ -17,11 +18,13 @@ typedef std::vector<SegmentedBlock> PartitionedGraph;

PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info);

std::vector<SegmentedBlock> Partition(
PartitionedGraph Partition(
torch::jit::Block* block,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
const PartitionInfo& partition_info);

std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g);

} // namespace partitioning
} // namespace core
} // namespace trtorch
3 changes: 3 additions & 0 deletions core/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ cc_library(
hdrs = [
"jit_util.h",
],
srcs = [
"jit_util.cpp"
],
deps = select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
Expand Down
69 changes: 69 additions & 0 deletions core/util/jit_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "core/util/jit_util.h"

namespace trtorch {
namespace core {
namespace util {

c10::optional<at::ScalarType> getBlockFirstCalcDType(const std::shared_ptr<torch::jit::Block>& b) {
auto ns = b->nodes();

c10::optional<at::ScalarType> dtype = {};

// For each node check the inputs to find a prim:Constant, which will provide a static tensor.
// Use that tensor to determine operating dtype for the first calculation in the block
for (auto n : ns) {
if (n->kind() == torch::jit::prim::Constant) {
// Not really helpful to evaluate typing for constants
continue;
}

auto ins = n->inputs();
auto outs = n->outputs();

bool outputs_tensor = false;
for (auto o : outs) {
if (o->type() == c10::TensorType::get()) {
outputs_tensor = true;
}
}

if (outputs_tensor) {
// If all input tensors are block inputs then this node will not give us useful type info so move to the next one
std::unordered_set<torch::jit::Value*> node_input_set = {ins.begin(), ins.end()};

bool all_n_ins_are_b_ins = true;
for (auto b_in : b->inputs()) {
if (node_input_set.find(b_in) == node_input_set.end()) {
all_n_ins_are_b_ins = false;
}
}

if (all_n_ins_are_b_ins) {
continue;
}


// If node outputs a Tensor it might be a result of tensor calcuation so check to see
// if any inputs to the calculation can give us hints
c10::optional<torch::jit::Node*> const_tensor_n = {};

// Backtrace to constants which will immediately give us the Tensor type if possible
for (auto in : ins) {
if (in->type() == c10::TensorType::get()) {
if (in->node()->kind() == torch::jit::prim::Constant) {
auto const_ival = in->node()->get(c10::Symbol::attr("value"));
dtype = {const_ival.value().toTensor().scalar_type()};
goto exit_first_calc_dtype;
}
}
}
}
}

exit_first_calc_dtype:
return dtype;
}

} // namespace util
} // namespace core
} // namespace trtorch
2 changes: 2 additions & 0 deletions core/util/jit_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
return source_code;
}

c10::optional<at::ScalarType> getBlockFirstCalcDType(const std::shared_ptr<torch::jit::Block>& b);

} // namespace util
} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion core/util/logging/TRTorchLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ namespace {

TRTorchLogger& get_global_logger() {
#ifndef NDEBUG
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
#else
static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
#endif
Expand Down
2 changes: 1 addition & 1 deletion tests/core/partitioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ cc_test(
)

test_suite(
name = "partitioning_test",
name = "partitioning_tests",
tests = [
":test_segmentation",
":test_shape_analysis",
Expand Down
Loading

0 comments on commit 8927e77

Please sign in to comment.