Skip to content

Commit

Permalink
Resolve segmented-block dependency problems in edge cases
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Mar 8, 2021
1 parent 1ca13d8 commit 0d28164
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 48 deletions.
48 changes: 32 additions & 16 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,27 +180,34 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::



void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg) {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> output_input_map;
void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg,
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new_g) {
// old_to_new_g maps: original-graph Value -> new-graph Value, mini-graph Value -> new-graph Value, and new-graph Value -> mini-graph Value (reverse lookup)
size_t input_idx = 0;
if (seg.target == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
auto self = g->insertInput(0, "self_1");
self->setType(seg.inputs()[0]->type());
}
output_input_map[seg.inputs()[input_idx++]] = g->inputs()[0];
old_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
}

for (size_t i = 0; i < g->outputs().size(); ++i) {
auto prev_output = g->outputs()[i];
auto next_input = seg.inputs()[input_idx++];
output_input_map[next_input] = prev_output;
for (auto &raw_input : seg.raw_inputs()) {
if (old_to_new_g.count(raw_input)) {
old_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
}
}

torch::jit::Node *node;
for (const auto n : seg.nodes()) {
node = partitioning::cloneNode(n, g, output_input_map);
node = partitioning::cloneNode(n, g, old_to_new_g);
}

// original graph value => new global graph value
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]];
}

for (size_t i = 0; i < g->outputs().size(); ++i) {
g->eraseOutput(i);
}
Expand Down Expand Up @@ -248,20 +255,29 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
// segment the graph and convert segmented TensorRT block
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges);

for (auto &seg_block : segmented_blocks) {
LOG_INFO(*seg_block.g() << "SegmentedBlockGraph");
}

int trt_engine_id = 0;
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
for (auto &seg_block : segmented_blocks) {
// LOG_INFO(*seg_block.g_ << "SegmentedBlockGraph");
if (seg_block.target == partitioning::SegmentedBlock::kTensorRT) {
std::vector<int64_t> input_range = util::toVec(seg_block.in_shape_[0]);
convert_cfg.input_ranges[0] = conversion::InputRange(input_range);
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
std::vector<conversion::InputRange> input_ranges;
for (auto &shape : seg_block.in_shape()) {
input_ranges.push_back(conversion::InputRange(util::toVec(shape)));
}
convert_cfg.input_ranges = input_ranges;
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);
// printf("type: %s\n", temp_g->inputs()[0]->type()->str().c_str());
auto temp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, temp_g);
AddSegmentedBlockToGraph(new_g, temp_seg_block);
// auto temp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, temp_g);
// AddSegmentedBlockToGraph(new_g, temp_seg_block);
seg_block.update_graph(temp_g);
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
} else {
AddSegmentedBlockToGraph(new_g, seg_block);
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
}
}

Expand Down
107 changes: 85 additions & 22 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ torch::jit::Value* getOrAddInputForValue(torch::jit::Value* old_value, std::shar
}
auto new_value = graph->block()->addInput();
old_to_new[old_value] = new_value;
// mapping from new graph input Values to original graph values
old_to_new[new_value] = old_value;
new_value->copyMetadata(old_value);
return new_value;
} else {
Expand Down Expand Up @@ -56,37 +58,53 @@ c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<t
return c10::FunctionSchema(method_name, method_name, args, returns);
}

std::vector<nvinfer1::Dims> registerSegmentInOutShape(SegmentedBlock &seg_block, std::vector<nvinfer1::Dims> &input_shape) {
auto g = seg_block.g_->copy();
void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<torch::jit::Value*, nvinfer1::Dims> &input_shape_map) {
// create a module to run the graph
auto g = seg_block.g();
auto copy_g = g->copy();
torch::jit::script::Module cur_mod(c10::QualifiedName("module"));

auto self = g->insertInput(0, "self_1");
auto self = copy_g->insertInput(0, "self_1");
self->setType(cur_mod.type());

auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
auto schema = getFunctionSchema(cur_method->name(), g);
auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), copy_g);
auto schema = getFunctionSchema(cur_method->name(), copy_g);
cur_mod.type()->addMethod(cur_method);
cur_method->setSchema(schema);

