From fad4a10c842a23af2fcc02518e604ebda94f1dcf Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Tue, 7 Apr 2020 15:55:10 -0700 Subject: [PATCH] feat(//lowering): centralize lowering and try to use PyTorch Conv2DBN folding before using the converter Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/compiler.cpp | 52 +++++++++++++------------------------- core/lowering/lowering.cpp | 28 ++++++++++++++++++-- core/lowering/lowering.h | 5 +++- 3 files changed, 48 insertions(+), 37 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 33e2f04bff..459efd1356 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -24,24 +24,24 @@ namespace trtorch { namespace core { -c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr& g) { +c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr& g) { std::vector args; for (auto in : g->inputs()) { args.push_back(c10::Argument(in->debugName(), in->type())); } - + std::vector returns; for (auto out : g->outputs()) { returns.push_back(c10::Argument(out->debugName(), out->type())); } - + return c10::FunctionSchema(method_name, method_name, args, returns); } void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr& g, std::string& serialized_engine) { - execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine); + execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine); auto schema = execution::GetEngineFunctionSchema(uid); auto num_io = execution::GetEngineIO(uid); @@ -53,58 +53,42 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptrsetType(c10::TensorType::get()); graph_inputs.push_back(in_val); } - + auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef(graph_inputs), num_io.second); g->block()->appendNode(engine_node); for (auto o : engine_node->outputs()) { g->registerOutput(o); } - + return; } bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) { - auto g = mod.get_method(method_name).graph(); - // Go through PyTorch Lowering to simplify graph and extract weight parameters - auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); - - g = graph_and_parameters.first; - - // Go through TRTorch Lowering to reformat graph to be conversion friendly - // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) - lowering::LowerGraph(g); - + // Go through Lowering to simplify graph and extract weight parameters + auto graph_and_parameters = lowering::Lower(mod, method_name); + + auto g = graph_and_parameters.first; auto params = graph_and_parameters.second; auto named_params = conversion::get_named_params(g->inputs(), params); LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n"); - - // Is this necessary? - lowering::LowerBlock(g->block()); - + return conversion::VerifyConverterSupportForBlock(g->block()); } std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, conversion::ExtraInfo cfg) { - auto g = mod.get_method(method_name).graph(); - // Go through PyTorch Lowering to simplify graph and extract weight parameters - auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); - - g = graph_and_parameters.first; - - // Go through TRTorch Lowering to reformat graph to be conversion friendly - // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) - lowering::LowerGraph(g); - + // Go through Lowering to simplify graph and extract weight parameters + auto graph_and_parameters = lowering::Lower(mod, method_name); + + auto g = graph_and_parameters.first; auto params = graph_and_parameters.second; auto named_params = conversion::get_named_params(g->inputs(), params); + LOG_INFO(*g << "(CompileGraph)\n"); - - // Is this necessary? - lowering::LowerBlock(g->block()); + auto engine = ConvertBlockToEngine(g->block(), cfg, named_params); return std::move(engine); } @@ -128,7 +112,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, return new_mod; } - + } // namespace core } // namespace trtorch diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index f7f1254e00..d18a9d612d 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -1,5 +1,7 @@ -#include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/fuse_linear.h" +#include "torch/csrc/jit/passes/lower_graph.h" +#include "torch/csrc/jit/passes/quantization.h" #include "core/lowering/lowering.h" #include "core/lowering/irfusers/irfusers.h" @@ -22,7 +24,29 @@ void LowerGraph(std::shared_ptr& g) { //irfusers::UnpackBatchNorm(g); //torch::jit::EliminateDeadCode(g); } - + +void LowerModule(const torch::jit::script::Module& mod) { + torch::jit::FoldConvBatchNorm2d(mod); +} + +std::pair, std::vector> Lower(const torch::jit::script::Module& mod, + std::string method_name) { + LowerModule(mod); + auto g = mod.get_method(method_name).graph(); + // Go through PyTorch Lowering to simplify graph and extract weight parameters + auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); + + g = graph_and_parameters.first; + + // Go through TRTorch Lowering to reformat graph to be conversion friendly + // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) + lowering::LowerGraph(g); + // Is this necessary? + lowering::LowerBlock(g->block()); + return graph_and_parameters; +} + + } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h index ed34ee30a0..95547c778c 100644 --- a/core/lowering/lowering.h +++ b/core/lowering/lowering.h @@ -5,9 +5,12 @@ namespace trtorch { namespace core { namespace lowering { - + void LowerBlock(torch::jit::Block* b); void LowerGraph(std::shared_ptr& g); +void LowerModule(const torch::jit::script::Module& mod); +std::pair, std::vector> Lower(const torch::jit::script::Module& mod, + std::string method_name); } // namespace lowering } // namespace core