Skip to content

Commit

Permalink
Implement RefUnwindSafe for StatsdClient
Browse files Browse the repository at this point in the history
Make sure that the StatsdClient is unwind (panic) safe by ensuring
that pointers to sinks and error handlers require the object to be
unwind safe.

Make the QueuingMetricSink unwind safe by not using a CondVar and
Mutex but instead using an AtomicBool to indicate when the worker
is stopped. Additionally, assert that the Crossbeam MsQueue is
unwind safe because it implements Sync.

See rust-lang/rust#54768

Fixes #77
  • Loading branch information
56quarters committed Nov 2, 2018
1 parent 634698b commit dd2443d
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 37 deletions.
37 changes: 27 additions & 10 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

use std::fmt;
use std::net::{ToSocketAddrs, UdpSocket};
use std::panic::RefUnwindSafe;
use std::sync::Arc;
use std::time::Duration;

Expand Down Expand Up @@ -345,15 +346,15 @@ pub trait MetricBackend {
/// ```
pub struct StatsdClientBuilder {
prefix: String,
sink: Box<MetricSink + Sync + Send>,
errors: Box<Fn(MetricError) -> () + Sync + Send>,
sink: Box<MetricSink + Sync + Send + RefUnwindSafe>,
errors: Box<Fn(MetricError) -> () + Sync + Send + RefUnwindSafe>,
}

impl StatsdClientBuilder {
// Set the required fields and defaults for optional fields
fn new<T>(prefix: &str, sink: T) -> Self
where
T: MetricSink + Sync + Send + 'static,
T: MetricSink + Sync + Send + RefUnwindSafe + 'static,
{
StatsdClientBuilder {
// required
Expand All @@ -376,7 +377,7 @@ impl StatsdClientBuilder {
/// implementation.
pub fn with_error_handler<F>(mut self, errors: F) -> Self
where
F: Fn(MetricError) -> () + Sync + Send + 'static,
F: Fn(MetricError) -> () + Sync + Send + RefUnwindSafe + 'static,
{
self.errors = Box::new(errors);
self
Expand Down Expand Up @@ -437,14 +438,15 @@ impl StatsdClientBuilder {
/// idea is to share this between threads.
///
/// ``` no_run
/// use std::panic::RefUnwindSafe;
/// use std::net::UdpSocket;
/// use std::sync::Arc;
/// use std::thread;
/// use cadence::prelude::*;
/// use cadence::{StatsdClient, BufferedUdpMetricSink, DEFAULT_PORT};
///
/// struct MyRequestHandler {
/// metrics: Arc<MetricClient + Send + Sync>,
/// metrics: Arc<MetricClient + Send + Sync + RefUnwindSafe>,
/// }
///
/// impl MyRequestHandler {
Expand Down Expand Up @@ -513,8 +515,8 @@ impl StatsdClientBuilder {
#[derive(Clone)]
pub struct StatsdClient {
prefix: String,
sink: Arc<MetricSink + Sync + Send>,
errors: Arc<Fn(MetricError) -> () + Sync + Send>,
sink: Arc<MetricSink + Sync + Send + RefUnwindSafe>,
errors: Arc<Fn(MetricError) -> () + Sync + Send + RefUnwindSafe>,
}

impl StatsdClient {
Expand Down Expand Up @@ -565,7 +567,7 @@ impl StatsdClient {
/// ```
pub fn from_sink<T>(prefix: &str, sink: T) -> Self
where
T: MetricSink + Sync + Send + 'static,
T: MetricSink + Sync + Send + RefUnwindSafe + 'static,
{
Self::builder(prefix, sink).build()
}
Expand Down Expand Up @@ -641,7 +643,7 @@ impl StatsdClient {
/// ```
pub fn builder<T>(prefix: &str, sink: T) -> StatsdClientBuilder
where
T: MetricSink + Sync + Send + 'static,
T: MetricSink + Sync + Send + RefUnwindSafe + 'static,
{
StatsdClientBuilder::new(prefix, sink)
}
Expand Down Expand Up @@ -752,6 +754,7 @@ fn nop_error_handler(_err: MetricError) {
mod tests {
use std::cell::RefCell;
use std::io;
use std::panic::RefUnwindSafe;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
Expand All @@ -761,7 +764,7 @@ mod tests {
Counted, Gauged, Histogrammed, Metered, MetricClient, Setted, StatsdClient, Timed,
};

use sinks::{MetricSink, NopMetricSink};
use sinks::{MetricSink, NopMetricSink, QueuingMetricSink};
use types::{ErrorKind, Metric, MetricError};

#[test]
Expand Down Expand Up @@ -1011,4 +1014,18 @@ mod tests {
client.histogram("some.histogram", 32).unwrap();
client.set("some.set", 5).unwrap();
}

#[test]
fn test_statsd_client_as_thread_and_panic_safe() {
let client: Box<MetricClient + Send + Sync + RefUnwindSafe> = Box::new(
StatsdClient::from_sink("prefix", QueuingMetricSink::from(NopMetricSink)),
);

client.count("some.counter", 3).unwrap();
client.time("some.timer", 198).unwrap();
client.gauge("some.gauge", 4).unwrap();
client.meter("some.meter", 29).unwrap();
client.histogram("some.histogram", 32).unwrap();
client.set("some.set", 5).unwrap();
}
}
151 changes: 124 additions & 27 deletions src/sinks/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

use std::fmt;
use std::io;
use std::panic::{AssertUnwindSafe, RefUnwindSafe};
use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering};

use std::sync::Arc;
use std::thread;

Expand Down Expand Up @@ -87,7 +89,7 @@ impl QueuingMetricSink {
/// ```
pub fn from<T>(sink: T) -> QueuingMetricSink
where
T: MetricSink + Sync + Send + 'static,
T: MetricSink + Sync + Send + RefUnwindSafe + 'static,
{
let worker = Worker::new(move |v: String| {
let _r = sink.emit(&v);
Expand Down Expand Up @@ -262,8 +264,8 @@ struct Worker<T>
where
T: Send + 'static,
{
task: Box<Fn(T) -> () + Sync + Send + 'static>,
queue: MsQueue<Option<T>>,
task: Box<Fn(T) -> () + Sync + Send + RefUnwindSafe + 'static>,
queue: AssertUnwindSafe<MsQueue<Option<T>>>,
stopped: AtomicBool,
}

Expand All @@ -273,11 +275,11 @@ where
{
fn new<F>(task: F) -> Worker<T>
where
F: Fn(T) -> () + Sync + Send + 'static,
F: Fn(T) -> () + Sync + Send + RefUnwindSafe + 'static,
{
Worker {
task: Box::new(task),
queue: MsQueue::new(),
queue: AssertUnwindSafe(MsQueue::new()),
stopped: AtomicBool::new(false),
}
}
Expand Down Expand Up @@ -333,7 +335,7 @@ mod tests {
use super::{QueuingMetricSink, Worker};
use sinks::core::MetricSink;
use std::io;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;

Expand Down Expand Up @@ -391,29 +393,79 @@ mod tests {
assert!(worker.is_stopped());
}

struct TestMetricSink {
metrics: Arc<Mutex<Vec<String>>>,
}
// Make sure the worker and its queue are in the expected state
// when the producer size of the queue panics.
#[test]
fn test_worker_panic_on_submit_side() {
let worker = Arc::new(Worker::new(move |_: String| {}));
let worker_ref1 = worker.clone();
let worker_ref2 = worker.clone();

impl TestMetricSink {
fn new(store: Arc<Mutex<Vec<String>>>) -> TestMetricSink {
TestMetricSink { metrics: store }
}
#[allow(unreachable_code)]
let t1 = thread::spawn(move || {
worker_ref1.submit(panic!("This thread is supposed to panic"))
});

let t2 = thread::spawn(move || {
worker_ref2.run();
});

worker.stop();

assert!(t1.join().is_err());
assert!(t2.join().is_ok());

assert!(worker.is_stopped());
assert!(worker.queue.is_empty());
}

impl MetricSink for TestMetricSink {
fn emit(&self, m: &str) -> io::Result<usize> {
let mut store = self.metrics.lock().unwrap();
store.push(m.to_string());
Ok(m.len())
}
// Make sure the worker and its queue are in the expected state
// when the consumer side of the queue panics.
#[test]
fn test_worker_panic_on_run_side() {
let worker = Arc::new(Worker::new(move |_: String| { panic!("This thread is supposed to panic"); }));
let worker_ref1 = worker.clone();
let worker_ref2 = worker.clone();

let t1 = thread::spawn(move || {
worker_ref1.submit("foo".to_owned());
});

let t2 = thread::spawn(move || {
worker_ref2.run();
});

assert!(t1.join().is_ok());
assert!(t2.join().is_err());

assert!(!worker.is_stopped());
assert!(worker.queue.is_empty());
}

#[test]
fn test_queuing_sink_emit() {
struct TestMetricSink {
metrics: Arc<Mutex<Vec<String>>>,
}

impl TestMetricSink {
fn new(metrics: Arc<Mutex<Vec<String>>>) -> TestMetricSink {
TestMetricSink { metrics }
}
}

impl MetricSink for TestMetricSink {
fn emit(&self, m: &str) -> io::Result<usize> {
let mut store = self.metrics.lock().unwrap();
store.push(m.to_string());
Ok(m.len())
}
}

let store = Arc::new(Mutex::new(vec![]));
let wrapped = TestMetricSink::new(store.clone());
let queuing = QueuingMetricSink::from(wrapped);

queuing.emit("foo.counter:1|c").unwrap();
queuing.emit("bar.counter:2|c").unwrap();
queuing.emit("baz.counter:3|c").unwrap();
Expand All @@ -424,17 +476,17 @@ mod tests {
assert_eq!("baz.counter:3|c".to_string(), store.lock().unwrap()[2]);
}

struct PanickingMetricSink;
#[test]
fn test_queuing_sink_emit_panics() {
struct PanickingMetricSink;

impl MetricSink for PanickingMetricSink {
#[allow(unused_variables)]
fn emit(&self, metric: &str) -> io::Result<usize> {
panic!("This thread is supposed to panic, relax :p");
impl MetricSink for PanickingMetricSink {
#[allow(unused_variables)]
fn emit(&self, metric: &str) -> io::Result<usize> {
panic!("This thread is supposed to panic");
}
}
}

#[test]
fn test_queuing_sink_emit_panics() {
let queuing = QueuingMetricSink::from(PanickingMetricSink);
queuing.emit("foo.counter:4|c").unwrap();
queuing.emit("foo.counter:5|c").unwrap();
Expand All @@ -443,4 +495,49 @@ mod tests {

assert_eq!(3, queuing.panics());
}

// Make sure that subsequent metrics make it to the wrapped sink even when
// the wrapped sink panics. This ensures that the thread running the sink
// is restarted correctly and the worker and queue are in the correct state.
#[test]
fn test_queuing_sink_emit_recover_from_panics() {
struct SometimesPanickingMetricSink {
metrics: Arc<Mutex<Vec<String>>>,
counter: AtomicUsize,
}

impl SometimesPanickingMetricSink {
fn new(metrics: Arc<Mutex<Vec<String>>>) -> Self {
SometimesPanickingMetricSink {
metrics,
counter: AtomicUsize::new(0)
}
}
}

impl MetricSink for SometimesPanickingMetricSink {
fn emit(&self, m: &str) -> io::Result<usize> {
let val = self.counter.fetch_add(1, Ordering::Acquire);
if val == 0 {
panic!("This thread is supposed to panic");
}

let mut store = self.metrics.lock().unwrap();
store.push(m.to_string());
Ok(m.len())
}
}

let store = Arc::new(Mutex::new(vec![]));
let queuing = QueuingMetricSink::from(SometimesPanickingMetricSink::new(store.clone()));

queuing.emit("foo.counter:4|c").unwrap();
queuing.emit("foo.counter:5|c").unwrap();
queuing.emit("foo.timer:34|ms").unwrap();
queuing.context.worker.stop_and_wait();

assert_eq!(1, queuing.panics());
assert_eq!("foo.counter:5|c".to_string(), store.lock().unwrap()[0]);
assert_eq!("foo.timer:34|ms".to_string(), store.lock().unwrap()[1]);
}
}

0 comments on commit dd2443d

Please sign in to comment.