diff --git a/gpjax/typing.py b/gpjax/typing.py index 66b162fad..fb82f1699 100644 --- a/gpjax/typing.py +++ b/gpjax/typing.py @@ -17,17 +17,18 @@ Callable, Union, ) -from jax.random import KeyArray as JAXKeyArray from jaxtyping import ( Array as JAXArray, Bool, Float, Int, + Key, UInt32, ) from numpy import ndarray as NumpyArray OldKeyArray = UInt32[JAXArray, "2"] +JAXKeyArray = Key[JAXArray, ""] KeyArray = Union[ OldKeyArray, JAXKeyArray ] # for compatibility regardless of enable_custom_prng setting