Skip to content

Commit

Permalink
add with_schema_name and with_table_name for mysql (maxcountryman#43
Browse files Browse the repository at this point in the history
)
  • Loading branch information
DirectorX authored and vinnymeller committed Aug 27, 2024
1 parent 1a8b2e3 commit 9427ea6
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
53 changes: 51 additions & 2 deletions sqlx-store/src/mysql_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,38 @@ impl MySqlStore {
}
}

/// Set the session table schema name with the provided name.
pub fn with_schema_name(mut self, schema_name: impl AsRef<str>) -> Result<Self, String> {
let schema_name = schema_name.as_ref();
if !is_valid_identifier(schema_name) {
return Err(format!(
"Invalid schema name '{}'. Schema names must start with a letter or underscore \
(including letters with diacritical marks and non-Latin letters).Subsequent \
characters can be letters, underscores, digits (0-9), or dollar signs ($).",
schema_name
));
}

schema_name.clone_into(&mut self.schema_name);
Ok(self)
}

/// Set the session table name with the provided name.
pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Result<Self, String> {
let table_name = table_name.as_ref();
if !is_valid_identifier(table_name) {
return Err(format!(
"Invalid table name '{}'. Table names must start with a letter or underscore \
(including letters with diacritical marks and non-Latin letters).Subsequent \
characters can be letters, underscores, digits (0-9), or dollar signs ($).",
table_name
));
}

table_name.clone_into(&mut self.table_name);
Ok(self)
}

/// Migrate the session schema.
///
/// # Examples
Expand Down Expand Up @@ -112,7 +144,7 @@ impl MySqlStore {
table_name = self.table_name
);
sqlx::query(&query)
.bind(&record.id.to_string())
.bind(record.id.to_string())
.bind(rmp_serde::to_vec(&record).map_err(SqlxStoreError::Encode)?)
.bind(convert_expiry_date(record.expiry_date))
.execute(conn)
Expand Down Expand Up @@ -193,11 +225,28 @@ impl SessionStore for MySqlStore {
table_name = self.table_name
);
sqlx::query(&query)
.bind(&session_id.to_string())
.bind(session_id.to_string())
.execute(&self.pool)
.await
.map_err(SqlxStoreError::Sqlx)?;

Ok(())
}
}

/// A valid MySQL identifier must start with a letter or underscore
/// (including letters with diacritical marks and non-Latin letters). Subsequent
/// characters in an identifier or keyword can be letters, underscores, digits
/// (0-9), or dollar signs ($).
/// See https://dev.mysql.com/doc/refman/8.4/en/identifiers.html for details.
fn is_valid_identifier(name: &str) -> bool {
!name.is_empty()
&& name
.chars()
.next()
.map(|c| c.is_alphabetic() || c == '_')
.unwrap_or_default()
&& name
.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == '$')
}
4 changes: 2 additions & 2 deletions sqlx-store/src/postgres_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl PostgresStore {
);

sqlx::query(&query)
.bind(&record.id.to_string())
.bind(record.id.to_string())
.bind(rmp_serde::to_vec(&record).map_err(SqlxStoreError::Encode)?)
.bind(convert_expiry_date(record.expiry_date))
.execute(conn)
Expand Down Expand Up @@ -240,7 +240,7 @@ impl SessionStore for PostgresStore {
table_name = self.table_name
);
sqlx::query(&query)
.bind(&session_id.to_string())
.bind(session_id.to_string())
.execute(&self.pool)
.await
.map_err(SqlxStoreError::Sqlx)?;
Expand Down
4 changes: 2 additions & 2 deletions sqlx-store/src/sqlite_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl SqliteStore {
table_name = self.table_name
);
sqlx::query(&query)
.bind(&record.id.to_string())
.bind(record.id.to_string())
.bind(rmp_serde::to_vec(record).map_err(SqlxStoreError::Encode)?)
.bind(convert_expiry_date(record.expiry_date))
.execute(conn)
Expand Down Expand Up @@ -179,7 +179,7 @@ impl SessionStore for SqliteStore {
self.table_name
);
sqlx::query(&query)
.bind(&session_id.to_string())
.bind(session_id.to_string())
.execute(&self.pool)
.await
.map_err(SqlxStoreError::Sqlx)?;
Expand Down

0 comments on commit 9427ea6

Please sign in to comment.