Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement: Advanced Filtering API #1468

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion mealie/repos/repository_generic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from math import ceil
from typing import Any, Generic, TypeVar, Union

from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import func
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes

from mealie.core.root_logger import get_logger
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter

Schema = TypeVar("Schema", bound=BaseModel)
Model = TypeVar("Model")
Expand Down Expand Up @@ -236,14 +238,23 @@ def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase
are filtered by the user and group id when applicable.

NOTE: When you provide an override you'll need to manually type the result of this method
as the override, as the type system, is not able to infer the result of this method.
as the override, as the type system is not able to infer the result of this method.
"""
eff_schema = override or self.schema

q = self.session.query(self.model)

fltr = self._filter_builder()
q = q.filter_by(**fltr)
if pagination.query_filter:
try:
qf = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model)

except ValueError as e:
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e))

count = q.count()

# interpret -1 as "get_all"
Expand Down
11 changes: 11 additions & 0 deletions mealie/repos/repository_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional
from uuid import UUID

from fastapi import HTTPException
from pydantic import UUID4
from slugify import slugify
from sqlalchemy import and_, func
Expand All @@ -20,6 +21,7 @@
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter

from .repository_generic import RepositoryGeneric

Expand Down Expand Up @@ -147,6 +149,15 @@ def page_all(self, pagination: PaginationQuery, override=None, load_food=False)

fltr = self._filter_builder()
q = q.filter_by(**fltr)
if pagination.query_filter:
try:
qf = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model)

except ValueError as e:
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e))

count = q.count()

# interpret -1 as "get_all"
Expand Down
5 changes: 5 additions & 0 deletions mealie/schema/recipe/recipe_ingredient.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import enum
from typing import Optional, Union
from uuid import UUID, uuid4
Expand Down Expand Up @@ -27,6 +28,8 @@ class SaveIngredientFood(CreateIngredientFood):
class IngredientFood(CreateIngredientFood):
id: UUID4
label: Optional[MultiPurposeLabelSummary] = None
created_at: Optional[datetime.datetime]
update_at: Optional[datetime.datetime]

class Config:
orm_mode = True
Expand All @@ -48,6 +51,8 @@ class SaveIngredientUnit(CreateIngredientUnit):

class IngredientUnit(CreateIngredientUnit):
id: UUID4
created_at: Optional[datetime.datetime]
update_at: Optional[datetime.datetime]

class Config:
orm_mode = True
Expand Down
1 change: 1 addition & 0 deletions mealie/schema/response/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class PaginationQuery(MealieModel):
per_page: int = 50
order_by: str = "created_at"
order_direction: OrderDirection = OrderDirection.desc
query_filter: str = None


class PaginationBase(GenericModel, Generic[DataT]):
Expand Down
231 changes: 231 additions & 0 deletions mealie/schema/response/query_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
from __future__ import annotations

import re
from enum import Enum
from typing import Any, TypeVar, Union, cast

from dateutil import parser as date_parser
from dateutil.parser._parser import ParserError
from humps import decamelize
from sqlalchemy import bindparam, text
from sqlalchemy.orm.query import Query
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.expression import BindParameter

Model = TypeVar("Model")


class RelationalOperator(Enum):
EQ = "="
NOTEQ = "<>"
GT = ">"
LT = "<"
GTE = ">="
LTE = "<="


class LogicalOperator(Enum):
AND = "AND"
OR = "OR"


class QueryFilterComponent:
"""A single relational statement"""

def __init__(self, attribute_name: str, relational_operator: RelationalOperator, value: str) -> None:
self.attribute_name = decamelize(attribute_name)
self.relational_operator = relational_operator
self.value = value

# remove encasing quotes
if len(value) > 2 and value[0] == '"' and value[-1] == '"':
self.value = value[1:-1]

def __repr__(self) -> str:
return f"[{self.attribute_name} {self.relational_operator.value} {self.value}]"


class QueryFilter:
lsep: str = "("
rsep: str = ")"

def __init__(self, filter_string: str) -> None:
# parse filter string
components = QueryFilter._break_filter_string_into_components(filter_string)
base_components = QueryFilter._break_components_into_base_components(components)
if base_components.count(QueryFilter.lsep) != base_components.count(QueryFilter.rsep):
raise ValueError("invalid filter string: parenthesis are unbalanced")

# parse base components into a filter group
self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components)

def __repr__(self) -> str:
return f'<<{" ".join([str(component.value if isinstance(component, LogicalOperator) else component) for component in self.filter_components])}>>'

def filter_query(self, query: Query, model: type[Model]) -> Query:
segments: list[str] = []
params: list[BindParameter] = []
for i, component in enumerate(self.filter_components):
if component in [QueryFilter.lsep, QueryFilter.rsep]:
segments.append(component) # type: ignore
continue

if isinstance(component, LogicalOperator):
segments.append(component.value)
continue

# for some reason typing doesn't like the lsep and rsep literals, so we explicitly mark this as a filter component instead
# cast doesn't actually do anything at runtime
component = cast(QueryFilterComponent, component)

if not hasattr(model, component.attribute_name):
raise ValueError(f"invalid query string: '{component.attribute_name}' does not exist on this schema")

# convert values to their proper types
attr = getattr(model, component.attribute_name)
value: Any = component.value

if isinstance(attr.type, sqltypes.Date) or isinstance(attr.type, sqltypes.DateTime):
try:
value = date_parser.parse(component.value)

except ParserError:
raise ValueError(f"invalid query string: unknown date or datetime format '{component.value}'")

if isinstance(attr.type, sqltypes.Boolean):
try:
value = component.value.lower()[0] in ["t", "y"] or component.value == "1"

except IndexError:
raise ValueError("invalid query string")

paramkey = f"P{i+1}"
segments.append(" ".join([component.attribute_name, component.relational_operator.value, f":{paramkey}"]))
params.append(bindparam(paramkey, value, attr.type))

qs = text(" ".join(segments)).bindparams(*params)
query = query.filter(qs)
return query

@staticmethod
def _break_filter_string_into_components(filter_string: str) -> list[str]:
"""Recursively break filter string into components based on parenthesis groupings"""
components = [filter_string]
in_quotes = False
while True:
subcomponents = list()
for component in components:
# don't parse components comprised of only a separator
if component in [QueryFilter.lsep, QueryFilter.rsep]:
subcomponents.append(component)
continue

# construct a component until it hits the right separator
new_component = ""
for c in component:
# ignore characters in-between quotes
if c == '"':
in_quotes = not in_quotes

if c in [QueryFilter.lsep, QueryFilter.rsep] and not in_quotes:
if new_component:
subcomponents.append(new_component)

subcomponents.append(c)
new_component = ""
continue

new_component += c

if new_component:
subcomponents.append(new_component.strip())

if components == subcomponents:
break

components = subcomponents

return components

@staticmethod
def _break_components_into_base_components(components: list[str]) -> list[str]:
"""Further break down components by splitting at relational and logical operators"""
logical_operators = re.compile(
f'({"|".join(operator.value for operator in LogicalOperator)})', flags=re.IGNORECASE
)

base_components = []
for component in components:
offset = 0
subcomponents = component.split('"')
for i, subcomponent in enumerate(subcomponents):
# don't parse components comprised of only a separator
if subcomponent in [QueryFilter.lsep, QueryFilter.rsep]:
offset += 1
base_components.append(subcomponent)
continue

# this subscomponent was surrounded in quotes, so we keep it as-is
if (i + offset) % 2:
base_components.append(f'"{subcomponent.strip()}"')
continue

# if the final subcomponent has quotes, it creates an extra empty subcomponent at the end
if not subcomponent:
continue

# parse out logical operators
new_components = [
base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component
]

# parse out relational operators; each base_subcomponent has exactly zero or one relational operator
# we do them one at a time in descending length since some operators overlap (e.g. :> and >)
for component in new_components:
if not component:
continue

added_to_base_components = False
for rel_op in sorted([operator.value for operator in RelationalOperator], key=len, reverse=True):
if rel_op in component:
new_base_components = [
base_component.strip() for base_component in component.split(rel_op) if base_component
]
new_base_components.insert(1, rel_op)
base_components.extend(new_base_components)

added_to_base_components = True
break

if not added_to_base_components:
base_components.append(component)

return base_components

@staticmethod
def _parse_base_components_into_filter_components(
base_components: list[str],
) -> list[Union[str, QueryFilterComponent, LogicalOperator]]:
"""Walk through base components and construct filter collections"""
relational_operators = [op.value for op in RelationalOperator]
logical_operators = [op.value for op in LogicalOperator]

# parse QueryFilterComponents and logical operators
components: list[Union[str, QueryFilterComponent, LogicalOperator]] = []
for i, base_component in enumerate(base_components):
if base_component in [QueryFilter.lsep, QueryFilter.rsep]:
components.append(base_component)

elif base_component in relational_operators:
components.append(
QueryFilterComponent(
attribute_name=base_components[i - 1],
relational_operator=RelationalOperator(base_components[i]),
value=base_components[i + 1],
)
)

elif base_component.upper() in logical_operators:
components.append(LogicalOperator(base_component.upper()))

return components
Loading