feat: support aten::conv1d and aten::conv_transpose1d
Signed-off-by: Ruoqian Guo <[email protected]>
ruoqianguo committed Oct 19, 2021
1 parent 4d95b04 commit c8dc6e9
Showing 2 changed files with 148 additions and 23 deletions.
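In short: the commit moves conv/deconv parameter extraction out of add_conv_deconv and into each registered conversion pattern, because aten::conv1d and aten::conv_transpose1d order their schema arguments differently from aten::_convolution (conv1d has no transposed/output_padding arguments, and conv_transpose1d takes dilation last). For orientation, a minimal sketch of the kind of TorchScript graph the new aten::conv1d pattern matches, written in the same embedded-IR style the new tests use; the shapes and value names are illustrative, not taken from the commit:

// Minimal sketch (illustrative shapes/names), mirroring the style of the new tests.
const auto example_graph = R"IR(
  graph(%input : Tensor,
        %weight : Float(4, 3, 3, strides=[9, 3, 1]),
        %bias : Float(4)):
    %one : int = prim::Constant[value=1]()
    %zero : int = prim::Constant[value=0]()
    %stride : int[] = prim::ListConstruct(%one)
    %padding : int[] = prim::ListConstruct(%zero)
    %dilation : int[] = prim::ListConstruct(%one)
    %out : Tensor = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %one)
    return (%out))IR";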
85 changes: 62 additions & 23 deletions core/conversion/converters/impl/conv_deconv.cpp
@@ -10,18 +10,19 @@ namespace converters {
 namespace impl {
 namespace {
 
-bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
+bool add_conv_deconv(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    args& args,
+    nvinfer1::Dims& stride,
+    nvinfer1::Dims& padding,
+    nvinfer1::Dims& dilation,
+    bool transposed,
+    nvinfer1::Dims& out_padding,
+    int64_t groups) {
   // Input to conv/deconv
   auto in = args[0].ITensor();
 
-  // Conv /deconv parameters
-  auto stride = util::toDims(args[3].unwrapToIntList());
-  auto padding = util::toDims(args[4].unwrapToIntList());
-  auto dilation = util::toDims(args[5].unwrapToIntList());
-  bool transposed = args[6].unwrapToBool();
-  auto out_padding = util::toDims(args[7].unwrapToIntList());
-  int64_t groups = args[8].unwrapToInt();
-
   // Reshape the parameters to 2D if needed
   if (stride.nbDims == 1) {
     stride = util::unsqueezeDims(stride, 1, 1);
@@ -174,28 +175,66 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   return true;
 }
 
-auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
-    .pattern({
-        R"SIG(aten::_convolution(Tensor input, Tensor weight,
+auto conv_registrations TRTORCH_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern({
+            R"SIG(aten::_convolution(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG",
-        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-          return add_conv_deconv(ctx, n, args);
-        }})
-    .pattern({
-        R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              // Conv /deconv parameters
+              auto stride = util::toDims(args[3].unwrapToIntList());
+              auto padding = util::toDims(args[4].unwrapToIntList());
+              auto dilation = util::toDims(args[5].unwrapToIntList());
+              bool transposed = args[6].unwrapToBool();
+              auto out_padding = util::toDims(args[7].unwrapToIntList());
+              int64_t groups = args[8].unwrapToInt();
+              return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+            }})
+        .pattern({
+            R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
-        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-          // This pattern is only matched for traced JIT models which do not
-          // have allow_tf32 bool in the function signature. The TRT conversion
-          // code is exactly same as the above call.
-          return add_conv_deconv(ctx, n, args);
-        }});
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              // This pattern is only matched for traced JIT models which do not
+              // have allow_tf32 bool in the function signature. The TRT conversion
+              // code is exactly same as the above call.
+              auto stride = util::toDims(args[3].unwrapToIntList());
+              auto padding = util::toDims(args[4].unwrapToIntList());
+              auto dilation = util::toDims(args[5].unwrapToIntList());
+              bool transposed = args[6].unwrapToBool();
+              auto out_padding = util::toDims(args[7].unwrapToIntList());
+              int64_t groups = args[8].unwrapToInt();
+              return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+            }})
+        .pattern(
+            {R"SIG(aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor)SIG",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // Conv /deconv parameters
+               auto stride = util::toDims(args[3].unwrapToIntList());
+               auto padding = util::toDims(args[4].unwrapToIntList());
+               auto dilation = util::toDims(args[5].unwrapToIntList());
+               bool transposed = false;
+               nvinfer1::Dims out_padding{1, {0}};
+               int64_t groups = args[6].unwrapToInt();
+               return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+             }})
+        .pattern(
+            {R"SIG(aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor)SIG",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // Conv /deconv parameters
+               auto stride = util::toDims(args[3].unwrapToIntList());
+               auto padding = util::toDims(args[4].unwrapToIntList());
+               auto out_padding = util::toDims(args[5].unwrapToIntList());
+               bool transposed = true;
+               int64_t groups = args[6].unwrapToInt();
+               auto dilation = util::toDims(args[7].unwrapToIntList());
+               return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+             }});
 } // namespace
 } // namespace impl
 } // namespace converters
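A note on how 1D support works here: the TensorRT layer is built with 2D parameters, so the converter unsqueezes 1D stride/padding/dilation first (the stride.nbDims == 1 branch above). Below is a standalone sketch of the assumed behavior of util::unsqueezeDims — the Dims struct and field names are simplified stand-ins for the real TensorRT/TRTorch types, not code from this repo:

#include <cstdint>
#include <cstdio>

// Simplified stand-in for nvinfer1::Dims, for illustration only.
struct Dims {
  int32_t nbDims;
  int32_t d[8];
};

// Assumed behavior of util::unsqueezeDims: insert a new dimension of size
// `val` at position `pos`, shifting the remaining dimensions right.
Dims unsqueezeDims(const Dims& in, int pos, int32_t val) {
  Dims out{in.nbDims + 1, {}};
  int i = 0;
  for (int j = 0; j < out.nbDims; ++j) {
    out.d[j] = (j == pos) ? val : in.d[i++];
  }
  return out;
}

int main() {
  Dims stride{1, {2}};                          // conv1d stride = [2]
  Dims stride2d = unsqueezeDims(stride, 1, 1);  // becomes [2, 1] for the 2D layer
  std::printf("stride2d = [%d, %d]\n", stride2d.d[0], stride2d.d[1]);
  return 0;
}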
86 changes: 86 additions & 0 deletions tests/core/conversion/converters/test_conv_deconv.cpp
@@ -10,6 +10,12 @@
 // int[] output_padding, int groups, bool benchmark,
 // bool deterministic, bool cudnn_enabled) -> (Tensor)
 
+// aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) ->
+// Tensor
+
+// aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding,
+// int groups, int[] dilation) -> Tensor
+
 void conv_test_helper(std::string graph_ir) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, g.get());
@@ -116,6 +122,86 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenConv1dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int = prim::Constant[value=1]()
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %8 : int[] = prim::ListConstruct(%3)
+        %9 : int[] = prim::ListConstruct(%4)
+        %10 : int[] = prim::ListConstruct(%5)
+        %12 : Tensor = aten::conv1d(%0, %1, %2, %8, %9, %10, %3)
+        return (%12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenConvTranspose1dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int = prim::Constant[value=1]()
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=0]()
+        %8 : int[] = prim::ListConstruct(%3)
+        %9 : int[] = prim::ListConstruct(%4)
+        %10 : int[] = prim::ListConstruct(%5)
+        %11 : int[] = prim::ListConstruct(%6)
+        %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %8, %9, %11, %3, %10)
+        return (%12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {8, 4, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor,

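As a sanity check on the shapes in ATenConvTranspose1dConvertsCorrectly: a 1D transposed convolution produces L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1. A minimal sketch (the formula is standard PyTorch semantics, not code from this repo) confirming that the test's input length 3 and kernel size 3 give output length 5:

#include <cassert>
#include <cstdint>

// Output length of a 1D transposed convolution (standard PyTorch formula).
int64_t conv_transpose1d_out_len(
    int64_t l_in, int64_t kernel, int64_t stride,
    int64_t padding, int64_t dilation, int64_t out_padding) {
  return (l_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + out_padding + 1;
}

int main() {
  // Test shapes: input {1, 8, 3}, weight {8, 4, 3}; stride/dilation 1, paddings 0.
  assert(conv_transpose1d_out_len(3, 3, 1, 0, 1, 0) == 5);  // output is {1, 4, 5}
  return 0;
}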