Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Support latest version of BaaL (1.5.2) and add necessary utilities (#…
Browse files Browse the repository at this point in the history
…1315)

Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
krshrimali and ethanwharris authored May 6, 2022
1 parent 07d63e3 commit 88227d8
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 10 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'pre', topic: [ 'core' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras_baal' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video','video_extras' ] }
- { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'tabular' ] }
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed support for all the versions (including the latest and older) of `baal`. ([#1315](https://github.com/PyTorchLightning/lightning-flash/pull/1315))

- Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324))

## [0.7.4] - 2022-04-27
Expand Down
20 changes: 15 additions & 5 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import os
import types
from importlib.util import find_spec
from typing import List, Tuple, Union
from typing import Callable, List, Tuple, Union

import pkg_resources
from pkg_resources import DistributionNotFound

try:
Expand Down Expand Up @@ -48,21 +49,29 @@ def _module_available(module_path: str) -> bool:
return True


def _compare_version(package: str, op, version) -> bool:
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
"""Compare package version with some requirements.
>>> _compare_version("torch", operator.ge, "0.1")
True
>>> _compare_version("does_not_exist", operator.ge, "0.0")
False
"""
try:
pkg = importlib.import_module(package)
except (ModuleNotFoundError, DistributionNotFound, ValueError):
except (ImportError, DistributionNotFound):
return False
try:
pkg_version = Version(pkg.__version__)
if hasattr(pkg, "__version__"):
pkg_version = Version(pkg.__version__)
else:
# try pkg_resources to infer version
pkg_version = Version(pkg_resources.get_distribution(package).version)
except TypeError:
# this is mock by sphinx, so it shall return True to generate all summaries
# this is mocked by Sphinx, so it should return True to generate all summaries
return True
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))


Expand Down Expand Up @@ -128,6 +137,7 @@ class Image:
_PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0")
_ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0")
_TM_GREATER_EQUAL_0_7_0 = _compare_version("torchmetrics", operator.ge, "0.7.0")
_BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2")

_TEXT_AVAILABLE = all(
[
Expand Down
14 changes: 12 additions & 2 deletions flash/image/classification/integrations/baal/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,20 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

import flash
from flash.core.utilities.imports import _BAAL_AVAILABLE
from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2

if _BAAL_AVAILABLE:
from baal.bayesian.dropout import _patch_dropout_layers
# _patch_dropout_layers function was replaced with replace_layers_in_module helper
# function in v1.5.2 (https://github.com/ElementAI/baal/pull/194 for more details)
if _BAAL_GREATER_EQUAL_1_5_2:
from baal.bayesian.common import replace_layers_in_module
from baal.bayesian.consistent_dropout import _consistent_dropout_mapping_fn

def _patch_dropout_layers(module: torch.nn.Module):
return replace_layers_in_module(module, _consistent_dropout_mapping_fn)

else:
from baal.bayesian.consistent_dropout import _patch_dropout_layers


class InferenceMCDropoutTask(flash.Task):
Expand Down
2 changes: 0 additions & 2 deletions requirements/datatype_image_extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ icedata
effdet
albumentations
learn2learn
structlog==21.1.0 # remove when baal resolved its dependency.
baal
fastface
fairscale

Expand Down
2 changes: 2 additions & 0 deletions requirements/datatype_image_extras_baal.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# This is a separate file, as baal integration is affected by vissl installation (conflicts)
baal>=1.3.2
4 changes: 3 additions & 1 deletion tests/image/classification/test_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.utils.data import SequentialSampler

import flash
from flash.core.utilities.imports import _BAAL_AVAILABLE, _IMAGE_AVAILABLE
from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _IMAGE_AVAILABLE
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
from tests.image.classification.test_data import _rand_image
Expand Down Expand Up @@ -110,6 +110,8 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s
assert len(active_learning_dm._dataset) == 20
assert active_learning_loop.progress.total.completed == 3
labelled = active_learning_loop.state_dict()["state_dict"]["datamodule_state_dict"]["labelled"]
if _BAAL_GREATER_EQUAL_1_5_2:
labelled = labelled > 0
assert isinstance(labelled, np.ndarray)

# Check that we iterate over the actual pool and that shuffle is disabled.
Expand Down

0 comments on commit 88227d8

Please sign in to comment.