From 88227d882f053aa6e3386cb7bd35349d353f9e82 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Fri, 6 May 2022 17:11:23 +0530 Subject: [PATCH] Support latest version of BaaL (1.5.2) and add necessary utilities (#1315) Co-authored-by: Ethan Harris --- .github/workflows/ci-testing.yml | 1 + CHANGELOG.md | 2 ++ flash/core/utilities/imports.py | 20 ++++++++++++++----- .../integrations/baal/dropout.py | 14 +++++++++++-- requirements/datatype_image_extras.txt | 2 -- requirements/datatype_image_extras_baal.txt | 2 ++ .../classification/test_active_learning.py | 4 +++- 7 files changed, 35 insertions(+), 10 deletions(-) create mode 100644 requirements/datatype_image_extras_baal.txt diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d4880c55c2..a18b81f9f3 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -33,6 +33,7 @@ jobs: - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'pre', topic: [ 'core' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras_baal' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video','video_extras' ] } - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'tabular' ] } diff --git a/CHANGELOG.md b/CHANGELOG.md index 908b964a0c..757e7f313c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed support for all the versions (including the latest and older) of `baal`. ([#1315](https://github.com/PyTorchLightning/lightning-flash/pull/1315)) + - Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324)) ## [0.7.4] - 2022-04-27 diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 201430d3b6..0fdc4b0d75 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -17,8 +17,9 @@ import os import types from importlib.util import find_spec -from typing import List, Tuple, Union +from typing import Callable, List, Tuple, Union +import pkg_resources from pkg_resources import DistributionNotFound try: @@ -48,21 +49,29 @@ def _module_available(module_path: str) -> bool: return True -def _compare_version(package: str, op, version) -> bool: +def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: """Compare package version with some requirements. >>> _compare_version("torch", operator.ge, "0.1") True + >>> _compare_version("does_not_exist", operator.ge, "0.0") + False """ try: pkg = importlib.import_module(package) - except (ModuleNotFoundError, DistributionNotFound, ValueError): + except (ImportError, DistributionNotFound): return False try: - pkg_version = Version(pkg.__version__) + if hasattr(pkg, "__version__"): + pkg_version = Version(pkg.__version__) + else: + # try pkg_resources to infer version + pkg_version = Version(pkg_resources.get_distribution(package).version) except TypeError: - # this is mock by sphinx, so it shall return True to generate all summaries + # this is mocked by Sphinx, so it should return True to generate all summaries return True + if use_base_version: + pkg_version = Version(pkg_version.base_version) return op(pkg_version, Version(version)) @@ -128,6 +137,7 @@ class Image: _PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0") _ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0") _TM_GREATER_EQUAL_0_7_0 = _compare_version("torchmetrics", operator.ge, "0.7.0") + _BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2") _TEXT_AVAILABLE = all( [ diff --git a/flash/image/classification/integrations/baal/dropout.py b/flash/image/classification/integrations/baal/dropout.py index 103c76858d..db694e2642 100644 --- a/flash/image/classification/integrations/baal/dropout.py +++ b/flash/image/classification/integrations/baal/dropout.py @@ -15,10 +15,20 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException import flash -from flash.core.utilities.imports import _BAAL_AVAILABLE +from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2 if _BAAL_AVAILABLE: - from baal.bayesian.dropout import _patch_dropout_layers + # _patch_dropout_layers function was replaced with replace_layers_in_module helper + # function in v1.5.2 (https://github.com/ElementAI/baal/pull/194 for more details) + if _BAAL_GREATER_EQUAL_1_5_2: + from baal.bayesian.common import replace_layers_in_module + from baal.bayesian.consistent_dropout import _consistent_dropout_mapping_fn + + def _patch_dropout_layers(module: torch.nn.Module): + return replace_layers_in_module(module, _consistent_dropout_mapping_fn) + + else: + from baal.bayesian.consistent_dropout import _patch_dropout_layers class InferenceMCDropoutTask(flash.Task): diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 89c0d7b29d..ed1c2f386f 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -7,8 +7,6 @@ icedata effdet albumentations learn2learn -structlog==21.1.0 # remove when baal resolved its dependency. -baal fastface fairscale diff --git a/requirements/datatype_image_extras_baal.txt b/requirements/datatype_image_extras_baal.txt new file mode 100644 index 0000000000..37386a6359 --- /dev/null +++ b/requirements/datatype_image_extras_baal.txt @@ -0,0 +1,2 @@ +# This is a separate file, as baal integration is affected by vissl installation (conflicts) +baal>=1.3.2 diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index 7e622f490f..8965228838 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -22,7 +22,7 @@ from torch.utils.data import SequentialSampler import flash -from flash.core.utilities.imports import _BAAL_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _IMAGE_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop from tests.image.classification.test_data import _rand_image @@ -110,6 +110,8 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s assert len(active_learning_dm._dataset) == 20 assert active_learning_loop.progress.total.completed == 3 labelled = active_learning_loop.state_dict()["state_dict"]["datamodule_state_dict"]["labelled"] + if _BAAL_GREATER_EQUAL_1_5_2: + labelled = labelled > 0 assert isinstance(labelled, np.ndarray) # Check that we iterate over the actual pool and that shuffle is disabled.