diff --git a/autoendpoint/src/auth.rs b/autoendpoint/src/auth.rs new file mode 100644 index 000000000..7bf6485bc --- /dev/null +++ b/autoendpoint/src/auth.rs @@ -0,0 +1,13 @@ +use openssl::error::ErrorStack; +use openssl::hash::MessageDigest; +use openssl::pkey::PKey; +use openssl::sign::Signer; + +/// Sign some data with a key and return the hex representation +pub fn sign_with_key(key: &[u8], data: &[u8]) -> Result { + let key = PKey::hmac(key)?; + let mut signer = Signer::new(MessageDigest::sha256(), &key)?; + + signer.update(data)?; + Ok(hex::encode(signer.sign_to_vec()?)) +} diff --git a/autoendpoint/src/db/client.rs b/autoendpoint/src/db/client.rs index dc3029198..a49352faf 100644 --- a/autoendpoint/src/db/client.rs +++ b/autoendpoint/src/db/client.rs @@ -29,6 +29,11 @@ pub trait DbClient: Send + Sync { /// exists. async fn add_user(&self, user: &DynamoDbUser) -> DbResult<()>; + /// Update a user in the database. An error will occur if the user does not + /// already exist, has a different router type, or has a newer + /// `connected_at` timestamp. + async fn update_user(&self, user: &DynamoDbUser) -> DbResult<()>; + /// Read a user from the database async fn get_user(&self, uaid: Uuid) -> DbResult>; @@ -137,21 +142,63 @@ impl DbClientImpl { impl DbClient for DbClientImpl { async fn add_user(&self, user: &DynamoDbUser) -> DbResult<()> { let input = PutItemInput { + table_name: self.router_table.clone(), item: serde_dynamodb::to_hashmap(user)?, - table_name: self.router_table.to_string(), condition_expression: Some("attribute_not_exists(uaid)".to_string()), ..Default::default() }; retry_policy() .retry_if( - move || self.ddb.put_item(input.clone()), + || self.ddb.put_item(input.clone()), retryable_putitem_error(self.metrics.clone()), ) .await?; Ok(()) } + async fn update_user(&self, user: &DynamoDbUser) -> DbResult<()> { + let mut user_map = serde_dynamodb::to_hashmap(&user)?; + user_map.remove("uaid"); + let input = UpdateItemInput { + table_name: self.router_table.clone(), + key: ddb_item! { uaid: s => user.uaid.to_simple().to_string() }, + update_expression: Some(format!( + "SET {}", + user_map + .keys() + .map(|key| format!("{0}=:{0}", key)) + .collect::>() + .join(", ") + )), + expression_attribute_values: Some( + user_map + .into_iter() + .map(|(key, value)| (format!(":{}", key), value)) + .collect(), + ), + condition_expression: Some( + "attribute_exists(uaid) and ( + attribute_not_exists(router_type) or + (router_type = :router_type) + ) and ( + attribute_not_exists(node_id) or + (connected_at < :connected_at) + )" + .to_string(), + ), + ..Default::default() + }; + + retry_policy() + .retry_if( + || self.ddb.update_item(input.clone()), + retryable_updateitem_error(self.metrics.clone()), + ) + .await?; + Ok(()) + } + async fn get_user(&self, uaid: Uuid) -> DbResult> { let input = GetItemInput { table_name: self.router_table.clone(), @@ -245,8 +292,7 @@ impl DbClient for DbClientImpl { // Convert the IDs from String to Uuid let channels = channels .into_iter() - .map(|s| Uuid::parse_str(&s)) - .filter_map(Result::ok) + .filter_map(|s| Uuid::parse_str(&s).ok()) .collect(); Ok(channels) diff --git a/autoendpoint/src/db/mock.rs b/autoendpoint/src/db/mock.rs index f81bad99e..aa4ca60f1 100644 --- a/autoendpoint/src/db/mock.rs +++ b/autoendpoint/src/db/mock.rs @@ -17,6 +17,8 @@ mockall::mock! { pub DbClient { fn add_user(&self, user: &DynamoDbUser) -> DbResult<()>; + fn update_user(&self, user: &DynamoDbUser) -> DbResult<()>; + fn get_user(&self, uaid: Uuid) -> DbResult>; fn remove_user(&self, uaid: Uuid) -> DbResult<()>; @@ -47,6 +49,10 @@ impl DbClient for Arc { Arc::as_ref(self).add_user(user) } + async fn update_user(&self, user: &DynamoDbUser) -> DbResult<()> { + Arc::as_ref(self).update_user(user) + } + async fn get_user(&self, uaid: Uuid) -> DbResult> { Arc::as_ref(self).get_user(uaid) } diff --git a/autoendpoint/src/error.rs b/autoendpoint/src/error.rs index 12ea0b279..974a70866 100644 --- a/autoendpoint/src/error.rs +++ b/autoendpoint/src/error.rs @@ -110,6 +110,9 @@ pub enum ApiErrorKind { #[error("{0}")] Internal(String), + + #[error("Invalid Authentication")] + InvalidAuthentication, } impl ApiErrorKind { @@ -121,18 +124,20 @@ impl ApiErrorKind { ApiErrorKind::Validation(_) | ApiErrorKind::InvalidEncryption(_) - | ApiErrorKind::TokenHashValidation(_) | ApiErrorKind::NoTTL | ApiErrorKind::InvalidRouterType | ApiErrorKind::InvalidRouterToken | ApiErrorKind::InvalidMessageId => StatusCode::BAD_REQUEST, - ApiErrorKind::NoUser | ApiErrorKind::NoSubscription => StatusCode::GONE, - - ApiErrorKind::VapidError(_) | ApiErrorKind::Jwt(_) => StatusCode::UNAUTHORIZED, + ApiErrorKind::VapidError(_) + | ApiErrorKind::Jwt(_) + | ApiErrorKind::TokenHashValidation(_) + | ApiErrorKind::InvalidAuthentication => StatusCode::UNAUTHORIZED, ApiErrorKind::InvalidToken | ApiErrorKind::InvalidApiVersion => StatusCode::NOT_FOUND, + ApiErrorKind::NoUser | ApiErrorKind::NoSubscription => StatusCode::GONE, + ApiErrorKind::Io(_) | ApiErrorKind::Metrics(_) | ApiErrorKind::Database(_) @@ -166,7 +171,8 @@ impl ApiErrorKind { ApiErrorKind::VapidError(_) | ApiErrorKind::TokenHashValidation(_) - | ApiErrorKind::Jwt(_) => Some(109), + | ApiErrorKind::Jwt(_) + | ApiErrorKind::InvalidAuthentication => Some(109), ApiErrorKind::InvalidEncryption(_) => Some(110), diff --git a/autoendpoint/src/extractors/authorization_check.rs b/autoendpoint/src/extractors/authorization_check.rs new file mode 100644 index 000000000..1e5153eb8 --- /dev/null +++ b/autoendpoint/src/extractors/authorization_check.rs @@ -0,0 +1,69 @@ +use crate::auth::sign_with_key; +use crate::error::{ApiError, ApiErrorKind}; +use crate::headers::util::get_header; +use crate::server::ServerState; +use actix_web::dev::{Payload, PayloadStream}; +use actix_web::web::Data; +use actix_web::{FromRequest, HttpRequest}; +use futures::future::LocalBoxFuture; +use futures::FutureExt; +use uuid::Uuid; + +/// Verifies the request authorization via the authorization header. +/// +/// The expected token is the HMAC-SHA256 hash of the UAID, signed with one of +/// the available keys (allows for key rotation). +pub struct AuthorizationCheck; + +impl FromRequest for AuthorizationCheck { + type Error = ApiError; + type Future = LocalBoxFuture<'static, Result>; + type Config = (); + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + let req = req.clone(); + + async move { + let uaid = req + .match_info() + .get("uaid") + .expect("{uaid} must be part of the path") + .parse::() + .map_err(|_| ApiErrorKind::NoUser)?; + let state: Data = Data::extract(&req) + .into_inner() + .expect("No server state found"); + let auth_header = + get_header(&req, "Authorization").ok_or(ApiErrorKind::InvalidAuthentication)?; + let token = get_token_from_auth_header(auth_header) + .ok_or(ApiErrorKind::InvalidAuthentication)?; + + // Check the token against the expected token for each key + for key in state.settings.auth_keys() { + let expected_token = sign_with_key(key.as_bytes(), uaid.as_bytes()) + .map_err(ApiErrorKind::RegistrationSecretHash)?; + + if expected_token.len() == token.len() + && openssl::memcmp::eq(expected_token.as_bytes(), token.as_bytes()) + { + return Ok(Self); + } + } + + Err(ApiErrorKind::InvalidAuthentication.into()) + } + .boxed_local() + } +} + +/// Get the token from a bearer authorization header +fn get_token_from_auth_header(header: &str) -> Option<&str> { + let mut split = header.splitn(2, ' '); + let scheme = split.next()?; + + if scheme.to_lowercase() != "bearer" { + return None; + } + + split.next() +} diff --git a/autoendpoint/src/extractors/mod.rs b/autoendpoint/src/extractors/mod.rs index 3526d3a5d..34c817021 100644 --- a/autoendpoint/src/extractors/mod.rs +++ b/autoendpoint/src/extractors/mod.rs @@ -1,10 +1,12 @@ //! Actix extractors (`FromRequest`). These extractors transform and validate //! the incoming request data. +pub mod authorization_check; pub mod message_id; pub mod notification; pub mod notification_headers; pub mod registration_path_args; +pub mod registration_path_args_with_uaid; pub mod router_data_input; pub mod routers; pub mod subscription; diff --git a/autoendpoint/src/extractors/registration_path_args_with_uaid.rs b/autoendpoint/src/extractors/registration_path_args_with_uaid.rs new file mode 100644 index 000000000..18a59dfda --- /dev/null +++ b/autoendpoint/src/extractors/registration_path_args_with_uaid.rs @@ -0,0 +1,53 @@ +use crate::error::{ApiError, ApiErrorKind}; +use crate::extractors::registration_path_args::RegistrationPathArgs; +use crate::extractors::routers::RouterType; +use crate::server::ServerState; +use actix_web::dev::{Payload, PayloadStream}; +use actix_web::web::Data; +use actix_web::{FromRequest, HttpRequest}; +use futures::future::LocalBoxFuture; +use futures::FutureExt; +use uuid::Uuid; + +/// An extension of `RegistrationPathArgs` which requires a `uaid` path arg. +/// The `uaid` is verified by checking if the user exists in the database. +pub struct RegistrationPathArgsWithUaid { + pub router_type: RouterType, + pub app_id: String, + pub uaid: Uuid, +} + +impl FromRequest for RegistrationPathArgsWithUaid { + type Error = ApiError; + type Future = LocalBoxFuture<'static, Result>; + type Config = (); + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + let req = req.clone(); + + async move { + let state: Data = Data::extract(&req) + .into_inner() + .expect("No server state found"); + let path_args = RegistrationPathArgs::extract(&req).into_inner()?; + let uaid = req + .match_info() + .get("uaid") + .expect("{uaid} must be part of the path") + .parse::() + .map_err(|_| ApiErrorKind::NoUser)?; + + // Verify that the user exists + if state.ddb.get_user(uaid).await?.is_none() { + return Err(ApiErrorKind::NoUser.into()); + } + + Ok(Self { + router_type: path_args.router_type, + app_id: path_args.app_id, + uaid, + }) + } + .boxed_local() + } +} diff --git a/autoendpoint/src/main.rs b/autoendpoint/src/main.rs index 7d75c2acf..f7bb04e1c 100644 --- a/autoendpoint/src/main.rs +++ b/autoendpoint/src/main.rs @@ -3,6 +3,7 @@ #[macro_use] extern crate slog_scope; +mod auth; mod db; mod error; mod extractors; diff --git a/autoendpoint/src/routes/registration.rs b/autoendpoint/src/routes/registration.rs index 47d90de6d..aee89b611 100644 --- a/autoendpoint/src/routes/registration.rs +++ b/autoendpoint/src/routes/registration.rs @@ -1,5 +1,8 @@ +use crate::auth::sign_with_key; use crate::error::{ApiErrorKind, ApiResult}; +use crate::extractors::authorization_check::AuthorizationCheck; use crate::extractors::registration_path_args::RegistrationPathArgs; +use crate::extractors::registration_path_args_with_uaid::RegistrationPathArgsWithUaid; use crate::extractors::router_data_input::RouterDataInput; use crate::extractors::routers::Routers; use crate::headers::util::get_header; @@ -9,10 +12,6 @@ use actix_web::{HttpRequest, HttpResponse}; use autopush_common::db::DynamoDbUser; use autopush_common::endpoint::make_endpoint; use cadence::{Counted, StatsdClient}; -use openssl::error::ErrorStack; -use openssl::hash::MessageDigest; -use openssl::pkey::PKey; -use openssl::sign::Signer; use uuid::Uuid; /// Handle the `POST /v1/{router_type}/{app_id}/registration` route @@ -77,6 +76,38 @@ pub async fn register_uaid_route( }))) } +/// Handle the `PUT /v1/{router_type}/{app_id}/registration/{uaid}` route +pub async fn update_token_route( + _auth: AuthorizationCheck, + path_args: RegistrationPathArgsWithUaid, + router_data_input: RouterDataInput, + routers: Routers, + state: Data, +) -> ApiResult { + // Re-register with router + debug!( + "Updating the token of UAID {} with the {} router", + path_args.uaid, path_args.router_type + ); + trace!("token = {}", router_data_input.token); + let router = routers.get(path_args.router_type); + let router_data = router.register(&router_data_input, &path_args.app_id)?; + + // Update the user in the database + let user = DynamoDbUser { + uaid: path_args.uaid, + router_type: path_args.router_type.to_string(), + router_data: Some(router_data), + ..Default::default() + }; + trace!("Updating user with UAID {}", user.uaid); + trace!("user = {:?}", user); + state.ddb.update_user(&user).await?; + + trace!("Finished updating token for UAID {}", user.uaid); + Ok(HttpResponse::Ok().finish()) +} + /// Increment a metric with data from the request fn incr_metric(name: &str, metrics: &StatsdClient, request: &HttpRequest) { metrics @@ -88,12 +119,3 @@ fn incr_metric(name: &str, metrics: &StatsdClient, request: &HttpRequest) { .with_tag("host", get_header(&request, "Host").unwrap_or("unknown")) .send() } - -/// Sign some data with a key and return the hex representation -fn sign_with_key(key: &[u8], data: &[u8]) -> Result { - let key = PKey::hmac(key)?; - let mut signer = Signer::new(MessageDigest::sha256(), &key)?; - - signer.update(data)?; - Ok(hex::encode(signer.sign_to_vec()?)) -} diff --git a/autoendpoint/src/server.rs b/autoendpoint/src/server.rs index 8b7f8af53..a7dc6e7ab 100644 --- a/autoendpoint/src/server.rs +++ b/autoendpoint/src/server.rs @@ -6,7 +6,7 @@ use crate::metrics; use crate::middleware::sentry::sentry_middleware; use crate::routers::fcm::router::FcmRouter; use crate::routes::health::{health_route, lb_heartbeat_route, status_route, version_route}; -use crate::routes::registration::register_uaid_route; +use crate::routes::registration::{register_uaid_route, update_token_route}; use crate::routes::webpush::{delete_notification_route, webpush_route}; use crate::settings::Settings; use actix_cors::Cors; @@ -86,6 +86,10 @@ impl Server { web::resource("/v1/{router_type}/{app_id}/registration") .route(web::post().to(register_uaid_route)), ) + .service( + web::resource("/v1/{router_type}/{app_id}/registration/{uaid}") + .route(web::put().to(update_token_route)), + ) // Health checks .service(web::resource("/status").route(web::get().to(status_route))) .service(web::resource("/health").route(web::get().to(health_route))) diff --git a/autoendpoint/src/settings.rs b/autoendpoint/src/settings.rs index 88eaf41ff..973327efa 100644 --- a/autoendpoint/src/settings.rs +++ b/autoendpoint/src/settings.rs @@ -103,7 +103,7 @@ impl Settings { /// Initialize the fernet encryption instance pub fn make_fernet(&self) -> MultiFernet { - let fernets = Self::read_list_from_str(&self.crypto_keys, "Invalid AUTOEND_CRYPTO_KEY") + let fernets = Self::read_list_from_str(&self.crypto_keys, "Invalid AUTOEND_CRYPTO_KEYS") .map(|key| Fernet::new(key).expect("Invalid AUTOEND_CRYPTO_KEYS")) .collect(); MultiFernet::new(fernets)