Efficiency of scan in backprop; TensorArray? #15906
This touches on lots of interesting points!

**Part 1: what JAX does today.**

First of all, note that JAX doesn't support dynamic shapes. It's not actually possible to implement "naive variant 1" under JAX's model of computation.

Next, as Matt has touched on, `scan` has its own autodiff rule. Let's unpack that. Here's a simple program:

```python
import jax
import jax.lax as lax
import jax.numpy as jnp
def f(x):
    def body(carry, _):
        return carry + 1, carry + 1
    _, out = lax.scan(body, x, xs=None, length=5)
    return jnp.sum(out)
forward_jaxpr = jax.make_jaxpr(f)(1.)
print("forward computation", forward_jaxpr) Now, here's the result:
This shows how JAX interprets our Python program. We see that JAX has recorded the Now let's look at the backward pass: backward_jaxpr = jax.make_jaxpr(jax.grad(f))(1.)
print("forward+backward computation", backward_jaxpr) this gives:
We've differentiated our scan. This produces two scans: one for the forward iteration, one for the backward iteration, and the important bit is that we still haven't specified how we're implementing our It's only once you've done all the autodiff that this actually gets lowered to a while-with-DUS. We can inspect that (using some private internals) like so: import equinox.internal as eqxi
eqxi.primitive_finalisations[lax.scan_p] = jax._src.lax.control_flow.loops._scan_impl
forward_lowered_jaxpr = eqxi.finalise_jaxpr(forward_jaxpr)
print("forward lowered computation", forward_lowered_jaxpr)
print("")
backward_lowered_jaxpr = eqxi.finalise_jaxpr(backward_jaxpr)
print("forward+backward lowered computation", backward_lowered_jaxpr) this produces:
In these lowered jaxprs we can now see the while loop and the dynamic-update-slices, rather than a single `scan` primitive.

So, JAX avoids inefficient memory consumption by doing autodiff at a higher-level representation (the `scan` level), and only lowering to a while-with-DUS afterwards.

**Part 2: what if you tried to do this yourself?**

As you've noticed, you can't do this with a `while_loop`, as that doesn't support reverse-mode autodiff. You could compose a loop with dynamic-update-slices yourself, but then you hit a footgun: the "in-place" update can silently become a copy of the whole buffer. (The most common example of this footgun is when vmap'ing a while, with a body function that uses a scatter. Even without autodiff this will actually incur a needless copy, blech.)

**Part 3: let's do even better!**

We've seen that JAX is smart enough to handle extensive outputs of scans, but that:

1. a dynamic number of steps (`while_loop`) doesn't support reverse-mode autodiff at all;
2. in-place updates inside loops can silently incur full copies.
The good news is that JAX actually has enough tools that we can fix both of these issues ourselves, as a user! (It's unfortunate that these footguns exist in the first place of course -- it'd be nice if JAX fixed these issues itself at some point.)

Take a look at `equinox.internal.while_loop`. This supports reverse-mode autodiff by using a fixed number of checkpoints, and then rematerialising intermediate computations as required. (If you know of the "treeverse" algorithm, then this is an extension to the unbounded-number-of-steps case.) This fixes issue 1. It also supports a `buffers` argument for handling in-place updates to (part of) the carry without the extra copies, which fixes issue 2. There's also a convenience wrapper around this.

Phew, that was rather a lot! Does that all make sense / do you have any follow-up questions?
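(As an aside: in stock JAX you can already get some of this rematerialisation idea by applying `jax.checkpoint` to the scanned function. The following is just a sketch of that general pattern, with a made-up toy body -- it is not the checkpointed while loop described above.)

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Toy body. Wrapping it in jax.checkpoint means its internals are
    # recomputed on the backward pass rather than stored as residuals.
    carry = jnp.sin(carry) * x
    return carry, carry

def loss(z, xs):
    y, ys = jax.lax.scan(jax.checkpoint(step), z, xs)
    return y + jnp.sum(ys)

xs = jnp.arange(1.0, 6.0)
print(jax.grad(loss)(1.0, xs))
```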
Thanks a lot for the very detailed write-up!
Regarding the lack of dynamic shapes: I assume this is for efficiency reasons, maybe also simplicity reasons on the XLA side? Are there any plans to loosen this restriction on the XLA side, and then also on the JAX side?

So, if I need this, what are my options? One obvious use case: implementing beam search for some encoder-decoder-like model, or just a language model. You don't know in advance how long the output is going to be. For this, you need dynamically-sized arrays. Or what are the possible workarounds? Setting some arbitrary upper limit for the target length? But if you set a very high upper limit, you waste a lot of memory and compute.

I wonder a bit about this. Isn't this a quite common case? And despite the dynamic-size issue: wouldn't something like `TensorArray` make sense in JAX as well?
I understand there's ongoing work (in both JAX and XLA) to implement "bounded dynamism"/"dynamic shapes", for which the size of the array is dynamic but has some known upper limit. AFAICT that's not landing any time soon though.

For working around this: indeed, typically we just set some upper limit. (This is usually not too big a deal; our available memory is already setting an upper limit anyway.) If you use something like a mask over the padded steps, the extra work is usually acceptable. Another possibility is to do the dynamic part of the iteration in Python, rather than within the confines of JIT. This often makes sense if you're doing it at the top level (above the level at which you're doing autodiff etc.).

I don't think this is that common as a use case. In many respects I think this constraint is quite useful, as it ensures that we write our programs in ways we can efficiently transform/lower to all backends/etc. For example, I know the Julia folks have had some recent issues in which CUDA.jl only supports static memory allocation, but many of their programs were written assuming dynamic memory allocation is possible. This made it hard to compose these two parts of their ecosystem. I think JAX made the right call here.
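To illustrate the upper-limit-plus-mask workaround mentioned above, here's a rough sketch. The `MAX_LEN`, the update rule and the stopping rule are all made up purely for illustration:

```python
import jax
import jax.numpy as jnp

MAX_LEN = 16  # made-up upper bound on the number of steps

def decode(x0):
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def step(carry, _):
        x, done = carry
        new_x = jnp.where(done, x, 0.9 * x + 1.0)  # freeze the state once finished
        new_done = done | (new_x > 5.0)            # toy stopping rule
        # also output whether this step was still "live", for masking later
        return (new_x, new_done), (new_x, ~done)

    init = (x0, jnp.array(False))
    _, (xs, active) = jax.lax.scan(step, init, None, length=MAX_LEN)
    return jnp.sum(jnp.where(active, xs, 0.0))     # masked reduction over the padded steps

print(jax.grad(decode)(0.5))
```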
In the past, I implemented beam search in TensorFlow in a way that was differentiable, and we used it for max-expected-BLEU or min-expected-WER training. I don't really see any reasonable way to do this in JAX now. The only (ugly) way I see currently is: do the search with an outer loop in pure Python, to get the search-space lattice and to know the real max sequence length, or maybe even directly the best N sequences (or a graph/lattice representing the best sequences). Then do a second pass using the given sequences, so now the sequence length is known.

This already sounds complicated, but when thinking more about it, you realize that it is even more problematic: you don't want to compute the encoder twice, but you still want backprop to go through the encoder for the second pass. How would you even do that?

All of this is not that uncommon for tasks operating on sequences, like speech recognition or machine translation.

In any case, any preallocation based on some given max upper sequence length would be problematic. In practice, you would want it to just use as much memory as it needs, and you get an OOM if it needs more. Also, for the case of short sequences, it should still be fast: even if the number of calculation steps is small, allocating such a big tensor in each iteration would make it slow. In our case, the batch size is also dynamic, because we have sequences of very different lengths. So we have some very small batch sizes with very long sequences, but also some very big batch sizes with very short sequences.
Btw, I tried an implementation using `scan` with `dynamic_update_index_in_dim` and measured the memory consumption over the sequence length. It actually looks linear? This is my code. Maybe it is wrong?

```python
import jax
import psutil
import matplotlib.pyplot as plt
import os
batch_dim = 10
feature_dim = 5
def test_scan(time_dim: int):
    def func(xs):
        def body(state, _):
            i, ys = state
            x = xs[i]
            x_ = x * (i ** 0.5)
            ys = jax.lax.dynamic_update_index_in_dim(ys, x_, i, axis=0)
            return (i + 1, ys), None
        i = 0
        ys = jax.numpy.zeros((time_dim, batch_dim, feature_dim))
        (_, ys), _ = jax.lax.scan(body, init=(i, ys), xs=None, length=time_dim)
        y = jax.numpy.sum(ys)
        return y
    rnd_key = jax.random.PRNGKey(42)
    xs = jax.random.uniform(rnd_key, (time_dim, batch_dim, feature_dim))
    grad_xs = jax.grad(func)(xs)
    return grad_xs
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
jax.config.update('jax_platform_name', 'cpu')
xs = list(range(5, 10_000, 100))
ys = []
for n in xs:
    y = test_scan(n)
    y.block_until_ready()
    mem = psutil.Process().memory_info().rss
    print(f"n={n}, mem={mem}")
    ys.append(mem)
fig, ax = plt.subplots()
ax.plot(xs, ys)
plt.show()
```

Edit: Ah yes, I think this is wrong. I'm not really measuring the memory consumption during the backprop. I need to measure at the end of the forward pass, just before backprop starts. Not sure how I can hook that point.
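One way to hook that point might be to split the forward and backward passes with `jax.vjp` inside `test_scan` (just a sketch, reusing `func`, `xs` and `psutil` from the script above):

```python
    # replacing `grad_xs = jax.grad(func)(xs)` inside test_scan:
    y, f_vjp = jax.vjp(func, xs)                  # forward pass; residuals are kept alive by f_vjp
    y.block_until_ready()
    mem_fwd = psutil.Process().memory_info().rss  # memory including the saved residuals
    (grad_xs,) = f_vjp(jax.numpy.ones_like(y))    # backward pass
    return grad_xs
```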
If you need to differentiate through a variably-sized computation -- which is what I get from what you're saying -- then I think you're a bit stuck: updating a preallocated buffer through the carry, rather than using `scan`'s extensive outputs, runs into the problems above. Here's an example of the O(N^2) behaviour: #8192
Actually, no; you still get O(T) memory scaling even if you don't use the extensive inputs and outputs. Here's an example:

```python
import jax
import jax.numpy as jnp
import jax.ad_checkpoint
xs = 1 + jnp.arange(3.)
def f(z):
    xs_ = z * xs
    ys = jnp.zeros_like(xs)
    def body(carry, _):
        i, c, xs, ys = carry
        new_c = jnp.sin(jnp.sin(xs[i]) * c)
        new_ys = jax.lax.dynamic_update_index_in_dim(ys, new_c, i, 0)
        new_carry = (i + 1, new_c, xs, new_ys)
        return new_carry, None
    (_, y, _, ys), _ = jax.lax.scan(body, (0, z, xs_, ys), None, length=3)
    return jnp.sum(y * ys)
def f2(z):
    xs_ = z * xs
    def body(c, x):
        new_c = jnp.sin(jnp.sin(x) * c)
        return new_c, new_c
    y, ys = jax.lax.scan(body, z, xs_)
    return jnp.sum(y * ys)
print('=== first-order AD ===')
print('carry-only loop')
jax.ad_checkpoint.print_saved_residuals(f, 3.)
print()
print('extensive inputs/outputs loop')
jax.ad_checkpoint.print_saved_residuals(f2, 3.)
print()
```

Notice how nothing is of size T^2, i.e. nothing is shaped like `[T, T]`. JAX only saves residuals when primal values interact with tangents in a first-order primitive application, i.e. it only saves them when they're really needed. That means we don't need to save anything to do with indexing into the big thing to get the small thing. We only save residuals when we hit e.g. multiplies, and those only involve small things. So we don't get the O(T^2) memory consumption you were worried about.

There's much more that can be said about loops, but I think that covers this set of questions from the OP.
I think the other questions may have been covered by @patrick-kidger already, though I didn't actually check! I just wanted to talk about this particular point about what gets saved.
In simple cases like this, yes. In practice there are programs which do spuriously get O(n^2) scaling, due to the particular way their computation is expressed -- for example, adding extra operations to the loop body, or constructing the updated buffer in a slightly different way, can be enough to defeat this. (Honourable mention: nested loops, as you mention, sometimes mean that the transpose-of-DS really does create an array of zeros -- this introduces a spurious O(n^2) runtime.)

I can see that we're talking about slightly different things. You're thinking specifically about the reimplementation of extensive outputs. Whilst I didn't originally discuss specifics, I'm mostly trying to emphasise that real-world uses of in-place updates are silently dangerous!
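(To make that honourable-mention point concrete, here's a quick sketch: the transpose of a `dynamic_slice` materialises a full-size array of zeros with the cotangent scattered in.)

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0)
# pull back a cotangent through a length-1 dynamic_slice
y, pullback = jax.vjp(lambda x: jax.lax.dynamic_slice(x, (3,), (1,)), x)
(ct,) = pullback(jnp.ones_like(y))
print(ct)  # [0. 0. 0. 1. 0. 0. 0. 0.] -- a dense, full-size buffer
```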
Thanks again for the answers. I think the examples from @patrick-kidger are clear to me, in that they introduce O(n^2) memory consumption. What @mattjj writes, though, was unexpected to me. In this line:

```python
new_ys = jax.lax.dynamic_update_index_in_dim(ys, new_c, i, 0)
```

you only get O(n) if it can be sure that it does not need to store the old `ys` for the backward pass. Consider:

```python
ys = jax.lax.dynamic_update_index_in_dim(ys, c1, 0, 0)
ys = jax.lax.dynamic_update_index_in_dim(ys, c2, 0, 0)
```

So, backprop on the final `ys` would do something like:

```python
grad_c2 = grad_ys[0]
grad_ys = jax.lax.dynamic_update_index_in_dim(grad_ys, 0, 0, 0)
grad_c1 = grad_ys[0]  # now zero
```
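A quick sketch to check this (the overwritten contribution really does get a zero gradient):

```python
import jax
import jax.numpy as jnp

def f(c1, c2):
    ys = jnp.zeros(3)
    ys = jax.lax.dynamic_update_index_in_dim(ys, c1, 0, 0)
    ys = jax.lax.dynamic_update_index_in_dim(ys, c2, 0, 0)  # overwrites index 0
    return ys.sum()

print(jax.grad(f, argnums=(0, 1))(1.0, 2.0))  # -> (0.0, 1.0)
```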
Actually, in TensorFlow, the gradient of a gather is a sparse `IndexedSlices` rather than a dense array of zeros, so I would not have expected this problem there.

I was just checking that by extending my script, and yes, this seems to be the case: the runtime seems to be O(n^2).
Certainly you can write programs where you need to save a full copy of the carry for each step, but that wasn't the question. My point is simply that if you wrote a loop which only slices into the carry you don't get the quadratic memory scaling.
> In TensorFlow, the gradient of a gather is a sparse `IndexedSlices` rather than a dense array of zeros
That's what we did in the original Autograd too. But in JAX we ultimately lower to HLO, and there aren't runtime-sparse data types in HLO.

Just to summarize, my points are:

1. a `scan` whose body only slices into the carry does not save a full copy of the carry per step, so autodiff itself doesn't give you O(T^2) memory;
2. making the transpose of gather/scatter efficient relies either on compiler optimizations or on sparse runtime data structures.
On the latter point, preserving efficiency while not relying on compiler optimizations (like JAX does) or sparse runtime data structures (like Autograd/TF do) was one of the points of the Dex paper.
Like Dex, the secret (i.e. undocumented) next-gen loop construct in JAX uses in-place effectful indexed write/addupdate operations to ensure efficient transposition (at the jaxpr cost-model level, i.e. at the moment we'd still ultimately lower to HLO) without requiring the indexing to be built into the loop itself like `scan`'s extensive inputs/outputs do.

I can't find a public version of the slides @dougalm and others have made to explain the design options, but here are a few inline:
Thanks for the summary. That clarifies some of the main points from my original post. I think most other points were also addressed already here.

Now I wonder though: why is this not efficient in the case of TF? I don't really remember anymore. I think I have seen quadratic memory consumption for similar code. Or maybe just quadratic runtime? Or maybe just a very slow constant overhead? As we discussed, it actually has sparse gradients at runtime, so I would not expect that.

I also still wonder how to efficiently implement beam search decoding in a differentiable way with unknown sequence lengths. Or rather, it looks like this is just not possible currently in JAX? This would need dynamic shapes.
I was curious, and tried this PyTorch code:

```python
import torch

batch_dim = 10   # same dims as in the JAX script above
feature_dim = 5

def test_scan(time_dim: int):
    xs = torch.rand((time_dim, batch_dim, feature_dim), requires_grad=True)
    ys = torch.zeros((time_dim, batch_dim, feature_dim))
    for i in range(time_dim):
        x = xs[i]
        x_ = x * (i ** 0.5)
        ys[i] = x_
    y = ys.sum()
    y.backward()
    return xs.grad
```

It should be equivalent to my JAX code above. It seems I get linear runtime and memory consumption for this.
Can you just write it the same way you would in PyTorch, without using any control flow primitive? JAX can differentiate through Python loops. (The performance may not be great but then it's just a performance optimization problem.)
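For example, something like this sketch -- the same computation as the scan example above, written as a plain Python loop -- differentiates fine, and under `jit` it would simply be unrolled:

```python
import jax
import jax.numpy as jnp

batch_dim, feature_dim, time_dim = 10, 5, 100

def func(xs):
    ys = []
    for i in range(time_dim):          # ordinary Python loop, no lax.scan
        ys.append(xs[i] * (i ** 0.5))
    return jnp.sum(jnp.stack(ys))

xs = jax.random.uniform(jax.random.PRNGKey(42), (time_dim, batch_dim, feature_dim))
print(jax.grad(func)(xs).shape)
```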
It would be a dynamic ending condition, sth like a `while` loop that stops once an end-of-sequence token is produced. So, that would work fine with JAX? I thought JAX does not support dynamic shapes?
TF provides the `TensorArray` to make automatic iteration and stacking efficient in `scan` or `while_loop`. The naive variant with gathering and concatenating or dynamic updates would be inefficient with backprop, because backprop would keep copies of the full array in each iteration. E.g. assume you are collecting `ys` of shape [T,B,F], and iterating over t in [0,...,T-1]. Now two possible naive variants:

1. You don't know T in advance. You allocate the initial `ys` tensor of shape [0,B,F], and in each iteration you concatenate a new vector [B,F], extended as [1,B,F], to it, so in each step t the current `ys` is of shape [t,B,F].
2. You know T in advance. You can allocate the initial `ys` tensor of shape [T,B,F]. In each iteration, you update `ys[t]` (e.g. `tensor_scatter_nd_update`).

I have seen the concat variant being used for self-attention implementations.

I was checking JAX `while_loop` and `scan`, and it seems it does not have `TensorArray` but instead uses `dynamic_index_in_dim`/`dynamic_slice` and `dynamic_update_index_in_dim`/`dynamic_update_slice` (like `tensor_scatter_nd_update`), which is like variant 2.

Without considering backprop, variant 2 can be efficient if it updates in place, actually more so than `TensorArray`. But if it does not update in place for some reason, you get O(T^2) runtime. Variant 1 would also likely lead to O(T^2) runtime, unless it can be very clever and preallocate a bigger tensor; then it might get away with O(T log T) runtime, similar to C++ `std::vector`. But I very much doubt that.

When considering backprop, it is much worse, unless there are some clever optimizations happening. In the standard case, it would need to store a copy of the full `ys` tensor in every iteration for use in backprop. So you get T copies of `ys`, which means O(T^2) memory consumption.

With TF `TensorArray`, this is not the case, as each tensor `ys[t]` is treated separately. It is efficient and only has O(T) runtime and memory consumption, even with backprop.

So, my question is: is JAX `scan` really as inefficient as I described, especially in the case of backprop? Or if not, why not? How does it avoid the O(T^2) memory consumption in backprop? Is this some automatic optimization? How does it work?

I already got some preliminary answer by @mattjj here: as I understood it, it is efficient because `scan` has specific code for autodiff, which is implemented in terms of other `scan`s.

If you implement `scan` naively using `while_loop` and slicing and dynamic slice updates, it would have the problems I described though, right? I actually tried to implement that, but I get an error. So I cannot easily implement it?

If you need a dynamic ending condition, so there are no `xs` and `length` is unknown, how would you do that? Would that need such a custom `scan` implementation?

Somewhat related issue: #3106 on a TensorArray equivalent, but it doesn't really answer my question here on efficiency.

Original question also asked here on StackOverflow.
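For reference, here's a sketch of the kind of naive `while_loop` + dynamic-update implementation I mean (a toy example: the forward pass works, but `jax.grad` of it fails, since `lax.while_loop` does not support reverse-mode autodiff):

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan_like(xs):
    T = xs.shape[0]

    def cond(state):
        t, _ = state
        return t < T

    def body(state):
        t, ys = state
        x = lax.dynamic_index_in_dim(xs, t, axis=0, keepdims=False)
        ys = lax.dynamic_update_index_in_dim(ys, 2.0 * x, t, axis=0)  # toy per-step computation
        return t + 1, ys

    _, ys = lax.while_loop(cond, body, (0, jnp.zeros_like(xs)))
    return ys

xs = jnp.arange(6.0).reshape(3, 2)
print(scan_like(xs))  # forward pass is fine
# jax.grad(lambda xs: scan_like(xs).sum())(xs)  # raises: reverse-mode AD is not supported for while_loop
```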