Feedback on jax.experimental.host_callback #3127
Nice! For the first version I wanted to keep it simple, but your version is certainly a good candidate for id_print.
BTW, the above could also have been written as:
... essentially
I knew that repeated
to have
I think that this is what you meant by zipping the params and transforms.
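To make the "zipping" idea concrete, here is a minimal pure-Python sketch (not the `host_callback` API) that contrasts two hypothetical layouts for the metadata a tap function might receive. In the first, the transformation names and their params are kept apart, with params keyed by transformation name, so a repeated transformation such as nested `vmap` silently overwrites the earlier one's parameters. In the second, each transformation carries its own params. The names `Transform`, `parallel_layout`, and `zipped_layout` are made up for illustration.

```python
from typing import Any, Dict, List, NamedTuple, Tuple


class Transform(NamedTuple):
    """One applied transformation together with its own parameters (hypothetical)."""
    name: str
    params: Dict[str, Any]


def parallel_layout(applied: List[Tuple[str, Dict[str, Any]]]):
    """Unzipped layout: a tuple of names plus one params dict keyed by name.

    When a transformation is repeated (e.g. nested vmap), later params
    overwrite earlier ones, so the name/params pairing is lost.
    """
    names = tuple(name for name, _ in applied)
    params: Dict[str, Any] = {}
    for name, p in applied:
        params[name] = p  # a second "batch" entry clobbers the first
    return names, params


def zipped_layout(applied: List[Tuple[str, Dict[str, Any]]]) -> Tuple[Transform, ...]:
    """Zipped layout: every transformation keeps its own params."""
    return tuple(Transform(name, dict(p)) for name, p in applied)


# Nested vmap with a different batch dimension at each level.
applied = [("batch", {"batch_dims": (0,)}), ("batch", {"batch_dims": (1,)})]

names, params = parallel_layout(applied)
print(names, params)  # only one set of batch_dims survives
print(zipped_layout(applied))  # both sets of batch_dims survive
```

With the zipped layout, a consumer can walk the tuple from innermost to outermost transformation without having to guess which parameters belong to which application.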
I do not understand this comment. What do you mean by "change the order in which JAX executes its arguments"? Note that in the last print statement you wrote
Yes, exactly!
Oops, fixed. I edited my original post to show exactly the output that I would have expected. Does that make more sense?
I see now what you are pointing out. The tapping order is guaranteed to respect data dependencies, but otherwise it depends on how JAX lays out the Jaxpr. In this case it seems that JAX produces a reordered Jaxpr that still respects the data dependencies. For example, the tap of … There is some thought about adding a "program-order" dependency to the id_tap primitive, but even that is not going to add a dependency between …
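In newer JAX releases `host_callback` has been deprecated in favor of callbacks such as `jax.debug.callback`, which accepts an `ordered` flag that adds the kind of program-order dependency discussed here. A minimal sketch, assuming a recent JAX version; `power3`, `record`, and `log` are hypothetical names, not code from this thread. Neither callback depends on the other's data, so without `ordered=True` only data dependencies would constrain when each one runs.

```python
import functools

import jax

log = []


def record(tag, value):
    # Runs on the host; `value` arrives as a concrete array.
    log.append((tag, float(value)))


def power3(x):
    y = x * x
    # ordered=True sequences this callback with other ordered callbacks
    # in program order, not just in an order consistent with data flow.
    jax.debug.callback(functools.partial(record, "y"), y, ordered=True)
    z = y * x
    jax.debug.callback(functools.partial(record, "z"), z, ordered=True)
    return z


result = jax.jit(power3)(2.0)
jax.effects_barrier()  # wait until all pending host callbacks have run
print(log)
```

The `str` tag is bound with `functools.partial` rather than passed as a callback argument, because callback arguments are treated as arrays.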
#3132 addresses issue 3. It seems that we have discussed the issues, but there is no action item. Should we close this issue?
I agree, we can close this for now. There are still a couple of follow-ups, but those can be discussed separately:
First of all -- this is an awesome new feature! I am so excited about finally being able to plumb values at runtime out of the guts of complex JAX transformations.
I played around with a few variations on the API for `id_print`:
- A variant of `id_print` that accepts `*args` like `print` and allows for inline `str` arguments (dropped from the return value), along the lines of my suggestion from "An implementation of id_print with CustomCall #2739 (comment)" (see `_simple_print_callback` below).
- … (see `_IndentingPrinter` below).

Generally I liked both of these changes. See below for my implementation, some examples and my feedback.
Implementation
Example usage
Here are three ways to compute the same derivative:
Outputs:
My feedback / questions
1. `jit(jacfwd(power3))` and `jacfwd(jit(power3))` execute variables in a different order.
2. `id_tap` was quite easy to adapt to my own printing function. Kudos for designing an easy-to-use API!
3. … `id_tap` is that `params` aren't associated with `transforms`, so it isn't obvious how they pair up. Perhaps not coincidentally, repeated applications of the same transformation don't seem to work (see "host_callback.id_print fails with nested vmap #3126" for `vmap`). It seems like transformations and their parameters really should be zipped together.
4. … `None` with `jacfwd` come from? For a simpler example, consider `jax.jacfwd(host_callback.id_print)(1.0)`, which outputs `batch_dims: (0, None) transforms: ('jvp', 'batch')`. I'm not quite sure what this batch dimension could correspond to, given that `id_print` only has one input.
5. … `id_print` on the outputs of a function can apparently change the order in which JAX executes its arguments under `jit`. I don't know if this is a bug or just me being surprised: … `jacfwd(jit(power3))` example above without printing `z`, I would have expected this to look like: …