
Add finetuning strategies for DeepSpeed #1377

Merged
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -8,7 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `torchvision` as a requirement to `datatype_audio.txt` as it's used for Audio Classification ([#1425](https://github.com/Lightning-AI/lightning-flash/pull/1425)).
- Added fine tuning strategies for DeepSpeed (with parameter loading and storing omitted) ([#1377](https://github.com/Lightning-AI/lightning-flash/pull/1377))

- Added `torchvision` as a requirement to `datatype_audio.txt` as it's used for Audio Classification ([#1425](https://github.com/Lightning-AI/lightning-flash/pull/1425))

- Added `figsize` and `limit_nb_samples` for showing batch images ([#1381](https://github.com/Lightning-AI/lightning-flash/pull/1381))

10 changes: 10 additions & 0 deletions docs/source/general/finetuning.rst
@@ -241,3 +241,13 @@ For even more customization, create your own finetuning callback. Learn more abo
:hide:

...

Working with DeepSpeed
======================

If you are using DeepSpeed, the following strategies are available. They are used in the same way as the strategies listed above, except that finetuning with DeepSpeed does not yet support loading and storing its parameters (a brief usage sketch follows the list below).

* ``freeze_deepspeed``
* ``no_freeze_deepspeed``
* ``freeze_unfreeze_deepspeed``
* ``unfreeze_milestones_deepspeed``
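
As a quick illustration (not part of the diff), here is a minimal sketch of passing one of these strategies to ``Trainer.finetune``. The image-classification task, data folder, backbone, and single-GPU DeepSpeed setup are assumptions made for the example:

```python
import flash
from flash.image import ImageClassificationData, ImageClassifier

# Illustrative data and model; the folder path and backbone are placeholders.
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    batch_size=4,
)
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# The "_deepspeed" strategies are intended for use with the DeepSpeed plugin on
# the Trainer; they skip storing/loading the finetuning parameter groups.
trainer = flash.Trainer(max_epochs=1, gpus=1, plugins="deepspeed")
trainer.finetune(model, datamodule=datamodule, strategy="freeze_deepspeed")
```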
55 changes: 54 additions & 1 deletion flash/core/finetuning.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from functools import partial
from typing import Iterable, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import BaseFinetuning
@@ -215,3 +215,56 @@ def __init__(
train_bn: bool = True,
):
super().__init__(FinetuningStrategies.UNFREEZE_MILESTONES, strategy_metadata, train_bn)


class FlashDeepSpeedFinetuning(FlashBaseFinetuning):
"""FlashDeepSpeedFinetuning can be used to create a custom Flash Finetuning Callback which works with
DeepSpeed.

DeepSpeed cannot store and load its parameters when working with Lightning. So FlashDeepSpeedFinetuning overrides
`_store` to not store its parameters.
"""

def _store(
self,
pl_module: LightningModule,
opt_idx: int,
num_param_groups: int,
current_param_groups: List[Dict[str, Any]],
) -> None:
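# DeepSpeed cannot store and load these parameter groups when working with Lightning, so intentionally do nothing here.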
pass


class NoFreezeDeepSpeed(FlashDeepSpeedFinetuning):
def __init__(self, train_bn: bool = True):
super().__init__(FinetuningStrategies.NO_FREEZE, train_bn)


class FreezeDeepSpeed(FlashDeepSpeedFinetuning):
def __init__(self, train_bn: bool = True):
super().__init__(FinetuningStrategies.FREEZE, train_bn)


class FreezeUnfreezeDeepSpeed(FlashDeepSpeedFinetuning):
def __init__(
self,
strategy_metadata: int,
train_bn: bool = True,
):
super().__init__(FinetuningStrategies.FREEZE_UNFREEZE, strategy_metadata, train_bn)


class UnfreezeMilestonesDeepSpeed(FlashDeepSpeedFinetuning):
def __init__(
self,
strategy_metadata: Tuple[Tuple[int, int], int],
train_bn: bool = True,
):
super().__init__(FinetuningStrategies.UNFREEZE_MILESTONES, strategy_metadata, train_bn)


for strategy in FinetuningStrategies:
_FINETUNING_STRATEGIES_REGISTRY(
name=f"{strategy.value}_deepspeed",
fn=partial(FlashDeepSpeedFinetuning, strategy_key=strategy),
)
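
For orientation (not part of the diff), the loop above registers one ``<strategy>_deepspeed`` entry per ``FinetuningStrategies`` member. A rough sketch of resolving such an entry, mirroring the test added below (the metadata value is illustrative):

```python
# Each "<strategy>_deepspeed" key yields a FlashDeepSpeedFinetuning callback bound
# to the corresponding strategy, with the same strategy/metadata handling as the
# plain key.
plain = _FINETUNING_STRATEGIES_REGISTRY.get(key="freeze_unfreeze")(strategy_metadata=1)
deepspeed = _FINETUNING_STRATEGIES_REGISTRY.get(key="freeze_unfreeze_deepspeed")(strategy_metadata=1)
assert isinstance(deepspeed, FlashDeepSpeedFinetuning)
assert plain.strategy == deepspeed.strategy
```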
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
@@ -118,6 +118,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_BAAL_AVAILABLE = _module_available("baal")
_TORCH_OPTIMIZER_AVAILABLE = _module_available("torch_optimizer")
_SENTENCE_TRANSFORMERS_AVAILABLE = _module_available("sentence_transformers")
_DEEPSPEED_AVAILABLE = _module_available("deepspeed")


if _PIL_AVAILABLE:
57 changes: 49 additions & 8 deletions tests/core/test_finetuning.py
@@ -26,8 +26,9 @@
from torch.utils.data import DataLoader

import flash
from flash.core.finetuning import _FINETUNING_STRATEGIES_REGISTRY
from flash.core.model import Task
from flash.core.utilities.imports import _CORE_TESTING
from flash.core.utilities.imports import _CORE_TESTING, _DEEPSPEED_AVAILABLE
from tests.helpers.boring_model import BoringModel


@@ -138,17 +139,38 @@ def on_train_epoch_start(self, trainer, pl_module):

@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
@pytest.mark.parametrize(
"strategy",
"strategy, plugins",
[
"no_freeze",
"freeze",
("freeze_unfreeze", 1),
("unfreeze_milestones", ((5, 10), 5)),
("no_freeze", None),
("freeze", None),
(("freeze_unfreeze", 1), None),
(("unfreeze_milestones", ((5, 10), 5)), None),
pytest.param(
"no_freeze_deepspeed",
"deepspeed",
marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
),
pytest.param(
"freeze_deepspeed",
"deepspeed",
marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
),
pytest.param(
("freeze_unfreeze_deepspeed", 1),
"deepspeed",
marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
),
pytest.param(
("unfreeze_milestones_deepspeed", ((5, 10), 5)),
"deepspeed",
marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not installed"),
),
],
)
def test_finetuning_with_none_return_type(strategy):
def test_finetuning_with_none_return_type(strategy, plugins):
gpus = 0 if plugins is None else 1
task = TestTaskWithoutFinetuning(loss_fn=F.nll_loss)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=10)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=10, gpus=gpus, plugins=plugins)
ds = DummyDataset()
trainer.finetune(task, train_dataloader=DataLoader(ds), strategy=strategy)

@@ -223,3 +245,22 @@ def test_finetuning_errors_and_exceptions(strategy):
ds = DummyDataset()
with pytest.raises(MisconfigurationException):
trainer.finetune(task, train_dataloader=DataLoader(ds), strategy=strategy)


@pytest.mark.parametrize(
"strategy_key, strategy_metadata",
[
("no_freeze", None),
("freeze", None),
("freeze_unfreeze", 2),
("unfreeze_milestones", ((5, 10), 15)),
],
)
def test_deepspeed_finetuning_strategy_key(strategy_key, strategy_metadata):
deepspeed_strategy_key = f"{strategy_key}_deepspeed"

strategy = _FINETUNING_STRATEGIES_REGISTRY.get(key=strategy_key)(strategy_metadata=strategy_metadata).strategy
deepspeed_strategy = _FINETUNING_STRATEGIES_REGISTRY.get(key=deepspeed_strategy_key)(
strategy_metadata=strategy_metadata
).strategy
assert strategy == deepspeed_strategy
Comment on lines +250 to +266 (Contributor):

Really nice tests! Thanks for adding them.