diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d055484b78..46e1870aae 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -128,7 +128,7 @@ jobs: if: contains( matrix.topic , 'serve' ) run: | sudo apt-get install libsndfile1 - pip install '.[all,audio]' --upgrade + pip install '.[all,audio]' icevision effdet --upgrade - name: Install audio test dependencies if: contains( matrix.topic , 'audio' ) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a41c1213b..88d1850fa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for audio file formats to `AudioClassificationData` ([#1085](https://github.com/PyTorchLightning/lightning-flash/pull/1085)) +- Added support for Flash serve to the `ObjectDetector` ([#1370](https://github.com/PyTorchLightning/lightning-flash/pull/1370)) + ### Changed - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276)) diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 9786d79e73..b9b3b5cfe3 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -115,3 +115,23 @@ creating a subclass of :class:`~flash.core.data.io.input_transform.InputTransfor transform=BrightnessContrastTransform, batch_size=4, ) + +------ + +******* +Serving +******* + +The :class:`~flash.image.detection.model.ObjectDetector` is servable. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. +Here's an example: + +.. literalinclude:: ../../../flash_examples/serve/object_detection/inference_server.py + :language: python + :lines: 14- + +You can now perform inference from your client like this: + +.. literalinclude:: ../../../flash_examples/serve/object_detection/client.py + :language: python + :lines: 14- diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 428382d087..4192197b54 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -20,7 +20,15 @@ from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _ICEVISION_GREATER_EQUAL_0_11_0, requires +from flash.core.utilities.imports import ( + _ICEVISION_AVAILABLE, + _ICEVISION_GREATER_EQUAL_0_11_0, + _IMAGE_AVAILABLE, + requires, +) + +if _IMAGE_AVAILABLE: + from PIL import Image if _ICEVISION_AVAILABLE: from icevision.core import tasks @@ -90,7 +98,10 @@ def to_icevision_record(sample: Dict[str, Any]): input_component = ImageRecordComponent() input_component.composite = record image = sample[DataKeys.INPUT] - image = image.permute(1, 2, 0).numpy() if isinstance(image, torch.Tensor) else image + if isinstance(image, torch.Tensor): + image = image.permute(1, 2, 0).numpy() + elif isinstance(image, Image.Image): + image = np.array(image) input_component.set_img(image) record.add_component(OriginalSizeRecordComponent(metadata.get("size", image.shape[:2]))) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 0fdc4b0d75..ce48639b71 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -287,9 +287,9 @@ def _import_module(self): # Global variables used for testing purposes (e.g. to only run doctests in the correct CI job) _CORE_TESTING = True _IMAGE_TESTING = _IMAGE_AVAILABLE -_IMAGE_EXTRAS_TESTING = False # Not for normal use +_IMAGE_EXTRAS_TESTING = True # Not for normal use _VIDEO_TESTING = _VIDEO_AVAILABLE -_VIDEO_EXTRAS_TESTING = False # Not for normal use +_VIDEO_EXTRAS_TESTING = True # Not for normal use _TABULAR_TESTING = _TABULAR_AVAILABLE _TEXT_TESTING = _TEXT_AVAILABLE _SERVE_TESTING = _SERVE_AVAILABLE diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index af21184fcc..a6af3ec112 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -11,12 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type, Union from flash.core.adapter import AdapterTask +from flash.core.data.io.input import ServeInput +from flash.core.data.io.output import Output +from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE +from flash.core.serve import Composition +from flash.core.utilities.imports import requires +from flash.core.utilities.types import INPUT_TRANSFORM_TYPE, LR_SCHEDULER_TYPE, OPTIMIZER_TYPE +from flash.image.data import ImageDeserializer from flash.image.detection.backbones import OBJECT_DETECTION_HEADS from flash.image.detection.output import OBJECT_DETECTION_OUTPUTS @@ -95,3 +101,16 @@ def predict_kwargs(self) -> Dict[str, Any]: @predict_kwargs.setter def predict_kwargs(self, predict_kwargs: Dict[str, Any]): self.adapter.predict_kwargs = predict_kwargs + + @requires("serve") + def serve( + self, + host: str = "127.0.0.1", + port: int = 8000, + sanity_check: bool = True, + input_cls: Optional[Type[ServeInput]] = ImageDeserializer, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, + ) -> Composition: + return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output) diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 91a45c8004..23674436fe 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -22,7 +22,7 @@ from flash.core.data.utilities.sort import sorted_alphanumeric from flash.core.integrations.icevision.data import IceVisionInput from flash.core.integrations.icevision.transforms import IceVisionInputTransform -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _KORNIA_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING, _KORNIA_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE @@ -39,7 +39,7 @@ # Skip doctests if requirements aren't available -if not _ICEVISION_AVAILABLE: +if not _IMAGE_EXTRAS_TESTING: __doctest_skip__ = ["InstanceSegmentationData", "InstanceSegmentationData.*"] diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index af0661419d..32d0395a0c 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -17,7 +17,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.integrations.icevision.data import IceVisionInput -from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform @@ -32,7 +32,7 @@ # Skip doctests if requirements aren't available -if not _ICEVISION_AVAILABLE: +if not _IMAGE_EXTRAS_TESTING: __doctest_skip__ = ["KeypointDetectionData", "KeypointDetectionData.*"] diff --git a/flash_examples/serve/object_detection/client.py b/flash_examples/serve/object_detection/client.py new file mode 100644 index 0000000000..77d9e89e7b --- /dev/null +++ b/flash_examples/serve/object_detection/client.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +from pathlib import Path + +import requests + +import flash + +with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f: + imgstr = base64.b64encode(f.read()).decode("UTF-8") + +body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}} +resp = requests.post("http://127.0.0.1:8000/predict", json=body) +print(resp.json()) diff --git a/flash_examples/serve/object_detection/inference_server.py b/flash_examples/serve/object_detection/inference_server.py new file mode 100644 index 0000000000..427c2045b3 --- /dev/null +++ b/flash_examples/serve/object_detection/inference_server.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.image import ObjectDetector + +model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.8.0/object_detection_model.pt") +model.serve() diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index f42a283a4b..77254d3b49 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -25,6 +25,7 @@ _GRAPH_TESTING, _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, + _IMAGE_EXTRAS_TESTING, _IMAGE_TESTING, _POINTCLOUD_TESTING, _TABULAR_TESTING, @@ -70,19 +71,19 @@ pytest.param( "object_detection.py", marks=pytest.mark.skipif( - not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" + not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" ), ), pytest.param( "instance_segmentation.py", marks=pytest.mark.skipif( - not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" + not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" ), ), pytest.param( "keypoint_detection.py", marks=pytest.mark.skipif( - not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" + not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" ), ), pytest.param( diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 62f056720f..6cd050345e 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -18,7 +18,7 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _PIL_AVAILABLE from flash.image.detection.data import ObjectDetectionData if _PIL_AVAILABLE: @@ -163,8 +163,7 @@ def _create_synth_fiftyone_dataset(tmpdir): return dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_image_detector_data_from_coco(tmpdir): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) @@ -198,7 +197,7 @@ def test_image_detector_data_from_coco(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_image_detector_data_from_fiftyone(tmpdir): @@ -230,8 +229,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) @@ -243,8 +241,7 @@ def test_image_detector_data_from_files(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index db7891f09d..9dd24318de 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -17,13 +17,7 @@ import torch import flash -from flash.core.utilities.imports import ( - _COCO_AVAILABLE, - _FIFTYONE_AVAILABLE, - _ICEVISION_AVAILABLE, - _IMAGE_AVAILABLE, - _PIL_AVAILABLE, -) +from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _PIL_AVAILABLE from flash.image import ObjectDetector from flash.image.detection import ObjectDetectionData @@ -39,8 +33,7 @@ from tests.image.detection.test_data import _create_synth_fiftyone_dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn"), ("faster_rcnn", "resnet18_fpn")]) def test_detection(tmpdir, head, backbone): @@ -63,7 +56,7 @@ def test_detection(tmpdir, head, backbone): trainer.predict(model, datamodule=datamodule) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") @pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) def test_detection_fiftyone(tmpdir, head, backbone): diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index fa5169584e..0fc4f02c48 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -13,6 +13,7 @@ # limitations under the License. import random from typing import Any +from unittest import mock import numpy as np import pytest @@ -22,7 +23,7 @@ from flash.core.data.io.input import DataKeys from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.trainer import Trainer -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _SERVE_TESTING from flash.image import ObjectDetector from tests.helpers.task_tester import TaskTester @@ -72,7 +73,7 @@ class TestObjectDetector(TaskTester): task = ObjectDetector task_kwargs = {"num_classes": 2} cli_command = "object_detection" - is_testing = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE + is_testing = _IMAGE_EXTRAS_TESTING is_available = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE # TODO: Resolve JIT support @@ -109,8 +110,7 @@ def example_test_sample(self): @pytest.mark.parametrize("head", ["retinanet"]) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_predict(tmpdir, head): model = ObjectDetector(num_classes=2, head=head, pretrained=False) ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) @@ -137,3 +137,11 @@ def test_predict(tmpdir, head): model.predict_kwargs = {"detection_threshold": 2} predictions = trainer.predict(model, dl, output="preds") assert len(predictions[0][0]["bboxes"]) == 0 + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + model = ObjectDetector(2) + model.eval() + model.serve() diff --git a/tests/image/instance_segmentation/test_data.py b/tests/image/instance_segmentation/test_data.py index e57d39411a..81dae6ff7e 100644 --- a/tests/image/instance_segmentation/test_data.py +++ b/tests/image/instance_segmentation/test_data.py @@ -16,14 +16,13 @@ import torch from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _COCO_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _IMAGE_EXTRAS_TESTING from flash.image.instance_segmentation import InstanceSegmentationData from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) @@ -35,8 +34,7 @@ def test_image_detector_data_from_files(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) @@ -48,8 +46,7 @@ def test_image_detector_data_from_folders(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_instance_segmentation_output_transform(): sample = { diff --git a/tests/image/instance_segmentation/test_model.py b/tests/image/instance_segmentation/test_model.py index f4f76ce31a..2518f1a07a 100644 --- a/tests/image/instance_segmentation/test_model.py +++ b/tests/image/instance_segmentation/test_model.py @@ -22,7 +22,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_EXTRAS_TESTING from flash.image import InstanceSegmentation, InstanceSegmentationData from tests.helpers.task_tester import TaskTester @@ -95,7 +95,7 @@ class TestInstanceSegmentation(TaskTester): task = InstanceSegmentation task_kwargs = {"num_classes": 2} cli_command = "instance_segmentation" - is_testing = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE + is_testing = _IMAGE_EXTRAS_TESTING is_available = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE # TODO: Resolve JIT support @@ -132,8 +132,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize("backbone, head", [("resnet18_fpn", "mask_rcnn")]) def test_model(coco_instances, backbone, head): datamodule = InstanceSegmentationData.from_coco( diff --git a/tests/image/keypoint_detection/test_data.py b/tests/image/keypoint_detection/test_data.py index 92e2429bd4..684fda6478 100644 --- a/tests/image/keypoint_detection/test_data.py +++ b/tests/image/keypoint_detection/test_data.py @@ -14,13 +14,12 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _COCO_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _IMAGE_EXTRAS_TESTING from flash.image.keypoint_detection import KeypointDetectionData from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) @@ -32,8 +31,7 @@ def test_image_detector_data_from_files(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py index b05cc90ee4..b0980961cd 100644 --- a/tests/image/keypoint_detection/test_model.py +++ b/tests/image/keypoint_detection/test_model.py @@ -22,7 +22,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_EXTRAS_TESTING from flash.image import KeypointDetectionData, KeypointDetector from tests.helpers.task_tester import TaskTester @@ -99,7 +99,7 @@ class TestKeypointDetector(TaskTester): task_args = (2,) task_kwargs = {"num_classes": 2} cli_command = "keypoint_detection" - is_testing = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE + is_testing = _IMAGE_EXTRAS_TESTING is_available = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE # TODO: Resolve JIT support @@ -139,8 +139,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing") +@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize("backbone, head", [("resnet18_fpn", "keypoint_rcnn")]) def test_model(coco_keypoints, backbone, head): datamodule = KeypointDetectionData.from_coco(