JAX v0.5.0
As of this release, JAX now uses effort-based versioning.
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.

Breaking changes
- Enable jax_threefry_partitionable by default (see the update note).
- This release drops support for Mac x86 wheels. Mac ARM of course remains supported. For a recent discussion, see #22936. Two key factors motivated this decision:
  - The Mac x86 build (only) has a number of test failures and crashes. We would rather ship no release than a broken release.
  - Mac x86 hardware is end-of-life and cannot easily be obtained by developers at this point, so it is difficult for us to fix this kind of problem even if we wanted to.

  We are open to re-adding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.
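The jax_threefry_partitionable change above alters which random values jax.random produces for a given key. The sketch below, assuming JAX 0.5.0 or later, reads the flag and temporarily switches it back off with jax.config.update, for example to reproduce numbers recorded with an older release. The array shape and seed are arbitrary illustrative choices.

```python
import jax

# Read the current setting (enabled by default as of JAX 0.5.0).
print("partitionable:", jax.config.jax_threefry_partitionable)

key = jax.random.key(0)
new_bits = jax.random.uniform(key, (4,))  # drawn under the new default

# Temporarily opt out to reproduce values generated by older releases.
jax.config.update("jax_threefry_partitionable", False)
old_bits = jax.random.uniform(key, (4,))

# Restore the new default afterwards.
jax.config.update("jax_threefry_partitionable", True)
```

Because the two settings generate bits differently, old_bits and new_bits may not match; opting out is best treated as a stopgap while updating tests that hard-code random values.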

Changes
- The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
- The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
- jax.numpy.einsum now defaults to optimize='auto' rather than optimize='optimal'. This avoids exponentially scaling trace time in the case of many arguments (#25214).
- jax.numpy.linalg.solve no longer supports batched 1D arguments on the right-hand side. To recover the previous behavior in these cases, use solve(a, b[..., None]).squeeze(-1).
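A short sketch of the einsum default change, assuming JAX 0.5.0 or later: calls without an optimize argument now use the cheaper 'auto' contraction-path search, and code that relied on the exhaustive search can still request optimize="optimal" explicitly. Both settings should compute the same values up to floating-point rounding; they differ in how the contraction order is chosen at trace time. The operand shapes are arbitrary.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3))
b = jnp.ones((3, 4))
c = jnp.ones((4, 5))

# New default: optimize='auto'.
out_auto = jnp.einsum("ij,jk,kl->il", a, b, c)

# Explicitly request the previous default. The exhaustive path search is
# cheap for three operands but can become very slow to trace with many.
out_optimal = jnp.einsum("ij,jk,kl->il", a, b, c, optimize="optimal")
```

Each output entry sums 3 * 4 products of ones, so both results are a (2, 5) array filled with 12.0.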
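The jax.numpy.linalg.solve workaround mentioned above can be sketched as follows, assuming a batch of square systems a with shape (..., N, N) and a batch of 1D right-hand sides b with shape (..., N). The trailing axis turns each right-hand side into an (N, 1) matrix, which solve still accepts in batched form. Mapping the unbatched solve with jax.vmap is shown as an equivalent alternative; the matrices are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Two 3x3 systems (2*I and 4*I) and two 1D right-hand sides of ones.
a = jnp.stack([2.0 * jnp.eye(3), 4.0 * jnp.eye(3)])  # shape (2, 3, 3)
b = jnp.ones((2, 3))                                 # shape (2, 3)

# Add a trailing axis, solve the (2, 3, 1) problem, then drop the axis again.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)    # shape (2, 3)

# Alternative: map the unbatched solve over the leading batch axis.
x_vmap = jax.vmap(jnp.linalg.solve)(a, b)            # shape (2, 3)
```

Since a[0] = 2I and a[1] = 4I, the solutions are 0.5 and 0.25 in every entry of the respective rows.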

New Features
- jax.numpy.fft.fftn, jax.numpy.fft.rfftn, jax.numpy.fft.ifftn, and jax.numpy.fft.irfftn now support transforms in more than 3 dimensions, which was previously the limit. See #25606 for more details.
- Added support for user-defined state in the FFI via the new jax.ffi.register_ffi_type_id function.
- The AOT lowering .as_text() method now supports the debug_info option to include debugging information, e.g., source locations, in the output.

Deprecations
- From jax.interpreters.xla, abstractify and pytype_aval_mappings are now deprecated, having been replaced by symbols of the same name in jax.core.
- jax.scipy.special.lpmn and jax.scipy.special.lpmn_values are deprecated, following their deprecation in SciPy v1.15.0. There are no plans to replace these deprecated functions with new APIs.
- The jax.extend.ffi submodule was moved to jax.ffi, and the previous import path is deprecated.

Deletions
- The jax_enable_memories flag has been deleted, and the behavior it enabled is now on by default.
- From jax.lib.xla_client, the previously deprecated Device and XlaRuntimeError symbols have been removed; instead, use jax.Device and jax.errors.JaxRuntimeError, respectively.
- The jax.experimental.array_api module has been removed after being deprecated in JAX v0.4.32. Since that release, jax.numpy supports the array API directly.