From 3a832d33f4ab4b81c0115367fc3d221dc66e3439 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 13 Sep 2023 12:49:29 -0700 Subject: [PATCH] chex: alias PRNGKey to jax.Array Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see https://github.com/google/jax/pull/17297) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict. PiperOrigin-RevId: 565133147 --- distrax/_src/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index 0d60e9c..ec33f6d 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -324,7 +324,7 @@ def convert_seed_and_sample_shape( else: # key is of type PRNGKey rng = seed - return rng, sample_shape + return rng, sample_shape # type: ignore[bad-return-type] def to_batch_shape_index(