diff --git a/Cargo.toml b/Cargo.toml index 32da602d10..a889009065 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ slog-scope = "4.3" slog-stdlog = "4.1" slog-term = "2.6" time = "0.2" +tokio = { version = "0.2", features = ["macros"] } url = "2.1" urlencoding = "1.1" uuid = { version = "0.8.1", features = ["serde", "v4"] } @@ -75,9 +76,6 @@ validator = "0.11" validator_derive = "0.11" woothee = "0.11" -[dev-dependencies] -tokio = { version = "0.3", features = ["macros", "rt"] } - [features] no_auth = [] diff --git a/src/db/mock.rs b/src/db/mock.rs index b9b52c62fe..6fb673bb3b 100644 --- a/src/db/mock.rs +++ b/src/db/mock.rs @@ -118,7 +118,9 @@ impl<'a> Db<'a> for MockDb { mock_db_method!(delete_batch, DeleteBatch); #[cfg(test)] - fn clear_coll_cache(&self) {} + fn clear_coll_cache(&self) -> DbFuture<'_, ()> { + Box::pin(future::ok(())) + } #[cfg(test)] fn set_quota(&mut self, _: bool, _: usize) {} diff --git a/src/db/mod.rs b/src/db/mod.rs index 4584405a71..7d61acb06c 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -236,7 +236,7 @@ pub trait Db<'a>: Debug + 'a { fn delete_batch(&self, params: params::DeleteBatch) -> DbFuture<'_, ()>; #[cfg(test)] - fn clear_coll_cache(&self); + fn clear_coll_cache(&self) -> DbFuture<'_, ()>; #[cfg(test)] fn set_quota(&mut self, enabled: bool, limit: usize); diff --git a/src/db/mysql/models.rs b/src/db/mysql/models.rs index dd074b5ae8..ed2fd95b41 100644 --- a/src/db/mysql/models.rs +++ b/src/db/mysql/models.rs @@ -1118,8 +1118,15 @@ impl<'a> Db<'a> for MysqlDb { sync_db_method!(delete_batch, delete_batch_sync, DeleteBatch); #[cfg(test)] - fn clear_coll_cache(&self) { - self.coll_cache.clear(); + fn clear_coll_cache(&self) -> DbFuture<'_, ()> { + let db = self.clone(); + Box::pin( + block(move || { + db.coll_cache.clear(); + Ok(()) + }) + .map_err(Into::into), + ) } #[cfg(test)] diff --git a/src/db/spanner/batch.rs b/src/db/spanner/batch.rs index 321c00502a..d25959843d 100644 --- a/src/db/spanner/batch.rs +++ 
b/src/db/spanner/batch.rs @@ -290,7 +290,8 @@ pub async fn do_append_async( let mut tags = Tags::default(); tags.tags.insert( "collection".to_owned(), - db.get_collection_name(collection_id)? + db.get_collection_name(collection_id) + .await .unwrap_or_else(|| "UNKNOWN".to_string()), ); diff --git a/src/db/spanner/models.rs b/src/db/spanner/models.rs index b715e2a5c6..be4d07bfc1 100644 --- a/src/db/spanner/models.rs +++ b/src/db/spanner/models.rs @@ -137,12 +137,12 @@ impl SpannerDb { } } - pub(super) fn get_collection_name(&self, id: i32) -> Result<Option<String>> { - self.coll_cache.get_name(id) + pub(super) async fn get_collection_name(&self, id: i32) -> Option<String> { + self.coll_cache.get_name(id).await } pub(super) async fn get_collection_id_async(&self, name: &str) -> Result<i32> { - if let Some(id) = self.coll_cache.get_id(name)? { + if let Some(id) = self.coll_cache.get_id(name).await { return Ok(id); } let result = self @@ -161,7 +161,7 @@ .parse::<i32>() .map_err(|e| DbErrorKind::Integrity(e.to_string()))?; if !self.in_write_transaction() { - self.coll_cache.put(id, name.to_owned())?; + self.coll_cache.put(id, name.to_owned()).await; } Ok(id) } @@ -631,15 +631,10 @@ impl SpannerDb { &self, collection_ids: impl Iterator<Item = &'a i32>, ) -> Result<HashMap<i32, String>> { - let mut names = HashMap::new(); - let mut uncached = Vec::new(); - for &id in collection_ids { - if let Some(name) = self.coll_cache.get_name(id)?
{ - names.insert(id, name); - } else { - uncached.push(id); - } - } + let (mut names, uncached) = self + .coll_cache + .get_names(&collection_ids.cloned().collect::<Vec<i32>>()) + .await; if !uncached.is_empty() { let mut params = HashMap::new(); @@ -664,7 +659,7 @@ let name = row[1].take_string_value(); names.insert(id, name.clone()); if !self.in_write_transaction() { - self.coll_cache.put(id, name)?; + self.coll_cache.put(id, name).await; } } } @@ -2100,8 +2095,12 @@ impl<'a> Db<'a> for SpannerDb { } #[cfg(test)] - fn clear_coll_cache(&self) { - self.coll_cache.clear(); + fn clear_coll_cache(&self) -> DbFuture<'_, ()> { + let db = self.clone(); + Box::pin(async move { + db.coll_cache.clear().await; + Ok(()) + }) } #[cfg(test)] diff --git a/src/db/spanner/pool.rs b/src/db/spanner/pool.rs index 994ca22c8f..215a7a3d85 100644 --- a/src/db/spanner/pool.rs +++ b/src/db/spanner/pool.rs @@ -1,22 +1,22 @@ +use std::{collections::HashMap, fmt, sync::Arc}; + use async_trait::async_trait; use bb8::ErrorSink; +use tokio::sync::RwLock; -use std::{ - collections::HashMap, - fmt, - sync::{Arc, RwLock}, +use crate::{ + db::{error::DbError, results, Db, DbPool, STD_COLLS}, + error::ApiResult, + server::metrics::Metrics, + settings::Settings, }; -use super::models::Result; -use crate::db::{error::DbError, results, Db, DbPool, STD_COLLS}; -use crate::server::metrics::Metrics; -use crate::settings::Settings; - -use super::manager::{SpannerSession, SpannerSessionManager}; -use super::models::SpannerDb; -use crate::error::ApiResult; - pub use super::manager::Conn; +use super::{ + manager::{SpannerSession, SpannerSessionManager}, + models::Result, + models::SpannerDb, +}; embed_migrations!(); @@ -117,43 +117,44 @@ pub struct CollectionCache { } impl CollectionCache { - pub fn put(&self, id: i32, name: String) -> Result<()> { + pub async fn put(&self, id: i32, name: String) { // XXX: should this emit a metric?
- // XXX: should probably either lock both simultaneously during - // writes or use an RwLock alternative - self.by_name - .write() - .map_err(|_| DbError::internal("by_name write"))? - .insert(name.clone(), id); - self.by_id - .write() - .map_err(|_| DbError::internal("by_id write"))? - .insert(id, name); - Ok(()) - } - - pub fn get_id(&self, name: &str) -> Result<Option<i32>> { - Ok(self - .by_name - .read() - .map_err(|_| DbError::internal("by_name read"))? - .get(name) - .cloned()) - } - - pub fn get_name(&self, id: i32) -> Result<Option<String>> { - Ok(self - .by_id - .read() - .map_err(|_| DbError::internal("by_id read"))? - .get(&id) - .cloned()) + // XXX: one RwLock might be sufficient? + self.by_name.write().await.insert(name.clone(), id); + self.by_id.write().await.insert(id, name); + } + + pub async fn get_id(&self, name: &str) -> Option<i32> { + self.by_name.read().await.get(name).cloned() + } + + pub async fn get_name(&self, id: i32) -> Option<String> { + self.by_id.read().await.get(&id).cloned() + } + + /// Get multiple names, returning a tuple of both the mapping of + /// ids to their names and a Vec of ids not found in the cache.
+ pub async fn get_names(&self, ids: &[i32]) -> (HashMap<i32, String>, Vec<i32>) { + let len = ids.len(); + // the ids array shouldn't be very large but avoid reallocating + // while holding the lock + let mut names = HashMap::with_capacity(len); + let mut missing = Vec::with_capacity(len); + let by_id = self.by_id.read().await; + for &id in ids { + if let Some(name) = by_id.get(&id) { + names.insert(id, name.to_owned()); + } else { + missing.push(id) + } + } + (names, missing) } #[cfg(test)] - pub fn clear(&self) { - self.by_name.write().expect("by_name write").clear(); - self.by_id.write().expect("by_id write").clear(); + pub async fn clear(&self) { + self.by_name.write().await.clear(); + self.by_id.write().await.clear(); } } diff --git a/src/db/tests/db.rs b/src/db/tests/db.rs index 1b9c57e818..4e0ef2e59d 100644 --- a/src/db/tests/db.rs +++ b/src/db/tests/db.rs @@ -1063,7 +1063,7 @@ async fn collection_cache() -> Result<()> { }) .await?; - db.clear_coll_cache().await?; let cols = db.get_collection_timestamps(hid(uid)).await?; assert!(cols.contains_key(coll)); Ok(())