using JaxprTracer as index of list in lax.fori_loop #2962

Open
tvieijra opened this issue May 5, 2020 · 2 comments
Labels
enhancement New feature or request

Comments

@tvieijra

tvieijra commented May 5, 2020

I am working on an application where I have a list of numpy tensors with different shapes. I want to loop over this list and contract the tensors into a single scalar, possibly also transforming each tensor in some way at every step of the loop. I can compile this code with jit and it works fine, but compilation takes exceedingly long for large lists. I want to reduce the compilation time using lax.fori_loop, but indexing the list does not seem to be possible.

Here is a piece of toy code:

import numpy as onp
import jax.numpy as np
from jax import jit, random

@jit
def trace(list_of_tensors):
    L = len(list_of_tensors)
    edge = np.ones((1,))

    # Plain Python loop: jit unrolls it, so compile time grows with L.
    for i in range(L):
        edge = np.einsum('i,ij->j', edge, list_of_tensors[i])

    return edge

for i in range(5):
    key = random.PRNGKey(onp.random.randint(1,1000000))
    list_of_tensors = [random.normal(key, (1,2)), random.normal(key, (2,4)), random.normal(key, (4,8)), random.normal(key, (8,1))]

    print(trace(list_of_tensors))

I want to replace the explicit for loop in trace with a lax.fori_loop, like this:

import numpy as onp
import jax.numpy as np
from jax import jit, lax, random

@jit
def trace(list_of_tensors):
    L = len(list_of_tensors)
    edge = np.ones((1,))

    list_of_tensors, edge = lax.fori_loop(0, L, loop_body, (list_of_tensors, edge))

    return edge

def loop_body(i, args):
    # i is a tracer here, so indexing the Python list args[0] raises the TypeError below.
    edge = np.einsum('i,ij->j', args[1], args[0][i])

    return (args[0], edge)

for i in range(5):
    key = random.PRNGKey(onp.random.randint(1,1000000))
    list_of_tensors = [random.normal(key, (1,2)), random.normal(key, (2,4)), random.normal(key, (4,8)), random.normal(key, (8,1))]

    print(trace(list_of_tensors))

When running this code, I get the error "TypeError: list indices must be integers or slices, not JaxprTracer", because I am using the loop index to index a Python list. I tried casting the loop index to a jax.ops.index or to an integer, but neither works. Is there another way to use the loop index as an index into the list?

@skye
Member

skye commented May 5, 2020

To slightly sidestep your question, check out lax.scan. I think you should be able to express your original for-loop more easily using scan than fori_loop.

To more directly answer your question, there's no way to use the index from a fori_loop to index a list. Basically, the body of a fori_loop needs to be expressible as a single traced jax computation, and since jax tracers don't "see" the indexing of a regular Python list, it can't trace the list access properly. It would work if you were indexing a jax array instead of a Python list.
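As a minimal sketch of that last point, the version below indexes a single stacked array with the loop index. It assumes every tensor has the same shape, which is not the case in the original question; the names `chain` and `stacked` are hypothetical, and `jnp` is used here for `jax.numpy`.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def chain(stacked, edge):
    # stacked: (n, d, d) array of same-shape matrices; edge: (d,) vector.
    def body(i, e):
        # stacked[i] is a traced dynamic index into a jax array, which is allowed.
        return jnp.einsum('i,ij->j', e, stacked[i])

    return lax.fori_loop(0, stacked.shape[0], body, edge)

# Four copies of 2 * identity, so each step doubles the vector.
stacked = jnp.stack([2.0 * jnp.eye(3) for _ in range(4)])
result = chain(stacked, jnp.ones(3))
```

Because the loop body is traced once rather than unrolled, the compiled program does not grow with the number of matrices.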

@mattjj
Collaborator

mattjj commented May 7, 2020

Interestingly, this isn't just a tracing issue: in XLA HLO there's no way to use a dynamic value (like a loop iteration count) to index into a tuple (i.e. the only product type in HLO), in part because tuples can have elements with different shapes and the shape of every intermediate must be decidable in the type system. In other words, there's no dynamic version of GetTupleElement.

Unfortunately that means even if we could trace this computation effectively in Python (which I'm sure we could work out how to do) we don't have a way to lower it to a single compiled computation.

The only option now seems to be to pad and mask things into one shape so that you can stack things into an array. We're working on a transformation jax.mask to handle that automatically for you, and this is a good example use case that wasn't on our radar, but it's still a prototype and not ready for use at the moment.
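A manual version of the pad-and-stack idea might look like the sketch below. It is not jax.mask; the helpers `pad_and_stack` and `trace_padded` are hypothetical names. It assumes the per-step operation is a plain vector-matrix product, where zero padding contributes nothing to the result, so no explicit mask is needed. A nonlinear transformation at each step would generally need masking as well.

```python
import numpy as onp
import jax.numpy as jnp
from jax import jit, lax

def pad_and_stack(tensors):
    # Zero-pad every 2-D tensor to a common (d, d) shape and stack into (n, d, d).
    d = max(max(t.shape) for t in tensors)
    padded = [jnp.pad(t, ((0, d - t.shape[0]), (0, d - t.shape[1]))) for t in tensors]
    return jnp.stack(padded)

@jit
def trace_padded(stacked):
    d = stacked.shape[-1]
    # Padded version of ones((1,)): a one in slot 0, zeros elsewhere.
    edge = jnp.zeros(d).at[0].set(1.0)

    def step(e, m):
        # One contraction per stacked matrix; scan needs no per-step output here.
        return e @ m, None

    edge, _ = lax.scan(step, edge, stacked)
    return edge[0]

# Same shapes as in the question, filled with fixed random values.
rng = onp.random.RandomState(0)
shapes = [(1, 2), (2, 4), (4, 8), (8, 1)]
tensors = [jnp.asarray(rng.normal(size=s)) for s in shapes]
value = trace_padded(pad_and_stack(tensors))

# Unpadded reference computed with plain numpy for comparison.
reference = onp.ones(1)
for t in tensors:
    reference = reference @ onp.asarray(t)
```

The cost is extra memory and arithmetic on the padded zeros, which can be significant when the shapes in the list vary widely.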

I'll self-assign this because I'm working on jax.mask (with several others too), but I likely won't be able to report progress for several weeks or more.
