Skip to content

Commit

Permalink
Abstract the a2 client behind a trait and add a bunch of APNS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AzureMarker committed Jul 24, 2020
1 parent ce6c8dc commit 9cdb14a
Show file tree
Hide file tree
Showing 3 changed files with 366 additions and 63 deletions.
333 changes: 313 additions & 20 deletions autoendpoint/src/routers/apns/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::routers::common::build_message_data;
use crate::routers::{Router, RouterError, RouterResponse};
use a2::request::notification::LocalizedAlert;
use a2::request::payload::{APSAlert, Payload, APS};
use a2::{Endpoint, NotificationOptions, Priority};
use a2::{Endpoint, Error, NotificationOptions, Priority, Response};
use async_trait::async_trait;
use cadence::{Counted, StatsdClient};
use futures::{StreamExt, TryStreamExt};
Expand All @@ -18,17 +18,29 @@ use url::Url;

pub struct ApnsRouter {
/// A map from release channel to APNS client
clients: HashMap<String, ApnsClient>,
clients: HashMap<String, ApnsClientData>,
settings: ApnsSettings,
endpoint_url: Url,
metrics: StatsdClient,
}

struct ApnsClient {
client: a2::Client,
struct ApnsClientData {
client: Box<dyn ApnsClient>,
topic: String,
}

#[async_trait]
trait ApnsClient: Send + Sync {
async fn send(&self, payload: Payload<'_>) -> Result<a2::Response, a2::Error>;
}

#[async_trait]
impl ApnsClient for a2::Client {
async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
self.send(payload).await
}
}

impl ApnsRouter {
/// Create a new APNS router. APNS clients will be initialized for each
/// channel listed in the settings.
Expand Down Expand Up @@ -56,17 +68,19 @@ impl ApnsRouter {
async fn create_client(
name: String,
settings: ApnsChannel,
) -> Result<(String, ApnsClient), ApnsError> {
) -> Result<(String, ApnsClientData), ApnsError> {
let endpoint = if settings.sandbox {
Endpoint::Sandbox
} else {
Endpoint::Production
};
let cert = tokio::fs::read(settings.cert).await?;
let key = tokio::fs::read(settings.key).await?;
let client = ApnsClient {
client: a2::Client::certificate_parts(&cert, &key, endpoint)
.map_err(ApnsError::ApnsClient)?,
let client = ApnsClientData {
client: Box::new(
a2::Client::certificate_parts(&cert, &key, endpoint)
.map_err(ApnsError::ApnsClient)?,
),
topic: settings
.topic
.unwrap_or_else(|| format!("com.mozilla.org.{}", name)),
Expand All @@ -75,6 +89,21 @@ impl ApnsRouter {
Ok((name, client))
}

/// The default APS data for a notification
fn default_aps<'a>() -> APS<'a> {
APS {
alert: Some(APSAlert::Localized({
LocalizedAlert {
title_loc_key: Some("SentTab.NoTabArrivingNotification.title"),
loc_key: Some("SentTab.NoTabArrivingNotification.body"),
..Default::default()
}
})),
mutable_content: Some(1),
..Default::default()
}
}

/// Update metrics after successfully routing the notification
fn incr_success_metrics(&self, notification: &Notification, channel: &str) {
self.metrics
Expand Down Expand Up @@ -190,22 +219,15 @@ impl Router for ApnsRouter {
let aps: APS<'_> = router_data
.get("aps")
.and_then(|value| APS::deserialize(value).ok())
.unwrap_or_else(|| APS {
alert: Some(APSAlert::Localized({
LocalizedAlert {
title_loc_key: Some("SentTab.NoTabArrivingNotification.title"),
loc_key: Some("SentTab.NoTabArrivingNotification.body"),
..Default::default()
}
})),
mutable_content: Some(1),
..Default::default()
});
.unwrap_or_else(Self::default_aps);
let mut message_data = build_message_data(notification, self.settings.max_data)?;
message_data.insert("ver", notification.message_id.clone());

// Get client and build payload
let ApnsClient { client, topic } = self.clients.get(channel).unwrap();
let ApnsClientData { client, topic } = self
.clients
.get(channel)
.ok_or(ApnsError::InvalidReleaseChannel)?;
let payload = Payload {
aps,
data: message_data
Expand Down Expand Up @@ -241,3 +263,274 @@ impl Router for ApnsRouter {
))
}
}

#[cfg(test)]
mod tests {
use crate::error::ApiErrorKind;
use crate::extractors::routers::RouterType;
use crate::routers::apns::error::ApnsError;
use crate::routers::apns::router::{ApnsClient, ApnsClientData, ApnsRouter};
use crate::routers::apns::settings::ApnsSettings;
use crate::routers::common::tests::{make_notification, CHANNEL_ID};
use crate::routers::{Router, RouterError, RouterResponse};
use a2::request::payload::Payload;
use a2::{Error, Response};
use async_trait::async_trait;
use cadence::StatsdClient;
use std::collections::HashMap;
use url::Url;

const DEVICE_TOKEN: &str = "test-token";
const APNS_ID: &str = "deadbeef-4f5e-4403-be8f-35d0251655f5";

/// A mock APNS client which allows one to supply a custom APNS response/error
struct MockApnsClient {
send_fn: Box<dyn Fn(Payload<'_>) -> Result<a2::Response, a2::Error> + Send + Sync>,
}

#[async_trait]
impl ApnsClient for MockApnsClient {
async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
(self.send_fn)(payload)
}
}

impl MockApnsClient {
fn new<F>(send_fn: F) -> Self
where
F: Fn(Payload<'_>) -> Result<a2::Response, a2::Error>,
F: Send + Sync + 'static,
{
Self {
send_fn: Box::new(send_fn),
}
}
}

/// Create a successful APNS response
fn apns_success_response() -> a2::Response {
a2::Response {
error: None,
apns_id: Some(APNS_ID.to_string()),
code: 200,
}
}

/// Create a router for testing, using the given APNS client
fn make_router(client: MockApnsClient) -> ApnsRouter {
ApnsRouter {
clients: {
let mut map = HashMap::new();
map.insert(
"test-channel".to_string(),
ApnsClientData {
client: Box::new(client),
topic: "test-topic".to_string(),
},
);
map
},
settings: ApnsSettings::default(),
endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
metrics: StatsdClient::from_sink("autopush", cadence::NopMetricSink),
}
}

/// Create default user router data
fn default_router_data() -> HashMap<String, serde_json::Value> {
let mut map = HashMap::new();
map.insert(
"token".to_string(),
serde_json::to_value(DEVICE_TOKEN).unwrap(),
);
map.insert(
"rel_channel".to_string(),
serde_json::to_value("test-channel").unwrap(),
);
map
}

/// A notification with no data is packaged correctly and sent to APNS
#[tokio::test]
async fn successful_routing_no_data() {
let client = MockApnsClient::new(|payload| {
assert_eq!(
serde_json::to_value(payload.aps).unwrap(),
serde_json::to_value(ApnsRouter::default_aps()).unwrap()
);
assert_eq!(payload.device_token, DEVICE_TOKEN);
assert_eq!(payload.options.apns_topic, Some("test-topic"));
assert_eq!(
serde_json::to_value(payload.data).unwrap(),
serde_json::json!({
"chid": CHANNEL_ID,
"ver": "test-message-id"
})
);

Ok(apns_success_response())
});
let router = make_router(client);
let notification = make_notification(default_router_data(), None, RouterType::APNS);

let result = router.route_notification(&notification).await;
assert!(result.is_ok(), "result = {:?}", result);
assert_eq!(
result.unwrap(),
RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
);
}

/// A notification with data is packaged correctly and sent to APNS
#[tokio::test]
async fn successful_routing_with_data() {
let client = MockApnsClient::new(|payload| {
assert_eq!(
serde_json::to_value(payload.aps).unwrap(),
serde_json::to_value(ApnsRouter::default_aps()).unwrap()
);
assert_eq!(payload.device_token, DEVICE_TOKEN);
assert_eq!(payload.options.apns_topic, Some("test-topic"));
assert_eq!(
serde_json::to_value(payload.data).unwrap(),
serde_json::json!({
"chid": CHANNEL_ID,
"ver": "test-message-id",
"body": "test-data",
"con": "test-encoding",
"enc": "test-encryption",
"cryptokey": "test-crypto-key",
"enckey": "test-encryption-key"
})
);

Ok(apns_success_response())
});
let router = make_router(client);
let data = "test-data".to_string();
let notification = make_notification(default_router_data(), Some(data), RouterType::APNS);

let result = router.route_notification(&notification).await;
assert!(result.is_ok(), "result = {:?}", result);
assert_eq!(
result.unwrap(),
RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
);
}

/// If there is no client for the user's release channel, an error is
/// returned and the APNS request is not sent.
#[tokio::test]
async fn missing_client() {
let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
let router = make_router(client);
let mut router_data = default_router_data();
router_data.insert(
"rel_channel".to_string(),
serde_json::to_value("unknown-app-id").unwrap(),
);
let notification = make_notification(router_data, None, RouterType::APNS);

let result = router.route_notification(&notification).await;
assert!(result.is_err());
assert!(
matches!(
result.as_ref().unwrap_err().kind,
ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidReleaseChannel))
),
"result = {:?}",
result
);
}

/// If APNS says the user doesn't exist anymore, we return a specific error
#[tokio::test]
async fn user_not_found() {
let client = MockApnsClient::new(|_| {
Err(a2::Error::ResponseError(a2::Response {
error: Some(a2::ErrorBody {
reason: a2::ErrorReason::Unregistered,
timestamp: Some(0),
}),
apns_id: None,
code: 410,
}))
});
let router = make_router(client);
let notification = make_notification(default_router_data(), None, RouterType::APNS);

let result = router.route_notification(&notification).await;
assert!(result.is_err());
assert!(
matches!(
result.as_ref().unwrap_err().kind,
ApiErrorKind::Router(RouterError::Apns(ApnsError::Unregistered))
),
"result = {:?}",
result
);
}

/// APNS errors (other than Unregistered) are wrapped and returned
#[tokio::test]
async fn upstream_error() {
let client = MockApnsClient::new(|_| {
Err(a2::Error::ResponseError(a2::Response {
error: Some(a2::ErrorBody {
reason: a2::ErrorReason::BadCertificate,
timestamp: None,
}),
apns_id: None,
code: 403,
}))
});
let router = make_router(client);
let notification = make_notification(default_router_data(), None, RouterType::APNS);

let result = router.route_notification(&notification).await;
assert!(result.is_err());
assert!(
matches!(
result.as_ref().unwrap_err().kind,
ApiErrorKind::Router(RouterError::Apns(ApnsError::ApnsUpstream(
a2::Error::ResponseError(a2::Response {
error: Some(a2::ErrorBody {
reason: a2::ErrorReason::BadCertificate,
timestamp: None,
}),
apns_id: None,
code: 403,
})
)))
),
"result = {:?}",
result
);
}

/// The default APS data is used if the user's APS data is invalid
#[tokio::test]
async fn invalid_aps_data() {
let client = MockApnsClient::new(|payload| {
assert_eq!(
serde_json::to_value(payload.aps).unwrap(),
serde_json::to_value(ApnsRouter::default_aps()).unwrap()
);
Ok(apns_success_response())
});
let router = make_router(client);
let mut router_data = default_router_data();
router_data.insert(
"aps".to_string(),
serde_json::json!({"mutable-content": "should be a number"}),
);
let notification = make_notification(router_data, None, RouterType::APNS);

let result = router.route_notification(&notification).await;
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
);
}
}
Loading

0 comments on commit 9cdb14a

Please sign in to comment.