Fix for #20425
The issue was caused by the iterator not being fully consumed, which meant `on_epoch_end` was never called.

Added an exception to catch this situation in the future.

Added a unit test that runs `model.fit()` against all the combinations of data adapters.
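
For context, here is a minimal sketch (not part of this commit, and not the original reproduction from #20425) of one way the epoch iterator can end up only partially consumed: fitting a finite `PyDataset` with `steps_per_epoch` smaller than the number of batches. The `ToyDataset` and model below are hypothetical, and the sketch assumes the public `keras.utils.PyDataset` API behaves like the `py_dataset_adapter.PyDataset` used in the test below.

import numpy as np
import keras


class ToyDataset(keras.utils.PyDataset):
    """A hypothetical finite dataset of 20 batches of (inputs, targets)."""

    @property
    def num_batches(self):
        return 20

    def __getitem__(self, idx):
        return np.ones((5, 4)), np.zeros((5, 3))


model = keras.Sequential([keras.layers.Dense(3)])
model.compile(optimizer="adam", loss="mse")

# Only 5 of the 20 batches are consumed per epoch, so the adapter must close
# the epoch explicitly; before this fix `on_epoch_end` could be skipped, and
# `on_epoch_begin` could then be called again without the epoch being closed.
model.fit(ToyDataset(), steps_per_epoch=5, epochs=2, verbose=0)

With the changes in this commit, `EpochIterator.reset()` closes out the epoch, and the new `_within_epoch` guard raises if it does not.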
hertschuh committed Nov 5, 2024
1 parent 272bb90 commit 35ba7db
Showing 3 changed files with 181 additions and 6 deletions.
16 changes: 10 additions & 6 deletions keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -193,6 +193,7 @@ def __init__(
self.enqueuer = None
self.shuffle = shuffle
self._output_signature = None
self._within_epoch = False

workers = self.py_dataset.workers
use_multiprocessing = self.py_dataset.use_multiprocessing
@@ -314,6 +315,12 @@ def get_torch_dataloader(self):
return data_adapter_utils.get_torch_dataloader(self._get_iterator())

def on_epoch_begin(self):
if self._within_epoch:
raise ValueError(
"`on_epoch_begin` was called twice without `on_epoch_end` "
"having been called."
)
self._within_epoch = True
if self.enqueuer:
self.enqueuer.start()
self.py_dataset.on_epoch_begin()
@@ -322,6 +329,7 @@ def on_epoch_end(self):
if self.enqueuer:
self.enqueuer.stop()
self.py_dataset.on_epoch_end()
self._within_epoch = False

@property
def num_batches(self):
@@ -460,7 +468,7 @@ def start(self):
return
self.running = True
self.run_thread = threading.Thread(target=self._run)
-self.run_thread.name = f"Worker_{self.uid}" # TODO remove
+self.run_thread.name = f"Worker_{self.uid}"
self.run_thread.daemon = True
self.run_thread.start()

@@ -644,11 +652,7 @@ def get(self):
if inputs is not None:
yield inputs
except queue.Empty:
-warnings.warn(
-    "Generator ran out of batches before reaching `num_batches`"
-)
-self.stop()
-return
+pass
except Exception as e:
self.stop(drain_queue_and_join=True)
raise e
1 change: 1 addition & 0 deletions keras/src/trainers/epoch_iterator.py
@@ -91,6 +91,7 @@ def reset(self):
self._num_batches = self.data_adapter.num_batches
self._steps_seen = 0
self._epoch_iterator = None
self.data_adapter.on_epoch_end()

def _enumerate_iterator(self):
self.data_adapter.on_epoch_begin()
170 changes: 170 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -18,6 +18,7 @@
from keras.src.callbacks.callback import Callback
from keras.src.optimizers.rmsprop import RMSprop
from keras.src.testing.test_utils import named_product
from keras.src.trainers.data_adapters import py_dataset_adapter

if backend.backend() == "jax":
from keras.src.backend.jax.trainer import JAXTrainer as Trainer
@@ -141,6 +142,76 @@ def call(self, x, training=False):
return x * 0


class TestPyDataset(py_dataset_adapter.PyDataset):
def __init__(self, infinite=False, **kwargs):
super().__init__(**kwargs)
self.infinite = infinite

@property
def num_batches(self):
return None if self.infinite else 20

def __getitem__(self, idx):
return ops.ones((5, 4)), ops.zeros((5, 3))


def create_dataset(dataset_type, dataset_kwargs):
if dataset_type == "np_array":
return np.ones((100, 4)), np.zeros((100, 3))
elif dataset_type == "native_array":
return ops.ones((100, 4)), ops.zeros((100, 3))
elif dataset_type == "py_dataset":
return TestPyDataset(**dataset_kwargs), None
elif dataset_type == "tf_dataset":
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
(tf.ones((100, 4)), tf.zeros((100, 3)))
).batch(5)
if dataset_kwargs.get("infinite", False):
dataset = dataset.repeat()
return dataset, None
elif dataset_type == "torch_dataloader":
import torch

class TestIterableDataset(torch.utils.data.IterableDataset):
def __iter__(self):
for _ in range(20):
yield torch.ones((5, 4)), torch.zeros((5, 3))

class TestIterableDatasetWithLen(TestIterableDataset):
def __len__(self):
return 20

if dataset_kwargs.get("iterable", False):
if dataset_kwargs.get("has_len", False):
dataset = TestIterableDatasetWithLen()
else:
dataset = TestIterableDataset()
return torch.utils.data.DataLoader(dataset), None
else:
dataset = torch.utils.data.TensorDataset(
torch.ones((100, 4)), torch.zeros((100, 3))
)
return torch.utils.data.DataLoader(dataset, batch_size=5), None
elif dataset_type == "generator":

def generate_finite():
for _ in range(20):
yield ops.ones((5, 4)), ops.zeros((5, 3))

def generate_infinite():
while True:
yield ops.ones((5, 4)), ops.zeros((5, 3))

if dataset_kwargs.get("infinite", False):
return generate_infinite(), None
else:
return generate_finite(), None
else:
raise ValueError(f"Invalid dataset type {dataset_type}")


def sparse_generator(generator_type):
if generator_type == "scipy":
import scipy
@@ -397,6 +468,105 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
atol=1.0, # TODO: results vary across backends
)

@parameterized.named_parameters(
[
{
"testcase_name": "np_array",
"dataset_type": "np_array",
"fit_kwargs": {"batch_size": 5},
},
{
"testcase_name": "native_array",
"dataset_type": "native_array",
"fit_kwargs": {"batch_size": 5},
},
{
"testcase_name": "py_dataset",
"dataset_type": "py_dataset",
},
{
"testcase_name": "py_dataset_infinite",
"dataset_type": "py_dataset",
"dataset_kwargs": {"infinite": True},
"fit_kwargs": {"steps_per_epoch": 20},
},
{
"testcase_name": "py_dataset_multithreading",
"dataset_type": "py_dataset",
"dataset_kwargs": {"workers": 2},
},
{
"testcase_name": "py_dataset_multithreading_infinite",
"dataset_type": "py_dataset",
"dataset_kwargs": {"infinite": True, "workers": 2},
"fit_kwargs": {"steps_per_epoch": 20},
},
{
"testcase_name": "py_dataset_multiprocessing",
"dataset_type": "py_dataset",
"dataset_kwargs": {"workers": 2, "use_multiprocessing": True},
},
{
"testcase_name": "py_dataset_multiprocessing_infinite",
"dataset_type": "py_dataset",
"dataset_kwargs": {
"infinite": True,
"workers": 2,
"use_multiprocessing": True,
},
"fit_kwargs": {"steps_per_epoch": 20},
},
{
"testcase_name": "tf_dataset",
"dataset_type": "tf_dataset",
},
{
"testcase_name": "tf_dataset_infinite",
"dataset_type": "tf_dataset",
"dataset_kwargs": {"infinite": True},
"fit_kwargs": {"steps_per_epoch": 20},
},
{
"testcase_name": "torch_dataloader_tensor",
"dataset_type": "torch_dataloader",
},
{
"testcase_name": "torch_dataloader_iterable",
"dataset_type": "torch_dataloader",
"dataset_kwargs": {"iterable": True, "has_len": False},
},
{
"testcase_name": "torch_dataloader_iterable_with_len",
"dataset_type": "torch_dataloader",
"dataset_kwargs": {"iterable": True, "has_len": True},
},
{
"testcase_name": "generator",
"dataset_type": "generator",
},
{
"testcase_name": "generator_infinite",
"dataset_type": "generator",
"dataset_kwargs": {"infinite": True},
"fit_kwargs": {"steps_per_epoch": 20},
},
]
)
@pytest.mark.requires_trainable_backend
def test_fit_with_data_adapter(
self, dataset_type, dataset_kwargs={}, fit_kwargs={}
):
model = ExampleModel(units=3)
optimizer = optimizers.Adagrad()
model.compile(
optimizer=optimizer,
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
jit_compile=True,
)
x, y = create_dataset(dataset_type, dataset_kwargs)
model.fit(x, y, epochs=3, **fit_kwargs)

@parameterized.named_parameters(
[
("eager", True, False, False),
