From 35ba7dbee91d90b06b1b052267cb0988b7b220d8 Mon Sep 17 00:00:00 2001
From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com>
Date: Mon, 4 Nov 2024 11:04:14 -0800
Subject: [PATCH] Fix for https://github.com/keras-team/keras/issues/20425

The issue was caused by the iterator not being fully consumed, which
meant `on_epoch_end` was never called. Added an exception to catch this
situation in the future.

Added a unit test that runs `model.fit()` with all combinations of data
adapters.
---
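Reviewer note, not part of the diff: the sketch below illustrates the
failure mode with a hypothetical `ToyDataset`; see the linked issue for
the actual reproduction.

    import numpy as np
    from keras import Sequential, layers
    from keras.utils import PyDataset

    class ToyDataset(PyDataset):  # hypothetical stand-in dataset
        def __len__(self):
            return 20

        def __getitem__(self, idx):
            return np.ones((5, 4)), np.zeros((5, 1))

    model = Sequential([layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

    # Before this patch, an epoch whose iterator was not fully consumed
    # never triggered `on_epoch_end`, so the enqueuer kept running and
    # the next `on_epoch_begin` started a new epoch on top of it. Now
    # `EpochIterator.reset()` always calls `data_adapter.on_epoch_end()`,
    # and a nested `on_epoch_begin` raises a `ValueError`.
    model.fit(ToyDataset(workers=2, use_multiprocessing=False), epochs=2)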
 .../data_adapters/py_dataset_adapter.py |  16 +-
 keras/src/trainers/epoch_iterator.py    |   1 +
 keras/src/trainers/trainer_test.py      | 170 ++++++++++++++++++
 3 files changed, 181 insertions(+), 6 deletions(-)

diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py
index 6a117a6c50b..f4f12e5d511 100644
--- a/keras/src/trainers/data_adapters/py_dataset_adapter.py
+++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -193,6 +193,7 @@ def __init__(
         self.enqueuer = None
         self.shuffle = shuffle
         self._output_signature = None
+        self._within_epoch = False
 
         workers = self.py_dataset.workers
         use_multiprocessing = self.py_dataset.use_multiprocessing
@@ -314,6 +315,12 @@ def get_torch_dataloader(self):
         return data_adapter_utils.get_torch_dataloader(self._get_iterator())
 
     def on_epoch_begin(self):
+        if self._within_epoch:
+            raise ValueError(
+                "`on_epoch_begin` was called twice without `on_epoch_end` "
+                "having been called."
+            )
+        self._within_epoch = True
         if self.enqueuer:
             self.enqueuer.start()
         self.py_dataset.on_epoch_begin()
@@ -322,6 +329,7 @@ def on_epoch_end(self):
         if self.enqueuer:
             self.enqueuer.stop()
         self.py_dataset.on_epoch_end()
+        self._within_epoch = False
 
     @property
     def num_batches(self):
@@ -460,7 +468,7 @@ def start(self):
             return
         self.running = True
         self.run_thread = threading.Thread(target=self._run)
-        self.run_thread.name = f"Worker_{self.uid}"  # TODO remove
+        self.run_thread.name = f"Worker_{self.uid}"
         self.run_thread.daemon = True
         self.run_thread.start()
 
@@ -644,11 +652,7 @@ def get(self):
                 if inputs is not None:
                     yield inputs
             except queue.Empty:
-                warnings.warn(
-                    "Generator ran out of batches before reaching `num_batches`"
-                )
-                self.stop()
-                return
+                pass
             except Exception as e:
                 self.stop(drain_queue_and_join=True)
                 raise e
diff --git a/keras/src/trainers/epoch_iterator.py b/keras/src/trainers/epoch_iterator.py
index 6f83215b68a..9f2e670be1f 100644
--- a/keras/src/trainers/epoch_iterator.py
+++ b/keras/src/trainers/epoch_iterator.py
@@ -91,6 +91,7 @@ def reset(self):
         self._num_batches = self.data_adapter.num_batches
         self._steps_seen = 0
         self._epoch_iterator = None
+        self.data_adapter.on_epoch_end()
 
     def _enumerate_iterator(self):
         self.data_adapter.on_epoch_begin()
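Reviewer note, not part of the diff: in the test matrix added below, every
infinite data source (a PyDataset whose `num_batches` is None, a repeated
tf.data dataset, an endless generator) is paired with `steps_per_epoch`,
since `fit()` cannot otherwise tell where an epoch ends. A minimal sketch
of the generator case, assuming a compiled `model` as in the test:

    import numpy as np

    def generate_infinite():
        while True:
            yield np.ones((5, 4)), np.zeros((5, 3))

    # Without `steps_per_epoch`, the first epoch would never terminate.
    model.fit(generate_infinite(), steps_per_epoch=20, epochs=3)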
"dataset_type": "torch_dataloader", + }, + { + "testcase_name": "torch_dataloader_iterable", + "dataset_type": "torch_dataloader", + "dataset_kwargs": {"iterable": True, "has_len": False}, + }, + { + "testcase_name": "torch_dataloader_iterable_with_len", + "dataset_type": "torch_dataloader", + "dataset_kwargs": {"iterable": True, "has_len": True}, + }, + { + "testcase_name": "generator", + "dataset_type": "generator", + }, + { + "testcase_name": "generator_infinite", + "dataset_type": "generator", + "dataset_kwargs": {"infinite": True}, + "fit_kwargs": {"steps_per_epoch": 20}, + }, + ] + ) + @pytest.mark.requires_trainable_backend + def test_fit_with_data_adapter( + self, dataset_type, dataset_kwargs={}, fit_kwargs={} + ): + model = ExampleModel(units=3) + optimizer = optimizers.Adagrad() + model.compile( + optimizer=optimizer, + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + jit_compile=True, + ) + x, y = create_dataset(dataset_type, dataset_kwargs) + model.fit(x, y, epochs=3, **fit_kwargs) + @parameterized.named_parameters( [ ("eager", True, False, False),