diff --git a/autoconnect/autoconnect-common/src/protocol.rs b/autoconnect/autoconnect-common/src/protocol.rs
index 7936a5d60..d02d623da 100644
--- a/autoconnect/autoconnect-common/src/protocol.rs
+++ b/autoconnect/autoconnect-common/src/protocol.rs
@@ -80,10 +80,15 @@ impl FromStr for ClientMessage {
     }
 }
 
+/// Returned ACKnowledgement of the received message by the User Agent.
+/// This is the payload for the `messageType:ack` packet.
+///
 #[derive(Debug, Deserialize)]
 pub struct ClientAck {
+    // The channel_id which received messages
     #[serde(rename = "channelID")]
     pub channel_id: Uuid,
+    // The corresponding version number for the message.
     pub version: String,
 }
 
diff --git a/autoconnect/autoconnect-settings/src/app_state.rs b/autoconnect/autoconnect-settings/src/app_state.rs
index 0d7656eca..8383c6ee8 100644
--- a/autoconnect/autoconnect-settings/src/app_state.rs
+++ b/autoconnect/autoconnect-settings/src/app_state.rs
@@ -68,6 +68,7 @@ impl AppState {
             db_settings: settings.db_settings.clone(),
         };
         let storage_type = StorageType::from_dsn(&db_settings.dsn);
