Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: port to SQLAlchemy 1.4 / 2.0 #28

Merged
merged 10 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ workflows:
context: org-global-v2
- test:
context: org-global-v2
sqlalchemy_version: "1.3"
matrix:
parameters:
sqlalchemy_version: ["1.3", "1.4"]
Expand Down
102 changes: 81 additions & 21 deletions fastapi_sqla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import math
import os
from contextlib import contextmanager
from typing import Callable, Generic, List, TypeVar
from functools import singledispatch
from typing import Callable, Generic, List, TypeVar, Union

import structlog
from fastapi import Depends, FastAPI, Query, Request
Expand All @@ -11,9 +12,10 @@
from pydantic.generics import GenericModel
from sqlalchemy import engine_from_config
from sqlalchemy.ext.declarative import DeferredReflection, declarative_base
from sqlalchemy.orm import Query as DbQuery
from sqlalchemy.orm import Query as LegacyQuery
from sqlalchemy.orm.session import Session as SqlaSession
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql import Select, func, select

__all__ = ["Base", "Page", "Paginate", "Session", "open_session", "setup"]

Expand Down Expand Up @@ -147,24 +149,92 @@ class Page(Collection, Generic[T]):
meta: Meta


def _query_count(session: Session, query: DbQuery) -> int:
DbQuery = Union[LegacyQuery, Select]
QueryCount = Callable[[SqlaSession, DbQuery], int]
QueryCountDependency = Callable[..., QueryCount]
PaginateSignature = Callable[[DbQuery], Page[T]]


def query_count(session: Session, query: DbQuery) -> int:
"""Default function used to count items returned by a query.

Default Query.count is slower than a manually written query could be: It runs the
query in a subquery, and count the number of elements returned:
It is slower than a manually written query could be: It runs the query in a subquery,
and count the number of elements returned.

See https://gist.github.com/hest/8798884
"""
return query.count()


PaginateSignature = Callable[[DbQuery], Page[T]]
if isinstance(query, LegacyQuery):
result = query.count()

elif isinstance(query, Select):
result = session.execute(select(func.count()).select_from(query)).scalar()

else: # pragma no cover
raise NotImplementedError(f"Query type {type(query)!r} is not supported")

return result


@singledispatch
def paginate_query(
query: DbQuery,
session: Session,
total_items: int,
offset: int,
limit: int,
) -> Page[T]: # pragma no cover
"Dispatch on registered functions based on `query` type"
raise NotImplementedError(f"no paginate_query registered for type {type(query)!r}")


@paginate_query.register
def _paginate_legacy(
query: LegacyQuery,
session: Session,
total_items: int,
offset: int,
limit: int,
) -> Page[T]:
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
return Page[T](
data=query.offset(offset).limit(limit).all(),
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)


@paginate_query.register
def _paginate(
query: Select,
session: Session,
total_items: int,
offset: int,
limit: int,
) -> Page[T]:
hadrien marked this conversation as resolved.
Show resolved Hide resolved
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
query = query.offset(offset).limit(limit)
result = session.execute(query)
return Page[T](
data=iter(result.unique().scalars()),
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)


