diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 30c4cd80b7..da8ff27410 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,20 +10,18 @@ # owners /.github/CODEOWNERS @williamfalcon # main -/README.md @edenlightning @ethanwharris +/README.md @ethanwharris @krshrimali # installation -/setup.py @borda @ethanwharris -/__about__.py @borda @ethanwharris -/__init__.py @borda @ethanwharris +/setup.py @borda @ethanwharris @krshrimali +/__about__.py @borda @ethanwharris @krshrimali +/__init__.py @borda @ethanwharris @krshrimali # CI/CD -/.github/workflows/ @borda @ethanwharris +/.github/workflows/ @borda @ethanwharris @krshrimali # configs in root -/*.yml @borda @ethanwharris +/*.yml @borda @ethanwharris @krshrimali # Docs -/docs/ @edenlightning @ethanwharris -/.github/*.md @edenlightning @ethanwharris -/.github/ISSUE_TEMPLATE/*.md @edenlightning @ethanwharris +/.github/ISSUE_TEMPLATE/*.md @borda @ethanwharris @krshrimali /docs/source/conf.py @borda @ethanwharris /flash/core/integrations/labelstudio @KonstantinKorotaev @niklub diff --git a/.github/labeler.yml b/.github/labeler.yml index 2f4664afde..b190ac2113 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -4,7 +4,6 @@ documentation: examples: - flash_examples/**/* - - flash_notebooks/**/* data: - flash/core/data/**/* diff --git a/.github/workflows/ci-notebook.yml b/.github/workflows/ci-notebook.yml deleted file mode 100644 index 92f3592631..0000000000 --- a/.github/workflows/ci-notebook.yml +++ /dev/null @@ -1,69 +0,0 @@ -name: CI notebook -on: - push: - branches: [master] - pull_request: - branches: [master] - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-20.04, macOS-10.15, windows-2022] - python-version: [3.8] # 3.6, - env: - TEST_ENV: TRUE - steps: - - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - # Note: This uses an internal pip API and may not always work - # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow - - name: Get pip cache - id: pip-cache - run: | - python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" - - - name: Cache pip - uses: actions/cache@v3 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} - restore-keys: | - ${{ runner.os }}-${{ matrix.python-version }}-pip- - - - name: Install dependencies - run: | - pip install -U pip wheel - pip install -e .[notebooks,image,tabular,text] --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html - - - name: Cache datasets - uses: actions/cache@v3 - with: - path: flash_examples/finetuning # This path is specific to Ubuntu - # Look to see if there is a cache hit for the corresponding requirements file - key: flash-datasets_finetuning - - - name: Cache datasets - uses: actions/cache@v3 - with: - path: flash_examples/predict # This path is specific to Ubuntu - # Look to see if there is a cache hit for the corresponding requirements file - key: flash-datasets_predict - - - name: Run Notebooks - env: - FIFTYONE_DO_NOT_TRACK: true - FLASH_TESTING: 1 - run: | - jupyter nbconvert --to script flash_notebooks/image_classification.ipynb - jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb - jupyter nbconvert --to script flash_notebooks/text_classification.ipynb - - ipython flash_notebooks/image_classification.py - ipython flash_notebooks/tabular_classification.py - ipython flash_notebooks/text_classification.py diff --git a/.gitignore b/.gitignore index 8eb6b037be..41cfb04542 100644 --- a/.gitignore +++ b/.gitignore @@ -144,8 +144,6 @@ titanic.csv data_folder *.pt *.zip -flash_notebooks/*.py -flash_notebooks/data /data MNIST* titanic diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a7ecde5ca..281370aca2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added fine tuning strategies for DeepSpeed (with parameter loading and storing omitted) ([#1377](https://github.com/Lightning-AI/lightning-flash/pull/1377)) +- Added `torchvision` as a requirement to `datatype_audio.txt` as it's used for Audio Classification ([#1425](https://github.com/Lightning-AI/lightning-flash/pull/1425)) + - Added `figsize` and `limit_nb_samples` for showing batch images ([#1381](https://github.com/Lightning-AI/lightning-flash/pull/1381)) - Added support for `from_lists` for Tabular Classification and Regression ([#1337](https://github.com/PyTorchLightning/lightning-flash/pull/1337)) @@ -48,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed when suitable error not being raised for image segmentation (kornia) ([#1425](https://github.com/Lightning-AI/lightning-flash/pull/1425)). + - Fixed the script of integrating `lightning-flash` with `learn2learn` ([#1376](https://github.com/Lightning-AI/lightning-flash/pull/1383)) - Fixed JIT tracing tests where the model class was not attached to the `Trainer` class ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410)) diff --git a/README.md b/README.md index 1734992079..26773424f4 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ model.serve() or make predictions from raw data directly. ```py -trainer = Trainer(accelerator='ddp', gpus=2) +trainer = Trainer(strategy='ddp', accelerator="gpu", gpus=2) dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB") predictions = trainer.predict(model, dm) ``` diff --git a/docs/source/_templates/theme_variables.jinja b/docs/source/_templates/theme_variables.jinja index 1aad8a99bc..121ac9f79c 100644 --- a/docs/source/_templates/theme_variables.jinja +++ b/docs/source/_templates/theme_variables.jinja @@ -6,7 +6,7 @@ 'docs': 'https://lightning-flash.readthedocs.io', 'twitter': 'https://twitter.com/PyTorchLightnin', 'discuss': 'https://pytorch-lightning.slack.com', - 'tutorials': 'https://github.com/PyTorchLightning/lightning-flash/tree/master/flash_notebooks', + 'tutorials': 'https://github.com/Lightning-AI/tutorials', 'previous_pytorch_versions': 'https://lightning-flash.readthedocs.io/en/stable', 'home': 'https://lightning-flash.readthedocs.io', 'get_started': 'https://lightning-flash.readthedocs.io/en/latest/quickstart.html', diff --git a/docs/source/governance.rst b/docs/source/governance.rst index 301f789c65..22594c8207 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -6,13 +6,13 @@ Flash Governance | Persons of interest Leads ----- - Ethan Harris (`ethanwharris `_) +- Kushashwa Ravi Shrimali (`krshrimali `_) - Thomas Chaton (`tchaton `_) -- William Falcon (`williamFalcon `_) Core Maintainers ---------------- +- William Falcon (`williamFalcon `_) - Jirka Borovec (`Borda `_) -- Kushashwa Ravi Shrimali (`krshrimali` `_) - Kaushik Bokka (`kaushikb11 `_) - Justus Schock (`justusschock `_) - Akihiro Nitta (`akihironitta `_) diff --git a/flash/audio/classification/input_transform.py b/flash/audio/classification/input_transform.py index cd0640cf91..22384d21e8 100644 --- a/flash/audio/classification/input_transform.py +++ b/flash/audio/classification/input_transform.py @@ -19,7 +19,7 @@ from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, requires if _TORCHVISION_AVAILABLE: from torchvision import transforms as T @@ -51,6 +51,7 @@ def train_per_sample_transform(self) -> Callable: ] ) + @requires("audio") def per_sample_transform(self) -> Callable: return T.Compose( [ diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index 83f5a636b4..3ae42575f9 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -15,6 +15,7 @@ import os import sys from copy import deepcopy +from enum import Enum from typing import Any, cast, Dict, Iterable, List, Sequence, Tuple, Union from pytorch_lightning.utilities.enums import LightningEnum @@ -171,7 +172,18 @@ def __init__(self, running_stage: RunningStage, *args: Any, **kwargs: Any) -> No def _call_load_sample(self, sample: Any) -> Any: # Deepcopy the sample to avoid leaks with complex data structures - return getattr(self, f"{_STAGES_PREFIX[self.running_stage]}_load_sample")(deepcopy(sample)) + sample_output = getattr(self, f"{_STAGES_PREFIX[self.running_stage]}_load_sample")(deepcopy(sample)) + + # Change DataKeys Enum to strings + if isinstance(sample_output, dict): + output_dict = {} + for key, val in sample_output.items(): + if isinstance(key, Enum) and hasattr(key, "value"): + output_dict[key.value] = val + else: + output_dict[key] = val + return output_dict + return sample_output @staticmethod def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]: diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 1c1ce6fe8b..bc22549655 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -163,7 +163,7 @@ class Image: ) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE -_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE]) +_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE]) _GRAPH_AVAILABLE = ( _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE and _NETWORKX_AVAILABLE ) diff --git a/flash/image/segmentation/input_transform.py b/flash/image/segmentation/input_transform.py index e36d6a2c49..ba4b7075da 100644 --- a/flash/image/segmentation/input_transform.py +++ b/flash/image/segmentation/input_transform.py @@ -17,7 +17,7 @@ from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys, kornia_collate, KorniaParallelTransforms -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, requires if _KORNIA_AVAILABLE: import kornia as K @@ -47,6 +47,7 @@ class SemanticSegmentationInputTransform(InputTransform): mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) + @requires("image") def train_per_sample_transform(self) -> Callable: return T.Compose( [ @@ -61,6 +62,7 @@ def train_per_sample_transform(self) -> Callable: ] ) + @requires("image") def per_sample_transform(self) -> Callable: return T.Compose( [ @@ -72,6 +74,7 @@ def per_sample_transform(self) -> Callable: ] ) + @requires("image") def predict_per_sample_transform(self) -> Callable: return ApplyToKeys( DataKeys.INPUT, diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index 8af2ca255d..b506dff891 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -146,6 +146,7 @@ def collate(self): trainer = flash.Trainer( max_epochs=1, gpus=1, + accelerator="gpu", precision=16, ) diff --git a/flash_notebooks/image_classification.ipynb b/flash_notebooks/image_classification.ipynb deleted file mode 100644 index 6ee85d359a..0000000000 --- a/flash_notebooks/image_classification.ipynb +++ /dev/null @@ -1,364 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "governing-statement", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "id": "touched-summary", - "metadata": {}, - "source": [ - "In this notebook, we'll go over the basics of lightning Flash by finetuning/predictin with an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", - "\n", - "# Finetuning\n", - "\n", - "Finetuning consists of four steps:\n", - " \n", - " - 1. Train a source neural network model on a source dataset. For computer vision, it is traditionally the [ImageNet dataset](http://www.image-net.org/search?q=cat). As training is costly, library such as [Torchvion](https://pytorch.org/docs/stable/torchvision/index.html) library supports popular pre-trainer model architectures . In this notebook, we will be using their [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/).\n", - " \n", - " - 2. Create a new neural network called the target model. Its architecture replicates the source model and parameters, expect the latest layer which is removed. This model without its latest layer is traditionally called a backbone\n", - " \n", - " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", - " \n", - " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is subclassing `pytorch_lightning.callbacks.BaseFinetuning`.\n", - " \n", - " \n", - "\n", - " \n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n", - " - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://www.pytorchlightning.ai/community)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "tested-torture", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "! pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]'" - ] - }, - { - "cell_type": "markdown", - "id": "present-region", - "metadata": {}, - "source": [ - "### The notebook runtime has to be re-started once Flash is installed." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "national-practice", - "metadata": {}, - "outputs": [], - "source": [ - "# https://github.com/streamlit/demo-self-driving/issues/17\n", - "if 'google.colab' in str(get_ipython()):\n", - " import os\n", - " os.kill(os.getpid(), 9)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "little-writer", - "metadata": {}, - "outputs": [], - "source": [ - "import flash\n", - "from flash.core.data.utils import download_data\n", - "from flash.image import ImageClassificationData, ImageClassifier" - ] - }, - { - "cell_type": "markdown", - "id": "federal-anaheim", - "metadata": {}, - "source": [ - "## 1. Download data\n", - "The data are downloaded from a URL, and save in a 'data' directory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "former-arcade", - "metadata": {}, - "outputs": [], - "source": [ - "download_data(\"https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip\", 'data/')" - ] - }, - { - "cell_type": "markdown", - "id": "attempted-serve", - "metadata": {}, - "source": [ - "

