Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add on_executing_* hook to extensions #1400

Merged
merged 9 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Add `on_executing_*` hooks to extensions to allow you to override the execution phase of a GraphQL operation.
83 changes: 83 additions & 0 deletions docs/guides/custom-extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,89 @@ class MyExtension(Extension):
print('GraphQL parsing end')
```

### Execution

`on_executing_start` and `on_executing_end` can be used to run code on the execution step of
the GraphQL execution. Both methods can be implemented asynchronously.

```python
from strawberry.extensions import Extension

class MyExtension(Extension):
def on_executing_start(self):
print('GraphQL execution start')

def on_executing_end(self):
print('GraphQL execution end')
```

#### Examples:

<details>
<summary>In memory cached execution</summary>

```python
import json
import strawberry
from strawberry.extensions import Extension

# Use an actual cache in production so that this doesn't grow unbounded
response_cache = {}

class ExecutionCache(Extension):
def on_executing_start(self):
# Check if we've come across this query before
execution_context = self.execution_context
self.cache_key = (
f"{execution_context.query}:{json.dumps(execution_context.variables)}"
)
if self.cache_key in response_cache:
self.execution_context.result = response_cache[self.cache_key]

def on_executing_end(self):
execution_context = self.execution_context
if self.cache_key not in response_cache:
response_cache[self.cache_key] = execution_context.result


schema = strawberry.Schema(
Query,
extensions=[
ExecutionCache,
]
)
```

</details>

<details>
<summary>Rejecting a request before executing it</summary>

```python
import strawberry
from strawberry.extensions import Extension

class RejectSomeQueries(Extension):
def on_executing_start(self):
# Reject all operations called "RejectMe"
execution_context = self.execution_context
if execution_context.operation_name == "RejectMe":
self.execution_context.result = GraphQLExecutionResult(
data=None,
errors=[GraphQLError("Well you asked for it")],
)


schema = strawberry.Schema(
Query,
extensions=[
RejectSomeQueries,
]
)
```

</details>

### Execution Context

The `Extension` object has an `execution_context` property on `self` of type
Expand Down
6 changes: 6 additions & 0 deletions strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def on_parsing_start(self) -> AwaitableOrValue[None]:
def on_parsing_end(self) -> AwaitableOrValue[None]:
"""This method is called after the parsing step"""

def on_executing_start(self) -> AwaitableOrValue[None]:
"""This method is called before the execution step"""

def on_executing_end(self) -> AwaitableOrValue[None]:
"""This method is called after the executing step"""

def resolve(
self, _next, root, info: Info, *args, **kwargs
) -> AwaitableOrValue[object]:
Expand Down
18 changes: 18 additions & 0 deletions strawberry/extensions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,21 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
for extension in self.extensions:
await await_maybe(extension.on_parsing_end())


class ExecutingContextManager(ExtensionContextManager):
def __enter__(self):
for extension in self.extensions:
extension.on_executing_start()

def __exit__(self, exc_type, exc_val, exc_tb):
for extension in self.extensions:
extension.on_executing_end()

async def __aenter__(self):
for extension in self.extensions:
await await_maybe(extension.on_executing_start())

async def __aexit__(self, exc_type, exc_val, exc_tb):
for extension in self.extensions:
await await_maybe(extension.on_executing_end())
4 changes: 4 additions & 0 deletions strawberry/extensions/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from graphql import MiddlewareManager

from strawberry.extensions.context import (
ExecutingContextManager,
ParsingContextManager,
RequestContextManager,
ValidationContextManager,
Expand Down Expand Up @@ -49,6 +50,9 @@ def validation(self) -> ValidationContextManager:
def parsing(self) -> ParsingContextManager:
return ParsingContextManager(self.extensions)

def executing(self) -> ExecutingContextManager:
return ExecutingContextManager(self.extensions)

def get_extensions_results_sync(self) -> Dict[str, Any]:
data: Dict[str, Any] = {}

Expand Down
84 changes: 48 additions & 36 deletions strawberry/schema/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,32 @@ async def execute(
if execution_context.errors:
return ExecutionResult(data=None, errors=execution_context.errors)

result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
result = await cast(Awaitable[GraphQLExecutionResult], result)
async with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

execution_context.result = cast(GraphQLExecutionResult, result)
if isawaitable(result):
result = await cast(Awaitable[GraphQLExecutionResult], result)

result = cast(GraphQLExecutionResult, result)
result = cast(GraphQLExecutionResult, result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

return ExecutionResult(
data=result.data,
errors=result.errors,
data=execution_context.result.data,
errors=execution_context.result.errors,
extensions=await extensions_runner.get_extensions_results(),
)

Expand Down Expand Up @@ -157,28 +162,35 @@ def execute_sync(
if execution_context.errors:
return ExecutionResult(data=None, errors=execution_context.errors)

result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)
with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
ensure_future(cast(Awaitable[GraphQLExecutionResult], result)).cancel()
raise RuntimeError("GraphQL execution failed to complete synchronously.")
if isawaitable(result):
result = cast(Awaitable[GraphQLExecutionResult], result)
ensure_future(result).cancel()
raise RuntimeError(
"GraphQL execution failed to complete synchronously."
)

result = cast(GraphQLExecutionResult, result)
execution_context.result = result
if result.errors:
execution_context.errors = result.errors
result = cast(GraphQLExecutionResult, result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

return ExecutionResult(
data=result.data,
errors=result.errors,
data=execution_context.result.data,
errors=execution_context.result.errors,
extensions=extensions_runner.get_extensions_results_sync(),
)
Loading