From 946f4180ec598a3b31c86e3278a8ad391a87418a Mon Sep 17 00:00:00 2001 From: Igal Shilman Date: Wed, 23 Oct 2024 15:14:02 +0200 Subject: [PATCH] Add support for pydantic (#26) This commit adds optional support for pydantic models. To use it, simply add the pydantic dependency, and annotate a handler with a pydantic model. ```py class Greeting(BaseModel): name: str @svc.handler() async def greet(ctx, greeting: Greeting): .. ``` With this, any bad input (validation error) will result in a TerminalError raised while deserializing the handler input. --- Justfile | 4 +-- python/restate/discovery.py | 12 +++++-- python/restate/handler.py | 70 ++++++++++++++++++++++++++++++++++--- python/restate/object.py | 4 +-- python/restate/serde.py | 37 ++++++++++++++++++++ python/restate/service.py | 4 +-- python/restate/workflow.py | 4 +-- shell.nix | 5 +-- 8 files changed, 123 insertions(+), 17 deletions(-) diff --git a/Justfile b/Justfile index 66b9035..7af69bd 100644 --- a/Justfile +++ b/Justfile @@ -18,8 +18,8 @@ mypy: # Recipe to run pylint for linting pylint: @echo "Running pylint..." - {{python}} -m pylint python/restate - {{python}} -m pylint examples/ + {{python}} -m pylint python/restate --ignore-paths='^.*.?venv.*$' + {{python}} -m pylint examples/ --ignore-paths='^.*\.?venv.*$' test: @echo "Running Python tests..." 
diff --git a/python/restate/discovery.py b/python/restate/discovery.py index efd0975..52c9f5c 100644 --- a/python/restate/discovery.py +++ b/python/restate/discovery.py @@ -113,6 +113,14 @@ def compute_discovery_json(endpoint: RestateEndpoint, headers = {"content-type": "application/vnd.restate.endpointmanifest.v1+json"} return (headers, json_str) +def try_extract_json_schema(model: Any) -> typing.Optional[typing.Any]: + """ + Try to extract the JSON schema from a schema object + """ + if model: + return model.model_json_schema(mode='serialization') + return None + def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal["bidi", "request_response"]) -> Endpoint: """ return restate's discovery object for an endpoint @@ -131,11 +139,11 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[ # input inp = InputPayload(required=False, contentType=handler.handler_io.accept, - jsonSchema=None) + jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_input_model)) # output out = OutputPayload(setContentTypeIfEmpty=False, contentType=handler.handler_io.content_type, - jsonSchema=None) + jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_output_model)) # add the handler service_handlers.append(Handler(name=handler.name, ty=ty, input=inp, output=out)) diff --git a/python/restate/handler.py b/python/restate/handler.py index 7012d35..4d833ba 100644 --- a/python/restate/handler.py +++ b/python/restate/handler.py @@ -16,9 +16,11 @@ """ from dataclasses import dataclass +from inspect import Signature from typing import Any, Callable, Awaitable, Generic, Literal, Optional, TypeVar -from restate.serde import Serde +from restate.exceptions import TerminalError +from restate.serde import JsonSerde, Serde, PydanticJsonSerde I = TypeVar('I') O = TypeVar('O') @@ -26,6 +28,22 @@ # we will use this symbol to store the handler in the function RESTATE_UNIQUE_HANDLER_SYMBOL = str(object()) + +def 
try_import_pydantic_base_model(): + """ + Try to import PydanticBaseModel from Pydantic. + """ + try: + from pydantic import BaseModel # type: ignore # pylint: disable=import-outside-toplevel + return BaseModel + except ImportError: + class Dummy: # pylint: disable=too-few-public-methods + """a dummy class to use when Pydantic is not available""" + + return Dummy + +PYDANTIC_BASE_MODEL = try_import_pydantic_base_model() + @dataclass class ServiceTag: """ @@ -42,13 +60,45 @@ class HandlerIO(Generic[I, O]): Attributes: accept (str): The accept header value for the handler. content_type (str): The content type header value for the handler. - serializer: The serializer function to convert output to bytes. - deserializer: The deserializer function to convert input type to bytes. """ accept: str content_type: str input_serde: Serde[I] output_serde: Serde[O] + pydantic_input_model: Optional[I] = None + pydantic_output_model: Optional[O] = None + +def is_pydantic(annotation) -> bool: + """ + Check if an object is a Pydantic model. + """ + try: + return issubclass(annotation, PYDANTIC_BASE_MODEL) + except TypeError: + # annotation is not a class or a type + return False + + +def infer_pydantic_io(handler_io: HandlerIO[I, O], signature: Signature): + """ + Augment handler_io with Pydantic models when these are provided. 
+ This method will inspect the signature of a handler and will look for + the input and the return types of a function, and will: + * capture any Pydantic models (to be used later at discovery) + * replace the default json serializer (if unchanged by the user) with a Pydantic serde + """ + # check if the handler's I/O is a PydanticBaseModel + annotation = list(signature.parameters.values())[-1].annotation + if is_pydantic(annotation): + handler_io.pydantic_input_model = annotation + if isinstance(handler_io.input_serde, JsonSerde): # type: ignore + handler_io.input_serde = PydanticJsonSerde(annotation) + + annotation = signature.return_annotation + if is_pydantic(annotation): + handler_io.pydantic_output_model = annotation + if isinstance(handler_io.output_serde, JsonSerde): # type: ignore + handler_io.output_serde = PydanticJsonSerde(annotation) @dataclass class Handler(Generic[I, O]): @@ -71,7 +121,7 @@ def make_handler(service_tag: ServiceTag, name: str | None, kind: Optional[Literal["exclusive", "shared", "workflow"]], wrapped: Any, - arity: int) -> Handler[I, O]: + signature: Signature) -> Handler[I, O]: """ Factory function to create a handler. """ @@ -82,12 +132,19 @@ def make_handler(service_tag: ServiceTag, if not handler_name: raise ValueError("Handler name must be provided") + if len(signature.parameters) == 0: + raise ValueError("Handler must have at least one parameter") + + arity = len(signature.parameters) + infer_pydantic_io(handler_io, signature) + handler = Handler[I, O](service_tag, handler_io, kind, handler_name, wrapped, arity) + vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler return handler @@ -105,7 +162,10 @@ async def invoke_handler(handler: Handler[I, O], ctx: Any, in_buffer: bytes) -> Invoke the handler with the given context and input. 
""" if handler.arity == 2: - in_arg = handler.handler_io.input_serde.deserialize(in_buffer) # type: ignore + try: + in_arg = handler.handler_io.input_serde.deserialize(in_buffer) # type: ignore + except Exception as e: + raise TerminalError(message=f"Unable to parse an input argument. {e}") from e out_arg = await handler.fn(ctx, in_arg) # type: ignore [call-arg, arg-type] else: out_arg = await handler.fn(ctx) # type: ignore [call-arg] diff --git a/python/restate/object.py b/python/restate/object.py index ffd9b44..2eef77f 100644 --- a/python/restate/object.py +++ b/python/restate/object.py @@ -85,8 +85,8 @@ def wrapper(fn): def wrapped(*args, **kwargs): return fn(*args, **kwargs) - arity = len(inspect.signature(fn).parameters) - handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, arity) + signature = inspect.signature(fn) + handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature) self.handlers[handler.name] = handler return wrapped diff --git a/python/restate/serde.py b/python/restate/serde.py index b1cfc08..f5b4b23 100644 --- a/python/restate/serde.py +++ b/python/restate/serde.py @@ -108,6 +108,43 @@ def serialize(self, obj: typing.Optional[I]) -> bytes: return bytes(json.dumps(obj), "utf-8") +class PydanticJsonSerde(Serde[I]): + """ + Serde for Pydantic models to/from JSON + """ + + def __init__(self, model): + self.model = model + + def deserialize(self, buf: bytes) -> typing.Optional[I]: + """ + Deserializes a bytearray to a Pydantic model. + + Args: + buf (bytearray): The bytearray to deserialize. + + Returns: + typing.Optional[I]: The deserialized Pydantic model. + """ + if not buf: + return None + return self.model.model_validate_json(buf) + + def serialize(self, obj: typing.Optional[I]) -> bytes: + """ + Serializes a Pydantic model to a bytearray. + + Args: + obj (I): The Pydantic model to serialize. + + Returns: + bytearray: The serialized bytearray. 
+ """ + if obj is None: + return bytes() + json_str = obj.model_dump_json() # type: ignore[attr-defined] + return json_str.encode("utf-8") + def deserialize_json(buf: typing.ByteString) -> typing.Optional[O]: """ Deserializes a bytearray to a JSON object. diff --git a/python/restate/service.py b/python/restate/service.py index 7cffb23..ef83491 100644 --- a/python/restate/service.py +++ b/python/restate/service.py @@ -84,8 +84,8 @@ def wrapper(fn): def wrapped(*args, **kwargs): return fn(*args, **kwargs) - arity = len(inspect.signature(fn).parameters) - handler = make_handler(self.service_tag, handler_io, name, None, wrapped, arity) + signature = inspect.signature(fn) + handler = make_handler(self.service_tag, handler_io, name, None, wrapped, signature) self.handlers[handler.name] = handler return wrapped diff --git a/python/restate/workflow.py b/python/restate/workflow.py index 2e1f6b1..2dfe0d0 100644 --- a/python/restate/workflow.py +++ b/python/restate/workflow.py @@ -114,8 +114,8 @@ def wrapper(fn): def wrapped(*args, **kwargs): return fn(*args, **kwargs) - arity = len(inspect.signature(fn).parameters) - handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, arity) + signature = inspect.signature(fn) + handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature) self.handlers[handler.name] = handler return wrapped diff --git a/shell.nix b/shell.nix index ae60cdf..58cb365 100755 --- a/shell.nix +++ b/shell.nix @@ -1,7 +1,7 @@ { pkgs ? import {} }: (pkgs.buildFHSUserEnv { - name = "my-python-env"; + name = "sdk-python"; targetPkgs = pkgs: (with pkgs; [ python3 python3Packages.pip @@ -10,6 +10,7 @@ # rust rustup + cargo clang llvmPackages.bintools protobuf @@ -29,6 +30,6 @@ LIBCLANG_PATH = pkgs.lib.makeLibraryPath [ pkgs.llvmPackages_latest.libclang.lib ]; runScript = '' - bash + bash ''; }).env