diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index a584c634de..1c634e2ccf 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -63,6 +63,11 @@ jobs: brew update brew install libomp # https://github.com/pytorch/pytorch/issues/20030 + - name: Install graphviz + if: matrix.topic == 'serve' + run: | + sudo apt-get install graphviz + - name: Set min. dependencies if: matrix.requires == 'minimal' run: | diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 908ac20149..93797ba20a 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -64,6 +64,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_AVAILABLE = _module_available("torch") +_BOLTS_AVAILABLE = _module_available("pl_bolts") and _compare_version("torch", operator.lt, "1.9.0") _PANDAS_AVAILABLE = _module_available("pandas") _SKLEARN_AVAILABLE = _module_available("sklearn") _TABNET_AVAILABLE = _module_available("pytorch_tabnet") diff --git a/flash/image/backbones.py b/flash/image/backbones.py index ec7cacf8f4..f61a963fdf 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -20,11 +20,11 @@ import torch from pytorch_lightning import LightningModule -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from torch import nn as nn from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE if _TIMM_AVAILABLE: import timm diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 06ee9152b2..e77c95bf34 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -51,7 +51,7 @@ class ImageEmbedder(Task): def __init__( self, embedding_dim: Optional[int] = None, - backbone: str = "swav-imagenet", + backbone: str = "resnet101", pretrained: bool = True, loss_fn: Callable = F.cross_entropy, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index fa26f8e76e..2e65c800fc 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -17,10 +17,10 @@ import torch.nn as nn from deprecate import deprecated -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE from flash.image.backbones import catch_url_error if _TORCHVISION_AVAILABLE: diff --git a/flash_examples/finetuning/image_classification_multi_label.py b/flash_examples/finetuning/image_classification_multi_label.py index 79c24b8cc1..5b55e4fec2 100644 --- a/flash_examples/finetuning/image_classification_multi_label.py +++ b/flash_examples/finetuning/image_classification_multi_label.py @@ -58,8 +58,8 @@ def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], L metrics=F1(num_classes=len(genres)), ) -# 4. Create the trainer. Train on 2 gpus for 10 epochs. -trainer = flash.Trainer(max_epochs=10) +# 4. Create the trainer +trainer = flash.Trainer(fast_dev_run=True) # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py index d763766c04..ae7716bd71 100644 --- a/flash_examples/predict/image_embedder.py +++ b/flash_examples/predict/image_embedder.py @@ -19,9 +19,8 @@ # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") -# 2. Create an ImageEmbedder with swav trained on imagenet. -# Check out SWAV: https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav -embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128) +# 2. Create an ImageEmbedder with resnet101 trained on imagenet. +embedder = ImageEmbedder(backbone="resnet101", embedding_dim=128) # 3. Generate an embedding from an image path. embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"]) diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/segmentation/test_backbones.py index 230ca0bc14..cf09425c87 100644 --- a/tests/image/segmentation/test_backbones.py +++ b/tests/image/segmentation/test_backbones.py @@ -13,8 +13,9 @@ # limitations under the License. import pytest import torch -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE +from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _BOLTS_AVAILABLE from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index 91f0062b8a..6036927555 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -14,9 +14,9 @@ import urllib.error import pytest -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE +from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE -from flash.core.utilities.imports import _TIMM_AVAILABLE +from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE from flash.image.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES diff --git a/tests/serve/test_integration.py b/tests/serve/test_integration.py index 604f1b7451..7d76600579 100644 --- a/tests/serve/test_integration.py +++ b/tests/serve/test_integration.py @@ -230,10 +230,9 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): }, "session": "UUID", } - # TODO: Add graphviz to CI - # resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") - # assert resp.headers["content-type"] == "text/html; charset=utf-8" - # assert resp.template.name == "dag.html" + resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert resp.template.name == "dag.html" @pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") @@ -317,10 +316,9 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad }, "session": "UUID", } - # TODO: Add graphviz to CI - # resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") - # assert resp.headers["content-type"] == "text/html; charset=utf-8" - # assert resp.template.name == "dag.html" + resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert resp.template.name == "dag.html" @pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") @@ -385,19 +383,18 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ app = composit.serve(host="0.0.0.0", port=8000) with TestClient(app) as tc: - # TODO: Add graphviz to CI - # resp = tc.get("http://127.0.0.1:8000/gridserve/component_dags") - # assert resp.headers["content-type"] == "text/html; charset=utf-8" - # assert resp.template.name == "dag.html" - # resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") - # assert resp.headers["content-type"] == "text/html; charset=utf-8" - # assert resp.template.name == "dag.html" - # resp = tc.get("http://127.0.0.1:8000/predict_seat_img/dag") - # assert resp.headers["content-type"] == "text/html; charset=utf-8" - # assert resp.template.name == "dag.html" - # resp = tc.get("http://127.0.0.1:8000/predict_seat_img_two/dag") - # assert resp.headers["content-type"] == "text/html; charset=utf-8" - # assert resp.template.name == "dag.html" + resp = tc.get("http://127.0.0.1:8000/gridserve/component_dags") + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert resp.template.name == "dag.html" + resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert resp.template.name == "dag.html" + resp = tc.get("http://127.0.0.1:8000/predict_seat_img/dag") + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert resp.template.name == "dag.html" + resp = tc.get("http://127.0.0.1:8000/predict_seat_img_two/dag") + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert resp.template.name == "dag.html" with (session_global_datadir / "cat.jpg").open("rb") as f: imgstr = base64.b64encode(f.read()).decode("UTF-8")