Skip to content

Commit

Permalink
[api] fix, SQL selects and many to many joins (#1361)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored Apr 29, 2020
1 parent b7a9fc2 commit 3685073
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 58 deletions.
2 changes: 1 addition & 1 deletion examples/crud_rest_api/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class GreetingApi(BaseApi):
openapi_spec_methods = {
"greeting": {
"get": {
"description": "Override description",
"description": "Override description",
}
}
}
Expand Down
12 changes: 1 addition & 11 deletions flask_appbuilder/api/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, data):
self.childs = list()

def __repr__(self):
return "{}.{}".format(self.data, str(self.childs))
return f"{self.data}.{str(self.childs)}"


class Tree:
Expand Down Expand Up @@ -176,16 +176,6 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False)
field.validate.append(self.validators_columns[column.data])
return field

@staticmethod
def get_column_child_model(column):
if "." in column:
return column.split(".")[0]
return column

@staticmethod
def is_column_dotted(column):
return "." in column

def convert(self, columns, model=None, nested=True, enum_dump_by_name=False):
"""
Creates a Marshmallow ModelSchema class
Expand Down
130 changes: 88 additions & 42 deletions flask_appbuilder/models/sqla/interface.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# -*- coding: utf-8 -*-
import logging
import sys
from typing import List, Tuple

from flask_sqlalchemy import BaseQuery
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Load
from sqlalchemy.orm import aliased, Load
from sqlalchemy.orm.descriptor_props import SynonymProperty
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy_utils.types.uuid import UUIDType

from . import filters
from . import filters, Model
from ..base import BaseInterface
from ..group import GroupByCol, GroupByDateMonth, GroupByDateYear
from ..mixins import FileColumn, ImageColumn
Expand All @@ -23,6 +26,7 @@
LOGMSG_WAR_DBI_EDIT_INTEGRITY,
)
from ...filemanager import FileManager, ImageManager
from ...utils.base import get_column_leaf, get_column_root_relation, is_column_dotted

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -95,33 +99,69 @@ def _get_base_query(
query = query.order_by(self._get_attr(order_column).desc())
return query

def _query_join_dotted_column(self, query, column) -> (object, tuple):
relation_tuple = tuple()
if len(column.split(".")) >= 2:
for join_relation in column.split(".")[:-1]:
relation_tuple = self.get_related_model_and_join(join_relation)
model_relation, relation_join = relation_tuple
if not self.is_model_already_joined(query, model_relation):
query = query.join(model_relation, relation_join, isouter=True)
return query, relation_tuple

def _query_select_options(self, query, select_columns=None):
def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuery:
"""
Add select load options to query. The goal
is to only SQL select what is requested
Helper function that applies necessary joins for dotted columns on a
SQLAlchemy query object
:param query: SQLAlchemy Query obj
:param query: SQLAlchemy query object
:param root_relation: The root part of a dotted column, so the root relation
:return: Transformed SQLAlchemy Query
"""
relations = self.get_related_model_and_join(root_relation)

for relation in relations:
model_relation, relation_join = relation
# Support multiple joins for the same table
if self.is_model_already_joined(query, model_relation):
# Since the join already exists apply a new aliased one
model_relation = aliased(model_relation)
# The binary expression needs to be inverted
relation_join = BinaryExpression(
relation_join.left, model_relation.id, relation_join.operator
)
query = query.join(model_relation, relation_join, isouter=True)
return query

def _query_join_dotted_column(self, query: BaseQuery, column: str) -> BaseQuery:
"""
:param query: SQLAlchemy query object
:param column: If the column is dotted will join the root relation
:return: Transformed SQLAlchemy Query
"""
if is_column_dotted(column):
return self._query_join_relation(query, get_column_root_relation(column))
return query

def _query_select_options(
self, query: BaseQuery, select_columns: List[str] = None
) -> BaseQuery:
"""
Add select load options to query. The goal
is to only SQL select what is requested and join all the necessary
models when dotted notation is used
:param query: SQLAlchemy Query obj to apply joins and selects
:param select_columns: (list) of columns
:return: SQLAlchemy Query obj
:return: Transformed SQLAlchemy Query
"""
if select_columns:
_load_options = list()
load_options = list()
joined_models = list()
for column in select_columns:
query, relation_tuple = self._query_join_dotted_column(query, column)
model_relation, relation_join = relation_tuple or (None, None)
if model_relation:
_load_options.append(
Load(model_relation).load_only(column.split(".")[1])
if is_column_dotted(column):
root_relation = get_column_root_relation(column)
leaf_column = get_column_leaf(column)
if root_relation not in joined_models:
query = self._query_join_relation(query, root_relation)
joined_models.append(root_relation)
load_options.append(
(
Load(self.obj)
.joinedload(root_relation)
.load_only(leaf_column)
)
)
else:
# is a custom property method field?
Expand All @@ -131,10 +171,11 @@ def _query_select_options(self, query, select_columns=None):
elif not self.is_relation(column) and not hasattr(
getattr(self.obj, column), "__call__"
):
_load_options.append(Load(self.obj).load_only(column))
load_options.append(Load(self.obj).load_only(column))
# it's a normal column
else:
_load_options.append(Load(self.obj))
query = query.options(*tuple(_load_options))
load_options.append(Load(self.obj))
query = query.options(*tuple(load_options))
return query

def query(
Expand All @@ -147,21 +188,21 @@ def query(
select_columns=None,
):
"""
QUERY
:param filters:
dict with filters {<col_name>:<value,...}
:param order_column:
name of the column to order
:param order_direction:
the direction to order <'asc'|'desc'>
:param page:
the current page
:param page_size:
the current page size
Returns the results for a model query, applies filters, sorting and pagination
:param filters:
dict with filters {<col_name>:<value,...}
:param order_column:
name of the column to order
:param order_direction:
the direction to order <'asc'|'desc'>
:param page:
the current page
:param page_size:
the current page size
"""
query = self.session.query(self.obj)
query, relation_tuple = self._query_join_dotted_column(query, order_column)
query = self._query_join_dotted_column(query, order_column)
query = self._query_select_options(query, select_columns)
query_count = self.session.query(func.count("*")).select_from(self.obj)

Expand Down Expand Up @@ -521,12 +562,17 @@ def get_col_default(self, col_name):
return None
return value

def get_related_model(self, col_name):
def get_related_model(self, col_name: str) -> Model:
return self.list_properties[col_name].mapper.class_

def get_related_model_and_join(self, col_name):
def get_related_model_and_join(self, col_name: str) -> List[Tuple[Model, object]]:
relation = self.list_properties[col_name]
return relation.mapper.class_, relation.primaryjoin
if relation.direction.name == "MANYTOMANY":
return [
(relation.secondary, relation.primaryjoin),
(relation.mapper.class_, relation.secondaryjoin),
]
return [(relation.mapper.class_, relation.primaryjoin)]

def query_model_relation(self, col_name):
model = self.get_related_model(col_name)
Expand Down
2 changes: 2 additions & 0 deletions flask_appbuilder/tests/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class ModelMMChild(Model):
__tablename__ = "child"
id = Column(Integer, primary_key=True)
field_string = Column(String(50), unique=True, nullable=False)
field_integer = Column(Integer())


assoc_parent_child_required = Table(
Expand Down Expand Up @@ -292,6 +293,7 @@ def insert_data(session, count):
for i in range(1, 4):
model = ModelMMChild()
model.field_string = str(i)
model.field_integer = i
children.append(model)
session.add(model)
session.commit()
Expand Down
57 changes: 53 additions & 4 deletions flask_appbuilder/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ class ModelMMApi(ModelRestApi):

self.appbuilder.add_api(ModelMMApi)

class ModelDottedMMApi(ModelRestApi):
datamodel = SQLAInterface(ModelMMParent)
list_columns = ["field_string", "children.field_integer"]
show_columns = ["field_string", "children.field_integer"]

self.appbuilder.add_api(ModelDottedMMApi)

class ModelOMParentApi(ModelRestApi):
datamodel = SQLAInterface(ModelOMParent)

Expand Down Expand Up @@ -793,16 +800,38 @@ def test_get_item_mm_field(self):

# We can't get a base filtered item
pk = 1
rv = self.auth_client_get(client, token, "api/v1/modelmmapi/{}".format(pk))
rv = self.auth_client_get(client, token, f"api/v1/modelmmapi/{pk}")
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
expected_rel_field = [
{"field_string": "1", "id": 1},
{"field_string": "2", "id": 2},
{"field_string": "3", "id": 3},
{"field_string": "1", "field_integer": 1, "id": 1},
{"field_string": "2", "field_integer": 2, "id": 2},
{"field_string": "3", "field_integer": 3, "id": 3},
]
self.assertEqual(data[API_RESULT_RES_KEY]["children"], expected_rel_field)

def test_get_item_dotted_mm_field(self):
"""
REST Api: Test get item with dotted N-N related field
"""
client = self.app.test_client()
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

# We can't get a base filtered item
pk = 1
rv = self.auth_client_get(client, token, f"api/v1/modeldottedmmapi/{pk}")
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
expected_result = {
"field_string": "0",
"children": [
{"field_integer": 1},
{"field_integer": 2},
{"field_integer": 3},
],
}
self.assertEqual(data[API_RESULT_RES_KEY], expected_result)

def test_get_item_om_field(self):
"""
REST Api: Test get item with O-M related field
Expand Down Expand Up @@ -860,6 +889,26 @@ def test_get_list_dotted_notation(self):
{"field_string": "test0", "group": {"field_string": "test0"}},
)

