Skip to content

Commit

Permalink
Implement initial state management routines for lazy_imports
Browse files Browse the repository at this point in the history
  • Loading branch information
bswck committed Nov 7, 2024
1 parent f2803d3 commit 211afd4
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 62 deletions.
106 changes: 65 additions & 41 deletions injection/contrib/pep690.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,40 @@
from __future__ import annotations

import sys
import types
from collections.abc import Callable
from contextlib import suppress
from collections.abc import Generator
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from copy import copy
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, overload
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload

from injection.main import peek_or_inject

if TYPE_CHECKING:
from _typeshed.importlib import MetaPathFinderProtocol, PathEntryFinderProtocol
from typing_extensions import Never, TypeVar
from typing_extensions import TypeAlias

from injection.main import Injection


T = TypeVar("T", default=None)
T = TypeVar("T")
Obj = TypeVar("Obj")
InjectedAttributeStash: TypeAlias = "dict[Injection[Obj], T]"


class SysActions(Enum):
class StateActionType(Enum):
PERSIST = auto()
"""Copy state visible now and expose it to the original thread on future request."""

FUTURE = auto()
"""
Allow the state to evolve naturally at runtime.
Rely on that future version of the state when it's requested.
"""

CONSTANT = auto()
SPECIFIED = auto()
"""Define one state forever (like PERSIST, but with custom value)."""


class StateAction(Generic[T]):
Expand All @@ -35,61 +45,75 @@ class StateAction(Generic[T]):
@overload
def __init__(
self,
action: Literal[SysActions.PERSIST, SysActions.FUTURE],
action_type: Literal[StateActionType.PERSIST, StateActionType.FUTURE],
data: None = None,
) -> None: ...

@overload
def __init__(
self,
action: Literal[SysActions.CONSTANT],
action_type: Literal[StateActionType.CONSTANT],
data: T,
) -> None: ...

def __init__(self, action: SysActions, data: T | None = None) -> None:
self.action = action
def __init__(
self,
action_type: StateActionType,
data: T | None = None,
) -> None:
self.action_type = action_type
self.data = data


PERSIST: StateAction = StateAction(SysActions.PERSIST)
FUTURE: StateAction = StateAction(SysActions.FUTURE)
PERSIST: StateAction[None] = StateAction(StateActionType.PERSIST)
FUTURE: StateAction[None] = StateAction(StateActionType.FUTURE)


injection_var: ContextVar[Injection[Any]] = ContextVar("injection")


class AttributeMappings(TypedDict, Generic[Obj]):
path: dict[Injection[Obj], list[str]]
path_hooks: dict[Injection[Obj], list[Callable[[str], PathEntryFinderProtocol]]]
meta_path: dict[Injection[Obj], list[MetaPathFinderProtocol]]


@dataclass
class _LazyImportsSys(types.ModuleType, Generic[Obj]):
attribute_mappings: AttributeMappings[Obj]
class SysAttributeGetter:
attribute_name: str
mainstream_value: Any
stash: InjectedAttributeStash[Injection[Any], Any]

def __getattr__(self, name: str) -> Any:
def __call__(self) -> Any:
with suppress(LookupError):
injection = injection_var.get()
mapping = self.attribute_mappings[name] # type: ignore[literal-required]
mapping = self.stash[injection]
return mapping[injection]
return getattr(sys, name)


@dataclass
class LazyImportBuiltin:
def __call__(self, *args: Any, **kwds: Any) -> Any:
pass
return self.mainstream_value


@contextmanager
def lazy_imports(
*,
sys_path: StateAction = PERSIST,
sys_meta_path: StateAction = PERSIST,
sys_path_hooks: StateAction = PERSIST,
) -> None:
pass


def type_imports() -> Never:
raise NotImplementedError
sys_path: StateAction[Any] = PERSIST,
sys_meta_path: StateAction[Any] = PERSIST,
sys_path_hooks: StateAction[Any] = PERSIST,
) -> Generator[None]:
stash: dict[Injection[Any], Any] = {}

for attribute_name, action in (
("path", sys_path),
("meta_path", sys_meta_path),
("path_hooks", sys_path_hooks),
):
mainstream_value = getattr(sys, attribute_name)
if action.action_type is StateActionType.PERSIST:
action.data = copy(mainstream_value)
action.action_type = StateActionType.CONSTANT

peek_or_inject(
vars(sys),
attribute_name,
factory=SysAttributeGetter(
attribute_name=attribute_name,
mainstream_value=mainstream_value,
stash=stash,
),
)
vars(sys)[attribute_name]

yield
125 changes: 108 additions & 17 deletions injection/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from __future__ import annotations

from contextlib import suppress
from contextvars import ContextVar, copy_context
from dataclasses import dataclass
from threading import Lock, RLock, get_ident
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
Literal,
NamedTuple,
TypeVar,
cast,
overload,
)
from weakref import WeakSet

from injection.compat import get_frame

Expand All @@ -28,6 +40,10 @@

Object_co = TypeVar("Object_co", covariant=True)

PEEK_MUTEX = RLock()
peeking_var: ContextVar[bool] = ContextVar("peeking", default=False)
peeked_early_var: ContextVar[EarlyObject[Any]] = ContextVar("peeked_early")


