From e76b87eab0d1d3ed6677ae9c96583af9acd7b514 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 22 Jan 2024 11:55:25 -0800 Subject: [PATCH] lax.full_like: add sharding argument --- jax/_src/lax/lax.py | 10 +++++++--- tests/lax_test.py | 8 ++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 650fb5631b2c..5ed216d2db71 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1361,7 +1361,7 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array: def full_like(x: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, - shape: Shape | None = None) -> Array: + shape: Shape | None = None, sharding: Sharding | None = None) -> Array: """Create a full array like np.full based on the example array `x`. Args: @@ -1369,6 +1369,9 @@ def full_like(x: ArrayLike | DuckTypedArray, fill_value: a scalar value to fill the entries of the output array. dtype: optional, a dtype parameter for the output ndarray. shape: optional, a shape parameter for the output ndarray. + sharding: an optional sharding specification for the resulting array. + If not specified, the output will have the same sharding as the input, + so long as ``shape`` is also not specified. Returns: An ndarray with the same shape as `x` with its entries set equal to @@ -1379,10 +1382,11 @@ def full_like(x: ArrayLike | DuckTypedArray, dtype = dtype or _dtype(x) if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] - val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type)) + val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type), + sharding=sharding) # TODO(yashkatariya): Use shard_like in tracing mode too i.e. remove the # ArrayImpl check. - if shape is None and isinstance(x, array.ArrayImpl): + if shape is None and sharding is None and isinstance(x, array.ArrayImpl): if xla_extension_version < 227: sharding = x.sharding # type: ignore[union-attr] if (not dispatch.is_single_device_sharding(sharding) and diff --git a/tests/lax_test.py b/tests/lax_test.py index 2531a9f961e0..f3b91b720ba1 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2734,6 +2734,14 @@ def test_lax_full_sharding(self): x = lax.full((len(devices),), 1.0, sharding=sharding) self.assertEqual(x.sharding, sharding) + def test_lax_full_like_sharding(self): + devices = jax.devices() + mesh = Mesh(devices, axis_names=("i")) + sharding = NamedSharding(mesh, P('i')) + x = lax.iota("float32", len(devices)) + y = lax.full_like(x, 1, sharding=sharding) + self.assertEqual(y.sharding, sharding) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected):