diff --git a/autopush_rs/src/client.rs b/autopush_rs/src/client.rs index ff87fdd7..f2f12705 100644 --- a/autopush_rs/src/client.rs +++ b/autopush_rs/src/client.rs @@ -5,6 +5,7 @@ //! of connected clients. Note that it's expected there'll be a lot of connected //! clients, so this may appears relatively heavily optimized! +use std::collections::HashMap; use std::rc::Rc; use cadence::prelude::*; @@ -23,6 +24,7 @@ use errors::*; use protocol::{ClientAck, ClientMessage, ServerMessage, ServerNotification, Notification}; use server::Server; use util::parse_user_agent; +use util::megaphone::{ClientServices, Service, ServiceClientInit}; pub struct RegisteredClient { pub uaid: Uuid, @@ -68,6 +70,7 @@ pub struct WebPushClient { uaid: Uuid, rx: mpsc::UnboundedReceiver, flags: ClientFlags, + broadcast_services: ClientServices, message_month: String, unacked_direct_notifs: Vec, unacked_stored_notifs: Vec, @@ -120,7 +123,7 @@ impl ClientFlags { pub enum ClientState { WaitingForHello(Timeout), - WaitingForProcessHello(MyFuture), + WaitingForProcessHello(MyFuture, Vec), WaitingForRegister(Uuid, MyFuture), WaitingForUnRegister(Uuid, MyFuture), WaitingForCheckStorage(MyFuture), @@ -187,6 +190,14 @@ where self.data.shutdown(); } + pub fn broadcast_delta(&mut self) -> Option> { + if let Some(ref mut webpush) = self.data.webpush { + self.data.srv.broadcast_delta(&mut webpush.broadcast_services) + } else { + None + } + } + fn transition(&mut self) -> Poll { let host = self.data.host.clone(); let next_state = match self.state { @@ -271,20 +282,25 @@ where } ClientState::WaitingForHello(ref mut timeout) => { debug!("State: WaitingForHello"); - let uaid = match try_ready!(self.data.input_with_timeout(timeout)) { + let (uaid, services) = match try_ready!(self.data.input_with_timeout(timeout)) { ClientMessage::Hello { uaid, use_webpush: Some(true), + broadcasts, .. - } => uaid.and_then(|uaid| Uuid::parse_str(uaid.as_str()).ok()), + } => ( + uaid.and_then(|uaid| Uuid::parse_str(uaid.as_str()).ok()), + Service::from_hashmap(broadcasts.unwrap_or(HashMap::new())) + ), _ => return Err("Invalid message, must be hello".into()), }; let connected_at = time::precise_time_ns() / 1000; ClientState::WaitingForProcessHello( self.data.srv.hello(&connected_at, uaid.as_ref()), + services, ) } - ClientState::WaitingForProcessHello(ref mut response) => { + ClientState::WaitingForProcessHello(ref mut response, ref services) => { debug!("State: WaitingForProcessHello"); match try_ready!(response.poll()) { call::HelloResponse { @@ -302,6 +318,7 @@ where rotate_message_table, check_storage, connected_at, + services, ) } call::HelloResponse { uaid: None, .. } => { @@ -422,6 +439,19 @@ where return Ok(next_state.into()); } match try_ready!(self.data.input_or_notif()) { + Either::A(ClientMessage::BroadcastSubscribe { broadcasts }) => { + let webpush = self.data.webpush.as_mut().unwrap(); + let service_delta = self.data.srv.client_service_add_service( + &mut webpush.broadcast_services, + &Service::from_hashmap(broadcasts), + ).unwrap_or(Vec::new()); + ClientState::FinishSend( + Some(ServerMessage::BroadcastSubscribe { + broadcasts: Service::to_hashmap(service_delta) + }), + Some(Box::new(ClientState::WaitingForAcks)), + ) + } Either::A(ClientMessage::Register { channel_id, key }) => { self.data.process_register(channel_id, key) } @@ -470,6 +500,19 @@ where return Ok(ClientState::CheckStorage.into()); } match try_ready!(self.data.input_or_notif()) { + Either::A(ClientMessage::BroadcastSubscribe { broadcasts }) => { + let webpush = self.data.webpush.as_mut().unwrap(); + let service_delta = self.data.srv.client_service_add_service( + &mut webpush.broadcast_services, + &Service::from_hashmap(broadcasts), + ).unwrap_or(Vec::new()); + ClientState::FinishSend( + Some(ServerMessage::BroadcastSubscribe { + broadcasts: Service::to_hashmap(service_delta) + }), + Some(Box::new(ClientState::Await)), + ) + } Either::A(ClientMessage::Register { channel_id, key }) => { self.data.process_register(channel_id, key) } @@ -479,7 +522,7 @@ where Either::A(ClientMessage::Nack { .. }) => { self.data.srv.metrics.incr("ua.command.nack").ok(); self.data.webpush.as_mut().unwrap().stats.nacks += 1; - ClientState::WaitingForAcks + ClientState::Await } Either::B(ServerNotification::Notification(notif)) => { let webpush = self.data.webpush.as_mut().unwrap(); @@ -570,6 +613,7 @@ where rotate_message_table: bool, check_storage: bool, connected_at: u64, + services: &Vec, ) -> ClientState { let (tx, rx) = mpsc::unbounded(); let mut flags = ClientFlags::new(); @@ -577,8 +621,10 @@ where flags.reset_uaid = reset_uaid; flags.rotate_message_table = rotate_message_table; + let ServiceClientInit(client_services, broadcasts) = self.srv.broadcast_init(services); self.webpush = Some(WebPushClient { uaid, + broadcast_services: client_services, flags, rx, message_month, @@ -608,6 +654,7 @@ where uaid: uaid.hyphenated().to_string(), status: 200, use_webpush: Some(true), + broadcasts: Service::to_hashmap(broadcasts), }; ClientState::FinishSend(Some(response), Some(Box::new(ClientState::Await))) } diff --git a/autopush_rs/src/protocol.rs b/autopush_rs/src/protocol.rs index 04975f94..55b959e2 100644 --- a/autopush_rs/src/protocol.rs +++ b/autopush_rs/src/protocol.rs @@ -17,7 +17,7 @@ pub enum ServerNotification { } #[derive(Deserialize)] -#[serde(tag = "messageType", rename_all = "lowercase")] +#[serde(tag = "messageType", rename_all = "snake_case")] pub enum ClientMessage { Hello { uaid: Option, @@ -25,6 +25,8 @@ pub enum ClientMessage { channel_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] use_webpush: Option, + #[serde(skip_serializing_if = "Option::is_none")] + broadcasts: Option>, }, Register { @@ -39,6 +41,10 @@ pub enum ClientMessage { code: Option, }, + BroadcastSubscribe { + broadcasts: HashMap, + }, + Ack { updates: Vec }, Nack { @@ -56,13 +62,14 @@ pub struct ClientAck { } #[derive(Serialize)] -#[serde(tag = "messageType", rename_all = "lowercase")] +#[serde(tag = "messageType", rename_all = "snake_case")] pub enum ServerMessage { Hello { uaid: String, status: u32, #[serde(skip_serializing_if = "Option::is_none")] use_webpush: Option, + broadcasts: HashMap, }, Register { @@ -79,6 +86,14 @@ pub enum ServerMessage { status: u32, }, + Broadcast { + broadcasts: HashMap, + }, + + BroadcastSubscribe { + broadcasts: HashMap, + }, + Notification(Notification), } diff --git a/autopush_rs/src/server/mod.rs b/autopush_rs/src/server/mod.rs index 9cac103c..23c10a5c 100644 --- a/autopush_rs/src/server/mod.rs +++ b/autopush_rs/src/server/mod.rs @@ -42,6 +42,7 @@ use server::dispatch::{Dispatch, RequestType}; use server::metrics::metrics_from_opts; use server::webpush_io::WebpushIo; use util::{self, RcObject, timeout}; +use util::megaphone::{ClientServices, Service, ServiceClientInit, ServiceChangeTracker}; mod dispatch; mod metrics; @@ -84,6 +85,7 @@ pub struct AutopushServerOptions { pub struct Server { uaids: RefCell>, + broadcaster: ServiceChangeTracker, open_connections: Cell, tls_acceptor: Option, pub tx: queue::Sender, @@ -317,6 +319,7 @@ impl Server { let core = Core::new()?; let srv = Rc::new(Server { opts: opts.clone(), + broadcaster: ServiceChangeTracker::new(Vec::new()), uaids: RefCell::new(HashMap::new()), open_connections: Cell::new(0), handle: core.handle(), @@ -478,6 +481,26 @@ impl Server { let mut uaids = self.uaids.borrow_mut(); uaids.remove(uaid).expect("uaid not registered"); } + + /// Generate a new service client list for a newly connected client + pub fn broadcast_init(&self, services: &[Service]) -> ServiceClientInit { + debug!("Initialized broadcast services"); + self.broadcaster.service_delta(services) + } + + /// Calculate whether there's new service versions to go out + pub fn broadcast_delta(&self, client_services: &mut ClientServices) -> Option> { + self.broadcaster.change_count_delta(client_services) + } + + /// Add services to be tracked by a client + pub fn client_service_add_service( + &self, + client_services: &mut ClientServices, + services: &[Service], + ) -> Option> { + self.broadcaster.client_service_add_service(client_services, services) + } } impl Drop for Server { @@ -546,10 +569,28 @@ impl Future for PingManager { let mut socket = self.socket.borrow_mut(); loop { if socket.ping { + // Don't check if we already have a delta to broadcast + if socket.broadcast_delta.is_none() { + // Determine if we can do a broadcast check, we need a connected webpush client + if let CloseState::Exchange(ref mut client) = self.client { + if let Some(delta) = client.broadcast_delta() { + socket.broadcast_delta = Some(delta); + } + } + } + if socket.send_ping()?.is_ready() { - let at = Instant::now() + self.srv.opts.auto_ping_timeout; - self.timeout.reset(at); - self.waiting = WaitingFor::Pong; + // If we just sent a broadcast, reset the ping interval and clear the delta + if socket.broadcast_delta.is_some() { + let at = Instant::now() + self.srv.opts.auto_ping_interval; + self.timeout.reset(at); + self.waiting = WaitingFor::SendPing; + socket.broadcast_delta = None; + } else { + let at = Instant::now() + self.srv.opts.auto_ping_timeout; + self.timeout.reset(at); + self.waiting = WaitingFor::Pong; + } } else { break; } @@ -641,6 +682,7 @@ struct WebpushSocket { pong_received: bool, ping: bool, pong_timeout: bool, + broadcast_delta: Option>, } impl WebpushSocket { @@ -650,6 +692,7 @@ impl WebpushSocket { pong_received: false, ping: false, pong_timeout: false, + broadcast_delta: None, } } @@ -659,8 +702,20 @@ impl WebpushSocket { Error: From, { if self.ping { - debug!("sending a ping"); - match self.inner.start_send(Message::Ping(Vec::new()))? { + let mut msg = Message::Ping(Vec::new()); + if let Some(broadcasts) = self.broadcast_delta.clone() { + debug!("sending a broadcast delta"); + let server_msg = ServerMessage::Broadcast { + broadcasts: Service::to_hashmap(broadcasts) + }; + let s = serde_json::to_string(&server_msg).chain_err( + || "failed to serialize", + )?; + msg = Message::Text(s); + } else { + debug!("sending a ping"); + } + match self.inner.start_send(msg)? { AsyncSink::Ready => { debug!("ping sent"); self.ping = false; diff --git a/autopush_rs/src/util/megaphone.rs b/autopush_rs/src/util/megaphone.rs index 7f8eb2f9..feaf8b2c 100644 --- a/autopush_rs/src/util/megaphone.rs +++ b/autopush_rs/src/util/megaphone.rs @@ -60,12 +60,35 @@ struct ServiceRevision { // A provided Service/Version used for `ChangeList` initialization, client comparisons, and // outgoing deltas -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct Service { service_id: String, version: String, } +// Handy From impls for common hashmap to/from conversions +impl From<(String, String)> for Service { + fn from(val: (String, String)) -> Service { + Service { service_id: val.0, version: val.1 } + } +} + +impl From for (String, String) { + fn from(svc: Service) -> (String, String) { + (svc.service_id, svc.version) + } +} + +impl Service { + pub fn from_hashmap(val: HashMap) -> Vec { + val.into_iter().map(|v| v.into()).collect() + } + + pub fn to_hashmap(service_vec: Vec) -> HashMap { + service_vec.into_iter().map(|v| v.into()).collect() + } +} + // ServiceChangeTracker tracks the services, their change_count, and the service lookup registry #[derive(Debug)] pub struct ServiceChangeTracker { @@ -155,6 +178,9 @@ impl ServiceChangeTracker { if svc.change_count <= client_set.change_count { break; } + if !client_set.service_list.contains(&svc.service) { + continue; + } if let Some(ver) = self.service_versions.get(&svc.service) { if let Some(svc_id) = self.service_registry.lookup_id(svc.service) { svc_delta.push(Service { @@ -174,7 +200,7 @@ impl ServiceChangeTracker { /// Returns a delta for `services` that are out of date with the latest version and a new /// `ClientSet``. - pub fn service_delta(&self, services: Vec) -> ServiceClientInit { + pub fn service_delta(&self, services: &[Service]) -> ServiceClientInit { let mut svc_list = Vec::new(); let mut svc_delta = Vec::new(); for svc in services.iter() { @@ -198,6 +224,31 @@ impl ServiceChangeTracker { svc_delta, ) } + + /// Update a `ClientServices` to account for a new service. + /// + /// Returns services that have changed. + pub fn client_service_add_service(&self, client_service: &mut ClientServices, services: &[Service]) -> Option> { + let mut svc_delta = self.change_count_delta(client_service).unwrap_or(Vec::new()); + for svc in services.iter() { + if let Some(svc_key) = self.service_registry.lookup_key(svc.service_id.clone()) { + if let Some(ver) = self.service_versions.get(&svc_key) { + if *ver != svc.version { + svc_delta.push(Service { + service_id: svc.service_id.clone(), + version: (*ver).clone(), + }); + } + } + client_service.service_list.push(svc_key) + } + } + if svc_delta.is_empty() { + None + } else { + Some(svc_delta) + } + } } #[cfg(test)] @@ -216,7 +267,7 @@ mod tests { let services = make_service_base(); let client_services = services.clone(); let mut svc_chg_tracker = ServiceChangeTracker::new(services); - let ServiceClientInit(mut client_svc, delta) = svc_chg_tracker.service_delta(client_services); + let ServiceClientInit(mut client_svc, delta) = svc_chg_tracker.service_delta(&client_services); assert_eq!(delta.len(), 0); assert_eq!(client_svc.change_count, 0); assert_eq!(client_svc.service_list.len(), 2); @@ -228,16 +279,28 @@ mod tests { assert!(delta.is_some()); let delta = delta.unwrap(); assert_eq!(delta.len(), 1); + } + + #[test] + fn test_service_change_handles_new_services() { + let services = make_service_base(); + let client_services = services.clone(); + let mut svc_chg_tracker = ServiceChangeTracker::new(services); + let ServiceClientInit(mut client_svc, _) = svc_chg_tracker.service_delta(&client_services); svc_chg_tracker.add_service( Service { service_id: String::from("svcc"), version: String::from("revmega") } ); let delta = svc_chg_tracker.change_count_delta(&mut client_svc); - assert!(delta.is_some()); - let delta = delta.unwrap(); + assert!(delta.is_none()); + + let delta = svc_chg_tracker.client_service_add_service( + &mut client_svc, + &vec![Service { service_id: String::from("svcc"), version: String::from("revision_alpha") } ], + ).unwrap(); assert_eq!(delta.len(), 1); assert_eq!(delta[0].version, String::from("revmega")); - assert_eq!(client_svc.change_count, 2); - assert_eq!(svc_chg_tracker.service_list.len(), 2); + assert_eq!(client_svc.change_count, 1); + assert_eq!(svc_chg_tracker.service_list.len(), 1); } } diff --git a/autopush_rs/src/util/mod.rs b/autopush_rs/src/util/mod.rs index 6a991d73..9ee19847 100644 --- a/autopush_rs/src/util/mod.rs +++ b/autopush_rs/src/util/mod.rs @@ -16,7 +16,7 @@ use errors::*; mod autojson; mod aws; -mod megaphone; +pub mod megaphone; mod rc; mod send_all; mod user_agent;