Error raised in comparing static argument for a vmapped function after running the function without vmap before #20466

Closed
JasonMH17 opened this issue Mar 27, 2024 · 7 comments
Labels
question Questions for the JAX team


@JasonMH17

JasonMH17 commented Mar 27, 2024

Description

While running a single jitted instance of our JAX function, we found that a subsequent run of the same function, but vmapped, leads to a jax.errors.TracerBoolConversionError. Specifically, this occurs when our static_arg is compared for equality with the cached one: because the class was traced during the prior single run, the comparison produces a tracer, leading to the boolean conversion error.
The MWE code is enclosed below:

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class SomeStaticType:
  params: jnp.ndarray = dataclasses.field(
      default_factory=lambda: jnp.zeros((10,))
  )

  def tree_flatten(self):
    children = (self.params,)
    return (children, ('params',))

  @classmethod
  def tree_unflatten(cls, _, children):
    return cls(*children)

  def __hash__(self):
    return 1  # Make sure the `__eq__` is called.

  def __eq__(self, other):
    return (self.params == other.params).all()  # !!! Raises error because `(self.params == other.params)` becomes a tracer.


@functools.partial(jax.jit, static_argnames=("static_arg",))
def inner_func(x, static_arg: SomeStaticType):
    return jnp.sum(x*x + 3*x)


@functools.partial(jax.jit, static_argnames=("static_arg",))  # `jit`.
def outer_func(x, static_arg: SomeStaticType):
    return jax.vmap(inner_func, in_axes=(1,None))(x, static_arg)  # `vmap`.


## Main
static_arg = SomeStaticType()
print(static_arg)
x = jnp.zeros((3,))

### First call `inner_func` once to get the cache filled.
inner_func(x, static_arg)

### Then call `outer_func`, which leads to errors.
x_extended = jnp.zeros((3,2))
outer_func(x_extended, static_arg)
```

### System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.10.7 (tags/v3.10.7:6cc6b13, Sep  5 2022, 14:08:36) [MSC v.1933 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
@JasonMH17 JasonMH17 added the bug Something isn't working label Mar 27, 2024
@jakevdp
Collaborator

jakevdp commented Apr 1, 2024

Hi, thanks for the report! This is working as intended: when you pass arrays through a JAX transformation like jit or vmap, they become traced, and you can no longer perform logic that requires concrete values.

Your SomeStaticType class is not actually static because it contains JAX arrays. It correctly flattens the parameter array in children, which specifies the dynamic attributes of a pytree. Because SomeStaticType is not static, it's not correct to mark corresponding arguments as static, and you can fix the problem by avoiding doing so:

```python
@jax.jit
def inner_func(x, static_arg: SomeStaticType):
    return jnp.sum(x*x + 3*x)

@jax.jit
def outer_func(x, static_arg: SomeStaticType):
    return jax.vmap(inner_func, in_axes=(1,None))(x, static_arg)
```

With this change your code runs without any error.

(Side-note: I'd also remove the __hash__ and __eq__ definitions from your type, because your structure is not actually hashable given that it contains non-hashable contents!)
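A self-contained sketch that applies both suggestions (no static marking, and no custom `__hash__`/`__eq__`) might look like the following, assuming a recent JAX 0.4.x. The class name `Params` is a hypothetical rename of `SomeStaticType`, chosen because the type is not static; everything else follows the MWE.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Params:  # hypothetical rename of SomeStaticType; no __hash__/__eq__
    params: jax.Array = dataclasses.field(
        default_factory=lambda: jnp.zeros((10,))
    )

    def tree_flatten(self):
        # `params` is a child, i.e. a dynamic leaf; no static aux data needed.
        return (self.params,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit  # no static_argnames: `p` is passed as an ordinary pytree argument
def inner_func(x, p: Params):
    return jnp.sum(x * x + 3 * x)


@jax.jit
def outer_func(x, p: Params):
    # Map over axis 1 of x; broadcast p (in_axes=None) to every call.
    return jax.vmap(inner_func, in_axes=(1, None))(x, p)


p = Params()
x = jnp.zeros((3,))
print(inner_func(x, p))  # fills inner_func's cache
x_extended = jnp.zeros((3, 2))
print(outer_func(x_extended, p))  # runs without a TracerBoolConversionError
```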

@jakevdp jakevdp self-assigned this Apr 1, 2024
@jakevdp jakevdp added question Questions for the JAX team and removed bug Something isn't working labels Apr 1, 2024
@yuhonglin

Your SomeStaticType class is not actually static because it contains JAX arrays. It correctly flattens the parameter array in children, which specifies the dynamic attributes of a pytree. Because SomeStatictype is not static, it's not correct to mark corresponding arguments as static

Thanks for the reply. I don't understand the logic behind why a JAX array can't be static.

  1. Does it mean that even in jax.jit, labeling a JAX array as static is considered inappropriate? (Then why doesn't jax.jit throw an error in such cases?)
  2. What if we really need some arrays as static parameters? Should we always use tuples and convert them into arrays every time we use them?
  3. Is it a side effect of omnistaging?

@jakevdp
Collaborator

jakevdp commented Apr 1, 2024

  1. Does it mean that even in jax.jit, labeling a JAX array as static is considered inappropriate? (Then why doesn't jax.jit throw an error in such cases?)

Yes. Static arguments must be hashable, and JAX arrays are not hashable, so they cannot be used as static arguments. The reasoning behind this is that cached JIT compilations are indexed by a hash table with keys built from the hashes of static arguments along with the static attributes of any array arguments. If we have non-hashable static arguments, then we have no way of knowing whether a particular function call should be re-compiled.
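To illustrate the caching behaviour described above, here is a small sketch (the function `scale` and the counter `trace_count` are hypothetical). A Python side effect inside a jitted function runs only while JAX traces it, so counting those side effects shows when a static argument produces a cache hit or a miss. The last lines show the error `jax.jit` raises for an array passed directly as a static argument; since the exact exception text may differ between versions, the sketch catches both `ValueError` and `TypeError`.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces `scale`


@functools.partial(jax.jit, static_argnums=1)
def scale(x, factor):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return x * factor


x = jnp.arange(3.0)
scale(x, 2.0)  # cache miss: trace and compile for factor=2.0
scale(x, 2.0)  # cache hit: 2.0 hashes and compares equal to the cached key
scale(x, 3.0)  # cache miss: a new static value means a new compilation
print(trace_count)  # 2

# A JAX array is not hashable, so JAX cannot build a cache key from it.
try:
    jax.jit(lambda a: a + 1, static_argnums=0)(jnp.arange(3))
    static_array_rejected = False
except (ValueError, TypeError) as err:
    static_array_rejected = True
    print(type(err).__name__)
```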

jax.jit does throw an error if you try to pass a JAX array directly as a static argument, but as you found you can work around this by wrapping your arrays in some class that defines a __hash__ method. Typically such a workaround will lead to a different error further down in the stack.

  2. What if we really need some arrays as static parameters? Should we always use tuples and convert them into arrays every time we use them?

I suspect that in any situation in which you think a static array is the right solution, a dynamic array is probably the better solution.

  3. Is it a side effect of omnistaging?

Sort of. Before omnistaging, it was possible to mix staged and non-staged JAX computations inside a jitted function; now all JAX computations there are staged. But the requirement that static data be hashable predates omnistaging. In very early JAX versions it was possible to mark JAX arrays as static while implicitly using their object ID as their hash, but this frequently led to surprising recompilations, so we made it an error.
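A minimal sketch of staging under `jit`, assuming JAX 0.4.x: even an operation on two concrete arrays captured from outside becomes a tracer once it runs inside a jitted function, and Python control flow on the result then fails at trace time. This appears to be the same mechanism behind the `__eq__` in the MWE, whose comparison runs while `outer_func` is being traced. The names `f` and `captured` are hypothetical.

```python
import jax
import jax.numpy as jnp

captured = jnp.zeros(3)  # a concrete array created outside any trace


@jax.jit
def f(x):
    # This comparison is staged into the jit trace, so `same` is a tracer
    # even though both operands are concrete arrays.
    same = (captured == captured).all()
    if same:  # a Python `if` needs a concrete bool, so this raises
        return x
    return -x


raised = False
try:
    f(jnp.ones(3))
except jax.errors.TracerBoolConversionError:
    raised = True
print(raised)
```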

@yuhonglin

Thanks for the further reply! Regarding the OP's original question: is it true that a static argument of a pytree class can't be used with vmap, because vmap has no way to designate which parameter is static? Although this may fall into the category you described ("lead to a different error further down in the stack"), IIUC this really feels like a limitation of vmap, and at least a better error message might be needed...

Please see my other responses inline below.

Typically such a workaround will lead to a different error further down in the stack.

Thanks for telling us this. We currently depend heavily on this "workaround", i.e., wrapping jnp.ndarray in hashable classes. Hopefully it won't cause many issues (I suppose that as long as this usage is allowed, it should be OK...)

I suspect that in any situation in which you think a static array is the right solution, a dynamic array is probably the better solution.

Sometimes we need to store parameters such as the coefficients of a linear model or filter. These parameters are vectors of floats and won't be changed during training (they are hyperparameters), so it is more convenient to store them in a static jnp.ndarray.

@jakevdp
Collaborator

jakevdp commented Apr 1, 2024

These parameters are vectors of floats and won't be changed in training

Note that "static" and "constant" are not identical concepts in JAX. It's fine to have a constant array that is dynamic, and you shouldn't see any performance penalties in this case. If you have constant arrays, you should not go out of your way to mark them as static, especially if it requires workarounds like wrapping them in a hashable class.
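The distinction between "static" and "constant" can be seen with a trace counter. In this hypothetical sketch (`apply_filter` and `trace_count` are illustrative names), a constant array passed as an ordinary dynamic argument is keyed in the jit cache by its shape and dtype rather than its values, so later calls reuse the compiled function even when the values change.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces `apply_filter`


@jax.jit
def apply_filter(x, coeffs):
    # `coeffs` is an ordinary (dynamic) argument: the jit cache is keyed on
    # its shape and dtype, not on its values.
    global trace_count
    trace_count += 1
    return jnp.convolve(x, coeffs, mode="same")


x = jnp.arange(8.0)
coeffs = jnp.array([0.25, 0.5, 0.25])  # a constant, e.g. fixed filter taps
y1 = apply_filter(x, coeffs)
y2 = apply_filter(x, coeffs)  # same shape/dtype: no retrace
y3 = apply_filter(x, jnp.array([0.0, 1.0, 0.0]))  # new values: still no retrace
print(trace_count)  # 1
```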

@yuhonglin

These parameters are vectors of floats and won't be changed in training

Note that "static" and "constant" are not identical concepts in JAX. It's fine to have a constant array that is dynamic, and you shouldn't see any performance penalties in this case. If you have constant arrays, you should not go out of your way to mark them as static, especially if it requires workarounds like wrapping them in a hashable class.

Thanks, understood. We feel that treating such parameters as "static" is better because:

  • Compared with a plain dynamic input argument, it is conceptually clearer, and it is easier to do conditionals etc. based on it.
  • Compared with global constants, it is safer because JAX automatically recompiles the function every time the value changes.

@jakevdp
Collaborator

jakevdp commented Apr 2, 2024

Well, I disagree on "conceptually clear" (as I mentioned, "static" is a different concept than "constant"). And doing conditionals on such values is somewhat fraught, because element access is a traced operation under JIT.
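As a sketch of conditionals on a dynamic parameter array (the function `maybe_negate` is hypothetical): indexing `params[0]` under `jit` yields a tracer, so a Python `if` cannot branch on it, but `jax.lax.cond` can select a branch at run time without marking anything static.

```python
import jax
import jax.numpy as jnp


@jax.jit
def maybe_negate(x, params):
    # `params[0] > 0` is a traced boolean under jit, so a Python `if` on it
    # would raise; lax.cond stages both branches and picks one at run time.
    return jax.lax.cond(params[0] > 0, lambda v: v, lambda v: -v, x)


x = jnp.arange(3.0)
print(maybe_negate(x, jnp.array([1.0, 2.0])))   # x unchanged
print(maybe_negate(x, jnp.array([-1.0, 2.0])))  # x negated
```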

But if you want to treat arrays as static values, I'd recommend one of two approaches:

  1. Split the values into tuples of ints or floats, and mark these as static. The advantage here is that such tuples are actually hashable, and their individual elements are concrete Python values (so, e.g., it's trivial to use them for trace-time control flow). The disadvantage is that converting to and from an array will be somewhat expensive, particularly as the size of the parameter array grows.
  2. Use a wrapper class that looks something like this:
```python
import jax
from functools import partial
import dataclasses

@dataclasses.dataclass(frozen=True)
class HashableArrayWrapper:
  val: jax.Array
  def __hash__(self):
    return id(self.val)
  def __eq__(self, other):
    return isinstance(other, HashableArrayWrapper) and id(self.val) == id(other.val)

@partial(jax.jit, static_argnums=0)
def f(x):
  val = x.val
  return val ** 2

x = jax.numpy.arange(5)
f(HashableArrayWrapper(x))
# Array([ 0,  1,  4,  9, 16], dtype=int32)
```

The disadvantage here is that any time you change val, it will needlessly trigger a re-compilation. You also will not be able to use elements of val statically for control flow, because array indexing is a traced operation. I honestly cannot think of any advantages to this approach, and I would not recommend this as a solution over just using your parameter array directly as a dynamic variable. But if it's important to you that your parameters are stored in an array, and that the array be treated as static by JIT, then this is probably the best way to do it.
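For the first option (tuples of floats), a minimal sketch might look like this. The function `fir_filter` and its all-zero shortcut are hypothetical, chosen only to show that the tuple's elements are ordinary Python floats at trace time and can drive Python control flow, while the tuple is converted back to an array for the actual computation.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)
def fir_filter(x, coeffs: tuple):
    # `coeffs` is a hashable tuple of Python floats, so it is a valid static
    # argument, and its elements are concrete at trace time.
    if all(c == 0.0 for c in coeffs):  # trace-time control flow is fine here
        return jnp.zeros_like(x)
    kernel = jnp.asarray(coeffs)  # tuple -> array, embedded as a constant
    return jnp.convolve(x, kernel, mode="same")


params = jnp.array([0.25, 0.5, 0.25])
static_params = tuple(float(c) for c in params)  # array -> hashable tuple
x = jnp.arange(8.0)
y = fir_filter(x, static_params)
```

Each distinct tuple value still triggers its own compilation, which is the intended behaviour for static arguments.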


3 participants