diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index 4b016db05880..cb7a12e5203e 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -24,7 +24,7 @@ def broadcast_arrays(*arrays: Array) -> list[Array]: return jax.numpy.broadcast_arrays(*arrays) -def broadcast_to(x: Array, /, shape: tuple[int]) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: """Broadcasts an array to a specified shape.""" return jax.numpy.broadcast_to(x, shape=shape)