Skip to content

Commit

Permalink
feat(aten::slice): Patching slice for new optional params
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jul 30, 2021
1 parent 254eab2 commit a11287f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
17 changes: 14 additions & 3 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,30 @@ auto select_registrations TRTORCH_UNUSED =
return true;
}})
.pattern(
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)",
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto axis = args[1].unwrapToInt();
auto maxDim = static_cast<int64_t>(in->getDimensions().d[axis]);
auto startIdx = 0;
auto startIdxIVal = args[2].IValue();
if (!startIdxIVal->isNone()) {
startIdx = startIdxIVal->toInt();
}
// Handle case when given tensor index is negative
auto startIdx = args[2].unwrapToInt();
auto start = (startIdx < 0) ? (maxDim + startIdx) : startIdx;
// Bound the end index to input tensor dimensions at specified axis
auto endIdx = std::min(args[3].unwrapToInt(), maxDim);
auto endIdx = maxDim;
auto endIdxIVal = args[3].IValue();
if (!endIdxIVal->isNone()) {
endIdx = std::min(endIdxIVal->toInt(), maxDim);
}
auto end = (endIdx < 0) ? (maxDim + endIdx) : endIdx;
auto step = args[4].unwrapToInt();

LOG_DEBUG("Start idx: " << start);
LOG_DEBUG("End idx: " << end);

// indices to be accessed need to be an at::Tensor
at::Tensor indices = torch::arange(start, end, step).to(torch::kI32);
auto weights = Weights(ctx, indices);
Expand Down
4 changes: 2 additions & 2 deletions tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
TEST(Converters, ATenSliceConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=9223372036854775807]()
%2 : None = prim::Constant()
%3 : int = prim::Constant[value=2]()
%4 : int = prim::Constant[value=4]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=0]()
%7 : Tensor = aten::select(%x.1, %6, %6)
%8 : Tensor = aten::select(%7, %6, %5)
%9 : Tensor = aten::slice(%8, %6, %5, %4, %3)
%10 : Tensor = aten::slice(%9, %5, %6, %2, %5)
%10 : Tensor = aten::slice(%9, %5, %2, %2, %5)
return (%10))IR";

auto g = std::make_shared<torch::jit::Graph>();
Expand Down

0 comments on commit a11287f

Please sign in to comment.