From 316ccb2d3cf1082969da0721bc494eb002188883 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 24 Nov 2021 21:18:40 +0000 Subject: [PATCH] Change how Private is implemented to better support type checkers (#1437) --- RELEASE.md | 4 ++ strawberry/experimental/pydantic/utils.py | 6 ++- strawberry/ext/mypy_plugin.py | 15 ------- strawberry/private.py | 46 +++++++++++---------- strawberry/types/type_resolver.py | 6 +-- tests/pyright/test_private.py | 50 +++++++++++++++++++++++ 6 files changed, 86 insertions(+), 41 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/pyright/test_private.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..209eae28bb --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +This release changes how `strawberry.Private` is implemented to +improve support for type checkers. diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index bb76ff650a..ce5aa960aa 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -2,7 +2,7 @@ from typing import Any, List, Type from strawberry.experimental.pydantic.exceptions import UnregisteredTypeException -from strawberry.private import Private +from strawberry.private import is_private from strawberry.utils.typing import ( get_list_annotation, get_optional_annotation, @@ -30,7 +30,9 @@ def get_strawberry_type_from_model(type_: Any): def get_private_fields(cls: Type) -> List[dataclasses.Field]: private_fields: List[dataclasses.Field] = [] + for field in dataclasses.fields(cls): - if isinstance(field.type, Private): + if is_private(field.type): private_fields.append(field) + return private_fields diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py index e4e699a4db..7db037b89c 100644 --- a/strawberry/ext/mypy_plugin.py +++ b/strawberry/ext/mypy_plugin.py @@ -78,13 +78,6 @@ def strawberry_field_hook(ctx: FunctionContext) -> Type: return AnyType(TypeOfAny.special_form) -def private_type_analyze_callback(ctx: AnalyzeTypeContext) -> Type: - type_name = ctx.type.args[0] - type_ = ctx.api.analyze_type(type_name) - - return type_ - - def _get_named_type(name: str, api: SemanticAnalyzerPluginInterface): if "." in name: return api.named_type_or_none(name) # type: ignore @@ -638,9 +631,6 @@ def get_type_analyze_hook(self, fullname: str): if self._is_strawberry_lazy_type(fullname): return lazy_type_analyze_callback - if self._is_strawberry_private(fullname): - return private_type_analyze_callback - return None def get_class_decorator_hook( @@ -682,11 +672,6 @@ def _is_strawberry_enum(self, fullname: str) -> bool: def _is_strawberry_lazy_type(self, fullname: str) -> bool: return fullname == "strawberry.lazy_type.LazyType" - def _is_strawberry_private(self, fullname: str) -> bool: - return fullname == "strawberry.private.Private" or fullname.endswith( - "strawberry.Private" - ) - def _is_strawberry_decorator(self, fullname: str) -> bool: if any( strawberry_decorator in fullname diff --git a/strawberry/private.py b/strawberry/private.py index 30bed922f2..e0d0bc972b 100644 --- a/strawberry/private.py +++ b/strawberry/private.py @@ -1,27 +1,31 @@ -class Private: - """Represent a private field that won't be converted into a GraphQL field +from typing import TypeVar - Example: +from typing_extensions import Annotated, get_args, get_origin - >>> import strawberry - >>> @strawberry.type - ... class User: - ... name: str - ... age: strawberry.Private[int] - """ - __slots__ = ("type",) +class StrawberryPrivate: + ... - def __init__(self, type): - self.type = type - def __repr__(self): - if isinstance(self.type, type): - type_name = self.type.__name__ - else: - # typing objects, e.g. List[int] - type_name = repr(self.type) - return f"strawberry.Private[{type_name}]" +T = TypeVar("T") - def __class_getitem__(cls, type): - return Private(type) +Private = Annotated[T, StrawberryPrivate()] +Private.__doc__ = """Represent a private field that won't be converted into a GraphQL field + +Example: + +>>> import strawberry +>>> @strawberry.type +... class User: +... name: str +... age: strawberry.Private[int] +""" + + +def is_private(type_: object) -> bool: + if get_origin(type_) is Annotated: + return any( + isinstance(argument, StrawberryPrivate) for argument in get_args(type_) + ) + + return False diff --git a/strawberry/types/type_resolver.py b/strawberry/types/type_resolver.py index 7e5f51d716..332a5154b9 100644 --- a/strawberry/types/type_resolver.py +++ b/strawberry/types/type_resolver.py @@ -9,7 +9,7 @@ PrivateStrawberryFieldError, ) from strawberry.field import StrawberryField -from strawberry.private import Private +from strawberry.private import is_private from ..arguments import UNSET @@ -81,7 +81,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o if isinstance(field, StrawberryField): # Check that the field type is not Private - if isinstance(field.type, Private): + if is_private(field.type): raise PrivateStrawberryFieldError(field.python_name, cls.__name__) # Check that default is not set if a resolver is defined @@ -125,7 +125,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o # Create a StrawberryField for fields that didn't use strawberry.field else: # Only ignore Private fields that weren't defined using StrawberryFields - if isinstance(field.type, Private): + if is_private(field.type): continue field_type = field.type diff --git a/tests/pyright/test_private.py b/tests/pyright/test_private.py new file mode 100644 index 0000000000..fe52f3c855 --- /dev/null +++ b/tests/pyright/test_private.py @@ -0,0 +1,50 @@ +from .utils import Result, requires_pyright, run_pyright, skip_on_windows + + +pytestmark = [skip_on_windows, requires_pyright] + + +CODE = """ +import strawberry + + +@strawberry.type +class User: + name: str + age: strawberry.Private[int] + + +patrick = User(name="Patrick", age=1) +User(n="Patrick") + +reveal_type(patrick.name) +reveal_type(patrick.age) +""" + + +def test_pyright(): + results = run_pyright(CODE) + + assert results == [ + Result( + type="error", + message='No parameter named "n" (reportGeneralTypeIssues)', + line=12, + column=6, + ), + Result( + type="error", + message=( + "Arguments missing for parameters " + '"name", "age" (reportGeneralTypeIssues)' + ), + line=12, + column=1, + ), + Result( + type="info", message='Type of "patrick.name" is "str"', line=14, column=13 + ), + Result( + type="info", message='Type of "patrick.age" is "int"', line=15, column=13 + ), + ]