Skip to content

Commit

Permalink
Add all missing type definitions (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Jul 1, 2020
1 parent 57c6e2a commit a7152f9
Show file tree
Hide file tree
Showing 72 changed files with 537 additions and 344 deletions.
16 changes: 16 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,19 @@ no_implicit_optional = True
strict_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
disallow_untyped_defs = True

[mypy-graphql.language.printer]
disallow_untyped_defs = False

[mypy-graphql.pyutils.frozen_dict]
disallow_untyped_defs = False

[mypy-graphql.pyutils.frozen_list]
disallow_untyped_defs = False

[mypy-graphql.type.introspection]
disallow_untyped_defs = False

[mypy-tests.*]
disallow_untyped_defs = False
10 changes: 5 additions & 5 deletions src/graphql/error/graphql_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def __init__(
if not self.__traceback__:
self.__traceback__ = exc_info()[2]

def __str__(self):
def __str__(self) -> str:
return print_error(self)

def __repr__(self):
def __repr__(self) -> str:
args = [repr(self.message)]
if self.locations:
args.append(f"locations={self.locations!r}")
Expand All @@ -149,7 +149,7 @@ def __repr__(self):
args.append(f"extensions={self.extensions!r}")
return f"{self.__class__.__name__}({', '.join(args)})"

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, GraphQLError)
and self.__class__ == other.__class__
Expand All @@ -165,11 +165,11 @@ def __eq__(self, other):
)
)

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other

@property
def formatted(self):
def formatted(self) -> Dict[str, Any]:
"""Get error formatted according to the specification."""
return format_error(self)

Expand Down
7 changes: 6 additions & 1 deletion src/graphql/error/syntax_error.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import TYPE_CHECKING

from .graphql_error import GraphQLError

if TYPE_CHECKING:
from ..language.source import Source # noqa: F401

__all__ = ["GraphQLSyntaxError"]


class GraphQLSyntaxError(GraphQLError):
"""A GraphQLError representing a syntax error."""

def __init__(self, source, position, description):
def __init__(self, source: "Source", position: int, description: str) -> None:
super().__init__(
f"Syntax Error: {description}", source=source, positions=[position]
)
Expand Down
47 changes: 28 additions & 19 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def build_response(
"""
if self.is_awaitable(data):

async def build_response_async():
async def build_response_async() -> ExecutionResult:
return self.build_response(await data) # type: ignore

return build_response_async()
Expand Down Expand Up @@ -285,6 +285,7 @@ def execute_operation(
#
# Similar to complete_value_catching_error.
try:
# noinspection PyArgumentList
result = (
self.execute_fields_serially
if operation.operation == OperationType.MUTATION
Expand All @@ -296,7 +297,7 @@ def execute_operation(
else:
if self.is_awaitable(result):
# noinspection PyShadowingNames
async def await_result():
async def await_result() -> Any:
try:
return await result # type: ignore
except GraphQLError as error:
Expand All @@ -316,7 +317,7 @@ def execute_fields_serially(
Implements the "Evaluating selection sets" section of the spec for "write" mode.
"""
results: Dict[str, Any] = {}
results: AwaitableOrValue[Dict[str, Any]] = {}
is_awaitable = self.is_awaitable
for response_name, field_nodes in fields.items():
field_path = Path(path, response_name)
Expand All @@ -327,30 +328,36 @@ def execute_fields_serially(
continue
if is_awaitable(results):
# noinspection PyShadowingNames
async def await_and_set_result(results, response_name, result):
async def await_and_set_result(
results: Awaitable[Dict[str, Any]],
response_name: str,
result: AwaitableOrValue[Any],
) -> Dict[str, Any]:
awaited_results = await results
awaited_results[response_name] = (
await result if is_awaitable(result) else result
)
return awaited_results

# noinspection PyTypeChecker
results = await_and_set_result(
cast(Awaitable, results), response_name, result
)
elif is_awaitable(result):
# noinspection PyShadowingNames
async def set_result(results, response_name, result):
async def set_result(
results: Dict[str, Any], response_name: str, result: Awaitable,
) -> Dict[str, Any]:
results[response_name] = await result
return results

# noinspection PyTypeChecker
results = set_result(results, response_name, result)
results = set_result(
cast(Dict[str, Any], results), response_name, result
)
else:
results[response_name] = result
cast(Dict[str, Any], results)[response_name] = result
if is_awaitable(results):
# noinspection PyShadowingNames
async def get_results():
async def get_results() -> Any:
return await cast(Awaitable, results)

return get_results()
Expand Down Expand Up @@ -389,7 +396,7 @@ def execute_fields(
# field, which is possibly a coroutine object. Return a coroutine object that
# will yield this same map, but with any coroutines awaited in parallel and
# replaced with the values they yielded.
async def get_results():
async def get_results() -> Dict[str, Any]:
results.update(
zip(
awaitable_fields,
Expand Down Expand Up @@ -579,7 +586,7 @@ def resolve_field_value_or_error(
result = resolve_fn(source, info, **args)
if self.is_awaitable(result):
# noinspection PyShadowingNames
async def await_result():
async def await_result() -> Any:
try:
return await result
except GraphQLError as error:
Expand Down Expand Up @@ -607,10 +614,11 @@ def complete_value_catching_error(
This is a small wrapper around completeValue which detects and logs errors in
the execution context.
"""
completed: AwaitableOrValue[Any]
try:
if self.is_awaitable(result):

async def await_result():
async def await_result() -> Any:
value = self.complete_value(
return_type, field_nodes, info, path, await result
)
Expand All @@ -625,7 +633,7 @@ async def await_result():
)
if self.is_awaitable(completed):
# noinspection PyShadowingNames
async def await_completed():
async def await_completed() -> Any:
try:
return await completed
except Exception as error:
Expand Down Expand Up @@ -783,7 +791,7 @@ def complete_list_value(
return completed_results

# noinspection PyShadowingNames
async def get_completed_results():
async def get_completed_results() -> Any:
for index, result in zip(
awaitable_indices,
await gather(
Expand Down Expand Up @@ -828,7 +836,7 @@ def complete_abstract_value(

if self.is_awaitable(runtime_type):

async def await_complete_object_value():
async def await_complete_object_value() -> Any:
value = self.complete_object_value(
self.ensure_valid_runtime_type(
await runtime_type, # type: ignore
Expand Down Expand Up @@ -912,14 +920,14 @@ def complete_object_value(

if self.is_awaitable(is_type_of):

async def collect_and_execute_subfields_async():
async def collect_and_execute_subfields_async() -> Dict[str, Any]:
if not await is_type_of: # type: ignore
raise invalid_return_type_error(
return_type, result, field_nodes
)
return self.collect_and_execute_subfields(
return_type, field_nodes, path, result
)
) # type: ignore

return collect_and_execute_subfields_async()

Expand Down Expand Up @@ -1158,11 +1166,12 @@ def default_type_resolver(

if awaitable_is_type_of_results:
# noinspection PyShadowingNames
async def get_type():
async def get_type() -> Optional[Union[GraphQLObjectType, str]]:
is_type_of_results = await gather(*awaitable_is_type_of_results)
for is_type_of_result, type_ in zip(is_type_of_results, awaitable_types):
if is_type_of_result:
return type_
return None

return get_type()

Expand Down
6 changes: 3 additions & 3 deletions src/graphql/execution/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_variable_values(
"""
errors: List[GraphQLError] = []

def on_error(error: GraphQLError):
def on_error(error: GraphQLError) -> None:
if max_errors is not None and len(errors) >= max_errors:
raise GraphQLError(
"Too many errors processing variables,"
Expand All @@ -74,7 +74,7 @@ def coerce_variable_values(
var_def_nodes: FrozenList[VariableDefinitionNode],
inputs: Dict[str, Any],
on_error: Callable[[GraphQLError], None],
):
) -> Dict[str, Any]:
coerced_values: Dict[str, Any] = {}
for var_def_node in var_def_nodes:
var_name = var_def_node.variable.name.value
Expand Down Expand Up @@ -123,7 +123,7 @@ def coerce_variable_values(

def on_input_value_error(
path: List[Union[str, int]], invalid_value: Any, error: GraphQLError
):
) -> None:
invalid_str = inspect(invalid_value)
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
if path:
Expand Down
Loading

0 comments on commit a7152f9

Please sign in to comment.