Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lax.full: add sharding argument #19445

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import jax
from jax import tree_util
from jax.sharding import Sharding
from jax.tree_util import tree_map

from jax._src import ad_util
Expand Down Expand Up @@ -1204,14 +1205,16 @@ def tie_in(x: Any, y: T) -> T:
"""Deprecated. Ignores ``x`` and returns ``y``."""
return y

def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) -> Array:
def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
sharding: Sharding | None = None) -> Array:
"""Returns an array of `shape` filled with `fill_value`.

Args:
shape: sequence of integers, describing the shape of the output array.
fill_value: the value to fill the new array with.
dtype: the type of the output array, or `None`. If not `None`, `fill_value`
will be cast to `dtype`.
sharding: an optional sharding specification for the resulting array.
"""
shape = canonicalize_shape(shape)
if np.shape(fill_value):
Expand All @@ -1222,7 +1225,11 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) ->
weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
fill_value = _convert_element_type(fill_value, dtype, weak_type)
return broadcast(fill_value, shape)
out = broadcast(fill_value, shape)
if sharding is not None:
return array.make_array_from_callback(shape, sharding, lambda idx: out[idx])
return out


def zeros_like_shaped_array(aval: ShapedArray) -> Array:
assert isinstance(aval, ShapedArray)
Expand Down
8 changes: 8 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from jax.interpreters import batching
from jax.interpreters import xla
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src import array
from jax._src import config
from jax._src import dtypes
Expand Down Expand Up @@ -2726,6 +2727,13 @@ def _step(carry, arg):

a, b = jax.lax.scan(_step, 0, jnp.arange(4, dtype=jnp.complex64))

def test_lax_full_sharding(self):
devices = jax.devices()
mesh = Mesh(devices, axis_names=("i"))
sharding = NamedSharding(mesh, P('i', None))
x = lax.full((len(devices),), 1.0, sharding=sharding)
self.assertEqual(x.sharding, sharding)


class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
Expand Down
Loading