diff --git a/.circleci/config.yml b/.circleci/config.yml index fd80627f..2581f330 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -20,7 +20,6 @@ workflows: context: org-global-v2 - test: context: org-global-v2 - sqlalchemy_version: "1.3" matrix: parameters: sqlalchemy_version: ["1.3", "1.4"] diff --git a/README.md b/README.md index 12bccf72..6a573e3b 100644 --- a/README.md +++ b/README.md @@ -96,24 +96,25 @@ def run_bg(): ```python from fastapi import APIRouter, Depends -from fastapi_sqla import Base, Page, Paginate, Session +from fastapi_sqla import Base, Page, Paginate from pydantic import BaseModel +from sqlalchemy import select router = APIRouter() -class UserEntity(Base): +class User(Base): __tablename__ = "user" -class User(BaseModel): +class UserModel(BaseModel): id: int name: str -@router.get("/users", response_model=Page[User]) -def all_users(session: Session = Depends(), paginate: Paginate = Depends()): - query = session.query(UserEntity) +@router.get("/users", response_model=Page[UserModel]) +def all_users(paginate: Paginate = Depends()): + query = select(User) return paginate(query) ``` @@ -134,23 +135,23 @@ To customize pagination, create a dependency using `fastapi_sqla.Pagination` from fastapi import APIRouter, Depends from fastapi_sqla import Base, Page, Pagination, Session from pydantic import BaseModel -from sqlalchemy import func -from sqlalchemy.orm import Query +from sqlalchemy import func, select +from sqlalchemy.sql import Select router = APIRouter() -class UserEntity(Base): +class User(Base): __tablename__ = "user" -class User(BaseModel): +class UserModel(BaseModel): id: int name: str -def query_count(session: Session, query: Query): - return query.statement.with_only_columns([func.count()]).scalar() +def query_count(session: Session, query: Select) -> int: + return session.execute(select(func.count()).select_from(User)).scalar() Paginate = Pagination( @@ -160,12 +161,9 @@ Paginate = Pagination( ) -@router.get("/users", response_model=Page[User]) -def all_users( - session: Session = Depends(), - paginate: Paginate = Depends(), -): - query = session.query(UserEntity) +@router.get("/users", response_model=Page[UserModel]) +def all_users(paginate: Paginate = Depends()): + query = select(User) return paginate(query) ``` diff --git a/fastapi_sqla/__init__.py b/fastapi_sqla/__init__.py index 5f4abbf0..ea4c9374 100644 --- a/fastapi_sqla/__init__.py +++ b/fastapi_sqla/__init__.py @@ -2,7 +2,8 @@ import math import os from contextlib import contextmanager -from typing import Callable, Generic, List, TypeVar +from functools import singledispatch +from typing import Callable, Generic, List, TypeVar, Union import structlog from fastapi import Depends, FastAPI, Query, Request @@ -11,9 +12,10 @@ from pydantic.generics import GenericModel from sqlalchemy import engine_from_config from sqlalchemy.ext.declarative import DeferredReflection, declarative_base -from sqlalchemy.orm import Query as DbQuery +from sqlalchemy.orm import Query as LegacyQuery from sqlalchemy.orm.session import Session as SqlaSession from sqlalchemy.orm.session import sessionmaker +from sqlalchemy.sql import Select, func, select __all__ = ["Base", "Page", "Paginate", "Session", "open_session", "setup"] @@ -147,24 +149,92 @@ class Page(Collection, Generic[T]): meta: Meta -def _query_count(session: Session, query: DbQuery) -> int: +DbQuery = Union[LegacyQuery, Select] +QueryCount = Callable[[SqlaSession, DbQuery], int] +QueryCountDependency = Callable[..., QueryCount] +PaginateSignature = Callable[[DbQuery], Page[T]] + + +def query_count(session: Session, query: DbQuery) -> int: """Default function used to count items returned by a query. - Default Query.count is slower than a manually written query could be: It runs the - query in a subquery, and count the number of elements returned: + It is slower than a manually written query could be: It runs the query in a subquery, + and count the number of elements returned. See https://gist.github.com/hest/8798884 """ - return query.count() - - -PaginateSignature = Callable[[DbQuery], Page[T]] + if isinstance(query, LegacyQuery): + result = query.count() + + elif isinstance(query, Select): + result = session.execute(select(func.count()).select_from(query)).scalar() + + else: # pragma no cover + raise NotImplementedError(f"Query type {type(query)!r} is not supported") + + return result + + +@singledispatch +def paginate_query( + query: DbQuery, + session: Session, + total_items: int, + offset: int, + limit: int, +) -> Page[T]: # pragma no cover + "Dispatch on registered functions based on `query` type" + raise NotImplementedError(f"no paginate_query registered for type {type(query)!r}") + + +@paginate_query.register +def _paginate_legacy( + query: LegacyQuery, + session: Session, + total_items: int, + offset: int, + limit: int, +) -> Page[T]: + total_pages = math.ceil(total_items / limit) + page_number = offset / limit + 1 + return Page[T]( + data=query.offset(offset).limit(limit).all(), + meta={ + "offset": offset, + "total_items": total_items, + "total_pages": total_pages, + "page_number": page_number, + }, + ) + + +@paginate_query.register +def _paginate( + query: Select, + session: Session, + total_items: int, + offset: int, + limit: int, +) -> Page[T]: + total_pages = math.ceil(total_items / limit) + page_number = offset / limit + 1 + query = query.offset(offset).limit(limit) + result = session.execute(query) + return Page[T]( + data=iter(result.unique().scalars()), + meta={ + "offset": offset, + "total_items": total_items, + "total_pages": total_pages, + "page_number": page_number, + }, + ) def Pagination( min_page_size: int = 10, max_page_size: int = 100, - query_count: Callable[[Session, DbQuery], int] = _query_count, + query_count: QueryCount = query_count, ) -> Callable[[Session, int, int], PaginateSignature]: def dependency( session: Session = Depends(), @@ -173,17 +243,7 @@ def dependency( ) -> PaginateSignature: def paginate(query: DbQuery) -> Page[T]: total_items = query_count(session, query) - total_pages = math.ceil(total_items / limit) - page_number = offset / limit + 1 - return Page[T]( - data=query.offset(offset).limit(limit).all(), - meta={ - "offset": offset, - "total_items": total_items, - "total_pages": total_pages, - "page_number": page_number, - }, - ) + return paginate_query(query, session, total_items, offset, limit) return paginate diff --git a/tests/test_pagination.py b/tests/test_pagination.py index fbc965d2..f3f8c11c 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -6,9 +6,10 @@ from faker import Faker from fastapi import Depends, FastAPI from pydantic import BaseModel -from pytest import fixture, mark -from sqlalchemy import MetaData, Table, func +from pytest import fixture, mark, param +from sqlalchemy import MetaData, Table, func, select from sqlalchemy.orm import joinedload, relationship +from sqlalchemy.sql import Select @fixture(scope="module", autouse=True) @@ -91,16 +92,58 @@ def test_pagination(session, user_cls, offset, limit, total_pages, page_number): "offset,limit,total_pages,page_number", [(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)], ) -def test_Pagination_with_custom_count( +def test_pagination_with_legacy_query_count( session, user_cls, offset, limit, total_pages, page_number ): - from fastapi_sqla import Pagination + from fastapi_sqla import Paginate + + query = session.query(user_cls).options(joinedload("notes")) + result = Paginate(session, offset, limit)(query) + + assert result.meta.total_items == 42 + assert result.meta.offset == offset + assert result.meta.total_pages == total_pages + assert result.meta.page_number == page_number + + +@mark.sqlalchemy("1.3") +@mark.parametrize( + "offset,limit,total_pages,page_number", + [(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)], +) +def test_Pagination_with_custom_sqla13_compliant_count( + session, user_cls, offset, limit, total_pages, page_number +): + from fastapi_sqla import DbQuery, Pagination, Session + + def query_count(session: Session, query: DbQuery) -> int: + return ( + session.query(user_cls).statement.with_only_columns([func.count()]).scalar() + ) + + pagination = Pagination(query_count=query_count) + query = session.query(user_cls).options(joinedload("notes")) + result = pagination(session, offset, limit)(query) + + assert result.meta.total_items == 42 + assert result.meta.offset == offset + assert result.meta.total_pages == total_pages + assert result.meta.page_number == page_number + + +@mark.sqlalchemy("1.4") +@mark.parametrize( + "offset,limit,total_pages,page_number", + [(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)], +) +def test_Pagination_with_custom_sqla14_compliant_count( + session, user_cls, offset, limit, total_pages, page_number +): + from fastapi_sqla import DbQuery, Pagination, Session + + def query_count(session: Session, query: DbQuery) -> int: + return session.execute(select(func.count(user_cls.id))).scalar() - query_count = ( - lambda sess, _: session.query(user_cls) - .statement.with_only_columns([func.count()]) - .scalar() - ) pagination = Pagination(query_count=query_count) query = session.query(user_cls).options(joinedload("notes")) result = pagination(session, offset, limit)(query) @@ -113,9 +156,14 @@ def test_Pagination_with_custom_count( @fixture def app(user_cls, note_cls): - from sqlalchemy.orm import joinedload - - from fastapi_sqla import Page, Paginate, Session, setup + from fastapi_sqla import ( + Page, + Paginate, + PaginateSignature, + Pagination, + Session, + setup, + ) app = FastAPI() setup(app) @@ -135,9 +183,26 @@ class User(BaseModel): class Config: orm_mode = True - @app.get("/users", response_model=Page[User]) - def all_users(session: Session = Depends(), paginate: Paginate = Depends()): - query = session.query(user_cls).options(joinedload("notes")) + @app.get("/v1/users", response_model=Page[User]) + def sqla_13_all_users(session: Session = Depends(), paginate: Paginate = Depends()): + query = ( + session.query(user_cls).options(joinedload("notes")).order_by(user_cls.id) + ) + return paginate(query) + + @app.get("/v2/users", response_model=Page[User]) + def sqla_14_all_users(paginate: Paginate = Depends()): + query = select(user_cls).options(joinedload("notes")).order_by(user_cls.id) + return paginate(query) + + def query_count(session: Session, query: Select) -> int: + return session.execute(select(func.count()).select_from(user_cls)).scalar() + + CustomPaginate: PaginateSignature = Pagination(query_count=query_count) + + @app.get("/v2/custom/users", response_model=Page[User]) + def sqla_14_all_users_custom_pagination(paginate: CustomPaginate = Depends()): + query = select(user_cls).options(joinedload("notes")).order_by(user_cls.id) return paginate(query) return app @@ -154,14 +219,27 @@ async def client(app): @mark.asyncio @mark.parametrize( - "offset,items_number", - [(0, 10), (10, 10), (40, 2)], + "offset,items_number,path", + [ + param(0, 10, "/v1/users"), + param(10, 10, "/v1/users"), + param(40, 2, "/v1/users"), + param(0, 10, "/v2/users", marks=mark.sqlalchemy("1.4")), + param(10, 10, "/v2/users", marks=mark.sqlalchemy("1.4")), + param(40, 2, "/v2/users", marks=mark.sqlalchemy("1.4")), + param(0, 10, "/v2/custom/users", marks=mark.sqlalchemy("1.4")), + param(10, 10, "/v2/custom/users", marks=mark.sqlalchemy("1.4")), + param(40, 2, "/v2/custom/users", marks=mark.sqlalchemy("1.4")), + ], ) -async def test_functional(client, offset, items_number): - result = await client.get("/users", params={"offset": offset}) +async def test_functional(client, offset, items_number, path): + result = await client.get(path, params={"offset": offset}) assert result.status_code == 200, result.json() users = result.json()["data"] assert len(users) == items_number user_ids = [u["id"] for u in users] assert user_ids == list(range(offset + 1, offset + 1 + items_number)) + + meta = result.json()["meta"] + assert meta["total_items"] == 42