diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index 4b016db05880..04a590579189 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -24,9 +24,9 @@ def broadcast_arrays(*arrays: Array) -> list[Array]: return jax.numpy.broadcast_arrays(*arrays) -def broadcast_to(x: Array, /, shape: tuple[int]) -> Array: """Broadcasts an array to a specified shape.""" return jax.numpy.broadcast_to(x, shape=shape) +def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: def concat(arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0) -> Array: