diff --git a/src/blueapi/core/__init__.py b/src/blueapi/core/__init__.py index 1040d86a3..afc759a5c 100644 --- a/src/blueapi/core/__init__.py +++ b/src/blueapi/core/__init__.py @@ -7,6 +7,7 @@ PlanGenerator, WatchableStatus, is_bluesky_compatible_device, + is_bluesky_compatible_device_type, is_bluesky_plan_generator, ) from .context import BlueskyContext @@ -27,4 +28,5 @@ "WatchableStatus", "is_bluesky_compatible_device", "is_bluesky_plan_generator", + "is_bluesky_compatible_device_type", ] diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index b2e618207..8a8f35f0b 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -58,12 +58,19 @@ def is_bluesky_compatible_device(obj: Any) -> bool: is_object = not inspect.isclass(obj) - follows_protocols = any( - map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS) - ) # We must separately check if Obj refers to an instance rather than a # class, as both follow the protocols but only one is a "device". - return is_object and follows_protocols + return is_object and _follows_bluesky_protocols(obj) + + +def is_bluesky_compatible_device_type(cls: Type[Any]) -> bool: + # We must separately check if Obj refers to an class rather than an + # instance, as both follow the protocols but only one is a type. + return inspect.isclass(cls) and _follows_bluesky_protocols(cls) + + +def _follows_bluesky_protocols(obj: Any) -> bool: + return any(map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS)) def is_bluesky_plan_generator(func: PlanGenerator) -> bool: diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index e607929e8..367c6542b 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -1,14 +1,28 @@ import logging from dataclasses import dataclass, field from importlib import import_module +from inspect import Parameter, signature from pathlib import Path from types import ModuleType -from typing import Dict, List, Optional, Union +from typing import ( + Any, + Callable, + Deque, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) from bluesky import RunEngine from bluesky.protocols import Flyable, Readable -from pydantic import BaseConfig, Extra, validate_arguments -from pydantic.decorator import ValidatedFunction +from pydantic import BaseModel, create_model, validator from blueapi.utils import load_module_all, schema_for_func @@ -17,6 +31,7 @@ Plan, PlanGenerator, is_bluesky_compatible_device, + is_bluesky_compatible_device_type, is_bluesky_plan_generator, ) from .device_lookup import find_component @@ -24,10 +39,6 @@ LOGGER = logging.getLogger(__name__) -class PlanConfig(BaseConfig): - extra = Extra.forbid - - @dataclass class BlueskyContext: """ @@ -37,8 +48,9 @@ class BlueskyContext: run_engine: RunEngine = field( default_factory=lambda: RunEngine(context_managers=[]) ) - plans: Dict[str, ValidatedFunction] = field(default_factory=dict) + plans: Dict[str, Plan] = field(default_factory=dict) devices: Dict[str, Device] = field(default_factory=dict) + plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict) def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]: """ @@ -112,8 +124,12 @@ def my_plan(a: int, b: str): if not is_bluesky_plan_generator(plan): raise TypeError(f"{plan} is not a valid plan generator function") - schema = ValidatedFunction(plan, PlanConfig) - self.plans[plan.__name__] = schema + def get_device(name: str) -> Device: + return self.find_device(name) + + model = generate_plan_model(plan, get_device) + self.plans[plan.__name__] = Plan(name=plan.__name__, model=model) + self.plan_functions[plan.__name__] = plan return plan def device(self, device: Device, name: Optional[str] = None) -> None: @@ -142,3 +158,62 @@ def device(self, device: Device, name: Optional[str] = None) -> None: raise KeyError(f"Must supply a name for this device: {device}") self.devices[name] = device + + +def generate_plan_model( + plan: PlanGenerator, get_device: Callable[[str], Device] +) -> Type[BaseModel]: + model_annotations: Dict[str, Tuple[Type, Any]] = {} + validators: Dict[str, Any] = {} + for name, param in signature(plan).parameters.items(): + type_annotation = param.annotation + if is_bluesky_compatible_device_type(type_annotation): + type_annotation = str + validators[name] = validator(name)(get_device) + elif is_iterable_of_devices(type_annotation): + validators[name] = validator(name, each_item=True)(get_device) + + default_value = param.default + if default_value is Parameter.empty: + default_value = ... + + anno = (type_annotation, default_value) + model_annotations[name] = anno + + name = f"{plan.__name__}_model" + from pprint import pprint + + pprint(model_annotations) + return create_model(name, **model_annotations, __validators__=validators) + + +def is_mapping_with_devices(dct: Type) -> bool: + if get_params(dct): + ... + + +def is_iterable_of_devices(lst: Type) -> bool: + if origin_is_iterable(lst): + params = list(get_params(lst)) + if params: + (inner,) = params + return is_bluesky_compatible_device_type(inner) + return False + + +def get_params(maybe_parametrised: Type) -> Iterable[Type]: + for attr in "__args__", "__parameters__": + yield from getattr(maybe_parametrised, attr, []) + + +def origin_is_iterable(to_check: Type) -> bool: + return any( + map( + lambda origin: origin_is(to_check, origin), + [List, Set, Tuple, FrozenSet, Deque], + ) + ) + + +def origin_is(to_check: Type, origin: Type) -> bool: + return hasattr(to_check, "__origin__") and to_check.__origin__ is origin