From 42fb829fdb2eab8f3ba6897903891eb7f1bdc949 Mon Sep 17 00:00:00 2001 From: Hayden Stainsby Date: Mon, 17 Jul 2023 17:46:10 +0200 Subject: [PATCH] test(subscriber): add initial integration tests The `console-subscriber` crate has no integration tests. There are some unit tests, but without very high coverage of features. Recently, we've found or fixed a few errors which probably could have been caught by a medium level of integration testing. However, testing `console-subscriber` isn't straight forward. It is effectively a tracing subscriber (or layer) on one end, and a gRPC server on the other end. This change adds enough of a testing framework to write some initial integration tests. It is the first step towards closing #450. Each test comprises 2 parts: - One or more "expcted tasks" - A future which will be driven to completion on a dedicated Tokio runtime. Behind the scenes, a console subscriber layer is created and it's server part is connected to a duplex stream. The client of the duplex stream then records incoming updates and reconstructs "actual tasks". The layer itself is set as the default subscriber for the duration of `block_on` which is used to drive the provided future to completioin. The expected tasks have a set of "matches", which is how we find the actual task that we want to validate against. Currently, the only value we match on is the task's name. The expected tasks also have a set of expectations. These are other fields on the actual task which are validated once a matching task is found. Currently, the two fields which can have expectations set on them are the `wakes` and `self_wakes` fields. So, to construct an expected task, which will match a task with the name `"my-task"` and then validate that the matched task gets woken once, the code would be: ```rust ExpectedTask::default() .match_name("my-task") .expect_wakes(1); ``` A future which passes this test could be: ```rust async { task::Builder::new() .name("my-task") .spawn(async { tokio::time::sleep(std::time::Duration::ZERO).await }) } ``` The full test would then look like: ```rust fn wakes_once() { let expected_task = ExpectedTask::default() .match_name("my-task") .expect_wakes(1); let future = async { task::Builder::new() .name("my-task") .spawn(async { tokio::time::sleep(std::time::Duration::ZERO).await }) }; assert_task(expected_task, future); } ``` The PR depends on 2 others: - #447 which fixes an error in the logic that determines whether a task is retained in the aggregator or not. - #451 which exposes the server parts and is necessary to allow us to connect the instrument server and client via a duplex channel. This change contains some initial tests for wakes and self wakes which would have caught the error fixed in #430. Additionally there are tests for the functionality of the testing framework itself. --- Cargo.lock | 2 + console-subscriber/Cargo.toml | 1 + console-subscriber/src/aggregator/id_data.rs | 8 +- console-subscriber/tests/framework.rs | 184 ++++++++++ console-subscriber/tests/support/mod.rs | 47 +++ console-subscriber/tests/support/state.rs | 139 ++++++++ .../tests/support/subscriber.rs | 318 ++++++++++++++++++ console-subscriber/tests/support/task.rs | 228 +++++++++++++ console-subscriber/tests/wake.rs | 48 +++ 9 files changed, 971 insertions(+), 4 deletions(-) create mode 100644 console-subscriber/tests/framework.rs create mode 100644 console-subscriber/tests/support/mod.rs create mode 100644 console-subscriber/tests/support/state.rs create mode 100644 console-subscriber/tests/support/subscriber.rs create mode 100644 console-subscriber/tests/support/task.rs create mode 100644 console-subscriber/tests/wake.rs diff --git a/Cargo.lock b/Cargo.lock index 27d4c3088..2b6a44e0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,6 +257,7 @@ dependencies = [ name = "console-api" version = "0.5.0" dependencies = [ + "futures-core", "prost", "prost-build", "prost-types", @@ -283,6 +284,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tower", "tracing", "tracing-core", "tracing-subscriber", diff --git a/console-subscriber/Cargo.toml b/console-subscriber/Cargo.toml index d3c414e3b..b79e6e3d5 100644 --- a/console-subscriber/Cargo.toml +++ b/console-subscriber/Cargo.toml @@ -55,6 +55,7 @@ crossbeam-channel = "0.5" [dev-dependencies] tokio = { version = "^1.21", features = ["full", "rt-multi-thread"] } +tower = "0.4" futures = "0.3" [package.metadata.docs.rs] diff --git a/console-subscriber/src/aggregator/id_data.rs b/console-subscriber/src/aggregator/id_data.rs index b9010b445..2ad2c74b0 100644 --- a/console-subscriber/src/aggregator/id_data.rs +++ b/console-subscriber/src/aggregator/id_data.rs @@ -104,18 +104,18 @@ impl IdData { if let Some(dropped_at) = stats.dropped_at() { let dropped_for = now.checked_duration_since(dropped_at).unwrap_or_default(); let dirty = stats.is_unsent(); - let should_drop = + let should_retain = // if there are any clients watching, retain all dirty tasks regardless of age (dirty && has_watchers) - || dropped_for > retention; + || dropped_for <= retention; tracing::trace!( stats.id = ?id, stats.dropped_at = ?dropped_at, stats.dropped_for = ?dropped_for, stats.dirty = dirty, - should_drop, + should_retain, ); - return !should_drop; + return should_retain; } true diff --git a/console-subscriber/tests/framework.rs b/console-subscriber/tests/framework.rs new file mode 100644 index 000000000..68bf2a0ce --- /dev/null +++ b/console-subscriber/tests/framework.rs @@ -0,0 +1,184 @@ +//! Framework tests +//! +//! The tests in this module are here to verify the testing framework itself. +//! As such, some of these tests may be repeated elsewhere (where we wish to +//! actually test the functionality of `console-subscriber`) and others are +//! negative tests that should panic. + +use std::time::Duration; + +use tokio::{task, time::sleep}; + +mod support; +use support::{assert_task, assert_tasks, ExpectedTask}; + +#[test] +fn expect_present() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_present(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: no expectations set, if you want to just expect that a matching task is present, use `expect_present()` +")] +fn fail_no_expectations() { + let expected_task = ExpectedTask::default().match_default_name(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `wakes` to be 5, but actual was 1 +")] +fn fail_wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(5); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `self_wakes` to be 1, but actual was 0 +")] +fn fail_self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn test_spawned_task() { + let expected_task = ExpectedTask::default() + .match_name("another-name".into()) + .expect_present(); + + let future = async { + task::Builder::new() + .name("another-name") + .spawn(async { task::yield_now().await }) + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: no matching actual task was found +")] +fn fail_wrong_task_name() { + let expected_task = ExpectedTask::default().match_name("wrong-name".into()); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +fn multiple_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(1), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `wakes` to be 2, but actual was 1 +")] +fn fail_1_of_2_expected_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(2), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} diff --git a/console-subscriber/tests/support/mod.rs b/console-subscriber/tests/support/mod.rs new file mode 100644 index 000000000..4937aff6a --- /dev/null +++ b/console-subscriber/tests/support/mod.rs @@ -0,0 +1,47 @@ +use futures::Future; + +mod state; +mod subscriber; +mod task; + +use subscriber::run_test; + +pub(crate) use subscriber::MAIN_TASK_NAME; +pub(crate) use task::ExpectedTask; + +/// Assert that an `expected_task` is recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// This function is equivalent to calling [`assert_tasks`] with a vector +/// containing a single task. +/// +/// # Panics +/// +/// This function will panic if the expectations on the expected task are not +/// met or if a matching task is not recorded. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_task(expected_task: ExpectedTask, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(vec![expected_task], future) +} + +/// Assert that the `expected_tasks` are recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// # Panics +/// +/// This function will panic if the expectations on any of the expected tasks +/// are not met or if matching tasks are not recorded for all expected tasks. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_tasks(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(expected_tasks, future) +} diff --git a/console-subscriber/tests/support/state.rs b/console-subscriber/tests/support/state.rs new file mode 100644 index 000000000..6fc663808 --- /dev/null +++ b/console-subscriber/tests/support/state.rs @@ -0,0 +1,139 @@ +use std::fmt; + +use tokio::sync::broadcast::{ + self, + error::{RecvError, TryRecvError}, +}; + +/// A step in the running of the test +#[derive(Clone, Debug, PartialEq, PartialOrd)] +pub(super) enum TestStep { + /// The overall test has begun + Start, + /// The instrument server has been started + ServerStarted, + /// The client has connected to the instrument server + ClientConnected, + /// The future being driven has completed + TestFinished, + /// The client has finished recording updates + UpdatesRecorded, +} + +impl fmt::Display for TestStep { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (self as &dyn fmt::Debug).fmt(f) + } +} + +/// The state of the test. +/// +/// This struct is used by various parts of the test framework to wait until +/// a specific test step has been reached and advance the test state to a new +/// step. +pub(super) struct TestState { + receiver: broadcast::Receiver, + sender: broadcast::Sender, + step: TestStep, +} + +impl TestState { + pub(super) fn new() -> Self { + let (sender, receiver) = broadcast::channel(1); + Self { + receiver, + sender, + step: TestStep::Start, + } + } + + /// Block asynchronously until the desired step has been reached. + /// + /// # Panics + /// + /// This function will panic if the underlying channel gets closed. + pub(super) async fn wait_for_step(&mut self, desired_step: TestStep) { + loop { + if self.step >= desired_step { + break; + } + + match self.receiver.recv().await { + Ok(step) => self.step = step, + Err(RecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(RecvError::Closed) => { + panic!("failed to receive current step, waiting for step: {desired_step}, did the test abort?"); + } + } + } + } + + /// Check whether the desired step has been reached without blocking. + pub(super) fn try_wait_for_step(&mut self, desired_step: TestStep) -> bool { + self.update_step(); + + self.step == desired_step + } + + /// Advance to the next step. + /// + /// The test must be at the step prior to the next step before starting. + /// Being in a different step is likely to indicate a logic error in the + /// test framework. + /// + /// # Panics + /// + /// This method will panic if the test state is not at the step prior to + /// `next_step` or if the underlying channel is closed. + #[track_caller] + pub(super) fn advance_to_step(&mut self, next_step: TestStep) { + self.update_step(); + + if self.step >= next_step { + panic!( + "cannot advance to previous or current step! current step: {current}, next step: {next_step}", + current = self.step); + } + + match (&self.step, &next_step) { + (TestStep::Start, TestStep::ServerStarted) | + (TestStep::ServerStarted, TestStep::ClientConnected) | + (TestStep::ClientConnected, TestStep::TestFinished) | + (TestStep::TestFinished, TestStep::UpdatesRecorded) => {}, + (_, _) => panic!( + "cannot advance more than one step! current step: {current}, next step: {next_step}", + current = self.step), + } + + self.sender + .send(next_step) + .expect("failed to send the next test step, did the test abort?"); + } + + fn update_step(&mut self) { + loop { + match self.receiver.try_recv() { + Ok(step) => self.step = step, + Err(TryRecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(TryRecvError::Closed) => { + panic!("failed to update current step, did the test abort?") + } + Err(TryRecvError::Empty) => break, + } + } + } +} + +impl Clone for TestState { + fn clone(&self) -> Self { + Self { + receiver: self.receiver.resubscribe(), + sender: self.sender.clone(), + step: self.step.clone(), + } + } +} diff --git a/console-subscriber/tests/support/subscriber.rs b/console-subscriber/tests/support/subscriber.rs new file mode 100644 index 000000000..36888ad5a --- /dev/null +++ b/console-subscriber/tests/support/subscriber.rs @@ -0,0 +1,318 @@ +use std::{collections::HashMap, fmt, future::Future, thread}; + +use console_api::{ + field::Value, + instrument::{instrument_client::InstrumentClient, InstrumentRequest}, +}; +use console_subscriber::ServerParts; +use futures::stream::StreamExt; +use tokio::{io::DuplexStream, task}; +use tonic::transport::{Channel, Endpoint, Server, Uri}; +use tower::service_fn; + +use super::state::{TestState, TestStep}; +use super::task::{ActualTask, ExpectedTask, TaskValidationFailure}; + +pub(crate) const MAIN_TASK_NAME: &str = "main"; + +#[derive(Debug)] +struct TestFailure { + failures: Vec, +} + +impl fmt::Display for TestFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Task validation failed:\n")?; + for failure in &self.failures { + write!(f, " - {failure}\n")?; + } + Ok(()) + } +} + +/// Runs the test +/// +/// This function runs the whole test. It sets up a `console-subscriber` layer +/// together with the gRPC server and connects a client to it. The subscriber +/// is then used to record traces as the provided future is driven to +/// completion on a current thread tokio runtime. +/// +/// This function will panic if the expectations on any of the expected tasks +/// are not met or if matching tasks are not recorded for all expected tasks. +#[track_caller] +pub(super) fn run_test(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + use tracing_subscriber::prelude::*; + + let (client_stream, server_stream) = tokio::io::duplex(1024); + let (console_layer, server) = console_subscriber::ConsoleLayer::builder().build(); + let registry = tracing_subscriber::registry().with(console_layer); + + let mut test_state = TestState::new(); + let mut test_state_test = test_state.clone(); + + let join_handle = thread::Builder::new() + .name("console::subscriber".into()) + .spawn(move || { + let _subscriber_guard = + tracing::subscriber::set_default(tracing_core::subscriber::NoSubscriber::default()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .expect("console-test error: failed to initialize console subscriber runtime"); + + runtime.block_on(async move { + task::Builder::new() + .name("console::serve") + .spawn(console_server(server, server_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-server' task"); + + let actual_tasks = task::Builder::new() + .name("console::client") + .spawn(console_client(client_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-client' task") + .await + .expect("console-test error: failed to await 'console-client' task"); + + test_state.advance_to_step(TestStep::UpdatesRecorded); + actual_tasks + }) + }) + .expect("console subscriber could not spawn thread"); + + tracing::subscriber::with_default(registry, || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async move { + test_state_test + .wait_for_step(TestStep::ClientConnected) + .await; + + // Run the future that we are testing. + _ = tokio::task::Builder::new() + .name(MAIN_TASK_NAME) + .spawn(future) + .expect("console-test error: couldn't spawn test task") + .await; + test_state_test.advance_to_step(TestStep::TestFinished); + + test_state_test + .wait_for_step(TestStep::UpdatesRecorded) + .await; + }); + }); + + let actual_tasks = join_handle + .join() + .expect("console-test error: failed to join 'console-subscriber' thread"); + + if let Err(test_failure) = validate_expected_tasks(expected_tasks, actual_tasks) { + panic!("Test failed: {test_failure}") + } +} + +/// Starts the console server. +/// +/// The server will start serving over its side of the duplex stream. +/// +/// Once the server gets spawned into its task, the test state is advanced +/// to the `ServerStarted` step. This function will then wait until the test +/// state reaches the `UpdatesRecorded` step (indicating that all validation of the +/// received updates has been completed) before dropping the aggregator. +/// +/// # Test State +/// +/// 1. Advances to: `ServerStarted` +/// 2. Waits for: `UpdatesRecorded` +async fn console_server( + server: console_subscriber::Server, + server_stream: DuplexStream, + mut test_state: TestState, +) { + let ServerParts { + instrument_server: service, + aggregator, + .. + } = server.into_parts(); + let aggregate = task::Builder::new() + .name("console::aggregate") + .spawn(aggregator.run()) + .expect("client-console error: couldn't spawn aggregator"); + Server::builder() + .add_service(service) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + server_stream, + )])) + .await + .expect("client-console error: couldn't start instrument server."); + test_state.advance_to_step(TestStep::ServerStarted); + + test_state.wait_for_step(TestStep::UpdatesRecorded).await; + aggregate.abort(); +} + +/// Starts the console client and validates the expected tasks. +/// +/// First we wait until the server has started (test step `ServerStarted`), then +/// the client is connected to its half of the duplex stream and we start recording +/// the actual tasks. +/// +/// Once recording finishes (see [`record_actual_tasks()`] for details on the test +/// state condition), the actual tasks returned. +/// +/// # Test State +/// +/// 1. Waits for: `ServerStarted` +/// 2. Advances to: `ClientConnected` +async fn console_client(client_stream: DuplexStream, mut test_state: TestState) -> Vec { + test_state.wait_for_step(TestStep::ServerStarted).await; + + let mut client_stream = Some(client_stream); + let channel = Endpoint::try_from("http://[::]:6669") + .expect("Could not create endpoint") + .connect_with_connector(service_fn(move |_: Uri| { + let client = client_stream.take(); + + async move { + if let Some(client) = client { + Ok(client) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Client already taken", + )) + } + } + })) + .await + .expect("client-console error: couldn't create client"); + test_state.advance_to_step(TestStep::ClientConnected); + + record_actual_tasks(channel, test_state.clone()).await +} + +/// Records the actual tasks which are received by the client channel. +/// +/// Updates will be received until the test state reaches the `TestFinished` step +/// (indicating that the test itself has finished running), at which point we wait +/// for a final update before returning all the actual tasks which were recorded. +/// +/// # Test State +/// +/// 1. Waits for: `TestFinished` +async fn record_actual_tasks( + client_channel: Channel, + mut test_state: TestState, +) -> Vec { + let mut client = InstrumentClient::new(client_channel); + + let mut stream = loop { + let request = tonic::Request::new(InstrumentRequest {}); + match client.watch_updates(request).await { + Ok(stream) => break stream.into_inner(), + Err(err) => panic!("Client cannot connect to watch updates: {err}"), + } + }; + + let mut tasks = HashMap::new(); + + let mut last_update = false; + while let Some(update) = stream.next().await { + match update { + Ok(update) => { + if let Some(task_update) = &update.task_update { + for new_task in &task_update.new_tasks { + if let Some(id) = &new_task.id { + let mut actual_task = ActualTask::new(id.id); + for field in &new_task.fields { + if let Some(console_api::field::Name::StrName(field_name)) = + &field.name + { + if field_name == "task.name" { + actual_task.name = match &field.value { + Some(Value::DebugVal(value)) => Some(value.clone()), + Some(Value::StrVal(value)) => Some(value.clone()), + _ => None, // Anything that isn't string-like shouldn't be used as a name. + }; + } + } + } + tasks.insert(actual_task.id, actual_task); + } + } + + for (id, stats) in &task_update.stats_update { + if let Some(mut task) = tasks.get_mut(id) { + task.wakes = stats.wakes; + task.self_wakes = stats.self_wakes; + } + } + } + } + Err(e) => { + panic!("update stream error: {}", e); + } + } + + if last_update { + break; + } + + if test_state.try_wait_for_step(TestStep::TestFinished) { + // Once the test finishes running, we will get one further update and finish. + last_update = true; + } + } + + tasks.into_values().collect() +} + +/// Validate the expected tasks against the actual tasks. +/// +/// Each expected task is checked in turn. +/// +/// A matching actual task is searched for. If one is found it, the +/// expected task is validated against the actual task. +/// +/// Any validation errors result in failure. If no matches +fn validate_expected_tasks( + expected_tasks: Vec, + actual_tasks: Vec, +) -> Result<(), TestFailure> { + let failures: Vec<_> = expected_tasks + .iter() + .map(|expected| validate_expected_task(expected, &actual_tasks)) + .filter_map(|r| match r { + Ok(_) => None, + Err(validation_error) => Some(validation_error), + }) + .collect(); + + if failures.is_empty() { + Ok(()) + } else { + Err(TestFailure { failures: failures }) + } +} + +fn validate_expected_task( + expected: &ExpectedTask, + actual_tasks: &Vec, +) -> Result<(), TaskValidationFailure> { + for actual in actual_tasks { + if expected.matches_actual_task(actual) { + // We only match a single task. + // FIXME(hds): We should probably create an error or a warning if multiple tasks match. + return expected.validate_actual_task(actual); + } + } + + expected.no_match_error() +} diff --git a/console-subscriber/tests/support/task.rs b/console-subscriber/tests/support/task.rs new file mode 100644 index 000000000..6df878b1b --- /dev/null +++ b/console-subscriber/tests/support/task.rs @@ -0,0 +1,228 @@ +use std::{error, fmt}; + +use super::MAIN_TASK_NAME; + +/// An actual task +/// +/// This struct contains the values recorded from the console subscriber +/// client and represents what is known about an actual task running on +/// the test's runtime. +#[derive(Clone, Debug)] +pub(super) struct ActualTask { + pub(super) id: u64, + pub(super) name: Option, + pub(super) wakes: u64, + pub(super) self_wakes: u64, +} + +impl ActualTask { + pub(super) fn new(id: u64) -> Self { + Self { + id, + name: None, + wakes: 0, + self_wakes: 0, + } + } +} + +/// An error in task validation. +pub(super) struct TaskValidationFailure { + /// The expected task whose expectations were not met. + expected: ExpectedTask, + /// The actual task which failed the validation + actual: Option, + /// A textual description of the validation failure + failure: String, +} + +impl error::Error for TaskValidationFailure {} + +impl fmt::Display for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.failure) + } +} + +impl fmt::Debug for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.actual { + Some(actual) => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\n Actual Task: {actual:?}\n Failure: {failure}", + expected = self.expected, failure = self.failure), + None => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\n Failure: {failure}", + expected = self.expected, failure = self.failure), + } + } +} + +/// An expected task. +/// +/// This struct contains the fields that an expected task will attempt to match +/// actual tasks on, as well as the expectations that will be used to validate +/// which the actual task is as expected. +#[derive(Clone, Debug)] +pub(crate) struct ExpectedTask { + match_name: Option, + + expect_present: Option, + expect_wakes: Option, + expect_self_wakes: Option, +} + +impl Default for ExpectedTask { + fn default() -> Self { + Self { + match_name: None, + expect_present: None, + expect_wakes: None, + expect_self_wakes: None, + } + } +} + +impl ExpectedTask { + /// Returns whether or not an actual task matches this expected task. + /// + /// All matching rules will be run, if they all succeed, then `true` will + /// be returned, otherwise `false`. + pub(super) fn matches_actual_task(&self, actual_task: &ActualTask) -> bool { + if let Some(match_name) = &self.match_name { + if Some(match_name) == actual_task.name.as_ref() { + return true; + } + } + + false + } + + /// Returns an error specifying that no match was found for this expected + /// task. + pub(super) fn no_match_error(&self) -> Result<(), TaskValidationFailure> { + Err(TaskValidationFailure { + expected: self.clone(), + actual: None, + failure: format!("{self}: no matching actual task was found"), + }) + } + + /// Validates all expectations against the provided actual task. + /// + /// No check that the actual task matches is performed. That must have been + /// done prior. + /// + /// If all expections are met, this method returns `Ok(())`. If any + /// expectations are not met, then the first incorrect expectation will + /// be returned as an `Err`. + pub(super) fn validate_actual_task( + &self, + actual_task: &ActualTask, + ) -> Result<(), TaskValidationFailure> { + let mut no_expectations = true; + if let Some(_expected) = self.expect_present { + no_expectations = false; + } + + if let Some(expected_wakes) = self.expect_wakes { + no_expectations = false; + if expected_wakes != actual_task.wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `wakes` to be {expected_wakes}, but actual was {actual_wakes}", + actual_wakes = actual_task.wakes), + }); + } + } + + if let Some(expected_self_wakes) = self.expect_self_wakes { + no_expectations = false; + if expected_self_wakes != actual_task.self_wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `self_wakes` to be {expected_self_wakes}, but actual was {actual_self_wakes}", + actual_self_wakes = actual_task.self_wakes), + }); + } + } + + if no_expectations { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: no expectations set, if you want to just expect that a matching task is present, use `expect_present()`") + }); + } + + Ok(()) + } + + /// Matches tasks by name. + /// + /// To match this expected task, an actual task must have the name `name`. + #[allow(dead_code)] + pub(crate) fn match_name(mut self, name: String) -> Self { + self.match_name = Some(name); + self + } + + /// Matches tasks by the default task name. + /// + /// To match this expected task, an actual task must have the default name + /// assigned to the task which runs the future provided to [`assert_task`] + /// or [`assert_tasks`]. + /// + /// [`assert_task`]: fn@support::assert_task + /// [`assert_tasks`]: fn@support::assert_tasks + #[allow(dead_code)] + pub(crate) fn match_default_name(mut self) -> Self { + self.match_name = Some(MAIN_TASK_NAME.into()); + self + } + + /// Expects that a task is present. + /// + /// To validate, an actual task matching this expected task must be found. + #[allow(dead_code)] + pub(crate) fn expect_present(mut self) -> Self { + self.expect_present = Some(true); + self + } + + /// Expects that a task has a specific value for `wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of wakes equal to `wakes`. + #[allow(dead_code)] + pub(crate) fn expect_wakes(mut self, wakes: u64) -> Self { + self.expect_wakes = Some(wakes); + self + } + + /// Expects that a task has a specific value for `self_wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of self wakes equal to `self_wakes`. + #[allow(dead_code)] + pub(crate) fn expect_self_wakes(mut self, self_wakes: u64) -> Self { + self.expect_self_wakes = Some(self_wakes); + self + } +} + +impl fmt::Display for ExpectedTask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let fields = match &self.match_name { + Some(name) => format!("name={name}"), + None => "(no fields to match on)".into(), + }; + write!(f, "Task<{fields}>") + } +} diff --git a/console-subscriber/tests/wake.rs b/console-subscriber/tests/wake.rs new file mode 100644 index 000000000..e64e87a6e --- /dev/null +++ b/console-subscriber/tests/wake.rs @@ -0,0 +1,48 @@ +mod support; +use std::time::Duration; + +use support::{assert_task, ExpectedTask}; +use tokio::{task, time::sleep}; + +#[test] +fn sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn double_sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(2) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(1); + + let future = async { + task::yield_now().await; + }; + + assert_task(expected_task, future); +}