Skip to content

Commit

Permalink
fix bug of indexing with ellipsis
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Nov 15, 2021
1 parent 70cb0a5 commit aad598a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
9 changes: 8 additions & 1 deletion paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,20 @@ static void ParseIndexingSlice(
// specified_dims is the number of dimensions which indexed by Interger,
// Slices.
int specified_dims = 0;
int ell_count = 0;
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index, dim);
if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) {
specified_dims++;
} else if (slice_item == Py_Ellipsis) {
ell_count++;
}
}

PADDLE_ENFORCE_LE(ell_count, 1,
platform::errors::InvalidArgument(
"An index can only have a single ellipsis ('...')"));

for (int i = 0, dim = 0; i < size; ++i) {
PyObject *slice_item = PyTuple_GetItem(index, i);

Expand Down Expand Up @@ -639,7 +646,7 @@ static void ParseIndexingSlice(
}

// valid_index is the number of dimensions exclude None index
const int valid_indexs = size - none_axes->size();
const int valid_indexs = size - none_axes->size() - ell_count;
PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
platform::errors::InvalidArgument(
"Too many indices (%d) for tensor of dimension %d.",
Expand Down
5 changes: 5 additions & 0 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,11 @@ def assert_getitem_ellipsis_index(var_tensor, var_np):
assert_getitem_ellipsis_index(var_fp32, np_fp32_value)
assert_getitem_ellipsis_index(var_int, np_int_value)

# test 1 dim tensor
var_one_dim = paddle.to_tensor([1, 2, 3, 4])
self.assertTrue(
np.array_equal(var_one_dim[..., 0].numpy(), np.array([1])))

def _test_none_index(self):
shape = (8, 64, 5, 256)
np_value = np.random.random(shape).astype('float32')
Expand Down
8 changes: 6 additions & 2 deletions python/paddle/fluid/tests/unittests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,22 @@ def _test_slice_index_ellipsis(self, place):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
y = paddle.assign([1, 2, 3, 4])
out1 = x[0:, ..., 1:]
out2 = x[0:, ...]
out3 = x[..., 1:]
out4 = x[...]
out5 = x[[1, 0], [0, 0]]
out6 = x[([1, 0], [0, 0])]
out7 = y[..., 0]

exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out1, out2, out3, out4, out5, out6])
result = exe.run(prog,
fetch_list=[out1, out2, out3, out4, out5, out6, out7])

expected = [
data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...],
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])]
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])], np.array([1])
]

self.assertTrue((result[0] == expected[0]).all())
Expand All @@ -247,6 +250,7 @@ def _test_slice_index_ellipsis(self, place):
self.assertTrue((result[3] == expected[3]).all())
self.assertTrue((result[4] == expected[4]).all())
self.assertTrue((result[5] == expected[5]).all())
self.assertTrue((result[6] == expected[6]).all())

with self.assertRaises(IndexError):
res = x[[1.2, 0]]
Expand Down

1 comment on commit aad598a

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.