Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ws server]: parse subscription ID for unsubscription instead of hardcoding JsonValue::Null #136

Merged
merged 4 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/ws/raw/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::ws::transport::{TransportServerEvent, WsRequestId as RequestId, WsTra
use alloc::{borrow::ToOwned as _, string::String, vec, vec::Vec};
use core::{fmt, hash::Hash, num::NonZeroUsize};
use hashbrown::{hash_map::Entry, HashMap};
use std::convert::TryFrom;

/// Wraps around a "raw server" and adds capabilities.
///
Expand Down Expand Up @@ -399,6 +400,22 @@ impl RawServerSubscriptionId {
}
}

// Try to parse a subscription ID from `Params` where we try both index 0 of an array or `subscription`
// in a `Map`.
impl<'a> TryFrom<Params<'a>> for RawServerSubscriptionId {
type Error = ();

fn try_from(params: Params) -> Result<Self, Self::Error> {
if let Ok(other_id) = params.get(0) {
Self::from_wire_message(&other_id)
} else if let Ok(other_id) = params.get("subscription") {
Self::from_wire_message(&other_id)
} else {
Err(())
}
}
}

impl<'a> ServerSubscription<'a> {
/// Returns the id of the subscription.
///
Expand Down
39 changes: 22 additions & 17 deletions src/ws/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use futures::{channel::mpsc, future::Either, pin_mut, prelude::*};
use parking_lot::Mutex;
use std::{
collections::{HashMap, HashSet},
convert::TryFrom,
error,
net::SocketAddr,
sync::{atomic, Arc},
Expand Down Expand Up @@ -410,8 +411,8 @@ async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedR
}
}
Either::Right(RawServerEvent::Request(request)) => {
log::debug!("[backend]: server received request: {:?}", request);
if let Some(handler) = registered_methods.get_mut(request.method()) {
log::debug!("[backend]: server received request: {:?}", request);
let params: &common::Params = request.params().into();
log::debug!("server called handler");
match handler.send((request.id(), params.clone())).now_or_never() {
Expand All @@ -421,6 +422,7 @@ async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedR
}
}
} else if let Some(sub_unique_id) = subscribe_methods.get(request.method()) {
log::debug!("[backend]: server received subscription: {:?}", request);
if let Ok(sub_id) = request.into_subscription() {
debug_assert!(subscribed_clients.contains_key(&sub_unique_id));
if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) {
Expand All @@ -432,22 +434,25 @@ async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedR
active_subscriptions.insert(sub_id, *sub_unique_id);
}
} else if let Some(sub_unique_id) = unsubscribe_methods.get(request.method()) {
if let Ok(sub_id) = RawServerSubscriptionId::from_wire_message(&JsonValue::Null) {
// FIXME: from request params
debug_assert!(subscribed_clients.contains_key(&sub_unique_id));
if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) {
// TODO: we don't actually check whether the unsubscribe comes from the right
// clients, but since this the ID is randomly-generated, it should be
// fine
if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) {
clients.remove(client_pos);
}

if let Some(s_u_id) = active_subscriptions.remove(&sub_id) {
debug_assert_eq!(s_u_id, *sub_unique_id);
}
}
}
log::debug!("[backend]: server received unsubscription: {:?}", request);
match RawServerSubscriptionId::try_from(request.params()) {
Ok(sub_id) => {
debug_assert!(subscribed_clients.contains_key(&sub_unique_id));
if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) {
// TODO: we don't actually check whether the unsubscribe comes from the right
// clients, but since this the ID is randomly-generated, it should be
// fine
if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) {
clients.remove(client_pos);
}

if let Some(s_u_id) = active_subscriptions.remove(&sub_id) {
debug_assert_eq!(s_u_id, *sub_unique_id);
}
}
}
Err(_) => log::error!("Unsubscription of method=\"{}\" failed; The subscription ID must passed as the first argument of Array or \"subscription\" name of Object, got={:?}", request.method(), request.params()),
Copy link
Contributor

@maciejhirsz maciejhirsz Oct 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentiation is off (tabs vs spaces), otherwise looks fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch, I must have messed it up during merging but still weird that rustfmt didn't fix it.

}
} else {
// TODO: we assert that the request is valid because the parsing succeeded but
// not registered.
Expand Down