2. Load the data

\n", - "\n", - "Flash Tasks have built-in DataModules that you can use to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n", - "Creates a ImageClassificationData object from folders of images arranged in this way:\n", - "\n", - "\n", - " train/dog/xxx.png\n", - " train/dog/xxy.png\n", - " train/dog/xxz.png\n", - " train/cat/123.png\n", - " train/cat/nsdf3.png\n", - " train/cat/asd932.png\n", - "\n", - "\n", - "Note: Each sub-folder content will be considered as a new class." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "undefined-expert", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = ImageClassificationData.from_folders(\n", - " train_folder=\"data/hymenoptera_data/train/\",\n", - " val_folder=\"data/hymenoptera_data/val/\",\n", - " test_folder=\"data/hymenoptera_data/test/\",\n", - " batch_size=1,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "raised-groove", - "metadata": {}, - "source": [ - "### 3. Build the model\n", - "Create the ImageClassifier task. By default, the ImageClassifier task uses a [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone to train or finetune your model.\n", - "For [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images, ``datamodule.num_classes`` will be 2.\n", - "Backbone can easily be changed with `ImageClassifier(backbone=\"resnet50\")` or you could provide your own `ImageClassifier(backbone=my_backbone)`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unauthorized-timer", - "metadata": {}, - "outputs": [], - "source": [ - "model = ImageClassifier(num_classes=datamodule.num_classes)" - ] - }, - { - "cell_type": "markdown", - "id": "referenced-sacramento", - "metadata": {}, - "source": [ - "### 4. Create the trainer. Run once on data\n", - "\n", - "The trainer object can be used for training or fine-tuning tasks on new sets of data. \n", - "\n", - "You can pass in parameters to control the training routine- limit the number of epochs, run on GPUs or TPUs, etc.\n", - "\n", - "For more details, read the [Trainer Documentation](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html).\n", - "\n", - "In this demo, we will limit the fine-tuning to run just one epoch using max_epochs=2." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "average-aggregate", - "metadata": {}, - "outputs": [], - "source": [ - "trainer = flash.Trainer(max_epochs=3)" - ] - }, - { - "cell_type": "markdown", - "id": "vulnerable-contamination", - "metadata": {}, - "source": [ - "### 5. Finetune the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "apart-arrangement", - "metadata": { - "tags": [ - "outputPrepend" - ] - }, - "outputs": [], - "source": [ - "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze\")" - ] - }, - { - "cell_type": "markdown", - "id": "electronic-lobby", - "metadata": {}, - "source": [ - "### 6. Test the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "refined-narrative", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.test(model, datamodule=datamodule)" - ] - }, - { - "cell_type": "markdown", - "id": "alpha-vacuum", - "metadata": {}, - "source": [ - "### 7. Save it!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "composed-equivalent", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.save_checkpoint(\"image_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "surprised-express", - "metadata": {}, - "source": [ - "# Predicting" - ] - }, - { - "cell_type": "markdown", - "id": "bridal-christianity", - "metadata": {}, - "source": [ - "### 1. Load the model from a checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "extreme-surrey", - "metadata": {}, - "outputs": [], - "source": [ - "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "individual-recipe", - "metadata": {}, - "source": [ - "### 2. Predict what's on a few images! ants or bees?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "lyric-johnston", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = ImageClassificationData.from_files(\n", - " predict_files=[\n", - " \"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg\",\n", - " \"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg\",\n", - " \"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg\",\n", - " ],\n", - " batch_size=1,\n", - ")\n", - "predictions = trainer.predict(model, datamodule=datamodule)\n", - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "prime-leadership", - "metadata": {}, - "source": [ - "\n", - "

