Skip to content

Commit

Permalink
feat: port to SQLAlchemy 1.4 / 2.0 (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
hadrien authored Apr 8, 2021
1 parent 78308f4 commit cbf06d1
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 59 deletions.
1 change: 0 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ workflows:
context: org-global-v2
- test:
context: org-global-v2
sqlalchemy_version: "1.3"
matrix:
parameters:
sqlalchemy_version: ["1.3", "1.4"]
Expand Down
34 changes: 16 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,25 @@ def run_bg():

```python
from fastapi import APIRouter, Depends
from fastapi_sqla import Base, Page, Paginate, Session
from fastapi_sqla import Base, Page, Paginate
from pydantic import BaseModel
from sqlalchemy import select

router = APIRouter()


class UserEntity(Base):
class User(Base):
__tablename__ = "user"


class User(BaseModel):
class UserModel(BaseModel):
id: int
name: str


@router.get("/users", response_model=Page[User])
def all_users(session: Session = Depends(), paginate: Paginate = Depends()):
query = session.query(UserEntity)
@router.get("/users", response_model=Page[UserModel])
def all_users(paginate: Paginate = Depends()):
query = select(User)
return paginate(query)
```

Expand All @@ -134,23 +135,23 @@ To customize pagination, create a dependency using `fastapi_sqla.Pagination`
from fastapi import APIRouter, Depends
from fastapi_sqla import Base, Page, Pagination, Session
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Query
from sqlalchemy import func, select
from sqlalchemy.sql import Select

router = APIRouter()


class UserEntity(Base):
class User(Base):
__tablename__ = "user"


class User(BaseModel):
class UserModel(BaseModel):
id: int
name: str


def query_count(session: Session, query: Query):
return query.statement.with_only_columns([func.count()]).scalar()
def query_count(session: Session, query: Select) -> int:
return session.execute(select(func.count()).select_from(User)).scalar()


Paginate = Pagination(
Expand All @@ -160,12 +161,9 @@ Paginate = Pagination(
)


@router.get("/users", response_model=Page[User])
def all_users(
session: Session = Depends(),
paginate: Paginate = Depends(),
):
query = session.query(UserEntity)
@router.get("/users", response_model=Page[UserModel])
def all_users(paginate: Paginate = Depends()):
query = select(User)
return paginate(query)
```

Expand Down
102 changes: 81 additions & 21 deletions fastapi_sqla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import math
import os
from contextlib import contextmanager
from typing import Callable, Generic, List, TypeVar
from functools import singledispatch
from typing import Callable, Generic, List, TypeVar, Union

import structlog
from fastapi import Depends, FastAPI, Query, Request
Expand All @@ -11,9 +12,10 @@
from pydantic.generics import GenericModel
from sqlalchemy import engine_from_config
from sqlalchemy.ext.declarative import DeferredReflection, declarative_base
from sqlalchemy.orm import Query as DbQuery
from sqlalchemy.orm import Query as LegacyQuery
from sqlalchemy.orm.session import Session as SqlaSession
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql import Select, func, select

__all__ = ["Base", "Page", "Paginate", "Session", "open_session", "setup"]

Expand Down Expand Up @@ -147,24 +149,92 @@ class Page(Collection, Generic[T]):
meta: Meta


def _query_count(session: Session, query: DbQuery) -> int:
DbQuery = Union[LegacyQuery, Select]
QueryCount = Callable[[SqlaSession, DbQuery], int]
QueryCountDependency = Callable[..., QueryCount]
PaginateSignature = Callable[[DbQuery], Page[T]]


def query_count(session: Session, query: DbQuery) -> int:
"""Default function used to count items returned by a query.
Default Query.count is slower than a manually written query could be: It runs the
query in a subquery, and count the number of elements returned:
It is slower than a manually written query could be: It runs the query in a subquery,
and count the number of elements returned.
See https://gist.github.com/hest/8798884
"""
return query.count()


PaginateSignature = Callable[[DbQuery], Page[T]]
if isinstance(query, LegacyQuery):
result = query.count()

elif isinstance(query, Select):
result = session.execute(select(func.count()).select_from(query)).scalar()

else: # pragma no cover
raise NotImplementedError(f"Query type {type(query)!r} is not supported")

return result


@singledispatch
def paginate_query(
query: DbQuery,
session: Session,
total_items: int,
offset: int,
limit: int,
) -> Page[T]: # pragma no cover
"Dispatch on registered functions based on `query` type"
raise NotImplementedError(f"no paginate_query registered for type {type(query)!r}")


@paginate_query.register
def _paginate_legacy(
query: LegacyQuery,
session: Session,
total_items: int,
offset: int,
limit: int,
) -> Page[T]:
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
return Page[T](
data=query.offset(offset).limit(limit).all(),
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)


@paginate_query.register
def _paginate(
query: Select,
session: Session,
total_items: int,
offset: int,
limit: int,
) -> Page[T]:
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
query = query.offset(offset).limit(limit)
result = session.execute(query)
return Page[T](
data=iter(result.unique().scalars()),
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)


def Pagination(
min_page_size: int = 10,
max_page_size: int = 100,
query_count: Callable[[Session, DbQuery], int] = _query_count,
query_count: QueryCount = query_count,
) -> Callable[[Session, int, int], PaginateSignature]:
def dependency(
session: Session = Depends(),
Expand All @@ -173,17 +243,7 @@ def dependency(
) -> PaginateSignature:
def paginate(query: DbQuery) -> Page[T]:
total_items = query_count(session, query)
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
return Page[T](
data=query.offset(offset).limit(limit).all(),
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)
return paginate_query(query, session, total_items, offset, limit)

return paginate

Expand Down
116 changes: 97 additions & 19 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from faker import Faker
from fastapi import Depends, FastAPI
from pydantic import BaseModel
from pytest import fixture, mark
from sqlalchemy import MetaData, Table, func
from pytest import fixture, mark, param
from sqlalchemy import MetaData, Table, func, select
from sqlalchemy.orm import joinedload, relationship
from sqlalchemy.sql import Select


@fixture(scope="module", autouse=True)
Expand Down Expand Up @@ -91,16 +92,58 @@ def test_pagination(session, user_cls, offset, limit, total_pages, page_number):
"offset,limit,total_pages,page_number",
[(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)],
)
def test_Pagination_with_custom_count(
def test_pagination_with_legacy_query_count(
session, user_cls, offset, limit, total_pages, page_number
):
from fastapi_sqla import Pagination
from fastapi_sqla import Paginate

query = session.query(user_cls).options(joinedload("notes"))
result = Paginate(session, offset, limit)(query)

assert result.meta.total_items == 42
assert result.meta.offset == offset
assert result.meta.total_pages == total_pages
assert result.meta.page_number == page_number


@mark.sqlalchemy("1.3")
@mark.parametrize(
"offset,limit,total_pages,page_number",
[(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)],
)
def test_Pagination_with_custom_sqla13_compliant_count(
session, user_cls, offset, limit, total_pages, page_number
):
from fastapi_sqla import DbQuery, Pagination, Session

def query_count(session: Session, query: DbQuery) -> int:
return (
session.query(user_cls).statement.with_only_columns([func.count()]).scalar()
)

pagination = Pagination(query_count=query_count)
query = session.query(user_cls).options(joinedload("notes"))
result = pagination(session, offset, limit)(query)

assert result.meta.total_items == 42
assert result.meta.offset == offset
assert result.meta.total_pages == total_pages
assert result.meta.page_number == page_number


@mark.sqlalchemy("1.4")
@mark.parametrize(
"offset,limit,total_pages,page_number",
[(0, 5, 9, 1), (10, 10, 5, 2), (40, 10, 5, 5)],
)
def test_Pagination_with_custom_sqla14_compliant_count(
session, user_cls, offset, limit, total_pages, page_number
):
from fastapi_sqla import DbQuery, Pagination, Session

def query_count(session: Session, query: DbQuery) -> int:
return session.execute(select(func.count(user_cls.id))).scalar()

query_count = (
lambda sess, _: session.query(user_cls)
.statement.with_only_columns([func.count()])
.scalar()
)
pagination = Pagination(query_count=query_count)
query = session.query(user_cls).options(joinedload("notes"))
result = pagination(session, offset, limit)(query)
Expand All @@ -113,9 +156,14 @@ def test_Pagination_with_custom_count(

@fixture
def app(user_cls, note_cls):
from sqlalchemy.orm import joinedload

from fastapi_sqla import Page, Paginate, Session, setup
from fastapi_sqla import (
Page,
Paginate,
PaginateSignature,
Pagination,
Session,
setup,
)

app = FastAPI()
setup(app)
Expand All @@ -135,9 +183,26 @@ class User(BaseModel):
class Config:
orm_mode = True

@app.get("/users", response_model=Page[User])
def all_users(session: Session = Depends(), paginate: Paginate = Depends()):
query = session.query(user_cls).options(joinedload("notes"))
@app.get("/v1/users", response_model=Page[User])
def sqla_13_all_users(session: Session = Depends(), paginate: Paginate = Depends()):
query = (
session.query(user_cls).options(joinedload("notes")).order_by(user_cls.id)
)
return paginate(query)

@app.get("/v2/users", response_model=Page[User])
def sqla_14_all_users(paginate: Paginate = Depends()):
query = select(user_cls).options(joinedload("notes")).order_by(user_cls.id)
return paginate(query)

def query_count(session: Session, query: Select) -> int:
return session.execute(select(func.count()).select_from(user_cls)).scalar()

CustomPaginate: PaginateSignature = Pagination(query_count=query_count)

@app.get("/v2/custom/users", response_model=Page[User])
def sqla_14_all_users_custom_pagination(paginate: CustomPaginate = Depends()):
query = select(user_cls).options(joinedload("notes")).order_by(user_cls.id)
return paginate(query)

return app
Expand All @@ -154,14 +219,27 @@ async def client(app):

@mark.asyncio
@mark.parametrize(
"offset,items_number",
[(0, 10), (10, 10), (40, 2)],
"offset,items_number,path",
[
param(0, 10, "/v1/users"),
param(10, 10, "/v1/users"),
param(40, 2, "/v1/users"),
param(0, 10, "/v2/users", marks=mark.sqlalchemy("1.4")),
param(10, 10, "/v2/users", marks=mark.sqlalchemy("1.4")),
param(40, 2, "/v2/users", marks=mark.sqlalchemy("1.4")),
param(0, 10, "/v2/custom/users", marks=mark.sqlalchemy("1.4")),
param(10, 10, "/v2/custom/users", marks=mark.sqlalchemy("1.4")),
param(40, 2, "/v2/custom/users", marks=mark.sqlalchemy("1.4")),
],
)
async def test_functional(client, offset, items_number):
result = await client.get("/users", params={"offset": offset})
async def test_functional(client, offset, items_number, path):
result = await client.get(path, params={"offset": offset})

assert result.status_code == 200, result.json()
users = result.json()["data"]
assert len(users) == items_number
user_ids = [u["id"] for u in users]
assert user_ids == list(range(offset + 1, offset + 1 + items_number))

meta = result.json()["meta"]
assert meta["total_items"] == 42

0 comments on commit cbf06d1

Please sign in to comment.