diff --git a/src/client.rs b/src/client.rs index b78856347..b4d3f206b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -29,7 +29,7 @@ use db::{CheckStorageResponse, HelloResponse, RegisterResponse}; use errors::*; use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification}; use server::Server; -use util::megaphone::{Broadcast, BroadcastSubs, BroadcastSubsInit}; +use util::megaphone::{Broadcast, BroadcastSubs}; use util::{ms_since_epoch, parse_user_agent, sec_since_epoch}; // Created and handed to the AutopushServer @@ -406,7 +406,7 @@ where flags.check = check_storage; flags.reset_uaid = reset_uaid; flags.rotate_message_table = rotate_message_table; - let BroadcastSubsInit(initialized_subs, broadcasts) = + let (initialized_subs, broadcasts) = srv.broadcast_init(&desired_broadcasts); broadcast_subs.replace(initialized_subs); let uid = Uuid::new_v4(); @@ -432,7 +432,7 @@ where uaid: uaid.simple().to_string(), status: 200, use_webpush: Some(true), - broadcasts: Broadcast::into_hashmap(broadcasts), + broadcasts, }; let auth_state_machine = AuthClientState::start( vec![response], @@ -771,15 +771,16 @@ where Either::A(ClientMessage::BroadcastSubscribe { broadcasts }) => { let broadcast_delta = { let mut broadcast_subs = data.broadcast_subs.borrow_mut(); - data.srv.subscribe_to_broadcasts( + data.srv.process_broadcasts( &mut broadcast_subs, &Broadcast::from_hashmap(broadcasts), ) }; - if let Some(delta) = broadcast_delta { + + if let Some(response) = broadcast_delta { transition!(Send { smessages: vec![ServerMessage::Broadcast { - broadcasts: Broadcast::into_hashmap(delta), + broadcasts: response, }], data, }); diff --git a/src/protocol.rs b/src/protocol.rs index 8d910b8c6..ff4f23e6f 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -86,6 +86,13 @@ pub struct ClientAck { pub version: String, } +#[derive(Serialize)] +#[serde(untagged)] +pub enum BroadcastValue { + Value(String), + Nested(HashMap), +} + #[derive(Serialize)] #[serde(tag = "messageType", rename_all = "snake_case")] pub enum ServerMessage { @@ -94,7 +101,7 @@ pub enum ServerMessage { status: u32, #[serde(skip_serializing_if = "Option::is_none")] use_webpush: Option, - broadcasts: HashMap, + broadcasts: HashMap, }, Register { @@ -112,7 +119,7 @@ pub enum ServerMessage { }, Broadcast { - broadcasts: HashMap, + broadcasts: HashMap, }, Notification(Notification), diff --git a/src/server/mod.rs b/src/server/mod.rs index 1594a3b24..03ab6fd4b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -41,7 +41,7 @@ use errors::*; use errors::{Error, Result}; use http; use logging; -use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification}; +use protocol::{BroadcastValue, ClientMessage, Notification, ServerMessage, ServerNotification}; use server::dispatch::{Dispatch, RequestType}; use server::metrics::metrics_from_opts; use server::webpush_io::WebpushIo; @@ -368,7 +368,9 @@ impl Server { let client = request.and_then(move |(socket, request)| -> MyFuture<_> { match request { RequestType::Status => write_status(socket), - RequestType::LBHeartBeat => write_json(socket, StatusCode::Ok, serde_json::Value::from("")), + RequestType::LBHeartBeat => { + write_json(socket, StatusCode::Ok, serde_json::Value::from("")) + } RequestType::Version => write_version_file(socket), RequestType::LogCheck => write_log_check(socket), RequestType::Websocket => { @@ -521,11 +523,22 @@ impl Server { } /// Initialize broadcasts for a newly connected client - pub fn broadcast_init(&self, desired_broadcasts: &[Broadcast]) -> BroadcastSubsInit { + pub fn broadcast_init( + &self, + desired_broadcasts: &[Broadcast], + ) -> (BroadcastSubs, HashMap) { debug!("Initialized broadcasts"); - self.broadcaster - .borrow() - .broadcast_delta(desired_broadcasts) + let bc = self.broadcaster.borrow(); + let BroadcastSubsInit(broadcast_subs, broadcasts) = bc.broadcast_delta(desired_broadcasts); + let mut response = Broadcast::into_hashmap(broadcasts); + let missing = bc.missing_broadcasts(desired_broadcasts); + if !missing.is_empty() { + response.insert( + "errors".to_string(), + BroadcastValue::Nested(Broadcast::into_hashmap(missing)), + ); + } + (broadcast_subs, response) } /// Calculate whether there's new broadcast versions to go out @@ -533,15 +546,30 @@ impl Server { self.broadcaster.borrow().change_count_delta(broadcast_subs) } - /// Add new broadcasts to be tracked by a client - pub fn subscribe_to_broadcasts( + /// Process a broadcast list, adding new broadcasts to be tracked and locating missing ones + /// Returns an appropriate response for use by the prototocol + pub fn process_broadcasts( &self, broadcast_subs: &mut BroadcastSubs, broadcasts: &[Broadcast], - ) -> Option> { - self.broadcaster - .borrow() - .subscribe_to_broadcasts(broadcast_subs, broadcasts) + ) -> Option> { + let bc = self.broadcaster.borrow(); + let mut response: HashMap = HashMap::new(); + let missing = bc.missing_broadcasts(broadcasts); + if !missing.is_empty() { + response.insert( + "errors".to_string(), + BroadcastValue::Nested(Broadcast::into_hashmap(missing)), + ); + } + if let Some(delta) = bc.subscribe_to_broadcasts(broadcast_subs, broadcasts) { + response.extend(Broadcast::into_hashmap(delta)); + }; + if response.is_empty() { + None + } else { + Some(response) + } } } @@ -937,9 +965,10 @@ fn write_status(socket: WebpushIo) -> MyFuture<()> { /// Return a static copy of `version.json` from compile time. pub fn write_version_file(socket: WebpushIo) -> MyFuture<()> { write_json( - socket, - StatusCode::Ok, - serde_json::Value::from(include_str!("../../version.json"))) + socket, + StatusCode::Ok, + serde_json::Value::from(include_str!("../../version.json")), + ) } fn write_log_check(socket: WebpushIo) -> MyFuture<()> { diff --git a/src/util/megaphone.rs b/src/util/megaphone.rs index df1bcae5f..064c51891 100644 --- a/src/util/megaphone.rs +++ b/src/util/megaphone.rs @@ -2,6 +2,8 @@ use errors::Result; use std::collections::HashMap; use std::time::Duration; +use protocol::BroadcastValue; + use reqwest; // A Broadcast entry Key in a BroadcastRegistry @@ -70,6 +72,16 @@ pub struct Broadcast { version: String, } +impl Broadcast { + /// Errors out a broadcast for broadcasts that weren't found + pub fn error(self) -> Broadcast { + Broadcast { + broadcast_id: self.broadcast_id, + version: "Broadcast not found".to_string(), + } + } +} + // Handy From impls for common hashmap to/from conversions impl From<(String, String)> for Broadcast { fn from(val: (String, String)) -> Broadcast { @@ -80,9 +92,9 @@ impl From<(String, String)> for Broadcast { } } -impl From for (String, String) { - fn from(bcast: Broadcast) -> (String, String) { - (bcast.broadcast_id, bcast.version) +impl From for (String, BroadcastValue) { + fn from(bcast: Broadcast) -> (String, BroadcastValue) { + (bcast.broadcast_id, BroadcastValue::Value(bcast.version)) } } @@ -91,7 +103,7 @@ impl Broadcast { val.into_iter().map(|v| v.into()).collect() } - pub fn into_hashmap(broadcasts: Vec) -> HashMap { + pub fn into_hashmap(broadcasts: Vec) -> HashMap { broadcasts.into_iter().map(|v| v.into()).collect() } } @@ -121,9 +133,7 @@ impl BroadcastChangeTracker { change_count: 0, }; for srv in broadcasts { - let key = tracker - .broadcast_registry - .add_broadcast(srv.broadcast_id); + let key = tracker.broadcast_registry.add_broadcast(srv.broadcast_id); tracker.broadcast_versions.insert(key, srv.version); } tracker @@ -154,7 +164,9 @@ impl BroadcastChangeTracker { return change_count; } self.change_count += 1; - let key = self.broadcast_registry.add_broadcast(broadcast.broadcast_id); + let key = self + .broadcast_registry + .add_broadcast(broadcast.broadcast_id); self.broadcast_versions.insert(key, broadcast.version); self.broadcast_list.push(BroadcastRevision { change_count: self.change_count, @@ -167,7 +179,8 @@ impl BroadcastChangeTracker { /// /// Returns an error if the `broadcast` was never initialized/added. pub fn update_broadcast(&mut self, broadcast: Broadcast) -> Result { - let key = self.broadcast_registry + let key = self + .broadcast_registry .lookup_key(&broadcast.broadcast_id) .ok_or("Broadcast not found")?; @@ -181,10 +194,17 @@ impl BroadcastChangeTracker { } // Check to see if this broadcast has been updated since initialization - let bcast_index = self.broadcast_list + let bcast_index = self + .broadcast_list .iter() .enumerate() - .filter_map(|(i, bcast)| if bcast.broadcast == key { Some(i) } else { None }) + .filter_map(|(i, bcast)| { + if bcast.broadcast == key { + Some(i) + } else { + None + } + }) .nth(0); self.change_count += 1; if let Some(bcast_index) = bcast_index { @@ -265,8 +285,7 @@ impl BroadcastChangeTracker { broadcast_subs: &mut BroadcastSubs, broadcasts: &[Broadcast], ) -> Option> { - let mut bcast_delta = self.change_count_delta(broadcast_subs) - .unwrap_or_default(); + let mut bcast_delta = self.change_count_delta(broadcast_subs).unwrap_or_default(); for bcast in broadcasts.iter() { if let Some(bcast_key) = self.broadcast_registry.lookup_key(&bcast.broadcast_id) { if let Some(ver) = self.broadcast_versions.get(&bcast_key) { @@ -286,6 +305,24 @@ impl BroadcastChangeTracker { Some(bcast_delta) } } + + /// Check a broadcast list and return unknown broadcast id's with their appropriate error + pub fn missing_broadcasts(&self, broadcasts: &[Broadcast]) -> Vec { + broadcasts + .iter() + .filter_map(|b| { + if self + .broadcast_registry + .lookup_key(&b.broadcast_id) + .is_none() + { + Some(b.clone().error()) + } else { + None + } + }) + .collect() + } } #[cfg(test)] @@ -345,12 +382,10 @@ mod tests { let delta = tracker .subscribe_to_broadcasts( &mut broadcast_subs, - &vec![ - Broadcast { - broadcast_id: String::from("bcastc"), - version: String::from("revision_alpha"), - }, - ], + &vec![Broadcast { + broadcast_id: String::from("bcastc"), + version: String::from("revision_alpha"), + }], ) .unwrap(); assert_eq!(delta.len(), 1); diff --git a/tests/test_integration.py b/tests/test_integration.py index d862d0464..712ddf90f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1026,6 +1026,23 @@ def test_broadcast_update_on_connect(self): yield self.shut_down(client) + @inlineCallbacks + def test_broadcast_update_on_connect_with_errors(self): + self.mock_megaphone.services = {"kinto:123": "ver1"} + self.mock_megaphone.polled.clear() + self.mock_megaphone.polled.wait(timeout=5) + + old_ver = {"kinto:123": "ver0", "kinto:456": "ver1"} + client = Client(self._ws_url) + yield client.connect() + result = yield client.hello(services=old_ver) + assert result != {} + assert result["use_webpush"] is True + assert result["broadcasts"]["kinto:123"] == "ver1" + assert result["broadcasts"]["errors"][ + "kinto:456"] == "Broadcast not found" + yield self.shut_down(client) + @inlineCallbacks def test_broadcast_subscribe(self): self.mock_megaphone.services = {"kinto:123": "ver1"} @@ -1053,6 +1070,28 @@ def test_broadcast_subscribe(self): yield self.shut_down(client) + @inlineCallbacks + def test_broadcast_subscribe_with_errors(self): + self.mock_megaphone.services = {"kinto:123": "ver1"} + self.mock_megaphone.polled.clear() + self.mock_megaphone.polled.wait(timeout=5) + + old_ver = {"kinto:123": "ver0", "kinto:456": "ver1"} + client = Client(self._ws_url) + yield client.connect() + result = yield client.hello() + assert result != {} + assert result["use_webpush"] is True + assert result["broadcasts"] == {} + + client.broadcast_subscribe(old_ver) + result = yield client.get_broadcast() + assert result["broadcasts"]["kinto:123"] == "ver1" + assert result["broadcasts"]["errors"][ + "kinto:456"] == "Broadcast not found" + + yield self.shut_down(client) + @inlineCallbacks def test_broadcast_no_changes(self): self.mock_megaphone.services = {"kinto:123": "ver1"}