Equivalent to TensorArray for while_loop #3106

Closed
Joshuaalbert opened this issue May 15, 2020 · 7 comments

@Joshuaalbert
Contributor

A large class of algorithms requires maintaining an arbitrary-length list of arrays across loop iterations.
What is the flow for collecting a list of tensors (all of the same shape and dtype) in a while_loop?
This should be the same notion as TensorArray in TensorFlow.
One obvious way is to use Python while loops, but what I'm looking for is while_loop support.
Is there a way to do it without unrolling the graph?
Also, if it can only be done with Python while loops, what are the constraints and performance effects? Is it still jittable? Is it feasible for really long loops?
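On the Python-loop question, a minimal sketch assuming the loop length `n` is known at trace time. Under `jax.jit`, a Python `for` loop over a static `n` runs during tracing and is unrolled, so it is jittable, but the compiled program (and compile time) grows with `n`, which makes very long loops impractical. A Python `while` whose condition depends on a traced value cannot be used inside `jit` at all, because the condition has no concrete value at trace time. The name `collect_unrolled` and the toy update `x -> 0.5 * x + 1` are illustrative only.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def collect_unrolled(x, n):
    # n is static, so this Python loop executes at trace time and is unrolled:
    # the compiled program contains n copies of the loop body.
    ys = []
    for _ in range(n):
        x = 0.5 * x + 1.0
        ys.append(x)
    # Stack the Python list of same-shape arrays into a single (n, ...) array.
    return jnp.stack(ys)


ys = collect_unrolled(jnp.zeros(3), 4)
print(ys.shape)  # (4, 3)
```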

@hawkinsp
Collaborator

Most likely, you're looking for jax.lax.scan:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan

For example, if you were writing an RNN, that's what you would want to use.

Does that help?
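A minimal sketch of the `lax.scan` pattern: the step function returns `(new_carry, y)`, and scan stacks the per-iteration `y` values along a new leading axis, which plays the role of a TensorArray whose length is fixed by `xs`. The cumulative-sum step is a toy placeholder, not an RNN cell.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # carry: state threaded from one iteration to the next
    # x: the current slice of xs along its leading axis
    new_carry = carry + x
    y = new_carry * 2.0  # per-iteration output; scan stacks these into ys
    return new_carry, y


xs = jnp.arange(5.0)
final, ys = lax.scan(step, jnp.zeros(()), xs)
# final is the last carry; ys has shape (5,), one entry per iteration.
```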

@shoyer
Collaborator

shoyer commented May 15, 2020

TensorArray lets you grow tensors in each loop iteration, but that isn't possible in XLA (i.e., under jit).

Instead, you should preallocate the full-sized arrays you need and update them using primitives such as while_loop and scan. This may seem extremely inefficient, but it is actually only a 2x cost in memory (the difference between the upper triangular half of a matrix and the full matrix).
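A minimal sketch of the preallocation approach, assuming the maximum length `n` is known at trace time: allocate the full result once, then write row `i` on iteration `i` with `.at[i].set` inside `lax.fori_loop`. The name `fill_preallocated` and the toy update `x -> 0.5 * x + 1` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def fill_preallocated(x0):
    n = 6  # static length, known at trace time
    # Preallocate the full (n, *x0.shape) result instead of growing a list.
    buf = jnp.zeros((n,) + x0.shape, x0.dtype)

    def body(i, state):
        x, buf = state
        x = 0.5 * x + 1.0
        buf = buf.at[i].set(x)  # functional update of row i
        return x, buf

    _, buf = lax.fori_loop(0, n, body, (x0, buf))
    return buf


buf = fill_preallocated(jnp.zeros(2))
```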

@Joshuaalbert
Contributor Author

Thanks for the pointer to lax.scan. It solves most problems, except those with a dynamic length. Thanks for the info about XLA too.

@Joshuaalbert
Contributor Author

@shoyer, re-opening this thread to ask another related question. Is there a way to have both a dynamic halting condition and a dynamically allocated result? Basically, I need scan-like behaviour in that it collects results, but with a dynamic termination condition as in a while_loop. I was wondering how you implement scan under the hood, and noticed that in TF it's done with a while_loop. In JAX, are you essentially just preallocating a result array tree and then using dynamic_update_slice?

Joshuaalbert reopened this Jul 14, 2020
@Joshuaalbert
Contributor Author

It looks like _scan_impl in JAX also uses a while_loop, with dynamic_update_index_in_dim. Makes sense. So I'm guessing the answer to my problem is to do something similar.
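Following that approach, a sketch that combines a dynamic halting condition with a preallocated result, using `lax.while_loop` and `lax.dynamic_update_index_in_dim`. It assumes a static upper bound `MAX_LEN` (a hypothetical constant) on the number of iterations; the loop also returns the count `n`, and only the first `n` rows of `buf` hold valid results. The name `collect_until` and the convergence test on `x -> 0.5 * x + 1` are illustrative placeholders.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_LEN = 100  # hypothetical static upper bound on the number of iterations


@jax.jit
def collect_until(x0, tol):
    # Preallocate room for the maximum number of iterations.
    buf = jnp.zeros((MAX_LEN,) + x0.shape, x0.dtype)

    def cond(state):
        i, x, delta, buf = state
        # Dynamic halting condition, capped by the buffer size.
        return (i < MAX_LEN) & (delta > tol)

    def body(state):
        i, x, _, buf = state
        x_new = 0.5 * x + 1.0
        # Write this iteration's result into row i of the buffer.
        buf = lax.dynamic_update_index_in_dim(buf, x_new, i, 0)
        delta = jnp.max(jnp.abs(x_new - x))
        return i + 1, x_new, delta, buf

    init = (jnp.int32(0), x0, jnp.asarray(jnp.inf, x0.dtype), buf)
    n, _, _, buf = lax.while_loop(cond, body, init)
    return buf, n


buf, n = collect_until(jnp.zeros(2), 1e-3)
```

Slicing `buf[:n]` with a traced `n` is not allowed inside `jit` (the result shape would be dynamic), so do it outside, e.g. `buf[:int(n)]`, or keep working with a mask such as `jnp.arange(MAX_LEN) < n`.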

@Joshuaalbert
Contributor Author

Clear

@albertz

albertz commented May 8, 2023

Extending on the question, I wonder: without something like TensorArray, how can this be efficient with backprop? Wouldn't it keep a copy of ys for each iteration for backprop? There is a StackOverflow question with more details.
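One related point for backprop: `lax.while_loop` supports forward-mode differentiation but not reverse-mode (`jax.grad`), so a loop with a dynamic halting condition cannot be backpropagated through directly. A common workaround, sketched here under the assumption of a static upper bound `MAX_LEN` (hypothetical), is to run `lax.scan` for `MAX_LEN` steps and mask out steps past a dynamic `n_steps`, emitting per-step results as scan outputs (`ys`). The masked steps still execute, so the cost is always that of `MAX_LEN` iterations. The names `masked_rollout` and `loss` and the update `x -> a * x` are illustrative only; the sketch shows that the masked scan is reverse-differentiable, not how much memory its backward pass uses.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_LEN = 20  # hypothetical static upper bound on the number of steps


def masked_rollout(a, x0, n_steps):
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def step(x, i):
        # Steps with i >= n_steps leave the state unchanged.
        active = i < n_steps
        x_new = jnp.where(active, a * x, x)
        # Emit the per-step value as a scan output, zeroed when inactive.
        y = jnp.where(active, x_new, 0.0)
        return x_new, y

    x_final, ys = lax.scan(step, x0, jnp.arange(MAX_LEN))
    return x_final, ys


def loss(a, x0, n_steps):
    _, ys = masked_rollout(a, x0, n_steps)
    return jnp.sum(ys)


# With x0 = 1 and n_steps = 3, loss = a + a**2 + a**3.
grad_a = jax.grad(loss)(jnp.float32(0.9), 1.0, 3)
```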

4 participants