def test_get_list_dotted_mm_field(self):
"""
REST Api: Test get list with dotted N-N related field
"""
client = self.app.test_client()
token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

arguments = {"order_column": "field_string", "order_direction": "asc"}
uri = (
f"api/v1/modeldottedmmapi/?" f"{API_URI_RIS_KEY}={prison.dumps(arguments)}"
)
rv = self.auth_client_get(client, token, uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
i = 0
self.assertEqual(data[API_RESULT_RES_KEY][i]["field_string"], "0")
self.assertIn({"field_integer": 1}, data[API_RESULT_RES_KEY][i]["children"])
self.assertIn({"field_integer": 2}, data[API_RESULT_RES_KEY][i]["children"])
self.assertIn({"field_integer": 3}, data[API_RESULT_RES_KEY][i]["children"])

def test_get_list_dotted_order(self):
"""
REST Api: Test get list and order dotted notation
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions flask_appbuilder/utils/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def get_column_root_relation(column: str) -> str:
if "." in column:
return column.split(".")[0]
return column


def get_column_leaf(column: str) -> str:
if "." in column:
return column.split(".")[1]
return column


def is_column_dotted(column: str) -> bool:
return "." in column
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ Pillow>=7.0.0, <8.0.0
mockldap>=0.3.0
psycopg2-binary==2.7.5
mysqlclient>=1.4.2, < 2.0.0
cython==0.29.17
pymssql==2.1.4
black==19.3b0

0 comments on commit 3685073

Please sign in to comment.