Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Serve sanity checks #423

Merged
merged 22 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ jobs:
pip list
shell: bash

- name: Install serve test dependencies
if: matrix.topic == 'serve'
run: |
pip install '.[all]' --pre --upgrade
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

- name: Cache datasets
uses: actions/cache@v2
with:
Expand All @@ -115,7 +120,8 @@ jobs:

- name: Tests
env:
FIFTYONE_DO_NOT_TRACK: true
FLASH_TEST_TOPIC: ${{ matrix.topic }}
FIFTYONE_DO_NOT_TRACK: true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall t be rather just 0/1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it can be. This was done by the fiftyone people for something on their end. I wouldn't want to change it in case something breaks.

run: |
# tox --sitepackages
coverage run --source flash -m pytest flash tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))
- Added a sanity checking feature to flash.serve ([#423](https://github.com/PyTorchLightning/lightning-flash/pull/423))

### Changed

Expand Down
1 change: 1 addition & 0 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from flash.core.trainer import Trainer # noqa: E402

_PACKAGE_ROOT = os.path.dirname(__file__)
ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets")
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"

Expand Down
Binary file added flash/assets/fish.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 4 additions & 2 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
if not torch.is_tensor(sample):
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
sample = torch.tensor(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
Expand Down Expand Up @@ -153,7 +154,8 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
if isinstance(sample, Dict) and DefaultDataKeys.PREDS in sample:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
sample = sample[DefaultDataKeys.PREDS]
sample = torch.tensor(sample)
labels = None

Expand Down
4 changes: 4 additions & 0 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) ->
data_pipeline_state._initialized = True # TODO: Not sure we need this
return data_pipeline_state

@property
def example_input(self) -> str:
return self._deserializer.example_input

@staticmethod
def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool:
"""
Expand Down
7 changes: 6 additions & 1 deletion flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
from abc import ABC, abstractclassmethod, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence

import torch
from pytorch_lightning.trainer.states import RunningStage
Expand Down Expand Up @@ -569,6 +569,11 @@ class Deserializer(Properties):
def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor???
raise NotImplementedError

@property
@abstractmethod
def example_input(self) -> str:
raise NotImplementedError
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

def __call__(self, sample: Any) -> Any:
return self.deserialize(sample)

Expand Down
19 changes: 17 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import functools
import inspect
import os
from copy import deepcopy
from importlib import import_module
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
Expand Down Expand Up @@ -592,7 +593,9 @@ def configure_callbacks(self):
if flash._IS_TESTING and torch.cuda.is_available():
return [BenchmarkConvergenceCI()]

def serve(self, host: str = "127.0.0.1", port: int = 8000) -> 'Composition':
def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> 'Composition':
from fastapi.testclient import TestClient

from flash.core.serve.flash_components import FlashInputs, FlashOutputs

class FlashServeModelComponent(ModelComponent):
Expand Down Expand Up @@ -626,7 +629,19 @@ def predict(self, inputs):
preds = self.postprocessor(preds)
return preds

if sanity_check:
print("Running sanity check")
comp = FlashServeModelComponent(self)
composition = Composition(predict=comp, TESTING=True, DEBUG=True)
app = composition.serve(host="0.0.0.0", port=8000)

with TestClient(app) as tc:
input_str = self.data_pipeline._deserializer.example_input
body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
resp = tc.post("http://0.0.0.0:8000/predict", json=body)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
print(f"Sanity check response: {resp.json()}")

comp = FlashServeModelComponent(self)
composition = Composition(predict=comp)
composition = Composition(predict=comp, TESTING=os.environ["FLASH_TESTING"] == "1")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, some variables are 0/1 other you write s true... lets be consistent

ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
composition.serve(host=host, port=port)
return composition
11 changes: 10 additions & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""

import importlib
import operator
import os
from importlib.util import find_spec

from pkg_resources import DistributionNotFound
Expand Down Expand Up @@ -94,3 +94,12 @@ def _compare_version(package: str, op, version) -> bool:
_VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE
_IMAGE_AVAILABLE = _TORCHVISION_AVAILABLE and _TIMM_AVAILABLE and _KORNIA_AVAILABLE
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE

if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
_IMAGE_AVAILABLE = topic == "image"
_VIDEO_AVAILABLE = topic == "video"
_TABULAR_AVAILABLE = topic == "tabular"
_TEXT_AVAILABLE = topic == "text"
_IMAGE_STLYE_TRANSFER = topic == "image_style_transfer"
_SERVE_AVAILABLE = topic == "serve"
34 changes: 9 additions & 25 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -25,43 +23,29 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE
from flash.image.classification.transforms import default_transforms, train_default_transforms
from flash.image.data import ImageFiftyOneDataSource, ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource
from flash.image.data import (
ImageDeserializer,
ImageFiftyOneDataSource,
ImageNumpyDataSource,
ImagePathsDataSource,
ImageTensorDataSource,
)

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
plt = None

if _TORCHVISION_AVAILABLE:
import torchvision

if _IMAGE_AVAILABLE:
from PIL import Image
from PIL import Image as PILImage
else:

class Image:
Image = None


class ImageClassificationDeserializer(Deserializer):

def __init__(self):

self.to_tensor = torchvision.transforms.ToTensor()

def deserialize(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
return {
DefaultDataKeys.INPUT: img,
}


class ImageClassificationPreprocess(Preprocess):

def __init__(
Expand All @@ -88,7 +72,7 @@ def __init__(
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSORS: ImageTensorDataSource(),
},
deserializer=deserializer or ImageClassificationDeserializer(),
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
)

Expand Down
36 changes: 35 additions & 1 deletion flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional

import torch

import flash
from flash.core.data.data_source import (
DefaultDataKeys,
FiftyOneDataSource,
NumpyDataSource,
PathsDataSource,
TensorDataSource,
)
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.core.data.process import Deserializer
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
from torchvision.transforms.functional import to_pil_image
else:
IMG_EXTENSIONS = []

if _IMAGE_AVAILABLE:
from PIL import Image as PILImage
else:

class Image:
Image = None


class ImageDeserializer(Deserializer):

def __init__(self):
super().__init__()
self.to_tensor = torchvision.transforms.ToTensor()

def deserialize(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
return {
DefaultDataKeys.INPUT: img,
}

@property
def example_input(self) -> str:
with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
return base64.b64encode(f.read()).decode("UTF-8")


class ImagePathsDataSource(PathsDataSource):

Expand Down
17 changes: 6 additions & 11 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE
from flash.image.data import ImageDeserializer
from flash.image.segmentation.serialization import SegmentationLabels
from flash.image.segmentation.transforms import default_transforms, train_default_transforms

Expand Down Expand Up @@ -215,19 +216,13 @@ def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
return sample


class SemanticSegmentationDeserializer(Deserializer):

def __init__(self):

self.to_tensor = torchvision.transforms.ToTensor()
class SemanticSegmentationDeserializer(ImageDeserializer):

def deserialize(self, data: str) -> torch.Tensor:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
buffer = BytesIO(img)
img = PILImage.open(buffer, mode="r")
img = self.to_tensor(img)
return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: {"size": img.shape}}
result = super().deserialize(data)
result[DefaultDataKeys.INPUT] = self.to_tensor(result[DefaultDataKeys.INPUT])
result[DefaultDataKeys.METADATA] = {"size": result[DefaultDataKeys.INPUT].shape}
return result


class SemanticSegmentationPreprocess(Preprocess):
Expand Down
26 changes: 12 additions & 14 deletions flash/tabular/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(
classes: Optional[List[str]] = None,
is_regression: bool = True
):

super().__init__()
self.cat_cols = cat_cols
self.num_cols = num_cols
self.target_col = target_col
Expand All @@ -134,20 +135,8 @@ def __init__(
self.classes = classes
self.is_regression = is_regression

@staticmethod
def _convert_row(row):
_row = []
for c in row:
try:
_row.append(float(c))
except Exception:
_row.append(c)
return _row

def deserialize(self, data: str) -> Any:
columns = data.split("\n")[0].split(',')
df = pd.DataFrame([TabularDeserializer._convert_row(x.split(',')[1:]) for x in data.split('\n')[1:-1]],
columns=columns)
df = pd.read_csv(StringIO(data))
df = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col,
self.target_codes)[0]

Expand All @@ -159,6 +148,15 @@ def deserialize(self, data: str) -> Any:

return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)]

@property
def example_input(self) -> str:
row = {}
for cat_col in self.cat_cols:
row[cat_col] = ["test"]
for num_col in self.num_cols:
row[num_col] = [0]
return str(DataFrame.from_dict(row).to_csv())


class TabularPreprocess(Preprocess):

Expand Down
5 changes: 5 additions & 0 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@
class TextDeserializer(Deserializer):

def __init__(self, backbone: str, max_length: int, use_fast: bool = True):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=use_fast)
self.max_length = max_length

def deserialize(self, text: str) -> Tensor:
return self.tokenizer(text, max_length=self.max_length, truncation=True, padding="max_length")

@property
def example_input(self) -> str:
return "An example input"


class TextDataSource(DataSource):

Expand Down
2 changes: 1 addition & 1 deletion flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def fn_resnet(pretrained: bool = True):
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, serializer=Labels())

# 5. Create the trainer
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
trainer = flash.Trainer(max_epochs=3)

# 6. Train the model
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
Expand Down
Loading