I am working on an application where I have a list of numpy tensors with different shapes. I want to loop over this list to multiply these tensors together into a single scalar, possibly also transforming the tensor in some way at every step in the loop. I can compile this code with jit and it works fine, but the compilation time becomes exceedingly long for large lists. I want to reduce the compilation time using lax.fori_loop, but indexing the list seems not possible.
Here is a piece of toy code:
```python
import numpy as onp
import jax.numpy as np
from jax import jit, random

@jit
def trace(list_of_tensors):
    L = len(list_of_tensors)
    edge = np.ones(1)  # start from a length-1 vector of ones
    for i in range(L):
        # contract the current edge vector with the next tensor
        edge = np.einsum('i,ij->j', edge, list_of_tensors[i])
    return edge

for i in range(5):
    key = random.PRNGKey(onp.random.randint(1, 1000000))
    list_of_tensors = [random.normal(key, (1, 2)), random.normal(key, (2, 4)),
                       random.normal(key, (4, 8)), random.normal(key, (8, 1))]
    print(trace(list_of_tensors))
```
I want to replace the explicit for loop in trace with a lax.fori_loop.
When running this code, I get the error TypeError: list indices must be integers or slices, not JaxprTracer because I want to use the loop index as an index of a list. I tried casting the loop index to a jax.ops.index or an integer but none of these work. Is there another way to use the loop index as an index of the list?
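The fori_loop snippet from the original post was not preserved, so the following is a hypothetical minimal reconstruction of the pattern described, not the poster's exact code. Inside the body, the loop index i is a tracer rather than a Python int, so indexing a Python list with it fails; the exact exception type and message depend on the JAX version (older releases reported JaxprTracer). Note that this particular loop would fail even with array-style indexing, because the carried edge vector changes shape at every step and fori_loop requires a fixed carry shape.

```python
import jax.numpy as jnp
from jax import lax

def trace_fori(list_of_tensors):
    def body(i, edge):
        # `i` is a traced value here, not a Python int, so list indexing fails
        return jnp.einsum('i,ij->j', edge, list_of_tensors[i])
    return lax.fori_loop(0, len(list_of_tensors), body, jnp.ones(1))

tensors = [jnp.ones((1, 2)), jnp.ones((2, 4)), jnp.ones((4, 8)), jnp.ones((8, 1))]
try:
    trace_fori(tensors)
    failed = False
except Exception as err:  # exact error type/message depends on the JAX version
    failed = True
    print(type(err).__name__, err)
```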
To slightly sidestep your question, check out lax.scan. I think you should be able to express your original for-loop more easily using scan than fori_loop.
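A minimal sketch of the scan formulation, assuming (unlike the original list) that all tensors share one shape (D, D) so they can be stacked along a leading axis; lax.scan iterates over that axis and threads the edge vector through as the carry. The name trace_scan is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

@jax.jit
def trace_scan(stacked):
    # stacked has shape (L, D, D): L same-shape matrices along a leading axis
    def step(edge, mat):
        # carry: current edge vector; mat: next matrix in the sequence
        return jnp.einsum('i,ij->j', edge, mat), None
    edge0 = jnp.ones(stacked.shape[1])
    edge, _ = lax.scan(step, edge0, stacked)
    return edge

key = random.PRNGKey(0)
stacked = random.normal(key, (6, 4, 4))
result = trace_scan(stacked)
```

Because the step function is traced once regardless of L, compile time should not grow with the number of tensors the way the unrolled Python loop does.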
To more directly answer your question, there's no way to use the index from a fori_loop to index a list. Basically, the body of a fori_loop needs to be expressible as a single traced jax computation, and since jax tracers don't "see" the indexing of a regular Python list, it can't trace the list access properly. It would work if you were indexing a jax array instead of a Python list.
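To illustrate the last point: a sketch, assuming the tensors have already been stacked into one jax array of same-shape matrices, where fori_loop indexes that array with the traced loop index. The name trace_fori_array is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def trace_fori_array(stacked):
    def body(i, edge):
        # indexing a jax array with a traced index is allowed (a dynamic slice)
        return jnp.einsum('i,ij->j', edge, stacked[i])
    return lax.fori_loop(0, stacked.shape[0], body, jnp.ones(stacked.shape[1]))

# four scaled identity matrices: 1*I, 2*I, 3*I, 4*I
stacked = jnp.stack([jnp.eye(3) * (k + 1) for k in range(4)])
result = trace_fori_array(stacked)  # each entry equals 1*2*3*4 = 24
```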
Interestingly, this isn't just a tracing issue: in XLA HLO there's no way to use a dynamic value (like a loop iteration count) to index into a tuple (i.e. the only product type in HLO), in part because tuples can have elements with different shapes and the shape of every intermediate must be decidable in the type system. In other words, there's no dynamic version of GetTupleElement.
Unfortunately that means even if we could trace this computation effectively in Python (which I'm sure we could work out how to do) we don't have a way to lower it to a single compiled computation.
The only option now seems to be to pad and mask things into one shape so that you can stack things into an array. We're working on a transformation jax.mask to handle that automatically for you, and this is a good example use case that wasn't on our radar, but it's still a prototype and not ready for use at the moment.
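A minimal sketch of doing the padding by hand (not using jax.mask): zero-pad every matrix to a common (D, D) shape, stack them, and scan. Zero padding is exact here because padded rows and columns are zero, so the extra entries of the edge vector stay zero at every step and only entry 0 of the final vector is meaningful. The helper names pad_and_stack and trace_padded are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

def pad_and_stack(tensors):
    # Hypothetical helper: zero-pad every matrix to a common (D, D) shape.
    D = max(max(t.shape) for t in tensors)
    padded = [jnp.pad(t, ((0, D - t.shape[0]), (0, D - t.shape[1]))) for t in tensors]
    return jnp.stack(padded)

@jax.jit
def trace_padded(stacked):
    D = stacked.shape[1]
    # e_0 = [1, 0, ..., 0] is the padded version of the original ones(1) vector
    edge0 = jnp.zeros(D).at[0].set(1.0)
    def step(edge, mat):
        # zero padding keeps entries beyond the true dimension at zero
        return jnp.einsum('i,ij->j', edge, mat), None
    edge, _ = lax.scan(step, edge0, stacked)
    # the true output has shape (1,), so keep only the first entry
    return edge[:1]

keys = random.split(random.PRNGKey(0), 4)
shapes = [(1, 2), (2, 4), (4, 8), (8, 1)]
tensors = [random.normal(k, s) for k, s in zip(keys, shapes)]
stacked = pad_and_stack(tensors)
result = trace_padded(stacked)
```

The trade-off is wasted work on the zero blocks, which may matter if the shapes in the list vary widely.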
I'll self-assign this because I'm working on jax.mask (with several others too), but I likely won't be able to report progress for several weeks or more.