diff --git a/Cargo.toml b/Cargo.toml index f175a3ecda..9a52d742df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ futures-util = { version = "^0.3" } tracing = { version = "0.1", features = ["log"] } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.8.0", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.24.0", git = "https://github.com/SeaQL/sea-query", branch = "orders-mut-for-each", features = ["thread-safe"] } +sea-query = { version = "^0.24.0", git = "https://github.com/SeaQL/sea-query", branch = "clear-order-by", features = ["thread-safe"] } sea-strum = { version = "^0.23", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } diff --git a/src/executor/cursor.rs b/src/executor/cursor.rs index fb9fff67a7..e369e5cba4 100644 --- a/src/executor/cursor.rs +++ b/src/executor/cursor.rs @@ -1,19 +1,22 @@ use crate::{ - ColumnTrait, ConnectionTrait, DbErr, EntityTrait, FromQueryResult, QueryOrder, Select, - SelectModel, SelectorTrait, + ConnectionTrait, DbErr, EntityTrait, FromQueryResult, Identity, IntoIdentity, QueryOrder, + Select, SelectModel, SelectorTrait, +}; +use sea_query::{ + Condition, DynIden, Expr, IntoValueTuple, Order, SeaRc, SelectStatement, SimpleExpr, Value, + ValueTuple, OrderedStatement, }; -use sea_query::{OrderedStatement, SelectStatement, Value}; use std::marker::PhantomData; /// Cursor pagination -/// -/// To ensure proper ordering of the paginated result, the select statement must have order by expression. #[derive(Debug, Clone)] pub struct Cursor where S: SelectorTrait, { pub(crate) query: SelectStatement, + pub(crate) table: DynIden, + pub(crate) order_columns: Identity, pub(crate) last: bool, pub(crate) phantom: PhantomData, } @@ -23,67 +26,110 @@ where S: SelectorTrait, { /// Initialize a cursor - pub fn new(query: SelectStatement) -> Self { + pub fn new(query: SelectStatement, table: DynIden, order_columns: C) -> Self + where + C: IntoIdentity, + { Self { query, + table, + order_columns: order_columns.into_identity(), last: false, phantom: PhantomData, } } /// Filter paginated rows with column value less than the input value - pub fn before(&mut self, col: C, val: V) -> &mut Self + pub fn before(&mut self, values: V) -> &mut Self where - C: ColumnTrait, - V: Into, + V: IntoValueTuple, { - self.query.and_where(col.lt(val)); + let condition = self.apply_filter(values, |c, v| { + Expr::tbl(SeaRc::clone(&self.table), SeaRc::clone(c)).lt(v) + }); + self.query.cond_where(condition); self } /// Filter paginated rows with column value greater than the input value - pub fn after(&mut self, col: C, val: V) -> &mut Self + pub fn after(&mut self, values: V) -> &mut Self where - C: ColumnTrait, - V: Into, + V: IntoValueTuple, { - self.query.and_where(col.gt(val)); + let condition = self.apply_filter(values, |c, v| { + Expr::tbl(SeaRc::clone(&self.table), SeaRc::clone(c)).gt(v) + }); + self.query.cond_where(condition); self } - fn reverse_ordering(&mut self) { - self.query.orders_mut_for_each(|order_expr| { - order_expr.reverse_ordering(); - }); + fn apply_filter(&self, values: V, f: F) -> Condition + where + V: IntoValueTuple, + F: Fn(&DynIden, Value) -> SimpleExpr, + { + match (&self.order_columns, values.into_value_tuple()) { + (Identity::Unary(c1), ValueTuple::One(v1)) => Condition::all().add(f(c1, v1)), + (Identity::Binary(c1, c2), ValueTuple::Two(v1, v2)) => { + Condition::all().add(f(c1, v1)).add(f(c2, v2)) + } + (Identity::Ternary(c1, c2, c3), ValueTuple::Three(v1, v2, v3)) => Condition::all() + .add(f(c1, v1)) + .add(f(c2, v2)) + .add(f(c3, v3)), + _ => panic!("column arity mismatch"), + } } /// Limit result set to only first N rows in ascending order of the paginated query pub fn first(&mut self, num_rows: u64) -> &mut Self { - self.query.limit(num_rows); - if self.last { - self.reverse_ordering(); - } + self.query.limit(num_rows).clear_order_by(); + let table = SeaRc::clone(&self.table); + self.apply_order_by(|query, col| { + query.order_by((SeaRc::clone(&table), SeaRc::clone(col)), Order::Asc); + }); self.last = false; self } /// Limit result set to only last N rows in ascending order of the paginated query pub fn last(&mut self, num_rows: u64) -> &mut Self { - self.query.limit(num_rows); - if !self.last { - self.reverse_ordering(); - } + self.query.limit(num_rows).clear_order_by(); + let table = SeaRc::clone(&self.table); + self.apply_order_by(|query, col| { + query.order_by((SeaRc::clone(&table), SeaRc::clone(col)), Order::Desc); + }); self.last = true; self } + fn apply_order_by(&mut self, f: F) + where + F: Fn(&mut SelectStatement, &DynIden), + { + let query = &mut self.query; + match &self.order_columns { + Identity::Unary(c1) => { + f(query, c1); + } + Identity::Binary(c1, c2) => { + f(query, c1); + f(query, c2); + } + Identity::Ternary(c1, c2, c3) => { + f(query, c1); + f(query, c2); + f(query, c3); + } + } + } + /// Fetch the rows - pub async fn all(&self, db: &C) -> Result, DbErr> + pub async fn all(&mut self, db: &C) -> Result, DbErr> where C: ConnectionTrait, { - let builder = db.get_database_backend(); - let stmt = builder.build(&self.query); + let stmt = db.get_database_backend().build(&self.query); let rows = db.query_all(stmt).await?; let mut buffer = Vec::with_capacity(rows.len()); for row in rows.into_iter() { @@ -113,7 +159,9 @@ pub trait CursorTrait { type Selector: SelectorTrait + Send + Sync; /// Convert current type into a cursor - fn cursor(self) -> Cursor; + fn cursor(self, order_columns: C) -> Cursor + where + C: IntoIdentity; } impl CursorTrait for Select @@ -123,8 +171,11 @@ where { type Selector = SelectModel; - fn cursor(self) -> Cursor { - Cursor::new(self.query) + fn cursor(self, order_columns: C) -> Cursor + where + C: IntoIdentity, + { + Cursor::new(self.query, SeaRc::new(E::default()), order_columns) } } @@ -160,9 +211,8 @@ mod tests { assert_eq!( Entity::find() - .cursor() - .order_by_asc(Column::Id) - .before(Column::Id, 10) + .cursor(Column::Id) + .before(10) .first(2) .all(&db) .await?, @@ -210,9 +260,8 @@ mod tests { assert_eq!( Entity::find() - .order_by_asc(Column::Id) - .cursor() - .after(Column::Id, 10) + .cursor(Column::Id) + .after(10) .last(2) .all(&db) .await?, @@ -271,10 +320,9 @@ mod tests { assert_eq!( Entity::find() - .order_by_asc(Column::Id) - .cursor() - .after(Column::Id, 25) - .before(Column::Id, 30) + .cursor(Column::Id) + .after(25) + .before(30) .last(2) .all(&db) .await?, diff --git a/tests/cursor_tests.rs b/tests/cursor_tests.rs index 0b49a28bd4..76bd757f6c 100644 --- a/tests/cursor_tests.rs +++ b/tests/cursor_tests.rs @@ -2,7 +2,7 @@ pub mod common; pub use common::{features::*, setup::*, TestContext}; use pretty_assertions::assert_eq; -use sea_orm::{entity::prelude::*, QueryOrder}; +use sea_orm::entity::prelude::*; #[sea_orm_macros::test] #[cfg(any( @@ -55,9 +55,9 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> { // Before 5, i.e. id < 5 - let mut cursor = Entity::find().order_by_asc(Column::Id).cursor(); + let mut cursor = Entity::find().cursor(Column::Id); - cursor.before(Column::Id, 5); + cursor.before(5); assert_eq!( cursor.first(4).all(db).await?, @@ -101,9 +101,9 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> { // After 5, i.e. id > 5 - let mut cursor = Entity::find().order_by_asc(Column::Id).cursor(); + let mut cursor = Entity::find().cursor(Column::Id); - cursor.after(Column::Id, 5); + cursor.after(5); assert_eq!( cursor.first(4).all(db).await?, @@ -171,9 +171,9 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> { // Between 5 and 8, i.e. id > 5 AND id < 8 - let mut cursor = Entity::find().order_by_asc(Column::Id).cursor(); + let mut cursor = Entity::find().cursor(Column::Id); - cursor.after(Column::Id, 5).before(Column::Id, 8); + cursor.after(5).before(8); assert_eq!(cursor.first(1).all(db).await?, vec![Model { id: 6 }]);