diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 587bfe5b6e..53d3221409 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -396,6 +396,20 @@ auto aten_registrations TRTORCH_UNUSED =
                         EvalOptions().validSchemas({
                             "aten::numel(Tensor self) -> int",
                         })})
+        .evaluator({c10::Symbol::fromQualString("aten::t"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      auto tensor_var = args.at(n->input(0));
+                      if (tensor_var.IValue()->isTensor()) {
+                        auto tensor = tensor_var.unwrapToTensor();
+                        return tensor.t();
+                      } else {
+                        TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
+                        return {};
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::t(Tensor self) -> Tensor",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::dim"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index aec9ebb8e1..b25af1b8de 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -36,7 +36,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   torch::jit::LowerAllTuples(g);
   passes::RemoveContiguous(g);
   passes::RemoveDropout(g);
-  passes::FuseFlattenLinear(g);
+  passes::LinearToAddMM(g);
   passes::Conv2DToConvolution(g);
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index f213a2539a..b786ddd615 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -17,7 +17,7 @@ cc_library(
         "conv3d_to_convolution.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
-        "fuse_flatten_linear.cpp",
+        "linear_to_addmm.cpp",
         "remove_bn_dim_check.cpp",
         "remove_contiguous.cpp",
         "remove_dropout.cpp",
diff --git a/core/lowering/passes/fuse_flatten_linear.cpp b/core/lowering/passes/linear_to_addmm.cpp
similarity index 63%
rename from core/lowering/passes/fuse_flatten_linear.cpp
rename to core/lowering/passes/linear_to_addmm.cpp
index 936a4930da..0cc8ef584c 100644
--- a/core/lowering/passes/fuse_flatten_linear.cpp
+++ b/core/lowering/passes/linear_to_addmm.cpp
@@ -7,29 +7,31 @@
 namespace core {
 namespace lowering {
 namespace passes {
 
-void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
+void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
   // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
   std::string flatten_linear_pattern = R"IR(
-    graph(%input, %6, %7, %weight, %bias):
-      %flat = aten::flatten(%input, %6, %7)
-      %res = aten::linear(%flat, %weight, %bias)
+    graph(%input, %weight, %bias):
+      %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
   std::string flatten_linear_bias_none_pattern = R"IR(
-    graph(%input, %6, %7, %weight):
-      %flat = aten::flatten(%input, %6, %7)
+    graph(%input, %weight):
       %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%flat, %weight, %bias)
-      return (%res))IR";
-  std::string fused_linear = R"IR(
-    graph(%input, %6, %7, %weight, %bias):
       %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
+  std::string fused_linear = R"IR(
+    graph(%input, %weight_t, %bias):
+      %1: int = prim::Constant[value=1]()
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add_(%b_f, %mm, %1)
+      return (%out))IR";
   std::string fused_linear_bias_none = R"IR(
-    graph(%input, %6, %7, %weight):
-      %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
+    graph(%input, %weight_t):
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      return (%mm))IR";
 
   torch::jit::SubgraphRewriter flatten_linear_to_linear;
   flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
@@ -38,7 +40,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
   torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
   flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
   flatten_linear_bias_none_to_linear.runOnGraph(graph);
-  LOG_GRAPH("Post flatten linear: " << *graph);
+  LOG_GRAPH("Post linear to addmm: " << *graph);
 }
 
 } // namespace passes
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index 770982f67f..df204df918 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -10,7 +10,7 @@ namespace passes {
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
-void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
+void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index 7742a07e06..9e778bb5cc 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -7,6 +7,10 @@ config_setting(
     }
 )
 
+lowering_test(
+    name = "test_linear_to_addmm",
+)
+
 lowering_test(
     name = "test_remove_contiguous_pass",
 )
@@ -30,6 +34,7 @@ lowering_test(
 test_suite(
     name = "lowering_tests",
     tests = [
+        ":test_linear_to_addmm",
         ":test_remove_contiguous_pass",
         ":test_remove_to",
         ":test_remove_detach_pass",
diff --git a/tests/core/lowering/test_linear_to_addmm.cpp b/tests/core/lowering/test_linear_to_addmm.cpp
new file mode 100644
index 0000000000..3f5acae812
--- /dev/null
+++ b/tests/core/lowering/test_linear_to_addmm.cpp
@@ -0,0 +1,34 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, LinearToAddMM) {
+  std::string source_graph = R"IR(
+    graph(%input, %6, %7, %weight, %bias):
+      %flat = aten::flatten(%input, %6, %7)
+      %res = aten::linear(%flat, %weight, %bias)
+      return (%res))IR";
+  std::string target_graph = R"IR(
+    graph(%input, %6, %7, %weight_t, %bias):
+      %1: int = prim::Constant[value=1]()
+      %flat = aten::flatten(%input, %6, %7)
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%flat, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add_(%b_f, %mm, %1)
+      return (%out))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  trtorch::core::lowering::passes::LinearToAddMM(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
\ No newline at end of file
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index 8b31c2c516..e2a0516e0a 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -46,15 +46,15 @@
         "path": "both"
     },
     "resnet18": {
-        "model": torch.hub.load('pytorch/vision:v0.8.2', 'resnet18', pretrained=True),
+        "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True),
         "path": "both"
     },
     "resnet50": {
-        "model": torch.hub.load('pytorch/vision:v0.8.2', 'resnet50', pretrained=True),
+        "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
         "path": "both"
     },
     "fcn_resnet101": {
-        "model": torch.hub.load('pytorch/vision:v0.8.2', 'fcn_resnet101', pretrained=True),
+        "model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
         "path": "script"
     },
     "ssd": {
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index fe0613413b..a21385f6e1 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -79,7 +79,6 @@ def test_is_colored_output_on(self):
 def test_suite():
     suite = unittest.TestSuite()
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
-    suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet50(pretrained=True)))
    suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
     suite.addTest(unittest.makeSuite(TestLoggingAPIs))
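
For reference (not part of the patch): the rewrite target in fused_linear is the standard decomposition of a linear layer into a weight transpose, a matmul, and a bias add. Below is a minimal PyTorch sketch of that equivalence; the tensor shapes and variable names are assumptions chosen only for illustration, not TRTorch code.

# Illustration only; shapes and names are assumed, not taken from the patch.
import torch

x = torch.randn(4, 10)       # (N, in_features)
weight = torch.randn(6, 10)  # (out_features, in_features), as nn.Linear stores it
bias = torch.randn(6)

reference = torch.nn.functional.linear(x, weight, bias)
# Same decomposition the fused_linear pattern emits: transpose weight, matmul, add bias.
decomposed = torch.add(bias, torch.matmul(x, weight.t()), alpha=1)

assert torch.allclose(reference, decomposed)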