Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce a protocol for compilation wrappers #9950

Merged
merged 1 commit into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import jax
from jax import core
from jax import linear_util as lu
from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
Expand All @@ -51,7 +52,6 @@
from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
Expand All @@ -74,7 +74,12 @@
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
Expand All @@ -83,11 +88,6 @@
from jax.interpreters import masking
from jax.interpreters import invertible_ad as iad
from jax.interpreters.invertible_ad import custom_ivjp
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.ad_checkpoint import checkpoint_policies

from jax._src.config import (
flags, config, bool_env,
Expand Down Expand Up @@ -236,7 +236,7 @@ def jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> Any:
) -> stages.Wrapped:
"""Sets up ``fun`` for just-in-time compilation with XLA.

Args:
Expand Down Expand Up @@ -347,7 +347,7 @@ def _python_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> F:
) -> stages.Wrapped:
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
Expand Down Expand Up @@ -399,7 +399,7 @@ def _cpp_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> Any:
) -> stages.Wrapped:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
Expand Down Expand Up @@ -1897,7 +1897,7 @@ def _python_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
Expand Down
166 changes: 91 additions & 75 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Any, Optional, Tuple, Union
from typing_extensions import Protocol

from jax import core
from jax.interpreters import pxla
Expand All @@ -32,75 +33,6 @@
zip, unsafe_zip = util.safe_zip, zip


class Lowered:
"""Lowering of a function specialized to argument types and values.

A lowering is a computation ready for compilation. This class
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
_no_kwargs: bool

def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.

Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs

def compile(self) -> 'Compiled':
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)

def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")

# TODO(frostig): remove this in favor of `compiler_ir`
def _xla_computation(self):
return self._lowering.hlo()


class Compiled:
"""Compiled representation of a function specialized to types/values.

Expand All @@ -114,7 +46,6 @@ class Compiled:
"_no_kwargs"
]


# The PyTreeDef of the (positional arguments, keyword arguments).
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
Expand Down Expand Up @@ -157,14 +88,14 @@ def __call__(self, *args, **kwargs):
if self._no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())
raise NotImplementedError(
'function was compiled by a transformation that does not support '
"function was compiled by a transformation that does not support "
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_flatten((args, kwargs))
if in_tree != self.in_tree:
# TODO(frostig): provide more info about the source function
# and transformation
raise TypeError(
f'function compiled for {self.in_tree}, called with {in_tree}')
f"function compiled for {self.in_tree}, called with {in_tree}")
try:
out_flat = self._executable.call(*args_flat)
except TypeError:
Expand All @@ -179,9 +110,94 @@ def __call__(self, *args, **kwargs):
for arg in args_flat:
if isinstance(arg, core.Tracer):
raise TypeError(
'Cannot apply JAX transformations to a function lowered and '
'compiled for a particular signature. Detected argument of '
f'Tracer type {type(arg)}.')
"Cannot apply JAX transformations to a function lowered and "
"compiled for a particular signature. Detected argument of "
f"Tracer type {type(arg)}.")
else:
raise
return tree_unflatten(self.out_tree, out_flat)


class Lowered:
"""Lowering of a function specialized to argument types and values.

A lowering is a computation ready for compilation. This class
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
_no_kwargs: bool

def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.

Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs

def compile(self) -> Compiled:
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)

def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")

# TODO(frostig): remove this in favor of `compiler_ir`
def _xla_computation(self):
return self._lowering.hlo()


class Wrapped(Protocol):
froystig marked this conversation as resolved.
Show resolved Hide resolved
def __call__(self, *args, **kwargs):
"""Executes the wrapped function, lowering and compiling as needed."""

def lower(self, *args, **kwargs) -> Lowered:
"""Lower this function for the given arguments.

A lowered function is staged out of Python and translated to a
compiler's input language, possibly in a backend-dependent
manner. It is ready for compilation but not yet compiled.

Returns:
A ``Lowered`` instance representing the lowering.
"""
20 changes: 10 additions & 10 deletions jax/experimental/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def xmap(fun: Callable,
axis_sizes: Dict[AxisName, int] = {},
axis_resources: Dict[AxisName, ResourceSet] = {},
donate_argnums: Union[int, Sequence[int]] = (),
backend: Optional[str] = None):
backend: Optional[str] = None) -> stages.Wrapped:
"""Assign a positional signature to a program that uses named array axes.

.. warning::
Expand Down Expand Up @@ -644,17 +644,20 @@ def verify_outputs(out_flat, out_tree, params):
f"but the output has rank {out.ndim} (and shape {out.shape})")
return tree_unflatten(out_tree(), out_flat)

def decorate_serial(f):
for loop_params in reversed(anon_serial_loops):
f = serial_loop(*loop_params)(f)
return f

@wraps(fun)
@decorate_serial
def fun_mapped(*args):
tree_map(_check_arg, args)
fun_flat, args_flat, params, _, out_tree = infer_params(*args)
out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
return verify_outputs(out_flat, out_tree, params)

def decorate_serial(f):
for loop_params in reversed(anon_serial_loops):
f = serial_loop(*loop_params)(f)
return f

@decorate_serial
def lower(*args):
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
Expand All @@ -671,10 +674,7 @@ def lower(*args):
computation, in_tree, in_avals, out_tree(), donate_argnums,
no_kwargs=True)

fun_mapped = wraps(fun)(
traceback_util.api_boundary(decorate_serial(fun_mapped)))
fun_mapped.lower = decorate_serial(lower)

fun_mapped.lower = lower
return fun_mapped

def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars,
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def pjit(fun: Callable,
in_axis_resources,
out_axis_resources,
static_argnums: Union[int, Sequence[int]] = (),
donate_argnums: Union[int, Sequence[int]] = ()):
donate_argnums: Union[int, Sequence[int]] = ()) -> stages.Wrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.

The returned function has semantics equivalent to those of ``fun``, but is
Expand Down
1 change: 1 addition & 0 deletions jax/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
from jax._src.stages import (
Compiled as Compiled,
Lowered as Lowered,
Wrapped as Wrapped,
)