From 52e5009963891c75ef12763eb48d3e55f93115e5 Mon Sep 17 00:00:00 2001 From: Laszlo Nagy Date: Sat, 12 Oct 2024 13:22:09 +1100 Subject: [PATCH] rust: intercept module tested to work to transfer events --- rust/intercept/src/collector.rs | 46 +++++++------- rust/intercept/src/lib.rs | 60 ++---------------- rust/intercept/src/reporter.rs | 12 ++-- rust/intercept/tests/test.rs | 107 ++++++++++++++++++++++++++++++++ 4 files changed, 137 insertions(+), 88 deletions(-) create mode 100644 rust/intercept/tests/test.rs diff --git a/rust/intercept/src/collector.rs b/rust/intercept/src/collector.rs index 8a9addad..db2806b6 100644 --- a/rust/intercept/src/collector.rs +++ b/rust/intercept/src/collector.rs @@ -17,11 +17,12 @@ along with this program. If not, see . */ -use std::net::{TcpListener, TcpStream}; +use std::net::{SocketAddr, TcpListener, TcpStream}; -use crossbeam::channel::{Receiver, Sender}; -use crossbeam_channel::bounded; +use crossbeam::channel::Sender; use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use super::Envelope; @@ -29,26 +30,27 @@ use super::Envelope; pub struct SessionLocator(pub String); pub trait EventCollector { - fn address(&self) -> Result; + fn address(&self) -> SessionLocator; fn collect(&self, destination: Sender) -> Result<(), anyhow::Error>; fn stop(&self) -> Result<(), anyhow::Error>; } pub struct EventCollectorOnTcp { - control_input: Sender, - control_output: Receiver, + shutdown: Arc, listener: TcpListener, + address: SocketAddr, } impl EventCollectorOnTcp { pub fn new() -> Result { - let (control_input, control_output) = bounded(0); + let shutdown = Arc::new(AtomicBool::new(false)); let listener = TcpListener::bind("127.0.0.1:0")?; + let address = listener.local_addr()?; let result = EventCollectorOnTcp { - control_input, - control_output, + shutdown, listener, + address, }; Ok(result) @@ -67,25 +69,20 @@ impl EventCollectorOnTcp { } impl EventCollector for EventCollectorOnTcp { - fn address(&self) -> Result { - let local_addr = self.listener.local_addr()?; - let locator = SessionLocator(local_addr.to_string()); - Ok(locator) + fn address(&self) -> SessionLocator { + SessionLocator(self.address.to_string()) } fn collect(&self, destination: Sender) -> Result<(), anyhow::Error> { - loop { - if let Ok(shutdown) = self.control_output.try_recv() { - if shutdown { - break; - } + for stream in self.listener.incoming() { + if self.shutdown.load(Ordering::Relaxed) { + break; } - match self.listener.accept() { - Ok((stream, _)) => { - println!("Got a connection"); + match stream { + Ok(connection) => { // ... (process the connection in a separate thread or task) - self.send(stream, destination.clone())?; + self.send(connection, destination.clone())?; } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // No new connection available, continue checking for shutdown @@ -97,13 +94,12 @@ impl EventCollector for EventCollectorOnTcp { } } } - - println!("Server shutting down"); Ok(()) } fn stop(&self) -> Result<(), anyhow::Error> { - self.control_input.send(true)?; + self.shutdown.store(true, Ordering::Relaxed); + let _ = TcpStream::connect(self.address)?; Ok(()) } } diff --git a/rust/intercept/src/lib.rs b/rust/intercept/src/lib.rs index 4d673db6..e3317ff3 100644 --- a/rust/intercept/src/lib.rs +++ b/rust/intercept/src/lib.rs @@ -34,10 +34,10 @@ pub mod reporter; #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub struct ReporterId(pub u64); -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub struct ProcessId(pub u32); -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub struct Execution { pub executable: PathBuf, pub arguments: Vec, @@ -51,7 +51,7 @@ pub struct Execution { // terminate), but can be extended later with performance related // events like monitoring the CPU usage or the memory allocation if // this information is available. -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub enum Event { Started { pid: ProcessId, @@ -66,7 +66,7 @@ pub enum Event { }, } -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub struct Envelope { pub rid: ReporterId, pub timestamp: u64, @@ -106,55 +106,3 @@ impl Envelope { Ok(length) } } - -#[cfg(test)] -mod test { - use super::*; - use lazy_static::lazy_static; - use std::io::Cursor; - - #[test] - fn read_write_works() { - let mut writer = Cursor::new(vec![0; 1024]); - for envelope in ENVELOPES.iter() { - let result = Envelope::write_into(envelope, &mut writer); - assert!(result.is_ok()); - } - - let mut reader = Cursor::new(writer.get_ref()); - for envelope in ENVELOPES.iter() { - let result = Envelope::read_from(&mut reader); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), *envelope.clone()); - } - } - - lazy_static! { - static ref ENVELOPES: Vec = vec![ - Envelope { - rid: ReporterId(1), - timestamp: 0, - event: Event::Started { - pid: ProcessId(1), - ppid: ProcessId(0), - execution: Execution { - executable: PathBuf::from("/usr/bin/ls"), - arguments: vec!["-l".to_string()], - working_dir: PathBuf::from("/tmp"), - environment: HashMap::new(), - }, - }, - }, - Envelope { - rid: ReporterId(1), - timestamp: 0, - event: Event::Terminated { status: 0 }, - }, - Envelope { - rid: ReporterId(1), - timestamp: 0, - event: Event::Signaled { signal: 15 }, - }, - ]; - } -} diff --git a/rust/intercept/src/reporter.rs b/rust/intercept/src/reporter.rs index b738c452..b6e4366c 100644 --- a/rust/intercept/src/reporter.rs +++ b/rust/intercept/src/reporter.rs @@ -36,21 +36,18 @@ impl ReporterId { // supervisor processes). The events are collected in a common place // in order to reconstruct of final report of a build process. pub trait Reporter { - fn report(&mut self, event: Event) -> Result<(), anyhow::Error>; + fn report(&self, event: Event) -> Result<(), anyhow::Error>; } -struct TcpReporter { - socket: TcpStream, +pub struct TcpReporter { destination: String, reporter_id: ReporterId, } impl TcpReporter { pub fn new(destination: String) -> Result { - let socket = TcpStream::connect(destination.clone())?; let reporter_id = ReporterId::new(); let result = TcpReporter { - socket, destination, reporter_id, }; @@ -59,9 +56,10 @@ impl TcpReporter { } impl Reporter for TcpReporter { - fn report(&mut self, event: Event) -> Result<(), anyhow::Error> { + fn report(&self, event: Event) -> Result<(), anyhow::Error> { let envelope = Envelope::new(&self.reporter_id, event); - envelope.write_into(&mut self.socket)?; + let mut socket = TcpStream::connect(self.destination.clone())?; + envelope.write_into(&mut socket)?; Ok(()) } diff --git a/rust/intercept/tests/test.rs b/rust/intercept/tests/test.rs new file mode 100644 index 00000000..062154a1 --- /dev/null +++ b/rust/intercept/tests/test.rs @@ -0,0 +1,107 @@ +use intercept::collector::{EventCollector, EventCollectorOnTcp}; +use intercept::reporter::{Reporter, TcpReporter}; +use intercept::*; + +mod test { + use super::*; + use crossbeam_channel::bounded; + use lazy_static::lazy_static; + use std::collections::HashMap; + use std::io::Cursor; + use std::path::PathBuf; + use std::sync::Arc; + use std::thread; + use std::time::Duration; + + // Test that the TCP reporter and the TCP collector work together. + // We create a TCP collector and a TCP reporter, then we send events + // to the reporter and check if the collector receives them. + // + // We use a bounded channel to send the events from the reporter to the + // collector. The collector reads the events from the channel and checks + // if they are the same as the original events. + #[test] + fn tcp_reporter_and_collectors_work() { + let collector = EventCollectorOnTcp::new().unwrap(); + let reporter = TcpReporter::new(collector.address().0).unwrap(); + + // Create wrapper to share the collector across threads. + let thread_collector = Arc::new(collector); + let main_collector = thread_collector.clone(); + + // Start the collector in a separate thread. + let (input, output) = bounded(EVENTS.len()); + let receiver_thread = thread::spawn(move || { + thread_collector.collect(input).unwrap(); + }); + // Send events to the reporter. + for event in EVENTS.iter() { + let result = reporter.report(event.clone()); + assert!(result.is_ok()); + } + + // Call the stop method to stop the collector. This will close the + // channel and the collector will stop reading from it. + thread::sleep(Duration::from_secs(1)); + main_collector.stop().unwrap(); + + // Empty the channel and assert that we received all the events. + let mut count = 0; + for envelope in output.iter() { + assert!(EVENTS.contains(&envelope.event)); + count += 1; + } + assert_eq!(count, EVENTS.len()); + // shutdown the receiver thread + receiver_thread.join().unwrap(); + } + + // Test that the serialization and deserialization of the Envelope works. + // We write the Envelope to a buffer and read it back to check if the + // deserialized Envelope is the same as the original one. + #[test] + fn read_write_works() { + let mut writer = Cursor::new(vec![0; 1024]); + for envelope in ENVELOPES.iter() { + let result = Envelope::write_into(envelope, &mut writer); + assert!(result.is_ok()); + } + + let mut reader = Cursor::new(writer.get_ref()); + for envelope in ENVELOPES.iter() { + let result = Envelope::read_from(&mut reader); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), envelope.clone()); + } + } + + lazy_static! { + static ref ENVELOPES: Vec = vec![ + Envelope { + rid: ReporterId(1), + timestamp: 0, + event: Event::Started { + pid: ProcessId(1), + ppid: ProcessId(0), + execution: Execution { + executable: PathBuf::from("/usr/bin/ls"), + arguments: vec!["ls".to_string(), "-l".to_string()], + working_dir: PathBuf::from("/tmp"), + environment: HashMap::new(), + }, + }, + }, + Envelope { + rid: ReporterId(1), + timestamp: 0, + event: Event::Terminated { status: 0 }, + }, + Envelope { + rid: ReporterId(1), + timestamp: 0, + event: Event::Signaled { signal: 15 }, + }, + ]; + static ref EVENTS: Vec = ENVELOPES.iter().map(|e| e.event.clone()).collect(); + } +}