Skip to content

Commit

Permalink
refactor: Add ToSpannerValue trait
Browse files Browse the repository at this point in the history
Closes #260
  • Loading branch information
Ethan Donowitz committed Apr 16, 2021
1 parent 57bd30a commit 0983b57
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 71 deletions.
1 change: 1 addition & 0 deletions foo.txt

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions service_account.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"client_id": "32555940559.apps.googleusercontent.com",
"client_secret": "ZmssLNjJy2998hD4CTg2ejr2",
"refresh_token": "1//0dmw73YB_jSoCCgYIARAAGA0SNwF-L9IrChvqz29RR3X_48X-4wT51oIcM5eN1mHGODSkOSnMwuqJulFNnVBAouyZlqG0Dbndeaw",
"type": "authorized_user"
}
2 changes: 1 addition & 1 deletion src/db/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ impl MockDbPool {
}
}

#[async_trait(?Send)]
#[async_trait]
impl DbPool for MockDbPool {
async fn get<'a>(&'a self) -> ApiResult<Box<dyn Db<'a>>> {
Ok(Box::new(MockDb::new()) as Box<dyn Db<'a>>)
Expand Down
2 changes: 1 addition & 1 deletion src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub const BATCH_LIFETIME: i64 = 2 * 60 * 60 * 1000; // 2 hours, in milliseconds

type DbFuture<'a, T> = LocalBoxFuture<'a, Result<T, ApiError>>;

#[async_trait(?Send)]
#[async_trait]
pub trait DbPool: Sync + Send + Debug {
async fn get(&self) -> ApiResult<Box<dyn Db<'_>>>;

Expand Down
2 changes: 1 addition & 1 deletion src/db/mysql/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl MysqlDbPool {
}
}

#[async_trait(?Send)]
#[async_trait]
impl DbPool for MysqlDbPool {
async fn get<'a>(&'a self) -> ApiResult<Box<dyn Db<'a>>> {
let pool = self.clone();
Expand Down
29 changes: 14 additions & 15 deletions src/db/spanner/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ use protobuf::{
};
use uuid::Uuid;

use super::support::{null_value, struct_type_field};
use super::support::{null_value, struct_type_field, ToSpannerValue};
use super::{
models::{Result, SpannerDb, DEFAULT_BSO_TTL, PRETOUCH_TS},
support::{as_list_value, as_value},
};
use crate::{
db::{params, results, util::to_rfc3339, DbError, DbErrorKind, BATCH_LIFETIME},
Expand Down Expand Up @@ -298,7 +297,7 @@ pub async fn do_append_async(
"collection_id" => collection_id.to_string(),
"batch_id" => batch.id.clone(),
};
params.insert("ids".to_owned(), as_list_value(bso_ids));
params.insert("ids".to_owned(), bso_ids.collect::<Vec<String>>().to_spanner_value());
let mut existing_stream = db
.sql(
"SELECT batch_bso_id
Expand Down Expand Up @@ -355,23 +354,23 @@ pub async fn do_append_async(
} else {
let sortindex = bso
.sortindex
.map(|sortindex| as_value(sortindex.to_string()))
.map(|sortindex| sortindex.to_spanner_value())
.unwrap_or_else(null_value);
let payload = bso.payload.map(as_value).unwrap_or_else(null_value);
let payload = bso.payload.map(ToSpannerValue::to_spanner_value).unwrap_or_else(null_value);
let ttl = bso
.ttl
.map(|ttl| as_value(ttl.to_string()))
.map(ToSpannerValue::to_spanner_value)
.unwrap_or_else(null_value);

// convert to a protobuf structure for direct insertion to
// avoid some mutation limits.
let mut row = ListValue::new();
row.set_values(RepeatedField::from_vec(vec![
as_value(user_id.fxa_uid.clone()),
as_value(user_id.fxa_kid.clone()),
as_value(collection_id.to_string()),
as_value(batch.id.clone()),
as_value(bso.id),
user_id.fxa_uid.clone().to_spanner_value(),
user_id.fxa_kid.clone().to_spanner_value(),
collection_id.to_spanner_value(),
batch.id.clone().to_spanner_value(),
bso.id.to_spanner_value(),
sortindex,
payload,
ttl,
Expand Down Expand Up @@ -480,15 +479,15 @@ pub async fn do_append_async(
};
if let Some(sortindex) = val.sortindex {
fields.push("sortindex");
params.insert("sortindex".to_owned(), as_value(sortindex.to_string()));
params.insert("sortindex".to_owned(), ToSpannerValue::to_spanner_value(sortindex.to_string()));
}
if let Some(payload) = val.payload {
fields.push("payload");
params.insert("payload".to_owned(), as_value(payload));
params.insert("payload".to_owned(), payload.to_spanner_value());
};
if let Some(ttl) = val.ttl {
fields.push("ttl");
params.insert("ttl".to_owned(), as_value(ttl.to_string()));
params.insert("ttl".to_owned(), ttl.to_spanner_value());
}
if fields.is_empty() {
continue;
Expand Down Expand Up @@ -545,7 +544,7 @@ async fn pretouch_collection_async(
.one_or_none()
.await?;
if result.is_none() {
sqlparams.insert("modified".to_owned(), as_value(PRETOUCH_TS.to_owned()));
sqlparams.insert("modified".to_owned(), PRETOUCH_TS.to_owned().to_spanner_value());
let sql = if db.quota.enabled {
"INSERT INTO user_collections (fxa_uid, fxa_kid, collection_id, modified, count, total_bytes)
VALUES (@fxa_uid, @fxa_kid, @collection_id, @modified, 0, 0)"
Expand Down
2 changes: 1 addition & 1 deletion src/db/spanner/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ macro_rules! params {
let _cap = params!(@count $($key),*);
let mut _map = ::std::collections::HashMap::with_capacity(_cap);
$(
_map.insert($key.to_owned(), as_value($value));
_map.insert($key.to_owned(), ToSpannerValue::to_spanner_value($value));
)*
_map
}
Expand Down
36 changes: 18 additions & 18 deletions src/db/spanner/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ use super::{
batch,
pool::{CollectionCache, Conn},
support::{
as_list_value, as_type, as_value, bso_from_row, bso_to_insert_row, bso_to_update_row,
ExecuteSqlRequestBuilder, StreamedResultSetAsync,
as_type, bso_from_row, bso_to_insert_row, bso_to_update_row,
ExecuteSqlRequestBuilder, StreamedResultSetAsync, ToSpannerValue,
},
};

Expand Down Expand Up @@ -639,7 +639,7 @@ impl SpannerDb {
let mut params = HashMap::new();
params.insert(
"ids".to_owned(),
as_list_value(uncached.into_iter().map(|id| id.to_string())),
uncached.into_iter().map(|id| id.to_string()).collect::<Vec<String>>().to_spanner_value(),
);
let mut rs = self
.sql(
Expand Down Expand Up @@ -892,9 +892,9 @@ impl SpannerDb {
if self.quota.enabled {
sqlparams.insert(
"total_bytes".to_owned(),
as_value(result[0].take_string_value()),
result[0].take_string_value().to_spanner_value(),
);
sqlparams.insert("count".to_owned(), as_value(result[1].take_string_value()));
sqlparams.insert("count".to_owned(), result[1].take_string_value().to_spanner_value());
sqltypes.insert(
"total_bytes".to_owned(),
crate::db::spanner::support::as_type(TypeCode::INT64),
Expand Down Expand Up @@ -1175,7 +1175,7 @@ impl SpannerDb {
"fxa_kid" => user_id.fxa_kid,
"collection_id" => collection_id.to_string(),
};
sqlparams.insert("ids".to_owned(), as_list_value(params.ids.into_iter()));
sqlparams.insert("ids".to_owned(), params.ids.to_spanner_value());
self.sql(
"DELETE FROM bsos
WHERE fxa_uid = @fxa_uid
Expand Down Expand Up @@ -1220,7 +1220,7 @@ impl SpannerDb {

if !ids.is_empty() {
query = format!("{} AND bso_id IN UNNEST(@ids)", query);
sqlparams.insert("ids".to_owned(), as_list_value(ids.into_iter()));
sqlparams.insert("ids".to_owned(), ids.to_spanner_value());
}

// issue559: Dead code (timestamp always None)
Expand All @@ -1243,12 +1243,12 @@ impl SpannerDb {
*/
if let Some(older) = older {
query = format!("{} AND modified < @older", query);
sqlparams.insert("older".to_string(), as_value(older.as_rfc3339()?));
sqlparams.insert("older".to_string(), older.as_rfc3339()?.to_spanner_value());
sqltypes.insert("older".to_string(), as_type(TypeCode::TIMESTAMP));
}
if let Some(newer) = newer {
query = format!("{} AND modified > @newer", query);
sqlparams.insert("newer".to_string(), as_value(newer.as_rfc3339()?));
sqlparams.insert("newer".to_string(), newer.as_rfc3339()?.to_spanner_value());
sqltypes.insert("newer".to_string(), as_type(TypeCode::TIMESTAMP));
}
query = match sort {
Expand Down Expand Up @@ -1527,7 +1527,7 @@ impl SpannerDb {
};
sqlparams.insert(
"ids".to_owned(),
as_list_value(params.bsos.iter().map(|pbso| pbso.id.clone())),
params.bsos.iter().map(|pbso| pbso.id.clone()).collect::<Vec<String>>().to_spanner_value(),
);
let mut streaming = self
.sql(
Expand Down Expand Up @@ -1700,7 +1700,7 @@ impl SpannerDb {
"{}{}",
q,
if let Some(sortindex) = bso.sortindex {
sqlparams.insert("sortindex".to_string(), as_value(sortindex.to_string()));
sqlparams.insert("sortindex".to_string(), sortindex.to_spanner_value());
sqltypes.insert("sortindex".to_string(), as_type(TypeCode::INT64));

format!("{}{}", comma(&q), "sortindex = @sortindex")
Expand All @@ -1714,7 +1714,7 @@ impl SpannerDb {
q,
if let Some(ttl) = bso.ttl {
let expiry = timestamp.as_i64() + (i64::from(ttl) * 1000);
sqlparams.insert("expiry".to_string(), as_value(to_rfc3339(expiry)?));
sqlparams.insert("expiry".to_string(), to_rfc3339(expiry)?.to_spanner_value());
sqltypes.insert("expiry".to_string(), as_type(TypeCode::TIMESTAMP));
format!("{}{}", comma(&q), "expiry = @expiry")
} else {
Expand All @@ -1726,7 +1726,7 @@ impl SpannerDb {
"{}{}",
q,
if bso.payload.is_some() || bso.sortindex.is_some() {
sqlparams.insert("modified".to_string(), as_value(timestamp.as_rfc3339()?));
sqlparams.insert("modified".to_string(), timestamp.as_rfc3339()?.to_spanner_value());
sqltypes.insert("modified".to_string(), as_type(TypeCode::TIMESTAMP));
format!("{}{}", comma(&q), "modified = @modified")
} else {
Expand All @@ -1738,7 +1738,7 @@ impl SpannerDb {
"{}{}",
q,
if let Some(payload) = bso.payload {
sqlparams.insert("payload".to_string(), as_value(payload));
sqlparams.insert("payload".to_string(), payload.to_spanner_value());
format!("{}{}", comma(&q), "payload = @payload")
} else {
"".to_string()
Expand Down Expand Up @@ -1782,14 +1782,14 @@ impl SpannerDb {
use super::support::null_value;
let sortindex = bso
.sortindex
.map(|sortindex| as_value(sortindex.to_string()))
.map(|sortindex| sortindex.to_spanner_value())
.unwrap_or_else(null_value);
sqlparams.insert("sortindex".to_string(), sortindex);
sqltypes.insert("sortindex".to_string(), as_type(TypeCode::INT64));
}
sqlparams.insert(
"payload".to_string(),
as_value(bso.payload.unwrap_or_else(|| "".to_owned())),
bso.payload.unwrap_or_else(|| "".to_owned()).to_spanner_value(),
);
let now_millis = timestamp.as_i64();
let ttl = bso.ttl.map_or(i64::from(DEFAULT_BSO_TTL), |ttl| {
Expand All @@ -1801,10 +1801,10 @@ impl SpannerDb {
"!!!!! [test] INSERT expirystring:{:?}, timestamp:{:?}, ttl:{:?}",
&expirystring, timestamp, ttl
);
sqlparams.insert("expiry".to_string(), as_value(expirystring));
sqlparams.insert("expiry".to_string(), expirystring.to_spanner_value());
sqltypes.insert("expiry".to_string(), as_type(TypeCode::TIMESTAMP));

sqlparams.insert("modified".to_string(), as_value(timestamp.as_rfc3339()?));
sqlparams.insert("modified".to_string(), timestamp.as_rfc3339()?.to_spanner_value());
sqltypes.insert("modified".to_string(), as_type(TypeCode::TIMESTAMP));
sql.to_owned()
};
Expand Down
2 changes: 1 addition & 1 deletion src/db/spanner/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl SpannerDbPool {
}
}

#[async_trait(?Send)]
#[async_trait]
impl DbPool for SpannerDbPool {
async fn get<'a>(&'a self) -> ApiResult<Box<dyn Db<'a>>> {
let mut metrics = self.metrics.clone();
Expand Down
Loading

0 comments on commit 0983b57

Please sign in to comment.