Feedback on jax.experimental.host_callback #3127

Closed
shoyer opened this issue May 17, 2020 · 8 comments

@shoyer
Collaborator

shoyer commented May 17, 2020

First of all -- this is an awesome new feature! I am so excited about finally being able to plumb values at runtime out of the guts of complex JAX transformations.

I played around with a few variations on the API for id_print:

  1. I experimented with an API for id_print that accepts *args like print and allows inline str arguments (which are dropped from the return value), along the lines of my suggestion in "An implementation of id_print with CustomCall" #2739 (comment) (see _simple_print_callback below).
  2. I experimented with automatically adding indentation to show the current transformation stack (see _IndentingPrinter below).

Generally I liked both of these changes. See below for my implementation, some examples and my feedback.

Implementation

import io
import sys
from functools import partial
import textwrap

from jax import jacfwd, grad, jit
from jax.experimental import host_callback

def _common_element_count(xs, ys):
  i = 0
  while i < min(len(xs), len(ys)) and xs[i] == ys[i]:
    i += 1
  return i

class _IndentingPrinter:
  def __init__(self, transforms=()):
    self._current_transforms = list(transforms)

  def print(self, *value, sep=' ', end='\n', file=sys.stdout, flush=False):
    buffer = io.StringIO()
    print(*value, sep=sep, end='', file=buffer)
    prefix = '| ' * len(self._current_transforms)
    print(textwrap.indent(buffer.getvalue(), prefix),
          file=file, flush=flush, end=end)

  def callback(self, arg, *, strings, sep, end, file, flush, transforms=(),
               **kwargs):
    # update current transforms, if needed
    i = _common_element_count(transforms, self._current_transforms)
    del self._current_transforms[i:]
    for transform in transforms[i:]:
      if transform == 'batch':
        self.print(f"batch [batch_dims={kwargs['batch_dims']}]")
      elif transform == 'mask':
        self.print(f"mask [logical_shapes={kwargs['logical_shapes']}]")
      else:
        self.print(transform)
      self._current_transforms.append(transform)

    # print current value
    args = list(arg)
    for i, v in strings:
      args.insert(i, v)
    self.print(*args, sep=sep, end=end, file=file, flush=flush)

def _id_print(print_callback, *args,
              result=None, sep=' ', end='\n', file=sys.stdout, flush=False):
  strings = [(i, v) for i, v in enumerate(args) if isinstance(v, str)]
  others = tuple(a for a in args if not isinstance(a, str))
  f = partial(print_callback, strings=strings, sep=sep, end=end,
              file=file, flush=flush)
  out = host_callback.id_tap(f, others, result=result)
  if result is None and len(others) == 1:
    out, = out
  return out

def _simple_print_callback(arg, *, strings, sep, end, file, flush, **kwargs):
  args = list(arg)
  for i, v in strings:
    args.insert(i, v)
  if kwargs:
    # write the transform parameters to the same stream as the values
    print(', '.join(f'{k}={v}' for k, v in kwargs.items()), end=': ', file=file)
  print(*args, sep=sep, end=end, file=file, flush=flush)

id_print_inline = partial(_id_print, _simple_print_callback)
id_print_indented = partial(_id_print, _IndentingPrinter().callback)
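
For readers without JAX at hand, the inline-string handling in _id_print and the callbacks can be sketched on its own. This is a minimal illustration, not part of host_callback; the helper names split_inline_strings and merge_inline_strings are hypothetical.

```python
def split_inline_strings(args):
    """Separate str arguments (with their positions) from the other values."""
    strings = [(i, v) for i, v in enumerate(args) if isinstance(v, str)]
    others = tuple(a for a in args if not isinstance(a, str))
    return strings, others


def merge_inline_strings(others, strings):
    """Re-insert the strings at their original positions."""
    merged = list(others)
    # Positions are ascending, so each insert lands at its original index.
    for i, v in strings:
        merged.insert(i, v)
    return merged


args = ('x=', 2.0, ' y=', 4.0)
strings, others = split_inline_strings(args)
merged = merge_inline_strings(others, strings)
print(strings)  # positions and text of the inline strings
print(others)   # the values that would be handed to id_tap
print(merged)   # what the print callback reassembles on the host
```

Only the non-string values travel through the tap; the strings ride along as static keyword arguments and are spliced back in before printing.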

Example usage

Here are three ways to compute the same derivative:

def power3(x):
  x = id_print_indented('x=', x, sep='')
  y = x ** 2
  y = id_print_indented('y=', y, sep='')
  z = y * x
  return z

print('\njit(jacfwd(power3)):')
f = jit(jacfwd(power3))
with host_callback.outfeed_receiver():
  f(2.0)

print('\njacfwd(jit(power3)):')
f = jacfwd(jit(power3))
with host_callback.outfeed_receiver():
  f(2.0)

print('\ngrad(power3):')
f = grad(power3)
with host_callback.outfeed_receiver():
  f(2.0)

Outputs:

jit(jacfwd(power3)):
x=2.0
jvp
| batch [batch_dims=(0, None)]
| | x=[1.]
y=4.0
jvp
| batch [batch_dims=(0, None)]
| | y=[4.]

jacfwd(jit(power3)):
x=2.0
y=4.0
jvp
| batch [batch_dims=(0, None)]
| | x=[1.]
| | y=[4.]

grad(power3):
x=2.0
y=4.0
jvp
| transpose
| | y=2.0
| | x=12.0
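
The indentation in these outputs comes from comparing each tap's transforms tuple with the stack that is currently open. Below is a JAX-free sketch of that logic. The events list is hand-written to mimic the grad(power3) run and is an assumption about what id_tap passed, and StackPrinter is a hypothetical, simplified version of _IndentingPrinter.

```python
def common_prefix_len(xs, ys):
    """Number of leading elements that xs and ys share."""
    n = 0
    for a, b in zip(xs, ys):
        if a != b:
            break
        n += 1
    return n


class StackPrinter:
    """Simplified stand-in for _IndentingPrinter that collects lines."""

    def __init__(self):
        self.stack = []
        self.lines = []

    def emit(self, text):
        self.lines.append('| ' * len(self.stack) + text)

    def tap(self, text, transforms=()):
        # Close transforms that no longer apply, then open the new ones.
        keep = common_prefix_len(transforms, self.stack)
        del self.stack[keep:]
        for name in transforms[keep:]:
            self.emit(name)
            self.stack.append(name)
        self.emit(text)


# Hand-written events (assumed, not captured from a real run).
events = [
    ('x=2.0', ()),
    ('y=4.0', ()),
    ('y=2.0', ('jvp', 'transpose')),
    ('x=12.0', ('jvp', 'transpose')),
]

printer = StackPrinter()
for text, transforms in events:
    printer.tap(text, transforms)
print('\n'.join(printer.lines))
```

A header line is printed only when the stack changes, which is why consecutive taps under the same transforms share one header.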

My feedback / questions

  1. Wow, it is great to finally see what JAX is doing! E.g., in the example above, I was surprised (but in retrospect it makes complete sense) that jit(jacfwd(power3)) and jacfwd(jit(power3)) execute variables in a different order.
  2. id_tap was quite easy to adapt to my own printing function. Kudos for designing an easy to use API!
  3. The one awkward part about id_tap is that params aren't associated with transforms, so it isn't obvious how they pair up. Perhaps not coincidentally, repeated versions of the same transformation don't seem to work (see host_callback.id_print fails with nested vmap #3126 for vmap). It seems like transformations and their parameters really should be zipped together.
  4. Where does the second batch dimension None with jacfwd come from? For a simpler example, consider jax.jacfwd(host_callback.id_print)(1.0), which outputs batch_dims: (0, None) transforms: ('jvp', 'batch'). I'm not quite sure what this batch dimension could correspond to, given that id_print only has one input.
  5. Something else that surprised me is how adding id_print on the outputs of a function can apparently change the order in which JAX executes its arguments under jit. I don't know if this is a bug or just me being surprised:
    def power3(x):
      x = id_print_indented('x=', x, sep='')
      y = x ** 2
      y = id_print_indented('y=', y, sep='')
      z = y * x
      z = id_print_indented('z=', z, sep='')
      return z
    
    print('\njacfwd(jit(power3)):')
    f = jacfwd(jit(power3))
    with host_callback.outfeed_receiver():
      f(2.0)
    outputs:
    jacfwd(jit(power3)):
    x=2.0
    jvp
    | batch [batch_dims=(0, None)]
    | | x=[1.]
    y=4.0
    jvp
    | batch [batch_dims=(0, None)]
    | | y=[4.]
    z=8.0
    jvp
    | batch [batch_dims=(0, None)]
    | | z=[12.]
    
    Based on the jacfwd(jit(power3)) example above without printing z, I would have expected this to look like:
    jacfwd(jit(power3)):
    x=2.0
    y=4.0
    z=8.0
    jvp
    | batch [batch_dims=(0, None)]
    | | x=[1.]
    | | y=[4.]
    | | z=[12.]
    
@gnecula
Collaborator

gnecula commented May 18, 2020

I experimented with an API for id_print that accepts *args like print and allows for inline str arguments (dropped from the return value), along the line of my suggestion from
#2739 (comment) (see _simple_print_callback below)

Nice! For the first version I wanted to keep it simple, but your version is certainly a good candidate for id_print.

  f = partial(print_callback, strings=strings, sep=sep, end=end,
              file=file, flush=flush)
  out = host_callback.id_tap(f, others, result=result)

BTW, the above could have been written also:

out = host_callback.id_tap(print_callback, others, result=result,
                           strings=strings, sep=sep, end=end, file=file, flush=flush)

... essentially id_tap will keep the kwargs and pass them along to the tap function.
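
The equivalence of the two spellings can be checked without JAX. In this sketch fake_id_tap is a hypothetical stand-in that only mimics the kwargs forwarding described above; the real id_tap does more (for example, it runs the tap on the host and adds transform information).

```python
from functools import partial


def fake_id_tap(tap_func, arg, *, result=None, **kwargs):
    """Hypothetical stand-in: forward kwargs to tap_func, then return arg."""
    tap_func(arg, **kwargs)
    return arg if result is None else result


calls = []


def record(arg, *, sep, end, **kwargs):
    """Tap function that just records what it was called with."""
    calls.append((arg, sep, end))


# Spelling 1: bind the keyword arguments up front with partial.
out_a = fake_id_tap(partial(record, sep='', end='\n'), (2.0,))
# Spelling 2: let the tap wrapper forward the keyword arguments.
out_b = fake_id_tap(record, (2.0,), sep='', end='\n')

print(calls[0] == calls[1])
print(out_a == out_b)
```

Both calls hand the tap function the same arguments, so the choice between them is a matter of taste.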

@gnecula
Collaborator

gnecula commented May 18, 2020

The one awkward part about id_tap is that params aren't associated with transforms, so it isn't obvious how they pair up. Perhaps not coincidentally, repeated versions of the same transformation don't seem to work (see #3126 for vmap). It seems like transformations and their parameters really should be zipped together.

I knew that repeated vmap and mask transforms don't work. I was planning to change transforms so that an entry is a tuple of name and parameters for the transformation. Thus, instead of

dict(transforms=('jvp', 'batch'), batch_dims=(0, 0))

to have

dict(transforms=(('jvp', dict()),
                 ('batch', dict(batch_dims=(0, 0)))))

I think that this is what you meant by zipping the params and transforms.
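
To make the proposed shape concrete, here is a rough sketch that converts the flat form into the zipped one. The helper zip_transforms and the PARAM_FOR_TRANSFORM table are hypothetical and not part of host_callback; the pairing of 'batch' with batch_dims and 'mask' with logical_shapes is taken from the examples in this thread.

```python
# Assumed pairing of transform names with their parameter names.
PARAM_FOR_TRANSFORM = {'batch': 'batch_dims', 'mask': 'logical_shapes'}


def zip_transforms(transforms, **params):
    """Pair each transform name with a dict of its own parameters."""
    zipped = []
    for name in transforms:
        key = PARAM_FOR_TRANSFORM.get(name)
        own = {key: params[key]} if key in params else {}
        zipped.append((name, own))
    return tuple(zipped)


flat = dict(transforms=('jvp', 'batch'), batch_dims=(0, 0))
zipped = zip_transforms(**flat)
print(zipped)

# Illustrative values only: with two nested 'batch' entries the flat form has
# a single batch_dims slot, while the zipped form keeps one dict per entry.
nested = (('batch', dict(batch_dims=(0,))),
          ('batch', dict(batch_dims=(1,))))
print(len(nested))
```

The flat form cannot represent repeated transforms because keyword names collide, which is the limitation the zipped form removes.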

@gnecula
Collaborator

gnecula commented May 18, 2020

  1. Something else that surprised me is how adding id_print on the outputs of a function can apparently change the order in which JAX executes its arguments under jit. I don't know if this is a bug or just me being surprised:

I do not understand this comment. What do you mean by "change the order in which JAX executes its arguments"?

Note that in the last print statement you wrote z = id_print_indented('z=', y, sep='') but you are actually printing and returning y.

@shoyer
Collaborator Author

shoyer commented May 18, 2020

I think that this is what you meant by zipping the params and transforms.

Yes, exactly!

@shoyer
Collaborator Author

shoyer commented May 18, 2020

  1. Something else that surprised me is how adding id_print on the outputs of a function can apparently change the order in which JAX executes its arguments under jit. I don't know if this is a bug or just me being surprised:

I do not understand this comment. What do you mean by "change the order in which JAX executes its arguments"?

Note that in the last print statement you wrote z = id_print_indented('z=', y, sep='') but you are actually printing and returning y.

Oops, fixed.

I edited my original post showing exactly the output that I would have expected. Does that make more sense?

@gnecula
Collaborator

gnecula commented May 19, 2020

I edited my original post showing exactly the output that I would have expected. Does that make more sense?

I see now what you are pointing out. The tapping order is guaranteed to respect data dependency, but otherwise it is subject to how JAX lays out the Jaxpr. In this case it seems that JAX produces a reordered Jaxpr that still respects the data dependency. For example, the taps of y and jvp(x) have no data dependency, so they can be reordered.

There has been some thought of adding a "program-order" dependency to the id_tap primitive, but even that would not add a dependency between y and jvp(x). So I am not sure that we can do anything to force an order.
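
A small JAX-free check illustrates why both observed layouts are legal. The dependency graph below is a hand-written assumption about the taps in power3 (the tangent of y = x ** 2 needs both x and the tangent of x), not something extracted from a Jaxpr.

```python
# Assumed data dependencies: a tap may fire only after everything in its set.
deps = {
    'x': set(),
    'y': {'x'},
    'jvp(x)': set(),
    'jvp(y)': {'x', 'jvp(x)'},
}


def respects_dependencies(order, deps):
    """True if every name appears after all of its dependencies."""
    seen = set()
    for name in order:
        if not deps[name] <= seen:
            return False
        seen.add(name)
    return True


order_a = ['x', 'jvp(x)', 'y', 'jvp(y)']  # layout seen with jit(jacfwd(power3))
order_b = ['x', 'y', 'jvp(x)', 'jvp(y)']  # layout seen with jacfwd(jit(power3))
order_bad = ['jvp(y)', 'x', 'y', 'jvp(x)']  # violates the assumed graph

print(respects_dependencies(order_a, deps))
print(respects_dependencies(order_b, deps))
print(respects_dependencies(order_bad, deps))
```

Since y and jvp(x) do not depend on each other, either can come first, and data dependency alone cannot pin down a single order.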

@gnecula
Collaborator

gnecula commented May 24, 2020

#3132 addresses item 3 from the original list. It seems that we have discussed the issues, but there is no action item. Should we close this issue?

@shoyer
Collaborator Author

shoyer commented May 24, 2020

I agree, we can close this for now. There are still a couple of follow-ups, but those can be discussed separately:

  • I opened Adding host_callbacks.id_tap reorders JVP evaluation #3198 for discussing (5) in depth.
  • I'm still not sure about (4), but in the scheme of things it's not a huge deal. I guess the extra batch dimension somehow corresponds to the tangent? It's still a little surprising, given that I'm pretty sure vmap is never called with multiple in_axes, even in the implementation of jacfwd. I'll see if I can work up a simpler example...

@gnecula gnecula closed this as completed May 26, 2020