Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return broadcast errors for invalid broadcast id's #63

Merged
merged 1 commit into from
Aug 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use db::{CheckStorageResponse, HelloResponse, RegisterResponse};
use errors::*;
use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification};
use server::Server;
use util::megaphone::{Broadcast, BroadcastSubs, BroadcastSubsInit};
use util::megaphone::{Broadcast, BroadcastSubs};
use util::{ms_since_epoch, parse_user_agent, sec_since_epoch};

// Created and handed to the AutopushServer
Expand Down Expand Up @@ -406,7 +406,7 @@ where
flags.check = check_storage;
flags.reset_uaid = reset_uaid;
flags.rotate_message_table = rotate_message_table;
let BroadcastSubsInit(initialized_subs, broadcasts) =
let (initialized_subs, broadcasts) =
srv.broadcast_init(&desired_broadcasts);
broadcast_subs.replace(initialized_subs);
let uid = Uuid::new_v4();
Expand All @@ -432,7 +432,7 @@ where
uaid: uaid.simple().to_string(),
status: 200,
use_webpush: Some(true),
broadcasts: Broadcast::into_hashmap(broadcasts),
broadcasts,
};
let auth_state_machine = AuthClientState::start(
vec![response],
Expand Down Expand Up @@ -771,15 +771,16 @@ where
Either::A(ClientMessage::BroadcastSubscribe { broadcasts }) => {
let broadcast_delta = {
let mut broadcast_subs = data.broadcast_subs.borrow_mut();
data.srv.subscribe_to_broadcasts(
data.srv.process_broadcasts(
&mut broadcast_subs,
&Broadcast::from_hashmap(broadcasts),
)
};
if let Some(delta) = broadcast_delta {

if let Some(response) = broadcast_delta {
transition!(Send {
smessages: vec![ServerMessage::Broadcast {
broadcasts: Broadcast::into_hashmap(delta),
broadcasts: response,
}],
data,
});
Expand Down
11 changes: 9 additions & 2 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ pub struct ClientAck {
pub version: String,
}

#[derive(Serialize)]
#[serde(untagged)]
pub enum BroadcastValue {
Value(String),
Nested(HashMap<String, BroadcastValue>),
}

#[derive(Serialize)]
#[serde(tag = "messageType", rename_all = "snake_case")]
pub enum ServerMessage {
Expand All @@ -94,7 +101,7 @@ pub enum ServerMessage {
status: u32,
#[serde(skip_serializing_if = "Option::is_none")]
use_webpush: Option<bool>,
broadcasts: HashMap<String, String>,
broadcasts: HashMap<String, BroadcastValue>,
},

Register {
Expand All @@ -112,7 +119,7 @@ pub enum ServerMessage {
},

Broadcast {
broadcasts: HashMap<String, String>,
broadcasts: HashMap<String, BroadcastValue>,
},

Notification(Notification),
Expand Down
59 changes: 44 additions & 15 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use errors::*;
use errors::{Error, Result};
use http;
use logging;
use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification};
use protocol::{BroadcastValue, ClientMessage, Notification, ServerMessage, ServerNotification};
use server::dispatch::{Dispatch, RequestType};
use server::metrics::metrics_from_opts;
use server::webpush_io::WebpushIo;
Expand Down Expand Up @@ -368,7 +368,9 @@ impl Server {
let client = request.and_then(move |(socket, request)| -> MyFuture<_> {
match request {
RequestType::Status => write_status(socket),
RequestType::LBHeartBeat => write_json(socket, StatusCode::Ok, serde_json::Value::from("")),
RequestType::LBHeartBeat => {
write_json(socket, StatusCode::Ok, serde_json::Value::from(""))
}
RequestType::Version => write_version_file(socket),
RequestType::LogCheck => write_log_check(socket),
RequestType::Websocket => {
Expand Down Expand Up @@ -521,27 +523,53 @@ impl Server {
}

/// Initialize broadcasts for a newly connected client
pub fn broadcast_init(&self, desired_broadcasts: &[Broadcast]) -> BroadcastSubsInit {
pub fn broadcast_init(
&self,
desired_broadcasts: &[Broadcast],
) -> (BroadcastSubs, HashMap<String, BroadcastValue>) {
debug!("Initialized broadcasts");
self.broadcaster
.borrow()
.broadcast_delta(desired_broadcasts)
let bc = self.broadcaster.borrow();
let BroadcastSubsInit(broadcast_subs, broadcasts) = bc.broadcast_delta(desired_broadcasts);
let mut response = Broadcast::into_hashmap(broadcasts);
let missing = bc.missing_broadcasts(desired_broadcasts);
if !missing.is_empty() {
response.insert(
"errors".to_string(),
BroadcastValue::Nested(Broadcast::into_hashmap(missing)),
);
}
(broadcast_subs, response)
}

/// Calculate whether there's new broadcast versions to go out
pub fn broadcast_delta(&self, broadcast_subs: &mut BroadcastSubs) -> Option<Vec<Broadcast>> {
self.broadcaster.borrow().change_count_delta(broadcast_subs)
}

/// Add new broadcasts to be tracked by a client
pub fn subscribe_to_broadcasts(
/// Process a broadcast list, adding new broadcasts to be tracked and locating missing ones
/// Returns an appropriate response for use by the prototocol
pub fn process_broadcasts(
&self,
broadcast_subs: &mut BroadcastSubs,
broadcasts: &[Broadcast],
) -> Option<Vec<Broadcast>> {
self.broadcaster
.borrow()
.subscribe_to_broadcasts(broadcast_subs, broadcasts)
) -> Option<HashMap<String, BroadcastValue>> {
let bc = self.broadcaster.borrow();
let mut response: HashMap<String, BroadcastValue> = HashMap::new();
let missing = bc.missing_broadcasts(broadcasts);
if !missing.is_empty() {
response.insert(
"errors".to_string(),
BroadcastValue::Nested(Broadcast::into_hashmap(missing)),
);
}
if let Some(delta) = bc.subscribe_to_broadcasts(broadcast_subs, broadcasts) {
response.extend(Broadcast::into_hashmap(delta));
};
if response.is_empty() {
None
} else {
Some(response)
}
}
}

Expand Down Expand Up @@ -937,9 +965,10 @@ fn write_status(socket: WebpushIo) -> MyFuture<()> {
/// Return a static copy of `version.json` from compile time.
pub fn write_version_file(socket: WebpushIo) -> MyFuture<()> {
write_json(
socket,
StatusCode::Ok,
serde_json::Value::from(include_str!("../../version.json")))
socket,
StatusCode::Ok,
serde_json::Value::from(include_str!("../../version.json")),
)
}

fn write_log_check(socket: WebpushIo) -> MyFuture<()> {
Expand Down
73 changes: 54 additions & 19 deletions src/util/megaphone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use errors::Result;
use std::collections::HashMap;
use std::time::Duration;

use protocol::BroadcastValue;

use reqwest;

// A Broadcast entry Key in a BroadcastRegistry
Expand Down Expand Up @@ -70,6 +72,16 @@ pub struct Broadcast {
version: String,
}

impl Broadcast {
/// Errors out a broadcast for broadcasts that weren't found
pub fn error(self) -> Broadcast {
Broadcast {
broadcast_id: self.broadcast_id,
version: "Broadcast not found".to_string(),
}
}
}

// Handy From impls for common hashmap to/from conversions
impl From<(String, String)> for Broadcast {
fn from(val: (String, String)) -> Broadcast {
Expand All @@ -80,9 +92,9 @@ impl From<(String, String)> for Broadcast {
}
}

impl From<Broadcast> for (String, String) {
fn from(bcast: Broadcast) -> (String, String) {
(bcast.broadcast_id, bcast.version)
impl From<Broadcast> for (String, BroadcastValue) {
fn from(bcast: Broadcast) -> (String, BroadcastValue) {
(bcast.broadcast_id, BroadcastValue::Value(bcast.version))
}
}

Expand All @@ -91,7 +103,7 @@ impl Broadcast {
val.into_iter().map(|v| v.into()).collect()
}

pub fn into_hashmap(broadcasts: Vec<Broadcast>) -> HashMap<String, String> {
pub fn into_hashmap(broadcasts: Vec<Broadcast>) -> HashMap<String, BroadcastValue> {
broadcasts.into_iter().map(|v| v.into()).collect()
}
}
Expand Down Expand Up @@ -121,9 +133,7 @@ impl BroadcastChangeTracker {
change_count: 0,
};
for srv in broadcasts {
let key = tracker
.broadcast_registry
.add_broadcast(srv.broadcast_id);
let key = tracker.broadcast_registry.add_broadcast(srv.broadcast_id);
tracker.broadcast_versions.insert(key, srv.version);
}
tracker
Expand Down Expand Up @@ -154,7 +164,9 @@ impl BroadcastChangeTracker {
return change_count;
}
self.change_count += 1;
let key = self.broadcast_registry.add_broadcast(broadcast.broadcast_id);
let key = self
.broadcast_registry
.add_broadcast(broadcast.broadcast_id);
self.broadcast_versions.insert(key, broadcast.version);
self.broadcast_list.push(BroadcastRevision {
change_count: self.change_count,
Expand All @@ -167,7 +179,8 @@ impl BroadcastChangeTracker {
///
/// Returns an error if the `broadcast` was never initialized/added.
pub fn update_broadcast(&mut self, broadcast: Broadcast) -> Result<u32> {
let key = self.broadcast_registry
let key = self
.broadcast_registry
.lookup_key(&broadcast.broadcast_id)
.ok_or("Broadcast not found")?;

Expand All @@ -181,10 +194,17 @@ impl BroadcastChangeTracker {
}

// Check to see if this broadcast has been updated since initialization
let bcast_index = self.broadcast_list
let bcast_index = self
.broadcast_list
.iter()
.enumerate()
.filter_map(|(i, bcast)| if bcast.broadcast == key { Some(i) } else { None })
.filter_map(|(i, bcast)| {
if bcast.broadcast == key {
Some(i)
} else {
None
}
})
.nth(0);
self.change_count += 1;
if let Some(bcast_index) = bcast_index {
Expand Down Expand Up @@ -265,8 +285,7 @@ impl BroadcastChangeTracker {
broadcast_subs: &mut BroadcastSubs,
broadcasts: &[Broadcast],
) -> Option<Vec<Broadcast>> {
let mut bcast_delta = self.change_count_delta(broadcast_subs)
.unwrap_or_default();
let mut bcast_delta = self.change_count_delta(broadcast_subs).unwrap_or_default();
for bcast in broadcasts.iter() {
if let Some(bcast_key) = self.broadcast_registry.lookup_key(&bcast.broadcast_id) {
if let Some(ver) = self.broadcast_versions.get(&bcast_key) {
Expand All @@ -286,6 +305,24 @@ impl BroadcastChangeTracker {
Some(bcast_delta)
}
}

/// Check a broadcast list and return unknown broadcast id's with their appropriate error
pub fn missing_broadcasts(&self, broadcasts: &[Broadcast]) -> Vec<Broadcast> {
broadcasts
.iter()
.filter_map(|b| {
if self
.broadcast_registry
.lookup_key(&b.broadcast_id)
.is_none()
{
Some(b.clone().error())
} else {
None
}
})
.collect()
}
}

#[cfg(test)]
Expand Down Expand Up @@ -345,12 +382,10 @@ mod tests {
let delta = tracker
.subscribe_to_broadcasts(
&mut broadcast_subs,
&vec![
Broadcast {
broadcast_id: String::from("bcastc"),
version: String::from("revision_alpha"),
},
],
&vec![Broadcast {
broadcast_id: String::from("bcastc"),
version: String::from("revision_alpha"),
}],
)
.unwrap();
assert_eq!(delta.len(), 1);
Expand Down
39 changes: 39 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,23 @@ def test_broadcast_update_on_connect(self):

yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_update_on_connect_with_errors(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
self.mock_megaphone.polled.clear()
self.mock_megaphone.polled.wait(timeout=5)

old_ver = {"kinto:123": "ver0", "kinto:456": "ver1"}
client = Client(self._ws_url)
yield client.connect()
result = yield client.hello(services=old_ver)
assert result != {}
assert result["use_webpush"] is True
assert result["broadcasts"]["kinto:123"] == "ver1"
assert result["broadcasts"]["errors"][
"kinto:456"] == "Broadcast not found"
yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_subscribe(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
Expand Down Expand Up @@ -1053,6 +1070,28 @@ def test_broadcast_subscribe(self):

yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_subscribe_with_errors(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
self.mock_megaphone.polled.clear()
self.mock_megaphone.polled.wait(timeout=5)

old_ver = {"kinto:123": "ver0", "kinto:456": "ver1"}
client = Client(self._ws_url)
yield client.connect()
result = yield client.hello()
assert result != {}
assert result["use_webpush"] is True
assert result["broadcasts"] == {}

client.broadcast_subscribe(old_ver)
result = yield client.get_broadcast()
assert result["broadcasts"]["kinto:123"] == "ver1"
assert result["broadcasts"]["errors"][
"kinto:456"] == "Broadcast not found"

yield self.shut_down(client)

@inlineCallbacks
def test_broadcast_no_changes(self):
self.mock_megaphone.services = {"kinto:123": "ver1"}
Expand Down