Skip to content

Commit

Permalink
Better graceful shutdown for KeyboardInterrupt (#19976)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Jun 16, 2024
1 parent b16e998 commit c1af4d0
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 22 deletions.
34 changes: 34 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - YYYY-MM-DD

### Added

-

-

### Changed

-

-

### Deprecated

-

-

### Removed

-

-

### Fixed

-

-



## [2.3.0] - 2024-06-13

### Added
Expand Down
4 changes: 4 additions & 0 deletions src/lightning/fabric/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import logging
import os
import signal
import time
from contextlib import nullcontext
from datetime import timedelta
Expand Down Expand Up @@ -306,8 +307,11 @@ def _init_dist_connection(


def _destroy_dist_connection() -> None:
# Don't allow Ctrl+C to interrupt this handler
signal.signal(signal.SIGINT, signal.SIG_IGN)
if _distributed_is_initialized():
torch.distributed.destroy_process_group()
signal.signal(signal.SIGINT, signal.SIG_DFL)


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
Expand Down
35 changes: 35 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,41 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - YYYY-MM-DD

### Added

-

-

### Changed

- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))

-

### Deprecated

-

-

### Removed

-

-

### Fixed

-

-



## [2.3.0] - 2024-06-13

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
def kill(self, signum: _SIGNUM) -> None:
for proc in self.procs:
if proc.is_alive() and proc.pid is not None:
log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
with suppress(ProcessLookupError):
os.kill(proc.pid, signum)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
@override
def kill(self, signum: _SIGNUM) -> None:
for proc in self.procs:
log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
# this skips subprocesses already terminated
proc.send_signal(signum)

Expand Down
20 changes: 14 additions & 6 deletions src/lightning/pytorch/trainer/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Type, Union

Expand All @@ -20,10 +21,12 @@
import lightning.pytorch as pl
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
from lightning.pytorch.trainer.states import TrainerStatus
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

log = logging.getLogger(__name__)

Expand All @@ -49,12 +52,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
trainer.state.status = TrainerStatus.FINISHED
trainer.state.stage = None

# TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not trainer.interrupted:
_interrupt(trainer, exception)
rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
# user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
signal.signal(signal.SIGINT, signal.SIG_IGN)
_interrupt(trainer, exception)
trainer._teardown()
launcher = trainer.strategy.launcher
if isinstance(launcher, _SubprocessScriptLauncher):
launcher.kill(_get_sigkill_signal())
exit(1)

except BaseException as exception:
_interrupt(trainer, exception)
trainer._teardown()
Expand Down
11 changes: 5 additions & 6 deletions src/lightning/pytorch/trainer/connectors/signal_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import re
import signal
import sys
import threading
from subprocess import call
from types import FrameType
Expand Down Expand Up @@ -54,7 +53,7 @@ def register_signal_handlers(self) -> None:
sigterm_handlers.append(self._sigterm_handler_fn)

# Windows seems to have signal incompatibilities
if not self._is_on_windows():
if not _IS_WINDOWS:
sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
assert sigusr is not None
if sigusr_handlers and not self._has_already_handler(sigusr):
Expand Down Expand Up @@ -155,10 +154,6 @@ def _valid_signals() -> Set[signal.Signals]:
}
return set(signal.Signals)

@staticmethod
def _is_on_windows() -> bool:
return sys.platform == "win32"

@staticmethod
def _has_already_handler(signum: _SIGNUM) -> bool:
return signal.getsignal(signum) not in (None, signal.SIG_DFL)
Expand All @@ -172,3 +167,7 @@ def __getstate__(self) -> Dict:
state = self.__dict__.copy()
state["_original_handlers"] = {}
return state


def _get_sigkill_signal() -> _SIGNUM:
return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def on_train_start(self) -> None:

with mock.patch(
"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop:
) as mock_progress_stop, pytest.raises(SystemExit):
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmp_path,
Expand Down
9 changes: 7 additions & 2 deletions tests/tests_pytorch/callbacks/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial

import pytest
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, LambdaCallback
from lightning.pytorch.demos.boring_classes import BoringModel
Expand All @@ -23,10 +24,13 @@
def test_lambda_call(tmp_path):
seed_everything(42)

class CustomException(Exception):
pass

class CustomModel(BoringModel):
def on_train_epoch_start(self):
if self.current_epoch > 1:
raise KeyboardInterrupt
raise CustomException("Custom exception to trigger `on_exception` hooks")

checker = set()

Expand Down Expand Up @@ -59,7 +63,8 @@ def call(hook, *_, **__):
limit_predict_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
trainer.fit(model, ckpt_path=ckpt_path)
with pytest.raises(CustomException):
trainer.fit(model, ckpt_path=ckpt_path)
trainer.test(model)
trainer.predict(model)

Expand Down
3 changes: 2 additions & 1 deletion tests/tests_pytorch/trainer/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,6 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):

trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params)

trainer.fit(model)
with pytest.raises(SystemExit):
trainer.fit(model)
assert trainer.interrupted
34 changes: 30 additions & 4 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import torch
import torch.nn as nn
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.seed import seed_everything
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator
Expand All @@ -45,7 +46,7 @@
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper
from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
Expand Down Expand Up @@ -1007,7 +1008,8 @@ def on_exception(self, trainer, pl_module, exception):
)
assert not trainer.interrupted
assert handle_interrupt_callback.exception is None
trainer.fit(model)
with pytest.raises(SystemExit):
trainer.fit(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
with pytest.raises(MisconfigurationException):
Expand All @@ -1016,6 +1018,30 @@ def on_exception(self, trainer, pl_module, exception):
assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)


def test_keyboard_interrupt(tmp_path):
class InterruptCallback(Callback):
def __init__(self):
super().__init__()

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
raise KeyboardInterrupt

model = BoringModel()
trainer = Trainer(
callbacks=[InterruptCallback()],
barebones=True,
default_root_dir=tmp_path,
)

trainer.strategy._launcher = Mock(spec=_SubprocessScriptLauncher)
trainer.strategy._launcher.launch = lambda function, *args, trainer, **kwargs: function(*args, **kwargs)

with pytest.raises(SystemExit) as exc_info:
trainer.fit(model)
assert exc_info.value.args[0] == 1
trainer.strategy._launcher.kill.assert_called_once_with(15 if _IS_WINDOWS else 9)


@pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))])
@RunIf(sklearn=True)
def test_gradient_clipping_by_norm(tmp_path, precision):
Expand Down Expand Up @@ -2042,7 +2068,7 @@ def on_fit_start(self):

trainer = Trainer(default_root_dir=tmp_path)
with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress(
Exception
Exception, SystemExit
):
trainer.fit(ExceptionModel())
on_exception_mock.assert_called_once_with(exception)
Expand All @@ -2061,7 +2087,7 @@ def on_fit_start(self):
datamodule.on_exception = Mock()
trainer = Trainer(default_root_dir=tmp_path)

with suppress(Exception):
with suppress(Exception, SystemExit):
trainer.fit(ExceptionModel(), datamodule=datamodule)
datamodule.on_exception.assert_called_once_with(exception)

Expand Down

0 comments on commit c1af4d0

Please sign in to comment.