Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.full_like & co: support device parameter #19504

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2226,37 +2226,40 @@ def copy(a: ArrayLike, order: str | None = None) -> Array:
@util.implements(np.zeros_like)
def zeros_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
util.check_arraylike("zeros_like", a)
dtypes.check_user_dtype_supported(dtype, "zeros_like")
if shape is not None:
shape = canonicalize_shape(shape)
return lax.full_like(a, 0, dtype, shape)
return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device))


@util.implements(np.ones_like)
def ones_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
util.check_arraylike("ones_like", a)
dtypes.check_user_dtype_supported(dtype, "ones_like")
if shape is not None:
shape = canonicalize_shape(shape)
return lax.full_like(a, 1, dtype, shape)
return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device))


@util.implements(np.empty_like, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty_like(prototype: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing
util.check_arraylike("empty_like", prototype)
dtypes.check_user_dtype_supported(dtype, "empty_like")
return zeros_like(prototype, dtype=dtype, shape=shape)
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)


def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array:
Expand Down Expand Up @@ -2286,7 +2289,8 @@ def full(shape: Any, fill_value: ArrayLike,
@util.implements(np.full_like)
def full_like(a: ArrayLike | DuckTypedArray,
fill_value: ArrayLike, dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing
util.check_arraylike("full_like", 0, fill_value)
else:
Expand All @@ -2295,11 +2299,11 @@ def full_like(a: ArrayLike | DuckTypedArray,
if shape is not None:
shape = canonicalize_shape(shape)
if ndim(fill_value) == 0:
return lax.full_like(a, fill_value, dtype, shape)
return lax.full_like(a, fill_value, dtype, shape, sharding=_normalize_to_sharding(device))
else:
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
dtype = result_type(a) if dtype is None else dtype # type: ignore[arg-type]
return broadcast_to(asarray(fill_value, dtype=dtype), shape)
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)


@util.implements(np.zeros)
Expand Down
12 changes: 8 additions & 4 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def empty(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def empty_like(prototype: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...
euler_gamma: float
def exp(x: ArrayLike, /) -> Array: ...
Expand Down Expand Up @@ -366,7 +367,8 @@ def full(shape: Any, fill_value: ArrayLike,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def full_like(a: Union[ArrayLike, DuckTypedArray],
fill_value: ArrayLike, dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: ...
generic = _np.generic
def geomspace(
Expand Down Expand Up @@ -609,7 +611,8 @@ def ones(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def ones_like(a: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ...
def packbits(
a: ArrayLike, axis: Optional[int] = ..., bitorder: str = ...
Expand Down Expand Up @@ -885,6 +888,7 @@ def zeros(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def zeros_like(a: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...

def vectorize(pyfunc, *, excluded = ..., signature = ...) -> Callable: ...
44 changes: 41 additions & 3 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2850,6 +2850,44 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
out = func(**kwds, shape=shape, dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
dtype=default_dtypes,
)
def testFullLikeWithDevice(self, func, shape, dtype):
device = jax.devices()[-1]
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
kwds = {'fill_value': 1} if func is jnp.full_like else {}

with self.subTest('device from keyword'):
out = func(x, **kwds, device=device)
self.assertEqual(out.devices(), {device})

with self.subTest('device from input array'):
out2 = func(out, **kwds)
self.assertEqual(out2.devices(), out.devices())

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
dtype=default_dtypes,
)
def testFullLikeWithSharding(self, func, shape, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
kwds = {'fill_value': 1} if func is jnp.full_like else {}

with self.subTest('device from keyword'):
out = func(x, **kwds, device=sharding)
self.assertEqual(out.sharding, sharding)

with self.subTest('device from input array'):
out2 = func(out, **kwds)
self.assertEqual(out2.devices(), out.devices())

def testDuckTypedLike(self):
x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32"))
self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype))
Expand Down Expand Up @@ -5674,7 +5712,7 @@ def testWrappedSignaturesMatch(self):
'identity': ['like'],
'isin': ['kind'],
'full': ['order', 'like'],
'full_like': ['device', 'subok', 'order'],
'full_like': ['subok', 'order'],
'fromfunction': ['like'],
'histogram': ['normed'],
'histogram2d': ['normed'],
Expand All @@ -5685,7 +5723,7 @@ def testWrappedSignaturesMatch(self):
'nanstd': ['correction', 'mean'],
'nanvar': ['correction', 'mean'],
'ones': ['order', 'like'],
'ones_like': ['device', 'subok', 'order'],
'ones_like': ['subok', 'order'],
'partition': ['kind', 'order'],
'percentile': ['weights'],
'quantile': ['weights'],
Expand All @@ -5695,7 +5733,7 @@ def testWrappedSignaturesMatch(self):
'tri': ['like'],
'var': ['correction', 'mean'],
'vstack': ['casting'],
'zeros_like': ['device', 'subok', 'order']
'zeros_like': ['subok', 'order']
}

extra_params = {
Expand Down
Loading