Skip to content

Commit

Permalink
Add lower to specialize making it a true Stage.
Browse files Browse the repository at this point in the history
So now users can do:

```
specialized = jax.jit(f).specialize(*args)
print(specialized.jaxpr, specialized.out_info)

lowered = specialized.lower()

compiled = lowered.compile()
```
PiperOrigin-RevId: 640737396
  • Loading branch information
yashk2810 authored and jax authors committed Jun 6, 2024
1 parent d117305 commit 55d0f5e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
24 changes: 16 additions & 8 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class PjitInfo(NamedTuple):


def _python_pjit_helper(jit_info, *args, **kwargs):
(args_flat, _, params, _, out_tree, _, arg_names,
(args_flat, params, _, out_tree, _, arg_names,
attrs_tracked) = _infer_params(jit_info, args, kwargs)

for arg in args_flat:
Expand Down Expand Up @@ -480,7 +480,7 @@ def lower(*args, **kwargs):
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())

(args_flat, flat_global_in_avals, params, in_tree, out_tree,
(args_flat, params, in_tree, out_tree,
donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs)
try:
lowering = _resolve_and_lower(
Expand All @@ -496,13 +496,14 @@ def lower(*args, **kwargs):
raise ValueError(msg) from None

donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
jaxpr = params["jaxpr"]
return stages.Lowered.from_flat_info(
lowering, in_tree, flat_global_in_avals, donate_argnums,
out_tree, fun_name=params["name"], jaxpr=params["jaxpr"])
lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree,
fun_name=params["name"], jaxpr=jaxpr)

@api_boundary
def eval_shape(*args, **kwargs):
_, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
_, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
out_s = [None if is_unspecified(s) else s for s in params['out_shardings']]
# TODO(yashkatariya): Add `Layout` to SDS.
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
Expand All @@ -511,12 +512,19 @@ def eval_shape(*args, **kwargs):

@api_boundary
def specialize(*args, **kwargs) -> stages.Specialized:
_, _, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())

args_flat, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
jit_info, args, kwargs)

donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
jaxpr = params['jaxpr']
args_info = stages.make_args_info(in_tree, jaxpr.in_avals, donate_argnums)
return stages.Specialized(jaxpr, args_info, out_tree)
lower_callable = partial(_resolve_and_lower, args_flat, **params,
lowering_parameters=lowering_parameters)
return stages.Specialized(jaxpr, args_info, params["name"], out_tree,
lower_callable)

wrapped = _cpp_pjit(jit_info)
wrapped.lower = lower
Expand Down Expand Up @@ -667,7 +675,7 @@ def _infer_params(jit_info, args, kwargs):
keep_unused=keep_unused,
inline=inline,
)
return (consts + args_flat, in_type, params, in_tree, out_tree(),
return (consts + args_flat, params, in_tree, out_tree(),
donated_invars, dbg.arg_names if dbg else None, attrs_tracked)

def _extract_implicit_args(
Expand Down
11 changes: 9 additions & 2 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,18 +426,25 @@ class CompiledCallParams(NamedTuple):


class Specialized(Stage):
__slots__ = ["jaxpr", "args_info", "_out_tree"]
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable"]

def __init__(self, jaxpr: core.ClosedJaxpr, args_info, out_tree):
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
lower_callable):
self.jaxpr = jaxpr
self.args_info = args_info
self.fun_name = fun_name
self._out_tree = out_tree
self._lower_callable = lower_callable

@property
def out_info(self):
return self._out_tree.unflatten(
[OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])

def lower(self):
lowering = self._lower_callable()
return Lowered(lowering, self.args_info, self._out_tree)


class Compiled(Stage):
"""Compiled representation of a function specialized to types/values.
Expand Down
19 changes: 19 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4191,6 +4191,25 @@ def f(x):
self.assertLen(specialized.in_avals[0], 1)
self.assertLen(specialized.in_avals[1], 0) # empty kwarg

def test_jit_specialize_lower_and_compile(self):
def f(x):
return x * 2

lowered = jax.jit(f).specialize(jnp.arange(8)).lower()
self.assertEqual(lowered.args_info[0][0].shape, (8,))

compiled = lowered.compile()
out = compiled(jnp.arange(8))
self.assertArraysEqual(out, np.arange(8) * 2)

# fast-forward
lowered2 = jax.jit(f).lower(jnp.arange(8))
self.assertEqual(lowered2.args_info[0][0].shape, (8,))

compiled2 = lowered2.compile()
out2 = compiled2(jnp.arange(8))
self.assertArraysEqual(out2, np.arange(8) * 2)


def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
Expand Down

0 comments on commit 55d0f5e

Please sign in to comment.