Skip to content

Commit

Permalink
fix(aten::cat): support neg dim for cat
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jul 28, 2021
1 parent 114969b commit d8ca182
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
8 changes: 5 additions & 3 deletions core/conversion/converters/impl/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ auto cat_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
for (auto t : ts) {
if (t.isTensor()) {
auto torch_tensor = t.toTensor();
auto t_weights = Weights(ctx, torch_tensor);
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
tensors.push_back(const_layer->getOutput(0));
tensors.push_back(tensor_to_const(ctx, torch_tensor));
} else {
auto cont = t.toCustomClass<TensorContainer>();
tensors.push_back(cont->tensor());
}
}

if (dim < 0) {
dim = tensors[0]->getDimensions().nbDims + dim;
}

auto cat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
cat_layer->setAxis(static_cast<int>(dim));
auto cat_out = ctx->AssociateValueAndTensor(n->outputs()[0], cat_layer->getOutput(0));
Expand Down
47 changes: 47 additions & 0 deletions tests/core/conversion/converters/test_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,52 @@ TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
TEST(Converters, ATenCatPureTensorNegDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor):
%2 : Tensor[] = prim::ListConstruct(%0, %1)
%3 : int = prim::Constant[value=-1]()
%4 : Tensor = aten::cat(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in1 = at::randint(1, 10, {5, 5}, {at::kCUDA});
auto in2 = at::randint(1, 10, {5, 5}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenCatDiffTensorNegDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(5)):
%2 : Tensor[] = prim::ListConstruct(%0, %1)
%3 : int = prim::Constant[value=-1]()
%4 : Tensor = aten::cat(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in1 = at::randint(1, 10, {5, 5}, {at::kCUDA});
auto in2 = at::randint(1, 10, {5, 5}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});

params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 comments on commit d8ca182

Please sign in to comment.