diff --git a/jax/random.py b/jax/random.py index 729de54f9d03..923e41e281ad 100644 --- a/jax/random.py +++ b/jax/random.py @@ -199,12 +199,16 @@ "PRNGKeyArray": ( "jax.random.PRNGKeyArray is deprecated. Use jax.Array for annotations, and " "jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of " - "typed prng keys.", _PRNGKeyArray + "typed prng keys (i.e. keys created with jax.random.key).\n" + "For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html", + _PRNGKeyArray ), "KeyArray": ( "jax.random.KeyArray is deprecated. Use jax.Array for annotations, and " "jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of " - "typed prng keys.", _PRNGKeyArray + "typed prng keys (i.e. keys created with jax.random.key).\n" + "For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html", + _PRNGKeyArray ), # Added September 21, 2023 "threefry2x32_key": (