class InjectionKey(str):
__slots__ = ("origin", "hash", "reset", "early")
Expand All @@ -49,11 +65,19 @@ def __eq__(self, other: object) -> bool:
self.reset = False
return True

caller_locals = get_frame(1).f_locals
try:
caller_locals = get_frame(1).f_locals
except ValueError:
# can happen if we patch sys
return True

if caller_locals.get("__injection_recursive_guard__"):
return True

if peeking_var.get():
peeked_early_var.set(self.early)
return True

with self.early.__mutex__:
__injection_recursive_guard__ = True # noqa: F841
self.early.__inject__()
Expand All @@ -73,9 +97,19 @@ def strict_recursion_guard(early: EarlyObject[object]) -> Never:
raise RecursionError(msg)


class InjectionFactoryWrapper(NamedTuple, Generic[Object_co]):
actual_factory: Any
pass_scope: bool

def __call__(self, scope: Locals) -> Object_co:
if self.pass_scope:
return cast("Object_co", self.actual_factory(scope))
return cast("Object_co", self.actual_factory())


@dataclass
class Injection(Generic[Object_co]):
factory: Callable[..., Object_co]
actual_factory: Callable[..., Object_co]
pass_scope: bool = False
cache: bool = False
cache_per_alias: bool = False
Expand All @@ -84,52 +118,63 @@ class Injection(Generic[Object_co]):

_reassignment_lock: ClassVar[Lock] = Lock()

def _call_factory(self, scope: Locals) -> Object_co:
if self.pass_scope:
return self.factory(scope)
return self.factory()
@property
def factory(self) -> InjectionFactoryWrapper[Object_co]:
return InjectionFactoryWrapper(
actual_factory=self.actual_factory,
pass_scope=self.pass_scope,
)

def __post_init__(self) -> None:
if self.debug_info is None:
factory, cache, cache_per_alias = (
self.factory,
actual_factory, cache, cache_per_alias = (
self.actual_factory,
self.cache,
self.cache_per_alias,
)
init_opts = f"{factory=!r}, {cache=!r}, {cache_per_alias=!r}"
init_opts = f"{actual_factory=!r}, {cache=!r}, {cache_per_alias=!r}"
include = ""
if debug_info := self.debug_info:
include = f", {debug_info}"
self.debug_info = f"<injection {init_opts}{include}>"

def assign_to(self, *aliases: str, scope: Locals) -> None:
def assign_to(
self,
*aliases: str,
scope: Locals,
) -> WeakSet[EarlyObject[Object_co]]:
if not aliases:
msg = f"expected at least one alias in Injection.assign_to() ({self!r})"
raise ValueError(msg)

state = ObjectState(
state: ObjectState[Object_co] = ObjectState(
cache=self.cache,
factory=self._call_factory,
factory=self.factory, # type: ignore[arg-type] # not sure why
recursion_guard=self.recursion_guard,
debug_info=self.debug_info,
scope=scope,
)

cache_per_alias = self.cache_per_alias

early_objects: WeakSet[EarlyObject[Object_co]] = WeakSet()

for alias in aliases:
debug_info = f"{alias!r} from {self.debug_info}"
early = EarlyObject(
early_object = EarlyObject(
alias=alias,
state=state,
cache_per_alias=cache_per_alias,
debug_info=debug_info,
)
key = early.__key__
early_objects.add(early_object)
key = early_object.__key__

with self._reassignment_lock:
scope.pop(key, None)
scope[key] = early
scope[key] = early_object

return early_objects


SENTINEL = object()
Expand Down Expand Up @@ -286,7 +331,7 @@ def inject( # noqa: PLR0913
"""
inj = Injection(
factory=factory,
actual_factory=factory,
pass_scope=pass_scope,
cache_per_alias=cache_per_alias,
cache=cache,
Expand All @@ -295,3 +340,49 @@ def inject( # noqa: PLR0913
)
if into is not None and aliases:
inj.assign_to(*aliases, scope=into)


def peek(scope: Locals, alias: str) -> EarlyObject[Any] | None:
"""Safely get early object from a scope without triggering injection behavior."""
peeking_context = copy_context()
peeking_context.run(peeking_var.set, True) # noqa: FBT003
with suppress(KeyError):
peeking_context.run(scope.__getitem__, alias)
return peeking_context.get(peeked_early_var)


def peek_or_inject( # noqa: PLR0913
scope: Locals,
alias: str,
*,
factory: Callable[[], Object_co] | Callable[[Locals], Object_co],
pass_scope: bool = False,
cache: bool = False,
cache_per_alias: bool = False,
recursion_guard: Callable[[EarlyObject[Any]], object] = strict_recursion_guard,
debug_info: str | None = None,
) -> EarlyObject[Object_co]:
"""
Peek or inject as necessary in a thread-safe manner.
If an injection is present, return the existing early object.
If it is not present, create a new injection, inject it and return an early object.
This function works only for one alias at a time.
"""
with PEEK_MUTEX:
metadata = peek(scope, alias)
if metadata is None:
return next(
iter(
Injection(
actual_factory=factory,
pass_scope=pass_scope,
cache=cache,
cache_per_alias=cache_per_alias,
recursion_guard=recursion_guard,
debug_info=debug_info,
).assign_to(alias, scope=scope)
)
)
return metadata
Loading

0 comments on commit 211afd4

Please sign in to comment.