Skip to content

Commit

Permalink
Implementation of local_receive_filter
Browse files Browse the repository at this point in the history
FilterChain for local_receive_filter is in place, and working as
expected, with accompanying unit tests.

Found a small issue in Session as well that also got fixed in this PR
that came up in the unit tests.

Work on #1
  • Loading branch information
markmandel committed Jun 1, 2020
1 parent a808ea2 commit 7b4c960
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 94 deletions.
2 changes: 0 additions & 2 deletions src/extensions/filter_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ impl FilterRegistry {
}

/// get returns the filter for a given Key. Returns None if not found.
// TODO: remove when used
#[allow(dead_code)]
pub fn get(&self, key: &String) -> Option<&Arc<dyn Filter>> {
self.registry.get(key)
}
Expand Down
242 changes: 150 additions & 92 deletions src/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use tokio::sync::{Mutex, RwLock};
use tokio::time::{delay_for, Duration, Instant};

use crate::config::{Config, ConnectionConfig};
use crate::extensions::FilterRegistry;
use crate::extensions::{Filter, FilterChain, FilterRegistry};
use crate::server::sessions::{Packet, Session, SESSION_TIMEOUT_SECONDS};

type SessionMap = Arc<RwLock<HashMap<(SocketAddr, SocketAddr), Mutex<Session>>>>;
Expand All @@ -38,8 +38,6 @@ type SessionMap = Arc<RwLock<HashMap<(SocketAddr, SocketAddr), Mutex<Session>>>>
pub struct Server {
log: Logger,
/// registry for the set of available filters
/// TODO: remove this once we have a registry
#[allow(dead_code)]
filter_registry: FilterRegistry,
}

Expand All @@ -62,10 +60,14 @@ impl Server {
// HashMap key is from,destination addresses as a tuple.
let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new()));
let (send_packets, receive_packets) = mpsc::channel::<Packet>(1024);
let chain = Arc::new(FilterChain::from_config(
config.clone(),
&self.filter_registry,
)?);

