From 0c601469911e09f9c7ef263ade7dc0a23822eec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sat, 7 Oct 2023 00:54:08 +0200 Subject: [PATCH] feat(common): support `Self` annotations for `Annotable` --- ibis/common/grounds.py | 2 +- ibis/common/patterns.py | 4 +++- ibis/common/tests/test_graph_benchmarks.py | 19 +++++++++++++++++++ ibis/common/tests/test_grounds.py | 20 +++++++++++++++++++- ibis/common/tests/test_typing.py | 13 ++++++++++--- ibis/common/typing.py | 13 ++++++++++--- 6 files changed, 62 insertions(+), 9 deletions(-) create mode 100644 ibis/common/tests/test_graph_benchmarks.py diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py index 951e319b3ae4..af5674312368 100644 --- a/ibis/common/grounds.py +++ b/ibis/common/grounds.py @@ -50,7 +50,7 @@ def __new__(metacls, clsname, bases, dct, **kwargs): annotations = dct.get("__annotations__", {}) # TODO(kszucs): pass dct as localns to evaluate_annotations - typehints = evaluate_annotations(annotations, module) + typehints = evaluate_annotations(annotations, module, clsname) for name, typehint in typehints.items(): if get_origin(typehint) is ClassVar: continue diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index ca3f73d29fda..758d00ef684b 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -119,9 +119,11 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern: elif isinstance(annot, Enum): # for enums we check the value against the enum values return EqualTo(annot) - elif isinstance(annot, (str, ForwardRef)): + elif isinstance(annot, str): # for strings and forward references we check in a lazy way return LazyInstanceOf(annot) + elif isinstance(annot, ForwardRef): + return LazyInstanceOf(annot.__forward_arg__) else: raise TypeError(f"Cannot create validator from annotation {annot!r}") elif origin is CoercedTo: diff --git a/ibis/common/tests/test_graph_benchmarks.py b/ibis/common/tests/test_graph_benchmarks.py new file mode 100644 index 000000000000..12529f894f55 --- /dev/null +++ b/ibis/common/tests/test_graph_benchmarks.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ibis.common.collections import frozendict # noqa: TCH001 +from ibis.common.graph import Node +from ibis.common.grounds import Concrete + +if TYPE_CHECKING: + from typing_extensions import Self + + +class MyNode(Node, Concrete): + a: int + b: str + c: tuple[int, ...] + d: frozendict[str, int] + e: Self + f: tuple[Self, ...] diff --git a/ibis/common/tests/test_grounds.py b/ibis/common/tests/test_grounds.py index c75957d778a5..5d5263aa583a 100644 --- a/ibis/common/tests/test_grounds.py +++ b/ibis/common/tests/test_grounds.py @@ -5,7 +5,7 @@ import sys import weakref from abc import ABCMeta -from typing import Callable, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union import pytest @@ -42,6 +42,9 @@ ) from ibis.tests.util import assert_pickle_roundtrip +if TYPE_CHECKING: + from typing_extensions import Self + is_any = InstanceOf(object) is_bool = InstanceOf(bool) is_float = InstanceOf(float) @@ -314,6 +317,21 @@ class Op2(Annotable): Op2() +class RecursiveNode(Annotable): + child: Optional[Self] = None + + +def test_annotable_with_self_typehint() -> None: + node = RecursiveNode(RecursiveNode(RecursiveNode(None))) + assert isinstance(node, RecursiveNode) + assert isinstance(node.child, RecursiveNode) + assert isinstance(node.child.child, RecursiveNode) + assert node.child.child.child is None + + with pytest.raises(ValidationError): + RecursiveNode(1) + + def test_annotable_with_recursive_generic_type_annotations(): # testing cons list pattern = Pattern.from_typehint(List[Integer]) diff --git a/ibis/common/tests/test_typing.py b/ibis/common/tests/test_typing.py index 2991314ec25e..d17e2599db98 100644 --- a/ibis/common/tests/test_typing.py +++ b/ibis/common/tests/test_typing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, Optional, Union +from typing import ForwardRef, Generic, Optional, Union from typing_extensions import TypeVar @@ -41,11 +41,18 @@ def example(a: int, b: str) -> str: # type: ignore def test_evaluate_annotations() -> None: - annotations = {"a": "Union[int, str]", "b": "Optional[str]"} - hints = evaluate_annotations(annotations, module_name=__name__) + annots = {"a": "Union[int, str]", "b": "Optional[str]"} + hints = evaluate_annotations(annots, module_name=__name__) assert hints == {"a": Union[int, str], "b": Optional[str]} +def test_evaluate_annotations_with_self() -> None: + annots = {"a": "Union[int, Self]", "b": "Optional[Self]"} + myhint = ForwardRef(f"{__name__}.My") + hints = evaluate_annotations(annots, module_name=__name__, class_name="My") + assert hints == {"a": Union[int, myhint], "b": Optional[myhint]} + + def test_get_type_hints() -> None: hints = get_type_hints(My) assert hints == {"a": T, "b": S, "c": str} diff --git a/ibis/common/typing.py b/ibis/common/typing.py index feac3b2d853e..d263b9faab46 100644 --- a/ibis/common/typing.py +++ b/ibis/common/typing.py @@ -167,7 +167,9 @@ def get_bound_typevars(obj: Any) -> dict[TypeVar, tuple[str, type]]: def evaluate_annotations( - annots: dict[str, str], module_name: str, localns: Optional[Namespace] = None + annots: dict[str, str], + module_name: str, + class_name: Optional[str] = None, ) -> dict[str, Any]: """Evaluate type annotations that are strings. @@ -178,8 +180,9 @@ def evaluate_annotations( module_name The name of the module that the annotations are defined in, hence providing global scope. - localns - The local namespace to use for evaluation. + class_name + The name of the class that the annotations are defined in, hence + providing Self type. Returns ------- @@ -193,6 +196,10 @@ def evaluate_annotations( """ module = sys.modules.get(module_name, None) globalns = getattr(module, "__dict__", None) + if class_name is None: + localns = None + else: + localns = dict(Self=f"{module_name}.{class_name}") return { k: eval(v, globalns, localns) if isinstance(v, str) else v # noqa: PGH001 for k, v in annots.items()