Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add missing providers #674

Merged
merged 9 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ docs/_build/
docs/api/
docs/notebooks/
docs/source/api/generated/
docs/source/integrations/generated/

# PyBuilder
target/
Expand Down
14 changes: 14 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

try:
from flash import __about__ as about
from flash.core.utilities import providers

except ModuleNotFoundError:

Expand All @@ -32,6 +33,7 @@ def _load_py_module(fname, pkg="flash"):
return py

about = _load_py_module("__about__.py")
providers = _load_py_module("flash/core/utilities/providers.py")

SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))

Expand All @@ -43,6 +45,18 @@ def _load_py_module(fname, pkg="flash"):
copyright = "2020-2021, PyTorch Lightning"
author = "PyTorch Lightning"

# -- Generate providers ------------------------------------------------------

lines = []
for provider in providers.PROVIDERS:
lines.append(f"- {str(provider)}\n")

generated_dir = os.path.join("integrations", "generated")
os.makedirs(generated_dir, exist_ok=True)

with open(os.path.join(generated_dir, "providers.rst"), "w") as f:
f.writelines(sorted(lines, key=str.casefold))

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Lightning Flash
:maxdepth: 1
:caption: Integrations

integrations/providers
integrations/fiftyone

.. toctree::
Expand Down
6 changes: 4 additions & 2 deletions docs/source/integrations/fiftyone.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
.. _fiftyone:

########
FiftyOne
########

We have collaborated with the team at
`Voxel51 <https://voxel51.com>`_ to integrate their tool,
`FiftyOne <https://fiftyone.ai>`_, into Lightning Flash.
`Voxel51 <https://voxel51.com>`__ to integrate their tool,
`FiftyOne <https://fiftyone.ai>`__, into Lightning Flash.

FiftyOne is an open-source tool for building high-quality
datasets and computer vision models. The FiftyOne API and App enable you to
Expand Down
15 changes: 15 additions & 0 deletions docs/source/integrations/providers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
.. _providers:

#########
Providers
#########

Flash is a framework integrator.
We rely on many open source frameworks for our tasks, visualizations and backbones.
Here's a list of some of the providers we use for backbones and heads within Flash (check them out and star their repos to spread the open source love!):

.. include:: generated/providers.rst

You can also read our guides for some of our larger integrations:

- :ref:`fiftyone`
2 changes: 2 additions & 0 deletions flash/audio/speech_recognition/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE

SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")

Expand All @@ -27,4 +28,5 @@
SPEECH_RECOGNITION_BACKBONES(
fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name),
name=model_name,
providers=[_HUGGINGFACE, _FAIRSEQ],
)
13 changes: 2 additions & 11 deletions flash/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_REGISTERED_FUNCTION = Dict[str, Any]


@dataclass
class Provider:
from flash.core.utilities.providers import Provider

name: str
url: str

def __str__(self):
return f"{self.name} ({self.url})"
_REGISTERED_FUNCTION = Dict[str, Any]


def print_provider_info(name, providers, func):
Expand Down
28 changes: 27 additions & 1 deletion flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.registry import Provider
from dataclasses import dataclass

PROVIDERS = [] #: testing


@dataclass
class Provider:

name: str
url: str

def __post_init__(self):
PROVIDERS.append(self)

def __str__(self):
return f"{self.name} ({self.url})"


_TIMM = Provider("rwightman/pytorch-image-models", "https://github.com/rwightman/pytorch-image-models")
_DINO = Provider("Facebook Research/dino", "https://github.com/facebookresearch/dino")
_ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision")
_TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision")
_ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5")
_MMDET = Provider("OpenMMLab/MMDetection", "https://github.com/open-mmlab/mmdetection")
_EFFDET = Provider("rwightman/efficientdet-pytorch", "https://github.com/rwightman/efficientdet-pytorch")
_SEGMENTATION_MODELS = Provider(
"qubvel/segmentation_models.pytorch", "https://github.com/qubvel/segmentation_models.pytorch"
)
_PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche")
_HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers")
_FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq")
_OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML")
_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
2 changes: 2 additions & 0 deletions flash/image/classification/backbones/timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TIMM_AVAILABLE
from flash.core.utilities.providers import _TIMM
from flash.core.utilities.url_error import catch_url_error
from flash.image.classification.backbones.torchvision import TORCHVISION_MODELS

Expand Down Expand Up @@ -47,4 +48,5 @@ def register_timm_backbones(register: FlashRegistry):
name=model_name,
namespace="vision",
package="timm",
providers=_TIMM,
)
7 changes: 4 additions & 3 deletions flash/image/classification/backbones/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.core.utilities.providers import _TORCHVISION
from flash.core.utilities.url_error import catch_url_error
from flash.image.classification.backbones.resnet import RESNET_MODELS

Expand Down Expand Up @@ -59,8 +60,8 @@ def register_mobilenet_vgg_backbones(register: FlashRegistry):
fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
type=_type,
providers=_TORCHVISION,
)


