From 4993e72152b36c3904ddaec2bdfea640af6333f6 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 24 Jan 2024 12:04:06 -0500 Subject: [PATCH] Fix type annotation for array_api.broadcast_to Signed-off-by: nstarman --- jax/experimental/array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index 4b016db05880..04a590579189 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -24,9 +24,9 @@ def broadcast_arrays(*arrays: Array) -> list[Array]: return jax.numpy.broadcast_arrays(*arrays) -def broadcast_to(x: Array, /, shape: tuple[int]) -> Array: """Broadcasts an array to a specified shape.""" return jax.numpy.broadcast_to(x, shape=shape) +def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: def concat(arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0) -> Array: