Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add serve support for object detection #1370

Merged
merged 6 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ jobs:
if: contains( matrix.topic , 'serve' )
run: |
sudo apt-get install libsndfile1
pip install '.[all,audio]' --upgrade
pip install '.[all,audio]' icevision effdet --upgrade

- name: Install audio test dependencies
if: contains( matrix.topic , 'audio' )
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for audio file formats to `AudioClassificationData` ([#1085](https://github.com/PyTorchLightning/lightning-flash/pull/1085))

- Added support for Flash serve to the `ObjectDetector` ([#1370](https://github.com/PyTorchLightning/lightning-flash/pull/1370))

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
Expand Down
20 changes: 20 additions & 0 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,23 @@ creating a subclass of :class:`~flash.core.data.io.input_transform.InputTransfor
transform=BrightnessContrastTransform,
batch_size=4,
)

------

*******
Serving
*******

The :class:`~flash.image.detection.model.ObjectDetector` is servable.
This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`.
Here's an example:

.. literalinclude:: ../../../flash_examples/serve/object_detection/inference_server.py
:language: python
:lines: 14-

You can now perform inference from your client like this:

.. literalinclude:: ../../../flash_examples/serve/object_detection/client.py
:language: python
:lines: 14-
23 changes: 21 additions & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type, Union

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import ServeInput
from flash.core.data.io.output import Output
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.core.serve import Composition
from flash.core.utilities.imports import requires
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE, LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.image.data import ImageDeserializer
from flash.image.detection.backbones import OBJECT_DETECTION_HEADS
from flash.image.detection.output import OBJECT_DETECTION_OUTPUTS

Expand Down Expand Up @@ -95,3 +101,16 @@ def predict_kwargs(self) -> Dict[str, Any]:
@predict_kwargs.setter
def predict_kwargs(self, predict_kwargs: Dict[str, Any]):
self.adapter.predict_kwargs = predict_kwargs

@requires("serve")
def serve(
self,
host: str = "127.0.0.1",
port: int = 8000,
sanity_check: bool = True,
input_cls: Optional[Type[ServeInput]] = ImageDeserializer,
transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
transform_kwargs: Optional[Dict] = None,
output: Optional[Union[str, Output]] = None,
) -> Composition:
return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output)
26 changes: 26 additions & 0 deletions flash_examples/serve/object_detection/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
17 changes: 17 additions & 0 deletions flash_examples/serve/object_detection/inference_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.image import ObjectDetector

model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.8.0/object_detection_model.pt")
model.serve()
11 changes: 10 additions & 1 deletion tests/image/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import random
from typing import Any
from unittest import mock

import numpy as np
import pytest
Expand All @@ -22,7 +23,7 @@
from flash.core.data.io.input import DataKeys
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.trainer import Trainer
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _SERVE_TESTING
from flash.image import ObjectDetector
from tests.helpers.task_tester import TaskTester

Expand Down Expand Up @@ -137,3 +138,11 @@ def test_predict(tmpdir, head):
model.predict_kwargs = {"detection_threshold": 2}
predictions = trainer.predict(model, dl, output="preds")
assert len(predictions[0][0]["bboxes"]) == 0


@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
@mock.patch("flash._IS_TESTING", True)
def test_serve():
model = ObjectDetector(2)
model.eval()
model.serve()