Congratulations - Time to Join the Community!

\n", - "
\n", - "\n", - "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", - "\n", - "### Help us build Flash by adding support for new data-types and new tasks.\n", - "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", - "If you are interested, please open a PR with your contributions !!! \n", - "\n", - "\n", - "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", - "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", - "\n", - "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", - "\n", - "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", - "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", - "\n", - "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", - "\n", - "* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "\n", - "### Contributions !\n", - "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", - "\n", - "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* You can also contribute your own notebooks with useful examples !\n", - "\n", - "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", - "\n", - "" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb deleted file mode 100644 index 776857c976..0000000000 --- a/flash_notebooks/tabular_classification.ipynb +++ /dev/null @@ -1,327 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "twelve-miracle", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "id": "genuine-elephant", - "metadata": {}, - "source": [ - "In this notebook, we'll go over the basics of lightning Flash by training a TabularClassifier on [Titanic Dataset](https://www.kaggle.com/c/titanic).\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n", - " - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://www.pytorchlightning.ai/community)" - ] - }, - { - "cell_type": "markdown", - "id": "sorted-dancing", - "metadata": {}, - "source": [ - "# Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "caring-appreciation", - "metadata": {}, - "outputs": [], - "source": [ - "# %%capture\n", - "! pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[tabular]'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "sexual-diabetes", - "metadata": {}, - "outputs": [], - "source": [ - "from torchmetrics.classification import Accuracy, Precision, Recall\n", - "\n", - "import flash\n", - "from flash.core.data.utils import download_data\n", - "from flash.tabular import TabularClassifier, TabularClassificationData" - ] - }, - { - "cell_type": "markdown", - "id": "boxed-harvest", - "metadata": {}, - "source": [ - "### 1. Download the data\n", - "The data are downloaded from a URL, and save in a 'data' directory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "backed-render", - "metadata": {}, - "outputs": [], - "source": [ - "download_data(\"https://pl-flash-data.s3.amazonaws.com/titanic.zip\", 'data/')" - ] - }, - { - "cell_type": "markdown", - "id": "young-arthritis", - "metadata": {}, - "source": [ - "### 2. Load the data\n", - "Flash Tasks have built-in DataModules that you can use to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n", - "\n", - "Creates a TabularData relies on [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html). " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ultimate-bunny", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = TabularClassificationData.from_csv(\n", - " [\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", - " [\"Fare\"],\n", - " target_fields=\"Survived\",\n", - " train_file=\"./data/titanic/titanic.csv\",\n", - " test_file=\"./data/titanic/test.csv\",\n", - " val_split=0.25,\n", - " batch_size=8,\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": "brutal-hypothesis", - "metadata": {}, - "source": [ - "### 3. Build the model\n", - "\n", - "Note: Categorical columns will be mapped to the embedding space. Embedding space is set of tensors to be trained associated to each categorical column. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "practical-perry", - "metadata": {}, - "outputs": [], - "source": [ - "model = TabularClassifier.from_data(datamodule)" - ] - }, - { - "cell_type": "markdown", - "id": "dietary-bowling", - "metadata": {}, - "source": [ - "### 4. Create the trainer. Run 10 times on data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "integral-interface", - "metadata": {}, - "outputs": [], - "source": [ - "trainer = flash.Trainer(max_epochs=10)" - ] - }, - { - "cell_type": "markdown", - "id": "liable-remains", - "metadata": {}, - "source": [ - "### 5. Train the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "controversial-newcastle", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.fit(model, datamodule=datamodule)" - ] - }, - { - "cell_type": "markdown", - "id": "fluid-franchise", - "metadata": {}, - "source": [ - "### 6. Test model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "therapeutic-bidder", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.test(model, datamodule=datamodule)" - ] - }, - { - "cell_type": "markdown", - "id": "genuine-pilot", - "metadata": {}, - "source": [ - "### 7. Save it!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "alien-stand", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.save_checkpoint(\"tabular_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "conventional-travel", - "metadata": {}, - "source": [ - "# Predicting" - ] - }, - { - "cell_type": "markdown", - "id": "coated-insulation", - "metadata": {}, - "source": [ - "### 8. Load the model from a checkpoint\n", - "\n", - "`TabularClassifier.load_from_checkpoint` supports both url or local_path to a checkpoint. If provided with an url, the checkpoint will first be downloaded and laoded to re-create the model. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "alpine-drilling", - "metadata": {}, - "outputs": [], - "source": [ - "model = TabularClassifier.load_from_checkpoint(\n", - " \"https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "painted-assistant", - "metadata": {}, - "source": [ - "### 9. Generate predictions from a sheet file! Who would survive?\n", - "\n", - "`TabularClassifier.predict` support both DataFrame and path to `.csv` file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "located-cable", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = TabularClassificationData.from_csv(\n", - " predict_file=\"data/titanic/titanic.csv\",\n", - " parameters=datamodule.parameters,\n", - " batch_size=8,\n", - ")\n", - "predictions = trainer.predict(model, datamodule=datamodule)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "realistic-infection", - "metadata": {}, - "outputs": [], - "source": [ - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "classified-casino", - "metadata": {}, - "source": [ - "\n", - "

