add core.closed_call_p #10711
LGTM. Thanks for the detailed PR description.
@@ -986,13 +986,13 @@ def f_lowered(ctx, *args, **params):

def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, tokens_in, *args):
  if isinstance(call_jaxpr, core.Jaxpr):
To confirm, this check exists because not all call primitives use closed jaxprs yet. When they do, we can delete this.
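For readers following the thread, the check can be pictured with a small stand-alone model. This is only a sketch: `ToyJaxpr`, `ToyClosedJaxpr`, and `as_closed` are hypothetical names invented here, not the real `core.Jaxpr`, `core.ClosedJaxpr`, or anything in mlir.py. It assumes the rule normalizes an open jaxpr into a closed one with no constants, so the rest of the rule has a single case to handle.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class ToyJaxpr:
  """Hypothetical stand-in for an open jaxpr (no constants attached)."""
  name: str


@dataclass(frozen=True)
class ToyClosedJaxpr:
  """Hypothetical stand-in for a jaxpr paired with its constants."""
  jaxpr: ToyJaxpr
  consts: Tuple[Any, ...] = ()


def as_closed(call_jaxpr):
  """Normalize either form to the closed form (illustrative helper)."""
  if isinstance(call_jaxpr, ToyJaxpr):
    # Open form: wrap it with an empty tuple of constants.
    return ToyClosedJaxpr(call_jaxpr, ())
  # Already closed: pass through unchanged.
  return call_jaxpr


open_form = ToyJaxpr("f")
closed_form = ToyClosedJaxpr(ToyJaxpr("g"), (3.0,))
normalized_open = as_closed(open_form)
normalized_closed = as_closed(closed_form)
```

Once every call primitive carries the closed form, the `isinstance` branch in a helper like this has nothing left to do, which matches the comment above about deleting the check.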
@@ -511,6 +513,12 @@ def partial_eval_wrapper_nounits(
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}

def _closed_call_param_updater(params, _, __):
  jaxpr = params.get('call_jaxpr')
  if jaxpr is None: return params
When is jaxpr None here?
Yeah good question, I forget... let me see if I can exercise this.
Ah, it's because in `JaxprTrace.process_call` we actually call the same `call_param_updater` for both the bind-form and jaxpr-form parameter versions. Usually it's just used to update params like `donated_invars`, and it doesn't matter whether we're working with the bind-form or the jaxpr-form (e.g. for `xla_call`).
This required behavior was covered by the tests in core_test.py.
Makes sense, thanks!
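The behavior discussed in this thread can be sketched with a toy model. Everything below is hypothetical and for illustration only: `ToyJaxpr`, `ToyClosedJaxpr`, and `toy_closed_call_param_updater` are stand-ins, and the wrapping step is an assumption about what an updater like this would do once it finds a `call_jaxpr` entry. The point is just that one updater sees both parameter dicts, and only the jaxpr-form dict has a `call_jaxpr` key.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class ToyJaxpr:
  """Hypothetical stand-in for an open jaxpr."""
  name: str


@dataclass(frozen=True)
class ToyClosedJaxpr:
  """Hypothetical stand-in for a jaxpr paired with constants."""
  jaxpr: ToyJaxpr
  consts: Tuple[Any, ...] = ()


def toy_closed_call_param_updater(params, _, __):
  """Toy updater; the two ignored arguments mirror the hook's shape."""
  jaxpr = params.get('call_jaxpr')
  if jaxpr is None:
    # Bind-form params: the callee is a Python callable passed elsewhere,
    # so there is no 'call_jaxpr' entry to rewrite.
    return params
  # Jaxpr-form params: wrap the open jaxpr with no constants (assumed).
  return dict(params, call_jaxpr=ToyClosedJaxpr(jaxpr, ()))


bind_form_params = {'name': 'f'}
jaxpr_form_params = {'name': 'f', 'call_jaxpr': ToyJaxpr('f')}

updated_bind = toy_closed_call_param_updater(bind_form_params, None, None)
updated_jaxpr = toy_closed_call_param_updater(jaxpr_form_params, None, None)
```

Note that the toy updater returns a new dict for the jaxpr form and leaves the caller's dict untouched.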
This PR adds a variant of `core.call_p` called `core.closed_call_p`. The only difference is that in its 'jaxpr form' its `call_jaxpr` parameter is a `core.ClosedJaxpr` rather than a `core.Jaxpr`.
Some background:
- `core.call_p` is the most vanilla call primitive possible: unlike, say, `xla_call_p` (the primitive underlying `jax.jit`), its impl rule isn't to compile an XLA computation, but instead it's just to interpret its jaxpr (staying in Python, using `core.eval_jaxpr`). Correspondingly, it doesn't need to raise the abstraction level of its arguments. It's basically a model for other "final-style" call primitives, each of which is interesting in precisely how it deviates from `core.call_p` (e.g. `xla_call_p`'s impl rule stages out for compilation; `remat_call_p` has a special partial evaluation rule; `custom_jvp_call_p` has a special JVP rule; etc). Historically it was the first call primitive we introduced, just to test the system; `core.call_p` is not really used anywhere.
- `core.ClosedJaxpr` is a data type which would be better named as `PartiallyAppliedJaxpr`. When we form jaxprs, they usually get paired with "constants" (e.g. `trace_to_jaxpr_nounits` and `trace_to_jaxpr_dynamic` output a list of constants), which are values that are not arguments and that we don't want to turn into literals (e.g. because we want to de-duplicate them, or even just avoid inlining them in pretty-prints). In some cases, these "constants" can be `core.Tracer`s, like when we form the jaxprs for `jax.lax.scan` and the body function closes over some `Tracer`; when that's possible, because `Tracer`s have to be handled with `core.Primitive.bind`, we typically just convert them to arguments (via `pe.convert_constvars_jaxpr`). But in other cases the constants that come out can't be `Tracer`s (e.g. in the JVP rule of an initial-style primitive, when we run `ad.jvp_jaxpr`, we can get new constants out which can't be `Tracer`s and must be raw array values). That's when `core.ClosedJaxpr` comes in handy: it lets us pair a `jaxpr` with some array constants so that the caller, e.g. a JVP rule for an initial-style higher-order primitive, doesn't need to deal with handling new constant values and their input binders. In other words, primitives which are parameterized by `ClosedJaxpr`s can have simpler rules, especially jaxpr-to-jaxpr rules, since those rules don't need to worry about handling new constants/binders introduced by the rule.

On that last point, when working on #10576 we ran into a situation where

- the current signature for "custom-policy partial eval rules" didn't allow a custom partial evaluation rule to introduce new constants (because such rules just get to output a pair of `Optional[JaxprEqn]`s and have no output for "new constants for the caller to handle appropriately");
- but to perform an optimization, namely hoisting loop-invariant residual computations out of a `scan` body, we might need such a rule to introduce multiple equations as well as new constants.

To proceed, there were at least two options:
1. make the signature for custom-policy partial evaluation rules even more complex (to support outputting multiple equations, new variable names being introduced, new constants, etc)
2. just use a call primitive to handle the "multiple equations with new variables" problem, and as long as it was a call primitive with a `ClosedJaxpr` it would handle the constants problem too.

I chose the second approach, which led to this PR.
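To make the "partially applied jaxpr" idea concrete, here is a toy model in plain Python. It is a sketch under stated assumptions, not JAX code: `ToyJaxpr` represents a jaxpr as a function called with its constants first and its arguments after, and `ToyClosedJaxpr`, `eval_closed`, and `scale_rule` are hypothetical names. The toy rule introduces a new constant, yet the caller's arguments stay the same because the constant travels inside the closed value.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class ToyJaxpr:
  """Toy open jaxpr: `fn` is called as fn(*consts, *args)."""
  num_consts: int
  fn: Callable[..., Any]


@dataclass(frozen=True)
class ToyClosedJaxpr:
  """Toy closed jaxpr: an open jaxpr already paired with its constants."""
  jaxpr: ToyJaxpr
  consts: Tuple[Any, ...]


def eval_closed(closed, *args):
  """Evaluate by supplying the stored constants ahead of the arguments."""
  assert len(closed.consts) == closed.jaxpr.num_consts
  return closed.jaxpr.fn(*closed.consts, *args)


def scale_rule(closed, factor):
  """Toy jaxpr-to-jaxpr rule computing factor * f(args).

  The rule needs a new constant (`factor`). It stores that constant in the
  returned closed jaxpr, so the caller never handles a new binder.
  """
  old = closed.jaxpr

  def new_fn(new_const, *rest):
    # `rest` holds the old constants followed by the call arguments.
    return new_const * old.fn(*rest)

  new_jaxpr = ToyJaxpr(old.num_consts + 1, new_fn)
  return ToyClosedJaxpr(new_jaxpr, (factor, *closed.consts))


# f(x) = x + 10, with 10 stored as a constant rather than an argument.
add_offset = ToyClosedJaxpr(ToyJaxpr(1, lambda c, x: x + c), (10,))
scaled = scale_rule(add_offset, 3)

before = eval_closed(add_offset, 2)
after = eval_closed(scaled, 2)
```

The caller invokes `eval_closed(scaled, 2)` with the same single argument as before the rule ran. In this toy setting, that is the simplification the second option relies on.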
For simplicity, we could delete `core.call_p` in favor of this `core.closed_call_p`; after all, the former is not used at all. Going further, we might want to make all higher-order primitives (i.e. even the final-style ones, not just the initial-style ones as at present) take `ClosedJaxpr`s rather than `Jaxpr`s; further still, at that point we could de-duplicate `Jaxpr` and `ClosedJaxpr` so that we only have one such type. Those simplifications sound reasonable, but they're out of scope for this PR. Here I just want to land a change for enabling the new `remat` implementation with `scan` inside!

Finally, some notes on the changes here. Final-style primitives (like the new `closed_call_p`) have two forms, with different parameters: the 'bind form' used during tracing, which takes a Python callable as a parameter representing the function to be called (really a `linear_util.WrappedFun`), and the 'jaxpr form', which appears in a jaxpr and itself takes a `Jaxpr` (or, after this PR, alternatively a `ClosedJaxpr`). Since we're introducing a primitive which is like `core.call_p` except that it takes a `ClosedJaxpr` parameter, we need to

- update places where the bind-form primitive is converted to the jaxpr-form primitive (i.e. `JaxprTrace.process_call` and `DynamicJaxprTrace.process_call` in partial_eval.py, both of which can be handled by using the existing "call param updater" hook) to actually produce a `ClosedJaxpr` parameter;
- update places where the jaxpr-form is converted to the bind-form (namely `ClosedCallPrimitive.get_bind_params` in core.py);
- update rules which consume the jaxpr-form to handle the `ClosedJaxpr` parameter (namely the MLIR lowering rule in mlir.py, the transpose rule in ad.py, the typecheck rule in core.py, the DCE rule in partial_eval.py, and (once it exists for any calls) the forwarding rule in partial_eval.py); note that we do not need to update rules which consume the bind form (e.g. `JVPTrace.process_call` or `BatchTrace.process_call`) since the bind forms of `call_p` and `closed_call_p` are identical;
- update core_test.py to cover the new call primitive.
Only the second-to-last bullet seems burdensome. That would be mitigated by moving to make all call primitives take `ClosedJaxpr` parameters, which I think was already a good idea. But again that's out of scope!