Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto set timestamp column when update & insert #854

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions sea-orm-macros/src/derives/entity_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Res
let mut default_expr = None;
let mut select_as = None;
let mut save_as = None;
let mut extra = None;
let mut created_at = false;
let mut updated_at = false;
let mut indexed = false;
let mut ignore = false;
let mut unique = false;
Expand Down Expand Up @@ -191,6 +194,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Res
format!("Invalid save_as {:?}", nv.lit),
));
}
} else if name == "extra" {
extra = Some(nv.lit.to_owned());
}
}
}
Expand All @@ -208,6 +213,10 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Res
indexed = true;
} else if name == "unique" {
unique = true;
} else if name == "created_at" {
created_at = true;
} else if name == "updated_at" {
updated_at = true;
billy1624 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -337,12 +346,21 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Res
if unique {
match_row = quote! { #match_row.unique() };
}
if created_at {
match_row = quote! { #match_row.created_at() };
}
if updated_at {
match_row = quote! { #match_row.updated_at() };
}
if let Some(default_value) = default_value {
match_row = quote! { #match_row.default_value(#default_value) };
}
if let Some(default_expr) = default_expr {
match_row = quote! { #match_row.default_expr(#default_expr) };
}
if let Some(extra) = extra {
match_row = quote! { #match_row.extra(#extra.into()) };
}
columns_trait.push(match_row);
}
}
Expand Down
24 changes: 24 additions & 0 deletions src/entity/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ pub struct ColumnDef {
pub(crate) null: bool,
pub(crate) unique: bool,
pub(crate) indexed: bool,
pub(crate) created_at: bool,
pub(crate) updated_at: bool,
pub(crate) default_value: Option<Value>,
pub(crate) extra: Option<String>,
}

macro_rules! bind_oper {
Expand Down Expand Up @@ -295,7 +298,10 @@ impl ColumnTypeTrait for ColumnType {
null: false,
unique: false,
indexed: false,
created_at: false,
updated_at: false,
default_value: None,
extra: None,
}
}

Expand Down Expand Up @@ -335,6 +341,18 @@ impl ColumnDef {
self
}

/// Set the `created_at` field to `true`
pub fn created_at(mut self) -> Self {
self.created_at = true;
self
}

/// Set the `updated_at` field to `true`
pub fn updated_at(mut self) -> Self {
self.updated_at = true;
self
}

/// Set the default value
pub fn default_value<T>(mut self, value: T) -> Self
where
Expand All @@ -344,6 +362,12 @@ impl ColumnDef {
self
}

/// Set the extra
pub fn extra(mut self, value: String) -> Self {
self.extra = Some(value);
self
}

/// Get [ColumnType] as reference
pub fn get_column_type(&self) -> &ColumnType {
&self.col_type
Expand Down
14 changes: 12 additions & 2 deletions src/query/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,23 @@ where
let columns_empty = self.columns.is_empty();
for (idx, col) in <A::Entity as EntityTrait>::Column::iter().enumerate() {
let av = am.take(col);
let av_has_val = av.is_set() || av.is_unchanged();
let col_def = col.def();
let insert_timestamp_col = col_def.created_at || col_def.updated_at;
let av_has_val = av.is_set() || av.is_unchanged() || insert_timestamp_col;
if columns_empty {
self.columns.push(av_has_val);
} else if self.columns[idx] != av_has_val {
panic!("columns mismatch");
}
if av_has_val {

if insert_timestamp_col {
columns.push(col);
let val = match av.into_value() {
Some(v) => Expr::value(v),
None => Expr::current_timestamp().into(),
};
values.push(val);
} else if av_has_val {
columns.push(col);
values.push(col.save_as(Expr::val(av.into_value().unwrap())));
}
Expand Down
3 changes: 3 additions & 0 deletions src/query/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,13 @@ where
if <A::Entity as EntityTrait>::PrimaryKey::from_column(col).is_some() {
continue;
}
let col_def = col.def();
let av = self.model.get(col);
if av.is_set() {
let expr = col.save_as(Expr::val(av.into_value().unwrap()));
self.query.value(col, expr);
} else if col_def.updated_at {
self.query.value(col, Expr::current_timestamp());
}
}
self
Expand Down
3 changes: 3 additions & 0 deletions src/schema/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ where
if let Some(value) = orm_column_def.default_value {
column_def.default(value);
}
if let Some(value) = orm_column_def.extra {
column_def.default(value);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found a typo

Suggested change
if let Some(value) = orm_column_def.extra {
column_def.default(value);
}
if let Some(value) = orm_column_def.extra {
column_def.extra(value);
}

Copy link
Contributor Author

@liberwang1013 liberwang1013 Jan 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @billy1624 thanks for your comment,
what about replaced with column_def.default(SimpleExpr::CustomWithExpr(value, vec![]));, I found the generated default expr in schema is going to quoted with ' in this(column_def.default(value);) way.

for primary_key in E::PrimaryKey::iter() {
if column.to_string() == primary_key.into_column().to_string() {
if E::PrimaryKey::auto_increment() {
Expand Down
154 changes: 154 additions & 0 deletions tests/check_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
pub mod common;

pub use common::{features::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection};
use std::{thread, time::Duration};

#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
async fn main() -> Result<(), DbErr> {
let ctx = TestContext::new("check_tests").await;
create_tables(&ctx.db).await?;
insert_check(&ctx.db).await?;
update_check(&ctx.db).await?;
ctx.delete().await;

Ok(())
}

pub async fn insert_check(db: &DatabaseConnection) -> Result<(), DbErr> {
use check::*;

let timestamp = "2022-08-03T00:00:00+08:00"
.parse::<DateTimeWithTimeZone>()
.unwrap();

let model = ActiveModel {
pay: Set("Billy".to_owned()),
amount: Set(100.0),
..Default::default()
}
.insert(db)
.await?;

assert_eq!(
model,
Entity::find()
.filter(Column::Id.eq(1))
.one(db)
.await?
.unwrap()
);

Check::insert_many([
ActiveModel {
pay: Set("Billy".to_owned()),
amount: Set(100.0),
..Default::default()
},
ActiveModel {
pay: Set("Billy".to_owned()),
amount: Set(100.0),
created_at: Set(timestamp.clone()),
..Default::default()
},
ActiveModel {
pay: Set("Billy".to_owned()),
amount: Set(100.0),
updated_at: Set(timestamp.clone()),
..Default::default()
},
ActiveModel {
pay: Set("Billy".to_owned()),
amount: Set(100.0),
updated_at: Set(timestamp.clone()),
created_at: Set(timestamp.clone()),
..Default::default()
},
])
.exec(db)
.await?;

assert_eq!(5, Entity::find().count(db).await?);

assert_eq!(
timestamp,
Entity::find()
.filter(Column::Id.eq(3))
.one(db)
.await?
.unwrap()
.created_at
);

assert_eq!(
timestamp,
Entity::find()
.filter(Column::Id.eq(4))
.one(db)
.await?
.unwrap()
.updated_at
);

let model = Entity::find()
.filter(Column::Id.eq(5))
.one(db)
.await?
.unwrap();

assert_eq!(timestamp, model.updated_at);
assert_eq!(timestamp, model.created_at);

Ok(())
}

pub async fn update_check(db: &DatabaseConnection) -> Result<(), DbErr> {
use check::*;

let timestamp = "2022-08-03T16:24:00+08:00"
.parse::<DateTimeWithTimeZone>()
.unwrap();

let model = Entity::find()
.filter(Column::Id.eq(1))
.one(db)
.await?
.unwrap();

thread::sleep(Duration::from_secs(1));

let updated_model = ActiveModel {
amount: Set(128.0),
..model.clone().into_active_model()
}
.update(db)
.await?;

assert_eq!(128.0, updated_model.amount);
assert!(model.updated_at < updated_model.updated_at);
assert!(model.created_at == updated_model.created_at);

let model = Entity::find()
.filter(Column::Id.eq(1))
.one(db)
.await?
.unwrap();

let updated_model = ActiveModel {
updated_at: Set(timestamp.clone()),
..model.clone().into_active_model()
}
.update(db)
.await?;

assert_eq!(timestamp.clone(), updated_model.updated_at);
assert!(model.created_at == updated_model.created_at);

Ok(())
}
19 changes: 19 additions & 0 deletions tests/common/features/check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "check")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub pay: String,
pub amount: f64,
#[sea_orm(updated_at, nullable, extra = "DEFAULT CURRENT_TIMESTAMP")]
pub updated_at: DateTimeWithTimeZone,
#[sea_orm(created_at, nullable, extra = "DEFAULT CURRENT_TIMESTAMP")]
pub created_at: DateTimeWithTimeZone,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl ActiveModelBehavior for ActiveModel {}
2 changes: 2 additions & 0 deletions tests/common/features/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod active_enum;
pub mod active_enum_child;
pub mod applog;
pub mod byte_primary_key;
pub mod check;
pub mod collection;
pub mod custom_active_model;
pub mod edit_log;
Expand All @@ -23,6 +24,7 @@ pub use active_enum::Entity as ActiveEnum;
pub use active_enum_child::Entity as ActiveEnumChild;
pub use applog::Entity as Applog;
pub use byte_primary_key::Entity as BytePrimaryKey;
pub use check::Entity as Check;
pub use collection::Entity as Collection;
pub use edit_log::Entity as EditLog;
pub use event_trigger::Entity as EventTrigger;
Expand Down
28 changes: 28 additions & 0 deletions tests/common/features/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> {
create_pi_table(db).await?;
create_uuid_fmt_table(db).await?;
create_edit_log_table(db).await?;
create_check_table(db).await?;

if DbBackend::Postgres == db_backend {
create_collection_table(db).await?;
Expand Down Expand Up @@ -498,3 +499,30 @@ pub async fn create_edit_log_table(db: &DbConn) -> Result<ExecResult, DbErr> {

create_table(db, &stmt, EditLog).await
}

pub async fn create_check_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let stmt = sea_query::Table::create()
.table(check::Entity)
.col(
ColumnDef::new(check::Column::Id)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(ColumnDef::new(check::Column::Pay).string().not_null())
.col(ColumnDef::new(check::Column::Amount).double().not_null())
.col(
ColumnDef::new(check::Column::UpdatedAt)
.timestamp_with_time_zone()
.default("CURRENT_TIMESTAMP"),
)
.col(
ColumnDef::new(check::Column::CreatedAt)
.timestamp_with_time_zone()
.default("CURRENT_TIMESTAMP"),
)
.to_owned();

create_table(db, &stmt, Check).await
}