add core.closed_call_p #10711

Merged
merged 1 commit into from
May 14, 2022

@mattjj mattjj commented May 14, 2022

This PR adds a variant of core.call_p called core.closed_call_p. The only difference is that in its 'jaxpr form' its call_jaxpr parameter is a core.ClosedJaxpr rather than a core.Jaxpr.

Some background:

  • core.call_p is the most vanilla call primitive possible: unlike, say, xla_call_p (the primitive underlying jax.jit), its impl rule isn't to compile an XLA computation, but instead it's just to interpret its jaxpr (staying in Python, using core.eval_jaxpr). Correspondingly, it doesn't need to raise the abstraction level of its arguments. It's basically a model for other "final-style" call primitives, each of which is interesting in precisely how it deviates from core.call_p (e.g. xla_call_p's impl rule stages out for compilation; remat_call_p has a special partial evaluation rule; custom_jvp_call_p has a special JVP rule; etc). Historically it was the first call primitive we introduced, just to test the system; core.call_p is not really used anywhere.

  • core.ClosedJaxpr is a data type which would be better named as PartiallyAppliedJaxpr. When we form jaxprs, they usually get paired with "constants" (e.g. trace_to_jaxpr_nounits and trace_to_jaxpr_dynamic output a list of constants), which are values that are not arguments and that we don't want to turn into literals (e.g. because we want to de-duplicate them, or even just avoid inlining them in pretty-prints). In some cases, these "constants" can be core.Tracers, like when we form the jaxprs for jax.lax.scan and the body function closes over some Tracer; when that's possible, because Tracers have to be handled with core.Primitive.bind, we typically just convert them to arguments (via pe.convert_constvars_jaxpr). But in other cases the constants that come out can't be Tracers (e.g. in the JVP rule of an initial-style primitive, when we run ad.jvp_jaxpr, we can get new constants out which can't be Tracers and must be raw array values). That's when core.ClosedJaxpr comes in handy: it lets us pair a jaxpr with some array constants so that the caller, e.g. a JVP rule for an initial-style higher-order primitive, doesn't need to deal with handling new constant values and their input binders. In other words, primitives which are parameterized by ClosedJaxprs can have simpler rules, especially jaxpr-to-jaxpr rules, since those rules don't need to worry about handling new constants/binders introduced by the rule.
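
To make the "partially applied" reading concrete, here is a toy model in plain Python. It does not use JAX at all: `ToyJaxpr`, `ToyClosedJaxpr`, `eval_toy_jaxpr`, `call_closed`, and `convert_constvars` are hypothetical stand-ins that only mirror the roles described above for core.Jaxpr, core.ClosedJaxpr, core.eval_jaxpr, and pe.convert_constvars_jaxpr. Treat it as a sketch of the idea, not of the real data structures.

```python
from dataclasses import dataclass
import operator

# Toy stand-ins only; none of these names are JAX APIs.
@dataclass(frozen=True)
class ToyEqn:
    op: str          # "add" or "mul"
    invars: tuple    # names of input variables
    outvar: str      # name of the output variable

@dataclass(frozen=True)
class ToyJaxpr:
    constvars: tuple  # binders for constants (values that are not arguments)
    invars: tuple     # binders for arguments
    eqns: tuple
    outvars: tuple

@dataclass(frozen=True)
class ToyClosedJaxpr:
    jaxpr: ToyJaxpr
    consts: tuple     # values for jaxpr.constvars, already supplied

OPS = {"add": operator.add, "mul": operator.mul}

def eval_toy_jaxpr(jaxpr, consts, *args):
    """Interpret the jaxpr in Python, one equation at a time."""
    env = dict(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))
    for eqn in jaxpr.eqns:
        env[eqn.outvar] = OPS[eqn.op](*(env[v] for v in eqn.invars))
    return [env[v] for v in jaxpr.outvars]

def call_closed(closed, *args):
    """A closed jaxpr needs only its arguments; the consts travel with it."""
    return eval_toy_jaxpr(closed.jaxpr, closed.consts, *args)

def convert_constvars(jaxpr):
    """Turn constant binders into leading argument binders."""
    return ToyJaxpr((), jaxpr.constvars + jaxpr.invars, jaxpr.eqns, jaxpr.outvars)

# f(x) = (x * c) + x, where c is a constant rather than an argument.
jaxpr = ToyJaxpr(
    constvars=("c",), invars=("x",),
    eqns=(ToyEqn("mul", ("x", "c"), "t"), ToyEqn("add", ("t", "x"), "y")),
    outvars=("y",))
closed = ToyClosedJaxpr(jaxpr, consts=(3.0,))

# The caller of the closed form never sees c.
out_closed = call_closed(closed, 2.0)
# After conversion, the caller must thread c through as an explicit argument.
out_open = eval_toy_jaxpr(convert_constvars(jaxpr), (), 3.0, 2.0)
```

Both calls compute the same value; the difference is only who is responsible for supplying `c`.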

On that last point, when working on #10576 we ran into a situation where

  • the current signature for "custom-policy partial eval rules" didn't allow a custom partial evaluation rule to introduce new constants (because such rules just get to output a pair of Optional[JaxprEqn]s and have no output for "new constants for the caller to handle appropriately");

  • but to perform an optimization, namely hoisting loop-invariant residual computations out of a scan body, we might need such a rule to introduce multiple equations as well as new constants.

To proceed, there were at least two options:

  1. make the signature for custom-policy partial evaluation rules even more complex (to support outputting multiple equations, new variable names being introduced, new constants, etc)

  2. just use a call primitive to handle the "multiple equations with new variables" problem, and as long as it was a call primitive with a ClosedJaxpr it would handle the constants problem too.

I chose the second approach, which led to this PR.
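
A toy illustration of why option 2 works, again in plain Python rather than JAX. Everything here (`ToyEqn`, `ToyJaxpr`, `ToyClosedJaxpr`, `eval_eqns`, `as_single_eqn`, and the `"closed_call"` op name) is hypothetical and only models the shape of the argument: a rule that is allowed to return a single equation can still emit several equations and a fresh constant by packaging them inside one call equation whose parameter is a closed jaxpr.

```python
from dataclasses import dataclass, field
import operator

# Toy stand-ins only; none of these names are JAX APIs.
@dataclass(frozen=True)
class ToyEqn:
    op: str
    invars: tuple
    outvars: tuple
    params: dict = field(default_factory=dict)

@dataclass(frozen=True)
class ToyJaxpr:
    constvars: tuple
    invars: tuple
    eqns: tuple
    outvars: tuple

@dataclass(frozen=True)
class ToyClosedJaxpr:
    jaxpr: ToyJaxpr
    consts: tuple

OPS = {"add": operator.add, "mul": operator.mul}

def eval_eqns(eqns, env):
    """Evaluate equations in order, updating and returning env."""
    for eqn in eqns:
        args = [env[v] for v in eqn.invars]
        if eqn.op == "closed_call":
            closed = eqn.params["call_jaxpr"]
            # The callee gets its own environment, seeded with its consts.
            inner = dict(zip(closed.jaxpr.constvars, closed.consts))
            inner.update(zip(closed.jaxpr.invars, args))
            eval_eqns(closed.jaxpr.eqns, inner)
            outs = [inner[v] for v in closed.jaxpr.outvars]
        else:
            outs = [OPS[eqn.op](*args)]
        env.update(zip(eqn.outvars, outs))
    return env

def as_single_eqn(eqns, invars, outvars, constvars, consts):
    """Package several equations plus new constants as ONE equation."""
    body = ToyJaxpr(tuple(constvars), tuple(invars), tuple(eqns), tuple(outvars))
    closed = ToyClosedJaxpr(body, tuple(consts))
    return ToyEqn("closed_call", tuple(invars), tuple(outvars),
                  {"call_jaxpr": closed})

# A rule wants to emit two equations and a fresh constant k.
new_eqns = [ToyEqn("mul", ("x", "k"), ("t",)), ToyEqn("add", ("t", "x"), ("y",))]
eqn = as_single_eqn(new_eqns, invars=["x"], outvars=["y"],
                    constvars=["k"], consts=[10.0])

# The caller evaluates one equation and never handles k or t itself.
env = eval_eqns([eqn], {"x": 2.0})
```

The outer environment ends up with only `x` and `y`; the intermediate `t` and the constant `k` stay inside the call.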

For simplicity, we could delete core.call_p in favor of this core.closed_call_p; after all, the former is not used at all. Going further, we might want to make all higher-order primitives (i.e. even the final-style ones, not just the initial-style ones as at present) take ClosedJaxprs rather than Jaxprs; further still, at that point we could de-duplicate Jaxpr and ClosedJaxpr so that we only have one such type. Those simplifications sound reasonable, but they're out of scope for this PR. Here I just want to land a change for enabling the new remat implementation with scan inside!

Finally, some notes on the changes here. Final-style primitives (like the new closed_call_p) have two forms, with different parameters: the 'bind form', used during tracing, which takes a Python callable as a parameter representing the function to be called (really a linear_util.WrappedFun), and the 'jaxpr form', which appears in a jaxpr and takes a Jaxpr (or, after this PR, alternatively a ClosedJaxpr) as a parameter. Since we're introducing a primitive which is like core.call_p except that it takes a ClosedJaxpr parameter, we need to

  • update places where the bind-form primitive is converted to the jaxpr-form primitive (i.e. JaxprTrace.process_call and DynamicJaxprTrace.process_call in partial_eval.py, both of which can be handled by using the existing "call param updater" hook) to actually produce a ClosedJaxpr parameter;

  • update places where the jaxpr-form is converted to the bind-form (namely ClosedCallPrimitive.get_bind_params in core.py);

  • update rules which consume the jaxpr-form to handle the ClosedJaxpr parameter (namely the MLIR lowering rule in mlir.py, the transpose rule in ad.py, the typecheck rule in core.py, the DCE rule in partial_eval.py, and (once it exists for any calls) the forwarding rule in partial_eval.py); note that we do not need to update rules which consume the bind form (e.g. JVPTrace.process_call or BatchTrace.process_call) since the bind forms of call_p and closed_call_p are identical;

  • update core_test.py to cover the new call primitive.

Only the second-to-last bullet seems burdensome. That would be mitigated by moving to make all call primitives take ClosedJaxpr parameters, which I think was already a good idea. But again that's out of scope!
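
As a rough picture of the jaxpr-form to bind-form direction mentioned in the second bullet, the sketch below turns a parameter dict holding a closed jaxpr back into a plain Python callable plus the remaining parameters. It is a toy in plain Python: `ToyJaxpr`, `ToyClosedJaxpr`, `eval_toy_jaxpr`, and this `get_bind_params` function are hypothetical and are not the code in core.py; the `"name"` parameter is likewise made up for illustration.

```python
from dataclasses import dataclass
from functools import partial

# Toy stand-ins only; none of these names are JAX APIs.
@dataclass(frozen=True)
class ToyJaxpr:
    constvars: tuple
    invars: tuple
    eqns: tuple      # each eqn: (python_callable, input names, output name)
    outvars: tuple

@dataclass(frozen=True)
class ToyClosedJaxpr:
    jaxpr: ToyJaxpr
    consts: tuple

def eval_toy_jaxpr(jaxpr, consts, *args):
    """Interpret the toy jaxpr in Python."""
    env = dict(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))
    for fn, ins, out in jaxpr.eqns:
        env[out] = fn(*(env[v] for v in ins))
    return [env[v] for v in jaxpr.outvars]

def get_bind_params(params):
    """Jaxpr-form params -> (callable, remaining params) for the bind form."""
    new_params = dict(params)                 # leave the caller's dict intact
    closed = new_params.pop("call_jaxpr")
    # Because the jaxpr is closed, the consts can be baked into the callable.
    fun = partial(eval_toy_jaxpr, closed.jaxpr, closed.consts)
    return fun, new_params

# y = x * c, with c stored alongside the jaxpr.
body = ToyJaxpr(("c",), ("x",), ((lambda x, c: x * c, ("x", "c"), "y"),), ("y",))
jaxpr_form = {"call_jaxpr": ToyClosedJaxpr(body, (4.0,)), "name": "toy_call"}

fun, bind_params = get_bind_params(jaxpr_form)
result = fun(2.5)
```

The callable takes only the call's arguments, which is what makes the two forms interchangeable from the point of view of rules that consume the bind form.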

@mattjj mattjj requested a review from sharadmv May 14, 2022 18:07

@sharadmv sharadmv left a comment

LGTM. Thanks for the detailed PR description.

@@ -986,13 +986,13 @@ def f_lowered(ctx, *args, **params):

def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, tokens_in, *args):
  if isinstance(call_jaxpr, core.Jaxpr):

Collaborator

To confirm, this check exists because not all call primitives use closed jaxprs yet. When they do, we can delete this.
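
One way a rule could accept both parameter types is to normalize at the top and work with a single type afterwards. The sketch below shows that pattern with toy classes in plain Python; `ToyJaxpr`, `ToyClosedJaxpr`, and `as_closed` are hypothetical names, and the rest of the real lowering rule is not shown in this excerpt, so this is only an illustration of the `isinstance` dispatch being discussed.

```python
from dataclasses import dataclass

# Toy stand-ins only; none of these names are JAX APIs.
@dataclass(frozen=True)
class ToyJaxpr:
    constvars: tuple
    invars: tuple

@dataclass(frozen=True)
class ToyClosedJaxpr:
    jaxpr: ToyJaxpr
    consts: tuple

def as_closed(call_jaxpr):
    """Normalize either form so the rest of a rule sees one type."""
    if isinstance(call_jaxpr, ToyJaxpr):
        # Assumption of this sketch: an open jaxpr reaching a call rule has
        # no constant binders left, so an empty consts tuple is valid.
        assert not call_jaxpr.constvars
        return ToyClosedJaxpr(call_jaxpr, ())
    return call_jaxpr

open_form = ToyJaxpr(constvars=(), invars=("x",))
closed_form = ToyClosedJaxpr(ToyJaxpr(("c",), ("x",)), (1.0,))

from_open = as_closed(open_form)       # wrapped with empty consts
from_closed = as_closed(closed_form)   # passed through unchanged
```

If every call primitive carried a closed jaxpr, the `isinstance` branch would have nothing left to do and could be removed, as the comment above suggests.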

@@ -511,6 +513,12 @@ def partial_eval_wrapper_nounits(
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}

def _closed_call_param_updater(params, _, __):
  jaxpr = params.get('call_jaxpr')
  if jaxpr is None: return params

Collaborator

When is jaxpr None here?

Collaborator Author

Yeah good question, I forget... let me see if I can exercise this.

@mattjj mattjj May 14, 2022

Ah, it's because in JaxprTrace.process_call we actually call the same call_param_updater for both the bind-form and jaxpr-form parameter versions. Usually it's just used to update params like donated_invars, and it doesn't matter whether we're working with the bind-form or the jaxpr-form (e.g. for xla_call).

This required behavior was covered by the tests in core_test.py.

Collaborator

Makes sense, thanks!
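
The behavior explained in this thread can be sketched with a toy updater in plain Python. The names `ToyJaxpr`, `ToyClosedJaxpr`, and `toy_closed_call_param_updater` are hypothetical, as is the `"name"` parameter; the two ignored arguments are placeholders and are not modeled. The point is only that one hook sees both kinds of parameter dicts, so it must pass through a dict that has no `call_jaxpr` entry.

```python
from dataclasses import dataclass

# Toy stand-ins only; none of these names are JAX APIs.
@dataclass(frozen=True)
class ToyJaxpr:
    invars: tuple

@dataclass(frozen=True)
class ToyClosedJaxpr:
    jaxpr: ToyJaxpr
    consts: tuple

def toy_closed_call_param_updater(params, _, __):
    # The two ignored arguments stand in for whatever extra information the
    # hook receives; this sketch does not model them.
    jaxpr = params.get("call_jaxpr")
    if jaxpr is None:
        return params  # bind-form params: no 'call_jaxpr' entry to convert
    # Jaxpr-form params: pair the jaxpr with an empty tuple of constants.
    return dict(params, call_jaxpr=ToyClosedJaxpr(jaxpr, ()))

bind_form = {"name": "f"}
jaxpr_form = {"name": "f", "call_jaxpr": ToyJaxpr(("x",))}

updated_bind = toy_closed_call_param_updater(bind_form, None, None)
updated_jaxpr = toy_closed_call_param_updater(jaxpr_form, None, None)
```

Without the early return, the bind-form call would try to wrap `None`, which is the failure the `if jaxpr is None` line guards against.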

@google-ml-butler google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels May 14, 2022
@copybara-service copybara-service bot merged commit 86899ee into jax-ml:main May 14, 2022
@mattjj mattjj deleted the closed-call branch May 11, 2023 05:03