Skip to content

Commit

Permalink
feat: advanced filtering API (#1468)
Browse files Browse the repository at this point in the history
* created query filter classes

* extended pagination to include query filtering

* added filtering tests

* type improvements

* move type help to dev dependency

* minor type and perf fixes

* breakup test cases

Co-authored-by: Hayden <[email protected]>
  • Loading branch information
michael-genson and hay-kot authored Jul 10, 2022
1 parent c64da1f commit 7f50071
Show file tree
Hide file tree
Showing 8 changed files with 480 additions and 353 deletions.
13 changes: 12 additions & 1 deletion mealie/repos/repository_generic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from math import ceil
from typing import Any, Generic, TypeVar, Union

from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import func
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes

from mealie.core.root_logger import get_logger
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter

Schema = TypeVar("Schema", bound=BaseModel)
Model = TypeVar("Model")
Expand Down Expand Up @@ -236,14 +238,23 @@ def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase
are filtered by the user and group id when applicable.
NOTE: When you provide an override you'll need to manually type the result of this method
as the override, as the type system, is not able to infer the result of this method.
as the override, as the type system is not able to infer the result of this method.
"""
eff_schema = override or self.schema

q = self.session.query(self.model)

fltr = self._filter_builder()
q = q.filter_by(**fltr)
if pagination.query_filter:
try:
qf = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model)

except ValueError as e:
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e))

count = q.count()

# interpret -1 as "get_all"
Expand Down
11 changes: 11 additions & 0 deletions mealie/repos/repository_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional
from uuid import UUID

from fastapi import HTTPException
from pydantic import UUID4
from slugify import slugify
from sqlalchemy import and_, func
Expand All @@ -20,6 +21,7 @@
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter

from .repository_generic import RepositoryGeneric

Expand Down Expand Up @@ -147,6 +149,15 @@ def page_all(self, pagination: PaginationQuery, override=None, load_food=False)

fltr = self._filter_builder()
q = q.filter_by(**fltr)
if pagination.query_filter:
try:
qf = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model)

except ValueError as e:
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e))

count = q.count()

# interpret -1 as "get_all"
Expand Down
5 changes: 5 additions & 0 deletions mealie/schema/recipe/recipe_ingredient.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import enum
from typing import Optional, Union
from uuid import UUID, uuid4
Expand Down Expand Up @@ -27,6 +28,8 @@ class SaveIngredientFood(CreateIngredientFood):
class IngredientFood(CreateIngredientFood):
id: UUID4
label: Optional[MultiPurposeLabelSummary] = None
created_at: Optional[datetime.datetime]
update_at: Optional[datetime.datetime]

class Config:
orm_mode = True
Expand All @@ -48,6 +51,8 @@ class SaveIngredientUnit(CreateIngredientUnit):

class IngredientUnit(CreateIngredientUnit):
id: UUID4
created_at: Optional[datetime.datetime]
update_at: Optional[datetime.datetime]

class Config:
orm_mode = True
Expand Down
1 change: 1 addition & 0 deletions mealie/schema/response/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class PaginationQuery(MealieModel):
per_page: int = 50
order_by: str = "created_at"
order_direction: OrderDirection = OrderDirection.desc
query_filter: str = None


class PaginationBase(GenericModel, Generic[DataT]):
Expand Down
235 changes: 235 additions & 0 deletions mealie/schema/response/query_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from __future__ import annotations

import re
from enum import Enum
from typing import Any, TypeVar, Union, cast

from dateutil import parser as date_parser
from dateutil.parser import ParserError
from humps import decamelize
from sqlalchemy import bindparam, text
from sqlalchemy.orm.query import Query
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.expression import BindParameter

Model = TypeVar("Model")


class RelationalOperator(Enum):
    """Comparison operators that may appear between an attribute and a value.

    Each member's value is the operator token exactly as it is written in a
    filter string; the same token is placed into the generated SQL fragment.
    Member order is significant to code that iterates over this enum.
    """

    # equality / inequality
    EQ = "="
    NOTEQ = "<>"

    # strict ordering
    GT = ">"
    LT = "<"

    # inclusive ordering
    GTE = ">="
    LTE = "<="


class LogicalOperator(Enum):
    """Keywords used to join relational statements in a filter string.

    Each member's value is the upper-case keyword that is written into the
    generated SQL fragment.
    """

    AND = "AND"  # both neighbouring statements must hold
    OR = "OR"  # at least one neighbouring statement must hold


class QueryFilterComponent:
    """A single relational statement: ``<attribute> <operator> <value>``.

    Attributes:
        attribute_name: the attribute name after being passed through ``decamelize``.
        relational_operator: the comparison operator of the statement.
        value: the raw string value with one pair of encasing double quotes removed.
    """

    def __init__(self, attribute_name: str, relational_operator: RelationalOperator, value: str) -> None:
        self.attribute_name = decamelize(attribute_name)
        self.relational_operator = relational_operator
        self.value = QueryFilterComponent._strip_encasing_quotes(value)

    @staticmethod
    def _strip_encasing_quotes(value: str) -> str:
        """Return *value* without its surrounding pair of double quotes, if it has one.

        The length check is ``>= 2`` so that the empty quoted string ``""``
        becomes an empty value, while a lone ``"`` character (which is both the
        first and last character) is left untouched.
        """
        if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
            return value[1:-1]

        return value

    def __repr__(self) -> str:
        return f"[{self.attribute_name} {self.relational_operator.value} {self.value}]"


class QueryFilter:
    """Parse a filter string and apply it to a SQLAlchemy query.

    A filter string is made of relational statements (``<attribute> <operator> <value>``)
    joined by the logical operators ``AND`` / ``OR`` (case-insensitive) and optionally
    grouped with parentheses. Values may be wrapped in double quotes, in which case
    parentheses and operator keywords inside the quotes are treated as plain text.

    Parsing happens once in ``__init__``; ``filter_query`` can then be applied to a query.
    Malformed input is reported with ``ValueError``.
    """

    lsep: str = "("
    rsep: str = ")"

    seps: set[str] = {lsep, rsep}

    def __init__(self, filter_string: str) -> None:
        """Parse *filter_string* into ``self.filter_components``.

        Raises:
            ValueError: if the parentheses are unbalanced or a relational operator
                is missing one of its operands.
        """
        # parse filter string
        components = QueryFilter._break_filter_string_into_components(filter_string)
        base_components = QueryFilter._break_components_into_base_components(components)

        # track nesting depth rather than comparing counts, so that a closing
        # parenthesis appearing before its opening one (e.g. ")(") is rejected too
        depth = 0
        for base_component in base_components:
            if base_component == QueryFilter.lsep:
                depth += 1
            elif base_component == QueryFilter.rsep:
                depth -= 1
                if depth < 0:
                    break

        if depth != 0:
            raise ValueError("invalid filter string: parenthesis are unbalanced")

        # parse base components into a filter group
        self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components)

    def __repr__(self) -> str:
        parts = [
            str(component.value if isinstance(component, LogicalOperator) else component)
            for component in self.filter_components
        ]
        return f'<<{" ".join(parts)}>>'

    def filter_query(self, query: Query, model: type[Model]) -> Query:
        """Return *query* with the parsed filter applied as a WHERE fragment.

        Attribute names are only written into the SQL text after they have been
        found on *model*; values are always passed as bound parameters typed with
        the attribute's ``type``.

        Args:
            query: the query to filter.
            model: the mapped class whose attributes the filter refers to.

        Returns:
            The filtered query, or *query* unchanged when the filter is empty.

        Raises:
            ValueError: if an attribute does not exist on *model*, cannot be used
                in a filter, or a value cannot be converted to the attribute's type.
        """
        segments: list[str] = []
        params: list[BindParameter] = []
        for i, component in enumerate(self.filter_components):
            # the only plain strings stored in filter_components are the grouping separators
            if isinstance(component, str):
                segments.append(component)
                continue

            if isinstance(component, LogicalOperator):
                segments.append(component.value)
                continue

            if not hasattr(model, component.attribute_name):
                raise ValueError(f"invalid query string: '{component.attribute_name}' does not exist on this schema")

            # attributes without a `type` (methods, plain properties, ...) cannot be bound to a
            # typed parameter; report them as a ValueError instead of leaking an AttributeError
            attr = getattr(model, component.attribute_name)
            attr_type = getattr(attr, "type", None)
            if attr_type is None:
                raise ValueError(f"invalid query string: '{component.attribute_name}' cannot be used in a filter")

            # convert values to their proper types
            value: Any = component.value

            if isinstance(attr_type, (sqltypes.Date, sqltypes.DateTime)):
                try:
                    value = date_parser.parse(component.value)

                # OverflowError is raised by dateutil for dates that exceed the platform's integer range
                except (ParserError, OverflowError) as e:
                    raise ValueError(
                        f"invalid query string: unknown date or datetime format '{component.value}'"
                    ) from e

            elif isinstance(attr_type, sqltypes.Boolean):
                try:
                    value = component.value.lower()[0] in ["t", "y"] or component.value == "1"

                # an empty value has no first character to inspect
                except IndexError as e:
                    raise ValueError("invalid query string") from e

            paramkey = f"P{i+1}"
            segments.append(" ".join([component.attribute_name, component.relational_operator.value, f":{paramkey}"]))
            params.append(bindparam(paramkey, value, attr_type))

        # nothing was parsed (e.g. a blank filter string); an empty text clause would
        # produce an empty WHERE fragment, so leave the query as it is
        if not segments:
            return query

        qs = text(" ".join(segments)).bindparams(*params)
        return query.filter(qs)

    @staticmethod
    def _break_filter_string_into_components(filter_string: str) -> list[str]:
        """Split the filter string at every parenthesis that is not inside double quotes.

        Each parenthesis becomes its own component; the text between parentheses is
        stripped of surrounding whitespace and dropped if nothing remains. The scan is a
        single pass with the quote state local to the call, so the result does not depend
        on surrounding whitespace even when the quotes are unbalanced.
        """
        components: list[str] = []
        current: list[str] = []
        in_quotes = False

        for c in filter_string:
            # ignore separators in-between quotes
            if c == '"':
                in_quotes = not in_quotes

            if c in QueryFilter.seps and not in_quotes:
                pending = "".join(current).strip()
                if pending:
                    components.append(pending)

                components.append(c)
                current = []
                continue

            current.append(c)

        pending = "".join(current).strip()
        if pending:
            components.append(pending)

        return components

    @staticmethod
    def _break_components_into_base_components(components: list[str]) -> list[str]:
        """Further break down components by splitting at relational and logical operators.

        Quoted text is kept as a single base component (with its quotes, inner whitespace
        stripped) and is never searched for operators or separators.
        """
        # word boundaries keep "and" / "or" inside identifiers and values
        # (e.g. "category", "brand") from being mistaken for logical operators
        operator_pattern = "|".join(operator.value for operator in LogicalOperator)
        logical_operators = re.compile(rf"\b({operator_pattern})\b", flags=re.IGNORECASE)

        # try longer operators first since some operators overlap (e.g. >= and >)
        relational_operators = sorted((operator.value for operator in RelationalOperator), key=len, reverse=True)

        base_components: list[str] = []
        for component in components:
            # don't parse components comprised of only a separator
            if component in QueryFilter.seps:
                base_components.append(component)
                continue

            # after splitting on quotes, the text that was inside quotes sits at the odd indices
            for i, subcomponent in enumerate(component.split('"')):
                # this subcomponent was surrounded in quotes, so we keep it as-is
                # (checked first so that a quoted "(" or ")" is a value, not a separator)
                if i % 2:
                    base_components.append(f'"{subcomponent.strip()}"')
                    continue

                # text that starts or ends with a quote creates an empty subcomponent at that edge
                if not subcomponent:
                    continue

                # parse out logical operators
                statements = [part.strip() for part in logical_operators.split(subcomponent) if part]

                # parse out relational operators; each statement has exactly zero or one relational operator
                for statement in statements:
                    if not statement:
                        continue

                    for rel_op in relational_operators:
                        if rel_op in statement:
                            operands = [operand.strip() for operand in statement.split(rel_op) if operand]
                            operands.insert(1, rel_op)
                            base_components.extend(operands)
                            break

                    else:
                        # no relational operator found: a bare logical operator or operand
                        base_components.append(statement)

        return base_components

    @staticmethod
    def _parse_base_components_into_filter_components(
        base_components: list[str],
    ) -> list[Union[str, QueryFilterComponent, LogicalOperator]]:
        """Walk through base components and construct filter collections.

        Separators are passed through, logical operator keywords become
        ``LogicalOperator`` members, and every relational operator is combined with
        its two neighbours into a ``QueryFilterComponent``.

        Raises:
            ValueError: if a relational operator does not have an operand on both sides.
        """
        relational_operators = {op.value for op in RelationalOperator}
        logical_operators = {op.value for op in LogicalOperator}

        def is_operand(index: int) -> bool:
            # a valid operand exists in the list and is not itself a separator or operator;
            # the explicit bounds check also prevents index -1 from wrapping to the last item
            if not 0 <= index < len(base_components):
                return False

            candidate = base_components[index]
            return not (
                candidate in QueryFilter.seps
                or candidate in relational_operators
                or candidate.upper() in logical_operators
            )

        # parse QueryFilterComponents and logical operators
        components: list[Union[str, QueryFilterComponent, LogicalOperator]] = []
        for i, base_component in enumerate(base_components):
            if base_component in QueryFilter.seps:
                components.append(base_component)

            elif base_component in relational_operators:
                if not (is_operand(i - 1) and is_operand(i + 1)):
                    raise ValueError(
                        f"invalid filter string: operator '{base_component}' "
                        "requires an attribute on its left and a value on its right"
                    )

                components.append(
                    QueryFilterComponent(
                        attribute_name=base_components[i - 1],
                        relational_operator=RelationalOperator(base_component),
                        value=base_components[i + 1],
                    )
                )

            elif base_component.upper() in logical_operators:
                components.append(LogicalOperator(base_component.upper()))

        return components
Loading

0 comments on commit 7f50071

Please sign in to comment.