From 95aa46bf3b6cd82010a83e67d601cd71ecc1dc2b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 17 Aug 2021 21:14:48 +0100 Subject: [PATCH 1/7] Add some providers --- flash/core/utilities/providers.py | 4 ++++ flash/image/segmentation/backbones.py | 2 ++ flash/image/style_transfer/backbones.py | 3 ++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index ff464e690c..d5eee5e071 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -18,3 +18,7 @@ _ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5") _MMDET = Provider("OpenMMLab/MMDetection", "https://github.com/open-mmlab/mmdetection") _EFFDET = Provider("rwightman/efficientdet-pytorch", "https://github.com/rwightman/efficientdet-pytorch") +_SEGMENTATION_MODELS = Provider( + "qubvel/segmentation_models.pytorch", "https://github.com/qubvel/segmentation_models.pytorch" +) +_PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche") diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index 30690cfaf1..0c73cc14fa 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -15,6 +15,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE +from flash.core.utilities.providers import _SEGMENTATION_MODELS if _SEGMENTATION_MODELS_AVAILABLE: import segmentation_models_pytorch as smp @@ -39,4 +40,5 @@ def _load_smp_backbone(backbone: str, **_) -> str: name=short_name, namespace="image/segmentation", weights_paths=available_weights, + providers=_SEGMENTATION_MODELS, ) diff --git a/flash/image/style_transfer/backbones.py b/flash/image/style_transfer/backbones.py index 4d951603d2..07c05f1ca1 100644 --- a/flash/image/style_transfer/backbones.py +++ b/flash/image/style_transfer/backbones.py @@ -15,6 +15,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYSTICHE_AVAILABLE +from flash.core.utilities.providers import _PYSTICHE STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones") @@ -35,5 +36,5 @@ fn=lambda: (getattr(enc, mle_fn)(), None), name=match.group("name"), namespace="image/style_transfer", - package="pystiche", + providers=_PYSTICHE, ) From a9f543af56e424d6cff72c5a0d9e6b42f4e95e02 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 18 Aug 2021 12:33:38 +0100 Subject: [PATCH 2/7] Updates --- flash/core/utilities/providers.py | 2 ++ flash/image/classification/backbones/timm.py | 2 ++ flash/image/classification/backbones/torchvision.py | 7 ++++--- flash/image/classification/backbones/transformers.py | 7 +++---- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index d5eee5e071..593be6076d 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -13,6 +13,8 @@ # limitations under the License. from flash.core.registry import Provider +_TIMM = Provider("rwightman/pytorch-image-models", "https://github.com/rwightman/pytorch-image-models") +_DINO = Provider("facebookresearch/dino", "https://github.com/facebookresearch/dino") _ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision") _TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision") _ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5") diff --git a/flash/image/classification/backbones/timm.py b/flash/image/classification/backbones/timm.py index 30efb815dd..ffdc71c39a 100644 --- a/flash/image/classification/backbones/timm.py +++ b/flash/image/classification/backbones/timm.py @@ -18,6 +18,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TIMM_AVAILABLE +from flash.core.utilities.providers import _TIMM from flash.core.utilities.url_error import catch_url_error from flash.image.classification.backbones.torchvision import TORCHVISION_MODELS @@ -47,4 +48,5 @@ def register_timm_backbones(register: FlashRegistry): name=model_name, namespace="vision", package="timm", + providers=_TIMM, ) diff --git a/flash/image/classification/backbones/torchvision.py b/flash/image/classification/backbones/torchvision.py index 38e4afc2f3..11c59792d3 100644 --- a/flash/image/classification/backbones/torchvision.py +++ b/flash/image/classification/backbones/torchvision.py @@ -18,6 +18,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _TORCHVISION from flash.core.utilities.url_error import catch_url_error from flash.image.classification.backbones.resnet import RESNET_MODELS @@ -59,8 +60,8 @@ def register_mobilenet_vgg_backbones(register: FlashRegistry): fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)), name=model_name, namespace="vision", - package="torchvision", type=_type, + providers=_TORCHVISION, ) @@ -71,8 +72,8 @@ def register_resnext_model(register: FlashRegistry): fn=catch_url_error(partial(_fn_resnext, model_name)), name=model_name, namespace="vision", - package="torchvision", type="resnext", + providers=_TORCHVISION, ) @@ -83,6 +84,6 @@ def register_densenet_backbones(register: FlashRegistry): fn=catch_url_error(partial(_fn_densenet, model_name)), name=model_name, namespace="vision", - package="torchvision", type="densenet", + providers=_TORCHVISION, ) diff --git a/flash/image/classification/backbones/transformers.py b/flash/image/classification/backbones/transformers.py index 35ec17bbcc..cf1fd1637c 100644 --- a/flash/image/classification/backbones/transformers.py +++ b/flash/image/classification/backbones/transformers.py @@ -14,6 +14,7 @@ import torch from flash.core.registry import FlashRegistry +from flash.core.utilities.providers import _DINO from flash.core.utilities.url_error import catch_url_error @@ -41,7 +42,5 @@ def dino_vitb8(*_, **__): def register_dino_backbones(register: FlashRegistry): - register(catch_url_error(dino_deits16)) - register(catch_url_error(dino_deits8)) - register(catch_url_error(dino_vitb16)) - register(catch_url_error(dino_vitb8)) + for model in (dino_deits16, dino_deits8, dino_vitb16, dino_vitb8): + register(catch_url_error(model), providers=_DINO) From 34b9914249784b0a8444d4e9ceea59b4d7da33c7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 18 Aug 2021 12:36:52 +0100 Subject: [PATCH 3/7] Updates --- flash/image/segmentation/heads.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index bc7ff8cd01..4886dade8f 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -18,6 +18,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE +from flash.core.utilities.providers import _SEGMENTATION_MODELS if _SEGMENTATION_MODELS_AVAILABLE: import segmentation_models_pytorch as smp @@ -71,5 +72,5 @@ def _load_smp_head( partial(_load_smp_head, head=model_name), name=model_name, namespace="image/segmentation", - package="segmentation_models.pytorch", + providers=_SEGMENTATION_MODELS, ) From e093b973f4cec098d1c48ec8a52616144044afb7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 18 Aug 2021 13:20:17 +0100 Subject: [PATCH 4/7] Add speech recognition --- flash/audio/speech_recognition/backbone.py | 2 ++ flash/core/utilities/providers.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/flash/audio/speech_recognition/backbone.py b/flash/audio/speech_recognition/backbone.py index 425ef2eb00..e583d7366a 100644 --- a/flash/audio/speech_recognition/backbone.py +++ b/flash/audio/speech_recognition/backbone.py @@ -15,6 +15,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _AUDIO_AVAILABLE +from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones") @@ -27,4 +28,5 @@ SPEECH_RECOGNITION_BACKBONES( fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name), name=model_name, + providers=[_HUGGINGFACE, _FAIRSEQ], ) diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index 593be6076d..48a54f9e2c 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -24,3 +24,5 @@ "qubvel/segmentation_models.pytorch", "https://github.com/qubvel/segmentation_models.pytorch" ) _PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche") +_HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers") +_FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq") From b76954ede59f86657fb330ff02a96762b884f24d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 18 Aug 2021 16:53:22 +0100 Subject: [PATCH 5/7] Updates --- flash/core/utilities/providers.py | 4 +++- flash/pointcloud/detection/open3d_ml/backbones.py | 5 +++-- flash/pointcloud/segmentation/open3d_ml/backbones.py | 9 +++++---- flash/video/classification/model.py | 3 ++- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index 48a54f9e2c..65a8501351 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -14,7 +14,7 @@ from flash.core.registry import Provider _TIMM = Provider("rwightman/pytorch-image-models", "https://github.com/rwightman/pytorch-image-models") -_DINO = Provider("facebookresearch/dino", "https://github.com/facebookresearch/dino") +_DINO = Provider("Facebook Research/dino", "https://github.com/facebookresearch/dino") _ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision") _TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision") _ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5") @@ -26,3 +26,5 @@ _PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche") _HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers") _FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq") +_OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML") +_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo") diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py index b8b88b1d89..759b6bdb43 100644 --- a/flash/pointcloud/detection/open3d_ml/backbones.py +++ b/flash/pointcloud/detection/open3d_ml/backbones.py @@ -20,6 +20,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.providers import _OPEN3D_ML ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/" @@ -63,7 +64,7 @@ def get_collate_fn(model) -> Callable: return ObjectDetectBatchCollator return batcher.collate_fn - @register(parameters=PointPillars.__init__) + @register(parameters=PointPillars.__init__, providers=_OPEN3D_ML) def pointpillars_kitti(*args, **kwargs) -> PointPillars: cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml")) cfg.model.device = "cpu" @@ -75,7 +76,7 @@ def pointpillars_kitti(*args, **kwargs) -> PointPillars: model.cfg.batcher = "ObjectDetectBatchCollator" return model, 384, get_collate_fn(model) - @register(parameters=PointPillars.__init__) + @register(parameters=PointPillars.__init__, providers=_OPEN3D_ML) def pointpillars(*args, **kwargs) -> PointPillars: model = PointPillars(*args, **kwargs) model.cfg.batcher = "ObjectDetectBatch" diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/flash/pointcloud/segmentation/open3d_ml/backbones.py index abf1226b68..a326cbcdc5 100644 --- a/flash/pointcloud/segmentation/open3d_ml/backbones.py +++ b/flash/pointcloud/segmentation/open3d_ml/backbones.py @@ -19,6 +19,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.providers import _OPEN3D_ML ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/" @@ -42,7 +43,7 @@ def get_collate_fn(model) -> Callable: batcher = None return batcher.collate_fn - @register + @register(providers=_OPEN3D_ML) def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet: cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_s3dis.yml")) model = RandLANet(**cfg.model) @@ -53,7 +54,7 @@ def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet: model.load_state_dict(pl_load(weight_url, map_location="cpu")["model_state_dict"]) return model, 32, get_collate_fn(model) - @register + @register(providers=_OPEN3D_ML) def randlanet_toronto3d(*args, **kwargs) -> RandLANet: cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml")) model = RandLANet(**cfg.model) @@ -64,7 +65,7 @@ def randlanet_toronto3d(*args, **kwargs) -> RandLANet: ) return model, 32, get_collate_fn(model) - @register + @register(providers=_OPEN3D_ML) def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet: cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml")) model = RandLANet(**cfg.model) @@ -75,7 +76,7 @@ def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet: ) return model, 32, get_collate_fn(model) - @register + @register(providers=_OPEN3D_ML) def randlanet(*args, **kwargs) -> RandLANet: model = RandLANet(*args, **kwargs) return model, 32, get_collate_fn(model) diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index e6b3b77cf9..9345b7b19b 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -31,6 +31,7 @@ from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE +from flash.core.utilities.providers import _PYTORCHVIDEO _VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones") @@ -41,7 +42,7 @@ if "__" not in fn_name: fn = getattr(hub, fn_name) if isinstance(fn, FunctionType): - _VIDEO_CLASSIFIER_BACKBONES(fn=fn) + _VIDEO_CLASSIFIER_BACKBONES(fn=fn, providers=_PYTORCHVIDEO) class VideoClassifierFinetuning(BaseFinetuning): From afaeee547a4c3fdf20d94cb4103b5f42c5a0b299 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 18 Aug 2021 18:07:12 +0100 Subject: [PATCH 6/7] Add providers list to docs --- .gitignore | 1 + docs/source/conf.py | 14 ++++++++++++++ docs/source/index.rst | 1 + docs/source/integrations/fiftyone.rst | 6 ++++-- docs/source/integrations/providers.rst | 15 +++++++++++++++ flash/core/registry.py | 13 ++----------- flash/core/utilities/providers.py | 18 +++++++++++++++++- 7 files changed, 54 insertions(+), 14 deletions(-) create mode 100644 docs/source/integrations/providers.rst diff --git a/.gitignore b/.gitignore index 9ab9838b44..7b25e29d16 100644 --- a/.gitignore +++ b/.gitignore @@ -78,6 +78,7 @@ docs/_build/ docs/api/ docs/notebooks/ docs/source/api/generated/ +docs/source/integrations/generated/ # PyBuilder target/ diff --git a/docs/source/conf.py b/docs/source/conf.py index de578a2121..554f4c8dc7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,6 +22,7 @@ try: from flash import __about__ as about + from flash.core.utilities import providers except ModuleNotFoundError: @@ -32,6 +33,7 @@ def _load_py_module(fname, pkg="flash"): return py about = _load_py_module("__about__.py") + providers = _load_py_module("flash/core/utilities/providers.py") SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True)) @@ -43,6 +45,18 @@ def _load_py_module(fname, pkg="flash"): copyright = "2020-2021, PyTorch Lightning" author = "PyTorch Lightning" +# -- Generate providers ------------------------------------------------------ + +lines = [] +for provider in providers.PROVIDERS: + lines.append(f"- {str(provider)}\n") + +generated_dir = os.path.join("integrations", "generated") +os.makedirs(generated_dir, exist_ok=True) + +with open(os.path.join(generated_dir, "providers.rst"), "w") as f: + f.writelines(lines) + # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be diff --git a/docs/source/index.rst b/docs/source/index.rst index 95c7e2933f..8ce5e881e1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -82,6 +82,7 @@ Lightning Flash :maxdepth: 1 :caption: Integrations + integrations/providers integrations/fiftyone .. toctree:: diff --git a/docs/source/integrations/fiftyone.rst b/docs/source/integrations/fiftyone.rst index 51df47764c..8592fad47b 100644 --- a/docs/source/integrations/fiftyone.rst +++ b/docs/source/integrations/fiftyone.rst @@ -1,10 +1,12 @@ +.. _fiftyone: + ######## FiftyOne ######## We have collaborated with the team at -`Voxel51 `_ to integrate their tool, -`FiftyOne `_, into Lightning Flash. +`Voxel51 `__ to integrate their tool, +`FiftyOne `__, into Lightning Flash. FiftyOne is an open-source tool for building high-quality datasets and computer vision models. The FiftyOne API and App enable you to diff --git a/docs/source/integrations/providers.rst b/docs/source/integrations/providers.rst new file mode 100644 index 0000000000..7254acd6cf --- /dev/null +++ b/docs/source/integrations/providers.rst @@ -0,0 +1,15 @@ +.. _providers: + +######### +Providers +######### + +Flash is a framework integrator. +We rely on many open source frameworks for our tasks, visualizations and backbones. +Here's a list of some of the providers we use for backbones and heads within Flash (check them out and star their repos to spread the open source love!): + +.. include:: generated/providers.rst + +You can also read our guides for some of our larger integrations: + +- :ref:`fiftyone` diff --git a/flash/core/registry.py b/flash/core/registry.py index 1f97f2a664..d5b1b1d764 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -12,23 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException -_REGISTERED_FUNCTION = Dict[str, Any] - - -@dataclass -class Provider: +from flash.core.utilities.providers import Provider - name: str - url: str - - def __str__(self): - return f"{self.name} ({self.url})" +_REGISTERED_FUNCTION = Dict[str, Any] def print_provider_info(name, providers, func): diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index 65a8501351..f25c402683 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -11,7 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from flash.core.registry import Provider +from dataclasses import dataclass + +PROVIDERS = [] #: testing + + +@dataclass +class Provider: + + name: str + url: str + + def __post_init__(self): + PROVIDERS.append(self) + + def __str__(self): + return f"{self.name} ({self.url})" + _TIMM = Provider("rwightman/pytorch-image-models", "https://github.com/rwightman/pytorch-image-models") _DINO = Provider("Facebook Research/dino", "https://github.com/facebookresearch/dino") From c29206ecdd8004d225af9d2455f00dfc66f4c501 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 18 Aug 2021 18:15:42 +0100 Subject: [PATCH 7/7] Add sorting --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 554f4c8dc7..15fecb69bb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -55,7 +55,7 @@ def _load_py_module(fname, pkg="flash"): os.makedirs(generated_dir, exist_ok=True) with open(os.path.join(generated_dir, "providers.rst"), "w") as f: - f.writelines(lines) + f.writelines(sorted(lines, key=str.casefold)) # -- General configuration ---------------------------------------------------