std::vector<int64_t> shape;
shape.insert(shape.begin(), std::begin(input_shape[0].d), std::begin(input_shape[0].d) + input_shape[0].nbDims);
auto in = at::randint(5, shape, {at::kCUDA});
std::vector<torch::jit::IValue> jit_inputs_ivalues;
jit_inputs_ivalues.push_back(in.clone());

// set inputs ivalues
for (auto &input : seg_block.raw_inputs()) {
std::vector<int64_t> shape;
nvinfer1::Dims cur_shape = input_shape_map[input];
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
auto in = at::randint(5, shape, {at::kCUDA});
jit_inputs_ivalues.push_back(in.clone());
}

std::vector<at::Tensor> jit_results;
torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
if (!jit_results_ivalues.isTensor()) {
std::cerr << "Mini graph output is NOT a Tensor!\n";
if (jit_results_ivalues.isTensor()) {
jit_results.push_back(jit_results_ivalues.toTensor());
} else {
auto results = jit_results_ivalues.toTuple()->elements();
for (auto r : results) {
jit_results.push_back(r.toTensor());
}
}
auto jit_results_tensor = jit_results_ivalues.toTensor();
auto output_sizes = jit_results_tensor.sizes();

std::vector<nvinfer1::Dims> output_shape;
output_shape.push_back(util::toDims(output_sizes));
seg_block.register_inshape(input_shape);
seg_block.register_outshape(output_shape);
size_t idx = 0;
for (auto &output : seg_block.raw_outputs()) {
input_shape_map[output] = util::toDims(jit_results[idx++].sizes());
}

std::vector<nvinfer1::Dims> input_shape;
for (auto &i : seg_block.raw_inputs()) {
input_shape.push_back(input_shape_map[i]);
}

return output_shape;
seg_block.register_inshape(input_shape);
}

std::vector<nvinfer1::Dims> extractNvinfer1Dims(std::vector<conversion::InputRange>& input_ranges) {
Expand All @@ -97,6 +115,35 @@ std::vector<nvinfer1::Dims> extractNvinfer1Dims(std::vector<conversion::InputRan
return res;
}

// Wire up cross-block dependencies after segmentation.
//
// Every Value consumed by some segmented block (plus every output of the
// global graph) must be *produced* by exactly one block. A block that holds
// such a Value internally, but does not already receive it as an input, is
// its producer and must expose it via registerOutput() so the consuming
// block (or the stitched-together global graph) can reach it.
//
// @param segmented_blocks  all blocks produced by segment_graph(); mutated
//                          in place (inputs registered, outputs appended)
// @param g                 the original global graph the blocks were cut from
void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
  // Gather every Value some block needs as an input. std::set keeps the
  // iteration order deterministic across runs.
  std::set<torch::jit::Value*> input_values;
  for (auto &seg_block : segmented_blocks) {
    seg_block.registerInputs();
    for (auto &input : seg_block.raw_inputs()) {
      input_values.insert(input);
    }
  }

  // NOTE(review): erasing the global graph inputs from input_values was
  // deliberately left disabled — presumably the blocks still need those
  // mappings present; confirm before re-enabling.
  // for (auto &graph_input : g->inputs()) {
  //   input_values.erase(graph_input);
  // }

  // The global graph's own outputs must also be surfaced by whichever block
  // produces them, so treat them as required values too.
  for (auto &graph_output : g->outputs()) {
    input_values.insert(graph_output);
  }

  for (auto &needed_value : input_values) {
    for (auto &seg_block : segmented_blocks) {
      const auto &block_inputs = seg_block.raw_inputs();
      const bool consumed_here =
          std::find(block_inputs.begin(), block_inputs.end(), needed_value) != block_inputs.end();
      // Producer = contains the value internally but does not take it as an
      // input; register it as this block's output.
      if (!consumed_here && seg_block.contain_raw_input(needed_value)) {
        seg_block.registerOutput(needed_value);
      }
    }
  }
}

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges) {
std::vector<SegmentedBlock> segmented_blocks;

Expand All @@ -105,9 +152,10 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
// segment the nodes
for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant) continue;