Congratulations - Time to Join the Community!

\n", - "
\n", - "\n", - "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", - "\n", - "### Help us build Flash by adding support for new data-types and new tasks.\n", - "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", - "If you are interested, please open a PR with your contributions !!! \n", - "\n", - "\n", - "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", - "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", - "\n", - "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", - "\n", - "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", - "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", - "\n", - "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", - "\n", - "* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "\n", - "### Contributions !\n", - "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", - "\n", - "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* You can also contribute your own notebooks with useful examples !\n", - "\n", - "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", - "\n", - "" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/flash_notebooks/text_classification.ipynb b/flash_notebooks/text_classification.ipynb deleted file mode 100644 index 11515905e3..0000000000 --- a/flash_notebooks/text_classification.ipynb +++ /dev/null @@ -1,354 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "instant-bruce", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "id": "orange-spread", - "metadata": {}, - "source": [ - "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", - "\n", - "# Finetuning\n", - "\n", - "Finetuning consists of four steps:\n", - " \n", - " - 1. Train a source neural network model on a source dataset. For text classication, it is traditionally a transformer model such as BERT [Bidirectional Encoder Representations from Transformers](https://arxiv.org/abs/1810.04805) trained on wikipedia.\n", - "As those model are costly to train, [Transformers](https://github.com/huggingface/transformers) or [FairSeq](https://github.com/pytorch/fairseq) libraries provides popular pre-trained model architectures for NLP. In this notebook, we will be using [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny).\n", - "\n", - " \n", - " - 2. Create a new neural network the target model. Its architecture replicates all model designs and their parameters on the source model, expect the latest layer which is removed. This model without its latest layers is traditionally called a backbone\n", - " \n", - "\n", - "- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head, will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", - " \n", - "\n", - "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is subclassing `pytorch_lightning.callbacks.BaseFinetuning`.\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n", - " - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://www.pytorchlightning.ai/community)" - ] - }, - { - "cell_type": "markdown", - "id": "generic-evaluation", - "metadata": {}, - "source": [ - "### Setup \n", - "Lightning Flash is easy to install. Simply ```pip install lightning-flash```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "academic-alpha", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "! pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[text]'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "historical-asthma", - "metadata": {}, - "outputs": [], - "source": [ - "import flash\n", - "from flash.core.data.utils import download_data\n", - "from flash.text import TextClassificationData, TextClassifier" - ] - }, - { - "cell_type": "markdown", - "id": "bronze-ghost", - "metadata": {}, - "source": [ - "### 1. Download the data\n", - "The data are downloaded from a URL, and save in a 'data' directory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "applied-operation", - "metadata": {}, - "outputs": [], - "source": [ - "download_data(\"https://pl-flash-data.s3.amazonaws.com/imdb.zip\", 'data/')" - ] - }, - { - "cell_type": "markdown", - "id": "instrumental-approval", - "metadata": {}, - "source": [ - "