+        #[allow(unused)]
        let db: Box<dyn DbClient> = match storage_type {
            #[cfg(feature = "bigtable")]
diff --git a/autoconnect/autoconnect-web/Cargo.toml b/autoconnect/autoconnect-web/Cargo.toml
index beebbd533..08f5f4ba8 100644
--- a/autoconnect/autoconnect-web/Cargo.toml
+++ b/autoconnect/autoconnect-web/Cargo.toml
@@ -34,3 +34,5 @@ ctor.workspace = true
 tokio.workspace = true
 
 autoconnect_common = { workspace = true, features = ["test-support"] }
+
+[features]
diff --git a/autoconnect/autoconnect-ws/Cargo.toml b/autoconnect/autoconnect-ws/Cargo.toml
index b6b08accf..194d79742 100644
--- a/autoconnect/autoconnect-ws/Cargo.toml
+++ b/autoconnect/autoconnect-ws/Cargo.toml
@@ -33,3 +33,5 @@ async-stream = "0.3"
 ctor.workspace = true
 
 autoconnect_common = { workspace = true, features = ["test-support"] }
+
+[features]
diff --git a/autoconnect/autoconnect-ws/autoconnect-ws-sm/Cargo.toml b/autoconnect/autoconnect-ws/autoconnect-ws-sm/Cargo.toml
index 2e3dbfa30..bc76baf62 100644
--- a/autoconnect/autoconnect-ws/autoconnect-ws-sm/Cargo.toml
+++ b/autoconnect/autoconnect-ws/autoconnect-ws-sm/Cargo.toml
@@ -30,3 +30,5 @@ tokio.workspace = true
 serde_json.workspace = true
 
 autoconnect_common = { workspace = true, features = ["test-support"] }
+
+[features]
diff --git a/autoconnect/autoconnect-ws/autoconnect-ws-sm/src/identified/on_client_msg.rs b/autoconnect/autoconnect-ws/autoconnect-ws-sm/src/identified/on_client_msg.rs
index 784626ea7..ea277cdf1 100644
--- a/autoconnect/autoconnect-ws/autoconnect-ws-sm/src/identified/on_client_msg.rs
+++ b/autoconnect/autoconnect-ws/autoconnect-ws-sm/src/identified/on_client_msg.rs
@@ -216,6 +216,7 @@ impl WebPushClient {
             // Get the stored notification record.
             let n = &self.ack_state.unacked_stored_notifs[pos];
             debug!("✅ Ack notif: {:?}", &n);
+            // TODO: Record "ack'd" reliability_id, if present.
             // Only force delete Topic messages, since they don't have a timestamp.
             // Other messages persist in the database, to be, eventually, cleaned up by their
             // TTL. We will need to update the `CurrentTimestamp` field for the channel
diff --git a/autoendpoint/src/error.rs b/autoendpoint/src/error.rs
index 8881193b1..3aec6c2c3 100644
--- a/autoendpoint/src/error.rs
+++ b/autoendpoint/src/error.rs
@@ -70,6 +70,12 @@ pub enum ApiErrorKind {
     #[error(transparent)]
     Jwt(#[from] jsonwebtoken::errors::Error),
 
+    #[error(transparent)]
+    Serde(#[from] serde_json::Error),
+
+    #[error(transparent)]
+    ReqwestError(#[from] reqwest::Error),
+
     #[error("Error while validating token")]
     TokenHashValidation(#[source] openssl::error::ErrorStack),
 
@@ -143,6 +149,7 @@ impl ApiErrorKind {
             ApiErrorKind::VapidError(_)
             | ApiErrorKind::Jwt(_)
+            | ApiErrorKind::Serde(_)
             | ApiErrorKind::TokenHashValidation(_)
             | ApiErrorKind::InvalidAuthentication
             | ApiErrorKind::InvalidLocalAuth(_) => StatusCode::UNAUTHORIZED,
 
@@ -161,7 +168,8 @@ impl ApiErrorKind {
             | ApiErrorKind::Io(_)
             | ApiErrorKind::Metrics(_)
             | ApiErrorKind::EndpointUrl(_)
-            | ApiErrorKind::RegistrationSecretHash(_) => StatusCode::INTERNAL_SERVER_ERROR,
+            | ApiErrorKind::RegistrationSecretHash(_)
+            | ApiErrorKind::ReqwestError(_) => StatusCode::INTERNAL_SERVER_ERROR,
         }
     }
 
@@ -179,7 +187,7 @@ impl ApiErrorKind {
             ApiErrorKind::InvalidMessageId => "invalid_message_id",
 
             ApiErrorKind::VapidError(_) => "vapid_error",
-            ApiErrorKind::Jwt(_) => "jwt",
+            ApiErrorKind::Jwt(_) | ApiErrorKind::Serde(_) => "jwt",
             ApiErrorKind::TokenHashValidation(_) => "token_hash_validation",
             ApiErrorKind::InvalidAuthentication => "invalid_authentication",
             ApiErrorKind::InvalidLocalAuth(_) => "invalid_local_auth",
@@ -199,6 +207,7 @@ impl ApiErrorKind {
             ApiErrorKind::Conditional(_) => "conditional",
             ApiErrorKind::EndpointUrl(e) => return e.metric_label(),
             ApiErrorKind::RegistrationSecretHash(_) => "registration_secret_hash",
+            ApiErrorKind::ReqwestError(_) => "reqwest",
         })
     }
 
@@ -221,7 +230,8 @@ impl ApiErrorKind {
             // Ignore oversized payload.
             ApiErrorKind::PayloadError(_) |
             ApiErrorKind::Validation(_) |
-            ApiErrorKind::Conditional(_) => false,
+            ApiErrorKind::Conditional(_) |
+            ApiErrorKind::ReqwestError(_) => false,
             _ => true,
         }
     }
 
@@ -251,6 +261,7 @@ impl ApiErrorKind {
             ApiErrorKind::VapidError(_)
             | ApiErrorKind::TokenHashValidation(_)
             | ApiErrorKind::Jwt(_)
+            | ApiErrorKind::Serde(_)
             | ApiErrorKind::InvalidAuthentication
             | ApiErrorKind::InvalidLocalAuth(_) => Some(109),
 
@@ -269,7 +280,8 @@ impl ApiErrorKind {
             | ApiErrorKind::InvalidRouterToken
             | ApiErrorKind::RegistrationSecretHash(_)
             | ApiErrorKind::EndpointUrl(_)
-            | ApiErrorKind::InvalidMessageId => None,
+            | ApiErrorKind::InvalidMessageId
+            | ApiErrorKind::ReqwestError(_) => None,
         }
     }
 }
diff --git a/autoendpoint/src/extractors/notification.rs b/autoendpoint/src/extractors/notification.rs
index ad86a9660..f29143c8e 100644
--- a/autoendpoint/src/extractors/notification.rs
+++ b/autoendpoint/src/extractors/notification.rs
@@ -1,4 +1,4 @@
-use crate::error::{ApiError, ApiErrorKind};
+use crate::error::{ApiError, ApiErrorKind, ApiResult};
 use crate::extractors::{
     message_id::MessageId, notification_headers::NotificationHeaders, subscription::Subscription,
 };
@@ -103,6 +103,7 @@ impl From<Notification> for autopush_common::notification::Notification {
             timestamp: notification.timestamp,
             data: notification.data,
             sortkey_timestamp,
+            reliability_id: notification.subscription.reliability_id,
             headers: {
                 let headers: HashMap<String, String> = notification.headers.into();
                 if headers.is_empty() {
@@ -160,25 +161,28 @@ impl Notification {
     /// fields are still required when delivering to the connection server, so
     /// we can't simply convert this notification type to that one and serialize
     /// via serde.
-    pub fn serialize_for_delivery(&self) -> HashMap<&'static str, serde_json::Value> {
+    pub fn serialize_for_delivery(&self) -> ApiResult<HashMap<&'static str, serde_json::Value>> {
         let mut map = HashMap::new();
 
         map.insert(
             "channelID",
-            serde_json::to_value(self.subscription.channel_id).unwrap(),
+            serde_json::to_value(self.subscription.channel_id)?,
         );
-        map.insert("version", serde_json::to_value(&self.message_id).unwrap());
-        map.insert("ttl", serde_json::to_value(self.headers.ttl).unwrap());
-        map.insert("topic", serde_json::to_value(&self.headers.topic).unwrap());
-        map.insert("timestamp", serde_json::to_value(self.timestamp).unwrap());
+        map.insert("version", serde_json::to_value(&self.message_id)?);
+        map.insert("ttl", serde_json::to_value(self.headers.ttl)?);
+        map.insert("topic", serde_json::to_value(&self.headers.topic)?);
+        map.insert("timestamp", serde_json::to_value(self.timestamp)?);
+        if let Some(reliability_id) = &self.subscription.reliability_id {
+            map.insert("reliability_id", serde_json::to_value(reliability_id)?);
+        }
 
         if let Some(data) = &self.data {
-            map.insert("data", serde_json::to_value(data).unwrap());
+            map.insert("data", serde_json::to_value(data)?);
 
             let headers: HashMap<_, _> = self.headers.clone().into();
-            map.insert("headers", serde_json::to_value(headers).unwrap());
+            map.insert("headers", serde_json::to_value(headers)?);
         }
 
-        map
+        Ok(map)
     }
 }
diff --git a/autoendpoint/src/extractors/subscription.rs b/autoendpoint/src/extractors/subscription.rs
index a48602c4d..c1d5f9b0e 100644
--- a/autoendpoint/src/extractors/subscription.rs
+++ b/autoendpoint/src/extractors/subscription.rs
@@ -38,7 +38,7 @@ pub struct Subscription {
     /// (This should ONLY be applied for messages that match known
     /// Mozilla provided VAPID public keys.)
     ///
-    pub tracking_id: Option<String>,
+    pub reliability_id: Option<String>,
 }
 
 impl FromRequest for Subscription {
@@ -73,11 +73,13 @@ impl FromRequest for Subscription {
                 .transpose()?;
             trace!("raw vapid: {:?}", &vapid);
 
-            let trackable = if let Some(vapid) = &vapid {
-                app_state.reliability.is_trackable(vapid)
-            } else {
-                false
-            };
+            let reliability_id: Option<String> = vapid.as_ref().and_then(|v| {
+                app_state
+                    .vapid_tracker
+                    .is_trackable(v)
+                    .then(|| app_state.vapid_tracker.get_id(req.headers()))
+            });
+            debug!("🔍 Assigning Reliability ID: {reliability_id:?}");
 
             // Capturing the vapid sub right now will cause too much cardinality. Instead,
             // let's just capture if we have a valid VAPID, as well as what sort of bad sub
@@ -132,14 +134,11 @@ impl FromRequest for Subscription {
                     .incr(&format!("updates.vapid.draft{:02}", vapid.vapid.version()))?;
             }
 
-            let tracking_id =
-                trackable.then(|| app_state.reliability.get_tracking_id(req.headers()));
-
             Ok(Subscription {
                 user,
                 channel_id,
                 vapid,
-                tracking_id,
+                reliability_id,
             })
         }
         .boxed_local()
diff --git a/autoendpoint/src/routers/common.rs b/autoendpoint/src/routers/common.rs
index e09ffd373..34f68282d 100644
--- a/autoendpoint/src/routers/common.rs
+++ b/autoendpoint/src/routers/common.rs
@@ -21,6 +21,12 @@ pub fn build_message_data(notification: &Notification) -> ApiResult {
-                if error.is_timeout() {
-                    self.metrics.incr("error.node.timeout")?;
+                if let ApiErrorKind::ReqwestError(error) = &error.kind {
+                    if error.is_timeout() {
+                        self.metrics.incr("error.node.timeout")?;
+                    };
+                    if error.is_connect() {
+                        self.metrics.incr("error.node.connect")?;
+                    };
                 };
-                if error.is_connect() {
-                    self.metrics.incr("error.node.connect")?;
-                }
                 debug!("✉ Error while sending webpush notification: {}", error);
                 self.remove_node_id(user, node_id).await?
             }
@@ -177,11 +179,11 @@ impl WebPushRouter {
         &self,
         notification: &Notification,
         node_id: &str,
-    ) -> Result<Response, reqwest::Error> {
+    ) -> ApiResult<Response> {
         let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid);
-        let notification = notification.serialize_for_delivery();
+        let notification = notification.serialize_for_delivery()?;
 
-        self.http.put(&url).json(&notification).send().await
+        Ok(self.http.put(&url).json(&notification).send().await?)
     }
 
     /// Notify the node to check for notifications for the user
diff --git a/autoendpoint/src/server.rs b/autoendpoint/src/server.rs
index bebbc9a8e..864449510 100644
--- a/autoendpoint/src/server.rs
+++ b/autoendpoint/src/server.rs
@@ -48,7 +48,7 @@ pub struct AppState {
     pub apns_router: Arc<ApnsRouter>,
     #[cfg(feature = "stub")]
     pub stub_router: Arc<StubRouter>,
-    pub reliability: Arc<VapidTracker>,
+    pub vapid_tracker: Arc<VapidTracker>,
 }
 
 pub struct Server;
@@ -109,7 +109,7 @@ impl Server {
             )
             .await?,
         );
-        let reliability = Arc::new(VapidTracker(settings.tracking_keys()));
+        let vapid_tracker = Arc::new(VapidTracker(settings.tracking_keys()));
         #[cfg(feature = "stub")]
         let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?);
         let app_state = AppState {
@@ -122,7 +122,7 @@ impl Server {
             apns_router,
             #[cfg(feature = "stub")]
             stub_router,
-            reliability,
+            vapid_tracker,
         };
 
         spawn_pool_periodic_reporter(
diff --git a/autoendpoint/src/settings.rs b/autoendpoint/src/settings.rs
index 02dc4ba13..9d5997185 100644
--- a/autoendpoint/src/settings.rs
+++ b/autoendpoint/src/settings.rs
@@ -169,9 +169,11 @@ impl Settings {
     // public key, but that may not always be true.
     pub fn tracking_keys(&self) -> Vec<String> {
         let keys = &self.tracking_keys.replace(['"', ' '], "");
-        Self::read_list_from_str(keys, "Invalid AUTOEND_TRACKING_KEYS")
-            .map(|v| v.to_owned())
-            .collect()
+        let result = Self::read_list_from_str(keys, "Invalid AUTOEND_TRACKING_KEYS")
+            .map(|v| v.to_owned().replace("=", ""))
+            .collect();
+        trace!("🔍 tracking_keys: {result:?}");
+        result
     }
 
     /// Get the URL for this endpoint server
@@ -193,11 +195,20 @@ impl VapidTracker {
     pub fn is_trackable(&self, vapid: &VapidHeaderWithKey) -> bool {
         // ideally, [Settings.with_env_and_config_file()] does the work of pre-populating
        // the Settings.tracking_vapid_pubs cache, but we can't rely on that.
-        self.0.contains(&vapid.public_key)
+        let key = vapid.public_key.replace('=', "");
+        let result = self.0.contains(&key);
+        debug!("🔍 Checking {key} {}", {
+            if result {
+                "Match!"
+            } else {
+                "no match"
+            }
+        });
+        result
     }
 
     /// Extract the message Id from the headers (if present), otherwise just make one up.
-    pub fn get_tracking_id(&self, headers: &HeaderMap) -> String {
+    pub fn get_id(&self, headers: &HeaderMap) -> String {
         headers
             .get("X-MessageId")
             .and_then(|v|
@@ -304,7 +315,7 @@ mod tests {
     #[test]
     fn test_tracking_keys() -> ApiResult<()> {
         let settings = Settings{
-            tracking_keys: r#"["BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon-H_boHTzMtMoNHsAGDlDB6X7vI"]"#.to_owned(),
+            tracking_keys: r#"["BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon-H_boHTzMtMoNHsAGDlDB6X7"]"#.to_owned(),
             ..Default::default()
         };
 
@@ -314,7 +325,7 @@ mod tests {
             token: "".to_owned(),
             version_data: crate::headers::vapid::VapidVersionData::Version1,
         },
-        public_key: "BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon-H_boHTzMtMoNHsAGDlDB6X7vI".to_owned()
+        public_key: "BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon-H_boHTzMtMoNHsAGDlDB6X7==".to_owned()
     };
 
     let key_set = settings.tracking_keys();
@@ -327,12 +338,12 @@ mod tests {
     }
 
     #[test]
-    fn test_tracking_id() -> ApiResult<()> {
+    fn test_reliability_id() -> ApiResult<()> {
         let mut headers = HeaderMap::new();
         let keys = Vec::new();
         let reliability = VapidTracker(keys);
 
-        let key = reliability.get_tracking_id(&headers);
+        let key = reliability.get_id(&headers);
         assert!(!key.is_empty());
 
         headers.insert(
@@ -340,7 +351,7 @@ mod tests {
             HeaderValue::from_static("123foobar456"),
         );
 
-        let key = reliability.get_tracking_id(&headers);
+        let key = reliability.get_id(&headers);
         assert_eq!(key, "123foobar456".to_owned());
 
         Ok(())
diff --git a/autopush-common/src/db/bigtable/bigtable_client/mod.rs b/autopush-common/src/db/bigtable/bigtable_client/mod.rs
index 2e84aaffa..1310b5ab5 100644
--- a/autopush-common/src/db/bigtable/bigtable_client/mod.rs
+++ b/autopush-common/src/db/bigtable/bigtable_client/mod.rs
@@ -720,6 +720,7 @@ impl BigTableClientImpl {
             )
         })?;
 
+        // Create from the known, required fields.
         let mut notif = Notification {
             channel_id: range_key.channel_id,
             topic: range_key.topic,
@@ -730,6 +731,7 @@ impl BigTableClientImpl {
             ..Default::default()
         };
 
+        // Backfill the Optional fields
         if let Some(cell) = row.take_cell("data") {
             notif.data = Some(to_string(cell.value, "data")?);
         }
@@ -739,6 +741,10 @@ impl BigTableClientImpl {
                     .map_err(|e| DbError::Serialization(e.to_string()))?,
             );
         }
+        if let Some(cell) = row.take_cell("reliability_id") {
+            trace!("🚣 Is reliable");
+            notif.reliability_id = Some(to_string(cell.value, "reliability_id")?);
+        }
 
         trace!("🚣 Deserialized message row: {:?}", &notif);
         Ok(notif)
@@ -1182,6 +1188,15 @@ impl DbClient for BigTableClientImpl {
                 ..Default::default()
             });
         }
+
+        if let Some(reliability_id) = message.reliability_id {
+            cells.push(cell::Cell {
+                qualifier: "reliability_id".to_owned(),
+                value: reliability_id.into_bytes(),
+                timestamp: expiry,
+                ..Default::default()
+            });
+        }
         row.add_cells(family, cells);
         trace!("🉑 Adding row");
         self.write_row(row).await?;
@@ -1301,6 +1316,7 @@ impl DbClient for BigTableClientImpl {
         );
 
         let messages = self.rows_to_notifications(rows)?;
+
         // Note: Bigtable always returns a timestamp of None.
         // Under Bigtable `current_timestamp` is instead initially read
         // from [get_user].
diff --git a/autopush-common/src/db/mod.rs b/autopush-common/src/db/mod.rs
index 3b1d2f11b..70a2f0254 100644
--- a/autopush-common/src/db/mod.rs
+++ b/autopush-common/src/db/mod.rs
@@ -245,6 +245,10 @@ pub struct NotificationRecord {
     /// value before sending it to storage or a connection node.
     #[serde(skip_serializing_if = "Option::is_none")]
     updateid: Option<String>,
+    /// Internal Push Reliability tracking id. (Applied only to subscription updates generated
+    /// by Mozilla owned and consumed messages, like SendTab updates.)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    reliability_id: Option<String>,
 }
 
 impl NotificationRecord {
@@ -333,6 +337,7 @@ impl NotificationRecord {
             data: self.data,
             headers: self.headers.map(|m| m.into()),
             sortkey_timestamp: key.sortkey_timestamp,
+            reliability_id: None,
         })
     }
 
diff --git a/autopush-common/src/notification.rs b/autopush-common/src/notification.rs
index 33ef0ee32..ecbb7f423 100644
--- a/autopush-common/src/notification.rs
+++ b/autopush-common/src/notification.rs
@@ -27,6 +27,8 @@ pub struct Notification {
     pub sortkey_timestamp: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub headers: Option<HashMap<String, String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reliability_id: Option<String>,
 }
 
 pub const TOPIC_NOTIFICATION_PREFIX: &str = "01";
diff --git a/scripts/convert_pem_to_x962.py b/scripts/convert_pem_to_x962.py
index bf261215e..ffc2cd2c8 100644
--- a/scripts/convert_pem_to_x962.py
+++ b/scripts/convert_pem_to_x962.py
@@ -4,7 +4,7 @@
 Autopush will try to scan for known VAPID public keys to track. These keys
 are specified in the header as x962 formatted strings. X962 is effectively
 "raw" format and contains the two longs that are the coordinates for the
-public key.
+public key prefixed with a '\04' byte.
""" import base64 diff --git a/tests/integration/async_push_test_client.py b/tests/integration/async_push_test_client.py index 6b3a41437..714446ec1 100644 --- a/tests/integration/async_push_test_client.py +++ b/tests/integration/async_push_test_client.py @@ -28,6 +28,7 @@ class ClientMessageType(Enum): ACK = "ack" NACK = "nack" PING = "ping" + NOTIFICATION = "notification" class AsyncPushTestClient: diff --git a/tests/integration/test_integration_all_rust.py b/tests/integration/test_integration_all_rust.py index a9fa0a46d..4501f59b3 100644 --- a/tests/integration/test_integration_all_rust.py +++ b/tests/integration/test_integration_all_rust.py @@ -16,7 +16,7 @@ import uuid from queue import Empty, Queue from threading import Event, Thread -from typing import Any, AsyncGenerator, Generator +from typing import Any, AsyncGenerator, Generator, cast from urllib.parse import urlparse import ecdsa @@ -48,6 +48,8 @@ MSG_LIMIT = 20 CRYPTO_KEY = os.environ.get("CRYPTO_KEY") or Fernet.generate_key().decode("utf-8") +TRACKING_KEY = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) +TRACKING_PUB_KEY = cast(ecdsa.VerifyingKey, TRACKING_KEY.get_verifying_key()) CONNECTION_PORT = 9150 ENDPOINT_PORT = 9160 ROUTER_PORT = 9170 @@ -125,7 +127,7 @@ def base64url_encode(value: bytes | str) -> str: MOCK_SERVER_PORT: Any = get_free_port() MOCK_MP_SERVICES: dict = {} -MOCK_MP_TOKEN: str = "Bearer {}".format(uuid.uuid4().hex) +MOCK_MP_TOKEN: str = f"Bearer {uuid.uuid4().hex}" MOCK_MP_POLLED: Event = Event() MOCK_SENTRY_QUEUE: Queue = Queue() @@ -142,7 +144,7 @@ def base64url_encode(value: bytes | str) -> str: endpoint_scheme="http", router_tablename=ROUTER_TABLE, message_tablename=MESSAGE_TABLE, - crypto_key="[{}]".format(CRYPTO_KEY), + crypto_key=f"[{CRYPTO_KEY}]", auto_ping_interval=30.0, auto_ping_timeout=10.0, close_handshake_timeout=5, @@ -183,7 +185,9 @@ def base64url_encode(value: bytes | str) -> str: router_table_name=ROUTER_TABLE, message_table_name=MESSAGE_TABLE, human_logs="true", - crypto_keys="[{}]".format(CRYPTO_KEY), + crypto_keys=f"[{CRYPTO_KEY}]", + # convert to x692 format + tracking_keys=f"[{base64.urlsafe_b64encode((b"\4" + TRACKING_PUB_KEY.to_string())).decode()}]", ) @@ -198,10 +202,10 @@ def _get_vapid( global CONNECTION_CONFIG if endpoint is None: - endpoint = "{}://{}:{}".format( - CONNECTION_CONFIG.get("endpoint_scheme"), - CONNECTION_CONFIG.get("endpoint_hostname"), - CONNECTION_CONFIG.get("endpoint_port"), + endpoint = ( + f"{CONNECTION_CONFIG.get("endpoint_scheme")}://" + f"{CONNECTION_CONFIG.get("endpoint_hostname")}:" + f"{CONNECTION_CONFIG.get("endpoint_port")}" ) if not payload: payload = { @@ -213,7 +217,7 @@ def _get_vapid( payload["aud"] = endpoint if not key: key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) - vk: ecdsa.VerifyingKey = key.get_verifying_key() + vk: ecdsa.VerifyingKey = cast(ecdsa.VerifyingKey, key.get_verifying_key()) auth: str = jws.sign(payload, key, algorithm="ES256").strip("=") crypto_key: str = base64url_encode((b"\4" + vk.to_string())) return {"auth": auth, "crypto-key": crypto_key, "key": key} @@ -327,11 +331,11 @@ def get_rust_binary_path(binary) -> str: """ global STRICT_LOG_COUNTS - rust_bin: str = root_dir + "/target/release/{}".format(binary) + rust_bin: str = root_dir + f"/target/release/{binary}" possible_paths: list[str] = [ - "/target/debug/{}".format(binary), - "/{0}/target/release/{0}".format(binary), - "/{0}/target/debug/{0}".format(binary), + f"/target/debug/{binary}", + f"/{binary}/target/release/{binary}", + 
f"/{binary}/target/debug/{binary}", ] while possible_paths and not os.path.exists(rust_bin): # pragma: nocover rust_bin = root_dir + possible_paths.pop(0) @@ -347,7 +351,7 @@ def write_config_to_env(config, prefix) -> None: """Write configurations to application read environment variables.""" for key, val in config.items(): new_key = prefix + key - log.debug("✍ config {} => {}".format(new_key, val)) + log.debug(f"✍ config {new_key} => {val}") os.environ[new_key.upper()] = str(val) @@ -466,7 +470,7 @@ def setup_megaphone_server(connection_binary) -> None: else: write_config_to_env(MEGAPHONE_CONFIG, CONNECTION_SETTINGS_PREFIX) cmd = [connection_binary] - log.debug("🐍🟢 Starting Megaphone server: {}".format(" ".join(cmd))) + log.debug(f"🐍🟢 Starting Megaphone server: {' '.join(cmd)}") CN_MP_SERVER = subprocess.Popen(cmd, shell=True, env=os.environ) # nosec @@ -495,7 +499,7 @@ def setup_endpoint_server() -> None: # Run autoendpoint cmd = [get_rust_binary_path("autoendpoint")] - log.debug("🐍🟢 Starting Endpoint server: {}".format(" ".join(cmd))) + log.debug(f"🐍🟢 Starting Endpoint server: {' '.join(cmd)}") EP_SERVER = subprocess.Popen( cmd, shell=True, @@ -734,7 +738,7 @@ async def test_basic_delivery(registered_test_client: AsyncPushTestClient) -> No clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(bytes(uuid_data, "utf-8")) - assert result["messageType"] == "notification" + assert result["messageType"] == ClientMessageType.NOTIFICATION.value async def test_topic_basic_delivery(registered_test_client: AsyncPushTestClient) -> None: @@ -745,7 +749,7 @@ async def test_topic_basic_delivery(registered_test_client: AsyncPushTestClient) clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(uuid_data) - assert result["messageType"] == "notification" + assert result["messageType"] == ClientMessageType.NOTIFICATION.value async def test_topic_replacement_delivery( @@ -765,7 +769,7 @@ async def test_topic_replacement_delivery( clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(uuid_data_2) - assert result["messageType"] == "notification" + assert result["messageType"] == ClientMessageType.NOTIFICATION.value result = await registered_test_client.get_notification() assert result is None @@ -783,7 +787,7 @@ async def test_topic_no_delivery_on_reconnect(registered_test_client: AsyncPushT clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(uuid_data) - assert result["messageType"] == "notification" + assert result["messageType"] == ClientMessageType.NOTIFICATION.value await registered_test_client.ack(result["channelID"], result["version"]) await registered_test_client.disconnect() await registered_test_client.connect() @@ -807,7 +811,42 @@ async def test_basic_delivery_with_vapid( clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=") assert result["headers"]["encryption"] == clean_header assert result["data"] == base64url_encode(uuid_data) - assert result["messageType"] == "notification" + assert result["messageType"] == ClientMessageType.NOTIFICATION.value + # The key we used should not have been 
+    # The key we used should not have been registered, so no tracking should
+    # be occurring.
+    log.debug(f"🔍 Reliability: {result.get("reliability_id")}")
+    assert result.get("reliability_id") is None
+
+
+async def test_basic_delivery_with_tracked_vapid(
+    registered_test_client: AsyncPushTestClient,
+    vapid_payload: dict[str, int | str],
+) -> None:
+    """Test delivery of a basic push message with a tracked VAPID header."""
+    uuid_data: str = str(uuid.uuid4())
+    vapid_info = _get_vapid(key=TRACKING_KEY, payload=vapid_payload)
+    # quick sanity check to ensure that the keys match.
+    # (ideally, this should dump as x962, but DER is good enough.)
+    key = cast(
+        ecdsa.VerifyingKey, cast(ecdsa.SigningKey, vapid_info["key"]).get_verifying_key()
+    ).to_der()
+
+    assert key == TRACKING_PUB_KEY.to_der()
+
+    # let's do an offline submit so we can validate the reliability_id survives storage.
+    await registered_test_client.disconnect()
+    await registered_test_client.send_notification(data=uuid_data, vapid=vapid_info)
+    await registered_test_client.connect()
+    await registered_test_client.hello()
+    result = await registered_test_client.get_notification()
+
+    # the following presumes that only `salt` is padded.
+    clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=")
+    assert result["headers"]["encryption"] == clean_header
+    assert result["data"] == base64url_encode(uuid_data)
+    assert result["messageType"] == ClientMessageType.NOTIFICATION.value
+    log.debug(f"🔍 reliability {result}")
+    assert result["reliability_id"] is not None
 
 
 async def test_basic_delivery_with_invalid_vapid(
@@ -1004,7 +1043,7 @@ async def test_multiple_delivery_with_single_ack(
     result = await registered_test_client.get_notification(timeout=0.5)
     assert result != {}
     assert result["data"] == base64url_encode(uuid_data_1)
-    assert result["messageType"] == "notification"
+    assert result["messageType"] == ClientMessageType.NOTIFICATION.value
     result2 = await registered_test_client.get_notification()
     assert result2 != {}
     assert result2["data"] == base64url_encode(uuid_data_2)
@@ -1086,7 +1125,7 @@ async def test_ttl_0_connected(registered_test_client: AsyncPushTestClient) -> N
     clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=")
     assert result["headers"]["encryption"] == clean_header
     assert result["data"] == base64url_encode(uuid_data)
-    assert result["messageType"] == "notification"
+    assert result["messageType"] == ClientMessageType.NOTIFICATION.value
 
 
 async def test_ttl_0_not_connected(registered_test_client: AsyncPushTestClient) -> None:
@@ -1141,7 +1180,7 @@ async def test_ttl_batch_expired_and_good_one(registered_test_client: AsyncPushT
     clean_header = registered_test_client._crypto_key.replace('"', "").rstrip("=")
     assert result["headers"]["encryption"] == clean_header
     assert result["data"] == base64url_encode(uuid_data_2)
-    assert result["messageType"] == "notification"
+    assert result["messageType"] == ClientMessageType.NOTIFICATION.value
 
     result = await registered_test_client.get_notification(timeout=0.5)
     assert result is None
@@ -1202,7 +1241,7 @@ async def test_empty_message_without_crypto_headers(
     """Test that a message without crypto headers, and does not have data, is accepted."""
     result = await registered_test_client.send_notification(use_header=False)
     assert result is not None
-    assert result["messageType"] == "notification"
+    assert result["messageType"] == ClientMessageType.NOTIFICATION.value
     assert "headers" not in result
     assert "data" not in result
     await registered_test_client.ack(result["channelID"], result["version"])
@@ -1226,14 +1265,14 @@ async def test_empty_message_with_crypto_headers(
     """
     result = await registered_test_client.send_notification()
     assert result is not None
-    assert result["messageType"] == "notification"
+    assert result["messageType"] == ClientMessageType.NOTIFICATION.value
     assert "headers" not in result
     assert "data" not in result
 
     result2 = await registered_test_client.send_notification()
     # We shouldn't store headers for blank messages.
     assert result2 is not None
-    assert result2["messageType"] == "notification"
+    assert result2["messageType"] == ClientMessageType.NOTIFICATION.value
     assert "headers" not in result2
     assert "data" not in result2