Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API NEW][ARRAY METHOD] Add __Index__() and __array_namespace__() #20689

Merged
merged 5 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/python_docs/python/api/np/arrays.ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -512,12 +512,13 @@ Container customization: (see :ref:`Indexing <arrays.indexing>`)
ndarray.__getitem__
ndarray.__setitem__

Conversion; the operations :func:`int()` and :func:`float()`.
Conversion; the operations :func:`index()`, :func:`int()` and :func:`float()`.
They work only on arrays that have one element in them
and return the appropriate scalar.

.. autosummary::

ndarray.__index__
ndarray.__int__
ndarray.__float__

Expand Down
35 changes: 35 additions & 0 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from array import array as native_array
import functools
import ctypes
import sys
import datetime
import warnings
import numpy as _np
from .. import _deferred_compute as dc
Expand Down Expand Up @@ -412,6 +414,34 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
return mx_np_func(*new_args, **new_kwargs)


def __array_namespace__(self, api_version=None):
"""
Returns an object that has all the array API functions on it.

Notes
-----
This is a standard API in
https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none.

Parameters
----------
self : ndarray
The indexing key.
api_version : Optional, string
string representing the version of the array API specification to be returned, in `YYYY.MM` form.
If it is None, it should return the namespace corresponding to latest version of the array API
specification.
"""
if api_version is not None:
try:
date = datetime.datetime.strptime(api_version, '%Y.%m')
if date.year != 2021:
raise ValueError
except ValueError:
raise ValueError(f"Unrecognized array API version: {api_version!r}")
return sys.modules[self.__module__]


def _get_np_basic_indexing(self, key):
"""
This function indexes ``self`` with a tuple of `slice` objects only.
Expand Down Expand Up @@ -1255,6 +1285,11 @@ def __bool__(self):

__nonzero__ = __bool__

def __index__(self):
if self.ndim == 0 and _np.issubdtype(self.dtype, _np.integer):
return self.item()
raise TypeError('only integer scalar arrays can be converted to a scalar index')

def __float__(self):
num_elements = self.size
if num_elements != 1:
Expand Down
32 changes: 32 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import itertools
import os
import pytest
import operator
import numpy as _np
import mxnet as mx
from mxnet import np, npx, autograd
Expand Down Expand Up @@ -1426,3 +1427,34 @@ def test_mixed_array_types_share_memory():
def test_save_load_empty(tmp_path):
mx.npx.savez(str(tmp_path / 'params.npz'))
mx.npx.load(str(tmp_path / 'params.npz'))

@use_np
@pytest.mark.parametrize('shape', [
(),
(1,),
(1,2)
])
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'bool', 'int32'])
def test_index_operator(shape, dtype):
if len(shape) >= 1 or not _np.issubdtype(dtype, _np.integer):
x = np.ones(shape=shape, dtype=dtype)
pytest.raises(TypeError, operator.index, x)
else:
assert operator.index(np.ones(shape=shape, dtype=dtype)) == \
operator.index(_np.ones(shape=shape, dtype=dtype))


@pytest.mark.parametrize('api_version, raise_exception', [
(None, False),
('2021.10', False),
('2020.09', True),
('2021.24', True),
])
def test_array_namespace(api_version, raise_exception):
x = np.array([1, 2, 3], dtype="float64")
if raise_exception:
pytest.raises(ValueError, x.__array_namespace__, api_version)
else:
xp = x.__array_namespace__(api_version)
y = xp.array([1, 2, 3], dtype="float64")
assert same(x, y)