Skip to content

Commit

Permalink
Change how Private is implemented to better support type checkers (#1437
Browse files Browse the repository at this point in the history
)
  • Loading branch information
patrick91 authored Nov 24, 2021
1 parent 1edb1e4 commit 316ccb2
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 41 deletions.
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: minor

This release changes how `strawberry.Private` is implemented to
improve support for type checkers.
6 changes: 4 additions & 2 deletions strawberry/experimental/pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, List, Type

from strawberry.experimental.pydantic.exceptions import UnregisteredTypeException
from strawberry.private import Private
from strawberry.private import is_private
from strawberry.utils.typing import (
get_list_annotation,
get_optional_annotation,
Expand Down Expand Up @@ -30,7 +30,9 @@ def get_strawberry_type_from_model(type_: Any):

def get_private_fields(cls: Type) -> List[dataclasses.Field]:
private_fields: List[dataclasses.Field] = []

for field in dataclasses.fields(cls):
if isinstance(field.type, Private):
if is_private(field.type):
private_fields.append(field)

return private_fields
15 changes: 0 additions & 15 deletions strawberry/ext/mypy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,6 @@ def strawberry_field_hook(ctx: FunctionContext) -> Type:
return AnyType(TypeOfAny.special_form)


def private_type_analyze_callback(ctx: AnalyzeTypeContext) -> Type:
type_name = ctx.type.args[0]
type_ = ctx.api.analyze_type(type_name)

return type_


def _get_named_type(name: str, api: SemanticAnalyzerPluginInterface):
if "." in name:
return api.named_type_or_none(name) # type: ignore
Expand Down Expand Up @@ -638,9 +631,6 @@ def get_type_analyze_hook(self, fullname: str):
if self._is_strawberry_lazy_type(fullname):
return lazy_type_analyze_callback

if self._is_strawberry_private(fullname):
return private_type_analyze_callback

return None

def get_class_decorator_hook(
Expand Down Expand Up @@ -682,11 +672,6 @@ def _is_strawberry_enum(self, fullname: str) -> bool:
def _is_strawberry_lazy_type(self, fullname: str) -> bool:
return fullname == "strawberry.lazy_type.LazyType"

def _is_strawberry_private(self, fullname: str) -> bool:
return fullname == "strawberry.private.Private" or fullname.endswith(
"strawberry.Private"
)

def _is_strawberry_decorator(self, fullname: str) -> bool:
if any(
strawberry_decorator in fullname
Expand Down
46 changes: 25 additions & 21 deletions strawberry/private.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
class Private:
"""Represent a private field that won't be converted into a GraphQL field
from typing import TypeVar

Example:
from typing_extensions import Annotated, get_args, get_origin

>>> import strawberry
>>> @strawberry.type
... class User:
... name: str
... age: strawberry.Private[int]
"""

__slots__ = ("type",)
class StrawberryPrivate:
...

def __init__(self, type):
self.type = type

def __repr__(self):
if isinstance(self.type, type):
type_name = self.type.__name__
else:
# typing objects, e.g. List[int]
type_name = repr(self.type)
return f"strawberry.Private[{type_name}]"
T = TypeVar("T")

def __class_getitem__(cls, type):
return Private(type)
Private = Annotated[T, StrawberryPrivate()]
Private.__doc__ = """Represent a private field that won't be converted into a GraphQL field
Example:
>>> import strawberry
>>> @strawberry.type
... class User:
... name: str
... age: strawberry.Private[int]
"""


def is_private(type_: object) -> bool:
if get_origin(type_) is Annotated:
return any(
isinstance(argument, StrawberryPrivate) for argument in get_args(type_)
)

return False
6 changes: 3 additions & 3 deletions strawberry/types/type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PrivateStrawberryFieldError,
)
from strawberry.field import StrawberryField
from strawberry.private import Private
from strawberry.private import is_private

from ..arguments import UNSET

Expand Down Expand Up @@ -81,7 +81,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o

if isinstance(field, StrawberryField):
# Check that the field type is not Private
if isinstance(field.type, Private):
if is_private(field.type):
raise PrivateStrawberryFieldError(field.python_name, cls.__name__)

# Check that default is not set if a resolver is defined
Expand Down Expand Up @@ -125,7 +125,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
# Create a StrawberryField for fields that didn't use strawberry.field
else:
# Only ignore Private fields that weren't defined using StrawberryFields
if isinstance(field.type, Private):
if is_private(field.type):
continue

field_type = field.type
Expand Down
50 changes: 50 additions & 0 deletions tests/pyright/test_private.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from .utils import Result, requires_pyright, run_pyright, skip_on_windows


pytestmark = [skip_on_windows, requires_pyright]


CODE = """
import strawberry
@strawberry.type
class User:
name: str
age: strawberry.Private[int]
patrick = User(name="Patrick", age=1)
User(n="Patrick")
reveal_type(patrick.name)
reveal_type(patrick.age)
"""


def test_pyright():
results = run_pyright(CODE)

assert results == [
Result(
type="error",
message='No parameter named "n" (reportGeneralTypeIssues)',
line=12,
column=6,
),
Result(
type="error",
message=(
"Arguments missing for parameters "
'"name", "age" (reportGeneralTypeIssues)'
),
line=12,
column=1,
),
Result(
type="info", message='Type of "patrick.name" is "str"', line=14, column=13
),
Result(
type="info", message='Type of "patrick.age" is "int"', line=15, column=13
),
]

0 comments on commit 316ccb2

Please sign in to comment.