Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix out of bounds error in iterator (#5402)
Browse files Browse the repository at this point in the history
Fixes 5401

>>> import mxnet as mx
>>> x = mx.nd.array([1, 2, 3])
>>> for a in x:
...     print a
...
<NDArray 1 @cpu(0)>
<NDArray 1 @cpu(0)>
<NDArray 1 @cpu(0)>
[09:44:11] mxnet/dmlc-core/include/dmlc/logging.h:300:
[09:44:11] mxnet/include/mxnet/ndarray.h:276: Check failed: shape_[0] > idx (3 vs. 3) index out of range

Stack trace returned 6 entries:
[bt] (0) 0   libmxnet.so                         0x000000010bb7f48e _ZN4dmlc15LogMessageFatalD2Ev + 46
[bt] (1) 1   libmxnet.so                         0x000000010bb6dbc5 _ZN4dmlc15LogMessageFatalD1Ev + 21
[bt] (2) 2   libmxnet.so                         0x000000010bb71a64 _ZNK5mxnet7NDArray2AtEj + 644
[bt] (3) 3   libmxnet.so                         0x000000010bb716b5 MXNDArrayAt + 101
  • Loading branch information
dleen authored and piiswrong committed Mar 15, 2017
1 parent e638aba commit 8b28b8c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ def __getitem__(self, key):
"""
# multi-dimensional slicing is not supported yet
if isinstance(key, int):
if key > self.shape[0] - 1:
raise IndexError(
'index {} is out of bounds for axis 0 with size {}'.format(
key, self.shape[0]))
return self._at(key)
if isinstance(key, py_slice):
if key.step is not None:
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,16 @@ def test_take():
result = mx.nd.take(data_real_mx, idx_real_mx)
assert_almost_equal(result.asnumpy(), data_real[idx_real])


def test_iter():
x = mx.nd.array([1, 2, 3])
y = []
for a in x:
y.append(a)

assert np.all(np.array(y) == x.asnumpy())


if __name__ == '__main__':
test_broadcast_binary()
test_ndarray_setitem()
Expand All @@ -603,3 +613,4 @@ def test_take():
test_order()
test_ndarray_equal()
test_take()
test_iter()

0 comments on commit 8b28b8c

Please sign in to comment.