
JAX v0.4.37

@hawkinsp released this on 10 Dec, 01:17

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

  • Bug fixes
    • Fixed a bug where jit would raise an error if an argument was named f (#25329).
    • Fixed a bug that would throw an index-out-of-range error in
      jax.lax.while_loop if the user registered a pytree node class whose
      flatten and flatten_with_path functions returned different aux data.
    • Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
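The while_loop fix concerns pytree classes registered with both a plain flatten function and a flatten-with-path (keyed) function. Below is a minimal sketch of such a registration, assuming jax 0.4.37 or later. The `Counter` class, its fields, and the `label` metadata are hypothetical. Both flatten functions return the same aux data (`label`), which is the consistent pattern; the bug was triggered when the two disagreed.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import GetAttrKey, register_pytree_with_keys


class Counter:
    """Hypothetical loop state: two array leaves plus static metadata."""

    def __init__(self, i, total, label):
        self.i = i
        self.total = total
        self.label = label  # static metadata, carried as pytree aux data


def flatten_with_keys(c):
    # Keyed flatten: (key, child) pairs plus aux data.
    return ((GetAttrKey("i"), c.i), (GetAttrKey("total"), c.total)), c.label


def flatten(c):
    # Plain flatten: must return the SAME aux data as flatten_with_keys.
    return (c.i, c.total), c.label


def unflatten(label, children):
    i, total = children
    return Counter(i, total, label)


register_pytree_with_keys(Counter, flatten_with_keys, unflatten, flatten)


def cond(c):
    return c.i < 5


def body(c):
    # Accumulate 0 + 1 + 2 + 3 + 4 while counting up to 5.
    return Counter(c.i + 1, c.total + c.i, c.label)


out = jax.lax.while_loop(cond, body, Counter(jnp.int32(0), jnp.int32(0), "demo"))
print(out.i, out.total)  # 5 10
```

Keeping the aux data identical across both functions lets JAX treat the keyed and unkeyed flattenings as the same tree structure.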