diff --git a/docs/python_docs/python/api/np/arrays.ndarray.rst b/docs/python_docs/python/api/np/arrays.ndarray.rst index e77d20b8a138..522a667d69b1 100644 --- a/docs/python_docs/python/api/np/arrays.ndarray.rst +++ b/docs/python_docs/python/api/np/arrays.ndarray.rst @@ -512,12 +512,13 @@ Container customization: (see :ref:`Indexing `) ndarray.__getitem__ ndarray.__setitem__ -Conversion; the operations :func:`int()` and :func:`float()`. +Conversion; the operations :func:`index()`, :func:`int()` and :func:`float()`. They work only on arrays that have one element in them and return the appropriate scalar. .. autosummary:: + ndarray.__index__ ndarray.__int__ ndarray.__float__ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a58f1faf5587..b3dbe04fbbbb 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -31,6 +31,8 @@ from array import array as native_array import functools import ctypes +import sys +import datetime import warnings import numpy as _np from .. import _deferred_compute as dc @@ -412,6 +414,34 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- return mx_np_func(*new_args, **new_kwargs) + def __array_namespace__(self, api_version=None): + """ + Returns an object that has all the array API functions on it. + + Notes + ----- + This is a standard API in + https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none. + + Parameters + ---------- + self : ndarray + The indexing key. + api_version : Optional, string + string representing the version of the array API specification to be returned, in `YYYY.MM` form. + If it is None, it should return the namespace corresponding to latest version of the array API + specification. + """ + if api_version is not None: + try: + date = datetime.datetime.strptime(api_version, '%Y.%m') + if date.year != 2021: + raise ValueError + except ValueError: + raise ValueError(f"Unrecognized array API version: {api_version!r}") + return sys.modules[self.__module__] + + def _get_np_basic_indexing(self, key): """ This function indexes ``self`` with a tuple of `slice` objects only. @@ -1255,6 +1285,11 @@ def __bool__(self): __nonzero__ = __bool__ + def __index__(self): + if self.ndim == 0 and _np.issubdtype(self.dtype, _np.integer): + return self.item() + raise TypeError('only integer scalar arrays can be converted to a scalar index') + def __float__(self): num_elements = self.size if num_elements != 1: diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 559b8a575f5d..2da60aa1fc8e 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -21,6 +21,7 @@ import itertools import os import pytest +import operator import numpy as _np import mxnet as mx from mxnet import np, npx, autograd @@ -1426,3 +1427,34 @@ def test_mixed_array_types_share_memory(): def test_save_load_empty(tmp_path): mx.npx.savez(str(tmp_path / 'params.npz')) mx.npx.load(str(tmp_path / 'params.npz')) + +@use_np +@pytest.mark.parametrize('shape', [ + (), + (1,), + (1,2) +]) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'bool', 'int32']) +def test_index_operator(shape, dtype): + if len(shape) >= 1 or not _np.issubdtype(dtype, _np.integer): + x = np.ones(shape=shape, dtype=dtype) + pytest.raises(TypeError, operator.index, x) + else: + assert operator.index(np.ones(shape=shape, dtype=dtype)) == \ + operator.index(_np.ones(shape=shape, dtype=dtype)) + + +@pytest.mark.parametrize('api_version, raise_exception', [ + (None, False), + ('2021.10', False), + ('2020.09', True), + ('2021.24', True), +]) +def test_array_namespace(api_version, raise_exception): + x = np.array([1, 2, 3], dtype="float64") + if raise_exception: + pytest.raises(ValueError, x.__array_namespace__, api_version) + else: + xp = x.__array_namespace__(api_version) + y = xp.array([1, 2, 3], dtype="float64") + assert same(x, y)