Skip to content

Commit

Permalink
Implemented fallback and ran successfully
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Mar 5, 2021
1 parent 123f026 commit bbd3835
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 16 deletions.
1 change: 1 addition & 0 deletions core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cc_library(
"//core/conversion",
"//core/runtime",
"//core/lowering",
"//core/partitioning",
"//core/util/logging",
"@tensorrt//:nvinfer"
] + select({
Expand Down
116 changes: 104 additions & 12 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ c10::FunctionSchema GenerateGraphSchema(
void AddEngineToGraph(
torch::jit::script::Module mod,
std::shared_ptr<torch::jit::Graph>& g,
std::string& serialized_engine) {
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
std::string& serialized_engine, int eng_id = 0) {
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + std::to_string(eng_id), serialized_engine);
// Get required metadata about the engine out
auto num_io = engine_ptr->num_io;
auto name = engine_ptr->name;
Expand Down Expand Up @@ -130,7 +130,6 @@ void AddEngineToGraph(
}



bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name);
Expand All @@ -149,14 +148,14 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
auto g = graph_and_parameters.first;

printf("*********************************************\n");
auto segmented_blocks = partitioning::segment_graph(g);
for (auto &seg :segmented_blocks) {
LOG_INFO(*seg.g_ << "(Seg Graph)\n");
// auto segmented_blocks = partitioning::segment_graph(g);
// for (auto &seg :segmented_blocks) {
// LOG_INFO(*seg.g_ << "(Seg Graph)\n");
// printf("segmented nodes No.: %d\n", seg.nodes.size());
// for (auto &val : seg.inputs) {
// printf("input: %s ", val->debugNameBase().c_str());
// } printf("\n");
}
// }
printf("|||||||||||||||||||||||||||||||||||||||||||||\n");
auto params = graph_and_parameters.second;
auto named_params = conversion::get_named_params(g->inputs(), params);
Expand Down Expand Up @@ -189,6 +188,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
// return new_mod;
//}



void AddTorchSegmentToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg) {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> output_input_map;
for (size_t i = 0; i < g->outputs().size(); ++i) {
Expand All @@ -211,6 +212,79 @@ void AddTorchSegmentToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning:
return;
}

// Appends a TensorRT segment's mini-graph onto the global graph `g`.
// The segment's first input (the engine "self" value) is wired to the global
// graph's first input; the remaining segment inputs consume the global graph's
// current outputs. After cloning the segment's nodes, the global graph's
// outputs are replaced by the outputs of the last cloned node.
void AddTensorRTSegmentToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg) {
  // Maps values in the segment graph to their counterparts in `g`, so cloned
  // nodes pick up the correct inputs.
  std::unordered_map<torch::jit::Value*, torch::jit::Value*> output_input_map;
  output_input_map[seg.inputs()[0]] = g->inputs()[0];

  // Segment input i+1 consumes the global graph's output i.
  for (size_t i = 0; i < g->outputs().size(); ++i) {
    auto prev_output = g->outputs()[i];
    auto next_input = seg.inputs()[i + 1];
    output_input_map[next_input] = prev_output;
  }

  // Clone every node of the segment into `g`; remember the last cloned node so
  // its outputs can become the new graph outputs. Initialize to nullptr to
  // avoid dereferencing an indeterminate pointer when the segment is empty
  // (the original left `node` uninitialized).
  torch::jit::Node* last_node = nullptr;
  for (const auto n : seg.nodes()) {
    last_node = partitioning::cloneNode(n, g, output_input_map);
  }

  // Drop all current outputs. Erase index 0 until empty: the original erased
  // at a rising index while the list shrank, which skips every other output.
  while (g->outputs().size() > 0) {
    g->eraseOutput(0);
  }
  if (last_node != nullptr) {
    for (auto value : last_node->outputs()) {
      g->registerOutput(value);
    }
  }
  LOG_INFO(*g << "(addTensorRT)\n");
  return;
}

//void print_type_dim(c10::TypePtr type) {
// printf("type: %s\n", type->str().c_str());
// auto tensor_type = type->cast<torch::jit::TensorType>();
// auto optional_vec = tensor_type->sizes().sizes().value();
// if (!tensor_type->isComplete()) {
// printf("Not complete type\n");
// return;
// }
// printf("dimension: %d\n", optional_vec.size());
// for (int i = 0; i < optional_vec.size(); ++i) {
// printf("dim(%d) : %d\n", i, optional_vec[i].value());
// }
//}
//
//void InferShapeForGraph(std::shared_ptr<torch::jit::Graph> &g) {
// printf("g input size: %d\n", g->inputs().size());
// torch::jit::PropagateInputShapes(g);
// for (size_t i = 0; i < g->inputs().size(); ++i) {
// printf("type: %s\n", g->inputs()[i]->type()->str().c_str());
// auto tensor_type = g->inputs()[i]->type()->cast<torch::jit::TensorType>();
// std::vector<int64_t> sizes{3, 3, 16, 16};
// g->inputs()[i]->setType(tensor_type->withSizes(sizes));
// print_type_dim((g->inputs()[i]->type()));
// }
//
// for (auto &i : g->inputs()) {
// auto tensor_type = i->type()->cast<torch::jit::TensorType>();
// auto shape = c10::SymbolicShape(input_shape);
// tensor_type->setType(tensor_type->withSymbolicShapes(shape));
// std::vector<int64_t> input_shape{3, 3, 16, 16};
// c10::VaryingShape<int64_t> input_size(input_shape);
// std::vector<int64_t> input_stride{768, 256, 16, 1};
// c10::VaryingShape<int64_t> tensor_stride(input_stride);
// i->setType(c10::TensorType::create(at::kInt, at::kCPU, input_size, tensor_stride, c10::nullopt));
// }
// for (auto &i : g->inputs()) {
// print_type_dim(i->type());
// }
//
// torch::jit::PropagateInputShapes(g);
//
//
// for (auto &i : g->outputs()) {
// print_type_dim((i->type()));
// }
//
// LOG_INFO(*g << "(after infershape)\n");
//}


torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
// TODO: Should be doing a functional transform but need PR #31978
Expand All @@ -230,20 +304,38 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
auto convert_cfg = std::move(cfg.convert_info);
LOG_INFO(*g << "(CompileGraph)\n");

// InferShapeForGraph(g);

// segment the graph and convert segmented TensorRT block
auto segmented_blocks = partitioning::segment_graph(g);
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges);
bool first = true;

