From d8ca18203bf628cfd5ed174c3c53d46ce59052e5 Mon Sep 17 00:00:00 2001
From: Naren Dasan <naren@narendasan.com>
Date: Wed, 28 Jul 2021 09:19:58 -0700
Subject: [PATCH] fix(aten::cat): support neg dim for cat

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
---
 core/conversion/converters/impl/concat.cpp        |  8 ++--
 .../core/conversion/converters/test_concat.cpp    | 47 +++++++++++++++++++
 2 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/core/conversion/converters/impl/concat.cpp b/core/conversion/converters/impl/concat.cpp
index 74a0018188..ce2032c569 100644
--- a/core/conversion/converters/impl/concat.cpp
+++ b/core/conversion/converters/impl/concat.cpp
@@ -20,15 +20,17 @@ auto cat_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
       for (auto t : ts) {
         if (t.isTensor()) {
           auto torch_tensor = t.toTensor();
-          auto t_weights = Weights(ctx, torch_tensor);
-          auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
-          tensors.push_back(const_layer->getOutput(0));
+          tensors.push_back(tensor_to_const(ctx, torch_tensor));
         } else {
           auto cont = t.toCustomClass<TensorContainer>();
           tensors.push_back(cont->tensor());
         }
       }
 
+      if (dim < 0) {
+        dim = tensors[0]->getDimensions().nbDims + dim;
+      }
+
       auto cat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
       cat_layer->setAxis(static_cast<int>(dim));
       auto cat_out = ctx->AssociateValueAndTensor(n->outputs()[0], cat_layer->getOutput(0));
diff --git a/tests/core/conversion/converters/test_concat.cpp b/tests/core/conversion/converters/test_concat.cpp
index 0b27ffede3..7e940df3c0 100644
--- a/tests/core/conversion/converters/test_concat.cpp
+++ b/tests/core/conversion/converters/test_concat.cpp
@@ -49,5 +49,52 @@ TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
   params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
+
+TEST(Converters, ATenCatPureTensorNegDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1)
+        %3 : int = prim::Constant[value=-1]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 5}, {at::kCUDA});
+  auto in2 = at::randint(1, 10, {5, 5}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenCatDiffTensorNegDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(5)):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1)
+        %3 : int = prim::Constant[value=-1]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 5}, {at::kCUDA});
+  auto in2 = at::randint(1, 10, {5, 5}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
\ No newline at end of file