feat(//core/lowering): Fuse aten::addmm branches into a single
aten::addmm op that can be expanded by a later pass

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed May 31, 2020
1 parent db20098 commit 68f0317
Showing 5 changed files with 105 additions and 2 deletions.
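
In IR terms, the commit collapses the two branches of a frozen linear layer's prim::If into a single aten::addmm node. The pattern below is taken from the doc comment in fuse_addmm_branches.cpp further down; the value names are illustrative.

Before fusion:

    %ret : Tensor = prim::If(%622)
      block0():
        %ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
        -> (%ret.1)
      block1():
        %output.1 : Tensor = aten::matmul(%x9.1, %3677)
        %output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
        -> (%output0.1)

After fusion, both arms are represented by the single op, which a later pass (UnpackAddMM) can then expand in one place:

    %ret : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)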
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -29,6 +29,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::RemoveDropout(g);
   passes::FuseFlattenLinear(g);
   passes::Conv2DToConvolution(g);
+  passes::FuseAddMMBranches(g);
   passes::UnpackAddMM(g);
   //passes::UnpackBatchNorm(g);
   passes::UnpackLogSoftmax(g);
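Note the ordering in LowerGraph: FuseAddMMBranches runs before UnpackAddMM, so by the time unpacking happens only the canonical aten::addmm form remains. For illustration, a minimal sketch of what an unpack pass like UnpackAddMM can look like using torch::jit::SubgraphRewriter; the function name and the exact patterns here are assumptions, not necessarily the actual TRTorch implementation:

    #include <string>

    #include "torch/csrc/jit/passes/subgraph_rewrite.h"

    // Sketch: rewrite aten::addmm(bias, mat1, mat2, beta, alpha) into
    // matmul + in-place add so each op maps onto a simpler converter.
    void UnpackAddMMSketch(std::shared_ptr<torch::jit::Graph>& graph) {
      std::string addmm_pattern = R"IR(
        graph(%b, %x, %w, %beta, %alpha):
          %out : Tensor = aten::addmm(%b, %x, %w, %beta, %alpha)
          return (%out))IR";
      std::string mm_add_pattern = R"IR(
        graph(%b, %x, %w, %beta, %alpha):
          %mm : Tensor = aten::matmul(%x, %w)
          %out : Tensor = aten::add_(%mm, %b, %alpha)
          return (%out))IR";

      // SubgraphRewriter finds every occurrence of the first pattern and
      // splices in the second, remapping inputs and outputs by position.
      torch::jit::SubgraphRewriter rewriter;
      rewriter.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
      rewriter.runOnGraph(graph);
    }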
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -15,6 +15,7 @@ cc_library(
     srcs = [
         "conv2d_to_convolution.cpp",
         "exception_elimination.cpp",
+        "fuse_addmm_branches.cpp",
         "fuse_flatten_linear.cpp",
         "remove_contiguous.cpp",
         "remove_dropout.cpp",
2 changes: 1 addition & 1 deletion core/lowering/passes/exception_elimination.cpp
@@ -64,7 +64,7 @@ struct ExceptionOrPassPatternElimination {
     for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
       auto n = *it;
       if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
-        LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)");
+        LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
         it.destroyCurrent();
       }
     }
100 changes: 100 additions & 0 deletions core/lowering/passes/fuse_addmm_branches.cpp
@@ -0,0 +1,100 @@
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/runtime/graph_executor.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"

#include "core/util/prelude.h"

#include <vector>

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {
namespace {
using namespace torch::jit;
struct AddMMBranchFusion {
  AddMMBranchFusion(std::shared_ptr<Graph> graph)
    : graph_(std::move(graph)) {}

  void run() {
    findAddMMVariantsNodes(graph_->block());
    torch::jit::EliminateDeadCode(graph_);
    LOG_GRAPH("Post aten::addmm branch fusion: " << *graph_);
  }

private:
  bool isAddMMVariantsNode(Node* n) {
    /// Check if this Node hosts a pattern like so:
    /// %ret : Tensor = prim::If(%622)
    ///   block0():
    ///     %ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
    ///     -> (%ret.1)
    ///   block1():
    ///     %output.1 : Tensor = aten::matmul(%x9.1, %3677)
    ///     %output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
    ///     -> (%output0.1)

    if (n->blocks().size() != 2) {
      return false;
    }
    auto arm1 = n->blocks()[0];
    auto arm2 = n->blocks()[1];

    auto arm1_start = arm1->nodes().begin();
    if ((*arm1_start)->kind().toQualString() != std::string("aten::addmm")
        || (*(++arm1_start))->kind() != prim::Return) {
      // Make sure that block0 is solely the aten::addmm op followed by the return
      return false;
    }

    auto arm2_start = arm2->nodes().begin();
    if ((*arm2_start)->kind().toQualString() != std::string("aten::matmul")
        || (*(++arm2_start))->kind().toQualString() != std::string("aten::add_")
        || (*(++arm2_start))->kind() != prim::Return) {
      // Make sure that block1 is solely aten::matmul then aten::add_ followed by the return
      return false;
    }

    return true;
  }

  void findAddMMVariantsNodes(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::If && isAddMMVariantsNode(n)) {
        LOG_GRAPH("Found that node " << *n << " is an AddMM variants node (FuseAddMMBranches)" << std::endl);
        auto arm1 = n->blocks()[0];
        auto arm1_start = arm1->nodes().begin();

        auto input_values = (*arm1_start)->inputs();

        auto new_addmm_node = b->owningGraph()->create(c10::Symbol::fromQualString("aten::addmm"), input_values, 1);
        n->replaceAllUsesWith(new_addmm_node);

        auto old_insert_point = b->owningGraph()->insertPoint();
        b->owningGraph()->setInsertPoint(n);
        b->owningGraph()->insertNode(new_addmm_node);
        b->owningGraph()->setInsertPoint(old_insert_point);

        it.destroyCurrent();
      }
    }
  }

  std::shared_ptr<Graph> graph_;
};
} // namespace

void FuseAddMMBranches(std::shared_ptr<Graph> graph) {
  AddMMBranchFusion ammbf(std::move(graph));
  ammbf.run();
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
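
To see the new pass in isolation, one can parse the documented pattern by hand and run FuseAddMMBranches over it. A self-contained sketch follows; the toy graph, value names, and driver are illustrative, not part of the commit:

    #include "torch/csrc/jit/ir/ir.h"
    #include "torch/csrc/jit/ir/irparser.h"

    #include "core/lowering/passes/passes.h"

    int main() {
      auto g = std::make_shared<torch::jit::Graph>();
      // Hand-written instance of the prim::If pattern matched above.
      torch::jit::parseIR(R"IR(
        graph(%cond : bool, %bias : Tensor, %x : Tensor, %w : Tensor):
          %1 : int = prim::Constant[value=1]()
          %ret : Tensor = prim::If(%cond)
            block0():
              %ret.1 : Tensor = aten::addmm(%bias, %x, %w, %1, %1)
              -> (%ret.1)
            block1():
              %out : Tensor = aten::matmul(%x, %w)
              %out0 : Tensor = aten::add_(%out, %bias, %1)
              -> (%out0)
          return (%ret))IR", g.get());

      trtorch::core::lowering::passes::FuseAddMMBranches(g);
      g->dump(); // the prim::If should now be a single aten::addmm node
      return 0;
    }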
3 changes: 2 additions & 1 deletion core/lowering/passes/passes.h
@@ -8,13 +8,14 @@ namespace lowering {
 namespace passes {

 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
+void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
-void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
} // namespace irfusers
} // namespace lowering
