Error raised when comparing a static argument of a vmapped function after first running the function without vmap #20466
Hi, thanks for the report! This is working as intended: when you pass arrays through a JAX transformation like …

```python
import jax
import jax.numpy as jnp

# SomeStaticType is the class from the original report.
@jax.jit
def inner_func(x, static_arg: SomeStaticType):
    return jnp.sum(x*x + 3*x)

@jax.jit
def outer_func(x, static_arg: SomeStaticType):
    return jax.vmap(inner_func, in_axes=(1, None))(x, static_arg)
```

With this change your code runs without any error. (Side-note: I'd also remove the …)
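For readers without the original MWE, here is a self-contained sketch of the pattern above. The `SomeStaticType` below is a hypothetical stand-in (a dataclass registered as a pytree with a single array field, `scale`), not the reporter's actual class, and the `scale` factor in `inner_func` is added only so the argument is used. Because the object is passed as an ordinary dynamic argument and broadcast with `in_axes=None`, its leaves are traced and jit never needs to hash or compare the object itself:

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the reporter's class: a pytree holding one array.
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class SomeStaticType:
    scale: jax.Array

    def tree_flatten(self):
        # (children that JAX traces, auxiliary static data)
        return (self.scale,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def inner_func(x, static_arg: SomeStaticType):
    # static_arg arrives with traced leaves; it is not used as a cache key.
    return static_arg.scale * jnp.sum(x * x + 3 * x)


@jax.jit
def outer_func(x, static_arg: SomeStaticType):
    # Map over the columns of x; pass static_arg unmapped to every call.
    return jax.vmap(inner_func, in_axes=(1, None))(x, static_arg)


x = jnp.arange(12.0).reshape(4, 3)
params = SomeStaticType(scale=jnp.array(2.0))
out = outer_func(x, params)
print(out.shape)  # (3,)
```

Note that `dataclasses.dataclass` with its default `eq=True` makes instances unhashable, which is fine here precisely because the object is never marked static.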
Thanks for the reply. I don't understand the logic behind …
Yes. Static arguments must be hashable, and JAX arrays are not hashable, so they cannot be used as static arguments. The reasoning behind this is that cached JIT compilations are indexed by a hash table with keys built from the hashes of static arguments along with the static attributes of any array arguments. If we have non-hashable static arguments, then we have no way of knowing whether a particular function call should be re-compiled.
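The caching behavior described above can be observed with a minimal sketch. The function `f` and the `trace_count` counter are illustrative, not part of JAX: a Python side effect inside a jitted function runs only while JAX traces it, so the counter records cache misses. A new static value forces a retrace, and a JAX array is rejected as a static argument because it cannot be hashed (the sketch catches either `TypeError` or `ValueError`, since the exact exception type is an assumption here):

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when f is traced (a cache miss)


@partial(jax.jit, static_argnums=1)
def f(x, n):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return x * n


x = jnp.arange(3.0)
f(x, 2)  # cache miss: trace and compile for n=2
f(x, 2)  # cache hit: same static value, same input shape/dtype
f(x, 3)  # cache miss: a new static value gives a new cache key
count_after_calls = trace_count
print(count_after_calls)  # 2

# JAX arrays are unhashable, so they cannot serve as static arguments.
err_type = None
try:
    f(x, jnp.array(2.0))
except (TypeError, ValueError) as err:
    err_type = type(err)
print(err_type)
```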
I suspect that in any situation in which you think a static array is the right solution, a dynamic array is probably the better solution.
Sort of. Prior to omnistaging, it was possible to mix staged and non-staged JAX computations; now all JAX computations are staged. But the requirement that static data be hashable predates omnistaging. In very early JAX versions it was possible to mark JAX arrays as static, with their object ID implicitly used as their hash, but this frequently led to surprising recompilations, so we made it an error.
Thanks for the further reply! For OP's original question: is it true that a "static argument of pytree class" can't be used with …

Please see my other responses inline below.
Thanks for telling us this. We currently depend heavily on this "workaround", i.e., wrapping …
Sometimes we need to store parameters such as those of a linear model or a filter. These parameters are vectors of floats that do not change during training (they are hyperparameters), so it is more convenient to store them in a static …
Note that "static" and "constant" are not identical concepts in JAX. It's fine to have a constant array that is dynamic, and you shouldn't see any performance penalties in this case. If you have constant arrays, you should not go out of your way to mark them as static, especially if it requires workarounds like wrapping them in a hashable class.
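As a minimal sketch of a constant-but-dynamic array (the filter function, coefficient values, and `n_traces` counter are hypothetical): the coefficients are passed as ordinary jit arguments, so swapping in new values with the same shape and dtype reuses the compiled function instead of retracing it.

```python
import jax
import jax.numpy as jnp

n_traces = 0  # counts traces of apply_filter (cache misses)


@jax.jit
def apply_filter(signal, coeffs):
    global n_traces
    n_traces += 1  # runs only while tracing
    # coeffs is a regular (dynamic) argument: only its shape and dtype,
    # not its values, are part of the compilation cache key.
    return jnp.convolve(signal, coeffs, mode="same")


signal = jnp.linspace(0.0, 1.0, 16)
coeffs_a = jnp.array([0.25, 0.5, 0.25])
coeffs_b = jnp.array([0.1, 0.8, 0.1])

out_a = apply_filter(signal, coeffs_a)  # traced and compiled once
out_b = apply_filter(signal, coeffs_b)  # new values, same shape/dtype: cached
print(n_traces)  # 1
```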
Thanks, understood. We feel that making such parameters "static" is better because …
Well, I disagree on "conceptually clear" (as I mentioned, "static" is a different concept than "constant"). And doing conditionals on such values is somewhat fraught, because element access is a traced operation under JIT. But if you want to treat arrays as static values, I'd recommend one of two approaches:
```python
import jax
from functools import partial
import dataclasses

@dataclasses.dataclass(frozen=True)
class HashableArrayWrapper:
    val: jax.Array

    def __hash__(self):
        return id(self.val)

    def __eq__(self, other):
        return isinstance(other, HashableArrayWrapper) and id(self.val) == id(other.val)

@partial(jax.jit, static_argnums=0)
def f(x):
    val = x.val
    return val ** 2

x = jax.numpy.arange(5)
f(HashableArrayWrapper(x))
# Array([ 0,  1,  4,  9, 16], dtype=int32)
```

The disadvantage here is that any time you change …
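Because `HashableArrayWrapper` hashes and compares by `id`, jit's cache hits depend on object identity rather than array contents. A minimal sketch of that behavior, repeating the class so the snippet runs on its own (the `n_traces` counter is our own illustration): rewrapping the same array object reuses the cache, while a new array with identical values triggers a retrace.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class HashableArrayWrapper:
    val: jax.Array

    def __hash__(self):
        return id(self.val)

    def __eq__(self, other):
        return isinstance(other, HashableArrayWrapper) and id(self.val) == id(other.val)


n_traces = 0  # counts traces of f (cache misses)


@partial(jax.jit, static_argnums=0)
def f(x):
    global n_traces
    n_traces += 1  # runs only while tracing
    return x.val ** 2


a = jnp.arange(5)
f(HashableArrayWrapper(a))  # cache miss: first trace
f(HashableArrayWrapper(a))  # cache hit: same underlying object, same id
b = jnp.arange(5)           # equal values, but a distinct array object
f(HashableArrayWrapper(b))  # cache miss: different id, so f is traced again
print(n_traces)  # 2
```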
Description
After running a single jitted instance of our JAX function, we found that a subsequent run of the same function under vmap raises a `jax.errors.TracerBoolConversionError`. Specifically, this occurs when our static_arg is compared for equality: the class was traced during the prior single run, which leads to the boolean conversion error.
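`TracerBoolConversionError` is what JAX raises whenever a traced value must become a concrete Python `bool`, for example in an `if` statement or when a comparison of traced arrays is converted to `bool`. A minimal sketch with hypothetical functions, unrelated to the reporter's MWE, showing the failure under `jit` and one traced alternative, `jnp.where`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_if_positive(coeffs, x):
    # Python `if` calls bool() on coeffs[0] > 0, which is a tracer under jit.
    if coeffs[0] > 0:
        return x * coeffs[0]
    return x


@jax.jit
def scale_if_positive_traced(coeffs, x):
    # jnp.where computes both branches and selects inside the traced computation.
    return jnp.where(coeffs[0] > 0, x * coeffs[0], x)


coeffs = jnp.array([2.0, 1.0])
x = jnp.arange(3.0)

raised = False
try:
    scale_if_positive(coeffs, x)
except jax.errors.TracerBoolConversionError:
    raised = True

result = scale_if_positive_traced(coeffs, x)
print(raised)  # True
print(result)  # [0. 2. 4.]
```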
The MWE code is enclosed below: