diff --git a/src/db/spanner/batch.rs b/src/db/spanner/batch.rs index e145ea9a31..b4a124376c 100644 --- a/src/db/spanner/batch.rs +++ b/src/db/spanner/batch.rs @@ -11,7 +11,7 @@ use protobuf::{ use uuid::Uuid; use super::models::{Result, SpannerDb, DEFAULT_BSO_TTL, PRETOUCH_TS}; -use super::support::{null_value, struct_type_field, ToSpannerValue}; +use super::support::{as_type, null_value, struct_type_field, ToSpannerValue}; use crate::{ db::{params, results, util::to_rfc3339, DbError, DbErrorKind, BATCH_LIFETIME}, web::{extractors::HawkIdentifier, tags::Tags}, @@ -35,20 +35,20 @@ pub async fn create_async( id: batch_id, }; - db.sql( - "INSERT INTO batches (fxa_uid, fxa_kid, collection_id, batch_id, expiry) - VALUES (@fxa_uid, @fxa_kid, @collection_id, @batch_id, @expiry)", - )? - .params(params! { + let (sqlparams, mut sqlparam_types) = params! { "fxa_uid" => params.user_id.fxa_uid.clone(), "fxa_kid" => params.user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), + "collection_id" => collection_id, "batch_id" => new_batch.id.clone(), "expiry" => to_rfc3339(timestamp + BATCH_LIFETIME)?, - }) - .param_types(param_types! { - "expiry" => TypeCode::TIMESTAMP, - }) + }; + sqlparam_types.insert("expiry".to_owned(), as_type(TypeCode::TIMESTAMP)); + db.sql( + "INSERT INTO batches (fxa_uid, fxa_kid, collection_id, batch_id, expiry) + VALUES (@fxa_uid, @fxa_kid, @collection_id, @batch_id, @expiry)", + )? + .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&db.conn) .await?; @@ -116,6 +116,12 @@ pub async fn get_async( params: params::GetBatch, ) -> Result> { let collection_id = db.get_collection_id_async(¶ms.collection).await?; + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid.clone(), + "fxa_kid" => params.user_id.fxa_kid.clone(), + "collection_id" => collection_id, + "batch_id" => params.id.clone(), + }; let batch = db .sql( "SELECT 1 @@ -126,12 +132,8 @@ pub async fn get_async( AND batch_id = @batch_id AND expiry > CURRENT_TIMESTAMP()", )? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid.clone(), - "fxa_kid" => params.user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), - "batch_id" => params.id.clone(), - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&db.conn)? .one_or_none() .await? @@ -141,6 +143,12 @@ pub async fn get_async( pub async fn delete_async(db: &SpannerDb, params: params::DeleteBatch) -> Result<()> { let collection_id = db.get_collection_id_async(¶ms.collection).await?; + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid.clone(), + "fxa_kid" => params.user_id.fxa_kid.clone(), + "collection_id" => collection_id, + "batch_id" => params.id, + }; // Also deletes child batch_bsos rows (INTERLEAVE IN PARENT batches ON // DELETE CASCADE) db.sql( @@ -150,12 +158,8 @@ pub async fn delete_async(db: &SpannerDb, params: params::DeleteBatch) -> Result AND collection_id = @collection_id AND batch_id = @batch_id", )? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid.clone(), - "fxa_kid" => params.user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), - "batch_id" => params.id, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&db.conn) .await?; Ok(()) @@ -181,17 +185,17 @@ pub async fn commit_async( // supplied in this batch let mut timer2 = db.metrics.clone(); timer2.start_timer("storage.spanner.apply_batch_update", None); + let (sqlparams, mut sqlparam_types) = params! 
{ + "fxa_uid" => params.user_id.fxa_uid.clone(), + "fxa_kid" => params.user_id.fxa_kid.clone(), + "collection_id" => collection_id, + "batch_id" => params.batch.id.clone(), + "timestamp" => as_rfc3339.clone(), + }; + sqlparam_types.insert("timestamp".to_owned(), as_type(TypeCode::TIMESTAMP)); db.sql(include_str!("batch_commit_update.sql"))? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid.clone(), - "fxa_kid" => params.user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), - "batch_id" => params.batch.id.clone(), - "timestamp" => as_rfc3339.clone(), - }) - .param_types(param_types! { - "timestamp" => TypeCode::TIMESTAMP, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&db.conn) .await?; } @@ -199,21 +203,20 @@ pub async fn commit_async( { // Then INSERT INTO SELECT remaining rows from this batch into the bsos // table (that didn't already exist there) + let (sqlparams, mut sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid.clone(), + "fxa_kid" => params.user_id.fxa_kid.clone(), + "collection_id" => collection_id, + "batch_id" => params.batch.id.clone(), + "timestamp" => as_rfc3339, + "default_bso_ttl" => DEFAULT_BSO_TTL, + }; + sqlparam_types.insert("timestamp".to_owned(), as_type(TypeCode::TIMESTAMP)); let mut timer3 = db.metrics.clone(); timer3.start_timer("storage.spanner.apply_batch_insert", None); db.sql(include_str!("batch_commit_insert.sql"))? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid.clone(), - "fxa_kid" => params.user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), - "batch_id" => params.batch.id.clone(), - "timestamp" => as_rfc3339, - "default_bso_ttl" => DEFAULT_BSO_TTL.to_string(), - }) - .param_types(param_types! { - "timestamp" => TypeCode::TIMESTAMP, - "default_bso_ttl" => TypeCode::INT64, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&db.conn) .await?; } @@ -288,17 +291,17 @@ pub async fn do_append_async( .unwrap_or_else(|| "UNKNOWN".to_string()), ); - let bso_ids = bsos.iter().map(|pbso| pbso.id.clone()); - let mut params = params! { + let bso_ids = bsos + .iter() + .map(|pbso| pbso.id.clone()) + .collect::>(); + let (sqlparams, sqlparam_types) = params! { "fxa_uid" => user_id.fxa_uid.clone(), "fxa_kid" => user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), + "collection_id" => collection_id, "batch_id" => batch.id.clone(), + "ids" => bso_ids, }; - params.insert( - "ids".to_owned(), - bso_ids.collect::>().to_spanner_value(), - ); let mut existing_stream = db .sql( "SELECT batch_bso_id @@ -309,7 +312,8 @@ pub async fn do_append_async( AND batch_id=@batch_id AND batch_bso_id in UNNEST(@ids);", )? - .params(params) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&db.conn)?; while let Some(row) = existing_stream.next_async().await { let row = row?; @@ -410,16 +414,6 @@ pub async fn do_append_async( } } - let param_types = param_types! { // ### TODO: this should be normalized to one instance. - "fxa_uid" => TypeCode::STRING, - "fxa_kid"=> TypeCode::STRING, - "collection_id"=> TypeCode::INT64, - "batch_id"=> TypeCode::STRING, - "batch_bso_id"=> TypeCode::STRING, - "sortindex"=> TypeCode::INT64, - "payload"=> TypeCode::STRING, - "ttl"=> TypeCode::INT64, - }; let fields = vec![ ("fxa_uid", TypeCode::STRING), ("fxa_kid", TypeCode::STRING), @@ -477,24 +471,27 @@ pub async fn do_append_async( if !update.is_empty() { for val in update { let mut fields = Vec::new(); - let mut params = params! 
{ + let (mut params, mut param_types) = params! { "fxa_uid" => user_id.fxa_uid.clone(), "fxa_kid" => user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), + "collection_id" => collection_id, "batch_id" => batch.id.clone(), "batch_bso_id" => val.bso_id, }; if let Some(sortindex) = val.sortindex { fields.push("sortindex"); params.insert("sortindex".to_owned(), sortindex.to_spanner_value()); + param_types.insert("sortindex".to_owned(), sortindex.spanner_type()); } if let Some(payload) = val.payload { fields.push("payload"); params.insert("payload".to_owned(), payload.to_spanner_value()); + param_types.insert("payload".to_owned(), payload.spanner_type()); }; if let Some(ttl) = val.ttl { fields.push("ttl"); params.insert("ttl".to_owned(), ttl.to_spanner_value()); + param_types.insert("ttl".to_owned(), ttl.spanner_type()); } if fields.is_empty() { continue; @@ -533,10 +530,10 @@ async fn pretouch_collection_async( user_id: &HawkIdentifier, collection_id: i32, ) -> Result<()> { - let mut sqlparams = params! { + let (mut sqlparams, mut sqlparam_types) = params! { "fxa_uid" => user_id.fxa_uid.clone(), "fxa_kid" => user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), + "collection_id" => collection_id, }; let result = db .sql( @@ -547,6 +544,7 @@ async fn pretouch_collection_async( AND collection_id = @collection_id", )? .params(sqlparams.clone()) + .param_types(sqlparam_types.clone()) .execute_async(&db.conn)? .one_or_none() .await?; @@ -555,6 +553,7 @@ async fn pretouch_collection_async( "modified".to_owned(), PRETOUCH_TS.to_owned().to_spanner_value(), ); + sqlparam_types.insert("modified".to_owned(), as_type(TypeCode::TIMESTAMP)); let sql = if db.quota.enabled { "INSERT INTO user_collections (fxa_uid, fxa_kid, collection_id, modified, count, total_bytes) VALUES (@fxa_uid, @fxa_kid, @collection_id, @modified, 0, 0)" @@ -564,9 +563,7 @@ async fn pretouch_collection_async( }; db.sql(sql)? .params(sqlparams) - .param_types(param_types! { - "modified" => TypeCode::TIMESTAMP, - }) + .param_types(sqlparam_types) .execute_dml_async(&db.conn) .await?; } diff --git a/src/db/spanner/macros.rs b/src/db/spanner/macros.rs index 479ce421b4..0b494382f3 100644 --- a/src/db/spanner/macros.rs +++ b/src/db/spanner/macros.rs @@ -6,28 +6,144 @@ macro_rules! params { ($($key:expr => $value:expr),*) => { { let _cap = params!(@count $($key),*); - let mut _map = ::std::collections::HashMap::with_capacity(_cap); + let mut _value_map = ::std::collections::HashMap::with_capacity(_cap); + let mut _type_map = ::std::collections::HashMap::with_capacity(_cap); $( - _map.insert($key.to_owned(), ToSpannerValue::to_spanner_value(&$value)); + _value_map.insert($key.to_owned(), ToSpannerValue::to_spanner_value(&$value)); + _type_map.insert($key.to_owned(), ToSpannerValue::spanner_type(&$value)); )* - _map + (_value_map, _type_map) } }; } -macro_rules! 
param_types { - (@single $($x:tt)*) => (()); - (@count $($rest:expr),*) => (<[()]>::len(&[$(param_types!(@single $rest)),*])); +#[test] +fn test_params_macro() { + use crate::db::spanner::support::ToSpannerValue; + use googleapis_raw::spanner::v1::type_pb::{Type, TypeCode}; + use protobuf::{ + well_known_types::{ListValue, Value}, + RepeatedField, + }; + use std::collections::HashMap; - ($($key:expr => $value:expr,)+) => { param_types!($($key => $value),+) }; - ($($key:expr => $value:expr),*) => { - { - let _cap = param_types!(@count $($key),*); - let mut _map = ::std::collections::HashMap::with_capacity(_cap); - $( - _map.insert($key.to_owned(), crate::db::spanner::support::as_type($value)); - )* - _map - } + let (sqlparams, sqlparam_types) = params! { + "String param" => "I am a String".to_owned(), + "i32 param" => 100i32, + "u32 param" => 100u32, + "Vec param" => vec!["I am a String".to_owned()], + "Vec param" => vec![100i32], + "Vec param" => vec![100u32], + }; + + let mut expected_sqlparams = HashMap::new(); + let string_value = { + let mut t = Value::new(); + t.set_string_value("I am a String".to_owned()); + t + }; + expected_sqlparams.insert("String param".to_owned(), string_value.clone()); + + let i32_value = { + let mut t = Value::new(); + t.set_string_value(100i32.to_string()); + t + }; + expected_sqlparams.insert("i32 param".to_owned(), i32_value.clone()); + + let u32_value = { + let mut t = Value::new(); + t.set_string_value(100u32.to_string()); + t + }; + expected_sqlparams.insert("u32 param".to_owned(), u32_value.clone()); + + let string_vec_value = { + let mut list = ListValue::new(); + list.set_values(RepeatedField::from_vec(vec![string_value])); + let mut value = Value::new(); + value.set_list_value(list); + value }; + expected_sqlparams.insert("Vec param".to_owned(), string_vec_value); + + let i32_vec_value = { + let mut list = ListValue::new(); + list.set_values(RepeatedField::from_vec(vec![i32_value])); + let mut value = Value::new(); + value.set_list_value(list); + value + }; + expected_sqlparams.insert("Vec param".to_owned(), i32_vec_value); + + let u32_vec_value = { + let mut list = ListValue::new(); + list.set_values(RepeatedField::from_vec(vec![u32_value])); + let mut value = Value::new(); + value.set_list_value(list); + value + }; + expected_sqlparams.insert("Vec param".to_owned(), u32_vec_value); + + let mut expected_sqlparam_types = HashMap::new(); + + let string_type = { + let mut t = Type::new(); + t.set_code(TypeCode::STRING); + t + }; + expected_sqlparam_types.insert("String param".to_owned(), string_type); + + let i32_type = { + let mut t = Type::new(); + t.set_code(TypeCode::INT64); + t + }; + expected_sqlparam_types.insert("i32 param".to_owned(), i32_type); + + let u32_type = { + let mut t = Type::new(); + t.set_code(TypeCode::INT64); + t + }; + expected_sqlparam_types.insert("u32 param".to_owned(), u32_type); + + let string_vec_type = { + let mut element_type = Type::new(); + element_type.set_code(TypeCode::STRING); + + let mut vec_type = Type::new(); + vec_type.set_code(TypeCode::ARRAY); + vec_type.set_array_element_type(element_type); + + vec_type + }; + expected_sqlparam_types.insert("Vec param".to_owned(), string_vec_type); + + let i32_vec_type = { + let mut element_type = Type::new(); + element_type.set_code(TypeCode::INT64); + + let mut vec_type = Type::new(); + vec_type.set_code(TypeCode::ARRAY); + vec_type.set_array_element_type(element_type); + + vec_type + }; + expected_sqlparam_types.insert("Vec param".to_owned(), i32_vec_type); + + let 
u32_vec_type = {
+        let mut element_type = Type::new();
+        element_type.set_code(TypeCode::INT64);
+
+        let mut vec_type = Type::new();
+        vec_type.set_code(TypeCode::ARRAY);
+        vec_type.set_array_element_type(element_type);
+
+        vec_type
+    };
+    expected_sqlparam_types.insert("Vec<u32> param".to_owned(), u32_vec_type);
+
+    assert_eq!(expected_sqlparams, sqlparams);
+    assert_eq!(expected_sqlparam_types, sqlparam_types);
 }
diff --git a/src/db/spanner/models.rs b/src/db/spanner/models.rs
index 2dd1591b6b..c4dbb190b8 100644
--- a/src/db/spanner/models.rs
+++ b/src/db/spanner/models.rs
@@ -144,13 +144,15 @@ impl SpannerDb {
         if let Some(id) = self.coll_cache.get_id(name).await {
             return Ok(id);
         }
+        let (sqlparams, sqlparam_types) = params! { "name" => name.to_string() };
         let result = self
             .sql(
                 "SELECT collection_id
                    FROM collections
                   WHERE name = @name",
             )?
-            .params(params! {"name" => name.to_string()})
+            .params(sqlparams)
+            .param_types(sqlparam_types)
             .execute_async(&self.conn)?
             .one_or_none()
             .await?
@@ -185,15 +187,17 @@ impl SpannerDb {
             .parse::<i32>()
             .map_err(|e| DbErrorKind::Integrity(e.to_string()))?;
         let id = FIRST_CUSTOM_COLLECTION_ID.max(max + 1);
+        let (sqlparams, sqlparam_types) = params! {
+            "name" => name.to_string(),
+            "collection_id" => id,
+        };
         self.sql(
             "INSERT INTO collections (collection_id, name)
              VALUES (@collection_id, @name)",
         )?
-        .params(params! {
-            "name" => name.to_string(),
-            "collection_id" => id.to_string(),
-        })
+        .params(sqlparams)
+        .param_types(sqlparam_types)
         .execute_dml_async(&self.conn)
         .await?;
         Ok(id)
@@ -260,6 +264,13 @@ impl SpannerDb {
         {
             Err(DbError::internal("Can't escalate read-lock to write-lock"))?
         }
+        let (sqlparams, mut sqlparam_types) = params! {
+            "fxa_uid" => params.user_id.fxa_uid.clone(),
+            "fxa_kid" => params.user_id.fxa_kid.clone(),
+            "collection_id" => collection_id,
+            "pretouch_ts" => PRETOUCH_TS.to_owned(),
+        };
+        sqlparam_types.insert("pretouch_ts".to_owned(), as_type(TypeCode::TIMESTAMP));
 
         let result = self
             .sql(
@@ -270,15 +281,8 @@
                    AND collection_id = @collection_id
                    AND modified > @pretouch_ts",
             )?
-            .params(params! {
-                "fxa_uid" => params.user_id.fxa_uid.clone(),
-                "fxa_kid" => params.user_id.fxa_kid.clone(),
-                "collection_id" => collection_id.to_string(),
-                "pretouch_ts" => PRETOUCH_TS.to_owned(),
-            })
-            .param_types(param_types! {
-                "pretouch_ts" => TypeCode::TIMESTAMP,
-            })
+            .params(sqlparams)
+            .param_types(sqlparam_types)
             .execute_async(&self.conn)?
             .one_or_none()
             .await?;
@@ -550,6 +554,13 @@ impl SpannerDb {
         {
             return Ok(*modified);
         }
+        let (sqlparams, mut sqlparam_types) = params! {
+            "fxa_uid" => params.user_id.fxa_uid,
+            "fxa_kid" => params.user_id.fxa_kid,
+            "collection_id" => collection_id,
+            "pretouch_ts" => PRETOUCH_TS.to_owned(),
+        };
+        sqlparam_types.insert("pretouch_ts".to_owned(), as_type(TypeCode::TIMESTAMP));
 
         let result = self
             .sql(
@@ -560,15 +571,8 @@
                    AND collection_id = @collection_id
                    AND modified > @pretouch_ts",
             )?
-            .params(params! {
-                "fxa_uid" => params.user_id.fxa_uid,
-                "fxa_kid" => params.user_id.fxa_kid,
-                "collection_id" => collection_id.to_string(),
-                "pretouch_ts" => PRETOUCH_TS.to_owned(),
-            })
-            .param_types(param_types! {
-                "pretouch_ts" => TypeCode::TIMESTAMP,
-            })
+            .params(sqlparams)
+            .param_types(sqlparam_types)
             .execute_async(&self.conn)?
             .one_or_none()
             .await?
@@ -581,6 +585,13 @@ impl SpannerDb {
         &self,
         user_id: params::GetCollectionTimestamps,
     ) -> Result<results::GetCollectionTimestamps> {
+        let (sqlparams, mut sqlparam_types) = params!
{ + "fxa_uid" => user_id.fxa_uid, + "fxa_kid" => user_id.fxa_kid, + "collection_id" => TOMBSTONE, + "pretouch_ts" => PRETOUCH_TS.to_owned(), + }; + sqlparam_types.insert("pretouch_ts".to_owned(), as_type(TypeCode::TIMESTAMP)); let mut streaming = self .sql( "SELECT collection_id, modified @@ -590,15 +601,8 @@ impl SpannerDb { AND collection_id != @collection_id AND modified > @pretouch_ts", )? - .params(params! { - "fxa_uid" => user_id.fxa_uid, - "fxa_kid" => user_id.fxa_kid, - "collection_id" => TOMBSTONE.to_string(), - "pretouch_ts" => PRETOUCH_TS.to_owned(), - }) - .param_types(param_types! { - "pretouch_ts" => TypeCode::TIMESTAMP, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)?; let mut results = HashMap::new(); while let Some(row) = streaming.next_async().await { @@ -674,6 +678,10 @@ impl SpannerDb { &self, user_id: params::GetCollectionCounts, ) -> Result { + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => user_id.fxa_uid, + "fxa_kid" => user_id.fxa_kid, + }; let mut streaming = self .sql( "SELECT collection_id, COUNT(collection_id) @@ -683,10 +691,8 @@ impl SpannerDb { AND expiry > CURRENT_TIMESTAMP() GROUP BY collection_id", )? - .params(params! { - "fxa_uid" => user_id.fxa_uid, - "fxa_kid" => user_id.fxa_kid, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)?; let mut counts = HashMap::new(); while let Some(row) = streaming.next_async().await { @@ -708,6 +714,10 @@ impl SpannerDb { &self, user_id: params::GetCollectionUsage, ) -> Result { + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => user_id.fxa_uid, + "fxa_kid" => user_id.fxa_kid + }; let mut streaming = self .sql( "SELECT collection_id, SUM(BYTE_LENGTH(payload)) @@ -717,10 +727,8 @@ impl SpannerDb { AND expiry > CURRENT_TIMESTAMP() GROUP BY collection_id", )? - .params(params! { - "fxa_uid" => user_id.fxa_uid, - "fxa_kid" => user_id.fxa_kid - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)?; let mut usages = HashMap::new(); while let Some(row) = streaming.next_async().await { @@ -742,6 +750,12 @@ impl SpannerDb { &self, user_id: params::GetStorageTimestamp, ) -> Result { + let (sqlparams, mut sqlparam_types) = params! { + "fxa_uid" => user_id.fxa_uid, + "fxa_kid" => user_id.fxa_kid, + "pretouch_ts" => PRETOUCH_TS.to_owned(), + }; + sqlparam_types.insert("pretouch_ts".to_owned(), as_type(TypeCode::TIMESTAMP)); let row = self .sql( "SELECT MAX(modified) @@ -750,14 +764,8 @@ impl SpannerDb { AND fxa_kid = @fxa_kid AND modified > @pretouch_ts", )? - .params(params! { - "fxa_uid" => user_id.fxa_uid, - "fxa_kid" => user_id.fxa_kid, - "pretouch_ts" => PRETOUCH_TS.to_owned(), - }) - .param_types(param_types! { - "pretouch_ts" => TypeCode::TIMESTAMP, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)? .one() .await?; @@ -772,6 +780,10 @@ impl SpannerDb { &self, user_id: params::GetStorageUsage, ) -> Result { + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => user_id.fxa_uid, + "fxa_kid" => user_id.fxa_kid + }; let result = self .sql( "SELECT SUM(BYTE_LENGTH(payload)) @@ -781,10 +793,8 @@ impl SpannerDb { AND expiry > CURRENT_TIMESTAMP() GROUP BY fxa_uid", )? - .params(params! { - "fxa_uid" => user_id.fxa_uid, - "fxa_kid" => user_id.fxa_kid - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)? 
.one_or_none() .await?; @@ -811,13 +821,15 @@ impl SpannerDb { WHERE fxa_uid = @fxa_uid AND fxa_kid = @fxa_kid AND collection_id = @collection_id"; + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid.clone(), + "fxa_kid" => params.user_id.fxa_kid.clone(), + "collection_id" => params.collection_id, + }; let result = self .sql(check_sql)? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid.clone(), - "fxa_kid" => params.user_id.fxa_kid.clone(), - "collection_id" => params.collection_id.to_string(), - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)? .one_or_none() .await?; @@ -851,15 +863,13 @@ impl SpannerDb { // specifying a TOMBSTONE collection_id. // This function should be called after any write operation. let timestamp = self.timestamp()?; - let mut sqlparams = params! { + let (mut sqlparams, mut sqltypes) = params! { "fxa_uid" => user.fxa_uid.clone(), "fxa_kid" => user.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), + "collection_id" => collection_id, "modified" => timestamp.as_rfc3339()?, }; - let mut sqltypes = param_types! { - "modified" => TypeCode::TIMESTAMP, - }; + sqltypes.insert("modified".to_owned(), as_type(TypeCode::TIMESTAMP)); self.metrics .clone() @@ -879,16 +889,21 @@ impl SpannerDb { AND collection_id = @collection_id GROUP BY fxa_uid" }; - let result = self - .sql(calc_sql)? - .params(params! { + + let result = { + let (sqlparams, sqlparam_types) = params! { "fxa_uid" => user.fxa_uid.clone(), "fxa_kid" => user.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), - }) - .execute_async(&self.conn)? - .one_or_none() - .await?; + "collection_id" => collection_id, + }; + + self.sql(calc_sql)? + .params(sqlparams) + .param_types(sqlparam_types) + .execute_async(&self.conn)? + .one_or_none() + .await? + }; let set_sql = if let Some(mut result) = result { // Update the user_collections table to reflect current numbers. // If there are BSOs, there are user_collections (or else something @@ -968,16 +983,13 @@ impl SpannerDb { async fn erect_tombstone(&self, user_id: &HawkIdentifier) -> Result { // Delete the old tombstone (if it exists) - let params = params! { + let (params, mut param_types) = params! { "fxa_uid" => user_id.fxa_uid.clone(), "fxa_kid" => user_id.fxa_kid.clone(), - "collection_id" => TOMBSTONE.to_string(), + "collection_id" => TOMBSTONE, "modified" => self.timestamp()?.as_rfc3339()? }; - let types = param_types! { - "collection_id" => TypeCode::INT64, - "modified" => TypeCode::TIMESTAMP, - }; + param_types.insert("modified".to_owned(), as_type(TypeCode::TIMESTAMP)); self.sql( "DELETE FROM user_collections WHERE fxa_uid = @fxa_uid @@ -985,7 +997,7 @@ impl SpannerDb { AND collection_id = @collection_id", )? .params(params.clone()) - .param_types(types.clone()) + .param_types(param_types.clone()) .execute_dml_async(&self.conn) .await?; self.update_user_collection_quotas(user_id, TOMBSTONE) @@ -998,15 +1010,17 @@ impl SpannerDb { pub async fn delete_storage_async(&self, user_id: params::DeleteStorage) -> Result<()> { // Also deletes child bsos/batch rows (INTERLEAVE IN PARENT // user_collections ON DELETE CASCADE) + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => user_id.fxa_uid, + "fxa_kid" => user_id.fxa_kid + }; self.sql( "DELETE FROM user_collections WHERE fxa_uid = @fxa_uid AND fxa_kid = @fxa_kid", )? - .params(params! 
{ - "fxa_uid" => user_id.fxa_uid, - "fxa_kid" => user_id.fxa_kid, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&self.conn) .await?; Ok(()) @@ -1025,10 +1039,14 @@ impl SpannerDb { ) -> Result { // Also deletes child bsos/batch rows (INTERLEAVE IN PARENT // user_collections ON DELETE CASCADE) - let collection_id = self - .get_collection_id_async(¶ms.collection) - .await? - .to_string(); + let collection_id = self.get_collection_id_async(¶ms.collection).await?; + let (sqlparams, mut sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid.clone(), + "fxa_kid" => params.user_id.fxa_kid.clone(), + "collection_id" => collection_id.clone(), + "pretouch_ts" => PRETOUCH_TS.to_owned(), + }; + sqlparam_types.insert("pretouch_ts".to_owned(), as_type(TypeCode::TIMESTAMP)); let affected_rows = self .sql( "DELETE FROM user_collections @@ -1037,15 +1055,8 @@ impl SpannerDb { AND collection_id = @collection_id AND modified > @pretouch_ts", )? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid.clone(), - "fxa_kid" => params.user_id.fxa_kid.clone(), - "collection_id" => collection_id.clone(), - "pretouch_ts" => PRETOUCH_TS.to_owned(), - }) - .param_types(param_types! { - "pretouch_ts" => TypeCode::TIMESTAMP, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&self.conn) .await?; if affected_rows > 0 { @@ -1082,15 +1093,13 @@ impl SpannerDb { return Ok(timestamp); } - let sqlparams = params! { + let (sqlparams, mut sqlparam_types) = params! { "fxa_uid" => user_id.fxa_uid.clone(), "fxa_kid" => user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), + "collection_id" => collection_id, "modified" => timestamp.as_rfc3339()?, }; - let sql_types = param_types! { - "modified" => TypeCode::TIMESTAMP, - }; + sqlparam_types.insert("modified".to_owned(), as_type(TypeCode::TIMESTAMP)); let result = self .sql( "SELECT 1 @@ -1100,6 +1109,7 @@ impl SpannerDb { AND collection_id = @collection_id", )? .params(sqlparams.clone()) + .param_types(sqlparam_types.clone()) .execute_async(&self.conn)? .one_or_none() .await?; @@ -1114,7 +1124,7 @@ impl SpannerDb { AND collection_id = @collection_id"; self.sql(sql)? .params(sqlparams) - .param_types(sql_types) + .param_types(sqlparam_types) .execute_dml_async(&self.conn) .await?; } else { @@ -1133,7 +1143,7 @@ impl SpannerDb { }; self.sql(update_sql)? .params(sqlparams) - .param_types(sql_types) + .param_types(sqlparam_types) .execute_dml_async(&self.conn) .await?; } @@ -1144,6 +1154,12 @@ impl SpannerDb { pub async fn delete_bso_async(&self, params: params::DeleteBso) -> Result { let collection_id = self.get_collection_id_async(¶ms.collection).await?; let user_id = params.user_id.clone(); + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid, + "fxa_kid" => params.user_id.fxa_kid, + "collection_id" => collection_id, + "bso_id" => params.id, + }; let affected_rows = self .sql( "DELETE FROM bsos @@ -1152,12 +1168,8 @@ impl SpannerDb { AND collection_id = @collection_id AND bso_id = @bso_id", )? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid, - "fxa_kid" => params.user_id.fxa_kid, - "collection_id" => collection_id.to_string(), - "bso_id" => params.id, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&self.conn) .await?; if affected_rows == 0 { @@ -1177,12 +1189,12 @@ impl SpannerDb { let user_id = params.user_id.clone(); let collection_id = self.get_collection_id_async(¶ms.collection).await?; - let mut sqlparams = params! 
{ + let (sqlparams, sqlparam_types) = params! { "fxa_uid" => user_id.fxa_uid, "fxa_kid" => user_id.fxa_kid, - "collection_id" => collection_id.to_string(), + "collection_id" => collection_id, + "ids" => params.ids, }; - sqlparams.insert("ids".to_owned(), params.ids.to_spanner_value()); self.sql( "DELETE FROM bsos WHERE fxa_uid = @fxa_uid @@ -1191,6 +1203,7 @@ impl SpannerDb { AND bso_id IN UNNEST(@ids)", )? .params(sqlparams) + .param_types(sqlparam_types) .execute_dml_async(&self.conn) .await?; let mut tags = Tags::default(); @@ -1208,10 +1221,10 @@ impl SpannerDb { params: params::GetBsos, ) -> Result { let mut query = query_str.to_owned(); - let mut sqlparams = params! { + let (mut sqlparams, mut sqlparam_types) = params! { "fxa_uid" => params.user_id.fxa_uid, "fxa_kid" => params.user_id.fxa_kid, - "collection_id" => self.get_collection_id_async(¶ms.collection).await?.to_string(), + "collection_id" => self.get_collection_id_async(¶ms.collection).await?, }; let BsoQueryParams { newer, @@ -1223,11 +1236,10 @@ impl SpannerDb { .. } = params.params; - let mut sqltypes = HashMap::new(); - if !ids.is_empty() { query = format!("{} AND bso_id IN UNNEST(@ids)", query); sqlparams.insert("ids".to_owned(), ids.to_spanner_value()); + sqlparam_types.insert("ids".to_owned(), ids.spanner_type()); } // issue559: Dead code (timestamp always None) @@ -1251,12 +1263,12 @@ impl SpannerDb { if let Some(older) = older { query = format!("{} AND modified < @older", query); sqlparams.insert("older".to_string(), older.as_rfc3339()?.to_spanner_value()); - sqltypes.insert("older".to_string(), as_type(TypeCode::TIMESTAMP)); + sqlparam_types.insert("older".to_string(), as_type(TypeCode::TIMESTAMP)); } if let Some(newer) = newer { query = format!("{} AND modified > @newer", query); sqlparams.insert("newer".to_string(), newer.as_rfc3339()?.to_spanner_value()); - sqltypes.insert("newer".to_string(), as_type(TypeCode::TIMESTAMP)); + sqlparam_types.insert("newer".to_string(), as_type(TypeCode::TIMESTAMP)); } query = match sort { // issue559: Revert to previous sorting @@ -1295,7 +1307,7 @@ impl SpannerDb { } self.sql(&query)? .params(sqlparams) - .param_types(sqltypes) + .param_types(sqlparam_types) .execute_async(&self.conn) } @@ -1436,6 +1448,12 @@ impl SpannerDb { pub async fn get_bso_async(&self, params: params::GetBso) -> Result> { let collection_id = self.get_collection_id_async(¶ms.collection).await?; + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid, + "fxa_kid" => params.user_id.fxa_kid, + "collection_id" => collection_id, + "bso_id" => params.id, + }; self.sql( "SELECT bso_id, sortindex, payload, modified, expiry FROM bsos @@ -1445,12 +1463,8 @@ impl SpannerDb { AND bso_id = @bso_id AND expiry > CURRENT_TIMESTAMP()", )? - .params(params! { - "fxa_uid" => params.user_id.fxa_uid, - "fxa_kid" => params.user_id.fxa_kid, - "collection_id" => collection_id.to_string(), - "bso_id" => params.id, - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)? .one_or_none() .await? @@ -1463,6 +1477,12 @@ impl SpannerDb { params: params::GetBsoTimestamp, ) -> Result { let collection_id = self.get_collection_id_async(¶ms.collection).await?; + let (sqlparams, sqlparam_types) = params! { + "fxa_uid" => params.user_id.fxa_uid, + "fxa_kid" => params.user_id.fxa_kid, + "collection_id" => collection_id, + "bso_id" => params.id, + }; let result = self .sql( @@ -1474,12 +1494,8 @@ impl SpannerDb { AND bso_id = @bso_id AND expiry > CURRENT_TIMESTAMP()", )? - .params(params! 
{ - "fxa_uid" => params.user_id.fxa_uid, - "fxa_kid" => params.user_id.fxa_kid, - "collection_id" => collection_id.to_string(), - "bso_id" => params.id.to_string(), - }) + .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)? .one_or_none() .await?; @@ -1527,20 +1543,16 @@ impl SpannerDb { .update_collection_async(&user_id, collection_id, ¶ms.collection) .await?; - let mut sqlparams = params! { + let (sqlparams, sqlparam_types) = params! { "fxa_uid" => user_id.fxa_uid.clone(), "fxa_kid" => user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), - }; - sqlparams.insert( - "ids".to_owned(), - params + "collection_id" => collection_id, + "ids" => params .bsos .iter() .map(|pbso| pbso.id.clone()) - .collect::>() - .to_spanner_value(), - ); + .collect::>(), + }; let mut streaming = self .sql( "SELECT bso_id @@ -1551,6 +1563,7 @@ impl SpannerDb { AND bso_id IN UNNEST(@ids)", )? .params(sqlparams) + .param_types(sqlparam_types) .execute_async(&self.conn)?; let mut existing = HashSet::new(); while let Some(row) = streaming.next_async().await { @@ -1677,13 +1690,12 @@ impl SpannerDb { self.check_quota(&bso.user_id, &bso.collection, collection_id) .await?; - let mut sqlparams = params! { + let (mut sqlparams, mut sqlparam_types) = params! { "fxa_uid" => bso.user_id.fxa_uid.clone(), "fxa_kid" => bso.user_id.fxa_kid.clone(), - "collection_id" => collection_id.to_string(), - "bso_id" => bso.id.to_string(), + "collection_id" => collection_id, + "bso_id" => bso.id, }; - let mut sqltypes = HashMap::new(); // prewarm the collections table by ensuring that the row is added if not present. self.update_collection_async(&bso.user_id, collection_id, &bso.collection) .await?; @@ -1699,6 +1711,7 @@ impl SpannerDb { AND bso_id = @bso_id", )? .params(sqlparams.clone()) + .param_types(sqlparam_types.clone()) .execute_async(&self.conn)? 
.one_or_none() .await?; @@ -1713,7 +1726,7 @@ impl SpannerDb { q, if let Some(sortindex) = bso.sortindex { sqlparams.insert("sortindex".to_string(), sortindex.to_spanner_value()); - sqltypes.insert("sortindex".to_string(), as_type(TypeCode::INT64)); + sqlparam_types.insert("sortindex".to_string(), sortindex.spanner_type()); format!("{}{}", comma(&q), "sortindex = @sortindex") } else { @@ -1727,7 +1740,7 @@ impl SpannerDb { if let Some(ttl) = bso.ttl { let expiry = timestamp.as_i64() + (i64::from(ttl) * 1000); sqlparams.insert("expiry".to_string(), to_rfc3339(expiry)?.to_spanner_value()); - sqltypes.insert("expiry".to_string(), as_type(TypeCode::TIMESTAMP)); + sqlparam_types.insert("expiry".to_string(), as_type(TypeCode::TIMESTAMP)); format!("{}{}", comma(&q), "expiry = @expiry") } else { "".to_string() @@ -1742,7 +1755,7 @@ impl SpannerDb { "modified".to_string(), timestamp.as_rfc3339()?.to_spanner_value(), ); - sqltypes.insert("modified".to_string(), as_type(TypeCode::TIMESTAMP)); + sqlparam_types.insert("modified".to_string(), as_type(TypeCode::TIMESTAMP)); format!("{}{}", comma(&q), "modified = @modified") } else { "".to_string() @@ -1754,6 +1767,7 @@ impl SpannerDb { q, if let Some(payload) = bso.payload { sqlparams.insert("payload".to_string(), payload.to_spanner_value()); + sqlparam_types.insert("payload".to_string(), payload.spanner_type()); format!("{}{}", comma(&q), "payload = @payload") } else { "".to_string() @@ -1800,14 +1814,11 @@ impl SpannerDb { .map(|sortindex| sortindex.to_spanner_value()) .unwrap_or_else(null_value); sqlparams.insert("sortindex".to_string(), sortindex); - sqltypes.insert("sortindex".to_string(), as_type(TypeCode::INT64)); + sqlparam_types.insert("sortindex".to_string(), as_type(TypeCode::INT64)); } - sqlparams.insert( - "payload".to_string(), - bso.payload - .unwrap_or_else(|| "".to_owned()) - .to_spanner_value(), - ); + let payload = bso.payload.unwrap_or_else(|| "".to_owned()); + sqlparams.insert("payload".to_string(), payload.to_spanner_value()); + sqlparam_types.insert("payload".to_owned(), payload.spanner_type()); let now_millis = timestamp.as_i64(); let ttl = bso.ttl.map_or(i64::from(DEFAULT_BSO_TTL), |ttl| { ttl.try_into() @@ -1819,19 +1830,19 @@ impl SpannerDb { &expirystring, timestamp, ttl ); sqlparams.insert("expiry".to_string(), expirystring.to_spanner_value()); - sqltypes.insert("expiry".to_string(), as_type(TypeCode::TIMESTAMP)); + sqlparam_types.insert("expiry".to_string(), as_type(TypeCode::TIMESTAMP)); sqlparams.insert( "modified".to_string(), timestamp.as_rfc3339()?.to_spanner_value(), ); - sqltypes.insert("modified".to_string(), as_type(TypeCode::TIMESTAMP)); + sqlparam_types.insert("modified".to_string(), as_type(TypeCode::TIMESTAMP)); sql.to_owned() }; self.sql(&sql)? .params(sqlparams) - .param_types(sqltypes) + .param_types(sqlparam_types) .execute_dml_async(&self.conn) .await?; // update the counts for the user_collections table. 
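Taken together, the hunks above collapse the old `params!` + `param_types!` pair into a single `params!` invocation that returns both the value map and a type map inferred through `ToSpannerValue::spanner_type`; only values whose Spanner type cannot be inferred from the Rust type (RFC 3339 strings bound to TIMESTAMP columns) still get a manual override via `as_type`. A minimal sketch of the resulting call-site shape (illustrative only; the query and bindings below are not copied from any single hunk):

    // Values and their inferred Spanner types come back as a pair.
    let (sqlparams, mut sqlparam_types) = params! {
        "fxa_uid" => user_id.fxa_uid.clone(),   // String -> STRING
        "collection_id" => collection_id,       // i32    -> INT64
        "modified" => timestamp.as_rfc3339()?,  // String value, but...
    };
    // ...the column is a TIMESTAMP, so the inferred STRING type is overridden by hand.
    sqlparam_types.insert("modified".to_owned(), as_type(TypeCode::TIMESTAMP));

    db.sql(
        "UPDATE user_collections
            SET modified = @modified
          WHERE fxa_uid = @fxa_uid
            AND collection_id = @collection_id",
    )?
    .params(sqlparams)
    .param_types(sqlparam_types)
    .execute_dml_async(&db.conn)
    .await?;
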
diff --git a/src/db/spanner/support.rs b/src/db/spanner/support.rs
index 2cceff96f3..7f38371ec5 100644
--- a/src/db/spanner/support.rs
+++ b/src/db/spanner/support.rs
@@ -27,10 +27,20 @@ use crate::{
 use super::{models::Result, pool::Conn};
 
 pub trait ToSpannerValue {
+    const TYPE_CODE: TypeCode;
+
     fn to_spanner_value(&self) -> Value;
+
+    fn spanner_type(&self) -> Type {
+        let mut t = Type::new();
+        t.set_code(Self::TYPE_CODE);
+        t
+    }
 }
 
 impl ToSpannerValue for String {
+    const TYPE_CODE: TypeCode = TypeCode::STRING;
+
     fn to_spanner_value(&self) -> Value {
         let mut value = Value::new();
         value.set_string_value(self.clone());
@@ -39,25 +49,28 @@ impl ToSpannerValue for String {
 }
 
 impl ToSpannerValue for i32 {
+    const TYPE_CODE: TypeCode = TypeCode::INT64;
+
     fn to_spanner_value(&self) -> Value {
-        let mut value = Value::new();
-        value.set_number_value(*self as f64);
-        value
+        self.to_string().to_spanner_value()
     }
 }
 
 impl ToSpannerValue for u32 {
+    const TYPE_CODE: TypeCode = TypeCode::INT64;
+
     fn to_spanner_value(&self) -> Value {
-        let mut value = Value::new();
-        value.set_number_value(*self as f64);
-        value
+        self.to_string().to_spanner_value()
     }
 }
 
 impl<T> ToSpannerValue for Vec<T>
 where
     T: ToSpannerValue + Clone,
+    Vec<T>: SpannerArrayElementType,
 {
+    const TYPE_CODE: TypeCode = TypeCode::ARRAY;
+
     fn to_spanner_value(&self) -> Value {
         let mut list = ListValue::new();
         list.set_values(RepeatedField::from_vec(
@@ -67,6 +80,35 @@ where
         value.set_list_value(list);
         value
     }
+
+    fn spanner_type(&self) -> Type {
+        let mut t = Type::new();
+        t.set_code(Self::TYPE_CODE);
+        t.set_array_element_type(self.array_element_type());
+        t
+    }
+}
+
+pub trait SpannerArrayElementType {
+    const ARRAY_ELEMENT_TYPE_CODE: TypeCode;
+
+    fn array_element_type(&self) -> Type {
+        let mut t = Type::new();
+        t.set_code(Self::ARRAY_ELEMENT_TYPE_CODE);
+        t
+    }
+}
+
+impl SpannerArrayElementType for Vec<String> {
+    const ARRAY_ELEMENT_TYPE_CODE: TypeCode = TypeCode::STRING;
+}
+
+impl SpannerArrayElementType for Vec<i32> {
+    const ARRAY_ELEMENT_TYPE_CODE: TypeCode = TypeCode::INT64;
+}
+
+impl SpannerArrayElementType for Vec<u32> {
+    const ARRAY_ELEMENT_TYPE_CODE: TypeCode = TypeCode::INT64;
 }
 
 pub fn as_type(v: TypeCode) -> Type {