Skip to content

Commit

Permalink
Allow scalar cunumeric ndarrays as array indices (#479)
Browse files Browse the repository at this point in the history
* Allow scalar cunumeric ndarrays as array indices

* Remove some duplicate test lines
  • Loading branch information
manopapad authored Jul 29, 2022
1 parent f66f322 commit 731901b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
10 changes: 10 additions & 0 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
from __future__ import annotations

import operator
import warnings
from collections.abc import Iterable
from functools import reduce, wraps
Expand Down Expand Up @@ -823,6 +824,12 @@ def __ge__(self, rhs):

def _convert_key(self, key, first=True):
# Convert any arrays stored in a key to a cuNumeric array
if isinstance(key, slice):
key = slice(
operator.index(key.start) if key.start is not None else None,
operator.index(key.stop) if key.stop is not None else None,
operator.index(key.step) if key.step is not None else None,
)
if (
key is np.newaxis
or key is Ellipsis
Expand Down Expand Up @@ -970,6 +977,9 @@ def __imul__(self, rhs):

return multiply(self, rhs, out=self)

def __index__(self) -> int:
return self.__array__().__index__()

def __int__(self):
"""a.__int__(/)
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/test_get_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import pytest

import cunumeric as num
Expand All @@ -24,6 +25,21 @@ def test_basic():
assert x[2] == 3


ARRAYS_4_3_2_1_0 = [
4 - num.arange(5),
4 - np.arange(5),
[4, 3, 2, 1, 0],
]


@pytest.mark.parametrize("arr", ARRAYS_4_3_2_1_0)
def test_scalar_ndarray_as_index(arr):
offsets = num.arange(5) # [0, 1, 2, 3, 4]
offset = offsets[3] # 3
# assert arr[offset] == 1 # TODO: doesn't work when arr is a num.ndarray
assert np.array_equal(arr[offset - 2 : offset], [3, 2])


if __name__ == "__main__":
import sys

Expand Down
17 changes: 17 additions & 0 deletions tests/integration/test_set_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import pytest

import cunumeric as num
Expand All @@ -28,6 +29,22 @@ def test_basic():
assert x[2] == 3


ARRAYS_4_3_2_1_0 = [
4 - num.arange(5),
4 - np.arange(5),
[4, 3, 2, 1, 0],
]


@pytest.mark.parametrize("arr", ARRAYS_4_3_2_1_0)
def test_scalar_ndarray_as_index(arr):
offsets = num.arange(5) # [0, 1, 2, 3, 4]
offset = offsets[3] # 3
# arr[offset] = -1 # TODO: doesn't work when arr is a num.ndarray
arr[offset - 2 : offset] = [-1, -1]
assert np.array_equal(arr, [4, -1, -1, 1, 0])


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 731901b

Please sign in to comment.