
lax.while_loop calls body_fun only once #1708

Closed
mgbukov opened this issue Nov 17, 2019 · 6 comments
Labels
question Questions for the JAX team

Comments

mgbukov commented Nov 17, 2019

In the example below, I increment u[0] every time body_fun is called inside the while_loop. The loop should run for 10 iterations, but judging from the final value of u[0], body_fun seems to be called only once.

Any suggestions about how I can get the value of u[0] to update are welcome!

from jax import lax

# body_fun closes over the global list u and mutates it as a side effect
def body_fun(val):
    u[0] += 1
    return val + 1

cond_fun = lambda j: j < 10


### python
u = [0]
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

while_loop(cond_fun, body_fun, 0)
print('python', u)  # python [10] -- body_fun ran 10 times

### xla
u = [0]
lax.while_loop(cond_fun, body_fun, 0)
print('xla', u)  # xla [1] -- body_fun appears to run only once
shoyer (Collaborator) commented Nov 17, 2019

You cannot update array values in place with JAX. Instead, use index_update or index_add:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-In-Place-Updates
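
For example, a minimal sketch of the functional-update pattern those docs describe, using index_add on a jnp array in place of the Python list (the array and values here are just stand-ins):

import jax.numpy as jnp
from jax import ops

u = jnp.zeros(1)                        # a JAX array instead of a Python list
u2 = ops.index_add(u, ops.index[0], 1)  # returns a NEW array with u[0] incremented
print(u)   # [0.] -- the original array is unchanged
print(u2)  # [1.]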

shoyer (Collaborator) commented Nov 17, 2019

I'll also note that JAX's control flow operations like lax.while_loop use tracing, so the body of the loop is indeed only evaluated once. This is intentional -- if every loop iteration were evaluated in Python, the loop could not be compiled by XLA.

There's more info about this in the docs:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow
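
A quick way to see the tracing behaviour (an illustration, not part of the original report): a Python-side print inside body_fun fires once, at trace time, even though the compiled loop runs ten iterations:

from jax import lax

def body_fun(val):
    print('tracing body_fun with', val)  # runs once, during tracing
    return val + 1

result = lax.while_loop(lambda j: j < 10, body_fun, 0)
print(result)  # 10 -- the compiled loop still ran ten iterations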

mattjj (Collaborator) commented Nov 17, 2019

To put a fine point on it: to express something like this functionally, you might write

from jax import lax

def body_fun(val):
    return val + 1

cond_fun = lambda j: j < 10

u = 0
u = lax.while_loop(cond_fun, body_fun, u)  # the counter flows through the loop
print('xla', u)  # xla 10

That is, you can't rely on side effects: you must make data flow explicit.
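
And to keep the original u around explicitly, you could thread it through the loop state (a sketch reusing the names from the example above):

from jax import lax

# carry both the loop counter and the accumulator in the loop state
def body_fun(state):
    j, u = state
    return (j + 1, u + 1)

cond_fun = lambda state: state[0] < 10

_, u = lax.while_loop(cond_fun, body_fun, (0, 0))
print('xla', u)  # xla 10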

WDYT?

mattjj added the question (Questions for the JAX team) label Nov 17, 2019
mgbukov (Author) commented Nov 18, 2019

I see, thanks for the clarification!

My original motivation was the following: I have a model approximated by a neural network. For training purposes, I need to evaluate this model on a large dataset (my neural net encodes a generative model, and I need a huge number of samples to beat down the Monte Carlo fluctuations when evaluating the cost function). I put it on the GPU and quickly started running into memory errors, because the GPU has limited memory and the dataset is too large. So I decided to split the task into mini-batches and do the evaluation as follows:

for j in range(num_batches):
    batch, batch_idx = next(batches)
    log_psi[batch_idx], phase_psi[batch_idx] = evaluate_NN(params, batch)

I was hoping to get some speedup for this from XLA, but I guess there's no free lunch :/

I think the speed bottleneck here comes from transferring data to the GPU because I see that the calculation is comparably fast on a CPU without using mini-batches.

shoyer (Collaborator) commented Nov 18, 2019

You can probably still do this with XLA; you'll just have to be a little more creative about how you accumulate results. I would suggest looking into lax.scan.
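
For instance, a minimal sketch of the scan pattern, assuming evaluate_NN is a pure function of (params, batch) and the mini-batches are pre-stacked into one array (the stand-in evaluate_NN below is hypothetical, only there to make the snippet run):

from jax import lax
import jax.numpy as jnp

def evaluate_NN(params, batch):          # stand-in for the real network
    return batch * params, batch + params

def evaluate_all(params, stacked_batches):
    # stacked_batches: shape (num_batches, batch_size, ...)
    def step(carry, batch):
        log_psi, phase_psi = evaluate_NN(params, batch)
        return carry, (log_psi, phase_psi)   # scan stacks per-batch outputs

    _, (log_psi, phase_psi) = lax.scan(step, 0., stacked_batches)
    # collapse the (num_batches, batch_size) leading axes back into one
    return log_psi.reshape(-1), phase_psi.reshape(-1)

data = jnp.arange(12.).reshape(4, 3)     # 4 mini-batches of 3 samples
log_psi, phase_psi = evaluate_all(2., data)
print(log_psi.shape)  # (12,)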

mattjj (Collaborator) commented Nov 18, 2019

I'm going to close this because I think the specific question was answered, but @mgbukov please open another issue if you have follow-up questions.
