Refactor serving (#1062)
ethanwharris authored Dec 13, 2021
1 parent 3e1612b commit 052ed52
Showing 23 changed files with 290 additions and 169 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the `backbone` argument from `TextClassificationData`; it is now sufficient to provide a `backbone` argument only to the `TextClassifier` ([#1022](https://github.com/PyTorchLightning/lightning-flash/pull/1022))

- Removed support for the `serve_sanity_check` argument in `flash.Trainer` ([#1062](https://github.com/PyTorchLightning/lightning-flash/pull/1062))

## [0.5.2] - 2021-11-05

### Added
35 changes: 31 additions & 4 deletions flash/audio/speech_recognition/model.py
@@ -13,19 +13,32 @@
# limitations under the License.
import os
import warnings
from typing import Any, Dict
from typing import Any, Dict, Optional, Type

import torch
import torch.nn as nn

from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
from flash.audio.speech_recognition.output_transform import SpeechRecognitionBackboneState
from flash.audio.speech_recognition.input import SpeechRecognitionDeserializer
from flash.audio.speech_recognition.output_transform import (
    SpeechRecognitionBackboneState,
    SpeechRecognitionOutputTransform,
)
from flash.core.data.input_transform import InputTransform
from flash.core.data.io.input import ServeInput
from flash.core.data.states import CollateFn
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.core.serve import Composition
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires
from flash.core.utilities.types import (
    INPUT_TRANSFORM_TYPE,
    LR_SCHEDULER_TYPE,
    OPTIMIZER_TYPE,
    OUTPUT_TRANSFORM_TYPE,
    OUTPUT_TYPE,
)

if _AUDIO_AVAILABLE:
    from transformers import Wav2Vec2Processor
@@ -55,6 +68,7 @@ def __init__(
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        learning_rate: float = 1e-5,
        output: OUTPUT_TYPE = None,
        output_transform: OUTPUT_TRANSFORM_TYPE = SpeechRecognitionOutputTransform(),
    ):
        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # disable HF thousand warnings
@@ -69,6 +83,7 @@
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
            output=output,
            output_transform=output_transform,
        )

        self.save_hyperparameters()
@@ -86,3 +101,15 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
        out = self.model(batch["input_values"], labels=batch["labels"])
        out["logs"] = {"loss": out.loss}
        return out

@requires("serve")
def serve(
self,
host: str = "127.0.0.1",
port: int = 8000,
sanity_check: bool = True,
input_cls: Optional[Type[ServeInput]] = SpeechRecognitionDeserializer,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
) -> Composition:
return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs)
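
With the defaults supplied by this override, a speech recognition task can be served without wiring up a ServeInput by hand. A hedged usage sketch (the backbone string is a placeholder; the sanity_check flag replaces the Trainer-level serve_sanity_check removed in this PR):

    from flash.audio import SpeechRecognition

    # Relies on the defaults above: SpeechRecognitionDeserializer as input_cls and
    # InputTransform as transform. sanity_check=True posts one test request to the
    # /predict route before the server starts accepting traffic.
    model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
    model.serve(host="127.0.0.1", port=8000, sanity_check=True)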
7 changes: 4 additions & 3 deletions flash/audio/speech_recognition/output_transform.py
@@ -34,7 +34,6 @@ class SpeechRecognitionBackboneState(ProcessState):


class SpeechRecognitionOutputTransform(OutputTransform):
@requires("audio")
def __init__(self):
super().__init__()

@@ -54,6 +53,7 @@ def tokenizer(self):
            self._backbone = self.backbone
        return self._tokenizer

    @requires("audio")
    def per_batch_transform(self, batch: Any) -> Any:
        # converts logits into greedy transcription
        pred_ids = torch.argmax(batch.logits, dim=-1)
@@ -62,9 +62,10 @@ def per_batch_transform(self, batch: Any) -> Any:

    def __getstate__(self): # TODO: Find out why this is being pickled
        state = self.__dict__.copy()
        state.pop("_tokenizer")
        state.pop("_tokenizer", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
        if self.backbone is not None:
            self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
27 changes: 8 additions & 19 deletions flash/core/data/batch.py
@@ -11,38 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
from typing import Any, Sequence, TYPE_CHECKING

import torch
from torch import Tensor

from flash.core.data.callback import ControlFlow, FlashCallback
from flash.core.data.utils import convert_to_modules
from flash.core.utilities.stages import RunningStage

if TYPE_CHECKING:
    from flash.core.data.input_transform import InputTransform
    from flash.core.data.process import Deserializer
    from flash.core.data.io.input import ServeInput


class _DeserializeProcessorV2(torch.nn.Module):
class _ServeInputProcessor(torch.nn.Module):
    def __init__(
        self,
        deserializer: "Deserializer",
        input_transform: "InputTransform",
        per_sample_transform: Callable,
        callbacks: Optional[List[FlashCallback]] = None,
        serve_input: "ServeInput",
    ):
        super().__init__()
        self.input_transform = input_transform
        self.callback = ControlFlow(callbacks or [])
        self.deserializer = convert_to_modules(deserializer)
        self.per_sample_transform = convert_to_modules(per_sample_transform)
        self.serve_input = serve_input
        self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([])

    def forward(self, sample: str):
        sample = self.deserializer(sample)
        sample = self.per_sample_transform(sample)
        self.callback.on_per_sample_transform(sample, RunningStage.SERVING)
        sample = self.serve_input._call_load_sample(sample)
        sample = self.dataloader_collate_fn(sample)
        return sample


20 changes: 1 addition & 19 deletions flash/core/data/data_pipeline.py
@@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union
from typing import Any, Dict, List, Optional, Set, Type, Union

from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.data.batch import _DeserializeProcessorV2
from flash.core.data.input_transform import _create_collate_input_transform_processors
from flash.core.data.input_transform import InputTransform
from flash.core.data.input_transform import InputTransform as NewInputTransform
from flash.core.data.io.input import Input, InputBase
from flash.core.data.io.input_transform import _InputTransformProcessorV2
from flash.core.data.io.output import _OutputProcessor, Output
from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform
from flash.core.data.process import Deserializer
@@ -133,21 +130,6 @@ def _is_overridden_recursive(
            return has_different_code
        return has_different_code or cls._is_overridden_recursive(method_name, process_obj, super_obj)

    @staticmethod
    def _identity(samples: Sequence[Any]) -> Sequence[Any]:
        return samples

    def deserialize_processor(self) -> _DeserializeProcessorV2:
        return _DeserializeProcessorV2(
            self._deserializer,
            self._input_transform_pipeline,
            self._input_transform_pipeline._per_sample_transform,
            [],
        )

    def device_input_transform_processor(self, running_stage: RunningStage) -> _InputTransformProcessorV2:
        return _create_collate_input_transform_processors(self._input_transform_pipeline, [])[1]

    def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor:
        return self._create_output_transform_processor(running_stage, is_serving=is_serving)

78 changes: 51 additions & 27 deletions flash/core/model.py
@@ -39,7 +39,7 @@
import flash
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.input_transform import InputTransform
from flash.core.data.io.input import InputBase
from flash.core.data.io.input import InputBase, ServeInput
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import Deserializer, DeserializerMapping
@@ -354,18 +354,27 @@ def __init__(
        self.output = output
        self._wrapped_predict_step = False

    def _wrap_predict_step(task, predict_step: Callable) -> Callable:
    def _wrap_predict_step(self) -> None:
        if not self._wrapped_predict_step:
            process_fn = self.build_data_pipeline().output_transform_processor(RunningStage.PREDICTING)

        process_fn = task.build_data_pipeline().output_transform_processor(RunningStage.PREDICTING)
            predict_step = self.predict_step

        @functools.wraps(predict_step)
        def wrapper(self, *args, **kwargs):
            predictions = predict_step(self, *args, **kwargs)
            return process_fn(predictions)
            @functools.wraps(predict_step)
            def wrapper(*args, **kwargs):
                predictions = predict_step(*args, **kwargs)
                return process_fn(predictions)

        task._wrapped_predict_step = True
            self._original_predict_step = self.predict_step
            self.predict_step = wrapper

        return wrapper
            self._wrapped_predict_step = True

    def _unwrap_predict_step(self) -> None:
        if self._wrapped_predict_step:
            self.predict_step = self._original_predict_step
            del self._original_predict_step
            self._wrapped_predict_step = False

    def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
        """Implement the core logic for the training/validation/test step. By default this includes:
@@ -759,11 +768,6 @@ def build_data_pipeline(
        self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
        return data_pipeline

    @torch.jit.unused
    @property
    def is_servable(self) -> bool:
        return type(self.build_data_pipeline()._deserializer) != Deserializer

    @torch.jit.unused
    @property
    def data_pipeline(self) -> DataPipeline:
@@ -803,8 +807,10 @@ def output_transform(self) -> OutputTransform:
        return getattr(self.data_pipeline, "_output_transform", None)

    def on_predict_start(self) -> None:
        if self.trainer and not self._wrapped_predict_step:
            self.predict_step = self._wrap_predict_step(self.predict_step)
        self._wrap_predict_step()

    def on_predict_end(self) -> None:
        self._unwrap_predict_step()

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # This may be an issue since here we create the same problems with pickle as in
@@ -1066,36 +1072,54 @@ def configure_callbacks(self):
        return [BenchmarkConvergenceCI()]

    @requires("serve")
    def run_serve_sanity_check(self):
        if not self.is_servable:
            raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")

    def run_serve_sanity_check(self, serve_input: ServeInput):
        from fastapi.testclient import TestClient

        from flash.core.serve.flash_components import build_flash_serve_model_component

        print("Running serve sanity check")
        comp = build_flash_serve_model_component(self)
        comp = build_flash_serve_model_component(self, serve_input)
        composition = Composition(predict=comp, TESTING=True, DEBUG=True)
        app = composition.serve(host="0.0.0.0", port=8000)

        with TestClient(app) as tc:
            input_str = self.data_pipeline._deserializer.example_input
            input_str = serve_input.example_input
            body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
            resp = tc.post("http://0.0.0.0:8000/predict", json=body)
            print(f"Sanity check response: {resp.json()}")

    @requires("serve")
    def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> "Composition":
        if not self.is_servable:
            raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
    def serve(
        self,
        host: str = "127.0.0.1",
        port: int = 8000,
        sanity_check: bool = True,
        input_cls: Optional[Type[ServeInput]] = None,
        transform: INPUT_TRANSFORM_TYPE = InputTransform,
        transform_kwargs: Optional[Dict] = None,
    ) -> "Composition":
        """Serve the ``Task``. Override this method to provide a default ``input_cls``, ``transform``, and
        ``transform_kwargs``.
        Args:
            host: The IP address to host the ``Task`` on.
            port: The port to host on.
            sanity_check: If ``True``, runs a sanity check before serving.
            input_cls: The ``ServeInput`` type to use.
            transform: The transform to use when serving.
            transform_kwargs: Keyword arguments used to instantiate the transform.
        """
        from flash.core.serve.flash_components import build_flash_serve_model_component

        if input_cls is None:
            raise NotImplementedError("The `input_cls` must be provided to enable serving.")

        serve_input = input_cls(transform=transform, transform_kwargs=transform_kwargs)

        if sanity_check:
            self.run_serve_sanity_check()
            self.run_serve_sanity_check(serve_input)

        comp = build_flash_serve_model_component(self)
        comp = build_flash_serve_model_component(self, serve_input)
        composition = Composition(predict=comp, TESTING=flash._IS_TESTING)
        composition.serve(host=host, port=port)
        return composition
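
The request body assembled in run_serve_sanity_check above is the same shape an external client sends to a running composition. A hedged client-side sketch, assuming a server started via serve() on 127.0.0.1:8000 (the payload string is a placeholder and depends entirely on the configured ServeInput):

    import requests

    # Mirrors the body built in run_serve_sanity_check; replace "<example input>" with
    # whatever the ServeInput expects (for audio this is typically base64-encoded data).
    body = {"session": "UUID", "payload": {"inputs": {"data": "<example input>"}}}
    resp = requests.post("http://127.0.0.1:8000/predict", json=body)
    print(resp.json())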
26 changes: 19 additions & 7 deletions flash/core/serve/flash_components.py
@@ -3,6 +3,8 @@

import torch

from flash.core.data.batch import _ServeInputProcessor
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import DataKeys
from flash.core.serve import expose, ModelComponent
from flash.core.serve.types.base import BaseType
@@ -46,7 +48,18 @@ def deserialize(self, data: str) -> Any: # pragma: no cover
        return None


def build_flash_serve_model_component(model):
def build_flash_serve_model_component(model, serve_input):

    data_pipeline_state = DataPipelineState()
    for properties in [
        serve_input,
        getattr(serve_input, "transform", None),
        model._output_transform,
        model._output,
        model,
    ]:
        if properties is not None and hasattr(properties, "attach_data_pipeline_state"):
            properties.attach_data_pipeline_state(data_pipeline_state)

    data_pipeline = model.build_data_pipeline()

@@ -55,23 +68,22 @@ def __init__(self, model):
            self.model = model
            self.model.eval()
            self.data_pipeline = model.build_data_pipeline()
            self.deserializer = self.data_pipeline._deserializer
            self.dataloader_collate_fn = self.data_pipeline._deserializer._create_dataloader_collate_fn([])
            self.on_after_batch_transfer_fn = self.data_pipeline._deserializer._create_on_after_batch_transfer_fn([])
            self.serve_input = serve_input
            self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([])
            self.on_after_batch_transfer_fn = self.serve_input._create_on_after_batch_transfer_fn([])
            self.output_transform_processor = self.data_pipeline.output_transform_processor(
                RunningStage.PREDICTING, is_serving=True
                RunningStage.SERVING, is_serving=True
            )
            # todo (tchaton) Remove this hack
            self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
            self.device = self.model.device

        @expose(
            inputs={"inputs": FlashInputs(data_pipeline._deserializer._call_load_sample)},
            inputs={"inputs": FlashInputs(_ServeInputProcessor(serve_input))},
            outputs={"outputs": FlashOutputs(data_pipeline.output_processor())},
        )
        def predict(self, inputs):
            with torch.no_grad():
                inputs = self.dataloader_collate_fn(inputs)
                if self.extra_arguments:
                    inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
                else:
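
Task.serve() in flash/core/model.py above is a thin wrapper around this builder. A hedged sketch of the equivalent manual wiring, for when a Composition is wanted without calling serve() (the speech task is used only as a convenient ServeInput/model pairing; the instantiation mirrors what Task.serve() does in this diff):

    from flash.audio import SpeechRecognition
    from flash.audio.speech_recognition.input import SpeechRecognitionDeserializer
    from flash.core.data.input_transform import InputTransform
    from flash.core.serve import Composition
    from flash.core.serve.flash_components import build_flash_serve_model_component

    model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
    # Instantiated the same way Task.serve() does it: input_cls(transform=..., transform_kwargs=...).
    serve_input = SpeechRecognitionDeserializer(transform=InputTransform)
    comp = build_flash_serve_model_component(model, serve_input)
    composition = Composition(predict=comp)
    composition.serve(host="127.0.0.1", port=8000)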