diff --git a/Cargo.toml b/Cargo.toml index 8fc5401d7..9a52d742d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ futures-util = { version = "^0.3" } tracing = { version = "0.1", features = ["log"] } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.8.0", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.24.5", features = ["thread-safe"] } +sea-query = { version = "^0.24.0", git = "https://github.com/SeaQL/sea-query", branch = "clear-order-by", features = ["thread-safe"] } sea-strum = { version = "^0.23", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } diff --git a/sea-orm-migration/Cargo.toml b/sea-orm-migration/Cargo.toml index 9ca00c742..5d3e94e12 100644 --- a/sea-orm-migration/Cargo.toml +++ b/sea-orm-migration/Cargo.toml @@ -23,7 +23,7 @@ clap = { version = "^2.33" } dotenv = { version = "^0.15" } sea-orm = { version = "^0.8.0", path = "../", default-features = false, features = ["macros"] } sea-orm-cli = { version = "^0.8.1", path = "../sea-orm-cli", default-features = false } -sea-schema = { version = "^0.8.1" } +sea-schema = { version = "^0.8.1", git = "https://github.com/SeaQL/sea-schema", branch = "bump-for-sea-orm-cursor-pagination" } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -39,4 +39,4 @@ runtime-async-std-native-tls = [ "sea-orm/runtime-async-std-native-tls" ] runtime-tokio-native-tls = [ "sea-orm/runtime-tokio-native-tls" ] runtime-actix-rustls = [ "sea-orm/runtime-actix-rustls" ] runtime-async-std-rustls = [ "sea-orm/runtime-async-std-rustls" ] -runtime-tokio-rustls = [ "sea-orm/runtime-tokio-rustls" ] \ No newline at end of file +runtime-tokio-rustls = [ "sea-orm/runtime-tokio-rustls" ] diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 10b1bb503..14e029100 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -1,8 +1,8 @@ pub use crate::{ error::*, ActiveEnum, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, - ColumnType, DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, ForeignKeyAction, - Iden, IdenStatic, Linked, ModelTrait, PaginatorTrait, PrimaryKeyToColumn, PrimaryKeyTrait, - QueryFilter, QueryResult, Related, RelationDef, RelationTrait, Select, Value, + ColumnType, CursorTrait, DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, + ForeignKeyAction, Iden, IdenStatic, Linked, ModelTrait, PaginatorTrait, PrimaryKeyToColumn, + PrimaryKeyTrait, QueryFilter, QueryResult, Related, RelationDef, RelationTrait, Select, Value, }; #[cfg(feature = "macros")] diff --git a/src/executor/cursor.rs b/src/executor/cursor.rs new file mode 100644 index 000000000..60fa432b9 --- /dev/null +++ b/src/executor/cursor.rs @@ -0,0 +1,437 @@ +use crate::{ + ConnectionTrait, DbErr, EntityTrait, FromQueryResult, Identity, IntoIdentity, QueryOrder, + Select, SelectModel, SelectorTrait, +}; +use sea_query::{ + Condition, DynIden, Expr, IntoValueTuple, Order, OrderedStatement, SeaRc, SelectStatement, + SimpleExpr, Value, ValueTuple, +}; +use std::marker::PhantomData; + +/// Cursor pagination +#[derive(Debug, Clone)] +pub struct Cursor +where + S: SelectorTrait, +{ + pub(crate) query: SelectStatement, + pub(crate) table: DynIden, + pub(crate) order_columns: Identity, + pub(crate) last: bool, + pub(crate) phantom: PhantomData, +} + +impl Cursor +where + S: SelectorTrait, +{ + /// Initialize a cursor + pub fn new(query: SelectStatement, table: DynIden, order_columns: C) -> Self + where + C: IntoIdentity, + { + Self { + query, + table, + order_columns: order_columns.into_identity(), + last: false, + phantom: PhantomData, + } + } + + /// Filter paginated result with corresponding column less than the input value + pub fn before(&mut self, values: V) -> &mut Self + where + V: IntoValueTuple, + { + let condition = self.apply_filter(values, |c, v| { + Expr::tbl(SeaRc::clone(&self.table), SeaRc::clone(c)).lt(v) + }); + self.query.cond_where(condition); + self + } + + /// Filter paginated result with corresponding column greater than the input value + pub fn after(&mut self, values: V) -> &mut Self + where + V: IntoValueTuple, + { + let condition = self.apply_filter(values, |c, v| { + Expr::tbl(SeaRc::clone(&self.table), SeaRc::clone(c)).gt(v) + }); + self.query.cond_where(condition); + self + } + + fn apply_filter(&self, values: V, f: F) -> Condition + where + V: IntoValueTuple, + F: Fn(&DynIden, Value) -> SimpleExpr, + { + match (&self.order_columns, values.into_value_tuple()) { + (Identity::Unary(c1), ValueTuple::One(v1)) => Condition::all().add(f(c1, v1)), + (Identity::Binary(c1, c2), ValueTuple::Two(v1, v2)) => { + Condition::all().add(f(c1, v1)).add(f(c2, v2)) + } + (Identity::Ternary(c1, c2, c3), ValueTuple::Three(v1, v2, v3)) => Condition::all() + .add(f(c1, v1)) + .add(f(c2, v2)) + .add(f(c3, v3)), + _ => panic!("column arity mismatch"), + } + } + + /// Limit result set to only first N rows in ascending order of the order by column + pub fn first(&mut self, num_rows: u64) -> &mut Self { + self.query.limit(num_rows).clear_order_by(); + let table = SeaRc::clone(&self.table); + self.apply_order_by(|query, col| { + query.order_by((SeaRc::clone(&table), SeaRc::clone(col)), Order::Asc); + }); + self.last = false; + self + } + + /// Limit result set to only last N rows in ascending order of the order by column + pub fn last(&mut self, num_rows: u64) -> &mut Self { + self.query.limit(num_rows).clear_order_by(); + let table = SeaRc::clone(&self.table); + self.apply_order_by(|query, col| { + query.order_by((SeaRc::clone(&table), SeaRc::clone(col)), Order::Desc); + }); + self.last = true; + self + } + + fn apply_order_by(&mut self, f: F) + where + F: Fn(&mut SelectStatement, &DynIden), + { + let query = &mut self.query; + match &self.order_columns { + Identity::Unary(c1) => { + f(query, c1); + } + Identity::Binary(c1, c2) => { + f(query, c1); + f(query, c2); + } + Identity::Ternary(c1, c2, c3) => { + f(query, c1); + f(query, c2); + f(query, c3); + } + } + } + + /// Fetch the paginated result + pub async fn all(&mut self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait, + { + let stmt = db.get_database_backend().build(&self.query); + let rows = db.query_all(stmt).await?; + let mut buffer = Vec::with_capacity(rows.len()); + for row in rows.into_iter() { + buffer.push(S::from_raw_query_result(row)?); + } + if self.last { + buffer.reverse() + } + Ok(buffer) + } +} + +impl QueryOrder for Cursor +where + S: SelectorTrait, +{ + type QueryStatement = SelectStatement; + + fn query(&mut self) -> &mut SelectStatement { + &mut self.query + } +} + +/// A trait for any type that can be turn into a cursor +pub trait CursorTrait { + /// Select operation + type Selector: SelectorTrait + Send + Sync; + + /// Convert current type into a cursor + fn cursor(self, order_columns: C) -> Cursor + where + C: IntoIdentity; +} + +impl CursorTrait for Select +where + E: EntityTrait, + M: FromQueryResult + Sized + Send + Sync, +{ + type Selector = SelectModel; + + fn cursor(self, order_columns: C) -> Cursor + where + C: IntoIdentity, + { + Cursor::new(self.query, SeaRc::new(E::default()), order_columns) + } +} + +#[cfg(test)] +#[cfg(feature = "mock")] +mod tests { + use super::*; + use crate::entity::prelude::*; + use crate::tests_cfg::*; + use crate::{DbBackend, MockDatabase, Statement, Transaction}; + use pretty_assertions::assert_eq; + + #[smol_potat::test] + async fn first_2_before_10() -> Result<(), DbErr> { + use fruit::*; + + let models = vec![ + Model { + id: 1, + name: "Blueberry".into(), + cake_id: Some(1), + }, + Model { + id: 2, + name: "Rasberry".into(), + cake_id: Some(1), + }, + ]; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![models.clone()]) + .into_connection(); + + assert_eq!( + Entity::find() + .cursor(Column::Id) + .before(10) + .first(2) + .all(&db) + .await?, + models + ); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![Statement::from_sql_and_values( + DbBackend::Postgres, + [ + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id""#, + r#"FROM "fruit""#, + r#"WHERE "fruit"."id" < $1"#, + r#"ORDER BY "fruit"."id" ASC"#, + r#"LIMIT $2"#, + ] + .join(" ") + .as_str(), + vec![10_i32.into(), 2_u64.into()] + ),])] + ); + + Ok(()) + } + + #[smol_potat::test] + async fn last_2_after_10() -> Result<(), DbErr> { + use fruit::*; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![ + Model { + id: 22, + name: "Rasberry".into(), + cake_id: Some(1), + }, + Model { + id: 21, + name: "Blueberry".into(), + cake_id: Some(1), + }, + ]]) + .into_connection(); + + assert_eq!( + Entity::find() + .cursor(Column::Id) + .after(10) + .last(2) + .all(&db) + .await?, + vec![ + Model { + id: 21, + name: "Blueberry".into(), + cake_id: Some(1), + }, + Model { + id: 22, + name: "Rasberry".into(), + cake_id: Some(1), + }, + ] + ); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![Statement::from_sql_and_values( + DbBackend::Postgres, + [ + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id""#, + r#"FROM "fruit""#, + r#"WHERE "fruit"."id" > $1"#, + r#"ORDER BY "fruit"."id" DESC"#, + r#"LIMIT $2"#, + ] + .join(" ") + .as_str(), + vec![10_i32.into(), 2_u64.into()] + ),])] + ); + + Ok(()) + } + + #[smol_potat::test] + async fn last_2_after_25_before_30() -> Result<(), DbErr> { + use fruit::*; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![ + Model { + id: 27, + name: "Rasberry".into(), + cake_id: Some(1), + }, + Model { + id: 26, + name: "Blueberry".into(), + cake_id: Some(1), + }, + ]]) + .into_connection(); + + assert_eq!( + Entity::find() + .cursor(Column::Id) + .after(25) + .before(30) + .last(2) + .all(&db) + .await?, + vec![ + Model { + id: 26, + name: "Blueberry".into(), + cake_id: Some(1), + }, + Model { + id: 27, + name: "Rasberry".into(), + cake_id: Some(1), + }, + ] + ); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![Statement::from_sql_and_values( + DbBackend::Postgres, + [ + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id""#, + r#"FROM "fruit""#, + r#"WHERE "fruit"."id" > $1"#, + r#"AND "fruit"."id" < $2"#, + r#"ORDER BY "fruit"."id" DESC"#, + r#"LIMIT $3"#, + ] + .join(" ") + .as_str(), + vec![25_i32.into(), 30_i32.into(), 2_u64.into()] + ),])] + ); + + Ok(()) + } + + #[smol_potat::test] + async fn composite_keys() -> Result<(), DbErr> { + use cake_filling::*; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![ + Model { + cake_id: 1, + filling_id: 2, + }, + Model { + cake_id: 1, + filling_id: 3, + }, + Model { + cake_id: 2, + filling_id: 3, + }, + ]]) + .into_connection(); + + assert_eq!( + Entity::find() + .cursor((Column::CakeId, Column::FillingId)) + .after((0, 1)) + .before((10, 11)) + .first(3) + .all(&db) + .await?, + vec![ + Model { + cake_id: 1, + filling_id: 2, + }, + Model { + cake_id: 1, + filling_id: 3, + }, + Model { + cake_id: 2, + filling_id: 3, + }, + ] + ); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![Statement::from_sql_and_values( + DbBackend::Postgres, + [ + r#"SELECT "cake_filling"."cake_id", "cake_filling"."filling_id""#, + r#"FROM "cake_filling""#, + r#"WHERE "cake_filling"."cake_id" > $1"#, + r#"AND "cake_filling"."filling_id" > $2"#, + r#"AND ("cake_filling"."cake_id" < $3"#, + r#"AND "cake_filling"."filling_id" < $4)"#, + r#"ORDER BY "cake_filling"."cake_id" ASC, "cake_filling"."filling_id" ASC"#, + r#"LIMIT $5"#, + ] + .join(" ") + .as_str(), + vec![ + 0_i32.into(), + 1_i32.into(), + 10_i32.into(), + 11_i32.into(), + 3_u64.into() + ] + ),])] + ); + + Ok(()) + } +} diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 4be8227e7..34f8695fe 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -1,3 +1,4 @@ +mod cursor; mod delete; mod execute; mod insert; @@ -6,6 +7,7 @@ mod query; mod select; mod update; +pub use cursor::*; pub use delete::*; pub use execute::*; pub use insert::*; diff --git a/src/query/mod.rs b/src/query/mod.rs index b7e304ac1..2de0e7908 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -23,6 +23,6 @@ pub use update::*; pub use util::*; pub use crate::{ - ConnectionTrait, InsertResult, PaginatorTrait, Statement, StreamTrait, TransactionTrait, - UpdateResult, Value, Values, + ConnectionTrait, CursorTrait, InsertResult, PaginatorTrait, Statement, StreamTrait, + TransactionTrait, UpdateResult, Value, Values, }; diff --git a/tests/cursor_tests.rs b/tests/cursor_tests.rs new file mode 100644 index 000000000..76bd757f6 --- /dev/null +++ b/tests/cursor_tests.rs @@ -0,0 +1,203 @@ +pub mod common; + +pub use common::{features::*, setup::*, TestContext}; +use pretty_assertions::assert_eq; +use sea_orm::entity::prelude::*; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("cursor_tests").await; + create_tables(&ctx.db).await?; + create_insert_default(&ctx.db).await?; + cursor_pagination(&ctx.db).await?; + ctx.delete().await; + + Ok(()) +} + +pub async fn create_insert_default(db: &DatabaseConnection) -> Result<(), DbErr> { + use insert_default::*; + + for _ in 0..10 { + ActiveModel { + ..Default::default() + } + .insert(db) + .await?; + } + + assert_eq!( + Entity::find().all(db).await?, + vec![ + Model { id: 1 }, + Model { id: 2 }, + Model { id: 3 }, + Model { id: 4 }, + Model { id: 5 }, + Model { id: 6 }, + Model { id: 7 }, + Model { id: 8 }, + Model { id: 9 }, + Model { id: 10 }, + ] + ); + + Ok(()) +} + +pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> { + use insert_default::*; + + // Before 5, i.e. id < 5 + + let mut cursor = Entity::find().cursor(Column::Id); + + cursor.before(5); + + assert_eq!( + cursor.first(4).all(db).await?, + vec![ + Model { id: 1 }, + Model { id: 2 }, + Model { id: 3 }, + Model { id: 4 }, + ] + ); + + assert_eq!( + cursor.first(5).all(db).await?, + vec![ + Model { id: 1 }, + Model { id: 2 }, + Model { id: 3 }, + Model { id: 4 }, + ] + ); + + assert_eq!( + cursor.last(4).all(db).await?, + vec![ + Model { id: 1 }, + Model { id: 2 }, + Model { id: 3 }, + Model { id: 4 }, + ] + ); + + assert_eq!( + cursor.last(5).all(db).await?, + vec![ + Model { id: 1 }, + Model { id: 2 }, + Model { id: 3 }, + Model { id: 4 }, + ] + ); + + // After 5, i.e. id > 5 + + let mut cursor = Entity::find().cursor(Column::Id); + + cursor.after(5); + + assert_eq!( + cursor.first(4).all(db).await?, + vec![ + Model { id: 6 }, + Model { id: 7 }, + Model { id: 8 }, + Model { id: 9 }, + ] + ); + + assert_eq!( + cursor.first(5).all(db).await?, + vec![ + Model { id: 6 }, + Model { id: 7 }, + Model { id: 8 }, + Model { id: 9 }, + Model { id: 10 }, + ] + ); + + assert_eq!( + cursor.first(6).all(db).await?, + vec![ + Model { id: 6 }, + Model { id: 7 }, + Model { id: 8 }, + Model { id: 9 }, + Model { id: 10 }, + ] + ); + + assert_eq!( + cursor.last(4).all(db).await?, + vec![ + Model { id: 7 }, + Model { id: 8 }, + Model { id: 9 }, + Model { id: 10 }, + ] + ); + + assert_eq!( + cursor.last(5).all(db).await?, + vec![ + Model { id: 6 }, + Model { id: 7 }, + Model { id: 8 }, + Model { id: 9 }, + Model { id: 10 }, + ] + ); + + assert_eq!( + cursor.last(6).all(db).await?, + vec![ + Model { id: 6 }, + Model { id: 7 }, + Model { id: 8 }, + Model { id: 9 }, + Model { id: 10 }, + ] + ); + + // Between 5 and 8, i.e. id > 5 AND id < 8 + + let mut cursor = Entity::find().cursor(Column::Id); + + cursor.after(5).before(8); + + assert_eq!(cursor.first(1).all(db).await?, vec![Model { id: 6 }]); + + assert_eq!( + cursor.first(2).all(db).await?, + vec![Model { id: 6 }, Model { id: 7 }] + ); + + assert_eq!( + cursor.first(3).all(db).await?, + vec![Model { id: 6 }, Model { id: 7 }] + ); + + assert_eq!(cursor.last(1).all(db).await?, vec![Model { id: 7 }]); + + assert_eq!( + cursor.last(2).all(db).await?, + vec![Model { id: 6 }, Model { id: 7 }] + ); + + assert_eq!( + cursor.last(3).all(db).await?, + vec![Model { id: 6 }, Model { id: 7 }] + ); + + Ok(()) +}