2. Load the data

\n", - "\n", - "Flash Tasks have built-in DataModules that you can use to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n", - "Creates a TextClassificationData object from csv file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "flush-prince", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = TextClassificationData.from_csv(\n", - " \"review\",\n", - " \"sentiment\",\n", - " train_file=\"data/imdb/train.csv\",\n", - " val_file=\"data/imdb/valid.csv\",\n", - " test_file=\"data/imdb/test.csv\",\n", - " batch_size=4,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "vital-ecuador", - "metadata": { - "jupyter": { - "outputs_hidden": true - } - }, - "source": [ - "### 3. Build the model\n", - "\n", - "Create the TextClassifier task. By default, the TextClassifier task uses a [tiny-bert](https://huggingface.co/prajjwal1/bert-tiny) backbone to train or finetune your model demo. You could use any models from [transformers - Text Classification](https://huggingface.co/models?filter=text-classification,pytorch)\n", - "\n", - "Backbone can easily be changed with such as `TextClassifier(backbone='bert-tiny-mnli')`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "weighted-cosmetic", - "metadata": {}, - "outputs": [], - "source": [ - "model = TextClassifier(num_classes=datamodule.num_classes, backbone=\"prajjwal1/bert-tiny\")" - ] - }, - { - "cell_type": "markdown", - "id": "neural-blade", - "metadata": { - "jupyter": { - "outputs_hidden": true - } - }, - "source": [ - "### 4. Create the trainer. Run once on data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "august-family", - "metadata": {}, - "outputs": [], - "source": [ - "trainer = flash.Trainer(max_epochs=1)" - ] - }, - { - "cell_type": "markdown", - "id": "figured-exhaust", - "metadata": { - "jupyter": { - "outputs_hidden": true - } - }, - "source": [ - "### 5. Fine-tune the model\n", - "\n", - "The backbone won't be freezed and the entire model will be finetuned on the imdb dataset " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "creative-reform", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze\")" - ] - }, - { - "cell_type": "markdown", - "id": "periodic-holocaust", - "metadata": { - "jupyter": { - "outputs_hidden": true - } - }, - "source": [ - "### 6. Test model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "stopped-clark", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.test(model, datamodule=datamodule)" - ] - }, - { - "cell_type": "markdown", - "id": "turned-harris", - "metadata": { - "jupyter": { - "outputs_hidden": true - } - }, - "source": [ - "### 7. Save it!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "rotary-account", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.save_checkpoint(\"text_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "protective-panic", - "metadata": {}, - "source": [ - "# Predicting" - ] - }, - { - "cell_type": "markdown", - "id": "precious-casino", - "metadata": {}, - "source": [ - "### 1. Load the model from a checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eligible-coordination", - "metadata": {}, - "outputs": [], - "source": [ - "model = TextClassifier.load_from_checkpoint(\"text_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "worst-consumer", - "metadata": {}, - "source": [ - "### 2. Classify a few sentences! How was the movie?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "distinct-tragedy", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = TextClassificationData.from_lists(\n", - " predict_data=[\n", - " \"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.\",\n", - " \"The worst movie in the history of cinema.\",\n", - " \"I come from Bulgaria where it 's almost impossible to have a tornado.\",\n", - " ],\n", - " batch_size=4,\n", - ")\n", - "predictions = trainer.predict(model, datamodule=datamodule)\n", - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "other-grain", - "metadata": {}, - "source": [ - "\n", - "

Congratulations - Time to Join the Community!

\n", - "
\n", - "\n", - "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", - "\n", - "### Help us build Flash by adding support for new data-types and new tasks.\n", - "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", - "If you are interested, please open a PR with your contributions !!! \n", - "\n", - "\n", - "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", - "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", - "\n", - "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", - "\n", - "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", - "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", - "\n", - "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", - "\n", - "* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "\n", - "### Contributions !\n", - "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", - "\n", - "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* You can also contribute your own notebooks with useful examples !\n", - "\n", - "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", - "\n", - "" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt index a865ca7ec4..4db6c7765a 100644 --- a/requirements/datatype_audio.txt +++ b/requirements/datatype_audio.txt @@ -1,4 +1,5 @@ torchaudio +torchvision librosa>=0.8.1 transformers>=4.13.0 datasets>=1.16.1 diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index 0addd5c433..4588c1b327 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -4,4 +4,4 @@ lightning-bolts>=0.3.3 Pillow>=7.2 kornia>=0.5.1 pystiche==1.* -segmentation-models-pytorch +segmentation-models-pytorch>=0.2.0 diff --git a/setup.cfg b/setup.cfg index 5e85913a25..c29a83b5f9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,6 @@ exclude = *.egg build temp - flash_notebooks .git select = E,W,F doctests = True diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index 350f1509e9..d1d96aba8c 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -590,8 +590,8 @@ def on_exception(self, execption): @pytest.mark.parametrize( "trainer_kwargs", ( - dict(accelerator="ddp_cpu"), - dict(accelerator="ddp_cpu", plugins="ddp_find_unused_parameters_false"), + dict(accelerator="cpu", strategy="ddp"), + dict(accelerator="cpu", strategy="ddp", plugins="ddp_find_unused_parameters_false"), ), ) @pytest.mark.skipif(not _PL_GREATER_EQUAL_1_4_0, reason="Bugs in PL < 1.4.0") diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py index 3fedd7be68..82832b7e40 100644 --- a/tests/image/classification/test_training_strategies.py +++ b/tests/image/classification/test_training_strategies.py @@ -20,7 +20,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _IMAGE_TESTING, _LEARN2LEARN_AVAILABLE +from flash.core.utilities.imports import _IMAGE_TESTING, _LEARN2LEARN_AVAILABLE, _PL_GREATER_EQUAL_1_6_0 from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.adapters import TRAINING_STRATEGIES from tests.image.classification.test_data import _rand_image @@ -54,7 +54,7 @@ def test_learn2learn_training_strategies_registry(): assert TRAINING_STRATEGIES.available_keys() == ["anil", "default", "maml", "metaoptnet", "prototypicalnetworks"] -def _test_learn2learning_training_strategies(gpus, accelerator, training_strategy, tmpdir): +def _test_learn2learning_training_strategies(gpus, training_strategy, tmpdir, accelerator=None, strategy=None): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -87,7 +87,11 @@ def _test_learn2learning_training_strategies(gpus, accelerator, training_strateg training_strategy_kwargs={"ways": dm.num_classes, "shots": 4, "meta_batch_size": 4}, ) - trainer = Trainer(fast_dev_run=2, gpus=gpus, accelerator=accelerator) + if _PL_GREATER_EQUAL_1_6_0: + trainer = Trainer(fast_dev_run=2, gpus=gpus, strategy=strategy) + else: + trainer = Trainer(fast_dev_run=2, gpus=gpus, accelerator=accelerator) + trainer.fit(model, datamodule=dm) @@ -95,7 +99,7 @@ def _test_learn2learning_training_strategies(gpus, accelerator, training_strateg @pytest.mark.parametrize("training_strategy", ["anil", "maml", "prototypicalnetworks"]) @pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") def test_learn2learn_training_strategies(training_strategy, tmpdir): - _test_learn2learning_training_strategies(0, None, training_strategy, tmpdir) + _test_learn2learning_training_strategies(0, training_strategy, tmpdir, accelerator=None) @pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") @@ -111,4 +115,7 @@ def test_wrongly_specified_training_strategies(): @pytest.mark.skipif(not os.getenv("FLASH_RUNNING_SPECIAL_TESTS", "0") == "1", reason="Should run with special test") @pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") def test_learn2learn_training_strategies_ddp(tmpdir): - _test_learn2learning_training_strategies(2, "ddp", "prototypicalnetworks", tmpdir) + if _PL_GREATER_EQUAL_1_6_0: + _test_learn2learning_training_strategies(2, "prototypicalnetworks", tmpdir, strategy="ddp") + else: + _test_learn2learning_training_strategies(2, "prototypicalnetworks", tmpdir, accelerator="ddp")