Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add benchmark for connection fields #1

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var/
*.egg-info/
.installed.cfg
*.egg
.python-version

# PyInstaller
# Usually these files are written by a python script from a template
Expand All @@ -47,6 +48,7 @@ nosetests.xml
coverage.xml
*,cover
.pytest_cache/
.benchmarks/

# Translations
*.mo
Expand Down
3 changes: 0 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ matrix:
- env: TOXENV=py27
python: 2.7
# Python 3.5
- env: TOXENV=py34
python: 3.4
# Python 3.5
- env: TOXENV=py35
python: 3.5
# Python 3.6
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .fields import SQLAlchemyConnectionField
from .utils import get_query, get_session

__version__ = "2.2.2"
__version__ = "2.3.0.dev0"

__all__ = [
"__version__",
Expand Down
72 changes: 72 additions & 0 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import sqlalchemy
from promise import dataloader, promise
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext


def get_batch_resolver(relationship_prop):

# Cache this across `batch_load_fn` calls
# This is so SQL string generation is cached under-the-hood via `bakery`
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))

class RelationshipLoader(dataloader.DataLoader):
cache = False

def batch_load_fn(self, parents): # pylint: disable=method-hidden
"""
Batch loads the relationships of all the parents as one SQL statement.

There is no way to do this out-of-the-box with SQLAlchemy but
we can piggyback on some internal APIs of the `selectin`
eager loading strategy. It's a bit hacky but it's preferable
than re-implementing and maintainnig a big chunk of the `selectin`
loader logic ourselves.

The approach here is to build a regular query that
selects the parent and `selectin` load the relationship.
But instead of having the query emits 2 `SELECT` statements
when callling `all()`, we skip the first `SELECT` statement
and jump right before the `selectin` loader is called.
To accomplish this, we have to construct objects that are
normally built in the first part of the query in order
to call directly `SelectInLoader._load_for_path`.

TODO Move this logic to a util in the SQLAlchemy repo as per
SQLAlchemy's main maitainer suggestion.
See https://git.io/JewQ7
"""
child_mapper = relationship_prop.mapper
parent_mapper = relationship_prop.parent
session = Session.object_session(parents[0])

# These issues are very unlikely to happen in practice...
for parent in parents:
# assert parent.__mapper__ is parent_mapper
# All instances must share the same session
assert session is Session.object_session(parent)
# The behavior of `selectin` is undefined if the parent is dirty
assert parent not in session.dirty

# Should the boolean be set to False? Does it matter for our purposes?
states = [(sqlalchemy.inspect(parent), True) for parent in parents]

# For our purposes, the query_context will only used to get the session
query_context = QueryContext(session.query(parent_mapper.entity))

selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
)

return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents])

loader = RelationshipLoader()

def resolve(root, info, **args):
return loader.load(root)

return resolve
38 changes: 30 additions & 8 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice

from .batching import get_batch_resolver
from .utils import get_query


Expand All @@ -33,14 +34,8 @@ def model(self):
return self.type._meta.node._meta.model

@classmethod
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if isinstance(sort, six.string_types):
query = query.order_by(sort.value)
else:
query = query.order_by(*(col.value for col in sort))
return query
def get_query(cls, model, info, **args):
return get_query(model, info.context)

@classmethod
def resolve_connection(cls, connection_type, model, info, args, resolved):
Expand Down Expand Up @@ -78,6 +73,7 @@ def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)


# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
def __init__(self, type, *args, **kwargs):
if "sort" not in kwargs and issubclass(type, Connection):
Expand All @@ -95,6 +91,32 @@ def __init__(self, type, *args, **kwargs):
del kwargs["sort"]
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)

@classmethod
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if isinstance(sort, six.string_types):
query = query.order_by(sort.value)
else:
query = query.order_by(*(col.value for col in sort))
return query


class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
"""
This is currently experimental.
The API and behavior may change in future versions.
Use at your own risk.
"""
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, self.resolver, self.type, self.model)

@classmethod
def from_relationship(cls, relationship, registry, **field_kwargs):
model = relationship.mapper.entity
model_type = registry.get_type_for_model(model)
return cls(model_type._meta.connection, resolver=get_batch_resolver(relationship), **field_kwargs)


def default_connection_field_factory(relationship, registry, **field_kwargs):
model = relationship.mapper.entity
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Reporter(Base):
last_name = Column(String(30), doc="Last name")
email = Column(String(), doc="Email")
favorite_pet_kind = Column(PetKind)
pets = relationship("Pet", secondary=association_table, backref="reporters")
pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id")
articles = relationship("Article", backref="reporter")
favorite_article = relationship("Article", uselist=False)

Expand Down
Loading