Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Flash serve bug fixes #422

Merged
merged 2 commits into from
Jun 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ We get the following output:
.. testcode::
:hide:

assert all([prediction in [0, 1] for prediction in predictions])
assert all([prediction in ["positive", "negative"] for prediction in predictions])

.. code-block::

[1, 1, 0]
["negative", "negative", "positive"]

-------

Expand Down
6 changes: 1 addition & 5 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,7 @@ def forward(self, batch: Sequence[Any]):
self.save_fn(pred)
else:
self.save_fn(final_preds)
else:
# todo (tchaton): Debug the serializer not iterating over a list.
if self.is_serving and isinstance(final_preds, list) and len(final_preds) == 1:
return final_preds[0]
return final_preds
return final_preds

def __str__(self) -> str:
return (
Expand Down
23 changes: 14 additions & 9 deletions flash/core/serve/flash_components.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import inspect
from pathlib import Path
from typing import Any, Callable, Optional, Type
from typing import Any, Callable, Mapping, Optional

import torch
from pytorch_lightning.trainer.states import RunningStage

from flash import Task
from flash.core.serve import Composition, expose, GridModel, ModelComponent
from flash.core.serve.core import FilePath, GridModelValidArgs_T, GridserveScriptLoader
from flash.core.data.data_source import DefaultDataKeys
from flash.core.serve.core import FilePath, GridserveScriptLoader
from flash.core.serve.types.base import BaseType


Expand All @@ -34,9 +31,17 @@ def __init__(
):
self._serializer = serializer

def serialize(self, output) -> Any: # pragma: no cover
result = self._serializer(output)
return result
def serialize(self, outputs) -> Any: # pragma: no cover
results = []
if isinstance(outputs, list) or isinstance(outputs, torch.Tensor):
for output in outputs:
result = self._serializer(output)
if isinstance(result, Mapping):
result = result[DefaultDataKeys.PREDS]
results.append(result)
if len(results) == 1:
return results[0]
return results

def deserialize(self, data: str) -> Any: # pragma: no cover
return None
Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def deserialize(self, data: str) -> torch.Tensor:
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
img = self.to_tensor(img)
return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: img.shape}
return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: {"size": img.shape}}


class SemanticSegmentationPreprocess(Preprocess):
Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
from torch import nn
Expand Down
2 changes: 0 additions & 2 deletions flash_examples/predict/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

model.serializer = Labels()

# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import SegmentationLabels

model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.classification import Labels
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
model.serializer = Labels(['Did not survive', 'Survived'])
model.serve()
4 changes: 2 additions & 2 deletions requirements/serve.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ fastapi>=0.65.2,<0.66.0
# to have full feature control of fastapi, manually install optional
# dependencies rather than installing fastapi[all]
# https://fastapi.tiangolo.com/#optional-dependencies
pydantic>=1.6.0,<2.0.0
starlette>=0.14.0
pydantic>1.8.1,<2.0.0
starlette==0.14.2
uvicorn[standard]>=0.12.0,<0.14.0
aiofiles
jinja2
Expand Down