diff --git a/src/db/error.rs b/src/db/error.rs index 731c1f081a..595309818f 100644 --- a/src/db/error.rs +++ b/src/db/error.rs @@ -38,8 +38,8 @@ pub enum DbErrorKind { #[error("Specified batch does not exist")] BatchNotFound, - #[error("Tokenserver user not found")] - TokenserverUserNotFound, + #[error("Tokenserver user retired")] + TokenserverUserRetired, #[error("An attempt at a conflicting write")] Conflict, @@ -84,9 +84,7 @@ impl DbError { impl From for DbError { fn from(kind: DbErrorKind) -> Self { let status = match kind { - DbErrorKind::TokenserverUserNotFound - | DbErrorKind::CollectionNotFound - | DbErrorKind::BsoNotFound => StatusCode::NOT_FOUND, + DbErrorKind::CollectionNotFound | DbErrorKind::BsoNotFound => StatusCode::NOT_FOUND, // Matching the Python code here (a 400 vs 404) DbErrorKind::BatchNotFound | DbErrorKind::SpannerTooLarge(_) => StatusCode::BAD_REQUEST, // NOTE: the protocol specification states that we should return a @@ -96,6 +94,8 @@ impl From for DbError { // * android bug: https://bugzilla.mozilla.org/show_bug.cgi?id=959032 DbErrorKind::Conflict => StatusCode::SERVICE_UNAVAILABLE, DbErrorKind::Quota => StatusCode::FORBIDDEN, + // NOTE: TokenserverUserRetired is an internal service error for compatibility reasons + // (the legacy Tokenserver returned an internal service error in this situation) _ => StatusCode::INTERNAL_SERVER_ERROR, }; diff --git a/src/tokenserver/db/mock.rs b/src/tokenserver/db/mock.rs index 6141225b32..5815052a48 100644 --- a/src/tokenserver/db/mock.rs +++ b/src/tokenserver/db/mock.rs @@ -37,8 +37,8 @@ impl MockDb { } impl Db for MockDb { - fn get_user(&self, _params: params::GetUser) -> DbFuture<'_, results::GetUser> { - Box::pin(future::ok(results::GetUser::default())) + fn replace_user(&self, _params: params::ReplaceUser) -> DbFuture<'_, results::ReplaceUser> { + Box::pin(future::ok(())) } fn replace_users(&self, _params: params::ReplaceUsers) -> DbFuture<'_, results::ReplaceUsers> { @@ -49,6 +49,10 @@ impl Db for MockDb { Box::pin(future::ok(results::PostUser::default())) } + fn allocate_user(&self, _params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser> { + Box::pin(future::ok(results::AllocateUser::default())) + } + fn put_user(&self, _params: params::PutUser) -> DbFuture<'_, results::PutUser> { Box::pin(future::ok(())) } @@ -61,6 +65,24 @@ impl Db for MockDb { Box::pin(future::ok(results::GetNodeId::default())) } + fn get_best_node(&self, _params: params::GetBestNode) -> DbFuture<'_, results::GetBestNode> { + Box::pin(future::ok(results::GetBestNode::default())) + } + + fn add_user_to_node( + &self, + _params: params::AddUserToNode, + ) -> DbFuture<'_, results::AddUserToNode> { + Box::pin(future::ok(())) + } + + fn get_or_create_user( + &self, + _params: params::GetOrCreateUser, + ) -> DbFuture<'_, results::GetOrCreateUser> { + Box::pin(future::ok(results::GetOrCreateUser::default())) + } + #[cfg(test)] fn set_user_created_at( &self, @@ -69,6 +91,19 @@ impl Db for MockDb { Box::pin(future::ok(())) } + #[cfg(test)] + fn set_user_replaced_at( + &self, + _params: params::SetUserReplacedAt, + ) -> DbFuture<'_, results::SetUserReplacedAt> { + Box::pin(future::ok(())) + } + + #[cfg(test)] + fn get_user(&self, _params: params::GetUser) -> DbFuture<'_, results::GetUser> { + Box::pin(future::ok(results::GetUser::default())) + } + #[cfg(test)] fn get_users(&self, _params: params::GetRawUsers) -> DbFuture<'_, results::GetRawUsers> { Box::pin(future::ok(results::GetRawUsers::default())) @@ -79,6 +114,21 @@ impl Db for MockDb { Box::pin(future::ok(results::PostNode::default())) } + #[cfg(test)] + fn get_node(&self, _params: params::GetNode) -> DbFuture<'_, results::GetNode> { + Box::pin(future::ok(results::GetNode::default())) + } + + #[cfg(test)] + fn unassign_node(&self, _params: params::UnassignNode) -> DbFuture<'_, results::UnassignNode> { + Box::pin(future::ok(())) + } + + #[cfg(test)] + fn remove_node(&self, _params: params::RemoveNode) -> DbFuture<'_, results::RemoveNode> { + Box::pin(future::ok(())) + } + #[cfg(test)] fn post_service(&self, _params: params::PostService) -> DbFuture<'_, results::PostService> { Box::pin(future::ok(results::PostService::default())) diff --git a/src/tokenserver/db/models.rs b/src/tokenserver/db/models.rs index 1b7e6341b4..2b2cf957eb 100644 --- a/src/tokenserver/db/models.rs +++ b/src/tokenserver/db/models.rs @@ -1,22 +1,30 @@ -use actix_web::web::block; +use actix_web::{http::StatusCode, web::block}; use diesel::{ mysql::MysqlConnection, r2d2::{ConnectionManager, PooledConnection}, - sql_types::{Bigint, Integer, Nullable, Text}, - RunQueryDsl, + sql_types::{Bigint, Float, Integer, Nullable, Text}, + OptionalExtension, RunQueryDsl, }; #[cfg(test)] use diesel_logger::LoggingConnection; use futures::future::LocalBoxFuture; use futures::TryFutureExt; -use std::{result, sync::Arc}; +use std::{ + result, + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; use super::{params, results}; use crate::db::error::{DbError, DbErrorKind}; use crate::error::ApiError; use crate::sync_db_method; +/// The maximum possible generation number. Used as a tombstone to mark users that have been +/// "retired" from the db. +const MAX_GENERATION: i64 = i64::MAX; + pub type DbFuture<'a, T> = LocalBoxFuture<'a, Result>; pub type DbResult = result::Result; type Conn = PooledConnection>; @@ -66,69 +74,6 @@ impl TokenserverDb { } } - /// Get the most current user record for the given email and service ID. This function also - /// marks any old user records as replaced, in case of data races that may have occurred - /// during row creation. - fn get_user_sync(&self, params: params::GetUser) -> DbResult { - const QUERY: &str = r#" - SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at, - replaced_at - FROM users - LEFT OUTER JOIN nodes ON users.nodeid = nodes.id - WHERE email = ? - AND users.service = ? - ORDER BY created_at DESC, uid DESC - LIMIT 20 - "#; - let mut raw_users = diesel::sql_query(QUERY) - .bind::(¶ms.email) - .bind::(params.service_id) - .load::(&self.inner.conn)?; - - if raw_users.is_empty() { - return Err(DbErrorKind::TokenserverUserNotFound.into()); - } - - raw_users.sort_by_key(|raw_user| (raw_user.generation, raw_user.created_at)); - raw_users.reverse(); - - // The user with the greatest `generation` and `created_at` is the current user - let raw_user = raw_users[0].clone(); - - // Collect any old client states that differ from the current client state - let old_client_states = raw_users[1..] - .iter() - .map(|user| user.client_state.clone()) - .filter(|client_state| client_state != &raw_user.client_state) - .collect(); - - // Make sure every old row is marked as replaced. They might not be, due to races in row - // creation. - for old_user in &raw_users[1..] { - if old_user.replaced_at.is_none() { - let params = params::ReplaceUser { - uid: old_user.uid, - service_id: params.service_id, - replaced_at: old_user.created_at, - }; - - self.replace_user_sync(params)?; - } - } - - let user = results::GetUser { - uid: raw_user.uid, - client_state: raw_user.client_state, - generation: raw_user.generation, - node: raw_user.node, - keys_changed_at: raw_user.keys_changed_at, - created_at: raw_user.created_at, - old_client_states, - }; - - Ok(user) - } - fn get_node_id_sync(&self, params: params::GetNodeId) -> DbResult { const QUERY: &str = r#" SELECT id @@ -216,11 +161,11 @@ impl TokenserverDb { /// Create a new user. fn post_user_sync(&self, user: params::PostUser) -> DbResult { - const INSERT_USER_QUERY: &str = r#" + const QUERY: &str = r#" INSERT INTO users (service, email, generation, client_state, created_at, nodeid, keys_changed_at, replaced_at) VALUES (?, ?, ?, ?, ?, ?, ?, NULL); "#; - diesel::sql_query(INSERT_USER_QUERY) + diesel::sql_query(QUERY) .bind::(user.service_id) .bind::(&user.email) .bind::(user.generation) @@ -242,6 +187,83 @@ impl TokenserverDb { Ok(result as u64 > 0) } + /// Gets the least-loaded node that has available slots. + fn get_best_node_sync(&self, params: params::GetBestNode) -> DbResult { + const GET_BEST_NODE_QUERY: &str = r#" + SELECT id, node + FROM nodes + WHERE service = ? + AND available > 0 + AND capacity > current_load + AND downed = 0 + AND backoff = 0 + ORDER BY LOG(current_load) / LOG(capacity) + LIMIT 1 + "#; + const RELEASE_CAPACITY_QUERY: &str = r#" + UPDATE nodes + SET available = LEAST(capacity * ?, capacity - current_load) + WHERE service = ? + AND available <= 0 + AND capacity > current_load + AND downed = 0 + "#; + const DEFAULT_CAPACITY_RELEASE_RATE: f32 = 0.1; + + // We may have to retry the query if we need to release more capacity. This loop allows + // a maximum of five retries before bailing out. + for _ in 0..5 { + let maybe_result = diesel::sql_query(GET_BEST_NODE_QUERY) + .bind::(params.service_id) + .get_result::(&self.inner.conn) + .optional()?; + + if let Some(result) = maybe_result { + return Ok(result); + } + + // There were no available nodes. Try to release additional capacity from any nodes + // that are not fully occupied. + let affected_rows = diesel::sql_query(RELEASE_CAPACITY_QUERY) + .bind::( + params + .capacity_release_rate + .unwrap_or(DEFAULT_CAPACITY_RELEASE_RATE), + ) + .bind::(params.service_id) + .execute(&self.inner.conn)?; + + // If no nodes were affected by the last query, give up. + if affected_rows == 0 { + break; + } + } + + let mut db_error: DbError = DbErrorKind::Internal("unable to get a node".to_owned()).into(); + db_error.status = StatusCode::SERVICE_UNAVAILABLE; + Err(db_error) + } + + fn add_user_to_node_sync( + &self, + params: params::AddUserToNode, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE nodes + SET current_load = current_load + 1, + available = GREATEST(available - 1, 0) + WHERE service = ? + AND node = ? + "#; + + diesel::sql_query(QUERY) + .bind::(params.service_id) + .bind::(¶ms.node) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + #[cfg(test)] fn set_user_created_at_sync( &self, @@ -260,6 +282,38 @@ impl TokenserverDb { .map_err(Into::into) } + #[cfg(test)] + fn set_user_replaced_at_sync( + &self, + params: params::SetUserReplacedAt, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = ? + WHERE uid = ? + "#; + diesel::sql_query(QUERY) + .bind::(params.replaced_at) + .bind::(¶ms.uid) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + #[cfg(test)] + fn get_user_sync(&self, params: params::GetUser) -> DbResult { + const QUERY: &str = r#" + SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at + FROM users + WHERE uid = ? + "#; + + diesel::sql_query(QUERY) + .bind::(params.id) + .get_result::(&self.inner.conn) + .map_err(Into::into) + } + #[cfg(test)] fn get_users_sync(&self, email: String) -> DbResult { const QUERY: &str = r#" @@ -278,11 +332,11 @@ impl TokenserverDb { #[cfg(test)] fn post_node_sync(&self, params: params::PostNode) -> DbResult { - const INSERT_NODE_QUERY: &str = r#" + const QUERY: &str = r#" INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff) VALUES (?, ?, ?, ?, ?, ?, ?) "#; - diesel::sql_query(INSERT_NODE_QUERY) + diesel::sql_query(QUERY) .bind::(params.service_id) .bind::(¶ms.node) .bind::(params.available) @@ -297,6 +351,52 @@ impl TokenserverDb { .map_err(Into::into) } + #[cfg(test)] + fn get_node_sync(&self, params: params::GetNode) -> DbResult { + const QUERY: &str = r#" + SELECT * + FROM nodes + WHERE id = ? + "#; + + diesel::sql_query(QUERY) + .bind::(params.id) + .get_result::(&self.inner.conn) + .map_err(Into::into) + } + + #[cfg(test)] + fn unassign_node_sync(&self, params: params::UnassignNode) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = ? + WHERE nodeid = ? + "#; + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as i64; + + diesel::sql_query(QUERY) + .bind::(current_time) + .bind::(params.node_id) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + + #[cfg(test)] + fn remove_node_sync(&self, params: params::RemoveNode) -> DbResult { + const QUERY: &str = "DELETE FROM nodes WHERE id = ?"; + + diesel::sql_query(QUERY) + .bind::(params.node_id) + .execute(&self.inner.conn) + .map(|_| ()) + .map_err(Into::into) + } + #[cfg(test)] fn post_service_sync(&self, params: params::PostService) -> DbResult { const INSERT_SERVICE_QUERY: &str = r#" @@ -315,11 +415,182 @@ impl TokenserverDb { } impl Db for TokenserverDb { - sync_db_method!(get_user, get_user_sync, GetUser); + sync_db_method!(replace_user, replace_user_sync, ReplaceUser); sync_db_method!(replace_users, replace_users_sync, ReplaceUsers); sync_db_method!(post_user, post_user_sync, PostUser); + + /// Creates a new user and assigns them to a node. + fn allocate_user(&self, params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser> { + Box::pin(async move { + // Get the least-loaded node + let node = self + .get_best_node(params::GetBestNode { + service_id: params.service_id, + capacity_release_rate: params.capacity_release_rate, + }) + .await?; + + // Decrement `available` and increment `current_load` on the node assigned to the user. + self.add_user_to_node(params::AddUserToNode { + service_id: params.service_id, + node: node.node.clone(), + }) + .await?; + + let created_at = { + let start = SystemTime::now(); + start.duration_since(UNIX_EPOCH).unwrap().as_millis() as i64 + }; + let uid = self + .post_user(params::PostUser { + service_id: params.service_id, + email: params.email.clone(), + generation: params.generation, + client_state: params.client_state.clone(), + created_at, + node_id: node.id, + keys_changed_at: params.keys_changed_at, + }) + .await? + .id; + + Ok(results::AllocateUser { + uid, + node: node.node, + created_at, + }) + }) + } + sync_db_method!(put_user, put_user_sync, PutUser); sync_db_method!(get_node_id, get_node_id_sync, GetNodeId); + sync_db_method!(get_best_node, get_best_node_sync, GetBestNode); + sync_db_method!(add_user_to_node, add_user_to_node_sync, AddUserToNode); + + /// Gets the user with the given email and service ID, or if one doesn't exist, allocates a new + /// user. + fn get_or_create_user( + &self, + params: params::GetOrCreateUser, + ) -> DbFuture<'_, results::GetOrCreateUser> { + const QUERY: &str = r#" + SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at, + replaced_at + FROM users + LEFT OUTER JOIN nodes ON users.nodeid = nodes.id + WHERE email = ? + AND users.service = ? + ORDER BY created_at DESC, uid DESC + LIMIT 20 + "#; + + Box::pin(async move { + let mut raw_users = diesel::sql_query(QUERY) + .bind::(¶ms.email) + .bind::(params.service_id) + .load::(&self.inner.conn) + .map_err(|e| ApiError::from(DbError::from(e)))?; + + if raw_users.is_empty() { + // There are no users in the database with the given email and service ID, so + // allocate a new one. + let allocate_user_result = self + .allocate_user(params.clone() as params::AllocateUser) + .await?; + + Ok(results::GetOrCreateUser { + uid: allocate_user_result.uid, + email: params.email.clone(), + client_state: params.client_state, + generation: params.generation, + node: allocate_user_result.node, + keys_changed_at: params.keys_changed_at, + created_at: allocate_user_result.created_at, + replaced_at: None, + old_client_states: vec![], + }) + } else { + raw_users.sort_by_key(|raw_user| (raw_user.generation, raw_user.created_at)); + raw_users.reverse(); + + // The user with the greatest `generation` and `created_at` is the current user + let raw_user = raw_users[0].clone(); + + // Collect any old client states that differ from the current client state + let old_client_states = raw_users[1..] + .iter() + .map(|user| user.client_state.clone()) + .filter(|client_state| client_state != &raw_user.client_state) + .collect(); + + // Make sure every old row is marked as replaced. They might not be, due to races in row + // creation. + for old_user in &raw_users[1..] { + if old_user.replaced_at.is_none() { + let params = params::ReplaceUser { + uid: old_user.uid, + service_id: params.service_id, + replaced_at: raw_user.created_at, + }; + + self.replace_user(params).await?; + } + } + + match (raw_user.replaced_at, raw_user.node) { + // If the most up-to-date user is marked as replaced or does not have a node + // assignment, allocate a new user. Note that, if the current user is marked + // as replaced, we do not want to create a new user with the account metadata + // in the parameters to this method. Rather, we want to create a duplicate of + // the replaced user assigned to a new node. This distinction is important + // because the account metadata in the parameters to this method may not match + // that currently stored on the most up-to-date user and may be invalid. + (Some(_), _) | (_, None) if raw_user.generation < MAX_GENERATION => { + let allocate_user_result = self + .allocate_user(params::AllocateUser { + service_id: params.service_id, + email: params.email.clone(), + generation: raw_user.generation, + client_state: raw_user.client_state.clone(), + keys_changed_at: raw_user.keys_changed_at, + capacity_release_rate: params.capacity_release_rate, + }) + .await?; + + Ok(results::GetOrCreateUser { + uid: allocate_user_result.uid, + email: params.email.clone(), + client_state: raw_user.client_state, + generation: raw_user.generation, + node: allocate_user_result.node, + keys_changed_at: raw_user.keys_changed_at, + created_at: allocate_user_result.created_at, + replaced_at: None, + old_client_states, + }) + } + // The most up-to-date user has a node. Note that this user may be retired or + // replaced. + (_, Some(node)) => Ok(results::GetOrCreateUser { + uid: raw_user.uid, + email: params.email.clone(), + client_state: raw_user.client_state, + generation: raw_user.generation, + node, + keys_changed_at: raw_user.keys_changed_at, + created_at: raw_user.created_at, + replaced_at: None, + old_client_states, + }), + // The most up-to-date user doesn't have a node and is retired. + (_, None) => Err(DbError::from(DbErrorKind::TokenserverUserRetired).into()), + } + } + }) + } + + #[cfg(test)] + sync_db_method!(get_user, get_user_sync, GetUser); fn check(&self) -> DbFuture<'_, results::Check> { let db = self.clone(); @@ -333,41 +604,89 @@ impl Db for TokenserverDb { SetUserCreatedAt ); + #[cfg(test)] + sync_db_method!( + set_user_replaced_at, + set_user_replaced_at_sync, + SetUserReplacedAt + ); + #[cfg(test)] sync_db_method!(get_users, get_users_sync, GetRawUsers); #[cfg(test)] sync_db_method!(post_node, post_node_sync, PostNode); + #[cfg(test)] + sync_db_method!(get_node, get_node_sync, GetNode); + + #[cfg(test)] + sync_db_method!(unassign_node, unassign_node_sync, UnassignNode); + + #[cfg(test)] + sync_db_method!(remove_node, remove_node_sync, RemoveNode); + #[cfg(test)] sync_db_method!(post_service, post_service_sync, PostService); } pub trait Db { - fn get_user(&self, params: params::GetUser) -> DbFuture<'_, results::GetUser>; + fn replace_user(&self, params: params::ReplaceUser) -> DbFuture<'_, results::ReplaceUser>; fn replace_users(&self, params: params::ReplaceUsers) -> DbFuture<'_, results::ReplaceUsers>; fn post_user(&self, params: params::PostUser) -> DbFuture<'_, results::PostUser>; + fn allocate_user(&self, params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser>; + fn put_user(&self, params: params::PutUser) -> DbFuture<'_, results::PutUser>; fn check(&self) -> DbFuture<'_, results::Check>; fn get_node_id(&self, params: params::GetNodeId) -> DbFuture<'_, results::GetNodeId>; + fn get_best_node(&self, params: params::GetBestNode) -> DbFuture<'_, results::GetBestNode>; + + fn add_user_to_node( + &self, + params: params::AddUserToNode, + ) -> DbFuture<'_, results::AddUserToNode>; + + fn get_or_create_user( + &self, + params: params::GetOrCreateUser, + ) -> DbFuture<'_, results::GetOrCreateUser>; + #[cfg(test)] fn set_user_created_at( &self, params: params::SetUserCreatedAt, ) -> DbFuture<'_, results::SetUserCreatedAt>; + #[cfg(test)] + fn set_user_replaced_at( + &self, + params: params::SetUserReplacedAt, + ) -> DbFuture<'_, results::SetUserReplacedAt>; + + #[cfg(test)] + fn get_user(&self, params: params::GetUser) -> DbFuture<'_, results::GetUser>; + #[cfg(test)] fn get_users(&self, params: params::GetRawUsers) -> DbFuture<'_, results::GetRawUsers>; #[cfg(test)] fn post_node(&self, params: params::PostNode) -> DbFuture<'_, results::PostNode>; + #[cfg(test)] + fn get_node(&self, params: params::GetNode) -> DbFuture<'_, results::GetNode>; + + #[cfg(test)] + fn unassign_node(&self, params: params::UnassignNode) -> DbFuture<'_, results::UnassignNode>; + + #[cfg(test)] + fn remove_node(&self, params: params::RemoveNode) -> DbFuture<'_, results::RemoveNode>; + #[cfg(test)] fn post_service(&self, params: params::PostService) -> DbFuture<'_, results::PostService>; } @@ -376,7 +695,8 @@ pub trait Db { mod tests { use super::*; - use std::time::{SystemTime, UNIX_EPOCH}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::settings::test_settings; use crate::tokenserver::db; @@ -384,63 +704,16 @@ mod tests { type Result = std::result::Result; - #[tokio::test] - async fn get_user() -> Result<()> { - let pool = db_pool().await?; - let db = pool.get()?; - - // Add a node - let node_id = db - .post_node(params::PostNode { - service_id: db::SYNC_1_5_SERVICE_ID, - ..Default::default() - }) - .await?; - - // Add a user - let email1 = "test_user_1"; - let user_id = db - .post_user(params::PostUser { - service_id: db::SYNC_1_5_SERVICE_ID, - node_id: node_id.id, - email: email1.to_owned(), - ..Default::default() - }) - .await?; - - // Add another user - db.post_user(params::PostUser { - service_id: db::SYNC_1_5_SERVICE_ID, - node_id: node_id.id, - email: "test_user_2".to_owned(), - ..Default::default() - }) - .await?; - - let user = db - .get_user(params::GetUser { - email: email1.to_owned(), - service_id: db::SYNC_1_5_SERVICE_ID, - }) - .await?; - - // Ensure that the correct user has been returned - assert_eq!(user.uid, user_id.id); - - Ok(()) - } - #[tokio::test] async fn test_update_generation() -> Result<()> { let pool = db_pool().await?; let db = pool.get()?; // Add a node - let node = "node"; let node_id = db .post_node(params::PostNode { service_id: db::SYNC_1_5_SERVICE_ID, - node: node.to_owned(), + node: "https://node1".to_owned(), ..Default::default() }) .await? @@ -458,12 +731,7 @@ mod tests { .await? .id; - let user = db - .get_user(params::GetUser { - email: email.to_owned(), - service_id: db::SYNC_1_5_SERVICE_ID, - }) - .await?; + let user = db.get_user(params::GetUser { id: uid }).await?; assert_eq!(user.generation, 0); assert_eq!(user.client_state, ""); @@ -477,15 +745,9 @@ mod tests { }) .await?; - let user = db - .get_user(params::GetUser { - email: email.to_owned(), - service_id: db::SYNC_1_5_SERVICE_ID, - }) - .await?; + let user = db.get_user(params::GetUser { id: uid }).await?; - assert_eq!(user.uid, uid); - assert_eq!(user.node, node); + assert_eq!(user.node_id, node_id); assert_eq!(user.generation, 42); assert_eq!(user.client_state, ""); @@ -498,8 +760,9 @@ mod tests { }) .await?; - assert_eq!(user.uid, uid); - assert_eq!(user.node, node); + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.node_id, node_id); assert_eq!(user.generation, 42); assert_eq!(user.client_state, ""); @@ -512,11 +775,10 @@ mod tests { let db = pool.get()?; // Add a node - let node = "node"; let node_id = db .post_node(params::PostNode { service_id: db::SYNC_1_5_SERVICE_ID, - node: node.to_owned(), + node: "https://node".to_owned(), ..Default::default() }) .await? @@ -534,12 +796,7 @@ mod tests { .await? .id; - let user = db - .get_user(params::GetUser { - email: email.to_owned(), - service_id: db::SYNC_1_5_SERVICE_ID, - }) - .await?; + let user = db.get_user(params::GetUser { id: uid }).await?; assert_eq!(user.keys_changed_at, None); assert_eq!(user.client_state, ""); @@ -553,15 +810,9 @@ mod tests { }) .await?; - let user = db - .get_user(params::GetUser { - email: email.to_owned(), - service_id: db::SYNC_1_5_SERVICE_ID, - }) - .await?; + let user = db.get_user(params::GetUser { id: uid }).await?; - assert_eq!(user.uid, uid); - assert_eq!(user.node, node); + assert_eq!(user.node_id, node_id); assert_eq!(user.keys_changed_at, Some(42)); assert_eq!(user.client_state, ""); @@ -574,8 +825,9 @@ mod tests { }) .await?; - assert_eq!(user.uid, uid); - assert_eq!(user.node, node); + let user = db.get_user(params::GetUser { id: uid }).await?; + + assert_eq!(user.node_id, node_id); assert_eq!(user.keys_changed_at, Some(42)); assert_eq!(user.client_state, ""); @@ -634,12 +886,17 @@ mod tests { service_id: db::SYNC_1_5_SERVICE_ID, node_id: node_id.id, email: email1.to_owned(), - replaced_at: Some(an_hour_ago + MILLISECONDS_IN_A_MINUTE), ..Default::default() }) .await? .id; + db.set_user_replaced_at(params::SetUserReplacedAt { + replaced_at: an_hour_ago + MILLISECONDS_IN_A_MINUTE, + uid, + }) + .await?; + db.set_user_created_at(params::SetUserCreatedAt { created_at: an_hour_ago, uid, @@ -752,56 +1009,48 @@ mod tests { service_id: db::SYNC_1_5_SERVICE_ID, ..Default::default() }; - let node_id = db.post_node(post_node_params.clone()).await?; + let node_id = db.post_node(post_node_params.clone()).await?.id; // Add a user let email1 = "test_user_1"; let post_user_params1 = params::PostUser { service_id: db::SYNC_1_5_SERVICE_ID, - node_id: node_id.id, email: email1.to_owned(), - ..Default::default() + generation: 1, + client_state: "616161".to_owned(), + created_at: 2, + node_id, + keys_changed_at: Some(3), }; - let post_user_result1 = db.post_user(post_user_params1.clone()).await?; + let uid1 = db.post_user(post_user_params1.clone()).await?.id; // Add another user let email2 = "test_user_2"; let post_user_params2 = params::PostUser { service_id: db::SYNC_1_5_SERVICE_ID, - node_id: node_id.id, + node_id, email: email2.to_owned(), ..Default::default() }; - let post_user_result2 = db.post_user(post_user_params2).await?; + let uid2 = db.post_user(post_user_params2).await?.id; // Ensure that two separate users were created - assert_ne!(post_user_result1.id, post_user_result2.id); + assert_ne!(uid1, uid2); // Get a user - let user = { - let params = params::GetUser { - email: email1.to_owned(), - service_id: db::SYNC_1_5_SERVICE_ID, - }; - - db.get_user(params).await? - }; + let user = db.get_user(params::GetUser { id: uid1 }).await?; // Ensure the user has the expected values - let mut expected_get_user = results::GetUser { - uid: post_user_result1.id, - client_state: post_user_params1.client_state.clone(), - generation: post_user_params1.generation, - keys_changed_at: post_user_params1.keys_changed_at, - node: post_node_params.node, - created_at: 0, - old_client_states: vec![], + let expected_get_user = results::GetUser { + service_id: db::SYNC_1_5_SERVICE_ID, + email: email1.to_owned(), + generation: 1, + client_state: "616161".to_owned(), + replaced_at: None, + node_id, + keys_changed_at: Some(3), }; - // Set created_at manually, since there's no way for us to know that timestamp without - // querying for the user - expected_get_user.created_at = user.created_at; - assert_eq!(user, expected_get_user); Ok(()) @@ -845,6 +1094,684 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_node_allocation() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add a node + let node_id = db + .post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await? + .id; + + // Allocating a user assigns it to the node + let user = db + .allocate_user(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + assert_eq!(user.node, "https://node1"); + + // Getting the user from the database does not affect node assignment + let user = db.get_user(params::GetUser { id: user.uid }).await?; + assert_eq!(user.node_id, node_id); + + Ok(()) + } + + #[tokio::test] + async fn test_allocation_to_least_loaded_node() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add two nodes + db.post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + db.post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node2".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + // Allocate two users + let user1 = db + .allocate_user(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test1@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + let user2 = db + .allocate_user(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test2@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + // Because users are always assigned to the least-loaded node, the users should have been + // assigned to different nodes + assert_ne!(user1.node, user2.node); + + Ok(()) + } + + #[tokio::test] + async fn test_allocation_is_not_allowed_to_downed_nodes() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add a downed node + db.post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + downed: 1, + ..Default::default() + }) + .await?; + + // User allocation fails because allocation is not allowed to downed nodes + let result = db + .allocate_user(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await; + let error = result.unwrap_err(); + assert_eq!(error.to_string(), "Unexpected error: unable to get a node"); + + Ok(()) + } + + #[tokio::test] + async fn test_allocation_is_not_allowed_to_backoff_nodes() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add a backoff node + db.post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + backoff: 1, + ..Default::default() + }) + .await?; + + // User allocation fails because allocation is not allowed to backoff nodes + let result = db + .allocate_user(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await; + let error = result.unwrap_err(); + assert_eq!(error.to_string(), "Unexpected error: unable to get a node"); + + Ok(()) + } + + #[tokio::test] + async fn test_node_reassignment_when_records_are_replaced() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add a node + db.post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + // Allocate a user + let allocate_user_result = db + .allocate_user(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + let user1 = db + .get_user(params::GetUser { + id: allocate_user_result.uid, + }) + .await?; + + // Mark the user as replaced + db.replace_user(params::ReplaceUser { + uid: allocate_user_result.uid, + service_id: db::SYNC_1_5_SERVICE_ID, + replaced_at: 1234, + }) + .await?; + + let user2 = db + .get_or_create_user(params::GetOrCreateUser { + email: "test@test.com".to_owned(), + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1235, + client_state: "626262".to_owned(), + keys_changed_at: Some(1235), + capacity_release_rate: None, + }) + .await?; + + // Calling get_or_create_user() results in the creation of a new user record, since the + // previous record was marked as replaced + assert_ne!(allocate_user_result.uid, user2.uid); + + // The account metadata should match that of the original user and *not* that in the + // method parameters + assert_eq!(user1.generation, user2.generation); + assert_eq!(user1.keys_changed_at, user2.keys_changed_at); + assert_eq!(user1.client_state, user2.client_state); + + Ok(()) + } + + #[tokio::test] + async fn test_node_reassignment_not_done_for_retired_users() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add a node + db.post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await?; + + // Add a retired user + let user1 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: MAX_GENERATION, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + let user2 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + // Calling get_or_create_user() does not update the user's node + assert_eq!(user1.uid, user2.uid); + assert_eq!(user2.generation, MAX_GENERATION); + assert_eq!(user1.client_state, user2.client_state); + + Ok(()) + } + + #[tokio::test] + async fn test_node_reassignment_and_removal() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add two nodes + let node1_id = db + .post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await? + .id; + + let node2_id = db + .post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node2".to_owned(), + current_load: 0, + capacity: 100, + available: 100, + ..Default::default() + }) + .await? + .id; + + // Create four users. We should get two on each node. + let user1 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test1@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + let user2 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test2@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + let user3 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test3@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + let user4 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test4@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + let node1_count = [&user1, &user2, &user3, &user4] + .iter() + .filter(|user| user.node == "https://node1") + .count(); + assert_eq!(node1_count, 2); + let node2_count = [&user1, &user2, &user3, &user4] + .iter() + .filter(|user| user.node == "https://node2") + .count(); + assert_eq!(node2_count, 2); + + // Clear the assignments on the first node. + db.unassign_node(params::UnassignNode { node_id: node1_id }) + .await?; + + // The users previously on the first node should balance across both nodes, + // giving 1 on the first node and 3 on the second node. + let mut node1_count = 0; + let mut node2_count = 0; + + for user in [&user1, &user2, &user3, &user4] { + let new_user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + email: user.email.clone(), + generation: user.generation, + client_state: user.client_state.clone(), + keys_changed_at: user.keys_changed_at, + capacity_release_rate: None, + }) + .await?; + + if new_user.node == "https://node1" { + node1_count += 1; + } else { + assert_eq!(new_user.node, "https://node2"); + + node2_count += 1; + } + } + + assert_eq!(node1_count, 1); + assert_eq!(node2_count, 3); + + // Remove the second node. Everyone should end up on the first node. + db.remove_node(params::RemoveNode { node_id: node2_id }) + .await?; + + // Every user should be on the first node now. + for user in [&user1, &user2, &user3, &user4] { + let new_user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + email: user.email.clone(), + generation: user.generation, + client_state: user.client_state.clone(), + keys_changed_at: user.keys_changed_at, + capacity_release_rate: None, + }) + .await?; + + assert_eq!(new_user.node, "https://node1"); + } + + Ok(()) + } + + #[tokio::test] + async fn test_gradual_release_of_node_capacity() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add two nodes + let node1_id = db + .post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 4, + capacity: 8, + available: 1, + ..Default::default() + }) + .await? + .id; + + let node2_id = db + .post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node2".to_owned(), + current_load: 4, + capacity: 6, + available: 1, + ..Default::default() + }) + .await? + .id; + + // Two user creations should succeed without releasing capacity on either of the nodes. + // The users should be assigned to different nodes. + let user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test1@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + assert_eq!(user.node, "https://node1"); + let node = db.get_node(params::GetNode { id: node1_id }).await?; + assert_eq!(node.current_load, 5); + assert_eq!(node.capacity, 8); + assert_eq!(node.available, 0); + + let user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test2@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + assert_eq!(user.node, "https://node2"); + let node = db.get_node(params::GetNode { id: node2_id }).await?; + assert_eq!(node.current_load, 5); + assert_eq!(node.capacity, 6); + assert_eq!(node.available, 0); + + // The next allocation attempt will release 10% more capacity, which is one more slot for + // each node. + let user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test3@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + assert_eq!(user.node, "https://node1"); + let node = db.get_node(params::GetNode { id: node1_id }).await?; + assert_eq!(node.current_load, 6); + assert_eq!(node.capacity, 8); + assert_eq!(node.available, 0); + + let user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test4@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + assert_eq!(user.node, "https://node2"); + let node = db.get_node(params::GetNode { id: node2_id }).await?; + assert_eq!(node.current_load, 6); + assert_eq!(node.capacity, 6); + assert_eq!(node.available, 0); + + // Now that node2 is full, further allocations will go to node1. + let user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test5@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + assert_eq!(user.node, "https://node1"); + let node = db.get_node(params::GetNode { id: node1_id }).await?; + assert_eq!(node.current_load, 7); + assert_eq!(node.capacity, 8); + assert_eq!(node.available, 0); + + let user = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test6@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + assert_eq!(user.node, "https://node1"); + let node = db.get_node(params::GetNode { id: node1_id }).await?; + assert_eq!(node.current_load, 8); + assert_eq!(node.capacity, 8); + assert_eq!(node.available, 0); + + // Once the capacity is reached, further user allocations will result in an error. + let result = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test7@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await; + + assert_eq!( + result.unwrap_err().to_string(), + "Unexpected error: unable to get a node" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_correct_created_at_used_during_node_reassignment() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add a node + let node_id = db + .post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 4, + capacity: 8, + available: 1, + ..Default::default() + }) + .await? + .id; + + // Create a user + let user1 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test4@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + // Clear the user's node + db.unassign_node(params::UnassignNode { node_id }).await?; + + // Sleep very briefly to ensure the timestamp created during node reassignment is greater + // than the timestamp created during user creation + thread::sleep(Duration::from_millis(5)); + + // Get the user, prompting the user's reassignment to the same node + let user2 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test4@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + // The user's timestamp should be updated since a new user record was created. + assert!(user2.created_at > user1.created_at); + + Ok(()) + } + + #[tokio::test] + async fn test_correct_created_at_used_during_user_retrieval() -> Result<()> { + let pool = db_pool().await?; + let db = pool.get()?; + + // Add a node + db.post_node(params::PostNode { + service_id: db::SYNC_1_5_SERVICE_ID, + node: "https://node1".to_owned(), + current_load: 4, + capacity: 8, + available: 1, + ..Default::default() + }) + .await?; + + // Create a user + let user1 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test4@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + // Sleep very briefly to ensure that any timestamp that might be created below is greater + // than the timestamp created during user creation + thread::sleep(Duration::from_millis(5)); + + // Get the user + let user2 = db + .get_or_create_user(params::GetOrCreateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test4@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }) + .await?; + + // The user's timestamp should be equal to the one generated when the user was created + assert_eq!(user1.created_at, user2.created_at); + + Ok(()) + } + async fn db_pool() -> DbResult { let _ = env_logger::try_init(); diff --git a/src/tokenserver/db/params.rs b/src/tokenserver/db/params.rs index cade61f38d..25d2c410cb 100644 --- a/src/tokenserver/db/params.rs +++ b/src/tokenserver/db/params.rs @@ -1,11 +1,5 @@ //! Parameter types for database methods. -#[derive(Default)] -pub struct GetUser { - pub email: String, - pub service_id: i32, -} - #[derive(Clone, Default)] pub struct PostNode { pub service_id: i32, @@ -17,12 +11,29 @@ pub struct PostNode { pub backoff: i32, } +#[derive(Clone, Default)] +pub struct GetNode { + pub id: i64, +} + #[derive(Default)] pub struct PostService { pub service: String, pub pattern: String, } +#[derive(Clone, Default)] +pub struct GetOrCreateUser { + pub service_id: i32, + pub email: String, + pub generation: i64, + pub client_state: String, + pub keys_changed_at: Option, + pub capacity_release_rate: Option, +} + +pub type AllocateUser = GetOrCreateUser; + #[derive(Clone, Default)] pub struct PostUser { pub service_id: i32, @@ -30,7 +41,6 @@ pub struct PostUser { pub generation: i64, pub client_state: String, pub created_at: i64, - pub replaced_at: Option, pub node_id: i64, pub keys_changed_at: Option, } @@ -65,6 +75,18 @@ pub struct GetNodeId { pub node: String, } +#[derive(Default)] +pub struct GetBestNode { + pub service_id: i32, + pub capacity_release_rate: Option, +} + +#[derive(Default)] +pub struct AddUserToNode { + pub service_id: i32, + pub node: String, +} + #[cfg(test)] pub type GetRawUsers = String; @@ -73,3 +95,25 @@ pub struct SetUserCreatedAt { pub uid: i64, pub created_at: i64, } + +#[cfg(test)] +pub struct SetUserReplacedAt { + pub uid: i64, + pub replaced_at: i64, +} + +#[cfg(test)] +#[derive(Default)] +pub struct GetUser { + pub id: i64, +} + +#[cfg(test)] +pub struct UnassignNode { + pub node_id: i64, +} + +#[cfg(test)] +pub struct RemoveNode { + pub node_id: i64, +} diff --git a/src/tokenserver/db/results.rs b/src/tokenserver/db/results.rs index 051d296b11..9d3efc0f07 100644 --- a/src/tokenserver/db/results.rs +++ b/src/tokenserver/db/results.rs @@ -16,8 +16,8 @@ pub struct GetRawUser { pub client_state: String, #[sql_type = "Bigint"] pub generation: i64, - #[sql_type = "Text"] - pub node: String, + #[sql_type = "Nullable"] + pub node: Option, #[sql_type = "Nullable"] pub keys_changed_at: Option, #[sql_type = "Bigint"] @@ -26,57 +26,128 @@ pub struct GetRawUser { pub replaced_at: Option, } -#[cfg(test)] -pub type GetRawUsers = Vec; +#[derive(Debug, Default, PartialEq)] +pub struct AllocateUser { + pub uid: i64, + pub node: String, + pub created_at: i64, +} /// Represents the relevant information from the most recently-created user record in the database /// for a given email and service ID, along with any previously-seen client states seen for the /// user. #[derive(Debug, Default, PartialEq)] -pub struct GetUser { +pub struct GetOrCreateUser { pub uid: i64, + pub email: String, pub client_state: String, pub generation: i64, pub node: String, pub keys_changed_at: Option, pub created_at: i64, + pub replaced_at: Option, pub old_client_states: Vec, } -#[cfg(test)] #[derive(Default, QueryableByName)] -pub struct PostNode { +pub struct PostUser { #[sql_type = "Bigint"] pub id: i64, } -#[cfg(test)] +pub type ReplaceUsers = (); +pub type ReplaceUser = (); +pub type PutUser = (); + #[derive(Default, QueryableByName)] -pub struct PostService { - #[sql_type = "Integer"] - pub id: i32, +pub struct GetNodeId { + #[sql_type = "Bigint"] + pub id: i64, } #[derive(Default, QueryableByName)] -pub struct PostUser { +pub struct GetBestNode { #[sql_type = "Bigint"] pub id: i64, + #[sql_type = "Text"] + pub node: String, } -pub type ReplaceUsers = (); -pub type ReplaceUser = (); -pub type PutUser = (); +pub type AddUserToNode = (); +#[cfg(test)] +pub type GetRawUsers = Vec; + +#[cfg(test)] +#[derive(Debug, Default, PartialEq, QueryableByName)] +pub struct GetUser { + #[sql_type = "Integer"] + #[column_name = "service"] + pub service_id: i32, + #[sql_type = "Text"] + pub email: String, + #[sql_type = "Bigint"] + pub generation: i64, + #[sql_type = "Text"] + pub client_state: String, + #[sql_type = "Nullable"] + pub replaced_at: Option, + #[sql_type = "Bigint"] + #[column_name = "nodeid"] + pub node_id: i64, + #[sql_type = "Nullable"] + pub keys_changed_at: Option, +} + +#[cfg(test)] #[derive(Default, QueryableByName)] -pub struct GetNodeId { +pub struct PostNode { #[sql_type = "Bigint"] pub id: i64, } +#[cfg(test)] +#[derive(Default, QueryableByName)] +pub struct GetNode { + #[sql_type = "Bigint"] + pub id: i64, + #[sql_type = "Integer"] + #[column_name = "service"] + pub service_id: i32, + #[sql_type = "Text"] + pub node: String, + #[sql_type = "Integer"] + pub available: i32, + #[sql_type = "Integer"] + pub current_load: i32, + #[sql_type = "Integer"] + pub capacity: i32, + #[sql_type = "Integer"] + pub downed: i32, + #[sql_type = "Integer"] + pub backoff: i32, +} + +#[cfg(test)] +#[derive(Default, QueryableByName)] +pub struct PostService { + #[sql_type = "Integer"] + pub id: i32, +} + #[cfg(test)] pub type SetUserCreatedAt = (); +#[cfg(test)] +pub type SetUserReplacedAt = (); + #[cfg(test)] pub type GetUsers = Vec; pub type Check = bool; + +#[cfg(test)] +pub type UnassignNode = (); + +#[cfg(test)] +pub type RemoveNode = (); diff --git a/src/tokenserver/error.rs b/src/tokenserver/error.rs index 2b3aeecb7a..d7308bb1ad 100644 --- a/src/tokenserver/error.rs +++ b/src/tokenserver/error.rs @@ -8,17 +8,17 @@ use serde::{ #[derive(Debug, PartialEq)] pub struct TokenserverError { - status: &'static str, - location: ErrorLocation, - name: String, - description: &'static str, - http_status: StatusCode, + pub status: &'static str, + pub location: ErrorLocation, + pub name: String, + pub description: &'static str, + pub http_status: StatusCode, } impl Default for TokenserverError { fn default() -> Self { Self { - status: "", + status: "error", location: ErrorLocation::default(), name: "".to_owned(), description: "Unauthorized", @@ -55,6 +55,7 @@ impl TokenserverError { pub fn invalid_credentials(description: &'static str) -> Self { Self { status: "invalid-credentials", + location: ErrorLocation::Body, description, ..Self::default() } @@ -63,7 +64,6 @@ impl TokenserverError { pub fn invalid_client_state(description: &'static str) -> Self { Self { status: "invalid-client-state", - location: ErrorLocation::Body, description, name: "X-Client-State".to_owned(), ..Self::default() @@ -92,7 +92,7 @@ impl TokenserverError { pub fn unauthorized(description: &'static str) -> Self { Self { - status: "error", + location: ErrorLocation::Body, description, ..Self::default() } diff --git a/src/tokenserver/extractors.rs b/src/tokenserver/extractors.rs index f40df015df..3e0d1c83f7 100644 --- a/src/tokenserver/extractors.rs +++ b/src/tokenserver/extractors.rs @@ -17,7 +17,7 @@ use serde::Deserialize; use sha2::Sha256; use super::db::{self, models::Db, params, results}; -use super::error::TokenserverError; +use super::error::{ErrorLocation, TokenserverError}; use super::support::TokenData; use super::ServerState; use crate::settings::Secrets; @@ -27,7 +27,7 @@ const DEFAULT_TOKEN_DURATION: u64 = 5 * 60; /// Information from the request needed to process a Tokenserver request. #[derive(Debug, Default, PartialEq)] pub struct TokenserverRequest { - pub user: results::GetUser, + pub user: results::GetOrCreateUser, pub fxa_uid: String, pub email: String, pub generation: Option, @@ -189,9 +189,13 @@ impl FromRequest for TokenserverRequest { TokenserverError::internal_error() })?; - db.get_user(params::GetUser { - email: email.clone(), + db.get_or_create_user(params::GetOrCreateUser { service_id, + email: email.clone(), + generation: token_data.generation.unwrap_or(0), + client_state: key_id.client_state.clone(), + keys_changed_at: Some(key_id.keys_changed_at), + capacity_release_rate: state.node_capacity_release_rate, }) .await? }; @@ -321,12 +325,12 @@ impl FromRequest for KeyId { let (keys_changed_at_string, encoded_client_state) = x_key_id .split_once("-") - .ok_or_else(|| TokenserverError::invalid_key_id("Invalid X-KeyID header"))?; + .ok_or_else(|| TokenserverError::invalid_credentials("Unauthorized"))?; let client_state = { let client_state_bytes = base64::decode_config(encoded_client_state, base64::URL_SAFE_NO_PAD) - .map_err(|_| TokenserverError::invalid_key_id("Invalid base64 encoding"))?; + .map_err(|_| TokenserverError::invalid_credentials("Unauthorized"))?; let client_state = hex::encode(client_state_bytes); @@ -337,7 +341,12 @@ impl FromRequest for KeyId { .and_then(|header| header.to_str().ok()); if let Some(x_client_state) = maybe_x_client_state { if x_client_state != client_state { - return Err(TokenserverError::invalid_client_state("Unauthorized").into()); + return Err(TokenserverError { + status: "invalid-client-state", + location: ErrorLocation::Body, + ..TokenserverError::default() + } + .into()); } } @@ -346,7 +355,7 @@ impl FromRequest for KeyId { let keys_changed_at = keys_changed_at_string .parse::() - .map_err(|_| TokenserverError::invalid_credentials("invalid keysChangedAt"))?; + .map_err(|_| TokenserverError::invalid_credentials("Unauthorized"))?; Ok(KeyId { client_state, @@ -455,7 +464,7 @@ mod tests { .await .unwrap(); let expected_tokenserver_request = TokenserverRequest { - user: results::GetUser::default(), + user: results::GetOrCreateUser::default(), fxa_uid: fxa_uid.to_owned(), email: "test123@test.com".to_owned(), generation: Some(1234), @@ -682,7 +691,7 @@ mod tests { let response: HttpResponse = KeyId::extract(&request).await.unwrap_err().into(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let expected_error = TokenserverError::invalid_key_id("Invalid X-KeyID header"); + let expected_error = TokenserverError::invalid_credentials("Unauthorized"); let body = extract_body_as_str(ServiceResponse::new(request, response)); assert_eq!(body, serde_json::to_string(&expected_error).unwrap()); } @@ -695,7 +704,7 @@ mod tests { let response: HttpResponse = KeyId::extract(&request).await.unwrap_err().into(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let expected_error = TokenserverError::invalid_key_id("Invalid base64 encoding"); + let expected_error = TokenserverError::invalid_credentials("Unauthorized"); let body = extract_body_as_str(ServiceResponse::new(request, response)); assert_eq!(body, serde_json::to_string(&expected_error).unwrap()); } @@ -721,7 +730,7 @@ mod tests { let response: HttpResponse = KeyId::extract(&request).await.unwrap_err().into(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let expected_error = TokenserverError::invalid_credentials("invalid keysChangedAt"); + let expected_error = TokenserverError::invalid_credentials("Unauthorized"); let body = extract_body_as_str(ServiceResponse::new(request, response)); assert_eq!(body, serde_json::to_string(&expected_error).unwrap()); } @@ -735,7 +744,11 @@ mod tests { let response: HttpResponse = KeyId::extract(&request).await.unwrap_err().into(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let expected_error = TokenserverError::invalid_client_state("Unauthorized"); + let expected_error = TokenserverError { + status: "invalid-client-state", + location: ErrorLocation::Body, + ..TokenserverError::default() + }; let body = extract_body_as_str(ServiceResponse::new(request, response)); assert_eq!(body, serde_json::to_string(&expected_error).unwrap()); } @@ -775,12 +788,14 @@ mod tests { // The request includes a generation that is less than the generation currently stored on // the user record let tokenserver_request = TokenserverRequest { - user: results::GetUser { + user: results::GetOrCreateUser { uid: 1, + email: "test@test.com".to_owned(), client_state: "616161".to_owned(), generation: 1234, node: "node".to_owned(), keys_changed_at: Some(1234), + replaced_at: None, created_at: 1234, old_client_states: vec![], }, @@ -805,13 +820,15 @@ mod tests { // The request includes a keys_changed_at that is less than the keys_changed_at currently // stored on the user record let tokenserver_request = TokenserverRequest { - user: results::GetUser { + user: results::GetOrCreateUser { uid: 1, + email: "test@test.com".to_owned(), client_state: "616161".to_owned(), generation: 1234, node: "node".to_owned(), keys_changed_at: Some(1234), created_at: 1234, + replaced_at: None, old_client_states: vec![], }, fxa_uid: "test".to_owned(), @@ -834,13 +851,15 @@ mod tests { async fn test_keys_changed_without_generation_change() { // The request includes a new value for keys_changed_at without a new value for generation let tokenserver_request = TokenserverRequest { - user: results::GetUser { + user: results::GetOrCreateUser { uid: 1, + email: "test@test.com".to_owned(), client_state: "616161".to_owned(), generation: 1234, node: "node".to_owned(), keys_changed_at: Some(1234), created_at: 1234, + replaced_at: None, old_client_states: vec![], }, fxa_uid: "test".to_owned(), @@ -864,13 +883,15 @@ mod tests { // The request includes a previously-used client state that is not the user's current // client state let tokenserver_request = TokenserverRequest { - user: results::GetUser { + user: results::GetOrCreateUser { uid: 1, + email: "test@test.com".to_owned(), client_state: "616161".to_owned(), generation: 1234, node: "node".to_owned(), keys_changed_at: Some(1234), created_at: 1234, + replaced_at: None, old_client_states: vec!["626262".to_owned()], }, fxa_uid: "test".to_owned(), @@ -894,13 +915,15 @@ mod tests { async fn test_new_client_state_without_generation_change() { // The request includes a new client state without a new generation value let tokenserver_request = TokenserverRequest { - user: results::GetUser { + user: results::GetOrCreateUser { uid: 1, + email: "test@test.com".to_owned(), client_state: "616161".to_owned(), generation: 1234, node: "node".to_owned(), keys_changed_at: Some(1234), created_at: 1234, + replaced_at: None, old_client_states: vec![], }, fxa_uid: "test".to_owned(), @@ -924,13 +947,15 @@ mod tests { async fn test_new_client_state_without_key_change() { // The request includes a new client state without a new keys_changed_at value let tokenserver_request = TokenserverRequest { - user: results::GetUser { + user: results::GetOrCreateUser { uid: 1, + email: "test@test.com".to_owned(), client_state: "616161".to_owned(), generation: 1234, node: "node".to_owned(), keys_changed_at: Some(1234), created_at: 1234, + replaced_at: None, old_client_states: vec![], }, fxa_uid: "test".to_owned(), @@ -961,6 +986,7 @@ mod tests { fxa_metrics_hash_secret: "".to_owned(), oauth_verifier: Box::new(verifier), db_pool: Box::new(MockTokenserverPool::new()), + node_capacity_release_rate: None, } } } diff --git a/src/tokenserver/handlers.rs b/src/tokenserver/handlers.rs index ea53d2991e..16ddb13ad0 100644 --- a/src/tokenserver/handlers.rs +++ b/src/tokenserver/handlers.rs @@ -10,6 +10,7 @@ use serde_json::Value; use super::db::models::Db; use super::db::params::{GetNodeId, PostUser, PutUser, ReplaceUsers}; +use super::error::TokenserverError; use super::extractors::TokenserverRequest; use super::support::Tokenlib; use crate::tokenserver::support::MakeTokenPlaintext; @@ -31,7 +32,7 @@ pub async fn get_tokenserver_result( let updates = update_user(&req, db).await?; let (token, derived_secret) = { - let token_plaintext = get_token_plaintext(&req, &updates); + let token_plaintext = get_token_plaintext(&req, &updates)?; Tokenlib::get_token_and_derived_secret(token_plaintext, &req.shared_secret)? }; @@ -47,9 +48,19 @@ pub async fn get_tokenserver_result( Ok(HttpResponse::build(StatusCode::OK).json(result)) } -fn get_token_plaintext(req: &TokenserverRequest, updates: &UserUpdates) -> MakeTokenPlaintext { +fn get_token_plaintext( + req: &TokenserverRequest, + updates: &UserUpdates, +) -> Result { let fxa_kid = { - let client_state_b64 = base64::encode_config(&req.client_state, base64::URL_SAFE_NO_PAD); + // If decoding the hex bytes fails, it means we did something wrong when we stored the + // client state in the database + let client_state = hex::decode(req.client_state.clone()).map_err(|_| { + error!("⚠️ Failed to decode client state hex"); + + TokenserverError::internal_error() + })?; + let client_state_b64 = base64::encode_config(&client_state, base64::URL_SAFE_NO_PAD); format!("{:013}-{:}", updates.keys_changed_at, client_state_b64) }; @@ -62,7 +73,7 @@ fn get_token_plaintext(req: &TokenserverRequest, updates: &UserUpdates) -> MakeT expires.as_secs() }; - MakeTokenPlaintext { + Ok(MakeTokenPlaintext { node: req.user.node.to_owned(), fxa_kid, fxa_uid: req.fxa_uid.clone(), @@ -70,7 +81,7 @@ fn get_token_plaintext(req: &TokenserverRequest, updates: &UserUpdates) -> MakeT hashed_fxa_uid: req.hashed_fxa_uid.clone(), expires, uid: updates.uid.to_owned(), - } + }) } struct UserUpdates { @@ -115,7 +126,6 @@ async fn update_user(req: &TokenserverRequest, db: Box) -> Result, + pub node_capacity_release_rate: Option, } impl ServerState { @@ -45,6 +46,7 @@ impl ServerState { fxa_metrics_hash_secret: settings.fxa_metrics_hash_secret.clone(), oauth_verifier, db_pool: Box::new(db_pool), + node_capacity_release_rate: settings.node_capacity_release_rate, }) .map_err(Into::into) } diff --git a/src/tokenserver/settings.rs b/src/tokenserver/settings.rs index a65ddd9e1c..465edbedc4 100644 --- a/src/tokenserver/settings.rs +++ b/src/tokenserver/settings.rs @@ -28,6 +28,9 @@ pub struct Settings { /// When test mode is enabled, OAuth tokens are unpacked without being verified. pub test_mode_enabled: bool, + + /// The rate at which capacity should be released from nodes that are at capacity. + pub node_capacity_release_rate: Option, } impl Default for Settings { @@ -42,6 +45,7 @@ impl Default for Settings { fxa_metrics_hash_secret: "secret".to_owned(), fxa_oauth_server_url: None, test_mode_enabled: false, + node_capacity_release_rate: None, } } } diff --git a/tools/integration_tests/tokenserver/test_misc.py b/tools/integration_tests/tokenserver/test_misc.py index b9cd7e0eeb..5a12dda73e 100644 --- a/tools/integration_tests/tokenserver/test_misc.py +++ b/tools/integration_tests/tokenserver/test_misc.py @@ -5,6 +5,8 @@ from tokenserver.test_support import TestCase +MAX_GENERATION = 9223372036854775807 + class TestMisc(TestCase, unittest.TestCase): def setUp(self): @@ -171,3 +173,72 @@ def test_user_updates_with_same_client_state(self): user = self._get_user(uid) self.assertEqual(user['generation'], 1235) self.assertEqual(user['keys_changed_at'], 1235) + + def test_retired_users_can_make_requests(self): + # Add a retired user to the database + self._add_user(generation=MAX_GENERATION) + oauth_token = self._forge_oauth_token(generation=1234) + headers = { + 'Authorization': 'Bearer %s' % oauth_token, + 'X-KeyID': '1234-YWFh' + } + # Retired users cannot make requests with a generation smaller than + # the max generation + res = self.app.get('/1.0/sync/1.5', headers=headers, status=401) + expected_error_response = { + "status": "invalid-generation", + "errors": [ + { + "location": "body", + "name": "", + "description": "Unauthorized" + } + ] + } + self.assertEqual(res.json, expected_error_response) + # Retired users can make requests with a generation number equal to + # the max generation + oauth_token = self._forge_oauth_token(generation=MAX_GENERATION) + headers['Authorization'] = 'Bearer %s' % oauth_token + self.app.get('/1.0/sync/1.5', headers=headers) + + def test_replaced_users_can_make_requests(self): + # Add a replaced user to the database + self._add_user(generation=1234, created_at=1234, replaced_at=1234) + oauth_token = self._forge_oauth_token(generation=1234) + headers = { + 'Authorization': 'Bearer %s' % oauth_token, + 'X-KeyID': '1234-YWFh' + } + # Replaced users can make requests + self.app.get('/1.0/sync/1.5', headers=headers) + + def test_retired_users_with_no_node_cannot_make_requests(self): + # Add a retired user to the database + invalid_node_id = self.NODE_ID + 1 + self._add_user(generation=MAX_GENERATION, nodeid=invalid_node_id) + oauth_token = self._forge_oauth_token(generation=1234) + headers = { + 'Authorization': 'Bearer %s' % oauth_token, + 'X-KeyID': '1234-YWFh' + } + # Retired users without a node cannot make requests + oauth_token = self._forge_oauth_token(generation=MAX_GENERATION) + headers['Authorization'] = 'Bearer %s' % oauth_token + self.app.get('/1.0/sync/1.5', headers=headers, status=500) + + def test_replaced_users_with_no_node_can_make_requests(self): + # Add a replaced user to the database + invalid_node_id = self.NODE_ID + 1 + self._add_user(created_at=1234, replaced_at=1234, + nodeid=invalid_node_id) + oauth_token = self._forge_oauth_token(generation=1234) + headers = { + 'Authorization': 'Bearer %s' % oauth_token, + 'X-KeyID': '1234-YWFh' + } + # Replaced users without a node can make requests + res = self.app.get('/1.0/sync/1.5', headers=headers) + user = self._get_user(res.json['uid']) + # The user is assigned to a new node + self.assertEqual(user['nodeid'], self.NODE_ID)