From 81ac67f38164b7626d733d081a87ff49b235b9d0 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Thu, 2 Nov 2023 07:54:16 -0700 Subject: [PATCH] Fix typing annotations for `@jax.named_call` PiperOrigin-RevId: 578852649 --- jax/_src/api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 6c5e5a5becbd..4dcdff7e4d6a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2835,10 +2835,10 @@ def eval_shape(fun, *args, **kwargs): def named_call( - fun: Callable[..., Any], + fun: F, *, name: str | None = None, - ) -> Callable[..., Any]: +) -> F: """Adds a user specified name to a function when staging out JAX computations. When staging out computations for just-in-time compilation to XLA (or other @@ -2867,6 +2867,7 @@ def named_call( return source_info_util.extend_name_stack(name)(fun) + @contextmanager def named_scope( name: str,