Skip to content

Commit

Permalink
fix: allow forward references in type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
MHajoha committed Jan 8, 2025
1 parent 2f51d6a commit e16c276
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


class SinglechoiceAttempt(Attempt):
question: "SinglechoiceQuestion"

def _compute_score(self) -> float:
if not self.response or "choice" not in self.response:
msg = "'choice' is missing"
Expand Down
9 changes: 3 additions & 6 deletions questionpy/_attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from questionpy_common.api.attempt import AttemptFile, AttemptUi, CacheControl, ScoredInputModel, ScoringCode

from ._ui import create_jinja2_environment
from ._util import get_mro_type_hint
from ._util import reify_type_hint

if TYPE_CHECKING:
from ._qtype import Question
Expand Down Expand Up @@ -138,8 +138,8 @@ class Attempt(ABC):
attempt_state: BaseAttemptState
scoring_state: BaseScoringState | None

attempt_state_class: ClassVar[type[BaseAttemptState]]
scoring_state_class: ClassVar[type[BaseScoringState]]
attempt_state_class: ClassVar[type[BaseAttemptState]] = reify_type_hint("attempt_state", BaseAttemptState)
scoring_state_class: ClassVar[type[BaseScoringState]] = reify_type_hint("scoring_state", BaseScoringState)

def __init__(
self,
Expand Down Expand Up @@ -246,9 +246,6 @@ def variant(self) -> int:
def __init_subclass__(cls, *args: object, **kwargs: object):
super().__init_subclass__(*args, **kwargs)

cls.attempt_state_class = get_mro_type_hint(cls, "attempt_state", BaseAttemptState)
cls.scoring_state_class = get_mro_type_hint(cls, "scoring_state", BaseScoringState)


class _ScoringError(Exception):
def __init__(self, scoring_code: ScoringCode, *args: object) -> None:
Expand Down
14 changes: 5 additions & 9 deletions questionpy/_qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from questionpy_common.environment import get_qpy_environment

from ._attempt import Attempt, AttemptProtocol, AttemptScoredProtocol, AttemptStartedProtocol
from ._util import get_mro_type_hint
from ._util import cached_class_property, get_mro_type_hint, reify_type_hint

Check failure on line 14 in questionpy/_qtype.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (F401)

questionpy/_qtype.py:14:43: F401 `._util.get_mro_type_hint` imported but unused
from .form import FormModel, OptionsFormDefinition
from .form.validation import validate_form

Expand All @@ -37,8 +37,10 @@ class Question(ABC):
question_state: BaseQuestionState

options_class: ClassVar[type[FormModel]]
question_state_class: ClassVar[type[BaseQuestionState]]
question_state_with_version_class: ClassVar[type[QuestionStateWithVersion]]
question_state_class: ClassVar[type[BaseQuestionState]] = reify_type_hint("question_state", BaseQuestionState)
question_state_with_version_class: ClassVar[type[QuestionStateWithVersion]] = cached_class_property(
lambda cls: QuestionStateWithVersion[cls.options_class, cls.question_state_class] # type: ignore[name-defined]
)

def __init__(self, qswv: QuestionStateWithVersion) -> None:
self.question_state_with_version = qswv
Expand Down Expand Up @@ -164,12 +166,6 @@ def __init_subclass__(cls, *args: object, **kwargs: object) -> None:
msg = f"Missing '{cls.__name__}.attempt_class' attribute. It should point to your attempt implementation"
raise TypeError(msg)

cls.question_state_class = get_mro_type_hint(cls, "question_state", BaseQuestionState)
cls.options_class = get_mro_type_hint(cls, "options", FormModel)
cls.question_state_with_version_class = QuestionStateWithVersion[ # type: ignore[misc]
cls.options_class, cls.question_state_class # type: ignore[name-defined]
]

# A form may have unresolved references when it is intended to be used as a repetition, group, section, etc.
# Only the complete form must pass validation, so the validation has to happen here instead of in FormModel or
# OptionsFormDefinition.
Expand Down
44 changes: 40 additions & 4 deletions questionpy/_util.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,47 @@
from collections.abc import Callable
from types import UnionType
from typing import TypeVar, get_args, get_type_hints
from typing import Generic, TypeVar, cast, get_args, get_type_hints

_T = TypeVar("_T", bound=type)
_TypeT = TypeVar("_TypeT", bound=type)
_T = TypeVar("_T")
_UNSET = object()


def get_mro_type_hint(klass: type, attr_name: str, bound: _T) -> _T:
# FIXME: This function is called too early, when the classes referenced in forward refs may not be defined yet.
class _CachedClassProperty(Generic[_TypeT, _T]):
def __init__(self, getter: Callable[[_TypeT], _T]) -> None:
self._getter = getter
self._name: str | None = None

def __get__(self, instance: None, owner: _TypeT) -> _T:
if not self._name:
return self._getter(owner)

cached_value = owner.__dict__.get(self._name, _UNSET)
if cached_value is _UNSET:
value = self._getter(owner)
setattr(owner, self._name, value)
return value

return cached_value

def __set_name__(self, owner: _TypeT, name: str) -> None:
self._name = name


def cached_class_property(getter: Callable[[_TypeT], _T]) -> _T:
"""Similar to [functools.cached_property][], but for class properties.
Like [functools.cached_property][], the descriptor replaces itself with the computed value after the first lookup.
"""
return cast(_T, _CachedClassProperty(getter))


def reify_type_hint(attr_name: str, bound: _TypeT) -> _TypeT:
"""Creates a [cached_class_property][] which returns the type hint of the given attribute."""
return cached_class_property(lambda cls: get_mro_type_hint(cls, attr_name, bound))


def get_mro_type_hint(klass: type, attr_name: str, bound: _TypeT) -> _TypeT:
for superclass in klass.mro():
hints = get_type_hints(superclass)
if attr_name in hints:
Expand Down

0 comments on commit e16c276

Please sign in to comment.