Skip to content

Commit

Permalink
Merge pull request #384 from inocsin/fix_transpose
Browse files Browse the repository at this point in the history
feat: support aten::transpose with negative dim
  • Loading branch information
narendasan authored Mar 2, 2021
2 parents b1b5f19 + 4a1d2f3 commit 36a5d97
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
2 changes: 2 additions & 0 deletions core/conversion/converters/impl/shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ static auto shuffle_registrations TRTORCH_UNUSED =
for (size_t i = 0; i < ndims; i++) {
new_order.push_back(i);
}
dim0 = dim0 < 0 ? (dim0 + ndims) : dim0;
dim1 = dim1 < 0 ? (dim1 + ndims) : dim1;
auto tmp = dim0;
new_order[dim0] = new_order[dim1];
new_order[dim1] = tmp;
Expand Down
26 changes: 26 additions & 0 deletions tests/core/conversion/converters/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,29 @@ TEST(Converters, ATenTransposeConvertsCorrectly) {

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenTransposeNegativeConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=-1]()
%3 : int = prim::Constant[value=-3]()
%4 : Tensor = aten::transpose(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {2, 3, 4, 5, 6}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});

std::cout << "Running JIT" << std::endl;
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

std::cout << "Running TRT" << std::endl;
in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

0 comments on commit 36a5d97

Please sign in to comment.