This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix unfreeze strategies with onecyclelr and reduced lr #1329

Merged (5 commits, May 6, 2022)
CHANGELOG.md (4 additions, 0 deletions)
@@ -26,6 +26,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324))

- Fixed a bug where the `freeze_unfreeze` and `unfreeze_milestones` finetuning strategies could not be used in tandem with a `onecyclelr` LR scheduler ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))

- Fixed a bug where the backbone learning rate would be divided by 10 when unfrozen if using the `freeze_unfreeze` or `unfreeze_milestones` strategies ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))

## [0.7.4] - 2022-04-27

### Fixed
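Taken together, the two entries above describe the user-facing effect of this PR: the `freeze_unfreeze` and `unfreeze_milestones` strategies can now be combined with a `onecyclelr` scheduler, and the backbone keeps the task's scheduled learning rate when it is unfrozen instead of receiving one tenth of it. A rough sketch of a configuration that exercises both fixes; the task, backbone, and data path are illustrative stand-ins, while the `lr_scheduler` tuple and strategy values mirror the test added in this PR:

```python
import flash
from flash.image import ImageClassificationData, ImageClassifier

# Placeholder data; any Flash task/datamodule pair should behave the same way.
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",  # illustrative path
    batch_size=4,
)

# Attach a onecyclelr scheduler that is stepped every batch.
model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
    optimizer="sgd",
    learning_rate=0.1,
    lr_scheduler=(
        "onecyclelr",
        {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10},
        {"interval": "step"},
    ),
)

# The backbone is unfrozen at epoch 2. With this fix it joins the existing
# optimizer param group, so the scheduler keeps stepping and the backbone
# trains at the scheduled learning rate rather than one tenth of it.
trainer = flash.Trainer(max_epochs=50, limit_train_batches=10)
trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 2))
```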
docs/source/general/finetuning.rst (1 addition, 1 deletion)
@@ -228,7 +228,7 @@ For even more customization, create your own finetuning callback. Learn more abo

# When ``current_epoch`` is 5, backbone will start to be trained.
if current_epoch == self._unfreeze_epoch:
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
pl_module.backbone,
optimizer,
)
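The docs change above only shows the unfreeze branch of the example callback. For context, here is a self-contained sketch of the same pattern written directly against PyTorch Lightning's `BaseFinetuning`; the class and attribute names are illustrative assumptions, and the unfreeze branch inlines what the new `unfreeze_and_extend_param_group` helper (added below in `flash/core/finetuning.py`) does, so the callback no longer creates a second param group:

```python
from pytorch_lightning.callbacks import BaseFinetuning


class BackboneFreezeUnfreeze(BaseFinetuning):
    """Illustrative callback: freeze ``pl_module.backbone``, unfreeze it at a given epoch."""

    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        super().__init__()
        self._unfreeze_epoch = unfreeze_epoch
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module):
        # Freeze the backbone before training starts (assumes ``pl_module.backbone``).
        self.freeze(pl_module.backbone, train_bn=self.train_bn)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # When ``current_epoch`` reaches the unfreeze epoch, make the backbone
        # trainable and append its parameters to the optimizer's existing param
        # group, which is what ``unfreeze_and_extend_param_group`` does for you.
        if current_epoch == self._unfreeze_epoch:
            self.make_trainable(pl_module.backbone)
            params = self.filter_params(pl_module.backbone, train_bn=self.train_bn, requires_grad=True)
            params = self.filter_on_optimizer(optimizer, params)
            if params:
                optimizer.param_groups[0]["params"].extend(params)
```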
flash/core/finetuning.py (16 additions, 3 deletions)
@@ -103,6 +103,19 @@ def freeze_before_training(self, pl_module: Union[Module, Iterable[Union[Module,
modules = [modules]
self.freeze(modules=modules, train_bn=self.train_bn)

def unfreeze_and_extend_param_group(
self,
modules: Union[Module, Iterable[Union[Module, Iterable]]],
optimizer: Optimizer,
train_bn: bool = True,
) -> None:
self.make_trainable(modules)

params = self.filter_params(modules, train_bn=train_bn, requires_grad=True)
params = self.filter_on_optimizer(optimizer, params)
if params:
optimizer.param_groups[0]["params"].extend(params)

def _freeze_unfreeze_function(
self,
pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
@@ -117,7 +130,7 @@ def _freeze_unfreeze_function(

modules = self._get_modules_to_freeze(pl_module=pl_module)
if modules is not None:
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
modules=modules,
optimizer=optimizer,
train_bn=self.train_bn,
@@ -140,15 +153,15 @@ def _unfreeze_milestones_function(
# unfreeze num_layers last layers

backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[-num_layers:]
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
modules=backbone_modules,
optimizer=optimizer,
train_bn=self.train_bn,
)
elif epoch == unfreeze_milestones[1]:
# unfreeze remaining layers
backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[:-num_layers]
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
modules=backbone_modules,
optimizer=optimizer,
train_bn=self.train_bn,
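The new helper differs from Lightning's `unfreeze_and_add_param_group` in one important way: rather than appending a new param group (whose learning rate Lightning divides by `initial_denom_lr`, 10 by default), it extends the optimizer's first group, so the unfrozen parameters share the existing learning rate and a scheduler created at setup time still sees a single group. A minimal, Flash-free sketch of why this matters for `OneCycleLR`; the toy modules are assumptions, not library code:

```python
import torch
from torch import nn

# Toy stand-ins for a frozen backbone and a trainable head.
backbone, head = nn.Linear(4, 4), nn.Linear(4, 2)

# The optimizer and the scheduler only know about the head's parameters at
# construction time, as they would during a "freeze" phase.
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, epochs=50, steps_per_epoch=10
)

# Old behaviour: ``optimizer.add_param_group(...)`` introduces a second group
# that the already-constructed OneCycleLR has no per-group state for, and the
# new group's lr would be the current lr divided by 10.
#
# New behaviour: the unfrozen parameters are appended to the existing group,
# so there is still exactly one group and one scheduled learning rate.
optimizer.param_groups[0]["params"].extend(
    p for p in backbone.parameters() if p.requires_grad
)

optimizer.step()
scheduler.step()  # still fine: the scheduler sees the same single param group
assert len(optimizer.param_groups) == 1
```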
tests/core/test_finetuning.py (31 additions, 6 deletions)
@@ -155,20 +155,45 @@ def test_finetuning_with_none_return_type(strategy):

@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
@pytest.mark.parametrize(
("strategy", "checker_class", "checker_class_data"),
("strategy", "lr_scheduler", "checker_class", "checker_class_data"),
[
("no_freeze", None, {}),
("freeze", FreezeStrategyChecking, {}),
(("freeze_unfreeze", 2), FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
("no_freeze", None, None, {}),
("freeze", None, FreezeStrategyChecking, {}),
(("freeze_unfreeze", 2), None, FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
(
("unfreeze_milestones", ((1, 3), 1)),
None,
UnfreezeMilestonesStrategyChecking,
{"check_epochs": [1, 3], "num_layers": 1},
),
(
"no_freeze",
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
None,
{},
),
(
"freeze",
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
FreezeStrategyChecking,
{},
),
(
("freeze_unfreeze", 2),
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
FreezeUnfreezeStrategyChecking,
{"check_epoch": 2},
),
(
("unfreeze_milestones", ((1, 3), 1)),
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
UnfreezeMilestonesStrategyChecking,
{"check_epochs": [1, 3], "num_layers": 1},
),
],
)
def test_finetuning(tmpdir, strategy, checker_class, checker_class_data):
task = TestTaskWithFinetuning(loss_fn=F.nll_loss)
def test_finetuning(tmpdir, strategy, lr_scheduler, checker_class, checker_class_data):
task = TestTaskWithFinetuning(loss_fn=F.nll_loss, lr_scheduler=lr_scheduler, optimizer="sgd", learning_rate=0.1)
callbacks = [] if checker_class is None else checker_class(dirpath=tmpdir, **checker_class_data)
trainer = flash.Trainer(max_epochs=5, limit_train_batches=10, callbacks=callbacks)
ds = DummyDataset()
Expand Down