diff --git a/jax/_src/api.py b/jax/_src/api.py index dd1b9547b881..575dbe6fbe64 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -33,6 +33,7 @@ import types from typing import (Any, Callable, Iterable, NamedTuple, Mapping, Optional, Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable) +from typing_extensions import ParamSpec from warnings import warn import numpy as np @@ -110,6 +111,7 @@ F = TypeVar("F", bound=Callable) T = TypeVar("T") U = TypeVar("U") +P = ParamSpec("P") map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -228,7 +230,7 @@ def _infer_argnums_and_argnames( def jit( - fun: Callable, + fun: Callable[P, T], *, static_argnums: Union[int, Iterable[int], None] = None, static_argnames: Union[str, Iterable[str], None] = None, @@ -236,7 +238,7 @@ def jit( backend: Optional[str] = None, donate_argnums: Union[int, Iterable[int]] = (), inline: bool = False, -) -> stages.Wrapped: +) -> stages.Wrapped[P, T]: """Sets up ``fun`` for just-in-time compilation with XLA. Args: @@ -340,14 +342,14 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums, def _python_jit( - fun: Callable, + fun: Callable[P, T], static_argnums: Union[int, Iterable[int], None] = None, static_argnames: Union[str, Iterable[str], None] = None, device: Optional[xc.Device] = None, backend: Optional[str] = None, donate_argnums: Union[int, Iterable[int]] = (), inline: bool = False, -) -> stages.Wrapped: +) -> stages.Wrapped[P, T]: # The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit. _check_callable(fun) static_argnums, static_argnames = _infer_argnums_and_argnames( @@ -392,14 +394,14 @@ class _FastpathData(NamedTuple): _cpp_jit_cache = jax_jit.CompiledFunctionCache() def _cpp_jit( - fun: Callable, + fun: Callable[P, T], static_argnums: Union[int, Iterable[int], None] = None, static_argnames: Union[str, Iterable[str], None] = None, device: Optional[xc.Device] = None, backend: Optional[str] = None, donate_argnums: Union[int, Iterable[int]] = (), inline: bool = False, -) -> stages.Wrapped: +) -> stages.Wrapped[P, T]: # An implementation of `jit` that tries to do as much as possible in C++. # The goal of this function is to speed up the time it takes to process the # arguments, find the correct C++ executable, start the transfer of arguments @@ -1890,7 +1892,7 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums, def _python_pmap( - fun: Callable, + fun: Callable[P, T], axis_name: Optional[AxisName] = None, *, in_axes=0, @@ -1901,7 +1903,7 @@ def _python_pmap( axis_size: Optional[int] = None, donate_argnums: Union[int, Iterable[int]] = (), global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None, -) -> stages.Wrapped: +) -> stages.Wrapped[P, T]: """The Python only implementation.""" axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 47995e70d1e1..53b794e67fe1 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -13,8 +13,8 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Optional, Sequence, Tuple -from typing_extensions import Protocol +from typing import Any, Optional, Sequence, Tuple, TypeVar +from typing_extensions import Protocol, ParamSpec from jax import core from jax import tree_util @@ -32,6 +32,9 @@ zip, unsafe_zip = util.safe_zip, zip +T_co = TypeVar("T_co", covariant=True) +P_contra = ParamSpec("P_contra", contravariant=True) + @dataclass class ArgInfo: aval: core.ShapedArray @@ -239,12 +242,12 @@ def _xla_computation(self): return self._lowering.hlo() -class Wrapped(Protocol): - def __call__(self, *args, **kwargs): +class Wrapped(Protocol[P_contra, T_co]): + def __call__(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> T_co: """Executes the wrapped function, lowering and compiling as needed.""" raise NotImplementedError - def lower(self, *args, **kwargs) -> Lowered: + def lower(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> Lowered: """Lower this function for the given arguments. A lowered function is staged out of Python and translated to a