Skip to content

Commit

Permalink
fix(aten::_convolution): out channels was passed in incorrectly for
Browse files Browse the repository at this point in the history
deconv

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 22, 2020
1 parent ce6cf75 commit ee727f8
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 11 deletions.
14 changes: 6 additions & 8 deletions core/conversion/converters/impl/conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,25 @@ auto conv_registrations = RegisterNodeConversionPatterns()
auto in = args[0].ITensor();

auto w = Weights(ctx, args[1].unwrapToTensor());
auto stride = util::toDimsHW(args[3].unwrapToIntList());
auto stride = util::toDims(args[3].unwrapToIntList());
LOG_DEBUG("stride: " << stride);
auto padding = util::toDimsHW(args[4].unwrapToIntList());
auto padding = util::toDims(args[4].unwrapToIntList());
LOG_DEBUG("padding: " << padding);
auto dilation = util::toDimsHW(args[5].unwrapToIntList());
auto dilation = util::toDims(args[5].unwrapToIntList());
LOG_DEBUG("dilation: " << dilation);
bool transposed = args[6].unwrapToBool();
auto out_padding = util::toDimsHW(args[7].unwrapToIntList());
auto out_padding = util::toDims(args[7].unwrapToIntList());
LOG_DEBUG("out_padding: " << out_padding);
int64_t groups = args[8].unwrapToInt();

nvinfer1::ILayer* new_layer;
if (transposed) {
//TODO: Check deconv correctness
LOG_WARNING(ctx->logger, "Deconvolution converter has not be tested");
nvinfer1::IDeconvolutionLayer* deconv;
if (args[2].IValue()->isTensor()) {
Weights b(ctx, args[2].IValue()->toTensor());
deconv = ctx->net->addDeconvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, b.data);
deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, b.data);
} else {
deconv = ctx->net->addDeconvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, {});
deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, {});
}

TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
Expand Down
4 changes: 2 additions & 2 deletions tests/core/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ converter_test(
)

converter_test(
name = "test_conv"
name = "test_conv_deconv"
)

converter_test(
Expand Down Expand Up @@ -56,7 +56,7 @@ test_suite(
tests = [
":test_activation",
":test_batch_norm",
":test_conv",
":test_conv_deconv",
":test_element_wise",
":test_linear",
":test_matrix_multiply",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,173 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenConvTransposeConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(8, 3, 3, 3),
%2 : Float(8)):
%3 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=0]()
%7 : bool = prim::Constant[value=1]()
%8 : int[] = prim::ListConstruct(%3, %3)
%9 : int[] = prim::ListConstruct(%4, %4)
%10 : int[] = prim::ListConstruct(%5, %5)
%11 : int[] = prim::ListConstruct(%6, %6)
%12 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %3, %7, %7, %7)
return (%12))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 3, {1, 8, 5, 5}, {at::kCUDA});
auto w = at::randint(1, 3, {8, 3, 3, 3}, {at::kCUDA});
auto b = at::randint(1, 3, {3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_w = at::clone(w);
auto jit_b = at::clone(b);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
auto trt_b = at::clone(b);
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenConvTransposeNoBiasConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 1, 3, 3)):
%2 : None = prim::Constant()
%3 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=0]()
%7 : bool = prim::Constant[value=1]()
%8 : int[] = prim::ListConstruct(%3, %3)
%9 : int[] = prim::ListConstruct(%4, %4)
%10 : int[] = prim::ListConstruct(%5, %5)
%11 : int[] = prim::ListConstruct(%6, %6)
%12 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %3, %7, %7, %7)
return (%12))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 2, {1, 4, 3, 3}, {at::kCUDA});
auto w = at::randint(1, 2, {4, 1, 2, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_w = at::clone(w);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}


TEST(Converters, ATenConvTransposeWithStrideConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 3, 3, 3),
%2 : Float(4)):
%3 : int = prim::Constant[value=3]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=0]()
%7 : bool = prim::Constant[value=1]()
%8 : int[] = prim::ListConstruct(%3, %3)
%9 : int[] = prim::ListConstruct(%4, %4)
%10 : int[] = prim::ListConstruct(%5, %5)
%11 : int[] = prim::ListConstruct(%6, %6)
%12 : int = prim::Constant[value=1]()
%13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
return (%13))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {1, 4, 9, 9}, {at::kCUDA});
auto w = at::randint(1, 10, {4, 3, 3, 3}, {at::kCUDA});
auto b = at::randint(1, 10, {3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_w = at::clone(w);
auto jit_b = at::clone(b);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
auto trt_b = at::clone(b);
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 3, 4, 4),
%2 : Float(4)):
%3 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=2]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=0]()
%7 : bool = prim::Constant[value=1]()
%8 : int[] = prim::ListConstruct(%3, %3)
%9 : int[] = prim::ListConstruct(%4, %4)
%10 : int[] = prim::ListConstruct(%5, %5)
%11 : int[] = prim::ListConstruct(%6, %6)
%12 : int = prim::Constant[value=1]()
%13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
return (%13))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});
auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA});
auto b = at::randint(1, 10, {3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_w = at::clone(w);
auto jit_b = at::clone(b);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
auto trt_b = at::clone(b);
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

// TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
// const auto graph = R"IR(
// graph(%0 : Tensor,
Expand Down
4 changes: 3 additions & 1 deletion tests/util/run_graph_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g,
std::vector<at::Tensor> inputs) {
LOG_DEBUG("Running TRT version");
auto in = toInputRanges(inputs);
std::string eng = core::conversion::ConvertBlockToEngine(g->block(), in, named_params);
auto info = core::conversion::ConversionInfo(in);
info.engine_settings.workspace_size = 1 << 20;
std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
return RunEngine(eng, inputs);
}

Expand Down

0 comments on commit ee727f8

Please sign in to comment.