for (auto &seg_block : segmented_blocks) {
LOG_INFO(*seg_block.g_ << "(MiniGraph)\n");

if (seg_block.target == partitioning::SegmentedBlock::kTensorRT) {
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
AddEngineToGraph(new_mod, new_g, engine);
LOG_INFO(*new_g << "(new global graph)\n");
if (first) {
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
AddEngineToGraph(new_mod, new_g, engine);
LOG_INFO(*new_g << "(new global graph)\n");
first = false;
} else {
std::vector<int64_t> temp_range = util::toVec(seg_block.in_shape_[0]);
convert_cfg.input_ranges[0] = conversion::InputRange(temp_range);
printf("before\n");
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
auto tmp_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, tmp_g, engine, 1);
LOG_INFO(*tmp_g << "(second engine graph)\n");
auto tmp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, tmp_g);
AddTensorRTSegmentToGraph(new_g, tmp_seg_block);
LOG_INFO(*new_g << "(new global graph)\n");
}
} else {
printf("Torch Segment\n");
AddTorchSegmentToGraph(new_g, seg_block);
}
}
printf("after seg parts\n");

auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
Expand Down
65 changes: 64 additions & 1 deletion core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "partitioning.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/api/module.h"
#include "core/util/prelude.h"


namespace trtorch {
namespace core {
Expand Down Expand Up @@ -39,12 +42,70 @@ torch::jit::Node* cloneNode(torch::jit::Node* node, std::shared_ptr<torch::jit::
return new_node;
}

// Builds a FunctionSchema for graph `g`: one argument per graph input and one
// return per graph output, each named after the value's debug name. The method
// name doubles as the schema's overload name.
c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
  std::vector<c10::Argument> arguments;
  arguments.reserve(g->inputs().size());
  for (auto input_value : g->inputs()) {
    arguments.emplace_back(input_value->debugName(), input_value->type());
  }

  std::vector<c10::Argument> return_values;
  return_values.reserve(g->outputs().size());
  for (auto output_value : g->outputs()) {
    return_values.emplace_back(output_value->debugName(), output_value->type());
  }

