Skip to content

Commit

Permalink
Update cursor API
Browse files Browse the repository at this point in the history
  • Loading branch information
billy1624 committed Jun 20, 2022
1 parent 88c4268 commit 1e2c1c4
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ futures-util = { version = "^0.3" }
tracing = { version = "0.1", features = ["log"] }
rust_decimal = { version = "^1", optional = true }
sea-orm-macros = { version = "^0.8.0", path = "sea-orm-macros", optional = true }
sea-query = { version = "^0.24.0", git = "https://github.com/SeaQL/sea-query", branch = "orders-mut-for-each", features = ["thread-safe"] }
sea-query = { version = "^0.24.0", git = "https://github.com/SeaQL/sea-query", branch = "clear-order-by", features = ["thread-safe"] }
sea-strum = { version = "^0.23", features = ["derive", "sea-orm"] }
serde = { version = "^1.0", features = ["derive"] }
serde_json = { version = "^1", optional = true }
Expand Down
132 changes: 90 additions & 42 deletions src/executor/cursor.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
use crate::{
ColumnTrait, ConnectionTrait, DbErr, EntityTrait, FromQueryResult, QueryOrder, Select,
SelectModel, SelectorTrait,
ConnectionTrait, DbErr, EntityTrait, FromQueryResult, Identity, IntoIdentity, QueryOrder,
Select, SelectModel, SelectorTrait,
};
use sea_query::{
Condition, DynIden, Expr, IntoValueTuple, Order, SeaRc, SelectStatement, SimpleExpr, Value,
ValueTuple, OrderedStatement,
};
use sea_query::{OrderedStatement, SelectStatement, Value};
use std::marker::PhantomData;

