Avoid OOM on TPU #1690
-
Hi, I've been able to solve an OOM on a TPU v3-8 with an ugly hack that I don't understand.

Problem you have encountered: When running my training script on a TPU v3-8, I get an OOM error.

What you expected to happen: Due to my quick hack (see below), it should run with no problem.

Logs, error messages, etc:
How do I solve it?
Note:
Steps to reproduce:
-
So I think the issue is due to the following:
Another solution I had considered is to move all the weights to the CPU and then move them back to the TPU. Is there a cleaner way to handle TPU memory allocation?
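For reference, a minimal sketch of what that CPU round trip could look like (the dummy `state` pytree and device choices are assumptions, not code from this thread):

```python
import jax
import jax.numpy as jnp

# Dummy "weights" pytree standing in for a real train state (assumption).
state = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros((1024,))}

cpu = jax.devices("cpu")[0]
accel = jax.devices()[0]  # default backend device, e.g. TPU core 0 on a v3-8

# Move every array in the pytree to host memory; the old device copies become
# eligible for deallocation once nothing else references them.
state = jax.device_put(state, cpu)

# ... run whatever needs the freed device memory here ...

# Move the pytree back onto the accelerator before resuming training.
state = jax.device_put(state, accel)
```

Note that `jax.device_put` only makes the old TPU copies collectable; whether the allocator can then reuse that space without fragmentation is a separate question.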
-
This is quite odd for sure. Fragmentation and being close to the limit in terms of memory could of course result in errors that appear almost randomly. One thing you could try is to initialize the model on CPU, e.g. `jax.jit(model.init, backend="cpu")`. The params are moved to TPU automatically during training or during replication of the state (e.g. `jax_utils.replicate`).
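A minimal, self-contained sketch of that CPU-init pattern might look like the following (the toy `MLP` model, shapes, and RNG seed are assumptions, not the original poster's setup):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import jax_utils

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4096)(x)
        return nn.Dense(10)(x)

model = MLP()
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 512))

# Run the parameter initialization on the host so the full parameter tree
# never has to be materialized on a single TPU core.
init_fn = jax.jit(model.init, backend="cpu")
params = init_fn(rng, dummy_input)

# Replicating the state (or pmapping the train step) moves the params onto the TPUs.
replicated_params = jax_utils.replicate(params)
```

The point of `backend="cpu"` is that the unreplicated parameter tree lives in host RAM, so the first thing the TPU ever holds is the copy actually used by the train step.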
-
I'm having a similar issue, even when I run it on CPU. It's baffling. I'm not sure how it's a parameter issue if, like @borisdayma said, the memory allocation seems to scale with batch size.
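One way to check how allocation scales with batch size is to dump a device memory profile after a step and compare the snapshots in pprof. A hedged sketch (the toy `step` function and file names are assumptions, not code from this thread):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in for a real train step (assumption).
    return (x @ x.T).sum()

for batch_size in (8, 32, 128):
    x = jnp.ones((batch_size, 4096))
    step(x).block_until_ready()
    # Writes a pprof-compatible snapshot of live device allocations.
    jax.profiler.save_device_memory_profile(f"memory_bs{batch_size}.prof")
```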
-
One issue that I faced in using
-
I'm having a similar issue, also around checkpointing. The difference is that I'm starting from randomly initialized weights and then checkpointing with orbax. The initial training works fine and the checkpointing works fine, but when the training loop resumes after the first checkpoint, the error appears. I'm guessing that checkpointing moves the model off of the TPU to the CPU, so my hunch is that when it tries to get it back onto the TPU, there isn't enough space. I'm wondering if it'd be possible somehow to "flush" the TPU after checkpointing? I tried the

```
  File "train_tcn.py", line 492, in <module>
    state, loss = jit_train_step(state, input, target, config.loss_fn)
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/pjit.py", line 1143, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/mike/jaxfun/.venv/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1349, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to reserve 2.61G at the bottom of memory. That was not possible. There are 1.70G free, 0B reserved, and 1.70G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
```
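On the "flush after checkpointing" idea, one thing worth trying is to make the host copy explicit and short-lived, so the device-resident state stays the only long-lived copy. This is a hedged sketch assuming an orbax `PyTreeCheckpointer` and a `state` pytree; it is not necessarily what orbax does internally:

```python
import jax
import orbax.checkpoint as ocp

checkpointer = ocp.PyTreeCheckpointer()

def save_checkpoint(state, path):
    # Pull the train state into host memory explicitly so the bytes written
    # to disk come from host buffers rather than an extra device-side copy.
    host_state = jax.device_get(state)
    checkpointer.save(path, host_state)
    # Drop the host copy right away; the original device-resident `state`
    # remains the only live copy for the training loop.
    del host_state

# Example usage with a dummy pytree (assumption; orbax may require an absolute path):
# save_checkpoint({"w": jax.numpy.zeros((8, 8))}, "/tmp/ckpt/step_1")
```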