feat(//core/lowering): New freeze model pass and new exception elimination pass

Signed-off-by: Naren Dasan <[email protected]>
narendasan committed May 1, 2020
1 parent 90c44b9 commit 4acc3fd
Showing 17 changed files with 188 additions and 46 deletions.
7 changes: 7 additions & 0 deletions core/conversion/converters/impl/linear.cpp
@@ -9,6 +9,13 @@ namespace impl {
namespace {

auto linear_registrations = RegisterNodeConversionPatterns()
// .pattern({
// "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
// [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// auto in = args[0].ITensor();

// }
// })
.pattern({
"aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
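Aside (not part of this commit): the stubbed-out aten::addmm registration above is presumably left commented because, with beta=1 and alpha=1, addmm(bias, input, weight.t()) is exactly linear(input, weight, bias), and the lowering phase below already calls torch::jit::FuseLinear to rewrite such graphs onto aten::linear. A small plain-libtorch sanity check of that identity, assuming nothing about the converter API:

#include <cassert>
#include "torch/torch.h"

// Check that addmm with beta = alpha = 1 matches linear with the transposed
// weight, which is the case the aten::linear converter above handles.
int main() {
  auto x = torch::randn({8, 32});
  auto W = torch::randn({10, 32});
  auto b = torch::randn({10});

  auto via_addmm  = torch::addmm(b, x, W.t());
  auto via_linear = torch::linear(x, W, b);

  assert(torch::allclose(via_addmm, via_linear, /*rtol=*/1e-4, /*atol=*/1e-6));
  return 0;
}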
3 changes: 2 additions & 1 deletion core/lowering/BUILD
@@ -11,7 +11,8 @@ cc_library(
],
deps = [
"@libtorch//:libtorch",
"//core/lowering/irfusers"
"//core/lowering/passes",
"//core/util:prelude"
]
)

35 changes: 22 additions & 13 deletions core/lowering/lowering.cpp
@@ -1,10 +1,13 @@
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/fuse_linear.h"
#include "torch/csrc/jit/passes/freeze_module.h"
#include "torch/csrc/jit/passes/lower_graph.h"
#include "torch/csrc/jit/passes/quantization.h"
#include "torch/csrc/jit/passes/guard_elimination.h"

#include "core/util/prelude.h"
#include "core/lowering/lowering.h"
#include "core/lowering/irfusers/irfusers.h"
#include "core/lowering/passes/passes.h"

namespace trtorch {
namespace core {
@@ -17,30 +20,36 @@ void LowerBlock(torch::jit::Block* b) {
}

void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
torch::jit::EliminateRedundantGuards(g);
passes::EliminateExceptionOrPassPattern(g);
torch::jit::FuseLinear(g);
irfusers::RemoveDropout(g);
irfusers::FuseFlattenLinear(g);
irfusers::ExpandLogSoftmax(g);
passes::RemoveDropout(g);
passes::FuseFlattenLinear(g);
passes::ExpandLogSoftmax(g);
//passes::RemoveDimExeception(g);
//irfusers::UnpackBatchNorm(g);
//torch::jit::EliminateDeadCode(g);
torch::jit::EliminateDeadCode(g);
LOG_GRAPH(*g);
}

void LowerModule(const torch::jit::script::Module& mod) {
torch::jit::FoldConvBatchNorm2d(mod);
torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
auto mod_ = torch::jit::freeze_module(mod);
return mod_;
}

std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
std::string method_name) {
LowerModule(mod);
auto g = mod.get_method(method_name).graph();
// Go through PyTorch Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());

g = graph_and_parameters.first;
auto lowered_mod = LowerModule(mod);
auto g = lowered_mod.get_method(method_name).graph();
LOG_GRAPH(*g);

// Go through TRTorch Lowering to reformat graph to be conversion friendly
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(g);
// torch::jit::FoldConvBatchNorm2d(lowered_mod);
LOG_GRAPH("LibTorch Lowering");
auto graph_and_parameters = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
// Is this necessary?
lowering::LowerBlock(g->block());
return graph_and_parameters;
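A minimal sketch of driving the reworked entry point from a caller; the model path and method name here are assumptions for illustration, not part of the commit. The module is frozen first inside LowerModule, the TRTorch passes then run over the method graph, and LibTorch's LowerGraph finally extracts the weight parameters:

#include <iostream>
#include "torch/script.h"
#include "core/lowering/lowering.h"

int main() {
  // Hypothetical TorchScript module; freeze_module requires eval mode.
  auto mod = torch::jit::load("model.ts");
  mod.eval();

  // Freeze, run the TRTorch graph passes, then the LibTorch lowering.
  auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");

  std::cout << *graph_and_parameters.first << std::endl;        // lowered graph
  std::cout << graph_and_parameters.second.size() << " extracted parameters\n";
  return 0;
}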
2 changes: 1 addition & 1 deletion core/lowering/lowering.h
@@ -8,7 +8,7 @@ namespace lowering {

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
void LowerModule(const torch::jit::script::Module& mod);
torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
std::string method_name);

8 changes: 5 additions & 3 deletions core/lowering/irfusers/BUILD → core/lowering/passes/BUILD
@@ -1,17 +1,19 @@
package(default_visibility = ["//visibility:public"])

cc_library(
name = "irfusers",
name = "passes",
hdrs = [
"irfusers.h",
"passes.h",
],
srcs = [
"fuse_flatten_linear.cpp",
"expand_log_softmax.cpp",
"remove_dropout.cpp",
"unpack_batch_norm.cpp"
"unpack_batch_norm.cpp",
"exception_elimination.cpp"
],
deps = [
"//core/util:prelude",
"@libtorch//:libtorch",
]
)
86 changes: 86 additions & 0 deletions core/lowering/passes/exception_elimination.cpp
@@ -0,0 +1,86 @@
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/runtime/graph_executor.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"

#include "core/util/prelude.h"

#include <vector>

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {
namespace {
using namespace torch::jit;
struct ExceptionOrPassPatternElimination {
ExceptionOrPassPatternElimination(std::shared_ptr<Graph> graph)
: graph_(std::move(graph)) {}

void run() {
LOG_GRAPH("Pre exeception or pass elimination: " << *graph_);
findExceptionOrPassNodes(graph_->block());
torch::jit::EliminateDeadCode(graph_);
LOG_GRAPH("Post exeception or pass elimination: " << *graph_);
}

private:
bool isExceptionOrPassNode(Node* n) {
/// Check if this Node hosts a pattern like so:
/// = prim::If(%5958)
/// block0():
/// = prim::RaiseException(%45)
/// -> ()
/// block1():
/// -> ()
if (n->blocks().size() != 2) {
return false;
}
auto arm1 = n->blocks()[0];
auto arm2 = n->blocks()[1];
if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
// Make sure that the node doesn't actually produce any Values that are used by other nodes
return false;
}

auto arm1_start = arm1->nodes().begin();

if ((*arm1_start)->kind() != prim::RaiseException || (*(++arm1_start))->kind() != prim::Return) {
// Make sure that block0 is solely just the exception and the return
return false;
}

if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
// Make sure that block1 is solely the return
return false;
}

return true;
}

void findExceptionOrPassNodes(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
auto n = *it;
if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)");
it.destroyCurrent();
}
}
}

std::shared_ptr<Graph> graph_;
};
} // namespace

void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
ExceptionOrPassPatternElimination eppe(std::move(graph));
eppe.run();
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
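For illustration (not part of the commit), a sketch of exercising the new pass in isolation. The graph is hand-built to contain the exception-or-pass shape described in the comment above; in a scripted module the same shape typically comes from assert or raise statements in the Python source:

#include <iostream>
#include <memory>

#include "torch/csrc/jit/ir/ir.h"
#include "core/lowering/passes/passes.h"

int main() {
  // Build a graph with a prim::If whose true arm only raises and whose false
  // arm falls through, guarding an otherwise useful computation.
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->addInput("x");
  auto cond = g->addInput("cond")->setType(c10::BoolType::get());
  auto msg = g->insertConstant(std::string("illustrative error message"));

  auto if_node = g->insertNode(g->create(c10::prim::If, {cond}, /*num_outputs=*/0));
  auto raise_arm = if_node->addBlock();
  if_node->addBlock();  // pass-through arm, intentionally empty
  {
    torch::jit::WithInsertPoint guard(raise_arm);
    g->insertNode(g->create(c10::prim::RaiseException, {msg}, /*num_outputs=*/0));
  }
  g->registerOutput(x);

  // The prim::If node should be destroyed; the now-unused message constant is
  // swept up by the EliminateDeadCode call the pass performs internally.
  trtorch::core::lowering::passes::EliminateExceptionOrPassPattern(g);
  std::cout << *g;
  return 0;
}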
core/lowering/irfusers/expand_log_softmax.cpp → core/lowering/passes/expand_log_softmax.cpp
@@ -4,14 +4,14 @@
namespace trtorch {
namespace core {
namespace lowering {
namespace irfusers {
namespace passes {

void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
// It's easier for TensorRT if we separate softmax and log
// There might need to be a reshape inserted see:
// https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1593
// Should the reshapes be added here or in the converter?

// TODO: In the future this should be removed in favor of a dedicated log_softmax converter (more efficient)
// But it's easier to stand up a working system if the number of op converters is lower
std::string logsoftmax_pattern = R"IR(
@@ -33,19 +33,19 @@ void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
%dtype : int? = prim::Constant()
%softmax = aten::softmax(%input, %dim, %dtype)
%log_softmax = aten::log(%softmax)
return (%log_softmax))IR";

torch::jit::SubgraphRewriter logsoftmax_to_softmax_log;
logsoftmax_to_softmax_log.RegisterRewritePattern(logsoftmax_pattern, softmax_log_pattern);
logsoftmax_to_softmax_log.runOnGraph(graph);

torch::jit::SubgraphRewriter logsoftmax_none_to_softmax_log_none;
logsoftmax_none_to_softmax_log_none.RegisterRewritePattern(
logsoftmax_none_pattern, softmax_log_none_pattern);
logsoftmax_none_to_softmax_log_none.runOnGraph(graph);
}

} // namespace irfusers
} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
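A quick numeric spot check (plain libtorch, not part of the commit) that the split form matches the fused op for typical inputs; the dedicated log_softmax converter mentioned in the TODO would avoid materializing the intermediate softmax:

#include <cassert>
#include "torch/torch.h"

// log_softmax(x) and log(softmax(x)) agree up to floating point error for
// well-scaled inputs, which is exactly the rewrite the patterns above encode.
int main() {
  auto x = torch::randn({4, 10});
  auto fused = torch::log_softmax(x, /*dim=*/1);
  auto split = torch::log(torch::softmax(x, /*dim=*/1));
  assert(torch::allclose(fused, split, /*rtol=*/1e-4, /*atol=*/1e-6));
  return 0;
}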
core/lowering/irfusers/fuse_flatten_linear.cpp → core/lowering/passes/fuse_flatten_linear.cpp
@@ -4,7 +4,7 @@
namespace trtorch {
namespace core {
namespace lowering {
namespace irfusers {
namespace passes {

void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
//TensorRT implicitly adds a flatten layer in front of FC layers if necessary
@@ -33,13 +33,47 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::SubgraphRewriter flatten_linear_to_linear;
flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
flatten_linear_to_linear.runOnGraph(graph);


torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
flatten_linear_bias_none_to_linear.RegisterRewritePattern(
flatten_linear_bias_none_pattern, fused_linear_bias_none);
flatten_linear_bias_none_to_linear.runOnGraph(graph);
}

void FuseFlattenAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
//TensorRT implicitly adds a flatten layer in front of FC layers if necessary
std::string flatten_linear_pattern = R"IR(
graph(%input, %6, %7, %weight, %bias):
%flat = aten::flatten(%input, %6, %7)
%res = aten::linear(%flat, %weight, %bias)
return (%res))IR";
std::string flatten_linear_bias_none_pattern = R"IR(
graph(%input, %6, %7, %weight):
%flat = aten::flatten(%input, %6, %7)
%bias: Tensor? = prim::Constant()
%res = aten::linear(%flat, %weight, %bias)
return (%res))IR";
std::string fused_linear = R"IR(
graph(%input, %6, %7, %weight, %bias):
%res = aten::linear(%input, %weight, %bias)
return (%res))IR";

std::string fused_linear_bias_none = R"IR(
graph(%input, %6, %7, %weight):
%bias: Tensor? = prim::Constant()
%res = aten::linear(%input, %weight, %bias)
return (%res))IR";

torch::jit::SubgraphRewriter flatten_linear_to_linear;
flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
flatten_linear_to_linear.runOnGraph(graph);

torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
flatten_linear_bias_none_to_linear.RegisterRewritePattern(
flatten_linear_bias_none_pattern, fused_linear_bias_none);
flatten_linear_bias_none_to_linear.runOnGraph(graph);
}
} // namespace irfusers
} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
core/lowering/irfusers/irfusers.h → core/lowering/passes/passes.h
@@ -5,12 +5,13 @@
namespace trtorch {
namespace core {
namespace lowering {
namespace irfusers {
namespace passes {

void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

} // namespace irfusers
} // namespace lowering
core/lowering/irfusers/remove_dropout.cpp → core/lowering/passes/remove_dropout.cpp
@@ -4,7 +4,7 @@
namespace trtorch {
namespace core {
namespace lowering {
namespace irfusers {
namespace passes {

void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
std::string dropout_pattern = R"IR(
@@ -14,15 +14,15 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
std::string no_dropout_pattern = R"IR(
graph(%input, %4, %5):
return (%input))IR";

// Remove the dropout node by rewriting the pattern to return its input directly
torch::jit::SubgraphRewriter remove_dropout;
remove_dropout.RegisterRewritePattern(
dropout_pattern, no_dropout_pattern);
remove_dropout.runOnGraph(graph);
}

} // namespace irfusers
} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
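The rewrite is safe for the frozen, inference-only graphs this lowering targets because dropout is the identity when training is disabled; a quick plain-libtorch check (not part of the commit):

#include <cassert>
#include "torch/torch.h"

// With train=false, aten::dropout returns its input unchanged, which is what
// justifies deleting the node from an eval-mode graph.
int main() {
  auto x = torch::randn({4, 16});
  auto y = torch::dropout(x, /*p=*/0.5, /*train=*/false);
  assert(torch::equal(x, y));
  return 0;
}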
core/lowering/irfusers/unpack_batch_norm.cpp → core/lowering/passes/unpack_batch_norm.cpp
@@ -23,7 +23,7 @@ RegisterOperators trt_const_op_reg({
namespace trtorch {
namespace core {
namespace lowering {
namespace irfusers {
namespace passes {

// // May be abusing aten::_tensor_to_list(Tensor self) -> int[]
// // Treating it as an emit_constant by the converters
@@ -60,7 +60,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) {
unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern);
unpack_batch_norm.runOnGraph(graph);
}
} // Namespace Irfusers
} // Namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion core/util/logging/TRTorchLogger.cpp
@@ -101,7 +101,7 @@ namespace {
TRTorchLogger& get_global_logger() {
#ifndef NDEBUG
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ",
LogLevel::kDEBUG,
LogLevel::kGRAPH,
true);
#else
static TRTorchLogger global_logger("[TRTorch] - ",
6 changes: 3 additions & 3 deletions core/util/macros.h
@@ -11,21 +11,21 @@
l.log(sev, ss.str()); \
} while (0)

#define GRAPH_DUMP_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s)
#define LOG_GRAPH_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s)
#define LOG_DEBUG_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kDEBUG, s)
#define LOG_INFO_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINFO, s)
#define LOG_WARNING_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kWARNING, s)
#define LOG_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kERROR, s)
#define LOG_INTERNAL_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINTERNAL_ERROR, s)

#define GRAPH_DUMP_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s)
#define LOG_GRAPH_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s)
#define LOG_DEBUG_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kDEBUG, s)
#define LOG_INFO_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINFO, s)
#define LOG_WARNING_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kWARNING, s)
#define LOG_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kERROR, s)
#define LOG_INTERNAL_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINTERNAL_ERROR, s)

#define GRAPH_DUMP(...) GET_MACRO(__VA_ARGS__, GRAPH_DUMP_OWN, GRAPH_DUMP_GLOBAL)(__VA_ARGS__)
#define LOG_GRAPH(...) GET_MACRO(__VA_ARGS__, LOG_GRAPH_OWN, LOG_GRAPH_GLOBAL)(__VA_ARGS__)
#define LOG_DEBUG(...) GET_MACRO(__VA_ARGS__, LOG_DEBUG_OWN, LOG_DEBUG_GLOBAL)(__VA_ARGS__)
#define LOG_INFO(...) GET_MACRO(__VA_ARGS__, LOG_INFO_OWN, LOG_INFO_GLOBAL)(__VA_ARGS__)
#define LOG_WARNING(...) GET_MACRO(__VA_ARGS__, LOG_WARNING_OWN, LOG_WARNING_GLOBAL)(__VA_ARGS__)
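Call sites keep the same shape as the old GRAPH_DUMP macros, only the name changes; a small usage sketch, not part of the commit (the file placement and function name are assumptions, and the enclosing trtorch namespace is needed so the global variant's unqualified core::util:: expansion resolves):

#include <memory>
#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/ir.h"

namespace trtorch {
namespace core {
namespace lowering {

// The one-argument form routes to LOG_GRAPH_GLOBAL (global logger, kGRAPH
// severity); passing a logger as the first argument selects LOG_GRAPH_OWN.
void DumpLoweredGraph(std::shared_ptr<torch::jit::Graph>& g) {
  LOG_GRAPH("Lowered graph:\n" << *g);
}

} // namespace lowering
} // namespace core
} // namespace trtorch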