diff --git a/autoendpoint/src/extractors/subscription.rs b/autoendpoint/src/extractors/subscription.rs index 9885a013d..a48602c4d 100644 --- a/autoendpoint/src/extractors/subscription.rs +++ b/autoendpoint/src/extractors/subscription.rs @@ -1,6 +1,5 @@ use std::borrow::Cow; use std::error::Error; -use std::str::FromStr; use actix_web::{dev::Payload, web::Data, FromRequest, HttpRequest}; use autopush_common::{ @@ -12,7 +11,6 @@ use cadence::{CountedExt, StatsdClient}; use futures::{future::LocalBoxFuture, FutureExt}; use jsonwebtoken::{Algorithm, DecodingKey, Validation}; use openssl::hash::MessageDigest; -use url::Url; use uuid::Uuid; use crate::error::{ApiError, ApiErrorKind, ApiResult}; @@ -293,8 +291,9 @@ fn validate_vapid_jwt( let public_key = decode_public_key(public_key)?; let mut validation = Validation::new(Algorithm::ES256); - let audience: Vec<&str> = settings.vapid_aud.iter().map(|s| s.as_str()).collect(); - validation.set_audience(&audience); + // Set the audiences we allow. This obsoletes the need to manually match + // against values later. + validation.set_audience(&[settings.endpoint_url().origin().ascii_serialization()]); validation.set_required_spec_claims(&["exp", "aud", "sub"]); let token_data = match jsonwebtoken::decode::( @@ -389,27 +388,6 @@ fn validate_vapid_jwt( return Err(VapidError::FutureExpirationToken.into()); } - let aud = match Url::from_str(&token_data.claims.aud) { - Ok(v) => v, - Err(_) => { - error!("Bad Aud: Invalid audience {:?}", &token_data.claims.aud); - metrics.clone().incr("notification.auth.bad_vapid.aud"); - return Err(VapidError::InvalidAudience.into()); - } - }; - - let domain = &settings.endpoint_url(); - - if domain != &aud { - info!( - "Bad Aud: I am <{:?}>, asked for <{:?}> ", - domain.as_str(), - token_data.claims.aud - ); - metrics.clone().incr("notification.auth.bad_vapid.domain"); - return Err(VapidError::InvalidAudience.into()); - } - Ok(()) } @@ -512,6 +490,23 @@ pub mod tests { )); } + #[test] + fn vapid_aud_valid_for_alternate_host() { + let domain = "https://example.org"; + let test_settings = Settings { + endpoint_url: domain.to_owned(), + ..Default::default() + }; + let header = make_vapid( + "mailto:admin@example.com", + domain, + VapidClaims::default_exp() - 100, + PUB_KEY.to_owned(), + ); + let result = validate_vapid_jwt(&header, &test_settings, &Metrics::noop()); + assert!(result.is_ok()); + } + #[test] fn vapid_exp_is_string() { #[derive(Debug, Deserialize, Serialize)] @@ -523,7 +518,7 @@ pub mod tests { let domain = "https://push.services.mozilla.org"; let test_settings = Settings { - endpoint_url: "domain".to_owned(), + endpoint_url: domain.to_owned(), ..Default::default() }; let jwk_header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256); diff --git a/autoendpoint/src/settings.rs b/autoendpoint/src/settings.rs index e519ac3ef..02dc4ba13 100644 --- a/autoendpoint/src/settings.rs +++ b/autoendpoint/src/settings.rs @@ -31,8 +31,6 @@ pub struct Settings { pub router_table_name: String, pub message_table_name: String, - pub vapid_aud: Vec, - /// A stringified JSON list of VAPID public keys which should be tracked internally. /// This should ONLY include Mozilla generated and consumed messages (e.g. "SendToTab", etc.) /// These keys should be specified in stripped, b64encoded, X962 format (e.g. a single line of @@ -70,10 +68,6 @@ impl Default for Settings { db_settings: "".to_owned(), router_table_name: "router".to_string(), message_table_name: "message".to_string(), - vapid_aud: vec![ - "https://push.services.mozilla.org".to_string(), - "http://127.0.0.1:9160".to_string(), - ], // max data is a bit hard to figure out, due to encryption. Using something // like pywebpush, if you encode a block of 4096 bytes, you'll get a // 4216 byte data block. Since we're going to be receiving this, we have to