From 869e324f438bc84f26db44f4ba66d1c0596d8fd0 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 31 May 2023 14:28:15 +0800 Subject: [PATCH] implements `StatementBuilder` for `sea_query::WithQuery` --- src/database/statement.rs | 12 +++++++ src/executor/query.rs | 70 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/src/database/statement.rs b/src/database/statement.rs index fa54e1de9..b8b067baf 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -109,6 +109,18 @@ build_query_stmt!(sea_query::SelectStatement); build_query_stmt!(sea_query::UpdateStatement); build_query_stmt!(sea_query::DeleteStatement); +impl StatementBuilder for sea_query::WithQuery { + fn build(&self, db_backend: &DbBackend) -> Statement { + use sea_query::QueryStatementWriter; + let stmt = match db_backend { + DbBackend::MySql => QueryStatementWriter::build(self, MysqlQueryBuilder), + DbBackend::Postgres => QueryStatementWriter::build(self, PostgresQueryBuilder), + DbBackend::Sqlite => QueryStatementWriter::build(self, SqliteQueryBuilder), + }; + Statement::from_string_values_tuple(*db_backend, stmt) + } +} + macro_rules! build_schema_stmt { ($stmt: ty) => { impl StatementBuilder for $stmt { diff --git a/src/executor/query.rs b/src/executor/query.rs index 28093822b..0a52166e2 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -1157,4 +1157,74 @@ mod tests { let expected = "A null value was encountered while decoding column".to_owned(); assert_eq!(DbErr::from(try_get_error), DbErr::Type(expected)); } + + #[test] + fn build_with_query() { + use sea_orm::{DbBackend, Statement}; + use sea_query::*; + + let base_query = SelectStatement::new() + .column(Alias::new("id")) + .expr(1i32) + .column(Alias::new("next")) + .column(Alias::new("value")) + .from(Alias::new("table")) + .to_owned(); + + let cte_referencing = SelectStatement::new() + .column(Alias::new("id")) + .expr(Expr::col(Alias::new("depth")).add(1i32)) + .column(Alias::new("next")) + .column(Alias::new("value")) + .from(Alias::new("table")) + .join( + JoinType::InnerJoin, + Alias::new("cte_traversal"), + Expr::col((Alias::new("cte_traversal"), Alias::new("next"))) + .equals((Alias::new("table"), Alias::new("id"))), + ) + .to_owned(); + + let common_table_expression = CommonTableExpression::new() + .query( + base_query + .clone() + .union(UnionType::All, cte_referencing) + .to_owned(), + ) + .columns([ + Alias::new("id"), + Alias::new("depth"), + Alias::new("next"), + Alias::new("value"), + ]) + .table_name(Alias::new("cte_traversal")) + .to_owned(); + + let select = SelectStatement::new() + .column(ColumnRef::Asterisk) + .from(Alias::new("cte_traversal")) + .to_owned(); + + let with_clause = WithClause::new() + .recursive(true) + .cte(common_table_expression) + .cycle(Cycle::new_from_expr_set_using( + SimpleExpr::Column(ColumnRef::Column(Alias::new("id").into_iden())), + Alias::new("looped"), + Alias::new("traversal_path"), + )) + .to_owned(); + + let with_query = select.with(with_clause).to_owned(); + + assert_eq!( + DbBackend::MySql.build(&with_query), + Statement::from_sql_and_values( + DbBackend::MySql, + r#"WITH RECURSIVE `cte_traversal` (`id`, `depth`, `next`, `value`) AS (SELECT `id`, ?, `next`, `value` FROM `table` UNION ALL (SELECT `id`, `depth` + ?, `next`, `value` FROM `table` INNER JOIN `cte_traversal` ON `cte_traversal`.`next` = `table`.`id`)) SELECT * FROM `cte_traversal`"#, + [1.into(), 1.into()] + ) + ); + } }