Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add copy argument to Array.__array__ #20077

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,10 @@ def is_fully_addressable(self) -> bool:
"""
return self.sharding.is_fully_addressable

def __array__(self, dtype=None, context=None):
return np.asarray(self._value, dtype=dtype)
def __array__(self, dtype=None, context=None, copy=None):
# copy argument is supported by np.asarray starting in numpy 2.0
kwds = {} if copy is None else {'copy': copy}
return np.asarray(self._value, dtype=dtype, **kwds)

def __dlpack__(self, *, stream: int | Any | None = None):
if len(self._arrays) != 1:
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ class Array(abc.ABC):
# Even though we don't always support the NumPy array protocol, e.g., for
# tracer types, for type checking purposes we must declare support so we
# implement the NumPy ArrayLike protocol.
def __array__(self) -> np.ndarray: ...
def __array__(self, dtype: Optional[np.dtype] = ...,
copy: Optional[bool] = ...) -> np.ndarray: ...
def __dlpack__(self) -> Any: ...

# JAX extensions
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3465,7 +3465,7 @@ def testArrayOutputsArrays(self):
assert type(jnp.array(np.array([]))) is array.ArrayImpl

class NDArrayLike:
def __array__(self, dtype=None):
def __array__(self, dtype=None, copy=None):
return np.array([], dtype=dtype)
assert type(jnp.array(NDArrayLike())) is array.ArrayImpl

Expand All @@ -3478,7 +3478,7 @@ def __array__(self, dtype=None):
def testArrayMethod(self):
class arraylike:
dtype = np.dtype('float32')
def __array__(self, dtype=None):
def __array__(self, dtype=None, copy=None):
return np.array(3., dtype=dtype)
a = arraylike()
ans = jnp.array(a)
Expand Down