feat (//core/conversion): Add converter for torch.repeat_interleave (#1313)

Merged · 7 commits · Aug 28, 2022
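
For context, a minimal standalone sketch (not part of this diff) of what aten::repeat_interleave computes, using the ATen C++ API; tensor values are illustrative only:

// Illustrative only — shows the op the converter below targets.
#include <iostream>
#include <torch/torch.h>

int main() {
  auto x = torch::tensor({{1, 2, 3}, {4, 5, 6}});

  // dim=1: every element along dim 1 is repeated 3 times in place ->
  // [[1, 1, 1, 2, 2, 2, 3, 3, 3], [4, 4, 4, 5, 5, 5, 6, 6, 6]]
  auto with_dim = torch::repeat_interleave(x, /*repeats=*/3, /*dim=*/1);

  // dim omitted (None): the input is flattened first, then every element is repeated ->
  // [1, 1, 1, 2, 2, 2, ..., 6, 6, 6]
  auto no_dim = torch::repeat_interleave(x, /*repeats=*/3);

  std::cout << with_dim << "\n" << no_dim << std::endl;
  return 0;
}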
110 changes: 110 additions & 0 deletions core/conversion/converters/impl/expand.cpp
@@ -282,6 +282,116 @@ auto expand_registrations TORCHTRT_UNUSED =
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
return true;
}})
.pattern(
{"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto repeats = args[1].unwrapToScalar().to<int>();

auto input_shape = self->getDimensions();

int dim;
if (args[2].IValue()->isNone()) {
dim = 0;

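// When dim is None, repeat_interleave repeats every element of the flattened
// input, so reshape self to 1-D and treat the repeat as acting along dim 0.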
// Flatten self tensor
int size;
if (ctx->input_is_dynamic) {
// Set size to -1 if input is dynamic
size = -1;
} else {
size = 1;
for (int i = 0; i < input_shape.nbDims; i++) {
size *= input_shape.d[i];
}
}
auto flatten = ctx->net->addShuffle(*self);
TORCHTRT_CHECK(flatten, "Unable to create shuffle layer from node: " << *n);
flatten->setReshapeDimensions(util::toDims(std::vector<int64_t>({size})));
self = flatten->getOutput(0);
input_shape = self->getDimensions();
} else {
dim = args[2].unwrapToScalar().to<int>();
}

if (ctx->input_is_dynamic) {
int dynamic_dims = 0;
for (int idx = 0; idx < input_shape.nbDims; idx++) {
if (input_shape.d[idx] == -1) {
dynamic_dims++;
}
}

if (dynamic_dims > 1) {
TORCHTRT_THROW_ERROR(
"Repeat_interleave is currently not supported when target shape contains more than one dynamic dimension");
}
}

// Insert singleton dimension after desired repeat dimension
std::vector<int64_t> repeat_shape_vec;
for (int j = 0; j < input_shape.nbDims; j++) {
repeat_shape_vec.push_back(input_shape.d[j]);
if (j == dim) {
repeat_shape_vec.push_back(1);
}
}
auto expand = ctx->net->addShuffle(*self);
TORCHTRT_CHECK(expand, "Unable to create shuffle layer from node: " << *n);
auto repeat_shape_dims = util::toDims(repeat_shape_vec);
expand->setReshapeDimensions(repeat_shape_dims);

// Expand on newly created singleton dimension
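// A slice whose stride is 0 along that singleton dimension re-reads the same
// element `repeats` times, which effectively broadcasts it to size `repeats`.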
repeat_shape_dims.d[dim + 1] = repeats;
std::vector<int64_t> start_vec(repeat_shape_dims.nbDims, 0);
auto start_dims = util::toDims(start_vec);

std::vector<int64_t> strides_vec(repeat_shape_dims.nbDims, 1);
strides_vec[dim + 1] = 0;
auto strides_dims = util::toDims(strides_vec);

auto slice = ctx->net->addSlice(*expand->getOutput(0), start_dims, repeat_shape_dims, strides_dims);

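// For dynamic input the static dims above may contain -1, so compute the real
// slice shape at runtime as shape(expand output) * repeat_const and feed the
// start/shape/stride tensors to the slice layer via setInput.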
if (ctx->input_is_dynamic) {
auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32));

auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0);
std::vector<int64_t> repeat_const_vec(repeat_shape_dims.nbDims, 1);
repeat_const_vec[dim + 1] = repeats;
auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32));
auto repeat_shape_tensor =
ctx->net
->addElementWise(*expand_output_shape, *repeat_const, nvinfer1::ElementWiseOperation::kPROD)
->getOutput(0);

auto strides_tensor = tensor_to_const(ctx, torch::tensor(strides_vec, torch::kInt32));
slice->setInput(1, *start_tensor);
slice->setInput(2, *repeat_shape_tensor);
slice->setInput(3, *strides_tensor);
}

// Collapse repeated dimension back into desired dimension
std::vector<int64_t> collapse_shape_vec;
for (int k = 0; k < repeat_shape_dims.nbDims; k++) {
if (k == dim) {
// Read both operands before advancing k (d[k] * d[++k] is unsequenced)
int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k + 1];
k++; // skip the singleton dimension that was just folded in
// If the repeated dim is dynamic (-1), the product is -repeats; clamp it back to -1
collapse_dim = std::max(collapse_dim, (int64_t)-1);
collapse_shape_vec.push_back(collapse_dim);
} else {
collapse_shape_vec.push_back(repeat_shape_dims.d[k]);
}
}
auto collapse = ctx->net->addShuffle(*slice->getOutput(0));
TORCHTRT_CHECK(collapse, "Unable to create shuffle layer from node: " << *n);
collapse->setReshapeDimensions(util::toDims(collapse_shape_vec));

collapse->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}});

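A minimal standalone sketch (assumed, not part of the PR) of the shape bookkeeping the converter performs for a static 2-D input, independent of TensorRT: insert a singleton after `dim`, broadcast it to `repeats`, then fold it back into `dim`.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> repeat_interleave_shape(std::vector<int64_t> shape, int dim, int64_t repeats) {
  std::vector<int64_t> expanded;                  // e.g. {2, 3} -> {2, 3, 1}
  for (int j = 0; j < (int)shape.size(); j++) {
    expanded.push_back(shape[j]);
    if (j == dim) expanded.push_back(1);
  }
  expanded[dim + 1] = repeats;                    // stride-0 slice: {2, 3, 1} -> {2, 3, 3}
  std::vector<int64_t> collapsed;                 // fold back: {2, 3, 3} -> {2, 9}
  for (int k = 0; k < (int)expanded.size(); k++) {
    if (k == dim) {
      collapsed.push_back(expanded[k] * expanded[k + 1]);
      k++;
    } else {
      collapsed.push_back(expanded[k]);
    }
  }
  return collapsed;
}

int main() {
  for (auto d : repeat_interleave_shape({2, 3}, /*dim=*/1, /*repeats=*/3)) std::cout << d << ' ';
  std::cout << std::endl;  // prints: 2 9
  return 0;
}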
224 changes: 224 additions & 0 deletions tests/core/conversion/converters/test_expand.cpp
@@ -445,3 +445,227 @@ TEST(Converters, ATenRepeatExtraDimsConvertsCorrectlyWithDynamicInput) {

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectlyWithDynamicInput) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : None = prim::Constant()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectlyWithDynamicInput) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : None = prim::Constant()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectlyWithDynamicInput) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : None = prim::Constant()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicInput) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=3]()
%3 : None = prim::Constant()
%4 : None = prim::Constant()
%5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}