From 88227d882f053aa6e3386cb7bd35349d353f9e82 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Fri, 6 May 2022 17:11:23 +0530 Subject: [PATCH 1/3] Support latest version of BaaL (1.5.2) and add necessary utilities (#1315) Co-authored-by: Ethan Harris --- .github/workflows/ci-testing.yml | 1 + CHANGELOG.md | 2 ++ flash/core/utilities/imports.py | 20 ++++++++++++++----- .../integrations/baal/dropout.py | 14 +++++++++++-- requirements/datatype_image_extras.txt | 2 -- requirements/datatype_image_extras_baal.txt | 2 ++ .../classification/test_active_learning.py | 4 +++- 7 files changed, 35 insertions(+), 10 deletions(-) create mode 100644 requirements/datatype_image_extras_baal.txt diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d4880c55c29..a18b81f9f3e 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -33,6 +33,7 @@ jobs: - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'pre', topic: [ 'core' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras_baal' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video','video_extras' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'tabular' ] } diff --git a/CHANGELOG.md b/CHANGELOG.md index 908b964a0c1..757e7f313c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed support for all the versions (including the latest and older) of `baal`. ([#1315](https://github.com/PyTorchLightning/lightning-flash/pull/1315)) + - Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324)) ## [0.7.4] - 2022-04-27 diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 201430d3b64..0fdc4b0d758 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -17,8 +17,9 @@ import os import types from importlib.util import find_spec -from typing import List, Tuple, Union +from typing import Callable, List, Tuple, Union +import pkg_resources from pkg_resources import DistributionNotFound try: @@ -48,21 +49,29 @@ def _module_available(module_path: str) -> bool: return True -def _compare_version(package: str, op, version) -> bool: +def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: """Compare package version with some requirements. >>> _compare_version("torch", operator.ge, "0.1") True + >>> _compare_version("does_not_exist", operator.ge, "0.0") + False """ try: pkg = importlib.import_module(package) - except (ModuleNotFoundError, DistributionNotFound, ValueError): + except (ImportError, DistributionNotFound): return False try: - pkg_version = Version(pkg.__version__) + if hasattr(pkg, "__version__"): + pkg_version = Version(pkg.__version__) + else: + # try pkg_resources to infer version + pkg_version = Version(pkg_resources.get_distribution(package).version) except TypeError: - # this is mock by sphinx, so it shall return True to generate all summaries + # this is mocked by Sphinx, so it should return True to generate all summaries return True + if use_base_version: + pkg_version = Version(pkg_version.base_version) return op(pkg_version, Version(version)) @@ -128,6 +137,7 @@ class Image: _PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0") _ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0") _TM_GREATER_EQUAL_0_7_0 = _compare_version("torchmetrics", operator.ge, "0.7.0") + _BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2") _TEXT_AVAILABLE = all( [ diff --git a/flash/image/classification/integrations/baal/dropout.py b/flash/image/classification/integrations/baal/dropout.py index 103c76858dc..db694e26421 100644 --- a/flash/image/classification/integrations/baal/dropout.py +++ b/flash/image/classification/integrations/baal/dropout.py @@ -15,10 +15,20 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException import flash -from flash.core.utilities.imports import _BAAL_AVAILABLE +from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2 if _BAAL_AVAILABLE: - from baal.bayesian.dropout import _patch_dropout_layers + # _patch_dropout_layers function was replaced with replace_layers_in_module helper + # function in v1.5.2 (https://github.com/ElementAI/baal/pull/194 for more details) + if _BAAL_GREATER_EQUAL_1_5_2: + from baal.bayesian.common import replace_layers_in_module + from baal.bayesian.consistent_dropout import _consistent_dropout_mapping_fn + + def _patch_dropout_layers(module: torch.nn.Module): + return replace_layers_in_module(module, _consistent_dropout_mapping_fn) + + else: + from baal.bayesian.consistent_dropout import _patch_dropout_layers class InferenceMCDropoutTask(flash.Task): diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 89c0d7b29d3..ed1c2f386f8 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -7,8 +7,6 @@ icedata effdet albumentations learn2learn -structlog==21.1.0 # remove when baal resolved its dependency. -baal fastface fairscale diff --git a/requirements/datatype_image_extras_baal.txt b/requirements/datatype_image_extras_baal.txt new file mode 100644 index 00000000000..37386a63592 --- /dev/null +++ b/requirements/datatype_image_extras_baal.txt @@ -0,0 +1,2 @@ +# This is a separate file, as baal integration is affected by vissl installation (conflicts) +baal>=1.3.2 diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index 7e622f490fc..89652288386 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -22,7 +22,7 @@ from torch.utils.data import SequentialSampler import flash -from flash.core.utilities.imports import _BAAL_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _IMAGE_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop from tests.image.classification.test_data import _rand_image @@ -110,6 +110,8 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s assert len(active_learning_dm._dataset) == 20 assert active_learning_loop.progress.total.completed == 3 labelled = active_learning_loop.state_dict()["state_dict"]["datamodule_state_dict"]["labelled"] + if _BAAL_GREATER_EQUAL_1_5_2: + labelled = labelled > 0 assert isinstance(labelled, np.ndarray) # Check that we iterate over the actual pool and that shuffle is disabled. From 48a25006cad455672063dc5dc9141c43fcf6607d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 6 May 2022 14:49:37 +0100 Subject: [PATCH 2/3] Add fit tests for tabular tasks (#1332) --- .../integrations/pytorch_tabular/adapter.py | 11 +- flash_examples/question_answering.py | 7 + tests/helpers/task_tester.py | 4 +- tests/tabular/classification/test_model.py | 128 +++++++----------- tests/tabular/forecasting/test_model.py | 96 ++++--------- tests/tabular/regression/test_model.py | 126 +++++++---------- 6 files changed, 144 insertions(+), 228 deletions(-) diff --git a/flash/core/integrations/pytorch_tabular/adapter.py b/flash/core/integrations/pytorch_tabular/adapter.py index 569d609639f..597e4cbd2bf 100644 --- a/flash/core/integrations/pytorch_tabular/adapter.py +++ b/flash/core/integrations/pytorch_tabular/adapter.py @@ -21,9 +21,10 @@ class PytorchTabularAdapter(Adapter): - def __init__(self, backbone): + def __init__(self, task_type, backbone): super().__init__() + self.task_type = task_type self.backbone = backbone @classmethod @@ -52,21 +53,23 @@ def from_task( "output_dim": output_dim, } adapter = cls( + task_type, task.backbones.get(backbone)( task_type=task_type, parameters=parameters, loss_fn=loss_fn, metrics=metrics, **backbone_kwargs - ) + ), ) return adapter - @staticmethod - def convert_batch(batch): + def convert_batch(self, batch): new_batch = { "continuous": batch[DataKeys.INPUT][1], "categorical": batch[DataKeys.INPUT][0], } if DataKeys.TARGET in batch: new_batch["target"] = batch[DataKeys.TARGET].reshape(-1, 1) + if self.task_type == "regression": + new_batch["target"] = new_batch["target"].float() return new_batch def training_step(self, batch, batch_idx) -> Any: diff --git a/flash_examples/question_answering.py b/flash_examples/question_answering.py index beb2a7add87..ab2448fa09b 100644 --- a/flash_examples/question_answering.py +++ b/flash_examples/question_answering.py @@ -13,8 +13,15 @@ # limitations under the License. from flash import Trainer from flash.core.data.utils import download_data +from flash.core.utilities.imports import example_requires from flash.text import QuestionAnsweringData, QuestionAnsweringTask +example_requires("text") + +import nltk # noqa: E402 + +nltk.download("punkt") + # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/") diff --git a/tests/helpers/task_tester.py b/tests/helpers/task_tester.py index 83b2d189c31..06a6ed9fd01 100644 --- a/tests/helpers/task_tester.py +++ b/tests/helpers/task_tester.py @@ -36,7 +36,7 @@ def _copy_func(f): return g -class _StaticDataset(Dataset): +class StaticDataset(Dataset): def __init__(self, sample, length): super().__init__() @@ -60,7 +60,7 @@ def _test_forward(self): def _test_fit(self, tmpdir, task_kwargs): """Tests that a single batch fit pass completes.""" - dataset = _StaticDataset(self.example_train_sample, 4) + dataset = StaticDataset(self.example_train_sample, 4) args = self.task_args kwargs = dict(**self.task_kwargs) diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index d9ba0a1d814..c515ab76a0a 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -17,34 +17,13 @@ import pandas as pd import pytest import torch -from pytorch_lightning import Trainer +import flash from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _SERVE_TESTING, _TABULAR_AVAILABLE, _TABULAR_TESTING from flash.tabular.classification.data import TabularClassificationData from flash.tabular.classification.model import TabularClassifier -from tests.helpers.task_tester import TaskTester - -# ======== Mock functions ======== - - -class DummyDataset(torch.utils.data.Dataset): - def __init__(self, num_num=16, num_cat=16): - super().__init__() - self.num_num = num_num - self.num_cat = num_cat - - def __getitem__(self, index): - target = torch.randint(0, 10, size=(1,)).item() - cat_vars = torch.randint(0, 10, size=(self.num_cat,)) - num_vars = torch.rand(self.num_num) - return {DataKeys.INPUT: (cat_vars, num_vars), DataKeys.TARGET: target} - - def __len__(self) -> int: - return 100 - - -# ============================== +from tests.helpers.task_tester import StaticDataset, TaskTester class TestTabularClassifier(TaskTester): @@ -66,6 +45,23 @@ class TestTabularClassifier(TaskTester): scriptable = False traceable = False + marks = { + "test_fit": [ + pytest.mark.parametrize( + "task_kwargs", + [ + {"backbone": "tabnet"}, + {"backbone": "tabtransformer"}, + {"backbone": "fttransformer"}, + {"backbone": "autoint"}, + {"backbone": "node"}, + {"backbone": "category_embedding"}, + ], + ) + ], + "test_cli": [pytest.mark.parametrize("extra_args", ([],))], + } + @property def example_forward_input(self): return { @@ -77,63 +73,37 @@ def check_forward_output(self, output: Any): assert isinstance(output, torch.Tensor) assert output.shape == torch.Size([1, 10]) + @property + def example_train_sample(self): + return {DataKeys.INPUT: (torch.randint(0, 10, size=(4,)), torch.rand(4)), DataKeys.TARGET: 1} -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") -@pytest.mark.parametrize( - "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] -) -def test_init_train(backbone, tmpdir): - train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=16) - data_properties = { - "parameters": {"categorical_fields": list(range(16))}, - "embedding_sizes": [(10, 32) for _ in range(16)], - "cat_dims": [10 for _ in range(16)], - "num_features": 32, - "num_classes": 10, - "backbone": backbone, - } - - model = TabularClassifier(**data_properties) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model, train_dl) - - -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") -@pytest.mark.parametrize( - "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] -) -def test_init_train_no_num(backbone, tmpdir): - train_dl = torch.utils.data.DataLoader(DummyDataset(num_num=0), batch_size=16) - data_properties = { - "parameters": {"categorical_fields": list(range(16))}, - "embedding_sizes": [(10, 32) for _ in range(16)], - "cat_dims": [10 for _ in range(16)], - "num_features": 16, - "num_classes": 10, - "backbone": backbone, - } - - model = TabularClassifier(**data_properties) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model, train_dl) - - -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") -@pytest.mark.parametrize("backbone", ["tabnet", "tabtransformer", "autoint", "node", "category_embedding"]) -def test_init_train_no_cat(backbone, tmpdir): - train_dl = torch.utils.data.DataLoader(DummyDataset(num_cat=0), batch_size=16) - data_properties = { - "parameters": {"categorical_fields": []}, - "embedding_sizes": [], - "cat_dims": [], - "num_features": 16, - "num_classes": 10, - "backbone": backbone, - } - - model = TabularClassifier(**data_properties) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model, train_dl) + @pytest.mark.parametrize( + "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] + ) + def test_init_train_no_num(self, backbone, tmpdir): + no_num_sample = {DataKeys.INPUT: (torch.randint(0, 10, size=(4,)), torch.empty(0)), DataKeys.TARGET: 1} + dataset = StaticDataset(no_num_sample, 4) + + args = self.task_args + kwargs = dict(**self.task_kwargs) + kwargs.update(num_features=4) + model = self.task(*args, **kwargs) + + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, model.process_train_dataset(dataset, batch_size=4)) + + @pytest.mark.parametrize("backbone", ["tabnet", "tabtransformer", "autoint", "node", "category_embedding"]) + def test_init_train_no_cat(self, backbone, tmpdir): + no_cat_sample = {DataKeys.INPUT: (torch.empty(0), torch.rand(4)), DataKeys.TARGET: 1} + dataset = StaticDataset(no_cat_sample, 4) + + args = self.task_args + kwargs = dict(**self.task_kwargs) + kwargs.update(parameters={"categorical_fields": []}, embedding_sizes=[], cat_dims=[], num_features=4) + model = self.task(*args, **kwargs) + + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, model.process_train_dataset(dataset, batch_size=4)) @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py index ac393b7b606..a6e9a12c9f1 100644 --- a/tests/tabular/forecasting/test_model.py +++ b/tests/tabular/forecasting/test_model.py @@ -17,20 +17,17 @@ import torch import flash -from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE, _TABULAR_TESTING -from flash.tabular.forecasting import TabularForecaster, TabularForecastingData -from tests.helpers.task_tester import TaskTester +from flash import DataKeys +from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TABULAR_TESTING +from flash.tabular.forecasting import TabularForecaster +from tests.helpers.task_tester import StaticDataset, TaskTester if _TABULAR_AVAILABLE: from pytorch_forecasting.data import EncoderNormalizer, NaNLabelEncoder - from pytorch_forecasting.data.examples import generate_ar_data else: EncoderNormalizer = object NaNLabelEncoder = object -if _PANDAS_AVAILABLE: - import pandas as pd - class TestTabularForecaster(TaskTester): @@ -102,66 +99,33 @@ def check_forward_output(self, output: Any): assert isinstance(output["prediction"], torch.Tensor) assert output["prediction"].shape == torch.Size([2, 20]) + @property + def example_train_sample(self): + return { + DataKeys.INPUT: { + "x_cat": torch.empty(60, 0, dtype=torch.int64), + "x_cont": torch.zeros(60, 1), + "encoder_target": torch.zeros(60), + "encoder_length": 60, + "decoder_length": 20, + "encoder_time_idx_start": 1, + "groups": torch.zeros(1), + "target_scale": torch.zeros(2), + }, + DataKeys.TARGET: (torch.rand(20), None), + } -@pytest.fixture -def sample_data(): - data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42) - data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") - max_prediction_length = 20 - training_cutoff = data["time_idx"].max() - max_prediction_length - return data, training_cutoff, max_prediction_length - - -@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") -def test_fast_dev_run_smoke(sample_data): - """Test that fast dev run works with the NBeats example data.""" - data, training_cutoff, max_prediction_length = sample_data - datamodule = TabularForecastingData.from_data_frame( - time_idx="time_idx", - target="value", - categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, - group_ids=["series"], - time_varying_unknown_reals=["value"], - max_encoder_length=60, - max_prediction_length=max_prediction_length, - train_data_frame=data[lambda x: x.time_idx <= training_cutoff], - val_data_frame=data, - batch_size=4, - ) - - model = TabularForecaster( - datamodule.parameters, - backbone="n_beats", - backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1}, - ) - - trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01) - trainer.fit(model, datamodule=datamodule) - + def test_testing_raises(self, tmpdir): + """Tests that ``NotImplementedError`` is raised when attempting to perform a test pass.""" + dataset = StaticDataset(self.example_train_sample, 4) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") -def test_testing_raises(sample_data): - """Tests that ``NotImplementedError`` is raised when attempting to perform a test pass.""" - data, training_cutoff, max_prediction_length = sample_data - datamodule = TabularForecastingData.from_data_frame( - time_idx="time_idx", - target="value", - categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, - group_ids=["series"], - time_varying_unknown_reals=["value"], - max_encoder_length=60, - max_prediction_length=max_prediction_length, - train_data_frame=data[lambda x: x.time_idx <= training_cutoff], - test_data_frame=data, - batch_size=4, - ) + args = self.task_args + kwargs = dict(**self.task_kwargs) + model = self.task(*args, **kwargs) - model = TabularForecaster( - datamodule.parameters, - backbone="n_beats", - backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1}, - ) - trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01) + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.raises(NotImplementedError, match="Backbones provided by PyTorch Forecasting don't support testing."): - trainer.test(model, datamodule=datamodule) + with pytest.raises( + NotImplementedError, match="Backbones provided by PyTorch Forecasting don't support testing." + ): + trainer.test(model, model.process_test_dataset(dataset, batch_size=4)) diff --git a/tests/tabular/regression/test_model.py b/tests/tabular/regression/test_model.py index 31c239a0f7f..f476467cc7c 100644 --- a/tests/tabular/regression/test_model.py +++ b/tests/tabular/regression/test_model.py @@ -17,33 +17,12 @@ import pandas as pd import pytest import torch -from pytorch_lightning import Trainer +import flash from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _SERVE_TESTING, _TABULAR_AVAILABLE, _TABULAR_TESTING from flash.tabular import TabularRegressionData, TabularRegressor -from tests.helpers.task_tester import TaskTester - -# ======== Mock functions ======== - - -class DummyDataset(torch.utils.data.Dataset): - def __init__(self, num_num=16, num_cat=16): - super().__init__() - self.num_num = num_num - self.num_cat = num_cat - - def __getitem__(self, index): - target = torch.rand(1) - cat_vars = torch.randint(0, 10, size=(self.num_cat,)) - num_vars = torch.rand(self.num_num) - return {DataKeys.INPUT: (cat_vars, num_vars), DataKeys.TARGET: target} - - def __len__(self) -> int: - return 100 - - -# ============================== +from tests.helpers.task_tester import StaticDataset, TaskTester class TestTabularRegressor(TaskTester): @@ -64,6 +43,23 @@ class TestTabularRegressor(TaskTester): scriptable = False traceable = False + marks = { + "test_fit": [ + pytest.mark.parametrize( + "task_kwargs", + [ + {"backbone": "tabnet"}, + {"backbone": "tabtransformer"}, + {"backbone": "fttransformer"}, + {"backbone": "autoint"}, + {"backbone": "node"}, + {"backbone": "category_embedding"}, + ], + ) + ], + "test_cli": [pytest.mark.parametrize("extra_args", ([],))], + } + @property def example_forward_input(self): return { @@ -75,61 +71,37 @@ def check_forward_output(self, output: Any): assert isinstance(output, torch.Tensor) assert output.shape == torch.Size([1, 1]) + @property + def example_train_sample(self): + return {DataKeys.INPUT: (torch.randint(0, 10, size=(4,)), torch.rand(4)), DataKeys.TARGET: 0.1} -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") -@pytest.mark.parametrize( - "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] -) -def test_init_train(backbone, tmpdir): - train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=16) - - data_properties = { - "parameters": {"categorical_fields": list(range(16))}, - "embedding_sizes": [(10, 32) for _ in range(16)], - "cat_dims": [10 for _ in range(16)], - "num_features": 32, - "backbone": backbone, - } - - model = TabularRegressor(**data_properties) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model, train_dl) - - -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") -@pytest.mark.parametrize( - "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] -) -def test_init_train_no_num(backbone, tmpdir): - train_dl = torch.utils.data.DataLoader(DummyDataset(num_num=0), batch_size=16) - data_properties = { - "parameters": {"categorical_fields": list(range(16))}, - "embedding_sizes": [(10, 32) for _ in range(16)], - "cat_dims": [10 for _ in range(16)], - "num_features": 16, - "backbone": backbone, - } - - model = TabularRegressor(**data_properties) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model, train_dl) - - -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") -@pytest.mark.parametrize("backbone", ["tabnet", "tabtransformer", "autoint", "node", "category_embedding"]) -def test_init_train_no_cat(backbone, tmpdir): - train_dl = torch.utils.data.DataLoader(DummyDataset(num_cat=0), batch_size=16) - data_properties = { - "parameters": {"categorical_fields": []}, - "embedding_sizes": [], - "cat_dims": [], - "num_features": 16, - "backbone": backbone, - } - - model = TabularRegressor(**data_properties) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model, train_dl) + @pytest.mark.parametrize( + "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] + ) + def test_init_train_no_num(self, backbone, tmpdir): + no_num_sample = {DataKeys.INPUT: (torch.randint(0, 10, size=(4,)), torch.empty(0)), DataKeys.TARGET: 0.1} + dataset = StaticDataset(no_num_sample, 4) + + args = self.task_args + kwargs = dict(**self.task_kwargs) + kwargs.update(num_features=4) + model = self.task(*args, **kwargs) + + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, model.process_train_dataset(dataset, batch_size=4)) + + @pytest.mark.parametrize("backbone", ["tabnet", "tabtransformer", "autoint", "node", "category_embedding"]) + def test_init_train_no_cat(self, backbone, tmpdir): + no_cat_sample = {DataKeys.INPUT: (torch.empty(0), torch.rand(4)), DataKeys.TARGET: 0.1} + dataset = StaticDataset(no_cat_sample, 4) + + args = self.task_args + kwargs = dict(**self.task_kwargs) + kwargs.update(parameters={"categorical_fields": []}, embedding_sizes=[], cat_dims=[], num_features=4) + model = self.task(*args, **kwargs) + + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, model.process_train_dataset(dataset, batch_size=4)) @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") From 1ed793824709c868635276f443d3680b8b46a8e3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 6 May 2022 16:08:39 +0100 Subject: [PATCH 3/3] Fix unfreeze strategies with onecyclelr and reduced lr (#1329) --- CHANGELOG.md | 4 ++++ docs/source/general/finetuning.rst | 2 +- flash/core/finetuning.py | 19 ++++++++++++--- tests/core/test_finetuning.py | 37 +++++++++++++++++++++++++----- 4 files changed, 52 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 757e7f313c6..7060a8e02bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324)) +- Fixed a bug where the `freeze_unfreeze` and `unfreeze_milestones` finetuning strategies could not be used in tandem with a `onecyclelr` LR scheduler ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329)) + +- Fixed a bug where the backbone learning rate would be divided by 10 when unfrozen if using the `freeze_unfreeze` or `unfreeze_milestones` strategies ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329)) + ## [0.7.4] - 2022-04-27 ### Fixed diff --git a/docs/source/general/finetuning.rst b/docs/source/general/finetuning.rst index fdd31146d8a..42cb873e29d 100644 --- a/docs/source/general/finetuning.rst +++ b/docs/source/general/finetuning.rst @@ -228,7 +228,7 @@ For even more customization, create your own finetuning callback. Learn more abo # When ``current_epoch`` is 5, backbone will start to be trained. if current_epoch == self._unfreeze_epoch: - self.unfreeze_and_add_param_group( + self.unfreeze_and_extend_param_group( pl_module.backbone, optimizer, ) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 4038f1e9ca4..f6b39c928ea 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -103,6 +103,19 @@ def freeze_before_training(self, pl_module: Union[Module, Iterable[Union[Module, modules = [modules] self.freeze(modules=modules, train_bn=self.train_bn) + def unfreeze_and_extend_param_group( + self, + modules: Union[Module, Iterable[Union[Module, Iterable]]], + optimizer: Optimizer, + train_bn: bool = True, + ) -> None: + self.make_trainable(modules) + + params = self.filter_params(modules, train_bn=train_bn, requires_grad=True) + params = self.filter_on_optimizer(optimizer, params) + if params: + optimizer.param_groups[0]["params"].extend(params) + def _freeze_unfreeze_function( self, pl_module: Union[Module, Iterable[Union[Module, Iterable]]], @@ -117,7 +130,7 @@ def _freeze_unfreeze_function( modules = self._get_modules_to_freeze(pl_module=pl_module) if modules is not None: - self.unfreeze_and_add_param_group( + self.unfreeze_and_extend_param_group( modules=modules, optimizer=optimizer, train_bn=self.train_bn, @@ -140,7 +153,7 @@ def _unfreeze_milestones_function( # unfreeze num_layers last layers backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[-num_layers:] - self.unfreeze_and_add_param_group( + self.unfreeze_and_extend_param_group( modules=backbone_modules, optimizer=optimizer, train_bn=self.train_bn, @@ -148,7 +161,7 @@ def _unfreeze_milestones_function( elif epoch == unfreeze_milestones[1]: # unfreeze remaining layers backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[:-num_layers] - self.unfreeze_and_add_param_group( + self.unfreeze_and_extend_param_group( modules=backbone_modules, optimizer=optimizer, train_bn=self.train_bn, diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index 311eb721693..67c63647b06 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -155,20 +155,45 @@ def test_finetuning_with_none_return_type(strategy): @pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") @pytest.mark.parametrize( - ("strategy", "checker_class", "checker_class_data"), + ("strategy", "lr_scheduler", "checker_class", "checker_class_data"), [ - ("no_freeze", None, {}), - ("freeze", FreezeStrategyChecking, {}), - (("freeze_unfreeze", 2), FreezeUnfreezeStrategyChecking, {"check_epoch": 2}), + ("no_freeze", None, None, {}), + ("freeze", None, FreezeStrategyChecking, {}), + (("freeze_unfreeze", 2), None, FreezeUnfreezeStrategyChecking, {"check_epoch": 2}), ( ("unfreeze_milestones", ((1, 3), 1)), + None, + UnfreezeMilestonesStrategyChecking, + {"check_epochs": [1, 3], "num_layers": 1}, + ), + ( + "no_freeze", + ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}), + None, + {}, + ), + ( + "freeze", + ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}), + FreezeStrategyChecking, + {}, + ), + ( + ("freeze_unfreeze", 2), + ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}), + FreezeUnfreezeStrategyChecking, + {"check_epoch": 2}, + ), + ( + ("unfreeze_milestones", ((1, 3), 1)), + ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}), UnfreezeMilestonesStrategyChecking, {"check_epochs": [1, 3], "num_layers": 1}, ), ], ) -def test_finetuning(tmpdir, strategy, checker_class, checker_class_data): - task = TestTaskWithFinetuning(loss_fn=F.nll_loss) +def test_finetuning(tmpdir, strategy, lr_scheduler, checker_class, checker_class_data): + task = TestTaskWithFinetuning(loss_fn=F.nll_loss, lr_scheduler=lr_scheduler, optimizer="sgd", learning_rate=0.1) callbacks = [] if checker_class is None else checker_class(dirpath=tmpdir, **checker_class_data) trainer = flash.Trainer(max_epochs=5, limit_train_batches=10, callbacks=callbacks) ds = DummyDataset()