self.run_receive_packet(send_socket, receive_packets);
self.run_prune_sessions(&sessions);
self.run_recv_from(config, receive_socket, &sessions, send_packets);
self.run_recv_from(config, chain, receive_socket, &sessions, send_packets);
// convert to an IO error
stop.await
.map_err(|err| Error::new(ErrorKind::BrokenPipe, err))
Expand All @@ -92,6 +94,7 @@ impl Server {
fn run_recv_from(
&self,
config: Arc<Config>,
chain: Arc<FilterChain>,
mut receive_socket: RecvHalf,
sessions: &SessionMap,
send_packets: mpsc::Sender<Packet>,
Expand All @@ -103,6 +106,7 @@ impl Server {
if let Err(err) = Server::recv_from(
&log,
config.clone(),
chain.clone(),
&mut receive_socket,
sessions.clone(),
send_packets.clone(),
Expand All @@ -120,6 +124,7 @@ impl Server {
async fn recv_from(
log: &Logger,
config: Arc<Config>,
chain: Arc<FilterChain>,
receive_socket: &mut RecvHalf,
sessions: SessionMap,
send_packets: mpsc::Sender<Packet>,
Expand All @@ -138,44 +143,48 @@ impl Server {
from_utf8(packet).unwrap()
);

for endpoint in endpoints.iter() {
if let Err(err) = Server::ensure_session(
&log,
sessions.clone(),
recv_addr,
endpoint.address,
send_packets.clone(),
)
.await
{
error!(log, "Error ensuring session exists"; "error" => %err);
continue;
}
let result = chain.local_receive_filter(&endpoints, recv_addr, packet.to_vec());

if let Some((endpoints, packet)) = result {
for endpoint in endpoints.iter() {
if let Err(err) = Server::ensure_session(
&log,
sessions.clone(),
recv_addr,
endpoint.address,
send_packets.clone(),
)
.await
{
error!(log, "Error ensuring session exists"; "error" => %err);
continue;
}

let map = sessions.read().await;
let key = (recv_addr, endpoint.address);
match map.get(&key) {
Some(mtx) => {
let mut session = mtx.lock().await;
match session.send_to(packet).await {
Ok(_) => {
session.increment_expiration().await;
}
Err(err) => {
error!(log, "Error sending packet from session"; "error" => %err)
}
};
let map = sessions.read().await;
let key = (recv_addr, endpoint.address);
match map.get(&key) {
Some(mtx) => {
let mut session = mtx.lock().await;
match session.send_to(packet.as_slice()).await {
Ok(_) => {
session.increment_expiration().await;
}
Err(err) => {
error!(log, "Error sending packet from session"; "error" => %err)
}
};
}
None => warn!(
log,
"Could not find session for key: ({}:{})",
key.0.to_string(),
key.1.to_string()
),
}
None => warn!(
log,
"Could not find session for key: ({}:{})",
key.0.to_string(),
key.1.to_string()
),
}
}
});
return Ok(());
Ok(())
}

/// run_receive_packet is a non-blocking loop on receive_packets.recv() channel
Expand Down Expand Up @@ -294,7 +303,7 @@ mod tests {
use crate::config::{Config, ConnectionConfig, EndPoint, Local};
use crate::extensions::default_filters;
use crate::server::sessions::{Packet, SESSION_TIMEOUT_SECONDS};
use crate::test_utils::{ephemeral_socket, logger, recv_udp, recv_udp_done};
use crate::test_utils::{ephemeral_socket, logger, recv_udp, recv_udp_done, TestFilter};

use super::*;

Expand Down Expand Up @@ -400,70 +409,113 @@ mod tests {

#[tokio::test]
async fn recv_from() {
time::pause();

let log = logger();
let msg = "hello";
let (local_addr, wait) = recv_udp().await;
struct Result {
msg: String,
addr: SocketAddr,
}
struct Expected {
session_len: usize,
}

let config = Arc::new(Config {
local: Local { port: 0 },
filters: vec![],
connections: ConnectionConfig::Client {
address: local_addr,
connection_id: String::from(""),
},
});
let receive_socket = ephemeral_socket().await;
let receive_addr = receive_socket.local_addr().unwrap();
let (mut recv, mut send) = receive_socket.split();
let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new()));
let (send_packets, mut recv_packets) = mpsc::channel::<Packet>(1);
async fn test(
name: String,
log: &Logger,
chain: Arc<FilterChain>,
expected: Expected,
) -> Result {
time::pause();
info!(log, "Test"; "name" => name);
let msg = "hello".to_string();
let (local_addr, wait) = recv_udp().await;

let config = Arc::new(Config {
local: Local { port: 0 },
filters: vec![],
connections: ConnectionConfig::Client {
address: local_addr,
connection_id: String::from(""),
},
});
let receive_socket = ephemeral_socket().await;
let receive_addr = receive_socket.local_addr().unwrap();
let (mut recv, mut send) = receive_socket.split();
let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new()));
let (send_packets, mut recv_packets) = mpsc::channel::<Packet>(1);

let sessions_clone = sessions.clone();
let log_clone = log.clone();

let time_increment = 10;
time::advance(Duration::from_secs(time_increment)).await;

tokio::spawn(async move {
Server::recv_from(
&log_clone,
config,
chain,
&mut recv,
sessions_clone,
send_packets.clone(),
)
.await
});

let sessions_clone = sessions.clone();
let log_clone = log.clone();
send.send_to(msg.as_bytes(), &receive_addr).await.unwrap();

let time_increment = 10;
time::advance(Duration::from_secs(time_increment)).await;
let result = wait.await.unwrap();
recv_packets.close();

tokio::spawn(async move {
Server::recv_from(
&log_clone,
config,
&mut recv,
sessions_clone,
send_packets.clone(),
)
.await
});
let map = sessions.read().await;
assert_eq!(expected.session_len, map.len());

// need to switch to 127.0.0.1, as the request comes locally
let mut receive_addr_local = receive_addr.clone();
receive_addr_local.set_ip("127.0.0.1".parse().unwrap());
let build_key = (receive_addr_local, local_addr);
assert!(map.contains_key(&build_key));
let session = map.get(&build_key).unwrap().lock().await;
assert_eq!(
SESSION_TIMEOUT_SECONDS,
session
.expiration()
.await
.duration_since(Instant::now())
.as_secs(),
);

send.send_to("hello".as_bytes(), &receive_addr)
.await
.unwrap();
time::resume();

assert_eq!(msg, wait.await.unwrap());
recv_packets.close();
Result {
msg: result,
addr: receive_addr_local,
}
}

let map = sessions.read().await;
assert_eq!(1, map.len());
let log = logger();

// need to switch to 127.0.0.1, as the request comes locally
let mut receive_addr_local = receive_addr.clone();
receive_addr_local.set_ip("127.0.0.1".parse().unwrap());
let build_key = (receive_addr_local, local_addr);
assert!(map.contains_key(&build_key));
let chain = Arc::new(FilterChain::new(vec![]));
let result = test(
"no filter".to_string(),
&log,
chain,
Expected { session_len: 1 },
)
.await;
assert_eq!("hello", result.msg);

let chain = Arc::new(FilterChain::new(vec![Arc::new(TestFilter {})]));
let result = test(
"test filter".to_string(),
&log,
chain,
Expected { session_len: 2 },
)
.await;

let session = map.get(&build_key).unwrap().lock().await;
assert_eq!(
SESSION_TIMEOUT_SECONDS,
session
.expiration()
.await
.duration_since(Instant::now())
.as_secs()
format!("hello:lrf:127.0.0.1:{}", result.addr.port()),
result.msg
);

time::resume();
}

#[tokio::test]
Expand All @@ -486,7 +538,13 @@ mod tests {
let sessions: SessionMap = Arc::new(RwLock::new(HashMap::new()));
let (send_packets, mut recv_packets) = mpsc::channel::<Packet>(1);

server.run_recv_from(config, recv, &sessions, send_packets);
server.run_recv_from(
config,
Arc::new(FilterChain::new(vec![])),
recv,
&sessions,
send_packets,
);

send.send_to(msg.as_bytes(), &addr).await.unwrap();
assert_eq!(msg, wait.await.unwrap());
Expand Down
4 changes: 4 additions & 0 deletions src/server/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ impl Session {
is_closed.store(true, Relaxed);
debug!(log, "Closing Session");
return;
} else if let None = close_request {
is_closed.store(true, Relaxed);
debug!(log, "Dropping Session");
return;
}
}
};
Expand Down

0 comments on commit 7b4c960

Please sign in to comment.