Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization Hints #5

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions graphene_django_extras/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
is_valid_django_model,
DJANGO_FILTER_INSTALLED,
)
from graphql.execution.base import get_field_def
from graphene_django_extras.settings import graphql_api_settings

from graphene_django_extras.filters.filter import get_filterset_class
from .base_types import DjangoListObjectBase
from .paginations.pagination import BaseDjangoGraphqlPagination
from .utils import get_extra_filters, queryset_factory, get_related_fields, find_field
from .utils import get_extra_filters, queryset_factory, get_related_fields, find_field, get_type


# *********************************************** #
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs,
**kwargs
):

if DJANGO_FILTER_INSTALLED:
Expand Down Expand Up @@ -129,7 +130,7 @@ def list_resolver(manager, filterset_class, filtering_args, root, info, **kwargs
qs = None

if qs is None:
qs = queryset_factory(manager, info.field_asts, info.fragments, **kwargs)
qs = queryset_factory(manager, info, **kwargs)
qs = filterset_class(
data=filter_kwargs, queryset=qs, request=info.context
).qs
Expand Down Expand Up @@ -161,7 +162,7 @@ def __init__(
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs,
**kwargs
):

_fields = _type._meta.filter_fields
Expand Down Expand Up @@ -219,14 +220,11 @@ def __init__(
def model(self):
return self.type.of_type._meta.node._meta.model

def get_queryset(self, manager, info, **kwargs):
return queryset_factory(manager, info.field_asts, info.fragments, **kwargs)

def list_resolver(
self, manager, filterset_class, filtering_args, root, info, **kwargs
):
filter_kwargs = {k: v for k, v in kwargs.items() if k in filtering_args}
qs = self.get_queryset(manager, info, **kwargs)
qs = queryset_factory(manager, info, **kwargs)
qs = filterset_class(data=filter_kwargs, queryset=qs, request=info.context).qs

if root and is_valid_django_model(root._meta.model):
Expand Down Expand Up @@ -258,7 +256,7 @@ def __init__(
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs,
**kwargs
):

if DJANGO_FILTER_INSTALLED:
Expand Down Expand Up @@ -299,7 +297,7 @@ def list_resolver(
self, manager, filterset_class, filtering_args, root, info, **kwargs
):

qs = queryset_factory(manager, info.field_asts, info.fragments, **kwargs)
qs = queryset_factory(manager, info, **kwargs)

filter_kwargs = {k: v for k, v in kwargs.items() if k in filtering_args}

Expand All @@ -318,4 +316,4 @@ def get_resolver(self, parent_resolver):
self.type._meta.model._default_manager,
self.filterset_class,
self.filtering_args,
)
)
19 changes: 19 additions & 0 deletions graphene_django_extras/hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
__all__ = ["OptimizationHints", "resolver_hints"]


class OptimizationHints(object):
def __init__(self, model_field=None, select_related=None, prefetch_related=None, only=None):
self.model_field = model_field
self.prefetch_related = set(prefetch_related) if prefetch_related else set()
self.select_related = set(select_related) if select_related else set()
self.only = set(only) if only else set()


def resolver_hints(*args, **kwargs):
optimization_hints = OptimizationHints(*args, **kwargs)

def apply_resolver_hints(resolver):
resolver.optimization_hints = optimization_hints
return resolver

return apply_resolver_hints
31 changes: 29 additions & 2 deletions graphene_django_extras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ManyToManyRel,
)
from django.db.models.base import ModelBase
from graphql.execution.base import get_field_def
from graphene.utils.str_converters import to_snake_case
from graphene_django.utils import is_valid_django_model
from graphql import GraphQLList, GraphQLNonNull
Expand Down Expand Up @@ -295,13 +296,23 @@ def find_field(field, fields_dict):


def recursive_params(
selection_set, fragments, available_related_fields, select_related, prefetch_related
root_info,
current_parent_type,
selection_set,
fragments,
available_related_fields,
select_related,
prefetch_related,
):

for field in selection_set.selections:
field_def = get_field_def(root_info.schema, current_parent_type, field.name.value)
field_type = get_type(field_def.type)

if isinstance(field, FragmentSpread) and fragments:
a, b = recursive_params(
root_info,
field_type,
fragments[field.name.value].selection_set,
fragments,
available_related_fields,
Expand All @@ -314,6 +325,8 @@ def recursive_params(

if isinstance(field, InlineFragment):
a, b = recursive_params(
root_info,
FrownyFace marked this conversation as resolved.
Show resolved Hide resolved
field_type,
field.selection_set,
fragments,
available_related_fields,
Expand All @@ -324,6 +337,11 @@ def recursive_params(
[prefetch_related.append(x) for x in b if x not in prefetch_related]
continue

optimization_hints = getattr(field_def.resolver, "optimization_hints", None)
if optimization_hints:
select_related = sorted(set(select_related) | optimization_hints.select_related)
prefetch_related = sorted(set(prefetch_related) | optimization_hints.prefetch_related)

temp = available_related_fields.get(
field.name.value,
available_related_fields.get(to_snake_case(field.name.value), None),
Expand All @@ -336,6 +354,8 @@ def recursive_params(
select_related.append(temp.name)
elif getattr(field, "selection_set", None):
a, b = recursive_params(
root_info,
field_type,
field.selection_set,
fragments,
available_related_fields,
Expand All @@ -348,7 +368,9 @@ def recursive_params(
return select_related, prefetch_related


def queryset_factory(manager, fields_asts=None, fragments=None, **kwargs):
def queryset_factory(manager, info, **kwargs):
fields_asts = info.field_asts
fragments = info.fragments

select_related = []
prefetch_related = []
Expand All @@ -364,8 +386,13 @@ def queryset_factory(manager, fields_asts=None, fragments=None, **kwargs):
else:
select_related.append(temp.name)

base_field_def = get_field_def(info.schema, info.parent_type, info.field_name)
base_type = get_type(base_field_def.type)

if fields_asts:
select_related, prefetch_related = recursive_params(
info,
base_type,
fields_asts[0].selection_set,
fragments,
available_related_fields,
Expand Down