This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor serving #1062

Merged 5 commits on Dec 13, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the `backbone` argument from `TextClassificationData`, it is now sufficient to only provide a `backbone` argument to the `TextClassifier` ([#1022](https://github.com/PyTorchLightning/lightning-flash/pull/1022))

- Removed support for the `serve_sanity_check` argument in `flash.Trainer` ([#1062](https://github.com/PyTorchLightning/lightning-flash/pull/1062))

## [0.5.2] - 2021-11-05

### Added
35 changes: 31 additions & 4 deletions flash/audio/speech_recognition/model.py
@@ -13,19 +13,32 @@
# limitations under the License.
import os
import warnings
from typing import Any, Dict
from typing import Any, Dict, Optional, Type

import torch
import torch.nn as nn

from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
from flash.audio.speech_recognition.output_transform import SpeechRecognitionBackboneState
from flash.audio.speech_recognition.input import SpeechRecognitionDeserializer
from flash.audio.speech_recognition.output_transform import (
SpeechRecognitionBackboneState,
SpeechRecognitionOutputTransform,
)
from flash.core.data.input_transform import InputTransform
from flash.core.data.io.input import ServeInput
from flash.core.data.states import CollateFn
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.core.serve import Composition
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires
from flash.core.utilities.types import (
INPUT_TRANSFORM_TYPE,
LR_SCHEDULER_TYPE,
OPTIMIZER_TYPE,
OUTPUT_TRANSFORM_TYPE,
OUTPUT_TYPE,
)

if _AUDIO_AVAILABLE:
from transformers import Wav2Vec2Processor
@@ -55,6 +68,7 @@ def __init__(
lr_scheduler: LR_SCHEDULER_TYPE = None,
learning_rate: float = 1e-5,
output: OUTPUT_TYPE = None,
output_transform: OUTPUT_TRANSFORM_TYPE = SpeechRecognitionOutputTransform(),
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
@@ -69,6 +83,7 @@ lr_scheduler=lr_scheduler,
lr_scheduler=lr_scheduler,
learning_rate=learning_rate,
output=output,
output_transform=output_transform,
)

self.save_hyperparameters()
@@ -86,3 +101,15 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
out = self.model(batch["input_values"], labels=batch["labels"])
out["logs"] = {"loss": out.loss}
return out

@requires("serve")
def serve(
self,
host: str = "127.0.0.1",
port: int = 8000,
sanity_check: bool = True,
input_cls: Optional[Type[ServeInput]] = SpeechRecognitionDeserializer,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
) -> Composition:
return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
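
A usage sketch of the new per-task override, assuming the `serve` and `audio` extras are installed; the wav2vec2 backbone name is illustrative. Because `SpeechRecognitionDeserializer` and `SpeechRecognitionOutputTransform` are wired in as defaults, no extra serving arguments are needed:

```python
from flash.audio import SpeechRecognition

# Build the task; the `serve` override above supplies SpeechRecognitionDeserializer
# as the default input_cls, and the output transform is set in __init__.
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# Runs the serve sanity check, then starts the server on the given host/port.
model.serve(host="127.0.0.1", port=8000, sanity_check=True)
```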
7 changes: 4 additions & 3 deletions flash/audio/speech_recognition/output_transform.py
@@ -34,7 +34,6 @@ class SpeechRecognitionBackboneState(ProcessState):


class SpeechRecognitionOutputTransform(OutputTransform):
@requires("audio")
def __init__(self):
super().__init__()

@@ -54,6 +53,7 @@ def tokenizer(self):
self._backbone = self.backbone
return self._tokenizer

@requires("audio")
def per_batch_transform(self, batch: Any) -> Any:
# converts logits into greedy transcription
pred_ids = torch.argmax(batch.logits, dim=-1)
@@ -62,9 +62,10 @@

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("_tokenizer")
state.pop("_tokenizer", None)
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
if self.backbone is not None:
self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
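
A minimal sketch of what the guards above allow, assuming (per the TODO in the diff) that the transform may be pickled before any backbone state has been attached; the round trip is illustrative only:

```python
import pickle

from flash.audio.speech_recognition.output_transform import SpeechRecognitionOutputTransform

# With no backbone attached, `_tokenizer` may be absent and `self.backbone` is None,
# so `__getstate__` pops defensively and `__setstate__` skips rebuilding the tokenizer.
transform = SpeechRecognitionOutputTransform()
restored = pickle.loads(pickle.dumps(transform))
```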
27 changes: 8 additions & 19 deletions flash/core/data/batch.py
@@ -11,38 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
from typing import Any, Sequence, TYPE_CHECKING

import torch
from torch import Tensor

from flash.core.data.callback import ControlFlow, FlashCallback
from flash.core.data.utils import convert_to_modules
from flash.core.utilities.stages import RunningStage

if TYPE_CHECKING:
from flash.core.data.input_transform import InputTransform
from flash.core.data.process import Deserializer
from flash.core.data.io.input import ServeInput


class _DeserializeProcessorV2(torch.nn.Module):
class _ServeInputProcessor(torch.nn.Module):
def __init__(
self,
deserializer: "Deserializer",
input_transform: "InputTransform",
per_sample_transform: Callable,
callbacks: Optional[List[FlashCallback]] = None,
serve_input: "ServeInput",
):
super().__init__()
self.input_transform = input_transform
self.callback = ControlFlow(callbacks or [])
self.deserializer = convert_to_modules(deserializer)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.serve_input = serve_input
self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([])

def forward(self, sample: str):
sample = self.deserializer(sample)
sample = self.per_sample_transform(sample)
self.callback.on_per_sample_transform(sample, RunningStage.SERVING)
sample = self.serve_input._call_load_sample(sample)
sample = self.dataloader_collate_fn(sample)
return sample


20 changes: 1 addition & 19 deletions flash/core/data/data_pipeline.py
@@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union
from typing import Any, Dict, List, Optional, Set, Type, Union

from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.data.batch import _DeserializeProcessorV2
from flash.core.data.input_transform import _create_collate_input_transform_processors
from flash.core.data.input_transform import InputTransform
from flash.core.data.input_transform import InputTransform as NewInputTransform
from flash.core.data.io.input import Input, InputBase
from flash.core.data.io.input_transform import _InputTransformProcessorV2
from flash.core.data.io.output import _OutputProcessor, Output
from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform
from flash.core.data.process import Deserializer
@@ -133,21 +130,6 @@ def _is_overridden_recursive(
return has_different_code
return has_different_code or cls._is_overridden_recursive(method_name, process_obj, super_obj)

@staticmethod
def _identity(samples: Sequence[Any]) -> Sequence[Any]:
return samples

def deserialize_processor(self) -> _DeserializeProcessorV2:
return _DeserializeProcessorV2(
self._deserializer,
self._input_transform_pipeline,
self._input_transform_pipeline._per_sample_transform,
[],
)

def device_input_transform_processor(self, running_stage: RunningStage) -> _InputTransformProcessorV2:
return _create_collate_input_transform_processors(self._input_transform_pipeline, [])[1]

def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor:
return self._create_output_transform_processor(running_stage, is_serving=is_serving)

78 changes: 51 additions & 27 deletions flash/core/model.py
@@ -39,7 +39,7 @@
import flash
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.input_transform import InputTransform
from flash.core.data.io.input import InputBase
from flash.core.data.io.input import InputBase, ServeInput
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import Deserializer, DeserializerMapping
@@ -354,18 +354,27 @@ def __init__(
self.output = output
self._wrapped_predict_step = False

def _wrap_predict_step(task, predict_step: Callable) -> Callable:
def _wrap_predict_step(self) -> None:
if not self._wrapped_predict_step:
process_fn = self.build_data_pipeline().output_transform_processor(RunningStage.PREDICTING)

process_fn = task.build_data_pipeline().output_transform_processor(RunningStage.PREDICTING)
predict_step = self.predict_step

@functools.wraps(predict_step)
def wrapper(self, *args, **kwargs):
predictions = predict_step(self, *args, **kwargs)
return process_fn(predictions)
@functools.wraps(predict_step)
def wrapper(*args, **kwargs):
predictions = predict_step(*args, **kwargs)
return process_fn(predictions)

task._wrapped_predict_step = True
self._original_predict_step = self.predict_step
self.predict_step = wrapper

return wrapper
self._wrapped_predict_step = True

def _unwrap_predict_step(self) -> None:
if self._wrapped_predict_step:
self.predict_step = self._original_predict_step
del self._original_predict_step
self._wrapped_predict_step = False

def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
"""Implement the core logic for the training/validation/test step. By default this includes:
@@ -759,11 +768,6 @@ def build_data_pipeline(
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
return data_pipeline

@torch.jit.unused
@property
def is_servable(self) -> bool:
return type(self.build_data_pipeline()._deserializer) != Deserializer

@torch.jit.unused
@property
def data_pipeline(self) -> DataPipeline:
@@ -803,8 +807,10 @@ def output_transform(self) -> OutputTransform:
return getattr(self.data_pipeline, "_output_transform", None)

def on_predict_start(self) -> None:
if self.trainer and not self._wrapped_predict_step:
self.predict_step = self._wrap_predict_step(self.predict_step)
self._wrap_predict_step()

def on_predict_end(self) -> None:
self._unwrap_predict_step()

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# This may be an issue since here we create the same problems with pickle as in
@@ -1066,36 +1072,54 @@ def configure_callbacks(self):
return [BenchmarkConvergenceCI()]

@requires("serve")
def run_serve_sanity_check(self):
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")

def run_serve_sanity_check(self, serve_input: ServeInput):
from fastapi.testclient import TestClient

from flash.core.serve.flash_components import build_flash_serve_model_component

print("Running serve sanity check")
comp = build_flash_serve_model_component(self)
comp = build_flash_serve_model_component(self, serve_input)
composition = Composition(predict=comp, TESTING=True, DEBUG=True)
app = composition.serve(host="0.0.0.0", port=8000)

with TestClient(app) as tc:
input_str = self.data_pipeline._deserializer.example_input
input_str = serve_input.example_input
body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
resp = tc.post("http://0.0.0.0:8000/predict", json=body)
print(f"Sanity check response: {resp.json()}")

@requires("serve")
def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> "Composition":
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
def serve(
self,
host: str = "127.0.0.1",
port: int = 8000,
sanity_check: bool = True,
input_cls: Optional[Type[ServeInput]] = None,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
) -> "Composition":
"""Serve the ``Task``. Override this method to provide a default ``input_cls``, ``transform``, and
``transform_kwargs``.

Args:
host: The IP address to host the ``Task`` on.
port: The port to host on.
sanity_check: If ``True``, runs a sanity check before serving.
input_cls: The ``ServeInput`` type to use.
transform: The transform to use when serving.
transform_kwargs: Keyword arguments used to instantiate the transform.
"""
from flash.core.serve.flash_components import build_flash_serve_model_component

if input_cls is None:
raise NotImplementedError("The `input_cls` must be provided to enable serving.")

serve_input = input_cls(transform=transform, transform_kwargs=transform_kwargs)

if sanity_check:
self.run_serve_sanity_check()
self.run_serve_sanity_check(serve_input)

comp = build_flash_serve_model_component(self)
comp = build_flash_serve_model_component(self, serve_input)
composition = Composition(predict=comp, TESTING=flash._IS_TESTING)
composition.serve(host=host, port=port)
return composition
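
A minimal sketch of the new base-class contract, assuming an existing task instance `task`; `RawJSONServeInput` is a hypothetical `ServeInput` written only for illustration, since the base `serve` now raises unless an `input_cls` is provided:

```python
from typing import Any, Dict

from flash.core.data.io.input import ServeInput


class RawJSONServeInput(ServeInput):
    """Hypothetical ServeInput that turns the raw request payload into a sample dict."""

    def load_sample(self, sample: str) -> Dict[str, Any]:
        return {"data": sample}


# Without `input_cls` the base implementation raises NotImplementedError (see above).
# With one, it optionally runs the sanity check and then serves a Composition.
composition = task.serve(
    host="0.0.0.0",
    port=8000,
    sanity_check=False,
    input_cls=RawJSONServeInput,
)
```

Tasks that have a natural deserializer (such as `SpeechRecognition` above) override `serve` to fill in `input_cls`, so only custom tasks need to pass one explicitly.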
26 changes: 19 additions & 7 deletions flash/core/serve/flash_components.py
@@ -3,6 +3,8 @@

import torch

from flash.core.data.batch import _ServeInputProcessor
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import DataKeys
from flash.core.serve import expose, ModelComponent
from flash.core.serve.types.base import BaseType
@@ -46,7 +48,18 @@ def deserialize(self, data: str) -> Any: # pragma: no cover
return None


def build_flash_serve_model_component(model):
def build_flash_serve_model_component(model, serve_input):

data_pipeline_state = DataPipelineState()
for properties in [
serve_input,
getattr(serve_input, "transform", None),
model._output_transform,
model._output,
model,
]:
if properties is not None and hasattr(properties, "attach_data_pipeline_state"):
properties.attach_data_pipeline_state(data_pipeline_state)

data_pipeline = model.build_data_pipeline()

Expand All @@ -55,23 +68,22 @@ def __init__(self, model):
self.model = model
self.model.eval()
self.data_pipeline = model.build_data_pipeline()
self.deserializer = self.data_pipeline._deserializer
self.dataloader_collate_fn = self.data_pipeline._deserializer._create_dataloader_collate_fn([])
self.on_after_batch_transfer_fn = self.data_pipeline._deserializer._create_on_after_batch_transfer_fn([])
self.serve_input = serve_input
self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([])
self.on_after_batch_transfer_fn = self.serve_input._create_on_after_batch_transfer_fn([])
self.output_transform_processor = self.data_pipeline.output_transform_processor(
RunningStage.PREDICTING, is_serving=True
RunningStage.SERVING, is_serving=True
)
# todo (tchaton) Remove this hack
self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
self.device = self.model.device

@expose(
inputs={"inputs": FlashInputs(data_pipeline._deserializer._call_load_sample)},
inputs={"inputs": FlashInputs(_ServeInputProcessor(serve_input))},
outputs={"outputs": FlashOutputs(data_pipeline.output_processor())},
)
def predict(self, inputs):
with torch.no_grad():
inputs = self.dataloader_collate_fn(inputs)
if self.extra_arguments:
inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
else:
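
For completeness, a hedged sketch of calling the served endpoint once a composition is running, mirroring the payload posted by `run_serve_sanity_check` above; the serialized input string is a placeholder:

```python
import requests

# The composition exposes a single /predict route; the body format matches the
# sanity-check request in Task.run_serve_sanity_check.
body = {"session": "UUID", "payload": {"inputs": {"data": "<serialized input>"}}}
response = requests.post("http://127.0.0.1:8000/predict", json=body)
print(response.json())
```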