Merge branch 'master' into tabular_classification/from_dict
krshrimali authored May 9, 2022
2 parents e5699aa + 1ed7938 commit def986a
Showing 16 changed files with 231 additions and 248 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-testing.yml
@@ -33,6 +33,7 @@ jobs:
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'pre', topic: [ 'core' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras_baal' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video','video_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'tabular' ] }
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -22,8 +22,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed support for all versions of `baal`, including the latest and older releases ([#1315](https://github.com/PyTorchLightning/lightning-flash/pull/1315))

- Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324))

- Fixed a bug where the `freeze_unfreeze` and `unfreeze_milestones` finetuning strategies could not be used in tandem with a `onecyclelr` LR scheduler ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))

- Fixed a bug where the backbone learning rate would be divided by 10 when unfrozen if using the `freeze_unfreeze` or `unfreeze_milestones` strategies ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))
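
As a rough illustration of the two finetuning fixes above, the following sketch shows the combination that now works (the `ImageClassifier` arguments and the commented-out `datamodule` are illustrative, not taken from this commit):

```python
from flash import Trainer
from flash.image import ImageClassifier

# Sketch only: a finetuning run that pairs a freezing strategy with a step-wise
# OneCycle LR schedule, which previously failed when the backbone was unfrozen.
model = ImageClassifier(
    num_classes=2,
    backbone="resnet18",
    optimizer="sgd",
    learning_rate=0.1,
    lr_scheduler=("onecyclelr", {"max_lr": 1e-3, "epochs": 10, "steps_per_epoch": 10}, {"interval": "step"}),
)
trainer = Trainer(max_epochs=10, limit_train_batches=10)
# The backbone unfreezes after 2 epochs and keeps the full learning rate rather
# than one tenth of it:
# trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 2))
```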

## [0.7.4] - 2022-04-27

### Fixed
2 changes: 1 addition & 1 deletion docs/source/general/finetuning.rst
@@ -228,7 +228,7 @@ For even more customization, create your own finetuning callback.

# When ``current_epoch`` is 5, the backbone will start to be trained.
if current_epoch == self._unfreeze_epoch:
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
pl_module.backbone,
optimizer,
)
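For context, a fuller version of the callback sketched in that docs snippet might look like the following (a minimal sketch built on PyTorch Lightning's `BaseFinetuning`; the class name, epoch choice, and the assumption that the module exposes a `backbone` attribute are illustrative). It appends the unfrozen parameters to the optimizer's existing param group, mirroring the new `unfreeze_and_extend_param_group` helper added in `flash/core/finetuning.py` below:

```python
from pytorch_lightning.callbacks import BaseFinetuning


class UnfreezeBackboneAtEpoch(BaseFinetuning):
    """Freeze ``pl_module.backbone`` at the start of training and unfreeze it at a given epoch."""

    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        super().__init__()
        self._unfreeze_epoch = unfreeze_epoch
        self._train_bn = train_bn

    def freeze_before_training(self, pl_module):
        self.freeze(pl_module.backbone, train_bn=self._train_bn)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        if current_epoch == self._unfreeze_epoch:
            self.make_trainable(pl_module.backbone)
            params = self.filter_params(pl_module.backbone, train_bn=self._train_bn, requires_grad=True)
            params = self.filter_on_optimizer(optimizer, params)
            if params:
                # Extend the existing group rather than adding a new one, so that
                # schedulers built against a single param group keep working.
                optimizer.param_groups[0]["params"].extend(params)
```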
19 changes: 16 additions & 3 deletions flash/core/finetuning.py
@@ -103,6 +103,19 @@ def freeze_before_training(self, pl_module: Union[Module, Iterable[Union[Module,
modules = [modules]
self.freeze(modules=modules, train_bn=self.train_bn)

def unfreeze_and_extend_param_group(
self,
modules: Union[Module, Iterable[Union[Module, Iterable]]],
optimizer: Optimizer,
train_bn: bool = True,
) -> None:
self.make_trainable(modules)

params = self.filter_params(modules, train_bn=train_bn, requires_grad=True)
params = self.filter_on_optimizer(optimizer, params)
if params:
optimizer.param_groups[0]["params"].extend(params)

def _freeze_unfreeze_function(
self,
pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
@@ -117,7 +130,7 @@ def _freeze_unfreeze_function(

modules = self._get_modules_to_freeze(pl_module=pl_module)
if modules is not None:
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
modules=modules,
optimizer=optimizer,
train_bn=self.train_bn,
Expand All @@ -140,15 +153,15 @@ def _unfreeze_milestones_function(
# unfreeze the last num_layers layers

backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[-num_layers:]
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
modules=backbone_modules,
optimizer=optimizer,
train_bn=self.train_bn,
)
elif epoch == unfreeze_milestones[1]:
# unfreeze remaining layers
backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[:-num_layers]
self.unfreeze_and_add_param_group(
self.unfreeze_and_extend_param_group(
modules=backbone_modules,
optimizer=optimizer,
train_bn=self.train_bn,
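Why extend the existing group instead of calling `unfreeze_and_add_param_group`? A plain-PyTorch illustration (toy modules, for exposition only): `OneCycleLR` records its per-group settings when it is constructed, so a param group added afterwards makes the scheduler fail, and `unfreeze_and_add_param_group` would also divide the new group's learning rate by 10 by default (`initial_denom_lr=10`). Appending to the existing group avoids both issues:

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

backbone, head = torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)
for p in backbone.parameters():
    p.requires_grad = False  # backbone starts frozen

optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=5, steps_per_epoch=10)

# Later, when unfreezing the backbone:
for p in backbone.parameters():
    p.requires_grad = True
optimizer.param_groups[0]["params"].extend(backbone.parameters())

optimizer.step()
scheduler.step()  # the scheduler still sees the single group it was built with
```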
11 changes: 7 additions & 4 deletions flash/core/integrations/pytorch_tabular/adapter.py
@@ -21,9 +21,10 @@


class PytorchTabularAdapter(Adapter):
def __init__(self, backbone):
def __init__(self, task_type, backbone):
super().__init__()

self.task_type = task_type
self.backbone = backbone

@classmethod
@@ -52,21 +53,23 @@ def from_task(
"output_dim": output_dim,
}
adapter = cls(
task_type,
task.backbones.get(backbone)(
task_type=task_type, parameters=parameters, loss_fn=loss_fn, metrics=metrics, **backbone_kwargs
)
),
)

return adapter

@staticmethod
def convert_batch(batch):
def convert_batch(self, batch):
new_batch = {
"continuous": batch[DataKeys.INPUT][1],
"categorical": batch[DataKeys.INPUT][0],
}
if DataKeys.TARGET in batch:
new_batch["target"] = batch[DataKeys.TARGET].reshape(-1, 1)
if self.task_type == "regression":
new_batch["target"] = new_batch["target"].float()
return new_batch

def training_step(self, batch, batch_idx) -> Any:
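A rough sketch of what `convert_batch` produces (plain tensors with illustrative shapes, and string keys standing in for `DataKeys`): pytorch_tabular expects a dict with `categorical`/`continuous` inputs and an `(N, 1)` target, and regression losses such as MSE need a float target, hence the extra cast for regression tasks:

```python
import torch

batch = {
    "input": (torch.randint(0, 3, (8, 2)), torch.randn(8, 5)),  # (categorical, continuous)
    "target": torch.randint(0, 2, (8,)),
}

new_batch = {
    "categorical": batch["input"][0],
    "continuous": batch["input"][1],
    "target": batch["target"].reshape(-1, 1),
}

task_type = "regression"
if task_type == "regression":
    # Regression losses (e.g. MSE) require float targets.
    new_batch["target"] = new_batch["target"].float()
```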
20 changes: 15 additions & 5 deletions flash/core/utilities/imports.py
@@ -17,8 +17,9 @@
import os
import types
from importlib.util import find_spec
from typing import List, Tuple, Union
from typing import Callable, List, Tuple, Union

import pkg_resources
from pkg_resources import DistributionNotFound

try:
@@ -48,21 +49,29 @@ def _module_available(module_path: str) -> bool:
return True


def _compare_version(package: str, op, version) -> bool:
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
"""Compare package version with some requirements.
>>> _compare_version("torch", operator.ge, "0.1")
True
>>> _compare_version("does_not_exist", operator.ge, "0.0")
False
"""
try:
pkg = importlib.import_module(package)
except (ModuleNotFoundError, DistributionNotFound, ValueError):
except (ImportError, DistributionNotFound):
return False
try:
pkg_version = Version(pkg.__version__)
if hasattr(pkg, "__version__"):
pkg_version = Version(pkg.__version__)
else:
# try pkg_resources to infer version
pkg_version = Version(pkg_resources.get_distribution(package).version)
except TypeError:
# this is mock by sphinx, so it shall return True to generate all summaries
# this is mocked by Sphinx, so it should return True to generate all summaries
return True
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))


@@ -128,6 +137,7 @@ class Image:
_PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0")
_ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0")
_TM_GREATER_EQUAL_0_7_0 = _compare_version("torchmetrics", operator.ge, "0.7.0")
_BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2")

_TEXT_AVAILABLE = all(
[
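A short sketch of how the new `use_base_version` flag can be used (the `baal` check here is illustrative; the flag compares against `Version.base_version`, so a pre-release such as `1.5.2rc1` is treated as `1.5.2`):

```python
import operator

from flash.core.utilities.imports import _compare_version

# Without use_base_version, "1.5.2rc1" < "1.5.2" and this check would be False.
_BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2", use_base_version=True)
```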
14 changes: 12 additions & 2 deletions flash/image/classification/integrations/baal/dropout.py
@@ -15,10 +15,20 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

import flash
from flash.core.utilities.imports import _BAAL_AVAILABLE
from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2

if _BAAL_AVAILABLE:
from baal.bayesian.dropout import _patch_dropout_layers
# The _patch_dropout_layers function was replaced with the replace_layers_in_module helper
# function in baal v1.5.2 (see https://github.com/ElementAI/baal/pull/194 for more details).
if _BAAL_GREATER_EQUAL_1_5_2:
from baal.bayesian.common import replace_layers_in_module
from baal.bayesian.consistent_dropout import _consistent_dropout_mapping_fn

def _patch_dropout_layers(module: torch.nn.Module):
return replace_layers_in_module(module, _consistent_dropout_mapping_fn)

else:
from baal.bayesian.consistent_dropout import _patch_dropout_layers


class InferenceMCDropoutTask(flash.Task):
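A minimal usage sketch of the shim (assuming `baal` is installed; the toy model is illustrative, and the return value is assumed to indicate whether any layer was replaced):

```python
import torch

# Swap the Dropout layers for baal's consistent-dropout variants so the model
# can be used for MC-Dropout inference.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(0.5), torch.nn.Linear(8, 2))
changed = _patch_dropout_layers(model)  # assumed True if at least one layer was replaced
```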
7 changes: 7 additions & 0 deletions flash_examples/question_answering.py
@@ -13,8 +13,15 @@
# limitations under the License.
from flash import Trainer
from flash.core.data.utils import download_data
from flash.core.utilities.imports import example_requires
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

example_requires("text")

import nltk # noqa: E402

nltk.download("punkt")

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/")

2 changes: 0 additions & 2 deletions requirements/datatype_image_extras.txt
@@ -7,8 +7,6 @@ icedata
effdet
albumentations
learn2learn
structlog==21.1.0 # remove when baal resolved its dependency.
baal
fastface
fairscale

2 changes: 2 additions & 0 deletions requirements/datatype_image_extras_baal.txt
@@ -0,0 +1,2 @@
# This is a separate file, as the baal integration conflicts with the vissl installation.
baal>=1.3.2
37 changes: 31 additions & 6 deletions tests/core/test_finetuning.py
@@ -155,20 +155,45 @@ def test_finetuning_with_none_return_type(strategy):

@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
@pytest.mark.parametrize(
("strategy", "checker_class", "checker_class_data"),
("strategy", "lr_scheduler", "checker_class", "checker_class_data"),
[
("no_freeze", None, {}),
("freeze", FreezeStrategyChecking, {}),
(("freeze_unfreeze", 2), FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
("no_freeze", None, None, {}),
("freeze", None, FreezeStrategyChecking, {}),
(("freeze_unfreeze", 2), None, FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
(
("unfreeze_milestones", ((1, 3), 1)),
None,
UnfreezeMilestonesStrategyChecking,
{"check_epochs": [1, 3], "num_layers": 1},
),
(
"no_freeze",
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
None,
{},
),
(
"freeze",
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
FreezeStrategyChecking,
{},
),
(
("freeze_unfreeze", 2),
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
FreezeUnfreezeStrategyChecking,
{"check_epoch": 2},
),
(
("unfreeze_milestones", ((1, 3), 1)),
("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
UnfreezeMilestonesStrategyChecking,
{"check_epochs": [1, 3], "num_layers": 1},
),
],
)
def test_finetuning(tmpdir, strategy, checker_class, checker_class_data):
task = TestTaskWithFinetuning(loss_fn=F.nll_loss)
def test_finetuning(tmpdir, strategy, lr_scheduler, checker_class, checker_class_data):
task = TestTaskWithFinetuning(loss_fn=F.nll_loss, lr_scheduler=lr_scheduler, optimizer="sgd", learning_rate=0.1)
callbacks = [] if checker_class is None else checker_class(dirpath=tmpdir, **checker_class_data)
trainer = flash.Trainer(max_epochs=5, limit_train_batches=10, callbacks=callbacks)
ds = DummyDataset()
4 changes: 2 additions & 2 deletions tests/helpers/task_tester.py
@@ -36,7 +36,7 @@ def _copy_func(f):
return g


class _StaticDataset(Dataset):
class StaticDataset(Dataset):
def __init__(self, sample, length):
super().__init__()

@@ -60,7 +60,7 @@ def _test_forward(self):

def _test_fit(self, tmpdir, task_kwargs):
"""Tests that a single batch fit pass completes."""
dataset = _StaticDataset(self.example_train_sample, 4)
dataset = StaticDataset(self.example_train_sample, 4)

args = self.task_args
kwargs = dict(**self.task_kwargs)
4 changes: 3 additions & 1 deletion tests/image/classification/test_active_learning.py
@@ -22,7 +22,7 @@
from torch.utils.data import SequentialSampler

import flash
from flash.core.utilities.imports import _BAAL_AVAILABLE, _IMAGE_AVAILABLE
from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _IMAGE_AVAILABLE
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
from tests.image.classification.test_data import _rand_image
@@ -110,6 +110,8 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s
assert len(active_learning_dm._dataset) == 20
assert active_learning_loop.progress.total.completed == 3
labelled = active_learning_loop.state_dict()["state_dict"]["datamodule_state_dict"]["labelled"]
if _BAAL_GREATER_EQUAL_1_5_2:
labelled = labelled > 0
assert isinstance(labelled, np.ndarray)

# Check that we iterate over the actual pool and that shuffle is disabled.
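For clarity, the extra binarization step assumes that in `baal >= 1.5.2` the datamodule's `labelled` state holds integer labelling steps/counts rather than booleans, so the test converts it before asserting on the `np.ndarray`; a tiny illustration:

```python
import numpy as np

labelled = np.array([0, 2, 0, 1])  # hypothetical per-sample labelling steps
labelled = labelled > 0            # -> array([False,  True, False,  True])
```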