diff --git a/examples/crud_rest_api/app/api.py b/examples/crud_rest_api/app/api.py index afbe49cf3c..8182f74ca0 100644 --- a/examples/crud_rest_api/app/api.py +++ b/examples/crud_rest_api/app/api.py @@ -27,7 +27,7 @@ class GreetingApi(BaseApi): openapi_spec_methods = { "greeting": { "get": { - "description": "Override description", + "description": "Override description", } } } diff --git a/flask_appbuilder/api/convert.py b/flask_appbuilder/api/convert.py index fce433131a..bf0decefaf 100644 --- a/flask_appbuilder/api/convert.py +++ b/flask_appbuilder/api/convert.py @@ -10,7 +10,7 @@ def __init__(self, data): self.childs = list() def __repr__(self): - return "{}.{}".format(self.data, str(self.childs)) + return f"{self.data}.{str(self.childs)}" class Tree: @@ -176,16 +176,6 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False) field.validate.append(self.validators_columns[column.data]) return field - @staticmethod - def get_column_child_model(column): - if "." in column: - return column.split(".")[0] - return column - - @staticmethod - def is_column_dotted(column): - return "." in column - def convert(self, columns, model=None, nested=True, enum_dump_by_name=False): """ Creates a Marshmallow ModelSchema class diff --git a/flask_appbuilder/models/sqla/interface.py b/flask_appbuilder/models/sqla/interface.py index aa2f62d0c4..735a91ef7a 100644 --- a/flask_appbuilder/models/sqla/interface.py +++ b/flask_appbuilder/models/sqla/interface.py @@ -1,15 +1,18 @@ # -*- coding: utf-8 -*- import logging import sys +from typing import List, Tuple +from flask_sqlalchemy import BaseQuery import sqlalchemy as sa from sqlalchemy import func from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Load +from sqlalchemy.orm import aliased, Load from sqlalchemy.orm.descriptor_props import SynonymProperty +from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy_utils.types.uuid import UUIDType -from . import filters +from . import filters, Model from ..base import BaseInterface from ..group import GroupByCol, GroupByDateMonth, GroupByDateYear from ..mixins import FileColumn, ImageColumn @@ -23,6 +26,7 @@ LOGMSG_WAR_DBI_EDIT_INTEGRITY, ) from ...filemanager import FileManager, ImageManager +from ...utils.base import get_column_leaf, get_column_root_relation, is_column_dotted log = logging.getLogger(__name__) @@ -95,33 +99,69 @@ def _get_base_query( query = query.order_by(self._get_attr(order_column).desc()) return query - def _query_join_dotted_column(self, query, column) -> (object, tuple): - relation_tuple = tuple() - if len(column.split(".")) >= 2: - for join_relation in column.split(".")[:-1]: - relation_tuple = self.get_related_model_and_join(join_relation) - model_relation, relation_join = relation_tuple - if not self.is_model_already_joined(query, model_relation): - query = query.join(model_relation, relation_join, isouter=True) - return query, relation_tuple - - def _query_select_options(self, query, select_columns=None): + def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuery: """ - Add select load options to query. The goal - is to only SQL select what is requested + Helper function that applies necessary joins for dotted columns on a + SQLAlchemy query object - :param query: SQLAlchemy Query obj + :param query: SQLAlchemy query object + :param root_relation: The root part of a dotted column, so the root relation + :return: Transformed SQLAlchemy Query + """ + relations = self.get_related_model_and_join(root_relation) + + for relation in relations: + model_relation, relation_join = relation + # Support multiple joins for the same table + if self.is_model_already_joined(query, model_relation): + # Since the join already exists apply a new aliased one + model_relation = aliased(model_relation) + # The binary expression needs to be inverted + relation_join = BinaryExpression( + relation_join.left, model_relation.id, relation_join.operator + ) + query = query.join(model_relation, relation_join, isouter=True) + return query + + def _query_join_dotted_column(self, query: BaseQuery, column: str) -> BaseQuery: + """ + + :param query: SQLAlchemy query object + :param column: If the column is dotted will join the root relation + :return: Transformed SQLAlchemy Query + """ + if is_column_dotted(column): + return self._query_join_relation(query, get_column_root_relation(column)) + return query + + def _query_select_options( + self, query: BaseQuery, select_columns: List[str] = None + ) -> BaseQuery: + """ + Add select load options to query. The goal + is to only SQL select what is requested and join all the necessary + models when dotted notation is used + + :param query: SQLAlchemy Query obj to apply joins and selects :param select_columns: (list) of columns - :return: SQLAlchemy Query obj + :return: Transformed SQLAlchemy Query """ if select_columns: - _load_options = list() + load_options = list() + joined_models = list() for column in select_columns: - query, relation_tuple = self._query_join_dotted_column(query, column) - model_relation, relation_join = relation_tuple or (None, None) - if model_relation: - _load_options.append( - Load(model_relation).load_only(column.split(".")[1]) + if is_column_dotted(column): + root_relation = get_column_root_relation(column) + leaf_column = get_column_leaf(column) + if root_relation not in joined_models: + query = self._query_join_relation(query, root_relation) + joined_models.append(root_relation) + load_options.append( + ( + Load(self.obj) + .joinedload(root_relation) + .load_only(leaf_column) + ) ) else: # is a custom property method field? @@ -131,10 +171,11 @@ def _query_select_options(self, query, select_columns=None): elif not self.is_relation(column) and not hasattr( getattr(self.obj, column), "__call__" ): - _load_options.append(Load(self.obj).load_only(column)) + load_options.append(Load(self.obj).load_only(column)) + # it's a normal column else: - _load_options.append(Load(self.obj)) - query = query.options(*tuple(_load_options)) + load_options.append(Load(self.obj)) + query = query.options(*tuple(load_options)) return query def query( @@ -147,21 +188,21 @@ def query( select_columns=None, ): """ - QUERY - :param filters: - dict with filters {: - :param page: - the current page - :param page_size: - the current page size - + Returns the results for a model query, applies filters, sorting and pagination + + :param filters: + dict with filters {: + :param page: + the current page + :param page_size: + the current page size """ query = self.session.query(self.obj) - query, relation_tuple = self._query_join_dotted_column(query, order_column) + query = self._query_join_dotted_column(query, order_column) query = self._query_select_options(query, select_columns) query_count = self.session.query(func.count("*")).select_from(self.obj) @@ -521,12 +562,17 @@ def get_col_default(self, col_name): return None return value - def get_related_model(self, col_name): + def get_related_model(self, col_name: str) -> Model: return self.list_properties[col_name].mapper.class_ - def get_related_model_and_join(self, col_name): + def get_related_model_and_join(self, col_name: str) -> List[Tuple[Model, object]]: relation = self.list_properties[col_name] - return relation.mapper.class_, relation.primaryjoin + if relation.direction.name == "MANYTOMANY": + return [ + (relation.secondary, relation.primaryjoin), + (relation.mapper.class_, relation.secondaryjoin), + ] + return [(relation.mapper.class_, relation.primaryjoin)] def query_model_relation(self, col_name): model = self.get_related_model(col_name) diff --git a/flask_appbuilder/tests/sqla/models.py b/flask_appbuilder/tests/sqla/models.py index e7a930dfde..48d5c62cae 100644 --- a/flask_appbuilder/tests/sqla/models.py +++ b/flask_appbuilder/tests/sqla/models.py @@ -133,6 +133,7 @@ class ModelMMChild(Model): __tablename__ = "child" id = Column(Integer, primary_key=True) field_string = Column(String(50), unique=True, nullable=False) + field_integer = Column(Integer()) assoc_parent_child_required = Table( @@ -292,6 +293,7 @@ def insert_data(session, count): for i in range(1, 4): model = ModelMMChild() model.field_string = str(i) + model.field_integer = i children.append(model) session.add(model) session.commit() diff --git a/flask_appbuilder/tests/test_api.py b/flask_appbuilder/tests/test_api.py index b972fef5c8..1de3fed7f4 100644 --- a/flask_appbuilder/tests/test_api.py +++ b/flask_appbuilder/tests/test_api.py @@ -276,6 +276,13 @@ class ModelMMApi(ModelRestApi): self.appbuilder.add_api(ModelMMApi) + class ModelDottedMMApi(ModelRestApi): + datamodel = SQLAInterface(ModelMMParent) + list_columns = ["field_string", "children.field_integer"] + show_columns = ["field_string", "children.field_integer"] + + self.appbuilder.add_api(ModelDottedMMApi) + class ModelOMParentApi(ModelRestApi): datamodel = SQLAInterface(ModelOMParent) @@ -793,16 +800,38 @@ def test_get_item_mm_field(self): # We can't get a base filtered item pk = 1 - rv = self.auth_client_get(client, token, "api/v1/modelmmapi/{}".format(pk)) + rv = self.auth_client_get(client, token, f"api/v1/modelmmapi/{pk}") data = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 200) expected_rel_field = [ - {"field_string": "1", "id": 1}, - {"field_string": "2", "id": 2}, - {"field_string": "3", "id": 3}, + {"field_string": "1", "field_integer": 1, "id": 1}, + {"field_string": "2", "field_integer": 2, "id": 2}, + {"field_string": "3", "field_integer": 3, "id": 3}, ] self.assertEqual(data[API_RESULT_RES_KEY]["children"], expected_rel_field) + def test_get_item_dotted_mm_field(self): + """ + REST Api: Test get item with dotted N-N related field + """ + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + + # We can't get a base filtered item + pk = 1 + rv = self.auth_client_get(client, token, f"api/v1/modeldottedmmapi/{pk}") + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + expected_result = { + "field_string": "0", + "children": [ + {"field_integer": 1}, + {"field_integer": 2}, + {"field_integer": 3}, + ], + } + self.assertEqual(data[API_RESULT_RES_KEY], expected_result) + def test_get_item_om_field(self): """ REST Api: Test get item with O-M related field @@ -860,6 +889,26 @@ def test_get_list_dotted_notation(self): {"field_string": "test0", "group": {"field_string": "test0"}}, ) + def test_get_list_dotted_mm_field(self): + """ + REST Api: Test get list with dotted N-N related field + """ + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + + arguments = {"order_column": "field_string", "order_direction": "asc"} + uri = ( + f"api/v1/modeldottedmmapi/?" f"{API_URI_RIS_KEY}={prison.dumps(arguments)}" + ) + rv = self.auth_client_get(client, token, uri) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + i = 0 + self.assertEqual(data[API_RESULT_RES_KEY][i]["field_string"], "0") + self.assertIn({"field_integer": 1}, data[API_RESULT_RES_KEY][i]["children"]) + self.assertIn({"field_integer": 2}, data[API_RESULT_RES_KEY][i]["children"]) + self.assertIn({"field_integer": 3}, data[API_RESULT_RES_KEY][i]["children"]) + def test_get_list_dotted_order(self): """ REST Api: Test get list and order dotted notation diff --git a/flask_appbuilder/utils/__init__.py b/flask_appbuilder/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flask_appbuilder/utils/base.py b/flask_appbuilder/utils/base.py new file mode 100644 index 0000000000..1224df88d4 --- /dev/null +++ b/flask_appbuilder/utils/base.py @@ -0,0 +1,14 @@ +def get_column_root_relation(column: str) -> str: + if "." in column: + return column.split(".")[0] + return column + + +def get_column_leaf(column: str) -> str: + if "." in column: + return column.split(".")[1] + return column + + +def is_column_dotted(column: str) -> bool: + return "." in column diff --git a/requirements-dev.txt b/requirements-dev.txt index 0091b33af1..f58074adb2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,5 +11,6 @@ Pillow>=7.0.0, <8.0.0 mockldap>=0.3.0 psycopg2-binary==2.7.5 mysqlclient>=1.4.2, < 2.0.0 +cython==0.29.17 pymssql==2.1.4 black==19.3b0