diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index dcf3a73744d26..ba3f3cfca46cd 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018)) +- Added `work.delete` method to delete the work ([#16103](https://github.com/Lightning-AI/lightning/pull/16103)) + - Added `display_name` property to LightningWork for the cloud ([#16095](https://github.com/Lightning-AI/lightning/pull/16095)) diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py index 12203d4db3ff8..b34f42f2fef9a 100644 --- a/src/lightning_app/core/work.py +++ b/src/lightning_app/core/work.py @@ -644,6 +644,19 @@ def stop(self): app = _LightningAppRef().get_current() self._backend.stop_work(app, self) + def delete(self): + """Delete LightingWork component and shuts down hardware provisioned via L.CloudCompute. + + Locally, the work.delete() behaves as work.stop(). + """ + if not self._backend: + raise Exception( + "Can't delete the work, it looks like it isn't attached to a LightningFlow. " + "Make sure to assign the Work to a flow instance." + ) + app = _LightningAppRef().get_current() + self._backend.delete_work(app, self) + def _check_run_is_implemented(self) -> None: if not is_overridden("run", instance=self, parent=LightningWork): raise TypeError( diff --git a/src/lightning_app/runners/backends/mp_process.py b/src/lightning_app/runners/backends/mp_process.py index dc0681390046e..36f3cb8097604 100644 --- a/src/lightning_app/runners/backends/mp_process.py +++ b/src/lightning_app/runners/backends/mp_process.py @@ -88,6 +88,9 @@ def stop_work(self, app, work: "lightning_app.LightningWork") -> None: work_manager: MultiProcessWorkManager = app.processes[work.name] work_manager.kill() + def delete_work(self, app, work: "lightning_app.LightningWork") -> None: + self.stop_work(app, work) + class CloudMultiProcessingBackend(MultiProcessingBackend): def __init__(self, *args, **kwargs): @@ -108,3 +111,6 @@ def stop_work(self, app, work: "lightning_app.LightningWork") -> None: disable_port(work._port) self.ports = [port for port in self.ports if port != work._port] return super().stop_work(app, work) + + def delete_work(self, app, work: "lightning_app.LightningWork") -> None: + self.stop_work(app, work) diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py index 01f55b1f90d1b..ea3288e6b761b 100644 --- a/tests/tests_app/core/test_lightning_work.py +++ b/tests/tests_app/core/test_lightning_work.py @@ -1,6 +1,6 @@ from queue import Empty from re import escape -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock import pytest @@ -386,6 +386,18 @@ def test_lightning_app_work_start(cache_calls, parallel): MultiProcessRuntime(app, start_server=False).dispatch() +def test_lightning_work_delete(): + work = WorkCounter() + + with pytest.raises(Exception, match="Can't delete the work"): + work.delete() + + mock = MagicMock() + work._backend = mock + work.delete() + assert work == mock.delete_work._mock_call_args_list[0].args[1] + + class WorkDisplay(LightningWork): def __init__(self): super().__init__()