refactored the new graph output registration
Signed-off-by: Bo Wang <[email protected]>
bowang007 committed Mar 9, 2021
1 parent 0d28164 commit 55e0510
Showing 2 changed files with 7 additions and 36 deletions.
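
At a glance, the refactor moves graph-output bookkeeping out of AddSegmentedBlockToGraph and into a single pass at the end of CompileGraph: as each segment is cloned into the new graph, old_to_new_g records which new value stands in for each original value, and once every segment is spliced in, the original graph's outputs are looked up in that map and registered exactly once. A minimal sketch of that final step, assuming libtorch's JIT IR headers (the helper name registerFinalOutputs is mine; the commit inlines this loop in CompileGraph):

#include <torch/csrc/jit/ir/ir.h>

#include <memory>
#include <unordered_map>

// Register the stitched graph's outputs once, by translating the
// original graph's outputs through the old-to-new value map.
void registerFinalOutputs(
    const std::shared_ptr<torch::jit::Graph>& g,      // original graph
    const std::shared_ptr<torch::jit::Graph>& new_g,  // stitched graph
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
  for (auto &output : g->outputs()) {
    new_g->registerOutput(old_to_new_g[output]);
  }
}

The two diffs below are the halves of that change: compiler.cpp gains the final registration loop, and the per-segment output rewriting (plus leftover debug output) is deleted.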
35 changes: 7 additions & 28 deletions core/compiler.cpp
@@ -200,38 +200,18 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin

-  torch::jit::Node *node;
   for (const auto n : seg.nodes()) {
-    node = partitioning::cloneNode(n, g, old_to_new_g);
+    partitioning::cloneNode(n, g, old_to_new_g);
   }

   // original graph value => new global graph value
   for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
     old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]];
   }

-  for (size_t i = 0; i < g->outputs().size(); ++i) {
-    g->eraseOutput(i);
-  }
-  for (auto &value : node->outputs()) {
-    g->registerOutput(value);
-  }
   LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
   return;
 }

-//void print_type_dim(c10::TypePtr type) {
-//  printf("type: %s\n", type->str().c_str());
-//  auto tensor_type = type->cast<torch::jit::TensorType>();
-//  auto optional_vec = tensor_type->sizes().sizes().value();
-//  if (!tensor_type->isComplete()) {
-//    printf("Not complete type\n");
-//    return;
-//  }
-//  printf("dimension: %d\n", optional_vec.size());
-//  for (int i = 0; i < optional_vec.size(); ++i) {
-//    printf("dim(%d) : %d\n", i, optional_vec[i].value());
-//  }
-//}

 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
   // TODO: Should be doing a functional transform but need PR #31978
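
The erase/registerOutput block removed from AddSegmentedBlockToGraph above is what motivated the refactor: it rewrote the enclosing graph's outputs on every segment, kept only the outputs of the last cloned node, and its erase loop advanced i while eraseOutput(i) shifted the remaining outputs left, so it skipped every other entry. If all outputs ever do need to be cleared, draining from the back avoids the shift; a hypothetical helper, not code from this commit:

#include <torch/csrc/jit/ir/ir.h>

#include <memory>

// Erase every graph output without the index-skipping bug: removing
// the last output never shifts the indices of the ones still to go.
void eraseAllOutputs(const std::shared_ptr<torch::jit::Graph>& g) {
  while (g->outputs().size() > 0) {
    g->eraseOutput(g->outputs().size() - 1);
  }
}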
@@ -255,10 +235,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
     // segment the graph and convert segmented TensorRT block
     auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges);

-    for (auto &seg_block : segmented_blocks) {
-      LOG_INFO(*seg_block.g() << "SegmentedBlockGraph");
-    }
-
     int trt_engine_id = 0;
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
     for (auto &seg_block : segmented_blocks) {
@@ -271,16 +247,19 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
         auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
         auto temp_g = std::make_shared<torch::jit::Graph>();
         AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);
-        // printf("type: %s\n", temp_g->inputs()[0]->type()->str().c_str());
-        // auto temp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, temp_g);
-        // AddSegmentedBlockToGraph(new_g, temp_seg_block);
         seg_block.update_graph(temp_g);
         AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
       } else {
         AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
       }
     }

+    for (auto &output : g->outputs()) {
+      new_g->registerOutput(old_to_new_g[output]);
+    }
+
+    LOG_INFO(*new_g << "(After CompileGraph)\n");
+
     auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
     auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
     new_mod.type()->addMethod(new_method);
8 changes: 0 additions & 8 deletions core/partitioning/partitioning.cpp
@@ -164,7 +164,6 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
     }
   }

-  printf("before register input\n");
   registerSegmentsInputsOutputs(segmented_blocks, g);

   std::vector<nvinfer1::Dims> graph_inputs_shape = extractNvinfer1Dims(input_ranges);
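
extractNvinfer1Dims, called just above, bridges the compile spec's input ranges to TensorRT's shape struct. Its body is not part of this diff; a hedged sketch of the per-shape conversion it plausibly performs (the helper name toDims and the bare std::vector<int64_t> input are assumptions):

#include <NvInfer.h>

#include <cassert>
#include <cstdint>
#include <vector>

// Convert one shape vector into an nvinfer1::Dims, the fixed-capacity
// dimension struct TensorRT uses to describe tensor shapes.
nvinfer1::Dims toDims(const std::vector<int64_t>& shape) {
  assert(shape.size() <= static_cast<size_t>(nvinfer1::Dims::MAX_DIMS));
  nvinfer1::Dims dims;
  dims.nbDims = static_cast<int32_t>(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    dims.d[i] = static_cast<int32_t>(shape[i]);
  }
  return dims;
}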
@@ -175,13 +174,6 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
   }

-  for (auto &seg_block : segmented_blocks) {
-    LOG_INFO(*seg_block.g() << "In partitioning\n");
-  }
-
-  printf("before register shapes\n");
-
   for (auto &seg_block : segmented_blocks) {
-    printf("h\n");
     registerSegmentInOutShape(seg_block, input_shape_map);
   }

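
For context on registerSegmentsInputsOutputs and the seg.raw_outputs() bookkeeping in compiler.cpp: segment boundary values are conventionally found by scanning a segment's nodes: anything consumed but not produced inside the segment is an input, and anything produced inside but consumed elsewhere becomes an output. A conceptual sketch of the input half under that assumption (none of these internals appear in the commit):

#include <torch/csrc/jit/ir/ir.h>

#include <unordered_set>
#include <vector>

// Collect the values a segment consumes but does not itself produce,
// i.e. the values that must become the segment's inputs.
std::vector<torch::jit::Value*> segmentInputs(
    const std::vector<torch::jit::Node*>& nodes) {
  std::unordered_set<torch::jit::Value*> produced;
  std::unordered_set<torch::jit::Value*> seen;
  std::vector<torch::jit::Value*> inputs;
  for (auto n : nodes) {
    for (auto in : n->inputs()) {
      // defined outside the segment and not recorded yet
      if (!produced.count(in) && seen.insert(in).second) {
        inputs.push_back(in);
      }
    }
    for (auto out : n->outputs()) {
      produced.insert(out);
    }
  }
  return inputs;
}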
