This repository contains common functionality for writing ML training loops in JAX. The goal is to keep training loops short and readable (by moving common tasks into small libraries) without removing the flexibility that research requires.
To get started, check out this notebook.
This project started as a fork of CLU (Common Loop Utils). See CHANGELOG.md for details on the changes made since the fork.