diff --git a/Cargo.lock b/Cargo.lock index e3eae0699d..88e9605b19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -838,6 +838,7 @@ checksum = "3e2de9deab977a153492a1468d1b1c0662c1cf39e5ea87d0c060ecd59ef18d8c" dependencies = [ "byteorder", "diesel_derives", + "libsqlite3-sys", "mysqlclient-sys", "r2d2", "url 1.7.2", @@ -1620,6 +1621,16 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "libsqlite3-sys" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e704a02bcaecd4a08b93a23f6be59d0bd79cd161e0963e9499165a0a35df7bd" +dependencies = [ + "pkg-config", + "vcpkg", +] + [[package]] name = "libz-sys" version = "1.1.2" diff --git a/Cargo.toml b/Cargo.toml index 9bcda33bc0..ad4c542c06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ cadence = "0.22" chrono = "0.4" config = "0.10" deadpool = "0.6" -diesel = { version = "1.4", features = ["mysql", "r2d2"] } +diesel = { version = "1.4", features = ["mysql", "r2d2", "sqlite"] } diesel_logger = "0.1.1" diesel_migrations = { version = "1.4.0", features = ["mysql"] } docopt = "1.1.0" diff --git a/README.md b/README.md index 7bb30265f1..c724b29968 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Mozilla Sync Storage built with [Rust](https://rust-lang.org). - [Local Setup](#local-setup) - [MySQL](#mysql) - [Spanner](#spanner) + - [Sqlite](#sqlite) - [Running via Docker](#running-via-docker) - [Connecting to Firefox](#connecting-to-firefox) - [Logging](#logging) @@ -106,6 +107,15 @@ To point to a GCP hosted Spanner instance from your local machine, follow these 4. `make run_spanner`. 5. Visit `http://localhost:8000/__heartbeat__` to make sure the server is running. +### Sqlite + +Setting up the server with sqlite only requires a path to the database file, +which will be created automatically: + +`sqlite:path/syncdb.sqlite` + +This requires at least sqlite v3.24.0 to be installed on the host system. + ### Running via Docker This requires access to the mozilla-rust-sdk which is now available at `/vendor/mozilla-rust-adk`. diff --git a/migrations/sqlite/2020-11-29-124806_init/down.sql b/migrations/sqlite/2020-11-29-124806_init/down.sql new file mode 100644 index 0000000000..63a99dd2eb --- /dev/null +++ b/migrations/sqlite/2020-11-29-124806_init/down.sql @@ -0,0 +1,8 @@ +-- DROP INDEX IF EXISTS `bso_expiry_idx`; +-- DROP INDEX IF EXISTS `bso_usr_col_mod_idx`; + +-- DROP TABLE IF EXISTS `bso`; +-- DROP TABLE IF EXISTS `collections`; +-- DROP TABLE IF EXISTS `user_collections`; +-- DROP TABLE IF EXISTS `batch_uploads`; +-- DROP TABLE IF EXISTS `batch_upload_items`; diff --git a/migrations/sqlite/2020-11-29-124806_init/up.sql b/migrations/sqlite/2020-11-29-124806_init/up.sql new file mode 100644 index 0000000000..2fd4aeb2ca --- /dev/null +++ b/migrations/sqlite/2020-11-29-124806_init/up.sql @@ -0,0 +1,79 @@ +-- XXX: bsov1, etc +-- We use Bigint for some fields instead of Integer, even though Sqlite does not have the concept of Bigint, +-- to allow diesel to assume that integer can be mapped to i64. 
See https://github.com/diesel-rs/diesel/issues/852 + + +CREATE TABLE IF NOT EXISTS `bso` +( + `userid` BIGINT NOT NULL, + `collection` INTEGER NOT NULL, + `id` TEXT NOT NULL, + + `sortindex` INTEGER, + + `payload` TEXT NOT NULL, + `payload_size` BIGINT DEFAULT 0, + + -- last modified time in milliseconds since epoch + `modified` BIGINT NOT NULL, + -- expiration in milliseconds since epoch + `ttl` BIGINT DEFAULT '3153600000000' NOT NULL, + + PRIMARY KEY (`userid`, `collection`, `id`) +); +CREATE INDEX IF NOT EXISTS `bso_expiry_idx` ON `bso` (`ttl`); +CREATE INDEX IF NOT EXISTS `bso_usr_col_mod_idx` ON `bso` (`userid`, `collection`, `modified`); + +CREATE TABLE IF NOT EXISTS `collections` +( + `id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + `name` TEXT UNIQUE NOT NULL +); +INSERT INTO collections (id, name) +VALUES (1, 'clients'), + (2, 'crypto'), + (3, 'forms'), + (4, 'history'), + (5, 'keys'), + (6, 'meta'), + (7, 'bookmarks'), + (8, 'prefs'), + (9, 'tabs'), + (10, 'passwords'), + (11, 'addons'), + (12, 'addresses'), + (13, 'creditcards'), + -- Reserve space for additions to the standard collections + (100, ''); + + +CREATE TABLE IF NOT EXISTS `user_collections` +( + `userid` BIGINT NOT NULL, + `collection` INTEGER NOT NULL, + -- last modified time in milliseconds since epoch + `last_modified` BIGINT NOT NULL, + `total_bytes` BIGINT, + `count` INTEGER, + PRIMARY KEY (`userid`, `collection`) +); + +CREATE TABLE IF NOT EXISTS `batch_uploads` +( + `batch` BIGINT NOT NULL, + `userid` BIGINT NOT NULL, + `collection` INTEGER NOT NULL, + PRIMARY KEY (`batch`, `userid`) +); + +CREATE TABLE IF NOT EXISTS `batch_upload_items` +( + `batch` BIGINT NOT NULL, + `userid` BIGINT NOT NULL, + `id` TEXT NOT NULL, + `sortindex` INTEGER DEFAULT NULL, + `payload` TEXT, + `payload_size` BIGINT DEFAULT NULL, + `ttl_offset` INTEGER DEFAULT NULL, + PRIMARY KEY (`batch`, `userid`, `id`) +); diff --git a/src/db/mod.rs b/src/db/mod.rs index 7b37500077..093d3f4a30 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -6,6 +6,7 @@ pub mod mysql; pub mod params; pub mod results; pub mod spanner; +pub mod sqlite; #[cfg(test)] mod tests; pub mod transaction; @@ -273,6 +274,7 @@ pub async fn pool_from_settings( Ok(match url.scheme() { "mysql" => Box::new(mysql::pool::MysqlDbPool::new(&settings, &metrics)?), "spanner" => Box::new(spanner::pool::SpannerDbPool::new(&settings, &metrics).await?), + "sqlite" => Box::new(sqlite::pool::SqliteDbPool::new(&settings, &metrics)?), _ => Err(DbErrorKind::InvalidUrl(settings.database_url.to_owned()))?, }) } diff --git a/src/db/sqlite/batch.rs b/src/db/sqlite/batch.rs new file mode 100644 index 0000000000..99b5b6d527 --- /dev/null +++ b/src/db/sqlite/batch.rs @@ -0,0 +1,284 @@ +use std::collections::HashSet; + +use diesel::{ + self, + dsl::sql, + insert_into, + result::{DatabaseErrorKind::UniqueViolation, Error as DieselError}, + sql_query, + sql_types::{BigInt, Integer}, + ExpressionMethods, OptionalExtension, QueryDsl, RunQueryDsl, +}; + +use super::{ + models::{Result, SqliteDb}, + schema::{batch_upload_items, batch_uploads}, +}; + +use crate::{ + db::{params, results, DbError, DbErrorKind, BATCH_LIFETIME}, + web::extractors::HawkIdentifier, +}; + +const MAXTTL: i32 = 2_100_000_000; + +pub fn create(db: &SqliteDb, params: params::CreateBatch) -> Result { + let user_id = params.user_id.legacy_id as i64; + let collection_id = db.get_collection_id(¶ms.collection)?; + // Careful, there's some weirdness here! 
+ // + // Sync timestamps are in seconds and quantized to two decimal places, so + // when we convert one to a bigint in milliseconds, the final digit is + // always zero. But we want to use the lower digits of the batchid for + // sharding writes via (batchid % num_tables), and leaving it as zero would + // skew the sharding distribution. + // + // So we mix in the lowest digit of the uid to improve the distribution + // while still letting us treat these ids as millisecond timestamps. It's + // yuck, but it works and it keeps the weirdness contained to this single + // line of code. + let batch_id = db.timestamp().as_i64() + (user_id % 10); + insert_into(batch_uploads::table) + .values(( + batch_uploads::batch_id.eq(&batch_id), + batch_uploads::user_id.eq(&user_id), + batch_uploads::collection_id.eq(&collection_id), + )) + .execute(&db.conn) + .map_err(|e| -> DbError { + match e { + // The user tried to create two batches with the same timestamp + DieselError::DatabaseError(UniqueViolation, _) => DbErrorKind::Conflict.into(), + _ => e.into(), + } + })?; + + do_append(db, batch_id, params.user_id, collection_id, params.bsos)?; + Ok(results::CreateBatch { + id: encode_id(batch_id), + size: None, + }) +} + +pub fn validate(db: &SqliteDb, params: params::ValidateBatch) -> Result { + let batch_id = decode_id(¶ms.id)?; + // Avoid hitting the db for batches that are obviously too old. Recall + // that the batchid is a millisecond timestamp. + if (batch_id + BATCH_LIFETIME) < db.timestamp().as_i64() { + return Ok(false); + } + + let user_id = params.user_id.legacy_id as i64; + let collection_id = db.get_collection_id(¶ms.collection)?; + let exists = batch_uploads::table + .select(sql::("1")) + .filter(batch_uploads::batch_id.eq(&batch_id)) + .filter(batch_uploads::user_id.eq(&user_id)) + .filter(batch_uploads::collection_id.eq(&collection_id)) + .get_result::(&db.conn) + .optional()?; + Ok(exists.is_some()) +} + +pub fn append(db: &SqliteDb, params: params::AppendToBatch) -> Result<()> { + let exists = validate( + db, + params::ValidateBatch { + user_id: params.user_id.clone(), + collection: params.collection.clone(), + id: params.batch.id.clone(), + }, + )?; + + if !exists { + Err(DbErrorKind::BatchNotFound)? 
+ } + + let batch_id = decode_id(¶ms.batch.id)?; + let collection_id = db.get_collection_id(¶ms.collection)?; + do_append(db, batch_id, params.user_id, collection_id, params.bsos)?; + Ok(()) +} + +pub fn get(db: &SqliteDb, params: params::GetBatch) -> Result> { + let is_valid = validate( + db, + params::ValidateBatch { + user_id: params.user_id, + collection: params.collection, + id: params.id.clone(), + }, + )?; + let batch = if is_valid { + Some(results::GetBatch { id: params.id }) + } else { + None + }; + Ok(batch) +} + +pub fn delete(db: &SqliteDb, params: params::DeleteBatch) -> Result<()> { + let batch_id = decode_id(¶ms.id)?; + let user_id = params.user_id.legacy_id as i64; + let collection_id = db.get_collection_id(¶ms.collection)?; + diesel::delete(batch_uploads::table) + .filter(batch_uploads::batch_id.eq(&batch_id)) + .filter(batch_uploads::user_id.eq(&user_id)) + .filter(batch_uploads::collection_id.eq(&collection_id)) + .execute(&db.conn)?; + diesel::delete(batch_upload_items::table) + .filter(batch_upload_items::batch_id.eq(&batch_id)) + .filter(batch_upload_items::user_id.eq(&user_id)) + .execute(&db.conn)?; + Ok(()) +} + +/// Commits a batch to the bsos table, deleting the batch when succesful +pub fn commit(db: &SqliteDb, params: params::CommitBatch) -> Result { + let batch_id = decode_id(¶ms.batch.id)?; + let user_id = params.user_id.legacy_id as i64; + let collection_id = db.get_collection_id(¶ms.collection)?; + let timestamp = db.timestamp(); + sql_query(include_str!("batch_commit.sql")) + .bind::(user_id as i64) + .bind::(&collection_id) + .bind::(&db.timestamp().as_i64()) + .bind::(&db.timestamp().as_i64()) + .bind::((MAXTTL as i64) * 1000) // XXX: + .bind::(&batch_id) + .bind::(user_id as i64) + .bind::(&db.timestamp().as_i64()) + .execute(&db.conn)?; + + db.update_collection(user_id as u32, collection_id)?; + + delete( + db, + params::DeleteBatch { + user_id: params.user_id, + collection: params.collection, + id: params.batch.id, + }, + )?; + Ok(results::PostBsos { + modified: timestamp, + success: Default::default(), + failed: Default::default(), + }) +} + +pub fn do_append( + db: &SqliteDb, + batch_id: i64, + user_id: HawkIdentifier, + _collection_id: i32, + bsos: Vec, +) -> Result<()> { + fn exist_idx(user_id: u64, batch_id: i64, bso_id: &str) -> String { + // Construct something that matches the key for batch_upload_items + format!( + "{batch_id}-{user_id}-{bso_id}", + batch_id = batch_id, + user_id = user_id, + bso_id = bso_id, + ) + }; + + // It's possible for the list of items to contain a duplicate key entry. + // This means that we can't really call `ON DUPLICATE` here, because that's + // more about inserting one item at a time. (e.g. it works great if the + // values contain a key that's already in the database, less so if the + // the duplicate is in the value set we're inserting. + #[derive(Debug, QueryableByName)] + #[table_name = "batch_upload_items"] + struct ExistsResult { + user_id: i64, + batch_id: i64, + id: String, + }; + + #[derive(AsChangeset)] + #[table_name = "batch_upload_items"] + struct UpdateBatches { + payload: Option, + payload_size: Option, + ttl_offset: Option, + } + + let mut existing = HashSet::new(); + + // pre-load the "existing" hashset with any batched uploads that are already in the table. + for item in sql_query( + "SELECT userid as user_id, batch as batch_id, id FROM batch_upload_items WHERE userid=? AND batch=?;", + ) + .bind::(user_id.legacy_id as i64) + .bind::(batch_id) + .get_results::(&db.conn)? 
+ { + existing.insert(exist_idx( + user_id.legacy_id, + item.batch_id, + &item.id.to_string(), + )); + } + + for bso in bsos { + let payload_size = bso.payload.as_ref().map(|p| p.len() as i64); + let exist_idx = exist_idx(user_id.legacy_id, batch_id, &bso.id); + + if existing.contains(&exist_idx) { + diesel::update( + batch_upload_items::table + .filter(batch_upload_items::user_id.eq(user_id.legacy_id as i64)) + .filter(batch_upload_items::batch_id.eq(batch_id)), + ) + .set(&UpdateBatches { + payload: bso.payload, + payload_size, + ttl_offset: bso.ttl.map(|ttl| ttl as i32), + }) + .execute(&db.conn)?; + } else { + diesel::insert_into(batch_upload_items::table) + .values(( + batch_upload_items::batch_id.eq(&batch_id), + batch_upload_items::user_id.eq(user_id.legacy_id as i64), + batch_upload_items::id.eq(bso.id.clone()), + batch_upload_items::sortindex.eq(bso.sortindex), + batch_upload_items::payload.eq(bso.payload), + batch_upload_items::payload_size.eq(payload_size), + batch_upload_items::ttl_offset.eq(bso.ttl.map(|ttl| ttl as i32)), + )) + .execute(&db.conn)?; + // make sure to include the key into our table check. + existing.insert(exist_idx); + } + } + + Ok(()) +} + +pub fn validate_batch_id(id: &str) -> Result<()> { + decode_id(id).map(|_| ()) +} + +fn encode_id(id: i64) -> String { + base64::encode(&id.to_string()) +} + +fn decode_id(id: &str) -> Result { + let bytes = base64::decode(id).unwrap_or_else(|_| id.as_bytes().to_vec()); + let decoded = std::str::from_utf8(&bytes).unwrap_or(id); + decoded + .parse::() + .map_err(|e| DbError::internal(&format!("Invalid batch_id: {}", e))) +} + +// #[macro_export] +// macro_rules! batch_db_method { +// ($name:ident, $batch_name:ident, $type:ident) => { +// pub fn $name(&self, params: params::$type) -> Result { +// batch::$batch_name(self, params) +// } +// }; +// } diff --git a/src/db/sqlite/batch_commit.sql b/src/db/sqlite/batch_commit.sql new file mode 100644 index 0000000000..9b96fa7b16 --- /dev/null +++ b/src/db/sqlite/batch_commit.sql @@ -0,0 +1,19 @@ +INSERT INTO bso (userid, collection, id, modified, sortindex, ttl, payload, payload_size) +SELECT + ?, + ?, + id, + ?, + sortindex, + COALESCE((ttl_offset * 1000) + ?, ?), + COALESCE(payload, ''), + COALESCE(payload_size, 0) + FROM batch_upload_items + WHERE batch = ? + AND userid = ? 
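+  -- The positional binds above and below are supplied by commit() in src/db/sqlite/batch.rs, in order:
+  -- userid, collection, modified, ttl base (now, in ms), fallback expiry (MAXTTL in ms), batch, userid, modified.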
+ ON CONFLICT(`userid`, `collection`, `id`) DO UPDATE SET + modified = ?, + sortindex = COALESCE(excluded.sortindex, bso.sortindex), + ttl = COALESCE(excluded.ttl, bso.ttl), + payload = CASE WHEN excluded.payload != '' THEN excluded.payload ELSE bso.payload END, + payload_size = CASE WHEN excluded.payload_size != 0 THEN excluded.payload_size ELSE bso.payload_size END diff --git a/src/db/sqlite/mod.rs b/src/db/sqlite/mod.rs new file mode 100644 index 0000000000..142b018d62 --- /dev/null +++ b/src/db/sqlite/mod.rs @@ -0,0 +1,9 @@ +#[macro_use] +mod batch; +pub mod models; +pub mod pool; +mod schema; +#[cfg(test)] +mod test; + +pub use self::pool::SqliteDbPool; diff --git a/src/db/sqlite/models.rs b/src/db/sqlite/models.rs new file mode 100644 index 0000000000..2c3e2a0aa2 --- /dev/null +++ b/src/db/sqlite/models.rs @@ -0,0 +1,1170 @@ +use actix_web::web::block; + +use futures::future::TryFutureExt; + +use std::{self, cell::RefCell, collections::HashMap, fmt, ops::Deref, sync::Arc}; + +use diesel::{ + connection::TransactionManager, + delete, + dsl::max, + expression::sql_literal::sql, + r2d2::{ConnectionManager, PooledConnection}, + sql_query, + sql_types::{BigInt, Integer, Nullable, Text}, + sqlite::SqliteConnection, + Connection, ExpressionMethods, GroupByDsl, OptionalExtension, QueryDsl, RunQueryDsl, +}; +#[cfg(test)] +use diesel_logger::LoggingConnection; + +use super::{ + batch, + pool::CollectionCache, + schema::{bso, collections, user_collections}, +}; +use crate::batch_db_method; +use crate::db::{ + error::{DbError, DbErrorKind}, + params, results, + util::SyncTimestamp, + Db, DbFuture, Sorting, +}; +use crate::server::metrics::Metrics; +use crate::settings::{Quota, DEFAULT_MAX_TOTAL_RECORDS}; +use crate::web::extractors::{BsoQueryParams, HawkIdentifier}; +use crate::web::tags::Tags; + +pub type Result = std::result::Result; +type Conn = PooledConnection>; + +/// The ttl to use for rows that are never supposed to expire (in seconds) +/// We store the TTL as a SyncTimestamp, which is milliseconds, so remember +/// to multiply this by 1000. +pub const DEFAULT_BSO_TTL: u32 = 2_100_000_000; +// this is the max number of records we will return. +pub static DEFAULT_LIMIT: u32 = DEFAULT_MAX_TOTAL_RECORDS; + +pub const TOMBSTONE: i32 = 0; +/// SQL Variable remapping +/// These names are the legacy values mapped to the new names. +pub const COLLECTION_ID: &str = "collection"; +pub const USER_ID: &str = "userid"; +pub const MODIFIED: &str = "modified"; +pub const EXPIRY: &str = "ttl"; +pub const LAST_MODIFIED: &str = "last_modified"; +pub const COUNT: &str = "count"; +pub const TOTAL_BYTES: &str = "total_bytes"; + +#[derive(Debug)] +pub enum CollectionLock { + Read, + Write, +} + +/// Per session Db metadata +#[derive(Debug, Default)] +struct SqliteDbSession { + /// The "current time" on the server used for this session's operations + timestamp: SyncTimestamp, + /// Cache of collection modified timestamps per (user_id, collection_id) + coll_modified_cache: HashMap<(u32, i32), SyncTimestamp>, + /// Currently locked collections + coll_locks: HashMap<(u32, i32), CollectionLock>, + /// Whether a transaction was started (begin() called) + in_transaction: bool, + in_write_transaction: bool, +} + +#[derive(Clone, Debug)] +pub struct SqliteDb { + /// Synchronous Diesel calls are executed in actix_web::web::block to satisfy + /// the Db trait's asynchronous interface. 
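+    ///
+    /// For example, each async trait method is generated by the
+    /// `sync_db_method!` macro below and expands to roughly this sketch:
+    ///
+    /// ```ignore
+    /// fn put_bso(&self, params: params::PutBso) -> DbFuture<'_, results::PutBso> {
+    ///     let db = self.clone();
+    ///     Box::pin(block(move || db.put_bso_sync(params).map_err(Into::into)).map_err(Into::into))
+    /// }
+    /// ```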
+ /// + /// Arc provides a Clone impl utilized for safely moving to + /// the thread pool but does not provide Send as the underlying db + /// conn. structs are !Sync (Arc requires both for Send). See the Send impl + /// below. + pub(super) inner: Arc, + + /// Pool level cache of collection_ids and their names + coll_cache: Arc, + + pub metrics: Metrics, + pub quota: Quota, +} + +/// Despite the db conn structs being !Sync (see Arc above) we +/// don't spawn multiple SqliteDb calls at a time in the thread pool. Calls are +/// queued to the thread pool via Futures, naturally serialized. +unsafe impl Send for SqliteDb {} + +pub struct SqliteDbInner { + #[cfg(not(test))] + pub(super) conn: Conn, + #[cfg(test)] + pub(super) conn: LoggingConnection, // display SQL when RUST_LOG="diesel_logger=trace" + + session: RefCell, +} + +impl fmt::Debug for SqliteDbInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SqliteDbInner {{ session: {:?} }}", self.session) + } +} + +impl Deref for SqliteDb { + type Target = SqliteDbInner; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl SqliteDb { + pub fn new( + conn: Conn, + coll_cache: Arc, + metrics: &Metrics, + quota: &Quota, + ) -> Self { + let inner = SqliteDbInner { + #[cfg(not(test))] + conn, + #[cfg(test)] + conn: LoggingConnection::new(conn), + session: RefCell::new(Default::default()), + }; + SqliteDb { + inner: Arc::new(inner), + coll_cache, + metrics: metrics.clone(), + quota: *quota, + } + } + + /// APIs for collection-level locking + /// + /// Explicitly lock the matching row in the user_collections table. Read + /// locks do SELECT ... LOCK IN SHARE MODE and write locks do SELECT + /// ... FOR UPDATE. + /// + /// In theory it would be possible to use serializable transactions rather + /// than explicit locking, but our ops team have expressed concerns about + /// the efficiency of that approach at scale. + pub fn lock_for_read_sync(&self, params: params::LockCollection) -> Result<()> { + let user_id = params.user_id.legacy_id as i64; + let collection_id = + self.get_collection_id(¶ms.collection) + .or_else(|e| match e.kind() { + // If the collection doesn't exist, we still want to start a + // transaction so it will continue to not exist. + DbErrorKind::CollectionNotFound => Ok(0), + _ => Err(e), + })?; + // If we already have a read or write lock then it's safe to + // use it as-is. + if self + .session + .borrow() + .coll_locks + .get(&(user_id as u32, collection_id)) + .is_some() + { + return Ok(()); + } + + // Lock the db + self.begin(false)?; + let modified = user_collections::table + .select(user_collections::modified) + .filter(user_collections::user_id.eq(user_id)) + .filter(user_collections::collection_id.eq(collection_id)) + // .lock_in_share_mode() + .first(&self.conn) + .optional()?; + if let Some(modified) = modified { + let modified = SyncTimestamp::from_i64(modified)?; + self.session + .borrow_mut() + .coll_modified_cache + .insert((user_id as u32, collection_id), modified); // why does it still expect a u32 int? 
+ } + // XXX: who's responsible for unlocking (removing the entry) + self.session + .borrow_mut() + .coll_locks + .insert((user_id as u32, collection_id), CollectionLock::Read); + Ok(()) + } + + pub fn lock_for_write_sync(&self, params: params::LockCollection) -> Result<()> { + let user_id = params.user_id.legacy_id as i64; + let collection_id = self.get_or_create_collection_id(¶ms.collection)?; + if let Some(CollectionLock::Read) = self + .session + .borrow() + .coll_locks + .get(&(user_id as u32, collection_id)) + { + Err(DbError::internal("Can't escalate read-lock to write-lock"))? + } + + // Lock the db + self.begin(true)?; + let modified = user_collections::table + .select(user_collections::modified) + .filter(user_collections::user_id.eq(user_id)) + .filter(user_collections::collection_id.eq(collection_id)) + // .for_update() + .first(&self.conn) + .optional()?; + if let Some(modified) = modified { + let modified = SyncTimestamp::from_i64(modified)?; + // Forbid the write if it would not properly incr the timestamp + if modified >= self.timestamp() { + Err(DbErrorKind::Conflict)? + } + self.session + .borrow_mut() + .coll_modified_cache + .insert((user_id as u32, collection_id), modified); + } + self.session + .borrow_mut() + .coll_locks + .insert((user_id as u32, collection_id), CollectionLock::Write); + Ok(()) + } + + pub(super) fn begin(&self, for_write: bool) -> Result<()> { + self.conn + .transaction_manager() + .begin_transaction(&self.conn)?; + self.session.borrow_mut().in_transaction = true; + if for_write { + self.session.borrow_mut().in_write_transaction = true; + } + Ok(()) + } + + pub async fn begin_async(&self, for_write: bool) -> Result<()> { + self.begin(for_write) + } + + pub fn commit_sync(&self) -> Result<()> { + if self.session.borrow().in_transaction { + self.conn + .transaction_manager() + .commit_transaction(&self.conn)?; + } + Ok(()) + } + + pub fn rollback_sync(&self) -> Result<()> { + if self.session.borrow().in_transaction { + self.conn + .transaction_manager() + .rollback_transaction(&self.conn)?; + } + Ok(()) + } + + fn erect_tombstone(&self, user_id: i32) -> Result<()> { + sql_query(format!( + r#"INSERT INTO user_collections ({user_id}, {collection_id}, {modified}) + VALUES (?, ?, ?) + ON CONFLICT({user_id}, {collection_id}) DO UPDATE SET + {modified} = excluded.{modified}"#, + user_id = USER_ID, + collection_id = COLLECTION_ID, + modified = LAST_MODIFIED + )) + .bind::(user_id as i64) + .bind::(TOMBSTONE) + .bind::(self.timestamp().as_i64()) + .execute(&self.conn)?; + Ok(()) + } + + pub fn delete_storage_sync(&self, user_id: HawkIdentifier) -> Result<()> { + let user_id = user_id.legacy_id as i64; + // Delete user data. + delete(bso::table) + .filter(bso::user_id.eq(user_id)) + .execute(&self.conn)?; + // Delete user collections. 
+ delete(user_collections::table) + .filter(user_collections::user_id.eq(user_id)) + .execute(&self.conn)?; + Ok(()) + } + + // Deleting the collection should result in: + // - collection does not appear in /info/collections + // - X-Last-Modified timestamp at the storage level changing + pub fn delete_collection_sync( + &self, + params: params::DeleteCollection, + ) -> Result { + let user_id = params.user_id.legacy_id as i64; + let collection_id = self.get_collection_id(¶ms.collection)?; + let mut count = delete(bso::table) + .filter(bso::user_id.eq(user_id)) + .filter(bso::collection_id.eq(&collection_id)) + .execute(&self.conn)?; + count += delete(user_collections::table) + .filter(user_collections::user_id.eq(user_id)) + .filter(user_collections::collection_id.eq(&collection_id)) + .execute(&self.conn)?; + if count == 0 { + Err(DbErrorKind::CollectionNotFound)? + } else { + self.erect_tombstone(user_id as i32)?; + } + self.get_storage_timestamp_sync(params.user_id) + } + + pub(super) fn get_or_create_collection_id(&self, name: &str) -> Result { + if let Some(id) = self.coll_cache.get_id(name)? { + return Ok(id); + } + + let id = self.conn.transaction(|| { + diesel::insert_or_ignore_into(collections::table) + .values(collections::name.eq(name)) + .execute(&self.conn)?; + + collections::table + .select(collections::id) + .filter(collections::name.eq(name)) + .first(&self.conn) + })?; + + if !self.session.borrow().in_write_transaction { + self.coll_cache.put(id, name.to_owned())?; + } + + Ok(id) + } + + pub(super) fn get_collection_id(&self, name: &str) -> Result { + if let Some(id) = self.coll_cache.get_id(name)? { + return Ok(id); + } + + let id = sql_query( + "SELECT id + FROM collections + WHERE name = ?", + ) + .bind::(name) + .get_result::(&self.conn) + .optional()? + .ok_or(DbErrorKind::CollectionNotFound)? + .id; + if !self.session.borrow().in_write_transaction { + self.coll_cache.put(id, name.to_owned())?; + } + Ok(id) + } + + fn _get_collection_name(&self, id: i32) -> Result { + let name = if let Some(name) = self.coll_cache.get_name(id)? { + name + } else { + sql_query( + "SELECT name + FROM collections + WHERE id = ?", + ) + .bind::(&id) + .get_result::(&self.conn) + .optional()? + .ok_or(DbErrorKind::CollectionNotFound)? 
+ .name + }; + Ok(name) + } + + pub fn put_bso_sync(&self, bso: params::PutBso) -> Result { + /* + if bso.payload.is_none() && bso.sortindex.is_none() && bso.ttl.is_none() { + // XXX: go returns an error here (ErrNothingToDo), and is treated + // as other errors + return Ok(()); + } + */ + + let collection_id = self.get_or_create_collection_id(&bso.collection)?; + let user_id: u64 = bso.user_id.legacy_id; + let timestamp = self.timestamp().as_i64(); + if self.quota.enabled { + let usage = self.get_quota_usage_sync(params::GetQuotaUsage { + user_id: HawkIdentifier::new_legacy(user_id), + collection: bso.collection.clone(), + collection_id, + })?; + if usage.total_bytes >= self.quota.size as usize { + let mut tags = Tags::default(); + tags.tags + .insert("collection".to_owned(), bso.collection.clone()); + self.metrics + .incr_with_tags("storage.quota.at_limit", Some(tags)); + if self.quota.enforced { + return Err(DbErrorKind::Quota.into()); + } else { + warn!("Quota at limit for user's collection ({} bytes)", usage.total_bytes; "collection"=>bso.collection.clone()); + } + } + } + + self.conn.transaction(|| { + let payload = bso.payload.as_deref().unwrap_or_default(); + let sortindex = bso.sortindex; + let ttl = bso.ttl.map_or(DEFAULT_BSO_TTL, |ttl| ttl); + let q = format!(r#" + INSERT INTO bso ({user_id}, {collection_id}, id, sortindex, payload, {modified}, {expiry}) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT({user_id}, {collection_id}, id) DO UPDATE SET + {user_id} = excluded.{user_id}, + {collection_id} = excluded.{collection_id}, + id = excluded.id + "#, user_id=USER_ID, modified=MODIFIED, collection_id=COLLECTION_ID, expiry=EXPIRY); + let q = format!( + "{}{}", + q, + if bso.sortindex.is_some() { + ", sortindex = excluded.sortindex" + } else { + "" + }, + ); + let q = format!( + "{}{}", + q, + if bso.payload.is_some() { + ", payload = excluded.payload" + } else { + "" + }, + ); + let q = format!( + "{}{}", + q, + if bso.ttl.is_some() { + format!(", {expiry} = excluded.{expiry}", expiry=EXPIRY) + } else { + "".to_owned() + }, + ); + let q = format!( + "{}{}", + q, + if bso.payload.is_some() || bso.sortindex.is_some() { + format!(", {modified} = excluded.{modified}", modified=MODIFIED) + } else { + "".to_owned() + }, + ); + sql_query(q) + .bind::(user_id as i64) // XXX: + .bind::(&collection_id) + .bind::(&bso.id) + .bind::, _>(sortindex) + .bind::(payload) + .bind::(timestamp) + .bind::(timestamp + (i64::from(ttl) * 1000)) // remember: this is in millis + .execute(&self.conn)?; + self.update_collection(user_id as u32, collection_id) + }) + } + + pub fn get_bsos_sync(&self, params: params::GetBsos) -> Result { + let user_id = params.user_id.legacy_id as i64; + let collection_id = self.get_collection_id(¶ms.collection)?; + let BsoQueryParams { + newer, + older, + sort, + limit, + offset, + ids, + .. 
+ } = params.params; + let now = self.timestamp().as_i64(); + let mut query = bso::table + .select(( + bso::id, + bso::modified, + bso::payload, + bso::sortindex, + bso::expiry, + )) + .filter(bso::user_id.eq(user_id)) + .filter(bso::collection_id.eq(collection_id as i32)) // XXX: + .filter(bso::expiry.gt(now)) + .into_boxed(); + + if let Some(older) = older { + query = query.filter(bso::modified.lt(older.as_i64())); + } + if let Some(newer) = newer { + query = query.filter(bso::modified.gt(newer.as_i64())); + } + + if !ids.is_empty() { + query = query.filter(bso::id.eq_any(ids)); + } + + // it's possible for two BSOs to be inserted with the same `modified` date, + // since there's no guarantee of order when doing a get, pagination can return + // an error. We "fudge" a bit here by taking the id order as a secondary, since + // that is guaranteed to be unique by the client. + query = match sort { + // issue559: Revert to previous sorting + /* + Sorting::Index => query.order(bso::id.desc()).order(bso::sortindex.desc()), + Sorting::Newest | Sorting::None => { + query.order(bso::id.desc()).order(bso::modified.desc()) + } + Sorting::Oldest => query.order(bso::id.asc()).order(bso::modified.asc()), + */ + Sorting::Index => query.order(bso::sortindex.desc()), + Sorting::Newest => query.order((bso::modified.desc(), bso::id.desc())), + Sorting::Oldest => query.order((bso::modified.asc(), bso::id.asc())), + _ => query, + }; + + let limit = limit.map(i64::from).unwrap_or(DEFAULT_LIMIT as i64).max(0); + // fetch an extra row to detect if there are more rows that + // match the query conditions + query = query.limit(if limit > 0 { limit + 1 } else { limit }); + + let numeric_offset = offset.map_or(0, |offset| offset.offset as i64); + + if numeric_offset > 0 { + // XXX: copy over this optimization: + // https://github.com/mozilla-services/server-syncstorage/blob/a0f8117/syncstorage/storage/sql/__init__.py#L404 + query = query.offset(numeric_offset); + } + let mut bsos = query.load::(&self.conn)?; + + // XXX: an additional get_collection_timestamp is done here in + // python to trigger potential CollectionNotFoundErrors + //if bsos.len() == 0 { + //} + + let next_offset = if limit >= 0 && bsos.len() > limit as usize { + bsos.pop(); + Some((limit + numeric_offset).to_string()) + } else { + // if an explicit "limit=0" is sent, return the offset of "0" + // Otherwise, this would break at least the db::tests::db::get_bsos_limit_offset + // unit test. + if limit == 0 { + Some(0.to_string()) + } else { + None + } + }; + + Ok(results::GetBsos { + items: bsos, + offset: next_offset, + }) + } + + pub fn get_bso_ids_sync(&self, params: params::GetBsos) -> Result { + let user_id = params.user_id.legacy_id as i64; + let collection_id = self.get_collection_id(¶ms.collection)?; + let BsoQueryParams { + newer, + older, + sort, + limit, + offset, + ids, + .. 
+ } = params.params; + let mut query = bso::table + .select(bso::id) + .filter(bso::user_id.eq(user_id)) + .filter(bso::collection_id.eq(collection_id as i32)) // XXX: + .filter(bso::expiry.gt(self.timestamp().as_i64())) + .into_boxed(); + + if let Some(older) = older { + query = query.filter(bso::modified.lt(older.as_i64())); + } + if let Some(newer) = newer { + query = query.filter(bso::modified.gt(newer.as_i64())); + } + + if !ids.is_empty() { + query = query.filter(bso::id.eq_any(ids)); + } + + query = match sort { + Sorting::Index => query.order(bso::sortindex.desc()), + Sorting::Newest => query.order(bso::modified.desc()), + Sorting::Oldest => query.order(bso::modified.asc()), + _ => query, + }; + + // negative limits are no longer allowed by mysql. TODO + let limit = limit.map(i64::from).unwrap_or(DEFAULT_LIMIT as i64).max(0); + // fetch an extra row to detect if there are more rows that + // match the query conditions. Negative limits will cause an error. + query = query.limit(if limit == 0 { limit } else { limit + 1 }); + let numeric_offset = offset.map_or(0, |offset| offset.offset as i64); + if numeric_offset != 0 { + // XXX: copy over this optimization: + // https://github.com/mozilla-services/server-syncstorage/blob/a0f8117/syncstorage/storage/sql/__init__.py#L404 + query = query.offset(numeric_offset); + } + let mut ids = query.load::(&self.conn)?; + + // XXX: an additional get_collection_timestamp is done here in + // python to trigger potential CollectionNotFoundErrors + //if bsos.len() == 0 { + //} + + let next_offset = if limit >= 0 && ids.len() > limit as usize { + ids.pop(); + Some((limit + numeric_offset).to_string()) + } else { + None + }; + + Ok(results::GetBsoIds { + items: ids, + offset: next_offset, + }) + } + + pub fn get_bso_sync(&self, params: params::GetBso) -> Result> { + let user_id = params.user_id.legacy_id as i64; + let collection_id = self.get_collection_id(¶ms.collection)?; + Ok(bso::table + .select(( + bso::id, + bso::modified, + bso::payload, + bso::sortindex, + bso::expiry, + )) + .filter(bso::user_id.eq(user_id)) + .filter(bso::collection_id.eq(&collection_id)) + .filter(bso::id.eq(¶ms.id)) + .filter(bso::expiry.ge(self.timestamp().as_i64())) + .get_result::(&self.conn) + .optional()?) + } + + pub fn delete_bso_sync(&self, params: params::DeleteBso) -> Result { + let user_id = params.user_id.legacy_id; + let collection_id = self.get_collection_id(¶ms.collection)?; + let affected_rows = delete(bso::table) + .filter(bso::user_id.eq(user_id as i64)) + .filter(bso::collection_id.eq(&collection_id)) + .filter(bso::id.eq(params.id)) + .filter(bso::expiry.gt(&self.timestamp().as_i64())) + .execute(&self.conn)?; + if affected_rows == 0 { + Err(DbErrorKind::BsoNotFound)? 
+ } + self.update_collection(user_id as u32, collection_id) + } + + pub fn delete_bsos_sync(&self, params: params::DeleteBsos) -> Result { + let user_id = params.user_id.legacy_id as i64; + let collection_id = self.get_collection_id(¶ms.collection)?; + delete(bso::table) + .filter(bso::user_id.eq(user_id)) + .filter(bso::collection_id.eq(&collection_id)) + .filter(bso::id.eq_any(params.ids)) + .execute(&self.conn)?; + self.update_collection(user_id as u32, collection_id) + } + + pub fn post_bsos_sync(&self, input: params::PostBsos) -> Result { + let collection_id = self.get_or_create_collection_id(&input.collection)?; + let mut result = results::PostBsos { + modified: self.timestamp(), + success: Default::default(), + failed: input.failed, + }; + + for pbso in input.bsos { + let id = pbso.id; + let put_result = self.put_bso_sync(params::PutBso { + user_id: input.user_id.clone(), + collection: input.collection.clone(), + id: id.clone(), + payload: pbso.payload, + sortindex: pbso.sortindex, + ttl: pbso.ttl, + }); + // XXX: python version doesn't report failures from db + // layer.. (wouldn't db failures abort the entire transaction + // anyway?) + // XXX: sanitize to.to_string()? + match put_result { + Ok(_) => result.success.push(id), + Err(e) => { + result.failed.insert(id, e.to_string()); + } + } + } + self.update_collection(input.user_id.legacy_id as u32, collection_id)?; + Ok(result) + } + + pub fn get_storage_timestamp_sync(&self, user_id: HawkIdentifier) -> Result { + let user_id = user_id.legacy_id as i64; + let modified = user_collections::table + .select(max(user_collections::modified)) + .filter(user_collections::user_id.eq(user_id)) + .first::>(&self.conn)? + .unwrap_or_default(); + Ok(SyncTimestamp::from_i64(modified)?) + } + + pub fn get_collection_timestamp_sync( + &self, + params: params::GetCollectionTimestamp, + ) -> Result { + let user_id = params.user_id.legacy_id as u32; + let collection_id = self.get_collection_id(¶ms.collection)?; + if let Some(modified) = self + .session + .borrow() + .coll_modified_cache + .get(&(user_id, collection_id)) + { + return Ok(*modified); + } + user_collections::table + .select(user_collections::modified) + .filter(user_collections::user_id.eq(user_id as i64)) + .filter(user_collections::collection_id.eq(collection_id)) + .first(&self.conn) + .optional()? + .ok_or_else(|| DbErrorKind::CollectionNotFound.into()) + } + + pub fn get_bso_timestamp_sync(&self, params: params::GetBsoTimestamp) -> Result { + let user_id = params.user_id.legacy_id as i64; + let collection_id = self.get_collection_id(¶ms.collection)?; + let modified = bso::table + .select(bso::modified) + .filter(bso::user_id.eq(user_id)) + .filter(bso::collection_id.eq(&collection_id)) + .filter(bso::id.eq(¶ms.id)) + .first::(&self.conn) + .optional()? + .unwrap_or_default(); + Ok(SyncTimestamp::from_i64(modified)?) + } + + pub fn get_collection_timestamps_sync( + &self, + user_id: HawkIdentifier, + ) -> Result { + let modifieds = sql_query(format!( + "SELECT {collection_id}, {modified} + FROM user_collections + WHERE {user_id} = ? + AND {collection_id} != ?", + collection_id = COLLECTION_ID, + user_id = USER_ID, + modified = LAST_MODIFIED + )) + .bind::(user_id.legacy_id as i64) + .bind::(TOMBSTONE) + .load::(&self.conn)? 
+ .into_iter() + .map(|cr| SyncTimestamp::from_i64(cr.last_modified).map(|ts| (cr.collection, ts))) + .collect::>>()?; + self.map_collection_names(modifieds) + } + + fn check_sync(&self) -> Result { + // Check if querying works + sql_query("SELECT 1").execute(&self.conn)?; + Ok(true) + } + + fn map_collection_names(&self, by_id: HashMap) -> Result> { + let mut names = self.load_collection_names(by_id.keys())?; + by_id + .into_iter() + .map(|(id, value)| { + names + .remove(&id) + .map(|name| (name, value)) + .ok_or_else(|| DbError::internal("load_collection_names unknown collection id")) + }) + .collect() + } + + fn load_collection_names<'a>( + &self, + collection_ids: impl Iterator, + ) -> Result> { + let mut names = HashMap::new(); + let mut uncached = Vec::new(); + for &id in collection_ids { + if let Some(name) = self.coll_cache.get_name(id)? { + names.insert(id, name); + } else { + uncached.push(id); + } + } + + if !uncached.is_empty() { + let result = collections::table + .select((collections::id, collections::name)) + .filter(collections::id.eq_any(uncached)) + .load::<(i32, String)>(&self.conn)?; + + for (id, name) in result { + names.insert(id, name.clone()); + if !self.session.borrow().in_write_transaction { + self.coll_cache.put(id, name)?; + } + } + } + + Ok(names) + } + + pub(super) fn update_collection( + &self, + user_id: u32, + collection_id: i32, + ) -> Result { + let quota = if self.quota.enabled { + self.calc_quota_usage_sync(user_id, collection_id)? + } else { + results::GetQuotaUsage { + count: 0, + total_bytes: 0, + } + }; + let upsert = format!( + r#" + INSERT INTO user_collections ({user_id}, {collection_id}, {modified}, {total_bytes}, {count}) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT({user_id}, {collection_id}) DO UPDATE SET + {modified} = ?, + {total_bytes} = ?, + {count} = ? + "#, + user_id = USER_ID, + collection_id = COLLECTION_ID, + modified = LAST_MODIFIED, + count = COUNT, + total_bytes = TOTAL_BYTES, + ); + let total_bytes = quota.total_bytes as i64; + sql_query(upsert) + .bind::(user_id as i64) + .bind::(&collection_id) + .bind::(&self.timestamp().as_i64()) + .bind::(&total_bytes) + .bind::("a.count) + .bind::(&self.timestamp().as_i64()) + .bind::(&total_bytes) + .bind::("a.count) + .execute(&self.conn)?; + Ok(self.timestamp()) + } + + // Perform a lighter weight "read only" storage size check + pub fn get_storage_usage_sync( + &self, + user_id: HawkIdentifier, + ) -> Result { + let uid = user_id.legacy_id as i64; + let total_bytes = bso::table + .select(sql::>("SUM(LENGTH(payload))")) + .filter(bso::user_id.eq(uid)) + .filter(bso::expiry.gt(&self.timestamp().as_i64())) + .get_result::>(&self.conn)?; + Ok(total_bytes.unwrap_or_default() as u64) + } + + // Perform a lighter weight "read only" quota storage check + pub fn get_quota_usage_sync( + &self, + params: params::GetQuotaUsage, + ) -> Result { + let uid = params.user_id.legacy_id as i64; + let (total_bytes, count): (i64, i32) = user_collections::table + .select(( + sql::("COALESCE(SUM(COALESCE(total_bytes, 0)), 0)"), + sql::("COALESCE(SUM(COALESCE(count, 0)), 0)"), + )) + .filter(user_collections::user_id.eq(uid)) + .filter(user_collections::collection_id.eq(params.collection_id)) + .get_result(&self.conn) + .optional()? 
+ .unwrap_or_default(); + Ok(results::GetQuotaUsage { + total_bytes: total_bytes as usize, + count, + }) + } + + // perform a heavier weight quota calculation + pub fn calc_quota_usage_sync( + &self, + user_id: u32, + collection_id: i32, + ) -> Result { + let (total_bytes, count): (i64, i32) = bso::table + .select(( + sql::(r#"COALESCE(SUM(LENGTH(COALESCE(payload, ""))),0)"#), + sql::("COALESCE(COUNT(*),0)"), + )) + .filter(bso::user_id.eq(user_id as i64)) + .filter(bso::expiry.gt(self.timestamp().as_i64())) + .filter(bso::collection_id.eq(collection_id)) + .get_result(&self.conn) + .optional()? + .unwrap_or_default(); + Ok(results::GetQuotaUsage { + total_bytes: total_bytes as usize, + count, + }) + } + + pub fn get_collection_usage_sync( + &self, + user_id: HawkIdentifier, + ) -> Result { + let counts = bso::table + .select((bso::collection_id, sql::("SUM(LENGTH(payload))"))) + .filter(bso::user_id.eq(user_id.legacy_id as i64)) + .filter(bso::expiry.gt(&self.timestamp().as_i64())) + .group_by(bso::collection_id) + .load(&self.conn)? + .into_iter() + .collect(); + self.map_collection_names(counts) + } + + pub fn get_collection_counts_sync( + &self, + user_id: HawkIdentifier, + ) -> Result { + let counts = bso::table + .select(( + bso::collection_id, + sql::(&format!( + "COUNT({collection_id})", + collection_id = COLLECTION_ID + )), + )) + .filter(bso::user_id.eq(user_id.legacy_id as i64)) + .filter(bso::expiry.gt(&self.timestamp().as_i64())) + .group_by(bso::collection_id) + .load(&self.conn)? + .into_iter() + .collect(); + self.map_collection_names(counts) + } + + batch_db_method!(create_batch_sync, create, CreateBatch); + batch_db_method!(validate_batch_sync, validate, ValidateBatch); + batch_db_method!(append_to_batch_sync, append, AppendToBatch); + batch_db_method!(commit_batch_sync, commit, CommitBatch); + #[cfg(test)] + batch_db_method!(delete_batch_sync, delete, DeleteBatch); + + pub fn get_batch_sync(&self, params: params::GetBatch) -> Result> { + batch::get(&self, params) + } + + pub fn timestamp(&self) -> SyncTimestamp { + self.session.borrow().timestamp + } +} + +macro_rules! 
sync_db_method { + ($name:ident, $sync_name:ident, $type:ident) => { + sync_db_method!($name, $sync_name, $type, results::$type); + }; + ($name:ident, $sync_name:ident, $type:ident, $result:ty) => { + fn $name(&self, params: params::$type) -> DbFuture<'_, $result> { + let db = self.clone(); + Box::pin(block(move || db.$sync_name(params).map_err(Into::into)).map_err(Into::into)) + } + }; +} + +impl<'a> Db<'a> for SqliteDb { + fn commit(&self) -> DbFuture<'_, ()> { + let db = self.clone(); + Box::pin(block(move || db.commit_sync().map_err(Into::into)).map_err(Into::into)) + } + + fn rollback(&self) -> DbFuture<'_, ()> { + let db = self.clone(); + Box::pin(block(move || db.rollback_sync().map_err(Into::into)).map_err(Into::into)) + } + + fn begin(&self, for_write: bool) -> DbFuture<'_, ()> { + let db = self.clone(); + Box::pin(async move { db.begin_async(for_write).map_err(Into::into).await }) + } + + fn box_clone(&self) -> Box> { + Box::new(self.clone()) + } + + fn check(&self) -> DbFuture<'_, results::Check> { + let db = self.clone(); + Box::pin(block(move || db.check_sync().map_err(Into::into)).map_err(Into::into)) + } + + sync_db_method!(lock_for_read, lock_for_read_sync, LockCollection); + sync_db_method!(lock_for_write, lock_for_write_sync, LockCollection); + sync_db_method!( + get_collection_timestamps, + get_collection_timestamps_sync, + GetCollectionTimestamps + ); + sync_db_method!( + get_collection_timestamp, + get_collection_timestamp_sync, + GetCollectionTimestamp + ); + sync_db_method!( + get_collection_counts, + get_collection_counts_sync, + GetCollectionCounts + ); + sync_db_method!( + get_collection_usage, + get_collection_usage_sync, + GetCollectionUsage + ); + sync_db_method!( + get_storage_timestamp, + get_storage_timestamp_sync, + GetStorageTimestamp + ); + sync_db_method!(get_storage_usage, get_storage_usage_sync, GetStorageUsage); + sync_db_method!(get_quota_usage, get_quota_usage_sync, GetQuotaUsage); + sync_db_method!(delete_storage, delete_storage_sync, DeleteStorage); + sync_db_method!(delete_collection, delete_collection_sync, DeleteCollection); + sync_db_method!(delete_bsos, delete_bsos_sync, DeleteBsos); + sync_db_method!(get_bsos, get_bsos_sync, GetBsos); + sync_db_method!(get_bso_ids, get_bso_ids_sync, GetBsoIds); + sync_db_method!(post_bsos, post_bsos_sync, PostBsos); + sync_db_method!(delete_bso, delete_bso_sync, DeleteBso); + sync_db_method!(get_bso, get_bso_sync, GetBso, Option); + sync_db_method!( + get_bso_timestamp, + get_bso_timestamp_sync, + GetBsoTimestamp, + results::GetBsoTimestamp + ); + sync_db_method!(put_bso, put_bso_sync, PutBso); + sync_db_method!(create_batch, create_batch_sync, CreateBatch); + sync_db_method!(validate_batch, validate_batch_sync, ValidateBatch); + sync_db_method!(append_to_batch, append_to_batch_sync, AppendToBatch); + sync_db_method!( + get_batch, + get_batch_sync, + GetBatch, + Option + ); + sync_db_method!(commit_batch, commit_batch_sync, CommitBatch); + + fn get_collection_id(&self, name: String) -> DbFuture<'_, i32> { + let db = self.clone(); + Box::pin(block(move || db.get_collection_id(&name).map_err(Into::into)).map_err(Into::into)) + } + + #[cfg(test)] + fn create_collection(&self, name: String) -> DbFuture<'_, i32> { + let db = self.clone(); + Box::pin( + block(move || db.get_or_create_collection_id(&name).map_err(Into::into)) + .map_err(Into::into), + ) + } + + #[cfg(test)] + fn update_collection(&self, param: params::UpdateCollection) -> DbFuture<'_, SyncTimestamp> { + let db = self.clone(); + Box::pin( + 
block(move || { + db.update_collection(param.user_id.legacy_id as u32, param.collection_id) + .map_err(Into::into) + }) + .map_err(Into::into), + ) + } + + #[cfg(test)] + fn timestamp(&self) -> SyncTimestamp { + self.timestamp() + } + + #[cfg(test)] + fn set_timestamp(&self, timestamp: SyncTimestamp) { + self.session.borrow_mut().timestamp = timestamp; + } + + #[cfg(test)] + sync_db_method!(delete_batch, delete_batch_sync, DeleteBatch); + + #[cfg(test)] + fn clear_coll_cache(&self) -> DbFuture<'_, ()> { + let db = self.clone(); + Box::pin( + block(move || { + db.coll_cache.clear(); + Ok(()) + }) + .map_err(Into::into), + ) + } + + #[cfg(test)] + fn set_quota(&mut self, enabled: bool, limit: usize, enforced: bool) { + self.quota = Quota { + size: limit, + enabled, + enforced, + } + } +} + +#[derive(Debug, QueryableByName)] +struct IdResult { + #[sql_type = "Integer"] + id: i32, +} + +#[allow(dead_code)] // Not really dead, Rust can't see the use above +#[derive(Debug, QueryableByName)] +struct NameResult { + #[sql_type = "Text"] + name: String, +} + +#[derive(Debug, QueryableByName)] +struct UserCollectionsResult { + // Can't substitute column names here. + #[sql_type = "Integer"] + collection: i32, // COLLECTION_ID + #[sql_type = "BigInt"] + last_modified: i64, // LAST_MODIFIED +} diff --git a/src/db/sqlite/pool.rs b/src/db/sqlite/pool.rs new file mode 100644 index 0000000000..4b95108bd4 --- /dev/null +++ b/src/db/sqlite/pool.rs @@ -0,0 +1,200 @@ +use actix_web::web::block; + +use async_trait::async_trait; + +use std::{ + collections::HashMap, + fmt, + sync::{Arc, RwLock}, +}; + +use diesel::{ + r2d2::{ConnectionManager, Pool}, + sqlite::SqliteConnection, + Connection, +}; +#[cfg(test)] +use diesel_logger::LoggingConnection; + +use super::models::{Result, SqliteDb}; +#[cfg(test)] +use super::test::TestTransactionCustomizer; +use crate::db::{error::DbError, results, Db, DbPool, STD_COLLS}; +use crate::error::{ApiError, ApiResult}; +use crate::server::metrics::Metrics; +use crate::settings::{Quota, Settings}; + +embed_migrations!("migrations/sqlite"); +/// Run the diesel embedded migrations +/// +/// Sqlite DDL statements implicitly commit which could disrupt SqlitePool's +/// begin_test_transaction during tests. So this runs on its own separate conn. +pub fn run_embedded_migrations(settings: &Settings) -> Result<()> { + let conn = SqliteConnection::establish(strip_sqlite_prefix(&settings.database_url))?; + #[cfg(test)] + // XXX: this doesn't show the DDL statements + // https://github.com/shssoichiro/diesel-logger/issues/1 + embedded_migrations::run(&LoggingConnection::new(conn))?; + #[cfg(not(test))] + embedded_migrations::run(&conn)?; + Ok(()) +} + +/// Diesel expects a simple file path or `:memory:`, not necessarily a full URL. This functions strips +/// the `sqlite:` prefix and returns the rest. +pub fn strip_sqlite_prefix(sqlite_url: &str) -> &str { + &sqlite_url["sqlite:".len()..] +} + +#[derive(Clone)] +pub struct SqliteDbPool { + /// Pool of db connections + pool: Pool>, + /// Thread Pool for running synchronous db calls + /// In-memory cache of collection_ids and their names + coll_cache: Arc, + + metrics: Metrics, + quota: Quota, +} + +impl SqliteDbPool { + /// Creates a new pool of Sqlite db connections. + /// + /// Also initializes the Sqlite db, ensuring all migrations are ran. 
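+    ///
+    /// Rough usage sketch (not a doctest; assumes `settings.database_url`
+    /// is something like `"sqlite:path/syncdb.sqlite"`):
+    ///
+    /// ```ignore
+    /// let pool = SqliteDbPool::new(&settings, &Metrics::noop())?;
+    /// let db = pool.get_sync()?;
+    /// ```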
+ pub fn new(settings: &Settings, metrics: &Metrics) -> Result { + run_embedded_migrations(settings)?; + Self::new_without_migrations(settings, metrics) + } + + pub fn new_without_migrations(settings: &Settings, metrics: &Metrics) -> Result { + let manager = + ConnectionManager::::new(strip_sqlite_prefix(&settings.database_url)); + let builder = Pool::builder() + .max_size(settings.database_pool_max_size.unwrap_or(10)) + .min_idle(settings.database_pool_min_idle); + + #[cfg(test)] + let builder = if settings.database_use_test_transactions { + builder.connection_customizer(Box::new(TestTransactionCustomizer)) + } else { + builder + }; + + Ok(Self { + pool: builder.build(manager)?, + coll_cache: Default::default(), + metrics: metrics.clone(), + quota: Quota { + size: settings.limits.max_quota_limit as usize, + enabled: settings.enable_quota, + enforced: settings.enforce_quota, + }, + }) + } + + pub fn get_sync(&self) -> Result { + Ok(SqliteDb::new( + self.pool.get()?, + Arc::clone(&self.coll_cache), + &self.metrics, + &self.quota, + )) + } +} + +#[async_trait(?Send)] +impl DbPool for SqliteDbPool { + async fn get<'a>(&'a self) -> ApiResult>> { + let pool = self.clone(); + let db = block(move || pool.get_sync().map_err(ApiError::from)).await?; + + Ok(Box::new(db) as Box>) + } + + fn state(&self) -> results::PoolState { + self.pool.state().into() + } + + fn validate_batch_id(&self, id: String) -> Result<()> { + super::batch::validate_batch_id(&id) + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + +impl fmt::Debug for SqliteDbPool { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("SqliteDbPool") + .field("coll_cache", &self.coll_cache) + .finish() + } +} + +#[derive(Debug)] +pub struct CollectionCache { + pub by_name: RwLock>, + pub by_id: RwLock>, +} + +impl CollectionCache { + pub fn put(&self, id: i32, name: String) -> Result<()> { + // XXX: should this emit a metric? + // XXX: should probably either lock both simultaneously during + // writes or use an RwLock alternative + self.by_name + .write() + .map_err(|_| DbError::internal("by_name write"))? + .insert(name.clone(), id); + self.by_id + .write() + .map_err(|_| DbError::internal("by_id write"))? + .insert(id, name); + Ok(()) + } + + pub fn get_id(&self, name: &str) -> Result> { + Ok(self + .by_name + .read() + .map_err(|_| DbError::internal("by_name read"))? + .get(name) + .cloned()) + } + + pub fn get_name(&self, id: i32) -> Result> { + Ok(self + .by_id + .read() + .map_err(|_| DbError::internal("by_id read"))? + .get(&id) + .cloned()) + } + + #[cfg(test)] + pub fn clear(&self) { + self.by_name.write().expect("by_name write").clear(); + self.by_id.write().expect("by_id write").clear(); + } +} + +impl Default for CollectionCache { + fn default() -> Self { + Self { + by_name: RwLock::new( + STD_COLLS + .iter() + .map(|(k, v)| ((*v).to_owned(), *k)) + .collect(), + ), + by_id: RwLock::new( + STD_COLLS + .iter() + .map(|(k, v)| (*k, (*v).to_owned())) + .collect(), + ), + } + } +} diff --git a/src/db/sqlite/schema.rs b/src/db/sqlite/schema.rs new file mode 100644 index 0000000000..c94b0d40a9 --- /dev/null +++ b/src/db/sqlite/schema.rs @@ -0,0 +1,74 @@ +// We use Bigint here instead of Integer, even though Sqlite does not have the concept of Bigint, +// to allow diesel to assume that integer is i64. See https://github.com/diesel-rs/diesel/issues/852 + +table! 
{ + batch_uploads (batch_id, user_id) { + #[sql_name="batch"] + batch_id -> Bigint, + #[sql_name="userid"] + user_id -> Bigint, + #[sql_name="collection"] + collection_id -> Integer, + } +} + +table! { + batch_upload_items (batch_id, user_id, id) { + #[sql_name="batch"] + batch_id -> Bigint, + #[sql_name="userid"] + user_id -> Bigint, + id -> Text, + sortindex -> Nullable, + payload -> Nullable, + payload_size -> Nullable, + ttl_offset -> Nullable, + } +} + +table! { + bso (user_id, collection_id, id) { + #[sql_name="userid"] + user_id -> BigInt, + #[sql_name="collection"] + collection_id -> Integer, + id -> Text, + sortindex -> Nullable, + payload -> Text, + // not used, but legacy + payload_size -> Bigint, + modified -> Bigint, + #[sql_name="ttl"] + expiry -> Bigint, + } +} + +table! { + collections (id) { + id -> Integer, + name -> Text, + } +} + +table! { + user_collections (user_id, collection_id) { + #[sql_name="userid"] + user_id -> BigInt, + #[sql_name="collection"] + collection_id -> Integer, + #[sql_name="last_modified"] + modified -> Bigint, + #[sql_name="count"] + count -> Integer, + #[sql_name="total_bytes"] + total_bytes -> BigInt, + } +} + +allow_tables_to_appear_in_same_query!( + batch_uploads, + batch_upload_items, + bso, + collections, + user_collections, +); diff --git a/src/db/sqlite/test.rs b/src/db/sqlite/test.rs new file mode 100644 index 0000000000..be1a4a617b --- /dev/null +++ b/src/db/sqlite/test.rs @@ -0,0 +1,91 @@ +use std::{collections::HashMap, result::Result as StdResult}; + +use diesel::{ + r2d2::{CustomizeConnection, Error as PoolError}, + // expression_methods::TextExpressionMethods, // See note below about `not_like` becoming swedish + sqlite::SqliteConnection, + Connection, + ExpressionMethods, + QueryDsl, + RunQueryDsl, +}; +use url::Url; + +use crate::db::sqlite::{ + models::{Result, SqliteDb}, + pool::SqliteDbPool, + schema::collections, +}; +use crate::server::metrics; +use crate::settings::{test_settings, Settings}; + +#[derive(Debug)] +pub struct TestTransactionCustomizer; + +impl CustomizeConnection for TestTransactionCustomizer { + fn on_acquire(&self, conn: &mut SqliteConnection) -> StdResult<(), PoolError> { + conn.begin_test_transaction().map_err(PoolError::QueryError) + } +} + +pub fn db(settings: &Settings) -> Result { + let _ = env_logger::try_init(); + // inherit SYNC_DATABASE_URL from the env + + let pool = SqliteDbPool::new(&settings, &metrics::Metrics::noop())?; + pool.get_sync() +} + +#[test] +fn static_collection_id() -> Result<()> { + let settings = test_settings(); + if Url::parse(&settings.database_url).unwrap().scheme() != "sqlite" { + // Skip this test if we're not using sqlite + return Ok(()); + } + let db = db(&settings)?; + + // ensure DB actually has predefined common collections + let cols: Vec<(i32, _)> = vec![ + (1, "clients"), + (2, "crypto"), + (3, "forms"), + (4, "history"), + (5, "keys"), + (6, "meta"), + (7, "bookmarks"), + (8, "prefs"), + (9, "tabs"), + (10, "passwords"), + (11, "addons"), + (12, "addresses"), + (13, "creditcards"), + ]; + // The integration tests can create collections that start + // with `xxx%`. We should not include those in our counts for local + // unit tests. + // Note: not sure why but as of 11/02/20, `.not_like("xxx%")` is apparently + // swedish-ci. Commenting that out for now. 
+    let results: HashMap<i32, String> = collections::table
+        .select((collections::id, collections::name))
+        .filter(collections::name.ne(""))
+        //.filter(collections::name.not_like("xxx%")) // from most integration tests
+        .filter(collections::name.ne("xxx_col2")) // from server::test
+        .filter(collections::name.ne("col2")) // from older integration tests
+        .load(&db.inner.conn)?
+        .into_iter()
+        .collect();
+    assert_eq!(results.len(), cols.len(), "mismatched columns");
+    for (id, name) in &cols {
+        assert_eq!(results.get(id).unwrap(), name);
+    }
+
+    for (id, name) in &cols {
+        let result = db.get_collection_id(name)?;
+        assert_eq!(result, *id);
+    }
+
+    let cid = db.get_or_create_collection_id("col1")?;
+    assert!(cid >= 100);
+    Ok(())
+}
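
A quick illustration of the batch id scheme implemented in src/db/sqlite/batch.rs
(hypothetical test fragment, not part of the diff above): a batch id is the current
millisecond timestamp with its last digit mixed in from the uid, base64-encoded
before being returned to the client; decode_id() also tolerates the plain decimal form.

    // e.g. inside a #[cfg(test)] test in src/db/sqlite/batch.rs
    let user_id: i64 = 42;
    let batch_id: i64 = 1_606_666_666_660 + (user_id % 10); // => ...662
    let encoded = encode_id(batch_id);
    assert_eq!(decode_id(&encoded).unwrap(), batch_id);
    assert_eq!(decode_id(&batch_id.to_string()).unwrap(), batch_id);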