From 089c4e84bbbea792d01f9801676e847c98945fa3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 17 Jun 2021 12:44:10 +0100 Subject: [PATCH] Flash serve bug fixes (#422) * Flash serve bug fixes * Fixes --- docs/source/quickstart.rst | 4 ++-- flash/core/data/batch.py | 6 +---- flash/core/serve/flash_components.py | 23 +++++++++++-------- flash/image/segmentation/data.py | 2 +- flash/image/segmentation/model.py | 2 +- flash_examples/predict/text_classification.py | 2 -- .../inference_server.py | 2 ++ .../inference_server.py | 2 ++ requirements/serve.txt | 4 ++-- 9 files changed, 25 insertions(+), 22 deletions(-) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index a244448e04..01c4795cc3 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -115,11 +115,11 @@ We get the following output: .. testcode:: :hide: - assert all([prediction in [0, 1] for prediction in predictions]) + assert all([prediction in ["positive", "negative"] for prediction in predictions]) .. code-block:: - [1, 1, 0] + ["negative", "negative", "positive"] ------- diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 587207a0a0..2ab0ce3c5b 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -316,11 +316,7 @@ def forward(self, batch: Sequence[Any]): self.save_fn(pred) else: self.save_fn(final_preds) - else: - # todo (tchaton): Debug the serializer not iterating over a list. - if self.is_serving and isinstance(final_preds, list) and len(final_preds) == 1: - return final_preds[0] - return final_preds + return final_preds def __str__(self) -> str: return ( diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 017d933e3d..1f549be029 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -1,13 +1,10 @@ -import inspect -from pathlib import Path -from typing import Any, Callable, Optional, Type +from typing import Any, Callable, Mapping, Optional import torch -from pytorch_lightning.trainer.states import RunningStage from flash import Task -from flash.core.serve import Composition, expose, GridModel, ModelComponent -from flash.core.serve.core import FilePath, GridModelValidArgs_T, GridserveScriptLoader +from flash.core.data.data_source import DefaultDataKeys +from flash.core.serve.core import FilePath, GridserveScriptLoader from flash.core.serve.types.base import BaseType @@ -34,9 +31,17 @@ def __init__( ): self._serializer = serializer - def serialize(self, output) -> Any: # pragma: no cover - result = self._serializer(output) - return result + def serialize(self, outputs) -> Any: # pragma: no cover + results = [] + if isinstance(outputs, list) or isinstance(outputs, torch.Tensor): + for output in outputs: + result = self._serializer(output) + if isinstance(result, Mapping): + result = result[DefaultDataKeys.PREDS] + results.append(result) + if len(results) == 1: + return results[0] + return results def deserialize(self, data: str) -> Any: # pragma: no cover return None diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 1b10137ce7..a21517704d 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -227,7 +227,7 @@ def deserialize(self, data: str) -> torch.Tensor: buffer = BytesIO(img) img = PILImage.open(buffer, mode="r") img = self.to_tensor(img) - return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: img.shape} + return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: {"size": img.shape}} class SemanticSegmentationPreprocess(Preprocess): diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index fb2e4d4b55..df880636d8 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch from torch import nn diff --git a/flash_examples/predict/text_classification.py b/flash_examples/predict/text_classification.py index 7f03eb27ee..ee7134c413 100644 --- a/flash_examples/predict/text_classification.py +++ b/flash_examples/predict/text_classification.py @@ -23,8 +23,6 @@ # 2. Load the model from a checkpoint model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") -model.serializer = Labels() - # 2a. Classify a few sentences! How was the movie? predictions = model.predict([ "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", diff --git a/flash_examples/serve/segmentic_segmentation/inference_server.py b/flash_examples/serve/segmentic_segmentation/inference_server.py index fe3d91d3c7..78d51a0c0e 100644 --- a/flash_examples/serve/segmentic_segmentation/inference_server.py +++ b/flash_examples/serve/segmentic_segmentation/inference_server.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from flash.image import SemanticSegmentation +from flash.image.segmentation.serialization import SegmentationLabels model = SemanticSegmentation.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" ) +model.serializer = SegmentationLabels(visualize=False) model.serve() diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/flash_examples/serve/tabular_classification/inference_server.py index cf5b57c9b3..f6aac866e2 100644 --- a/flash_examples/serve/tabular_classification/inference_server.py +++ b/flash_examples/serve/tabular_classification/inference_server.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from flash.core.classification import Labels from flash.tabular import TabularClassifier model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") +model.serializer = Labels(['Did not survive', 'Survived']) model.serve() diff --git a/requirements/serve.txt b/requirements/serve.txt index 2ceef6e06b..da47094178 100644 --- a/requirements/serve.txt +++ b/requirements/serve.txt @@ -10,8 +10,8 @@ fastapi>=0.65.2,<0.66.0 # to have full feature control of fastapi, manually install optional # dependencies rather than installing fastapi[all] # https://fastapi.tiangolo.com/#optional-dependencies -pydantic>=1.6.0,<2.0.0 -starlette>=0.14.0 +pydantic>1.8.1,<2.0.0 +starlette==0.14.2 uvicorn[standard]>=0.12.0,<0.14.0 aiofiles jinja2