From 0a159a83f6105964807b16bab83912db570175f7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 5 Aug 2021 12:24:32 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 docs/source/conf.py | 66 +-
 flash/__about__.py | 2 +-
 flash/__init__.py | 1 +
 flash/__main__.py | 17 +-
 flash/audio/classification/cli.py | 6 +-
 flash/audio/classification/data.py | 5 +-
 flash/audio/classification/transforms.py | 7 +-
 flash/audio/speech_recognition/cli.py | 4 +-
 flash/audio/speech_recognition/data.py | 41 +-
 flash/audio/speech_recognition/model.py | 7 +-
 flash/core/classification.py | 1 -
 flash/core/data/auto_dataset.py | 8 +-
 flash/core/data/batch.py | 48 +-
 flash/core/data/callback.py | 3 +-
 flash/core/data/data_module.py | 48 +-
 flash/core/data/data_pipeline.py | 103 +-
 flash/core/data/data_source.py | 33 +-
 flash/core/data/process.py | 37 +-
 flash/core/data/properties.py | 7 +-
 flash/core/data/transforms.py | 2 +-
 flash/core/data/utils.py | 25 +-
 flash/core/finetuning.py | 12 +-
 flash/core/model.py | 99 +-
 flash/core/registry.py | 12 +-
 flash/core/schedulers.py | 3 +-
 flash/core/serve/_compat/__init__.py | 2 +-
 flash/core/serve/_compat/cached_property.py | 2 +-
 flash/core/serve/component.py | 7 +-
 flash/core/serve/composition.py | 8 +-
 flash/core/serve/core.py | 23 +-
 flash/core/serve/dag/optimization.py | 140 +-
 flash/core/serve/dag/order.py | 6 +-
 flash/core/serve/dag/rewrite.py | 4 +-
 flash/core/serve/dag/task.py | 6 +-
 flash/core/serve/dag/visualize.py | 4 +-
 flash/core/serve/decorators.py | 5 +-
 flash/core/serve/execution.py | 32 +-
 flash/core/serve/flash_components.py | 3 -
 flash/core/serve/interfaces/http.py | 35 +-
 flash/core/serve/interfaces/models.py | 13 +-
 flash/core/serve/server.py | 2 +-
 flash/core/serve/types/label.py | 3 +-
 flash/core/serve/types/table.py | 3 +-
 flash/core/serve/utils.py | 2 +-
 flash/core/trainer.py | 6 +-
 flash/core/utilities/flash_cli.py | 11 +-
 flash/core/utilities/imports.py | 48 +-
 flash/core/utilities/lightning_cli.py | 120 +-
 flash/core/utilities/url_error.py | 4 +-
 flash/graph/classification/cli.py | 6 +-
 flash/graph/classification/data.py | 1 -
 flash/graph/classification/model.py | 2 -
 flash/graph/data.py | 1 -
 flash/image/backbones.py | 2 +-
 .../image/classification/backbones/resnet.py | 113 +-
 .../classification/backbones/torchvision.py | 6 +-
 .../classification/backbones/transformers.py | 8 +-
 flash/image/classification/cli.py | 11 +-
 flash/image/classification/data.py | 36 +-
 flash/image/classification/model.py | 10 +-
 flash/image/classification/transforms.py | 2 +-
 flash/image/data.py | 5 -
 flash/image/detection/cli.py | 6 +-
 flash/image/detection/data.py | 7 +-
 flash/image/detection/model.py | 23 +-
 flash/image/detection/serialization.py | 12 +-
 flash/image/detection/transforms.py | 16 +-
 flash/image/embedding/model.py | 8 +-
 flash/image/segmentation/cli.py | 8 +-
 flash/image/segmentation/data.py | 39 +-
 flash/image/segmentation/heads.py | 13 +-
 flash/image/segmentation/model.py | 9 +-
 flash/image/segmentation/serialization.py | 4 +-
 flash/image/segmentation/transforms.py | 7 +-
 flash/image/style_transfer/cli.py | 6 +-
 flash/image/style_transfer/data.py | 9 +-
 flash/image/style_transfer/model.py | 4 +-
 flash/pointcloud/detection/cli.py | 4 +-
 flash/pointcloud/detection/data.py | 6 +-
 flash/pointcloud/detection/datasets.py | 2 +-
 flash/pointcloud/detection/model.py | 13 +-
 flash/pointcloud/detection/open3d_ml/app.py | 8 +-
 .../detection/open3d_ml/backbones.py | 11 +-
 .../detection/open3d_ml/data_sources.py | 27 +-
 flash/pointcloud/segmentation/cli.py | 6 +-
 flash/pointcloud/segmentation/data.py | 7 +-
 flash/pointcloud/segmentation/datasets.py | 6 +-
 flash/pointcloud/segmentation/model.py | 10 +-
 .../pointcloud/segmentation/open3d_ml/app.py | 20 +-
 .../segmentation/open3d_ml/backbones.py | 16 +-
 .../open3d_ml/sequences_dataset.py | 29 +-
 flash/setup_tools.py | 24 +-
 flash/tabular/classification/cli.py | 2 +-
 flash/tabular/classification/model.py | 6 +-
 flash/tabular/data.py | 41 +-
 flash/template/classification/backbones.py | 32 +-
 flash/template/classification/model.py | 2 +-
 flash/text/classification/cli.py | 4 +-
 flash/text/classification/data.py | 40 +-
 flash/text/seq2seq/core/data.py | 48 +-
 flash/text/seq2seq/core/metrics.py | 11 +-
 flash/text/seq2seq/core/model.py | 8 +-
 flash/text/seq2seq/core/utils.py | 3 +-
 flash/text/seq2seq/question_answering/data.py | 3 +-
 .../text/seq2seq/question_answering/model.py | 4 +-
 flash/text/seq2seq/summarization/cli.py | 4 +-
 flash/text/seq2seq/summarization/data.py | 3 +-
 flash/text/seq2seq/summarization/model.py | 4 +-
 flash/text/seq2seq/translation/cli.py | 4 +-
 flash/text/seq2seq/translation/data.py | 3 +-
 flash/video/classification/cli.py | 4 +-
 flash/video/classification/data.py | 64 +-
 flash/video/classification/model.py | 8 +-
 flash_examples/audio_classification.py | 12 +-
 flash_examples/custom_task.py | 19 +-
 flash_examples/image_classification.py | 12 +-
 .../image_classification_multi_label.py | 12 +-
 flash_examples/object_detection.py | 12 +-
 flash_examples/pointcloud_detection.py | 10 +-
 flash_examples/pointcloud_segmentation.py | 12 +-
 flash_examples/semantic_segmentation.py | 14 +-
 .../boston_prediction/inference_server.py | 1 -
 .../serve/generic/detection/inference.py | 6 +-
 .../inference_server.py | 2 +-
 flash_examples/speech_recognition.py | 2 +-
 flash_examples/style_transfer.py | 12 +-
 flash_examples/template.py | 12 +-
 flash_examples/text_classification.py | 12 +-
 .../text_classification_multi_label.py | 12 +-
 flash_examples/translation.py | 12 +-
 .../visualizations/pointcloud_segmentation.py | 12 +-
 setup.py | 8 +-
 tests/__init__.py | 2 +-
 tests/audio/classification/test_data.py | 46 +-
 tests/audio/speech_recognition/test_data.py | 6 +-
 .../test_data_model_integration.py | 6 +-
 tests/audio/speech_recognition/test_model.py | 5 +-
 tests/conftest.py | 2 +-
 tests/core/data/test_auto_dataset.py | 1 -
 tests/core/data/test_base_viz.py | 10 +-
 tests/core/data/test_batch.py | 20 +-
 tests/core/data/test_callback.py | 3 +-
 tests/core/data/test_callbacks.py | 5 +-
 tests/core/data/test_data_pipeline.py | 62 +-
 tests/core/data/test_data_source.py | 2 +-
 tests/core/data/test_process.py | 26 +-
 tests/core/data/test_sampler.py | 10 +-
 tests/core/data/test_serialization.py | 8 +-
 tests/core/data/test_splits.py | 1 -
 tests/core/data/test_transforms.py | 102 +-
 tests/core/serve/models.py | 18 +-
 .../serve/test_compat/test_cached_property.py | 3 -
 tests/core/serve/test_components.py | 36 +-
 tests/core/serve/test_composition.py | 26 +-
 .../core/serve/test_dag/test_optimization.py | 1158 +++++++++--------
 tests/core/serve/test_dag/test_order.py | 112 +-
 tests/core/serve/test_dag/test_rewrite.py | 37 +-
 tests/core/serve/test_dag/test_task.py | 5 +-
 tests/core/serve/test_dag/test_utils.py | 5 +-
 tests/core/serve/test_gridbase_validations.py | 17 +-
 tests/core/serve/test_integration.py | 130 +-
 tests/core/serve/test_types/test_bbox.py | 20 +-
 tests/core/serve/test_types/test_repeated.py | 13 +-
 tests/core/serve/test_types/test_table.py | 9 +-
 tests/core/test_classification.py | 22 +-
 tests/core/test_data.py | 3 +-
 tests/core/test_finetuning.py | 7 +-
 tests/core/test_model.py | 60 +-
 tests/core/test_registry.py | 6 +-
 tests/core/test_trainer.py | 10 +-
 tests/core/test_utils.py | 3 +-
 tests/core/utilities/test_lightning_cli.py | 378 +++---
 tests/examples/test_integrations.py | 7 +-
 tests/examples/test_scripts.py | 34 +-
 tests/examples/utils.py | 6 +-
 tests/graph/classification/test_data.py | 6 +-
 tests/graph/classification/test_model.py | 8 +-
 tests/helpers/boring_model.py | 5 +-
 tests/image/classification/test_data.py | 78 +-
 tests/image/classification/test_model.py | 8 +-
 tests/image/detection/test_data.py | 105 +-
 .../detection/test_data_model_integration.py | 8 +-
 tests/image/detection/test_model.py | 7 +-
 tests/image/detection/test_serialization.py | 6 +-
 tests/image/embedding/test_model.py | 2 +-
 tests/image/segmentation/test_backbones.py | 11 +-
 tests/image/segmentation/test_data.py | 12 +-
 tests/image/segmentation/test_heads.py | 15 +-
 tests/image/segmentation/test_model.py | 4 +-
 .../image/segmentation/test_serialization.py | 5 +-
 tests/image/test_backbones.py | 51 +-
 tests/pointcloud/detection/test_data.py | 5 +-
 tests/pointcloud/detection/test_model.py | 2 +-
 tests/pointcloud/segmentation/test_data.py | 5 +-
 tests/pointcloud/segmentation/test_model.py | 2 +-
 tests/tabular/classification/test_data.py | 12 +-
 tests/tabular/classification/test_model.py | 7 +-
 tests/template/classification/test_data.py | 14 +-
 tests/template/classification/test_model.py | 4 +-
 tests/text/classification/test_data.py | 6 +-
 tests/text/classification/test_model.py | 12 +-
 tests/text/seq2seq/core/test_data.py | 21 +-
 tests/text/seq2seq/core/test_metrics.py | 4 +-
 .../seq2seq/question_answering/test_model.py | 5 +-
 .../text/seq2seq/summarization/test_model.py | 5 +-
 tests/text/seq2seq/translation/test_data.py | 2 +-
 tests/text/seq2seq/translation/test_model.py | 5 +-
 tests/video/classification/test_model.py | 106 +-
 208 files changed, 2458 insertions(+), 2673 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index d15cb85fd35..de58e174e6e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -17,7 +17,7 @@ import pt_lightning_sphinx_theme
 _PATH_HERE = os.path.abspath(os.path.dirname(__file__))
-_PATH_ROOT = os.path.join(_PATH_HERE, '..', '..')
+_PATH_ROOT = os.path.join(_PATH_HERE, "..", "..")
 sys.path.insert(0, os.path.abspath(_PATH_ROOT))
 try:
@@ -33,9 +33,9 @@ def _load_py_module(fname, pkg="flash"):
 about = _load_py_module("__about__.py")
-SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True))
+SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))
-html_favicon = '_static/images/icon.svg'
+html_favicon = "_static/images/icon.svg"
 # -- Project information -----------------------------------------------------
@@ -49,22 +49,22 @@ def _load_py_module(fname, pkg="flash"):
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", # 'sphinx.ext.coverage', - 'sphinx.ext.viewcode', - 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon', - 'sphinx.ext.imgmath', - 'recommonmark', + "sphinx.ext.viewcode", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.imgmath", + "recommonmark", # 'sphinx.ext.autosectionlabel', # 'nbsphinx', # it seems some sphinx issue - 'sphinx_autodoc_typehints', - 'sphinx_copybutton', - 'sphinx_paramlinks', - 'sphinx_togglebutton', + "sphinx_autodoc_typehints", + "sphinx_copybutton", + "sphinx_paramlinks", + "sphinx_togglebutton", ] # autodoc: Default to members and undoc-members @@ -114,8 +114,8 @@ def _load_py_module(fname, pkg="flash"): # documentation. html_theme_options = { - 'pytorch_project': 'https://pytorchlightning.ai', - 'canonical_url': about.__docs_url__, + "pytorch_project": "https://pytorchlightning.ai", + "canonical_url": about.__docs_url__, "collapse_navigation": False, "display_version": True, "logo_only": False, @@ -132,20 +132,20 @@ def _load_py_module(fname, pkg="flash"): def setup(app): # this is for hiding doctest decoration, # see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/ - app.add_js_file('copybutton.js') - app.add_css_file('main.css') + app.add_js_file("copybutton.js") + app.add_css_file("main.css") # Ignoring Third-party packages # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule def _package_list_from_file(pfile): assert os.path.isfile(pfile) - with open(pfile, 'r') as fp: + with open(pfile, "r") as fp: lines = fp.readlines() list_pkgs = [] for ln in lines: - found = [ln.index(ch) for ch in list(',=<>#@') if ch in ln] - pkg = ln[:min(found)] if found else ln + found = [ln.index(ch) for ch in list(",=<>#@") if ch in ln] + pkg = ln[: min(found)] if found else ln if pkg.strip(): list_pkgs.append(pkg.strip()) return list_pkgs @@ -153,26 +153,26 @@ def _package_list_from_file(pfile): # define mapping from PyPI names to python imports PACKAGE_MAPPING = { - 'pytorch-lightning': 'pytorch_lightning', - 'scikit-learn': 'sklearn', - 'Pillow': 'PIL', - 'PyYAML': 'yaml', - 'rouge-score': 'rouge_score', - 'lightning-bolts': 'pl_bolts', - 'pytorch-tabnet': 'pytorch_tabnet', - 'pyDeprecate': 'deprecate', + "pytorch-lightning": "pytorch_lightning", + "scikit-learn": "sklearn", + "Pillow": "PIL", + "PyYAML": "yaml", + "rouge-score": "rouge_score", + "lightning-bolts": "pl_bolts", + "pytorch-tabnet": "pytorch_tabnet", + "pyDeprecate": "deprecate", } MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: # mock also base packages when we are on RTD since we don't install them there - MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, 'requirements.txt')) + MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt")) # replace PyPI packages by importing ones MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES] autodoc_mock_imports = MOCK_PACKAGES # only run doctests marked with a ".. 
doctest::" directive -doctest_test_doctest_blocks = '' +doctest_test_doctest_blocks = "" doctest_global_setup = """ import torch import pytorch_lightning as pl diff --git a/flash/__about__.py b/flash/__about__.py index d66522a6694..e57715c058a 100644 --- a/flash/__about__.py +++ b/flash/__about__.py @@ -1,7 +1,7 @@ __version__ = "0.4.1dev" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" -__license__ = 'Apache-2.0' +__license__ = "Apache-2.0" __copyright__ = f"Copyright (c) 2020-2021, f{__author__}." __homepage__ = "https://github.com/PyTorchLightning/lightning-flash" __docs_url__ = "https://lightning-flash.readthedocs.io/en/stable/" diff --git a/flash/__init__.py b/flash/__init__.py index 7a13f9d20ba..e8321350c99 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -33,6 +33,7 @@ if _IS_TESTING: from pytorch_lightning import seed_everything + seed_everything(42) __all__ = [ diff --git a/flash/__main__.py b/flash/__main__.py index b93d9428d18..f4eb704a769 100644 --- a/flash/__main__.py +++ b/flash/__main__.py @@ -24,15 +24,16 @@ def main(): def register_command(command): - - @main.command(context_settings=dict( - help_option_names=[], - ignore_unknown_options=True, - )) - @click.argument('cli_args', nargs=-1, type=click.UNPROCESSED) + @main.command( + context_settings=dict( + help_option_names=[], + ignore_unknown_options=True, + ) + ) + @click.argument("cli_args", nargs=-1, type=click.UNPROCESSED) @functools.wraps(command) def wrapper(cli_args): - with patch('sys.argv', [command.__name__] + list(cli_args)): + with patch("sys.argv", [command.__name__] + list(cli_args)): command() @@ -63,5 +64,5 @@ def wrapper(cli_args): except ImportError: pass -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/flash/audio/classification/cli.py b/flash/audio/classification/cli.py index 38d24414009..c198a992395 100644 --- a/flash/audio/classification/cli.py +++ b/flash/audio/classification/cli.py @@ -44,12 +44,12 @@ def audio_classification(): AudioClassificationData, default_datamodule_builder=from_urban8k, default_arguments={ - 'trainer.max_epochs': 3, - } + "trainer.max_epochs": 3, + }, ) cli.trainer.save_checkpoint("audio_classification_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": audio_classification() diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index c458b279cb9..bcc421198ce 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -24,7 +24,6 @@ class AudioClassificationPreprocess(Preprocess): - @requires_extras(["audio", "image"]) def __init__( self, @@ -35,7 +34,7 @@ def __init__( spectrogram_size: Tuple[int, int] = (196, 196), time_mask_param: int = 80, freq_mask_param: int = 80, - deserializer: Optional['Deserializer'] = None, + deserializer: Optional["Deserializer"] = None, ): self.spectrogram_size = spectrogram_size self.time_mask_param = time_mask_param @@ -48,7 +47,7 @@ def __init__( predict_transform=predict_transform, data_sources={ DefaultDataSources.FILES: ImagePathsDataSource(), - DefaultDataSources.FOLDERS: ImagePathsDataSource() + DefaultDataSources.FOLDERS: ImagePathsDataSource(), }, deserializer=deserializer or ImageDeserializer(), default_data_source=DefaultDataSources.FILES, diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index e1850eb06b9..4fe89d38273 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -41,13 +41,14 @@ 
def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable] } -def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int, - freq_mask_param: int) -> Dict[str, Callable]: +def train_default_transforms( + spectrogram_size: Tuple[int, int], time_mask_param: int, freq_mask_param: int +) -> Dict[str, Callable]: """During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``""" transforms = { "post_tensor_transform": nn.Sequential( ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)), - ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)) + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)), ) } diff --git a/flash/audio/speech_recognition/cli.py b/flash/audio/speech_recognition/cli.py index e3b49929d1e..9bbdb48df8b 100644 --- a/flash/audio/speech_recognition/cli.py +++ b/flash/audio/speech_recognition/cli.py @@ -47,7 +47,7 @@ def speech_recognition(): SpeechRecognitionData, default_datamodule_builder=from_timit, default_arguments={ - 'trainer.max_epochs': 3, + "trainer.max_epochs": 3, }, finetune=False, ) @@ -55,5 +55,5 @@ def speech_recognition(): cli.trainer.save_checkpoint("speech_recognition_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": speech_recognition() diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index dd7f5d187fb..029419b50b9 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -44,7 +44,6 @@ class SpeechRecognitionDeserializer(Deserializer): - def deserialize(self, sample: Any) -> Dict: encoded_with_padding = (sample + "===").encode("ascii") audio = base64.b64decode(encoded_with_padding) @@ -52,9 +51,7 @@ def deserialize(self, sample: Any) -> Dict: data, sampling_rate = sf.read(buffer) return { DefaultDataKeys.INPUT: data, - DefaultDataKeys.METADATA: { - "sampling_rate": sampling_rate - }, + DefaultDataKeys.METADATA: {"sampling_rate": sampling_rate}, } @property @@ -64,11 +61,13 @@ def example_input(self) -> str: class BaseSpeechRecognition: - def _load_sample(self, sample: Dict[str, Any]) -> Any: path = sample[DefaultDataKeys.INPUT] - if not os.path.isabs(path) and DefaultDataKeys.METADATA in sample and "root" in sample[DefaultDataKeys.METADATA - ]: + if ( + not os.path.isabs(path) + and DefaultDataKeys.METADATA in sample + and "root" in sample[DefaultDataKeys.METADATA] + ): path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path) speech_array, sampling_rate = sf.read(path) sample[DefaultDataKeys.INPUT] = speech_array @@ -77,7 +76,6 @@ def _load_sample(self, sample: Dict[str, Any]) -> Any: class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition): - def __init__(self, filetype: Optional[str] = None): super().__init__() self.filetype = filetype @@ -87,42 +85,42 @@ def load_data( data: Tuple[str, Union[str, List[str]], Union[str, List[str]]], dataset: Optional[Any] = None, ) -> Union[Sequence[Mapping[str, Any]]]: - if self.filetype == 'json': + if self.filetype == "json": file, input_key, target_key, field = data else: file, input_key, target_key = data stage = self.running_stage.value - if self.filetype == 'json' and field is not None: + if self.filetype == "json" and field is not None: dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)}, field=field) else: dataset_dict = load_dataset(self.filetype, data_files={stage: 
str(file)}) dataset = dataset_dict[stage] meta = {"root": os.path.dirname(file)} - return [{ - DefaultDataKeys.INPUT: input_file, - DefaultDataKeys.TARGET: target, - DefaultDataKeys.METADATA: meta, - } for input_file, target in zip(dataset[input_key], dataset[target_key])] + return [ + { + DefaultDataKeys.INPUT: input_file, + DefaultDataKeys.TARGET: target, + DefaultDataKeys.METADATA: meta, + } + for input_file, target in zip(dataset[input_key], dataset[target_key]) + ] def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: return self._load_sample(sample) class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource): - def __init__(self): - super().__init__(filetype='csv') + super().__init__(filetype="csv") class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource): - def __init__(self): - super().__init__(filetype='json') + super().__init__(filetype="json") class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition): - def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]: if isinstance(data, HFDataset): data = list(zip(data["file"], data["text"])) @@ -130,7 +128,6 @@ def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Seque class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition): - def __init__(self): super().__init__(("wav", "ogg", "flac", "mat")) @@ -139,7 +136,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: class SpeechRecognitionPreprocess(Preprocess): - @requires_extras("audio") def __init__( self, @@ -181,7 +177,6 @@ class SpeechRecognitionBackboneState(ProcessState): class SpeechRecognitionPostprocess(Postprocess): - @requires_extras("audio") def __init__(self): super().__init__() diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index d62767a8d8b..15cdcef4f9f 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -51,8 +51,9 @@ def __init__( # set os environ variable for multiprocesses os.environ["PYTHONWARNINGS"] = "ignore" - model = self.backbones.get(backbone - )() if backbone in self.backbones else Wav2Vec2ForCTC.from_pretrained(backbone) + model = ( + self.backbones.get(backbone)() if backbone in self.backbones else Wav2Vec2ForCTC.from_pretrained(backbone) + ) super().__init__( model=model, loss_fn=loss_fn, @@ -74,5 +75,5 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: out = self.model(batch["input_values"], labels=batch["labels"]) - out["logs"] = {'loss': out.loss} + out["logs"] = {"loss": out.loss} return out diff --git a/flash/core/classification.py b/flash/core/classification.py index d1775cb37c2..ba10162abcc 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -38,7 +38,6 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. 
class ClassificationTask(Task): - def __init__( self, *args, diff --git a/flash/core/data/auto_dataset.py b/flash/core/data/auto_dataset.py index 9a1251d4485..fcd03fb18c9 100644 --- a/flash/core/data/auto_dataset.py +++ b/flash/core/data/auto_dataset.py @@ -20,7 +20,7 @@ import flash from flash.core.data.utils import CurrentRunningStageFuncContext -DATA_TYPE = TypeVar('DATA_TYPE') +DATA_TYPE = TypeVar("DATA_TYPE") class BaseAutoDataset(Generic[DATA_TYPE]): @@ -41,7 +41,7 @@ class BaseAutoDataset(Generic[DATA_TYPE]): def __init__( self, data: DATA_TYPE, - data_source: 'flash.core.data.data_source.DataSource', + data_source: "flash.core.data.data_source.DataSource", running_stage: RunningStage, ) -> None: super().__init__() @@ -68,11 +68,11 @@ def running_stage(self, running_stage: RunningStage) -> None: self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( self.data_source, DataPipeline._resolve_function_hierarchy( - 'load_sample', + "load_sample", self.data_source, self.running_stage, DataSource, - ) + ), ) def _call_load_sample(self, sample: Any) -> Any: diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 80094cc59ac..dd0ed1e9ddc 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -41,7 +41,7 @@ class _Sequential(torch.nn.Module): def __init__( self, - preprocess: 'Preprocess', + preprocess: "Preprocess", pre_tensor_transform: Optional[Callable], to_tensor_transform: Optional[Callable], post_tensor_transform: Callable, @@ -101,11 +101,10 @@ def __str__(self) -> str: class _DeserializeProcessor(torch.nn.Module): - def __init__( self, - deserializer: 'Deserializer', - preprocess: 'Preprocess', + deserializer: "Deserializer", + preprocess: "Preprocess", pre_tensor_transform: Callable, to_tensor_transform: Callable, ): @@ -137,10 +136,9 @@ def forward(self, sample: str): class _SerializeProcessor(torch.nn.Module): - def __init__( self, - serializer: 'Serializer', + serializer: "Serializer", ): super().__init__() self.serializer = convert_to_modules(serializer) @@ -151,28 +149,28 @@ def forward(self, sample): class _Preprocessor(torch.nn.Module): """ - This class is used to encapsultate the following functions of a Preprocess Object: - Inside a worker: - per_sample_transform: Function to transform an individual sample - Inside a worker, it is actually make of 3 functions: - * pre_tensor_transform - * to_tensor_transform - * post_tensor_transform - collate: Function to merge sample into a batch - per_batch_transform: Function to transform an individual batch - * per_batch_transform - - Inside main process: - per_sample_transform: Function to transform an individual sample - * per_sample_transform_on_device - collate: Function to merge sample into a batch - per_batch_transform: Function to transform an individual batch - * per_batch_transform_on_device + This class is used to encapsultate the following functions of a Preprocess Object: + Inside a worker: + per_sample_transform: Function to transform an individual sample + Inside a worker, it is actually make of 3 functions: + * pre_tensor_transform + * to_tensor_transform + * post_tensor_transform + collate: Function to merge sample into a batch + per_batch_transform: Function to transform an individual batch + * per_batch_transform + + Inside main process: + per_sample_transform: Function to transform an individual sample + * per_sample_transform_on_device + collate: Function to merge sample into a batch + per_batch_transform: Function to transform an individual batch + * 
per_batch_transform_on_device """ def __init__( self, - preprocess: 'Preprocess', + preprocess: "Preprocess", collate_fn: Callable, per_sample_transform: Union[Callable, _Sequential], per_batch_transform: Callable, @@ -349,7 +347,7 @@ def default_uncollate(batch: Any): if isinstance(batch, Mapping): return [batch_type(dict(zip(batch, default_uncollate(t)))) for t in zip(*batch.values())] - if isinstance(batch, tuple) and hasattr(batch, '_fields'): # namedtuple + if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple return [batch_type(*default_uncollate(sample)) for sample in zip(*batch)] if isinstance(batch, Sequence) and not isinstance(batch, str): diff --git a/flash/core/data/callback.py b/flash/core/data/callback.py index 96ef4edb1ba..b4c2aa93eed 100644 --- a/flash/core/data/callback.py +++ b/flash/core/data/callback.py @@ -47,7 +47,6 @@ def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningSta class ControlFlow(FlashCallback): - def __init__(self, callbacks: List[FlashCallback]): self._callbacks = callbacks @@ -208,7 +207,7 @@ def enable(self): yield self.enabled = False - def attach_to_preprocess(self, preprocess: 'flash.core.data.process.Preprocess') -> None: + def attach_to_preprocess(self, preprocess: "flash.core.data.process.Preprocess") -> None: preprocess.add_callbacks([self]) self._preprocess = preprocess diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 7273ac73fc5..f725069e16f 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -219,22 +219,22 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool if reset: self.data_fetcher.batches[stage] = {} - def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_train_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the train dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING] self._show_batch(stage_name, hooks_names, reset=reset) - def show_val_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_val_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the validation dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING] self._show_batch(stage_name, hooks_names, reset=reset) - def show_test_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_test_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the test dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.TESTING] self._show_batch(stage_name, hooks_names, reset=reset) - def show_predict_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the predict dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING] self._show_batch(stage_name, hooks_names, reset=reset) @@ -255,16 +255,16 @@ def set_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, val def set_running_stages(self): if self._train_ds: - self.set_dataset_attribute(self._train_ds, 
'running_stage', RunningStage.TRAINING) + self.set_dataset_attribute(self._train_ds, "running_stage", RunningStage.TRAINING) if self._val_ds: - self.set_dataset_attribute(self._val_ds, 'running_stage', RunningStage.VALIDATING) + self.set_dataset_attribute(self._val_ds, "running_stage", RunningStage.VALIDATING) if self._test_ds: - self.set_dataset_attribute(self._test_ds, 'running_stage', RunningStage.TESTING) + self.set_dataset_attribute(self._test_ds, "running_stage", RunningStage.TESTING) if self._predict_ds: - self.set_dataset_attribute(self._predict_ds, 'running_stage', RunningStage.PREDICTING) + self.set_dataset_attribute(self._predict_ds, "running_stage", RunningStage.PREDICTING) def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> Optional[Callable]: if isinstance(dataset, (BaseAutoDataset, SplitDataset)): @@ -292,7 +292,7 @@ def _train_dataloader(self) -> DataLoader: shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, - sampler=self.sampler + sampler=self.sampler, ) return DataLoader( @@ -303,7 +303,7 @@ def _train_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=pin_memory, drop_last=drop_last, - collate_fn=collate_fn + collate_fn=collate_fn, ) def _val_dataloader(self) -> DataLoader: @@ -317,7 +317,7 @@ def _val_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=pin_memory, - collate_fn=collate_fn + collate_fn=collate_fn, ) return DataLoader( @@ -325,7 +325,7 @@ def _val_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=pin_memory, - collate_fn=collate_fn + collate_fn=collate_fn, ) def _test_dataloader(self) -> DataLoader: @@ -339,7 +339,7 @@ def _test_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=pin_memory, - collate_fn=collate_fn + collate_fn=collate_fn, ) return DataLoader( @@ -347,7 +347,7 @@ def _test_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=pin_memory, - collate_fn=collate_fn + collate_fn=collate_fn, ) def _predict_dataloader(self) -> DataLoader: @@ -366,7 +366,7 @@ def _predict_dataloader(self) -> DataLoader: batch_size=batch_size, num_workers=self.num_workers, pin_memory=pin_memory, - collate_fn=collate_fn + collate_fn=collate_fn, ) return DataLoader( @@ -455,7 +455,7 @@ def from_data_source( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to :meth:`~flash.core.data.data_source.DataSource.load_data` (``train_data``, ``val_data``, ``test_data``, ``predict_data``). 
The data source will be resolved from the instantiated @@ -555,7 +555,7 @@ def from_folders( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` @@ -638,7 +638,7 @@ def from_files( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FILES` from the passed or constructed @@ -725,7 +725,7 @@ def from_tensors( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.TENSOR` @@ -812,7 +812,7 @@ def from_numpy( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` @@ -899,7 +899,7 @@ def from_json( sampler: Optional[Sampler] = None, field: Optional[str] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.JSON` @@ -1008,7 +1008,7 @@ def from_csv( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` @@ -1092,7 +1092,7 @@ def from_datasets( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.DATASETS` @@ -1172,7 +1172,7 @@ def from_fiftyone( batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given FiftyOne Datasets using the :class:`~flash.core.data.data_source.DataSource` of name diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index a377e736053..4c707ef8c21 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -49,7 +49,8 @@ def set_state(self, state: ProcessState): else: rank_zero_warn( f"Attempted to add a state ({state}) after the data pipeline has 
already been initialized. This will" - " only have an effect when a new data pipeline is created.", UserWarning + " only have an effect when a new data pipeline is created.", + UserWarning, ) def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]: @@ -127,7 +128,7 @@ def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optiona """Cropped Version of https://github.com/PyTorchLightning/pytorch- lightning/blob/master/pytorch_lightning/utilities/model_helpers.py.""" - current_method_name = method_name if prefix is None else f'{prefix}_{method_name}' + current_method_name = method_name if prefix is None else f"{prefix}_{method_name}" if not hasattr(process_obj, current_method_name): return False @@ -144,7 +145,7 @@ def _is_overriden_recursive( if prefix is None and not hasattr(super_obj, method_name): raise MisconfigurationException(f"This function doesn't belong to the parent class {super_obj}") - current_method_name = method_name if prefix is None else f'{prefix}_{method_name}' + current_method_name = method_name if prefix is None else f"{prefix}_{method_name}" if not hasattr(process_obj, current_method_name): return DataPipeline._is_overriden_recursive(method_name, process_obj, super_obj) @@ -185,19 +186,19 @@ def _resolve_function_hierarchy( prefixes = [] if stage in (RunningStage.TRAINING, RunningStage.TUNING): - prefixes += ['train', 'fit'] + prefixes += ["train", "fit"] elif stage == RunningStage.VALIDATING: - prefixes += ['val', 'fit'] + prefixes += ["val", "fit"] elif stage == RunningStage.TESTING: - prefixes += ['test'] + prefixes += ["test"] elif stage == RunningStage.PREDICTING: - prefixes += ['predict'] + prefixes += ["predict"] prefixes += [None] for prefix in prefixes: if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix): - return function_name if prefix is None else f'{prefix}_{function_name}' + return function_name if prefix is None else f"{prefix}_{function_name}" return function_name @@ -222,8 +223,7 @@ def _create_collate_preprocessors( preprocess._default_collate = collate_fn func_names: Dict[str, str] = { - k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) - for k in self.PREPROCESS_FUNCS + k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) for k in self.PREPROCESS_FUNCS } collate_fn: Callable = getattr(preprocess, func_names["collate"]) @@ -243,8 +243,8 @@ def _create_collate_preprocessors( is_per_overriden = per_batch_transform_overriden and per_sample_transform_on_device_overriden if collate_in_worker_from_transform is None and is_per_overriden: raise MisconfigurationException( - f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` ' - f'are mutually exclusive for stage {stage}' + f"{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` " + f"are mutually exclusive for stage {stage}" ) if isinstance(collate_in_worker_from_transform, bool): @@ -254,9 +254,9 @@ def _create_collate_preprocessors( per_sample_transform_on_device_overriden, collate_fn ) - worker_collate_fn = worker_collate_fn.collate_fn if isinstance( - worker_collate_fn, _Preprocessor - ) else worker_collate_fn + worker_collate_fn = ( + worker_collate_fn.collate_fn if isinstance(worker_collate_fn, _Preprocessor) else worker_collate_fn + ) assert_contains_tensor = self._is_overriden_recursive( "to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] @@ -265,26 +265,29 @@ def _create_collate_preprocessors( 
deserialize_processor = _DeserializeProcessor( self._deserializer, preprocess, - getattr(preprocess, func_names['pre_tensor_transform']), - getattr(preprocess, func_names['to_tensor_transform']), + getattr(preprocess, func_names["pre_tensor_transform"]), + getattr(preprocess, func_names["to_tensor_transform"]), ) worker_preprocessor = _Preprocessor( - preprocess, worker_collate_fn, + preprocess, + worker_collate_fn, _Sequential( preprocess, - None if is_serving else getattr(preprocess, func_names['pre_tensor_transform']), - None if is_serving else getattr(preprocess, func_names['to_tensor_transform']), - getattr(preprocess, func_names['post_tensor_transform']), + None if is_serving else getattr(preprocess, func_names["pre_tensor_transform"]), + None if is_serving else getattr(preprocess, func_names["to_tensor_transform"]), + getattr(preprocess, func_names["post_tensor_transform"]), stage, assert_contains_tensor=assert_contains_tensor, - ), getattr(preprocess, func_names['per_batch_transform']), stage + ), + getattr(preprocess, func_names["per_batch_transform"]), + stage, ) worker_preprocessor._original_collate_fn = original_collate_fn device_preprocessor = _Preprocessor( preprocess, device_collate_fn, - getattr(preprocess, func_names['per_sample_transform_on_device']), - getattr(preprocess, func_names['per_batch_transform_on_device']), + getattr(preprocess, func_names["per_sample_transform_on_device"]), + getattr(preprocess, func_names["per_batch_transform_on_device"]), stage, apply_per_sample_transform=device_collate_fn != self._identity, on_device=True, @@ -293,7 +296,7 @@ def _create_collate_preprocessors( @staticmethod def _model_transfer_to_device_wrapper( - func: Callable, preprocessor: _Preprocessor, model: 'Task', stage: RunningStage + func: Callable, preprocessor: _Preprocessor, model: "Task", stage: RunningStage ) -> Callable: if not isinstance(func, _StageOrchestrator): @@ -303,7 +306,7 @@ def _model_transfer_to_device_wrapper( return func @staticmethod - def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, model: 'Task') -> Callable: + def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, model: "Task") -> Callable: if not isinstance(func, _StageOrchestrator): _original = func @@ -314,22 +317,22 @@ def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, m return func @staticmethod - def _get_dataloader(model: 'Task', loader_name: str) -> Tuple[DataLoader, str]: + def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]: dataloader, attr_name = None, None if hasattr(model, loader_name): dataloader = getattr(model, loader_name) attr_name = loader_name - elif model.trainer and hasattr(model.trainer, 'datamodule') and model.trainer.datamodule: - dataloader = getattr(model, f'trainer.datamodule.{loader_name}', None) - attr_name = f'trainer.datamodule.{loader_name}' + elif model.trainer and hasattr(model.trainer, "datamodule") and model.trainer.datamodule: + dataloader = getattr(model, f"trainer.datamodule.{loader_name}", None) + attr_name = f"trainer.datamodule.{loader_name}" return dataloader, attr_name @staticmethod - def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None: + def _set_loader(model: "Task", loader_name: str, new_loader: DataLoader) -> None: """This function is used to set the loader to model and/or datamodule.""" - *intermediates, final_name = loader_name.split('.') + *intermediates, final_name = loader_name.split(".") curr_attr = model # This 
relies on python calling all non-integral types by reference. @@ -342,7 +345,7 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None def _attach_preprocess_to_model( self, - model: 'Task', + model: "Task", stage: Optional[RunningStage] = None, device_transform_only: bool = False, is_serving: bool = False, @@ -357,7 +360,7 @@ def _attach_preprocess_to_model( for stage in stages: - loader_name = f'{_STAGES_PREFIX[stage]}_dataloader' + loader_name = f"{_STAGES_PREFIX[stage]}_dataloader" dataloader, whole_attr_name = self._get_dataloader(model, loader_name) @@ -381,8 +384,8 @@ def _attach_preprocess_to_model( if isinstance(loader, DataLoader): dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} - _, dl_args['collate_fn'], device_collate_fn = self._create_collate_preprocessors( - stage=stage, collate_fn=dl_args['collate_fn'], is_serving=is_serving + _, dl_args["collate_fn"], device_collate_fn = self._create_collate_preprocessors( + stage=stage, collate_fn=dl_args["collate_fn"], is_serving=is_serving ) if isinstance(dl_args["dataset"], IterableDataset): @@ -405,8 +408,8 @@ def _attach_preprocess_to_model( self._set_loader(model, whole_attr_name, dataloader) - model.transfer_batch_to_device = ( - self._model_transfer_to_device_wrapper(model.transfer_batch_to_device, device_collate_fn, model, stage) + model.transfer_batch_to_device = self._model_transfer_to_device_wrapper( + model.transfer_batch_to_device, device_collate_fn, model, stage ) def _create_uncollate_postprocessors( @@ -447,10 +450,10 @@ def _create_uncollate_postprocessors( def _attach_postprocess_to_model( self, - model: 'Task', + model: "Task", stage: RunningStage, is_serving: bool = False, - ) -> 'Task': + ) -> "Task": model.predict_step = self._model_predict_step_wrapper( model.predict_step, self._create_uncollate_postprocessors(stage, is_serving=is_serving), model ) @@ -458,7 +461,7 @@ def _attach_postprocess_to_model( def _attach_to_model( self, - model: 'Task', + model: "Task", stage: RunningStage = None, is_serving: bool = False, ): @@ -468,13 +471,13 @@ def _attach_to_model( if not stage or stage == RunningStage.PREDICTING: self._attach_postprocess_to_model(model, RunningStage.PREDICTING, is_serving=is_serving) - def _detach_from_model(self, model: 'Task', stage: Optional[RunningStage] = None): + def _detach_from_model(self, model: "Task", stage: Optional[RunningStage] = None): self._detach_preprocessing_from_model(model, stage) if not stage or stage == RunningStage.PREDICTING: self._detach_postprocess_from_model(model) - def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[RunningStage] = None): + def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[RunningStage] = None): if not stage: stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] elif isinstance(stage, RunningStage): @@ -493,7 +496,7 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin if not device_collate: device_collate = self._identity - loader_name = f'{_STAGES_PREFIX[stage]}_dataloader' + loader_name = f"{_STAGES_PREFIX[stage]}_dataloader" dataloader, whole_attr_name = self._get_dataloader(model, loader_name) @@ -515,11 +518,11 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin if isinstance(loader, DataLoader): dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} - if isinstance(dl_args['collate_fn'], _Preprocessor): + if 
isinstance(dl_args["collate_fn"], _Preprocessor): dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn if isinstance(dl_args["dataset"], IterableAutoDataset): - del dl_args['sampler'] + del dl_args["sampler"] del dl_args["batch_sampler"] @@ -536,9 +539,9 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin self._set_loader(model, whole_attr_name, dataloader) @staticmethod - def _detach_postprocess_from_model(model: 'Task'): + def _detach_postprocess_from_model(model: "Task"): - if hasattr(model.predict_step, '_original'): + if hasattr(model.predict_step, "_original"): # don't delete the predict_step here since we don't know # if any other pipeline is attached which may rely on this! model.predict_step = model.predict_step._original @@ -568,10 +571,10 @@ class _StageOrchestrator: RunningStage.VALIDATING: RunningStage.VALIDATING, RunningStage.TESTING: RunningStage.TESTING, RunningStage.PREDICTING: RunningStage.PREDICTING, - RunningStage.TUNING: RunningStage.TUNING + RunningStage.TUNING: RunningStage.TUNING, } - def __init__(self, func_to_wrap: Callable, model: 'Task') -> None: + def __init__(self, func_to_wrap: Callable, model: "Task") -> None: self.func = func_to_wrap self._stage_mapping = {k: None for k in RunningStage} diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index e4722df44d1..94a36dd535d 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -193,7 +193,7 @@ def __init__(self): self.metadata = {} def __setattr__(self, key, value): - if key != 'metadata': + if key != "metadata": self.metadata[key] = value object.__setattr__(self, key, value) @@ -390,10 +390,9 @@ def load_data( inputs, targets = data if targets is None: return self.predict_load_data(data) - return [{ - DefaultDataKeys.INPUT: input, - DefaultDataKeys.TARGET: target - } for input, target in zip(inputs, targets)] + return [ + {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in zip(inputs, targets) + ] @staticmethod def predict_load_data(data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapping[str, Any]]: @@ -439,9 +438,9 @@ def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool: # data is not path-like (e.g. 
it may be a list of paths) return False - def load_data(self, - data: Union[str, Tuple[List[str], List[Any]]], - dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + def load_data( + self, data: Union[str, Tuple[List[str], List[Any]]], dataset: Optional[Any] = None + ) -> Sequence[Mapping[str, Any]]: if self.isdir(data): classes, class_to_idx = self.find_classes(data) if not classes: @@ -460,9 +459,9 @@ def load_data(self, ) ) - def predict_load_data(self, - data: Union[str, List[str]], - dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + def predict_load_data( + self, data: Union[str, List[str]], dataset: Optional[Any] = None + ) -> Sequence[Mapping[str, Any]]: if self.isdir(data): data = [os.path.join(data, file) for file in os.listdir(data)] @@ -522,15 +521,19 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se def to_idx(t): return [class_to_idx[x] for x in t] + else: def to_idx(t): return class_to_idx[t] - return [{ - DefaultDataKeys.INPUT: f, - DefaultDataKeys.TARGET: to_idx(t), - } for f, t in zip(filepaths, targets)] + return [ + { + DefaultDataKeys.INPUT: f, + DefaultDataKeys.TARGET: to_idx(t), + } + for f, t in zip(filepaths, targets) + ] @staticmethod @requires("fiftyone") diff --git a/flash/core/data/process.py b/flash/core/data/process.py index f0e6bf79ca7..932ef8dc23e 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -32,7 +32,6 @@ class BasePreprocess(ABC): - @abstractmethod def get_state_dict(self) -> Dict[str, Any]: """Override this method to return state_dict.""" @@ -182,8 +181,8 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_sources: Optional[Dict[str, 'DataSource']] = None, - deserializer: Optional['Deserializer'] = None, + data_sources: Optional[Dict[str, "DataSource"]] = None, + deserializer: Optional["Deserializer"] = None, default_data_source: Optional[str] = None, ): super().__init__() @@ -221,7 +220,7 @@ def __init__( self._default_collate: Callable = default_collate @property - def deserializer(self) -> Optional['Deserializer']: + def deserializer(self) -> Optional["Deserializer"]: return self._deserializer def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]: @@ -243,19 +242,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): preprocess_state_dict["_meta"]["module"] = self.__module__ preprocess_state_dict["_meta"]["class_name"] = self.__class__.__name__ preprocess_state_dict["_meta"]["_state"] = self._state - destination['preprocess.state_dict'] = preprocess_state_dict - self._ddp_params_and_buffers_to_ignore = ['preprocess.state_dict'] + destination["preprocess.state_dict"] = preprocess_state_dict + self._ddp_params_and_buffers_to_ignore = ["preprocess.state_dict"] return super()._save_to_state_dict(destination, prefix, keep_vars) - def _check_transforms(self, transform: Optional[Dict[str, Callable]], - stage: RunningStage) -> Optional[Dict[str, Callable]]: + def _check_transforms( + self, transform: Optional[Dict[str, Callable]], stage: RunningStage + ) -> Optional[Dict[str, Callable]]: if transform is None: return transform if not isinstance(transform, Dict): raise MisconfigurationException( - "Transform should be a dict. " - f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}." + "Transform should be a dict. 
" f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}." ) keys_diff = set(transform.keys()).difference(_PREPROCESS_FUNCS) @@ -270,8 +269,7 @@ def _check_transforms(self, transform: Optional[Dict[str, Callable]], if is_per_batch_transform_in and is_per_sample_transform_on_device_in: raise MisconfigurationException( - f'{transform}: `per_batch_transform` and `per_sample_transform_on_device` ' - f'are mutually exclusive.' + f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` " f"are mutually exclusive." ) collate_in_worker: Optional[bool] = None @@ -317,16 +315,16 @@ def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: } @property - def callbacks(self) -> List['FlashCallback']: + def callbacks(self) -> List["FlashCallback"]: if not hasattr(self, "_callbacks"): self._callbacks: List[FlashCallback] = [] return self._callbacks @callbacks.setter - def callbacks(self, callbacks: List['FlashCallback']): + def callbacks(self, callbacks: List["FlashCallback"]): self._callbacks = callbacks - def add_callbacks(self, callbacks: List['FlashCallback']): + def add_callbacks(self, callbacks: List["FlashCallback"]): _callbacks = [c for c in callbacks if c not in self._callbacks] self._callbacks.extend(_callbacks) @@ -439,14 +437,13 @@ def data_source_of_name(self, data_source_name: str) -> DataSource: class DefaultPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_sources: Optional[Dict[str, 'DataSource']] = None, + data_sources: Optional[Dict[str, "DataSource"]] = None, default_data_source: Optional[str] = None, ): super().__init__( @@ -511,7 +508,7 @@ def save_sample(sample: Any, path: str) -> None: # TODO: Are those needed ? 
def format_sample_save_path(self, path: str) -> str: - path = os.path.join(path, f'sample_{self._saved_samples}.ptl') + path = os.path.join(path, f"sample_{self._saved_samples}.ptl") self._saved_samples += 1 return path @@ -570,7 +567,7 @@ def serialize(self, sample: Any) -> Any: return {key: serializer.serialize(sample[key]) for key, serializer in self._serializers.items()} raise ValueError("The model output must be a mapping when using a SerializerMapping.") - def attach_data_pipeline_state(self, data_pipeline_state: 'flash.core.data.data_pipeline.DataPipelineState'): + def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): for serializer in self._serializers.values(): serializer.attach_data_pipeline_state(data_pipeline_state) @@ -604,6 +601,6 @@ def deserialize(self, sample: Any) -> Any: return {key: deserializer.deserialize(sample[key]) for key, deserializer in self._deserializers.items()} raise ValueError("The model output must be a mapping when using a DeserializerMapping.") - def attach_data_pipeline_state(self, data_pipeline_state: 'flash.core.data.data_pipeline.DataPipelineState'): + def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): for deserializer in self._deserializers.values(): deserializer.attach_data_pipeline_state(data_pipeline_state) diff --git a/flash/core/data/properties.py b/flash/core/data/properties.py index 4ab24b74d9d..2a228467832 100644 --- a/flash/core/data/properties.py +++ b/flash/core/data/properties.py @@ -24,17 +24,16 @@ class ProcessState: """Base class for all process states.""" -STATE_TYPE = TypeVar('STATE_TYPE', bound=ProcessState) +STATE_TYPE = TypeVar("STATE_TYPE", bound=ProcessState) class Properties: - def __init__(self): super().__init__() self._running_stage: Optional[RunningStage] = None self._current_fn: Optional[str] = None - self._data_pipeline_state: Optional['flash.core.data.data_pipeline.DataPipelineState'] = None + self._data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None self._state: Dict[Type[ProcessState], ProcessState] = {} def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]: @@ -49,7 +48,7 @@ def set_state(self, state: ProcessState): if self._data_pipeline_state is not None: self._data_pipeline_state.set_state(state) - def attach_data_pipeline_state(self, data_pipeline_state: 'flash.core.data.data_pipeline.DataPipelineState'): + def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): self._data_pipeline_state = data_pipeline_state for state in self._state.values(): self._data_pipeline_state.set_state(state) diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py index 5f6ddb0791a..d637ab4acc1 100644 --- a/flash/core/data/transforms.py +++ b/flash/core/data/transforms.py @@ -44,7 +44,7 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: inputs = inputs[0] outputs = super().forward(inputs) if not isinstance(outputs, Sequence): - outputs = (outputs, ) + outputs = (outputs,) result = {} result.update(x) diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 376092ac6a6..3779b7426e8 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -24,10 +24,10 @@ from tqdm.auto import tqdm as tq _STAGES_PREFIX = { - RunningStage.TRAINING: 'train', - RunningStage.TESTING: 'test', - RunningStage.VALIDATING: 'val', - RunningStage.PREDICTING: 'predict' + 
RunningStage.TRAINING: "train", + RunningStage.TESTING: "test", + RunningStage.VALIDATING: "val", + RunningStage.PREDICTING: "predict", } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} @@ -61,7 +61,6 @@ class CurrentRunningStageContext: - def __init__(self, running_stage: RunningStage, obj: Any, reset: bool = True): self._running_stage = running_stage self._obj = obj @@ -79,7 +78,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class CurrentFuncContext: - def __init__(self, current_fn: str, obj: Any): self._current_fn = current_fn self._obj = obj @@ -96,7 +94,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class CurrentRunningStageFuncContext: - def __init__(self, running_stage: RunningStage, current_fn: str, obj: Any): self._running_stage = running_stage self._current_fn = current_fn @@ -131,9 +128,9 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: if not os.path.exists(path): os.makedirs(path) - local_filename = os.path.join(path, url.split('/')[-1]) + local_filename = os.path.join(path, url.split("/")[-1]) r = requests.get(url, stream=True, verify=False) - file_size = int(r.headers['Content-Length']) if 'Content-Length' in r.headers else 0 + file_size = int(r.headers["Content-Length"]) if "Content-Length" in r.headers else 0 chunk_size = 1024 num_bars = int(file_size / chunk_size) if verbose: @@ -141,19 +138,19 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: print(dict(num_bars=num_bars)) if not os.path.exists(local_filename): - with open(local_filename, 'wb') as fp: + with open(local_filename, "wb") as fp: for chunk in tq( r.iter_content(chunk_size=chunk_size), total=num_bars, - unit='KB', + unit="KB", desc=local_filename, - leave=True # progressbar stays + leave=True, # progressbar stays ): fp.write(chunk) # type: ignore - if '.zip' in local_filename: + if ".zip" in local_filename: if os.path.exists(local_filename): - with zipfile.ZipFile(local_filename, 'r') as zip_ref: + with zipfile.ZipFile(local_filename, "r") as zip_ref: zip_ref.extractall(path) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 5e58bca0909..854164fb159 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -21,7 +21,6 @@ class NoFreeze(BaseFinetuning): - def freeze_before_training(self, pl_module: LightningModule) -> None: pass @@ -67,7 +66,6 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O class Freeze(FlashBaseFinetuning): - def finetune_function( self, pl_module: LightningModule, @@ -79,7 +77,6 @@ def finetune_function( class FreezeUnfreeze(FlashBaseFinetuning): - def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10): super().__init__(attr_names, train_bn) self.unfreeze_epoch = unfreeze_epoch @@ -102,13 +99,12 @@ def finetune_function( class UnfreezeMilestones(FlashBaseFinetuning): - def __init__( self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_milestones: tuple = (5, 10), - num_layers: int = 5 + num_layers: int = 5, ): self.unfreeze_milestones = unfreeze_milestones self.num_layers = num_layers @@ -126,7 +122,7 @@ def finetune_function( if epoch == self.unfreeze_milestones[0]: # unfreeze num_layers last layers self.unfreeze_and_add_param_group( - modules=backbone_modules[-self.num_layers:], + modules=backbone_modules[-self.num_layers :], optimizer=optimizer, train_bn=self.train_bn, ) @@ -134,7 +130,7 @@ 
def finetune_function( elif epoch == self.unfreeze_milestones[1]: # unfreeze remaining layers self.unfreeze_and_add_param_group( - modules=backbone_modules[:-self.num_layers], + modules=backbone_modules[: -self.num_layers], optimizer=optimizer, train_bn=self.train_bn, ) @@ -144,7 +140,7 @@ def finetune_function( "no_freeze": NoFreeze, "freeze": Freeze, "freeze_unfreeze": FreezeUnfreeze, - "unfreeze_milestones": UnfreezeMilestones + "unfreeze_milestones": UnfreezeMilestones, } diff --git a/flash/core/model.py b/flash/core/model.py index f3862a6e7f0..51c77e879d5 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -51,11 +51,10 @@ class BenchmarkConvergenceCI(Callback): - def __init__(self): self.history = [] - def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.history.append(deepcopy(trainer.callback_metrics)) if trainer.current_epoch == trainer.max_epochs - 1: fn = getattr(pl_module, "_ci_benchmark_fn", None) @@ -87,7 +86,6 @@ def wrapper(self, *args, **kwargs) -> Any: class CheckDependenciesMeta(ABCMeta): - def __new__(mcs, *args, **kwargs): result = ABCMeta.__new__(mcs, *args, **kwargs) if result.required_extras is not None: @@ -396,21 +394,23 @@ def build_data_pipeline( deserializer, old_data_source, preprocess, postprocess, serializer = None, None, None, None, None # Datamodule - if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: - old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None) - preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None) - postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None) - serializer = getattr(self.datamodule.data_pipeline, '_serializer', None) - deserializer = getattr(self.datamodule.data_pipeline, '_deserializer', None) - - elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr( - self.trainer.datamodule, 'data_pipeline', None - ) is not None: - old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None) - preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) - postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None) - serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None) - deserializer = getattr(self.trainer.datamodule.data_pipeline, '_deserializer', None) + if self.datamodule is not None and getattr(self.datamodule, "data_pipeline", None) is not None: + old_data_source = getattr(self.datamodule.data_pipeline, "data_source", None) + preprocess = getattr(self.datamodule.data_pipeline, "_preprocess_pipeline", None) + postprocess = getattr(self.datamodule.data_pipeline, "_postprocess_pipeline", None) + serializer = getattr(self.datamodule.data_pipeline, "_serializer", None) + deserializer = getattr(self.datamodule.data_pipeline, "_deserializer", None) + + elif ( + self.trainer is not None + and hasattr(self.trainer, "datamodule") + and getattr(self.trainer.datamodule, "data_pipeline", None) is not None + ): + old_data_source = getattr(self.trainer.datamodule.data_pipeline, "data_source", None) + preprocess = getattr(self.trainer.datamodule.data_pipeline, "_preprocess_pipeline", None) + postprocess = getattr(self.trainer.datamodule.data_pipeline, "_postprocess_pipeline", None) + serializer = getattr(self.trainer.datamodule.data_pipeline, 
"_serializer", None) + deserializer = getattr(self.trainer.datamodule.data_pipeline, "_deserializer", None) else: # TODO: we should log with low severity level that we use defaults to create # `preprocess`, `postprocess` and `serializer`. @@ -435,10 +435,10 @@ def build_data_pipeline( preprocess, postprocess, serializer, - getattr(data_pipeline, '_deserializer', None), - getattr(data_pipeline, '_preprocess_pipeline', None), - getattr(data_pipeline, '_postprocess_pipeline', None), - getattr(data_pipeline, '_serializer', None), + getattr(data_pipeline, "_deserializer", None), + getattr(data_pipeline, "_preprocess_pipeline", None), + getattr(data_pipeline, "_postprocess_pipeline", None), + getattr(data_pipeline, "_serializer", None), ) data_source = data_source or old_data_source @@ -481,10 +481,10 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: self._preprocess, self._postprocess, self._serializer, - getattr(data_pipeline, '_deserializer', None), - getattr(data_pipeline, '_preprocess_pipeline', None), - getattr(data_pipeline, '_postprocess_pipeline', None), - getattr(data_pipeline, '_serializer', None), + getattr(data_pipeline, "_deserializer", None), + getattr(data_pipeline, "_preprocess_pipeline", None), + getattr(data_pipeline, "_postprocess_pipeline", None), + getattr(data_pipeline, "_serializer", None), ) # self._preprocess.state_dict() @@ -494,12 +494,12 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: @torch.jit.unused @property def preprocess(self) -> Preprocess: - return getattr(self.data_pipeline, '_preprocess_pipeline', None) + return getattr(self.data_pipeline, "_preprocess_pipeline", None) @torch.jit.unused @property def postprocess(self) -> Postprocess: - return getattr(self.data_pipeline, '_postprocess_pipeline', None) + return getattr(self.data_pipeline, "_postprocess_pipeline", None) def on_train_dataloader(self) -> None: if self.data_pipeline is not None: @@ -538,18 +538,18 @@ def on_fit_end(self) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html - if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: - checkpoint['data_pipeline'] = self.data_pipeline - if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: - checkpoint['_data_pipeline_state'] = self._data_pipeline_state + if self.data_pipeline is not None and "data_pipeline" not in checkpoint: + checkpoint["data_pipeline"] = self.data_pipeline + if self._data_pipeline_state is not None and "_data_pipeline_state" not in checkpoint: + checkpoint["_data_pipeline_state"] = self._data_pipeline_state super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: super().on_load_checkpoint(checkpoint) - if 'data_pipeline' in checkpoint: - self.data_pipeline = checkpoint['data_pipeline'] - if '_data_pipeline_state' in checkpoint: - self._data_pipeline_state = checkpoint['_data_pipeline_state'] + if "data_pipeline" in checkpoint: + self.data_pipeline = checkpoint["data_pipeline"] + if "_data_pipeline_state" in checkpoint: + self._data_pipeline_state = checkpoint["_data_pipeline_state"] @classmethod def available_backbones(cls) -> List[str]: @@ -636,14 +636,13 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, 
unexpected_keys, error_msgs ): - if 'preprocess.state_dict' in state_dict: + if "preprocess.state_dict" in state_dict: try: preprocess_state_dict = state_dict["preprocess.state_dict"] meta = preprocess_state_dict["_meta"] cls = getattr(import_module(meta["module"]), meta["class_name"]) self._preprocess = cls.load_state_dict( - {k: v - for k, v in preprocess_state_dict.items() if k != '_meta'}, + {k: v for k, v in preprocess_state_dict.items() if k != "_meta"}, strict=strict, ) self._preprocess._state = meta["_state"] @@ -685,7 +684,7 @@ def run_serve_sanity_check(self): print(f"Sanity check response: {resp.json()}") @requires_extras("serve") - def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> 'Composition': + def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> "Composition": if not self.is_servable: raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.") @@ -711,7 +710,7 @@ def set_state(self, state: ProcessState): if self._data_pipeline_state is not None: self._data_pipeline_state.set_state(state) - def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'): + def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): for state in self._state.values(): data_pipeline_state.set_state(state) @@ -735,7 +734,7 @@ def _process_dataset( pin_memory=pin_memory, shuffle=shuffle, drop_last=drop_last, - collate_fn=collate_fn + collate_fn=collate_fn, ) return dataset @@ -748,7 +747,7 @@ def process_train_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self._process_dataset( dataset, @@ -758,7 +757,7 @@ def process_train_dataset( collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, - sampler=sampler + sampler=sampler, ) def process_val_dataset( @@ -770,7 +769,7 @@ def process_val_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = False, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self._process_dataset( dataset, @@ -780,7 +779,7 @@ def process_val_dataset( collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, - sampler=sampler + sampler=sampler, ) def process_test_dataset( @@ -792,7 +791,7 @@ def process_test_dataset( collate_fn: Callable, shuffle: bool = False, drop_last: bool = True, - sampler: Optional[Sampler] = None + sampler: Optional[Sampler] = None, ) -> DataLoader: return self._process_dataset( dataset, @@ -802,7 +801,7 @@ def process_test_dataset( collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, - sampler=sampler + sampler=sampler, ) def process_predict_dataset( @@ -815,7 +814,7 @@ def process_predict_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, - convert_to_dataloader: bool = True + convert_to_dataloader: bool = True, ) -> Union[DataLoader, BaseAutoDataset]: return self._process_dataset( dataset, @@ -826,5 +825,5 @@ def process_predict_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, - convert_to_dataloader=convert_to_dataloader + convert_to_dataloader=convert_to_dataloader, ) diff --git a/flash/core/registry.py b/flash/core/registry.py index aafcdf6733f..e35e3e33798 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -36,7 +36,7 @@ def __contains__(self, key) -> bool: return any(key == e["name"] for e in self.functions) def 
__repr__(self) -> str: - return f'{self.__class__.__name__}(name={self.name}, functions={self.functions})' + return f"{self.__class__.__name__}(name={self.name}, functions={self.functions})" def get( self, @@ -73,7 +73,7 @@ def _register_function( fn: Callable, name: Optional[str] = None, override: bool = False, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, ): if not isinstance(fn, FunctionType) and not isinstance(fn, partial): raise MisconfigurationException(f"You can only register a function, found: {fn}") @@ -102,11 +102,7 @@ def _find_matching_index(self, item: _REGISTERED_FUNCTION) -> Optional[int]: return idx def __call__( - self, - fn: Optional[Callable[..., Any]] = None, - name: Optional[str] = None, - override: bool = False, - **metadata + self, fn: Optional[Callable[..., Any]] = None, name: Optional[str] = None, override: bool = False, **metadata ) -> Callable: """This function is used to register new functions to the registry along their metadata. @@ -118,7 +114,7 @@ def __call__( # raise the error ahead of time if not (name is None or isinstance(name, str)): - raise TypeError(f'`name` must be a str, found {name}') + raise TypeError(f"`name` must be a str, found {name}") def _register(cls): self._register_function(fn=cls, name=name, override=override, metadata=metadata) diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py index 4e01306b2a5..bfc1bc82b81 100644 --- a/flash/core/schedulers.py +++ b/flash/core/schedulers.py @@ -7,8 +7,9 @@ if _TRANSFORMERS_AVAILABLE: from transformers import optimization + functions: List[Callable] = [ - getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler') + getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != "get_scheduler") ] for fn in functions: _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:]) diff --git a/flash/core/serve/_compat/__init__.py b/flash/core/serve/_compat/__init__.py index 439ab3add04..50af1bf725e 100644 --- a/flash/core/serve/_compat/__init__.py +++ b/flash/core/serve/_compat/__init__.py @@ -1,3 +1,3 @@ from flash.core.serve._compat.cached_property import cached_property -__all__ = ("cached_property", ) +__all__ = ("cached_property",) diff --git a/flash/core/serve/_compat/cached_property.py b/flash/core/serve/_compat/cached_property.py index d490d1015ca..2adde68103a 100644 --- a/flash/core/serve/_compat/cached_property.py +++ b/flash/core/serve/_compat/cached_property.py @@ -5,7 +5,7 @@ credits: https://github.com/penguinolog/backports.cached_property """ -__all__ = ("cached_property", ) +__all__ = ("cached_property",) # Standard Library from sys import version_info diff --git a/flash/core/serve/component.py b/flash/core/serve/component.py index 47fbbdc316b..cf5c81f266f 100644 --- a/flash/core/serve/component.py +++ b/flash/core/serve/component.py @@ -41,7 +41,7 @@ def _validate_exposed_input_parameters_valid(instance): ) -def _validate_subclass_init_signature(cls: Type['ModelComponent']): +def _validate_subclass_init_signature(cls: Type["ModelComponent"]): """Raises SyntaxError if the __init__ method is not formatted correctly. Expects arguments: ['self', 'models', Optional['config']] @@ -163,7 +163,9 @@ def __new__(cls, name, bases, namespace): # alter namespace to insert flash serve info as bound components of class. 
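The `FlashRegistry` hunks above expose `__call__(fn=None, name=None, override=False, **metadata)` and a name-based `__contains__`, and the schedulers hunk registers functions by calling the registry directly. A minimal sketch of that pattern, assuming the `flash.core.registry` import path and that `get` looks an entry up by its registered name:

    from flash.core.registry import FlashRegistry

    BACKBONES = FlashRegistry("backbones")  # a throwaway demo registry, not one of Flash's built-ins

    def tiny_backbone(pretrained: bool = False):
        return None  # placeholder; real backbone entries typically return (model, num_features)

    # Direct-call registration, mirroring _SCHEDULERS_REGISTRY(fn, name=...) above;
    # extra keyword arguments are stored as metadata alongside the function.
    BACKBONES(tiny_backbone, name="tiny", package="demo")

    assert "tiny" in BACKBONES   # __contains__ compares the key against registered names
    fn = BACKBONES.get("tiny")   # assumed: retrieval by name returns the registered callable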
exposed = first(ex_meths.values()) namespace["_flashserve_meta_"] = exposed.flashserve_meta - namespace["__call__"] = wraps(exposed)(exposed, ) + namespace["__call__"] = wraps(exposed)( + exposed, + ) new_cls = super().__new__(cls, name, bases, namespace) if new_cls.__name__ != "ModelComponent": @@ -243,5 +245,6 @@ def outputs(self) -> ParameterContainer: def uid(self) -> str: return self._flashserve_meta_.uid + else: ModelComponent = object diff --git a/flash/core/serve/composition.py b/flash/core/serve/composition.py index 5a6642cb4a2..f3f9e8441e5 100644 --- a/flash/core/serve/composition.py +++ b/flash/core/serve/composition.py @@ -14,8 +14,9 @@ concat, first = None, None -def _parse_composition_kwargs(**kwargs: Union[ModelComponent, - Endpoint]) -> Tuple[Dict[str, ModelComponent], Dict[str, Endpoint]]: +def _parse_composition_kwargs( + **kwargs: Union[ModelComponent, Endpoint] +) -> Tuple[Dict[str, ModelComponent], Dict[str, Endpoint]]: components, endpoints = {}, {} for k, v in kwargs.items(): @@ -28,8 +29,7 @@ def _parse_composition_kwargs(**kwargs: Union[ModelComponent, if len(components) > 1 and len(endpoints) == 0: raise ValueError( - "Must explicitly define atelast one Endpoint when " - "two or more components are included in a composition." + "Must explicitly define atelast one Endpoint when " "two or more components are included in a composition." ) return (components, endpoints) diff --git a/flash/core/serve/core.py b/flash/core/serve/core.py index 38a9a81d8c3..e05717212ae 100644 --- a/flash/core/serve/core.py +++ b/flash/core/serve/core.py @@ -41,8 +41,7 @@ class Endpoint: def __post_init__(self): if not isinstance(self.route, str): raise TypeError( - f"route parameter must be type={str}, recieved " - f"route={self.route} of type={type(self.route)}" + f"route parameter must be type={str}, recieved " f"route={self.route} of type={type(self.route)}" ) if not self.route.startswith("/"): raise ValueError("route must begin with a `slash` character (ie `/`).") @@ -76,8 +75,11 @@ def __call__(self, *args, **kwargs): return self.instance(*args, **kwargs) -ServableValidArgs_T = Union[Tuple[Type[pl.LightningModule], Union[HttpUrl, FilePath]], Tuple[HttpUrl], - Tuple[FilePath], ] +ServableValidArgs_T = Union[ + Tuple[Type[pl.LightningModule], Union[HttpUrl, FilePath]], + Tuple[HttpUrl], + Tuple[FilePath], +] class Servable: @@ -105,7 +107,7 @@ def __init__( self, *args: ServableValidArgs_T, download_path: Optional[Path] = None, - script_loader_cls: Type[FlashServeScriptLoader] = FlashServeScriptLoader + script_loader_cls: Type[FlashServeScriptLoader] = FlashServeScriptLoader, ): try: loc = args[-1] # last element in args is always loc @@ -175,8 +177,7 @@ def _repr_pretty_(self, p, cycle): # pragma: no cover def __str__(self): return ( - f"{self.source_component}.outputs.{self.source_key} >> " - f"{self.target_component}.inputs.{self.target_key}" + f"{self.source_component}.outputs.{self.source_key} >> " f"{self.target_component}.inputs.{self.target_key}" ) @@ -276,7 +277,6 @@ def __rshift__(self, other: "Parameter"): class DictAttrAccessBase: - def __grid_fields__(self) -> Iterator[str]: for field in dataclasses.fields(self): # noqa F402 yield field.name @@ -322,15 +322,16 @@ def make_parameter_container(data: Dict[str, Parameter]) -> ParameterContainer: ParameterContainer = make_dataclass( "ParameterContainer", dataclass_fields, - bases=(DictAttrAccessBase, ), + bases=(DictAttrAccessBase,), frozen=True, unsafe_hash=True, ) return ParameterContainer(**data) -def make_param_dict(inputs: 
Dict[str, BaseType], outputs: Dict[str, BaseType], - component_uid: str) -> Tuple[Dict[str, Parameter], Dict[str, Parameter]]: +def make_param_dict( + inputs: Dict[str, BaseType], outputs: Dict[str, BaseType], component_uid: str +) -> Tuple[Dict[str, Parameter], Dict[str, Parameter]]: """Convert exposed input/outputs parameters / dtypes to parameter objects. Returns diff --git a/flash/core/serve/dag/optimization.py b/flash/core/serve/dag/optimization.py index 4c937491b08..ea4293798e4 100644 --- a/flash/core/serve/dag/optimization.py +++ b/flash/core/serve/dag/optimization.py @@ -62,7 +62,7 @@ def default_fused_linear_keys_renamer(keys): if typ is tuple and len(keys[0]) > 0 and isinstance(keys[0][0], str): names = [key_split(x) for x in keys[:0:-1]] names.append(keys[0][0]) - return ("-".join(names), ) + keys[0][1:] + return ("-".join(names),) + keys[0][1:] return None @@ -381,7 +381,7 @@ def _enforce_max_key_limit(key_name): names = sorted(names) names.append(first_key[0]) concatenated_name = "-".join(names) - return (_enforce_max_key_limit(concatenated_name), ) + first_key[1:] + return (_enforce_max_key_limit(concatenated_name),) + first_key[1:] # PEP-484 compliant singleton constant @@ -552,16 +552,18 @@ def fuse( children_stack_pop() # This is a leaf node in the reduction region # key, task, fused_keys, height, width, number of nodes, fudge, set of edges - info_stack_append(( - child, - rv[child], - [child] if rename_keys else None, - 1, - 1, - 1, - 0, - deps[child] - reducible, - )) + info_stack_append( + ( + child, + rv[child], + [child] if rename_keys else None, + 1, + 1, + 1, + 0, + deps[child] - reducible, + ) + ) else: children_stack_pop() # Calculate metrics and fuse as appropriate @@ -591,7 +593,7 @@ def fuse( fudge += 1 # Sanity check; don't go too deep if new levels introduce new edge dependencies - if ((num_nodes + fudge) / height <= ave_width and (no_new_edges or height < max_depth_new_edges)): + if (num_nodes + fudge) / height <= ave_width and (no_new_edges or height < max_depth_new_edges): # Perform substitutions as we go val = subs(dsk[parent], child_key, child_task) deps_parent.remove(child_key) @@ -606,27 +608,31 @@ def fuse( if children_stack: if no_new_edges: # Linear fuse - info_stack_append(( - parent, - val, - child_keys, - height, - width, - num_nodes, - fudge, - edges, - )) + info_stack_append( + ( + parent, + val, + child_keys, + height, + width, + num_nodes, + fudge, + edges, + ) + ) else: - info_stack_append(( - parent, - val, - child_keys, - height + 1, - width, - num_nodes + 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + val, + child_keys, + height + 1, + width, + num_nodes + 1, + fudge, + edges, + ) + ) else: rv[parent] = val break @@ -639,16 +645,18 @@ def fuse( if fudge > int(ave_width - 1): fudge = int(ave_width - 1) # This task *implicitly* depends on `edges` - info_stack_append(( - parent, - rv[parent], - [parent] if rename_keys else None, - 1, - width, - 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + rv[parent], + [parent] if rename_keys else None, + 1, + width, + 1, + fudge, + edges, + ) + ) else: break else: @@ -716,16 +724,18 @@ def fuse( fused_trees[parent] = child_keys if children_stack: - info_stack_append(( - parent, - val, - child_keys, - height + 1, - width, - num_nodes + 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + val, + child_keys, + height + 1, + width, + num_nodes + 1, + fudge, + edges, + ) + ) else: rv[parent] = val break @@ -742,16 +752,18 @@ def fuse( fudge = int(ave_width - 1) # key, 
task, height, width, number of nodes, fudge, set of edges # This task *implicitly* depends on `edges` - info_stack_append(( - parent, - rv[parent], - [parent] if rename_keys else None, - 1, - width, - 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + rv[parent], + [parent] if rename_keys else None, + 1, + width, + 1, + fudge, + edges, + ) + ) else: break # Traverse upwards @@ -827,7 +839,7 @@ def _inplace_fuse_subgraphs(dsk, keys, dependencies, fused_trees, rename_keys): # Create new task inkeys = tuple(inkeys_set) - dsk[outkey] = (SubgraphCallable(subgraph, outkey, inkeys), ) + inkeys + dsk[outkey] = (SubgraphCallable(subgraph, outkey, inkeys),) + inkeys # Mutate `fused_trees` if key renaming is needed (renaming done in fuse) if rename_keys: diff --git a/flash/core/serve/dag/order.py b/flash/core/serve/dag/order.py index 881a66ad505..da096decb9e 100644 --- a/flash/core/serve/dag/order.py +++ b/flash/core/serve/dag/order.py @@ -321,7 +321,7 @@ def finish_now_key(x): if len(deps) == 1: # Fast path! We trim down `deps` above hoping to reach here. - (dep, ) = deps + (dep,) = deps if not inner_stack: if add_to_inner_stack: inner_stack = [dep] @@ -565,7 +565,7 @@ def graph_metrics(dependencies, dependents, total_dependencies): key = current_pop() parents = dependents[key] if len(parents) == 1: - (parent, ) = parents + (parent,) = parents ( total_dependents, min_dependencies, @@ -665,7 +665,7 @@ class StrComparable: False """ - __slots__ = ("obj", ) + __slots__ = ("obj",) def __init__(self, obj): self.obj = obj diff --git a/flash/core/serve/dag/rewrite.py b/flash/core/serve/dag/rewrite.py index bb876661ded..a7682b05ac7 100644 --- a/flash/core/serve/dag/rewrite.py +++ b/flash/core/serve/dag/rewrite.py @@ -354,7 +354,7 @@ def _top_level(net, term): def _bottom_up(net, term): if istask(term): - term = (head(term), ) + tuple(_bottom_up(net, t) for t in args(term)) + term = (head(term),) + tuple(_bottom_up(net, t) for t in args(term)) elif isinstance(term, list): term = [_bottom_up(net, t) for t in args(term)] return net._rewrite(term) @@ -389,7 +389,7 @@ def _match(S, N): n = N.edges.get(VAR, None) if n: restore_state_flag = False - matches = matches + (S.term, ) + matches = matches + (S.term,) S.skip() N = n continue diff --git a/flash/core/serve/dag/task.py b/flash/core/serve/dag/task.py index a404cd3962a..da8becdfd49 100644 --- a/flash/core/serve/dag/task.py +++ b/flash/core/serve/dag/task.py @@ -399,7 +399,7 @@ def isdag(d, keys): class literal: """A small serializable object to wrap literal values without copying.""" - __slots__ = ("data", ) + __slots__ = ("data",) def __init__(self, data): self.data = data @@ -408,7 +408,7 @@ def __repr__(self): return "literal" % type(self.data).__name__ def __reduce__(self): - return (literal, (self.data, )) + return (literal, (self.data,)) def __call__(self): return self.data @@ -424,5 +424,5 @@ def quote(x): (literal,) """ if istask(x) or type(x) is list or type(x) is dict: - return (literal(x), ) + return (literal(x),) return x diff --git a/flash/core/serve/dag/visualize.py b/flash/core/serve/dag/visualize.py index fc2d60069ab..bc847d984a3 100644 --- a/flash/core/serve/dag/visualize.py +++ b/flash/core/serve/dag/visualize.py @@ -37,7 +37,7 @@ def _dag_to_graphviz(dag, dependencies, request_data, response_data, *, no_optim g.node(request_name, request_name, shape="oval") with g.subgraph(name=f"cluster_{cluster}") as c: c.node(task_key, task_key, shape="rectangle") - c.edge(task_key, task_key[:-len(".serial")]) + c.edge(task_key, task_key[: 
-len(".serial")]) g.edge(request_name, task_key) @@ -48,7 +48,7 @@ def _dag_to_graphviz(dag, dependencies, request_data, response_data, *, no_optim def visualize( - tc: 'TaskComposition', + tc: "TaskComposition", fhandle: BytesIO = None, format: str = "png", *, diff --git a/flash/core/serve/decorators.py b/flash/core/serve/decorators.py index ae647ef14d2..5569707000c 100644 --- a/flash/core/serve/decorators.py +++ b/flash/core/serve/decorators.py @@ -29,7 +29,7 @@ class UnboundMeta: @dataclass(unsafe_hash=True) class BoundMeta(UnboundMeta): - models: Union[List['Servable'], Tuple['Servable', ...], Dict[str, 'Servable']] + models: Union[List["Servable"], Tuple["Servable", ...], Dict[str, "Servable"]] uid: str = field(default_factory=lambda: uuid4().hex, init=False) out_attr_dict: ParameterContainer = field(default=None, init=False) inp_attr_dict: ParameterContainer = field(default=None, init=False) @@ -66,7 +66,7 @@ def __post_init__(self): ) @property - def connections(self) -> Sequence['Connection']: + def connections(self) -> Sequence["Connection"]: connections = [] for fld in fields(self.inp_attr_dict): connections.extend(getattr(self.inp_attr_dict, fld.name).connections) @@ -154,7 +154,6 @@ def expose(inputs: Dict[str, BaseType], outputs: Dict[str, BaseType]): _validate_expose_inputs_outputs_args(outputs) def wrapper(fn): - @wraps(fn) def wrapped(func): func.flashserve_meta = UnboundMeta(exposed=func, inputs=inputs, outputs=outputs) diff --git a/flash/core/serve/execution.py b/flash/core/serve/execution.py index e3ba5485f26..1546ff76d98 100644 --- a/flash/core/serve/execution.py +++ b/flash/core/serve/execution.py @@ -134,7 +134,7 @@ class UnprocessedTaskDask: def _process_initial( - endpoint_protocol: 'EndpointProtocol', components: Dict[str, 'ModelComponent'] + endpoint_protocol: "EndpointProtocol", components: Dict[str, "ModelComponent"] ) -> UnprocessedTaskDask: """Extract task dsk and payload / results keys and return computable form. @@ -154,22 +154,18 @@ def _process_initial( # mapping payload input keys -> serialized keys / tasks payload_dsk_key_map = { - payload_key: f"{input_key}.serial" - for payload_key, input_key in endpoint_protocol.dsk_input_key_map.items() + payload_key: f"{input_key}.serial" for payload_key, input_key in endpoint_protocol.dsk_input_key_map.items() } payload_input_tasks_dsk = { - input_dsk_key: (identity, payload_key) - for payload_key, input_dsk_key in payload_dsk_key_map.items() + input_dsk_key: (identity, payload_key) for payload_key, input_dsk_key in payload_dsk_key_map.items() } # mapping result keys -> serialize keys / tasks res_dsk_key_map = { - result_key: f"{output_key}.serial" - for result_key, output_key in endpoint_protocol.dsk_output_key_map.items() + result_key: f"{output_key}.serial" for result_key, output_key in endpoint_protocol.dsk_output_key_map.items() } result_output_tasks_dsk = { - result_key: (identity, output_dsk_key) - for result_key, output_dsk_key in res_dsk_key_map.items() + result_key: (identity, output_dsk_key) for result_key, output_dsk_key in res_dsk_key_map.items() } output_keys = list(res_dsk_key_map.keys()) @@ -198,10 +194,10 @@ def _process_initial( def build_composition( - endpoint_protocol: 'EndpointProtocol', - components: Dict[str, 'ModelComponent'], - connections: List['Connection'], -) -> 'TaskComposition': + endpoint_protocol: "EndpointProtocol", + components: Dict[str, "ModelComponent"], + connections: List["Connection"], +) -> "TaskComposition": r"""Build a composed graph. 
Notes on easy sources to introduce bugs. @@ -342,7 +338,7 @@ def _verify_no_cycles(dsk: Dict[str, tuple], out_keys: List[str], endpoint_name: ) -def connections_from_components_map(components: Dict[str, 'ModelComponent']) -> List[Dict[str, str]]: +def connections_from_components_map(components: Dict[str, "ModelComponent"]) -> List[Dict[str, str]]: dsk_connections = [] for con in flatten([comp._flashserve_meta_.connections for comp in components.values()]): # value of target key is mapped one-to-one from value of source @@ -350,7 +346,7 @@ def connections_from_components_map(components: Dict[str, 'ModelComponent']) -> return dsk_connections -def endpoint_protocol_content(ep_proto: 'EndpointProtocol') -> 'EndpointProtoJSON': +def endpoint_protocol_content(ep_proto: "EndpointProtocol") -> "EndpointProtoJSON": ep_proto_payload_dsk_key_map = valmap(lambda x: f"{x}.serial", ep_proto.dsk_input_key_map) ep_proto_result_key_dsk_map = valmap(lambda x: f"{x}.serial", ep_proto.dsk_output_key_map) @@ -362,7 +358,7 @@ def endpoint_protocol_content(ep_proto: 'EndpointProtocol') -> 'EndpointProtoJSO ) -def merged_dag_content(ep_proto: 'EndpointProtocol', components: Dict[str, 'ModelComponent']) -> 'MergedJSON': +def merged_dag_content(ep_proto: "EndpointProtocol", components: Dict[str, "ModelComponent"]) -> "MergedJSON": init = _process_initial(ep_proto, components) dsk_connections = connections_from_components_map(components) epjson = endpoint_protocol_content(ep_proto) @@ -376,7 +372,7 @@ def merged_dag_content(ep_proto: 'EndpointProtocol', components: Dict[str, 'Mode for request_name, task_key in init.payload_dsk_map.items(): cluster, *_ = task_key.split(".") - merged_proto[task_key[:-len(".serial")]].append(task_key) + merged_proto[task_key[: -len(".serial")]].append(task_key) merged_proto[task_key].append(request_name) merged_proto = dict(merged_proto) @@ -394,7 +390,7 @@ def merged_dag_content(ep_proto: 'EndpointProtocol', components: Dict[str, 'Mode ) -def component_dag_content(components: Dict[str, 'ModelComponent']) -> 'ComponentJSON': +def component_dag_content(components: Dict[str, "ModelComponent"]) -> "ComponentJSON": dsk_connections = connections_from_components_map(components) comp_dependencies, comp_dependents, comp_funcnames = {}, {}, {} diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index ea5ae85392b..f52afe63826 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -10,7 +10,6 @@ class FlashInputs(BaseType): - def __init__( self, deserializer: Callable, @@ -25,7 +24,6 @@ def deserialize(self, data: str) -> Any: # pragma: no cover class FlashOutputs(BaseType): - def __init__( self, serializer: Callable, @@ -53,7 +51,6 @@ def build_flash_serve_model_component(model): data_pipeline = model.build_data_pipeline() class FlashServeModelComponent(ModelComponent): - def __init__(self, model): self.model = model self.model.eval() diff --git a/flash/core/serve/interfaces/http.py b/flash/core/serve/interfaces/http.py index 594dea1b7f6..861ad329372 100644 --- a/flash/core/serve/interfaces/http.py +++ b/flash/core/serve/interfaces/http.py @@ -35,6 +35,7 @@ try: from typing import ForwardRef + RequestModel = ForwardRef("RequestModel") ResponseModel = ForwardRef("ResponseModel") except ImportError: @@ -47,7 +48,6 @@ def _build_endpoint( dsk_composition: TaskComposition, response_model: ResponseModel, ) -> Callable[[RequestModel], ResponseModel]: - def endpoint_fn(body: request_model): session = body.session if 
body.session else str(uuid.uuid4()) _res = get( @@ -67,7 +67,6 @@ def endpoint_fn(body: request_model): def _build_meta(Body: RequestModel) -> Callable[[], Dict[str, Any]]: - def meta() -> Dict[str, Any]: nonlocal Body return Body.schema() @@ -76,7 +75,6 @@ def meta() -> Dict[str, Any]: def _build_alive_check() -> Callable[[], Alive]: - def alive() -> Alive: return Alive.construct(alive=True) @@ -89,7 +87,6 @@ def _build_visualization( *, no_optimization: bool = False, ): - def endpoint_visualization(request: Request): nonlocal dsk_composition, templates, no_optimization with BytesIO() as f: @@ -104,8 +101,8 @@ def endpoint_visualization(request: Request): def _build_dag_json( - components: Dict[str, 'ModelComponent'], - ep_proto: Optional['EndpointProtocol'], + components: Dict[str, "ModelComponent"], + ep_proto: Optional["EndpointProtocol"], *, show_connected_components: bool = True, ): @@ -122,7 +119,7 @@ def dag_json(): return dag_json -def setup_http_app(composition: 'Composition', debug: bool) -> 'FastAPI': +def setup_http_app(composition: "Composition", debug: bool) -> "FastAPI": from flash import __version__ app = FastAPI( @@ -163,11 +160,13 @@ def setup_http_app(composition: 'Composition', debug: bool) -> 'FastAPI': name="components JSON DAG", summary="JSON representation of component DAG", response_model=ComponentJSON, - )(_build_dag_json( - components=composition.components, - ep_proto=None, - show_connected_components=False, - )) + )( + _build_dag_json( + components=composition.components, + ep_proto=None, + show_connected_components=False, + ) + ) for ep_name, ep_proto in composition.endpoint_protocols.items(): dsk = build_composition( @@ -221,9 +220,11 @@ def setup_http_app(composition: 'Composition', debug: bool) -> 'FastAPI': tags=[ep_name], summary="JSON representatino of DAG", response_model=MergedJSON, - )(_build_dag_json( - components=composition.components, - ep_proto=ep_proto, - show_connected_components=True, - )) + )( + _build_dag_json( + components=composition.components, + ep_proto=ep_proto, + show_connected_components=True, + ) + ) return app diff --git a/flash/core/serve/interfaces/models.py b/flash/core/serve/interfaces/models.py index 2ffec172f6d..3b2503b8665 100644 --- a/flash/core/serve/interfaces/models.py +++ b/flash/core/serve/interfaces/models.py @@ -12,6 +12,7 @@ try: from typing import ForwardRef + RequestModel = ForwardRef("RequestModel") ResponseModel = ForwardRef("ResponseModel") except ImportError: @@ -34,7 +35,7 @@ class initializer. Component inputs & outputs (as defined in `@expose` object de returned as subclasses of pydantic ``BaseModel``. """ - def __init__(self, name: str, endpoint: 'Endpoint', components: Dict[str, 'ModelComponent']): + def __init__(self, name: str, endpoint: "Endpoint", components: Dict[str, "ModelComponent"]): self._name = name self._endpoint = endpoint self._component = components @@ -119,10 +120,7 @@ def request_model(self) -> RequestModel: RequestModel = create_model( f"{self.name.title()}_RequestModel", __module__=self.__class__.__module__, - **{ - "session": (Optional[str], None), - "payload": (payload_model, ...) - }, + **{"session": (Optional[str], None), "payload": (payload_model, ...)}, ) RequestModel.update_forward_refs() return RequestModel @@ -180,10 +178,7 @@ def response_model(self) -> ResponseModel: ResponseModel = create_model( f"{self.name.title()}_Response", __module__=self.__class__.__module__, - **{ - "session": (Optional[str], None), - "result": (results_model, ...) 
- }, + **{"session": (Optional[str], None), "result": (results_model, ...)}, ) ResponseModel.update_forward_refs() return ResponseModel diff --git a/flash/core/serve/server.py b/flash/core/serve/server.py index a48df4925a3..ced1cc5fc96 100644 --- a/flash/core/serve/server.py +++ b/flash/core/serve/server.py @@ -25,7 +25,7 @@ class ServerMixin: DEBUG: bool TESTING: bool - def http_app(self) -> 'FastAPI': + def http_app(self) -> "FastAPI": return setup_http_app(composition=self, debug=self.DEBUG) def serve(self, host: str = "127.0.0.1", port: int = 8000): diff --git a/flash/core/serve/types/label.py b/flash/core/serve/types/label.py index 28cb0b18d1b..67e7340ce09 100644 --- a/flash/core/serve/types/label.py +++ b/flash/core/serve/types/label.py @@ -29,8 +29,7 @@ def __post_init__(self): if self.classes is None: if self.path is None: raise ValueError( - "Must provide either classes as a list or " - "path to a text file that contains classes" + "Must provide either classes as a list or " "path to a text file that contains classes" ) with Path(self.path).open(mode="r") as f: self.classes = tuple([item.strip() for item in f.readlines()]) diff --git a/flash/core/serve/types/table.py b/flash/core/serve/types/table.py index 22e3e57e9a9..5b993e7c57d 100644 --- a/flash/core/serve/types/table.py +++ b/flash/core/serve/types/table.py @@ -65,8 +65,7 @@ def deserialize(self, features: Dict[Union[int, str], Dict[int, Any]]): df = pd.DataFrame.from_dict(features) if len(self.column_names) != len(df.columns) or not np.all(df.columns == self.column_names): raise RuntimeError( - f"Failed to validate column names. \nExpected: " - f"{self.column_names}\nReceived: {list(df.columns)}" + f"Failed to validate column names. \nExpected: " f"{self.column_names}\nReceived: {list(df.columns)}" ) # TODO: This strict type checking needs to be changed when numpy arrays are returned if df.values.dtype.name not in allowed_types: diff --git a/flash/core/serve/utils.py b/flash/core/serve/utils.py index e3ca91c569e..813d3846ca1 100644 --- a/flash/core/serve/utils.py +++ b/flash/core/serve/utils.py @@ -7,7 +7,7 @@ def fn_outputs_to_keyed_map(serialize_fn_out_keys, fn_output) -> Dict[str, Any]: - """"convert outputs of a function to a dict of `{result_name: values}` + """ "convert outputs of a function to a dict of `{result_name: values}` accepts function outputs which are sequence, dict, or object. """ diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 5cc2cdd4f71..e376e3316b3 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -72,7 +72,6 @@ def insert_env_defaults(self, *args, **kwargs): class Trainer(PlTrainer): - @_defaults_from_env_vars def __init__(self, *args, serve_sanity_check: bool = False, **kwargs): if flash._IS_TESTING: @@ -186,7 +185,8 @@ def _resolve_callbacks(self, model, strategy): if strategy is not None: rank_zero_warn( "The model contains a default finetune callback. The provided {strategy} will be overriden.\n" - " HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. ", UserWarning + " HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. 
", + UserWarning, ) callback = model_callback else: @@ -214,7 +214,7 @@ def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser: return add_argparse_args(PlTrainer, *args, **kwargs) @classmethod - def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer': + def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> "Trainer": """Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` which populates ``valid_kwargs`` from :class:`pytorch_lightning.Trainer`.""" # the lightning trainer implementation does not support subclasses. diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index add089816f7..7cfb341342f 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -28,7 +28,6 @@ def drop_kwargs(func): - @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -46,7 +45,6 @@ def wrapper(*args, **kwargs): def make_args_optional(cls, args: Set[str]): - @wraps(cls) def wrapper(*args, **kwargs): return cls(*args, **kwargs) @@ -79,11 +77,10 @@ def get_overlapping_args(func_a, func_b) -> Set[str]: class FlashCLI(LightningCLI): - def __init__( self, model_class: Type[pl.LightningModule], - datamodule_class: Type['flash.DataModule'], + datamodule_class: Type["flash.DataModule"], trainer_class: Type[pl.Trainer] = flash.Trainer, default_datamodule_builder: Optional[Callable] = None, additional_datamodule_builders: Optional[List[Callable]] = None, @@ -171,9 +168,7 @@ def add_subcommand_from_function(self, subcommands, function, function_name=None preprocess_function = class_from_function(drop_kwargs(self.local_datamodule_class.preprocess_cls)) subcommand.add_class_arguments(datamodule_function, fail_untyped=False) subcommand.add_class_arguments( - preprocess_function, - fail_untyped=False, - skip=get_overlapping_args(datamodule_function, preprocess_function) + preprocess_function, fail_untyped=False, skip=get_overlapping_args(datamodule_function, preprocess_function) ) subcommand_name = function_name or function.__name__ subcommands.add_subcommand(subcommand_name, subcommand) @@ -189,7 +184,7 @@ def instantiate_classes(self) -> None: if getattr(self.datamodule, datamodule_attribute, None) is not None: self.config["model"][datamodule_attribute] = getattr(self.datamodule, datamodule_attribute) self.config_init = self.parser.instantiate_classes(self.config) - self.model = self.config_init['model'] + self.model = self.config_init["model"] self.instantiate_trainer() def prepare_fit_kwargs(self): diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index eaf16a41e6b..a1375fca9b8 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -100,36 +100,40 @@ def _compare_version(package: str, op, version) -> bool: if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") -_TEXT_AVAILABLE = all([ - _TRANSFORMERS_AVAILABLE, - _ROUGE_SCORE_AVAILABLE, - _SENTENCEPIECE_AVAILABLE, - _DATASETS_AVAILABLE, -]) +_TEXT_AVAILABLE = all( + [ + _TRANSFORMERS_AVAILABLE, + _ROUGE_SCORE_AVAILABLE, + _SENTENCEPIECE_AVAILABLE, + _DATASETS_AVAILABLE, + ] +) _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE _VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE -_IMAGE_AVAILABLE = all([ - _TORCHVISION_AVAILABLE, - _TIMM_AVAILABLE, - _PIL_AVAILABLE, - _KORNIA_AVAILABLE, - _PYSTICHE_AVAILABLE, - _SEGMENTATION_MODELS_AVAILABLE, -]) +_IMAGE_AVAILABLE = all( + [ + 
_TORCHVISION_AVAILABLE, + _TIMM_AVAILABLE, + _PIL_AVAILABLE, + _KORNIA_AVAILABLE, + _PYSTICHE_AVAILABLE, + _SEGMENTATION_MODELS_AVAILABLE, + ] +) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE _AUDIO_AVAILABLE = all([_ASTEROID_AVAILABLE, _TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE]) _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE _EXTRAS_AVAILABLE = { - 'image': _IMAGE_AVAILABLE, - 'tabular': _TABULAR_AVAILABLE, - 'text': _TEXT_AVAILABLE, - 'video': _VIDEO_AVAILABLE, - 'pointcloud': _POINTCLOUD_AVAILABLE, - 'serve': _SERVE_AVAILABLE, - 'audio': _AUDIO_AVAILABLE, - 'graph': _GRAPH_AVAILABLE, + "image": _IMAGE_AVAILABLE, + "tabular": _TABULAR_AVAILABLE, + "text": _TEXT_AVAILABLE, + "video": _VIDEO_AVAILABLE, + "pointcloud": _POINTCLOUD_AVAILABLE, + "serve": _SERVE_AVAILABLE, + "audio": _AUDIO_AVAILABLE, + "graph": _GRAPH_AVAILABLE, } diff --git a/flash/core/utilities/lightning_cli.py b/flash/core/utilities/lightning_cli.py index 2a82eb9dd0c..1b5170b88fe 100644 --- a/flash/core/utilities/lightning_cli.py +++ b/flash/core/utilities/lightning_cli.py @@ -40,7 +40,7 @@ def __new__(cls, *args, **kwargs): return_type = inspect.signature(func).return_annotation if isinstance(return_type, str): - if return_type == 'DataModule': + if return_type == "DataModule": return_type = DataModule class ClassFromFunction(return_type, ClassFromFunctionBase): # type: ignore @@ -64,17 +64,22 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non """ super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs) self.add_argument( - '--config', action=ActionConfigFile, help='Path to a configuration file in json or yaml format.' + "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." ) self.callback_keys: List[str] = [] self.optimizers_and_lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} def add_lightning_class_args( self, - lightning_class: Union[Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], - Type[Trainer], Type[LightningModule], Type[LightningDataModule], Type[Callback]], + lightning_class: Union[ + Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], + Type[Trainer], + Type[LightningModule], + Type[LightningDataModule], + Type[Callback], + ], nested_key: str, - subclass_mode: bool = False + subclass_mode: bool = False, ) -> List[str]: """Adds arguments from a lightning class to a nested key of the parser. @@ -107,8 +112,8 @@ def add_lightning_class_args( def add_optimizer_args( self, optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]], - nested_key: str = 'optimizer', - link_to: str = 'AUTOMATIC', + nested_key: str = "optimizer", + link_to: str = "AUTOMATIC", ) -> None: """Adds arguments from an optimizer class to a nested key of the parser. 
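The `LightningArgumentParser` hunks above only reformat the `add_optimizer_args` / `add_lr_scheduler_args` defaults (`nested_key`, `link_to="AUTOMATIC"`). A hedged sketch of how these hooks are typically used, relying only on the signatures visible in this diff; constructing the parser without arguments is an assumption:

    from torch.optim import Adam
    from torch.optim.lr_scheduler import StepLR
    from flash.core.utilities.lightning_cli import LightningArgumentParser

    parser = LightningArgumentParser()      # parse_as_dict=True by default, per the hunk above
    parser.add_optimizer_args(Adam)         # exposed under the "optimizer" nested key
    parser.add_lr_scheduler_args(StepLR)    # exposed under the "lr_scheduler" nested key
    # With link_to left at "AUTOMATIC", LightningCLI later folds these entries into a
    # generated configure_optimizers method (see add_configure_optimizers_method_to_model below).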
@@ -122,9 +127,9 @@ def add_optimizer_args( else: assert issubclass(optimizer_class, Optimizer) kwargs = { - 'instantiate': False, - 'fail_untyped': False, - 'skip': {'params'}, + "instantiate": False, + "fail_untyped": False, + "skip": {"params"}, } if isinstance(optimizer_class, tuple): self.add_subclass_arguments(optimizer_class, nested_key, required=True, **kwargs) @@ -135,8 +140,8 @@ def add_optimizer_args( def add_lr_scheduler_args( self, lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]], - nested_key: str = 'lr_scheduler', - link_to: str = 'AUTOMATIC', + nested_key: str = "lr_scheduler", + link_to: str = "AUTOMATIC", ) -> None: """Adds arguments from a learning rate scheduler class to a nested key of the parser. @@ -150,9 +155,9 @@ def add_lr_scheduler_args( else: assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = { - 'instantiate': False, - 'fail_untyped': False, - 'skip': {'optimizer'}, + "instantiate": False, + "fail_untyped": False, + "skip": {"optimizer"}, } if isinstance(lr_scheduler_class, tuple): self.add_subclass_arguments(lr_scheduler_class, nested_key, required=True, **kwargs) @@ -188,10 +193,10 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st config_path = os.path.join(log_dir, self.config_filename) if not self.overwrite and os.path.isfile(config_path): raise RuntimeError( - f'{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting' - ' results of a previous run. You can delete the previous config file,' - ' set `LightningCLI(save_config_callback=None)` to disable config saving,' - ' or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file.' + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file." ) if trainer.is_global_zero: # save only on rank zero to avoid race conditions on DDP. @@ -200,7 +205,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st get_filesystem(log_dir).makedirs(log_dir, exist_ok=True) self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite) - def __reduce__(self) -> Tuple[Type['SaveConfigCallback'], Tuple, Dict]: + def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]: # `ArgumentParser` is un-pickleable. 
Drop it return ( self.__class__, @@ -217,17 +222,17 @@ def __init__( model_class: Union[Type[LightningModule], Callable[..., LightningModule]], datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, - save_config_filename: str = 'config.yaml', + save_config_filename: str = "config.yaml", save_config_overwrite: bool = False, trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, trainer_defaults: Dict[str, Any] = None, seed_everything_default: int = None, - description: str = 'pytorch-lightning trainer command line tool', - env_prefix: str = 'PL', + description: str = "pytorch-lightning trainer command line tool", + env_prefix: str = "PL", env_parse: bool = False, parser_kwargs: Dict[str, Any] = None, subclass_mode_model: bool = False, - subclass_mode_data: bool = False + subclass_mode_data: bool = False, ) -> None: """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are called / instantiated using a parsed configuration file and / or command line args and then runs @@ -285,15 +290,15 @@ def __init__( self.subclass_mode_model = subclass_mode_model self.subclass_mode_data = subclass_mode_data self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs - self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse}) + self.parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse}) self.init_parser() self.add_core_arguments_to_parser() self.add_arguments_to_parser(self.parser) self.link_optimizers_and_lr_schedulers() self.parse_arguments() - if self.config['seed_everything'] is not None: - seed_everything(self.config['seed_everything'], workers=True) + if self.config["seed_everything"] is not None: + seed_everything(self.config["seed_everything"], workers=True) self.before_instantiate_classes() self.instantiate_classes() self.add_configure_optimizers_method_to_model() @@ -309,17 +314,17 @@ def init_parser(self) -> None: def add_core_arguments_to_parser(self) -> None: """Adds arguments from the core classes to the parser.""" self.parser.add_argument( - '--seed_everything', + "--seed_everything", type=Optional[int], default=self.seed_everything_default, - help='Set to an int to run seed_everything with this value before classes instantiation', + help="Set to an int to run seed_everything with this value before classes instantiation", ) - self.parser.add_lightning_class_args(self.trainer_class, 'trainer') - trainer_defaults = {'trainer.' + k: v for k, v in self.trainer_defaults.items() if k != 'callbacks'} + self.parser.add_lightning_class_args(self.trainer_class, "trainer") + trainer_defaults = {"trainer." 
+ k: v for k, v in self.trainer_defaults.items() if k != "callbacks"} self.parser.set_defaults(trainer_defaults) - self.parser.add_lightning_class_args(self.model_class, 'model', subclass_mode=self.subclass_mode_model) + self.parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model) if self.datamodule_class is not None: - self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) + self.parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data) def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """Implement to add extra arguments to parser or link arguments. @@ -331,7 +336,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def link_optimizers_and_lr_schedulers(self) -> None: """Creates argument links for optimizers and lr_schedulers that specified a link_to.""" for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items(): - if link_to == 'AUTOMATIC': + if link_to == "AUTOMATIC": continue if isinstance(class_type, tuple): self.parser.link_arguments(key, link_to) @@ -349,27 +354,27 @@ def before_instantiate_classes(self) -> None: def instantiate_classes(self) -> None: """Instantiates the classes using settings from self.config.""" self.config_init = self.parser.instantiate_classes(self.config) - self.datamodule = self.config_init.get('data') - self.model = self.config_init['model'] + self.datamodule = self.config_init.get("data") + self.model = self.config_init["model"] self.instantiate_trainer() def instantiate_trainer(self) -> None: """Instantiates the trainer using self.config_init['trainer']""" - if self.config_init['trainer'].get('callbacks') is None: - self.config_init['trainer']['callbacks'] = [] + if self.config_init["trainer"].get("callbacks") is None: + self.config_init["trainer"]["callbacks"] = [] callbacks = [self.config_init[c] for c in self.parser.callback_keys] - self.config_init['trainer']['callbacks'].extend(callbacks) - if 'callbacks' in self.trainer_defaults: - if isinstance(self.trainer_defaults['callbacks'], list): - self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks']) + self.config_init["trainer"]["callbacks"].extend(callbacks) + if "callbacks" in self.trainer_defaults: + if isinstance(self.trainer_defaults["callbacks"], list): + self.config_init["trainer"]["callbacks"].extend(self.trainer_defaults["callbacks"]) else: - self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks']) - if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']: + self.config_init["trainer"]["callbacks"].append(self.trainer_defaults["callbacks"]) + if self.save_config_callback and not self.config_init["trainer"]["fast_dev_run"]: config_callback = self.save_config_callback( self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite ) - self.config_init['trainer']['callbacks'].append(config_callback) - self.trainer = self.trainer_class(**self.config_init['trainer']) + self.config_init["trainer"]["callbacks"].append(config_callback) + self.trainer = self.trainer_class(**self.config_init["trainer"]) def add_configure_optimizers_method_to_model(self) -> None: """Adds to the model an automatically generated configure_optimizers method. 
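The hunks that follow show `instantiate_class` and the `configure_optimizers` body that this method generates. A small sketch of the `class_path` / `init_args` convention they rely on, assuming `instantiate_class` is imported directly from the module patched here:

    import torch
    from flash.core.utilities.lightning_cli import instantiate_class

    model = torch.nn.Linear(4, 2)
    optimizer_init = {
        "class_path": "torch.optim.Adam",   # module path + class name, as _global_add_class_path builds it
        "init_args": {"lr": 1e-3},
    }
    # instantiate_class imports class_path and calls cls(*args, **init_args); the generated
    # configure_optimizers does exactly this with self.parameters() as the positional argument.
    optimizer = instantiate_class(model.parameters(), optimizer_init)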
@@ -382,8 +387,8 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: automatic = [] for key, (base_class, link_to) in self.parser.optimizers_and_lr_schedulers.items(): if not isinstance(base_class, tuple): - base_class = (base_class, ) - if link_to == 'AUTOMATIC' and any(issubclass(c, class_type) for c in base_class): + base_class = (base_class,) + if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class): automatic.append(key) return automatic @@ -402,7 +407,7 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: "#optimizers-and-learning-rate-schedulers" ) - if is_overridden('configure_optimizers', self.model): + if is_overridden("configure_optimizers", self.model): warnings.warn( f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by " f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`." @@ -420,7 +425,7 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) def configure_optimizers( - self: LightningModule + self: LightningModule, ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]: optimizer = instantiate_class(self.parameters(), optimizer_init) if not lr_scheduler_init: @@ -432,9 +437,9 @@ def configure_optimizers( def prepare_fit_kwargs(self) -> None: """Prepares fit_kwargs including datamodule using self.config_init['data'] if given.""" - self.fit_kwargs = {'model': self.model} + self.fit_kwargs = {"model": self.model} if self.datamodule is not None: - self.fit_kwargs['datamodule'] = self.datamodule + self.fit_kwargs["datamodule"] = self.datamodule def before_fit(self) -> None: """Implement to run some code before fit is started.""" @@ -449,13 +454,12 @@ def after_fit(self) -> None: def _global_add_class_path(class_type: Type, init_args: Dict[str, Any]) -> Dict[str, Any]: return { - 'class_path': class_type.__module__ + '.' + class_type.__name__, - 'init_args': init_args, + "class_path": class_type.__module__ + "." + class_type.__name__, + "init_args": init_args, } def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]: - def add_class_path(init_args: Dict[str, Any]) -> Dict[str, Any]: return _global_add_class_path(class_type, init_args) @@ -472,10 +476,10 @@ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) - Returns: The instantiated class object. """ - kwargs = init.get('init_args', {}) + kwargs = init.get("init_args", {}) if not isinstance(args, tuple): - args = (args, ) - class_module, class_name = init['class_path'].rsplit('.', 1) + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) module = __import__(class_module, fromlist=[class_name]) args_class = getattr(module, class_name) return args_class(*args, **kwargs) diff --git a/flash/core/utilities/url_error.py b/flash/core/utilities/url_error.py index cd1f772e284..83559131c96 100644 --- a/flash/core/utilities/url_error.py +++ b/flash/core/utilities/url_error.py @@ -18,7 +18,6 @@ def catch_url_error(fn): - @functools.wraps(fn) def wrapper(*args, pretrained=False, **kwargs): try: @@ -28,7 +27,8 @@ def wrapper(*args, pretrained=False, **kwargs): rank_zero_warn( "Failed to download pretrained weights for the selected backbone. The backbone has been created with" " `pretrained=False` instead. 
If you are loading from a local checkpoint, this warning can be safely" - " ignored.", UserWarning + " ignored.", + UserWarning, ) return result diff --git a/flash/graph/classification/cli.py b/flash/graph/classification/cli.py index 8d9e100695a..f79af259d8c 100644 --- a/flash/graph/classification/cli.py +++ b/flash/graph/classification/cli.py @@ -52,14 +52,14 @@ def graph_classification(): GraphClassificationData, default_datamodule_builder=from_tu_dataset, default_arguments={ - 'trainer.max_epochs': 3, + "trainer.max_epochs": 3, }, finetune=False, - datamodule_attributes={"num_classes", "num_features"} + datamodule_attributes={"num_classes", "num_features"}, ) cli.trainer.save_checkpoint("graph_classification.pt") -if __name__ == '__main__': +if __name__ == "__main__": graph_classification() diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index f49f8082c8b..cd5e3568f81 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -25,7 +25,6 @@ class GraphClassificationPreprocess(Preprocess): - @requires_extras("graph") def __init__( self, diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index 6fe1b618441..e4d96c2d92e 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -29,7 +29,6 @@ class GraphBlock(nn.Module): - def __init__(self, nc_input, nc_output, conv_cls, act=nn.ReLU(), **conv_kwargs): super().__init__() self.conv = conv_cls(nc_input, nc_output, **conv_kwargs) @@ -43,7 +42,6 @@ def forward(self, x, edge_index, edge_weight): class BaseGraphModel(nn.Module): - def __init__( self, num_features: int, diff --git a/flash/graph/data.py b/flash/graph/data.py index 1987852675a..a3d020bc367 100644 --- a/flash/graph/data.py +++ b/flash/graph/data.py @@ -24,7 +24,6 @@ class GraphDatasetDataSource(DatasetDataSource): - @requires_extras("graph") def load_data(self, data: Dataset, dataset: Any = None) -> Dataset: data = super().load_data(data, dataset=dataset) diff --git a/flash/image/backbones.py b/flash/image/backbones.py index d3bca51b976..82bb8dc8a6c 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -43,5 +43,5 @@ def _fn_resnet_fpn( fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), name=model_name, package="torchvision", - type="resnet-fpn" + type="resnet-fpn", ) diff --git a/flash/image/classification/backbones/resnet.py b/flash/image/classification/backbones/resnet.py index 27f150ee30d..ccbbe14d1b3 100644 --- a/flash/image/classification/backbones/resnet.py +++ b/flash/image/classification/backbones/resnet.py @@ -38,7 +38,7 @@ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, d padding=dilation, groups=groups, bias=False, - dilation=dilation + dilation=dilation, ) @@ -60,13 +60,13 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') + raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 @@ -116,12 +116,12 @@ def __init__( groups: int = 1, base_width: int = 64, 
dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups + width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) @@ -157,7 +157,6 @@ def forward(self, x: Tensor) -> Tensor: class ResNet(nn.Module): - def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], @@ -245,7 +244,7 @@ def _make_layer( planes: int, blocks: int, stride: int = 1, - dilate: bool = False + dilate: bool = False, ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None @@ -320,11 +319,11 @@ def _resnet( model_weights = None if pretrained_flag: - if 'supervised' not in weights_paths: - raise KeyError('Supervised pretrained weights not available for {0}'.format(model_name)) + if "supervised" not in weights_paths: + raise KeyError("Supervised pretrained weights not available for {0}".format(model_name)) model_weights = load_state_dict_from_url( - weights_paths['supervised'], map_location=torch.device('cpu') if device == -1 else torch.device(device) + weights_paths["supervised"], map_location=torch.device("cpu") if device == -1 else torch.device(device) ) # for supervised pretrained weights @@ -334,7 +333,7 @@ def _resnet( if not pretrained_flag and isinstance(pretrained, str): if pretrained in weights_paths: model_weights = load_state_dict_from_url( - weights_paths[pretrained], map_location=torch.device('cpu') if device == -1 else torch.device(device) + weights_paths[pretrained], map_location=torch.device("cpu") if device == -1 else torch.device(device) ) if "classy_state_dict" in model_weights.keys(): @@ -344,11 +343,10 @@ def _resnet( for (key, val) in model_weights.items() } else: - raise KeyError('Unrecognized state dict. Logic for loading the current state dict missing.') + raise KeyError("Unrecognized state dict. 
Logic for loading the current state dict missing.") else: raise KeyError( - f"Requested weights for {model_name} not available," - f" choose from one of {weights_paths.keys()}" + f"Requested weights for {model_name} not available," f" choose from one of {weights_paths.keys()}" ) if model_weights is not None: @@ -359,78 +357,65 @@ def _resnet( HTTPS_VISSL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/" RESNET50_WEIGHTS_PATHS = { - "supervised": 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + "supervised": "https://download.pytorch.org/models/resnet50-0676ba61.pth", "simclr": HTTPS_VISSL + "simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/" "model_final_checkpoint_phase799.torch", "swav": HTTPS_VISSL + "swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/" "model_final_checkpoint_phase799.torch", } RESNET50W2_WEIGHTS_PATHS = { - 'simclr': HTTPS_VISSL + 'simclr_rn50w2_1000ep_simclr_8node_resnet_16_07_20.e1e3bbf0/' - 'model_final_checkpoint_phase999.torch', - 'swav': HTTPS_VISSL + 'swav_rn50w2_in1k_bs32_16node_400ep_swav_8node_resnet_30_07_20.93563e51/' - 'model_final_checkpoint_phase399.torch', + "simclr": HTTPS_VISSL + "simclr_rn50w2_1000ep_simclr_8node_resnet_16_07_20.e1e3bbf0/" + "model_final_checkpoint_phase999.torch", + "swav": HTTPS_VISSL + "swav_rn50w2_in1k_bs32_16node_400ep_swav_8node_resnet_30_07_20.93563e51/" + "model_final_checkpoint_phase399.torch", } RESNET50W4_WEIGHTS_PATHS = { - 'simclr': HTTPS_VISSL + 'simclr_rn50w4_1000ep_bs32_16node_simclr_8node_resnet_28_07_20.9e20b0ae/' - 'model_final_checkpoint_phase999.torch', - 'swav': HTTPS_VISSL + 'swav_rn50w4_in1k_bs40_8node_400ep_swav_8node_resnet_30_07_20.1736135b/' - 'model_final_checkpoint_phase399.torch', + "simclr": HTTPS_VISSL + "simclr_rn50w4_1000ep_bs32_16node_simclr_8node_resnet_28_07_20.9e20b0ae/" + "model_final_checkpoint_phase999.torch", + "swav": HTTPS_VISSL + "swav_rn50w4_in1k_bs40_8node_400ep_swav_8node_resnet_30_07_20.1736135b/" + "model_final_checkpoint_phase399.torch", } RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet50w2", "resnet50w4"] RESNET_PARAMS = [ { - 'block': BasicBlock, - 'layers': [2, 2, 2, 2], - 'num_features': 512, - 'weights_paths': { - "supervised": 'https://download.pytorch.org/models/resnet18-f37072fd.pth' - } - }, - { - 'block': BasicBlock, - 'layers': [3, 4, 6, 3], - 'num_features': 512, - 'weights_paths': { - "supervised": 'https://download.pytorch.org/models/resnet34-b627a593.pth' - } + "block": BasicBlock, + "layers": [2, 2, 2, 2], + "num_features": 512, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet18-f37072fd.pth"}, }, { - 'block': Bottleneck, - 'layers': [3, 4, 6, 3], - 'num_features': 2048, - 'weights_paths': RESNET50_WEIGHTS_PATHS + "block": BasicBlock, + "layers": [3, 4, 6, 3], + "num_features": 512, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet34-b627a593.pth"}, }, + {"block": Bottleneck, "layers": [3, 4, 6, 3], "num_features": 2048, "weights_paths": RESNET50_WEIGHTS_PATHS}, { - 'block': Bottleneck, - 'layers': [3, 4, 23, 3], - 'num_features': 2048, - 'weights_paths': { - "supervised": 'https://download.pytorch.org/models/resnet101-63fe2227.pth' - } + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "num_features": 2048, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet101-63fe2227.pth"}, }, { - 'block': Bottleneck, - 'layers': [3, 8, 36, 3], - 'num_features': 2048, - 'weights_paths': { - "supervised": 
'https://download.pytorch.org/models/resnet152-394f9c45.pth' - } + "block": Bottleneck, + "layers": [3, 8, 36, 3], + "num_features": 2048, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet152-394f9c45.pth"}, }, { - 'block': Bottleneck, - 'layers': [3, 4, 6, 3], - 'widen': 2, - 'num_features': 4096, - 'weights_paths': RESNET50W2_WEIGHTS_PATHS + "block": Bottleneck, + "layers": [3, 4, 6, 3], + "widen": 2, + "num_features": 4096, + "weights_paths": RESNET50W2_WEIGHTS_PATHS, }, { - 'block': Bottleneck, - 'layers': [3, 4, 6, 3], - 'widen': 4, - 'num_features': 8192, - 'weights_paths': RESNET50W4_WEIGHTS_PATHS + "block": Bottleneck, + "layers": [3, 4, 6, 3], + "widen": 4, + "num_features": 8192, + "weights_paths": RESNET50W4_WEIGHTS_PATHS, }, ] @@ -443,5 +428,5 @@ def register_resnet_backbones(register: FlashRegistry): namespace="vision", package="multiple", type="resnet", - weights_paths=params['weights_paths'] # update + weights_paths=params["weights_paths"], # update ) diff --git a/flash/image/classification/backbones/torchvision.py b/flash/image/classification/backbones/torchvision.py index b4b24d2eba1..38e4afc2f34 100644 --- a/flash/image/classification/backbones/torchvision.py +++ b/flash/image/classification/backbones/torchvision.py @@ -60,7 +60,7 @@ def register_mobilenet_vgg_backbones(register: FlashRegistry): name=model_name, namespace="vision", package="torchvision", - type=_type + type=_type, ) @@ -72,7 +72,7 @@ def register_resnext_model(register: FlashRegistry): name=model_name, namespace="vision", package="torchvision", - type="resnext" + type="resnext", ) @@ -84,5 +84,5 @@ def register_densenet_backbones(register: FlashRegistry): name=model_name, namespace="vision", package="torchvision", - type="densenet" + type="densenet", ) diff --git a/flash/image/classification/backbones/transformers.py b/flash/image/classification/backbones/transformers.py index 2a72eae58eb..35ec17bbcc6 100644 --- a/flash/image/classification/backbones/transformers.py +++ b/flash/image/classification/backbones/transformers.py @@ -21,22 +21,22 @@ # https://arxiv.org/abs/2104.14294 from Mathilde Caron and al. 
(29 Apr 2021) # weights from https://github.com/facebookresearch/dino def dino_deits16(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits16') + backbone = torch.hub.load("facebookresearch/dino:main", "dino_deits16") return backbone, 384 def dino_deits8(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits8') + backbone = torch.hub.load("facebookresearch/dino:main", "dino_deits8") return backbone, 384 def dino_vitb16(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') + backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb16") return backbone, 768 def dino_vitb8(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8') + backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb8") return backbone, 768 diff --git a/flash/image/classification/cli.py b/flash/image/classification/cli.py index c3df8be1187..6804c909f8a 100644 --- a/flash/image/classification/cli.py +++ b/flash/image/classification/cli.py @@ -44,12 +44,13 @@ def from_movie_posters( """Downloads and loads the movie posters genre classification data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "./data") return ImageClassificationData.from_csv( - "Id", ["Action", "Romance", "Crime", "Thriller", "Adventure"], + "Id", + ["Action", "Romance", "Crime", "Thriller", "Adventure"], train_file="data/movie_posters/train/metadata.csv", val_file="data/movie_posters/val/metadata.csv", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs + **preprocess_kwargs, ) @@ -61,13 +62,13 @@ def image_classification(): default_datamodule_builder=from_hymenoptera, additional_datamodule_builders=[from_movie_posters], default_arguments={ - 'trainer.max_epochs': 3, + "trainer.max_epochs": 3, }, - datamodule_attributes={"num_classes", "multi_label"} + datamodule_attributes={"num_classes", "multi_label"}, ) cli.trainer.save_checkpoint("image_classification_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": image_classification() diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index afb2dff76bd..4bf01f47a32 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -62,7 +62,6 @@ class Image: class ImageClassificationDataFrameDataSource( DataSource[Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str]]] ): - @staticmethod def _resolve_file(root: str, file_id: str) -> str: if os.path.isabs(file_id): @@ -120,25 +119,30 @@ def load_data( label_to_class = {v: k for k, v in enumerate(labels)} data_frame = data_frame.apply(partial(self._resolve_target, label_to_class, target_keys), axis=1) - return [{ - DefaultDataKeys.INPUT: row[input_key], - DefaultDataKeys.TARGET: row[target_keys], - DefaultDataKeys.METADATA: dict(root=root), - } for _, row in data_frame.iterrows()] + return [ + { + DefaultDataKeys.INPUT: row[input_key], + DefaultDataKeys.TARGET: row[target_keys], + DefaultDataKeys.METADATA: dict(root=root), + } + for _, row in data_frame.iterrows() + ] else: - return [{ - DefaultDataKeys.INPUT: row[input_key], - DefaultDataKeys.METADATA: dict(root=root), - } for _, row in data_frame.iterrows()] + return [ + { + DefaultDataKeys.INPUT: row[input_key], + DefaultDataKeys.METADATA: dict(root=root), + } + for _, row in data_frame.iterrows() + ] def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - file = 
self._resolve_file(sample[DefaultDataKeys.METADATA]['root'], sample[DefaultDataKeys.INPUT]) + file = self._resolve_file(sample[DefaultDataKeys.METADATA]["root"], sample[DefaultDataKeys.INPUT]) sample[DefaultDataKeys.INPUT] = default_loader(file) return sample class ImageClassificationCSVDataSource(ImageClassificationDataFrameDataSource): - def load_data( self, data: Tuple[str, str, Union[str, List[str]], Optional[str]], @@ -152,7 +156,6 @@ def load_data( class ImageClassificationPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -226,7 +229,7 @@ def from_data_frame( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given pandas ``DataFrame`` objects. @@ -320,7 +323,7 @@ def from_csv( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given CSV files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` from the passed or constructed @@ -403,6 +406,7 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: class MatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib.""" + max_cols: int = 4 # maximum number of columns we accept block_viz_window: bool = True # parameter to allow user to block visualisation windows @@ -446,7 +450,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) # show image and set label as subplot title ax.imshow(_img) ax.set_title(str(_label)) - ax.axis('off') + ax.axis("off") plt.show(block=self.block_viz_window) def show_load_sample(self, samples: List[Any], running_stage: RunningStage): diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index d4b240818de..a12780a86e5 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -109,7 +109,9 @@ def __init__( self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) head = head(num_features, num_classes) if isinstance(head, FunctionType) else head - self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), ) + self.head = head or nn.Sequential( + nn.Linear(num_features, num_classes), + ) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -124,9 +126,9 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch[DefaultDataKeys.PREDS] = super().predict_step((batch[DefaultDataKeys.INPUT]), - batch_idx, - dataloader_idx=dataloader_idx) + batch[DefaultDataKeys.PREDS] = super().predict_step( + (batch[DefaultDataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + ) return batch def forward(self, x) -> torch.Tensor: diff --git a/flash/image/classification/transforms.py b/flash/image/classification/transforms.py index 945f1cabc55..3b5ba98a4cc 100644 --- a/flash/image/classification/transforms.py +++ b/flash/image/classification/transforms.py @@ -47,7 +47,7 @@ def default_transforms(image_size: 
Tuple[int, int]) -> Dict[str, Callable]: "per_batch_transform_on_device": ApplyToKeys( DefaultDataKeys.INPUT, K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ) + ), } return { "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)), diff --git a/flash/image/data.py b/flash/image/data.py index 4f5605efc5f..30a64fcb79f 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -45,7 +45,6 @@ class Image: class ImageDeserializer(Deserializer): - @requires_extras("image") def __init__(self): super().__init__() @@ -67,7 +66,6 @@ def example_input(self) -> str: class ImagePathsDataSource(PathsDataSource): - @requires_extras("image") def __init__(self): super().__init__(extensions=IMG_EXTENSIONS) @@ -85,7 +83,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class ImageTensorDataSource(TensorDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = to_pil_image(sample[DefaultDataKeys.INPUT]) sample[DefaultDataKeys.INPUT] = img @@ -95,7 +92,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class ImageNumpyDataSource(NumpyDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) sample[DefaultDataKeys.INPUT] = img @@ -105,7 +101,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class ImageFiftyOneDataSource(FiftyOneDataSource): - @staticmethod def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img_path = sample[DefaultDataKeys.INPUT] diff --git a/flash/image/detection/cli.py b/flash/image/detection/cli.py index f7245c8cfb1..8c2eb0c3d18 100644 --- a/flash/image/detection/cli.py +++ b/flash/image/detection/cli.py @@ -34,7 +34,7 @@ def from_coco_128( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs + **preprocess_kwargs, ) @@ -46,11 +46,11 @@ def object_detection(): default_datamodule_builder=from_coco_128, default_arguments={ "trainer.max_epochs": 3, - } + }, ) cli.trainer.save_checkpoint("object_detection_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": object_detection() diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index f7dbe03659c..d19ec4f2e3f 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -44,7 +44,6 @@ class COCODataSource(DataSource[Tuple[str, str]]): - @requires("pycocotools") def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: root, ann_file = data @@ -95,7 +94,7 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq image_id=img_id, area=areas, iscrowd=iscrowd, - ) + ), ) ) return data @@ -113,7 +112,6 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource): - def __init__(self, label_field: str = "ground_truth", iscrowd: str = "iscrowd"): super().__init__(label_field=label_field) self.iscrowd = iscrowd @@ -165,7 +163,7 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se image_id=img_id, area=output_areas, iscrowd=output_iscrowd, - ) + ), ) ) img_id += 1 @@ -197,7 +195,6 @@ def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): class ObjectDetectionPreprocess(Preprocess): - def 
__init__( self, train_transform: Optional[Dict[str, Callable]] = None, diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 0323d5e2bbf..320f64bbee5 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -87,7 +87,7 @@ def __init__( pretrained: bool = True, pretrained_backbone: bool = True, trainable_backbone_layers: int = 3, - anchor_generator: Optional[Type['AnchorGenerator']] = None, + anchor_generator: Optional[Type["AnchorGenerator"]] = None, loss=None, metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, optimizer: Type[Optimizer] = torch.optim.AdamW, @@ -99,8 +99,15 @@ def __init__( if model in _models: model = ObjectDetector.get_model( - model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, - anchor_generator, **kwargs + model, + num_classes, + backbone, + fpn, + pretrained, + pretrained_backbone, + trainable_backbone_layers, + anchor_generator, + **kwargs, ) else: ValueError(f"{model} is not supported yet.") @@ -143,7 +150,7 @@ def get_model( in_channels=model.backbone.out_channels, num_anchors=model.head.classification_head.num_anchors, num_classes=num_classes, - **kwargs + **kwargs, ) else: backbone_model, num_features = ObjectDetector.backbones.get(backbone)( @@ -153,9 +160,11 @@ def get_model( ) backbone_model.out_channels = num_features if anchor_generator is None: - anchor_generator = AnchorGenerator( - sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), ) - ) if not hasattr(backbone_model, "fpn") else None + anchor_generator = ( + AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)) + if not hasattr(backbone_model, "fpn") + else None + ) if model_name == "fasterrcnn": model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator) diff --git a/flash/image/detection/serialization.py b/flash/image/detection/serialization.py index 46a31abe4b0..b2f0bd09013 100644 --- a/flash/image/detection/serialization.py +++ b/flash/image/detection/serialization.py @@ -101,11 +101,13 @@ def serialize(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] else: label = str(int(label)) - detections.append(fo.Detection( - label=label, - bounding_box=box, - confidence=confidence, - )) + detections.append( + fo.Detection( + label=label, + bounding_box=box, + confidence=confidence, + ) + ) fo_predictions = fo.Detections(detections=detections) if self.return_filepath: filepath = sample[DefaultDataKeys.METADATA]["filepath"] diff --git a/flash/image/detection/transforms.py b/flash/image/detection/transforms.py index 3c1684feb55..5179f1f8a73 100644 --- a/flash/image/detection/transforms.py +++ b/flash/image/detection/transforms.py @@ -32,16 +32,16 @@ def default_transforms() -> Dict[str, Callable]: batch.""" return { "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys("input", torchvision.transforms.ToTensor()), ApplyToKeys( - 'target', + "target", nn.Sequential( - ApplyToKeys('boxes', torch.as_tensor), - ApplyToKeys('labels', torch.as_tensor), - ApplyToKeys('image_id', torch.as_tensor), - ApplyToKeys('area', torch.as_tensor), - ApplyToKeys('iscrowd', torch.as_tensor), - ) + ApplyToKeys("boxes", torch.as_tensor), + ApplyToKeys("labels", torch.as_tensor), + ApplyToKeys("image_id", torch.as_tensor), + ApplyToKeys("area", torch.as_tensor), + ApplyToKeys("iscrowd", torch.as_tensor), + ), ), ), "collate": collate, diff --git 
a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 76bf5337109..f5e2c0cca95 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -63,7 +63,7 @@ def __init__( optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, metrics: Union[Metric, Callable, Mapping, Sequence, None] = (Accuracy()), learning_rate: float = 1e-3, - pooling_fn: Callable = torch.max + pooling_fn: Callable = torch.max, ): super().__init__( model=None, @@ -71,7 +71,7 @@ def __init__( optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, - preprocess=ImageClassificationPreprocess() + preprocess=ImageClassificationPreprocess(), ) self.save_hyperparameters() @@ -89,7 +89,7 @@ def __init__( nn.Flatten(), nn.Linear(num_features, embedding_dim), ) - rank_zero_warn('Adding linear layer on top of backbone. Remember to finetune first before using!') + rank_zero_warn("Adding linear layer on top of backbone. Remember to finetune first before using!") def apply_pool(self, x): x = self.pooling_fn(x, dim=-1) @@ -126,5 +126,5 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) + batch = batch[DefaultDataKeys.INPUT] return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) diff --git a/flash/image/segmentation/cli.py b/flash/image/segmentation/cli.py index 6d01d043272..64cb0c3d938 100644 --- a/flash/image/segmentation/cli.py +++ b/flash/image/segmentation/cli.py @@ -30,7 +30,7 @@ def from_carla( """Downloads and loads the CARLA capture data set.""" download_data( "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", - "./data" + "./data", ) return SemanticSegmentationData.from_folders( train_folder="data/CameraRGB", @@ -39,7 +39,7 @@ def from_carla( batch_size=batch_size, num_workers=num_workers, num_classes=num_classes, - **preprocess_kwargs + **preprocess_kwargs, ) @@ -51,11 +51,11 @@ def semantic_segmentation(): default_datamodule_builder=from_carla, default_arguments={ "trainer.max_epochs": 3, - } + }, ) cli.trainer.save_checkpoint("semantic_segmentation_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": semantic_segmentation() diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index dea0c256939..30cc7207c7a 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -76,7 +76,6 @@ class Image: class SemanticSegmentationNumpyDataSource(NumpyDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() sample[DefaultDataKeys.INPUT] = img @@ -85,7 +84,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class SemanticSegmentationTensorDataSource(TensorDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = sample[DefaultDataKeys.INPUT].float() sample[DefaultDataKeys.INPUT] = img @@ -94,13 +92,13 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class SemanticSegmentationPathsDataSource(PathsDataSource): - @requires_extras("image") def __init__(self): super().__init__(IMG_EXTENSIONS) - def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], - dataset: BaseAutoDataset) -> Sequence[Mapping[str, Any]]: + 
def load_data( + self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], dataset: BaseAutoDataset + ) -> Sequence[Mapping[str, Any]]: input_data, target_data = data if self.isdir(input_data) and self.isdir(target_data): @@ -131,8 +129,8 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], data = filter( lambda sample: ( - has_file_allowed_extension(sample[0], self.extensions) and - has_file_allowed_extension(sample[1], self.extensions) + has_file_allowed_extension(sample[0], self.extensions) + and has_file_allowed_extension(sample[1], self.extensions) ), zip(input_data, target_data), ) @@ -176,7 +174,6 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: class SemanticSegmentationFiftyOneDataSource(FiftyOneDataSource): - @requires_extras("image") def __init__(self, label_field: str = "ground_truth"): super().__init__(label_field=label_field) @@ -223,7 +220,6 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: class SemanticSegmentationDeserializer(ImageDeserializer): - def deserialize(self, data: str) -> torch.Tensor: result = super().deserialize(data) result[DefaultDataKeys.INPUT] = self.to_tensor(result[DefaultDataKeys.INPUT]) @@ -232,7 +228,6 @@ def deserialize(self, data: str) -> torch.Tensor: class SemanticSegmentationPreprocess(Preprocess): - @requires_extras("image") def __init__( self, @@ -241,7 +236,7 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (128, 128), - deserializer: Optional['Deserializer'] = None, + deserializer: Optional["Deserializer"] = None, num_classes: int = None, labels_map: Dict[int, Tuple[int, int, int]] = None, **data_source_kwargs: Any, @@ -284,9 +279,10 @@ def __init__( def get_state_dict(self) -> Dict[str, Any]: return { - **self.transforms, "image_size": self.image_size, + **self.transforms, + "image_size": self.image_size, "num_classes": self.num_classes, - "labels_map": self.labels_map + "labels_map": self.labels_map, } @classmethod @@ -308,7 +304,7 @@ class SemanticSegmentationData(DataModule): @staticmethod def configure_data_fetcher( labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None - ) -> 'SegmentationMatplotlibVisualization': + ) -> "SegmentationMatplotlibVisualization": return SegmentationMatplotlibVisualization(labels_map=labels_map) def set_block_viz_window(self, value: bool) -> None: @@ -333,15 +329,16 @@ def from_data_source( batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": - if 'num_classes' not in preprocess_kwargs: + if "num_classes" not in preprocess_kwargs: raise MisconfigurationException("`num_classes` should be provided during instantiation.") num_classes = preprocess_kwargs["num_classes"] - labels_map = getattr(preprocess_kwargs, "labels_map", - None) or SegmentationLabels.create_random_labels_map(num_classes) + labels_map = getattr(preprocess_kwargs, "labels_map", None) or SegmentationLabels.create_random_labels_map( + num_classes + ) data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map) @@ -363,7 +360,7 @@ def from_data_source( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs + **preprocess_kwargs, ) if dm.train_dataset is not None: @@ -392,7 +389,7 @@ def from_folders( num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, **preprocess_kwargs, - ) -> 'DataModule': + ) 
-> "DataModule": """Creates a :class:`~flash.image.segmentation.data.SemanticSegmentationData` object from the given data folders and corresponding target folders. @@ -509,7 +506,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) img_vis = np.hstack((image_vis, label_vis)) # send to visualiser ax.imshow(img_vis) - ax.axis('off') + ax.axis("off") plt.show(block=self.block_viz_window) def show_load_sample(self, samples: List[Any], running_stage: RunningStage): diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 294c7f36d9d..bc7ff8cd010 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -23,8 +23,15 @@ import segmentation_models_pytorch as smp SMP_MODEL_CLASS = [ - smp.Unet, smp.UnetPlusPlus, smp.MAnet, smp.Linknet, smp.FPN, smp.PSPNet, smp.DeepLabV3, smp.DeepLabV3Plus, - smp.PAN + smp.Unet, + smp.UnetPlusPlus, + smp.MAnet, + smp.Linknet, + smp.FPN, + smp.PSPNet, + smp.DeepLabV3, + smp.DeepLabV3Plus, + smp.PAN, ] SMP_MODELS = {a.__name__.lower(): a for a in SMP_MODEL_CLASS} @@ -64,5 +71,5 @@ def _load_smp_head( partial(_load_smp_head, head=model_name), name=model_name, namespace="image/segmentation", - package="segmentation_models.pytorch" + package="segmentation_models.pytorch", ) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index e073e4ef094..771014bbb50 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -33,9 +33,8 @@ class SemanticSegmentationPostprocess(Postprocess): - def per_sample_transform(self, sample: Any) -> Any: - resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation='bilinear') + resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear") sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS])) sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT])) return super().per_sample_transform(sample) @@ -104,7 +103,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, serializer=serializer or SegmentationLabels(), - postprocess=postprocess or self.postprocess_cls() + postprocess=postprocess or self.postprocess_cls(), ) self.save_hyperparameters() @@ -138,7 +137,7 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch_input = (batch[DefaultDataKeys.INPUT]) + batch_input = batch[DefaultDataKeys.INPUT] batch[DefaultDataKeys.PREDS] = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx) return batch @@ -149,7 +148,7 @@ def forward(self, x) -> torch.Tensor: # In particular, torchvision segmentation models return the output logits # in the key `out`. 
if _isinstance(res, Dict[str, torch.Tensor]): - res = res['out'] + res = res["out"] return res diff --git a/flash/image/segmentation/serialization.py b/flash/image/segmentation/serialization.py index 8b219531044..8bc893fce32 100644 --- a/flash/image/segmentation/serialization.py +++ b/flash/image/segmentation/serialization.py @@ -70,7 +70,7 @@ def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, i H, W = img_labels.shape out = torch.empty(3, H, W, dtype=torch.uint8) for label_id, label_val in labels_map.items(): - mask = (img_labels == label_id) + mask = img_labels == label_id for i in range(3): out[i].masked_fill_(mask, label_val[i]) return out @@ -79,7 +79,7 @@ def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, i def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: labels_map: Dict[int, Tuple[int, int, int]] = {} for i in range(num_classes): - labels_map[i] = torch.randint(0, 255, (3, )) + labels_map[i] = torch.randint(0, 255, (3,)) return labels_map @requires("matplotlib") diff --git a/flash/image/segmentation/transforms.py b/flash/image/segmentation/transforms.py index 498d09032ff..53bd0a6314c 100644 --- a/flash/image/segmentation/transforms.py +++ b/flash/image/segmentation/transforms.py @@ -40,7 +40,7 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: "post_tensor_transform": nn.Sequential( ApplyToKeys( [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], - KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')), + KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation="nearest")), ), ), "collate": Compose([kornia_collate, ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)]), @@ -51,12 +51,13 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] """During training, we apply the default transforms with additional ``RandomHorizontalFlip`` and ``ColorJitter``.""" return merge_transforms( - default_transforms(image_size), { + default_transforms(image_size), + { "post_tensor_transform": nn.Sequential( ApplyToKeys( [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)), ), ), - } + }, ) diff --git a/flash/image/style_transfer/cli.py b/flash/image/style_transfer/cli.py index d8c553bd00c..0fec3470210 100644 --- a/flash/image/style_transfer/cli.py +++ b/flash/image/style_transfer/cli.py @@ -33,7 +33,7 @@ def from_coco_128( train_folder="data/coco128/images/train2017/", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs + **preprocess_kwargs, ) @@ -45,7 +45,7 @@ def style_transfer(): default_datamodule_builder=from_coco_128, default_arguments={ "trainer.max_epochs": 3, - "model.style_image": os.path.join(flash.ASSETS_ROOT, "starry_night.jpg") + "model.style_image": os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"), }, finetune=False, ) @@ -53,5 +53,5 @@ def style_transfer(): cli.trainer.save_checkpoint("style_transfer_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": style_transfer() diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 65a017ce4cb..f9f63c5905c 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -32,9 +32,9 @@ __all__ = ["StyleTransferPreprocess", "StyleTransferData"] -def _apply_to_input(default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], - DefaultDataKeys]) -> Callable[..., Dict[str, ApplyToKeys]]: - +def 
_apply_to_input( + default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], DefaultDataKeys] +) -> Callable[..., Dict[str, ApplyToKeys]]: @functools.wraps(default_transforms_fn) def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: default_transforms = default_transforms_fn(*args, **kwargs) @@ -47,7 +47,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: class StyleTransferPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Union[Dict[str, Callable]]] = None, @@ -119,7 +118,7 @@ def from_folders( predict_transform: Optional[Union[str, Dict]] = None, preprocess: Optional[Preprocess] = None, **kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": if any(param in kwargs and kwargs[param] is not None for param in ("val_folder", "val_transform")): raise_not_supported("validation") diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 95cf6fe3372..2908df52e6a 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -40,7 +40,6 @@ class ops: MultiLayerEncodingOperator = None class loss: - class PerceptualLoss: pass @@ -100,7 +99,7 @@ def __init__( model = pystiche.demo.transformer() if not isinstance(style_layers, (List, Tuple)): - style_layers = (style_layers, ) + style_layers = (style_layers,) perceptual_loss = self._get_perceptual_loss( backbone=backbone, @@ -134,7 +133,6 @@ def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.Enc # oversight: they normalize the representation twice by the number of channels. To be compatible with them, we # do the same here. class GramOperator(ops.GramOperator): - def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor: repr = super().enc_to_repr(enc) num_channels = repr.size()[1] diff --git a/flash/pointcloud/detection/cli.py b/flash/pointcloud/detection/cli.py index 0043a7232f8..01a4c329cef 100644 --- a/flash/pointcloud/detection/cli.py +++ b/flash/pointcloud/detection/cli.py @@ -32,7 +32,7 @@ def from_kitti( val_folder="data/KITTI_Tiny/Kitti/val", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs + **preprocess_kwargs, ) @@ -51,5 +51,5 @@ def pointcloud_detection(): cli.trainer.save_checkpoint("pointcloud_detection_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": pointcloud_detection() diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 4527eba22ba..b6a778db75f 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -21,7 +21,6 @@ class PointCloudObjectDetectionDataFormat: class PointCloudObjectDetectorDatasetDataSource(DataSource): - def __init__(self, **kwargs): super().__init__() @@ -39,13 +38,12 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: sample = dataset.dataset[index] return { - DefaultDataKeys.INPUT: sample['data'], + DefaultDataKeys.INPUT: sample["data"], DefaultDataKeys.METADATA: sample["attr"], } class PointCloudObjectDetectorPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -106,7 +104,7 @@ def from_folders( calibrations_folder_name: Optional[str] = "calibs", data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the :class:`~flash.core.data.data_source.DataSource` of name 
:attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` diff --git a/flash/pointcloud/detection/datasets.py b/flash/pointcloud/detection/datasets.py index 4860da13639..335f6997572 100644 --- a/flash/pointcloud/detection/datasets.py +++ b/flash/pointcloud/detection/datasets.py @@ -32,7 +32,7 @@ def kitti(dataset_path, download, **kwargs): "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_kitti.sh", # noqa E501 None, dataset_path, - name + name, ) return KITTI(download_path, **kwargs) diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index d1abee600a3..b17adb67bac 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -79,9 +79,9 @@ def __init__( metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 1e-2, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(), - lambda_loss_cls: float = 1., - lambda_loss_bbox: float = 1., - lambda_loss_dir: float = 1., + lambda_loss_cls: float = 1.0, + lambda_loss_bbox: float = 1.0, + lambda_loss_dir: float = 1.0, ): super().__init__( @@ -120,8 +120,9 @@ def __init__( def compute_loss(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: losses = losses["loss"] return ( - self.hparams.lambda_loss_cls * losses["loss_cls"] + self.hparams.lambda_loss_bbox * losses["loss_bbox"] + - self.hparams.lambda_loss_dir * losses["loss_dir"] + self.hparams.lambda_loss_cls * losses["loss_cls"] + + self.hparams.lambda_loss_bbox * losses["loss_bbox"] + + self.hparams.lambda_loss_dir * losses["loss_dir"] ) def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]): @@ -143,7 +144,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return { DefaultDataKeys.INPUT: getattr(batch, "point", None), DefaultDataKeys.PREDS: boxes, - DefaultDataKeys.METADATA: [a["name"] for a in batch.attr] + DefaultDataKeys.METADATA: [a["name"] for a in batch.attr], } def forward(self, x) -> torch.Tensor: diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py index 5578955d8ac..065a0c51b9b 100644 --- a/flash/pointcloud/detection/open3d_ml/app.py +++ b/flash/pointcloud/detection/open3d_ml/app.py @@ -26,7 +26,6 @@ from open3d.visualization import gui class Visualizer(Visualizer): - def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768): """Visualize a dataset. 
@@ -125,14 +124,13 @@ def get_data(self, index): def get_attr(self, index): return self.dataset[index]["attr"] - def get_split(self, *_) -> 'VizDataset': + def get_split(self, *_) -> "VizDataset": return self def __len__(self) -> int: return len(self.dataset) class App: - def __init__(self, datamodule: DataModule): self.datamodule = datamodule self._enabled = not flash._IS_TESTING @@ -145,7 +143,7 @@ def show_train_dataset(self, indices=None): if self._enabled: dataset = self.get_dataset("train") viz = Visualizer() - viz.visualize_dataset(dataset, 'all', indices=indices) + viz.visualize_dataset(dataset, "all", indices=indices) def show_predictions(self, predictions): if self._enabled: @@ -167,5 +165,5 @@ def show_predictions(self, predictions): viz.visualize([data], bounding_boxes=bounding_box) -def launch_app(datamodule: DataModule) -> 'App': +def launch_app(datamodule: DataModule) -> "App": return App(datamodule) diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py index 622971299e1..b8b88b1d892 100644 --- a/flash/pointcloud/detection/open3d_ml/backbones.py +++ b/flash/pointcloud/detection/open3d_ml/backbones.py @@ -35,7 +35,6 @@ class ObjectDetectBatchCollator(ObjectDetectBatch): - def __init__(self, batches): self.num_batches = len(batches) super().__init__(batches) @@ -56,11 +55,11 @@ def register_open_3d_ml(register: FlashRegistry): def get_collate_fn(model) -> Callable: batcher_name = model.cfg.batcher - if batcher_name == 'DefaultBatcher': + if batcher_name == "DefaultBatcher": batcher = DefaultBatcher() - elif batcher_name == 'ConcatBatcher': + elif batcher_name == "ConcatBatcher": batcher = ConcatBatcher(torch, model.__class__.__name__) - elif batcher_name == 'ObjectDetectBatchCollator': + elif batcher_name == "ObjectDetectBatchCollator": return ObjectDetectBatchCollator return batcher.collate_fn @@ -70,7 +69,9 @@ def pointpillars_kitti(*args, **kwargs) -> PointPillars: cfg.model.device = "cpu" model = PointPillars(**cfg.model) weight_url = os.path.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth") - model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict'], ) + model.load_state_dict( + pl_load(weight_url, map_location="cpu")["model_state_dict"], + ) model.cfg.batcher = "ObjectDetectBatchCollator" return model, 384, get_collate_fn(model) diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py index 234344e6f25..f4c8a640bd8 100644 --- a/flash/pointcloud/detection/open3d_ml/data_sources.py +++ b/flash/pointcloud/detection/open3d_ml/data_sources.py @@ -36,7 +36,6 @@ class BasePointCloudObjectDetectorLoader: class KITTIPointCloudObjectDetectorLoader(BasePointCloudObjectDetectorLoader): - def __init__( self, image_size: tuple = (375, 1242), @@ -56,7 +55,7 @@ def load_meta(self, root_dir, dataset: Optional[BaseAutoDataset]): if not exists(meta_file): raise MisconfigurationException(f"The {root_dir} should contain a `meta.yaml` file about the classes.") - with open(meta_file, 'r') as f: + with open(meta_file, "r") as f: self.meta = yaml.safe_load(f) if "label_to_names" not in self.meta: @@ -94,11 +93,10 @@ def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]): dataset.path_list = scan_paths - return [{ - "scan_path": scan_path, - "label_path": label_path, - "calibration_path": calibration_path - } for scan_path, label_path, calibration_path, in zip(scan_paths, label_paths, calibration_paths)] + return [ + 
{"scan_path": scan_path, "label_path": label_path, "calibration_path": calibration_path} + for scan_path, label_path, calibration_path, in zip(scan_paths, label_paths, calibration_paths) + ] def load_sample( self, sample: Dict[str, str], dataset: Optional[BaseAutoDataset] = None, has_label: bool = True @@ -109,7 +107,7 @@ def load_sample( if has_label: label = KITTI.read_label(sample["label_path"], calib) - reduced_pc = DataProcessing.remove_outside_points(pc, calib['world_cam'], calib['cam_img'], self.image_size) + reduced_pc = DataProcessing.remove_outside_points(pc, calib["world_cam"], calib["cam_img"], self.image_size) attr = { "name": basename(sample["scan_path"]), @@ -120,12 +118,12 @@ def load_sample( } data = { - 'point': reduced_pc, - 'full_point': pc, - 'feat': None, - 'calib': calib, - 'bounding_boxes': label if has_label else None, - 'attr': attr + "point": reduced_pc, + "full_point": pc, + "feat": None, + "calib": calib, + "bounding_boxes": label if has_label else None, + "attr": attr, } return data, attr @@ -154,7 +152,6 @@ def predict_load_sample(self, data, dataset: Optional[BaseAutoDataset] = None): class PointCloudObjectDetectorFoldersDataSource(DataSource): - def __init__( self, data_format: Optional[BaseDataFormat] = None, diff --git a/flash/pointcloud/segmentation/cli.py b/flash/pointcloud/segmentation/cli.py index 7bb11d604e5..57d1125f9b2 100644 --- a/flash/pointcloud/segmentation/cli.py +++ b/flash/pointcloud/segmentation/cli.py @@ -29,10 +29,10 @@ def from_kitti( download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/") return PointCloudSegmentationData.from_folders( train_folder="data/SemanticKittiTiny/train", - val_folder='data/SemanticKittiTiny/val', + val_folder="data/SemanticKittiTiny/val", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs + **preprocess_kwargs, ) @@ -52,5 +52,5 @@ def pointcloud_segmentation(): cli.trainer.save_checkpoint("pointcloud_segmentation_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": pointcloud_segmentation() diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 193b5838e2d..92cd2cdbc22 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -8,7 +8,6 @@ class PointCloudSegmentationDatasetDataSource(DataSource): - def load_data( self, data: Any, @@ -25,13 +24,12 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: sample = dataset.dataset[index] return { - DefaultDataKeys.INPUT: sample['data'], + DefaultDataKeys.INPUT: sample["data"], DefaultDataKeys.METADATA: sample["attr"], } class PointCloudSegmentationFoldersDataSource(DataSource): - @requires_extras("pointcloud") def load_data( self, @@ -49,13 +47,12 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: sample = dataset.dataset[index] return { - DefaultDataKeys.INPUT: sample['data'], + DefaultDataKeys.INPUT: sample["data"], DefaultDataKeys.METADATA: sample["attr"], } class PointCloudSegmentationPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, diff --git a/flash/pointcloud/segmentation/datasets.py b/flash/pointcloud/segmentation/datasets.py index 19182d816fd..ff792282a41 100644 --- a/flash/pointcloud/segmentation/datasets.py +++ b/flash/pointcloud/segmentation/datasets.py @@ -34,7 +34,9 @@ def lyft(dataset_path): name = "Lyft" executor( 
"https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_lyft.sh", - "https://github.com/intel-isl/Open3D-ML/blob/master/scripts/preprocess_lyft.py", dataset_path, name + "https://github.com/intel-isl/Open3D-ML/blob/master/scripts/preprocess_lyft.py", + dataset_path, + name, ) return Lyft(os.path.join(dataset_path, name)) @@ -51,7 +53,7 @@ def semantickitti(dataset_path, download, **kwargs): "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_semantickitti.sh", # noqa E501 None, dataset_path, - name + name, ) return SemanticKITTI(os.path.join(dataset_path, name), **kwargs) diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index b6de290b25a..7098aea98e7 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -39,7 +39,6 @@ class PointCloudSegmentationFinetuning(BaseFinetuning): - def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1): super().__init__() self.num_layers = num_layers @@ -47,7 +46,7 @@ def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: i self.unfreeze_epoch = unfreeze_epoch def freeze_before_training(self, pl_module: LightningModule) -> None: - self.freeze(modules=list(pl_module.backbone.children())[:-self.num_layers], train_bn=self.train_bn) + self.freeze(modules=list(pl_module.backbone.children())[: -self.num_layers], train_bn=self.train_bn) def finetune_function( self, @@ -59,7 +58,7 @@ def finetune_function( if epoch != self.unfreeze_epoch: return self.unfreeze_and_add_param_group( - modules=list(pl_module.backbone.children())[-self.num_layers:], + modules=list(pl_module.backbone.children())[-self.num_layers :], optimizer=optimizer, train_bn=self.train_bn, ) @@ -112,6 +111,7 @@ def __init__( serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudSegmentationSerializer(), ): import flash + if metrics is None: metrics = IoU(num_classes=num_classes) @@ -168,9 +168,9 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT]) - batch[DefaultDataKeys.TARGET] = batch[DefaultDataKeys.INPUT]['labels'] + batch[DefaultDataKeys.TARGET] = batch[DefaultDataKeys.INPUT]["labels"] # drop sub-sampled pointclouds - batch[DefaultDataKeys.INPUT] = batch[DefaultDataKeys.INPUT]['xyz'][0] + batch[DefaultDataKeys.INPUT] = batch[DefaultDataKeys.INPUT]["xyz"][0] return batch def forward(self, x) -> torch.Tensor: diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py index f525ef64c92..b1145c53b57 100644 --- a/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/flash/pointcloud/segmentation/open3d_ml/app.py @@ -29,7 +29,6 @@ class Visualizer(Open3dVisualizer): - def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768): """Visualize a dataset. 
@@ -61,7 +60,6 @@ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768 class App: - def __init__(self, datamodule: DataModule): self.datamodule = datamodule self._enabled = True # not flash._IS_TESTING @@ -77,7 +75,7 @@ def show_train_dataset(self, indices=None): if self._enabled: dataset = self.get_dataset("train") viz = Visualizer() - viz.visualize_dataset(dataset, 'all', indices=indices) + viz.visualize_dataset(dataset, "all", indices=indices) def show_predictions(self, predictions): if self._enabled: @@ -86,12 +84,14 @@ def show_predictions(self, predictions): predictions_visualizations = [] for pred in predictions: - predictions_visualizations.append({ - "points": torch.stack(pred[DefaultDataKeys.INPUT]), - "labels": torch.stack(pred[DefaultDataKeys.TARGET]), - "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1, - "name": pred[DefaultDataKeys.METADATA]["name"], - }) + predictions_visualizations.append( + { + "points": torch.stack(pred[DefaultDataKeys.INPUT]), + "labels": torch.stack(pred[DefaultDataKeys.TARGET]), + "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1, + "name": pred[DefaultDataKeys.METADATA]["name"], + } + ) viz = Visualizer() lut = LabelLUT() @@ -103,5 +103,5 @@ def show_predictions(self, predictions): viz.visualize(predictions_visualizations) -def launch_app(datamodule: DataModule) -> 'App': +def launch_app(datamodule: DataModule) -> "App": return App(datamodule) diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/flash/pointcloud/segmentation/open3d_ml/backbones.py index aec3aa01235..abf1226b686 100644 --- a/flash/pointcloud/segmentation/open3d_ml/backbones.py +++ b/flash/pointcloud/segmentation/open3d_ml/backbones.py @@ -34,9 +34,9 @@ def register_open_3d_ml(register: FlashRegistry): def get_collate_fn(model) -> Callable: batcher_name = model.cfg.batcher - if batcher_name == 'DefaultBatcher': + if batcher_name == "DefaultBatcher": batcher = DefaultBatcher() - elif batcher_name == 'ConcatBatcher': + elif batcher_name == "ConcatBatcher": batcher = ConcatBatcher(torch, model.__class__.__name__) else: batcher = None @@ -50,7 +50,7 @@ def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet: weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_area5_202010091333utc.pth") else: weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_202010091238.pth") - model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict']) + model.load_state_dict(pl_load(weight_url, map_location="cpu")["model_state_dict"]) return model, 32, get_collate_fn(model) @register @@ -58,8 +58,9 @@ def randlanet_toronto3d(*args, **kwargs) -> RandLANet: cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml")) model = RandLANet(**cfg.model) model.load_state_dict( - pl_load(os.path.join(ROOT_URL, "randlanet_toronto3d_202010091306utc.pth"), - map_location='cpu')['model_state_dict'], + pl_load(os.path.join(ROOT_URL, "randlanet_toronto3d_202010091306utc.pth"), map_location="cpu")[ + "model_state_dict" + ], ) return model, 32, get_collate_fn(model) @@ -68,8 +69,9 @@ def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet: cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml")) model = RandLANet(**cfg.model) model.load_state_dict( - pl_load(os.path.join(ROOT_URL, "randlanet_semantickitti_202009090354utc.pth"), - map_location='cpu')['model_state_dict'], + pl_load(os.path.join(ROOT_URL, 
"randlanet_semantickitti_202009090354utc.pth"), map_location="cpu")[ + "model_state_dict" + ], ) return model, 32, get_collate_fn(model) diff --git a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py index 73a3344dcd8..983e6e8c9d8 100644 --- a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py +++ b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -28,16 +28,15 @@ class SequencesDataset(Dataset): - def __init__( self, data, - cache_dir='./logs/cache', + cache_dir="./logs/cache", use_cache=False, num_points=65536, ignored_label_inds=[0], predicting=False, - **kwargs + **kwargs, ): super().__init__() @@ -78,13 +77,13 @@ def load_meta(self, root_dir): f"The {root_dir} should contain a `meta.yaml` file about the pointcloud sequences." ) - with open(meta_file, 'r') as f: + with open(meta_file, "r") as f: self.meta = yaml.safe_load(f) self.label_to_names = self.get_label_to_names() self.num_classes = len(self.label_to_names) - with open(meta_file, 'r') as f: + with open(meta_file, "r") as f: self.meta = yaml.safe_load(f) remap_dict_val = self.meta["learning_map"] @@ -138,7 +137,7 @@ def get_label_to_names(self): def __getitem__(self, index): data = self.get_data(index) - data['attr'] = self.get_attr(index) + data["attr"] = self.get_attr(index) return data def get_data(self, idx): @@ -147,21 +146,21 @@ def get_data(self, idx): dir, file = split(pc_path) if self.predicting: - label_path = join(dir, file[:-4] + '.label') + label_path = join(dir, file[:-4] + ".label") else: - label_path = join(dir, '../labels', file[:-4] + '.label') + label_path = join(dir, "../labels", file[:-4] + ".label") if not exists(label_path): labels = np.zeros(np.shape(points)[0], dtype=np.int32) - if self.split not in ['test', 'all']: - raise FileNotFoundError(f' Label file {label_path} not found') + if self.split not in ["test", "all"]: + raise FileNotFoundError(f" Label file {label_path} not found") else: labels = DataProcessing.load_label_kitti(label_path, self.remap_lut_val).astype(np.int32) data = { - 'point': points[:, 0:3], - 'feat': None, - 'label': labels, + "point": points[:, 0:3], + "feat": None, + "label": labels, } return data @@ -170,10 +169,10 @@ def get_attr(self, idx): pc_path = self.path_list[idx] dir, file = split(pc_path) _, seq = split(split(dir)[0]) - name = '{}_{}'.format(seq, file[:-4]) + name = "{}_{}".format(seq, file[:-4]) pc_path = str(pc_path) - attr = {'idx': idx, 'name': name, 'path': pc_path, 'split': self.split} + attr = {"idx": idx, "name": name, "path": pc_path, "split": self.split} return attr def __len__(self): diff --git a/flash/setup_tools.py b/flash/setup_tools.py index b609bd70322..6bba0c335e9 100644 --- a/flash/setup_tools.py +++ b/flash/setup_tools.py @@ -19,17 +19,17 @@ _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) -def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_chars: str = '#@') -> List[str]: - with open(os.path.join(path_dir, file_name), 'r') as file: +def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_chars: str = "#@") -> List[str]: + with open(os.path.join(path_dir, file_name), "r") as file: lines = [ln.strip() for ln in file.readlines()] reqs = [] for ln in lines: # filer all comments found = [ln.index(ch) for ch in comment_chars if ch in ln] if found: - ln = ln[:min(found)].strip() + ln = ln[: min(found)].strip() # skip directly installed dependencies - if ln.startswith('http') or 
ln.startswith('git'): + if ln.startswith("http") or ln.startswith("git"): continue if ln: # if requirement is not empty reqs.append(ln) @@ -46,7 +46,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: text = open(path_readme, encoding="utf-8").read() # drop images from readme - text = text.replace('![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)', '') + text = text.replace("![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)", "") # https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png github_source_url = os.path.join(homepage, "raw", ver) @@ -55,17 +55,17 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") # readthedocs badge - text = text.replace('badge/?version=stable', f'badge/?version={ver}') - text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{ver}') + text = text.replace("badge/?version=stable", f"badge/?version={ver}") + text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{ver}") # codecov badge - text = text.replace('/branch/master/graph/badge.svg', f'/release/{ver}/graph/badge.svg') + text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg") # replace github badges for release ones - text = text.replace('badge.svg?branch=master&event=push', f'badge.svg?tag={ver}') + text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}") - skip_begin = r'' - skip_end = r'' + skip_begin = r"" + skip_end = r"" # todo: wrap content as commented description - text = re.sub(rf"{skip_begin}.+?{skip_end}", '', text, flags=re.IGNORECASE + re.DOTALL) + text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL) # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png # github_release_url = os.path.join(homepage, "releases", "download", ver) diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py index cfaba9f1365..63eff2458fd 100644 --- a/flash/tabular/classification/cli.py +++ b/flash/tabular/classification/cli.py @@ -55,5 +55,5 @@ def tabular_classification(): cli.trainer.save_checkpoint("tabular_classification_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": tabular_classification() diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index b600f4e8951..b01e99e4f6f 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -71,7 +71,7 @@ def __init__( cat_idxs=list(range(len(embedding_sizes))), cat_dims=list(cat_dims), cat_emb_dim=list(cat_emb_dim), - **tabnet_kwargs + **tabnet_kwargs, ) super().__init__( @@ -108,11 +108,11 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) + batch = batch[DefaultDataKeys.INPUT] return self(batch) @classmethod - def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier': + def from_data(cls, datamodule, **kwargs) -> "TabularClassifier": model = cls(datamodule.num_features, datamodule.num_classes, datamodule.embedding_sizes, **kwargs) return model diff --git 
a/flash/tabular/data.py b/flash/tabular/data.py index 006c32362b5..da36d726cec 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -39,7 +39,6 @@ class TabularDataFrameDataSource(DataSource[DataFrame]): - def __init__( self, cat_cols: Optional[List[str]] = None, @@ -73,8 +72,9 @@ def common_load_data( ): # impute_data # compute train dataset stats - dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, - self.target_codes) + dfs = _pre_transform( + [df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, self.target_codes + ) df = dfs[0] @@ -91,10 +91,9 @@ def common_load_data( def load_data(self, data: DataFrame, dataset: Optional[Any] = None): df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) - return [{ - DefaultDataKeys.INPUT: (c, n), - DefaultDataKeys.TARGET: t - } for c, n, t in zip(cat_vars, num_vars, target)] + return [ + {DefaultDataKeys.INPUT: (c, n), DefaultDataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, target) + ] def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) @@ -102,7 +101,6 @@ def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): class TabularCSVDataSource(TabularDataFrameDataSource): - def load_data(self, data: str, dataset: Optional[Any] = None): return super().load_data(pd.read_csv(data), dataset=dataset) @@ -111,7 +109,6 @@ def predict_load_data(self, data: str, dataset: Optional[Any] = None): class TabularDeserializer(Deserializer): - def __init__( self, cat_cols: Optional[List[str]] = None, @@ -122,7 +119,7 @@ def __init__( codes: Optional[Dict[str, Any]] = None, target_codes: Optional[Dict[str, Any]] = None, classes: Optional[List[str]] = None, - is_regression: bool = True + is_regression: bool = True, ): super().__init__() self.cat_cols = cat_cols @@ -137,8 +134,9 @@ def __init__( def deserialize(self, data: str) -> Any: df = pd.read_csv(StringIO(data)) - df = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, - self.target_codes)[0] + df = _pre_transform( + [df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, self.target_codes + )[0] cat_vars = _to_cat_vars_numpy(df, self.cat_cols) num_vars = _to_num_vars_numpy(df, self.num_cols) @@ -159,7 +157,6 @@ def example_input(self) -> str: class TabularPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -175,7 +172,7 @@ def __init__( target_codes: Optional[Dict[str, Any]] = None, classes: Optional[List[str]] = None, is_regression: bool = True, - deserializer: Optional[Deserializer] = None + deserializer: Optional[Deserializer] = None, ): classes = classes or [] @@ -203,7 +200,8 @@ def __init__( ), }, default_data_source=DefaultDataSources.CSV, - deserializer=deserializer or TabularDeserializer( + deserializer=deserializer + or TabularDeserializer( cat_cols=cat_cols, num_cols=num_cols, target_col=target_col, @@ -212,8 +210,8 @@ def __init__( codes=codes, target_codes=target_codes, classes=classes, - is_regression=is_regression - ) + is_regression=is_regression, + ), ) def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: @@ -231,12 +229,11 @@ def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: } @classmethod - def 
load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Preprocess": return cls(**state_dict) class TabularPostprocess(Postprocess): - def uncollate(self, batch: Any) -> Any: return batch @@ -277,13 +274,13 @@ def embedding_sizes(self) -> list: # The following "formula" provides a general rule of thumb about the number of embedding dimensions: # embedding_dimensions = number_of_categories**0.25 num_classes = [len(self.codes[cat]) for cat in self.cat_cols] - emb_dims = [max(int(n**0.25), 16) for n in num_classes] + emb_dims = [max(int(n ** 0.25), 16) for n in num_classes] return list(zip(num_classes, emb_dims)) @staticmethod def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]): if cat_cols is None and num_cols is None: - raise RuntimeError('Both `cat_cols` and `num_cols` are None!') + raise RuntimeError("Both `cat_cols` and `num_cols` are None!") return cat_cols or [], num_cols or [] @@ -455,7 +452,7 @@ def from_csv( batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. Args: diff --git a/flash/template/classification/backbones.py b/flash/template/classification/backbones.py index b36f6a398e2..7ea84130033 100644 --- a/flash/template/classification/backbones.py +++ b/flash/template/classification/backbones.py @@ -21,21 +21,27 @@ @TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification") def load_mlp_128(num_features, **_): """A simple MLP backbone with 128 hidden units.""" - return nn.Sequential( - nn.Linear(num_features, 128), - nn.ReLU(True), - nn.BatchNorm1d(128), - ), 128 + return ( + nn.Sequential( + nn.Linear(num_features, 128), + nn.ReLU(True), + nn.BatchNorm1d(128), + ), + 128, + ) @TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification") def load_mlp_128_256(num_features, **_): """An two layer MLP backbone with 128 and 256 hidden units respectively.""" - return nn.Sequential( - nn.Linear(num_features, 128), - nn.ReLU(True), - nn.BatchNorm1d(128), - nn.Linear(128, 256), - nn.ReLU(True), - nn.BatchNorm1d(256), - ), 256 + return ( + nn.Sequential( + nn.Linear(num_features, 128), + nn.ReLU(True), + nn.BatchNorm1d(128), + nn.Linear(128, 256), + nn.ReLU(True), + nn.BatchNorm1d(256), + ), + 256, + ) diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index b38e5814285..e330fafdc86 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -114,7 +114,7 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key from the input and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" - batch = (batch[DefaultDataKeys.INPUT]) + batch = batch[DefaultDataKeys.INPUT] return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) def forward(self, x) -> torch.Tensor: diff --git a/flash/text/classification/cli.py b/flash/text/classification/cli.py index 2418d80eccd..42499bb53fe 100644 --- a/flash/text/classification/cli.py +++ b/flash/text/classification/cli.py @@ -71,11 +71,11 @@ def text_classification(): default_arguments={ "trainer.max_epochs": 3, }, - 
datamodule_attributes={"num_classes", "multi_label", "backbone"} + datamodule_attributes={"num_classes", "multi_label", "backbone"}, ) cli.trainer.save_checkpoint("text_classification_model.pt") -if __name__ == '__main__': +if __name__ == "__main__": text_classification() diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 8d362e616c4..ebb202624e5 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -31,7 +31,6 @@ class TextDeserializer(Deserializer): - @requires_extras("text") def __init__(self, backbone: str, max_length: int, use_fast: bool = True): super().__init__() @@ -57,7 +56,6 @@ def __setstate__(self, state): class TextDataSource(DataSource): - @requires_extras("text") def __init__(self, backbone: str, max_length: int = 128): super().__init__() @@ -92,7 +90,6 @@ def __setstate__(self, state): class TextFileDataSource(TextDataSource): - def __init__(self, filetype: str, backbone: str, max_length: int = 128): super().__init__(backbone, max_length=max_length) @@ -110,7 +107,7 @@ def load_data( dataset: Optional[Any] = None, columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), ) -> Union[Sequence[Mapping[str, Any]]]: - if self.filetype == 'json': + if self.filetype == "json": file, input, target, field = data else: file, input, target = data @@ -123,22 +120,25 @@ def load_data( # FLASH_TESTING is set in the CI to run faster. if flash._IS_TESTING and not torch.cuda.is_available(): try: - if self.filetype == 'json' and field is not None: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'], - field=field)[0] - }) + if self.filetype == "json" and field is not None: + dataset_dict = DatasetDict( + { + stage: load_dataset( + self.filetype, data_files=data_files, split=[f"{stage}[:20]"], field=field + )[0] + } + ) else: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + dataset_dict = DatasetDict( + {stage: load_dataset(self.filetype, data_files=data_files, split=[f"{stage}[:20]"])[0]} + ) except Exception: - if self.filetype == 'json' and field is not None: + if self.filetype == "json" and field is not None: dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) else: dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - if self.filetype == 'json' and field is not None: + if self.filetype == "json" and field is not None: dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) else: dataset_dict = load_dataset(self.filetype, data_files=data_files) @@ -188,7 +188,6 @@ def __setstate__(self, state): class TextCSVDataSource(TextFileDataSource): - def __init__(self, backbone: str, max_length: int = 128): super().__init__("csv", backbone, max_length=max_length) @@ -203,7 +202,6 @@ def __setstate__(self, state): class TextJSONDataSource(TextFileDataSource): - def __init__(self, backbone: str, max_length: int = 128): super().__init__("json", backbone, max_length=max_length) @@ -218,7 +216,6 @@ def __setstate__(self, state): class TextSentencesDataSource(TextDataSource): - def __init__(self, backbone: str, max_length: int = 128): super().__init__(backbone, max_length=max_length) @@ -230,7 +227,12 @@ def load_data( if isinstance(data, str): data = [data] - return [self._tokenize_fn(s, ) for s in data] + return [ + self._tokenize_fn( + s, + ) + for s in data + ] def __getstate__(self): # TODO: Find 
out why this is being pickled state = self.__dict__.copy() @@ -243,7 +245,6 @@ def __setstate__(self, state): class TextClassificationPreprocess(Preprocess): - @requires_extras("text") def __init__( self, @@ -297,7 +298,6 @@ def collate(self, samples: Any) -> Tensor: class TextClassificationPostprocess(Postprocess): - def per_batch_transform(self, batch: Any) -> Any: if isinstance(batch, SequenceClassifierOutput): batch = batch.logits diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 6cf7ac785eb..60404a5b66d 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -33,14 +33,13 @@ class Seq2SeqDataSource(DataSource): - @requires_extras("text") def __init__( self, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): super().__init__() @@ -82,23 +81,22 @@ def __setstate__(self, state): class Seq2SeqFileDataSource(Seq2SeqDataSource): - def __init__( self, filetype: str, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', + padding: Union[str, bool] = "max_length", ): super().__init__(backbone, max_source_length, max_target_length, padding) self.filetype = filetype - def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': + def load_data(self, data: Any, columns: List[str] = None) -> "datasets.Dataset": if columns is None: columns = ["input_ids", "attention_mask", "labels"] - if self.filetype == 'json': + if self.filetype == "json": file, input, target, field = data else: file, input, target = data @@ -109,22 +107,25 @@ def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': # FLASH_TESTING is set in the CI to run faster. 
if flash._IS_TESTING: try: - if self.filetype == 'json' and field is not None: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'], - field=field)[0] - }) + if self.filetype == "json" and field is not None: + dataset_dict = DatasetDict( + { + stage: load_dataset( + self.filetype, data_files=data_files, split=[f"{stage}[:20]"], field=field + )[0] + } + ) else: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + dataset_dict = DatasetDict( + {stage: load_dataset(self.filetype, data_files=data_files, split=[f"{stage}[:20]"])[0]} + ) except Exception: - if self.filetype == 'json' and field is not None: + if self.filetype == "json" and field is not None: dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) else: dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - if self.filetype == 'json' and field is not None: + if self.filetype == "json" and field is not None: dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) else: dataset_dict = load_dataset(self.filetype, data_files=data_files) @@ -133,7 +134,7 @@ def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': dataset_dict.set_format(columns=columns) return dataset_dict[stage] - def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: + def predict_load_data(self, data: Any) -> Union["datasets.Dataset", List[Dict[str, torch.Tensor]]]: return self.load_data(data, columns=["input_ids", "attention_mask"]) def __getstate__(self): # TODO: Find out why this is being pickled @@ -147,13 +148,12 @@ def __setstate__(self, state): class Seq2SeqCSVDataSource(Seq2SeqFileDataSource): - def __init__( self, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', + padding: Union[str, bool] = "max_length", ): super().__init__( "csv", @@ -174,13 +174,12 @@ def __setstate__(self, state): class Seq2SeqJSONDataSource(Seq2SeqFileDataSource): - def __init__( self, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', + padding: Union[str, bool] = "max_length", ): super().__init__( "json", @@ -201,7 +200,6 @@ def __setstate__(self, state): class Seq2SeqSentencesDataSource(Seq2SeqDataSource): - def load_data( self, data: Union[str, List[str]], @@ -232,7 +230,6 @@ class Seq2SeqBackboneState(ProcessState): class Seq2SeqPreprocess(Preprocess): - @requires_extras("text") def __init__( self, @@ -243,7 +240,7 @@ def __init__( backbone: str = "sshleifer/tiny-mbart", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): self.backbone = backbone self.max_target_length = max_target_length @@ -276,7 +273,7 @@ def __init__( ), }, default_data_source="sentences", - deserializer=TextDeserializer(backbone, max_source_length) + deserializer=TextDeserializer(backbone, max_source_length), ) self.set_state(Seq2SeqBackboneState(self.backbone)) @@ -300,7 +297,6 @@ def collate(self, samples: Any) -> Tensor: class Seq2SeqPostprocess(Postprocess): - @requires_extras("text") def __init__(self): super().__init__() diff --git a/flash/text/seq2seq/core/metrics.py b/flash/text/seq2seq/core/metrics.py index 47992f59748..621bb23d748 100644 --- a/flash/text/seq2seq/core/metrics.py +++ 
b/flash/text/seq2seq/core/metrics.py @@ -49,7 +49,7 @@ def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: for i in range(1, n_gram + 1): for j in range(len(ngram_input_list) - i + 1): - ngram_key = tuple(ngram_input_list[j:(i + j)]) + ngram_key = tuple(ngram_input_list[j : (i + j)]) ngram_counter[ngram_key] += 1 return ngram_counter @@ -94,12 +94,11 @@ def compute(self): else: precision_scores = self.numerator / self.denominator - log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, - device=self.r.device) * torch.log(precision_scores) - geometric_mean = torch.exp(torch.sum(log_precision_scores)) - brevity_penalty = ( - tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len)) + log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, device=self.r.device) * torch.log( + precision_scores ) + geometric_mean = torch.exp(torch.sum(log_precision_scores)) + brevity_penalty = tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len)) bleu = brevity_penalty * geometric_mean return bleu diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 3d93ef9a958..283abaf1204 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -40,7 +40,7 @@ def _pad_tensors_to_max_len(model_cfg, tensor, max_length): ) padded_tensor = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device) - padded_tensor[:, :tensor.shape[-1]] = tensor + padded_tensor[:, : tensor.shape[-1]] = tensor return padded_tensor @@ -60,7 +60,7 @@ class Seq2SeqTask(Task): def __init__( self, - backbone: str = 't5-small', + backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -83,7 +83,7 @@ def forward(self, x: Any) -> Any: max_length = self.val_target_max_length if self.val_target_max_length else self.model.config.max_length num_beams = self.num_beams if self.num_beams else self.model.config.num_beams generated_tokens = self.model.generate( - input_ids=x['input_ids'], attention_mask=x['attention_mask'], max_length=max_length, num_beams=num_beams + input_ids=x["input_ids"], attention_mask=x["attention_mask"], max_length=max_length, num_beams=num_beams ) # in case the batch is shorter than max length, the output should be padded if generated_tokens.shape[-1] < max_length: @@ -125,7 +125,7 @@ def _initialize_model_specific_parameters(self): self.model.config.update(pars) @property - def tokenizer(self) -> 'PreTrainedTokenizerBase': + def tokenizer(self) -> "PreTrainedTokenizerBase": return self.data_pipeline.data_source.tokenizer def tokenize_labels(self, labels: Tensor) -> List[str]: diff --git a/flash/text/seq2seq/core/utils.py b/flash/text/seq2seq/core/utils.py index 02647f7264c..e48248754c8 100644 --- a/flash/text/seq2seq/core/utils.py +++ b/flash/text/seq2seq/core/utils.py @@ -16,8 +16,9 @@ from pytorch_lightning.utilities import _module_available nltk = None -if _module_available('nltk'): +if _module_available("nltk"): import nltk + nltk.download("punkt", quiet=True) diff --git a/flash/text/seq2seq/question_answering/data.py b/flash/text/seq2seq/question_answering/data.py index b3d42662a54..ad3f028f20e 100644 --- a/flash/text/seq2seq/question_answering/data.py +++ b/flash/text/seq2seq/question_answering/data.py @@ -17,7 +17,6 @@ class 
QuestionAnsweringPreprocess(Seq2SeqPreprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -27,7 +26,7 @@ def __init__( backbone: str = "t5-small", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): super().__init__( train_transform=train_transform, diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py index 51d030a7cea..2db3a6d6aa3 100644 --- a/flash/text/seq2seq/question_answering/model.py +++ b/flash/text/seq2seq/question_answering/model.py @@ -54,7 +54,7 @@ def __init__( val_target_max_length: Optional[int] = None, num_beams: Optional[int] = 4, use_stemmer: bool = True, - rouge_newline_sep: bool = True + rouge_newline_sep: bool = True, ): self.save_hyperparameters() super().__init__( @@ -64,7 +64,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, val_target_max_length=val_target_max_length, - num_beams=num_beams + num_beams=num_beams, ) self.rouge = RougeMetric( rouge_newline_sep=rouge_newline_sep, diff --git a/flash/text/seq2seq/summarization/cli.py b/flash/text/seq2seq/summarization/cli.py index b63b41958a6..666dd87f40a 100644 --- a/flash/text/seq2seq/summarization/cli.py +++ b/flash/text/seq2seq/summarization/cli.py @@ -49,11 +49,11 @@ def summarization(): default_arguments={ "trainer.max_epochs": 3, "model.backbone": "sshleifer/distilbart-xsum-1-1", - } + }, ) cli.trainer.save_checkpoint("summarization_model_xsum.pt") -if __name__ == '__main__': +if __name__ == "__main__": summarization() diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index c2a29df52ce..3797d97f924 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -17,7 +17,6 @@ class SummarizationPreprocess(Seq2SeqPreprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -27,7 +26,7 @@ def __init__( backbone: str = "sshleifer/distilbart-xsum-1-1", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): super().__init__( train_transform=train_transform, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index d810bd1d22d..af7820b10e5 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -54,7 +54,7 @@ def __init__( val_target_max_length: Optional[int] = None, num_beams: Optional[int] = 4, use_stemmer: bool = True, - rouge_newline_sep: bool = True + rouge_newline_sep: bool = True, ): self.save_hyperparameters() super().__init__( @@ -64,7 +64,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, val_target_max_length=val_target_max_length, - num_beams=num_beams + num_beams=num_beams, ) self.rouge = RougeMetric( rouge_newline_sep=rouge_newline_sep, diff --git a/flash/text/seq2seq/translation/cli.py b/flash/text/seq2seq/translation/cli.py index 8e9865431f2..1609cb4de04 100644 --- a/flash/text/seq2seq/translation/cli.py +++ b/flash/text/seq2seq/translation/cli.py @@ -49,11 +49,11 @@ def translation(): default_arguments={ "trainer.max_epochs": 3, "model.backbone": "Helsinki-NLP/opus-mt-en-ro", - } + }, ) cli.trainer.save_checkpoint("translation_model_en_ro.pt") -if __name__ == '__main__': +if __name__ == "__main__": translation() diff --git a/flash/text/seq2seq/translation/data.py 
b/flash/text/seq2seq/translation/data.py index 0b9e7a3ce70..5485be1003b 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -17,7 +17,6 @@ class TranslationPreprocess(Seq2SeqPreprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -27,7 +26,7 @@ def __init__( backbone: str = "t5-small", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): super().__init__( train_transform=train_transform, diff --git a/flash/video/classification/cli.py b/flash/video/classification/cli.py index 44af93fc60c..840386506be 100644 --- a/flash/video/classification/cli.py +++ b/flash/video/classification/cli.py @@ -51,11 +51,11 @@ def video_classification(): default_datamodule_builder=from_kinetics, default_arguments={ "trainer.max_epochs": 3, - } + }, ) cli.trainer.save_checkpoint("video_classification.pt") -if __name__ == '__main__': +if __name__ == "__main__": video_classification() diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index b062d31bac6..90c6351dd97 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -55,10 +55,9 @@ class BaseVideoClassification(object): - def __init__( self, - clip_sampler: 'ClipSampler', + clip_sampler: "ClipSampler", video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", @@ -68,12 +67,12 @@ def __init__( self.decode_audio = decode_audio self.decoder = decoder - def load_data(self, data: str, dataset: Optional[Any] = None) -> 'LabeledVideoDataset': + def load_data(self, data: str, dataset: Optional[Any] = None) -> "LabeledVideoDataset": ds = self._make_encoded_video_dataset(data) if self.training: label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels} self.set_state(LabelsState(label_to_class_mapping)) - dataset.num_classes = len(np.unique([s[1]['label'] for s in ds._labeled_videos])) + dataset.num_classes = len(np.unique([s[1]["label"] for s in ds._labeled_videos])) return ds def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -110,20 +109,17 @@ def _encoded_video_to_dict(self, video, annotation: Optional[Dict[str, Any]] = N "video_index": 0, "clip_index": clip_index, "aug_index": aug_index, - **({ - "audio": audio_samples - } if audio_samples is not None else {}), + **({"audio": audio_samples} if audio_samples is not None else {}), } - def _make_encoded_video_dataset(self, data) -> 'LabeledVideoDataset': + def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset": raise NotImplementedError("Subclass must implement _make_encoded_video_dataset()") class VideoClassificationPathsDataSource(BaseVideoClassification, PathsDataSource): - def __init__( self, - clip_sampler: 'ClipSampler', + clip_sampler: "ClipSampler", video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", @@ -139,7 +135,7 @@ def __init__( extensions=("mp4", "avi"), ) - def _make_encoded_video_dataset(self, data) -> 'LabeledVideoDataset': + def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset": ds: LabeledVideoDataset = labeled_video_dataset( pathlib.Path(data), self.clip_sampler, @@ -154,10 +150,9 @@ class VideoClassificationFiftyOneDataSource( BaseVideoClassification, FiftyOneDataSource, ): - def __init__( self, - clip_sampler: 'ClipSampler', + clip_sampler: 
"ClipSampler", video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", @@ -178,7 +173,7 @@ def __init__( def label_cls(self): return fol.Classification - def _make_encoded_video_dataset(self, data: SampleCollection) -> 'LabeledVideoDataset': + def _make_encoded_video_dataset(self, data: SampleCollection) -> "LabeledVideoDataset": classes = self._get_classes(data) label_to_class_mapping = dict(enumerate(classes)) class_to_label_mapping = {c: lab for lab, c in label_to_class_mapping.items()} @@ -199,14 +194,13 @@ def _make_encoded_video_dataset(self, data: SampleCollection) -> 'LabeledVideoDa class VideoClassificationPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - clip_sampler: Union[str, 'ClipSampler'] = "random", + clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, @@ -275,7 +269,7 @@ def get_state_dict(self) -> Dict[str, Any]: } @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> 'VideoClassificationPreprocess': + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> "VideoClassificationPreprocess": return cls(**state_dict) def default_transforms(self) -> Dict[str, Callable]: @@ -290,22 +284,26 @@ def default_transforms(self) -> Dict[str, Callable]: ] return { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([UniformTemporalSubsample(8)] + post_tensor_transform), - ), - ]), - "per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), + "post_tensor_transform": Compose( + [ + ApplyTransformToKey( + key="video", + transform=Compose([UniformTemporalSubsample(8)] + post_tensor_transform), + ), + ] + ), + "per_batch_transform_on_device": Compose( + [ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + data_format="BCTHW", + same_on_frame=False, + ), + ), + ] + ), } diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 483e4f8e93b..e6b3b77cf93 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -36,6 +36,7 @@ if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.models import hub + for fn_name in dir(hub): if "__" not in fn_name: fn = getattr(hub, fn_name) @@ -44,7 +45,6 @@ class VideoClassifierFinetuning(BaseFinetuning): - def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1): super().__init__() self.num_layers = num_layers @@ -52,7 +52,7 @@ def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: i self.unfreeze_epoch = unfreeze_epoch def freeze_before_training(self, pl_module: LightningModule) -> None: - self.freeze(modules=list(pl_module.backbone.children())[:-self.num_layers], train_bn=self.train_bn) + self.freeze(modules=list(pl_module.backbone.children())[: -self.num_layers], train_bn=self.train_bn) def finetune_function( self, @@ -64,7 +64,7 @@ def finetune_function( 
if epoch != self.unfreeze_epoch: return self.unfreeze_and_add_param_group( - modules=list(pl_module.backbone.children())[-self.num_layers:], + modules=list(pl_module.backbone.children())[-self.num_layers :], optimizer=optimizer, train_bn=self.train_bn, ) @@ -110,7 +110,7 @@ def __init__( optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, - serializer=serializer or Labels() + serializer=serializer or Labels(), ) self.save_hyperparameters() diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py index b8f0f8a312d..9cd53e45845 100644 --- a/flash_examples/audio_classification.py +++ b/flash_examples/audio_classification.py @@ -34,11 +34,13 @@ trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) # 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c -predictions = model.predict([ - "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg", - "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg", - "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg", -]) +predictions = model.predict( + [ + "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg", + "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg", + "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py index 2ab29f65269..837cf8afa8f 100644 --- a/flash_examples/custom_task.py +++ b/flash_examples/custom_task.py @@ -35,7 +35,6 @@ class RegressionTask(flash.Task): - def __init__(self, num_inputs, learning_rate=0.2, metrics=None): # what kind of model do we want? model = nn.Linear(num_inputs, 1) @@ -85,7 +84,6 @@ def forward(self, x): class NumpyDataSource(DataSource[Tuple[ND, ND]]): - def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]: if self.training: dataset.num_inputs = data[0].shape[1] @@ -97,7 +95,6 @@ def predict_load_data(data: ND) -> List[Dict[str, Any]]: class NumpyPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -163,13 +160,15 @@ class NumpyDataModule(flash.DataModule): trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False) trainer.fit(model, datamodule=datamodule) -predict_data = np.array([ - [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], - [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], - [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], - [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], - [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], -]) +predict_data = np.array( + [ + [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], + [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], + [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], + [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], + [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], + ] +) predictions = model.predict(predict_data) print(predictions) diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py index a675938c57b..97780a4b8c9 100644 --- 
a/flash_examples/image_classification.py +++ b/flash_examples/image_classification.py @@ -31,11 +31,13 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Predict what's on a few images! ants or bees? -predictions = model.predict([ - "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", - "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", - "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", -]) +predictions = model.predict( + [ + "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", + "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", + "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/image_classification_multi_label.py b/flash_examples/image_classification_multi_label.py index 307b8fe7ced..82d5e488a63 100644 --- a/flash_examples/image_classification_multi_label.py +++ b/flash_examples/image_classification_multi_label.py @@ -36,11 +36,13 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Predict the genre of a few movies! -predictions = model.predict([ - "data/movie_posters/predict/tt0085318.jpg", - "data/movie_posters/predict/tt0089461.jpg", - "data/movie_posters/predict/tt0097179.jpg", -]) +predictions = model.predict( + [ + "data/movie_posters/predict/tt0085318.jpg", + "data/movie_posters/predict/tt0089461.jpg", + "data/movie_posters/predict/tt0097179.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 118bdc5c67a..9e65aab098e 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -33,11 +33,13 @@ trainer.finetune(model, datamodule=datamodule) # 4. Detect objects in a few images! -predictions = model.predict([ - "data/coco128/images/train2017/000000000625.jpg", - "data/coco128/images/train2017/000000000626.jpg", - "data/coco128/images/train2017/000000000629.jpg", -]) +predictions = model.predict( + [ + "data/coco128/images/train2017/000000000625.jpg", + "data/coco128/images/train2017/000000000626.jpg", + "data/coco128/images/train2017/000000000629.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/pointcloud_detection.py b/flash_examples/pointcloud_detection.py index 4b4cc55d1f0..7c65735bd42 100644 --- a/flash_examples/pointcloud_detection.py +++ b/flash_examples/pointcloud_detection.py @@ -32,10 +32,12 @@ trainer.fit(model, datamodule) # 4. Predict what's within a few PointClouds? -predictions = model.predict([ - "data/KITTI_Tiny/Kitti/predict/scans/000000.bin", - "data/KITTI_Tiny/Kitti/predict/scans/000001.bin", -]) +predictions = model.predict( + [ + "data/KITTI_Tiny/Kitti/predict/scans/000000.bin", + "data/KITTI_Tiny/Kitti/predict/scans/000001.bin", + ] +) # 5. Save the model! trainer.save_checkpoint("pointcloud_detection_model.pt") diff --git a/flash_examples/pointcloud_segmentation.py b/flash_examples/pointcloud_segmentation.py index f316cc91086..95ba45fcc66 100644 --- a/flash_examples/pointcloud_segmentation.py +++ b/flash_examples/pointcloud_segmentation.py @@ -21,7 +21,7 @@ datamodule = PointCloudSegmentationData.from_folders( train_folder="data/SemanticKittiTiny/train", - val_folder='data/SemanticKittiTiny/val', + val_folder="data/SemanticKittiTiny/val", ) # 2. Build the task @@ -32,10 +32,12 @@ trainer.fit(model, datamodule) # 4. Predict what's within a few PointClouds? 
-predictions = model.predict([ - "data/SemanticKittiTiny/predict/000000.bin", - "data/SemanticKittiTiny/predict/000001.bin", -]) +predictions = model.predict( + [ + "data/SemanticKittiTiny/predict/000000.bin", + "data/SemanticKittiTiny/predict/000001.bin", + ] +) # 5. Save the model! trainer.save_checkpoint("pointcloud_segmentation_model.pt") diff --git a/flash_examples/semantic_segmentation.py b/flash_examples/semantic_segmentation.py index 65bb56b89d5..7b3b21421b3 100644 --- a/flash_examples/semantic_segmentation.py +++ b/flash_examples/semantic_segmentation.py @@ -20,7 +20,7 @@ # More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge download_data( "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", - "./data" + "./data", ) datamodule = SemanticSegmentationData.from_folders( @@ -43,11 +43,13 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Segment a few images! -predictions = model.predict([ - "data/CameraRGB/F61-1.png", - "data/CameraRGB/F62-1.png", - "data/CameraRGB/F63-1.png", -]) +predictions = model.predict( + [ + "data/CameraRGB/F61-1.png", + "data/CameraRGB/F62-1.png", + "data/CameraRGB/F63-1.png", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/serve/generic/boston_prediction/inference_server.py b/flash_examples/serve/generic/boston_prediction/inference_server.py index 1e1d958e9f3..acd1735ae93 100644 --- a/flash_examples/serve/generic/boston_prediction/inference_server.py +++ b/flash_examples/serve/generic/boston_prediction/inference_server.py @@ -35,7 +35,6 @@ class PricePrediction(ModelComponent): - def __init__(self, model): # skipcq: PYL-W0621 self.model = model diff --git a/flash_examples/serve/generic/detection/inference.py b/flash_examples/serve/generic/detection/inference.py index 0971fb380c7..813359a6dc0 100644 --- a/flash_examples/serve/generic/detection/inference.py +++ b/flash_examples/serve/generic/detection/inference.py @@ -18,16 +18,12 @@ class ObjectDetection(ModelComponent): - def __init__(self, model): self.model = model @expose( inputs={"img": Image()}, - outputs={ - "boxes": Repeated(BBox()), - "labels": Repeated(Label("classes.txt")) - }, + outputs={"boxes": Repeated(BBox()), "labels": Repeated(Label("classes.txt"))}, ) def detect(self, img): img = img.permute(0, 3, 2, 1).float() / 255 diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/flash_examples/serve/tabular_classification/inference_server.py index f6aac866e2e..4b58b8f6913 100644 --- a/flash_examples/serve/tabular_classification/inference_server.py +++ b/flash_examples/serve/tabular_classification/inference_server.py @@ -15,5 +15,5 @@ from flash.tabular import TabularClassifier model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") -model.serializer = Labels(['Did not survive', 'Survived']) +model.serializer = Labels(["Did not survive", "Survived"]) model.serve() diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py index a22282920a1..f084ebac3a2 100644 --- a/flash_examples/speech_recognition.py +++ b/flash_examples/speech_recognition.py @@ -30,7 +30,7 @@ # 3. Create the trainer and finetune the model trainer = flash.Trainer(max_epochs=1) -trainer.finetune(model, datamodule=datamodule, strategy='no_freeze') +trainer.finetune(model, datamodule=datamodule, strategy="no_freeze") # 4. Predict on audio files! 
predictions = model.predict(["data/timit/example.wav"]) diff --git a/flash_examples/style_transfer.py b/flash_examples/style_transfer.py index 37500e93589..1e60a9f844c 100644 --- a/flash_examples/style_transfer.py +++ b/flash_examples/style_transfer.py @@ -30,11 +30,13 @@ trainer.fit(model, datamodule=datamodule) # 4. Apply style transfer to a few images! -predictions = model.predict([ - "data/coco128/images/train2017/000000000625.jpg", - "data/coco128/images/train2017/000000000626.jpg", - "data/coco128/images/train2017/000000000629.jpg", -]) +predictions = model.predict( + [ + "data/coco128/images/train2017/000000000625.jpg", + "data/coco128/images/train2017/000000000626.jpg", + "data/coco128/images/train2017/000000000629.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/template.py b/flash_examples/template.py index 66ce579a832..978a341843b 100644 --- a/flash_examples/template.py +++ b/flash_examples/template.py @@ -31,11 +31,13 @@ trainer.fit(model, datamodule=datamodule) # 4. Classify a few examples -predictions = model.predict([ - np.array([4.9, 3.0, 1.4, 0.2]), - np.array([6.9, 3.2, 5.7, 2.3]), - np.array([7.2, 3.0, 5.8, 1.6]), -]) +predictions = model.predict( + [ + np.array([4.9, 3.0, 1.4, 0.2]), + np.array([6.9, 3.2, 5.7, 2.3]), + np.array([7.2, 3.0, 5.8, 1.6]), + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 1924d408deb..1ba19367589 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -34,11 +34,13 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Classify a few sentences! How was the movie? -predictions = model.predict([ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "I come from Bulgaria where it 's almost impossible to have a tornado.", -]) +predictions = model.predict( + [ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado.", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/text_classification_multi_label.py b/flash_examples/text_classification_multi_label.py index b9dab3944ea..80859efccde 100644 --- a/flash_examples/text_classification_multi_label.py +++ b/flash_examples/text_classification_multi_label.py @@ -40,11 +40,13 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Generate predictions for a few comments! -predictions = model.predict([ - "No, he is an arrogant, self serving, immature idiot. Get it right.", - "U SUCK HANNAH MONTANA", - "Would you care to vote? Thx.", -]) +predictions = model.predict( + [ + "No, he is an arrogant, self serving, immature idiot. Get it right.", + "U SUCK HANNAH MONTANA", + "Would you care to vote? Thx.", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/translation.py b/flash_examples/translation.py index 2a0d7889f24..a246fff1029 100644 --- a/flash_examples/translation.py +++ b/flash_examples/translation.py @@ -34,11 +34,13 @@ trainer.finetune(model, datamodule=datamodule) # 4. Translate something! 
-predictions = model.predict([ - "BBC News went to meet one of the project's first graduates.", - "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", - "Of course, it's still early in the election cycle.", -]) +predictions = model.predict( + [ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", + "Of course, it's still early in the election cycle.", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/flash_examples/visualizations/pointcloud_segmentation.py index 85565a7027b..d7d0fcd04ec 100644 --- a/flash_examples/visualizations/pointcloud_segmentation.py +++ b/flash_examples/visualizations/pointcloud_segmentation.py @@ -21,7 +21,7 @@ datamodule = PointCloudSegmentationData.from_folders( train_folder="data/SemanticKittiTiny/train", - val_folder='data/SemanticKittiTiny/val', + val_folder="data/SemanticKittiTiny/val", ) # 2. Build the task @@ -32,10 +32,12 @@ trainer.fit(model, datamodule) # 4. Predict what's within a few PointClouds? -predictions = model.predict([ - "data/SemanticKittiTiny/predict/000000.bin", - "data/SemanticKittiTiny/predict/000001.bin", -]) +predictions = model.predict( + [ + "data/SemanticKittiTiny/predict/000000.bin", + "data/SemanticKittiTiny/predict/000001.bin", + ] +) # 5. Save the model! trainer.save_checkpoint("pointcloud_segmentation_model.pt") diff --git a/setup.py b/setup.py index b5106c05b6d..96fb1a61645 100644 --- a/setup.py +++ b/setup.py @@ -33,8 +33,8 @@ def _load_py_module(fname, pkg="flash"): return py -about = _load_py_module('__about__.py') -setup_tools = _load_py_module('setup_tools.py') +about = _load_py_module("__about__.py") +setup_tools = _load_py_module("setup_tools.py") long_description = setup_tools._load_readme_description( _PATH_ROOT, @@ -84,12 +84,12 @@ def _load_py_module(fname, pkg="flash"): include_package_data=True, extras_require=extras, entry_points={ - 'console_scripts': ['flash=flash.__main__:main'], + "console_scripts": ["flash=flash.__main__:main"], }, zip_safe=False, keywords=["deep learning", "pytorch", "AI"], python_requires=">=3.6", - install_requires=setup_tools._load_requirements(_PATH_ROOT, file_name='requirements.txt'), + install_requires=setup_tools._load_requirements(_PATH_ROOT, file_name="requirements.txt"), project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/lightning-flash/issues", "Documentation": "https://lightning-flash.rtfd.io/en/latest/", diff --git a/tests/__init__.py b/tests/__init__.py index c64310c910f..2be74bcdc7e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,5 +2,5 @@ # TorchVision hotfix https://github.com/pytorch/vision/issues/1938 opener = urllib.request.build_opener() -opener.addheaders = [('User-agent', 'Mozilla/5.0')] +opener.addheaders = [("User-agent", "Mozilla/5.0")] urllib.request.install_opener(opener) diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index a1c0ba06775..d18a588e5d6 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -64,9 +64,9 @@ def test_from_filepaths_smoke(tmpdir): assert spectrograms_data.test_dataloader() is None data = next(iter(spectrograms_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 
) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [1, 2] @@ -96,24 +96,24 @@ def test_from_filepaths_list_image_paths(tmpdir): # check training data data = next(iter(spectrograms_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here # check validation data data = next(iter(spectrograms_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [1, 4] # check test data data = next(iter(spectrograms_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [2, 5] @@ -201,7 +201,7 @@ def test_from_filepaths_splits(tmpdir): _rand_image(img_size).save(tmpdir / "s.png") num_samples: int = 10 - val_split: float = .3 + val_split: float = 0.3 train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)] @@ -212,7 +212,7 @@ def test_from_filepaths_splits(tmpdir): _to_tensor = { "to_tensor_transform": nn.Sequential( ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor) + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), } @@ -228,9 +228,9 @@ def run(transform: Any = None): spectrogram_size=img_size, ) data = next(iter(dm.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (B, 3, H, W) - assert labels.shape == (B, ) + assert labels.shape == (B,) run(_to_tensor) @@ -251,9 +251,9 @@ def test_from_folders_only_train(tmpdir): spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) data = next(iter(spectrograms_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, ) + assert labels.shape == (1,) assert spectrograms_data.val_dataloader() is None assert spectrograms_data.test_dataloader() is None @@ -281,20 +281,20 @@ def test_from_folders_train_val(tmpdir): ) data = next(iter(spectrograms_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) data = next(iter(spectrograms_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [0, 0] data = next(iter(spectrograms_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [0, 0] @@ -323,18 +323,18 @@ def test_from_filepaths_multilabel(tmpdir): ) data = 
next(iter(dm.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) data = next(iter(dm.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) data = next(iter(dm.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(test_labels)) diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py index 2b871292107..6205da309dd 100644 --- a/tests/audio/speech_recognition/test_data.py +++ b/tests/audio/speech_recognition/test_data.py @@ -23,7 +23,7 @@ from tests.helpers.utils import _AUDIO_TESTING path = str(Path(flash.ASSETS_ROOT) / "example.wav") -sample = {'file': path, 'text': 'example input.'} +sample = {"file": path, "text": "example input."} TEST_CSV_DATA = f"""file,text {path},example input. @@ -42,8 +42,8 @@ def csv_data(tmpdir): def json_data(tmpdir, n_samples=5): path = Path(tmpdir) / "data.json" - with path.open('w') as f: - f.write('\n'.join([json.dumps(sample) for x in range(n_samples)])) + with path.open("w") as f: + f.write("\n".join([json.dumps(sample) for x in range(n_samples)])) return path diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py index 0c9773022dc..eda3ac86b3b 100644 --- a/tests/audio/speech_recognition/test_data_model_integration.py +++ b/tests/audio/speech_recognition/test_data_model_integration.py @@ -25,7 +25,7 @@ TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing path = str(Path(flash.ASSETS_ROOT) / "example.wav") -sample = {'file': path, 'text': 'example input.'} +sample = {"file": path, "text": "example input."} TEST_CSV_DATA = f"""file,text {path},example input. 
@@ -44,8 +44,8 @@ def csv_data(tmpdir): def json_data(tmpdir, n_samples=5): path = Path(tmpdir) / "data.json" - with path.open('w') as f: - f.write('\n'.join([json.dumps(sample) for x in range(n_samples)])) + with path.open("w") as f: + f.write("\n".join([json.dumps(sample) for x in range(n_samples)])) return path diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index c5e204adb45..f1b1f55ee52 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -30,14 +30,11 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { DefaultDataKeys.INPUT: np.random.randn(86631), DefaultDataKeys.TARGET: "some target text", - DefaultDataKeys.METADATA: { - "sampling_rate": 16000 - }, + DefaultDataKeys.METADATA: {"sampling_rate": 16000}, } def __len__(self) -> int: diff --git a/tests/conftest.py b/tests/conftest.py index b32e74d5246..43fd8dc824f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,7 +80,7 @@ def lightning_squeezenet1_1_obj(): def squeezenet_servable(squeezenet1_1_model, session_global_datadir): from flash.core.serve import Servable - trace = torch.jit.trace(squeezenet1_1_model.eval(), (torch.rand(1, 3, 224, 224), )) + trace = torch.jit.trace(squeezenet1_1_model.eval(), (torch.rand(1, 3, 224, 224),)) fpth = str(session_global_datadir / "squeezenet_jit_trace.pt") torch.jit.save(trace, fpth) diff --git a/tests/core/data/test_auto_dataset.py b/tests/core/data/test_auto_dataset.py index 7acbffe671a..8571363a0a1 100644 --- a/tests/core/data/test_auto_dataset.py +++ b/tests/core/data/test_auto_dataset.py @@ -22,7 +22,6 @@ class _AutoDatasetTestDataSource(DataSource): - def __init__(self, with_dset: bool): self._callbacks: List[FlashCallback] = [] self.load_data_count = 0 diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index 20d2084b9bd..9af754eb1ca 100644 --- a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -37,7 +37,6 @@ def _rand_image(): class CustomBaseVisualization(BaseVisualization): - def __init__(self): super().__init__() @@ -77,7 +76,6 @@ def check_reset(self): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") class TestBaseViz: - def test_base_viz(self, tmpdir): seed_everything(42) @@ -89,7 +87,6 @@ def test_base_viz(self, tmpdir): _rand_image().save(train_images[1]) class CustomImageClassificationData(ImageClassificationData): - @staticmethod def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: return CustomBaseVisualization(*args, **kwargs) @@ -154,7 +151,7 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("per_batch_transform") - assert res[0][DefaultDataKeys.TARGET].shape == (B, ) + assert res[0][DefaultDataKeys.TARGET].shape == (B,) assert dm.data_fetcher.show_load_sample_called assert dm.data_fetcher.show_pre_tensor_transform_called @@ -165,12 +162,13 @@ def _get_result(function_name: str): dm.data_fetcher.check_reset() @pytest.mark.parametrize( - "func_names, valid", [ + "func_names, valid", + [ (["load_sample"], True), (["not_a_hook"], False), (["load_sample", "pre_tensor_transform"], True), (["load_sample", "not_a_hook"], True), - ] + ], ) def test_show(self, func_names, valid): base_viz = CustomBaseVisualization() diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py index caba5cf4a0a..a03457ed773 100644 --- a/tests/core/data/test_batch.py +++ 
b/tests/core/data/test_batch.py @@ -102,9 +102,9 @@ def test_tensor_batch(): def test_sequence(self): batch = { - 'a': torch.rand(self.BATCH_SIZE, 4), - 'b': torch.rand(self.BATCH_SIZE, 2), - 'c': torch.rand(self.BATCH_SIZE) + "a": torch.rand(self.BATCH_SIZE, 4), + "b": torch.rand(self.BATCH_SIZE, 2), + "c": torch.rand(self.BATCH_SIZE), } output = default_uncollate(batch) @@ -112,13 +112,13 @@ def test_sequence(self): assert len(batch) == self.BATCH_SIZE for sample in output: - assert list(sample.keys()) == ['a', 'b', 'c'] - assert isinstance(sample['a'], list) - assert len(sample['a']) == 4 - assert isinstance(sample['b'], list) - assert len(sample['b']) == 2 - assert isinstance(sample['c'], torch.Tensor) - assert len(sample['c'].shape) == 0 + assert list(sample.keys()) == ["a", "b", "c"] + assert isinstance(sample["a"], list) + assert len(sample["a"]) == 4 + assert isinstance(sample["b"], list) + assert len(sample["b"]) == 2 + assert isinstance(sample["c"], torch.Tensor) + assert len(sample["c"].shape) == 0 def test_named_tuple(self): Batch = namedtuple("Batch", ["x", "y"]) diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index e11591f33ab..e9b6b853a27 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -47,7 +47,6 @@ def test_flash_callback(_, tmpdir): ] class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -91,5 +90,5 @@ def __init__(self): call.on_post_tensor_transform(ANY, RunningStage.VALIDATING), call.on_collate(ANY, RunningStage.VALIDATING), call.on_per_batch_transform(ANY, RunningStage.VALIDATING), - call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING) + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING), ] diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py index b01c46a1640..07e89fec161 100644 --- a/tests/core/data/test_callbacks.py +++ b/tests/core/data/test_callbacks.py @@ -23,9 +23,7 @@ def test_base_data_fetcher(tmpdir): - class CheckData(BaseDataFetcher): - def check(self): assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4] @@ -38,7 +36,6 @@ def check(self): assert self.batches["predict"] == {} class CustomDataModule(DataModule): - @staticmethod def configure_data_fetcher(): return CheckData() @@ -70,7 +67,7 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat data_fetcher.check() data_fetcher.reset() - assert data_fetcher.batches == {'train': {}, 'test': {}, 'val': {}, 'predict': {}} + assert data_fetcher.batches == {"train": {}, "test": {}, "val": {}, "predict": {}} def test_data_loaders_num_workers_to_0(tmpdir): diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index e6ca144a229..7124675f307 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -44,7 +44,6 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: return torch.rand(1), torch.rand(1) @@ -53,7 +52,6 @@ def __len__(self) -> int: class TestDataPipelineState: - @staticmethod def test_str(): state = DataPipelineState() @@ -95,9 +93,7 @@ def test_data_pipeline_str(): @pytest.mark.parametrize("use_preprocess", [False, True]) @pytest.mark.parametrize("use_postprocess", [False, True]) def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir): - 
class CustomModel(Task): - def __init__(self, postprocess: Optional[Postprocess] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) self._postprocess = postprocess @@ -135,9 +131,7 @@ class SubPostprocess(Postprocess): def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): - class CustomPreprocess(DefaultPreprocess): - def val_pre_tensor_transform(self, *_, **__): pass @@ -258,7 +252,6 @@ def test_per_batch_transform_on_device(self, *_, **__): class CustomPreprocess(DefaultPreprocess): - def train_per_sample_transform(self, *_, **__): pass @@ -307,9 +300,7 @@ def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): def test_detach_preprocessing_from_model(tmpdir): - class CustomModel(Task): - def __init__(self, postprocess: Optional[Postprocess] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) self._postprocess = postprocess @@ -333,7 +324,6 @@ def train_dataloader(self) -> Any: class TestPreprocess(DefaultPreprocess): - def train_per_sample_transform(self, *_, **__): pass @@ -363,7 +353,6 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_attaching_datapipeline_to_model(tmpdir): - class SubPreprocess(DefaultPreprocess): pass @@ -371,7 +360,6 @@ class SubPreprocess(DefaultPreprocess): data_pipeline = DataPipeline(preprocess=preprocess) class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) self._postprocess = Postprocess() @@ -513,8 +501,7 @@ def test_stage_orchestrator_state_attach_detach(tmpdir): _original_predict_step = model.predict_step class CustomDataPipeline(DataPipeline): - - def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _Postprocessor) -> 'Task': + def _attach_postprocess_to_model(self, model: "Task", _postprocesssor: _Postprocessor) -> "Task": model.predict_step = self._model_predict_step_wrapper(model.predict_step, _postprocesssor, model) return model @@ -528,7 +515,6 @@ def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _Postproc class LamdaDummyDataset(torch.utils.data.Dataset): - def __init__(self, fx: Callable): self.fx = fx @@ -540,7 +526,6 @@ def __len__(self) -> int: class TestPreprocessTransformationsDataSource(DataSource): - def __init__(self): super().__init__() @@ -589,7 +574,7 @@ def test_load_data(self, sample) -> LamdaDummyDataset: @staticmethod def fn_predict_load_data() -> List[str]: - return (["a", "b"]) + return ["a", "b"] def predict_load_data(self, sample) -> LamdaDummyDataset: assert self.predicting @@ -599,7 +584,6 @@ def predict_load_data(self, sample) -> LamdaDummyDataset: class TestPreprocessTransformations(DefaultPreprocess): - def __init__(self): super().__init__(data_sources={"default": TestPreprocessTransformationsDataSource()}) @@ -616,7 +600,7 @@ def train_pre_tensor_transform(self, sample: Any) -> Any: assert self.training assert self.current_fn == "pre_tensor_transform" self.train_pre_tensor_transform_called = True - return sample + (5, ) + return sample + (5,) def train_collate(self, samples) -> Tensor: assert self.training @@ -640,9 +624,9 @@ def val_collate(self, samples) -> Dict[str, Tensor]: assert self.validating assert self.current_fn == "collate" self.val_collate_called = True - _count = samples[0]['a'] - assert samples == [{'a': _count, 'b': _count + 1}, {'a': _count + 1, 'b': _count + 2}] - return {'a': tensor([0, 1]), 'b': tensor([1, 2])} + _count = samples[0]["a"] + assert samples == [{"a": _count, 
"b": _count + 1}, {"a": _count + 1, "b": _count + 2}] + return {"a": tensor([0, 1]), "b": tensor([1, 2])} def val_per_batch_transform_on_device(self, batch: Any) -> Any: assert self.validating @@ -668,14 +652,12 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor: class TestPreprocessTransformations2(TestPreprocessTransformations): - def val_to_tensor_transform(self, sample: Any) -> Tensor: self.val_to_tensor_transform_called = True return {"a": tensor(sample["a"]), "b": tensor(sample["b"])} class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -692,10 +674,10 @@ def test_step(self, batch, batch_idx): assert batch[0].shape == torch.Size([2, 1]) def predict_step(self, batch, batch_idx, dataloader_idx=None): - assert batch[0][0] == 'a' - assert batch[0][1] == 'a' - assert batch[1][0] == 'b' - assert batch[1][1] == 'b' + assert batch[0][0] == "a" + assert batch[0][1] == "a" + assert batch[1][0] == "b" + assert batch[1][1] == "b" return tensor([0, 0, 0]) @@ -709,8 +691,8 @@ def test_datapipeline_transformations(tmpdir): batch = next(iter(datamodule.train_dataloader())) assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]])) - assert datamodule.val_dataloader().dataset[0] == {'a': 0, 'b': 1} - assert datamodule.val_dataloader().dataset[1] == {'a': 1, 'b': 2} + assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1} + assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2} with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"): batch = next(iter(datamodule.val_dataloader())) @@ -728,7 +710,7 @@ def test_datapipeline_transformations(tmpdir): limit_val_batches=1, limit_test_batches=2, limit_predict_batches=2, - num_sanity_val_steps=1 + num_sanity_val_steps=1, ) trainer.fit(model, datamodule=datamodule) trainer.test(model) @@ -752,9 +734,7 @@ def test_datapipeline_transformations(tmpdir): def test_is_overriden_recursive(tmpdir): - class TestPreprocess(DefaultPreprocess): - def collate(self, *_): pass @@ -775,9 +755,7 @@ def val_collate(self, *_): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @patch("torch.save") # need to mock torch.save or we get pickle error def test_dummy_example(tmpdir): - class ImageDataSource(DataSource): - def load_data(self, folder: str): # from folder -> return files paths return ["a.jpg", "b.jpg"] @@ -788,7 +766,6 @@ def load_sample(self, path: str) -> Image.Image: return Image.fromarray(img8Bit) class ImageClassificationPreprocess(DefaultPreprocess): - def __init__( self, train_transform=None, @@ -817,7 +794,6 @@ def train_per_sample_transform_on_device(self, sample: Any) -> Any: return self._train_per_sample_transform_on_device(sample) class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -856,7 +832,7 @@ class CustomDataModule(DataModule): limit_val_batches=1, limit_test_batches=2, limit_predict_batches=2, - num_sanity_val_steps=1 + num_sanity_val_steps=1, ) trainer.fit(model, datamodule=datamodule) trainer.test(model) @@ -883,13 +859,13 @@ def test_preprocess_transforms(tmpdir): preprocess = DefaultPreprocess( train_transform={ "per_batch_transform": torch.nn.Linear(1, 1), - "per_sample_transform_on_device": torch.nn.Linear(1, 1) + "per_sample_transform_on_device": torch.nn.Linear(1, 1), } ) preprocess = DefaultPreprocess( train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, - 
predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)} + predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, ) # keep is None assert preprocess._train_collate_in_worker_from_transform is True @@ -908,7 +884,6 @@ def test_preprocess_transforms(tmpdir): assert predict_preprocessor.collate_fn.func == DataPipeline._identity class CustomPreprocess(DefaultPreprocess): - def per_sample_transform_on_device(self, sample: Any) -> Any: return super().per_sample_transform_on_device(sample) @@ -917,7 +892,7 @@ def per_batch_transform(self, batch: Any) -> Any: preprocess = CustomPreprocess( train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, - predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)} + predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, ) # keep is None assert preprocess._train_collate_in_worker_from_transform is True @@ -939,9 +914,7 @@ def per_batch_transform(self, batch: Any) -> Any: def test_iterable_auto_dataset(tmpdir): - class CustomDataSource(DataSource): - def load_sample(self, index: int) -> Dict[str, int]: return {"index": index} @@ -952,7 +925,6 @@ def load_sample(self, index: int) -> Dict[str, int]: class CustomPreprocessHyperparameters(DefaultPreprocess): - def __init__(self, token: str, *args, **kwargs): self.token = token super().__init__(*args, **kwargs) diff --git a/tests/core/data/test_data_source.py b/tests/core/data/test_data_source.py index 77dbb173be6..24a0b875fcd 100644 --- a/tests/core/data/test_data_source.py +++ b/tests/core/data/test_data_source.py @@ -17,7 +17,7 @@ def test_dataset_data_source(): data_source = DatasetDataSource() - input, target = 'test', 3 + input, target = "test", 3 assert data_source.load_sample((input, target)) == {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} assert data_source.load_sample(input) == {DefaultDataKeys.INPUT: input} diff --git a/tests/core/data/test_process.py b/tests/core/data/test_process.py index 7d240dcb573..509bbce3f8d 100644 --- a/tests/core/data/test_process.py +++ b/tests/core/data/test_process.py @@ -33,15 +33,15 @@ def test_serializer(): my_serializer = Serializer() - assert my_serializer.serialize('test') == 'test' + assert my_serializer.serialize("test") == "test" my_serializer.serialize = Mock() my_serializer.disable() - assert my_serializer('test') == 'test' + assert my_serializer("test") == "test" my_serializer.serialize.assert_not_called() my_serializer.enable() - my_serializer('test') + my_serializer("test") my_serializer.serialize.assert_called_once() @@ -52,24 +52,24 @@ def test_serializer_mapping(): """ serializer1 = Serializer() - serializer1.serialize = Mock(return_value='test1') + serializer1.serialize = Mock(return_value="test1") class Serializer1State(ProcessState): pass serializer2 = Serializer() - serializer2.serialize = Mock(return_value='test2') + serializer2.serialize = Mock(return_value="test2") class Serializer2State(ProcessState): pass - serializer_mapping = SerializerMapping({'key1': serializer1, 'key2': serializer2}) - assert serializer_mapping({'key1': 'serializer1', 'key2': 'serializer2'}) == {'key1': 'test1', 'key2': 'test2'} - serializer1.serialize.assert_called_once_with('serializer1') - serializer2.serialize.assert_called_once_with('serializer2') + serializer_mapping = SerializerMapping({"key1": serializer1, "key2": serializer2}) + assert serializer_mapping({"key1": "serializer1", "key2": "serializer2"}) == {"key1": "test1", "key2": "test2"} + 
serializer1.serialize.assert_called_once_with("serializer1") + serializer2.serialize.assert_called_once_with("serializer2") - with pytest.raises(ValueError, match='output must be a mapping'): - serializer_mapping('not a mapping') + with pytest.raises(ValueError, match="output must be a mapping"): + serializer_mapping("not a mapping") serializer1_state = Serializer1State() serializer2_state = Serializer2State() @@ -89,10 +89,9 @@ class Serializer2State(ProcessState): def test_saving_with_serializers(tmpdir): - checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt') + checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -112,7 +111,6 @@ def __init__(self): class CustomPreprocess(DefaultPreprocess): - def __init__(self): super().__init__( data_sources={ diff --git a/tests/core/data/test_sampler.py b/tests/core/data/test_sampler.py index 9ee9ace3a17..3480bc2abf2 100644 --- a/tests/core/data/test_sampler.py +++ b/tests/core/data/test_sampler.py @@ -19,14 +19,14 @@ @mock.patch("flash.core.data.data_module.DataLoader") def test_dataloaders_with_sampler(mock_dataloader): - train_ds = val_ds = test_ds = 'dataset' - mock_sampler = 'sampler' + train_ds = val_ds = test_ds = "dataset" + mock_sampler = "sampler" dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler) assert dm.sampler is mock_sampler dl = dm.train_dataloader() kwargs = mock_dataloader.call_args[1] - assert 'sampler' in kwargs - assert kwargs['sampler'] is mock_sampler + assert "sampler" in kwargs + assert kwargs["sampler"] is mock_sampler for dl in [dm.val_dataloader(), dm.test_dataloader()]: kwargs = mock_dataloader.call_args[1] - assert 'sampler' not in kwargs + assert "sampler" not in kwargs diff --git a/tests/core/data/test_serialization.py b/tests/core/data/test_serialization.py index 5c368bb0b99..948f6bee13d 100644 --- a/tests/core/data/test_serialization.py +++ b/tests/core/data/test_serialization.py @@ -25,13 +25,11 @@ class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) class CustomPreprocess(DefaultPreprocess): - @classmethod def load_data(cls, data): return data @@ -40,8 +38,8 @@ def load_data(cls, data): def test_serialization_data_pipeline(tmpdir): model = CustomModel() - checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt') - checkpoint = ModelCheckpoint(tmpdir, 'test.ckpt') + checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") + checkpoint = ModelCheckpoint(tmpdir, "test.ckpt") trainer = Trainer(callbacks=[checkpoint], max_epochs=1) dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float)))) trainer.fit(model, dummy_data) @@ -69,5 +67,5 @@ def fn(*args, **kwargs): assert loaded_model.data_pipeline assert isinstance(loaded_model.preprocess, CustomPreprocess) for file in os.listdir(tmpdir): - if file.endswith('.ckpt'): + if file.endswith(".ckpt"): os.remove(os.path.join(tmpdir, file)) diff --git a/tests/core/data/test_splits.py b/tests/core/data/test_splits.py index 14e7f129935..0d58ed22289 100644 --- a/tests/core/data/test_splits.py +++ b/tests/core/data/test_splits.py @@ -28,7 +28,6 @@ def test_split_dataset(): assert len(np.unique(train_ds.indices)) == len(train_ds.indices) class Dataset: - def __init__(self): self.data = [0, 1, 2] self.name = "something" diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py index f9239aa6543..b66bd41cc89 
100644 --- a/tests/core/data/test_transforms.py +++ b/tests/core/data/test_transforms.py @@ -23,40 +23,21 @@ class TestApplyToKeys: - @pytest.mark.parametrize( - "sample, keys, expected", [ - ({ - DefaultDataKeys.INPUT: "test" - }, DefaultDataKeys.INPUT, "test"), + "sample, keys, expected", + [ + ({DefaultDataKeys.INPUT: "test"}, DefaultDataKeys.INPUT, "test"), ( - { - DefaultDataKeys.INPUT: "test_a", - DefaultDataKeys.TARGET: "test_b" - }, + {DefaultDataKeys.INPUT: "test_a", DefaultDataKeys.TARGET: "test_b"}, [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], ["test_a", "test_b"], ), - ({ - "input": "test" - }, "input", "test"), - ({ - "input": "test_a", - "target": "test_b" - }, ["input", "target"], ["test_a", "test_b"]), - ({ - "input": "test_a", - "target": "test_b", - "extra": "..." - }, ["input", "target"], ["test_a", "test_b"]), - ({ - "input": "test_a", - "target": "test_b" - }, ["input", "target", "extra"], ["test_a", "test_b"]), - ({ - "target": "..." - }, "input", None), - ] + ({"input": "test"}, "input", "test"), + ({"input": "test_a", "target": "test_b"}, ["input", "target"], ["test_a", "test_b"]), + ({"input": "test_a", "target": "test_b", "extra": "..."}, ["input", "target"], ["test_a", "test_b"]), + ({"input": "test_a", "target": "test_b"}, ["input", "target", "extra"], ["test_a", "test_b"]), + ({"target": "..."}, "input", None), + ], ) def test_forward(self, sample, keys, expected): transform = Mock(return_value=["out"] * len(keys)) @@ -67,7 +48,8 @@ def test_forward(self, sample, keys, expected): transform.assert_not_called() @pytest.mark.parametrize( - "transform, expected", [ + "transform, expected", + [ ( ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.ReLU()), "ApplyToKeys(keys=, transform=ReLU())", @@ -82,7 +64,7 @@ def test_forward(self, sample, keys, expected): ApplyToKeys(["input", "target"], torch.nn.ReLU()), "ApplyToKeys(keys=['input', 'target'], transform=ReLU())", ), - ] + ], ) def test_repr(self, transform, expected): assert repr(transform) == expected @@ -118,18 +100,9 @@ def test_kornia_parallel_transforms(with_params): def test_kornia_collate(): samples = [ - { - DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), - DefaultDataKeys.TARGET: 1 - }, - { - DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), - DefaultDataKeys.TARGET: 2 - }, - { - DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), - DefaultDataKeys.TARGET: 3 - }, + {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 1}, + {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 2}, + {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 3}, ] result = kornia_collate(samples) @@ -145,24 +118,13 @@ def test_kornia_collate(): "base_transforms, additional_transforms, expected_result", [ ( - { - "to_tensor_transform": _MOCK_TRANSFORM - }, - { - "post_tensor_transform": _MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM, - "post_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM}, + {"post_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, ), ( - { - "to_tensor_transform": _MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM}, { "to_tensor_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) @@ -170,33 +132,23 @@ def test_kornia_collate(): }, ), ( - { - "to_tensor_transform": 
_MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM, - "post_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, { "to_tensor_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) ), - "post_tensor_transform": _MOCK_TRANSFORM + "post_tensor_transform": _MOCK_TRANSFORM, }, ), ( - { - "to_tensor_transform": _MOCK_TRANSFORM, - "post_tensor_transform": _MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM}, { "to_tensor_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) ), - "post_tensor_transform": _MOCK_TRANSFORM + "post_tensor_transform": _MOCK_TRANSFORM, }, ), ], diff --git a/tests/core/serve/models.py b/tests/core/serve/models.py index 63f99327f7e..9e0e914c411 100644 --- a/tests/core/serve/models.py +++ b/tests/core/serve/models.py @@ -14,7 +14,6 @@ class LightningSqueezenet(pl.LightningModule): - def __init__(self): super().__init__() self.model = squeezenet1_1(pretrained=True).eval() @@ -24,7 +23,6 @@ def forward(self, x): class LightningSqueezenetServable(pl.LightningModule): - def __init__(self, model): super().__init__() self.model = model @@ -38,7 +36,6 @@ def _func_from_exposed(arg): class ClassificationInference(ModelComponent): - def __init__(self, model): # skipcq: PYL-W0621 self.model = model @@ -73,7 +70,6 @@ def method_from_exposed(arg): try: class ClassificationInferenceRepeated(ModelComponent): - def __init__(self, model): self.model = model @@ -92,13 +88,14 @@ def classify(self, img): img = img.permute(0, 3, 2, 1) out = self.model(img) return ([out.argmax(), out.argmax()], torch.Tensor([21])) + + except TypeError: ClassificationInferenceRepeated = None try: class ClassificationInferenceModelSequence(ModelComponent): - def __init__(self, model): self.model1 = model[0] self.model2 = model[1] @@ -117,13 +114,14 @@ def classify(self, img): out2 = self.model2(img) assert out.argmax() == out2.argmax() return out.argmax() + + except TypeError: ClassificationInferenceRepeated = None try: class ClassificationInferenceModelMapping(ModelComponent): - def __init__(self, model): self.model1 = model["model_one"] self.model2 = model["model_two"] @@ -142,13 +140,14 @@ def classify(self, img): out2 = self.model2(img) assert out.argmax() == out2.argmax() return out.argmax() + + except TypeError: ClassificationInferenceModelMapping = None try: class ClassificationInferenceComposable(ModelComponent): - def __init__(self, model): self.model = model @@ -171,13 +170,14 @@ def classify(self, img, tag): out = self.model(img_new) return out.argmax(), img + + except TypeError: ClassificationInferenceComposable = None try: class SeatClassifier(ModelComponent): - def __init__(self, model, config): self.sport = config["sport"] @@ -197,5 +197,7 @@ def predict(self, section, isle, row, stadium): seat_num = section.item() * isle.item() * row.item() * stadium * len(self.sport) stadium_idx = torch.tensor(1000) return torch.Tensor([seat_num]), stadium_idx + + except TypeError: SeatClassifier = None diff --git a/tests/core/serve/test_compat/test_cached_property.py b/tests/core/serve/test_compat/test_cached_property.py index c6c909bdf86..b708fa81891 100644 --- a/tests/core/serve/test_compat/test_cached_property.py +++ 
b/tests/core/serve/test_compat/test_cached_property.py @@ -79,7 +79,6 @@ def cost(self): # noinspection PyStatementEffect @pytest.mark.skipif(sys.version_info >= (3, 8), reason="Python 3.8+ uses standard library implementation.") class TestCachedProperty: - @staticmethod def test_cached(): item = CachedCostItem() @@ -125,7 +124,6 @@ def test_object_with_slots(): @staticmethod def test_immutable_dict(): - class MyMeta(type): """Test metaclass.""" @@ -214,7 +212,6 @@ def test_doc(): @pytest.mark.skipif(sys.version_info < (3, 8), reason="Validate, that python 3.8 uses standard implementation") class TestPy38Plus: - @staticmethod def test_is(): import functools diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py index a32773726f8..f31f89c84a6 100644 --- a/tests/core/serve/test_components.py +++ b/tests/core/serve/test_components.py @@ -21,12 +21,14 @@ def test_model_compute_dependencies(lightning_squeezenet1_1_obj): comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp1.inputs.tag << comp2.outputs.predicted_tag - res = [{ - "source_component": "callnum_2", - "source_key": "predicted_tag", - "target_component": "callnum_1", - "target_key": "tag", - }] + res = [ + { + "source_component": "callnum_2", + "source_key": "predicted_tag", + "target_component": "callnum_1", + "target_key": "tag", + } + ] assert list(map(lambda x: x._asdict(), comp1._flashserve_meta_.connections)) == res assert list(comp2._flashserve_meta_.connections) == [] @@ -38,12 +40,14 @@ def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_ob comp2.outputs.predicted_tag >> comp1.inputs.tag - res = [{ - "source_component": "callnum_2", - "source_key": "predicted_tag", - "target_component": "callnum_1", - "target_key": "tag", - }] + res = [ + { + "source_component": "callnum_2", + "source_key": "predicted_tag", + "target_component": "callnum_1", + "target_key": "tag", + } + ] assert list(map(lambda x: x._asdict(), comp2._flashserve_meta_.connections)) == res assert list(comp1._flashserve_meta_.connections) == [] @@ -74,7 +78,6 @@ def test_two_component_invalid_dependencies_fail(lightning_squeezenet1_1_obj): comp2.outputs.predicted_tag >> comp1.outputs.predicted_tag class Foo: - def __init__(self): pass @@ -128,7 +131,6 @@ def test_invalid_expose_inputs(): with pytest.raises(SyntaxError, match="must be valid python attribute"): class ComposeClassInvalidExposeNameKeyword(ModelComponent): - def __init__(self, model): pass @@ -142,7 +144,6 @@ def predict(param): with pytest.raises(AttributeError, match="object has no attribute"): class ComposeClassInvalidExposeNameType(ModelComponent): - def __init__(self, model): pass @@ -156,7 +157,6 @@ def predict(param): with pytest.raises(TypeError, match="`expose` values must be"): class ComposeClassInvalidExposeInputsType(ModelComponent): - def __init__(self, model): pass @@ -170,7 +170,6 @@ def predict(param): with pytest.raises(ValueError, match="cannot set dict of length < 1"): class ComposeClassEmptyExposeInputsType(ModelComponent): - def __init__(self, model): pass @@ -206,7 +205,6 @@ def test_invalid_name(lightning_squeezenet1_1_obj): with pytest.raises(SyntaxError): class FailedExposedOutputsKeyworkName(ModelComponent): - def __init__(self, model): self.model = model @@ -222,7 +220,6 @@ def test_invalid_config_args(lightning_squeezenet1_1_obj): from flash.core.serve.types import Number class SomeComponent(ModelComponent): - def __init__(self, model, config=None): self.model = model self.config = config @@ 
-250,7 +247,6 @@ def test_invalid_model_args(lightning_squeezenet1_1_obj): from flash.core.serve.types import Number class SomeComponent(ModelComponent): - def __init__(self, model): self.model = model diff --git a/tests/core/serve/test_composition.py b/tests/core/serve/test_composition.py index 5679859ee22..c354e64f2f3 100644 --- a/tests/core/serve/test_composition.py +++ b/tests/core/serve/test_composition.py @@ -23,10 +23,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj): actual_endpoints = {k: asdict(v) for k, v in composit.endpoints.items()} assert actual_endpoints == { "classify_ENDPOINT": { - "inputs": { - "img": "callnum_1.inputs.img", - "tag": "callnum_1.inputs.tag" - }, + "inputs": {"img": "callnum_1.inputs.img", "tag": "callnum_1.inputs.tag"}, "outputs": { "cropped_img": "callnum_1.outputs.cropped_img", "predicted_tag": "callnum_1.outputs.predicted_tag", @@ -50,10 +47,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj): actual_endpoints = {k: asdict(v) for k, v in composit.endpoints.items()} assert actual_endpoints == { "predict_ep": { - "inputs": { - "label_1": "callnum_1.inputs.img", - "tag_1": "callnum_1.inputs.tag" - }, + "inputs": {"label_1": "callnum_1.inputs.img", "tag_1": "callnum_1.inputs.tag"}, "outputs": { "cropped": "callnum_1.outputs.cropped_img", "prediction": "callnum_1.outputs.predicted_tag", @@ -381,21 +375,13 @@ def test_start_server_from_composition(tmp_path, squeezenet_servable, session_gl data = { "session": "session_uuid", "payload": { - "img_1": { - "data": cat_imgstr - }, - "img_2": { - "data": fish_imgstr - }, - "tag_1": { - "label": "stingray" - }, + "img_1": {"data": cat_imgstr}, + "img_2": {"data": fish_imgstr}, + "tag_1": {"label": "stingray"}, }, } expected_response = { - "result": { - "prediction": "goldfish, Carassius auratus" - }, + "result": {"prediction": "goldfish, Carassius auratus"}, "session": "session_uuid", } diff --git a/tests/core/serve/test_dag/test_optimization.py b/tests/core/serve/test_dag/test_optimization.py index fa61545bdb2..673dce81064 100644 --- a/tests/core/serve/test_dag/test_optimization.py +++ b/tests/core/serve/test_dag/test_optimization.py @@ -60,12 +60,14 @@ def test_fuse(): "b": 2, } assert fuse(d, rename_keys=False) == with_deps({"w": (inc, (inc, (inc, (add, "a", "b")))), "a": 1, "b": 2}) - assert fuse(d, rename_keys=True) == with_deps({ - "z-y-x-w": (inc, (inc, (inc, (add, "a", "b")))), - "a": 1, - "b": 2, - "w": "z-y-x-w", - }) + assert fuse(d, rename_keys=True) == with_deps( + { + "z-y-x-w": (inc, (inc, (inc, (add, "a", "b")))), + "a": 1, + "b": 2, + "w": "z-y-x-w", + } + ) d = { "NEW": (inc, "y"), @@ -76,22 +78,26 @@ def test_fuse(): "a": 1, "b": 2, } - assert fuse(d, rename_keys=False) == with_deps({ - "NEW": (inc, "y"), - "w": (inc, (inc, "y")), - "y": (inc, (add, "a", "b")), - "a": 1, - "b": 2, - }) - assert fuse(d, rename_keys=True) == with_deps({ - "NEW": (inc, "z-y"), - "x-w": (inc, (inc, "z-y")), - "z-y": (inc, (add, "a", "b")), - "a": 1, - "b": 2, - "w": "x-w", - "y": "z-y", - }) + assert fuse(d, rename_keys=False) == with_deps( + { + "NEW": (inc, "y"), + "w": (inc, (inc, "y")), + "y": (inc, (add, "a", "b")), + "a": 1, + "b": 2, + } + ) + assert fuse(d, rename_keys=True) == with_deps( + { + "NEW": (inc, "z-y"), + "x-w": (inc, (inc, "z-y")), + "z-y": (inc, (add, "a", "b")), + "a": 1, + "b": 2, + "w": "x-w", + "y": "z-y", + } + ) d = { "v": (inc, "y"), @@ -105,24 +111,28 @@ def test_fuse(): "c": 1, "d": 2, } - assert fuse(d, rename_keys=False) == with_deps({ - "u": (inc, (inc, 
(inc, "y"))), - "v": (inc, "y"), - "y": (inc, (add, "a", "b")), - "a": (inc, 1), - "b": (inc, 2), - }) - assert fuse(d, rename_keys=True) == with_deps({ - "x-w-u": (inc, (inc, (inc, "z-y"))), - "v": (inc, "z-y"), - "z-y": (inc, (add, "c-a", "d-b")), - "c-a": (inc, 1), - "d-b": (inc, 2), - "a": "c-a", - "b": "d-b", - "u": "x-w-u", - "y": "z-y", - }) + assert fuse(d, rename_keys=False) == with_deps( + { + "u": (inc, (inc, (inc, "y"))), + "v": (inc, "y"), + "y": (inc, (add, "a", "b")), + "a": (inc, 1), + "b": (inc, 2), + } + ) + assert fuse(d, rename_keys=True) == with_deps( + { + "x-w-u": (inc, (inc, (inc, "z-y"))), + "v": (inc, "z-y"), + "z-y": (inc, (add, "c-a", "d-b")), + "c-a": (inc, 1), + "d-b": (inc, 2), + "a": "c-a", + "b": "d-b", + "u": "x-w-u", + "y": "z-y", + } + ) d = { "a": (inc, "x"), @@ -132,20 +142,19 @@ def test_fuse(): "x": (inc, "y"), "y": 0, } - assert fuse(d, rename_keys=False) == with_deps({ - "a": (inc, "x"), - "b": (inc, "x"), - "d": (inc, (inc, "x")), - "x": (inc, 0) - }) - assert fuse(d, rename_keys=True) == with_deps({ - "a": (inc, "y-x"), - "b": (inc, "y-x"), - "c-d": (inc, (inc, "y-x")), - "y-x": (inc, 0), - "d": "c-d", - "x": "y-x", - }) + assert fuse(d, rename_keys=False) == with_deps( + {"a": (inc, "x"), "b": (inc, "x"), "d": (inc, (inc, "x")), "x": (inc, 0)} + ) + assert fuse(d, rename_keys=True) == with_deps( + { + "a": (inc, "y-x"), + "b": (inc, "y-x"), + "c-d": (inc, (inc, "y-x")), + "y-x": (inc, 0), + "d": "c-d", + "x": "y-x", + } + ) d = {"a": 1, "b": (inc, "a"), "c": (add, "b", "b")} assert fuse(d, rename_keys=False) == with_deps({"b": (inc, 1), "c": (add, "b", "b")}) @@ -168,21 +177,19 @@ def test_fuse_keys(): "b": 2, } keys = ["x", "z"] - assert fuse(d, keys, rename_keys=False) == with_deps({ - "w": (inc, "x"), - "x": (inc, (inc, "z")), - "z": (add, "a", "b"), - "a": 1, - "b": 2 - }) - assert fuse(d, keys, rename_keys=True) == with_deps({ - "w": (inc, "y-x"), - "y-x": (inc, (inc, "z")), - "z": (add, "a", "b"), - "a": 1, - "b": 2, - "x": "y-x", - }) + assert fuse(d, keys, rename_keys=False) == with_deps( + {"w": (inc, "x"), "x": (inc, (inc, "z")), "z": (add, "a", "b"), "a": 1, "b": 2} + ) + assert fuse(d, keys, rename_keys=True) == with_deps( + { + "w": (inc, "y-x"), + "y-x": (inc, (inc, "z")), + "z": (add, "a", "b"), + "a": 1, + "b": 2, + "x": "y-x", + } + ) def test_inline(): @@ -238,9 +245,7 @@ def test_inline_ignores_curries_and_partials(): def test_inline_functions_non_hashable(): - class NonHashableCallable: - def __call__(self, a): return a + 1 @@ -277,7 +282,6 @@ def test_inline_functions_protects_output_keys(): def test_functions_of(): - def a(x): return x @@ -290,7 +294,7 @@ def b(x): assert functions_of((a, [[[(b, 1)]]])) == {a, b} assert functions_of(1) == set() assert functions_of(a) == set() - assert functions_of((a, )) == {a} + assert functions_of((a,)) == {a} def test_inline_cull_dependencies(): @@ -301,7 +305,6 @@ def test_inline_cull_dependencies(): def test_fuse_reductions_single_input(): - def f(*args): return args @@ -309,11 +312,9 @@ def f(*args): assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "c": (f, (f, "a"), (f, "a", "a"))}) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c": (f, (f, "a"), (f, "a", "a")), - "c": "b1-b2-c" - }) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + {"a": 1, "b1-b2-c": (f, (f, "a"), (f, "a", "a")), "c": 
"b1-b2-c"} + ) d = { "a": 1, @@ -324,25 +325,24 @@ def f(*args): } assert fuse(d, ave_width=2.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=2.9, rename_keys=True) == with_deps(d) - assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a": 1, - "c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")) - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-b3-c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")), - "c": "b1-b2-b3-c", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + {"a": 1, "c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a"))} + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-b3-c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")), + "c": "b1-b2-b3-c", + } + ) d = {"a": 1, "b1": (f, "a"), "b2": (f, "a"), "c": (f, "a", "b1", "b2")} assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "c": (f, "a", (f, "a"), (f, "a"))}) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c": (f, "a", (f, "a"), (f, "a")), - "c": "b1-b2-c" - }) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + {"a": 1, "b1-b2-c": (f, "a", (f, "a"), (f, "a")), "c": "b1-b2-c"} + ) d = { "a": 1, @@ -355,18 +355,18 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - assert fuse(d, ave_width=2, rename_keys=False) == with_deps({ - "a": 1, - "c": (f, (f, "a"), (f, "a")), - "e": (f, (f, "c"), (f, "c")) - }) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c": (f, (f, "a"), (f, "a")), - "d1-d2-e": (f, (f, "c"), (f, "c")), - "c": "b1-b2-c", - "e": "d1-d2-e", - }) + assert fuse(d, ave_width=2, rename_keys=False) == with_deps( + {"a": 1, "c": (f, (f, "a"), (f, "a")), "e": (f, (f, "c"), (f, "c"))} + ) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-c": (f, (f, "a"), (f, "a")), + "d1-d2-e": (f, (f, "c"), (f, "c")), + "c": "b1-b2-c", + "e": "d1-d2-e", + } + ) d = { "a": 1, @@ -380,37 +380,42 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - expected = with_deps({ - "a": 1, - "c1": (f, (f, "a"), (f, "a")), - "c2": (f, (f, "a"), (f, "a")), - "d": (f, "c1", "c2"), - }) + expected = with_deps( + { + "a": 1, + "c1": (f, (f, "a"), (f, "a")), + "c2": (f, (f, "a"), (f, "a")), + "d": (f, "c1", "c2"), + } + ) assert fuse(d, ave_width=2, rename_keys=False) == expected assert fuse(d, ave_width=2.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-c1": (f, (f, "a"), (f, "a")), - "b3-b4-c2": (f, (f, "a"), (f, "a")), - "d": (f, "c1", "c2"), - "c1": "b1-b2-c1", - "c2": "b3-b4-c2", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-c1": (f, (f, "a"), (f, "a")), + "b3-b4-c2": (f, (f, "a"), (f, "a")), + "d": (f, "c1", "c2"), + "c1": "b1-b2-c1", + "c2": "b3-b4-c2", + } + ) assert fuse(d, ave_width=2, rename_keys=True) == expected assert fuse(d, ave_width=2.9, rename_keys=True) == expected - assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a": 1, - "d": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))) - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-b3-b4-c1-c2-d": ( - f, - (f, (f, "a"), (f, "a")), - 
(f, (f, "a"), (f, "a")), - ), - "d": "b1-b2-b3-b4-c1-c2-d", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + {"a": 1, "d": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a")))} + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-b3-b4-c1-c2-d": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "d": "b1-b2-b3-b4-c1-c2-d", + } + ) d = { "a": 1, @@ -432,77 +437,89 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - expected = with_deps({ - "a": 1, - "c1": (f, (f, "a"), (f, "a")), - "c2": (f, (f, "a"), (f, "a")), - "c3": (f, (f, "a"), (f, "a")), - "c4": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "e": (f, "d1", "d2"), - }) + expected = with_deps( + { + "a": 1, + "c1": (f, (f, "a"), (f, "a")), + "c2": (f, (f, "a"), (f, "a")), + "c3": (f, (f, "a"), (f, "a")), + "c4": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "e": (f, "d1", "d2"), + } + ) assert fuse(d, ave_width=2, rename_keys=False) == expected assert fuse(d, ave_width=2.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-c1": (f, (f, "a"), (f, "a")), - "b3-b4-c2": (f, (f, "a"), (f, "a")), - "b5-b6-c3": (f, (f, "a"), (f, "a")), - "b7-b8-c4": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "e": (f, "d1", "d2"), - "c1": "b1-b2-c1", - "c2": "b3-b4-c2", - "c3": "b5-b6-c3", - "c4": "b7-b8-c4", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-c1": (f, (f, "a"), (f, "a")), + "b3-b4-c2": (f, (f, "a"), (f, "a")), + "b5-b6-c3": (f, (f, "a"), (f, "a")), + "b7-b8-c4": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "e": (f, "d1", "d2"), + "c1": "b1-b2-c1", + "c2": "b3-b4-c2", + "c3": "b5-b6-c3", + "c4": "b7-b8-c4", + } + ) assert fuse(d, ave_width=2, rename_keys=True) == expected assert fuse(d, ave_width=2.9, rename_keys=True) == expected - expected = with_deps({ - "a": 1, - "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "e": (f, "d1", "d2"), - }) + expected = with_deps( + { + "a": 1, + "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "e": (f, "d1", "d2"), + } + ) assert fuse(d, ave_width=3, rename_keys=False) == expected assert fuse(d, ave_width=4.6, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-b3-b4-c1-c2-d1": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b5-b6-b7-b8-c3-c4-d2": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "e": (f, "d1", "d2"), - "d1": "b1-b2-b3-b4-c1-c2-d1", - "d2": "b5-b6-b7-b8-c3-c4-d2", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-b3-b4-c1-c2-d1": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b5-b6-b7-b8-c3-c4-d2": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "e": (f, "d1", "d2"), + "d1": "b1-b2-b3-b4-c1-c2-d1", + "d2": "b5-b6-b7-b8-c3-c4-d2", + } + ) assert fuse(d, ave_width=3, rename_keys=True) == expected assert fuse(d, ave_width=4.6, rename_keys=True) == expected - assert fuse(d, ave_width=4.7, rename_keys=False) == with_deps({ - "a": 1, - "e": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - }) - assert fuse(d, ave_width=4.7, rename_keys=True) == 
with_deps({ - "a": 1, - "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "e": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e", - }) + assert fuse(d, ave_width=4.7, rename_keys=False) == with_deps( + { + "a": 1, + "e": ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + } + ) + assert fuse(d, ave_width=4.7, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e": ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + "e": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e", + } + ) d = { "a": 1, @@ -540,165 +557,181 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - expected = with_deps({ - "a": 1, - "c1": (f, (f, "a"), (f, "a")), - "c2": (f, (f, "a"), (f, "a")), - "c3": (f, (f, "a"), (f, "a")), - "c4": (f, (f, "a"), (f, "a")), - "c5": (f, (f, "a"), (f, "a")), - "c6": (f, (f, "a"), (f, "a")), - "c7": (f, (f, "a"), (f, "a")), - "c8": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "d3": (f, "c5", "c6"), - "d4": (f, "c7", "c8"), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - }) + expected = with_deps( + { + "a": 1, + "c1": (f, (f, "a"), (f, "a")), + "c2": (f, (f, "a"), (f, "a")), + "c3": (f, (f, "a"), (f, "a")), + "c4": (f, (f, "a"), (f, "a")), + "c5": (f, (f, "a"), (f, "a")), + "c6": (f, (f, "a"), (f, "a")), + "c7": (f, (f, "a"), (f, "a")), + "c8": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "d3": (f, "c5", "c6"), + "d4": (f, "c7", "c8"), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + } + ) assert fuse(d, ave_width=2, rename_keys=False) == expected assert fuse(d, ave_width=2.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-c1": (f, (f, "a"), (f, "a")), - "b3-b4-c2": (f, (f, "a"), (f, "a")), - "b5-b6-c3": (f, (f, "a"), (f, "a")), - "b7-b8-c4": (f, (f, "a"), (f, "a")), - "b10-b9-c5": (f, (f, "a"), (f, "a")), - "b11-b12-c6": (f, (f, "a"), (f, "a")), - "b13-b14-c7": (f, (f, "a"), (f, "a")), - "b15-b16-c8": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "d3": (f, "c5", "c6"), - "d4": (f, "c7", "c8"), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - "c1": "b1-b2-c1", - "c2": "b3-b4-c2", - "c3": "b5-b6-c3", - "c4": "b7-b8-c4", - "c5": "b10-b9-c5", - "c6": "b11-b12-c6", - "c7": "b13-b14-c7", - "c8": "b15-b16-c8", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-c1": (f, (f, "a"), (f, "a")), + "b3-b4-c2": (f, (f, "a"), (f, "a")), + "b5-b6-c3": (f, (f, "a"), (f, "a")), + "b7-b8-c4": (f, (f, "a"), (f, "a")), + "b10-b9-c5": (f, (f, "a"), (f, "a")), + "b11-b12-c6": (f, (f, "a"), (f, "a")), + "b13-b14-c7": (f, (f, "a"), (f, "a")), + "b15-b16-c8": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "d3": (f, "c5", "c6"), + "d4": (f, "c7", "c8"), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + "c1": "b1-b2-c1", + "c2": "b3-b4-c2", + "c3": "b5-b6-c3", + "c4": "b7-b8-c4", + "c5": "b10-b9-c5", + "c6": "b11-b12-c6", + "c7": "b13-b14-c7", + "c8": "b15-b16-c8", + } + ) assert fuse(d, ave_width=2, rename_keys=True) == expected assert fuse(d, ave_width=2.9, 
rename_keys=True) == expected - expected = with_deps({ - "a": 1, - "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d3": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d4": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - }) + expected = with_deps( + { + "a": 1, + "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d3": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d4": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + } + ) assert fuse(d, ave_width=3, rename_keys=False) == expected assert fuse(d, ave_width=4.6, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-b3-b4-c1-c2-d1": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b5-b6-b7-b8-c3-c4-d2": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b10-b11-b12-b9-c5-c6-d3": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b13-b14-b15-b16-c7-c8-d4": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - "d1": "b1-b2-b3-b4-c1-c2-d1", - "d2": "b5-b6-b7-b8-c3-c4-d2", - "d3": "b10-b11-b12-b9-c5-c6-d3", - "d4": "b13-b14-b15-b16-c7-c8-d4", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-b3-b4-c1-c2-d1": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b5-b6-b7-b8-c3-c4-d2": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b10-b11-b12-b9-c5-c6-d3": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b13-b14-b15-b16-c7-c8-d4": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + "d1": "b1-b2-b3-b4-c1-c2-d1", + "d2": "b5-b6-b7-b8-c3-c4-d2", + "d3": "b10-b11-b12-b9-c5-c6-d3", + "d4": "b13-b14-b15-b16-c7-c8-d4", + } + ) assert fuse(d, ave_width=3, rename_keys=True) == expected assert fuse(d, ave_width=4.6, rename_keys=True) == expected - expected = with_deps({ - "a": 1, - "e1": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "e2": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "f": (f, "e1", "e2"), - }) - assert fuse(d, ave_width=4.7, rename_keys=False) == expected - assert fuse(d, ave_width=7.4, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "f": (f, "e1", "e2"), - "e1": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1", - "e2": "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2", - }) - assert fuse(d, ave_width=4.7, rename_keys=True) == expected - assert fuse(d, ave_width=7.4, rename_keys=True) == expected - assert fuse(d, ave_width=7.5, rename_keys=False) == with_deps({ - "a": 1, - "f": ( - f, - ( + expected = with_deps( + { + "a": 1, + "e1": ( f, (f, (f, (f, "a"), (f, 
"a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ( + "e2": ( f, (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ), - }) - assert fuse(d, ave_width=7.5, rename_keys=True) == with_deps({ - "a": 1, - "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f": ( - f, - ( + "f": (f, "e1", "e2"), + } + ) + assert fuse(d, ave_width=4.7, rename_keys=False) == expected + assert fuse(d, ave_width=7.4, rename_keys=False) == expected + expected = with_deps( + { + "a": 1, + "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1": ( f, (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ( + "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2": ( f, (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ), - "f": "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f", - }) + "f": (f, "e1", "e2"), + "e1": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1", + "e2": "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2", + } + ) + assert fuse(d, ave_width=4.7, rename_keys=True) == expected + assert fuse(d, ave_width=7.4, rename_keys=True) == expected + assert fuse(d, ave_width=7.5, rename_keys=False) == with_deps( + { + "a": 1, + "f": ( + f, + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ), + } + ) + assert fuse(d, ave_width=7.5, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f": ( + f, + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ), + "f": "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f", + } + ) d = {"a": 1, "b": (f, "a")} assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"b": (f, 1)}) @@ -710,11 +743,9 @@ def f(*args): d = {"a": 1, "b": (f, "a"), "c": (f, "a", "b"), "d": (f, "a", "c")} assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"a": 1, "d": (f, "a", (f, "a", (f, "a")))}) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a": 1, - "b-c-d": (f, "a", (f, "a", (f, "a"))), - "d": "b-c-d" - }) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + {"a": 1, "b-c-d": (f, "a", (f, "a", (f, "a"))), "d": "b-c-d"} + ) d = { "a": 1, @@ -728,21 +759,25 @@ def f(*args): expected = with_deps({"a": 1, "b2": (f, "a"), "e1": (f, (f, (f, (f, "a")))), "f": (f, "e1", "b2")}) assert fuse(d, ave_width=1, rename_keys=False) == expected assert fuse(d, ave_width=1.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b2": (f, "a"), - "b1-c1-d1-e1": (f, (f, (f, (f, "a")))), - "f": (f, "e1", "b2"), - "e1": "b1-c1-d1-e1", - }) + expected = with_deps( + { + "a": 1, + "b2": (f, "a"), + "b1-c1-d1-e1": (f, (f, (f, (f, "a")))), + "f": (f, "e1", "b2"), + "e1": "b1-c1-d1-e1", + } + ) assert fuse(d, ave_width=1, rename_keys=True) == expected assert fuse(d, ave_width=1.9, rename_keys=True) == expected 
assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "f": (f, (f, (f, (f, (f, "a")))), (f, "a"))}) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c1-d1-e1-f": (f, (f, (f, (f, (f, "a")))), (f, "a")), - "f": "b1-b2-c1-d1-e1-f", - }) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-c1-d1-e1-f": (f, (f, (f, (f, (f, "a")))), (f, "a")), + "f": "b1-b2-c1-d1-e1-f", + } + ) d = { "a": 1, @@ -753,37 +788,42 @@ def f(*args): "e1": (f, "a", "d1"), "f": (f, "a", "e1", "b2"), } - expected = with_deps({ - "a": 1, - "b2": (f, "a"), - "e1": (f, "a", (f, "a", (f, "a", (f, "a")))), - "f": (f, "a", "e1", "b2"), - }) + expected = with_deps( + { + "a": 1, + "b2": (f, "a"), + "e1": (f, "a", (f, "a", (f, "a", (f, "a")))), + "f": (f, "a", "e1", "b2"), + } + ) assert fuse(d, ave_width=1, rename_keys=False) == expected assert fuse(d, ave_width=1.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b2": (f, "a"), - "b1-c1-d1-e1": (f, "a", (f, "a", (f, "a", (f, "a")))), - "f": (f, "a", "e1", "b2"), - "e1": "b1-c1-d1-e1", - }) + expected = with_deps( + { + "a": 1, + "b2": (f, "a"), + "b1-c1-d1-e1": (f, "a", (f, "a", (f, "a", (f, "a")))), + "f": (f, "a", "e1", "b2"), + "e1": "b1-c1-d1-e1", + } + ) assert fuse(d, ave_width=1, rename_keys=True) == expected assert fuse(d, ave_width=1.9, rename_keys=True) == expected - assert fuse(d, ave_width=2, rename_keys=False) == with_deps({ - "a": 1, - "f": (f, "a", (f, "a", (f, "a", (f, "a", (f, "a")))), (f, "a")) - }) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c1-d1-e1-f": ( - f, - "a", - (f, "a", (f, "a", (f, "a", (f, "a")))), - (f, "a"), - ), - "f": "b1-b2-c1-d1-e1-f", - }) + assert fuse(d, ave_width=2, rename_keys=False) == with_deps( + {"a": 1, "f": (f, "a", (f, "a", (f, "a", (f, "a", (f, "a")))), (f, "a"))} + ) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-c1-d1-e1-f": ( + f, + "a", + (f, "a", (f, "a", (f, "a", (f, "a")))), + (f, "a"), + ), + "f": "b1-b2-c1-d1-e1-f", + } + ) d = { "a": 1, @@ -800,24 +840,28 @@ def f(*args): "f": (f, "e"), "g": (f, "f"), } - assert fuse(d, ave_width=1, rename_keys=False) == with_deps({ - "a": 1, - "d1": (f, (f, (f, "a"))), - "d2": (f, (f, (f, "a"))), - "d3": (f, (f, (f, "a"))), - "g": (f, (f, (f, "d1", "d2", "d3"))), - }) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a": 1, - "b1-c1-d1": (f, (f, (f, "a"))), - "b2-c2-d2": (f, (f, (f, "a"))), - "b3-c3-d3": (f, (f, (f, "a"))), - "e-f-g": (f, (f, (f, "d1", "d2", "d3"))), - "d1": "b1-c1-d1", - "d2": "b2-c2-d2", - "d3": "b3-c3-d3", - "g": "e-f-g", - }) + assert fuse(d, ave_width=1, rename_keys=False) == with_deps( + { + "a": 1, + "d1": (f, (f, (f, "a"))), + "d2": (f, (f, (f, "a"))), + "d3": (f, (f, (f, "a"))), + "g": (f, (f, (f, "d1", "d2", "d3"))), + } + ) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + { + "a": 1, + "b1-c1-d1": (f, (f, (f, "a"))), + "b2-c2-d2": (f, (f, (f, "a"))), + "b3-c3-d3": (f, (f, (f, "a"))), + "e-f-g": (f, (f, (f, "d1", "d2", "d3"))), + "d1": "b1-c1-d1", + "d2": "b2-c2-d2", + "d3": "b3-c3-d3", + "g": "e-f-g", + } + ) d = { "a": 1, @@ -828,23 +872,22 @@ def f(*args): "f": (f, "e"), "g": (f, "d", "f"), } - assert fuse(d, ave_width=1, rename_keys=False) == with_deps({ - "b": (f, 1), - "d": (f, "b", (f, "b")), - "g": (f, "d", (f, (f, "d"))) - }) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a-b": (f, 1), - "c-d": (f, "b", (f, "b")), - "e-f-g": 
(f, "d", (f, (f, "d"))), - "b": "a-b", - "d": "c-d", - "g": "e-f-g", - }) + assert fuse(d, ave_width=1, rename_keys=False) == with_deps( + {"b": (f, 1), "d": (f, "b", (f, "b")), "g": (f, "d", (f, (f, "d")))} + ) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + { + "a-b": (f, 1), + "c-d": (f, "b", (f, "b")), + "e-f-g": (f, "d", (f, (f, "d"))), + "b": "a-b", + "d": "c-d", + "g": "e-f-g", + } + ) def test_fuse_stressed(): - def f(*args): return args @@ -917,7 +960,6 @@ def f(*args): def test_fuse_reductions_multiple_input(): - def f(*args): return args @@ -925,12 +967,9 @@ def f(*args): assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"c": (f, (f, 1, 2))}) assert fuse(d, ave_width=2, rename_keys=True) == with_deps({"a1-a2-b-c": (f, (f, 1, 2)), "c": "a1-a2-b-c"}) assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"a1": 1, "a2": 2, "c": (f, (f, "a1", "a2"))}) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b-c": (f, (f, "a1", "a2")), - "c": "b-c" - }) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + {"a1": 1, "a2": 2, "b-c": (f, (f, "a1", "a2")), "c": "b-c"} + ) d = { "a1": 1, @@ -945,17 +984,17 @@ def f(*args): assert fuse(d, ave_width=2.9, rename_keys=False) == expected assert fuse(d, ave_width=1, rename_keys=True) == expected assert fuse(d, ave_width=2.9, rename_keys=True) == expected - assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a1": 1, - "a2": 2, - "c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")) - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b1-b2-b3-c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")), - "c": "b1-b2-b3-c", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + {"a1": 1, "a2": 2, "c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2"))} + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a1": 1, + "a2": 2, + "b1-b2-b3-c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")), + "c": "b1-b2-b3-c", + } + ) d = { "a1": 1, @@ -968,22 +1007,26 @@ def f(*args): } assert fuse(d, ave_width=1, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1, rename_keys=True) == with_deps(d) - assert fuse(d, ave_width=2, rename_keys=False) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "c1": (f, (f, "a1"), "b2"), - "c2": (f, "b2", (f, "a2")), - }) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "b1-c1": (f, (f, "a1"), "b2"), - "b3-c2": (f, "b2", (f, "a2")), - "c1": "b1-c1", - "c2": "b3-c2", - }) + assert fuse(d, ave_width=2, rename_keys=False) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "c1": (f, (f, "a1"), "b2"), + "c2": (f, "b2", (f, "a2")), + } + ) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "b1-c1": (f, (f, "a1"), "b2"), + "b3-c2": (f, "b2", (f, "a2")), + "c1": "b1-c1", + "c2": "b3-c2", + } + ) d = { "a1": 1, @@ -1000,19 +1043,23 @@ def f(*args): # A more aggressive heuristic could do this at `ave_width=2`. Perhaps # we can improve this. Nevertheless, this is behaving as intended. 
- assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "b1-b3-c1-c2-d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), - "d": "b1-b3-c1-c2-d", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), + } + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "b1-b3-c1-c2-d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), + "d": "b1-b3-c1-c2-d", + } + ) def func_with_kwargs(a, b, c=2): @@ -1028,20 +1075,13 @@ def test_SubgraphCallable(): apply, partial_by_order, ["in2"], - { - "function": func_with_kwargs, - "other": [(1, 20)], - "c": 4 - }, + {"function": func_with_kwargs, "other": [(1, 20)], "c": 4}, ), "c": ( apply, partial_by_order, ["in2", "in1"], - { - "function": func_with_kwargs, - "other": [(1, 20)] - }, + {"function": func_with_kwargs, "other": [(1, 20)]}, ), "d": (inc, "a"), "e": (add, "c", "d"), @@ -1105,54 +1145,60 @@ def test_fuse_subgraphs(): } res = fuse(dsk, "inc-6", fuse_subgraphs=True) - sol = with_deps({ - "inc-6": "add-inc-x-1", - "add-inc-x-1": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", (inc, (inc, "x-1"))), - "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), - }, - "inc-6", - (), + sol = with_deps( + { + "inc-6": "add-inc-x-1", + "add-inc-x-1": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", (inc, (inc, "x-1"))), + "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), + }, + "inc-6", + (), + ), ), - ), - }) + } + ) assert res == sol res = fuse(dsk, "inc-6", fuse_subgraphs=True, rename_keys=False) - sol = with_deps({ - "inc-6": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", (inc, (inc, "x-1"))), - "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), - }, - "inc-6", - (), - ), - ) - }) + sol = with_deps( + { + "inc-6": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", (inc, (inc, "x-1"))), + "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), + }, + "inc-6", + (), + ), + ) + } + ) assert res == sol res = fuse(dsk, "add-2", fuse_subgraphs=True) - sol = with_deps({ - "add-inc-x-1": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", (inc, (inc, "x-1"))), - "add-2": (add, "add-1", (inc, (inc, "add-1"))), - }, - "add-2", - (), + sol = with_deps( + { + "add-inc-x-1": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", (inc, (inc, "x-1"))), + "add-2": (add, "add-1", (inc, (inc, "add-1"))), + }, + "add-2", + (), + ), ), - ), - "add-2": "add-inc-x-1", - "inc-6": (inc, (inc, "add-2")), - }) + "add-2": "add-inc-x-1", + "inc-6": (inc, (inc, "add-2")), + } + ) assert res == sol res = fuse(dsk, "inc-2", fuse_subgraphs=True) @@ -1160,24 +1206,27 @@ def test_fuse_subgraphs(): sols = [] for inkeys in itertools.permutations(("x-1", "inc-2")): sols.append( - with_deps({ - "x-1": 1, - "inc-2": (inc, (inc, "x-1")), - "inc-6": "inc-add-1", - "inc-add-1": ( - SubgraphCallable( - { - "add-1": (add, "x-1", "inc-2"), - "inc-6": ( - inc, - (inc, (add, "add-1", (inc, (inc, "add-1")))), - ), - }, - "inc-6", - inkeys, - ), - ) + inkeys, - }) + with_deps( + { + "x-1": 1, + "inc-2": (inc, (inc, "x-1")), + "inc-6": "inc-add-1", + "inc-add-1": ( + SubgraphCallable( + { + 
"add-1": (add, "x-1", "inc-2"), + "inc-6": ( + inc, + (inc, (add, "add-1", (inc, (inc, "add-1")))), + ), + }, + "inc-6", + inkeys, + ), + ) + + inkeys, + } + ) ) assert res in sols @@ -1186,22 +1235,25 @@ def test_fuse_subgraphs(): sols = [] for inkeys in itertools.permutations(("x-1", "inc-2")): sols.append( - with_deps({ - "x-1": 1, - "inc-2": (inc, (inc, "x-1")), - "inc-add-1": ( - SubgraphCallable( - { - "add-1": (add, "x-1", "inc-2"), - "add-2": (add, "add-1", (inc, (inc, "add-1"))), - }, - "add-2", - inkeys, - ), - ) + inkeys, - "add-2": "inc-add-1", - "inc-6": (inc, (inc, "add-2")), - }) + with_deps( + { + "x-1": 1, + "inc-2": (inc, (inc, "x-1")), + "inc-add-1": ( + SubgraphCallable( + { + "add-1": (add, "x-1", "inc-2"), + "add-2": (add, "add-1", (inc, (inc, "add-1"))), + }, + "add-2", + inkeys, + ), + ) + + inkeys, + "add-2": "inc-add-1", + "inc-6": (inc, (inc, "add-2")), + } + ) ) assert res in sols @@ -1217,23 +1269,25 @@ def test_fuse_subgraphs_linear_chains_of_duplicate_deps(): } res = fuse(dsk, "add-5", fuse_subgraphs=True) - sol = with_deps({ - "add-x-1": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", "x-1"), - "add-2": (add, "add-1", "add-1"), - "add-3": (add, "add-2", "add-2"), - "add-4": (add, "add-3", "add-3"), - "add-5": (add, "add-4", "add-4"), - }, - "add-5", - (), + sol = with_deps( + { + "add-x-1": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", "x-1"), + "add-2": (add, "add-1", "add-1"), + "add-3": (add, "add-2", "add-2"), + "add-4": (add, "add-3", "add-3"), + "add-5": (add, "add-4", "add-4"), + }, + "add-5", + (), + ), ), - ), - "add-5": "add-x-1", - }) + "add-5": "add-x-1", + } + ) assert res == sol diff --git a/tests/core/serve/test_dag/test_order.py b/tests/core/serve/test_dag/test_order.py index 4b4f1589c87..d11c11504f4 100644 --- a/tests/core/serve/test_dag/test_order.py +++ b/tests/core/serve/test_dag/test_order.py @@ -20,14 +20,14 @@ def f(*args): def test_ordering_keeps_groups_together(abcde): a, b, c, d, e = abcde - d = dict(((a, i), (f, )) for i in range(4)) + d = dict(((a, i), (f,)) for i in range(4)) d.update({(b, 0): (f, (a, 0), (a, 1)), (b, 1): (f, (a, 2), (a, 3))}) o = order(d) assert abs(o[(a, 0)] - o[(a, 1)]) == 1 assert abs(o[(a, 2)] - o[(a, 3)]) == 1 - d = dict(((a, i), (f, )) for i in range(4)) + d = dict(((a, i), (f,)) for i in range(4)) d.update({(b, 0): (f, (a, 0), (a, 2)), (b, 1): (f, (a, 1), (a, 3))}) o = order(d) @@ -46,8 +46,8 @@ def test_avoid_broker_nodes(abcde): """ a, b, c, d, e = abcde dsk = { - (a, 0): (f, ), - (a, 1): (f, ), + (a, 0): (f,), + (a, 1): (f,), (b, 0): (f, (a, 0)), (b, 1): (f, (a, 1)), (b, 2): (f, (a, 1)), @@ -57,8 +57,8 @@ def test_avoid_broker_nodes(abcde): # Switch name of 0, 1 to ensure that this isn't due to string comparison dsk = { - (a, 1): (f, ), - (a, 0): (f, ), + (a, 1): (f,), + (a, 0): (f,), (b, 0): (f, (a, 1)), (b, 1): (f, (a, 0)), (b, 2): (f, (a, 0)), @@ -68,8 +68,8 @@ def test_avoid_broker_nodes(abcde): # Switch name of 0, 1 for "b"s too dsk = { - (a, 0): (f, ), - (a, 1): (f, ), + (a, 0): (f,), + (a, 1): (f,), (b, 1): (f, (a, 0)), (b, 0): (f, (a, 1)), (b, 2): (f, (a, 1)), @@ -161,10 +161,10 @@ def test_avoid_upwards_branching_complex(abcde): (a, 2): (f, (a, 3)), (a, 3): (f, (b, 1), (c, 1)), (b, 1): (f, (b, 2)), - (b, 2): (f, ), + (b, 2): (f,), (c, 1): (f, (c, 2)), (c, 2): (f, (c, 3)), - (c, 3): (f, ), + (c, 3): (f,), (d, 1): (f, (c, 1)), (d, 2): (f, (d, 1)), (d, 3): (f, (d, 1)), @@ -261,7 +261,7 @@ def test_prefer_short_dependents(abcde): during the long computations. 
""" a, b, c, d, e = abcde - dsk = {c: (f, ), d: (f, c), e: (f, c), b: (f, c), a: (f, b)} + dsk = {c: (f,), d: (f, c), e: (f, c), b: (f, c), a: (f, b)} o = order(dsk) assert o[d] < o[b] @@ -287,17 +287,16 @@ def test_run_smaller_sections(abcde): log = [] def f(x): - def _(*args): log.append(x) return _ dsk = { - a: (f(a), ), - c: (f(c), ), - e: (f(e), ), - cc: (f(cc), ), + a: (f(a),), + c: (f(c),), + e: (f(e),), + cc: (f(cc),), b: (f(b), a, c), d: (f(d), c, e), bb: (f(bb), cc), @@ -335,20 +334,19 @@ def test_local_parents_of_reduction(abcde): log = [] def f(x): - def _(*args): log.append(x) return _ dsk = { - a3: (f(a3), ), + a3: (f(a3),), a2: (f(a2), a3), a1: (f(a1), a2), - b3: (f(b3), ), + b3: (f(b3),), b2: (f(b2), b3, a2), b1: (f(b1), b2), - c3: (f(c3), ), + c3: (f(c3),), c2: (f(c2), c3, b2), c1: (f(c1), c2), } @@ -374,10 +372,10 @@ def test_nearest_neighbor(abcde): b1, b2, b3, b4 = [b + i for i in "1234"] dsk = { - b1: (f, ), - b2: (f, ), - b3: (f, ), - b4: (f, ), + b1: (f,), + b2: (f,), + b3: (f,), + b4: (f,), a1: (f, b1), a2: (f, b1), a3: (f, b1, b2), @@ -398,14 +396,14 @@ def test_nearest_neighbor(abcde): def test_string_ordering(): """Prefer ordering tasks by name first.""" - dsk = {("a", 1): (f, ), ("a", 2): (f, ), ("a", 3): (f, )} + dsk = {("a", 1): (f,), ("a", 2): (f,), ("a", 3): (f,)} o = order(dsk) assert o == {("a", 1): 0, ("a", 2): 1, ("a", 3): 2} def test_string_ordering_dependents(): """Prefer ordering tasks by name first even when in dependencies.""" - dsk = {("a", 1): (f, "b"), ("a", 2): (f, "b"), ("a", 3): (f, "b"), "b": (f, )} + dsk = {("a", 1): (f, "b"), ("a", 2): (f, "b"), ("a", 3): (f, "b"), "b": (f,)} o = order(dsk) assert o == {"b": 0, ("a", 1): 1, ("a", 2): 2, ("a", 3): 3} @@ -502,19 +500,19 @@ def test_map_overlap(abcde): """ a, b, c, d, e = abcde dsk = { - (e, 1): (f, ), + (e, 1): (f,), (d, 1): (f, (e, 1)), (c, 1): (f, (d, 1)), (b, 1): (f, (c, 1), (c, 2)), - (d, 2): (f, ), + (d, 2): (f,), (c, 2): (f, (d, 1), (d, 2), (d, 3)), - (e, 3): (f, ), + (e, 3): (f,), (d, 3): (f, (e, 3)), (c, 3): (f, (d, 3)), (b, 3): (f, (c, 2), (c, 3), (c, 4)), - (d, 4): (f, ), + (d, 4): (f,), (c, 4): (f, (d, 3), (d, 4), (d, 5)), - (e, 5): (f, ), + (e, 5): (f,), (d, 5): (f, (e, 5)), (c, 5): (f, (d, 5)), (b, 5): (f, (c, 4), (c, 5)), @@ -532,16 +530,16 @@ def test_use_structure_not_keys(abcde): """ a, b, _, _, _ = abcde dsk = { - (a, 0): (f, ), - (a, 1): (f, ), - (a, 2): (f, ), - (a, 3): (f, ), - (a, 4): (f, ), - (a, 5): (f, ), - (a, 6): (f, ), - (a, 7): (f, ), - (a, 8): (f, ), - (a, 9): (f, ), + (a, 0): (f,), + (a, 1): (f,), + (a, 2): (f,), + (a, 3): (f,), + (a, 4): (f,), + (a, 5): (f,), + (a, 6): (f,), + (a, 7): (f,), + (a, 8): (f,), + (a, 9): (f,), (b, 5): (f, (a, 2)), (b, 7): (f, (a, 0), (a, 2)), (b, 9): (f, (a, 7), (a, 0), (a, 2)), @@ -701,21 +699,25 @@ def test_order_with_equal_dependents(abcde): dsk = {} abc = [a, b, c, d] for x in abc: - dsk.update({ - (x, 0): 0, - (x, 1): (f, (x, 0)), - (x, 2, 0): (f, (x, 0)), - (x, 2, 1): (f, (x, 1)), - }) + dsk.update( + { + (x, 0): 0, + (x, 1): (f, (x, 0)), + (x, 2, 0): (f, (x, 0)), + (x, 2, 1): (f, (x, 1)), + } + ) for i, y in enumerate(abc): - dsk.update({ - (x, 3, i): (f, (x, 2, 0), (y, 2, 1)), # cross x and y - (x, 4, i): (f, (x, 3, i)), - (x, 5, i, 0): (f, (x, 4, i)), - (x, 5, i, 1): (f, (x, 4, i)), - (x, 6, i, 0): (f, (x, 5, i, 0)), - (x, 6, i, 1): (f, (x, 5, i, 1)), - }) + dsk.update( + { + (x, 3, i): (f, (x, 2, 0), (y, 2, 1)), # cross x and y + (x, 4, i): (f, (x, 3, i)), + (x, 5, i, 0): (f, (x, 4, i)), + (x, 5, i, 1): (f, (x, 4, i)), + 
(x, 6, i, 0): (f, (x, 5, i, 0)), + (x, 6, i, 1): (f, (x, 5, i, 1)), + } + ) o = order(dsk) total = 0 for x in abc: diff --git a/tests/core/serve/test_dag/test_rewrite.py b/tests/core/serve/test_dag/test_rewrite.py index 64055f72114..97fbaf25f37 100644 --- a/tests/core/serve/test_dag/test_rewrite.py +++ b/tests/core/serve/test_dag/test_rewrite.py @@ -21,7 +21,7 @@ def test_head(): def test_args(): - assert args((inc, 1)) == (1, ) + assert args((inc, 1)) == (1,) assert args((add, 1, 2)) == (1, 2) assert args(1) == () assert args([1, 2, 3]) == [1, 2, 3] @@ -65,16 +65,16 @@ def repl_list(sd): return (list, x) -rule6 = RewriteRule((list, "x"), repl_list, ("x", )) +rule6 = RewriteRule((list, "x"), repl_list, ("x",)) def test_RewriteRule(): # Test extraneous vars are removed, varlist is correct - assert rule1.vars == ("a", ) + assert rule1.vars == ("a",) assert rule1._varlist == ["a"] - assert rule2.vars == ("a", ) + assert rule2.vars == ("a",) assert rule2._varlist == ["a", "a"] - assert rule3.vars == ("a", ) + assert rule3.vars == ("a",) assert rule3._varlist == ["a", "a"] assert rule4.vars == ("a", "b") assert rule4._varlist == ["b", "a"] @@ -97,32 +97,13 @@ def test_RuleSet(): { add: ( { - VAR: ({ - VAR: ({}, [1]), - 1: ({}, [0]) - }, []), - inc: ({ - VAR: ({ - inc: ({ - VAR: ({}, [2, 3]) - }, []) - }, []) - }, []), + VAR: ({VAR: ({}, [1]), 1: ({}, [0])}, []), + inc: ({VAR: ({inc: ({VAR: ({}, [2, 3])}, [])}, [])}, []), }, [], ), - list: ({ - VAR: ({}, [5]) - }, []), - sum: ({ - list: ({ - VAR: ({ - VAR: ({ - VAR: ({}, [4]) - }, []) - }, []) - }, []) - }, []), + list: ({VAR: ({}, [5])}, []), + sum: ({list: ({VAR: ({VAR: ({VAR: ({}, [4])}, [])}, [])}, [])}, []), }, [], ) diff --git a/tests/core/serve/test_dag/test_task.py b/tests/core/serve/test_dag/test_task.py index cd7479f5d5c..260bc72d0bc 100644 --- a/tests/core/serve/test_dag/test_task.py +++ b/tests/core/serve/test_dag/test_task.py @@ -52,7 +52,7 @@ def test_get_dependencies_nested(): def test_get_dependencies_empty(): - dsk = {"x": (inc, )} + dsk = {"x": (inc,)} assert get_dependencies(dsk, "x") == set() assert get_dependencies(dsk, "x", as_list=True) == [] @@ -181,7 +181,6 @@ class MyException(Exception): pass class F: - def __eq__(self, other): raise MyException() @@ -200,9 +199,7 @@ def test_subs_with_surprisingly_friendly_eq(): def test_subs_unexpected_hashable_key(): - class UnexpectedButHashable: - def __init__(self): self.name = "a" diff --git a/tests/core/serve/test_dag/test_utils.py b/tests/core/serve/test_dag/test_utils.py index 29a914ec785..7ce379d006f 100644 --- a/tests/core/serve/test_dag/test_utils.py +++ b/tests/core/serve/test_dag/test_utils.py @@ -12,7 +12,6 @@ def test_funcname_long(): - def a_long_function_name_11111111111111111111111111111111111111111111111(): pass @@ -23,7 +22,6 @@ def a_long_function_name_11111111111111111111111111111111111111111111111(): @pytest.mark.skipif(not _CYTOOLZ_AVAILABLE, reason="the library `cytoolz` is not installed.") def test_funcname_cytoolz(): - @curry def foo(a, b, c): pass @@ -45,12 +43,11 @@ def test_partial_by_order(): def test_funcname(): assert funcname(np.floor_divide) == "floor_divide" assert funcname(partial(bool)) == "bool" - assert (funcname(operator.methodcaller("__getitem__")) == "operator.methodcaller('__getitem__')") + assert funcname(operator.methodcaller("__getitem__")) == "operator.methodcaller('__getitem__')" assert funcname(lambda x: x) == "lambda" def test_numpy_vectorize_funcname(): - def myfunc(a, b): """Return a-b if a>b, otherwise return a+b.""" if a > b: 
diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py index 007cd800ed4..17e094dd830 100644 --- a/tests/core/serve/test_gridbase_validations.py +++ b/tests/core/serve/test_gridbase_validations.py @@ -12,7 +12,6 @@ def test_metaclass_raises_if_expose_decorator_not_applied_to_method(): with pytest.raises(SyntaxError, match=r"expose.* decorator"): class FailedNoExposed(ModelComponent): - def __init__(self, model): pass @@ -23,7 +22,6 @@ def test_metaclass_raises_if_more_than_one_expose_decorator_applied(): with pytest.raises(SyntaxError, match=r"decorator must be applied to one"): class FailedTwoExposed(ModelComponent): - def __init__(self, model): pass @@ -44,7 +42,6 @@ def test_metaclass_raises_if_first_arg_in_init_is_not_model(): with pytest.raises(SyntaxError, match="__init__ must set 'model' as first"): class FailedModelArg(ModelComponent): - def __init__(self, foo): pass @@ -60,7 +57,6 @@ def test_metaclass_raises_if_second_arg_is_not_config(): with pytest.raises(SyntaxError, match="__init__ can only set 'config'"): class FailedConfig(ModelComponent): - def __init__(self, model, OTHER): pass @@ -76,7 +72,6 @@ def test_metaclass_raises_if_random_parameters_in_init(): with pytest.raises(SyntaxError, match="__init__ can only have 1 or 2 parameters"): class FailedInit(ModelComponent): - def __init__(self, model, config, FOO): pass @@ -93,7 +88,6 @@ def test_metaclass_raises_uses_restricted_method_name(): with pytest.raises(TypeError, match="bound methods/attrs named"): class FailedMethod_Inputs(ModelComponent): - def __init__(self, model): pass @@ -109,7 +103,6 @@ def inputs(self): with pytest.raises(TypeError, match="bound methods/attrs named"): class FailedMethod_Outputs(ModelComponent): - def __init__(self, model): pass @@ -125,7 +118,6 @@ def outputs(self): with pytest.raises(TypeError, match="bound methods/attrs named"): class FailedMethod_Name(ModelComponent): - def __init__(self, model): pass @@ -136,11 +128,12 @@ def predict(param): @property def uid(self): - return f'{self.uid}_SHOULD_NOT_RETURN' + return f"{self.uid}_SHOULD_NOT_RETURN" # Ensure that if we add more restricted names in the future, # there is a test for them as well. from flash.core.serve.component import _FLASH_SERVE_RESERVED_NAMES + assert set(_FLASH_SERVE_RESERVED_NAMES).difference({"inputs", "outputs", "uid"}) == set() @@ -149,7 +142,6 @@ def test_metaclass_raises_if_argument_values_of_expose_arent_subclasses_of_baset with pytest.raises(TypeError, match="must be subclass of"): class FailedExposedDecoratorInputs(ModelComponent): - def __init__(self, model): self.model = model @@ -162,7 +154,6 @@ def predict(param): with pytest.raises(TypeError, match="must be subclass of"): class FailedExposedDecoratorOutputs(ModelComponent): - def __init__(self, model): self.model = model @@ -175,7 +166,6 @@ def predict(param): with pytest.raises(TypeError, match="must be subclass of"): class FailedExposedDecoratorClass(ModelComponent): - def __init__(self, model): self.model = model @@ -197,7 +187,6 @@ class defiition time. from tests.core.serve.models import ClassificationInference class FailedExposedDecorator(ModelComponent): - def __init__(self, model): self.model = model @@ -220,7 +209,6 @@ class defiition time. """ class ConfigComponent(ModelComponent): - def __init__(self, model, config): pass @@ -241,7 +229,6 @@ class defiition time. 
""" class ConfigComponent(ModelComponent): - def __init__(self, model): pass diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py index 2d3cebef273..4efafb548cb 100644 --- a/tests/core/serve/test_integration.py +++ b/tests/core/serve/test_integration.py @@ -89,35 +89,21 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat assert meta.json() == { "definitions": { "Ep_Ep_In_Image": { - "properties": { - "data": { - "title": "Data", - "type": "string" - } - }, + "properties": {"data": {"title": "Data", "type": "string"}}, "required": ["data"], "title": "Ep_Ep_In_Image", "type": "object", }, "Ep_Payload": { - "properties": { - "ep_in_image": { - "$ref": "#/definitions/Ep_Ep_In_Image" - } - }, + "properties": {"ep_in_image": {"$ref": "#/definitions/Ep_Ep_In_Image"}}, "required": ["ep_in_image"], "title": "Ep_Payload", "type": "object", }, }, "properties": { - "payload": { - "$ref": "#/definitions/Ep_Payload" - }, - "session": { - "title": "Session", - "type": "string" - }, + "payload": {"$ref": "#/definitions/Ep_Payload"}, + "session": {"title": "Session", "type": "string"}, }, "required": ["payload"], "title": "Ep_RequestModel", @@ -134,9 +120,7 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat assert "result" in success.json() expected = { "session": "UUID", - "result": { - "ep_out_prediction": "goldfish, Carassius auratus" - }, + "result": {"ep_out_prediction": "goldfish, Carassius auratus"}, } assert expected == success.json() @@ -209,26 +193,15 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): body = { "session": "UUID", "payload": { - "image": { - "data": imgstr - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "image": {"data": imgstr}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat", json=body) assert success.json() == { - "result": { - "seat_number": 4799680, - "team": "buffalo bills, the ralph" - }, + "result": {"seat_number": 4799680, "team": "buffalo bills, the ralph"}, "session": "UUID", } resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") @@ -295,26 +268,15 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad body = { "session": "UUID", "payload": { - "image": { - "data": imgstr - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "image": {"data": imgstr}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat", json=body) assert success.json() == { - "result": { - "seat_number_out": 4799680, - "team_out": "buffalo bills, the ralph" - }, + "result": {"seat_number_out": 4799680, "team_out": "buffalo bills, the ralph"}, "session": "UUID", } resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") @@ -339,10 +301,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ "section": seat_comp.inputs.section, "row": seat_comp.inputs.row, }, - outputs={ - "seat_number": seat_comp.outputs.seat_number, - "team": seat_comp.outputs.team - }, + outputs={"seat_number": seat_comp.outputs.seat_number, "team": seat_comp.outputs.team}, ) ep2 = Endpoint( route="/predict_seat_img", @@ -366,10 +325,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ "section": seat_comp.inputs.section, "row": seat_comp.inputs.row, }, - outputs={ - 
"seat_number": seat_comp.outputs.seat_number, - "team": seat_comp.outputs.team - }, + outputs={"seat_number": seat_comp.outputs.seat_number, "team": seat_comp.outputs.team}, ) composit = Composition( @@ -402,26 +358,15 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ body = { "session": "UUID", "payload": { - "image": { - "data": imgstr - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "image": {"data": imgstr}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat", json=body) assert success.json() == { - "result": { - "seat_number": 4799680, - "team": "buffalo bills, the ralph" - }, + "result": {"seat_number": 4799680, "team": "buffalo bills, the ralph"}, "session": "UUID", } @@ -438,26 +383,15 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ body = { "session": "UUID", "payload": { - "stadium": { - "label": "buffalo bills, the ralph" - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "stadium": {"label": "buffalo bills, the ralph"}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat_img_two", json=body) assert success.json() == { - "result": { - "seat_number": 16960000, - "team": "buffalo bills, the ralph" - }, + "result": {"seat_number": 16960000, "team": "buffalo bills, the ralph"}, "session": "UUID", } @@ -476,6 +410,7 @@ def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1 def test_composition_from_url_torchscript_servable(tmp_path): from flash.core.serve import expose, ModelComponent, Servable from flash.core.serve.types import Number + """ # Tensor x Tensor class MyModule(torch.nn.Module): @@ -494,7 +429,6 @@ def forward(self, a, b): TORCHSCRIPT_DOWNLOAD_URL = "https://github.com/pytorch/pytorch/raw/95489b590f00801bdee7f41783f30874883cf6bb/test/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt" # noqa E501 class ComponentTwoModels(ModelComponent): - def __init__(self, model): self.encoder = model["encoder"] self.decoder = model["decoder"] @@ -523,15 +457,11 @@ def do_my_predict(self, inp): body = { "session": "UUID", "payload": { - "ep_in": { - "num": 10 - }, + "ep_in": {"num": 10}, }, } success = tc.post("http://127.0.0.1:8000/predictr", json=body) assert success.json() == { - "result": { - "ep_out": 1.0 - }, + "result": {"ep_out": 1.0}, "session": "UUID", } diff --git a/tests/core/serve/test_types/test_bbox.py b/tests/core/serve/test_types/test_bbox.py index fb4fbe26c09..ca58a8f2a9b 100644 --- a/tests/core/serve/test_types/test_bbox.py +++ b/tests/core/serve/test_types/test_bbox.py @@ -6,7 +6,7 @@ def test_deserialize(): bbox = BBox() - assert torch.allclose(bbox.deserialize((0, 0, 0, 0)), torch.zeros((4, ))) + assert torch.allclose(bbox.deserialize((0, 0, 0, 0)), torch.zeros((4,))) assert bbox.deserialize((0, 0, 0, 0)).shape == torch.Size([4]) with pytest.raises(ValueError): # only three elements, need four @@ -19,15 +19,17 @@ def test_deserialize(): bbox.deserialize({1: 1, 2: 2, 3: 3, 4: 4}) with pytest.raises(ValueError): # tuple instead of float - bbox.deserialize(( + bbox.deserialize( ( - 0, - 0, - ), - (0, 0), - (0, 0), - (0, 0), - )) + ( + 0, + 0, + ), + (0, 0), + (0, 0), + (0, 0), + ) + ) def test_serialize(): diff --git a/tests/core/serve/test_types/test_repeated.py b/tests/core/serve/test_types/test_repeated.py index 
b8fa64ef7e2..2038dd29ec6 100644 --- a/tests/core/serve/test_types/test_repeated.py +++ b/tests/core/serve/test_types/test_repeated.py @@ -12,11 +12,7 @@ def test_repeated_deserialize(): def test_repeated_serialize(session_global_datadir): repeated = Repeated(dtype=Label(path=str(session_global_datadir / "imagenet_labels.txt"))) - assert repeated.deserialize(*({ - "label": "chickadee" - }, { - "label": "stingray" - })) == ( + assert repeated.deserialize(*({"label": "chickadee"}, {"label": "stingray"})) == ( torch.tensor(19), torch.tensor(6), ) @@ -29,11 +25,7 @@ def test_repeated_max_len(): with pytest.raises(ValueError): repeated.deserialize(*({"label": "classA"}, {"label": "classA"}, {"label": "classB"})) - assert repeated.deserialize(*({ - "label": "classA" - }, { - "label": "classB" - })) == ( + assert repeated.deserialize(*({"label": "classA"}, {"label": "classB"})) == ( torch.tensor(0), torch.tensor(1), ) @@ -52,7 +44,6 @@ def test_repeated_max_len(): def test_repeated_non_serve_dtype(): - class NonServeDtype: pass diff --git a/tests/core/serve/test_types/test_table.py b/tests/core/serve/test_types/test_table.py index c1da29b7034..5bccc648923 100644 --- a/tests/core/serve/test_types/test_table.py +++ b/tests/core/serve/test_types/test_table.py @@ -65,14 +65,7 @@ def test_deserialize(): with pytest.raises(RuntimeError): table.deserialize({"title1": {0: 100}, "title2": {0: 200}}) assert torch.allclose( - table.deserialize({ - "t1": { - 0: 100.0 - }, - "t2": { - 1: 200.0 - } - }), + table.deserialize({"t1": {0: 100.0}, "t2": {1: 200.0}}), torch.tensor([[100.0, float("nan")], [float("nan"), 200.0]], dtype=torch.float64), equal_nan=True, ) diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py index 88097cc7136..6cfa7a2c50b 100644 --- a/tests/core/test_classification.py +++ b/tests/core/test_classification.py @@ -21,17 +21,17 @@ def test_classification_serializers(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] assert torch.allclose(torch.tensor(Logits().serialize(example_output)), example_output) assert torch.allclose(torch.tensor(Probabilities().serialize(example_output)), torch.softmax(example_output, -1)) assert Classes().serialize(example_output) == 2 - assert Labels(labels).serialize(example_output) == 'class_3' + assert Labels(labels).serialize(example_output) == "class_3" def test_classification_serializers_multi_label(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] assert torch.allclose(torch.tensor(Logits(multi_label=True).serialize(example_output)), example_output) assert torch.allclose( @@ -39,7 +39,7 @@ def test_classification_serializers_multi_label(): torch.sigmoid(example_output), ) assert Classes(multi_label=True).serialize(example_output) == [1, 2] - assert Labels(labels, multi_label=True).serialize(example_output) == ['class_2', 'class_3'] + assert Labels(labels, multi_label=True).serialize(example_output) == ["class_2", "class_3"] @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -48,24 +48,24 @@ def test_classification_serializers_fiftyone(): logits = torch.tensor([-0.1, 0.2, 0.3]) example_output = {DefaultDataKeys.PREDS: logits, DefaultDataKeys.METADATA: {"filepath": "something"}} # 3 classes - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] predictions = 
FiftyOneLabels(return_filepath=True).serialize(example_output) - assert predictions["predictions"].label == '2' + assert predictions["predictions"].label == "2" assert predictions["filepath"] == "something" predictions = FiftyOneLabels(labels, return_filepath=True).serialize(example_output) - assert predictions["predictions"].label == 'class_3' + assert predictions["predictions"].label == "class_3" assert predictions["filepath"] == "something" predictions = FiftyOneLabels(store_logits=True).serialize(example_output) assert torch.allclose(torch.tensor(predictions.logits), logits) assert torch.allclose(torch.tensor(predictions.confidence), torch.softmax(logits, -1)[-1]) - assert predictions.label == '2' + assert predictions.label == "2" predictions = FiftyOneLabels(labels, store_logits=True).serialize(example_output) - assert predictions.label == 'class_3' + assert predictions.label == "class_3" predictions = FiftyOneLabels(store_logits=True, multi_label=True).serialize(example_output) assert torch.allclose(torch.tensor(predictions.logits), logits) - assert [c.label for c in predictions.classifications] == ['1', '2'] + assert [c.label for c in predictions.classifications] == ["1", "2"] predictions = FiftyOneLabels(labels, multi_label=True).serialize(example_output) - assert [c.label for c in predictions.classifications] == ['class_2', 'class_3'] + assert [c.label for c in predictions.classifications] == ["class_2", "class_3"] diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 65e3759323e..156669a657e 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -21,9 +21,8 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): - return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item() + return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item() def __len__(self) -> int: return 10 diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index ad44cc7dbf7..809bfb41abc 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -24,9 +24,8 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index: int) -> Any: - return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1, )).item()} + return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1,)).item()} def __len__(self) -> int: return 100 @@ -34,7 +33,7 @@ def __len__(self) -> int: @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize( - "strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] + "strategy", ["no_freeze", "freeze", "freeze_unfreeze", "unfreeze_milestones", None, "cls", "chocolat"] ) def test_finetuning(tmpdir: str, strategy): train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -43,7 +42,7 @@ def test_finetuning(tmpdir: str, strategy): trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) if strategy == "cls": strategy = NoFreeze() - if strategy == 'chocolat' or strategy is None: + if strategy == "chocolat" or strategy is None: with pytest.raises(MisconfigurationException, match="strategy should be provided"): trainer.finetune(task, train_dl, val_dl, strategy=strategy) else: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index eb04ecdb68d..91d846a1269 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -51,16 +51,14 @@ class Image: class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index: int) -> 
Tuple[Tensor, Number]: - return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item() + return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item() def __len__(self) -> int: return 9 class PredictDummyDataset(DummyDataset): - def __getitem__(self, index: int) -> Tensor: return torch.rand(1, 28, 28) @@ -71,7 +69,6 @@ class DummyPostprocess(Postprocess): class FixedDataset(torch.utils.data.Dataset): - def __init__(self, targets): super().__init__() @@ -85,13 +82,12 @@ def __len__(self) -> int: class OnesModel(nn.Module): - def __init__(self): super().__init__() self.layer = nn.Linear(1, 2) - self.register_buffer('zeros', torch.zeros(2)) - self.register_buffer('zero_one', torch.tensor([0.0, 1.0])) + self.register_buffer("zeros", torch.zeros(2)) + self.register_buffer("zero_one", torch.tensor([0.0, 1.0])) def forward(self, x): x = self.layer(x) @@ -99,7 +95,6 @@ def forward(self, x): class Parent(ClassificationTask): - def __init__(self, child): super().__init__() @@ -119,7 +114,6 @@ def forward(self, x): class GrandParent(Parent): - def __init__(self, child): super().__init__(Parent(child)) @@ -229,24 +223,27 @@ def test_task_datapipeline_save(tmpdir): assert task.postprocess.test -@pytest.mark.parametrize(["cls", "filename"], [ - pytest.param( - ImageClassifier, - "image_classification_model.pt", - marks=pytest.mark.skipif( - not _IMAGE_TESTING, - reason="image packages aren't installed", - ) - ), - pytest.param( - TabularClassifier, - "tabular_classification_model.pt", - marks=pytest.mark.skipif( - not _TABULAR_TESTING, - reason="tabular packages aren't installed", - ) - ), -]) +@pytest.mark.parametrize( + ["cls", "filename"], + [ + pytest.param( + ImageClassifier, + "image_classification_model.pt", + marks=pytest.mark.skipif( + not _IMAGE_TESTING, + reason="image packages aren't installed", + ), + ), + pytest.param( + TabularClassifier, + "tabular_classification_model.pt", + marks=pytest.mark.skipif( + not _TABULAR_TESTING, + reason="tabular packages aren't installed", + ), + ), + ], +) def test_model_download(tmpdir, cls, filename): url = "https://flash-weights.s3.amazonaws.com/" with tmpdir.as_cwd(): @@ -283,7 +280,7 @@ def test_optimization(tmpdir): model, optimizer=torch.optim.Adadelta, scheduler=torch.optim.lr_scheduler.StepLR, - scheduler_kwargs={"step_size": 1} + scheduler_kwargs={"step_size": 1}, ) optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) @@ -319,7 +316,7 @@ def test_optimization(tmpdir): assert isinstance(optimizer[0], torch.optim.Adadelta) assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) expected = get_linear_schedule_with_warmup.__name__ - assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected + assert scheduler[0].lr_lambdas[0].__qualname__.split(".")[0] == expected def test_classification_task_metrics(): @@ -329,9 +326,8 @@ def test_classification_task_metrics(): model = OnesModel() class CheckAccuracy(Callback): - - def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: - assert math.isclose(trainer.callback_metrics['train_accuracy_epoch'], 0.5) + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5) task = ClassificationTask(model) trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy()) diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py index 3af891aa3a8..674a3a4616a 100644 --- 
a/tests/core/test_registry.py +++ b/tests/core/test_registry.py @@ -82,7 +82,7 @@ def my_model(nc_input=5, nc_output=6): assert all(callable(f) for f in functions) # test available keys - assert backbones.available_keys() == ['foo', 'foo', 'foo', 'foo', 'foo', 'my_model'] + assert backbones.available_keys() == ["foo", "foo", "foo", "foo", "foo", "my_model"] # todo (tchaton) Debug this test. @@ -100,8 +100,8 @@ def my_model(): assert caplog.messages == [ "Registering: my_model function with name: bar and metadata: {'foobar': True}", - 'Registering: my_model function with name: foo and metadata: {}', - 'Registering: my_model function with name: my_model and metadata: {}' + "Registering: my_model function with name: foo and metadata: {}", + "Registering: my_model function with name: my_model and metadata: {}", ] assert len(backbones) == 3 diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index 7bd330d83a8..436bb48a2e1 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -27,7 +27,6 @@ class DummyDataset(torch.utils.data.Dataset): - def __init__(self, predict: bool = False): self._predict = predict @@ -35,14 +34,13 @@ def __getitem__(self, index: int) -> Any: sample = torch.rand(1, 28, 28) if self._predict: return sample - return sample, torch.randint(10, size=(1, )).item() + return sample, torch.randint(10, size=(1,)).item() def __len__(self) -> int: return 100 class DummyClassifier(nn.Module): - def __init__(self): super().__init__() self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) @@ -85,7 +83,6 @@ def test_resolve_callbacks_invalid_strategy(tmpdir): class MultiFinetuneClassificationTask(ClassificationTask): - def configure_finetune_callback(self): return [NoFreeze(), NoFreeze()] @@ -99,7 +96,6 @@ def test_resolve_callbacks_multi_error(tmpdir): class FinetuneClassificationTask(ClassificationTask): - def configure_finetune_callback(self): return [NoFreeze()] @@ -115,14 +111,14 @@ def test_resolve_callbacks_override_warning(tmpdir): def test_add_argparse_args(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) - args = parser.parse_args(['--gpus=1']) + args = parser.parse_args(["--gpus=1"]) assert args.gpus == 1 def test_from_argparse_args(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) - args = parser.parse_args(['--max_epochs=200']) + args = parser.parse_args(["--max_epochs=200"]) trainer = Trainer.from_argparse_args(args) assert trainer.max_epochs == 200 assert isinstance(trainer, Trainer) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index 250aba1122f..49d24bf7ab7 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -20,7 +20,6 @@ class A: - def __call__(self, x): return True @@ -54,4 +53,4 @@ def test_get_callable_dict(): def test_download_data(tmpdir): path = os.path.join(tmpdir, "data") download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", path) - assert set(os.listdir(path)) == {'titanic', 'titanic.zip'} + assert set(os.listdir(path)) == {"titanic", "titanic.zip"} diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index 542277a3360..1b664a02e51 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -28,12 +28,12 @@ ) from tests.helpers.boring_model import BoringDataModule, BoringModel -torchvision_version = version.parse('0') +torchvision_version = version.parse("0") if _TORCHVISION_AVAILABLE: - torchvision_version = 
version.parse(__import__('torchvision').__version__) + torchvision_version = version.parse(__import__("torchvision").__version__) -@mock.patch('argparse.ArgumentParser.parse_args') +@mock.patch("argparse.ArgumentParser.parse_args") def test_default_args(mock_argparse, tmpdir): """Tests default argument parser for Trainer.""" mock_argparse.return_value = Namespace(**Trainer.default_attributes()) @@ -48,7 +48,7 @@ def test_default_args(mock_argparse, tmpdir): assert trainer.max_epochs == 5 -@pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []]) +@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []]) def test_add_argparse_args_redefined(cli_args): """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness.""" parser = LightningArgumentParser(add_help=False, parse_as_dict=False) @@ -60,7 +60,7 @@ def test_add_argparse_args_redefined(cli_args): pickle.dumps(args) # Check few deprecated args are not in namespace: - for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'): + for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"): assert depr_name not in args trainer = Trainer.from_argparse_args(args=args) @@ -70,19 +70,19 @@ def test_add_argparse_args_redefined(cli_args): @pytest.mark.parametrize( - ['cli_args', 'expected'], + ["cli_args", "expected"], [ - ('--auto_lr_find=True --auto_scale_batch_size=power', dict(auto_lr_find=True, auto_scale_batch_size='power')), + ("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")), ( - '--auto_lr_find any_string --auto_scale_batch_size ON', - dict(auto_lr_find='any_string', auto_scale_batch_size=True), + "--auto_lr_find any_string --auto_scale_batch_size ON", + dict(auto_lr_find="any_string", auto_scale_batch_size=True), ), - ('--auto_lr_find=Yes --auto_scale_batch_size=On', dict(auto_lr_find=True, auto_scale_batch_size=True)), - ('--auto_lr_find Off --auto_scale_batch_size No', dict(auto_lr_find=False, auto_scale_batch_size=False)), - ('--auto_lr_find TRUE --auto_scale_batch_size FALSE', dict(auto_lr_find=True, auto_scale_batch_size=False)), - ('--limit_train_batches=100', dict(limit_train_batches=100)), - ('--limit_train_batches 0.8', dict(limit_train_batches=0.8)), - ('--weights_summary=null', dict(weights_summary=None)), + ("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)), + ("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)), + ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)), + ("--limit_train_batches=100", dict(limit_train_batches=100)), + ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)), + ("--weights_summary=null", dict(weights_summary=None)), ( "", dict( @@ -96,14 +96,14 @@ def test_add_argparse_args_redefined(cli_args): weights_save_path=None, truncated_bptt_steps=None, resume_from_checkpoint=None, - profiler=None + profiler=None, ), ), ], ) def test_parse_args_parsing(cli_args, expected): """Test parsing simple types and None optionals not modified.""" - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.split(" ") if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): @@ -115,14 +115,11 @@ def 
test_parse_args_parsing(cli_args, expected): @pytest.mark.parametrize( - ['cli_args', 'expected', 'instantiate'], + ["cli_args", "expected", "instantiate"], [ - (['--gpus', '[0, 2]'], dict(gpus=[0, 2]), False), - (['--tpu_cores=[1,3]'], dict(tpu_cores=[1, 3]), False), - (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={ - 5: 3, - 10: 20 - }), True), + (["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False), + (["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False), + (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True), ], ) def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): @@ -139,17 +136,17 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): @pytest.mark.parametrize( - ['cli_args', 'expected_gpu'], + ["cli_args", "expected_gpu"], [ - ('--gpus 1', [0]), - ('--gpus 0,', [0]), - ('--gpus 0,1', [0, 1]), + ("--gpus 1", [0]), + ("--gpus 0,", [0]), + ("--gpus 0,1", [0, 1]), ], ) def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" monkeypatch.setattr("torch.cuda.device_count", lambda: 2) - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.split(" ") if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): @@ -164,7 +161,7 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): reason="signature inspection while mocking is not working in Python < 3.7 despite autospec", ) @pytest.mark.parametrize( - ['cli_args', 'extra_args'], + ["cli_args", "extra_args"], [ ({}, {}), (dict(logger=False), {}), @@ -176,7 +173,7 @@ def test_init_from_argparse_args(cli_args, extra_args): unknown_args = dict(unknown_arg=0) # unkown args in the argparser/namespace should be ignored - with mock.patch('pytorch_lightning.Trainer.__init__', autospec=True, return_value=None) as init: + with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init: trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args) expected = dict(cli_args) expected.update(extra_args) # extra args should override any cli arg @@ -188,7 +185,6 @@ def test_init_from_argparse_args(cli_args, extra_args): class Model(LightningModule): - def __init__(self, model_param: int): super().__init__() self.model_param = model_param @@ -199,14 +195,12 @@ def model_builder(model_param: int) -> Model: def trainer_builder( - limit_train_batches: int, - fast_dev_run: bool = False, - callbacks: Optional[Union[List[Callback], Callback]] = None + limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None ) -> Trainer: return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks) -@pytest.mark.parametrize(['trainer_class', 'model_class'], [(Trainer, Model), (trainer_builder, model_builder)]) +@pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (trainer_builder, model_builder)]) def test_lightning_cli(trainer_class, model_class, monkeypatch): """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" @@ -225,79 +219,75 @@ def fit(trainer, model): def on_train_start(callback, trainer, _): config_dump = callback.parser.dump(callback.config, skip_none=False) for k, v in expected_model.items(): - assert f' 
{k}: {v}' in config_dump + assert f" {k}: {v}" in config_dump for k, v in expected_trainer.items(): - assert f' {k}: {v}' in config_dump + assert f" {k}: {v}" in config_dump trainer.ran_asserts = True - monkeypatch.setattr(Trainer, 'fit', fit) - monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start) + monkeypatch.setattr(Trainer, "fit", fit) + monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start) - with mock.patch('sys.argv', ['any.py', '--model.model_param=7', '--trainer.limit_train_batches=100']): + with mock.patch("sys.argv", ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]): cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback) - assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts + assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts def test_lightning_cli_args_callbacks(tmpdir): callbacks = [ dict( - class_path='pytorch_lightning.callbacks.LearningRateMonitor', - init_args=dict(logging_interval='epoch', log_momentum=True) + class_path="pytorch_lightning.callbacks.LearningRateMonitor", + init_args=dict(logging_interval="epoch", log_momentum=True), ), - dict(class_path='pytorch_lightning.callbacks.ModelCheckpoint', init_args=dict(monitor='NAME')), + dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")), ] class TestModel(BoringModel): - def on_fit_start(self): callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)] assert len(callback) == 1 - assert callback[0].logging_interval == 'epoch' + assert callback[0].logging_interval == "epoch" assert callback[0].log_momentum is True callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] assert len(callback) == 1 - assert callback[0].monitor == 'NAME' + assert callback[0].monitor == "NAME" self.trainer.ran_asserts = True - with mock.patch('sys.argv', ['any.py', f'--trainer.callbacks={json.dumps(callbacks)}']): + with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks={json.dumps(callbacks)}"]): cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) assert cli.trainer.ran_asserts def test_lightning_cli_configurable_callbacks(tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_lightning_class_args(LearningRateMonitor, 'learning_rate_monitor') + parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor") cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - '--learning_rate_monitor.logging_interval=epoch', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--learning_rate_monitor.logging_interval=epoch", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModel) callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)] assert len(callback) == 1 - assert callback[0].logging_interval == 'epoch' + assert callback[0].logging_interval == "epoch" def test_lightning_cli_args_cluster_environments(tmpdir): - plugins = [dict(class_path='pytorch_lightning.plugins.environments.SLURMEnvironment')] + plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")] class TestModel(BoringModel): - def on_fit_start(self): # Ensure SLURMEnvironment is set, instead of default LightningEnvironment assert 
isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment) self.trainer.ran_asserts = True - with mock.patch('sys.argv', ['any.py', f'--trainer.plugins={json.dumps(plugins)}']): + with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]): cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) assert cli.trainer.ran_asserts @@ -306,78 +296,78 @@ def on_fit_start(self): def test_lightning_cli_args(tmpdir): cli_args = [ - f'--data.data_dir={tmpdir}', - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - '--trainer.weights_summary=null', - '--seed_everything=1234', + f"--data.data_dir={tmpdir}", + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--trainer.weights_summary=null", + "--seed_everything=1234", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): - cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={'callbacks': [LearningRateMonitor()]}) + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]}) - assert cli.config['seed_everything'] == 1234 - config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' + assert cli.config["seed_everything"] == 1234 + config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml" assert os.path.isfile(config_path) with open(config_path) as f: config = yaml.safe_load(f.read()) - assert 'model' not in config and 'model' not in cli.config # no arguments to include - assert config['data'] == cli.config['data'] - assert config['trainer'] == cli.config['trainer'] + assert "model" not in config and "model" not in cli.config # no arguments to include + assert config["data"] == cli.config["data"] + assert config["trainer"] == cli.config["trainer"] def test_lightning_cli_save_config_cases(tmpdir): - config_path = tmpdir / 'config.yaml' + config_path = tmpdir / "config.yaml" cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.logger=False', - '--trainer.fast_dev_run=1', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.logger=False", + "--trainer.fast_dev_run=1", ] # With fast_dev_run!=False config should not be saved - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): LightningCLI(BoringModel) assert not os.path.isfile(config_path) # With fast_dev_run==False config should be saved - cli_args[-1] = '--trainer.max_epochs=1' - with mock.patch('sys.argv', ['any.py'] + cli_args): + cli_args[-1] = "--trainer.max_epochs=1" + with mock.patch("sys.argv", ["any.py"] + cli_args): LightningCLI(BoringModel) assert os.path.isfile(config_path) # If run again on same directory exception should be raised since config file already exists - with mock.patch('sys.argv', ['any.py'] + cli_args), pytest.raises(RuntimeError): + with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError): LightningCLI(BoringModel) def test_lightning_cli_config_and_subclass_mode(tmpdir): config = dict( - model=dict(class_path='tests.helpers.boring_model.BoringModel'), - data=dict(class_path='tests.helpers.boring_model.BoringDataModule', init_args=dict(data_dir=str(tmpdir))), - trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None) + model=dict(class_path="tests.helpers.boring_model.BoringModel"), + data=dict(class_path="tests.helpers.boring_model.BoringDataModule", init_args=dict(data_dir=str(tmpdir))), + 
trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None), ) - config_path = tmpdir / 'config.yaml' - with open(config_path, 'w') as f: + config_path = tmpdir / "config.yaml" + with open(config_path, "w") as f: f.write(yaml.dump(config)) - with mock.patch('sys.argv', ['any.py', '--config', str(config_path)]): + with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]): cli = LightningCLI( BoringModel, BoringDataModule, subclass_mode_model=True, subclass_mode_data=True, - trainer_defaults={'callbacks': LearningRateMonitor()} + trainer_defaults={"callbacks": LearningRateMonitor()}, ) - config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' + config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml" assert os.path.isfile(config_path) with open(config_path) as f: config = yaml.safe_load(f.read()) - assert config['model'] == cli.config['model'] - assert config['data'] == cli.config['data'] - assert config['trainer'] == cli.config['trainer'] + assert config["model"] == cli.config["model"] + assert config["data"] == cli.config["data"] + assert config["trainer"] == cli.config["trainer"] def any_model_any_data_cli(): @@ -391,54 +381,52 @@ def any_model_any_data_cli(): def test_lightning_cli_help(): - cli_args = ['any.py', '--help'] + cli_args = ["any.py", "--help"] out = StringIO() - with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): + with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): any_model_any_data_cli() - assert '--print_config' in out.getvalue() - assert '--config' in out.getvalue() - assert '--seed_everything' in out.getvalue() - assert '--model.help' in out.getvalue() - assert '--data.help' in out.getvalue() + assert "--print_config" in out.getvalue() + assert "--config" in out.getvalue() + assert "--seed_everything" in out.getvalue() + assert "--model.help" in out.getvalue() + assert "--data.help" in out.getvalue() - skip_params = {'self'} + skip_params = {"self"} for param in inspect.signature(Trainer.__init__).parameters.keys(): if param not in skip_params: - assert f'--trainer.{param}' in out.getvalue() + assert f"--trainer.{param}" in out.getvalue() - cli_args = ['any.py', '--data.help=tests.helpers.boring_model.BoringDataModule'] + cli_args = ["any.py", "--data.help=tests.helpers.boring_model.BoringDataModule"] out = StringIO() - with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): + with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): any_model_any_data_cli() - assert '--data.init_args.data_dir' in out.getvalue() + assert "--data.init_args.data_dir" in out.getvalue() def test_lightning_cli_print_config(): cli_args = [ - 'any.py', - '--seed_everything=1234', - '--model=tests.helpers.boring_model.BoringModel', - '--data=tests.helpers.boring_model.BoringDataModule', - '--print_config', + "any.py", + "--seed_everything=1234", + "--model=tests.helpers.boring_model.BoringModel", + "--data=tests.helpers.boring_model.BoringDataModule", + "--print_config", ] out = StringIO() - with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): + with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): any_model_any_data_cli() outval = yaml.safe_load(out.getvalue()) - assert outval['seed_everything'] == 1234 - assert outval['model']['class_path'] == 'tests.helpers.boring_model.BoringModel' - assert outval['data']['class_path'] == 
'tests.helpers.boring_model.BoringDataModule' + assert outval["seed_everything"] == 1234 + assert outval["model"]["class_path"] == "tests.helpers.boring_model.BoringModel" + assert outval["data"]["class_path"] == "tests.helpers.boring_model.BoringDataModule" def test_lightning_cli_submodules(tmpdir): - class MainModule(BoringModel): - def __init__( self, submodule1: LightningModule, @@ -456,29 +444,27 @@ def __init__( submodule2: class_path: tests.helpers.boring_model.BoringModel """ - config_path = tmpdir / 'config.yaml' - with open(config_path, 'w') as f: + config_path = tmpdir / "config.yaml" + with open(config_path, "w") as f: f.write(config) cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - f'--config={str(config_path)}', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + f"--config={str(config_path)}", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = LightningCLI(MainModule) - assert cli.config['model']['main_param'] == 2 + assert cli.config["model"]["main_param"] == 2 assert isinstance(cli.model.submodule1, BoringModel) assert isinstance(cli.model.submodule2, BoringModel) -@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required') +@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required") def test_lightning_cli_torch_modules(tmpdir): - class TestModule(BoringModel): - def __init__( self, activation: torch.nn.Module = None, @@ -501,17 +487,17 @@ def __init__( init_args: size: 64 """ - config_path = tmpdir / 'config.yaml' - with open(config_path, 'w') as f: + config_path = tmpdir / "config.yaml" + with open(config_path, "w") as f: f.write(config) cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - f'--config={str(config_path)}', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + f"--config={str(config_path)}", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = LightningCLI(TestModule) assert isinstance(cli.model.activation, torch.nn.LeakyReLU) @@ -521,7 +507,6 @@ def __init__( class BoringModelRequiredClasses(BoringModel): - def __init__( self, num_classes: int, @@ -533,7 +518,6 @@ def __init__( class BoringDataModuleBatchSizeAndClasses(BoringDataModule): - def __init__( self, batch_size: int = 8, @@ -544,34 +528,31 @@ def __init__( def test_lightning_cli_link_arguments(tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.link_arguments('data.batch_size', 'model.batch_size') - parser.link_arguments('data.num_classes', 'model.num_classes', apply_on='instantiate') + parser.link_arguments("data.batch_size", "model.batch_size") + parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate") cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - '--data.batch_size=12', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--data.batch_size=12", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses) assert cli.model.batch_size == 12 assert cli.model.num_classes == 5 class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.link_arguments('data.batch_size', 
'model.init_args.batch_size') - parser.link_arguments('data.num_classes', 'model.init_args.num_classes', apply_on='instantiate') + parser.link_arguments("data.batch_size", "model.init_args.batch_size") + parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate") - cli_args[-1] = '--model=tests.core.utilities.test_lightning_cli.BoringModelRequiredClasses' + cli_args[-1] = "--model=tests.core.utilities.test_lightning_cli.BoringModelRequiredClasses" - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI( BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, @@ -583,68 +564,66 @@ def add_arguments_to_parser(self, parser): class EarlyExitTestModel(BoringModel): - def on_fit_start(self): raise KeyboardInterrupt() -@pytest.mark.parametrize('logger', (False, True)) +@pytest.mark.parametrize("logger", (False, True)) @pytest.mark.parametrize( - 'trainer_kwargs', ( - dict(accelerator='ddp_cpu'), - dict(accelerator='ddp_cpu', plugins="ddp_find_unused_parameters_false"), - ) + "trainer_kwargs", + ( + dict(accelerator="ddp_cpu"), + dict(accelerator="ddp_cpu", plugins="ddp_find_unused_parameters_false"), + ), ) def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs): - with mock.patch('sys.argv', ['any.py']), pytest.raises(KeyboardInterrupt): + with mock.patch("sys.argv", ["any.py"]), pytest.raises(KeyboardInterrupt): LightningCLI( EarlyExitTestModel, trainer_defaults={ - 'default_root_dir': str(tmpdir), - 'logger': logger, - 'max_steps': 1, - 'max_epochs': 1, + "default_root_dir": str(tmpdir), + "logger": logger, + "max_steps": 1, + "max_epochs": 1, **trainer_kwargs, - } + }, ) if logger: - config_dir = tmpdir / 'lightning_logs' + config_dir = tmpdir / "lightning_logs" # no more version dirs should get created - assert os.listdir(config_dir) == ['version_0'] - config_path = config_dir / 'version_0' / 'config.yaml' + assert os.listdir(config_dir) == ["version_0"] + config_path = config_dir / "version_0" / "config.yaml" else: - config_path = tmpdir / 'config.yaml' + config_path = tmpdir / "config.yaml" assert os.path.isfile(config_path) def test_cli_config_overwrite(tmpdir): - trainer_defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1} + trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1} - with mock.patch('sys.argv', ['any.py']): + with mock.patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, trainer_defaults=trainer_defaults) - with mock.patch('sys.argv', ['any.py']), pytest.raises(RuntimeError, match='Aborting to avoid overwriting'): + with mock.patch("sys.argv", ["any.py"]), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"): LightningCLI(BoringModel, trainer_defaults=trainer_defaults) - with mock.patch('sys.argv', ['any.py']): + with mock.patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults) def test_lightning_cli_optimizer(tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): parser.add_optimizer_args(torch.optim.Adam) cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", ] match = ( - 'BoringModel.configure_optimizers` will be overridden by ' - '`MyLightningCLI.add_configure_optimizers_method_to_model`' + "BoringModel.configure_optimizers` will 
be overridden by " + "`MyLightningCLI.add_configure_optimizers_method_to_model`" ) - with mock.patch('sys.argv', ['any.py'] + cli_args), pytest.warns(UserWarning, match=match): + with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.warns(UserWarning, match=match): cli = MyLightningCLI(BoringModel) assert cli.model.configure_optimizers is not BoringModel.configure_optimizers @@ -654,74 +633,67 @@ def add_arguments_to_parser(self, parser): def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): parser.add_optimizer_args(torch.optim.Adam) parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR) cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - '--lr_scheduler.gamma=0.8', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--lr_scheduler.gamma=0.8", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModel) assert cli.model.configure_optimizers is not BoringModel.configure_optimizers assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) assert len(cli.trainer.lr_schedulers) == 1 - assert isinstance(cli.trainer.lr_schedulers[0]['scheduler'], torch.optim.lr_scheduler.ExponentialLR) - assert cli.trainer.lr_schedulers[0]['scheduler'].gamma == 0.8 + assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.ExponentialLR) + assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8 def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam)) parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR)) optimizer_arg = dict( - class_path='torch.optim.Adam', + class_path="torch.optim.Adam", init_args=dict(lr=0.01), ) lr_scheduler_arg = dict( - class_path='torch.optim.lr_scheduler.StepLR', + class_path="torch.optim.lr_scheduler.StepLR", init_args=dict(step_size=50), ) cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - f'--optimizer={json.dumps(optimizer_arg)}', - f'--lr_scheduler={json.dumps(lr_scheduler_arg)}', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + f"--optimizer={json.dumps(optimizer_arg)}", + f"--lr_scheduler={json.dumps(lr_scheduler_arg)}", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModel) assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) assert len(cli.trainer.lr_schedulers) == 1 - assert isinstance(cli.trainer.lr_schedulers[0]['scheduler'], torch.optim.lr_scheduler.StepLR) - assert cli.trainer.lr_schedulers[0]['scheduler'].step_size == 50 + assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR) + assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(torch.optim.Adam, nested_key='optim1', link_to='model.optim1') - parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key='optim2', link_to='model.optim2') - 
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to='model.scheduler') + parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1") + parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler") class TestModel(BoringModel): - def __init__( self, optim1: dict, @@ -734,14 +706,14 @@ def __init__( self.scheduler = instantiate_class(self.optim1, scheduler) cli_args = [ - f'--trainer.default_root_dir={tmpdir}', - '--trainer.max_epochs=1', - '--optim2.class_path=torch.optim.SGD', - '--optim2.init_args.lr=0.01', - '--lr_scheduler.gamma=0.2', + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--optim2.class_path=torch.optim.SGD", + "--optim2.init_args.lr=0.01", + "--lr_scheduler.gamma=0.2", ] - with mock.patch('sys.argv', ['any.py'] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(TestModel) assert isinstance(cli.model.optim1, torch.optim.Adam) diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py index b3af1de2f54..5fe061c678e 100644 --- a/tests/examples/test_integrations.py +++ b/tests/examples/test_integrations.py @@ -25,15 +25,16 @@ @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "folder, file", [ + "folder, file", + [ pytest.param( "fiftyone", "image_classification.py", marks=pytest.mark.skipif( not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" - ) + ), ), - ] + ], ) def test_integrations(tmpdir, folder, file): run_test(str(root / "flash_examples" / "integrations" / folder / file)) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index bc3260b1a84..75a5d7cd5ff 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -40,40 +40,39 @@ ), pytest.param( "audio_classification.py", - marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed") + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"), ), pytest.param( "speech_recognition.py", - marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed") + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"), ), pytest.param( "image_classification.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), ), pytest.param( "image_classification_multi_label.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), ), # pytest.param("finetuning", "object_detection.py"), # TODO: takes too long. 
pytest.param( "semantic_segmentation.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), ), pytest.param( - "style_transfer.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + "style_transfer.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "summarization.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") ), pytest.param( "tabular_classification.py", - marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed") + marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), ), pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")), pytest.param( "text_classification.py", - marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") + marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"), ), # pytest.param( # "text_classification_multi_label.py", @@ -84,21 +83,21 @@ ), pytest.param( "video_classification.py", - marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed") + marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed"), ), pytest.param( "pointcloud_segmentation.py", - marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") + marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), ), pytest.param( "pointcloud_detection.py", - marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") + marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), ), pytest.param( "graph_classification.py", - marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed") + marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed"), ), - ] + ], ) def test_example(tmpdir, file): run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file)) @@ -106,12 +105,13 @@ def test_example(tmpdir, file): @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "file", [ + "file", + [ pytest.param( "pointcloud_detection.py", - marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") + marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), ), - ] + ], ) def test_example_2(tmpdir, file): run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file)) diff --git a/tests/examples/utils.py b/tests/examples/utils.py index 109b49466a1..f35c00cc0ce 100644 --- a/tests/examples/utils.py +++ b/tests/examples/utils.py @@ -21,10 +21,10 @@ def call_script( args: Optional[List[str]] = None, timeout: Optional[int] = 60 * 10, ) -> Tuple[int, str, str]: - with open(filepath, 'r') as original: + with open(filepath, "r") as original: data = original.read() - with open(filepath, 'w') as modified: + with open(filepath, "w") as modified: modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data) if args is None: @@ -41,7 +41,7 @@ def call_script( stdout = stdout.decode("utf-8") stderr = stderr.decode("utf-8") - with open(filepath, 'w') as modified: + with open(filepath, "w") as 
modified: modified.write(data) return p.returncode, stdout, stderr diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py index 8a8835e83c1..de4d08ff725 100644 --- a/tests/graph/classification/test_data.py +++ b/tests/graph/classification/test_data.py @@ -42,7 +42,7 @@ def test_smoke(self): assert dm is not None def test_from_datasets(self, tmpdir): - tudataset = TUDataset(root=tmpdir, name='KKI') + tudataset = TUDataset(root=tmpdir, name="KKI") train_dataset = tudataset val_dataset = tudataset test_dataset = tudataset @@ -58,7 +58,7 @@ def test_from_datasets(self, tmpdir): val_transform=None, test_transform=None, predict_transform=None, - batch_size=2 + batch_size=2, ) assert dm is not None assert dm.train_dataloader() is not None @@ -81,7 +81,7 @@ def test_from_datasets(self, tmpdir): assert list(data.y.size()) == [2] def test_transforms(self, tmpdir): - tudataset = TUDataset(root=tmpdir, name='KKI') + tudataset = TUDataset(root=tmpdir, name="KKI") train_dataset = tudataset val_dataset = tudataset test_dataset = tudataset diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index d25d3b55676..656d69f7290 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -38,7 +38,7 @@ def test_smoke(): @pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") def test_train(tmpdir): """Tests that the model can be trained on a pytorch geometric dataset.""" - tudataset = datasets.TUDataset(root=tmpdir, name='KKI') + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) train_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) @@ -49,7 +49,7 @@ def test_train(tmpdir): @pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") def test_val(tmpdir): """Tests that the model can be validated on a pytorch geometric dataset.""" - tudataset = datasets.TUDataset(root=tmpdir, name='KKI') + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) val_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) @@ -60,7 +60,7 @@ def test_val(tmpdir): @pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") def test_test(tmpdir): """Tests that the model can be tested on a pytorch geometric dataset.""" - tudataset = datasets.TUDataset(root=tmpdir, name='KKI') + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) test_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) @@ -71,7 +71,7 @@ def test_test(tmpdir): @pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") def test_predict_dataset(tmpdir): """Tests that we can generate predictions from a pytorch geometric dataset.""" - tudataset = datasets.TUDataset(root=tmpdir, name='KKI') + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) data_pipe = DataPipeline(preprocess=GraphClassificationPreprocess()) 
out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe) diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index a2c06420975..e7ece2c0b83 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -8,7 +8,6 @@ class RandomDataset(Dataset): - def __init__(self, size, length): self.len = length self.data = torch.randn(length, size) @@ -21,7 +20,6 @@ def __len__(self): class BoringModel(LightningModule): - def __init__(self): """Testing PL Module. @@ -70,7 +68,7 @@ def validation_step(self, batch, batch_idx): return {"x": loss} def validation_epoch_end(self, outputs) -> None: - torch.stack([x['x'] for x in outputs]).mean() + torch.stack([x["x"] for x in outputs]).mean() def test_step(self, batch, batch_idx): output = self(batch) @@ -99,7 +97,6 @@ def predict_dataloader(self): class BoringDataModule(LightningDataModule): - def __init__(self, data_dir: str = "./"): super().__init__() self.data_dir = data_dir diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 87cb183504f..e0fcb3c1e82 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -79,9 +79,9 @@ def test_from_filepaths_smoke(tmpdir): assert img_data.test_dataloader() is None data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [1, 2] @@ -111,24 +111,24 @@ def test_from_filepaths_list_image_paths(tmpdir): # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here # check validation data data = next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [1, 4] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [2, 5] @@ -216,7 +216,7 @@ def test_from_filepaths_splits(tmpdir): _rand_image(img_size).save(tmpdir / "s.png") num_samples: int = 10 - val_split: float = .3 + val_split: float = 0.3 train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)] @@ -227,7 +227,7 @@ def test_from_filepaths_splits(tmpdir): _to_tensor = { "to_tensor_transform": nn.Sequential( ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor) + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), } @@ -243,9 +243,9 @@ def run(transform: Any = None): image_size=img_size, ) data = next(iter(dm.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (B, 3, H, W) - assert labels.shape == (B, ) + assert labels.shape == (B,) run(_to_tensor) @@ -266,9 +266,9 @@ 
def test_from_folders_only_train(tmpdir): img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, ) + assert labels.shape == (1,) assert img_data.val_dataloader() is None assert img_data.test_dataloader() is None @@ -296,20 +296,20 @@ def test_from_folders_train_val(tmpdir): ) data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) data = next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [0, 0] data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [0, 0] @@ -338,18 +338,18 @@ def test_from_filepaths_multilabel(tmpdir): ) data = next(iter(dm.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) data = next(iter(dm.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) data = next(iter(dm.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(test_labels)) @@ -377,24 +377,24 @@ def test_from_data(data, from_function): # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here # check validation data data = next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [1, 4] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [2, 5] @@ -435,23 +435,23 @@ def test_from_fiftyone(tmpdir): # check train data data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [0, 1] # check val data data 
= next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [0, 1] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [0, 1] @@ -469,19 +469,19 @@ def test_from_datasets(): data = next(iter(img_data.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) # check validation data data = next(iter(img_data.val_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) # check test data data = next(iter(img_data.test_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) @pytest.fixture @@ -517,7 +517,7 @@ def test_from_csv_single_target(single_target_csv): data = next(iter(img_data.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) @pytest.fixture diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 5171c3f4370..3fb01b87f24 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -31,11 +31,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { DefaultDataKeys.INPUT: torch.rand(3, 224, 224), - DefaultDataKeys.TARGET: torch.randint(10, size=(1, )).item(), + DefaultDataKeys.TARGET: torch.randint(10, size=(1,)).item(), } def __len__(self) -> int: @@ -43,14 +42,13 @@ def __len__(self) -> int: class DummyMultiLabelDataset(torch.utils.data.Dataset): - def __init__(self, num_classes: int): self.num_classes = num_classes def __getitem__(self, index): return { DefaultDataKeys.INPUT: torch.rand(3, 224, 224), - DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes, )), + DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes,)), } def __len__(self) -> int: @@ -118,7 +116,7 @@ def test_multilabel(tmpdir): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index d0ef137a249..2c5b6706717 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -18,44 +18,53 @@ def _create_dummy_coco_json(dummy_json_path): dummy_json = { - "images": [{ - "id": 0, - 'width': 1920, - 'height': 1080, - 'file_name': 'sample_one.png', - }, { - "id": 1, - "width": 1920, - "height": 1080, - "file_name": "sample_two.png", - }], - "annotations": [{ - "id": 
1, - "image_id": 0, - "category_id": 0, - "area": 150, - "bbox": [30, 40, 20, 20], - "iscrowd": 0, - }, { - "id": 2, - "image_id": 1, - "category_id": 0, - "area": 240, - "bbox": [50, 100, 280, 15], - "iscrowd": 0, - }, { - "id": 3, - "image_id": 1, - "category_id": 0, - "area": 170, - "bbox": [230, 130, 90, 180], - "iscrowd": 0, - }], - "categories": [{ - "id": 0, - "name": "person", - "supercategory": "person", - }] + "images": [ + { + "id": 0, + "width": 1920, + "height": 1080, + "file_name": "sample_one.png", + }, + { + "id": 1, + "width": 1920, + "height": 1080, + "file_name": "sample_two.png", + }, + ], + "annotations": [ + { + "id": 1, + "image_id": 0, + "category_id": 0, + "area": 150, + "bbox": [30, 40, 20, 20], + "iscrowd": 0, + }, + { + "id": 2, + "image_id": 1, + "category_id": 0, + "area": 240, + "bbox": [50, 100, 280, 15], + "iscrowd": 0, + }, + { + "id": 3, + "image_id": 1, + "category_id": 0, + "area": 170, + "bbox": [230, 130, 90, 180], + "iscrowd": 0, + }, + ], + "categories": [ + { + "id": 0, + "name": "person", + "supercategory": "person", + } + ], } with open(dummy_json_path, "w") as fp: @@ -67,8 +76,8 @@ def _create_synth_coco_dataset(tmpdir): train_dir.mkdir() (train_dir / "images").mkdir() - Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_one.png") - Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_two.png") + Image.new("RGB", (1920, 1080)).save(train_dir / "images" / "sample_one.png") + Image.new("RGB", (1920, 1080)).save(train_dir / "images" / "sample_two.png") (train_dir / "annotations").mkdir() dummy_json = train_dir / "annotations" / "sample.json" @@ -84,8 +93,8 @@ def _create_synth_fiftyone_dataset(tmpdir): img_dir = Path(tmpdir / "fo_imgs") img_dir.mkdir() - Image.new('RGB', (1920, 1080)).save(img_dir / "sample_one.png") - Image.new('RGB', (1920, 1080)).save(img_dir / "sample_two.png") + Image.new("RGB", (1920, 1080)).save(img_dir / "sample_one.png") + Image.new("RGB", (1920, 1080)).save(img_dir / "sample_two.png") dataset = fo.Dataset.from_dir( img_dir, @@ -134,7 +143,7 @@ def test_image_detector_data_from_coco(tmpdir): assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -156,7 +165,7 @@ def test_image_detector_data_from_coco(tmpdir): assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] data = next(iter(datamodule.test_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] @@ -164,7 +173,7 @@ def test_image_detector_data_from_coco(tmpdir): assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -181,7 +190,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 
'area', 'iscrowd'] + assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -200,7 +209,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] data = next(iter(datamodule.test_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] @@ -208,4 +217,4 @@ def test_image_detector_data_from_fiftyone(tmpdir): assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + assert list(labels[0].keys()) == ["boxes", "labels", "image_id", "area", "iscrowd"] diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index cba7034319e..becfe6c594d 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -49,8 +49,8 @@ def test_detection(tmpdir, model, backbone): test_image_one = os.fspath(tmpdir / "test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") - Image.new('RGB', (512, 512)).save(test_image_one) - Image.new('RGB', (512, 512)).save(test_image_two) + Image.new("RGB", (512, 512)).save(test_image_one) + Image.new("RGB", (512, 512)).save(test_image_two) test_images = [str(test_image_one), str(test_image_two)] model.predict(test_images) @@ -73,8 +73,8 @@ def test_detection_fiftyone(tmpdir, model, backbone): test_image_one = os.fspath(tmpdir / "test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") - Image.new('RGB', (512, 512)).save(test_image_one) - Image.new('RGB', (512, 512)).save(test_image_two) + Image.new("RGB", (512, 512)).save(test_image_one) + Image.new("RGB", (512, 512)).save(test_image_two) test_images = [str(test_image_one), str(test_image_two)] model.predict(test_images) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index c9388a280c9..cfc5e57d232 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -32,7 +32,6 @@ def collate_fn(samples): class DummyDetectionDataset(Dataset): - def __init__(self, img_shape, num_boxes, num_classes, length): super().__init__() self.img_shape = img_shape @@ -45,14 +44,14 @@ def __len__(self) -> int: def _random_bbox(self): c, h, w = self.img_shape - xs = torch.randint(w - 1, (2, )) - ys = torch.randint(h - 1, (2, )) + xs = torch.randint(w - 1, (2,)) + ys = torch.randint(h - 1, (2,)) return [min(xs), min(ys), max(xs) + 1, max(ys) + 1] def __getitem__(self, idx): img = torch.rand(self.img_shape) boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) - labels = torch.randint(self.num_classes, (self.num_boxes, )) + labels = torch.randint(self.num_classes, (self.num_boxes,)) return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}} diff --git a/tests/image/detection/test_serialization.py b/tests/image/detection/test_serialization.py index f0c3d0e757b..8f707a229aa 100644 --- a/tests/image/detection/test_serialization.py +++ b/tests/image/detection/test_serialization.py @@ -9,7 +9,6 @@ @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't 
installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") class TestFiftyOneDetectionLabels: - @staticmethod def test_smoke(): serial = FiftyOneDetectionLabels() @@ -17,7 +16,7 @@ def test_smoke(): @staticmethod def test_serialize_fiftyone(): - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] serial = FiftyOneDetectionLabels() filepath_serial = FiftyOneDetectionLabels(return_filepath=True) threshold_serial = FiftyOneDetectionLabels(threshold=0.9) @@ -26,8 +25,7 @@ def test_serialize_fiftyone(): sample = { DefaultDataKeys.PREDS: [ { - "boxes": [torch.tensor(20), torch.tensor(30), - torch.tensor(40), torch.tensor(50)], + "boxes": [torch.tensor(20), torch.tensor(30), torch.tensor(40), torch.tensor(50)], "labels": torch.tensor(0), "scores": torch.tensor(0.5), }, diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index 2700c3a37e3..e823212ef75 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -23,7 +23,7 @@ @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/segmentation/test_backbones.py index 6d1c118812f..4b8fb7a7a78 100644 --- a/tests/image/segmentation/test_backbones.py +++ b/tests/image/segmentation/test_backbones.py @@ -17,10 +17,13 @@ from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES -@pytest.mark.parametrize(["backbone"], [ - pytest.param("resnet50", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), - pytest.param("dpn131", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), -]) +@pytest.mark.parametrize( + ["backbone"], + [ + pytest.param("resnet50", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), + pytest.param("dpn131", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), + ], +) def test_semantic_segmentation_backbones_registry(backbone): backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)() assert backbone diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py index 5a081a5f733..b44a68da0de 100644 --- a/tests/image/segmentation/test_data.py +++ b/tests/image/segmentation/test_data.py @@ -22,8 +22,8 @@ def build_checkboard(n, m, k=8): x = np.zeros((n, m)) - x[k::k * 2, ::k] = 1 - x[::k * 2, k::k * 2] = 1 + x[k :: k * 2, ::k] = 1 + x[:: k * 2, k :: k * 2] = 1 return x @@ -48,7 +48,6 @@ def create_random_data(image_files: List[str], label_files: List[str], size: Tup class TestSemanticSegmentationPreprocess: - @staticmethod @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.") def test_smoke(): @@ -57,7 +56,6 @@ def test_smoke(): class TestSemanticSegmentationData: - @staticmethod @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_smoke(): @@ -203,7 +201,7 @@ def test_from_files(tmpdir): test_targets=targets, batch_size=2, num_workers=0, - num_classes=num_classes + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None 
@@ -259,7 +257,7 @@ def test_from_files_warning(tmpdir): train_targets=targets + [str(tmp_dir / "labels_img4.png")], batch_size=2, num_workers=0, - num_classes=num_classes + num_classes=num_classes, ) @staticmethod @@ -370,7 +368,7 @@ def test_map_labels(tmpdir): val_targets=targets, batch_size=2, num_workers=0, - num_classes=num_classes + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py index f6bfb6fb247..dbc4b3b38e9 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/segmentation/test_heads.py @@ -24,11 +24,12 @@ @pytest.mark.parametrize( - "head", [ + "head", + [ pytest.param("fpn", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), pytest.param("deeplabv3", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), pytest.param("unet", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), - ] + ], ) def test_semantic_segmentation_heads_registry(head): img = torch.rand(1, 3, 32, 32) @@ -52,11 +53,11 @@ def test_pretrained_weights(mock_smp): SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=True) kwargs = { - 'arch': 'unet', - 'classes': 10, - 'encoder_name': 'resnet18', - 'in_channels': 3, - "encoder_weights": "imagenet" + "arch": "unet", + "classes": 10, + "encoder_name": "resnet18", + "in_channels": 3, + "encoder_weights": "imagenet", } mock_smp.create_model.assert_called_with(**kwargs) diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index 0c3c3bd7f66..79058bec3f8 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -125,7 +125,7 @@ def test_predict_numpy(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") @@ -160,7 +160,7 @@ def test_load_from_checkpoint_dependency_error(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_available_pretrained_weights(): - assert SemanticSegmentation.available_pretrained_weights("resnet18") == ['imagenet', 'ssl', 'swsl'] + assert SemanticSegmentation.available_pretrained_weights("resnet18") == ["imagenet", "ssl", "swsl"] @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") diff --git a/tests/image/segmentation/test_serialization.py b/tests/image/segmentation/test_serialization.py index 9d82f557a60..0e7477348a2 100644 --- a/tests/image/segmentation/test_serialization.py +++ b/tests/image/segmentation/test_serialization.py @@ -21,7 +21,6 @@ class TestSemanticSegmentationLabels: - @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") @staticmethod def test_smoke(): @@ -69,9 +68,7 @@ def test_serialize_fiftyone(): sample = { DefaultDataKeys.PREDS: preds, - DefaultDataKeys.METADATA: { - "filepath": "something" - }, + DefaultDataKeys.METADATA: {"filepath": "something"}, } segmentation = serial.serialize(sample) diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index 978dc002a89..cc9f80c629b 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -21,11 +21,16 @@ from 
flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES -@pytest.mark.parametrize(["backbone", "expected_num_features"], [ - pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")), - pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), -]) +@pytest.mark.parametrize( + ["backbone", "expected_num_features"], + [ + pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), + pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")), + pytest.param( + "mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") + ), + ], +) def test_image_classifier_backbones_registry(backbone, expected_num_features): backbone_fn = IMAGE_CLASSIFIER_BACKBONES.get(backbone) backbone_model, num_features = backbone_fn(pretrained=False) @@ -33,14 +38,20 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): assert num_features == expected_num_features -@pytest.mark.parametrize(["backbone", "pretrained", "expected_num_features"], [ - pytest.param( - "resnet50", "supervised", 2048, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") - ), - pytest.param( - "resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") - ), -]) +@pytest.mark.parametrize( + ["backbone", "pretrained", "expected_num_features"], + [ + pytest.param( + "resnet50", + "supervised", + 2048, + marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision"), + ), + pytest.param( + "resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") + ), + ], +) def test_pretrained_weights_registry(backbone, pretrained, expected_num_features): backbone_fn = IMAGE_CLASSIFIER_BACKBONES.get(backbone) backbone_model, num_features = backbone_fn(pretrained=pretrained) @@ -48,20 +59,22 @@ def test_pretrained_weights_registry(backbone, pretrained, expected_num_features assert num_features == expected_num_features -@pytest.mark.parametrize(["backbone", "pretrained"], [ - pytest.param("resnet50w2", True), - pytest.param("resnet50w4", "supervised"), -]) +@pytest.mark.parametrize( + ["backbone", "pretrained"], + [ + pytest.param("resnet50w2", True), + pytest.param("resnet50w4", "supervised"), + ], +) def test_wide_resnets(backbone, pretrained): with pytest.raises(KeyError, match="Supervised pretrained weights not available for {0}".format(backbone)): IMAGE_CLASSIFIER_BACKBONES.get(backbone)(pretrained=pretrained) def test_pretrained_backbones_catch_url_error(): - def raise_error_if_pretrained(pretrained=False): if pretrained: - raise urllib.error.URLError('Test error') + raise urllib.error.URLError("Test error") with pytest.warns(UserWarning, match="Failed to download pretrained weights"): catch_url_error(raise_error_if_pretrained)(pretrained=True) diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py index 2423022bf00..b337fa28daa 100644 --- a/tests/pointcloud/detection/test_data.py +++ b/tests/pointcloud/detection/test_data.py @@ -37,14 +37,13 @@ def test_pointcloud_object_detection_data(tmpdir): dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train")) class 
MockModel(PointCloudObjectDetector): - def training_step(self, batch, batch_idx: int): assert isinstance(batch, ObjectDetectBatchCollator) assert len(batch.point) == 2 assert batch.point[0][1].shape == torch.Size([4]) assert len(batch.bboxes) > 1 - assert batch.attr[0]["name"] in ('000000.bin', '000001.bin') - assert batch.attr[1]["name"] in ('000000.bin', '000001.bin') + assert batch.attr[0]["name"] in ("000000.bin", "000001.bin") + assert batch.attr[1]["name"] in ("000000.bin", "000001.bin") num_classes = 19 model = MockModel(backbone="pointpillars_kitti", num_classes=num_classes) diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py index b7d807c8377..deafc06fafd 100644 --- a/tests/pointcloud/detection/test_model.py +++ b/tests/pointcloud/detection/test_model.py @@ -21,4 +21,4 @@ def test_backbones(): backbones = PointCloudObjectDetector.available_backbones() - assert backbones == ['pointpillars', 'pointpillars_kitti'] + assert backbones == ["pointpillars", "pointpillars_kitti"] diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py index 9411c3639eb..a4c808fff22 100644 --- a/tests/pointcloud/segmentation/test_data.py +++ b/tests/pointcloud/segmentation/test_data.py @@ -34,7 +34,6 @@ def test_pointcloud_segmentation_data(tmpdir): dm = PointCloudSegmentationData.from_folders(train_folder=join(tmpdir, "SemanticKittiMicro", "train")) class MockModel(PointCloudSegmentation): - def training_step(self, batch, batch_idx: int): assert batch[DefaultDataKeys.INPUT]["xyz"][0].shape == torch.Size([2, 45056, 3]) assert batch[DefaultDataKeys.INPUT]["xyz"][1].shape == torch.Size([2, 11264, 3]) @@ -43,8 +42,8 @@ def training_step(self, batch, batch_idx: int): assert batch[DefaultDataKeys.INPUT]["labels"].shape == torch.Size([2, 45056]) assert batch[DefaultDataKeys.INPUT]["labels"].max() == 19 assert batch[DefaultDataKeys.INPUT]["labels"].min() == 0 - assert batch[DefaultDataKeys.METADATA][0]["name"] in ('00_000000', '00_000001') - assert batch[DefaultDataKeys.METADATA][1]["name"] in ('00_000000', '00_000001') + assert batch[DefaultDataKeys.METADATA][0]["name"] in ("00_000000", "00_000001") + assert batch[DefaultDataKeys.METADATA][1]["name"] in ("00_000000", "00_000001") num_classes = 19 model = MockModel(backbone="randlanet", num_classes=num_classes) diff --git a/tests/pointcloud/segmentation/test_model.py b/tests/pointcloud/segmentation/test_model.py index 13c4120a1bf..234f867e642 100644 --- a/tests/pointcloud/segmentation/test_model.py +++ b/tests/pointcloud/segmentation/test_model.py @@ -22,7 +22,7 @@ def test_backbones(): backbones = PointCloudSegmentation.available_backbones() - assert backbones == ['randlanet', 'randlanet_s3dis', 'randlanet_semantic_kitti', 'randlanet_toronto3d'] + assert backbones == ["randlanet", "randlanet_s3dis", "randlanet_semantic_kitti", "randlanet_toronto3d"] @pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index a2c11ddebdc..b1e9ef3f252 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -110,7 +110,7 @@ def test_tabular_data(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") @@ -138,7 +138,7 @@ def 
test_categorical_target(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") @@ -154,7 +154,7 @@ def test_from_data_frame(tmpdir): val_data_frame=val_data_frame, test_data_frame=test_data_frame, num_workers=0, - batch_size=1 + batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: data = next(iter(dl)) @@ -162,7 +162,7 @@ def test_from_data_frame(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") @@ -181,7 +181,7 @@ def test_from_csv(tmpdir): val_file=str(val_csv), test_file=str(test_csv), num_workers=0, - batch_size=1 + batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: data = next(iter(dl)) @@ -189,7 +189,7 @@ def test_from_csv(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index a64c2d090d4..e7ee5e9f5d8 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -28,15 +28,14 @@ class DummyDataset(torch.utils.data.Dataset): - def __init__(self, num_num=16, num_cat=16): super().__init__() self.num_num = num_num self.num_cat = num_cat def __getitem__(self, index): - target = torch.randint(0, 10, size=(1, )).item() - cat_vars = torch.randint(0, 10, size=(self.num_cat, )) + target = torch.randint(0, 10, size=(1,)).item() + cat_vars = torch.randint(0, 10, size=(self.num_cat,)) num_vars = torch.rand(self.num_num) return {DefaultDataKeys.INPUT: (cat_vars, num_vars), DefaultDataKeys.TARGET: target} @@ -83,7 +82,7 @@ def test_jit(tmpdir): model.eval() # torch.jit.script doesn't work with tabnet - model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)), )) + model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)),)) # TODO: torch.jit.save doesn't work with tabnet # path = os.path.join(tmpdir, "test.pt") diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py index 6bdec2f2ef8..b793849e088 100644 --- a/tests/template/classification/test_data.py +++ b/tests/template/classification/test_data.py @@ -49,7 +49,7 @@ def test_smoke(): def test_from_numpy(self): """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method.""" data = np.random.rand(10, self.num_features) - targets = np.random.randint(0, self.num_classes, (10, )) + targets = np.random.randint(0, self.num_classes, (10,)) # instantiate the data module dm = TemplateData.from_numpy( @@ -71,19 +71,19 @@ def test_from_numpy(self): data = next(iter(dm.train_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, self.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) # check val data data = next(iter(dm.val_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, self.num_features) - assert targets.shape == (2, ) 
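Note on the tabular `test_jit` hunk above: the model is traced with a doubly nested tuple because the outer tuple is the positional-argument tuple that `torch.jit.trace` expects, while the inner tuple is the single `(categorical, numerical)` input the tabular model takes, so the reformatted `((cat, num),)` spelling is equivalent to the old `((cat, num), )`. A minimal sketch of the same pattern, using a hypothetical two-input module in place of the real model (the class name and shapes below are illustrative only, not from the patch):

import torch


class TwoInputModel(torch.nn.Module):
    # Hypothetical stand-in for a model whose forward takes one tuple argument: (categorical, numerical).
    def forward(self, x):
        cat_vars, num_vars = x
        return cat_vars.float().sum(dim=1, keepdim=True) + num_vars.sum(dim=1, keepdim=True)


model = TwoInputModel().eval()
example_input = (torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4))
# trace(func, example_inputs): example_inputs is the tuple of positional arguments,
# so the single tuple-valued argument is wrapped once more -> (example_input,)
traced = torch.jit.trace(model, (example_input,))
assert traced(example_input).shape == (1, 1)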
+ assert targets.shape == (2,) # check test data data = next(iter(dm.test_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, self.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) @staticmethod def test_from_sklearn(): @@ -107,16 +107,16 @@ def test_from_sklearn(): data = next(iter(dm.train_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, dm.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) # check val data data = next(iter(dm.val_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, dm.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) # check test data data = next(iter(dm.test_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, dm.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index 9fa57b80b91..cfd0f77f39a 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -39,7 +39,7 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { DefaultDataKeys.INPUT: torch.randn(self.num_features), - DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, (1, ))[0], + DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, (1,))[0], } def __len__(self) -> int: @@ -121,7 +121,7 @@ def test_predict_sklearn(): @pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index b92c3757cc9..4c42909b35f 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -90,7 +90,7 @@ def test_test_valid(tmpdir): train_file=csv_path, val_file=csv_path, test_file=csv_path, - batch_size=1 + batch_size=1, ) batch = next(iter(dm.val_dataloader())) assert batch["labels"].item() in [0, 1] @@ -135,9 +135,7 @@ def test_text_module_not_found_error(): "cls, kwargs", [ (TextDataSource, {}), - (TextFileDataSource, { - "filetype": "csv" - }), + (TextFileDataSource, {"filetype": "csv"}), (TextCSVDataSource, {}), (TextJSONDataSource, {}), (TextSentencesDataSource, {}), diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index 4bf7db1c824..73da369e25e 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -29,11 +29,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { - "input_ids": torch.randint(1000, size=(100, )), - "labels": torch.randint(2, size=(1, )).item(), + "input_ids": torch.randint(1000, size=(100,)), + "labels": torch.randint(2, size=(1,)).item(), } def __len__(self) -> int: @@ -92,8 +91,11 @@ def test_load_from_checkpoint_dependency_error(): @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") @pytest.mark.parametrize( - "cli_args", (["flash", "text-classification", 
"--trainer.fast_dev_run", "True" - ], ["flash", "text-classification", "--trainer.fast_dev_run", "True", "from_toxic"]) + "cli_args", + ( + ["flash", "text-classification", "--trainer.fast_dev_run", "True"], + ["flash", "text-classification", "--trainer.fast_dev_run", "True", "from_toxic"], + ), ) def test_cli(cli_args): with mock.patch("sys.argv", cli_args): diff --git a/tests/text/seq2seq/core/test_data.py b/tests/text/seq2seq/core/test_data.py index 4f2144aa905..d52bd9132ac 100644 --- a/tests/text/seq2seq/core/test_data.py +++ b/tests/text/seq2seq/core/test_data.py @@ -36,22 +36,11 @@ @pytest.mark.parametrize( "cls, kwargs", [ - (Seq2SeqDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), - (Seq2SeqFileDataSource, { - "backbone": "sshleifer/tiny-mbart", - "filetype": "csv" - }), - (Seq2SeqCSVDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), - (Seq2SeqJSONDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), - (Seq2SeqSentencesDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), + (Seq2SeqDataSource, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqFileDataSource, {"backbone": "sshleifer/tiny-mbart", "filetype": "csv"}), + (Seq2SeqCSVDataSource, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqJSONDataSource, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqSentencesDataSource, {"backbone": "sshleifer/tiny-mbart"}), (Seq2SeqPostprocess, {}), ], ) diff --git a/tests/text/seq2seq/core/test_metrics.py b/tests/text/seq2seq/core/test_metrics.py index 692c4a80784..c16f828c37b 100644 --- a/tests/text/seq2seq/core/test_metrics.py +++ b/tests/text/seq2seq/core/test_metrics.py @@ -28,7 +28,7 @@ def test_rouge(): @pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)]) def test_bleu_score(smooth, expected): - translate_corpus = ['the cat is on the mat'.split()] - reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + translate_corpus = ["the cat is on the mat".split()] + reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]] metric = BLEUScore(smooth=smooth) assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4) diff --git a/tests/text/seq2seq/question_answering/test_model.py b/tests/text/seq2seq/question_answering/test_model.py index 3f2ee8f960e..ad4389b7680 100644 --- a/tests/text/seq2seq/question_answering/test_model.py +++ b/tests/text/seq2seq/question_answering/test_model.py @@ -29,11 +29,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { - "input_ids": torch.randint(1000, size=(128, )), - "labels": torch.randint(1000, size=(128, )), + "input_ids": torch.randint(1000, size=(128,)), + "labels": torch.randint(1000, size=(128,)), } def __len__(self) -> int: diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index ccff5e6d85f..c6adf69fdc6 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -29,11 +29,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { - "input_ids": torch.randint(1000, size=(128, )), - "labels": torch.randint(1000, size=(128, )), + "input_ids": torch.randint(1000, size=(128,)), + "labels": torch.randint(1000, size=(128,)), } def __len__(self) -> int: diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py index 27162491a06..f87a51fdcd7 100644 --- 
a/tests/text/seq2seq/translation/test_data.py +++ b/tests/text/seq2seq/translation/test_data.py @@ -79,7 +79,7 @@ def test_from_files(tmpdir): train_file=csv_path, val_file=csv_path, test_file=csv_path, - batch_size=1 + batch_size=1, ) batch = next(iter(dm.val_dataloader())) assert "labels" in batch diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index c49ccd4c24c..237fa3bb5ad 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -29,11 +29,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { - "input_ids": torch.randint(1000, size=(128, )), - "labels": torch.randint(1000, size=(128, )), + "input_ids": torch.randint(1000, size=(128,)), + "labels": torch.randint(1000, size=(128,)), } def __len__(self) -> int: diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index 3ba81eaa360..dca5dc81ab5 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -45,7 +45,7 @@ def create_dummy_video_frames(num_frames: int, height: int, width: int): for i in range(num_frames): xc = float(i) / num_frames yc = 1 - float(i) / (2 * num_frames) - d = torch.exp(-((x - xc)**2 + (y - yc)**2) / 2) * 255 + d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) return torch.stack(data, 0) @@ -152,28 +152,34 @@ def test_video_classifier_finetune(tmpdir): assert len(VideoClassifier.available_backbones()) > 5 train_transform = { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([ - UniformTemporalSubsample(8), - RandomShortSideScale(min_size=256, max_size=320), - RandomCrop(244), - RandomHorizontalFlip(p=0.5), - ]), - ), - ]), - "per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), + "post_tensor_transform": Compose( + [ + ApplyTransformToKey( + key="video", + transform=Compose( + [ + UniformTemporalSubsample(8), + RandomShortSideScale(min_size=256, max_size=320), + RandomCrop(244), + RandomHorizontalFlip(p=0.5), + ] + ), + ), + ] + ), + "per_batch_transform_on_device": Compose( + [ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), + data_format="BCTHW", + same_on_frame=False, + ), + ), + ] + ), } datamodule = VideoClassificationData.from_folders( @@ -182,7 +188,7 @@ def test_video_classifier_finetune(tmpdir): clip_duration=half_duration, video_sampler=SequentialSampler, decode_audio=False, - train_transform=train_transform + train_transform=train_transform, ) model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") @@ -222,28 +228,34 @@ def test_video_classifier_finetune_fiftyone(tmpdir): assert len(VideoClassifier.available_backbones()) > 5 train_transform = { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([ - UniformTemporalSubsample(8), - RandomShortSideScale(min_size=256, max_size=320), - RandomCrop(244), - RandomHorizontalFlip(p=0.5), - ]), - ), - ]), - 
"per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), + "post_tensor_transform": Compose( + [ + ApplyTransformToKey( + key="video", + transform=Compose( + [ + UniformTemporalSubsample(8), + RandomShortSideScale(min_size=256, max_size=320), + RandomCrop(244), + RandomHorizontalFlip(p=0.5), + ] + ), + ), + ] + ), + "per_batch_transform_on_device": Compose( + [ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), + data_format="BCTHW", + same_on_frame=False, + ), + ), + ] + ), } datamodule = VideoClassificationData.from_fiftyone( @@ -252,7 +264,7 @@ def test_video_classifier_finetune_fiftyone(tmpdir): clip_duration=half_duration, video_sampler=SequentialSampler, decode_audio=False, - train_transform=train_transform + train_transform=train_transform, ) model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50")