/// Cursor pagination
///
/// To ensure proper ordering of the paginated result, the select statement must have order by expression.
#[derive(Debug, Clone)]
pub struct Cursor<S>
where
S: SelectorTrait,
{
pub(crate) query: SelectStatement,
pub(crate) table: DynIden,
pub(crate) order_columns: Identity,
pub(crate) last: bool,
pub(crate) phantom: PhantomData<S>,
}
Expand All @@ -23,67 +26,110 @@ where
S: SelectorTrait,
{
/// Initialize a cursor
pub fn new(query: SelectStatement) -> Self {
pub fn new<C>(query: SelectStatement, table: DynIden, order_columns: C) -> Self
where
C: IntoIdentity,
{
Self {
query,
table,
order_columns: order_columns.into_identity(),
last: false,
phantom: PhantomData,
}
}

/// Filter paginated rows with column value less than the input value
pub fn before<C, V>(&mut self, col: C, val: V) -> &mut Self
pub fn before<V>(&mut self, values: V) -> &mut Self
where
C: ColumnTrait,
V: Into<Value>,
V: IntoValueTuple,
{
self.query.and_where(col.lt(val));
let condition = self.apply_filter(values, |c, v| {
Expr::tbl(SeaRc::clone(&self.table), SeaRc::clone(c)).lt(v)
});
self.query.cond_where(condition);
self
}

/// Filter paginated rows with column value greater than the input value
pub fn after<C, V>(&mut self, col: C, val: V) -> &mut Self
pub fn after<V>(&mut self, values: V) -> &mut Self
where
C: ColumnTrait,
V: Into<Value>,
V: IntoValueTuple,
{
self.query.and_where(col.gt(val));
let condition = self.apply_filter(values, |c, v| {
Expr::tbl(SeaRc::clone(&self.table), SeaRc::clone(c)).gt(v)
});
self.query.cond_where(condition);
self
}

fn reverse_ordering(&mut self) {
self.query.orders_mut_for_each(|order_expr| {
order_expr.reverse_ordering();
});
fn apply_filter<V, F>(&self, values: V, f: F) -> Condition
where
V: IntoValueTuple,
F: Fn(&DynIden, Value) -> SimpleExpr,
{
match (&self.order_columns, values.into_value_tuple()) {
(Identity::Unary(c1), ValueTuple::One(v1)) => Condition::all().add(f(c1, v1)),
(Identity::Binary(c1, c2), ValueTuple::Two(v1, v2)) => {
Condition::all().add(f(c1, v1)).add(f(c2, v2))
}
(Identity::Ternary(c1, c2, c3), ValueTuple::Three(v1, v2, v3)) => Condition::all()
.add(f(c1, v1))
.add(f(c2, v2))
.add(f(c3, v3)),
_ => panic!("column arity mismatch"),
}
}

/// Limit result set to only first N rows in ascending order of the paginated query
pub fn first(&mut self, num_rows: u64) -> &mut Self {
self.query.limit(num_rows);
if self.last {
self.reverse_ordering();
}
self.query.limit(num_rows).clear_order_by();
let table = SeaRc::clone(&self.table);
self.apply_order_by(|query, col| {
query.order_by((SeaRc::clone(&table), SeaRc::clone(col)), Order::Asc);
});
self.last = false;
self
}

/// Limit result set to only last N rows in ascending order of the paginated query
pub fn last(&mut self, num_rows: u64) -> &mut Self {
self.query.limit(num_rows);
if !self.last {
self.reverse_ordering();
}
self.query.limit(num_rows).clear_order_by();
let table = SeaRc::clone(&self.table);
self.apply_order_by(|query, col| {
query.order_by((SeaRc::clone(&table), SeaRc::clone(col)), Order::Desc);
});
self.last = true;
self
}

fn apply_order_by<F>(&mut self, f: F)
where
F: Fn(&mut SelectStatement, &DynIden),
{
let query = &mut self.query;
match &self.order_columns {
Identity::Unary(c1) => {
f(query, c1);
}
Identity::Binary(c1, c2) => {
f(query, c1);
f(query, c2);
}
Identity::Ternary(c1, c2, c3) => {
f(query, c1);
f(query, c2);
f(query, c3);
}
}
}

/// Fetch the rows
pub async fn all<C>(&self, db: &C) -> Result<Vec<S::Item>, DbErr>
pub async fn all<C>(&mut self, db: &C) -> Result<Vec<S::Item>, DbErr>
where
C: ConnectionTrait,
{
let builder = db.get_database_backend();
let stmt = builder.build(&self.query);
let stmt = db.get_database_backend().build(&self.query);
let rows = db.query_all(stmt).await?;
let mut buffer = Vec::with_capacity(rows.len());
for row in rows.into_iter() {
Expand Down Expand Up @@ -113,7 +159,9 @@ pub trait CursorTrait {
type Selector: SelectorTrait + Send + Sync;

/// Convert current type into a cursor
fn cursor(self) -> Cursor<Self::Selector>;
fn cursor<C>(self, order_columns: C) -> Cursor<Self::Selector>
where
C: IntoIdentity;
}

impl<E, M> CursorTrait for Select<E>
Expand All @@ -123,8 +171,11 @@ where
{
type Selector = SelectModel<M>;

fn cursor(self) -> Cursor<Self::Selector> {
Cursor::new(self.query)
fn cursor<C>(self, order_columns: C) -> Cursor<Self::Selector>
where
C: IntoIdentity,
{
Cursor::new(self.query, SeaRc::new(E::default()), order_columns)
}
}

Expand Down Expand Up @@ -160,9 +211,8 @@ mod tests {

assert_eq!(
Entity::find()
.cursor()
.order_by_asc(Column::Id)
.before(Column::Id, 10)
.cursor(Column::Id)
.before(10)
.first(2)
.all(&db)
.await?,
Expand Down Expand Up @@ -210,9 +260,8 @@ mod tests {

assert_eq!(
Entity::find()
.order_by_asc(Column::Id)
.cursor()
.after(Column::Id, 10)
.cursor(Column::Id)
.after(10)
.last(2)
.all(&db)
.await?,
Expand Down Expand Up @@ -271,10 +320,9 @@ mod tests {

assert_eq!(
Entity::find()
.order_by_asc(Column::Id)
.cursor()
.after(Column::Id, 25)
.before(Column::Id, 30)
.cursor(Column::Id)
.after(25)
.before(30)
.last(2)
.all(&db)
.await?,
Expand Down
14 changes: 7 additions & 7 deletions tests/cursor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod common;

pub use common::{features::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
use sea_orm::{entity::prelude::*, QueryOrder};
use sea_orm::entity::prelude::*;

#[sea_orm_macros::test]
#[cfg(any(
Expand Down Expand Up @@ -55,9 +55,9 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> {

// Before 5, i.e. id < 5

let mut cursor = Entity::find().order_by_asc(Column::Id).cursor();
let mut cursor = Entity::find().cursor(Column::Id);

cursor.before(Column::Id, 5);
cursor.before(5);

assert_eq!(
cursor.first(4).all(db).await?,
Expand Down Expand Up @@ -101,9 +101,9 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> {

// After 5, i.e. id > 5

let mut cursor = Entity::find().order_by_asc(Column::Id).cursor();
let mut cursor = Entity::find().cursor(Column::Id);

cursor.after(Column::Id, 5);
cursor.after(5);

assert_eq!(
cursor.first(4).all(db).await?,
Expand Down Expand Up @@ -171,9 +171,9 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> {

// Between 5 and 8, i.e. id > 5 AND id < 8

let mut cursor = Entity::find().order_by_asc(Column::Id).cursor();
let mut cursor = Entity::find().cursor(Column::Id);

cursor.after(Column::Id, 5).before(Column::Id, 8);
cursor.after(5).before(8);

assert_eq!(cursor.first(1).all(db).await?, vec![Model { id: 6 }]);

Expand Down

0 comments on commit 1e2c1c4

Please sign in to comment.