Skip to content

Commit

Permalink
fix: allow forward references in type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
MHajoha committed Jan 13, 2025
1 parent 2f51d6a commit 56cf302
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


class SinglechoiceAttempt(Attempt):
question: "SinglechoiceQuestion"

def _compute_score(self) -> float:
if not self.response or "choice" not in self.response:
msg = "'choice' is missing"
Expand Down
10 changes: 4 additions & 6 deletions questionpy/_attempt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing

Check failure on line 1 in questionpy/_attempt.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (F401)

questionpy/_attempt.py:1:8: F401 `typing` imported but unused
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cached_property
Expand All @@ -9,7 +10,7 @@
from questionpy_common.api.attempt import AttemptFile, AttemptUi, CacheControl, ScoredInputModel, ScoringCode

from ._ui import create_jinja2_environment
from ._util import get_mro_type_hint
from ._util import reify_type_hint

if TYPE_CHECKING:
from ._qtype import Question
Expand Down Expand Up @@ -138,8 +139,8 @@ class Attempt(ABC):
attempt_state: BaseAttemptState
scoring_state: BaseScoringState | None

attempt_state_class: ClassVar[type[BaseAttemptState]]
scoring_state_class: ClassVar[type[BaseScoringState]]
attempt_state_class: ClassVar[type[BaseAttemptState]] = reify_type_hint("attempt_state", BaseAttemptState)
scoring_state_class: ClassVar[type[BaseScoringState]] = reify_type_hint("scoring_state", BaseScoringState)

