diff --git a/conda-store-server/conda_store_server/_internal/schema.py b/conda-store-server/conda_store_server/_internal/schema.py index da3d5bcca..b27cc17a2 100644 --- a/conda-store-server/conda_store_server/_internal/schema.py +++ b/conda-store-server/conda_store_server/_internal/schema.py @@ -586,6 +586,14 @@ class APIPaginatedResponse(APIResponse): count: int +class APICursorPaginatedResponse(BaseModel): + data: Optional[Any] + status: APIStatus + message: Optional[str] + cursor: Optional[str] + count: int + + class APIAckResponse(BaseModel): status: APIStatus message: Optional[str] = None @@ -650,7 +658,7 @@ class APIDeleteNamespaceRole(BaseModel): # GET /api/v1/environment -class APIListEnvironment(APIPaginatedResponse): +class APIListEnvironment(APICursorPaginatedResponse): data: List[Environment] diff --git a/conda-store-server/conda_store_server/_internal/server/views/api.py b/conda-store-server/conda_store_server/_internal/server/views/api.py index 43d0a213d..0734fdc82 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/api.py +++ b/conda-store-server/conda_store_server/_internal/server/views/api.py @@ -19,7 +19,12 @@ Permissions, ) from conda_store_server._internal.server import dependencies -from conda_store_server._internal.server.views.pagination import Cursor, paginate +from conda_store_server._internal.server.views.pagination import ( + Cursor, + CursorPaginatedArgs, + OrderingMetadata, + paginate, +) from conda_store_server.server.auth import Authentication @@ -27,6 +32,19 @@ def get_cursor(encoded_cursor: Optional[str] = None) -> Cursor: return Cursor.load(encoded_cursor) +def get_cursor_paginated_args( + order: Optional[str] = None, + limit: Optional[int] = None, + sort_by: List[str] = Query([]), + server=Depends(dependencies.get_server), +) -> CursorPaginatedArgs: + return CursorPaginatedArgs( + limit=server.max_page_size if limit is None else limit, + order=order, + sort_by=sort_by, + ) + + class PaginatedArgs(TypedDict): """Dictionary type holding information about paginated requests.""" @@ -639,7 +657,7 @@ async def api_list_environments( auth: Authentication = Depends(dependencies.get_auth), conda_store: app.CondaStore = Depends(dependencies.get_conda_store), entity: AuthenticationToken = Depends(dependencies.get_entity), - paginated_args: PaginatedArgs = Depends(get_paginated_args), + paginated_args: CursorPaginatedArgs = Depends(get_cursor_paginated_args), cursor: Cursor = Depends(get_cursor), artifact: Optional[schema.BuildArtifactType] = None, jwt: Optional[str] = None, @@ -648,7 +666,7 @@ async def api_list_environments( packages: Optional[List[str]] = Query([]), search: Optional[str] = None, status: Optional[schema.BuildStatus] = None, -): +) -> schema.APIListEnvironment: """Retrieve a list of environments. Parameters @@ -659,7 +677,7 @@ async def api_list_environments( the request entity : AuthenticationToken Token of the user making the request - paginated_args : PaginatedArgs + paginated_args : CursorPaginatedArgs Arguments for controlling pagination of the response conda_store : app.CondaStore The running conda store application @@ -683,7 +701,7 @@ async def api_list_environments( Returns ------- - Dict + schema.APIListEnvironment Paginated JSON response containing the requested environments. Results are sorted by each envrionment's build's scheduled_on time to ensure all results are returned when iterating over pages in systems where the number of environments is changing while results are being @@ -716,51 +734,22 @@ async def api_list_environments( role_bindings=auth.entity_bindings(entity), ) - valid_sort_by = { - "namespace": orm.Namespace.name, - "name": orm.Environment.name, - } - - # sorts = get_sorts( - # order=paginated_args["order"], - # sort_by=paginated_args["sort_by"], - # allowed_sort_bys=valid_sort_by, - # default_sort_by=["namespace", "name"], - # default_order="asc", - # ) - - # query = ( - # query - # .filter( - # or_( - # orm.Namespace.name > cursor.order_by['namespace'], - # orm.Namespace.name == cursor.order_by['namespace'] - # ) - # ) - # .order_by( - # *sorts, - # orm.Environment.id.asc() - # ) - # ) - - return paginate( + paginated, next_cursor = paginate( query=query, + ordering_metadata=OrderingMetadata( + order_names=["namespace", "name"], + column_names=["namespace.name", "name"], + ), cursor=cursor, - sort_by=paginated_args["sort_by"], - valid_sort_by=valid_sort_by, + order_by=paginated_args["sort_by"], + limit=paginated_args["limit"], ) - return paginated_api_response( - query, - paginated_args, - schema.Environment, - exclude={"current_build"}, - allowed_sort_bys={ - "namespace": orm.Namespace.name, - "name": orm.Environment.name, - }, - default_sort_by=["namespace", "name"], - default_order="asc", + return schema.APIListEnvironment( + data=paginated, + status="ok", + cursor=None if next_cursor is None else next_cursor.dump(), + count=1000, ) diff --git a/conda-store-server/conda_store_server/_internal/server/views/pagination.py b/conda-store-server/conda_store_server/_internal/server/views/pagination.py index ae3e914df..5d76e401f 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/pagination.py +++ b/conda-store-server/conda_store_server/_internal/server/views/pagination.py @@ -1,40 +1,59 @@ +from __future__ import annotations + import base64 +import operator +from typing import Any, TypedDict import pydantic -from sqlalchemy import tuple_ +from fastapi import HTTPException +from sqlalchemy import asc, desc, tuple_ from sqlalchemy.orm import Query as SqlQuery +from sqlalchemy.sql.expression import ColumnClause class Cursor(pydantic.BaseModel): - last_id: int | None = 1 + last_id: int | None = 0 + count: int | None = None - # List of names of attributes to order by, and the last value of the ordered attribute + # List query parameters to order by, and the last value of the ordered attribute # { # 'namespace': 'foo', # 'environment': 'bar', # } last_value: dict[str, str] | None = {} - def dump(self): + def dump(self) -> str: return base64.b64encode(self.model_dump_json()) @classmethod - def load(cls, data: str | None = None): + def load(cls, data: str | None = None) -> Cursor | None: if data is None: - return cls() + return None return cls.from_json(base64.b64decode(data)) + def get_last_values(self, order_names: list[str]) -> list[Any]: + if order_names: + return [self.last_value[name] for name in order_names] + else: + return [] + def paginate( query: SqlQuery, - cursor: Cursor, - sort_by: list[str] | None = None, - valid_sort_by: dict[str, object] | None = None, -) -> SqlQuery: + ordering_metadata: OrderingMetadata, + cursor: Cursor | None = None, + order_by: list[str] | None = None, + # valid_order_by: dict[str, str] | None = None, + order: str = "asc", + limit: int = 10, +) -> tuple[SqlQuery, Cursor]: """Paginate the query using the cursor and the requested sort_bys. - With cursor pagination, all keys used to order must be included in - the call to query.filter(). + This function assumes that the first column of the query contains + the type whose ID should be used to sort the results. + + Additionally, with cursor pagination all keys used to order the results + must be included in the call to query.filter(). https://medium.com/@george_16060/cursor-based-pagination-with-arbitrary-ordering-b4af6d5e22db @@ -42,28 +61,156 @@ def paginate( ---------- query : SqlQuery Query containing database results to paginate - cursor : Cursor - Cursor object containing information about the last item - on the previous page - sort_by : list[str] | None + valid_order_by : dict[str, str] | None + Mapping between valid names to order by and the column names on the orm object they apply to + cursor : Cursor | None + Cursor object containing information about the last item on the previous page. + If None, the first page is returned. + order_by : list[str] | None List of sort_by query parameters - valid_sort_by : dict[str, object] | None - Mapping between query parameter names and the orm object they apply to + + Returns + ------- + tuple[SqlQuery, Cursor] + Query containing the paginated results, and Cursor for retrieving + the next page """ - breakpoint() + if order_by is None: + order_by = [] + + if order == "asc": + comparison = operator.gt + order_func = asc + elif order == "desc": + comparison = operator.lt + order_func = desc + else: + raise HTTPException( + status_code=400, + detail=f"Invalid query parameter: order = {order}; must be one of ['asc', 'desc']", + ) + + # Get the python type of the objects being queried + queried_type = query.column_descriptions[0]["type"] + columns = ordering_metadata.get_requested_columns(order_by) + + # If there's a cursor already, use the last attributes to filter + # the results by (*attributes, id) >/< (*last_values, last_id) + # Order by desc or asc + if cursor is not None: + last_values = cursor.get_last_values(order_by) + query = query.filter( + comparison( + tuple_(*columns, queried_type.id), + (*last_values, cursor.last_id), + ) + ) + + query = query.order_by( + *[order_func(col) for col in columns], order_func(queried_type.id) + ) + data = query.limit(limit).all() + count = query.count() + + if count > 0: + last_result = data[-1] + last_value = ordering_metadata.get_attr_values(last_result, order_by) + + next_cursor = Cursor( + last_id=data[-1].id, last_value=last_value, count=query.count() + ) + else: + next_cursor = None + + return (data, next_cursor) + + +class CursorPaginatedArgs(TypedDict): + limit: int + order: str + sort_by: list[str] + + +class OrderingMetadata: + def __init__( + self, + order_names: list[str] | None = None, + column_names: list[str] | None = None, + ): + self.order_names = order_names + self.column_names = column_names + + def validate(self, model: Any): + if len(self.order_names) != len(self.column_names): + raise ValueError( + "Each name of a valid ordering available to the order_by query parameter" + "must have an associated column name to select in the table." + ) + + for col in self.column_names: + if not hasattr(model, col): + raise ValueError(f"No column named {col} found on model {model}.") + + def get_requested_columns( + self, + order_by: list[str] | None = None, + ) -> list[ColumnClause]: + """Get a list of sqlalchemy columns requested by the value of the order_by query param. + + Parameters + ---------- + order_by : list[str] | None + If specified, this should be a subset of self.order_names. If none, an + empty list is returned. + + Returns + ------- + list[ColumnClause] + A list of sqlalchemy columns corresponding to the order_by values passed + as a query parameter + """ + columns = [] + if order_by: + for order_name in order_by: + idx = self.order_names.index(order_name) + columns.append(self.column_names[idx]) + + return columns + + def get_attr_values( + self, + obj: Any, + order_by: list[str] | None = None, + ) -> dict[str, Any]: + """Using the order_by values, get the corresponding attribute values on obj. + + Parameters + ---------- + obj : Any + sqlalchemy model containing attribute names that are contained in + `self.column_names` + order_by : list[str] | None + Values that the user wants to order by; these are used to look up the corresponding + column names that are used to access the attributes of `obj`. + + Returns + ------- + dict[str, Any] + A mapping between the `order_by` values and the attribute values on `obj` + + """ + values = {} + for order_name in order_by: + idx = self.order_names.index(order_name) + values[order_name] = get_nested_attribute(obj, self.column_names[idx]) - if sort_by is None: - sort_by = [] + return values - if valid_sort_by is None: - valid_sort_by = {} - objects = [] - last_values = [] - for obj in sort_by: - objects.append(valid_sort_by[obj]) - last_values.append(cursor.last_value[obj]) +def get_nested_attribute(obj: Any, attr: str) -> Any: + attribute, *rest = attr.split(".") + while len(rest) > 0: + obj = getattr(obj, attribute) + attribute, *rest = rest - return query.filter( - tuple_(*objects, object.id) > (*last_values, cursor.last_id) - ) # .order_by(sorts) + return getattr(obj, attribute) diff --git a/conda-store-server/tests/_internal/server/views/test_api.py b/conda-store-server/tests/_internal/server/views/test_api.py index 5348d414c..cca4eb75e 100644 --- a/conda-store-server/tests/_internal/server/views/test_api.py +++ b/conda-store-server/tests/_internal/server/views/test_api.py @@ -1071,3 +1071,18 @@ def test_default_conda_store_dir(): assert dir == rf"C:\Users\{user}\AppData\Local\conda-store\conda-store" else: assert dir == f"/home/{user}/.local/share/conda-store" + + +def test_api_list_environments( + conda_store_server, + testclient, + seed_conda_store, + authenticate, +): + """Test that the REST API lists the expected paginated environments.""" + response = testclient.get("api/v1/environment/?sort_by=name") + response.raise_for_status() + + r = schema.APIListEnvironment.parse_obj(response.json()) + + assert r.status == schema.APIStatus.OK