Skip to content

Commit

Permalink
Add support for pydantic (#26)
Browse files Browse the repository at this point in the history
This commit adds an optional support for pydantic models.
To use this simply add the pydantic dependency, and annotate
a handler with it.

```py

Greeting(BaseModel):
  name: str

@svc.handler()
async def greet(ctx, greeting: Greeting):
  ..
```

With this, any bad input (validation error) will result with a
TerminalError thrown by the serializer.
  • Loading branch information
igalshilman authored Oct 23, 2024
1 parent ac0c995 commit 946f418
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ mypy:
# Recipe to run pylint for linting
pylint:
@echo "Running pylint..."
{{python}} -m pylint python/restate
{{python}} -m pylint examples/
{{python}} -m pylint python/restate --ignore-paths='^.*.?venv.*$'
{{python}} -m pylint examples/ --ignore-paths='^.*\.?venv.*$'

test:
@echo "Running Python tests..."
Expand Down
12 changes: 10 additions & 2 deletions python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def compute_discovery_json(endpoint: RestateEndpoint,
headers = {"content-type": "application/vnd.restate.endpointmanifest.v1+json"}
return (headers, json_str)

def try_extract_json_schema(model: Any) -> typing.Optional[typing.Any]:
"""
Try to extract the JSON schema from a schema object
"""
if model:
return model.model_json_schema(mode='serialization')
return None

def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal["bidi", "request_response"]) -> Endpoint:
"""
return restate's discovery object for an endpoint
Expand All @@ -131,11 +139,11 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
# input
inp = InputPayload(required=False,
contentType=handler.handler_io.accept,
jsonSchema=None)
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_input_model))
# output
out = OutputPayload(setContentTypeIfEmpty=False,
contentType=handler.handler_io.content_type,
jsonSchema=None)
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_output_model))
# add the handler
service_handlers.append(Handler(name=handler.name, ty=ty, input=inp, output=out))

Expand Down
70 changes: 65 additions & 5 deletions python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,34 @@
"""

from dataclasses import dataclass
from inspect import Signature
from typing import Any, Callable, Awaitable, Generic, Literal, Optional, TypeVar

from restate.serde import Serde
from restate.exceptions import TerminalError
from restate.serde import JsonSerde, Serde, PydanticJsonSerde

I = TypeVar('I')
O = TypeVar('O')

# we will use this symbol to store the handler in the function
RESTATE_UNIQUE_HANDLER_SYMBOL = str(object())


def try_import_pydantic_base_model():
"""
Try to import PydanticBaseModel from Pydantic.
"""
try:
from pydantic import BaseModel # type: ignore # pylint: disable=import-outside-toplevel
return BaseModel
except ImportError:
class Dummy: # pylint: disable=too-few-public-methods
"""a dummy class to use when Pydantic is not available"""

return Dummy

PYDANTIC_BASE_MODEL = try_import_pydantic_base_model()

@dataclass
class ServiceTag:
"""
Expand All @@ -42,13 +60,45 @@ class HandlerIO(Generic[I, O]):
Attributes:
accept (str): The accept header value for the handler.
content_type (str): The content type header value for the handler.
serializer: The serializer function to convert output to bytes.
deserializer: The deserializer function to convert input type to bytes.
"""
accept: str
content_type: str
input_serde: Serde[I]
output_serde: Serde[O]
pydantic_input_model: Optional[I] = None
pydantic_output_model: Optional[O] = None

def is_pydantic(annotation) -> bool:
"""
Check if an object is a Pydantic model.
"""
try:
return issubclass(annotation, PYDANTIC_BASE_MODEL)
except TypeError:
# annotation is not a class or a type
return False


def infer_pydantic_io(handler_io: HandlerIO[I, O], signature: Signature):
"""
Augment handler_io with Pydantic models when these are provided.
This method will inspect the signature of an handler and will look for
the input and the return types of a function, and will:
* capture any Pydantic models (to be used later at discovery)
* replace the default json serializer (is unchanged by a user) with a Pydantic serde
"""
# check if the handlers I/O is a PydanticBaseModel
annotation = list(signature.parameters.values())[-1].annotation
if is_pydantic(annotation):
handler_io.pydantic_input_model = annotation
if isinstance(handler_io.input_serde, JsonSerde): # type: ignore
handler_io.input_serde = PydanticJsonSerde(annotation)

annotation = signature.return_annotation
if is_pydantic(annotation):
handler_io.pydantic_output_model = annotation
if isinstance(handler_io.output_serde, JsonSerde): # type: ignore
handler_io.output_serde = PydanticJsonSerde(annotation)

@dataclass
class Handler(Generic[I, O]):
Expand All @@ -71,7 +121,7 @@ def make_handler(service_tag: ServiceTag,
name: str | None,
kind: Optional[Literal["exclusive", "shared", "workflow"]],
wrapped: Any,
arity: int) -> Handler[I, O]:
signature: Signature) -> Handler[I, O]:
"""
Factory function to create a handler.
"""
Expand All @@ -82,12 +132,19 @@ def make_handler(service_tag: ServiceTag,
if not handler_name:
raise ValueError("Handler name must be provided")

if len(signature.parameters) == 0:
raise ValueError("Handler must have at least one parameter")

arity = len(signature.parameters)
infer_pydantic_io(handler_io, signature)

handler = Handler[I, O](service_tag,
handler_io,
kind,
handler_name,
wrapped,
arity)

vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler
return handler

Expand All @@ -105,7 +162,10 @@ async def invoke_handler(handler: Handler[I, O], ctx: Any, in_buffer: bytes) ->
Invoke the handler with the given context and input.
"""
if handler.arity == 2:
in_arg = handler.handler_io.input_serde.deserialize(in_buffer) # type: ignore
try:
in_arg = handler.handler_io.input_serde.deserialize(in_buffer) # type: ignore
except Exception as e:
raise TerminalError(message=f"Unable to parse an input argument. {e}") from e
out_arg = await handler.fn(ctx, in_arg) # type: ignore [call-arg, arg-type]
else:
out_arg = await handler.fn(ctx) # type: ignore [call-arg]
Expand Down
4 changes: 2 additions & 2 deletions python/restate/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def wrapper(fn):
def wrapped(*args, **kwargs):
return fn(*args, **kwargs)

arity = len(inspect.signature(fn).parameters)
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, arity)
signature = inspect.signature(fn)
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature)
self.handlers[handler.name] = handler
return wrapped

Expand Down
37 changes: 37 additions & 0 deletions python/restate/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,43 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
return bytes(json.dumps(obj), "utf-8")


class PydanticJsonSerde(Serde[I]):
"""
Serde for Pydantic models to/from JSON
"""

def __init__(self, model):
self.model = model

def deserialize(self, buf: bytes) -> typing.Optional[I]:
"""
Deserializes a bytearray to a Pydantic model.
Args:
buf (bytearray): The bytearray to deserialize.
Returns:
typing.Optional[I]: The deserialized Pydantic model.
"""
if not buf:
return None
return self.model.model_validate_json(buf)

def serialize(self, obj: typing.Optional[I]) -> bytes:
"""
Serializes a Pydantic model to a bytearray.
Args:
obj (I): The Pydantic model to serialize.
Returns:
bytearray: The serialized bytearray.
"""
if obj is None:
return bytes()
json_str = obj.model_dump_json() # type: ignore[attr-defined]
return json_str.encode("utf-8")

def deserialize_json(buf: typing.ByteString) -> typing.Optional[O]:
"""
Deserializes a bytearray to a JSON object.
Expand Down
4 changes: 2 additions & 2 deletions python/restate/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def wrapper(fn):
def wrapped(*args, **kwargs):
return fn(*args, **kwargs)

arity = len(inspect.signature(fn).parameters)
handler = make_handler(self.service_tag, handler_io, name, None, wrapped, arity)
signature = inspect.signature(fn)
handler = make_handler(self.service_tag, handler_io, name, None, wrapped, signature)
self.handlers[handler.name] = handler
return wrapped

Expand Down
4 changes: 2 additions & 2 deletions python/restate/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def wrapper(fn):
def wrapped(*args, **kwargs):
return fn(*args, **kwargs)

arity = len(inspect.signature(fn).parameters)
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, arity)
signature = inspect.signature(fn)
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature)
self.handlers[handler.name] = handler
return wrapped

Expand Down
5 changes: 3 additions & 2 deletions shell.nix
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{ pkgs ? import <nixpkgs> {} }:

(pkgs.buildFHSUserEnv {
name = "my-python-env";
name = "sdk-python";
targetPkgs = pkgs: (with pkgs; [
python3
python3Packages.pip
Expand All @@ -10,6 +10,7 @@

# rust
rustup
cargo
clang
llvmPackages.bintools
protobuf
Expand All @@ -29,6 +30,6 @@
LIBCLANG_PATH = pkgs.lib.makeLibraryPath [ pkgs.llvmPackages_latest.libclang.lib ];

runScript = ''
bash
bash
'';
}).env

0 comments on commit 946f418

Please sign in to comment.