Skip to content

Commit

Permalink
feat: return broadcast errors for invalid broadcast id's
Browse files Browse the repository at this point in the history
Closes #59
  • Loading branch information
bbangert committed Aug 29, 2018
1 parent 22bc4dd commit ee7cb91
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 42 deletions.
13 changes: 7 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use db::{CheckStorageResponse, HelloResponse, RegisterResponse};
use errors::*;
use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification};
use server::Server;
use util::megaphone::{Broadcast, BroadcastSubs, BroadcastSubsInit};
use util::megaphone::{Broadcast, BroadcastSubs};
use util::{ms_since_epoch, parse_user_agent, sec_since_epoch};

// Created and handed to the AutopushServer
Expand Down Expand Up @@ -406,7 +406,7 @@ where
flags.check = check_storage;
flags.reset_uaid = reset_uaid;
flags.rotate_message_table = rotate_message_table;
let BroadcastSubsInit(initialized_subs, broadcasts) =
let (initialized_subs, broadcasts) =
srv.broadcast_init(&desired_broadcasts);
broadcast_subs.replace(initialized_subs);
let uid = Uuid::new_v4();
Expand All @@ -432,7 +432,7 @@ where
uaid: uaid.simple().to_string(),
status: 200,
use_webpush: Some(true),
broadcasts: Broadcast::into_hashmap(broadcasts),
broadcasts,
};
let auth_state_machine = AuthClientState::start(
vec![response],
Expand Down Expand Up @@ -771,15 +771,16 @@ where
Either::A(ClientMessage::BroadcastSubscribe { broadcasts }) => {
let broadcast_delta = {
let mut broadcast_subs = data.broadcast_subs.borrow_mut();
data.srv.subscribe_to_broadcasts(
data.srv.process_broadcasts(
&mut broadcast_subs,
&Broadcast::from_hashmap(broadcasts),
)
};
if let Some(delta) = broadcast_delta {

if let Some(response) = broadcast_delta {
transition!(Send {
smessages: vec![ServerMessage::Broadcast {
broadcasts: Broadcast::into_hashmap(delta),
broadcasts: response,
}],
data,
});
Expand Down
11 changes: 9 additions & 2 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ pub struct ClientAck {
pub version: String,
}

#[derive(Serialize)]
#[serde(untagged)]
pub enum BroadcastValue {
Value(String),
Nested(HashMap<String, BroadcastValue>),
}

#[derive(Serialize)]
#[serde(tag = "messageType", rename_all = "snake_case")]
pub enum ServerMessage {
Expand All @@ -94,7 +101,7 @@ pub enum ServerMessage {
status: u32,
#[serde(skip_serializing_if = "Option::is_none")]
use_webpush: Option<bool>,
broadcasts: HashMap<String, String>,
broadcasts: HashMap<String, BroadcastValue>,
},

Register {
Expand All @@ -112,7 +119,7 @@ pub enum ServerMessage {
},

Broadcast {
broadcasts: HashMap<String, String>,
broadcasts: HashMap<String, BroadcastValue>,
},

Notification(Notification),
Expand Down
59 changes: 44 additions & 15 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use errors::*;
use errors::{Error, Result};
use http;
use logging;
use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification};
use protocol::{BroadcastValue, ClientMessage, Notification, ServerMessage, ServerNotification};
use server::dispatch::{Dispatch, RequestType};
use server::metrics::metrics_from_opts;
use server::webpush_io::WebpushIo;
Expand Down Expand Up @@ -368,7 +368,9 @@ impl Server {
let client = request.and_then(move |(socket, request)| -> MyFuture<_> {
match request {
RequestType::Status => write_status(socket),
RequestType::LBHeartBeat => write_json(socket, StatusCode::Ok, serde_json::Value::from("")),
RequestType::LBHeartBeat => {
write_json(socket, StatusCode::Ok, serde_json::Value::from(""))
}
RequestType::Version => write_version_file(socket),
RequestType::LogCheck => write_log_check(socket),
RequestType::Websocket => {
Expand Down Expand Up @@ -521,27 +523,53 @@ impl Server {
}

/// Initialize broadcasts for a newly connected client
pub fn broadcast_init(&self, desired_broadcasts: &[Broadcast]) -> BroadcastSubsInit {
pub fn broadcast_init(
&self,
desired_broadcasts: &[Broadcast],
) -> (BroadcastSubs, HashMap<String, BroadcastValue>) {
debug!("Initialized broadcasts");
self.broadcaster
.borrow()
.broadcast_delta(desired_broadcasts)
let bc = self.broadcaster.borrow();
let BroadcastSubsInit(broadcast_subs, broadcasts) = bc.broadcast_delta(desired_broadcasts);
let mut response = Broadcast::into_hashmap(broadcasts);
let missing = bc.missing_broadcasts(desired_broadcasts);
if !missing.is_empty() {
response.insert(
"errors".to_string(),
BroadcastValue::Nested(Broadcast::into_hashmap(missing)),
);
}
(broadcast_subs, response)
}

/// Calculate whether there's new broadcast versions to go out
pub fn broadcast_delta(&self, broadcast_subs: &mut BroadcastSubs) -> Option<Vec<Broadcast>> {
self.broadcaster.borrow().change_count_delta(broadcast_subs)
}

/// Add new broadcasts to be tracked by a client
pub fn subscribe_to_broadcasts(
/// Process a broadcast list, adding new broadcasts to be tracked and locating missing ones
/// Returns an appropriate response for use by the prototocol
pub fn process_broadcasts(
&self,
broadcast_subs: &mut BroadcastSubs,
broadcasts: &[Broadcast],
) -> Option<Vec<Broadcast>> {
self.broadcaster
.borrow()
.subscribe_to_broadcasts(broadcast_subs, broadcasts)
) -> Option<HashMap<String, BroadcastValue>> {
let bc = self.broadcaster.borrow();
let mut response: HashMap<String, BroadcastValue> = HashMap::new();
let missing = bc.missing_broadcasts(broadcasts);
if !missing.is_empty() {
response.insert(
"errors".to_string(),
BroadcastValue::Nested(Broadcast::into_hashmap(missing)),
);
}
if let Some(delta) = bc.subscribe_to_broadcasts(broadcast_subs, broadcasts) {
response.extend(Broadcast::into_hashmap(delta));
};
if response.is_empty() {
None
} else {
Some(response)
}
}
}

Expand Down Expand Up @@ -937,9 +965,10 @@ fn write_status(socket: WebpushIo) -> MyFuture<()> {
/// Return a static copy of `version.json` from compile time.
pub fn write_version_file(socket: WebpushIo) -> MyFuture<()> {
write_json(
socket,
StatusCode::Ok,
serde_json::Value::from(include_str!("../../version.json")))
socket,
StatusCode::Ok,
serde_json::Value::from(include_str!("../../version.json")),
)
}

fn write_log_check(socket: WebpushIo) -> MyFuture<()> {
Expand Down
73 changes: 54 additions & 19 deletions src/util/megaphone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use errors::Result;
use std::collections::HashMap;
use std::time::Duration;

use protocol::BroadcastValue;

use reqwest;

// A Broadcast entry Key in a BroadcastRegistry
Expand Down Expand Up @@ -70,6 +72,16 @@ pub struct Broadcast {
version: String,
}

impl Broadcast {
/// Errors out a broadcast for broadcasts that weren't found
pub fn error(self) -> Broadcast {
Broadcast {
broadcast_id: self.broadcast_id,
version: "Broadcast not found".to_string(),
}
}
}

// Handy From impls for common hashmap to/from conversions
impl From<(String, String)> for Broadcast {
fn from(val: (String, String)) -> Broadcast {
Expand All @@ -80,9 +92,9 @@ impl From<(String, String)> for Broadcast {
}
}

impl From<Broadcast> for (String, String) {
fn from(bcast: Broadcast) -> (String, String) {
(bcast.broadcast_id, bcast.version)
impl From<Broadcast> for (String, BroadcastValue) {
fn from(bcast: Broadcast) -> (String, BroadcastValue) {
(bcast.broadcast_id, BroadcastValue::Value(bcast.version))
}
}

Expand All @@ -91,7 +103,7 @@ impl Broadcast {
val.into_iter().map(|v| v.into()).collect()
}

pub fn into_hashmap(broadcasts: Vec<Broadcast>) -> HashMap<String, String> {
pub fn into_hashmap(broadcasts: Vec<Broadcast>) -> HashMap<String, BroadcastValue> {
broadcasts.into_iter().map(|v| v.into()).collect()
}
}
Expand Down Expand Up @@ -121,9 +133,7 @@ impl BroadcastChangeTracker {
change_count: 0,
};
for srv in broadcasts {
let key = tracker
.broadcast_registry
.add_broadcast(srv.broadcast_id);
let key = tracker.broadcast_registry.add_broadcast(srv.broadcast_id);
tracker.broadcast_versions.insert(key, srv.version);
}
tracker
Expand Down Expand Up @@ -154,7 +164,9 @@ impl BroadcastChangeTracker {
return change_count;
}
self.change_count += 1;
let key = self.broadcast_registry.add_broadcast(broadcast.broadcast_id);
let key = self
.broadcast_registry
.add_broadcast(broadcast.broadcast_id);
self.broadcast_versions.insert(key, broadcast.version);
self.broadcast_list.push(BroadcastRevision {
change_count: self.change_count,
Expand All @@ -167,7 +179,8 @@ impl BroadcastChangeTracker {
///
/// Returns an error if the `broadcast` was never initialized/added.
pub fn update_broadcast(&mut self, broadcast: Broadcast) -> Result<u32> {
let key = self.broadcast_registry
let key = self
.broadcast_registry
.lookup_key(&broadcast.broadcast_id)
.ok_or("Broadcast not found")?;

Expand All @@ -181,10 +194,17 @@ impl BroadcastChangeTracker {
}

// Check to see if this broadcast has been updated since initialization
let bcast_index = self.broadcast_list
let bcast_index = self
.broadcast_list
.iter()
.enumerate()
.filter_map(|(i, bcast)| if bcast.broadcast == key { Some(i) } else { None })
.filter_map(|(i, bcast)| {
if bcast.broadcast == key {
Some(i)
} else {
None
}
})
.nth(0);
self.change_count += 1;
if let Some(bcast_index) = bcast_index {
Expand Down Expand Up @@ -265,8 +285,7 @@ impl BroadcastChangeTracker {
broadcast_subs: &mut BroadcastSubs,
broadcasts: &[Broadcast],
) -> Option<Vec<Broadcast>> {
let mut bcast_delta = self.change_count_delta(broadcast_subs)
.unwrap_or_default();
let mut bcast_delta = self.change_count_delta(broadcast_subs).unwrap_or_default();
for bcast in broadcasts.iter() {
if let Some(bcast_key) = self.broadcast_registry.lookup_key(&bcast.broadcast_id) {
if let Some(ver) = self.broadcast_versions.get(&bcast_key) {
Expand All @@ -286,6 +305,24 @@ impl BroadcastChangeTracker {
Some(bcast_delta)
}
}

/// Check a broadcast list and return unknown broadcast id's with their appropriate error
pub fn missing_broadcasts(&self, broadcasts: &[Broadcast]) -> Vec<Broadcast> {
broadcasts
.iter()
.filter_map(|b| {
if self
.broadcast_registry
.lookup_key(&b.broadcast_id)
.is_none()
{
Some(b.clone().error())
} else {
None
}
})
.collect()
}
}

#[cfg(test)]
Expand Down Expand Up @@ -345,12 +382,10 @@ mod tests {
let delta = tracker
.subscribe_to_broadcasts(
&mut broadcast_subs,
&vec![
Broadcast {
broadcast_id: String::from("bcastc"),
version: String::from("revision_alpha"),
},
],
&vec![Broadcast {
broadcast_id: String::from("bcastc"),
version: String::from("revision_alpha"),
}],
)
.unwrap();
assert_eq!(delta.len(), 1);
Expand Down
39 changes: 39 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,23 @@ def test_broadcast_update_on_connect(self):

yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_update_on_connect_with_errors(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
self.mock_megaphone.polled.clear()
self.mock_megaphone.polled.wait(timeout=5)

old_ver = {"kinto:123": "ver0", "kinto:456": "ver1"}
client = Client(self._ws_url)
yield client.connect()
result = yield client.hello(services=old_ver)
assert result != {}
assert result["use_webpush"] is True
assert result["broadcasts"]["kinto:123"] == "ver1"
assert result["broadcasts"]["errors"][
"kinto:456"] == "Broadcast not found"
yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_subscribe(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
Expand Down Expand Up @@ -1053,6 +1070,28 @@ def test_broadcast_subscribe(self):

yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_subscribe_with_errors(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
self.mock_megaphone.polled.clear()
self.mock_megaphone.polled.wait(timeout=5)

old_ver = {"kinto:123": "ver0", "kinto:456": "ver1"}
client = Client(self._ws_url)
yield client.connect()
result = yield client.hello()
assert result != {}
assert result["use_webpush"] is True
assert result["broadcasts"] == {}

client.broadcast_subscribe(old_ver)
result = yield client.get_broadcast()
assert result["broadcasts"]["kinto:123"] == "ver1"
assert result["broadcasts"]["errors"][
"kinto:456"] == "Broadcast not found"

yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_no_changes(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
Expand Down

0 comments on commit ee7cb91

Please sign in to comment.