Skip to content

Commit

Permalink
Fix typing annotations for @jax.named_call
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 578852649
  • Loading branch information
Conchylicultor authored and jax authors committed Nov 2, 2023
1 parent 1c66ac5 commit 81ac67f
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2835,10 +2835,10 @@ def eval_shape(fun, *args, **kwargs):


def named_call(
fun: Callable[..., Any],
fun: F,
*,
name: str | None = None,
) -> Callable[..., Any]:
) -> F:
"""Adds a user specified name to a function when staging out JAX computations.
When staging out computations for just-in-time compilation to XLA (or other
Expand Down Expand Up @@ -2867,6 +2867,7 @@ def named_call(

return source_info_util.extend_name_stack(name)(fun)


@contextmanager
def named_scope(
name: str,
Expand Down

0 comments on commit 81ac67f

Please sign in to comment.