This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix unfreeze strategies with onecyclelr and reduced lr #1329

Merged (5 commits) on May 6, 2022
CHANGELOG.md (4 additions, 0 deletions)
@@ -24,6 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324))

- Fixed a bug where the `freeze_unfreeze` and `unfreeze_milestones` finetuning strategies could not be used in tandem with a `onecyclelr` LR scheduler ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))

- Fixed a bug where the backbone learning rate would be divided by 10 when unfrozen if using the `freeze_unfreeze` or `unfreeze_milestones` strategies ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))

## [0.7.4] - 2022-04-27

### Fixed
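For context, the combination these two entries describe looks roughly like the sketch below: a task finetuned with the `freeze_unfreeze` strategy while its LR scheduler is `onecyclelr`. This is not code from the PR; the `ImageClassifier`/`ImageClassificationData` setup and the dataset path are illustrative. Before the fix, this combination either failed outright or trained the unfrozen backbone at a tenth of the scheduled learning rate.

```python
# Illustrative sketch only: any Flash task with a backbone should behave the same way.
import flash
from flash.image import ImageClassificationData, ImageClassifier

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",  # hypothetical dataset path
    batch_size=4,
)

model = ImageClassifier(
    num_classes=datamodule.num_classes,
    backbone="resnet18",
    optimizer="sgd",
    learning_rate=0.1,
    # (registry name, scheduler kwargs, Lightning scheduler config), as in the test changes below.
    lr_scheduler=("onecyclelr", {"max_lr": 1e-3, "epochs": 5, "steps_per_epoch": 10}, {"interval": "step"}),
)

trainer = flash.Trainer(max_epochs=5, limit_train_batches=10)
# The backbone stays frozen for 2 epochs; when unfrozen, its parameters now join the
# existing optimizer param group and follow the OneCycleLR schedule instead of lr / 10.
trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 2))
```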
flash/core/finetuning.py (16 additions, 3 deletions)
@@ -103,6 +103,19 @@ def freeze_before_training(self, pl_module: Union[Module, Iterable[Union[Module,
                modules = [modules]
            self.freeze(modules=modules, train_bn=self.train_bn)

+    def unfreeze_and_extend_param_group(
+        self,
+        modules: Union[Module, Iterable[Union[Module, Iterable]]],
+        optimizer: Optimizer,
+        train_bn: bool = True,
+    ) -> None:
+        self.make_trainable(modules)
+
+        params = self.filter_params(modules, train_bn=train_bn, requires_grad=True)
+        params = self.filter_on_optimizer(optimizer, params)
+        if params:
+            optimizer.param_groups[0]["params"].extend(params)
+
    def _freeze_unfreeze_function(
        self,
        pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
@@ -117,7 +130,7 @@ def _freeze_unfreeze_function

        modules = self._get_modules_to_freeze(pl_module=pl_module)
        if modules is not None:
-            self.unfreeze_and_add_param_group(
+            self.unfreeze_and_extend_param_group(
                modules=modules,
                optimizer=optimizer,
                train_bn=self.train_bn,
@@ -140,15 +153,15 @@ def _unfreeze_milestones_function
            # unfreeze num_layers last layers

            backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[-num_layers:]
-            self.unfreeze_and_add_param_group(
+            self.unfreeze_and_extend_param_group(
                modules=backbone_modules,
                optimizer=optimizer,
                train_bn=self.train_bn,
            )
        elif epoch == unfreeze_milestones[1]:
            # unfreeze remaining layers
            backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[:-num_layers]
-            self.unfreeze_and_add_param_group(
+            self.unfreeze_and_extend_param_group(
                modules=backbone_modules,
                optimizer=optimizer,
                train_bn=self.train_bn,
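The strategies previously called Lightning's `BaseFinetuning.unfreeze_and_add_param_group`, which appends a new optimizer param group and, by default, divides the learning rate by `initial_denom_lr=10`. Schedulers such as `OneCycleLR` are configured only for the param groups that exist when they are created, so the extra group either broke the scheduler or was left training at a tenth of the intended rate. The new `unfreeze_and_extend_param_group` helper instead extends the parameters of the existing first param group. A standalone sketch of the difference, in plain PyTorch rather than Flash:

```python
# Illustration only, not Flash code: a toy "head" trains from the start and a
# "backbone" is unfrozen later.
import torch
from torch import nn

head, backbone = nn.Linear(4, 2), nn.Linear(4, 4)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=50)

# Old path (unfreeze_and_add_param_group): a *new* param group is appended with
# lr divided by initial_denom_lr (10 by default). OneCycleLR only configured the
# groups present at construction, so stepping it with the extra group either
# errors out or leaves the backbone stuck at the reduced learning rate.
# optimizer.add_param_group(
#     {"params": list(backbone.parameters()), "lr": optimizer.param_groups[0]["lr"] / 10}
# )

# New path (unfreeze_and_extend_param_group): the backbone parameters join the
# existing first group and follow exactly the learning rate the scheduler sets.
optimizer.param_groups[0]["params"].extend(backbone.parameters())

scheduler.step()
print(len(optimizer.param_groups))      # 1 -- still a single, scheduled param group
print(optimizer.param_groups[0]["lr"])  # the OneCycleLR value, not a value divided by 10
```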
tests/core/test_finetuning.py (31 additions, 6 deletions)
@@ -152,20 +152,45 @@ def test_finetuning_with_none_return_type(strategy):


@pytest.mark.parametrize(
("strategy", "checker_class", "checker_class_data"),
("strategy", "lr_scheduler", "checker_class", "checker_class_data"),
[
("no_freeze", None, {}),
("freeze", FreezeStrategyChecking, {}),
(("freeze_unfreeze", 2), FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
("no_freeze", None, None, {}),
("freeze", None, FreezeStrategyChecking, {}),
(("freeze_unfreeze", 2), None, FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
(
("unfreeze_milestones", ((1, 3), 1)),
None,
UnfreezeMilestonesStrategyChecking,
{"check_epochs": [1, 3], "num_layers": 1},
),
(
"no_freeze",
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
None,
{},
),
(
"freeze",
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
FreezeStrategyChecking,
{},
),
(
("freeze_unfreeze", 2),
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
FreezeUnfreezeStrategyChecking,
{"check_epoch": 2},
),
(
("unfreeze_milestones", ((1, 3), 1)),
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
UnfreezeMilestonesStrategyChecking,
{"check_epochs": [1, 3], "num_layers": 1},
),
],
)
def test_finetuning(tmpdir, strategy, checker_class, checker_class_data):
task = TestTaskWithFinetuning(loss_fn=F.nll_loss)
def test_finetuning(tmpdir, strategy, lr_scheduler, checker_class, checker_class_data):
task = TestTaskWithFinetuning(loss_fn=F.nll_loss, lr_scheduler=lr_scheduler, optimizer="sgd", learning_rate=0.1)
callbacks = [] if checker_class is None else checker_class(dirpath=tmpdir, **checker_class_data)
trainer = flash.Trainer(max_epochs=5, limit_train_batches=10, callbacks=callbacks)
ds = DummyDataset()
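The new parametrizations exercise each strategy both with and without a `onecyclelr` scheduler. A hypothetical assertion (not part of the PR's checker classes; `trainer.optimizers` and `pl_module.backbone` are assumed names for illustration) of what the fix guarantees after the unfreeze epoch:

```python
def assert_backbone_in_scheduled_group(trainer, pl_module):
    # After unfreezing, there should still be a single param group -- the one the
    # LR scheduler was configured for -- and the backbone parameters should be in it.
    optimizer = trainer.optimizers[0]
    assert len(optimizer.param_groups) == 1
    group_params = {id(p) for p in optimizer.param_groups[0]["params"]}
    assert all(id(p) in group_params for p in pl_module.backbone.parameters())
```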