def __init__(
self,
Expand Down Expand Up @@ -246,9 +247,6 @@ def variant(self) -> int:
def __init_subclass__(cls, *args: object, **kwargs: object):
super().__init_subclass__(*args, **kwargs)

cls.attempt_state_class = get_mro_type_hint(cls, "attempt_state", BaseAttemptState)
cls.scoring_state_class = get_mro_type_hint(cls, "scoring_state", BaseScoringState)


class _ScoringError(Exception):
def __init__(self, scoring_code: ScoringCode, *args: object) -> None:
Expand Down
21 changes: 6 additions & 15 deletions questionpy/_qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from questionpy_common.environment import get_qpy_environment

from ._attempt import Attempt, AttemptProtocol, AttemptScoredProtocol, AttemptStartedProtocol
from ._util import get_mro_type_hint
from ._util import cached_class_property, reify_type_hint, get_mro_type_hint

Check failure on line 14 in questionpy/_qtype.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (F401)

questionpy/_qtype.py:14:60: F401 `._util.get_mro_type_hint` imported but unused
from .form import FormModel, OptionsFormDefinition
from .form.validation import validate_form

Check failure on line 16 in questionpy/_qtype.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (F401)

questionpy/_qtype.py:16:30: F401 `.form.validation.validate_form` imported but unused

Expand All @@ -36,9 +36,11 @@ class Question(ABC):
options: FormModel
question_state: BaseQuestionState

options_class: ClassVar[type[FormModel]]
question_state_class: ClassVar[type[BaseQuestionState]]
question_state_with_version_class: ClassVar[type[QuestionStateWithVersion]]
options_class: ClassVar[type[FormModel]] = reify_type_hint("options", FormModel)
question_state_class: ClassVar[type[BaseQuestionState]] = reify_type_hint("question_state", BaseQuestionState)
question_state_with_version_class: ClassVar[type[QuestionStateWithVersion]] = cached_class_property(
lambda cls: QuestionStateWithVersion[cls.options_class, cls.question_state_class] # type: ignore[name-defined]
)

def __init__(self, qswv: QuestionStateWithVersion) -> None:
self.question_state_with_version = qswv
Expand Down Expand Up @@ -164,17 +166,6 @@ def __init_subclass__(cls, *args: object, **kwargs: object) -> None:
msg = f"Missing '{cls.__name__}.attempt_class' attribute. It should point to your attempt implementation"
raise TypeError(msg)

cls.question_state_class = get_mro_type_hint(cls, "question_state", BaseQuestionState)
cls.options_class = get_mro_type_hint(cls, "options", FormModel)
cls.question_state_with_version_class = QuestionStateWithVersion[ # type: ignore[misc]
cls.options_class, cls.question_state_class # type: ignore[name-defined]
]

# A form may have unresolved references when it is intended to be used as a repetition, group, section, etc.
# Only the complete form must pass validation, so the validation has to happen here instead of in FormModel or
# OptionsFormDefinition.
validate_form(cls.options_class.qpy_form)

@property # type: ignore[no-redef]
def options(self) -> FormModel:
return self.question_state_with_version.options
Expand Down
46 changes: 42 additions & 4 deletions questionpy/_util.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,49 @@
from collections.abc import Callable
from types import UnionType
from typing import TypeVar, get_args, get_type_hints
from typing import Generic, TypeVar, cast, get_args, get_type_hints, TYPE_CHECKING

Check failure on line 3 in questionpy/_util.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (F401)

questionpy/_util.py:3:70: F401 `typing.TYPE_CHECKING` imported but unused

_T = TypeVar("_T", bound=type)
_TypeT = TypeVar("_TypeT", bound=type)

Check failure on line 5 in questionpy/_util.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (I001)

questionpy/_util.py:1:1: I001 Import block is un-sorted or un-formatted
_T = TypeVar("_T")
_UNSET = object()


def get_mro_type_hint(klass: type, attr_name: str, bound: _T) -> _T:
# FIXME: This function is called too early, when the classes referenced in forward refs may not be defined yet.
class _CachedClassProperty(Generic[_TypeT, _T]):
def __init__(self, getter: Callable[[_TypeT], _T]) -> None:
self._getter = getter
self._name: str | None = None

def __get__(self, instance: None, owner: _TypeT) -> _T:
if not self._name:
return self._getter(owner)

cached_value = owner.__dict__.get(self._name, _UNSET)
if cached_value is _UNSET:
value = self._getter(owner)
setattr(owner, self._name, value)
return value

return cached_value

def __set_name__(self, owner: _TypeT, name: str) -> None:
self._name = name


def cached_class_property(getter: Callable[[_TypeT], _T] | classmethod) -> _T:
"""Similar to [functools.cached_property][], but for class properties.
Like [functools.cached_property][], the descriptor replaces itself with the computed value after the first lookup.
"""
if isinstance(getter, classmethod):
getter = getter.__wrapped__
return cast(_T, _CachedClassProperty(getter))


def reify_type_hint(attr_name: str, bound: _TypeT) -> _TypeT:
"""Creates a [cached_class_property][] which returns the type hint of the given attribute."""
return cached_class_property(lambda cls: get_mro_type_hint(cls, attr_name, bound))


def get_mro_type_hint(klass: type, attr_name: str, bound: _TypeT) -> _TypeT:
for superclass in klass.mro():
hints = get_type_hints(superclass)
if attr_name in hints:
Expand Down
9 changes: 9 additions & 0 deletions questionpy/_wrappers/_qtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from questionpy_common.environment import Package
from questionpy_common.manifest import PackageFile

from questionpy.form.validation import validate_form


class QuestionTypeWrapper(QuestionTypeInterface):

Check failure on line 21 in questionpy/_wrappers/_qtype.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (I001)

questionpy/_wrappers/_qtype.py:4:1: I001 Import block is un-sorted or un-formatted
def __init__(
Expand All @@ -34,6 +36,13 @@ def __init__(
functionality. This will probably, but not necessarily, be a subclass of the default
[QuestionWrapper][questionpy.QuestionWrapper].
"""

Check failure on line 38 in questionpy/_wrappers/_qtype.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (D202)

questionpy/_wrappers/_qtype.py:29:9: D202 No blank lines allowed after function docstring (found 1)

# A form may have unresolved references when it is intended to be used as a repetition, group, section, etc.
# Only the complete form must pass validation, so the validation has to happen here instead of in FormModel or
# OptionsFormDefinition.
# We also can't do this in Question.__init_subclass__ since the type hint of options may be a forward reference.
validate_form(question_class.options_class.qpy_form)

self._question_class = question_class
self._package = package

Expand Down
24 changes: 24 additions & 0 deletions tests/questionpy/test_attempt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from questionpy import Attempt, BaseAttemptState, BaseScoringState


class MyAttempt(Attempt):
formulation = ""

def _compute_score(self) -> float:
return 1

attempt_state: "MyAttemptState"
scoring_state: "MyScoringState"


class MyAttemptState(BaseAttemptState):
pass


class MyScoringState(BaseScoringState):
pass


def test_should_resolve_forward_references() -> None:
assert MyAttempt.attempt_state_class is MyAttemptState
assert MyAttempt.scoring_state_class is MyScoringState
26 changes: 26 additions & 0 deletions tests/questionpy/test_qtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing

Check failure on line 1 in tests/questionpy/test_qtype.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (F401)

tests/questionpy/test_qtype.py:1:8: F401 `typing` imported but unused

from questionpy import Attempt, BaseAttemptState, BaseScoringState, Question, BaseQuestionState
from questionpy._qtype import QuestionStateWithVersion
from questionpy.form import FormModel
from tests.questionpy.test_attempt import MyAttempt


class MyQuestion(Question):

Check failure on line 9 in tests/questionpy/test_qtype.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (I001)

tests/questionpy/test_qtype.py:1:1: I001 Import block is un-sorted or un-formatted
attempt_class = MyAttempt

options: "MyFormModel"
question_state: "MyQuestionState"


class MyFormModel(FormModel):
pass

class MyQuestionState(BaseQuestionState):
pass


def test_should_resolve_forward_references() -> None:
assert MyQuestion.options_class is MyFormModel
assert MyQuestion.question_state_class is MyQuestionState
assert MyQuestion.question_state_with_version_class is QuestionStateWithVersion[MyFormModel, MyQuestionState]
1 change: 1 addition & 0 deletions tests/questionpy/wrappers/test_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_should_export_question_model(package: Package) -> None:
assert question_model == QuestionModel(lang="en", scoring_method=ScoringMethod.AUTOMATICALLY_SCORABLE)


# TODO: Test QuestionTypeWrapper here
def test_should_raise_when_form_is_invalid() -> None:
class ModelWithUnresolvedReference(FormModel):
my_text: str | None = text_input("Label", hide_if=is_not_checked("nonexistent_checkbox"))
Expand Down

0 comments on commit 56cf302

Please sign in to comment.