diff --git a/CHANGELOG.md b/CHANGELOG.md
index ff6558576a..281370aca2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added `torchvision` as a requirement to `datatype_audio.txt` as it's used for Audio Classification ([#1425](https://github.com/Lightning-AI/lightning-flash/pull/1425)).
+- Added fine tuning strategies for DeepSpeed (with parameter loading and storing omitted) ([#1377](https://github.com/Lightning-AI/lightning-flash/pull/1377))
+
+- Added `torchvision` as a requirement to `datatype_audio.txt` as it's used for Audio Classification ([#1425](https://github.com/Lightning-AI/lightning-flash/pull/1425))
 
 - Added `figsize` and `limit_nb_samples` for showing batch images ([#1381](https://github.com/Lightning-AI/lightning-flash/pull/1381))
 
diff --git a/docs/source/general/finetuning.rst b/docs/source/general/finetuning.rst
index 42cb873e29..f71ab35bdc 100644
--- a/docs/source/general/finetuning.rst
+++ b/docs/source/general/finetuning.rst
@@ -241,3 +241,13 @@ For even more customization, create your own finetuning callback. Learn more abo
     :hide:
 
     ...
+
+Working with DeepSpeed
+======================
+
+If you are using DeepSpeed, you can use the following strategies. They are used in the same way as the strategies listed above, but finetuning with DeepSpeed doesn't yet support loading and storing of its parameters.
+
+* ``freeze_deepspeed``
+* ``no_freeze_deepspeed``
+* ``freeze_unfreeze_deepspeed``
+* ``unfreeze_milestones_deepspeed``
diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py
index f6b39c928e..356b99f04f 100644
--- a/flash/core/finetuning.py
+++ b/flash/core/finetuning.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from functools import partial
-from typing import Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks import BaseFinetuning
@@ -215,3 +215,56 @@ def __init__(
         train_bn: bool = True,
     ):
         super().__init__(FinetuningStrategies.UNFREEZE_MILESTONES, strategy_metadata, train_bn)
+
+
+class FlashDeepSpeedFinetuning(FlashBaseFinetuning):
+    """FlashDeepSpeedFinetuning can be used to create a custom Flash Finetuning Callback which works with
+    DeepSpeed.
+
+    DeepSpeed cannot store and load its parameters when working with Lightning. So FlashDeepSpeedFinetuning overrides
+    `_store` to not store its parameters.
+    """
+
+    def _store(
+        self,
+        pl_module: LightningModule,
+        opt_idx: int,
+        num_param_groups: int,
+        current_param_groups: List[Dict[str, Any]],
+    ) -> None:
+        pass
+
+
+class NoFreezeDeepSpeed(FlashDeepSpeedFinetuning):
+    def __init__(self, train_bn: bool = True):
+        super().__init__(FinetuningStrategies.NO_FREEZE, train_bn)
+
+
+class FreezeDeepSpeed(FlashDeepSpeedFinetuning):
+    def __init__(self, train_bn: bool = True):
+        super().__init__(FinetuningStrategies.FREEZE, train_bn)
+
+
+class FreezeUnfreezeDeepSpeed(FlashDeepSpeedFinetuning):
+    def __init__(
+        self,
+        strategy_metadata: int,
+        train_bn: bool = True,
+    ):
+        super().__init__(FinetuningStrategies.FREEZE_UNFREEZE, strategy_metadata, train_bn)
+
+
+class UnfreezeMilestonesDeepSpeed(FlashDeepSpeedFinetuning):
+    def __init__(
+        self,
+        strategy_metadata: Tuple[Tuple[int, int], int],
+        train_bn: bool = True,
+    ):
+        super().__init__(FinetuningStrategies.UNFREEZE_MILESTONES, strategy_metadata, train_bn)
+
+
+for strategy in FinetuningStrategies:
+    _FINETUNING_STRATEGIES_REGISTRY(
+        name=f"{strategy.value}_deepspeed",
+        fn=partial(FlashDeepSpeedFinetuning, strategy_key=strategy),
+    )
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index d9a852cd0c..bc22549655 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -118,6 +118,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _BAAL_AVAILABLE = _module_available("baal")
 _TORCH_OPTIMIZER_AVAILABLE = _module_available("torch_optimizer")
 _SENTENCE_TRANSFORMERS_AVAILABLE = _module_available("sentence_transformers")
+_DEEPSPEED_AVAILABLE = _module_available("deepspeed")
 
 
 if _PIL_AVAILABLE:
diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py
index 67c63647b0..a37bc6e4fe 100644
--- a/tests/core/test_finetuning.py
+++ b/tests/core/test_finetuning.py
@@ -26,8 +26,9 @@
 from torch.utils.data import DataLoader
 
 import flash
+from flash.core.finetuning import _FINETUNING_STRATEGIES_REGISTRY
 from flash.core.model import Task
-from flash.core.utilities.imports import _CORE_TESTING
+from flash.core.utilities.imports import _CORE_TESTING, _DEEPSPEED_AVAILABLE
 from tests.helpers.boring_model import BoringModel
 
 
@@ -138,17 +139,38 @@ def on_train_epoch_start(self, trainer, pl_module):
 
 @pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
 @pytest.mark.parametrize(
-    "strategy",
+    "strategy, plugins",
     [
-        "no_freeze",
-        "freeze",
-        ("freeze_unfreeze", 1),
-        ("unfreeze_milestones", ((5, 10), 5)),
+        ("no_freeze", None),
+        ("freeze", None),
+        (("freeze_unfreeze", 1), None),
+        (("unfreeze_milestones", ((5, 10), 5)), None),
+        pytest.param(
+            "no_freeze_deepspeed",
+            "deepspeed",
+            marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
+        ),
+        pytest.param(
+            "freeze_deepspeed",
+            "deepspeed",
+            marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
+        ),
+        pytest.param(
+            ("freeze_unfreeze_deepspeed", 1),
+            "deepspeed",
+            marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
+        ),
+        pytest.param(
+            ("unfreeze_milestones_deepspeed", ((5, 10), 5)),
+            "deepspeed",
+            marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
+        ),
     ],
 )
-def test_finetuning_with_none_return_type(strategy):
+def test_finetuning_with_none_return_type(strategy, plugins):
+    gpus = 0 if plugins is None else 1
     task = TestTaskWithoutFinetuning(loss_fn=F.nll_loss)
-    trainer = flash.Trainer(max_epochs=1, limit_train_batches=10)
+    trainer = flash.Trainer(max_epochs=1, limit_train_batches=10, gpus=gpus, plugins=plugins)
     ds = DummyDataset()
     trainer.finetune(task, train_dataloader=DataLoader(ds), strategy=strategy)
 
@@ -223,3 +245,22 @@ def test_finetuning_errors_and_exceptions(strategy):
     ds = DummyDataset()
     with pytest.raises(MisconfigurationException):
         trainer.finetune(task, train_dataloader=DataLoader(ds), strategy=strategy)
+
+
+@pytest.mark.parametrize(
+    "strategy_key, strategy_metadata",
+    [
+        ("no_freeze", None),
+        ("freeze", None),
+        ("freeze_unfreeze", 2),
+        ("unfreeze_milestones", ((5, 10), 15)),
+    ],
+)
+def test_deepspeed_finetuning_strategy_key(strategy_key, strategy_metadata):
+    deepspeed_strategy_key = f"{strategy_key}_deepspeed"
+
+    strategy = _FINETUNING_STRATEGIES_REGISTRY.get(key=strategy_key)(strategy_metadata=strategy_metadata).strategy
+    deepspeed_strategy = _FINETUNING_STRATEGIES_REGISTRY.get(key=deepspeed_strategy_key)(
+        strategy_metadata=strategy_metadata
+    ).strategy
+    assert strategy == deepspeed_strategy
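
For anyone trying the new strategy keys end to end, here is a minimal usage sketch. It assumes DeepSpeed is installed and at least one GPU is available; the `ImageClassifier` / `ImageClassificationData` task and the dataset paths are illustrative stand-ins and are not part of this patch — only the `*_deepspeed` strategy keys and the `gpus`/`plugins="deepspeed"` wiring come from the changes above.

    import flash
    from flash.image import ImageClassificationData, ImageClassifier

    # Hypothetical dataset location; substitute any image-folder dataset.
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        batch_size=4,
    )
    model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

    # The `*_deepspeed` strategies skip storing/restoring optimizer parameter groups
    # (see FlashDeepSpeedFinetuning._store above), so they are meant to be paired
    # with the DeepSpeed plugin on GPU.
    trainer = flash.Trainer(max_epochs=3, gpus=1, plugins="deepspeed")
    trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze_deepspeed", 1))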