Skip to content

Commit

Permalink
feat(//lowering): centralize lowering and try to use PyTorch Conv2DBN…
Browse files Browse the repository at this point in the history
… folding

before using the converter

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 7, 2020
1 parent 4b58d3b commit fad4a10
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 37 deletions.
52 changes: 18 additions & 34 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@
namespace trtorch {
namespace core {

// Builds a c10::FunctionSchema describing the inputs and outputs of graph g,
// using each value's debugName() as the argument/return name and its JIT type
// as the schema type. Both the schema name and overload name are set to
// method_name.
// NOTE(review): mod is currently unused in the body — presumably kept for
// interface stability with callers; confirm before removing.
// (The scraped diff showed this signature line twice — old + new revision of
// the same line; exactly one belongs in the real file.)
c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {

  std::vector<c10::Argument> args;
  for (auto in : g->inputs()) {
    args.push_back(c10::Argument(in->debugName(), in->type()));
  }

  std::vector<c10::Argument> returns;
  for (auto out : g->outputs()) {
    returns.push_back(c10::Argument(out->debugName(), out->type()));
  }

  return c10::FunctionSchema(method_name, method_name, args, returns);
}


// Registers serialized_engine with the execution runtime, then appends a node
// to graph g that invokes the registered engine, and registers every output of
// that node as a graph output.
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
// NOTE(review): the next two identical statements are a diff-scrape artifact
// (pre- and post-commit revision of the same line); the real file has one.
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
auto schema = execution::GetEngineFunctionSchema(uid);
auto num_io = execution::GetEngineIO(uid);

// NOTE(review): the line below is GitHub diff chrome, not source; several
// lines (including the loop header that builds graph_inputs) are elided here.
Expand All @@ -53,58 +53,42 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
in_val->setType(c10::TensorType::get());
graph_inputs.push_back(in_val);
}

// Create the engine-invocation node from the schema's qualified name and the
// collected inputs; num_io.second is the engine's output count.
auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second);
g->block()->appendNode(engine_node);

// Expose every engine output as a graph output.
for (auto o : engine_node->outputs()) {
g->registerOutput(o);
}

return;
}

// Returns true iff every node in the lowered graph of mod's method
// method_name has converter support (checked block-wide by
// conversion::VerifyConverterSupportForBlock).
// (The scraped diff interleaved the pre-commit body — manual
// torch::jit::LowerGraph + lowering::LowerGraph/LowerBlock calls — with the
// post-commit body; this keeps only the post-commit version, which
// centralizes all of that inside lowering::Lower.)
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
                                std::string method_name) {
  // Go through Lowering to simplify graph and extract weight parameters
  auto graph_and_parameters = lowering::Lower(mod, method_name);

  auto g = graph_and_parameters.first;
  auto params = graph_and_parameters.second;
  auto named_params = conversion::get_named_params(g->inputs(), params);
  LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");

  return conversion::VerifyConverterSupportForBlock(g->block());
}

// Lowers mod's method method_name (module folding + graph lowering via
// lowering::Lower), then converts the lowered block into a serialized
// TensorRT engine string using the supplied conversion settings.
// (The scraped diff interleaved the pre-commit lowering calls with the
// post-commit lowering::Lower call; this keeps only the post-commit version.)
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
                                    std::string method_name,
                                    conversion::ExtraInfo cfg) {
  // Go through Lowering to simplify graph and extract weight parameters
  auto graph_and_parameters = lowering::Lower(mod, method_name);

  auto g = graph_and_parameters.first;
  auto params = graph_and_parameters.second;
  auto named_params = conversion::get_named_params(g->inputs(), params);

  LOG_INFO(*g << "(CompileGraph)\n");

  auto engine = ConvertBlockToEngine(g->block(), cfg, named_params);
  // Return the local by value: NRVO/implicit move applies here, and
  // `return std::move(engine)` would actually pessimize by blocking NRVO
  // (C++ Core Guidelines F.48).
  return engine;
}
Expand All @@ -128,7 +112,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,

return new_mod;
}

} // namespace core
} // namespace trtorch

28 changes: 26 additions & 2 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/lower_graph.h"
#include "torch/csrc/jit/passes/quantization.h"

#include "core/lowering/lowering.h"
#include "core/lowering/irfusers/irfusers.h"
Expand All @@ -22,7 +24,29 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
//irfusers::UnpackBatchNorm(g);
//torch::jit::EliminateDeadCode(g);
}


// Module-level lowering pass: folds Conv2d + BatchNorm2d pairs in mod into a
// single convolution (via PyTorch's FoldConvBatchNorm2d) before the method
// graph is extracted for conversion.
// NOTE(review): mod is taken by const& yet the fold rewrites the module in
// place — presumably fine because script::Module is a shared handle type;
// confirm against the torch::jit API.
void LowerModule(const torch::jit::script::Module& mod) {
torch::jit::FoldConvBatchNorm2d(mod);
}

// Centralized lowering pipeline for one method of a module:
//   1. module-level folding (Conv2d + BatchNorm2d),
//   2. PyTorch graph lowering, which simplifies the graph and lifts weights
//      out as a parameter tensor list,
//   3. TRTorch graph/block passes that reformat the graph into a
//      conversion-friendly form and segment for accelerators and executors
//      (TRT-DLA, TRT-GPU, PYT).
// Returns the lowered graph paired with the extracted parameter tensors.
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
                                                                             std::string method_name) {
  // Fold Conv+BN at the module level before pulling out the method graph
  LowerModule(mod);
  auto graph = mod.get_method(method_name).graph();

  // PyTorch lowering: simplify the graph and extract weight parameters
  auto lowered = torch::jit::LowerGraph(*graph, mod._ivalue());
  graph = lowered.first;

  // TRTorch lowering over the freshly lowered graph
  lowering::LowerGraph(graph);
  // NOTE(review): original author asks "Is this necessary?" — kept for parity
  lowering::LowerBlock(graph->block());

  return lowered;
}


} // namespace lowering
} // namespace core
} // namespace trtorch
5 changes: 4 additions & 1 deletion core/lowering/lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
namespace trtorch {
namespace core {
namespace lowering {

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
void LowerModule(const torch::jit::script::Module& mod);
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
std::string method_name);

} // namespace lowering
} // namespace core
Expand Down

0 comments on commit fad4a10

Please sign in to comment.