Skip to content

Commit

Permalink
Merge pull request #9950 from froystig:wrapped
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 435524235
  • Loading branch information
jax authors committed Mar 18, 2022
2 parents 4b81311 + e04ea89 commit 1f95273
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 96 deletions.
20 changes: 10 additions & 10 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import jax
from jax import core
from jax import linear_util as lu
from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
Expand All @@ -51,7 +52,6 @@
from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
Expand All @@ -74,7 +74,12 @@
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
Expand All @@ -83,11 +88,6 @@
from jax.interpreters import masking
from jax.interpreters import invertible_ad as iad
from jax.interpreters.invertible_ad import custom_ivjp
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.ad_checkpoint import checkpoint_policies

from jax._src.config import (
flags, config, bool_env,
Expand Down Expand Up @@ -236,7 +236,7 @@ def jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> Any:
) -> stages.Wrapped:
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
Expand Down Expand Up @@ -347,7 +347,7 @@ def _python_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> F:
) -> stages.Wrapped:
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
Expand Down Expand Up @@ -399,7 +399,7 @@ def _cpp_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> Any:
) -> stages.Wrapped:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
Expand Down Expand Up @@ -1897,7 +1897,7 @@ def _python_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
Expand Down
166 changes: 91 additions & 75 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Any, Optional, Tuple, Union
from typing_extensions import Protocol

from jax import core
from jax.interpreters import pxla
Expand All @@ -32,75 +33,6 @@
zip, unsafe_zip = util.safe_zip, zip


class Lowered:
"""Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
_no_kwargs: bool

def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.
Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs

def compile(self) -> 'Compiled':
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)

def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")

# TODO(frostig): remove this in favor of `compiler_ir`
def _xla_computation(self):
return self._lowering.hlo()


class Compiled:
"""Compiled representation of a function specialized to types/values.
Expand All @@ -114,7 +46,6 @@ class Compiled:
"_no_kwargs"
]


# The PyTreeDef of the (positional arguments, keyword arguments).
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
Expand Down Expand Up @@ -157,14 +88,14 @@ def __call__(self, *args, **kwargs):
if self._no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())
raise NotImplementedError(
'function was compiled by a transformation that does not support '
"function was compiled by a transformation that does not support "
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_flatten((args, kwargs))
if in_tree != self.in_tree:
# TODO(frostig): provide more info about the source function
# and transformation
raise TypeError(
f'function compiled for {self.in_tree}, called with {in_tree}')
f"function compiled for {self.in_tree}, called with {in_tree}")
try:
out_flat = self._executable.call(*args_flat)
except TypeError:
Expand All @@ -179,9 +110,94 @@ def __call__(self, *args, **kwargs):
for arg in args_flat:
if isinstance(arg, core.Tracer):
raise TypeError(
'Cannot apply JAX transformations to a function lowered and '
'compiled for a particular signature. Detected argument of '
f'Tracer type {type(arg)}.')
"Cannot apply JAX transformations to a function lowered and "
"compiled for a particular signature. Detected argument of "
f"Tracer type {type(arg)}.")
else:
raise
return tree_unflatten(self.out_tree, out_flat)


class Lowered:
"""Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
_no_kwargs: bool

def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.
Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs

def compile(self) -> Compiled:
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)

def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")

# TODO(frostig): remove this in favor of `compiler_ir`
def _xla_computation(self):
return self._lowering.hlo()


class Wrapped(Protocol):
def __call__(self, *args, **kwargs):
"""Executes the wrapped function, lowering and compiling as needed."""

def lower(self, *args, **kwargs) -> Lowered:
"""Lower this function for the given arguments.
A lowered function is staged out of Python and translated to a
compiler's input language, possibly in a backend-dependent
manner. It is ready for compilation but not yet compiled.
Returns:
A ``Lowered`` instance representing the lowering.
"""
20 changes: 10 additions & 10 deletions jax/experimental/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def xmap(fun: Callable,
axis_sizes: Dict[AxisName, int] = {},
axis_resources: Dict[AxisName, ResourceSet] = {},
donate_argnums: Union[int, Sequence[int]] = (),
backend: Optional[str] = None):
backend: Optional[str] = None) -> stages.Wrapped:
"""Assign a positional signature to a program that uses named array axes.
.. warning::
Expand Down Expand Up @@ -644,17 +644,20 @@ def verify_outputs(out_flat, out_tree, params):
f"but the output has rank {out.ndim} (and shape {out.shape})")
return tree_unflatten(out_tree(), out_flat)

def decorate_serial(f):
for loop_params in reversed(anon_serial_loops):
f = serial_loop(*loop_params)(f)
return f

@wraps(fun)
@decorate_serial
def fun_mapped(*args):
tree_map(_check_arg, args)
fun_flat, args_flat, params, _, out_tree = infer_params(*args)
out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
return verify_outputs(out_flat, out_tree, params)

def decorate_serial(f):
for loop_params in reversed(anon_serial_loops):
f = serial_loop(*loop_params)(f)
return f

@decorate_serial
def lower(*args):
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
Expand All @@ -671,10 +674,7 @@ def lower(*args):
computation, in_tree, in_avals, out_tree(), donate_argnums,
no_kwargs=True)

fun_mapped = wraps(fun)(
traceback_util.api_boundary(decorate_serial(fun_mapped)))
fun_mapped.lower = decorate_serial(lower)

fun_mapped.lower = lower
return fun_mapped

def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars,
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def pjit(fun: Callable,
in_axis_resources,
out_axis_resources,
static_argnums: Union[int, Sequence[int]] = (),
donate_argnums: Union[int, Sequence[int]] = ()):
donate_argnums: Union[int, Sequence[int]] = ()) -> stages.Wrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
The returned function has semantics equivalent to those of ``fun``, but is
Expand Down
1 change: 1 addition & 0 deletions jax/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
from jax._src.stages import (
Compiled as Compiled,
Lowered as Lowered,
Wrapped as Wrapped,
)

0 comments on commit 1f95273

Please sign in to comment.