  return c10::FunctionSchema(method_name, method_name, arguments, return_values);
}

// Discovers a segment's output shape by executing its mini-graph in Torch with
// a random input of shape input_shape[0], then records both the input and the
// inferred output shapes on the SegmentedBlock. Returns the output shape so it
// can be chained as the input shape of the next segment.
// NOTE(review): assumes the segment has exactly one input tensor and one
// tensor output — TODO confirm / generalize for multi-input segments.
std::vector<nvinfer1::Dims> registerSegmentInOutShape(SegmentedBlock &seg_block, std::vector<nvinfer1::Dims> &input_shape) {
  // Wrap a copy of the segment graph in a throwaway module so it can be run.
  auto g = seg_block.g_->copy();
  torch::jit::script::Module cur_mod(c10::QualifiedName("module"));

  auto self = g->insertInput(0, "self_1");
  self->setType(cur_mod.type());

  auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
  auto schema = getFunctionSchema(cur_method->name(), g);
  cur_mod.type()->addMethod(cur_method);
  cur_method->setSchema(schema);

  // Build a random CUDA tensor matching the first input shape.
  std::vector<int64_t> shape;
  shape.insert(shape.begin(), std::begin(input_shape[0].d), std::begin(input_shape[0].d) + input_shape[0].nbDims);
  auto in = at::randint(5, shape, {at::kCUDA});
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  jit_inputs_ivalues.push_back(in.clone());

  torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
  if (!jit_results_ivalues.isTensor()) {
    std::cerr << "Mini graph output is NOT a Tensor!\n";
  }
  auto jit_results_tensor = jit_results_ivalues.toTensor();
  auto output_sizes = jit_results_tensor.sizes();
  for (auto &i : output_sizes) {
    // Debug dump of the inferred output shape. sizes() yields int64_t; the
    // original passed it to "%d", which is undefined behavior on 64-bit
    // platforms — cast explicitly and use the matching conversion specifier.
    printf("%lld\n", static_cast<long long>(i));
  }

  std::vector<nvinfer1::Dims> output_shape;
  output_shape.push_back(util::toDims(output_sizes));
  seg_block.register_inshape(input_shape);
  seg_block.register_outshape(output_shape);

  return output_shape;
}

// Collects the static input shapes from a list of InputRanges as TensorRT Dims.
std::vector<nvinfer1::Dims> extractNvinfer1Dims(std::vector<conversion::InputRange>& input_ranges) {
  std::vector<nvinfer1::Dims> dims;
  dims.reserve(input_ranges.size());
  for (const auto &range : input_ranges) {
    dims.push_back(range.input_shape);
  }
  return dims;
}

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g) {
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges) {
std::vector<SegmentedBlock> segmented_blocks;

auto nodes = g->block()->nodes();

// segment the nodes
for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant) continue;
auto block_target = conversion::OpSupported(n) ? SegmentedBlock::kTensorRT : SegmentedBlock::kTorch;
Expand All @@ -58,8 +119,10 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g)
}
}

std::vector<nvinfer1::Dims> cur_input = extractNvinfer1Dims(input_ranges);
for (auto &seg_block : segmented_blocks) {
seg_block.registerOutput();
cur_input = registerSegmentInOutShape(seg_block, cur_input);
}

return segmented_blocks;
Expand Down
17 changes: 14 additions & 3 deletions core/partitioning/partitioning.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/conversion/conversion.h"
#include "torch/csrc/jit/ir/ir.h"


namespace trtorch {
namespace core {
namespace partitioning {
Expand All @@ -21,6 +22,8 @@ struct SegmentedBlock {

SegmentedBlock(SegmentedBlockTarget blk_target) : target(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}

SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target(blk_target), g_(g) {}

void appendNode(torch::jit::Node* n) {
last_node = cloneNode(n, g_, old_to_new_);
}
Expand All @@ -43,9 +46,17 @@ struct SegmentedBlock {
return g_->nodes();
}

void register_inshape(std::vector<nvinfer1::Dims>& in_shape) {
in_shape_ = in_shape;
}

void register_outshape(std::vector<nvinfer1::Dims>& out_shape) {
out_shape_ = out_shape;
}

SegmentedBlockTarget target;
nvinfer1::Dims in_shape;
nvinfer1::Dims out_shape;
std::vector<nvinfer1::Dims> in_shape_;
std::vector<nvinfer1::Dims> out_shape_;
// std::vector<torch::jit::Value*> inputs_;
// std::vector<torch::jit::Value*> outputs_;
std::shared_ptr<torch::jit::Graph> g_;
Expand All @@ -55,7 +66,7 @@ struct SegmentedBlock {

};

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g);
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges);

}
}
Expand Down

0 comments on commit bbd3835

Please sign in to comment.