Skip to content

Commit

Permalink
Update orbax_upgrade_guide.rst for async checkpointing usage examples
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushaladiti-2802 committed Jul 3, 2024
1 parent 0c19b6b commit 4b4d2fe
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,33 @@ Then, you can call ``orbax.checkpoint.AsyncCheckpointer.wait_until_finished()``

For more details, read the `checkpoint guide <https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#asynchronized-checkpointing>`_.

You can also use Orbax AsyncCheckpointer with Flax APIs through async manager. Async manager internally calls wait_until_finished(). This solution is not actively maintained and the recommedation is to use Orbax async checkpointing.

For example:

.. codediff::
:title: flax.checkpoints(async), orbax.checkpoint(async)
:skip_test: flax.checkpoints
:sync:

PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'
flax.config.update('flax_use_orbax_checkpointing', True)
async_manager = checkpoints.AsyncManager()

checkpoints.save_checkpoint(PURE_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True, async_manager=async_manager)
checkpoints.restore_checkpoint(PURE_CKPT_DIR, target=TARGET_PYTREE)
---

PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure'

import orbax.checkpoint as ocp
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(PURE_CKPT_DIR, args=ocp.args.StandardSave(pytree))
# ... Continue with your work...
# ... Until a time when you want to wait until the save completes:
ckptr.wait_until_finished() # Blocks until the checkpoint saving is completed.
ckptr.restore(PURE_CKPT_DIR, args=ocp.args.StandardRestore(target))


Saving/loading a single JAX or NumPy Array
******************************************
Expand Down Expand Up @@ -210,4 +237,4 @@ For example:
Final words
***********

This guide provides an overview of how to migrate from the "legacy" Flax checkpointing API to the Orbax API. Orbax provides more functionalities and the Orbax team is actively developing new features. Stay tuned and follow the `official Orbax GitHub repository <https://github.com/google/orbax>`__ for more!
This guide provides an overview of how to migrate from the "legacy" Flax checkpointing API to the Orbax API. Orbax provides more functionalities and the Orbax team is actively developing new features. Stay tuned and follow the `official Orbax GitHub repository <https://github.com/google/orbax>`__ for more!

0 comments on commit 4b4d2fe

Please sign in to comment.