Skip to content

Commit

Permalink
jnp.full_like & co: support device parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 26, 2024
1 parent 1ae054b commit 9549c74
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 16 deletions.
22 changes: 13 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2226,37 +2226,40 @@ def copy(a: ArrayLike, order: str | None = None) -> Array:
@util.implements(np.zeros_like)
def zeros_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
util.check_arraylike("zeros_like", a)
dtypes.check_user_dtype_supported(dtype, "zeros_like")
if shape is not None:
shape = canonicalize_shape(shape)
return lax.full_like(a, 0, dtype, shape)
return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device))


@util.implements(np.ones_like)
def ones_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing
util.check_arraylike("ones_like", a)
dtypes.check_user_dtype_supported(dtype, "ones_like")
if shape is not None:
shape = canonicalize_shape(shape)
return lax.full_like(a, 1, dtype, shape)
return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device))


@util.implements(np.empty_like, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty_like(prototype: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing
util.check_arraylike("empty_like", prototype)
dtypes.check_user_dtype_supported(dtype, "empty_like")
return zeros_like(prototype, dtype=dtype, shape=shape)
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)


def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array:
Expand Down Expand Up @@ -2286,7 +2289,8 @@ def full(shape: Any, fill_value: ArrayLike,
@util.implements(np.full_like)
def full_like(a: ArrayLike | DuckTypedArray,
fill_value: ArrayLike, dtype: DTypeLike | None = None,
shape: Any = None) -> Array:
shape: Any = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing
util.check_arraylike("full_like", 0, fill_value)
else:
Expand All @@ -2295,11 +2299,11 @@ def full_like(a: ArrayLike | DuckTypedArray,
if shape is not None:
shape = canonicalize_shape(shape)
if ndim(fill_value) == 0:
return lax.full_like(a, fill_value, dtype, shape)
return lax.full_like(a, fill_value, dtype, shape, sharding=_normalize_to_sharding(device))
else:
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
dtype = result_type(a) if dtype is None else dtype # type: ignore[arg-type]
return broadcast_to(asarray(fill_value, dtype=dtype), shape)
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)


@util.implements(np.zeros)
Expand Down
12 changes: 8 additions & 4 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def empty(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def empty_like(prototype: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...
euler_gamma: float
def exp(x: ArrayLike, /) -> Array: ...
Expand Down Expand Up @@ -366,7 +367,8 @@ def full(shape: Any, fill_value: ArrayLike,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def full_like(a: Union[ArrayLike, DuckTypedArray],
fill_value: ArrayLike, dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: ...
generic = _np.generic
def geomspace(
Expand Down Expand Up @@ -609,7 +611,8 @@ def ones(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def ones_like(a: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ...
def packbits(
a: ArrayLike, axis: Optional[int] = ..., bitorder: str = ...
Expand Down Expand Up @@ -885,6 +888,7 @@ def zeros(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def zeros_like(a: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
shape: Any = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...

def vectorize(pyfunc, *, excluded = ..., signature = ...) -> Callable: ...
44 changes: 41 additions & 3 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2850,6 +2850,44 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
out = func(**kwds, shape=shape, dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
dtype=default_dtypes,
)
def testFullLikeWithDevice(self, func, shape, dtype):
device = jax.devices()[-1]
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
kwds = {'fill_value': 1} if func is jnp.full_like else {}

with self.subTest('device from keyword'):
out = func(x, **kwds, device=device)
self.assertEqual(out.devices(), {device})

with self.subTest('device from input array'):
out2 = func(out, **kwds)
self.assertEqual(out2.devices(), out.devices())

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
dtype=default_dtypes,
)
def testFullLikeWithSharding(self, func, shape, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
kwds = {'fill_value': 1} if func is jnp.full_like else {}

with self.subTest('device from keyword'):
out = func(x, **kwds, device=sharding)
self.assertEqual(out.sharding, sharding)

with self.subTest('device from input array'):
out2 = func(out, **kwds)
self.assertEqual(out2.devices(), out.devices())

def testDuckTypedLike(self):
x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32"))
self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype))
Expand Down Expand Up @@ -5674,7 +5712,7 @@ def testWrappedSignaturesMatch(self):
'identity': ['like'],
'isin': ['kind'],
'full': ['order', 'like'],
'full_like': ['device', 'subok', 'order'],
'full_like': ['subok', 'order'],
'fromfunction': ['like'],
'histogram': ['normed'],
'histogram2d': ['normed'],
Expand All @@ -5685,7 +5723,7 @@ def testWrappedSignaturesMatch(self):
'nanstd': ['correction', 'mean'],
'nanvar': ['correction', 'mean'],
'ones': ['order', 'like'],
'ones_like': ['device', 'subok', 'order'],
'ones_like': ['subok', 'order'],
'partition': ['kind', 'order'],
'percentile': ['weights'],
'quantile': ['weights'],
Expand All @@ -5695,7 +5733,7 @@ def testWrappedSignaturesMatch(self):
'tri': ['like'],
'var': ['correction', 'mean'],
'vstack': ['casting'],
'zeros_like': ['device', 'subok', 'order']
'zeros_like': ['subok', 'order']
}

extra_params = {
Expand Down

0 comments on commit 9549c74

Please sign in to comment.