From 37de81343bcbfb786817142468321ac67bffbf48 Mon Sep 17 00:00:00 2001 From: Liber Wang Date: Sat, 20 Aug 2022 15:48:58 +0800 Subject: [PATCH] support `extra` in entity macro field --- sea-orm-macros/src/derives/entity_model.rs | 6 ++++++ src/entity/column.rs | 8 ++++++++ src/schema/entity.rs | 3 +++ tests/common/features/check.rs | 4 ++-- tests/common/features/schema.rs | 12 ++++++++++-- 5 files changed, 29 insertions(+), 4 deletions(-) diff --git a/sea-orm-macros/src/derives/entity_model.rs b/sea-orm-macros/src/derives/entity_model.rs index 959509f2c..55b48ca6a 100644 --- a/sea-orm-macros/src/derives/entity_model.rs +++ b/sea-orm-macros/src/derives/entity_model.rs @@ -90,6 +90,7 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res let mut nullable = false; let mut default_value = None; let mut default_expr = None; + let mut extra = None; let mut created_at = false; let mut updated_at = false; let mut indexed = false; @@ -171,6 +172,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res format!("Invalid enum_name {:?}", nv.lit), )); } + } else if name == "extra" { + extra = Some(nv.lit.to_owned()); } } } @@ -312,6 +315,9 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res if let Some(default_expr) = default_expr { match_row = quote! { #match_row.default_expr(#default_expr) }; } + if let Some(extra) = extra { + match_row = quote! { #match_row.extra(#extra.into()) }; + } columns_trait.push(match_row); } } diff --git a/src/entity/column.rs b/src/entity/column.rs index 69f5ef319..7a2997af5 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -12,6 +12,7 @@ pub struct ColumnDef { pub(crate) created_at: bool, pub(crate) updated_at: bool, pub(crate) default_value: Option, + pub(crate) extra: Option, } /// The type of column as defined in the SQL format @@ -314,6 +315,7 @@ impl ColumnType { created_at: false, updated_at: false, default_value: None, + extra: None, } } @@ -370,6 +372,12 @@ impl ColumnDef { self } + /// Set the extra + pub fn extra(mut self, value: String) -> Self { + self.extra = Some(value); + self + } + /// Get [ColumnType] as reference pub fn get_column_type(&self) -> &ColumnType { &self.col_type diff --git a/src/schema/entity.rs b/src/schema/entity.rs index c09862be2..77e827a8b 100644 --- a/src/schema/entity.rs +++ b/src/schema/entity.rs @@ -143,6 +143,9 @@ where if let Some(value) = orm_column_def.default_value { column_def.default(value); } + if let Some(value) = orm_column_def.extra { + column_def.extra(value); + } for primary_key in E::PrimaryKey::iter() { if column.to_string() == primary_key.into_column().to_string() { if E::PrimaryKey::auto_increment() { diff --git a/tests/common/features/check.rs b/tests/common/features/check.rs index 860d007bc..a24840cd0 100644 --- a/tests/common/features/check.rs +++ b/tests/common/features/check.rs @@ -7,9 +7,9 @@ pub struct Model { pub id: i32, pub pay: String, pub amount: f64, - #[sea_orm(updated_at, nullable)] + #[sea_orm(updated_at, nullable, extra = "DEFAULT CURRENT_TIMESTAMP")] pub updated_at: DateTimeWithTimeZone, - #[sea_orm(created_at, nullable)] + #[sea_orm(created_at, nullable, extra = "DEFAULT CURRENT_TIMESTAMP")] pub created_at: DateTimeWithTimeZone, } diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 9b375b352..a3bd2609a 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -340,8 +340,16 @@ pub async fn create_check_table(db: &DbConn) -> Result { ) .col(ColumnDef::new(check::Column::Pay).string().not_null()) .col(ColumnDef::new(check::Column::Amount).double().not_null()) - .col(ColumnDef::new(check::Column::UpdatedAt).timestamp_with_time_zone()) - .col(ColumnDef::new(check::Column::CreatedAt).timestamp_with_time_zone()) + .col( + ColumnDef::new(check::Column::UpdatedAt) + .timestamp_with_time_zone() + .extra("DEFAULT CURRENT_TIMESTAMP".into()), + ) + .col( + ColumnDef::new(check::Column::CreatedAt) + .timestamp_with_time_zone() + .extra("DEFAULT CURRENT_TIMESTAMP".into()), + ) .to_owned(); create_table(db, &stmt, Check).await