fix(aten::conv1d): Update namespace, fix typo in dest IR for conv1d
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Nov 3, 2021
1 parent 540e135 commit d53f136
Showing 5 changed files with 56 additions and 85 deletions.
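The commit title mentions fixing a typo in the destination IR for conv1d, but the corrected conv1d destination pattern itself is collapsed out of the diff below. For orientation, here is a sketch of the full conv1d lowering pass, modeled on the Conv2D/Conv3D variants added in this commit; the single-element output_padding constant [0] and the exact body shown are inferred for the 1-D case rather than copied from the committed file.

#include <memory>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

// Sketch of the conv1d -> _convolution lowering, mirroring the Conv2D/Conv3D
// passes added in this commit. The [0] output_padding for the 1-D case is an
// assumption, not taken verbatim from the commit.
void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string conv1d_pattern = R"IR(
    graph(%x, %w, %b, %s, %p, %d, %g):
        %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
        return (%4))IR";
  std::string convolution_pattern = R"IR(
    graph(%x, %w, %b, %s, %p, %d, %g):
        %1 : bool = prim::Constant[value=0]()
        %2 : int[] = prim::Constant[value=[0]]()
        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
        return (%4))IR";

  // SubgraphRewriter matches every aten::conv1d call against the source pattern
  // and rewrites it to the expanded aten::_convolution form.
  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
  map_conv1d_to_convolution.runOnGraph(graph);
}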
4 changes: 1 addition & 3 deletions core/lowering/passes/BUILD
@@ -10,9 +10,7 @@ config_setting(
cc_library(
name = "passes",
srcs = [
"conv1d_to_convolution.cpp",
"conv2d_to_convolution.cpp",
"conv3d_to_convolution.cpp",
"convNd_to_convolution.cpp",
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"linear_to_addmm.cpp",
33 changes: 0 additions & 33 deletions core/lowering/passes/conv2d_to_convolution.cpp

This file was deleted.

33 changes: 0 additions & 33 deletions core/lowering/passes/conv3d_to_convolution.cpp

This file was deleted.

core/lowering/passes/convNd_to_convolution.cpp
@@ -2,7 +2,7 @@

#include "core/util/prelude.h"

namespace trtorch {
namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {
@@ -12,6 +12,7 @@ void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";

std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=0]()
@@ -43,7 +44,45 @@ void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
}

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv2d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";
std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=0]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
return (%4))IR";

// replace aten::conv2d with the equivalent aten::_convolution call
torch::jit::SubgraphRewriter map_conv2d_to_convolution;
map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
map_conv2d_to_convolution.runOnGraph(graph);
LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
}

void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv3d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";
std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=0]()
%2 : int[] = prim::Constant[value=[0, 0, 0]]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
return (%4))IR";

// replace aten::conv3d with the equivalent aten::_convolution call
torch::jit::SubgraphRewriter map_conv3d_to_convolution;
map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
map_conv3d_to_convolution.runOnGraph(graph);
LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
} // namespace torch_tensorrt
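
For readers decoding the destination patterns above: the thirteen-argument aten::_convolution overload targeted by these rewrites follows the schema below. The argument names come from the upstream PyTorch operator schema, not from this commit.

aten::_convolution(input, weight, bias, stride, padding, dilation,
                   transposed, output_padding, groups,
                   benchmark, deterministic, cudnn_enabled, allow_tf32)

In the rewritten graphs, %1 is the false constant bound to transposed and to the four trailing flags, while %2 supplies a zero output_padding whose length matches the spatial rank (two zeros for conv2d, three for conv3d).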
28 changes: 14 additions & 14 deletions tests/core/lowering/test_conv1d_pass.cpp
@@ -35,10 +35,10 @@ TEST(LoweringPasses, Conv1dCorrectly) {
%12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
return (%12))IR";

trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
trtorch::core::lowering::passes::Conv1DToConvolution(sg);
torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);
@@ -50,13 +50,13 @@ TEST(LoweringPasses, Conv1dCorrectly) {
auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
auto trt_b = at::clone(b);
auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});
auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});

params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});
params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
}

TEST(LoweringPasses, ConvTransposed1dCorrectly) {
@@ -92,10 +92,10 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
%12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7)
return (%12))IR";

trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
trtorch::core::lowering::passes::ConvTransposed1DToConvolution(sg);
torch_tensorrt::core::lowering::passes::ConvTransposed1DToConvolution(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);
@@ -107,11 +107,11 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
auto trt_b = at::clone(b);
auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});
auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});

params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});
params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
}
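
With conv2d_to_convolution.cpp and conv3d_to_convolution.cpp folded into convNd_to_convolution.cpp, every convNd lowering entry point now lives under torch_tensorrt::core::lowering::passes. Below is a minimal sketch of driving the passes on a parsed graph, in the style of the tests above; the RunConvLowering wrapper name and the passes.h include path are illustrative assumptions.

#include <memory>
#include <torch/csrc/jit/ir/irparser.h>
#include "core/lowering/passes/passes.h"  // assumed header declaring the pass functions

// Hypothetical helper: apply every convNd lowering pass from this commit to one graph.
void RunConvLowering(std::shared_ptr<torch::jit::Graph>& g) {
  namespace passes = torch_tensorrt::core::lowering::passes;
  passes::Conv1DToConvolution(g);
  passes::ConvTransposed1DToConvolution(g);
  passes::Conv2DToConvolution(g);
  passes::Conv3DToConvolution(g);
}

Usage mirrors the tests: build a graph with torch::jit::parseIR(source, graph.get()), call RunConvLowering(graph), and the aten::_convolution nodes that remain are what the TensorRT converters consume.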