Expand All @@ -71,8 +72,8 @@ def register_resnext_model(register: FlashRegistry):
fn=catch_url_error(partial(_fn_resnext, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
type="resnext",
providers=_TORCHVISION,
)


Expand All @@ -83,6 +84,6 @@ def register_densenet_backbones(register: FlashRegistry):
fn=catch_url_error(partial(_fn_densenet, model_name)),
name=model_name,
namespace="vision",
package="torchvision",
type="densenet",
providers=_TORCHVISION,
)
7 changes: 3 additions & 4 deletions flash/image/classification/backbones/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from flash.core.registry import FlashRegistry
from flash.core.utilities.providers import _DINO
from flash.core.utilities.url_error import catch_url_error


Expand Down Expand Up @@ -41,7 +42,5 @@ def dino_vitb8(*_, **__):


def register_dino_backbones(register: FlashRegistry):
register(catch_url_error(dino_deits16))
register(catch_url_error(dino_deits8))
register(catch_url_error(dino_vitb16))
register(catch_url_error(dino_vitb8))
for model in (dino_deits16, dino_deits8, dino_vitb16, dino_vitb8):
register(catch_url_error(model), providers=_DINO)
2 changes: 2 additions & 0 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.core.utilities.providers import _SEGMENTATION_MODELS

if _SEGMENTATION_MODELS_AVAILABLE:
import segmentation_models_pytorch as smp
Expand All @@ -39,4 +40,5 @@ def _load_smp_backbone(backbone: str, **_) -> str:
name=short_name,
namespace="image/segmentation",
weights_paths=available_weights,
providers=_SEGMENTATION_MODELS,
)
3 changes: 2 additions & 1 deletion flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.core.utilities.providers import _SEGMENTATION_MODELS

if _SEGMENTATION_MODELS_AVAILABLE:
import segmentation_models_pytorch as smp
Expand Down Expand Up @@ -71,5 +72,5 @@ def _load_smp_head(
partial(_load_smp_head, head=model_name),
name=model_name,
namespace="image/segmentation",
package="segmentation_models.pytorch",
providers=_SEGMENTATION_MODELS,
)
3 changes: 2 additions & 1 deletion flash/image/style_transfer/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYSTICHE_AVAILABLE
from flash.core.utilities.providers import _PYSTICHE

STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")

Expand All @@ -35,5 +36,5 @@
fn=lambda: (getattr(enc, mle_fn)(), None),
name=match.group("name"),
namespace="image/style_transfer",
package="pystiche",
providers=_PYSTICHE,
)
5 changes: 3 additions & 2 deletions flash/pointcloud/detection/open3d_ml/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
from flash.core.utilities.providers import _OPEN3D_ML

ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"

Expand Down Expand Up @@ -63,7 +64,7 @@ def get_collate_fn(model) -> Callable:
return ObjectDetectBatchCollator
return batcher.collate_fn

@register(parameters=PointPillars.__init__)
@register(parameters=PointPillars.__init__, providers=_OPEN3D_ML)
def pointpillars_kitti(*args, **kwargs) -> PointPillars:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml"))
cfg.model.device = "cpu"
Expand All @@ -75,7 +76,7 @@ def pointpillars_kitti(*args, **kwargs) -> PointPillars:
model.cfg.batcher = "ObjectDetectBatchCollator"
return model, 384, get_collate_fn(model)

@register(parameters=PointPillars.__init__)
@register(parameters=PointPillars.__init__, providers=_OPEN3D_ML)
def pointpillars(*args, **kwargs) -> PointPillars:
model = PointPillars(*args, **kwargs)
model.cfg.batcher = "ObjectDetectBatch"
Expand Down
9 changes: 5 additions & 4 deletions flash/pointcloud/segmentation/open3d_ml/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
from flash.core.utilities.providers import _OPEN3D_ML

ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"

Expand All @@ -42,7 +43,7 @@ def get_collate_fn(model) -> Callable:
batcher = None
return batcher.collate_fn

@register
@register(providers=_OPEN3D_ML)
def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_s3dis.yml"))
model = RandLANet(**cfg.model)
Expand All @@ -53,7 +54,7 @@ def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet:
model.load_state_dict(pl_load(weight_url, map_location="cpu")["model_state_dict"])
return model, 32, get_collate_fn(model)

@register
@register(providers=_OPEN3D_ML)
def randlanet_toronto3d(*args, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml"))
model = RandLANet(**cfg.model)
Expand All @@ -64,7 +65,7 @@ def randlanet_toronto3d(*args, **kwargs) -> RandLANet:
)
return model, 32, get_collate_fn(model)

@register
@register(providers=_OPEN3D_ML)
def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet:
cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml"))
model = RandLANet(**cfg.model)
Expand All @@ -75,7 +76,7 @@ def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet:
)
return model, 32, get_collate_fn(model)

@register
@register(providers=_OPEN3D_ML)
def randlanet(*args, **kwargs) -> RandLANet:
model = RandLANet(*args, **kwargs)
return model, 32, get_collate_fn(model)
3 changes: 2 additions & 1 deletion flash/video/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from flash.core.data.process import Serializer
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE
from flash.core.utilities.providers import _PYTORCHVIDEO

_VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones")

Expand All @@ -41,7 +42,7 @@
if "__" not in fn_name:
fn = getattr(hub, fn_name)
if isinstance(fn, FunctionType):
_VIDEO_CLASSIFIER_BACKBONES(fn=fn)
_VIDEO_CLASSIFIER_BACKBONES(fn=fn, providers=_PYTORCHVIDEO)


class VideoClassifierFinetuning(BaseFinetuning):
Expand Down