Skip to content

Commit

Permalink
feat: insert nodes by dependencies for nonTensor inputs/outputs
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Mar 30, 2021
1 parent 54e407e commit 4e32eff
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 46 deletions.
174 changes: 134 additions & 40 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
#include "partitioning.h"
#include <queue>
#include "core/conversion/evaluators/eval_util.h"
#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/constant_pooling.h"

namespace trtorch {
namespace core {
namespace partitioning {

// Returns true when the value is a single Tensor or a List[Tensor];
// these are the only value kinds TensorRT segments exchange at runtime.
inline bool isTensorOrTensorList(torch::jit::Value* val) {
  const auto& val_type = val->type();
  if (val_type->isSubtypeOf(torch::jit::TensorType::get())) {
    return true;
  }
  return val_type->isSubtypeOf(torch::jit::ListType::ofTensors());
}

// Book-keeping for a single nonTensor value: which segment produces it and
// which Torch / TensorRT segments consume it. All ids are indices into the
// segmented_blocks vector (see resolveNonTensorInputs).
struct usage_info {
int produce_id = -1; // index of the producing segment; -1 if no segment produces it
std::vector<int> torch_use_id; // indices of Torch segments that take it as input
std::vector<int> tensorrt_use_id; // indices of TensorRT segments that take it as input
};

torch::jit::Value* getOrAddInputForValue(
torch::jit::Value* old_value,
std::shared_ptr<torch::jit::Graph>& graph,
Expand Down Expand Up @@ -39,6 +51,7 @@ torch::jit::Node* cloneNode(
auto* block = graph->block();
auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };

// create node for current graph by using the metadata in node and input Values in env
auto new_node = block->appendNode(graph->createClone(node, env));
for (size_t i = 0; i < node->outputs().size(); ++i) {
auto oo = node->outputs()[i];
Expand Down Expand Up @@ -68,7 +81,6 @@ void registerSegmentInOutIValues(
// create a module to run the graph
auto g = seg_block.g();
auto copy_g = g->copy();
// LOG_INFO(*copy_g << "(copy graph)\n");

// create tuple for multiple outputs
if (seg_block.raw_outputs().size() > 1) {
Expand Down Expand Up @@ -110,7 +122,10 @@ void registerSegmentInOutIValues(

// run segments to get outputs for later segments input shape, and other arguments such as Int
std::vector<torch::jit::IValue> jit_results;
printf("before forward\n");
torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
printf("after forward\n");

if (jit_results_ivalues.isTuple()) {
auto results = jit_results_ivalues.toTuple()->elements();
for (auto r : results) {
Expand Down Expand Up @@ -149,13 +164,10 @@ std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::Inp
return random_inputs;
}

void registerSegmentsInputsOutputs(
std::vector<SegmentedBlock>& segmented_blocks,
std::shared_ptr<torch::jit::Graph> g) {
void registerSegmentsOutputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
// find the corresponding raw values in original global graph for this segmented block's inputs/outputs
std::set<torch::jit::Value*> input_values;
for (auto& seg_block : segmented_blocks) {
seg_block.registerInputs();
for (auto& input : seg_block.raw_inputs()) {
input_values.insert(input);
}
Expand All @@ -165,51 +177,124 @@ void registerSegmentsInputsOutputs(
input_values.insert(graph_output);
}

// should be careful here because some in-place operations don't return any values
// should be careful here because some in-place operations don't return any values, there is no output for this kind
// of segment identify the output for each mini-graph by checking if any value in this graph is used later we
// shouldn't register nonTensor output for TensorRT segments
for (auto& seg_block : segmented_blocks) {
for (auto& mini_graph_input : input_values) {
if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
seg_block.raw_inputs().end() &&
seg_block.contain_raw_input(mini_graph_input)) {
seg_block.contain_raw_value(mini_graph_input)) {
if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
continue;
seg_block.registerOutput(mini_graph_input);
}
}
// if no output, then register the last node's output as current graph's output
if (seg_block.raw_outputs().empty()) {
seg_block.registerOutput(seg_block.raw_inputs()[0]);
// for Torch segments, register input as output
if (seg_block.target() == SegmentedBlock::kTorch) {
seg_block.registerOutput(seg_block.raw_inputs()[0]);
} else {
// for TensorRT segments, register last nonInput Tensor outputs
for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) {
for (auto node_output : seg_block.raw_nodes()[i]->outputs()) {
if (isTensorOrTensorList(node_output))
seg_block.registerOutput(node_output);
}
if (!seg_block.raw_outputs().empty())
break;
}
}
}
}
// erase segments which still have no output
segmented_blocks.erase(
std::remove_if(
segmented_blocks.begin(),
segmented_blocks.end(),
[](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
segmented_blocks.end());

return;
}

void eraseNonTensorInputsOutputs(
SegmentedBlock& seg_block,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
if (seg_block.target() == SegmentedBlock::kTorch)
return;
auto mini_graph = seg_block.g();

for (int i = seg_block.raw_inputs().size() - 1; i >= 0; --i) {
// erase this input and prepend a prim::Constant if it's not Tensor
if (!seg_block.raw_inputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
!seg_block.raw_inputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps[seg_block.raw_inputs()[i]]);
seg_block.inputs()[i]->replaceAllUsesWith(new_val);
seg_block.eraseInput(i);
// Collects, in topological (producer-first) order, the non-constant nodes that
// the given values transitively depend on through nonTensor inputs, using BFS
// from each value's producing node. Tensor/TensorList inputs terminate the
// walk since they are supplied as runtime segment inputs and need no
// injected producer.
// Fix: the original declared `visited` but never inserted into it, so nodes
// shared by several values were pushed into the result repeatedly and their
// input subtrees were re-traversed.
std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
  // using bfs to get the DAG dependency nodes for input value
  std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
      std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
  std::unordered_set<torch::jit::Node*> visited;
  std::vector<torch::jit::Node*> stk;
  while (!q.empty()) {
    auto cur_val = q.front();
    q.pop();
    auto node = cur_val->node();
    if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
      // mark visited so a node shared by several values is emitted only once
      visited.insert(node);
      stk.push_back(node);
      for (auto input : node->inputs()) {
        if (!isTensorOrTensorList(input)) {
          q.push(input);
        }
      }
    }
  }
  // nodes were discovered consumer-first; reverse for execution order
  std::reverse(stk.begin(), stk.end());
  return stk;
}

for (int i = seg_block.raw_outputs().size() - 1; i >= 0; --i) {
if (!seg_block.raw_outputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
!seg_block.raw_outputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
seg_block.eraseOutput(i);
// Rebuilds a segmented block so it no longer needs nonTensor inputs from
// other segments: the producer nodes of every nonTensor input are prepended
// to the block's own node list and a fresh SegmentedBlock is constructed.
SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
  // gather the block's nonTensor inputs
  std::vector<torch::jit::Value*> nontensor_inputs;
  for (auto blk_input : seg_block.raw_inputs()) {
    if (!isTensorOrTensorList(blk_input)) {
      nontensor_inputs.push_back(blk_input);
    }
  }

  // dependency nodes first, then the block's original nodes
  std::vector<torch::jit::Node*> rebuilt_nodes = getDependencyNodes(nontensor_inputs);
  rebuilt_nodes.insert(rebuilt_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
  return SegmentedBlock(seg_block.target(), rebuilt_nodes);
}

// not sure to delete this block or just fallback to pytorch
if (seg_block.raw_outputs().empty()) {
seg_block.update_target(SegmentedBlock::kTorch);
// For every nonTensor value that crosses segment boundaries, decide where to
// inject its producer nodes so each consuming segment can evaluate it locally:
//  - produced by a TensorRT segment but consumed by Torch segments -> inject
//    into the first consuming Torch segment;
//  - otherwise -> inject into every consuming TensorRT segment (TensorRT
//    segments can never take nonTensor inputs).
// `g` is currently unused but kept for interface stability.
// Fix: guard produce_id before indexing segmented_blocks — it stays -1 when
// no segment produces the value (e.g. it comes from the graph inputs), and
// segmented_blocks[-1] is undefined behavior. Such values fall through to the
// TensorRT-injection branch.
void resolveNonTensorInputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
  // for NonTensor inputs in TensorRT segments, count the usages on Torch segments and TensorRT segments
  std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
  // iterate back-to-front so the final produce_id is the earliest segment containing the value
  for (int i = segmented_blocks.size() - 1; i >= 0; --i) {
    for (auto input : segmented_blocks[i].raw_inputs()) {
      if (!isTensorOrTensorList(input)) {
        segmented_blocks[i].target() == SegmentedBlock::kTorch ? usage_counts[input].torch_use_id.push_back(i)
                                                               : usage_counts[input].tensorrt_use_id.push_back(i);
      }
    }
    for (auto& use : usage_counts) {
      if (segmented_blocks[i].contain_raw_value(use.first)) {
        use.second.produce_id = i;
      }
    }
  }
  std::unordered_set<int> updated_segments;
  for (auto& use : usage_counts) {
    auto use_info = use.second;
    // if the segment that produces this nonTensor value is kTensorRT but it is consumed
    // in kTorch, inject nodes into the first kTorch segment
    if (use_info.produce_id >= 0 && segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT &&
        !use_info.torch_use_id.empty()) {
      int first_torch_id = use_info.torch_use_id.front();
      if (!updated_segments.count(first_torch_id)) {
        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
        segmented_blocks[first_torch_id] = new_torch_block;
        updated_segments.insert(first_torch_id);
      }
    } else {
      // kTensorRT segments always need nodes injected for their nonTensor inputs
      for (int i : use_info.tensorrt_use_id) {
        if (!updated_segments.count(i)) {
          auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
          segmented_blocks[i] = new_seg_block;
          updated_segments.insert(i);
        }
      }
    }
  }
  return;
}

void construct_segments(
Expand All @@ -231,20 +316,18 @@ void construct_segments(
}
}

std::vector<SegmentedBlock> segment_graph(
void segment_graph(
std::shared_ptr<torch::jit::Graph> g,
std::vector<conversion::InputRange>& input_ranges,
const conversion::TorchFallback& fallback_info) {
const conversion::TorchFallback& fallback_info,
std::vector<SegmentedBlock>& segmented_blocks) {
auto min_block_size = fallback_info.min_block_size;
std::unordered_set<std::string> forced_fallback_operators(
fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
std::vector<SegmentedBlock> segmented_blocks;

auto nodes = g->block()->nodes();

// segment the nodes
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;

for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant)
continue;
Expand All @@ -261,22 +344,33 @@ std::vector<SegmentedBlock> segment_graph(
if (!pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
}
}

std::vector<SegmentedBlock> Partition(
std::shared_ptr<torch::jit::Graph> g,
std::vector<conversion::InputRange>& input_ranges,
const conversion::TorchFallback& fallback_info) {
// segment lowering global graph into blocks
std::vector<SegmentedBlock> segmented_blocks;
segment_graph(g, fallback_info, segmented_blocks);

// register input/output torch::jit::Value for segmetned graphs
registerSegmentsInputsOutputs(segmented_blocks, g);
// resolve nonTensor inputs/outputs
resolveNonTensorInputs(segmented_blocks, g);

// register input/output torch::jit::Value for segmented graphs
registerSegmentsOutputs(segmented_blocks, g);

// store the mapping from lowering graph torch::jit::Value => torch::jit::IValue that we get by running segments
std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;

std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
for (size_t i = 0; i < g->inputs().size(); ++i) {
ivalues_maps[g->inputs()[i]] = random_inputs[i];
}

// register every segment's input shape, and it's running output Ivalues
// register every segment's input shape, and it's running output IValues
for (auto& seg_block : segmented_blocks) {
torch::jit::ConstantPooling(seg_block.g());
registerSegmentInOutIValues(seg_block, ivalues_maps);
eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
}

return segmented_blocks;
Expand Down
20 changes: 14 additions & 6 deletions core/partitioning/partitioning.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ struct SegmentedBlock {

SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}

SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
// Builds a block from an ordered node list: keeps the original nodes in
// nodes_ and appends each into this block's fresh graph via appendNode
// (presumably cloning it into g_ — defined elsewhere; verify), then
// registers the graph inputs the appended nodes require.
// NOTE(review): `nodes` is taken by non-const reference; unclear from here
// whether mutation is intended — confirm before tightening to const&.
SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
: target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
for (auto& node : nodes) {
nodes_.push_back(node);
appendNode(node);
}
registerInputs();
}

SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
Expand All @@ -53,9 +55,9 @@ struct SegmentedBlock {
}
}

void registerOutput(torch::jit::Value* raw_input) {
outputs_.push_back(raw_input);
g_->registerOutput(old_to_new_[raw_input]);
// Records raw_output (a value from the original global graph) as a block
// output: remembers it in outputs_ and registers its counterpart in this
// block's graph — looked up through the old_to_new_ value map — as a graph output.
void registerOutput(torch::jit::Value* raw_output) {
outputs_.push_back(raw_output);
g_->registerOutput(old_to_new_[raw_output]);
}

torch::jit::Block* block() {
Expand Down Expand Up @@ -88,7 +90,11 @@ struct SegmentedBlock {
return outputs_;
}

bool contain_raw_input(torch::jit::Value* input) {
// The original (global-graph) nodes this block was built from, in the
// order they were appended.
const std::vector<torch::jit::Node*>& raw_nodes() const {
return nodes_;
}

// True if `input` (a value from the original graph) has a counterpart in
// this block's graph, i.e. it appears as a key in the old_to_new_ map.
bool contain_raw_value(torch::jit::Value* input) {
return old_to_new_.count(input);
}

Expand Down Expand Up @@ -121,15 +127,17 @@ struct SegmentedBlock {
std::vector<nvinfer1::Dims> in_shape_;
std::vector<torch::jit::Value*> inputs_;
std::vector<torch::jit::Value*> outputs_;
std::vector<torch::jit::Node*> nodes_;
std::shared_ptr<torch::jit::Graph> g_;
std::string trt_engine;
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
};

std::vector<SegmentedBlock> segment_graph(
std::vector<SegmentedBlock> Partition(
std::shared_ptr<torch::jit::Graph> g,
std::vector<conversion::InputRange>& input_ranges,
const conversion::TorchFallback& fallback_info);

} // namespace partitioning
} // namespace core
} // namespace trtorch

0 comments on commit 4e32eff

Please sign in to comment.