
Commit

Deserialise plan method arguments by replacing types instead of building custom validators (#154)

* Improve support for bluesky type deserialisation

Plans that reference bluesky types need an intermediate step in the
deserialisation process to allow them to be referenced by name only.
Types in plan signatures are now converted to a runtime-generated type
that can access the context where the plan will be run, allowing pydantic
to ensure that devices are present and are of the correct type.

* Change type conversion functions into methods

The functions took an instance of the context as an argument and were only
called from within existing context methods, so it makes more sense for
them to be instance methods of BlueskyContext instead of standalone
functions.

* Handle lazily evaluated types in type conversion

* Add test for concrete type conversion

---------

Co-authored-by: Peter Holloway <[email protected]>
tpoliaw authored May 4, 2023
1 parent 4c4df02 commit 6557da5
Showing 7 changed files with 188 additions and 1,041 deletions.
64 changes: 27 additions & 37 deletions docs/developer/explanations/type_validators.rst
@@ -18,17 +18,17 @@ Blueapi takes the parameters of a plan and internally creates a pydantic_ model
b: str = "b"
That way, when the plan parameters are sent in JSON form, they can be parsed and validated by pydantic.
However, it must also cover the case where a plan doesn't take a simple dictionary, list or primitive but
instead a device, such as a detector.
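
For the simple case above, a minimal sketch (pydantic v1 style, hand-written rather than generated) of such a model parsing JSON parameters:

.. code:: python

    from pydantic import BaseModel

    class MyPlanModel(BaseModel):
        a: int
        b: str = "b"

    # Parameters arriving over the network as JSON are parsed and validated
    params = MyPlanModel.parse_raw('{"a": 1}')
    assert params.a == 1 and params.b == "b"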

.. code:: python
def my_plan(a: int, b: Readable) -> Plan:
...
An Ophyd object cannot be passed over the network as JSON because it has state.
Instead, a string is passed, representing an ID of the object known to the ``BlueskyContext``.
At the time a plan's parameters are validated, blueapi must take all the strings that are supposed
to be devices and look them up against the context. For example with the request:
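
As a hypothetical illustration of such a request (the document's own example is collapsed above):

.. code:: python

    # Hypothetical request body: device arguments arrive as plain string IDs
    request = {
        "name": "my_plan",
        "params": {"a": 1, "b": "my_detector"},
    }

    # "my_detector" must be resolved against the devices known to the
    # BlueskyContext (e.g. via find_device) before the plan can run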
@@ -52,55 +52,45 @@

Solution
--------

Before pydantic, blueapi used apischema_, which had an ideal feature for this called conversions_.
Currently, the utils module of blueapi has a similar feature called type validators.

They enable the ``BlueskyContext`` to dynamically generate pydantic models, like above, that look
roughly like this:
When the context loads available plans, it iterates through the type signature
and replaces any reference to a bluesky protocol (or instance of a protocol)
with a new class that extends the original type. Defining a class validator on
this new type allows it to check that the string being deserialised is the ID of
a device of the correct type.

These new intermediate types are used only in the deserialisation process. The
object returned from the validator method is not checked by pydantic, so it can be
the actual device instance; the plan never sees the runtime-generated reference
type, only the type it was expecting.

.. note:: This uses the fact that the new types generated at runtime have access to
the context that required them via their closure. This circumvents the usual
problem of pydantic validation not being able to access external state when
validating or deserialising.

.. code:: python
def my_plan(a: int, b: Readable) -> Plan:
...
# Becomes
class MyPlanModel(BaseModel):
a: int
b: Readable
@validator("b")
def validate_b(self, val: str) -> Readable:
return ctx.find_device(val)
b: Reference[Readable]
This also allows ``Readable`` to be placed at various type levels. For example:

.. code:: python
def my_weird_plan(
a: Readable,
b: List[Readable],
c: Dict[str, Readable],
d: List[List[Readable]],
e: List[Dict[str, Set[Readable]]]) -> Plan:
...
Implementation Details
----------------------

Pydantic models have validators: functions that are applied to specific fields by name. This is
insufficient for the requirements here; it would be more helpful if validators could be applied by
type rather than by name.
The type validation module is essentially a shim layer that works out the names of all fields of a
particular type, then creates validators for all of those names. It also supports the type being in
nested lists and/or dictionaries, as mentioned above.
The field names are detected by comparing the type annotation in the model to the type requested.
The actual validator is a function supplied by the caller, but if a list or dictionary is passed,
it is applied to each item/value.
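
For contrast, a minimal pydantic (v1 style) sketch of the name-based mechanism described above:

.. code:: python

    from pydantic import BaseModel, validator

    class Model(BaseModel):
        b: str

        # Attached to the field *named* "b"; there is no built-in way to say
        # "apply this validator to every field annotated with a given type"
        @validator("b")
        def resolve_b(cls, value):
            return value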

.. _pydantic: https://docs.pydantic.dev/
.. _apischema: https://wyfo.github.io/apischema/0.18/
.. _conversions: https://wyfo.github.io/apischema/0.18/conversions/
6 changes: 5 additions & 1 deletion src/blueapi/core/bluesky_types.py
@@ -75,7 +75,11 @@ def _follows_bluesky_protocols(obj: Any) -> bool:


def is_bluesky_plan_generator(func: PlanGenerator) -> bool:
return get_type_hints(func).get("return") is MsgGenerator
try:
return get_type_hints(func).get("return") is MsgGenerator
except TypeError:
# get_type_hints fails on some objects (such as Union or Optional)
return False
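
# Standalone illustration (not part of this module) of why the try/except above
# is needed: get_type_hints raises TypeError for objects that are not modules,
# classes, methods or functions, such as bare typing aliases that a scanned
# module may export. MaybeInt below is a made-up example.
from typing import Optional, get_type_hints

MaybeInt = Optional[int]

try:
    get_type_hints(MaybeInt)
except TypeError:
    # is_bluesky_plan_generator returns False here instead of propagating
    pass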


class Plan(BlueapiBaseModel):
125 changes: 105 additions & 20 deletions src/blueapi/core/context.py
@@ -1,18 +1,27 @@
import logging
from dataclasses import dataclass, field
from importlib import import_module
from inspect import Parameter, signature
from pathlib import Path
from types import ModuleType
from typing import Dict, Iterable, List, Optional, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
get_args,
get_origin,
get_type_hints,
)

from bluesky import RunEngine
from pydantic import create_model

from blueapi.utils import (
BlueapiPlanModelConfig,
TypeValidatorDefinition,
create_model_with_type_validators,
load_module_all,
)
from blueapi.utils import BlueapiPlanModelConfig, load_module_all

from .bluesky_types import (
BLUESKY_PROTOCOLS,
@@ -41,6 +50,8 @@ class BlueskyContext:
devices: Dict[str, Device] = field(default_factory=dict)
plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict)

_reference_cache: Dict[Type, Type] = field(default_factory=dict)

def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]:
"""
Find a device in this context, allows for recursive search.
Expand Down Expand Up @@ -113,12 +124,10 @@ def my_plan(a: int, b: str):
if not is_bluesky_plan_generator(plan):
raise TypeError(f"{plan} is not a valid plan generator function")

validators = list(device_validators(self))
model = create_model_with_type_validators(
model = create_model( # type: ignore
plan.__name__,
validators,
func=plan,
config=BlueapiPlanModelConfig,
__config__=BlueapiPlanModelConfig,
**self._type_spec_for_function(plan),
)
self.plans[plan.__name__] = Plan(name=plan.__name__, model=model)
self.plan_functions[plan.__name__] = plan
@@ -151,13 +160,89 @@ def device(self, device: Device, name: Optional[str] = None) -> None:

self.devices[name] = device

def _reference(self, target: Type) -> Type:
"""
Create an intermediate reference type for the required ``target`` type that
will return an existing device during pydantic deserialisation/validation
Args:
target: The device type expected for the IDs being deserialised by the
returned reference type
Returns:
New type that can be deserialised by pydantic returning an existing device
for a string device ID
"""
if target not in self._reference_cache:

class Reference(target):
@classmethod
def __get_validators__(cls):
yield cls.valid

@classmethod
def valid(cls, value):
val = self.find_device(value)
if not isinstance(val, target):
raise ValueError(f"value is not {target}")
return val

self._reference_cache[target] = Reference

return self._reference_cache[target]
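
# Minimal standalone sketch of how such a runtime-generated reference type
# behaves with pydantic (v1-style __get_validators__). The Motor class and
# REGISTRY dict below are made-up stand-ins for a real device and for the
# context's device store; only the closure mechanism is the point here.
from pydantic import BaseModel


class Motor:
    """Stand-in for a real bluesky device."""


REGISTRY = {"x_motor": Motor()}


def make_reference(target: type) -> type:
    class Reference(target):
        @classmethod
        def __get_validators__(cls):
            yield cls.validate

        @classmethod
        def validate(cls, value):
            # The closure over REGISTRY gives the validator access to external
            # state that pydantic could not otherwise see
            device = REGISTRY.get(value)
            if not isinstance(device, target):
                raise ValueError(f"{value!r} is not a {target.__name__}")
            return device

    return Reference


class MoveParams(BaseModel):
    motor: make_reference(Motor)  # type: ignore[valid-type]
    position: float


params = MoveParams(motor="x_motor", position=1.5)
assert isinstance(params.motor, Motor)  # the actual device, not a Reference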

def device_validators(ctx: BlueskyContext) -> Iterable[TypeValidatorDefinition]:
def get_device(name: str) -> Device:
device = ctx.find_device(name)
if device is None:
raise KeyError(f"Could not find a device named {name}")
return device
for proto in BLUESKY_PROTOCOLS:
yield TypeValidatorDefinition(proto, get_device)
def _type_spec_for_function(
self, func: Callable[..., Any]
) -> dict[str, Tuple[Type, Any]]:
"""
Parse a function signature and build a map of field types and default
values that can be used to deserialise arguments from external sources.
Any references to any of the bluesky protocols are replaced with an
intermediate reference type that allows existing devices to be returned
for device ID strings.
Args:
func: The function whose signature is being parsed
Returns:
Mapping of {name: (type, default)} to be used by pydantic for deserialising
function arguments
"""
args = signature(func).parameters
types = get_type_hints(func)
new_args = {}
for name, para in args.items():
default = None if para.default is Parameter.empty else para.default
arg_type = types.get(name, Parameter.empty)
if arg_type is Parameter.empty:
raise ValueError(
f"Type annotation is required for '{name}' in '{func.__name__}'"
)
new_args[name] = (self._convert_type(arg_type), default)
return new_args
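
# Minimal standalone sketch (separate from the class above) of the same idea:
# turning a function signature into the {name: (type, default)} mapping that
# pydantic's create_model accepts. demo_plan is a made-up example function.
from inspect import Parameter, signature
from typing import Any, Callable, Dict, Tuple, Type, get_type_hints

from pydantic import create_model


def spec_for(func: Callable[..., Any]) -> Dict[str, Tuple[Type, Any]]:
    hints = get_type_hints(func)
    spec = {}
    for name, param in signature(func).parameters.items():
        # Ellipsis is pydantic's marker for a required field
        default = ... if param.default is Parameter.empty else param.default
        spec[name] = (hints[name], default)
    return spec


def demo_plan(a: int, b: str = "b"):
    ...


DemoModel = create_model("demo_plan", **spec_for(demo_plan))
print(DemoModel(a=1))  # a=1 b='b'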

def _convert_type(self, typ: Type) -> Type:
"""
Recursively convert a type to something that can be deserialised by
pydantic. Bluesky protocols (and types that extend them) are replaced
with intermediate reference types that allow the current context to
be used to look up an existing device when deserialising device ID
strings.
Other types are returned as passed in.
Args:
typ: The type that is required - potentially referencing Bluesky protocols
Returns:
A Type that can be deserialised by Pydantic
"""
if typ in BLUESKY_PROTOCOLS or any(
isinstance(typ, dev) for dev in BLUESKY_PROTOCOLS
):
return self._reference(typ)
args = get_args(typ)
if args:
new_types = tuple(self._convert_type(i) for i in args)
root = get_origin(typ)
return root[new_types] if root else typ
return typ
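
# Standalone sketch (with made-up types) of the recursion above: get_origin and
# get_args take a nested generic apart, and the origin can be re-subscripted
# with the converted inner types, so annotations such as List[Dict[str, ...]]
# keep their container structure while the protocol is swapped for a reference.
# Requires Python 3.9+ to subscript the builtin origins that get_origin returns.
from typing import Dict, List, get_args, get_origin


class Readable:
    """Stand-in for the bluesky protocol."""


class ReadableRef(Readable):
    """Stand-in for the generated reference type."""


def convert(typ):
    if isinstance(typ, type) and issubclass(typ, Readable):
        return ReadableRef
    args = get_args(typ)
    if args:
        origin = get_origin(typ)
        new_args = tuple(convert(a) for a in args)
        return origin[new_args] if origin else typ
    return typ


print(convert(List[Dict[str, Readable]]))  # -> list[dict[str, ReadableRef]]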
3 changes: 0 additions & 3 deletions src/blueapi/utils/__init__.py
@@ -3,14 +3,11 @@
from .modules import load_module_all
from .serialization import serialize
from .thread_exception import handle_all_exceptions
from .type_validator import TypeValidatorDefinition, create_model_with_type_validators

__all__ = [
"handle_all_exceptions",
"load_module_all",
"ConfigLoader",
"create_model_with_type_validators",
"TypeValidatorDefinition",
"serialize",
"BlueapiBaseModel",
"BlueapiModelConfig",