lax.while_loop calls body_fun only once #1708
You cannot update array values in-place with JAX. Instead use …
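As an illustration (the function recommended above is cut off in this copy, so this is not necessarily it), current JAX versions expose indexed updates through the `.at[]` property, which returns a new array instead of modifying the existing one:

```python
import jax.numpy as jnp

u = jnp.zeros(3)

# Writing u[0] += 1 would raise a TypeError: JAX arrays are immutable.
# .at[...].add(...) returns an updated *copy*; the original u is untouched.
v = u.at[0].add(1)

print(u)  # [0. 0. 0.]
print(v)  # [1. 0. 0.]
```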
I'll also note that JAX's control flow operations like … There's more info about this in the docs: …
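A minimal sketch of why the body appears to run only once: `lax.while_loop` traces `body_fun` (typically a single time) to build a compiled loop, so a Python side effect such as appending to a list happens at trace time, while the loop itself still runs all of its iterations. The `calls` list is purely illustrative.

```python
from jax import lax

calls = []  # records Python-level invocations of body_fun

def body_fun(val):
    # This Python side effect runs while JAX *traces* body_fun to build
    # the loop body, not on every iteration of the compiled loop.
    calls.append("traced")
    return val + 1

def cond_fun(val):
    return val < 10

result = lax.while_loop(cond_fun, body_fun, 0)

print(int(result))  # 10: the loop really did run ten iterations
print(len(calls))   # typically 1: body_fun was traced, not called per iteration
```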
To put a fine point on it, to express something like this functionally you might write:

```python
from jax import lax

def body_fun(val):
  return val + 1

cond_fun = lambda j: j < 10

u = 0
u = lax.while_loop(cond_fun, body_fun, u)
print('xla', u)
```

That is, you can't rely on side effects: you must make data flow explicit. WDYT?
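Applying the same idea to the original question about `u[0]`: a minimal sketch, assuming `u` is a small integer array (its shape and dtype here are illustrative), that carries the whole array through the loop and returns an updated copy on each iteration:

```python
import jax.numpy as jnp
from jax import lax

def cond_fun(u):
    # Keep looping while the first entry is below 10.
    return u[0] < 10

def body_fun(u):
    # Return a *new* array with u[0] incremented; no in-place mutation.
    return u.at[0].add(1)

u0 = jnp.zeros(3, dtype=jnp.int32)
u = lax.while_loop(cond_fun, body_fun, u0)
print(u)  # u[0] is now 10; the other entries are unchanged
```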
I see, thanks for the clarification! My original motivation was the following: I have a model approximated by a neural network. For training purposes, I need to evaluate this model on a large dataset (my neural net encodes a generative model, and I need a huge number of samples to beat down Monte Carlo fluctuations when evaluating the cost function). I put it on the GPU and quickly started running into memory errors, because the GPU has limited memory and the dataset is too large. So I decided to split the task into mini-batches and do the evaluation as follows: …
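The code that followed did not survive in this copy of the thread. As a rough sketch of the general pattern (not the original code), assuming a hypothetical `model_fn` and a host-side NumPy dataset, one can keep the data on the host, `jax.jit` a per-batch function, and accumulate the result in Python:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the neural-network model; the real model is not shown.
def model_fn(params, x):
    return jnp.tanh(x @ params)

@jax.jit
def batch_sum(params, x_batch):
    # Sum of model outputs for one mini-batch; compiled once per batch shape.
    return jnp.sum(model_fn(params, x_batch))

def evaluate_in_batches(params, data, batch_size):
    total = 0.0
    for start in range(0, data.shape[0], batch_size):
        # Only this slice is transferred to the device at each step.
        x_batch = jnp.asarray(data[start:start + batch_size])
        total += float(batch_sum(params, x_batch))
    return total / data.shape[0]

rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 4)).astype(np.float32)  # stays in host memory
params = jnp.ones((4,), dtype=jnp.float32)
print(evaluate_in_batches(params, data, batch_size=256))
```

Note that a shorter final batch (232 rows here) triggers one extra compilation, and `float(...)` synchronizes with the device on every batch.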
I was hoping to get some speedup for this from XLA, but I guess there's no free lunch :/ I think the speed bottleneck here comes from transferring data to the GPU, because I see that the calculation is comparably fast on a CPU without using mini-batches.
You can probably still do this with XLA; you'll just have to be a little more creative with how you accumulate results. I would suggest looking into …
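The function recommended here is cut off. One way to keep the accumulation inside XLA (a swapped-in choice, not necessarily what was suggested) is `lax.scan` with a running sum as the carry. This sketch assumes the whole dataset fits in device memory and only bounds the per-batch intermediates; if the raw data itself does not fit, a Python loop that transfers one batch at a time is still needed. `model_fn` is a hypothetical stand-in for the model.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical model; not the reporter's network.
def model_fn(params, x):
    return jnp.tanh(x @ params)

@jax.jit
def scanned_mean(params, batches):
    # batches has shape (num_batches, batch_size, features).
    def step(running_sum, x_batch):
        # Only one batch's intermediates are live at a time inside the loop.
        return running_sum + jnp.sum(model_fn(params, x_batch)), None

    init = jnp.zeros((), dtype=jnp.float32)
    total, _ = lax.scan(step, init, batches)
    return total / (batches.shape[0] * batches.shape[1])

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (1024, 4))
params = jnp.ones((4,))
batches = data.reshape(4, 256, 4)  # assumes the batch size divides the dataset
print(scanned_mean(params, batches))
```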
I'm going to close this because I think the specific question was answered, but, @mgbukov, please open another issue if you have follow-up questions.
In the example below, I increment the value of `u[0]` every time `body_fun` is called inside the `while_loop`; `while_loop` should break after 10 iterations, but it seems to do so after the first iteration. Any suggestions about how I can get the value of `u[0]` to update are welcome!