From fb3bcc18803138f420f25f89001b21bc8d7631bd Mon Sep 17 00:00:00 2001 From: Maxime Vergez <85738261+mvergez@users.noreply.github.com> Date: Wed, 4 Jan 2023 12:39:04 +0100 Subject: [PATCH] Feat/improve filter (#25) * feat(api): add Query Class to sites, grps & cats * feat(api): add sort and fix _get_model Via _get_entity * test(api): test sort query * feat(api): add api sort/sort_dir params To be able to sort through REST Api * fix(api): check if integer to avoid using ilike In filter_by_params * test(api): add test to check filter integer --- .../gn_module_monitoring/monitoring/models.py | 8 +++- .../monitoring/queries.py | 37 +++++++++++++++++++ backend/gn_module_monitoring/routes/site.py | 23 +++++++++--- .../routes/sites_groups.py | 14 +++++-- .../test_models/test_sites_groups.py | 36 ++++++++++++++++++ .../tests/test_routes/test_site.py | 22 +++++++++-- backend/gn_module_monitoring/utils/routes.py | 15 +++++++- 7 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 backend/gn_module_monitoring/monitoring/queries.py create mode 100644 backend/gn_module_monitoring/tests/test_monitoring/test_models/test_sites_groups.py diff --git a/backend/gn_module_monitoring/monitoring/models.py b/backend/gn_module_monitoring/monitoring/models.py index 2602a7be8..99d082d7e 100644 --- a/backend/gn_module_monitoring/monitoring/models.py +++ b/backend/gn_module_monitoring/monitoring/models.py @@ -20,6 +20,8 @@ from pypnusershub.db.models import User from geonature.core.gn_monitoring.models import corVisitObserver +from gn_module_monitoring.monitoring.queries import Query as MonitoringQuery + cor_module_categorie = DB.Table( "cor_module_categorie", DB.Column( @@ -162,6 +164,7 @@ class TMonitoringSites(TBaseSites): __mapper_args__ = { "polymorphic_identity": "monitoring_site", } + query_class = MonitoringQuery id_base_site = DB.Column( DB.ForeignKey("gn_monitoring.t_base_sites.id_base_site"), nullable=False, primary_key=True @@ -215,8 +218,9 @@ class TMonitoringSites(TBaseSites): @serializable class TMonitoringSitesGroups(DB.Model): - __tablename__ = "t_sites_groups" - __table_args__ = {"schema": "gn_monitoring"} + __tablename__ = 't_sites_groups' + __table_args__ = {'schema': 'gn_monitoring'} + query_class = MonitoringQuery id_sites_group = DB.Column(DB.Integer, primary_key=True, nullable=False, unique=True) diff --git a/backend/gn_module_monitoring/monitoring/queries.py b/backend/gn_module_monitoring/monitoring/queries.py new file mode 100644 index 000000000..cc2d6f972 --- /dev/null +++ b/backend/gn_module_monitoring/monitoring/queries.py @@ -0,0 +1,37 @@ +from flask_sqlalchemy import BaseQuery +from sqlalchemy import Integer, and_ +from werkzeug.datastructures import MultiDict + + +class Query(BaseQuery): + def _get_entity(self, entity): + if hasattr(entity, "_entities"): + return self._get_entity(entity._entities[0]) + return entity.entities[0] + + def _get_model(self): + # When sqlalchemy is updated: + # return self._raw_columns[0].entity_namespace + # But for now: + entity = self._get_entity(self) + return entity.c + + def filter_by_params(self, params: MultiDict = None): + model = self._get_model() + and_list = [] + for key, value in params.items(): + column = getattr(model, key) + if isinstance(column.type, Integer): + and_list.append(column == value) + else: + and_list.append(column.ilike(f"%{value}%")) + and_query = and_(*and_list) + return self.filter(and_query) + + def sort(self, label: str, direction: str): + model = self._get_model() + order_by = getattr(model, label) + if direction == "desc": + order_by = order_by.desc() + + return self.order_by(order_by) diff --git a/backend/gn_module_monitoring/routes/site.py b/backend/gn_module_monitoring/routes/site.py index cb10c21ca..1a53ca316 100644 --- a/backend/gn_module_monitoring/routes/site.py +++ b/backend/gn_module_monitoring/routes/site.py @@ -4,18 +4,27 @@ from werkzeug.datastructures import MultiDict from gn_module_monitoring.blueprint import blueprint -from gn_module_monitoring.monitoring.models import BibCategorieSite +from gn_module_monitoring.monitoring.models import BibCategorieSite, TMonitoringSites +from gn_module_monitoring.utils.routes import ( + filter_params, + get_limit_offset, + get_sort, + paginate, + sort, +) from gn_module_monitoring.monitoring.schemas import MonitoringSitesSchema,BibCategorieSiteSchema -from gn_module_monitoring.utils.routes import filter_params, get_limit_offset, paginate @blueprint.route("/sites/categories", methods=["GET"]) def get_categories(): params = MultiDict(request.args) limit, page = get_limit_offset(params=params) + sort_label, sort_dir = get_sort( + params=params, default_sort="id_categorie", default_direction="desc" + ) query = filter_params(query=BibCategorieSite.query, params=params) - query = query.order_by(BibCategorieSite.id_categorie) + query = sort(query=query, sort=sort_label, sort_dir=sort_dir) return paginate( query=query, @@ -38,10 +47,14 @@ def get_sites(): params = MultiDict(request.args) # TODO: add filter support limit, page = get_limit_offset(params=params) - query = TBaseSites.query.join( - BibCategorieSite, TBaseSites.id_categorie == BibCategorieSite.id_categorie + sort_label, sort_dir = get_sort( + params=params, default_sort="id_base_site", default_direction="desc" + ) + query = TMonitoringSites.query.join( + BibCategorieSite, TMonitoringSites.id_categorie == BibCategorieSite.id_categorie ) query = filter_params(query=query, params=params) + query = sort(query=query, sort=sort_label, sort_dir=sort_dir) return paginate( query=query, schema=MonitoringSitesSchema, diff --git a/backend/gn_module_monitoring/routes/sites_groups.py b/backend/gn_module_monitoring/routes/sites_groups.py index 7ea5609df..57317445d 100644 --- a/backend/gn_module_monitoring/routes/sites_groups.py +++ b/backend/gn_module_monitoring/routes/sites_groups.py @@ -3,18 +3,26 @@ from gn_module_monitoring.blueprint import blueprint from gn_module_monitoring.monitoring.models import TMonitoringSitesGroups +from gn_module_monitoring.utils.routes import ( + filter_params, + get_limit_offset, + get_sort, + paginate, + sort, +) from gn_module_monitoring.monitoring.schemas import MonitoringSitesGroupsSchema -from gn_module_monitoring.utils.routes import filter_params, get_limit_offset, paginate @blueprint.route("/sites_groups", methods=["GET"]) def get_sites_groups(): params = MultiDict(request.args) limit, page = get_limit_offset(params=params) - + sort_label, sort_dir = get_sort( + params=params, default_sort="id_sites_group", default_direction="desc" + ) query = filter_params(query=TMonitoringSitesGroups.query, params=params) - query = query.order_by(TMonitoringSitesGroups.id_sites_group) + query = sort(query=query, sort=sort_label, sort_dir=sort_dir) return paginate( query=query, schema=MonitoringSitesGroupsSchema, diff --git a/backend/gn_module_monitoring/tests/test_monitoring/test_models/test_sites_groups.py b/backend/gn_module_monitoring/tests/test_monitoring/test_models/test_sites_groups.py new file mode 100644 index 000000000..99a54f811 --- /dev/null +++ b/backend/gn_module_monitoring/tests/test_monitoring/test_models/test_sites_groups.py @@ -0,0 +1,36 @@ +import pytest + +from gn_module_monitoring.monitoring.models import TMonitoringSitesGroups + + +@pytest.mark.usefixtures("temporary_transaction") +class TestTMonitoringSitesGroups: + def test_sort_desc(self, sites_groups): + if len(sites_groups) < 2: + pytest.xfail( + "This test cannot work if there is less than 2 sites_groups in database (via fixtures or not)" + ) + + query = TMonitoringSitesGroups.query.filter( + TMonitoringSitesGroups.id_sites_group.in_( + group.id_sites_group for group in sites_groups.values() + ) + ).sort(label="id_sites_group", direction="desc") + result = query.all() + + assert result[0].id_sites_group > result[1].id_sites_group + + def test_sort_asc(self, sites_groups): + if len(sites_groups) < 2: + pytest.xfail( + "This test cannot work if there is less than 2 sites_groups in database (via fixtures or not)" + ) + + query = TMonitoringSitesGroups.query.filter( + TMonitoringSitesGroups.id_sites_group.in_( + group.id_sites_group for group in sites_groups.values() + ) + ).sort(label="id_sites_group", direction="asc") + result = query.all() + + assert result[0].id_sites_group < result[1].id_sites_group diff --git a/backend/gn_module_monitoring/tests/test_routes/test_site.py b/backend/gn_module_monitoring/tests/test_routes/test_site.py index d63591ca5..ecbbb3b68 100644 --- a/backend/gn_module_monitoring/tests/test_routes/test_site.py +++ b/backend/gn_module_monitoring/tests/test_routes/test_site.py @@ -22,9 +22,7 @@ def test_get_categories(self, categories): r = self.client.get(url_for("monitorings.get_categories")) assert r.json["count"] >= len(categories) - assert all( - [schema.dump(cat) in r.json["items"] for cat in categories.values()] - ) + assert all([schema.dump(cat) in r.json["items"] for cat in categories.values()]) def test_get_categories_label(self, categories): label = list(categories.keys())[0] @@ -47,6 +45,24 @@ def test_get_sites_limit(self, sites): assert len(r.json["items"]) == limit + def test_get_sites_base_site_name(self, sites): + site = list(sites.values())[0] + base_site_name = site.base_site_name + + r = self.client.get(url_for("monitorings.get_sites", base_site_name=base_site_name)) + + assert len(r.json["items"]) == 1 + assert r.json["items"][0]["base_site_name"] == base_site_name + + def test_get_sites_id_base_site(self, sites): + site = list(sites.values())[0] + id_base_site = site.id_base_site + + r = self.client.get(url_for("monitorings.get_sites", id_base_site=id_base_site)) + + assert len(r.json["items"]) == 1 + assert r.json["items"][0]["id_base_site"] == id_base_site + def test_get_module_sites(self): module_code = "TEST" r = self.client.get(url_for("monitorings.get_module_sites", module_code=module_code)) diff --git a/backend/gn_module_monitoring/utils/routes.py b/backend/gn_module_monitoring/utils/routes.py index a7889dc68..0102f323c 100644 --- a/backend/gn_module_monitoring/utils/routes.py +++ b/backend/gn_module_monitoring/utils/routes.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Query from werkzeug.datastructures import MultiDict +from gn_module_monitoring.monitoring.queries import Query as MonitoringQuery from gn_module_monitoring.monitoring.schemas import paginate_schema @@ -13,6 +14,10 @@ def get_limit_offset(params: MultiDict) -> Tuple[int]: return int(params.pop("limit", 50)), int(params.pop("offset", 1)) +def get_sort(params: MultiDict, default_sort: str, default_direction) -> Tuple[str]: + return params.pop("sort", default_sort), params.pop("sort_dir", default_direction) + + def paginate(query: Query, schema: Schema, limit: int, page: int) -> Response: result = query.paginate(page=page, error_out=False, per_page=limit) pagination_schema = paginate_schema(schema) @@ -22,7 +27,13 @@ def paginate(query: Query, schema: Schema, limit: int, page: int) -> Response: return jsonify(data) -def filter_params(query: Query, params: MultiDict) -> Query: +def filter_params(query: MonitoringQuery, params: MultiDict) -> MonitoringQuery: if len(params) != 0: - query = query.filter_by(**params) + query = query.filter_by_params(params) + return query + + +def sort(query: MonitoringQuery, sort: str, sort_dir: str) -> MonitoringQuery: + if sort_dir in ["desc", "asc"]: + query = query.sort(label=sort, direction=sort_dir) return query