From e04ea89b06e82ec58fda1ba92f64bb19165a88c5 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 15 Mar 2022 13:51:51 -0700 Subject: [PATCH] introduce a protocol for compilation wrappers --- jax/_src/api.py | 20 ++--- jax/_src/stages.py | 166 +++++++++++++++++++++------------------ jax/experimental/maps.py | 20 ++--- jax/experimental/pjit.py | 2 +- jax/stages.py | 1 + 5 files changed, 113 insertions(+), 96 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 03b28950aeef..694345e11963 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -41,6 +41,7 @@ import jax from jax import core from jax import linear_util as lu +from jax import stages from jax.core import eval_jaxpr from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, @@ -51,7 +52,6 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import source_info_util -from jax._src import stages from jax._src import traceback_util from jax._src.api_util import ( flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, @@ -74,7 +74,12 @@ local_devices, process_index, process_count, host_id, host_ids, host_count, default_backend) +from jax.ad_checkpoint import checkpoint_policies from jax.core import ShapedArray, raise_to_shaped +from jax.custom_batching import custom_vmap +from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp, + custom_vjp, linear_call) +from jax.custom_transpose import custom_transpose from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax.interpreters import pxla @@ -83,11 +88,6 @@ from jax.interpreters import masking from jax.interpreters import invertible_ad as iad from jax.interpreters.invertible_ad import custom_ivjp -from jax.custom_batching import custom_vmap -from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp, - custom_vjp, linear_call) -from jax.custom_transpose import custom_transpose -from jax.ad_checkpoint import checkpoint_policies from jax._src.config import ( flags, config, bool_env, @@ -236,7 +236,7 @@ def jit( backend: Optional[str] = None, donate_argnums: Union[int, Iterable[int]] = (), inline: bool = False, -) -> Any: +) -> stages.Wrapped: """Sets up ``fun`` for just-in-time compilation with XLA. Args: @@ -347,7 +347,7 @@ def _python_jit( backend: Optional[str] = None, donate_argnums: Union[int, Iterable[int]] = (), inline: bool = False, -) -> F: +) -> stages.Wrapped: # The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit. _check_callable(fun) static_argnums, static_argnames = _infer_argnums_and_argnames( @@ -399,7 +399,7 @@ def _cpp_jit( backend: Optional[str] = None, donate_argnums: Union[int, Iterable[int]] = (), inline: bool = False, -) -> Any: +) -> stages.Wrapped: # An implementation of `jit` that tries to do as much as possible in C++. # The goal of this function is to speed up the time it takes to process the # arguments, find the correct C++ executable, start the transfer of arguments @@ -1897,7 +1897,7 @@ def _python_pmap( axis_size: Optional[int] = None, donate_argnums: Union[int, Iterable[int]] = (), global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None, -) -> Any: +) -> stages.Wrapped: """The Python only implementation.""" axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, diff --git a/jax/_src/stages.py b/jax/_src/stages.py index eafff4aba74d..7f955f4a812c 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any, Optional, Tuple, Union +from typing_extensions import Protocol from jax import core from jax.interpreters import pxla @@ -32,75 +33,6 @@ zip, unsafe_zip = util.safe_zip, zip -class Lowered: - """Lowering of a function specialized to argument types and values. - - A lowering is a computation ready for compilation. This class - carries a lowering together with the remaining information needed to - later compile and execute it. It also provides a common API for - querying properties of lowered computations across JAX's various - lowering paths (``jit``, ``pmap``, etc.). - """ - __slots__ = [ - "in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering", - "_no_kwargs" - ] - - # The PyTreeDef of the (positional arguments, keyword arguments). - # - # To get the individual PyTreeDef for the positional an keyword arguments, - # use `in_tree.children() which will return you a sequence of 2 PyTreeDef. - in_tree: PyTreeDef - # The nested input tree of `ShapedArray` abstract values of (args, kwargs). - in_avals: Any - out_tree: PyTreeDef - donate_argnums: Tuple[int] - _lowering: Union[dispatch.XlaComputation, - pxla.MeshComputation, - pxla.PmapComputation] - _no_kwargs: bool - - def __init__(self, - lowering, - in_tree: PyTreeDef, - in_avals, - out_tree: PyTreeDef, - donate_argnums: Tuple[int], - no_kwargs: bool = False): - """Initializer. - - Args: - in_tree: The `PyTreeDef` of (args, kwargs). - out_tree: The `PyTreeDef` of the outputs. - no_kwargs: If `True` the transformation, and the `Compiled` returned from - this object will not support keyword arguments (an error will be raised - if some are provided). - """ - self._lowering = lowering - self.in_tree = in_tree - self.in_avals = in_avals - self.out_tree = out_tree - self.donate_argnums = donate_argnums - self._no_kwargs = no_kwargs - - def compile(self) -> 'Compiled': - return Compiled( - self._lowering.compile(), self.in_tree, self.in_avals, - self.out_tree, self.donate_argnums, self._no_kwargs) - - def compiler_ir(self, dialect: Optional[str] = None): - if dialect is None or dialect == "mhlo": - return self._lowering.mhlo() - elif dialect == "hlo": - return self._lowering.hlo() - else: - raise ValueError(f"Unknown dialect {dialect}") - - # TODO(frostig): remove this in favor of `compiler_ir` - def _xla_computation(self): - return self._lowering.hlo() - - class Compiled: """Compiled representation of a function specialized to types/values. @@ -114,7 +46,6 @@ class Compiled: "_no_kwargs" ] - # The PyTreeDef of the (positional arguments, keyword arguments). in_tree: PyTreeDef # The nested input tree of `ShapedArray` abstract values of (args, kwargs). @@ -157,14 +88,14 @@ def __call__(self, *args, **kwargs): if self._no_kwargs and kwargs: kws = ', '.join(kwargs.keys()) raise NotImplementedError( - 'function was compiled by a transformation that does not support ' + "function was compiled by a transformation that does not support " f"keyword arguments, but called with keyword arguments: {kws}") args_flat, in_tree = tree_flatten((args, kwargs)) if in_tree != self.in_tree: # TODO(frostig): provide more info about the source function # and transformation raise TypeError( - f'function compiled for {self.in_tree}, called with {in_tree}') + f"function compiled for {self.in_tree}, called with {in_tree}") try: out_flat = self._executable.call(*args_flat) except TypeError: @@ -179,9 +110,94 @@ def __call__(self, *args, **kwargs): for arg in args_flat: if isinstance(arg, core.Tracer): raise TypeError( - 'Cannot apply JAX transformations to a function lowered and ' - 'compiled for a particular signature. Detected argument of ' - f'Tracer type {type(arg)}.') + "Cannot apply JAX transformations to a function lowered and " + "compiled for a particular signature. Detected argument of " + f"Tracer type {type(arg)}.") else: raise return tree_unflatten(self.out_tree, out_flat) + + +class Lowered: + """Lowering of a function specialized to argument types and values. + + A lowering is a computation ready for compilation. This class + carries a lowering together with the remaining information needed to + later compile and execute it. It also provides a common API for + querying properties of lowered computations across JAX's various + lowering paths (``jit``, ``pmap``, etc.). + """ + __slots__ = [ + "in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering", + "_no_kwargs" + ] + + # The PyTreeDef of the (positional arguments, keyword arguments). + # + # To get the individual PyTreeDef for the positional an keyword arguments, + # use `in_tree.children() which will return you a sequence of 2 PyTreeDef. + in_tree: PyTreeDef + # The nested input tree of `ShapedArray` abstract values of (args, kwargs). + in_avals: Any + out_tree: PyTreeDef + donate_argnums: Tuple[int] + _lowering: Union[dispatch.XlaComputation, + pxla.MeshComputation, + pxla.PmapComputation] + _no_kwargs: bool + + def __init__(self, + lowering, + in_tree: PyTreeDef, + in_avals, + out_tree: PyTreeDef, + donate_argnums: Tuple[int], + no_kwargs: bool = False): + """Initializer. + + Args: + in_tree: The `PyTreeDef` of (args, kwargs). + out_tree: The `PyTreeDef` of the outputs. + no_kwargs: If `True` the transformation, and the `Compiled` returned from + this object will not support keyword arguments (an error will be raised + if some are provided). + """ + self._lowering = lowering + self.in_tree = in_tree + self.in_avals = in_avals + self.out_tree = out_tree + self.donate_argnums = donate_argnums + self._no_kwargs = no_kwargs + + def compile(self) -> Compiled: + return Compiled( + self._lowering.compile(), self.in_tree, self.in_avals, + self.out_tree, self.donate_argnums, self._no_kwargs) + + def compiler_ir(self, dialect: Optional[str] = None): + if dialect is None or dialect == "mhlo": + return self._lowering.mhlo() + elif dialect == "hlo": + return self._lowering.hlo() + else: + raise ValueError(f"Unknown dialect {dialect}") + + # TODO(frostig): remove this in favor of `compiler_ir` + def _xla_computation(self): + return self._lowering.hlo() + + +class Wrapped(Protocol): + def __call__(self, *args, **kwargs): + """Executes the wrapped function, lowering and compiling as needed.""" + + def lower(self, *args, **kwargs) -> Lowered: + """Lower this function for the given arguments. + + A lowered function is staged out of Python and translated to a + compiler's input language, possibly in a backend-dependent + manner. It is ready for compilation but not yet compiled. + + Returns: + A ``Lowered`` instance representing the lowering. + """ diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 0e7ccd9348cf..098a4c498177 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -328,7 +328,7 @@ def xmap(fun: Callable, axis_sizes: Dict[AxisName, int] = {}, axis_resources: Dict[AxisName, ResourceSet] = {}, donate_argnums: Union[int, Sequence[int]] = (), - backend: Optional[str] = None): + backend: Optional[str] = None) -> stages.Wrapped: """Assign a positional signature to a program that uses named array axes. .. warning:: @@ -644,17 +644,20 @@ def verify_outputs(out_flat, out_tree, params): f"but the output has rank {out.ndim} (and shape {out.shape})") return tree_unflatten(out_tree(), out_flat) + def decorate_serial(f): + for loop_params in reversed(anon_serial_loops): + f = serial_loop(*loop_params)(f) + return f + + @wraps(fun) + @decorate_serial def fun_mapped(*args): tree_map(_check_arg, args) fun_flat, args_flat, params, _, out_tree = infer_params(*args) out_flat = xmap_p.bind(fun_flat, *args_flat, **params) return verify_outputs(out_flat, out_tree, params) - def decorate_serial(f): - for loop_params in reversed(anon_serial_loops): - f = serial_loop(*loop_params)(f) - return f - + @decorate_serial def lower(*args): fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args) avals_flat = [shaped_abstractify(arg) for arg in args_flat] @@ -671,10 +674,7 @@ def lower(*args): computation, in_tree, in_avals, out_tree(), donate_argnums, no_kwargs=True) - fun_mapped = wraps(fun)( - traceback_util.api_boundary(decorate_serial(fun_mapped))) - fun_mapped.lower = decorate_serial(lower) - + fun_mapped.lower = lower return fun_mapped def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars, diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index ab56d7851727..d65aed262e37 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -65,7 +65,7 @@ def pjit(fun: Callable, in_axis_resources, out_axis_resources, static_argnums: Union[int, Sequence[int]] = (), - donate_argnums: Union[int, Sequence[int]] = ()): + donate_argnums: Union[int, Sequence[int]] = ()) -> stages.Wrapped: """Makes ``fun`` compiled and automatically partitioned across multiple devices. The returned function has semantics equivalent to those of ``fun``, but is diff --git a/jax/stages.py b/jax/stages.py index 91ffa9565e1e..d93cc6c84e31 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -16,4 +16,5 @@ from jax._src.stages import ( Compiled as Compiled, Lowered as Lowered, + Wrapped as Wrapped, )