diff --git a/autoendpoint/src/routers/adm/client.rs b/autoendpoint/src/routers/adm/client.rs index 581a2ff06..56715bd5f 100644 --- a/autoendpoint/src/routers/adm/client.rs +++ b/autoendpoint/src/routers/adm/client.rs @@ -1,5 +1,6 @@ use crate::routers::adm::error::AdmError; use crate::routers::adm::settings::{AdmProfile, AdmSettings}; +use crate::routers::common::message_size_check; use crate::routers::RouterError; use autopush_common::util::sec_since_epoch; use futures::lock::Mutex; @@ -15,6 +16,7 @@ pub struct AdmClient { base_url: Url, profile: AdmProfile, timeout: Duration, + max_data: usize, http: reqwest::Client, token_info: Mutex, } @@ -53,6 +55,7 @@ impl AdmClient { base_url: settings.base_url.clone(), profile, timeout: Duration::from_secs(settings.timeout as u64), + max_data: settings.max_data, http, // The default TokenInfo has dummy values to trigger a token fetch token_info: Mutex::default(), @@ -122,6 +125,10 @@ impl AdmClient { "data": data, "expiresAfter": ttl, }); + let message_json = message.to_string(); + message_size_check(message_json.as_bytes(), self.max_data)?; + + // Prepare request data let access_token = self.get_access_token().await?; let url = self .base_url @@ -136,6 +143,7 @@ impl AdmClient { .http .post(url) .header("Authorization", format!("Bearer {}", access_token.as_str())) + .header("Content-Type", "application/json") .header("Accept", "application/json") .header( "X-Amzn-Type-Version", @@ -145,7 +153,7 @@ impl AdmClient { "X-Amzn-Accept-Type", "com.amazon.device.messaging.ADMSendResult@1.0", ) - .json(&message) + .body(message_json) .timeout(self.timeout) .send() .await diff --git a/autoendpoint/src/routers/adm/router.rs b/autoendpoint/src/routers/adm/router.rs index 25c6cb068..a9e733e79 100644 --- a/autoendpoint/src/routers/adm/router.rs +++ b/autoendpoint/src/routers/adm/router.rs @@ -99,7 +99,7 @@ impl Router for AdmRouter { .and_then(Value::as_str) .ok_or(AdmError::NoProfile)?; let ttl = MAX_TTL.min(self.settings.min_ttl.max(notification.headers.ttl as usize)); - let message_data = build_message_data(notification, self.settings.max_data)?; + let message_data = build_message_data(notification)?; // Send the notification to ADM let client = self.clients.get(profile).ok_or(AdmError::InvalidProfile)?; diff --git a/autoendpoint/src/routers/apns/router.rs b/autoendpoint/src/routers/apns/router.rs index 1c1a5acc1..abf515b23 100644 --- a/autoendpoint/src/routers/apns/router.rs +++ b/autoendpoint/src/routers/apns/router.rs @@ -4,7 +4,9 @@ use crate::extractors::notification::Notification; use crate::extractors::router_data_input::RouterDataInput; use crate::routers::apns::error::ApnsError; use crate::routers::apns::settings::{ApnsChannel, ApnsSettings}; -use crate::routers::common::{build_message_data, incr_error_metric, incr_success_metrics}; +use crate::routers::common::{ + build_message_data, incr_error_metric, incr_success_metrics, message_size_check, +}; use crate::routers::{Router, RouterError, RouterResponse}; use a2::request::notification::LocalizedAlert; use a2::request::payload::{APSAlert, Payload, APS}; @@ -220,7 +222,7 @@ impl Router for ApnsRouter { .map(|value| APS::deserialize(value).map_err(|_| ApnsError::InvalidApsData)) .transpose()? .unwrap_or_else(Self::default_aps); - let mut message_data = build_message_data(notification, self.settings.max_data)?; + let mut message_data = build_message_data(notification)?; message_data.insert("ver", notification.message_id.clone()); // Get client and build payload @@ -249,11 +251,7 @@ impl Router for ApnsRouter { .clone() .to_json_string() .map_err(ApnsError::SizeLimit)?; - if payload_json.len() > self.settings.max_data { - return Err( - RouterError::TooMuchData(payload_json.len() - self.settings.max_data).into(), - ); - } + message_size_check(payload_json.as_bytes(), self.settings.max_data)?; // Send to APNS trace!("Sending message to APNS: {:?}", payload); diff --git a/autoendpoint/src/routers/common.rs b/autoendpoint/src/routers/common.rs index 4333ff67d..61359df5b 100644 --- a/autoendpoint/src/routers/common.rs +++ b/autoendpoint/src/routers/common.rs @@ -8,21 +8,12 @@ use std::collections::HashMap; use uuid::Uuid; /// Convert a notification into a WebPush message -pub fn build_message_data( - notification: &Notification, - max_data: usize, -) -> ApiResult> { +pub fn build_message_data(notification: &Notification) -> ApiResult> { let mut message_data = HashMap::new(); message_data.insert("chid", notification.subscription.channel_id.to_string()); // Only add the other headers if there's data if let Some(data) = ¬ification.data { - if data.len() > max_data { - // Too much data. Tell the client how many bytes extra they had. - return Err(RouterError::TooMuchData(data.len() - max_data).into()); - } - - // Add the body and headers message_data.insert("body", data.clone()); message_data.insert_opt("con", notification.headers.encoding.as_ref()); message_data.insert_opt("enc", notification.headers.encryption.as_ref()); @@ -33,6 +24,17 @@ pub fn build_message_data( Ok(message_data) } +/// Check the data against the max data size and return an error if there is too +/// much data. +pub fn message_size_check(data: &[u8], max_data: usize) -> Result<(), RouterError> { + if data.len() > max_data { + trace!("Data is too long by {} bytes", data.len() - max_data); + Err(RouterError::TooMuchData(data.len() - max_data)) + } else { + Ok(()) + } +} + /// Handle a bridge error by logging, updating metrics, etc pub async fn handle_error( error: RouterError, diff --git a/autoendpoint/src/routers/fcm/client.rs b/autoendpoint/src/routers/fcm/client.rs index 842f0971a..c0cb90e9a 100644 --- a/autoendpoint/src/routers/fcm/client.rs +++ b/autoendpoint/src/routers/fcm/client.rs @@ -1,3 +1,4 @@ +use crate::routers::common::message_size_check; use crate::routers::fcm::error::FcmError; use crate::routers::fcm::settings::{FcmCredential, FcmSettings}; use crate::routers::RouterError; @@ -16,6 +17,7 @@ const OAUTH_SCOPES: &[&str] = &["https://www.googleapis.com/auth/firebase.messag pub struct FcmClient { endpoint: Url, timeout: Duration, + max_data: usize, auth: DefaultAuthenticator, http: reqwest::Client, } @@ -41,6 +43,7 @@ impl FcmClient { )) .expect("Project ID is not URL-safe"), timeout: Duration::from_secs(settings.timeout as u64), + max_data: settings.max_data, auth, http, }) @@ -53,6 +56,11 @@ impl FcmClient { token: String, ttl: usize, ) -> Result<(), RouterError> { + // Check the payload size. FCM only cares about the `data` field when + // checking size. + let data_json = serde_json::to_string(&data).unwrap(); + message_size_check(data_json.as_bytes(), self.max_data)?; + // Build the FCM message let message = serde_json::json!({ "message": { @@ -63,6 +71,7 @@ impl FcmClient { } } }); + let access_token = self .auth .token(OAUTH_SCOPES) @@ -74,8 +83,7 @@ impl FcmClient { .http .post(self.endpoint.clone()) .header("Authorization", format!("Bearer {}", access_token.as_str())) - .header("Content-Type", "application/json; UTF-8") - .body(message.to_string()) + .json(&message) .timeout(self.timeout) .send() .await @@ -208,7 +216,7 @@ pub mod tests { let _token_mock = mock_token_endpoint(); let fcm_mock = mock_fcm_endpoint_builder() .match_header("Authorization", format!("Bearer {}", ACCESS_TOKEN).as_str()) - .match_header("Content-Type", "application/json; UTF-8") + .match_header("Content-Type", "application/json") .match_body(r#"{"message":{"android":{"data":{"is_test":"true"},"ttl":"42s"},"token":"test-token"}}"#) .create(); diff --git a/autoendpoint/src/routers/fcm/router.rs b/autoendpoint/src/routers/fcm/router.rs index c36a76c3e..4ec989e3f 100644 --- a/autoendpoint/src/routers/fcm/router.rs +++ b/autoendpoint/src/routers/fcm/router.rs @@ -111,7 +111,7 @@ impl Router for FcmRouter { .and_then(Value::as_str) .ok_or(FcmError::NoAppId)?; let ttl = MAX_TTL.min(self.settings.ttl.max(notification.headers.ttl as usize)); - let message_data = build_message_data(notification, self.settings.max_data)?; + let message_data = build_message_data(notification)?; // Send the notification to FCM let client = self.clients.get(app_id).ok_or(FcmError::InvalidAppId)?;