From 4d2eb11f42e257ab21cbb6d640fb8a54edf858cf Mon Sep 17 00:00:00 2001 From: Kade Heckel Date: Thu, 8 Feb 2024 12:16:08 +0000 Subject: [PATCH] should really debug things locally more often --- spyx/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spyx/data.py b/spyx/data.py index 78ddfa8..3fbf715 100644 --- a/spyx/data.py +++ b/spyx/data.py @@ -27,7 +27,7 @@ def shuffler(dataset, batch_size): """ x, y = dataset cutoff = y.shape[0] % batch_size - data_shape = (-1, batch_size) + obs.shape[1:] + data_shape = (-1, batch_size) + x.shape[1:] def _shuffle(dataset, shuffle_rng): """