Skip to content

Commit

Permalink
WIP on parameter passing
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Feb 23, 2023
1 parent 452caaa commit 39ad35d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/blueapi/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
PlanGenerator,
WatchableStatus,
is_bluesky_compatible_device,
is_bluesky_compatible_device_type,
is_bluesky_plan_generator,
)
from .context import BlueskyContext
Expand All @@ -27,4 +28,5 @@
"WatchableStatus",
"is_bluesky_compatible_device",
"is_bluesky_plan_generator",
"is_bluesky_compatible_device_type",
]
15 changes: 11 additions & 4 deletions src/blueapi/core/bluesky_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,19 @@

def is_bluesky_compatible_device(obj: Any) -> bool:
is_object = not inspect.isclass(obj)
follows_protocols = any(
map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS)
)
# We must separately check if Obj refers to an instance rather than a
# class, as both follow the protocols but only one is a "device".
return is_object and follows_protocols
return is_object and _follows_bluesky_protocols(obj)


def is_bluesky_compatible_device_type(cls: Type[Any]) -> bool:
# We must separately check if Obj refers to an class rather than an
# instance, as both follow the protocols but only one is a type.
return inspect.isclass(cls) and _follows_bluesky_protocols(cls)


def _follows_bluesky_protocols(obj: Any) -> bool:
return any(map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS))


def is_bluesky_plan_generator(func: PlanGenerator) -> bool:
Expand Down
95 changes: 85 additions & 10 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
import logging
from dataclasses import dataclass, field
from importlib import import_module
from inspect import Parameter, signature
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Union
from typing import (
Any,
Callable,
Deque,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)

from bluesky import RunEngine
from bluesky.protocols import Flyable, Readable
from pydantic import BaseConfig, Extra, validate_arguments
from pydantic.decorator import ValidatedFunction
from pydantic import BaseModel, create_model, validator

from blueapi.utils import load_module_all, schema_for_func

Expand All @@ -17,17 +31,14 @@
Plan,
PlanGenerator,
is_bluesky_compatible_device,
is_bluesky_compatible_device_type,
is_bluesky_plan_generator,
)
from .device_lookup import find_component

LOGGER = logging.getLogger(__name__)


class PlanConfig(BaseConfig):
extra = Extra.forbid


@dataclass
class BlueskyContext:
"""
Expand All @@ -37,8 +48,9 @@ class BlueskyContext:
run_engine: RunEngine = field(
default_factory=lambda: RunEngine(context_managers=[])
)
plans: Dict[str, ValidatedFunction] = field(default_factory=dict)
plans: Dict[str, Plan] = field(default_factory=dict)
devices: Dict[str, Device] = field(default_factory=dict)
plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict)

def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]:
"""
Expand Down Expand Up @@ -112,8 +124,12 @@ def my_plan(a: int, b: str):
if not is_bluesky_plan_generator(plan):
raise TypeError(f"{plan} is not a valid plan generator function")

schema = ValidatedFunction(plan, PlanConfig)
self.plans[plan.__name__] = schema
def get_device(name: str) -> Device:
return self.find_device(name)

model = generate_plan_model(plan, get_device)
self.plans[plan.__name__] = Plan(name=plan.__name__, model=model)
self.plan_functions[plan.__name__] = plan
return plan

def device(self, device: Device, name: Optional[str] = None) -> None:
Expand Down Expand Up @@ -142,3 +158,62 @@ def device(self, device: Device, name: Optional[str] = None) -> None:
raise KeyError(f"Must supply a name for this device: {device}")

self.devices[name] = device


def generate_plan_model(
plan: PlanGenerator, get_device: Callable[[str], Device]
) -> Type[BaseModel]:
model_annotations: Dict[str, Tuple[Type, Any]] = {}
validators: Dict[str, Any] = {}
for name, param in signature(plan).parameters.items():
type_annotation = param.annotation
if is_bluesky_compatible_device_type(type_annotation):
type_annotation = str
validators[name] = validator(name)(get_device)
elif is_iterable_of_devices(type_annotation):
validators[name] = validator(name, each_item=True)(get_device)

default_value = param.default
if default_value is Parameter.empty:
default_value = ...

anno = (type_annotation, default_value)
model_annotations[name] = anno

name = f"{plan.__name__}_model"
from pprint import pprint

pprint(model_annotations)
return create_model(name, **model_annotations, __validators__=validators)


def is_mapping_with_devices(dct: Type) -> bool:
if get_params(dct):
...


def is_iterable_of_devices(lst: Type) -> bool:
if origin_is_iterable(lst):
params = list(get_params(lst))
if params:
(inner,) = params
return is_bluesky_compatible_device_type(inner)
return False


def get_params(maybe_parametrised: Type) -> Iterable[Type]:
for attr in "__args__", "__parameters__":
yield from getattr(maybe_parametrised, attr, [])


def origin_is_iterable(to_check: Type) -> bool:
return any(
map(
lambda origin: origin_is(to_check, origin),
[List, Set, Tuple, FrozenSet, Deque],
)
)


def origin_is(to_check: Type, origin: Type) -> bool:
return hasattr(to_check, "__origin__") and to_check.__origin__ is origin

0 comments on commit 39ad35d

Please sign in to comment.