auto block_target = conversion::OpSupported(n) ? SegmentedBlock::kTensorRT : SegmentedBlock::kTorch;

if (segmented_blocks.empty() || block_target != segmented_blocks.back().target) {
if (segmented_blocks.empty() || block_target != segmented_blocks.back().target()) {
SegmentedBlock cur_block(block_target);
cur_block.appendNode(n);
segmented_blocks.push_back(cur_block);
Expand All @@ -116,10 +164,25 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
}
}

std::vector<nvinfer1::Dims> cur_input = extractNvinfer1Dims(input_ranges);
printf("before register input\n");
registerSegmentsInputsOutputs(segmented_blocks, g);

std::vector<nvinfer1::Dims> graph_inputs_shape = extractNvinfer1Dims(input_ranges);
std::unordered_map<torch::jit::Value*, nvinfer1::Dims> input_shape_map;

for (size_t i = 0; i < g->inputs().size(); ++i) {
input_shape_map[g->inputs()[i]] = graph_inputs_shape[i];
}

for (auto &seg_block : segmented_blocks) {
LOG_INFO(*seg_block.g() << "In partitioning\n");
}

printf("before register shapes\n");

for (auto &seg_block : segmented_blocks) {
seg_block.registerOutput();
cur_input = registerSegmentInOutShape(seg_block, cur_input);
printf("h\n");
registerSegmentInOutShape(seg_block, input_shape_map);
}

return segmented_blocks;
Expand Down
62 changes: 52 additions & 10 deletions core/partitioning/partitioning.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,29 @@ struct SegmentedBlock {
kTensorRT,
};

SegmentedBlock(SegmentedBlockTarget blk_target) : target(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}

SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target(blk_target), g_(g) {}
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}

enum SegmentedBlockTarget target() {
return target_;
}

void appendNode(torch::jit::Node* n) {
last_node = cloneNode(n, g_, old_to_new_);
cloneNode(n, g_, old_to_new_);
}

void registerOutput() {
for (auto &value : last_node->outputs()) {
g_->registerOutput(value);
void registerInputs() {
for (auto &value : g_->inputs()) {
inputs_.push_back(old_to_new_[value]);
}
}

void registerOutput(torch::jit::Value* raw_input) {
outputs_.push_back(raw_input);
g_->registerOutput(old_to_new_[raw_input]);
}

torch::jit::Block* block() {
return g_->block();
}
Expand All @@ -42,6 +51,22 @@ struct SegmentedBlock {
return g_->inputs();
}

c10::ArrayRef<torch::jit::Value*> outputs() {
return g_->outputs();
}

const std::vector<torch::jit::Value*> &raw_inputs() const {
return inputs_;
}

const std::vector<torch::jit::Value*> &raw_outputs() const {
return outputs_;
}

bool contain_raw_input(torch::jit::Value* input) {
return old_to_new_.count(input);
}

torch::jit::graph_node_list nodes() {
return g_->nodes();
}
Expand All @@ -54,14 +79,31 @@ struct SegmentedBlock {
out_shape_ = out_shape;
}

SegmentedBlockTarget target;
const std::vector<nvinfer1::Dims>& in_shape() const {
return in_shape_;
}

const std::vector<nvinfer1::Dims>& out_shape() const {
return out_shape_;
}

const std::shared_ptr<torch::jit::Graph>& g() const {
return g_;
}

void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
g_ = new_g;
}

private:
SegmentedBlockTarget target_;
std::vector<nvinfer1::Dims> in_shape_;
std::vector<nvinfer1::Dims> out_shape_;
// std::vector<torch::jit::Value*> inputs_;
// std::vector<torch::jit::Value*> outputs_;
std::vector<torch::jit::Value*> inputs_;
std::vector<torch::jit::Value*> outputs_;
std::shared_ptr<torch::jit::Graph> g_;
std::string trt_engine;
torch::jit::Node* last_node;
//last node on original global graph
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;

};
Expand Down

0 comments on commit 0d28164

Please sign in to comment.