Releases: jax-ml/jax
Jaxlib release v0.4.30
jaxlib-v0.4.30 jaxlib version 0.4.30
Jax release v0.4.30
jax-v0.4.30 jax version 0.4.30
Jaxlib release v0.4.29
Bug fixes
- Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
- Fixed a bug where XLA:CPU miscompiled certain matmul fusions (openxla/xla#13301).
- Fixed a compiler crash on GPU (#21396).
Deprecations
- `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will raise an error in a future version of jax. `None` is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
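A minimal sketch of that workaround, assuming a hypothetical elementwise function `f` and two small dict pytrees `a` and `b` chosen purely for illustration. Because `is_leaf` marks `None` in the first tree as a leaf, the value at the same position in the second tree is matched to it instead of `None` being treated as an empty subtree.

```python
import jax


def f(x, y):
    # Hypothetical elementwise function used only for illustration.
    return x + y


a = {"w": None, "b": 1.0}
b = {"w": 2.0, "b": 3.0}

# Treat None in `a` as a leaf; the lambda then passes it through unchanged
# and applies `f` only where `a` holds a real value.
out = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
# out == {"w": None, "b": 4.0}
```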
JAX v0.4.29
Changes
- We anticipate that this will be the last release of JAX and jaxlib supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. `pip install jax[cuda12]`).
- JAX now requires ml_dtypes version 0.4.0 or newer.
- Removed backwards-compatibility support for old usage of the `jax.experimental.export` API. It is no longer possible to use `from jax.experimental.export import export`; instead, use `from jax.experimental import export`. The removed functionality has been deprecated since 0.4.24.
Deprecations
- `jax.sharding.XLACompatibleSharding` is deprecated. Please use `jax.sharding.Sharding`.
- `jax.experimental.Exported.in_shardings` has been renamed to `jax.experimental.Exported.in_shardings_hlo`. The same applies to `out_shardings`. The old names will be removed after 3 months.
- Removed a number of previously-deprecated APIs:
  - from `jax.core`: `non_negative_dim`, `DimSize`, `Shape`
  - from `jax.lax`: `tie_in`
  - from `jax.nn`: `normalize`
  - from `jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `XlaOp`.
- The `tol` argument of `jax.numpy.linalg.matrix_rank` is being deprecated and will soon be removed. Use `rtol` instead.
- The `rcond` argument of `jax.numpy.linalg.pinv` is being deprecated and will soon be removed. Use `rtol` instead.
- The deprecated `jax.config` submodule has been removed. To configure JAX, use `import jax` and then reference the config object via `jax.config`.
- `jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of `jax.vmap` in such cases.
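Several of the replacements above can be sketched together. This is a minimal sketch assuming a recent JAX release in which these deprecations have taken effect, and assuming the old `XLACompatibleSharding` was used in `isinstance` checks. The variable names, example matrices, the `rtol=1e-3` value, and the choice of the `jax_numpy_rank_promotion` flag are illustrative only.

```python
import jax
import jax.numpy as jnp

# jax.config: reach the config object through the top-level package instead
# of the removed submodule import (e.g. `from jax.config import config`).
jax.config.update("jax_numpy_rank_promotion", "allow")

# jax.sharding.Sharding is the base class to check against; a single-device
# array's sharding is a subclass of it.
x = jnp.arange(8.0)
has_sharding = isinstance(x.sharding, jax.sharding.Sharding)

# jax.numpy.linalg: pass the tolerance as `rtol` instead of `tol` / `rcond`.
m = jnp.array([[1.0, 2.0], [2.0, 4.0]])  # second row is 2x the first: rank 1
rank = jnp.linalg.matrix_rank(m, rtol=1e-3)

d = jnp.diag(jnp.array([2.0, 4.0]))
d_pinv = jnp.linalg.pinv(d, rtol=1e-3)  # full-rank diagonal: plain inverse

# jax.random: map over a batch of keys explicitly with jax.vmap rather than
# handing the batched keys straight to a sampler.
keys = jax.random.split(jax.random.key(0), 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)  # shape (4, 3)
```

Using `jax.vmap` makes the batch dimension over keys explicit, so each row of `samples` is drawn from its own key.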
New Functionality
- Added `jax.experimental.Exported.in_shardings_jax` to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the `Exported` objects.
jaxlib v0.4.28
Bug fixes
- Fixed a memory corruption bug in the type name of Array and JIT Python objects in Python 3.10 or earlier.
- Fixed a warning `'+ptx84' is not a recognized feature for this target` under CUDA 12.4.
- Fixed a slow compilation problem on CPU.
Changes
- The Windows build is now built with Clang instead of MSVC.
JAX v0.4.28
Bug fixes
- Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).
Deprecations & removals
- The `kind` argument to `jax.numpy.sort` and `jax.numpy.argsort` is now removed. Use `stable=True` or `stable=False` instead.
- Removed `get_compute_capability` from the `jax.experimental.pallas.gpu` module. Use the `compute_capability` attribute of a GPU device, returned by `jax.devices` or `jax.local_devices`, instead.
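A minimal sketch of both replacements, assuming a recent JAX release. The input array is illustrative, and `gpu_compute_capabilities` is a hypothetical helper name; it returns an empty dict on machines without GPU devices, so it is safe to call on a CPU-only host.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1])

# Select sort stability with `stable=` (the `kind=` argument is gone).
sorted_x = jnp.sort(x, stable=True)
# With stable=True, equal elements keep their original relative order,
# so the two 1s (at indices 1 and 3) appear as indices 1 then 3.
order = jnp.argsort(x, stable=True)


def gpu_compute_capabilities():
    """Hypothetical helper: map each local GPU device id to its compute capability.

    Reads the `compute_capability` attribute of devices from jax.local_devices(),
    in place of the removed pallas.gpu get_compute_capability helper.
    Returns an empty dict when no GPU devices are present.
    """
    return {
        d.id: getattr(d, "compute_capability", None)
        for d in jax.local_devices()
        if d.platform == "gpu"
    }
```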
Changes
- The minimum jaxlib version of this release is 0.4.27.
Jaxlib release v0.4.27
jaxlib-v0.4.27 jaxlib version 0.4.27
Jax release v0.4.27
jax-v0.4.27 jax version 0.4.27
Jaxlib release v0.4.26
jaxlib-v0.4.26 jaxlib version 0.4.26
Jax release v0.4.26
jax-v0.4.26 jax version 0.4.26