Training hangs at the end of the first epoch when using a PyDataset and workers > 1. #20425

Closed
HGS-mbayer opened this issue Oct 29, 2024 · 3 comments

Training with a PyDataset and workers > 1 hangs at the end of the first epoch with Keras 3.6. This issue does not seem to occur with Keras 3.5.

  • Backend is Torch with GPU support (2.5.0+cu124)
  • Windows 11
  • Python 3.10.11

Example Code

Here is a slightly modified version of https://keras.io/examples/vision/mnist_convnet/ that reproduces the issue.

import math

import keras
import numpy as np
from keras import layers

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

batch_size = 512
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])


# A minimal PyDataset that serves contiguous (x, y) slices as batches.
class Data(keras.utils.PyDataset):
    def __init__(self, x, y, batch_size: int = 2, **kwargs):
        super().__init__(**kwargs)
        self._x = x
        self._y = y
        self._batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self._x) / self._batch_size)

    def __getitem__(self, index):
        # Raise IndexError past the last batch so iteration terminates.
        if index >= len(self):
            raise IndexError

        indices = range(len(self._x))[
            index * self._batch_size : (index + 1) * self._batch_size
        ]

        return self._x[indices, ...], self._y[indices, ...]


training_data = Data(x_train, y_train, batch_size=batch_size, workers=8)
validation_data = Data(x_test, y_test, batch_size=batch_size, workers=8)

# This will hang at the end of the first epoch with Keras 3.6.
model.fit(training_data, epochs=epochs, validation_data=validation_data)

Traceback

Here is the traceback I receive when interrupting the process.

Epoch 1/15
117/118 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - accuracy: 0.6114 - loss: 1.2523Traceback (most recent call last):
  File "example.py", line 75, in <module>
    model.fit(training_data, epochs=epochs, validation_data=validation_data)
  File "...\env\lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler
    return fn(*args, **kwargs)
  File "...\env\lib\site-packages\keras\src\backend\torch\trainer.py", line 252, in fit
    for step, data in epoch_iterator.enumerate_epoch():
  File "...\env\lib\site-packages\keras\src\trainers\epoch_iterator.py", line 110, in enumerate_epoch
    for step, data in enumerate(self._get_iterator()):
  File "...\env\lib\site-packages\torch\utils\data\dataloader.py", line 701, in __next__
    data = self._next_data()
  File "...\env\lib\site-packages\torch\utils\data\dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "...\env\lib\site-packages\torch\utils\data\_utils\fetch.py", line 42, in fetch
    data = next(self.dataset_iter)
  File "...\env\lib\site-packages\keras\src\trainers\data_adapters\data_adapter_utils.py", line 222, in __iter__
    for batch in self.iterable:
  File "...\env\lib\site-packages\keras\src\trainers\data_adapters\py_dataset_adapter.py", line 257, in _finite_enqueuer_generator
    for i, batch in enumerate(self.enqueuer.get()):
  File "...\env\lib\site-packages\keras\src\trainers\data_adapters\py_dataset_adapter.py", line 637, in get
    value = self.future_queue.get(block=True, timeout=5)
  File "...\AppData\Local\Programs\Python\Python310\lib\queue.py", line 180, in get
    self.not_empty.wait(remaining)
  File "...\AppData\Local\Programs\Python\Python310\lib\threading.py", line 324, in wait
    gotit = waiter.acquire(True, timeout)
KeyboardInterrupt

fchollet commented Nov 4, 2024

Thanks for the report.

This issue appears to have been introduced in fd8bbe2.

@hertschuh can you take a look? I started debugging it, and here's my reading: the following code

except queue.Empty:
    pass

is reached and leads to an infinite loop. That's because we never get to the exit condition:

if i >= num_batches - 1:
    self.enqueuer.stop()
    return

which is because `num_batches` returns a (correct) total that is larger than the number of batches actually drawable during the first epoch.
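
To make the failure mode concrete, here is a standalone sketch of that shape (hypothetical numbers and names, not the actual Keras source): the consumer exits only after seeing num_batches items, but the producer delivers one fewer, so an unbounded retry on queue.Empty would spin forever. The retry here is bounded so the script terminates and reports the shortfall instead of hanging.

import queue
import threading

num_batches = 118       # what the adapter believes the epoch contains
batches_enqueued = 117  # what the workers actually delivered

q = queue.Queue()

def producer():
    # Enqueue one batch fewer than the consumer expects.
    for i in range(batches_enqueued):
        q.put(i)

threading.Thread(target=producer).start()

consumed = 0
empty_hits = 0
while consumed < num_batches:
    try:
        q.get(block=True, timeout=0.1)
        consumed += 1
    except queue.Empty:
        # Keras 3.6 retried here without bound; we bail out for the demo.
        empty_hits += 1
        if empty_hits > 3:
            print(f"hang: consumed {consumed}/{num_batches} batches")
            break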


fchollet commented Nov 4, 2024

I added a workaround at HEAD so that training continues when the issue occurs. It's not a definitive fix, but it should help.

hertschuh added a commit to hertschuh/keras that referenced this issue Nov 5, 2024
The issue was caused by the fact that the iterator was not fully consumed and `on_epoch_end` was not called.

Added an exception to catch this situation in the future.

Added a unit test to test `model.fit()` with all the combinations of data adapters.
fchollet pushed a commit that referenced this issue Nov 5, 2024
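
For intuition, here is a hedged sketch of the safeguard that commit message describes (run_epoch and its arguments are hypothetical names, not the actual Keras API): count the batches the epoch actually yields, raise if the iterator came up short, and only then fire on_epoch_end.

def run_epoch(batch_iter, expected_batches, train_step, on_epoch_end):
    seen = 0
    for batch in batch_iter:
        train_step(batch)
        seen += 1
    if seen < expected_batches:
        # Surface under-consumption instead of hanging or silently
        # skipping the end-of-epoch bookkeeping.
        raise RuntimeError(
            f"Epoch yielded {seen} of {expected_batches} expected batches; "
            "the dataset iterator was not fully consumed."
        )
    on_epoch_end()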

fchollet commented Nov 5, 2024

This should now be fixed at HEAD.
