Skip to content

Commit

Permalink
Feat/improve filter (#25)
Browse files Browse the repository at this point in the history
* feat(api): add Query Class to sites, grps & cats

* feat(api): add sort and fix _get_model

Via _get_entity

* test(api): test sort query

* feat(api): add api sort/sort_dir params

To be able to sort through REST Api

* fix(api): check if integer to avoid using ilike

In filter_by_params

* test(api): add test to check filter integer
  • Loading branch information
mvergez authored and amandine-sahl committed Oct 10, 2023
1 parent a61458d commit fb3bcc1
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 15 deletions.
8 changes: 6 additions & 2 deletions backend/gn_module_monitoring/monitoring/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from pypnusershub.db.models import User
from geonature.core.gn_monitoring.models import corVisitObserver

from gn_module_monitoring.monitoring.queries import Query as MonitoringQuery

cor_module_categorie = DB.Table(
"cor_module_categorie",
DB.Column(
Expand Down Expand Up @@ -162,6 +164,7 @@ class TMonitoringSites(TBaseSites):
__mapper_args__ = {
"polymorphic_identity": "monitoring_site",
}
query_class = MonitoringQuery

id_base_site = DB.Column(
DB.ForeignKey("gn_monitoring.t_base_sites.id_base_site"), nullable=False, primary_key=True
Expand Down Expand Up @@ -215,8 +218,9 @@ class TMonitoringSites(TBaseSites):

@serializable
class TMonitoringSitesGroups(DB.Model):
__tablename__ = "t_sites_groups"
__table_args__ = {"schema": "gn_monitoring"}
__tablename__ = 't_sites_groups'
__table_args__ = {'schema': 'gn_monitoring'}
query_class = MonitoringQuery

id_sites_group = DB.Column(DB.Integer, primary_key=True, nullable=False, unique=True)

Expand Down
37 changes: 37 additions & 0 deletions backend/gn_module_monitoring/monitoring/queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from flask_sqlalchemy import BaseQuery
from sqlalchemy import Integer, and_
from werkzeug.datastructures import MultiDict


class Query(BaseQuery):
def _get_entity(self, entity):
if hasattr(entity, "_entities"):
return self._get_entity(entity._entities[0])
return entity.entities[0]

def _get_model(self):
# When sqlalchemy is updated:
# return self._raw_columns[0].entity_namespace
# But for now:
entity = self._get_entity(self)
return entity.c

def filter_by_params(self, params: MultiDict = None):
model = self._get_model()
and_list = []
for key, value in params.items():
column = getattr(model, key)
if isinstance(column.type, Integer):
and_list.append(column == value)
else:
and_list.append(column.ilike(f"%{value}%"))
and_query = and_(*and_list)
return self.filter(and_query)

def sort(self, label: str, direction: str):
model = self._get_model()
order_by = getattr(model, label)
if direction == "desc":
order_by = order_by.desc()

return self.order_by(order_by)
23 changes: 18 additions & 5 deletions backend/gn_module_monitoring/routes/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@
from werkzeug.datastructures import MultiDict

from gn_module_monitoring.blueprint import blueprint
from gn_module_monitoring.monitoring.models import BibCategorieSite
from gn_module_monitoring.monitoring.models import BibCategorieSite, TMonitoringSites
from gn_module_monitoring.utils.routes import (
filter_params,
get_limit_offset,
get_sort,
paginate,
sort,
)
from gn_module_monitoring.monitoring.schemas import MonitoringSitesSchema,BibCategorieSiteSchema
from gn_module_monitoring.utils.routes import filter_params, get_limit_offset, paginate


@blueprint.route("/sites/categories", methods=["GET"])
def get_categories():
params = MultiDict(request.args)
limit, page = get_limit_offset(params=params)
sort_label, sort_dir = get_sort(
params=params, default_sort="id_categorie", default_direction="desc"
)

query = filter_params(query=BibCategorieSite.query, params=params)
query = query.order_by(BibCategorieSite.id_categorie)
query = sort(query=query, sort=sort_label, sort_dir=sort_dir)

return paginate(
query=query,
Expand All @@ -38,10 +47,14 @@ def get_sites():
params = MultiDict(request.args)
# TODO: add filter support
limit, page = get_limit_offset(params=params)
query = TBaseSites.query.join(
BibCategorieSite, TBaseSites.id_categorie == BibCategorieSite.id_categorie
sort_label, sort_dir = get_sort(
params=params, default_sort="id_base_site", default_direction="desc"
)
query = TMonitoringSites.query.join(
BibCategorieSite, TMonitoringSites.id_categorie == BibCategorieSite.id_categorie
)
query = filter_params(query=query, params=params)
query = sort(query=query, sort=sort_label, sort_dir=sort_dir)
return paginate(
query=query,
schema=MonitoringSitesSchema,
Expand Down
14 changes: 11 additions & 3 deletions backend/gn_module_monitoring/routes/sites_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@

from gn_module_monitoring.blueprint import blueprint
from gn_module_monitoring.monitoring.models import TMonitoringSitesGroups
from gn_module_monitoring.utils.routes import (
filter_params,
get_limit_offset,
get_sort,
paginate,
sort,
)
from gn_module_monitoring.monitoring.schemas import MonitoringSitesGroupsSchema
from gn_module_monitoring.utils.routes import filter_params, get_limit_offset, paginate


@blueprint.route("/sites_groups", methods=["GET"])
def get_sites_groups():
params = MultiDict(request.args)
limit, page = get_limit_offset(params=params)

sort_label, sort_dir = get_sort(
params=params, default_sort="id_sites_group", default_direction="desc"
)
query = filter_params(query=TMonitoringSitesGroups.query, params=params)

query = query.order_by(TMonitoringSitesGroups.id_sites_group)
query = sort(query=query, sort=sort_label, sort_dir=sort_dir)
return paginate(
query=query,
schema=MonitoringSitesGroupsSchema,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from gn_module_monitoring.monitoring.models import TMonitoringSitesGroups


@pytest.mark.usefixtures("temporary_transaction")
class TestTMonitoringSitesGroups:
def test_sort_desc(self, sites_groups):
if len(sites_groups) < 2:
pytest.xfail(
"This test cannot work if there is less than 2 sites_groups in database (via fixtures or not)"
)

query = TMonitoringSitesGroups.query.filter(
TMonitoringSitesGroups.id_sites_group.in_(
group.id_sites_group for group in sites_groups.values()
)
).sort(label="id_sites_group", direction="desc")
result = query.all()

assert result[0].id_sites_group > result[1].id_sites_group

def test_sort_asc(self, sites_groups):
if len(sites_groups) < 2:
pytest.xfail(
"This test cannot work if there is less than 2 sites_groups in database (via fixtures or not)"
)

query = TMonitoringSitesGroups.query.filter(
TMonitoringSitesGroups.id_sites_group.in_(
group.id_sites_group for group in sites_groups.values()
)
).sort(label="id_sites_group", direction="asc")
result = query.all()

assert result[0].id_sites_group < result[1].id_sites_group
22 changes: 19 additions & 3 deletions backend/gn_module_monitoring/tests/test_routes/test_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ def test_get_categories(self, categories):
r = self.client.get(url_for("monitorings.get_categories"))

assert r.json["count"] >= len(categories)
assert all(
[schema.dump(cat) in r.json["items"] for cat in categories.values()]
)
assert all([schema.dump(cat) in r.json["items"] for cat in categories.values()])

def test_get_categories_label(self, categories):
label = list(categories.keys())[0]
Expand All @@ -47,6 +45,24 @@ def test_get_sites_limit(self, sites):

assert len(r.json["items"]) == limit

def test_get_sites_base_site_name(self, sites):
site = list(sites.values())[0]
base_site_name = site.base_site_name

r = self.client.get(url_for("monitorings.get_sites", base_site_name=base_site_name))

assert len(r.json["items"]) == 1
assert r.json["items"][0]["base_site_name"] == base_site_name

def test_get_sites_id_base_site(self, sites):
site = list(sites.values())[0]
id_base_site = site.id_base_site

r = self.client.get(url_for("monitorings.get_sites", id_base_site=id_base_site))

assert len(r.json["items"]) == 1
assert r.json["items"][0]["id_base_site"] == id_base_site

def test_get_module_sites(self):
module_code = "TEST"
r = self.client.get(url_for("monitorings.get_module_sites", module_code=module_code))
Expand Down
15 changes: 13 additions & 2 deletions backend/gn_module_monitoring/utils/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
from sqlalchemy.orm import Query
from werkzeug.datastructures import MultiDict

from gn_module_monitoring.monitoring.queries import Query as MonitoringQuery
from gn_module_monitoring.monitoring.schemas import paginate_schema


def get_limit_offset(params: MultiDict) -> Tuple[int]:
return int(params.pop("limit", 50)), int(params.pop("offset", 1))


def get_sort(params: MultiDict, default_sort: str, default_direction) -> Tuple[str]:
return params.pop("sort", default_sort), params.pop("sort_dir", default_direction)


def paginate(query: Query, schema: Schema, limit: int, page: int) -> Response:
result = query.paginate(page=page, error_out=False, per_page=limit)
pagination_schema = paginate_schema(schema)
Expand All @@ -22,7 +27,13 @@ def paginate(query: Query, schema: Schema, limit: int, page: int) -> Response:
return jsonify(data)


def filter_params(query: Query, params: MultiDict) -> Query:
def filter_params(query: MonitoringQuery, params: MultiDict) -> MonitoringQuery:
if len(params) != 0:
query = query.filter_by(**params)
query = query.filter_by_params(params)
return query


def sort(query: MonitoringQuery, sort: str, sort_dir: str) -> MonitoringQuery:
if sort_dir in ["desc", "asc"]:
query = query.sort(label=sort, direction=sort_dir)
return query

0 comments on commit fb3bcc1

Please sign in to comment.