Equivalent to TensorArray for while_loop #3106
Most likely, you're looking for …; for example, if you were writing an RNN, that's what you would want to use. Does that help?
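For the RNN case, one way to collect a result on every step is `jax.lax.scan`, which the thread mentions further down. The sketch below is a minimal illustration under assumptions: the toy tanh cell, the shapes, and the name `run_rnn` are all hypothetical. `scan` threads a carry (the hidden state) through the loop and stacks each step's output along a new leading axis, so no growable array is needed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def run_rnn(W_h, W_x, h0, xs):
    """Toy tanh RNN (hypothetical): returns the final state and all states.

    xs has shape (T, input_dim); lax.scan stacks the per-step outputs
    along a new leading axis of length T.
    """
    def step(h, x):
        h_new = jnp.tanh(W_h @ h + W_x @ x)
        # (carry passed to the next step, output collected for this step)
        return h_new, h_new

    h_final, hs = lax.scan(step, h0, xs)
    return h_final, hs


hidden, inp, T = 4, 3, 5
W_h = jnp.eye(hidden) * 0.5
W_x = jnp.ones((hidden, inp)) * 0.1
h0 = jnp.zeros(hidden)
xs = jnp.ones((T, inp))

h_final, hs = jax.jit(run_rnn)(W_h, W_x, h0, xs)
print(hs.shape)  # (5, 4)
```

The whole loop compiles to a single XLA loop, and the stacked output `hs` has a static shape `(T, hidden)`.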
TensorArray lets you grow tensors in each loop iteration, but that isn't possible in XLA (i.e., under jit), which requires array shapes to be static. Instead, you should preallocate the full-sized arrays you need and update them using primitives such as while_loop and scan. This may seem extremely inefficient, but it is actually only a 2x cost in memory (the difference between the upper triangular half of a matrix and the full matrix).
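A minimal sketch of the preallocate-then-update pattern, assuming a fixed buffer size known at trace time (`MAX_STEPS` and `collect_powers` are hypothetical names). The buffer is created once at full size and written row by row with the functional `.at[i].set(...)` update inside `lax.while_loop`; under jit, XLA can typically perform such updates in place.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 8  # hypothetical; the buffer size must be static under jit


@jax.jit
def collect_powers(x):
    # Preallocate the full-size output, since XLA needs static shapes.
    buf = jnp.zeros((MAX_STEPS,) + x.shape, dtype=x.dtype)

    def cond(state):
        i, _, _ = state
        return i < MAX_STEPS

    def body(state):
        i, cur, buf = state
        buf = buf.at[i].set(cur)  # functional update of row i
        return i + 1, cur * 2, buf

    _, _, buf = lax.while_loop(cond, body, (0, x, buf))
    return buf  # row k holds x * 2**k


out = collect_powers(jnp.array([1.0, 3.0]))
print(out.shape)  # (8, 2)
```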
Thanks for the pointer to
@shoyer, re-opening this thread to ask a related question. Is there a way to have both a dynamic halting condition and a dynamically allocated result? Basically, I need scan-like behaviour in that it collects results, but with a dynamic termination condition as in a while_loop. I was wondering how you implement scan under the hood, and noticed that in TF it's done with a while_loop. In JAX, are you essentially just preallocating a result array tree and then using
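One way to combine a dynamic halting condition with collected results is to preallocate a buffer for an assumed upper bound on the number of steps, stop the `lax.while_loop` on whichever comes first (the dynamic condition or a full buffer), and return the number of valid entries alongside the buffer. This is a sketch, not an official recipe: `MAX_LEN`, `collatz_trajectory`, and the Collatz update are hypothetical stand-ins for a real loop body.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_LEN = 16  # hypothetical upper bound on the number of iterations


@jax.jit
def collatz_trajectory(n):
    """Record Collatz iterates of n until reaching 1 or filling the buffer.

    Returns a fixed-size buffer plus the number of valid entries.
    """
    buf = jnp.zeros(MAX_LEN, dtype=jnp.int32)

    def cond(state):
        i, cur, _ = state
        # Dynamic halting: stop at 1, but never write past the buffer.
        return (i < MAX_LEN) & (cur != 1)

    def body(state):
        i, cur, buf = state
        buf = buf.at[i].set(cur)
        nxt = jnp.where(cur % 2 == 0, cur // 2, 3 * cur + 1)
        return i + 1, nxt, buf

    count, _, buf = lax.while_loop(cond, body, (0, n, buf))
    return buf, count


buf, count = collatz_trajectory(jnp.int32(6))
# Trimming to the valid prefix has a data-dependent shape, so do it outside jit.
trajectory = buf[: int(count)]
```

Here `trajectory` holds 6, 3, 10, 5, 16, 8, 4, 2; entries past `count` stay zero.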
Looks like the _scan_impl in JAX also uses a
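To make the "preallocate, then fill inside a while_loop" idea concrete, here is a toy scan built on `lax.while_loop`. It is a hypothetical sketch (`scan_via_while` is not JAX's actual `_scan_impl`) and, unlike the real `lax.scan`, it only handles a single-array `xs` and a single-array per-step output `y`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_via_while(f, init, xs):
    """Toy scan (hypothetical): f(carry, x) -> (carry, y).

    Preallocates the stacked output, then fills one row per iteration.
    """
    length = xs.shape[0]
    # Abstractly evaluate f once to learn the shape and dtype of y.
    _, y_spec = jax.eval_shape(f, init, xs[0])
    ys = jnp.zeros((length,) + y_spec.shape, y_spec.dtype)

    def cond(state):
        i, _, _ = state
        return i < length

    def body(state):
        i, carry, ys = state
        carry, y = f(carry, xs[i])
        return i + 1, carry, ys.at[i].set(y)

    _, carry, ys = lax.while_loop(cond, body, (0, init, ys))
    return carry, ys


def cumsum_step(c, x):
    c = c + x
    return c, c


xs = jnp.arange(1.0, 5.0)
carry, ys = scan_via_while(cumsum_step, jnp.float32(0.0), xs)
ref_carry, ref_ys = lax.scan(cumsum_step, jnp.float32(0.0), xs)
```

Both versions produce a final carry of 10 and stacked outputs 1, 3, 6, 10.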
Clear
Extending the question: without something like TensorArray, how can this be efficient with backprop? Wouldn't it keep a copy of
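On the backprop side, one relevant JAX fact: reverse-mode `jax.grad` through `lax.while_loop` (a loop with a dynamic trip count) raises an error, while `lax.scan` is reverse-differentiable because its length is static; scan stores the per-step residuals the backward pass needs, so that memory grows with the number of steps (`jax.checkpoint` can trade recomputation for memory). The sketch below, with a hypothetical `loss`, differentiates through a scan that collects per-step outputs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss(w, steps=3):
    def step(c, _):
        c = c * w
        return c, c  # carry, plus a collected per-step output

    _, outs = lax.scan(step, jnp.float32(1.0), None, length=steps)
    return outs.sum()  # w + w**2 + w**3 for steps=3


g = jax.grad(loss)(jnp.float32(2.0))
print(g)  # 17.0, i.e. 1 + 2*w + 3*w**2 at w = 2
```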
A large class of algorithms requires maintaining an arbitrary-length list of arrays across loop iterations.
What is the recommended way to collect a list of tensors (all of the same shape and dtype) in a while_loop?
This should be the same notion as TensorArray in TensorFlow.
One obvious way is to use Python while loops, but what I'm looking for is while_loop support.
Is there a way to do it without unrolling the graph?
Also, if it can only be done with Python while loops, what are the constraints and performance effects? Is it still jittable? Is it feasible for really long loops?
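On the Python-loop question: under `jax.jit`, a Python `for` loop with a static trip count is unrolled at trace time, so compile time and program size grow with the number of iterations. A Python `while` whose condition depends on a traced value raises an error under jit, so data-dependent stopping needs `lax.while_loop`. The sketch below (function names are hypothetical) contrasts an unrolled Python loop with `lax.scan`, which traces the body once; both return the same stacked result.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnums=1)
def unrolled(x, n):
    # Python loop: traced n times, so the compiled program has n copies of the body.
    outs = []
    for _ in range(n):
        x = jnp.sin(x)
        outs.append(x)
    return jnp.stack(outs)


@partial(jax.jit, static_argnums=1)
def rolled(x, n):
    # lax.scan: the body is traced once, regardless of n.
    def step(c, _):
        c = jnp.sin(c)
        return c, c

    _, outs = lax.scan(step, x, None, length=n)
    return outs


x = jnp.array([0.5, 1.0])
a = unrolled(x, 10)
b = rolled(x, 10)
```

For really long loops, the scan or while_loop form is usually the practical choice, since the unrolled version's trace and compile cost grows linearly with `n`.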