Skip to content

Commit

Permalink
clean messy code
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Mar 5, 2021
1 parent bbd3835 commit 1ca13d8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 98 deletions.
118 changes: 23 additions & 95 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ c10::FunctionSchema GenerateGraphSchema(
void AddEngineToGraph(
torch::jit::script::Module mod,
std::shared_ptr<torch::jit::Graph>& g,
std::string& serialized_engine, int eng_id = 0) {
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + std::to_string(eng_id), serialized_engine);
std::string& serialized_engine, int engine_id = 0) {
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + std::to_string(engine_id), serialized_engine);
// Get required metadata about the engine out
auto num_io = engine_ptr->num_io;
auto name = engine_ptr->name;
Expand Down Expand Up @@ -147,16 +147,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
auto convert_cfg = std::move(cfg.convert_info);
auto g = graph_and_parameters.first;

printf("*********************************************\n");
// auto segmented_blocks = partitioning::segment_graph(g);
// for (auto &seg :segmented_blocks) {
// LOG_INFO(*seg.g_ << "(Seg Graph)\n");
// printf("segmented nodes No.: %d\n", seg.nodes.size());
// for (auto &val : seg.inputs) {
// printf("input: %s ", val->debugNameBase().c_str());
// } printf("\n");
// }
printf("|||||||||||||||||||||||||||||||||||||||||||||\n");
auto params = graph_and_parameters.second;
auto named_params = conversion::get_named_params(g->inputs(), params);

Expand Down Expand Up @@ -190,35 +180,20 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::



void AddTorchSegmentToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg) {
void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg) {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> output_input_map;
for (size_t i = 0; i < g->outputs().size(); ++i) {
auto prev_output = g->outputs()[i];
auto next_input = seg.inputs()[i];
output_input_map[next_input] = prev_output;
size_t input_idx = 0;
if (seg.target == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
auto self = g->insertInput(0, "self_1");
self->setType(seg.inputs()[0]->type());
}
output_input_map[seg.inputs()[input_idx++]] = g->inputs()[0];
}

torch::jit::Node *node;
for (const auto n : seg.nodes()) {
node = partitioning::cloneNode(n, g, output_input_map);
}
for (size_t i = 0; i < g->outputs().size(); ++i) {
g->eraseOutput(i);
}
for (auto &value : node->outputs()) {
g->registerOutput(value);
}
LOG_INFO(*g << "(addTorch)\n");
return;
}

void AddTensorRTSegmentToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg) {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> output_input_map;
output_input_map[seg.inputs()[0]] = g->inputs()[0];

for (size_t i = 0; i < g->outputs().size(); ++i) {
auto prev_output = g->outputs()[i];
auto next_input = seg.inputs()[i + 1];
auto next_input = seg.inputs()[input_idx++];
output_input_map[next_input] = prev_output;
}

Expand All @@ -232,7 +207,7 @@ void AddTensorRTSegmentToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioni
for (auto &value : node->outputs()) {
g->registerOutput(value);
}
LOG_INFO(*g << "(addTensorRT)\n");
LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
return;
}

Expand All @@ -249,41 +224,6 @@ void AddTensorRTSegmentToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioni
// printf("dim(%d) : %d\n", i, optional_vec[i].value());
// }
//}
//
//void InferShapeForGraph(std::shared_ptr<torch::jit::Graph> &g) {
// printf("g input size: %d\n", g->inputs().size());
// torch::jit::PropagateInputShapes(g);
// for (size_t i = 0; i < g->inputs().size(); ++i) {
// printf("type: %s\n", g->inputs()[i]->type()->str().c_str());
// auto tensor_type = g->inputs()[i]->type()->cast<torch::jit::TensorType>();
// std::vector<int64_t> sizes{3, 3, 16, 16};
// g->inputs()[i]->setType(tensor_type->withSizes(sizes));
// print_type_dim((g->inputs()[i]->type()));
// }
//
// for (auto &i : g->inputs()) {
// auto tensor_type = i->type()->cast<torch::jit::TensorType>();
// auto shape = c10::SymbolicShape(input_shape);
// tensor_type->setType(tensor_type->withSymbolicShapes(shape));
// std::vector<int64_t> input_shape{3, 3, 16, 16};
// c10::VaryingShape<int64_t> input_size(input_shape);
// std::vector<int64_t> input_stride{768, 256, 16, 1};
// c10::VaryingShape<int64_t> tensor_stride(input_stride);
// i->setType(c10::TensorType::create(at::kInt, at::kCPU, input_size, tensor_stride, c10::nullopt));
// }
// for (auto &i : g->inputs()) {
// print_type_dim(i->type());
// }
//
// torch::jit::PropagateInputShapes(g);
//
//
// for (auto &i : g->outputs()) {
// print_type_dim((i->type()));
// }
//
// LOG_INFO(*g << "(after infershape)\n");
//}


torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
Expand All @@ -304,38 +244,26 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
auto convert_cfg = std::move(cfg.convert_info);
LOG_INFO(*g << "(CompileGraph)\n");

// InferShapeForGraph(g);

// segment the graph and convert segmented TensorRT block
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges);
bool first = true;

int trt_engine_id = 0;
for (auto &seg_block : segmented_blocks) {
LOG_INFO(*seg_block.g_ << "(MiniGraph)\n");
// LOG_INFO(*seg_block.g_ << "SegmentedBlockGraph");
if (seg_block.target == partitioning::SegmentedBlock::kTensorRT) {
if (first) {
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
AddEngineToGraph(new_mod, new_g, engine);
LOG_INFO(*new_g << "(new global graph)\n");
first = false;
} else {
std::vector<int64_t> temp_range = util::toVec(seg_block.in_shape_[0]);
convert_cfg.input_ranges[0] = conversion::InputRange(temp_range);
printf("before\n");
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
auto tmp_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, tmp_g, engine, 1);
LOG_INFO(*tmp_g << "(second engine graph)\n");
auto tmp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, tmp_g);
AddTensorRTSegmentToGraph(new_g, tmp_seg_block);
LOG_INFO(*new_g << "(new global graph)\n");
}
std::vector<int64_t> input_range = util::toVec(seg_block.in_shape_[0]);
convert_cfg.input_ranges[0] = conversion::InputRange(input_range);
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);
// printf("type: %s\n", temp_g->inputs()[0]->type()->str().c_str());
auto temp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, temp_g);
AddSegmentedBlockToGraph(new_g, temp_seg_block);
} else {
printf("Torch Segment\n");
AddTorchSegmentToGraph(new_g, seg_block);
AddSegmentedBlockToGraph(new_g, seg_block);
}
}
printf("after seg parts\n");

auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
Expand Down
3 changes: 0 additions & 3 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ std::vector<nvinfer1::Dims> registerSegmentInOutShape(SegmentedBlock &seg_block,
}
auto jit_results_tensor = jit_results_ivalues.toTensor();
auto output_sizes = jit_results_tensor.sizes();
for (auto &i : output_sizes) {
printf("%d\n", i);
}

std::vector<nvinfer1::Dims> output_shape;
output_shape.push_back(util::toDims(output_sizes));
Expand Down

0 comments on commit 1ca13d8

Please sign in to comment.