Skip to content

Commit

Permalink
Allow using a static method as a resolver (#1430)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Arminio <[email protected]>
Co-authored-by: ignormies <[email protected]>
  • Loading branch information
3 people authored Nov 27, 2021
1 parent c5a1b61 commit 2139e63
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 33 deletions.
19 changes: 19 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Release type: patch

This release fixes an issue that prevented using `classmethod`s and `staticmethod`s as resolvers

```python
import strawberry

@strawberry.type
class Query:
@strawberry.field
@staticmethod
def static_text() -> str:
return "Strawberry"

@strawberry.field
@classmethod
def class_name(cls) -> str:
return cls.__name__
```
6 changes: 3 additions & 3 deletions strawberry/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .object_type import TypeDefinition


_RESOLVER_TYPE = Union[StrawberryResolver, Callable]
_RESOLVER_TYPE = Union[StrawberryResolver, Callable, staticmethod, classmethod]


class StrawberryField(dataclasses.Field):
Expand All @@ -48,7 +48,7 @@ def __init__(
python_name: Optional[str] = None,
graphql_name: Optional[str] = None,
type_annotation: Optional[StrawberryAnnotation] = None,
origin: Optional[Union[Type, Callable]] = None,
origin: Optional[Union[Type, Callable, staticmethod, classmethod]] = None,
is_subscription: bool = False,
description: Optional[str] = None,
base_resolver: Optional[StrawberryResolver] = None,
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(
self.type_annotation = type_annotation

self.description: Optional[str] = description
self.origin: Optional[Union[Type, Callable]] = origin
self.origin = origin

self._base_resolver: Optional[StrawberryResolver] = None
if base_resolver is not None:
Expand Down
17 changes: 15 additions & 2 deletions strawberry/object_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import inspect
import types
from typing import Callable, List, Optional, Sequence, Type, TypeVar, cast, overload

from strawberry.schema_directive import StrawberrySchemaDirective
Expand Down Expand Up @@ -123,10 +124,22 @@ def _process_type(
# https://github.com/python/cpython/blob/577d7c4e/Lib/dataclasses.py#L873-L880
# so we need to restore them, this will change in future, but for now this
# solution should suffice

for field_ in fields:
if field_.base_resolver and field_.python_name:
setattr(cls, field_.python_name, field_.base_resolver.wrapped_func)
wrapped_func = field_.base_resolver.wrapped_func

# Bind the functions to the class object. This is necessary because when
# the @strawberry.field decorator is used on @staticmethod/@classmethods,
# we get the raw staticmethod/classmethod objects before class evaluation
# binds them to the class. We need to do this manually.
if isinstance(wrapped_func, staticmethod):
bound_method = wrapped_func.__get__(cls)
field_.base_resolver.wrapped_func = bound_method
elif isinstance(wrapped_func, classmethod):
bound_method = types.MethodType(wrapped_func.__func__, cls)
field_.base_resolver.wrapped_func = bound_method

setattr(cls, field_.python_name, wrapped_func)

return cls

Expand Down
86 changes: 59 additions & 27 deletions strawberry/types/fields/resolver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations
from __future__ import annotations as _

import builtins
import inspect
import sys
from inspect import isasyncgenfunction, iscoroutinefunction
from typing import Callable, Generic, List, Mapping, Optional, TypeVar, Union
from typing import Callable, Dict, Generic, List, Mapping, Optional, TypeVar, Union

from cached_property import cached_property # type: ignore

Expand All @@ -19,9 +19,12 @@


class StrawberryResolver(Generic[T]):
# TODO: Move to StrawberryArgument? StrawberryResolver ClassVar?
_SPECIAL_ARGS = {"root", "info", "self", "cls"}

def __init__(
self,
func: Callable[..., T],
func: Union[Callable[..., T], staticmethod, classmethod],
*,
description: Optional[str] = None,
type_override: Optional[Union[StrawberryType, type]] = None,
Expand All @@ -36,36 +39,46 @@ def __init__(

# TODO: Use this when doing the actual resolving? How to deal with async resolvers?
def __call__(self, *args, **kwargs) -> T:
if not callable(self.wrapped_func):
raise UncallableResolverError(self)
return self.wrapped_func(*args, **kwargs)

@cached_property
def arguments(self) -> List[StrawberryArgument]:
# TODO: Move to StrawberryArgument? StrawberryResolver ClassVar?
SPECIAL_ARGS = {"root", "self", "info"}
def annotations(self) -> Dict[str, object]:
"""Annotations for the resolver.
annotations = self.wrapped_func.__annotations__
parameters = inspect.signature(self.wrapped_func).parameters
function_arguments = set(parameters) - SPECIAL_ARGS
Does not include special args defined in _SPECIAL_ARGS (e.g. self, root, info)
"""
annotations = self._unbound_wrapped_func.__annotations__

annotations = {
name: annotation
for name, annotation in annotations.items()
if name not in (SPECIAL_ARGS | {"return"})
if name not in self._SPECIAL_ARGS
}

annotated_arguments = set(annotations)
arguments_missing_annotations = function_arguments - annotated_arguments
return annotations

@cached_property
def arguments(self) -> List[StrawberryArgument]:
parameters = inspect.signature(self._unbound_wrapped_func).parameters
function_arguments = set(parameters) - self._SPECIAL_ARGS

arguments = self.annotations.copy()
arguments.pop("return", None) # Discard return annotation to get just arguments

arguments_missing_annotations = function_arguments - set(arguments)

if any(arguments_missing_annotations):
raise MissingArgumentsAnnotationsError(
field_name=self.wrapped_func.__name__,
field_name=self.name,
arguments=arguments_missing_annotations,
)

module = sys.modules[self.wrapped_func.__module__]
module = sys.modules[self._module]
annotation_namespace = module.__dict__
arguments = []
for arg_name, annotation in annotations.items():
strawberry_arguments = []
for arg_name, annotation in arguments.items():
parameter = parameters[arg_name]

argument = StrawberryArgument(
Expand All @@ -77,40 +90,39 @@ def arguments(self) -> List[StrawberryArgument]:
default=parameter.default,
)

arguments.append(argument)
strawberry_arguments.append(argument)

return arguments
return strawberry_arguments

@cached_property
def has_info_arg(self) -> bool:
args = get_func_args(self.wrapped_func)
args = get_func_args(self._unbound_wrapped_func)
return "info" in args

@cached_property
def has_root_arg(self) -> bool:
args = get_func_args(self.wrapped_func)
args = get_func_args(self._unbound_wrapped_func)
return "root" in args

@cached_property
def has_self_arg(self) -> bool:
args = get_func_args(self.wrapped_func)
args = get_func_args(self._unbound_wrapped_func)
return args and args[0] == "self"

@cached_property
def name(self) -> str:
# TODO: What to do if resolver is a lambda?
return self.wrapped_func.__name__
return self._unbound_wrapped_func.__name__

@cached_property
def type_annotation(self) -> Optional[StrawberryAnnotation]:
try:
return_annotation = self.wrapped_func.__annotations__["return"]
return_annotation = self.annotations["return"]
except KeyError:
# No return annotation at all (as opposed to `-> None`)
return None

# TODO: PyCharm doesn't like this. Says `() -> ...` has no __module__ attribute
module = sys.modules[self.wrapped_func.__module__]
module = sys.modules[self._module]
type_annotation = StrawberryAnnotation(
annotation=return_annotation, namespace=module.__dict__
)
Expand All @@ -127,8 +139,8 @@ def type(self) -> Optional[Union[StrawberryType, type]]:

@cached_property
def is_async(self) -> bool:
return iscoroutinefunction(self.wrapped_func) or isasyncgenfunction(
self.wrapped_func
return iscoroutinefunction(self._unbound_wrapped_func) or isasyncgenfunction(
self._unbound_wrapped_func
)

def copy_with(
Expand All @@ -150,5 +162,25 @@ def copy_with(
type_override=type_override,
)

@cached_property
def _module(self) -> str:
return self._unbound_wrapped_func.__module__

@cached_property
def _unbound_wrapped_func(self) -> Callable[..., T]:
if isinstance(self.wrapped_func, (staticmethod, classmethod)):
return self.wrapped_func.__func__

return self.wrapped_func


class UncallableResolverError(Exception):
def __init__(self, resolver: "StrawberryResolver"):
message = (
f"Attempted to call resolver {resolver} with uncallable function "
f"{resolver.wrapped_func}"
)
super().__init__(message)


__all__ = ["StrawberryResolver"]
67 changes: 67 additions & 0 deletions tests/fields/test_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import re
from typing import ClassVar

import pytest

Expand All @@ -8,6 +10,7 @@
MissingFieldAnnotationError,
MissingReturnAnnotationError,
)
from strawberry.types.fields.resolver import StrawberryResolver, UncallableResolverError


def test_resolver_as_argument():
Expand Down Expand Up @@ -47,6 +50,52 @@ def name(self) -> str:
assert definition.fields[0].base_resolver(None) == Query().name()


def test_staticmethod_resolver_fields():
@strawberry.type
class Query:
@strawberry.field
@staticmethod
def name() -> str:
return "Name"

definition = Query._type_definition

assert definition.name == "Query"
assert len(definition.fields) == 1

assert definition.fields[0].python_name == "name"
assert definition.fields[0].graphql_name is None
assert definition.fields[0].type == str
assert definition.fields[0].base_resolver() == Query.name()

assert Query.name() == "Name"
assert Query().name() == "Name"


def test_classmethod_resolver_fields():
@strawberry.type
class Query:
my_val: ClassVar[str] = "thingy"

@strawberry.field
@classmethod
def val(cls) -> str:
return cls.my_val

definition = Query._type_definition

assert definition.name == "Query"
assert len(definition.fields) == 1

assert definition.fields[0].python_name == "val"
assert definition.fields[0].graphql_name is None
assert definition.fields[0].type == str
assert definition.fields[0].base_resolver() == Query.val()

assert Query.val() == "thingy"
assert Query().val() == "thingy"


def test_raises_error_when_return_annotation_missing():
with pytest.raises(MissingReturnAnnotationError) as e:

Expand Down Expand Up @@ -131,6 +180,24 @@ class Query: # noqa: F841
)


def test_raises_error_calling_uncallable_resolver():
@classmethod
def class_func(cls) -> int:
...

# Note that class_func is a raw classmethod object because it has not been bound
# to a class at this point
resolver = StrawberryResolver(class_func)

expected_error_message = re.escape(
f"Attempted to call resolver {resolver} with uncallable function "
f"{class_func}"
)

with pytest.raises(UncallableResolverError, match=expected_error_message):
resolver()


def test_can_reuse_resolver():
def get_name(self) -> str:
return "Name"
Expand Down
22 changes: 21 additions & 1 deletion tests/schema/test_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class Query:
assert result.data["helloWithParams"] == "I'm abc for helloWithParams"


def test_classmethods_resolvers():
def test_classmethod_resolvers():
global User

@strawberry.type
Expand Down Expand Up @@ -230,6 +230,26 @@ class Query:
del User


def test_staticmethod_resolvers():
class Alphabet:
@staticmethod
def get_letters() -> List[str]:
return ["a", "b", "c"]

@strawberry.type
class Query:
letters: List[str] = strawberry.field(resolver=Alphabet.get_letters)

schema = strawberry.Schema(query=Query)

query = "{ letters }"

result = schema.execute_sync(query)

assert not result.errors
assert result.data == {"letters": ["a", "b", "c"]}


def test_lambda_resolvers():
@strawberry.type
class Query:
Expand Down

0 comments on commit 2139e63

Please sign in to comment.