diff --git a/Cargo.toml b/Cargo.toml index 716903748..42f97b62e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.3.0", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.18.0", features = ["thread-safe"] } +sea-query = { version = "^0.18.0", git = "https://github.com/SeaQL/sea-query.git", branch = "sea-orm/active-enum-1", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } diff --git a/src/database/statement.rs b/src/database/statement.rs index 12b074873..86d7c630d 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -73,6 +73,15 @@ macro_rules! build_any_stmt { }; } +macro_rules! build_postgres_stmt { + ($stmt: expr, $db_backend: expr) => { + match $db_backend { + DbBackend::Postgres => $stmt.to_string(PostgresQueryBuilder), + DbBackend::MySql | DbBackend::Sqlite => unimplemented!(), + } + }; +} + macro_rules! build_query_stmt { ($stmt: ty) => { impl StatementBuilder for $stmt { @@ -105,3 +114,18 @@ build_schema_stmt!(sea_query::TableDropStatement); build_schema_stmt!(sea_query::TableAlterStatement); build_schema_stmt!(sea_query::TableRenameStatement); build_schema_stmt!(sea_query::TableTruncateStatement); + +macro_rules! build_type_stmt { + ($stmt: ty) => { + impl StatementBuilder for $stmt { + fn build(&self, db_backend: &DbBackend) -> Statement { + let stmt = build_postgres_stmt!(self, db_backend); + Statement::from_string(*db_backend, stmt) + } + } + }; +} + +build_type_stmt!(sea_query::extension::postgres::TypeAlterStatement); +build_type_stmt!(sea_query::extension::postgres::TypeCreateStatement); +build_type_stmt!(sea_query::extension::postgres::TypeDropStatement); diff --git a/src/entity/column.rs b/src/entity/column.rs index 25ed84473..9a1f5c9de 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -34,6 +34,7 @@ pub enum ColumnType { JsonBinary, Custom(String), Uuid, + Enum(String, Vec), } macro_rules! bind_oper { @@ -241,6 +242,13 @@ impl ColumnType { indexed: false, } } + + pub(crate) fn get_enum_name(&self) -> Option<&String> { + match self { + ColumnType::Enum(s, _) => Some(s), + _ => None, + } + } } impl ColumnDef { @@ -291,7 +299,7 @@ impl From for sea_query::ColumnType { ColumnType::Money(s) => sea_query::ColumnType::Money(s), ColumnType::Json => sea_query::ColumnType::Json, ColumnType::JsonBinary => sea_query::ColumnType::JsonBinary, - ColumnType::Custom(s) => { + ColumnType::Custom(s) | ColumnType::Enum(s, _) => { sea_query::ColumnType::Custom(sea_query::SeaRc::new(sea_query::Alias::new(&s))) } ColumnType::Uuid => sea_query::ColumnType::Uuid, diff --git a/src/query/insert.rs b/src/query/insert.rs index 5e504a0c5..f7f77d7d0 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -1,9 +1,9 @@ use crate::{ - ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, PrimaryKeyTrait, - QueryTrait, + ActiveModelTrait, ColumnTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, + PrimaryKeyTrait, QueryTrait, }; use core::marker::PhantomData; -use sea_query::{InsertStatement, ValueTuple}; +use sea_query::{Alias, Expr, InsertStatement, ValueTuple}; #[derive(Debug)] pub struct Insert @@ -131,11 +131,16 @@ where } if av_has_val { columns.push(col); - values.push(av.into_value().unwrap()); + let val = Expr::val(av.into_value().unwrap()); + let expr = match col.def().get_column_type().get_enum_name() { + Some(enum_name) => val.as_enum(Alias::new(enum_name)), + None => val.into(), + }; + values.push(expr); } } self.query.columns(columns); - self.query.values_panic(values); + self.query.exprs_panic(values); self } diff --git a/src/query/select.rs b/src/query/select.rs index 1b0c93c3d..5e433e8a4 100644 --- a/src/query/select.rs +++ b/src/query/select.rs @@ -2,7 +2,7 @@ use crate::{ColumnTrait, EntityTrait, Iterable, QueryFilter, QueryOrder, QuerySe use core::fmt::Debug; use core::marker::PhantomData; pub use sea_query::JoinType; -use sea_query::{DynIden, IntoColumnRef, SeaRc, SelectStatement, SimpleExpr}; +use sea_query::{Alias, DynIden, Expr, IntoColumnRef, SeaRc, SelectStatement, SimpleExpr}; #[derive(Clone, Debug)] pub struct Select @@ -109,13 +109,22 @@ where } fn prepare_select(mut self) -> Self { - self.query.columns(self.column_list()); + self.query.exprs(self.column_list()); self } - fn column_list(&self) -> Vec<(DynIden, E::Column)> { + fn column_list(&self) -> Vec { let table = SeaRc::new(E::default()) as DynIden; - E::Column::iter().map(|col| (table.clone(), col)).collect() + let text_type = SeaRc::new(Alias::new("text")) as DynIden; + E::Column::iter() + .map(|col| { + let expr = Expr::tbl(table.clone(), col); + match col.def().get_column_type().get_enum_name() { + Some(_) => expr.as_enum(text_type.clone()), + None => expr.into(), + } + }) + .collect() } fn prepare_from(mut self) -> Self { diff --git a/src/query/update.rs b/src/query/update.rs index fec7757cc..89348229a 100644 --- a/src/query/update.rs +++ b/src/query/update.rs @@ -3,7 +3,7 @@ use crate::{ QueryTrait, }; use core::marker::PhantomData; -use sea_query::{IntoIden, SimpleExpr, UpdateStatement}; +use sea_query::{Alias, Expr, IntoIden, SimpleExpr, UpdateStatement}; #[derive(Clone, Debug)] pub struct Update; @@ -106,7 +106,12 @@ where } let av = self.model.get(col); if av.is_set() { - self.query.value(col, av.unwrap()); + let val = Expr::val(av.into_value().unwrap()); + let expr = match col.def().get_column_type().get_enum_name() { + Some(enum_name) => val.as_enum(Alias::new(enum_name)), + None => val.into(), + }; + self.query.value_expr(col, expr); } } self diff --git a/src/schema/entity.rs b/src/schema/entity.rs index a95b70478..238bde36e 100644 --- a/src/schema/entity.rs +++ b/src/schema/entity.rs @@ -1,19 +1,56 @@ use crate::{ - unpack_table_ref, ColumnTrait, EntityTrait, Identity, Iterable, PrimaryKeyToColumn, - PrimaryKeyTrait, RelationTrait, Schema, + unpack_table_ref, ColumnTrait, ColumnType, DbBackend, EntityTrait, Identity, Iterable, + PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema, +}; +use sea_query::{ + extension::postgres::{Type, TypeCreateStatement}, + Alias, ColumnDef, ForeignKeyCreateStatement, Iden, Index, TableCreateStatement, }; -use sea_query::{ColumnDef, ForeignKeyCreateStatement, Iden, Index, TableCreateStatement}; impl Schema { - pub fn create_table_from_entity(entity: E) -> TableCreateStatement + pub fn create_enum_from_entity(entity: E, db_backend: DbBackend) -> Vec where E: EntityTrait, { - create_table_from_entity(entity) + create_enum_from_entity(entity, db_backend) + } + + pub fn create_table_from_entity(entity: E, db_backend: DbBackend) -> TableCreateStatement + where + E: EntityTrait, + { + create_table_from_entity(entity, db_backend) } } -pub(crate) fn create_table_from_entity(entity: E) -> TableCreateStatement +pub(crate) fn create_enum_from_entity(_: E, db_backend: DbBackend) -> Vec +where + E: EntityTrait, +{ + if matches!(db_backend, DbBackend::MySql | DbBackend::Sqlite) { + return Vec::new(); + } + let mut vec = Vec::new(); + for col in E::Column::iter() { + let col_def = col.def(); + let col_type = col_def.get_column_type(); + if !matches!(col_type, ColumnType::Enum(_, _)) { + continue; + } + let (name, values) = match col_type { + ColumnType::Enum(s, v) => (s.as_str(), v), + _ => unreachable!(), + }; + let stmt = Type::create() + .as_enum(Alias::new(name)) + .values(values.into_iter().map(|val| Alias::new(val.as_str()))) + .to_owned(); + vec.push(stmt); + } + vec +} + +pub(crate) fn create_table_from_entity(entity: E, db_backend: DbBackend) -> TableCreateStatement where E: EntityTrait, { @@ -21,7 +58,17 @@ where for column in E::Column::iter() { let orm_column_def = column.def(); - let types = orm_column_def.col_type.into(); + let types = match orm_column_def.col_type { + ColumnType::Enum(s, variants) => match db_backend { + DbBackend::MySql => { + ColumnType::Custom(format!("ENUM('{}')", variants.join("', '"))) + } + DbBackend::Postgres => ColumnType::Custom(s), + DbBackend::Sqlite => ColumnType::Text, + } + .into(), + _ => orm_column_def.col_type.into(), + }; let mut column_def = ColumnDef::new_with_type(column, types); if !orm_column_def.null { column_def.not_null(); @@ -121,13 +168,14 @@ where #[cfg(test)] mod tests { - use crate::{sea_query::*, tests_cfg::*, Schema}; + use crate::{sea_query::*, tests_cfg::*, DbBackend, Schema}; use pretty_assertions::assert_eq; #[test] fn test_create_table_from_entity() { assert_eq!( - Schema::create_table_from_entity(CakeFillingPrice).to_string(MysqlQueryBuilder), + Schema::create_table_from_entity(CakeFillingPrice, DbBackend::MySql) + .to_string(MysqlQueryBuilder), Table::create() .table(CakeFillingPrice) .col( diff --git a/tests/active_enum_tests.rs b/tests/active_enum_tests.rs index 568524814..ca9bf7d99 100644 --- a/tests/active_enum_tests.rs +++ b/tests/active_enum_tests.rs @@ -24,7 +24,7 @@ pub async fn insert_active_enum(db: &DatabaseConnection) -> Result<(), DbErr> { let am = ActiveModel { category: Set(None), color: Set(None), - // tea: Set(None), + tea: Set(None), ..Default::default() } .insert(db) @@ -36,14 +36,14 @@ pub async fn insert_active_enum(db: &DatabaseConnection) -> Result<(), DbErr> { id: 1, category: None, color: None, - // tea: None, + tea: None, } ); - ActiveModel { + let am = ActiveModel { category: Set(Some(Category::Big)), color: Set(Some(Color::Black)), - // tea: Set(Some(Tea::EverydayTea)), + tea: Set(Some(Tea::EverydayTea)), ..am } .save(db) @@ -55,9 +55,14 @@ pub async fn insert_active_enum(db: &DatabaseConnection) -> Result<(), DbErr> { id: 1, category: Some(Category::Big), color: Some(Color::Black), - // tea: Some(Tea::EverydayTea), + tea: Some(Tea::EverydayTea), } ); + let res = am.delete(db).await?; + + assert_eq!(res.rows_affected, 1); + assert_eq!(Entity::find().one(db).await?, None); + Ok(()) } diff --git a/tests/common/features/active_enum.rs b/tests/common/features/active_enum.rs index d7b15443e..f152f37c3 100644 --- a/tests/common/features/active_enum.rs +++ b/tests/common/features/active_enum.rs @@ -7,7 +7,7 @@ pub struct Model { pub id: i32, pub category: Option, pub color: Option, - // pub tea: Option, + pub tea: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] @@ -34,7 +34,10 @@ pub enum Color { } #[derive(Debug, Clone, PartialEq, DeriveActiveEnum)] -#[sea_orm(rs_type = "String", db_type = r#"Custom("tea".to_owned())"#)] +#[sea_orm( + rs_type = "String", + db_type = r#"Enum("tea".to_owned(), vec!["EverydayTea".to_owned(), "BreakfastTea".to_owned()])"# +)] pub enum Tea { #[sea_orm(string_value = "EverydayTea")] EverydayTea, diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 823ccdfd9..977b4b61c 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -1,9 +1,11 @@ pub use super::super::bakery_chain::*; use super::*; -use crate::common::setup::create_table; -use sea_orm::{error::*, sea_query, DatabaseConnection, DbConn, ExecResult}; -use sea_query::{ColumnDef, ForeignKeyCreateStatement}; +use crate::common::setup::{create_table, create_enum}; +use sea_orm::{ + error::*, sea_query, ConnectionTrait, DatabaseConnection, DbBackend, DbConn, ExecResult, +}; +use sea_query::{extension::postgres::Type, Alias, ColumnDef, ForeignKeyCreateStatement}; pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_log_table(db).await?; @@ -103,7 +105,26 @@ pub async fn create_self_join_table(db: &DbConn) -> Result { } pub async fn create_active_enum_table(db: &DbConn) -> Result { - let stmt = sea_query::Table::create() + let db_backend = db.get_database_backend(); + let tea_enum = Alias::new("tea"); + + let create_enum_stmts = match db_backend { + DbBackend::MySql | DbBackend::Sqlite => Vec::new(), + DbBackend::Postgres => vec![Type::create() + .as_enum(tea_enum.clone()) + .values(vec![Alias::new("EverydayTea"), Alias::new("BreakfastTea")]) + .to_owned()], + }; + + create_enum(db, &create_enum_stmts, ActiveEnum).await?; + + let mut tea_col = ColumnDef::new(active_enum::Column::Tea); + match db_backend { + DbBackend::MySql => tea_col.custom(Alias::new("ENUM('EverydayTea', 'BreakfastTea')")), + DbBackend::Sqlite => tea_col.text(), + DbBackend::Postgres => tea_col.custom(tea_enum), + }; + let create_table_stmt = sea_query::Table::create() .table(active_enum::Entity) .col( ColumnDef::new(active_enum::Column::Id) @@ -114,8 +135,8 @@ pub async fn create_active_enum_table(db: &DbConn) -> Result ) .col(ColumnDef::new(active_enum::Column::Category).string_len(1)) .col(ColumnDef::new(active_enum::Column::Color).integer()) - // .col(ColumnDef::new(active_enum::Column::Tea).custom(Alias::new("tea"))) + .col(&mut tea_col) .to_owned(); - create_table(db, &stmt, ActiveEnum).await + create_table(db, &create_table_stmt, ActiveEnum).await } diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index 615de234c..263e19f2f 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -1,9 +1,12 @@ use pretty_assertions::assert_eq; use sea_orm::{ - ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, DbBackend, DbConn, DbErr, - EntityTrait, ExecResult, Schema, Statement, + ColumnTrait, ColumnType, ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, + DbBackend, DbConn, DbErr, EntityTrait, ExecResult, Iterable, Schema, Statement, +}; +use sea_query::{ + extension::postgres::{Type, TypeCreateStatement}, + Alias, Table, TableCreateStatement, }; -use sea_query::{Alias, Table, TableCreateStatement}; pub async fn setup(base_url: &str, db_name: &str) -> DatabaseConnection { let db = if cfg!(feature = "sqlx-mysql") { @@ -74,6 +77,52 @@ pub async fn tear_down(base_url: &str, db_name: &str) { }; } +pub async fn create_enum( + db: &DbConn, + creates: &[TypeCreateStatement], + entity: E, +) -> Result<(), DbErr> +where + E: EntityTrait, +{ + let builder = db.get_database_backend(); + if builder == DbBackend::Postgres { + for col in E::Column::iter() { + let col_def = col.def(); + let col_type = col_def.get_column_type(); + if !matches!(col_type, ColumnType::Enum(_, _)) { + continue; + } + let name = match col_type { + ColumnType::Enum(s, _) => s.as_str(), + _ => unreachable!(), + }; + let drop_type_stmt = Type::drop() + .name(Alias::new(name)) + .if_exists() + .cascade() + .to_owned(); + let stmt = builder.build(&drop_type_stmt); + db.execute(stmt).await?; + } + } + + let expect_stmts: Vec = creates.iter().map(|stmt| builder.build(stmt)).collect(); + let create_from_entity_stmts: Vec = + Schema::create_enum_from_entity(entity, db.get_database_backend()) + .iter() + .map(|stmt| builder.build(stmt)) + .collect(); + + assert_eq!(expect_stmts, create_from_entity_stmts); + + for stmt in expect_stmts { + db.execute(stmt).await.map(|_| ())?; + } + + Ok(()) +} + pub async fn create_table( db: &DbConn, create: &TableCreateStatement, @@ -95,7 +144,10 @@ where let stmt = builder.build(create); assert_eq!( - builder.build(&Schema::create_table_from_entity(entity)), + builder.build(&Schema::create_table_from_entity( + entity, + db.get_database_backend() + )), stmt ); db.execute(stmt).await