fix(aten::linear): Fixes new issues in 1.8 that cause script-based
models to fail while traced models work. This appears to come down to
the fact that the two frontends create different graphs, and the script
path was having issues with aten::linear.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

narendasan committed Mar 12, 2021
1 parent 71c4dcb commit c5057f8
Showing 9 changed files with 76 additions and 22 deletions.
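As context for the failure mode described in the commit message: the script and trace frontends can emit different graphs for the same linear layer. A minimal sketch of how to observe this (assuming PyTorch 1.8; the module M is illustrative, not from this repo):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

# Script compiles the Python source; trace records the ops executed on a
# sample input. Per the commit message, the two can represent the linear
# layer differently (e.g. a single aten::linear node vs. a decomposed form).
scripted = torch.jit.script(M())
traced = torch.jit.trace(M(), torch.randn(1, 10))
print(scripted.graph)
print(traced.graph)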
14 changes: 14 additions & 0 deletions core/conversion/evaluators/aten.cpp
@@ -396,6 +396,20 @@ auto aten_registrations TRTORCH_UNUSED =
                 EvalOptions().validSchemas({
                     "aten::numel(Tensor self) -> int",
                 })})
+        .evaluator({c10::Symbol::fromQualString("aten::t"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      auto tensor_var = args.at(n->input(0));
+                      if (tensor_var.IValue()->isTensor()) {
+                        auto tensor = tensor_var.unwrapToTensor();
+                        return tensor.t();
+                      } else {
+                        TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
+                        return {};
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::t(Tensor self) -> Tensor",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::dim"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));
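The evaluator added above lets the converter fold aten::t when its input is a constant tensor (which is exactly what the transposed weight introduced by the lowering pass below is) by calling Tensor.t() and returning the result. A quick sketch of the semantics being reproduced, in Python for illustration:

import torch

w = torch.arange(6.0).reshape(2, 3)  # stand-in for a 2-D weight tensor
# Tensor.t() swaps the two dimensions of a 2-D tensor; this is what the
# evaluator computes at conversion time for a constant input.
assert w.t().shape == torch.Size([3, 2])
assert torch.equal(w.t(), w.transpose(0, 1))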
2 changes: 1 addition & 1 deletion core/lowering/lowering.cpp
@@ -36,7 +36,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   torch::jit::LowerAllTuples(g);
   passes::RemoveContiguous(g);
   passes::RemoveDropout(g);
-  passes::FuseFlattenLinear(g);
+  passes::LinearToAddMM(g);
   passes::Conv2DToConvolution(g);
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
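This swap is the heart of the fix: the old FuseFlattenLinear pass only matched aten::linear when it was preceded by aten::flatten (the shape traced classifier graphs tend to have), so a bare aten::linear node, as produced by scripting, was never rewritten. LinearToAddMM matches aten::linear directly, as the pass below shows.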
2 changes: 1 addition & 1 deletion core/lowering/passes/BUILD
@@ -17,7 +17,7 @@ cc_library(
"conv3d_to_convolution.cpp",
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"fuse_flatten_linear.cpp",
"linear_to_addmm.cpp",
"remove_bn_dim_check.cpp",
"remove_contiguous.cpp",
"remove_dropout.cpp",
core/lowering/passes/fuse_flatten_linear.cpp → core/lowering/passes/linear_to_addmm.cpp
@@ -7,29 +7,31 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
+void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
   // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
   std::string flatten_linear_pattern = R"IR(
-    graph(%input, %6, %7, %weight, %bias):
-      %flat = aten::flatten(%input, %6, %7)
-      %res = aten::linear(%flat, %weight, %bias)
+    graph(%input, %weight, %bias):
+      %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
   std::string flatten_linear_bias_none_pattern = R"IR(
-    graph(%input, %6, %7, %weight):
-      %flat = aten::flatten(%input, %6, %7)
+    graph(%input, %weight):
       %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%flat, %weight, %bias)
+      %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
-  std::string fused_linear = R"IR(
-    graph(%input, %6, %7, %weight, %bias):
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
 
+  std::string fused_linear = R"IR(
+    graph(%input, %weight_t, %bias):
+      %1: int = prim::Constant[value=1]()
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add_(%b_f, %mm, %1)
+      return (%out))IR";
   std::string fused_linear_bias_none = R"IR(
-    graph(%input, %6, %7, %weight):
-      %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
+    graph(%input, %weight_t):
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      return (%mm))IR";
 
   torch::jit::SubgraphRewriter flatten_linear_to_linear;
   flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
@@ -38,7 +40,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
   torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
   flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
   flatten_linear_bias_none_to_linear.runOnGraph(graph);
-  LOG_GRAPH("Post flatten linear: " << *graph);
+  LOG_GRAPH("Post linear to addmm: " << *graph);
 }
 
 } // namespace passes
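The rewritten patterns above rely on the identity linear(x, W, b) == matmul(x, t(W)) + b, with W stored as (out_features, in_features). A quick numerical check of that identity (shapes chosen arbitrarily for illustration):

import torch

x = torch.randn(2, 10)  # input batch
W = torch.randn(5, 10)  # weight, (out_features, in_features)
b = torch.randn(5)      # bias

ref = torch.nn.functional.linear(x, W, b)
dec = torch.matmul(x, W.t()) + b  # the decomposition the pass emits

assert torch.allclose(ref, dec, atol=1e-6)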
2 changes: 1 addition & 1 deletion core/lowering/passes/passes.h
@@ -10,7 +10,7 @@ namespace passes {
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
-void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
+void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
5 changes: 5 additions & 0 deletions tests/core/lowering/BUILD
@@ -7,6 +7,10 @@ config_setting(
 }
 )
 
+lowering_test(
+    name = "test_linear_to_addmm",
+)
+
 lowering_test(
     name = "test_remove_contiguous_pass",
 )
@@ -30,6 +34,7 @@ lowering_test(
 test_suite(
     name = "lowering_tests",
     tests = [
+        ":test_linear_to_addmm",
         ":test_remove_contiguous_pass",
         ":test_remove_to",
         ":test_remove_detach_pass",
34 changes: 34 additions & 0 deletions tests/core/lowering/test_linear_to_addmm.cpp
@@ -0,0 +1,34 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, LinearToAddMM) {
+  std::string source_graph = R"IR(
+    graph(%input, %6, %7, %weight, %bias):
+      %flat = aten::flatten(%input, %6, %7)
+      %res = aten::linear(%flat, %weight, %bias)
+      return (%res))IR";
+  std::string target_graph = R"IR(
+    graph(%input, %6, %7, %weight_t, %bias):
+      %1: int = prim::Constant[value=1]()
+      %flat = aten::flatten(%input, %6, %7)
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%flat, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add_(%b_f, %mm, %1)
+      return (%out))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  trtorch::core::lowering::passes::LinearToAddMM(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
6 changes: 3 additions & 3 deletions tests/modules/hub.py
@@ -46,15 +46,15 @@
"path": "both"
},
"resnet18": {
"model": torch.hub.load('pytorch/vision:v0.8.2', 'resnet18', pretrained=True),
"model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True),
"path": "both"
},
"resnet50": {
"model": torch.hub.load('pytorch/vision:v0.8.2', 'resnet50', pretrained=True),
"model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
"path": "both"
},
"fcn_resnet101": {
"model": torch.hub.load('pytorch/vision:v0.8.2', 'fcn_resnet101', pretrained=True),
"model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
"path": "script"
},
"ssd": {
1 change: 0 additions & 1 deletion tests/py/test_api.py
@@ -79,7 +79,6 @@ def test_is_colored_output_on(self):
 def test_suite():
     suite = unittest.TestSuite()
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
-    suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet50(pretrained=True)))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
     suite.addTest(unittest.makeSuite(TestLoggingAPIs))