Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add serve support for object detection #1370

Merged
merged 6 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ jobs:
if: contains( matrix.topic , 'serve' )
run: |
sudo apt-get install libsndfile1
pip install '.[all,audio]' --upgrade
pip install '.[all,audio]' icevision effdet --upgrade

- name: Install audio test dependencies
if: contains( matrix.topic , 'audio' )
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for audio file formats to `AudioClassificationData` ([#1085](https://github.com/PyTorchLightning/lightning-flash/pull/1085))

- Added support for Flash serve to the `ObjectDetector` ([#1370](https://github.com/PyTorchLightning/lightning-flash/pull/1370))

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
Expand Down
20 changes: 20 additions & 0 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,23 @@ creating a subclass of :class:`~flash.core.data.io.input_transform.InputTransfor
transform=BrightnessContrastTransform,
batch_size=4,
)

------

*******
Serving
*******

The :class:`~flash.image.detection.model.ObjectDetector` is servable.
This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`.
Here's an example:

.. literalinclude:: ../../../flash_examples/serve/object_detection/inference_server.py
:language: python
:lines: 14-

You can now perform inference from your client like this:

.. literalinclude:: ../../../flash_examples/serve/object_detection/client.py
:language: python
:lines: 14-
15 changes: 13 additions & 2 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _ICEVISION_GREATER_EQUAL_0_11_0, requires
from flash.core.utilities.imports import (
_ICEVISION_AVAILABLE,
_ICEVISION_GREATER_EQUAL_0_11_0,
_IMAGE_AVAILABLE,
requires,
)

if _IMAGE_AVAILABLE:
from PIL import Image

if _ICEVISION_AVAILABLE:
from icevision.core import tasks
Expand Down Expand Up @@ -90,7 +98,10 @@ def to_icevision_record(sample: Dict[str, Any]):
input_component = ImageRecordComponent()
input_component.composite = record
image = sample[DataKeys.INPUT]
image = image.permute(1, 2, 0).numpy() if isinstance(image, torch.Tensor) else image
if isinstance(image, torch.Tensor):
image = image.permute(1, 2, 0).numpy()
elif isinstance(image, Image.Image):
image = np.array(image)
input_component.set_img(image)

record.add_component(OriginalSizeRecordComponent(metadata.get("size", image.shape[:2])))
Expand Down
4 changes: 2 additions & 2 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ def _import_module(self):
# Global variables used for testing purposes (e.g. to only run doctests in the correct CI job)
_CORE_TESTING = True
_IMAGE_TESTING = _IMAGE_AVAILABLE
_IMAGE_EXTRAS_TESTING = False # Not for normal use
_IMAGE_EXTRAS_TESTING = True # Not for normal use
_VIDEO_TESTING = _VIDEO_AVAILABLE
_VIDEO_EXTRAS_TESTING = False # Not for normal use
_VIDEO_EXTRAS_TESTING = True # Not for normal use
_TABULAR_TESTING = _TABULAR_AVAILABLE
_TEXT_TESTING = _TEXT_AVAILABLE
_SERVE_TESTING = _SERVE_AVAILABLE
Expand Down
23 changes: 21 additions & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type, Union

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import ServeInput
from flash.core.data.io.output import Output
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.core.serve import Composition
from flash.core.utilities.imports import requires
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE, LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.image.data import ImageDeserializer
from flash.image.detection.backbones import OBJECT_DETECTION_HEADS
from flash.image.detection.output import OBJECT_DETECTION_OUTPUTS

Expand Down Expand Up @@ -95,3 +101,16 @@ def predict_kwargs(self) -> Dict[str, Any]:
@predict_kwargs.setter
def predict_kwargs(self, predict_kwargs: Dict[str, Any]):
self.adapter.predict_kwargs = predict_kwargs

@requires("serve")
def serve(
self,
host: str = "127.0.0.1",
port: int = 8000,
sanity_check: bool = True,
input_cls: Optional[Type[ServeInput]] = ImageDeserializer,
transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
transform_kwargs: Optional[Dict] = None,
output: Optional[Union[str, Output]] = None,
) -> Composition:
return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output)
4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from flash.core.data.utilities.sort import sorted_alphanumeric
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _KORNIA_AVAILABLE
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING, _KORNIA_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE

Expand All @@ -39,7 +39,7 @@


# Skip doctests if requirements aren't available
if not _ICEVISION_AVAILABLE:
if not _IMAGE_EXTRAS_TESTING:
__doctest_skip__ = ["InstanceSegmentationData", "InstanceSegmentationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/image/keypoint_detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform
Expand All @@ -32,7 +32,7 @@


# Skip doctests if requirements aren't available
if not _ICEVISION_AVAILABLE:
if not _IMAGE_EXTRAS_TESTING:
__doctest_skip__ = ["KeypointDetectionData", "KeypointDetectionData.*"]


Expand Down
26 changes: 26 additions & 0 deletions flash_examples/serve/object_detection/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
17 changes: 17 additions & 0 deletions flash_examples/serve/object_detection/inference_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.image import ObjectDetector

model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.8.0/object_detection_model.pt")
model.serve()
7 changes: 4 additions & 3 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_GRAPH_TESTING,
_ICEVISION_AVAILABLE,
_IMAGE_AVAILABLE,
_IMAGE_EXTRAS_TESTING,
_IMAGE_TESTING,
_POINTCLOUD_TESTING,
_TABULAR_TESTING,
Expand Down Expand Up @@ -70,19 +71,19 @@
pytest.param(
"object_detection.py",
marks=pytest.mark.skipif(
not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
),
),
pytest.param(
"instance_segmentation.py",
marks=pytest.mark.skipif(
not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
),
),
pytest.param(
"keypoint_detection.py",
marks=pytest.mark.skipif(
not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
),
),
pytest.param(
Expand Down
13 changes: 5 additions & 8 deletions tests/image/detection/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _PIL_AVAILABLE
from flash.image.detection.data import ObjectDetectionData

if _PIL_AVAILABLE:
Expand Down Expand Up @@ -163,8 +163,7 @@ def _create_synth_fiftyone_dataset(tmpdir):
return dataset


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
def test_image_detector_data_from_coco(tmpdir):

train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir)
Expand Down Expand Up @@ -198,7 +197,7 @@ def test_image_detector_data_from_coco(tmpdir):
assert sample[DataKeys.INPUT].shape == (128, 128, 3)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
def test_image_detector_data_from_fiftyone(tmpdir):

Expand Down Expand Up @@ -230,8 +229,7 @@ def test_image_detector_data_from_fiftyone(tmpdir):
assert sample[DataKeys.INPUT].shape == (128, 128, 3)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
def test_image_detector_data_from_files(tmpdir):

predict_files = _create_synth_files_dataset(tmpdir)
Expand All @@ -243,8 +241,7 @@ def test_image_detector_data_from_files(tmpdir):
assert sample[DataKeys.INPUT].shape == (128, 128, 3)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
def test_image_detector_data_from_folders(tmpdir):

predict_folder = _create_synth_folders_dataset(tmpdir)
Expand Down
13 changes: 3 additions & 10 deletions tests/image/detection/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@
import torch

import flash
from flash.core.utilities.imports import (
_COCO_AVAILABLE,
_FIFTYONE_AVAILABLE,
_ICEVISION_AVAILABLE,
_IMAGE_AVAILABLE,
_PIL_AVAILABLE,
)
from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _PIL_AVAILABLE
from flash.image import ObjectDetector
from flash.image.detection import ObjectDetectionData

Expand All @@ -39,8 +33,7 @@
from tests.image.detection.test_data import _create_synth_fiftyone_dataset


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn"), ("faster_rcnn", "resnet18_fpn")])
def test_detection(tmpdir, head, backbone):

Expand All @@ -63,7 +56,7 @@ def test_detection(tmpdir, head, backbone):
trainer.predict(model, datamodule=datamodule)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")])
def test_detection_fiftyone(tmpdir, head, backbone):
Expand Down
16 changes: 12 additions & 4 deletions tests/image/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import random
from typing import Any
from unittest import mock

import numpy as np
import pytest
Expand All @@ -22,7 +23,7 @@
from flash.core.data.io.input import DataKeys
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.trainer import Trainer
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _SERVE_TESTING
from flash.image import ObjectDetector
from tests.helpers.task_tester import TaskTester

Expand Down Expand Up @@ -72,7 +73,7 @@ class TestObjectDetector(TaskTester):
task = ObjectDetector
task_kwargs = {"num_classes": 2}
cli_command = "object_detection"
is_testing = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE
is_testing = _IMAGE_EXTRAS_TESTING
is_available = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE

# TODO: Resolve JIT support
Expand Down Expand Up @@ -109,8 +110,7 @@ def example_test_sample(self):


@pytest.mark.parametrize("head", ["retinanet"])
@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
def test_predict(tmpdir, head):
model = ObjectDetector(num_classes=2, head=head, pretrained=False)
ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10)
Expand All @@ -137,3 +137,11 @@ def test_predict(tmpdir, head):
model.predict_kwargs = {"detection_threshold": 2}
predictions = trainer.predict(model, dl, output="preds")
assert len(predictions[0][0]["bboxes"]) == 0


@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
@mock.patch("flash._IS_TESTING", True)
def test_serve():
model = ObjectDetector(2)
model.eval()
model.serve()
11 changes: 4 additions & 7 deletions tests/image/instance_segmentation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
import torch

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _COCO_AVAILABLE, _IMAGE_AVAILABLE
from flash.core.utilities.imports import _IMAGE_EXTRAS_TESTING
from flash.image.instance_segmentation import InstanceSegmentationData
from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform
from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
def test_image_detector_data_from_files(tmpdir):

predict_files = _create_synth_files_dataset(tmpdir)
Expand All @@ -35,8 +34,7 @@ def test_image_detector_data_from_files(tmpdir):
assert sample[DataKeys.INPUT].shape == (128, 128, 3)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
def test_image_detector_data_from_folders(tmpdir):

predict_folder = _create_synth_folders_dataset(tmpdir)
Expand All @@ -48,8 +46,7 @@ def test_image_detector_data_from_folders(tmpdir):
assert sample[DataKeys.INPUT].shape == (128, 128, 3)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.")
def test_instance_segmentation_output_transform():

sample = {
Expand Down
Loading