def Pagination(
min_page_size: int = 10,
max_page_size: int = 100,
query_count: Callable[[Session, DbQuery], int] = _query_count,
query_count: QueryCount = query_count,
) -> Callable[[Session, int, int], PaginateSignature]:
def dependency(
session: Session = Depends(),
Expand All @@ -173,17 +243,7 @@ def dependency(
) -> PaginateSignature:
def paginate(query: DbQuery) -> Page[T]:
total_items = query_count(session, query)
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
return Page[T](
data=query.offset(offset).limit(limit).all(),
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)
return paginate_query(query, session, total_items, offset, limit)

return paginate

Expand Down
116 changes: 97 additions & 19 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from faker import Faker
from fastapi import Depends, FastAPI
from pydantic import BaseModel
from pytest import fixture, mark
from sqlalchemy import MetaData, Table, func
from pytest import fixture, mark, param
from sqlalchemy import MetaData, Table, func, select
from sqlalchemy.orm import joinedload, relationship
from sqlalchemy.sql import Select


@fixture(scope="module", autouse=True)
Expand Down Expand Up @@ -91,16 +92,58 @@ def test_pagination(session, user_cls, offset, limit, total_pages, page_number):
"offset,limit,total_pages,page_number",
[(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)],
)
def test_Pagination_with_custom_count(
def test_pagination_with_legacy_query_count(
session, user_cls, offset, limit, total_pages, page_number
):
from fastapi_sqla import Pagination
from fastapi_sqla import Paginate

query = session.query(user_cls).options(joinedload("notes"))
result = Paginate(session, offset, limit)(query)

assert result.meta.total_items == 42
assert result.meta.offset == offset
assert result.meta.total_pages == total_pages
assert result.meta.page_number == page_number


@mark.sqlalchemy("1.3")
@mark.parametrize(
"offset,limit,total_pages,page_number",
[(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)],
)
def test_Pagination_with_custom_sqla13_compliant_count(
session, user_cls, offset, limit, total_pages, page_number
):
from fastapi_sqla import DbQuery, Pagination, Session

def query_count(session: Session, query: DbQuery) -> int:
return (
session.query(user_cls).statement.with_only_columns([func.count()]).scalar()
)

pagination = Pagination(query_count=query_count)
query = session.query(user_cls).options(joinedload("notes"))
result = pagination(session, offset, limit)(query)

assert result.meta.total_items == 42
assert result.meta.offset == offset
assert result.meta.total_pages == total_pages
assert result.meta.page_number == page_number


@mark.sqlalchemy("1.4")
@mark.parametrize(
"offset,limit,total_pages,page_number",
[(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)],
)
def test_Pagination_with_custom_sqla14_compliant_count(
session, user_cls, offset, limit, total_pages, page_number
):
from fastapi_sqla import DbQuery, Pagination, Session

def query_count(session: Session, query: DbQuery) -> int:
return session.execute(select(func.count(user_cls.id))).scalar()

query_count = (
lambda sess, _: session.query(user_cls)
.statement.with_only_columns([func.count()])
.scalar()
)
pagination = Pagination(query_count=query_count)
query = session.query(user_cls).options(joinedload("notes"))
result = pagination(session, offset, limit)(query)
Expand All @@ -113,9 +156,14 @@ def test_Pagination_with_custom_count(

@fixture
def app(user_cls, note_cls):
from sqlalchemy.orm import joinedload

from fastapi_sqla import Page, Paginate, Session, setup
from fastapi_sqla import (
Page,
Paginate,
PaginateSignature,
Pagination,
Session,
setup,
)

app = FastAPI()
setup(app)
Expand All @@ -135,9 +183,26 @@ class User(BaseModel):
class Config:
orm_mode = True

@app.get("/users", response_model=Page[User])
def all_users(session: Session = Depends(), paginate: Paginate = Depends()):
query = session.query(user_cls).options(joinedload("notes"))
@app.get("/v1/users", response_model=Page[User])
def sqla_13_all_users(session: Session = Depends(), paginate: Paginate = Depends()):
query = (
session.query(user_cls).options(joinedload("notes")).order_by(user_cls.id)
)
return paginate(query)

@app.get("/v2/users", response_model=Page[User])
def sqla_14_all_users(paginate: Paginate = Depends()):
query = select(user_cls).options(joinedload("notes")).order_by(user_cls.id)
return paginate(query)

def query_count(session: Session, query: Select) -> int:
return session.execute(select(func.count()).select_from(user_cls)).scalar()

CustomPaginate: PaginateSignature = Pagination(query_count=query_count)

@app.get("/v2/custom/users", response_model=Page[User])
def sqla_14_all_users_custom_pagination(paginate: CustomPaginate = Depends()):
query = select(user_cls).options(joinedload("notes")).order_by(user_cls.id)
return paginate(query)

return app
Expand All @@ -154,14 +219,27 @@ async def client(app):

@mark.asyncio
@mark.parametrize(
"offset,items_number",
[(0, 10), (10, 10), (40, 2)],
"offset,items_number,path",
[
param(0, 10, "/v1/users"),
param(10, 10, "/v1/users"),
param(40, 2, "/v1/users"),
param(0, 10, "/v2/users", marks=mark.sqlalchemy("1.4")),
param(10, 10, "/v2/users", marks=mark.sqlalchemy("1.4")),
param(40, 2, "/v2/users", marks=mark.sqlalchemy("1.4")),
param(0, 10, "/v2/custom/users", marks=mark.sqlalchemy("1.4")),
param(10, 10, "/v2/custom/users", marks=mark.sqlalchemy("1.4")),
param(40, 2, "/v2/custom/users", marks=mark.sqlalchemy("1.4")),
],
)
async def test_functional(client, offset, items_number):
result = await client.get("/users", params={"offset": offset})
async def test_functional(client, offset, items_number, path):
result = await client.get(path, params={"offset": offset})

assert result.status_code == 200, result.json()
users = result.json()["data"]
assert len(users) == items_number
user_ids = [u["id"] for u in users]
assert user_ids == list(range(offset + 1, offset + 1 + items_number))

meta = result.json()["meta"]
assert meta["total_items"] == 42