Skip to content

Commit

Permalink
fix(//core/lowering): use lower_info as parameter
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Aug 6, 2021
1 parent 74bbd10 commit 370aeb9
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
9 changes: 5 additions & 4 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ void AddEngineToGraph(
}

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name, false);
// Go through Lowering to simplify graph
CompileSpec cfg({});
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);

auto g = graph_and_parameters.first;
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
Expand All @@ -129,7 +130,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info.unfreeze_module);
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);

auto convert_cfg = std::move(cfg.convert_info);
auto g = graph_and_parameters.first;
Expand Down Expand Up @@ -187,7 +188,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
// Compile only forward methods. forward method contains the entire graph.
if (method.name().compare("forward") == 0) {
auto new_g = std::make_shared<torch::jit::Graph>();
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info.unfreeze_module);
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);

auto g = graph_and_parameters.first;
auto params = graph_and_parameters.second;
Expand Down
8 changes: 4 additions & 4 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
const torch::jit::script::Module& mod,
std::string method_name,
bool unfreeze_module = false) {
auto lowered_mod = unfreeze_module ? mod : LowerModule(mod);
LowerInfo lower_info) {
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod);
auto g = lowered_mod.get_method(method_name).graph();
LOG_GRAPH(*g);

Expand All @@ -75,15 +75,15 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
// unfreeze_module is used to not perform constant folding on weights in the network.
// In quantization aware trained (QAT) models, weights are passed through quantize and
// dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
if (!unfreeze_module) {
if (!lower_info.unfreeze_module) {
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(g, false);
}

LOG_GRAPH("LibTorch Lowering");
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

if (unfreeze_module) {
if (lower_info.unfreeze_module) {
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(graph_and_ivalues.first, true);
}
Expand Down
4 changes: 2 additions & 2 deletions core/lowering/lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ struct LowerInfo {
};

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse /*=false*/);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse=false);
torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
const torch::jit::script::Module& mod,
std::string method_name,
bool unfreeze_module /*=false*/);
LowerInfo lower_info);

} // namespace lowering
} // namespace core
Expand Down
2 changes: 1 addition & 1 deletion tests/core/conversion/converters/test_activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) {
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 4e-6));
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 5e-6));
}

TEST(Converters, ATenTanhConvertsCorrectly) {
Expand Down

0 comments on commit 370aeb9

Please sign in to comment.