Skip to content

Commit

Permalink
Merge pull request #19466 from jakevdp:full-like-sharding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600620827
  • Loading branch information
jax authors committed Jan 23, 2024
2 parents f21022b + e76b87e commit 70aec84
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
10 changes: 7 additions & 3 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,14 +1361,17 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:

def full_like(x: ArrayLike | DuckTypedArray,
fill_value: ArrayLike, dtype: DTypeLike | None = None,
shape: Shape | None = None) -> Array:
shape: Shape | None = None, sharding: Sharding | None = None) -> Array:
"""Create a full array like np.full based on the example array `x`.
Args:
x: example array-like, used for shape and dtype information.
fill_value: a scalar value to fill the entries of the output array.
dtype: optional, a dtype parameter for the output ndarray.
shape: optional, a shape parameter for the output ndarray.
sharding: an optional sharding specification for the resulting array.
If not specified, the output will have the same sharding as the input,
so long as ``shape`` is also not specified.
Returns:
An ndarray with the same shape as `x` with its entries set equal to
Expand All @@ -1379,10 +1382,11 @@ def full_like(x: ArrayLike | DuckTypedArray,
dtype = dtype or _dtype(x)
if dtypes.issubdtype(dtype, dtypes.extended):
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type))
val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type),
sharding=sharding)
# TODO(yashkatariya): Use shard_like in tracing mode too i.e. remove the
# ArrayImpl check.
if shape is None and isinstance(x, array.ArrayImpl):
if shape is None and sharding is None and isinstance(x, array.ArrayImpl):
if xla_extension_version < 227:
sharding = x.sharding # type: ignore[union-attr]
if (not dispatch.is_single_device_sharding(sharding) and
Expand Down
8 changes: 8 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,14 @@ def test_lax_full_sharding(self):
x = lax.full((len(devices),), 1.0, sharding=sharding)
self.assertEqual(x.sharding, sharding)

def test_lax_full_like_sharding(self):
devices = jax.devices()
mesh = Mesh(devices, axis_names=("i"))
sharding = NamedSharding(mesh, P('i'))
x = lax.iota("float32", len(devices))
y = lax.full_like(x, 1, sharding=sharding)
self.assertEqual(y.sharding, sharding)


class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
Expand Down

0 comments on commit 70aec84

Please sign in to comment.