Skip to content

Commit

Permalink
sqlalchemy 1.4 infer tables also from select object (sqlalchemy 2.0 s…
Browse files Browse the repository at this point in the history
…tyle?)
  • Loading branch information
bodik committed Nov 4, 2023
1 parent edc7c24 commit b6791e3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 34 deletions.
59 changes: 28 additions & 31 deletions sqlalchemy_filters/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import operator

from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import mapperlib
Expand All @@ -8,10 +10,14 @@
from .exceptions import BadQuery, FieldNotFound, BadSpec


def sqlalchemy_version_lt(version):
def sqlalchemy_version_cmp(op, version):
"""compares sqla version < version"""

return tuple(sqlalchemy_version.split('.')) < tuple(version.split('.'))
ops = {'<': operator.lt, '>=': operator.ge}
return ops[op](
tuple(sqlalchemy_version.split('.')),
tuple(version.split('.'))
)


class Field(object):
Expand Down Expand Up @@ -68,7 +74,7 @@ def get_model_from_table(table): # pragma: no_cover_sqlalchemy_lt_1_4
return None


def get_query_models(query):
def get_query_models(query): # pragma: nocover
"""Get models from query.
:param query:
Expand All @@ -80,37 +86,33 @@ def get_query_models(query):
models = [col_desc['entity'] for col_desc in query.column_descriptions]

# account joined entities
if sqlalchemy_version_lt('1.4'): # pragma: no_cover_sqlalchemy_gte_1_4
if sqlalchemy_version_cmp('<', '1.4'):
models.extend(mapper.class_ for mapper in query._join_entities)
else: # pragma: no_cover_sqlalchemy_lt_1_4
else:
try:
models.extend(
mapper.class_
for mapper
in query._compile_state()._join_entities
)
except InvalidRequestError:
except (InvalidRequestError, AttributeError):
# query might not contain columns yet, hence cannot be compiled
# try to infer the models from various internals
for table_tuple in query._setup_joins + query._legacy_setup_joins:
model_class = get_model_from_table(table_tuple[0])
if model_class:
models.append(model_class)
# or query might be a sqla2.0 select statement
pass
# also try to infer the models from various internals
for table_tuple in query._setup_joins + query._legacy_setup_joins:
models.append(get_model_from_table(table_tuple[0]))

# account also query.select_from entities
model_class = None
if sqlalchemy_version_lt('1.4'): # pragma: no_cover_sqlalchemy_gte_1_4
if sqlalchemy_version_cmp('<', '1.1'): # sqla 1.0
if query._select_from_entity:
model_class = (
query._select_from_entity
if sqlalchemy_version_lt('1.1')
else query._select_from_entity.class_
)
else: # pragma: no_cover_sqlalchemy_lt_1_4
models.append(query._select_from_entity)
elif sqlalchemy_version_cmp('<', '1.4'): # sqla 1.1-1.3
if query._select_from_entity:
models.append(query._select_from_entity.class_)
else: # sqla 1.4
if query._from_obj:
model_class = get_model_from_table(query._from_obj[0])
if model_class and (model_class not in models):
models.append(model_class)
models.append(get_model_from_table(query._from_obj[0]))

return {model.__name__: model for model in models if model is not None}

Expand Down Expand Up @@ -191,23 +193,18 @@ def auto_join(query, *model_names):
last_model = list(query_models)[-1]
model_registry = (
last_model._decl_class_registry
if sqlalchemy_version_lt('1.4')
if sqlalchemy_version_cmp('<', '1.4')
else last_model.registry._class_registry
)

for name in model_names:
model = get_model_class_by_name(model_registry, name)
if model and (model not in get_query_models(query).values()):
try:
if sqlalchemy_version_lt('1.4'): # pragma: no_cover_sqlalchemy_gte_1_4
query = query.join(model)
else: # pragma: no_cover_sqlalchemy_lt_1_4
# https://docs.sqlalchemy.org/en/14/changelog/migration_14.html
# Many Core and ORM statement objects now perform much of
# their construction and validation in the compile phase
tmp = query.join(model)
tmp = query.join(model)
if sqlalchemy_version_cmp('>=', '1.4'): # pragma: nocover
tmp._compile_state()
query = tmp
query = tmp
except InvalidRequestError:
pass # can't be autojoined
return query
18 changes: 18 additions & 0 deletions test/interface/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from six import string_types
from sqlalchemy import func
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select

from sqlalchemy_filters import apply_filters
from sqlalchemy_filters.exceptions import (
BadFilterFormat, BadSpec, FieldNotFound
)
from sqlalchemy_filters.models import sqlalchemy_version_cmp

from test.models import Foo, Bar, Qux, Corge

Expand Down Expand Up @@ -1316,3 +1318,19 @@ def test_filter_by_hybrid_methods(self, session):
assert set(map(type, quxs)) == {Qux}
assert {qux.id for qux in quxs} == {4}
assert {qux.three_times_count() for qux in quxs} == {45}


class TestSelectObject:

@pytest.mark.usefixtures('multiple_bars_inserted')
def test_filter_on_select(self, session):
if sqlalchemy_version_cmp('<', '1.4'):
pytest.skip("Sqlalchemy select style 2.0 not supported")

query = select(Bar)
filters = {'field': 'name', 'op': '==', 'value': 'name_2'}

query = apply_filters(query, filters)
result = session.execute(query).fetchall()

assert len(result) == 1
6 changes: 3 additions & 3 deletions test/interface/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from sqlalchemy_filters.exceptions import BadSpec, BadQuery
from sqlalchemy_filters.models import (
auto_join, get_default_model, get_query_models, get_model_class_by_name,
get_model_from_spec, sqlalchemy_version_lt, get_model_from_table
get_model_from_spec, get_model_from_table, sqlalchemy_version_cmp
)
from test.models import Base, Bar, Foo, Qux


class TestGetQueryModels(object):
@pytest.mark.skipif(
sqlalchemy_version_lt('1.4'), reason='tests sqlalchemy 1.4 code'
sqlalchemy_version_cmp('<', '1.4'), reason='tests sqlalchemy 1.4 code'
)
def test_returns_none_for_unknown_table(self):

Expand Down Expand Up @@ -153,7 +153,7 @@ class TestGetModelClassByName:
def registry(self):
return (
Base._decl_class_registry
if sqlalchemy_version_lt('1.4')
if sqlalchemy_version_cmp('<', '1.4')
else Base.registry._class_registry
)

Expand Down

0 comments on commit b6791e3

Please sign in to comment.