Skip to content

Commit

Permalink
Merge pull request #1088 from mfeliz-cruise/michael.feliz/fix_slice_a…
Browse files Browse the repository at this point in the history
…nd_unbind

Fix errors in unbind and list slice
  • Loading branch information
narendasan authored Aug 4, 2022
2 parents 07238c8 + d73738c commit 8ca7a22
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
2 changes: 2 additions & 0 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s

if (unbind) {
axis = args[1].unwrapToInt();
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
axis = axis < 0 ? axis + maxDim : axis;
numOutputs = in->getDimensions().d[axis];
sizes.insert(sizes.end(), numOutputs, 1);
} else {
Expand Down
7 changes: 6 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ auto aten_registrations TORCHTRT_UNUSED =
.evaluator({c10::Symbol::fromQualString("aten::slice"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
int64_t start = args.at(n->input(1)).unwrapToInt();

int64_t start = 0;
auto startIVal = args.at(n->input(1)).IValue();
if(!startIVal->isNone()){
start = args.at(n->input(1)).unwrapToInt();
}
int64_t end = args.at(n->input(2)).unwrapToInt();
int64_t step = args.at(n->input(3)).unwrapToInt();

Expand Down
59 changes: 59 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,38 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSliceListConvertsCorrectly) {
const auto graph = R"IR(
graph(%x : Tensor):
%1 : NoneType = prim::Constant()
%2 : int = prim::Constant[value=2]()
%3 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=3]()
%list : Tensor[] = aten::unbind(%x, %4)
%slice : Tensor[] = aten::slice(%list, %1, %2, %3)
%out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice)
return (%out.1, %out.2))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA});

auto jit_in_x = at::clone(in_x);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x});

auto trt_in_x = at::clone(in_x);
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x});

for (size_t i = 0; i < jit_results.size(); i++) {
auto trt = trt_results[i].reshape(jit_results[i].sizes());
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}

TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
Expand Down Expand Up @@ -796,3 +828,30 @@ TEST(Converters, ATenUnbindConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}

TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=-1]()
%3 : Tensor[] = aten::unbind(%x.1, %2)
%o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
return (%o1.1, %o2.1))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {5, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

for (size_t i = 0; i < jit_results.size(); i++) {
auto trt = trt_results[i].reshape(jit_results[i].sizes());
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}

0 comments on commit 8ca7a22

Please sign in to comment.