Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generic protocol for compilation wrappers #9999

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import types
from typing import (Any, Callable, Iterable, NamedTuple, Mapping, Optional,
Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable)
from typing_extensions import ParamSpec
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -110,6 +111,7 @@
F = TypeVar("F", bound=Callable)
T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -228,15 +230,15 @@ def _infer_argnums_and_argnames(


def jit(
fun: Callable,
fun: Callable[P, T],
*,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
) -> stages.Wrapped[P, T]:
"""Sets up ``fun`` for just-in-time compilation with XLA.

Args:
Expand Down Expand Up @@ -340,14 +342,14 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,


def _python_jit(
fun: Callable,
fun: Callable[P, T],
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
) -> stages.Wrapped[P, T]:
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
Expand Down Expand Up @@ -392,14 +394,14 @@ class _FastpathData(NamedTuple):
_cpp_jit_cache = jax_jit.CompiledFunctionCache()

def _cpp_jit(
fun: Callable,
fun: Callable[P, T],
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
) -> stages.Wrapped[P, T]:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
Expand Down Expand Up @@ -1890,7 +1892,7 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,


def _python_pmap(
fun: Callable,
fun: Callable[P, T],
axis_name: Optional[AxisName] = None,
*,
in_axes=0,
Expand All @@ -1901,7 +1903,7 @@ def _python_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
) -> stages.Wrapped[P, T]:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
Expand Down
13 changes: 8 additions & 5 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Optional, Sequence, Tuple
from typing_extensions import Protocol
from typing import Any, Optional, Sequence, Tuple, TypeVar
from typing_extensions import Protocol, ParamSpec

from jax import core
from jax import tree_util
Expand All @@ -32,6 +32,9 @@
zip, unsafe_zip = util.safe_zip, zip


T_co = TypeVar("T_co", covariant=True)
P_contra = ParamSpec("P_contra", contravariant=True)

@dataclass
class ArgInfo:
aval: core.ShapedArray
Expand Down Expand Up @@ -239,12 +242,12 @@ def _xla_computation(self):
return self._lowering.hlo()


class Wrapped(Protocol):
def __call__(self, *args, **kwargs):
class Wrapped(Protocol[P_contra, T_co]):
def __call__(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> T_co:
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

def lower(self, *args, **kwargs) -> Lowered:
def lower(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> Lowered:
"""Lower this function for the given arguments.

A lowered function is staged out of Python and translated to a
Expand Down