diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index 0d60e9c..ec33f6d 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -324,7 +324,7 @@ def convert_seed_and_sample_shape( else: # key is of type PRNGKey rng = seed - return rng, sample_shape + return rng, sample_shape # type: ignore[bad-return-type] def to_batch_shape_index(