Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flask async view #1907

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions strawberry/flask/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,55 @@ def dispatch_request(self) -> Response:
response.set_data(json.dumps(response_data))

return response


class AsyncGraphQLView(GraphQLView):
async def dispatch_request(self):
method = request.method
content_type = request.content_type or ""

if "application/json" in content_type:
data: dict = request.json # type:ignore[assignment]
elif content_type.startswith("multipart/form-data"):
operations = json.loads(request.form.get("operations", "{}"))
files_map = json.loads(request.form.get("map", "{}"))

scottweitzner marked this conversation as resolved.
Show resolved Hide resolved
data = replace_placeholders_with_files(operations, files_map, request.files)
elif method == "GET" and request.args:
data = parse_query_params(request.args.to_dict())
elif method == "GET" and should_render_graphiql(self.graphiql, request):
template = render_graphiql_page()

return self.render_template(template=template)
else:
return Response("Unsupported Media Type", 415)

try:
request_data = parse_request_data(data)
except MissingQueryError:
return Response("No valid query was provided for the request", 400)

response = Response(status=200, content_type="application/json")
context = self.get_context(response)

allowed_operation_types = OperationType.from_http(method)

if not self.allow_queries_via_get and method == "GET":
allowed_operation_types = allowed_operation_types - {OperationType.QUERY}
scottweitzner marked this conversation as resolved.
Show resolved Hide resolved

try:
result = await self.schema.execute(
request_data.query,
variable_values=request_data.variables,
context_value=context,
operation_name=request_data.operation_name,
root_value=self.get_root_value(),
allowed_operation_types=allowed_operation_types,
)
except InvalidOperationTypeError as e:
return Response(e.as_http_error_reason(method), 400)

response_data = self.process_result(result)
response.set_data(json.dumps(response_data))

return response
21 changes: 15 additions & 6 deletions tests/flask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import strawberry
from flask import Flask
from strawberry.file_uploads import Upload
from strawberry.flask.views import GraphQLView as BaseGraphQLView
from strawberry.flask.views import (
AsyncGraphQLView as BaseAsyncGraphQLView,
GraphQLView as BaseGraphQLView,
)


def create_app(**kwargs):
def create_app(use_async_view=False, **kwargs):
@strawberry.input
class FolderInput:
files: typing.List[Upload]
Expand Down Expand Up @@ -47,11 +50,17 @@ class GraphQLView(BaseGraphQLView):
def get_root_value(self):
return Query()

class AsyncGraphQLView(BaseAsyncGraphQLView):
def get_root_value(self):
return Query()

app = Flask(__name__)
app.debug = True

app.add_url_rule(
"/graphql",
view_func=GraphQLView.as_view("graphql_view", schema=schema, **kwargs),
)
if use_async_view:
view_func = AsyncGraphQLView.as_view("graphql_view", schema=schema, **kwargs)
else:
view_func = GraphQLView.as_view("graphql_view", schema=schema, **kwargs)
app.add_url_rule("/graphql", view_func=view_func)

return app
6 changes: 6 additions & 0 deletions tests/flask/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def flask_client():
yield client


@pytest.fixture
def async_flask_client():
with create_app(use_async_view=True).test_client() as client:
yield client


@pytest.fixture
def flask_client_no_graphiql():
with create_app(graphiql=False).test_client() as client:
Expand Down
16 changes: 14 additions & 2 deletions tests/flask/test_query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_no_query(flask_client):
assert response.status_code == 400


def test_get_with_query_params(flask_client):
def test_get_with_query_params(flask_client, async_flask_client):
params = {
"query": """
query {
Expand All @@ -31,8 +31,14 @@ def test_get_with_query_params(flask_client):
assert response.status_code == 200
assert data["data"]["hello"] == "Hello world"

response = async_flask_client.get("/graphql", query_string=params)
data = json.loads(response.data.decode())

assert response.status_code == 200
assert data["data"]["hello"] == "Hello world"


def test_can_pass_variables_with_query_params(flask_client):
def test_can_pass_variables_with_query_params(flask_client, async_flask_client):
params = {
"query": "query Hello($name: String!) { hello(name: $name) }",
"variables": '{"name": "James"}',
Expand All @@ -44,6 +50,12 @@ def test_can_pass_variables_with_query_params(flask_client):
assert response.status_code == 200
assert data["data"]["hello"] == "Hello James"

response = async_flask_client.get("/graphql", query_string=params)
data = json.loads(response.data.decode())

assert response.status_code == 200
assert data["data"]["hello"] == "Hello James"


def test_post_fails_with_query_params(flask_client):
params = {
Expand Down
33 changes: 29 additions & 4 deletions tests/flask/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .app import create_app


def test_graphql_query(flask_client):
def test_graphql_query(flask_client, async_flask_client):
query = {
"query": """
query {
Expand All @@ -23,8 +23,14 @@ def test_graphql_query(flask_client):
assert response.status_code == 200
assert data["data"]["hello"] == "Hello world"

response = async_flask_client.get("/graphql", json=query)
data = json.loads(response.data.decode())

assert response.status_code == 200
assert data["data"]["hello"] == "Hello world"


def test_can_pass_variables(flask_client):
def test_can_pass_variables(flask_client, async_flask_client):
query = {
"query": "query Hello($name: String!) { hello(name: $name) }",
"variables": {"name": "James"},
Expand All @@ -36,23 +42,42 @@ def test_can_pass_variables(flask_client):
assert response.status_code == 200
assert data["data"]["hello"] == "Hello James"

response = async_flask_client.get("/graphql", json=query)
data = json.loads(response.data.decode())

assert response.status_code == 200
assert data["data"]["hello"] == "Hello James"


def test_fails_when_request_body_has_invalid_json(flask_client):
def test_fails_when_request_body_has_invalid_json(flask_client, async_flask_client):
response = flask_client.post(
"/graphql",
data='{"qeury": "{__typena"',
headers={"content-type": "application/json"},
)
assert response.status_code == 400

response = async_flask_client.post(
"/graphql",
data='{"qeury": "{__typena"',
headers={"content-type": "application/json"},
)
assert response.status_code == 400

def test_graphiql_view(flask_client):

def test_graphiql_view(flask_client, async_flask_client):
flask_client.environ_base["HTTP_ACCEPT"] = "text/html"
response = flask_client.get("/graphql")
body = response.data.decode()

assert "GraphiQL" in body

async_flask_client.environ_base["HTTP_ACCEPT"] = "text/html"
response = async_flask_client.get("/graphql")
body = response.data.decode()

assert "GraphiQL" in body


def test_graphiql_disabled_view():
app = create_app(graphiql=False)
Expand Down