Skip to content

Commit

Permalink
Update aqt for newer JAX versions
Browse files Browse the repository at this point in the history
jax.random.KeyArray is deprecated in jax v0.4.16, and will be removed in a later release. See [JEP 9263](https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html) for details. I also changed from using private to public utilities to generate random bits.

PiperOrigin-RevId: 568880357
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Sep 27, 2023
1 parent dc9d7cc commit db2bdaa
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
2 changes: 1 addition & 1 deletion aqt/jax/v2/aqt_dot_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

@flax.struct.dataclass
class Context:
key: Optional[jax.random.KeyArray]
key: Optional[jax.Array]
train_step: Optional[int]


Expand Down
2 changes: 1 addition & 1 deletion aqt/jax/v2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

FreshScaleFn = Callable[[jnp.ndarray], jnp.ndarray]
ClipAndRoundFn = Callable[[jnp.ndarray, Context], jnp.ndarray]
NoiseFn = Callable[[tuple[int, ...], jax.random.KeyArray], jnp.ndarray]
NoiseFn = Callable[[tuple[int, ...], jax.Array], jnp.ndarray]


@dataclasses.dataclass
Expand Down
12 changes: 4 additions & 8 deletions aqt/jax/v2/stochastic_rounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,13 @@


def random_centered_uniform(
shape: tuple[int, ...], key: jax.random.KeyArray) -> jnp.ndarray:
shape: tuple[int, ...], key: jax.Array) -> jnp.ndarray:
"""Generates uniform number in [-0.5, 0.5]."""
nbits = 16
dtype = jnp.dtype('uint16')
nbits = jnp.iinfo(dtype).bits

# Generate random bits.
from jax._src import prng # pylint: disable=g-import-not-at-top
assert not jax.config.jax_enable_custom_prng
key = prng.random_wrap(key, impl=jax.random.default_prng_impl())
bits = prng.random_bits(key, bit_width=nbits, shape=shape)
assert bits.shape == shape, (bits.shape, bits.shape)
assert bits.dtype == {8: jnp.uint8, 16: jnp.uint16}[nbits], bits.dtype
bits = jax.random.bits(key, shape, dtype)

# Align bits with the mantissa of f32.
nmant = jnp.finfo(jnp.float32).nmant
Expand Down

0 comments on commit db2bdaa

Please sign in to comment.