diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 3599ab9939..0bb6ad01ff 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -22,6 +22,8 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
   if (unbind) {
     axis = args[1].unwrapToInt();
+    auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
+    axis = axis < 0 ? axis + maxDim : axis;
     numOutputs = in->getDimensions().d[axis];
     sizes.insert(sizes.end(), numOutputs, 1);
   } else {
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 4632744790..4d8795f378 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -181,7 +181,12 @@ auto aten_registrations TORCHTRT_UNUSED =
         .evaluator({c10::Symbol::fromQualString("aten::slice"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
-                      int64_t start = args.at(n->input(1)).unwrapToInt();
+
+                      int64_t start = 0;
+                      auto startIVal = args.at(n->input(1)).IValue();
+                      if(!startIVal->isNone()){
+                        start = args.at(n->input(1)).unwrapToInt();
+                      }
                       int64_t end = args.at(n->input(2)).unwrapToInt();
                       int64_t step = args.at(n->input(3)).unwrapToInt();
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index c4bb727d11..03b6bda36c 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -365,6 +365,38 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenSliceListConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%x : Tensor):
+              %1 : NoneType = prim::Constant()
+              %2 : int = prim::Constant[value=2]()
+              %3 : int = prim::Constant[value=1]()
+              %4 : int = prim::Constant[value=3]()
+              %list : Tensor[] = aten::unbind(%x, %4)
+              %slice : Tensor[] = aten::slice(%list, %1, %2, %3)
+              %out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice)
+              return (%out.1, %out.2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA});
+
+  auto jit_in_x = at::clone(in_x);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x});
+
+  auto trt_in_x = at::clone(in_x);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
+
 TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) {
   const auto graph = R"IR(
         graph(%x.1 : Tensor):
@@ -796,3 +828,30 @@ TEST(Converters, ATenUnbindConvertsCorrectly) {
     ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
   }
 }
+
+TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=-1]()
+            %3 : Tensor[] = aten::unbind(%x.1, %2)
+            %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
+            return (%o1.1, %o2.1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {5, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}