diff --git a/src/database/statement.rs b/src/database/statement.rs index 12b074873..86d7c630d 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -73,6 +73,15 @@ macro_rules! build_any_stmt { }; } +macro_rules! build_postgres_stmt { + ($stmt: expr, $db_backend: expr) => { + match $db_backend { + DbBackend::Postgres => $stmt.to_string(PostgresQueryBuilder), + DbBackend::MySql | DbBackend::Sqlite => unimplemented!(), + } + }; +} + macro_rules! build_query_stmt { ($stmt: ty) => { impl StatementBuilder for $stmt { @@ -105,3 +114,18 @@ build_schema_stmt!(sea_query::TableDropStatement); build_schema_stmt!(sea_query::TableAlterStatement); build_schema_stmt!(sea_query::TableRenameStatement); build_schema_stmt!(sea_query::TableTruncateStatement); + +macro_rules! build_type_stmt { + ($stmt: ty) => { + impl StatementBuilder for $stmt { + fn build(&self, db_backend: &DbBackend) -> Statement { + let stmt = build_postgres_stmt!(self, db_backend); + Statement::from_string(*db_backend, stmt) + } + } + }; +} + +build_type_stmt!(sea_query::extension::postgres::TypeAlterStatement); +build_type_stmt!(sea_query::extension::postgres::TypeCreateStatement); +build_type_stmt!(sea_query::extension::postgres::TypeDropStatement); diff --git a/src/schema/entity.rs b/src/schema/entity.rs index 3f99f8928..238bde36e 100644 --- a/src/schema/entity.rs +++ b/src/schema/entity.rs @@ -2,9 +2,19 @@ use crate::{ unpack_table_ref, ColumnTrait, ColumnType, DbBackend, EntityTrait, Identity, Iterable, PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema, }; -use sea_query::{ColumnDef, ForeignKeyCreateStatement, Iden, Index, TableCreateStatement}; +use sea_query::{ + extension::postgres::{Type, TypeCreateStatement}, + Alias, ColumnDef, ForeignKeyCreateStatement, Iden, Index, TableCreateStatement, +}; impl Schema { + pub fn create_enum_from_entity(entity: E, db_backend: DbBackend) -> Vec + where + E: EntityTrait, + { + create_enum_from_entity(entity, db_backend) + } + pub fn create_table_from_entity(entity: E, db_backend: DbBackend) -> TableCreateStatement where E: EntityTrait, @@ -13,6 +23,33 @@ impl Schema { } } +pub(crate) fn create_enum_from_entity(_: E, db_backend: DbBackend) -> Vec +where + E: EntityTrait, +{ + if matches!(db_backend, DbBackend::MySql | DbBackend::Sqlite) { + return Vec::new(); + } + let mut vec = Vec::new(); + for col in E::Column::iter() { + let col_def = col.def(); + let col_type = col_def.get_column_type(); + if !matches!(col_type, ColumnType::Enum(_, _)) { + continue; + } + let (name, values) = match col_type { + ColumnType::Enum(s, v) => (s.as_str(), v), + _ => unreachable!(), + }; + let stmt = Type::create() + .as_enum(Alias::new(name)) + .values(values.into_iter().map(|val| Alias::new(val.as_str()))) + .to_owned(); + vec.push(stmt); + } + vec +} + pub(crate) fn create_table_from_entity(entity: E, db_backend: DbBackend) -> TableCreateStatement where E: EntityTrait, diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index bbc9faf8b..7e9f3de7b 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -1,14 +1,11 @@ pub use super::super::bakery_chain::*; use super::*; -use crate::common::setup::create_table; +use crate::{common::setup::create_table, create_enum}; use sea_orm::{ error::*, sea_query, ConnectionTrait, DatabaseConnection, DbBackend, DbConn, ExecResult, - Statement, -}; -use sea_query::{ - extension::postgres::Type, Alias, ColumnDef, ForeignKeyCreateStatement, PostgresQueryBuilder, }; +use sea_query::{extension::postgres::Type, Alias, ColumnDef, ForeignKeyCreateStatement}; pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_log_table(db).await?; @@ -111,14 +108,23 @@ pub async fn create_active_enum_table(db: &DbConn) -> Result let db_backend = db.get_database_backend(); let tea_enum = Alias::new("tea"); + let create_enum_stmts = match db_backend { + DbBackend::MySql | DbBackend::Sqlite => Vec::new(), + DbBackend::Postgres => vec![Type::create() + .as_enum(tea_enum.clone()) + .values(vec![Alias::new("EverydayTea"), Alias::new("BreakfastTea")]) + .to_owned()], + }; + + create_enum(db, &create_enum_stmts, ActiveEnum).await?; + let mut tea_col = ColumnDef::new(active_enum::Column::Tea); match db_backend { DbBackend::MySql => tea_col.custom(Alias::new("ENUM('EverydayTea', 'BreakfastTea')")), - DbBackend::Postgres => tea_col.custom(tea_enum.clone()), DbBackend::Sqlite => tea_col.text(), + DbBackend::Postgres => tea_col.custom(tea_enum), }; - - let stmt = sea_query::Table::create() + let create_table_stmt = sea_query::Table::create() .table(active_enum::Entity) .col( ColumnDef::new(active_enum::Column::Id) @@ -132,32 +138,5 @@ pub async fn create_active_enum_table(db: &DbConn) -> Result .col(&mut tea_col) .to_owned(); - if db_backend == DbBackend::Postgres { - let drop_type_stmt = Type::drop() - .name(tea_enum.clone()) - .cascade() - .if_exists() - .to_owned(); - let (sql, values) = drop_type_stmt.build(PostgresQueryBuilder); - let stmt = Statement::from_sql_and_values(db.get_database_backend(), &sql, values); - db.execute(stmt).await?; - - let create_type_stmt = Type::create() - .as_enum(tea_enum) - .values(vec![Alias::new("EverydayTea"), Alias::new("BreakfastTea")]) - .to_owned(); - // FIXME: This is not working - { - let (sql, values) = create_type_stmt.build(PostgresQueryBuilder); - let _stmt = Statement::from_sql_and_values(db.get_database_backend(), &sql, values); - } - // But this is working... - let stmt = Statement::from_string( - db.get_database_backend(), - create_type_stmt.to_string(PostgresQueryBuilder), - ); - db.execute(stmt).await?; - } - - create_table(db, &stmt, ActiveEnum).await + create_table(db, &create_table_stmt, ActiveEnum).await } diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index 7266d1759..263e19f2f 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -1,9 +1,12 @@ use pretty_assertions::assert_eq; use sea_orm::{ - ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, DbBackend, DbConn, DbErr, - EntityTrait, ExecResult, Schema, Statement, + ColumnTrait, ColumnType, ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, + DbBackend, DbConn, DbErr, EntityTrait, ExecResult, Iterable, Schema, Statement, +}; +use sea_query::{ + extension::postgres::{Type, TypeCreateStatement}, + Alias, Table, TableCreateStatement, }; -use sea_query::{Alias, Table, TableCreateStatement}; pub async fn setup(base_url: &str, db_name: &str) -> DatabaseConnection { let db = if cfg!(feature = "sqlx-mysql") { @@ -74,6 +77,52 @@ pub async fn tear_down(base_url: &str, db_name: &str) { }; } +pub async fn create_enum( + db: &DbConn, + creates: &[TypeCreateStatement], + entity: E, +) -> Result<(), DbErr> +where + E: EntityTrait, +{ + let builder = db.get_database_backend(); + if builder == DbBackend::Postgres { + for col in E::Column::iter() { + let col_def = col.def(); + let col_type = col_def.get_column_type(); + if !matches!(col_type, ColumnType::Enum(_, _)) { + continue; + } + let name = match col_type { + ColumnType::Enum(s, _) => s.as_str(), + _ => unreachable!(), + }; + let drop_type_stmt = Type::drop() + .name(Alias::new(name)) + .if_exists() + .cascade() + .to_owned(); + let stmt = builder.build(&drop_type_stmt); + db.execute(stmt).await?; + } + } + + let expect_stmts: Vec = creates.iter().map(|stmt| builder.build(stmt)).collect(); + let create_from_entity_stmts: Vec = + Schema::create_enum_from_entity(entity, db.get_database_backend()) + .iter() + .map(|stmt| builder.build(stmt)) + .collect(); + + assert_eq!(expect_stmts, create_from_entity_stmts); + + for stmt in expect_stmts { + db.execute(stmt).await.map(|_| ())?; + } + + Ok(()) +} + pub async fn create_table( db: &DbConn, create: &TableCreateStatement,