Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add required dependencies error to available_* methods (#1148)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Feb 4, 2022
1 parent e65020c commit 20b3a7d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
11 changes: 6 additions & 5 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import functools
import inspect
import pickle
import re
from abc import ABCMeta
from copy import deepcopy
from importlib import import_module
Expand Down Expand Up @@ -267,11 +268,11 @@ def __new__(mcs, *args, **kwargs):
result = ABCMeta.__new__(mcs, *args, **kwargs)
if result.required_extras is not None:
result.__init__ = requires(result.required_extras)(result.__init__)
load_from_checkpoint = getattr(result, "load_from_checkpoint", None)
if load_from_checkpoint is not None:
result.load_from_checkpoint = classmethod(
requires(result.required_extras)(result.load_from_checkpoint.__func__)
)

patterns = ["load_from_checkpoint", "available_*"] # must match classmethods only
regex = "(" + ")|(".join(patterns) + ")"
for attribute_name, attribute_value in filter(lambda x: re.match(regex, x[0]), inspect.getmembers(result)):
setattr(result, attribute_name, classmethod(requires(result.required_extras)(attribute_value.__func__)))
return result


Expand Down
7 changes: 7 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from flash.core.utilities.imports import (
_AUDIO_TESTING,
_GRAPH_TESTING,
_IMAGE_AVAILABLE,
_IMAGE_TESTING,
_TABULAR_TESTING,
_TEXT_TESTING,
Expand Down Expand Up @@ -317,6 +318,12 @@ class Foo(ImageClassifier):
assert Foo.available_backbones() == {}


@pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.")
def test_available_backbones_raises():
with pytest.raises(ModuleNotFoundError, match="Required dependencies not available."):
_ = ImageClassifier.available_backbones()


@ClassificationTask.lr_schedulers
def custom_steplr_configuration_return_as_instance(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
Expand Down

0 comments on commit 20b3a7d

Please sign in to comment.