diff --git a/.github/workflows/test-sequencer.yml b/.github/workflows/test-sequencer.yml index 78ff9d8237..3d59116827 100644 --- a/.github/workflows/test-sequencer.yml +++ b/.github/workflows/test-sequencer.yml @@ -55,7 +55,6 @@ jobs: hotshot-state-prover = { path = "${GITHUB_WORKSPACE}/hotshot/crates/hotshot-state-prover" } hotshot-orchestrator = { path = "${GITHUB_WORKSPACE}/hotshot/crates/orchestrator" } hotshot-web-server = { path = "${GITHUB_WORKSPACE}/hotshot/crates/web_server" } - hotshot-task = { path = "${GITHUB_WORKSPACE}/hotshot/crates/task" } hotshot-task-impls = { path = "${GITHUB_WORKSPACE}/hotshot/crates/task-impls" } hotshot-testing = { path = "${GITHUB_WORKSPACE}/hotshot/crates/testing" } hotshot-types = { path = "${GITHUB_WORKSPACE}/hotshot/crates/types" } diff --git a/Cargo.lock b/Cargo.lock index 1894a51996..7ff950596c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -533,6 +533,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "async-broadcast" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "334d75cf09b33bede6cbc20e52515853ae7bee3d4eadd9540e13ce92af983d34" +dependencies = [ + "event-listener 3.1.0", + "event-listener-strategy 0.1.0", + "futures-core", +] + [[package]] name = "async-channel" version = "1.9.0" @@ -552,7 +563,7 @@ checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" dependencies = [ "concurrent-queue", "event-listener 4.0.3", - "event-listener-strategy", + "event-listener-strategy 0.4.0", "futures-core", "pin-project-lite 0.2.13", ] @@ -700,7 +711,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b" dependencies = [ "event-listener 4.0.3", - "event-listener-strategy", + "event-listener-strategy 0.4.0", "pin-project-lite 0.2.13", ] @@ -897,17 +908,6 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" -[[package]] -name = "atomic_enum" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6227a8d6fdb862bcb100c4314d0d9579e5cd73fa6df31a2e6f6e1acd3c5f1207" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "attohttpc" version = "0.24.1" @@ -2153,6 +2153,16 @@ dependencies = [ "pin-project-lite 0.2.13", ] +[[package]] +name = "event-listener-strategy" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15c97b4e30ea7e4b7e7b429d6e2d8510433ba8cee4e70dfb3243794e539d29fd" +dependencies = [ + "event-listener 3.1.0", + "pin-project-lite 0.2.13", +] + [[package]] name = "event-listener-strategy" version = "0.4.0" @@ -2667,6 +2677,7 @@ dependencies = [ name = "hotshot" version = "0.3.3" dependencies = [ + "async-broadcast", "async-compatibility-layer", "async-lock 2.8.0", "async-std", @@ -2758,16 +2769,10 @@ dependencies = [ name = "hotshot-task" version = "0.1.0" dependencies = [ + "async-broadcast", "async-compatibility-layer", - "async-lock 2.8.0", "async-std", - "async-trait", - "atomic_enum", - "either", "futures", - "pin-project", - "serde", - "snafu", "tokio", "tracing", ] @@ -2776,6 +2781,7 @@ dependencies = [ name = "hotshot-task-impls" version = "0.1.0" dependencies = [ + "async-broadcast", "async-compatibility-layer", "async-lock 2.8.0", "async-std", @@ -2800,6 +2806,7 @@ dependencies = [ name = "hotshot-testing" version = "0.1.0" dependencies = [ + 
"async-broadcast", "async-compatibility-layer", "async-lock 2.8.0", "async-std", @@ -2880,7 +2887,6 @@ dependencies = [ "ethereum-types", "generic-array", "hotshot-constants", - "hotshot-task", "hotshot-utils", "jf-plonk", "jf-primitives", diff --git a/Cargo.toml b/Cargo.toml index c1ee9c2f1e..d9b9e2b715 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,9 +33,11 @@ ark-ed-on-bn254 = "0.4" ark-ff = "0.4" ark-serialize = "0.4" ark-std = { version = "0.4", default-features = false } +async-broadcast = "0.6.0" async-compatibility-layer = { git = "https://github.com/EspressoSystems/async-compatibility-layer.git", tag = "1.4.1", default-features = false, features = [ "logging-utils", ] } +task = { git = "https://github.com/EspressoSystems/HotShotTasks.git" } async-lock = "2.8" async-trait = "0.1.77" bincode = "1.3.3" diff --git a/crates/constants/src/lib.rs b/crates/constants/src/lib.rs index 3a44c0902a..621df3cdfe 100644 --- a/crates/constants/src/lib.rs +++ b/crates/constants/src/lib.rs @@ -28,3 +28,6 @@ pub struct Version { /// Constant for protocol version 0.1. pub const VERSION_0_1: Version = Version { major: 0, minor: 1 }; + +/// Default Channel Size for consensus event sharing +pub const EVENT_CHANNEL_SIZE: usize = 100_000; diff --git a/crates/hotshot/Cargo.toml b/crates/hotshot/Cargo.toml index d80337bd88..546489612c 100644 --- a/crates/hotshot/Cargo.toml +++ b/crates/hotshot/Cargo.toml @@ -79,6 +79,7 @@ name = "orchestrator-combined" path = "examples/combined/orchestrator.rs" [dependencies] +async-broadcast = { workspace = true } async-compatibility-layer = { workspace = true } async-lock = { workspace = true } async-trait = { workspace = true } @@ -96,7 +97,6 @@ hotshot-web-server = { version = "0.1.1", path = "../web_server", default-featur hotshot-orchestrator = { version = "0.1.1", path = "../orchestrator", default-features = false } hotshot-types = { path = "../types", version = "0.1.0", default-features = false } hotshot-utils = { path = "../utils" } -hotshot-task = { path = "../task", version = "0.1.0", default-features = false } hotshot-task-impls = { path = "../task-impls", version = "0.1.0", default-features = false } libp2p-identity = { workspace = true } libp2p-networking = { workspace = true } @@ -108,6 +108,7 @@ time = { workspace = true } derive_more = "0.99.17" portpicker = "0.1.1" lru = "0.12.2" +hotshot-task = { path = "../task" } tracing = { workspace = true } diff --git a/crates/hotshot/examples/infra/mod.rs b/crates/hotshot/examples/infra/mod.rs index 6e391b72b8..4d97d4cc3e 100644 --- a/crates/hotshot/examples/infra/mod.rs +++ b/crates/hotshot/examples/infra/mod.rs @@ -23,7 +23,6 @@ use hotshot_orchestrator::{ client::{OrchestratorClient, ValidatorArgs}, config::{NetworkConfig, NetworkConfigFile, WebServerConfig}, }; -use hotshot_task::task::FilterEvent; use hotshot_testing::{ block_types::{TestBlockHeader, TestBlockPayload, TestTransaction}, state_types::TestInstanceState, @@ -393,7 +392,7 @@ pub trait RunDA< /// Starts HotShot consensus, returns when consensus has finished async fn run_hotshot( &self, - mut context: SystemContextHandle, + context: SystemContextHandle, transactions: &mut Vec, transactions_to_send_per_round: u64, ) { @@ -413,7 +412,7 @@ pub trait RunDA< error!("Starting HotShot example!"); let start = Instant::now(); - let (mut event_stream, _streamid) = context.get_event_stream(FilterEvent::default()).await; + let mut event_stream = context.get_event_stream(); let mut anchor_view: TYPES::Time = ::genesis(); let mut num_successful_commits = 0; diff 
--git a/crates/hotshot/src/lib.rs b/crates/hotshot/src/lib.rs index 7d93a97449..e948c47879 100644 --- a/crates/hotshot/src/lib.rs +++ b/crates/hotshot/src/lib.rs @@ -20,19 +20,19 @@ use crate::{ traits::{NodeImplementation, Storage}, types::{Event, SystemContextHandle}, }; +use async_broadcast::{broadcast, InactiveReceiver, Receiver, Sender}; use async_compatibility_layer::art::async_spawn; use async_lock::RwLock; use async_trait::async_trait; use commit::Committable; use custom_debug::Debug; use futures::join; -use hotshot_constants::VERSION_0_1; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - task_launcher::TaskRunner, -}; -use hotshot_task_impls::{events::HotShotEvent, network::NetworkTaskKind}; +use hotshot_constants::{EVENT_CHANNEL_SIZE, VERSION_0_1}; +use hotshot_task_impls::events::HotShotEvent; +use hotshot_task_impls::helpers::broadcast_event; +use hotshot_task_impls::network; +use hotshot_task::task::TaskRegistry; use hotshot_types::{ consensus::{Consensus, ConsensusMetricsValue, View, ViewInner}, data::Leaf, @@ -60,7 +60,7 @@ use std::{ time::Duration, }; use tasks::add_vid_task; -use tracing::{debug, info, instrument, trace}; +use tracing::{debug, instrument, trace}; // -- Rexports // External @@ -122,7 +122,7 @@ pub struct SystemContextInner> { private_key: ::PrivateKey, /// Configuration items for this hotshot instance - config: HotShotConfig, + pub config: HotShotConfig, /// This `HotShot` instance's storage backend storage: I::Storage, @@ -141,13 +141,16 @@ pub struct SystemContextInner> { // global_registry: GlobalRegistry, /// Access to the output event stream. - output_event_stream: ChannelStream>, + pub output_event_stream: (Sender>, InactiveReceiver>), /// access to the internal event stream, in case we need to, say, shut something down - internal_event_stream: ChannelStream>, + internal_event_stream: ( + Sender>, + InactiveReceiver>, + ), /// uid for instrumentation - id: u64, + pub id: u64, } /// Thread safe, shared view of a `HotShot` @@ -236,6 +239,13 @@ impl> SystemContext { }; let consensus = Arc::new(RwLock::new(consensus)); + let (internal_tx, internal_rx) = broadcast(EVENT_CHANNEL_SIZE); + let (mut external_tx, external_rx) = broadcast(EVENT_CHANNEL_SIZE); + + // This makes it so we won't block on broadcasting if there is not a receiver + // Our own copy of the receiver is inactive so it doesn't count. 
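// Editorial sketch of the two async-broadcast features being combined here,
// assuming only the 0.6 API this patch already uses (`deactivate`,
// `set_await_active`, `activate_cloned`, `broadcast_direct`): an
// `InactiveReceiver` keeps the channel open without counting as a listener,
// and disabling `await_active` lets a broadcast complete even while nobody
// is subscribed yet.
//
//     let (mut tx, rx) = async_broadcast::broadcast::<u32>(8);
//     let inactive = rx.deactivate();    // channel stays open; zero active listeners
//     tx.set_await_active(false);        // sends no longer park waiting for a listener
//     tx.broadcast_direct(0).await.ok(); // completes immediately instead of blocking
//     let mut sub = inactive.activate_cloned(); // a subscriber attaches later
//     tx.broadcast_direct(1).await.ok();
//     assert_eq!(sub.recv().await.unwrap(), 1); // receives events sent after activation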
+ external_tx.set_await_active(false); + let inner: Arc> = Arc::new(SystemContextInner { id: nonce, consensus, @@ -246,21 +256,27 @@ impl> SystemContext { networks: Arc::new(networks), memberships: Arc::new(memberships), _metrics: consensus_metrics.clone(), - internal_event_stream: ChannelStream::new(), - output_event_stream: ChannelStream::new(), + internal_event_stream: (internal_tx, internal_rx.deactivate()), + output_event_stream: (external_tx, external_rx.deactivate()), }); Ok(Self { inner }) } /// "Starts" consensus by sending a `QCFormed` event + /// + /// # Panics + /// Panics if sending genesis fails pub async fn start_consensus(&self) { + debug!("Starting Consensus"); self.inner .internal_event_stream - .publish(HotShotEvent::QCFormed(either::Left( + .0 + .broadcast_direct(HotShotEvent::QCFormed(either::Left( QuorumCertificate::genesis(), ))) - .await; + .await + .expect("Genesis Broadcast failed"); } /// Emit an external event @@ -268,7 +284,7 @@ impl> SystemContext { // TODO: remove with https://github.com/EspressoSystems/HotShot/issues/2407 async fn send_external_event(&self, event: Event) { debug!(?event, "send_external_event"); - self.inner.output_event_stream.publish(event).await; + broadcast_event(event, &self.inner.output_event_stream.0).await; } /// Publishes a transaction asynchronously to the network @@ -397,7 +413,8 @@ impl> SystemContext { ) -> Result< ( SystemContextHandle, - ChannelStream>, + Sender>, + Receiver>, ), HotShotError, > { @@ -415,9 +432,9 @@ impl> SystemContext { ) .await?; let handle = hotshot.clone().run_tasks().await; - let internal_event_stream = hotshot.inner.internal_event_stream.clone(); + let (tx, rx) = hotshot.inner.internal_event_stream.clone(); - Ok((handle, internal_event_stream)) + Ok((handle, tx, rx.activate())) } /// return the timeout for a view for `self` #[must_use] @@ -439,8 +456,7 @@ impl> SystemContext { #[allow(clippy::too_many_lines)] pub async fn run_tasks(self) -> SystemContextHandle { // ED Need to set first first number to 1, or properly trigger the change upon start - let task_runner = TaskRunner::new(); - let registry = task_runner.registry.clone(); + let registry = Arc::new(TaskRegistry::default()); let output_event_stream = self.inner.output_event_stream.clone(); let internal_event_stream = self.inner.internal_event_stream.clone(); @@ -452,80 +468,97 @@ impl> SystemContext { let vid_membership = self.inner.memberships.vid_membership.clone(); let view_sync_membership = self.inner.memberships.view_sync_membership.clone(); + let (event_tx, event_rx) = internal_event_stream.clone(); + let handle = SystemContextHandle { - registry, + registry: registry.clone(), output_event_stream: output_event_stream.clone(), internal_event_stream: internal_event_stream.clone(), hotshot: self.clone(), storage: self.inner.storage.clone(), }; - let task_runner = add_network_message_task( - task_runner, - internal_event_stream.clone(), - quorum_network.clone(), - ) - .await; - let task_runner = add_network_message_task( - task_runner, - internal_event_stream.clone(), - da_network.clone(), - ) - .await; + add_network_message_task(registry.clone(), event_tx.clone(), quorum_network.clone()).await; + add_network_message_task(registry.clone(), event_tx.clone(), da_network.clone()).await; - let task_runner = add_network_event_task( - task_runner, - internal_event_stream.clone(), + add_network_event_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), quorum_network.clone(), quorum_membership, - NetworkTaskKind::Quorum, + 
network::quorum_filter, ) .await; - let task_runner = add_network_event_task( - task_runner, - internal_event_stream.clone(), + add_network_event_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), da_network.clone(), da_membership, - NetworkTaskKind::Committee, + network::committee_filter, ) .await; - let task_runner = add_network_event_task( - task_runner, - internal_event_stream.clone(), + add_network_event_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), quorum_network.clone(), view_sync_membership, - NetworkTaskKind::ViewSync, + network::view_sync_filter, ) .await; - let task_runner = add_network_event_task( - task_runner, - internal_event_stream.clone(), + add_network_event_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), quorum_network.clone(), vid_membership, - NetworkTaskKind::VID, + network::vid_filter, ) .await; - let task_runner = add_consensus_task( - task_runner, - internal_event_stream.clone(), - output_event_stream.clone(), - handle.clone(), + add_consensus_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), + &handle, + ) + .await; + add_da_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), + &handle, + ) + .await; + add_vid_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), + &handle, + ) + .await; + add_transaction_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), + &handle, + ) + .await; + add_view_sync_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), + &handle, + ) + .await; + add_upgrade_task( + registry.clone(), + event_tx.clone(), + event_rx.activate_cloned(), + &handle, ) .await; - let task_runner = - add_da_task(task_runner, internal_event_stream.clone(), handle.clone()).await; - let task_runner = - add_vid_task(task_runner, internal_event_stream.clone(), handle.clone()).await; - let task_runner = - add_transaction_task(task_runner, internal_event_stream.clone(), handle.clone()).await; - let task_runner = - add_view_sync_task(task_runner, internal_event_stream.clone(), handle.clone()).await; - let task_runner = - add_upgrade_task(task_runner, internal_event_stream.clone(), handle.clone()).await; - async_spawn(async move { - let _ = task_runner.launch().await; - info!("Task runner exited!"); - }); handle } } @@ -563,7 +596,7 @@ impl> ConsensusApi async fn send_event(&self, event: Event) { debug!(?event, "send_event"); - self.inner.output_event_stream.publish(event).await; + broadcast_event(event, &self.inner.output_event_stream.0).await; } fn public_key(&self) -> &TYPES::SignatureKey { diff --git a/crates/hotshot/src/tasks/mod.rs b/crates/hotshot/src/tasks/mod.rs index 05c053f09a..aafa3b1d5b 100644 --- a/crates/hotshot/src/tasks/mod.rs +++ b/crates/hotshot/src/tasks/mod.rs @@ -1,30 +1,19 @@ //! 
Provides a number of tasks that run continuously use crate::{types::SystemContextHandle, HotShotConsensusApi}; -use async_compatibility_layer::art::async_sleep; -use futures::FutureExt; -use hotshot_task::{ - boxed_sync, - event_stream::ChannelStream, - task::{FilterEvent, HandleEvent, HandleMessage, HotShotTaskCompleted, HotShotTaskTypes}, - task_impls::TaskBuilder, - task_launcher::TaskRunner, - GeneratedStream, Merge, -}; +use async_broadcast::{Receiver, Sender}; +use async_compatibility_layer::art::{async_sleep, async_spawn}; + +use hotshot_task::task::{Task, TaskRegistry}; use hotshot_task_impls::{ - consensus::{ - consensus_event_filter, CommitmentAndMetadata, ConsensusTaskState, ConsensusTaskTypes, - }, - da::{DATaskState, DATaskTypes}, + consensus::{CommitmentAndMetadata, ConsensusTaskState}, + da::DATaskState, events::HotShotEvent, - network::{ - NetworkEventTaskState, NetworkEventTaskTypes, NetworkMessageTaskState, - NetworkMessageTaskTypes, NetworkTaskKind, - }, - transactions::{TransactionTaskState, TransactionsTaskTypes}, - upgrade::{UpgradeTaskState, UpgradeTaskTypes}, - vid::{VIDTaskState, VIDTaskTypes}, - view_sync::{ViewSyncTaskState, ViewSyncTaskStateTypes}, + network::{NetworkEventTaskState, NetworkMessageTaskState}, + transactions::TransactionTaskState, + upgrade::UpgradeTaskState, + vid::VIDTaskState, + view_sync::ViewSyncTaskState, }; use hotshot_types::traits::election::Membership; use hotshot_types::{ @@ -56,165 +45,96 @@ pub enum GlobalEvent { } /// Add the network task to handle messages and publish events. -/// # Panics -/// Is unable to panic. This section here is just to satisfy clippy pub async fn add_network_message_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, + task_reg: Arc, + event_stream: Sender>, channel: NET, -) -> TaskRunner { +) { let net = channel.clone(); - let broadcast_stream = GeneratedStream::>::new(Arc::new(move || { - let network = net.clone(); - let closure = async move { - loop { - let msgs = match network.recv_msgs(TransmitType::Broadcast).await { - Ok(msgs) => Messages(msgs), - Err(err) => { - error!("failed to receive broadcast messages: {err}"); + let network_state: NetworkMessageTaskState<_> = NetworkMessageTaskState { + event_stream: event_stream.clone(), + }; - // return zero messages so we sleep and try again - Messages(vec![]) - } - }; + // TODO we don't need two async tasks for this, we should combine the + // by getting rid of `TransmitType` + // https://github.com/EspressoSystems/HotShot/issues/2377 + let network = net.clone(); + let mut state = network_state.clone(); + let broadcast_handle = async_spawn(async move { + loop { + let msgs = match network.recv_msgs(TransmitType::Broadcast).await { + Ok(msgs) => Messages(msgs), + Err(err) => { + error!("failed to receive broadcast messages: {err}"); - if msgs.0.is_empty() { - async_sleep(Duration::from_millis(100)).await; - } else { - break msgs; + // return zero messages so we sleep and try again + Messages(vec![]) } + }; + if msgs.0.is_empty() { + // TODO: Stop sleeping here: https://github.com/EspressoSystems/HotShot/issues/2558 + async_sleep(Duration::from_millis(100)).await; + } else { + state.handle_messages(msgs.0).await; } - }; - Some(boxed_sync(closure)) - })); - let net = channel.clone(); - let direct_stream = GeneratedStream::>::new(Arc::new(move || { - let network = net.clone(); - let closure = async move { - loop { - let msgs = match network.recv_msgs(TransmitType::Direct).await { - Ok(msgs) => Messages(msgs), - Err(err) => { - error!("failed to 
receive direct messages: {err}"); + } + }); + let network = net.clone(); + let mut state = network_state.clone(); + let direct_handle = async_spawn(async move { + loop { + let msgs = match network.recv_msgs(TransmitType::Direct).await { + Ok(msgs) => Messages(msgs), + Err(err) => { + error!("failed to receive direct messages: {err}"); - // return zero messages so we sleep and try again - Messages(vec![]) - } - }; - if msgs.0.is_empty() { - async_sleep(Duration::from_millis(100)).await; - } else { - break msgs; + // return zero messages so we sleep and try again + Messages(vec![]) } - } - }; - Some(boxed_sync(closure)) - })); - let message_stream = Merge::new(broadcast_stream, direct_stream); - let network_state: NetworkMessageTaskState<_> = NetworkMessageTaskState { - event_stream: event_stream.clone(), - }; - let registry = task_runner.registry.clone(); - let network_message_handler = HandleMessage(Arc::new( - move |messages: either::Either, Messages>, - mut state: NetworkMessageTaskState| { - let messages = match messages { - either::Either::Left(messages) | either::Either::Right(messages) => messages, }; - async move { - state.handle_messages(messages.0).await; - (None, state) + if msgs.0.is_empty() { + // TODO: Stop sleeping here: https://github.com/EspressoSystems/HotShot/issues/2558 + async_sleep(Duration::from_millis(100)).await; + } else { + state.handle_messages(msgs.0).await; } - .boxed() - }, - )); - let networking_name = "Networking Task"; - - let networking_task_builder = - TaskBuilder::>::new(networking_name.to_string()) - .register_message_stream(message_stream) - .register_registry(&mut registry.clone()) - .await - .register_state(network_state) - .register_message_handler(network_message_handler); - - // impossible for unwraps to fail - // we *just* registered - let networking_task_id = networking_task_builder.get_task_id().unwrap(); - let networking_task = NetworkMessageTaskTypes::build(networking_task_builder).launch(); - - task_runner.add_task( - networking_task_id, - networking_name.to_string(), - networking_task, - ) + } + }); + task_reg.register(direct_handle).await; + task_reg.register(broadcast_handle).await; } /// Add the network task to handle events and send messages. -/// # Panics -/// Is unable to panic. 
This section here is just to satisfy clippy pub async fn add_network_event_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, + task_reg: Arc, + tx: Sender>, + rx: Receiver>, channel: NET, membership: TYPES::Membership, - task_kind: NetworkTaskKind, -) -> TaskRunner { - let filter = NetworkEventTaskState::::filter(task_kind); + filter: fn(&HotShotEvent) -> bool, +) { let network_state: NetworkEventTaskState<_, _> = NetworkEventTaskState { channel, - event_stream: event_stream.clone(), view: TYPES::Time::genesis(), + membership, + filter, }; - let registry = task_runner.registry.clone(); - let network_event_handler = HandleEvent(Arc::new( - move |event, mut state: NetworkEventTaskState<_, _>| { - let mem = membership.clone(); - - async move { - let completion_status = state.handle_event(event, &mem).await; - (completion_status, state) - } - .boxed() - }, - )); - let networking_name = "Networking Task"; - - let networking_task_builder = - TaskBuilder::>::new(networking_name.to_string()) - .register_event_stream(event_stream.clone(), filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(network_state) - .register_event_handler(network_event_handler); - - // impossible for unwraps to fail - // we *just* registered - let networking_task_id = networking_task_builder.get_task_id().unwrap(); - let networking_task = NetworkEventTaskTypes::build(networking_task_builder).launch(); - - task_runner.add_task( - networking_task_id, - networking_name.to_string(), - networking_task, - ) + let task = Task::new(tx, rx, task_reg.clone(), network_state); + task_reg.run_task(task).await; } -/// add the consensus task +/// Create the consensus task state /// # Panics -/// Is unable to panic. This section here is just to satisfy clippy -pub async fn add_consensus_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, - output_stream: ChannelStream>, - handle: SystemContextHandle, -) -> TaskRunner { +/// If genesis payload can't be encoded. This should not be possible +pub async fn create_consensus_state>( + output_stream: Sender>, + handle: &SystemContextHandle, +) -> ConsensusTaskState> { let consensus = handle.hotshot.get_consensus(); let c_api: HotShotConsensusApi = HotShotConsensusApi { inner: handle.hotshot.inner.clone(), }; - let registry = task_runner.registry.clone(); + let (payload, metadata) = ::genesis(); // Impossible for `unwrap` to fail on the genesis payload. 
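// (Editorial note: the genesis payload carries no transactions, which is why
// the `unwrap` here cannot fail. Under the helpers this patch already uses
// (`BlockPayload::encode`, `vid_commitment`), the computation below amounts
// to something like the following sketch, where `quorum_membership` is
// illustrative:
//
//     let encoded = payload.encode().unwrap().collect::<Vec<u8>>();
//     let payload_commitment = vid_commitment(&encoded, quorum_membership.total_nodes());
//
// i.e. a VID commitment over the encoded genesis payload, sized to the number
// of storage nodes.)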
let payload_commitment = vid_commitment( @@ -228,7 +148,6 @@ pub async fn add_consensus_task>( ); // build the consensus task let consensus_state = ConsensusTaskState { - registry: registry.clone(), consensus, timeout: handle.hotshot.inner.config.next_view_timeout, cur_view: TYPES::Time::new(0), @@ -243,7 +162,6 @@ pub async fn add_consensus_task>( timeout_vote_collector: None.into(), timeout_task: None, timeout_cert: None, - event_stream: event_stream.clone(), output_event_stream: output_stream, vid_shares: HashMap::new(), current_proposal: None, @@ -283,58 +201,34 @@ pub async fn add_consensus_task>( .quorum_network .inject_consensus_info(ConsensusIntentEvent::PollForLatestViewSyncCertificate) .await; + consensus_state +} - let filter = FilterEvent(Arc::new(consensus_event_filter)); - let consensus_name = "Consensus Task"; - let consensus_event_handler = HandleEvent(Arc::new( - move |event, mut state: ConsensusTaskState>| { - async move { - if let HotShotEvent::Shutdown = event { - (Some(HotShotTaskCompleted::ShutDown), state) - } else { - state.handle_event(event).await; - (None, state) - } - } - .boxed() - }, - )); - let consensus_task_builder = TaskBuilder::< - ConsensusTaskTypes>, - >::new(consensus_name.to_string()) - .register_event_stream(event_stream.clone(), filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(consensus_state) - .register_event_handler(consensus_event_handler); - // impossible for unwrap to fail - // we *just* registered - let consensus_task_id = consensus_task_builder.get_task_id().unwrap(); - let consensus_task = ConsensusTaskTypes::build(consensus_task_builder).launch(); - - task_runner.add_task( - consensus_task_id, - consensus_name.to_string(), - consensus_task, - ) +/// add the consensus task +pub async fn add_consensus_task>( + task_reg: Arc, + tx: Sender>, + rx: Receiver>, + handle: &SystemContextHandle, +) { + let state = + create_consensus_state(handle.hotshot.inner.output_event_stream.0.clone(), handle).await; + let task = Task::new(tx, rx, task_reg.clone(), state); + task_reg.run_task(task).await; } /// add the VID task -/// # Panics -/// Is unable to panic. 
This section here is just to satisfy clippy pub async fn add_vid_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, - handle: SystemContextHandle, -) -> TaskRunner { + task_reg: Arc, + tx: Sender>, + rx: Receiver>, + handle: &SystemContextHandle, +) { // build the vid task let c_api: HotShotConsensusApi = HotShotConsensusApi { inner: handle.hotshot.inner.clone(), }; - let registry = task_runner.registry.clone(); let vid_state = VIDTaskState { - registry: registry.clone(), api: c_api.clone(), consensus: handle.hotshot.get_consensus(), cur_view: TYPES::Time::new(0), @@ -343,38 +237,11 @@ pub async fn add_vid_task>( membership: c_api.inner.memberships.vid_membership.clone().into(), public_key: c_api.public_key().clone(), private_key: c_api.private_key().clone(), - event_stream: event_stream.clone(), id: handle.hotshot.inner.id, }; - let vid_event_handler = HandleEvent(Arc::new( - move |event, mut state: VIDTaskState>| { - async move { - let completion_status = state.handle_event(event).await; - (completion_status, state) - } - .boxed() - }, - )); - let vid_name = "VID Task"; - let vid_event_filter = FilterEvent(Arc::new( - VIDTaskState::>::filter, - )); - let vid_task_builder = - TaskBuilder::>>::new( - vid_name.to_string(), - ) - .register_event_stream(event_stream.clone(), vid_event_filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(vid_state) - .register_event_handler(vid_event_handler); - // impossible for unwrap to fail - // we *just* registered - let vid_task_id = vid_task_builder.get_task_id().unwrap(); - let vid_task = VIDTaskTypes::build(vid_task_builder).launch(); - task_runner.add_task(vid_task_id, vid_name.to_string(), vid_task) + let task = Task::new(tx, rx, task_reg.clone(), vid_state); + task_reg.run_task(task).await; } /// add the Upgrade task. @@ -383,72 +250,41 @@ pub async fn add_vid_task>( /// /// Uses .`unwrap()`, though this should never panic. 
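// Editorial note: every add_*_task function below now follows one wiring
// pattern — build the task's state struct, pair it with the shared broadcast
// channels, and hand it to the registry, which spawns the task and keeps its
// handle for shutdown. Distilled from the calls in this file (`Task` and
// `TaskRegistry` come from the external hotshot-task crate):
//
//     let task = Task::new(tx, rx, task_reg.clone(), state);
//     task_reg.run_task(task).await;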
pub async fn add_upgrade_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, - handle: SystemContextHandle, -) -> TaskRunner { + task_reg: Arc, + tx: Sender>, + rx: Receiver>, + handle: &SystemContextHandle, +) { let c_api: HotShotConsensusApi = HotShotConsensusApi { inner: handle.hotshot.inner.clone(), }; - let registry = task_runner.registry.clone(); let upgrade_state = UpgradeTaskState { api: c_api.clone(), - registry: registry.clone(), cur_view: TYPES::Time::new(0), quorum_membership: c_api.inner.memberships.quorum_membership.clone().into(), quorum_network: c_api.inner.networks.quorum_network.clone().into(), should_vote: |_upgrade_proposal| false, vote_collector: None.into(), - event_stream: event_stream.clone(), public_key: c_api.public_key().clone(), private_key: c_api.private_key().clone(), id: handle.hotshot.inner.id, }; - let upgrade_event_handler = HandleEvent(Arc::new( - move |event, mut state: UpgradeTaskState>| { - async move { - let completion_status = state.handle_event(event).await; - (completion_status, state) - } - .boxed() - }, - )); - let upgrade_name = "Upgrade Task"; - let upgrade_event_filter = FilterEvent(Arc::new( - UpgradeTaskState::>::filter, - )); - - let upgrade_task_builder = TaskBuilder::< - UpgradeTaskTypes>, - >::new(upgrade_name.to_string()) - .register_event_stream(event_stream.clone(), upgrade_event_filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(upgrade_state) - .register_event_handler(upgrade_event_handler); - // impossible for unwrap to fail - // we *just* registered - let upgrade_task_id = upgrade_task_builder.get_task_id().unwrap(); - let upgrade_task = UpgradeTaskTypes::build(upgrade_task_builder).launch(); - task_runner.add_task(upgrade_task_id, upgrade_name.to_string(), upgrade_task) + let task = Task::new(tx, rx, task_reg.clone(), upgrade_state); + task_reg.run_task(task).await; } /// add the Data Availability task -/// # Panics -/// Is unable to panic. 
This section here is just to satisfy clippy pub async fn add_da_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, - handle: SystemContextHandle, -) -> TaskRunner { + task_reg: Arc, + tx: Sender>, + rx: Receiver>, + handle: &SystemContextHandle, +) { // build the da task let c_api: HotShotConsensusApi = HotShotConsensusApi { inner: handle.hotshot.inner.clone(), }; - let registry = task_runner.registry.clone(); let da_state = DATaskState { - registry: registry.clone(), api: c_api.clone(), consensus: handle.hotshot.get_consensus(), da_membership: c_api.inner.memberships.da_membership.clone().into(), @@ -456,56 +292,27 @@ pub async fn add_da_task>( quorum_membership: c_api.inner.memberships.quorum_membership.clone().into(), cur_view: TYPES::Time::new(0), vote_collector: None.into(), - event_stream: event_stream.clone(), public_key: c_api.public_key().clone(), private_key: c_api.private_key().clone(), id: handle.hotshot.inner.id, }; - let da_event_handler = HandleEvent(Arc::new( - move |event, mut state: DATaskState>| { - async move { - let completion_status = state.handle_event(event).await; - (completion_status, state) - } - .boxed() - }, - )); - let da_name = "DA Task"; - let da_event_filter = FilterEvent(Arc::new( - DATaskState::>::filter, - )); - let da_task_builder = TaskBuilder::>>::new( - da_name.to_string(), - ) - .register_event_stream(event_stream.clone(), da_event_filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(da_state) - .register_event_handler(da_event_handler); - // impossible for unwrap to fail - // we *just* registered - let da_task_id = da_task_builder.get_task_id().unwrap(); - let da_task = DATaskTypes::build(da_task_builder).launch(); - task_runner.add_task(da_task_id, da_name.to_string(), da_task) + let task = Task::new(tx, rx, task_reg.clone(), da_state); + task_reg.run_task(task).await; } /// add the Transaction Handling task -/// # Panics -/// Is unable to panic. 
This section here is just to satisfy clippy pub async fn add_transaction_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, - handle: SystemContextHandle, -) -> TaskRunner { + task_reg: Arc, + tx: Sender>, + rx: Receiver>, + handle: &SystemContextHandle, +) { // build the transactions task let c_api: HotShotConsensusApi = HotShotConsensusApi { inner: handle.hotshot.inner.clone(), }; - let registry = task_runner.registry.clone(); let transactions_state = TransactionTaskState { - registry: registry.clone(), api: c_api.clone(), consensus: handle.hotshot.get_consensus(), transactions: Arc::default(), @@ -515,53 +322,24 @@ pub async fn add_transaction_task> membership: c_api.inner.memberships.quorum_membership.clone().into(), public_key: c_api.public_key().clone(), private_key: c_api.private_key().clone(), - event_stream: event_stream.clone(), id: handle.hotshot.inner.id, }; - let transactions_event_handler = HandleEvent(Arc::new( - move |event, mut state: TransactionTaskState>| { - async move { - let completion_status = state.handle_event(event).await; - (completion_status, state) - } - .boxed() - }, - )); - let transactions_name = "Transactions Task"; - let transactions_event_filter = FilterEvent(Arc::new( - TransactionTaskState::>::filter, - )); - let transactions_task_builder = TaskBuilder::< - TransactionsTaskTypes>, - >::new(transactions_name.to_string()) - .register_event_stream(event_stream.clone(), transactions_event_filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(transactions_state) - .register_event_handler(transactions_event_handler); - // impossible for unwrap to fail - // we *just* registered - let da_task_id = transactions_task_builder.get_task_id().unwrap(); - let da_task = TransactionsTaskTypes::build(transactions_task_builder).launch(); - task_runner.add_task(da_task_id, transactions_name.to_string(), da_task) + let task = Task::new(tx, rx, task_reg.clone(), transactions_state); + task_reg.run_task(task).await; } /// add the view sync task -/// # Panics -/// Is unable to panic. 
This section here is just to satisfy clippy pub async fn add_view_sync_task>( - task_runner: TaskRunner, - event_stream: ChannelStream>, - handle: SystemContextHandle, -) -> TaskRunner { + task_reg: Arc, + tx: Sender>, + rx: Receiver>, + handle: &SystemContextHandle, +) { let api = HotShotConsensusApi { inner: handle.hotshot.inner.clone(), }; // build the view sync task let view_sync_state = ViewSyncTaskState { - registry: task_runner.registry.clone(), - event_stream: event_stream.clone(), current_view: TYPES::Time::new(0), next_view: TYPES::Time::new(0), network: api.inner.networks.quorum_network.clone().into(), @@ -578,42 +356,7 @@ pub async fn add_view_sync_task>( id: handle.hotshot.inner.id, last_garbage_collected_view: TYPES::Time::new(0), }; - let registry = task_runner.registry.clone(); - let view_sync_event_handler = HandleEvent(Arc::new( - move |event, mut state: ViewSyncTaskState>| { - async move { - if let HotShotEvent::Shutdown = event { - (Some(HotShotTaskCompleted::ShutDown), state) - } else { - state.handle_event(event).await; - (None, state) - } - } - .boxed() - }, - )); - let view_sync_name = "ViewSync Task"; - let view_sync_event_filter = FilterEvent(Arc::new( - ViewSyncTaskState::>::filter, - )); - - let view_sync_task_builder = TaskBuilder::< - ViewSyncTaskStateTypes>, - >::new(view_sync_name.to_string()) - .register_event_stream(event_stream.clone(), view_sync_event_filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(view_sync_state) - .register_event_handler(view_sync_event_handler); - // impossible for unwrap to fail - // we *just* registered - let view_sync_task_id = view_sync_task_builder.get_task_id().unwrap(); - let view_sync_task = ViewSyncTaskStateTypes::build(view_sync_task_builder).launch(); - task_runner.add_task( - view_sync_task_id, - view_sync_name.to_string(), - view_sync_task, - ) + let task = Task::new(tx, rx, task_reg.clone(), view_sync_state); + task_reg.run_task(task).await; } diff --git a/crates/hotshot/src/traits/networking/combined_network.rs b/crates/hotshot/src/traits/networking/combined_network.rs index 81471dfe1f..9f061e8bd4 100644 --- a/crates/hotshot/src/traits/networking/combined_network.rs +++ b/crates/hotshot/src/traits/networking/combined_network.rs @@ -19,10 +19,10 @@ use async_trait::async_trait; use futures::join; use async_compatibility_layer::channel::UnboundedSendError; -use hotshot_task::{boxed_sync, BoxSyncFuture}; #[cfg(feature = "hotshot-testing")] use hotshot_types::traits::network::{NetworkReliability, TestableNetworkingImplementation}; use hotshot_types::{ + boxed_sync, data::ViewNumber, message::Message, traits::{ @@ -33,6 +33,7 @@ use hotshot_types::{ }, node_implementation::NodeType, }, + BoxSyncFuture, }; use std::{collections::hash_map::DefaultHasher, sync::Arc}; diff --git a/crates/hotshot/src/traits/networking/libp2p_network.rs b/crates/hotshot/src/traits/networking/libp2p_network.rs index 19f384d75c..9a8162c6ba 100644 --- a/crates/hotshot/src/traits/networking/libp2p_network.rs +++ b/crates/hotshot/src/traits/networking/libp2p_network.rs @@ -13,10 +13,10 @@ use async_trait::async_trait; use bimap::BiHashMap; use bincode::Options; use hotshot_constants::{Version, LOOK_AHEAD, VERSION_0_1}; -use hotshot_task::{boxed_sync, BoxSyncFuture}; #[cfg(feature = "hotshot-testing")] use hotshot_types::traits::network::{NetworkReliability, TestableNetworkingImplementation}; use hotshot_types::{ + boxed_sync, data::ViewNumber, message::{Message, MessageKind}, traits::{ @@ -28,6 +28,7 @@ use 
hotshot_types::{ node_implementation::{ConsensusTime, NodeType}, signature_key::SignatureKey, }, + BoxSyncFuture, }; use hotshot_utils::{bincode::bincode_opts, version::read_version}; use libp2p_identity::PeerId; diff --git a/crates/hotshot/src/traits/networking/memory_network.rs b/crates/hotshot/src/traits/networking/memory_network.rs index 3e28a5871e..ae9f1c53dc 100644 --- a/crates/hotshot/src/traits/networking/memory_network.rs +++ b/crates/hotshot/src/traits/networking/memory_network.rs @@ -13,8 +13,8 @@ use async_trait::async_trait; use bincode::Options; use dashmap::DashMap; use futures::StreamExt; -use hotshot_task::{boxed_sync, BoxSyncFuture}; use hotshot_types::{ + boxed_sync, message::{Message, MessageKind}, traits::{ election::Membership, @@ -25,6 +25,7 @@ use hotshot_types::{ node_implementation::NodeType, signature_key::SignatureKey, }, + BoxSyncFuture, }; use hotshot_utils::bincode::bincode_opts; use rand::Rng; @@ -300,7 +301,7 @@ impl ConnectedNetwork for Memory message: M, recipients: BTreeSet, ) -> Result<(), NetworkError> { - debug!(?message, "Broadcasting message"); + trace!(?message, "Broadcasting message"); // Bincode the message let vec = bincode_opts() .serialize(&message) @@ -348,7 +349,7 @@ impl ConnectedNetwork for Memory #[instrument(name = "MemoryNetwork::direct_message")] async fn direct_message(&self, message: M, recipient: K) -> Result<(), NetworkError> { - debug!(?message, ?recipient, "Sending direct message"); + // debug!(?message, ?recipient, "Sending direct message"); // Bincode the message let vec = bincode_opts() .serialize(&message) diff --git a/crates/hotshot/src/traits/networking/web_server_network.rs b/crates/hotshot/src/traits/networking/web_server_network.rs index 604e3f109d..a5dcb90a30 100644 --- a/crates/hotshot/src/traits/networking/web_server_network.rs +++ b/crates/hotshot/src/traits/networking/web_server_network.rs @@ -13,8 +13,8 @@ use async_lock::RwLock; use async_trait::async_trait; use derive_more::{Deref, DerefMut}; use hotshot_constants::VERSION_0_1; -use hotshot_task::{boxed_sync, BoxSyncFuture}; use hotshot_types::{ + boxed_sync, message::{Message, MessagePurpose}, traits::{ network::{ @@ -25,6 +25,7 @@ use hotshot_types::{ node_implementation::NodeType, signature_key::SignatureKey, }, + BoxSyncFuture, }; use hotshot_utils::version::read_version; use hotshot_web_server::{self, config}; diff --git a/crates/hotshot/src/types/handle.rs b/crates/hotshot/src/types/handle.rs index 862fa8d27a..fc322a0159 100644 --- a/crates/hotshot/src/types/handle.rs +++ b/crates/hotshot/src/types/handle.rs @@ -1,20 +1,17 @@ //! 
Provides an event-streaming handle for a [`SystemContext`] running in the background

 use crate::{traits::NodeImplementation, types::Event, SystemContext};
-use async_compatibility_layer::channel::UnboundedStream;
+use async_broadcast::{InactiveReceiver, Receiver, Sender};
+
 use async_lock::RwLock;
 use futures::Stream;
-use hotshot_task::{
-    boxed_sync,
-    event_stream::{ChannelStream, EventStream, StreamId},
-    global_registry::GlobalRegistry,
-    task::FilterEvent,
-    BoxSyncFuture,
-};
+
 use hotshot_task_impls::events::HotShotEvent;
 #[cfg(feature = "hotshot-testing")]
 use hotshot_types::traits::election::Membership;
+use hotshot_task::task::TaskRegistry;
+use hotshot_types::{boxed_sync, BoxSyncFuture};
 use hotshot_types::{
     consensus::Consensus, data::Leaf, error::HotShotError, traits::node_implementation::NodeType,
 };
@@ -25,13 +22,19 @@ use std::sync::Arc;
 /// This type provides the means to message and interact with a background [`SystemContext`] instance,
 /// allowing the ability to receive [`Event`]s from it, send transactions to it, and interact with
 /// the underlying storage.
+#[derive(Clone)]
 pub struct SystemContextHandle<TYPES: NodeType, I: NodeImplementation<TYPES>> {
-    /// The [sender](ChannelStream) for the output stream from the background process
-    pub(crate) output_event_stream: ChannelStream<Event<TYPES>>,
-    /// access to the internal event stream, in case we need to, say, shut something down
-    pub(crate) internal_event_stream: ChannelStream<HotShotEvent<TYPES>>,
+    /// The [sender](Sender) and an `InactiveReceiver` to keep the channel open.
+    /// The channel will output all the events. Subscribers will get an activated
+    /// clone of the `Receiver` when they request the output stream.
+    pub(crate) output_event_stream: (Sender<Event<TYPES>>, InactiveReceiver<Event<TYPES>>),
+    /// access to the internal event stream, in case we need to, say, shut something down
+    pub(crate) internal_event_stream: (
+        Sender<HotShotEvent<TYPES>>,
+        InactiveReceiver<HotShotEvent<TYPES>>,
+    ),
     /// registry for controlling tasks
-    pub(crate) registry: GlobalRegistry,
+    pub(crate) registry: Arc<TaskRegistry>,

     /// Internal reference to the underlying [`SystemContext`]
     pub hotshot: SystemContext<TYPES, I>,
@@ -40,38 +43,18 @@ pub struct SystemContextHandle<TYPES: NodeType, I: NodeImplementation<TYPES>> {
     pub(crate) storage: I::Storage,
 }

-impl<TYPES: NodeType, I: NodeImplementation<TYPES> + 'static> Clone
-    for SystemContextHandle<TYPES, I>
-{
-    fn clone(&self) -> Self {
-        Self {
-            registry: self.registry.clone(),
-            output_event_stream: self.output_event_stream.clone(),
-            internal_event_stream: self.internal_event_stream.clone(),
-            hotshot: self.hotshot.clone(),
-            storage: self.storage.clone(),
-        }
-    }
-}
-
 impl<TYPES: NodeType, I: NodeImplementation<TYPES> + 'static> SystemContextHandle<TYPES, I> {
     /// obtains a stream to expose to the user
-    pub async fn get_event_stream(
-        &mut self,
-        filter: FilterEvent<Event<TYPES>>,
-    ) -> (impl Stream<Item = Event<TYPES>>, StreamId) {
-        self.output_event_stream.subscribe(filter).await
+    pub fn get_event_stream(&self) -> impl Stream<Item = Event<TYPES>> {
+        self.output_event_stream.1.activate_cloned()
     }

     /// HACK so we can know the types when running tests...
     /// there are two cleaner solutions:
     /// - make the stream generic and in nodetypes or nodeimplementation
     /// - type wrapper
-    pub async fn get_event_stream_known_impl(
-        &mut self,
-        filter: FilterEvent<Event<TYPES>>,
-    ) -> (UnboundedStream<Event<TYPES>>, StreamId) {
-        self.output_event_stream.subscribe(filter).await
+    pub fn get_event_stream_known_impl(&self) -> Receiver<Event<TYPES>> {
+        self.output_event_stream.1.activate_cloned()
     }

     /// HACK so we can know the types when running tests...
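For consumers, the new subscription model introduced in this hunk looks like the sketch below (hypothetical caller code: `handle` is a `SystemContextHandle`, and `StreamExt` from the `futures` crate is assumed). Each call activates an independent clone of the broadcast receiver, so multiple consumers can tail the same event stream without coordinating:

    use futures::StreamExt;

    let mut events = handle.get_event_stream();
    while let Some(event) = events.next().await {
        // React to EventType::Decide, ViewFinished, Error, ...
    }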
@@ -79,11 +62,8 @@ impl + 'static> SystemContextHandl /// - make the stream generic and in nodetypes or nodeimpelmentation /// - type wrapper /// NOTE: this is only used for sanity checks in our tests - pub async fn get_internal_event_stream_known_impl( - &mut self, - filter: FilterEvent>, - ) -> (UnboundedStream>, StreamId) { - self.internal_event_stream.subscribe(filter).await + pub fn get_internal_event_stream_known_impl(&self) -> Receiver> { + self.internal_event_stream.1.activate_cloned() } /// Get the last decided validated state of the [`SystemContext`] instance. @@ -164,7 +144,7 @@ impl + 'static> SystemContextHandl { boxed_sync(async move { self.hotshot.inner.networks.shut_down_networks().await; - self.registry.shutdown_all().await; + self.registry.shutdown().await; }) } diff --git a/crates/task-impls/Cargo.toml b/crates/task-impls/Cargo.toml index e1a40992b3..b236dc23da 100644 --- a/crates/task-impls/Cargo.toml +++ b/crates/task-impls/Cargo.toml @@ -15,13 +15,14 @@ async-lock = { workspace = true } tracing = { workspace = true } hotshot-constants = { path = "../constants", default-features = false } hotshot-types = { path = "../types", default-features = false } -hotshot-task = { path = "../task", default-features = false } hotshot-utils = { path = "../utils" } time = { workspace = true } commit = { workspace = true } bincode = { workspace = true } bitvec = { workspace = true } sha2 = { workspace = true } +hotshot-task = { path = "../task" } +async-broadcast = { workspace = true } [target.'cfg(all(async_executor_impl = "tokio"))'.dependencies] tokio = { workspace = true } diff --git a/crates/task-impls/src/consensus.rs b/crates/task-impls/src/consensus.rs index c1e5d08b6f..ffe19d0a37 100644 --- a/crates/task-impls/src/consensus.rs +++ b/crates/task-impls/src/consensus.rs @@ -1,6 +1,6 @@ use crate::{ - events::HotShotEvent, - helpers::cancel_task, + events::{HotShotEvent, HotShotTaskCompleted}, + helpers::{broadcast_event, cancel_task}, vote::{create_vote_accumulator, AccumulatorInfo, VoteCollectionTaskState}, }; use async_compatibility_layer::art::{async_sleep, async_spawn}; @@ -10,12 +10,10 @@ use async_std::task::JoinHandle; use commit::Committable; use core::time::Duration; use hotshot_constants::LOOK_AHEAD; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - global_registry::GlobalRegistry, - task::{HotShotTaskCompleted, TS}, - task_impls::HSTWithEvent, -}; +use hotshot_task::task::{Task, TaskState}; + +use async_broadcast::Sender; + use hotshot_types::{ consensus::{Consensus, View}, data::{Leaf, QuorumProposal, VidCommitment, VidDisperse}, @@ -77,8 +75,6 @@ pub struct ConsensusTaskState< pub public_key: TYPES::SignatureKey, /// Our Private Key pub private_key: ::PrivateKey, - /// The global task registry - pub registry: GlobalRegistry, /// Reference to consensus. The replica will require a write lock on this. pub consensus: Arc>>, /// View timeout from config. @@ -124,11 +120,8 @@ pub struct ConsensusTaskState< /// last Timeout Certificate this node formed pub timeout_cert: Option>, - /// Global events stream to publish events - pub event_stream: ChannelStream>, - - /// Event stream to publish events to the application layer - pub output_event_stream: ChannelStream>, + /// Output events to application + pub output_event_stream: async_broadcast::Sender>, /// All the VID shares we've received for current and future views. 
/// In the future we will need a different struct similar to VidDisperse except @@ -175,7 +168,7 @@ impl, A: ConsensusApi + #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "Consensus vote if able", level = "error")] // Check if we are able to vote, like whether the proposal is valid, // whether we have DAC and VID share, and if so, vote. - async fn vote_if_able(&mut self) -> bool { + async fn vote_if_able(&mut self, event_stream: &Sender>) -> bool { if !self.quorum_membership.has_stake(&self.public_key) { debug!( "We were not chosen for consensus committee on {:?}", @@ -240,9 +233,7 @@ impl, A: ConsensusApi + "Sending vote to next quorum leader {:?}", vote.get_view_number() + 1 ); - self.event_stream - .publish(HotShotEvent::QuorumVoteSend(vote)) - .await; + broadcast_event(HotShotEvent::QuorumVoteSend(vote), event_stream).await; if let Some(commit_and_metadata) = &self.payload_commitment_and_metadata { if commit_and_metadata.is_genesis { self.payload_commitment_and_metadata = None; @@ -343,9 +334,7 @@ impl, A: ConsensusApi + "Sending vote to next quorum leader {:?}", vote.get_view_number() + 1 ); - self.event_stream - .publish(HotShotEvent::QuorumVoteSend(vote)) - .await; + broadcast_event(HotShotEvent::QuorumVoteSend(vote), event_stream).await; return true; } } @@ -365,7 +354,11 @@ impl, A: ConsensusApi + /// Must only update the view and GC if the view actually changes #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "Consensus update view", level = "error")] - async fn update_view(&mut self, new_view: TYPES::Time) -> bool { + async fn update_view( + &mut self, + new_view: TYPES::Time, + event_stream: &Sender>, + ) -> bool { if *self.cur_view < *new_view { debug!( "Updating view from {} to {} in consensus task", @@ -411,22 +404,22 @@ impl, A: ConsensusApi + .await; } - self.event_stream - .publish(HotShotEvent::ViewChange(new_view)) - .await; + broadcast_event(HotShotEvent::ViewChange(new_view), event_stream).await; // Spawn a timeout task if we did actually update view let timeout = self.timeout; self.timeout_task = Some(async_spawn({ - let stream = self.event_stream.clone(); + let stream = event_stream.clone(); // Nuance: We timeout on the view + 1 here because that means that we have // not seen evidence to transition to this new view let view_number = self.cur_view + 1; async move { async_sleep(Duration::from_millis(timeout)).await; - stream - .publish(HotShotEvent::Timeout(TYPES::Time::new(*view_number))) - .await; + broadcast_event( + HotShotEvent::Timeout(TYPES::Time::new(*view_number)), + &stream, + ) + .await; } })); let consensus = self.consensus.read().await; @@ -446,7 +439,11 @@ impl, A: ConsensusApi + /// Handles a consensus event received on the event stream #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "Consensus replica task", level = "error")] - pub async fn handle_event(&mut self, event: HotShotEvent) { + pub async fn handle( + &mut self, + event: HotShotEvent, + event_stream: Sender>, + ) { match event { HotShotEvent::QuorumProposalRecv(proposal, sender) => { debug!( @@ -503,7 +500,7 @@ impl, A: ConsensusApi + } // NOTE: We could update our view with a valid TC but invalid QC, but that is not what we do here - self.update_view(view).await; + self.update_view(view, &event_stream).await; let consensus = self.consensus.upgradable_read().await; @@ -513,16 +510,18 @@ impl, A: ConsensusApi + let leaf = self.genesis_leaf().await; match leaf { Some(ref leaf) => { - self.output_event_stream - 
.publish(Event { + broadcast_event( + Event { view_number: TYPES::Time::genesis(), event: EventType::Decide { leaf_chain: Arc::new(vec![leaf.clone()]), qc: Arc::new(justify_qc.clone()), block_size: None, }, - }) - .await; + }, + &self.output_event_stream, + ) + .await; } None => { error!( @@ -598,10 +597,10 @@ impl, A: ConsensusApi + "Attempting to publish proposal after voting; now in view: {}", *new_view ); - self.publish_proposal_if_able(qc.view_number + 1, None) + self.publish_proposal_if_able(qc.view_number + 1, None, &event_stream) .await; } - if self.vote_if_able().await { + if self.vote_if_able(&event_stream).await { self.current_proposal = None; } } @@ -757,12 +756,14 @@ impl, A: ConsensusApi + }, ) { error!("view publish error {e}"); - self.output_event_stream - .publish(Event { + broadcast_event( + Event { view_number: view, event: EventType::Error { error: e.into() }, - }) - .await; + }, + &self.output_event_stream, + ) + .await; } } @@ -787,17 +788,17 @@ impl, A: ConsensusApi + } #[allow(clippy::cast_precision_loss)] if new_decide_reached { - self.event_stream - .publish(HotShotEvent::LeafDecided(leaf_views.clone())) - .await; - let decide_sent = self.output_event_stream.publish(Event { - view_number: consensus.last_decided_view, - event: EventType::Decide { - leaf_chain: Arc::new(leaf_views), - qc: Arc::new(new_decide_qc.unwrap()), - block_size: Some(included_txns_set.len().try_into().unwrap()), + let decide_sent = broadcast_event( + Event { + view_number: consensus.last_decided_view, + event: EventType::Decide { + leaf_chain: Arc::new(leaf_views), + qc: Arc::new(new_decide_qc.unwrap()), + block_size: Some(included_txns_set.len().try_into().unwrap()), + }, }, - }); + &self.output_event_stream, + ); let old_anchor_view = consensus.last_decided_view; consensus .collect_garbage(old_anchor_view, new_anchor_view) @@ -823,6 +824,7 @@ impl, A: ConsensusApi + debug!("Sending Decide for view {:?}", consensus.last_decided_view); debug!("Decided txns len {:?}", included_txns_set.len()); decide_sent.await; + debug!("decide send succeeded"); } let new_view = self.current_proposal.clone().unwrap().view_number + 1; @@ -840,11 +842,11 @@ impl, A: ConsensusApi + "Attempting to publish proposal after voting; now in view: {}", *new_view ); - self.publish_proposal_if_able(qc.view_number + 1, None) + self.publish_proposal_if_able(qc.view_number + 1, None, &event_stream) .await; } - if !self.vote_if_able().await { + if !self.vote_if_able(&event_stream).await { return; } self.current_proposal = None; @@ -867,34 +869,33 @@ impl, A: ConsensusApi + } let mut collector = self.vote_collector.write().await; - let maybe_task = collector.take(); - - if maybe_task.is_none() - || vote.get_view_number() > maybe_task.as_ref().unwrap().view + if collector.is_none() || vote.get_view_number() > collector.as_ref().unwrap().view { debug!("Starting vote handle for view {:?}", vote.get_view_number()); let info = AccumulatorInfo { public_key: self.public_key.clone(), membership: self.quorum_membership.clone(), view: vote.get_view_number(), - event_stream: self.event_stream.clone(), id: self.id, - registry: self.registry.clone(), }; *collector = create_vote_accumulator::< TYPES, QuorumVote, QuorumCertificate, - >(&info, vote.clone(), event) + >(&info, vote.clone(), event, &event_stream) .await; } else { - let result = maybe_task.unwrap().handle_event(event.clone()).await; + let result = collector + .as_mut() + .unwrap() + .handle_event(event.clone(), &event_stream) + .await; - if result.0 == 
Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { + *collector = None; // The protocol has finished return; } - *collector = Some(result.1); } } HotShotEvent::TimeoutVoteRecv(ref vote) => { @@ -913,34 +914,34 @@ impl, A: ConsensusApi + return; } let mut collector = self.timeout_vote_collector.write().await; - let maybe_task = collector.take(); - if maybe_task.is_none() - || vote.get_view_number() > maybe_task.as_ref().unwrap().view + if collector.is_none() || vote.get_view_number() > collector.as_ref().unwrap().view { debug!("Starting vote handle for view {:?}", vote.get_view_number()); let info = AccumulatorInfo { public_key: self.public_key.clone(), membership: self.quorum_membership.clone(), view: vote.get_view_number(), - event_stream: self.event_stream.clone(), id: self.id, - registry: self.registry.clone(), }; *collector = create_vote_accumulator::< TYPES, TimeoutVote, TimeoutCertificate, - >(&info, vote.clone(), event) + >(&info, vote.clone(), event, &event_stream) .await; } else { - let result = maybe_task.unwrap().handle_event(event.clone()).await; + let result = collector + .as_mut() + .unwrap() + .handle_event(event.clone(), &event_stream) + .await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { + *collector = None; // The protocol has finished return; } - *collector = Some(result.1); } } HotShotEvent::QCFormed(cert) => { @@ -962,7 +963,10 @@ impl, A: ConsensusApi + let view = qc.view_number + 1; - if self.publish_proposal_if_able(view, Some(qc.clone())).await { + if self + .publish_proposal_if_able(view, Some(qc.clone()), &event_stream) + .await + { } else { warn!("Wasn't able to publish proposal"); } @@ -985,7 +989,7 @@ impl, A: ConsensusApi + ); if !self - .publish_proposal_if_able(qc.view_number + 1, None) + .publish_proposal_if_able(qc.view_number + 1, None, &event_stream) .await { debug!( @@ -1010,9 +1014,9 @@ impl, A: ConsensusApi + .write() .await .saved_da_certs - .insert(view, cert); + .insert(view, cert.clone()); - if self.vote_if_able().await { + if self.vote_if_able(&event_stream).await { self.current_proposal = None; } } @@ -1077,19 +1081,21 @@ impl, A: ConsensusApi + // update the view in state to the one in the message // Publish a view change event to the application - if !self.update_view(new_view).await { + if !self.update_view(new_view, &event_stream).await { debug!("view not updated"); return; } - self.output_event_stream - .publish(Event { + broadcast_event( + Event { view_number: old_view_number, event: EventType::ViewFinished { view_number: old_view_number, }, - }) - .await; + }, + &self.output_event_stream, + ) + .await; } HotShotEvent::Timeout(view) => { // NOTE: We may optionally have the timeout task listen for view change events @@ -1124,19 +1130,20 @@ impl, A: ConsensusApi + return; }; - self.event_stream - .publish(HotShotEvent::TimeoutVoteSend(vote)) - .await; + broadcast_event(HotShotEvent::TimeoutVoteSend(vote), &event_stream).await; debug!( "We did not receive evidence for view {} in time, sending timeout vote for that view!", *view ); - self.output_event_stream - .publish(Event { + + broadcast_event( + Event { view_number: view, event: EventType::ReplicaViewTimeout { view_number: view }, - }) - .await; + }, + &self.output_event_stream, + ) + .await; let consensus = self.consensus.read().await; consensus.metrics.number_of_timeouts.add(1); } @@ -1150,14 +1157,19 @@ impl, A: ConsensusApi + if self.quorum_membership.get_leader(view) == self.public_key && 
self.consensus.read().await.high_qc.get_view_number() + 1 == view { - self.publish_proposal_if_able(view, None).await; + self.publish_proposal_if_able(view, None, &event_stream) + .await; } if let Some(tc) = &self.timeout_cert { if self.quorum_membership.get_leader(tc.get_view_number() + 1) == self.public_key { - self.publish_proposal_if_able(view, self.timeout_cert.clone()) - .await; + self.publish_proposal_if_able( + view, + self.timeout_cert.clone(), + &event_stream, + ) + .await; } } } @@ -1171,6 +1183,7 @@ impl, A: ConsensusApi + &mut self, view: TYPES::Time, timeout_certificate: Option>, + event_stream: &Sender>, ) -> bool { if self.quorum_membership.get_leader(view) != self.public_key { // This is expected for view 1, so skipping the logging. @@ -1280,12 +1293,11 @@ impl, A: ConsensusApi + leaf.view_number, "" ); - self.event_stream - .publish(HotShotEvent::QuorumProposalSend( - message.clone(), - self.public_key.clone(), - )) - .await; + broadcast_event( + HotShotEvent::QuorumProposalSend(message.clone(), self.public_key.clone()), + event_stream, + ) + .await; self.payload_commitment_and_metadata = None; return true; @@ -1295,52 +1307,36 @@ impl, A: ConsensusApi + } } -impl, A: ConsensusApi> TS +impl, A: ConsensusApi + 'static> TaskState for ConsensusTaskState { -} - -/// Type alias for Consensus task -pub type ConsensusTaskTypes = HSTWithEvent< - ConsensusTaskError, - HotShotEvent, - ChannelStream>, - ConsensusTaskState, ->; - -/// Event handle for consensus -pub async fn sequencing_consensus_handle< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static, ->( - event: HotShotEvent, - mut state: ConsensusTaskState, -) -> ( - std::option::Option, - ConsensusTaskState, -) { - if let HotShotEvent::Shutdown = event { - (Some(HotShotTaskCompleted::ShutDown), state) - } else { - state.handle_event(event).await; - (None, state) + type Event = HotShotEvent; + type Output = (); + fn filter(&self, event: &HotShotEvent) -> bool { + !matches!( + event, + HotShotEvent::QuorumProposalRecv(_, _) + | HotShotEvent::QuorumVoteRecv(_) + | HotShotEvent::QCFormed(_) + | HotShotEvent::DACRecv(_) + | HotShotEvent::ViewChange(_) + | HotShotEvent::SendPayloadCommitmentAndMetadata(..) + | HotShotEvent::Timeout(_) + | HotShotEvent::TimeoutVoteRecv(_) + | HotShotEvent::VidDisperseRecv(..) + | HotShotEvent::Shutdown, + ) + } + async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> + where + Self: Sized, + { + let sender = task.clone_sender(); + tracing::trace!("sender queue len {}", sender.len()); + task.state_mut().handle(event, sender).await; + None + } + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) } -} - -/// Filter for consensus, returns true for event types the consensus task subscribes to. -pub fn consensus_event_filter(event: &HotShotEvent) -> bool { - matches!( - event, - HotShotEvent::QuorumProposalRecv(_, _) - | HotShotEvent::QuorumVoteRecv(_) - | HotShotEvent::QCFormed(_) - | HotShotEvent::DACRecv(_) - | HotShotEvent::ViewChange(_) - | HotShotEvent::SendPayloadCommitmentAndMetadata(..) - | HotShotEvent::Timeout(_) - | HotShotEvent::TimeoutVoteRecv(_) - | HotShotEvent::VidDisperseRecv(..) 
- | HotShotEvent::Shutdown, - ) } diff --git a/crates/task-impls/src/da.rs b/crates/task-impls/src/da.rs index 9f4ec122c6..577afbaca6 100644 --- a/crates/task-impls/src/da.rs +++ b/crates/task-impls/src/da.rs @@ -1,15 +1,12 @@ use crate::{ - events::HotShotEvent, + events::{HotShotEvent, HotShotTaskCompleted}, + helpers::broadcast_event, vote::{create_vote_accumulator, AccumulatorInfo, VoteCollectionTaskState}, }; +use async_broadcast::Sender; use async_lock::RwLock; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - global_registry::GlobalRegistry, - task::{HotShotTaskCompleted, TS}, - task_impls::HSTWithEvent, -}; +use hotshot_task::task::{Task, TaskState}; use hotshot_types::{ consensus::{Consensus, View}, data::DAProposal, @@ -50,8 +47,6 @@ pub struct DATaskState< > { /// The state's api pub api: A, - /// Global registry task for the state - pub registry: GlobalRegistry, /// View number this view is executing in. pub cur_view: TYPES::Time, @@ -73,9 +68,6 @@ pub struct DATaskState< /// The current vote collection task, if there is one. pub vote_collector: RwLock, DACertificate>>, - /// Global events stream to publish events - pub event_stream: ChannelStream>, - /// This node's public key pub public_key: TYPES::SignatureKey, @@ -91,9 +83,10 @@ impl, A: ConsensusApi + { /// main task event handler #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "DA Main Task", level = "error")] - pub async fn handle_event( + pub async fn handle( &mut self, event: HotShotEvent, + event_stream: Sender>, ) -> Option { match event { HotShotEvent::DAProposalRecv(proposal, sender) => { @@ -176,9 +169,8 @@ impl, A: ConsensusApi + // self.cur_view = view; debug!("Sending vote to the DA leader {:?}", vote.get_view_number()); - self.event_stream - .publish(HotShotEvent::DAVoteSend(vote)) - .await; + + broadcast_event(HotShotEvent::DAVoteSend(vote), &event_stream).await; let mut consensus = self.consensus.write().await; // Ensure this view is in the view map for garbage collection, but do not overwrite if @@ -203,34 +195,33 @@ impl, A: ConsensusApi + } let mut collector = self.vote_collector.write().await; - let maybe_task = collector.take(); - - if maybe_task.is_none() - || vote.get_view_number() > maybe_task.as_ref().unwrap().view + if collector.is_none() || vote.get_view_number() > collector.as_ref().unwrap().view { debug!("Starting vote handle for view {:?}", vote.get_view_number()); let info = AccumulatorInfo { public_key: self.public_key.clone(), membership: self.da_membership.clone(), view: vote.get_view_number(), - event_stream: self.event_stream.clone(), id: self.id, - registry: self.registry.clone(), }; *collector = create_vote_accumulator::< TYPES, DAVote, DACertificate, >(&info, vote.clone(), event, &event_stream) .await; } else { - let result = maybe_task.unwrap().handle_event(event.clone()).await; + let result = collector + .as_mut() + .unwrap() + .handle_event(event.clone(), &event_stream) + .await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { + *collector = None; // The protocol has finished return None; } - *collector = Some(result.1); } } HotShotEvent::ViewChange(view) => { @@ -310,12 +301,11 @@ impl, A: ConsensusApi + _pd: PhantomData, }; - self.event_stream - .publish(HotShotEvent::DAProposalSend( - message.clone(), - self.public_key.clone(), - )) - .await; + broadcast_event( + HotShotEvent::DAProposalSend(message.clone(), self.public_key.clone()), + &event_stream,
+ ) + .await; } HotShotEvent::Timeout(view) => { @@ -326,7 +316,7 @@ impl, A: ConsensusApi + HotShotEvent::Shutdown => { error!("Shutting down because of shutdown signal!"); - return Some(HotShotTaskCompleted::ShutDown); + return Some(HotShotTaskCompleted); } _ => { error!("unexpected event {:?}", event); @@ -334,10 +324,18 @@ impl, A: ConsensusApi + } None } +} - /// Filter the DA event. - pub fn filter(event: &HotShotEvent) -> bool { - matches!( +/// task state implementation for DA Task +impl, A: ConsensusApi + 'static> TaskState + for DATaskState +{ + type Event = HotShotEvent; + + type Output = HotShotTaskCompleted; + + fn filter(&self, event: &HotShotEvent) -> bool { + !matches!( event, HotShotEvent::DAProposalRecv(_, _) | HotShotEvent::DAVoteRecv(_) @@ -347,18 +345,16 @@ impl, A: ConsensusApi + | HotShotEvent::ViewChange(_) ) } -} -/// task state implementation for DA Task -impl, A: ConsensusApi + 'static> TS - for DATaskState -{ -} + async fn handle_event( + event: Self::Event, + task: &mut Task, + ) -> Option { + let sender = task.clone_sender(); + task.state_mut().handle(event, sender).await + } -/// Type alias for DA Task Types -pub type DATaskTypes = HSTWithEvent< - ConsensusTaskError, - HotShotEvent, - ChannelStream>, - DATaskState, ->; + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } +} diff --git a/crates/task-impls/src/events.rs b/crates/task-impls/src/events.rs index d1d94c8438..1002c5d5e8 100644 --- a/crates/task-impls/src/events.rs +++ b/crates/task-impls/src/events.rs @@ -15,6 +15,10 @@ use hotshot_types::{ traits::{node_implementation::NodeType, BlockPayload}, }; +/// Marker that the task completed +#[derive(Eq, Hash, PartialEq, Debug, Clone)] +pub struct HotShotTaskCompleted; /// All of the possible events that can be passed between Sequencing `HotShot` tasks #[derive(Eq, Hash, PartialEq, Debug, Clone)] pub enum HotShotEvent { diff --git a/crates/task-impls/src/harness.rs b/crates/task-impls/src/harness.rs index 2ee224b7cd..509a664751 100644 --- a/crates/task-impls/src/harness.rs +++ b/crates/task-impls/src/harness.rs @@ -1,36 +1,35 @@ -use crate::events::HotShotEvent; -use async_compatibility_layer::art::async_spawn; +use crate::events::{HotShotEvent, HotShotTaskCompleted}; +use async_broadcast::broadcast; -use futures::FutureExt; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - task::{FilterEvent, HandleEvent, HotShotTaskCompleted, HotShotTaskTypes, TS}, - task_impls::{HSTWithEvent, TaskBuilder}, - task_launcher::TaskRunner, -}; +use async_compatibility_layer::art::async_timeout; +use hotshot_task::task::{Task, TaskRegistry, TaskState}; use hotshot_types::traits::node_implementation::NodeType; -use snafu::Snafu; -use std::{collections::HashMap, future::Future, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; /// The state for the test harness task. Keeps track of which events and how many we expect to get pub struct TestHarnessState { /// The expected events we get from the test.
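
The DATaskState impl above shows the shape every task now follows: an associated `Event` and `Output`, an inverted `filter` (true means ignore), a static `handle_event` that borrows the `Task`, and `should_shutdown`. For orientation, here is a compact, self-contained stand-in for the new `hotshot_task::task::{Task, TaskState}` pair, inferred from how this patch calls it. It is a sketch only: the real trait is async and is driven from an async-broadcast receiver, and the `CounterState`/`CounterEvent` names are hypothetical.

/// Simplified, synchronous stand-in for hotshot_task::task::TaskState,
/// inferred from this patch's call sites; the real trait differs in detail.
trait TaskState: Sized {
    type Event: Clone;
    type Output;
    /// Return true for events this state wants to IGNORE. Note the inversion:
    /// every impl in this patch writes `!matches!(...)` over the events it handles.
    fn filter(&self, _event: &Self::Event) -> bool {
        false
    }
    /// Handle one event, optionally producing a terminal output.
    fn handle_event(event: Self::Event, task: &mut Task<Self>) -> Option<Self::Output>;
    /// Return true when this event should end the task.
    fn should_shutdown(event: &Self::Event) -> bool;
}

/// Stand-in for hotshot_task::task::Task: owns the state and drives it.
struct Task<S: TaskState> {
    state: S,
}

impl<S: TaskState> Task<S> {
    fn state_mut(&mut self) -> &mut S {
        &mut self.state
    }
    /// Feed events through the shutdown/filter/handle pipeline, as the real
    /// task loop does with events pulled from its broadcast receiver.
    fn run(mut self, events: impl IntoIterator<Item = S::Event>) -> Option<S::Output> {
        for event in events {
            if S::should_shutdown(&event) {
                return None;
            }
            if self.state.filter(&event) {
                continue;
            }
            if let Some(out) = S::handle_event(event, &mut self) {
                return Some(out);
            }
        }
        None
    }
}

/// Hypothetical example state: counts Ticks and completes at a threshold.
#[derive(Clone)]
enum CounterEvent {
    Tick,
    Ignored,
    Shutdown,
}

struct CounterState {
    seen: usize,
}

impl TaskState for CounterState {
    type Event = CounterEvent;
    type Output = usize;
    fn filter(&self, event: &Self::Event) -> bool {
        matches!(event, CounterEvent::Ignored)
    }
    fn handle_event(event: Self::Event, task: &mut Task<Self>) -> Option<usize> {
        let state = task.state_mut();
        if let CounterEvent::Tick = event {
            state.seen += 1;
            if state.seen == 2 {
                return Some(state.seen);
            }
        }
        None
    }
    fn should_shutdown(event: &Self::Event) -> bool {
        matches!(event, CounterEvent::Shutdown)
    }
}

fn main() {
    let task = Task { state: CounterState { seen: 0 } };
    let out = task.run([CounterEvent::Ignored, CounterEvent::Tick, CounterEvent::Tick]);
    assert_eq!(out, Some(2));
    // A shutdown event ends the task before any output is produced.
    let none = Task { state: CounterState { seen: 0 } }.run([CounterEvent::Shutdown]);
    assert_eq!(none, None);
}
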
Maps an event to the number of times we expect to see it expected_output: HashMap, usize>, + /// If true we won't fail the test if extra events come in + allow_extra_output: bool, } -impl TS for TestHarnessState {} +impl TaskState for TestHarnessState { + type Event = HotShotEvent; + type Output = HotShotTaskCompleted; -/// Error emitted if the test harness task fails -#[derive(Snafu, Debug)] -pub struct TestHarnessTaskError {} + async fn handle_event( + event: Self::Event, + task: &mut Task, + ) -> Option { + let extra = task.state_mut().allow_extra_output; + handle_event(event, task, extra) + } -/// Type alias for the Test Harness Task -pub type TestHarnessTaskTypes = HSTWithEvent< - TestHarnessTaskError, - HotShotEvent, - ChannelStream>, - TestHarnessState, ->; + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } +} /// Runs a test by building the task using `build_fn` and then passing it the `input` events /// and testing to make sure all of the `expected_output` events are seen @@ -43,46 +42,47 @@ pub type TestHarnessTaskTypes = HSTWithEvent< /// # Panics /// Panics if any state the test expects is not set. Panicking causes a test failure #[allow(clippy::implicit_hasher)] -pub async fn run_harness( +#[allow(clippy::panic)] +pub async fn run_harness>>( input: Vec>, expected_output: HashMap, usize>, - event_stream: Option>>, - build_fn: impl FnOnce(TaskRunner, ChannelStream>) -> Fut, + state: S, allow_extra_output: bool, ) where TYPES: NodeType, - Fut: Future, + S: Send + 'static, { - let task_runner = TaskRunner::new(); - let registry = task_runner.registry.clone(); - let event_stream = event_stream.unwrap_or_default(); - let state = TestHarnessState { expected_output }; - let handler = HandleEvent(Arc::new(move |event, state| { - async move { handle_event(event, state, allow_extra_output) }.boxed() - })); - let filter = FilterEvent::default(); - let builder = TaskBuilder::>::new("test_harness".to_string()) - .register_event_stream(event_stream.clone(), filter) - .await - .register_registry(&mut registry.clone()) - .await - .register_state(state) - .register_event_handler(handler); - - let id = builder.get_task_id().unwrap(); - - let task = TestHarnessTaskTypes::build(builder).launch(); - - let task_runner = task_runner.add_task(id, "test_harness".to_string(), task); - let task_runner = build_fn(task_runner, event_stream.clone()).await; - - let runner = async_spawn(async move { task_runner.launch().await }); + let registry = Arc::new(TaskRegistry::default()); + let mut tasks = vec![]; + // set up two broadcast channels so the test sends to the task and the task back to the test + let (to_task, from_test) = broadcast(1024); + let (to_test, from_task) = broadcast(1024); + let test_state = TestHarnessState { + expected_output, + allow_extra_output, + }; + + let test_task = Task::new( + to_test.clone(), + from_task.clone(), + registry.clone(), + test_state, + ); + let task = Task::new(to_test.clone(), from_test.clone(), registry.clone(), state); + + tasks.push(test_task.run()); + tasks.push(task.run()); for event in input { - let () = event_stream.publish(event).await; + to_task.broadcast_direct(event).await.unwrap(); } - let _ = runner.await; + if async_timeout(Duration::from_secs(2), futures::future::join_all(tasks)) + .await + .is_err() + { + panic!("Test timed out before all expected outputs received"); + } } /// Handles an event for the Test Harness Task.
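
The rewritten run_harness above replaces the old TaskRunner/TaskBuilder plumbing with two plain broadcast channels: the test pushes inputs into one channel, the task under test answers on the other, and a checker drains its expectation set until it is empty, with the whole join wrapped in a two-second async_timeout. A minimal sketch of that wiring pattern, using the async-broadcast 0.6 API this patch depends on (`broadcast`, `broadcast_direct`, `recv`) and a futures executor in place of the HotShot types; the doubling task and the expected-output vector are hypothetical stand-ins.

use async_broadcast::broadcast;
use futures::{executor::block_on, future::join};

fn main() {
    block_on(async {
        // One channel the "test" drives inputs into, one the "task" answers on.
        let (to_task, mut task_rx) = broadcast::<u32>(16);
        let (to_test, mut test_rx) = broadcast::<u32>(16);

        // Stand-in for the task under test: echo each input back, doubled.
        let task = async move {
            for _ in 0..3 {
                let n = task_rx.recv().await.unwrap();
                to_test.broadcast_direct(n * 2).await.unwrap();
            }
        };

        // Stand-in for TestHarnessState: wait until every expected output arrives.
        let checker = async move {
            let mut expected = vec![2u32, 4, 6];
            while !expected.is_empty() {
                let n = test_rx.recv().await.unwrap();
                expected.retain(|&e| e != n);
            }
        };

        // Feed the inputs, then drive both sides to completion, as run_harness
        // does (minus the 2-second async_timeout it wraps around the join).
        for n in [1u32, 2, 3] {
            to_task.broadcast_direct(n).await.unwrap();
        }
        join(task, checker).await;
    });
}
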
If the event is expected, remove it from @@ -97,12 +97,10 @@ pub async fn run_harness( #[allow(clippy::needless_pass_by_value)] pub fn handle_event( event: HotShotEvent, - mut state: TestHarnessState, + task: &mut Task>, allow_extra_output: bool, -) -> ( - std::option::Option, - TestHarnessState, -) { +) -> Option { + let state = task.state_mut(); // Check the output in either case: // * We allow outputs only in our expected output set. // * We haven't received all expected outputs yet. @@ -121,8 +119,9 @@ pub fn handle_event( } if state.expected_output.is_empty() { - return (Some(HotShotTaskCompleted::ShutDown), state); + tracing::error!("test harness task completed"); + return Some(HotShotTaskCompleted); } - (None, state) + None } diff --git a/crates/task-impls/src/helpers.rs b/crates/task-impls/src/helpers.rs index c50f776500..93376f7086 100644 --- a/crates/task-impls/src/helpers.rs +++ b/crates/task-impls/src/helpers.rs @@ -1,3 +1,4 @@ +use async_broadcast::{SendError, Sender}; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; #[cfg(async_executor_impl = "tokio")] @@ -10,3 +11,22 @@ pub async fn cancel_task(task: JoinHandle) { #[cfg(async_executor_impl = "tokio")] task.abort(); } + +/// Helper function to send events and log errors +pub async fn broadcast_event(event: E, sender: &Sender) { + match sender.broadcast_direct(event).await { + Ok(None) => (), + Ok(Some(overflowed)) => { + tracing::error!( + "Event sender queue overflow, Oldest event removed from queue: {:?}", + overflowed + ); + } + Err(SendError(e)) => { + tracing::warn!( + "Event: {:?}\n Sending failed, event stream probably shut down", + e + ); + } + } +} diff --git a/crates/task-impls/src/lib.rs b/crates/task-impls/src/lib.rs index 8299bf1edc..1be4020af2 100644 --- a/crates/task-impls/src/lib.rs +++ b/crates/task-impls/src/lib.rs @@ -32,4 +32,4 @@ pub mod vote; pub mod upgrade; /// Helper functions used by any task -mod helpers; +pub mod helpers; diff --git a/crates/task-impls/src/network.rs b/crates/task-impls/src/network.rs index 79bedc2260..92dd284aa2 100644 --- a/crates/task-impls/src/network.rs +++ b/crates/task-impls/src/network.rs @@ -1,16 +1,15 @@ -use crate::events::HotShotEvent; +use crate::{ + events::{HotShotEvent, HotShotTaskCompleted}, + helpers::broadcast_event, +}; +use async_broadcast::Sender; use either::Either::{self, Left, Right}; use hotshot_constants::VERSION_0_1; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - task::{FilterEvent, HotShotTaskCompleted, TS}, - task_impls::{HSTWithEvent, HSTWithMessage}, - GeneratedStream, Merge, -}; + +use hotshot_task::task::{Task, TaskState}; use hotshot_types::{ message::{ - CommitteeConsensusMessage, GeneralConsensusMessage, Message, MessageKind, Messages, - SequencingMessage, + CommitteeConsensusMessage, GeneralConsensusMessage, Message, MessageKind, SequencingMessage, }, traits::{ election::Membership, @@ -19,31 +18,82 @@ use hotshot_types::{ }, vote::{HasViewNumber, Vote}, }; -use snafu::Snafu; -use std::sync::Arc; use tracing::error; use tracing::instrument; -/// the type of network task -#[derive(Clone, Copy, Debug)] -pub enum NetworkTaskKind { - /// quorum: the normal "everyone" committee - Quorum, - /// da committee - Committee, - /// view sync - ViewSync, - /// vid - VID, +/// quorum filter +pub fn quorum_filter(event: &HotShotEvent) -> bool { + !matches!( + event, + HotShotEvent::QuorumProposalSend(_, _) + | HotShotEvent::QuorumVoteSend(_) + | HotShotEvent::Shutdown + | HotShotEvent::DACSend(_, _) + |
HotShotEvent::ViewChange(_) + | HotShotEvent::TimeoutVoteSend(_) + ) +} + +/// committee filter +pub fn committee_filter(event: &HotShotEvent) -> bool { + !matches!( + event, + HotShotEvent::DAProposalSend(_, _) + | HotShotEvent::DAVoteSend(_) + | HotShotEvent::Shutdown + | HotShotEvent::ViewChange(_) + ) } +/// vid filter +pub fn vid_filter(event: &HotShotEvent) -> bool { + !matches!( + event, + HotShotEvent::Shutdown | HotShotEvent::VidDisperseSend(_, _) | HotShotEvent::ViewChange(_) + ) +} + +/// view sync filter +pub fn view_sync_filter(event: &HotShotEvent) -> bool { + !matches!( + event, + HotShotEvent::ViewSyncPreCommitCertificate2Send(_, _) + | HotShotEvent::ViewSyncCommitCertificate2Send(_, _) + | HotShotEvent::ViewSyncFinalizeCertificate2Send(_, _) + | HotShotEvent::ViewSyncPreCommitVoteSend(_) + | HotShotEvent::ViewSyncCommitVoteSend(_) + | HotShotEvent::ViewSyncFinalizeVoteSend(_) + | HotShotEvent::Shutdown + | HotShotEvent::ViewChange(_) + ) +} /// the network message task state +#[derive(Clone)] pub struct NetworkMessageTaskState { - /// event stream (used for publishing) - pub event_stream: ChannelStream>, + /// Sender to send internal events this task generates to other tasks + pub event_stream: Sender>, } -impl TS for NetworkMessageTaskState {} +impl TaskState for NetworkMessageTaskState { + type Event = Vec>; + type Output = (); + + async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> + where + Self: Sized, + { + task.state_mut().handle_messages(event).await; + None + } + + fn filter(&self, _event: &Self::Event) -> bool { + false + } + + fn should_shutdown(_event: &Self::Event) -> bool { + false + } +} impl NetworkMessageTaskState { /// Handle the message. @@ -111,7 +161,7 @@ impl NetworkMessageTaskState { // TODO (Keyao benchmarking) Update these event variants (similar to the // `TransactionsRecv` event) so we can send one event for a vector of messages. // - self.event_stream.publish(event).await; + broadcast_event(event, &self.event_stream).await; } MessageKind::Data(message) => match message { hotshot_types::message::DataMessage::SubmitTransaction(transaction, _) => { @@ -121,9 +171,11 @@ impl NetworkMessageTaskState { }; } if !transactions.is_empty() { - self.event_stream - .publish(HotShotEvent::TransactionsRecv(transactions)) - .await; + broadcast_event( + HotShotEvent::TransactionsRecv(transactions), + &self.event_stream, + ) + .await; } } } @@ -132,16 +184,41 @@ impl NetworkMessageTaskState { pub struct NetworkEventTaskState> { /// comm channel pub channel: COMMCHANNEL, - /// event stream - pub event_stream: ChannelStream>, /// view number pub view: TYPES::Time, + /// membership for the channel + pub membership: TYPES::Membership, // TODO ED Need to add exchange so we can get the recipient key and our own key? 
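
The broadcast_event helper added in helpers.rs above has exactly three outcomes, and the two unhappy paths are worth seeing concretely: with overflow enabled an async-broadcast channel evicts and hands back the oldest queued item (the Ok(Some(..)) arm the helper logs as an error), and once every receiver is gone the channel closes and sends fail (the Err arm it logs as a warning). A minimal demonstration, assuming the async-broadcast 0.6 behavior this patch relies on; the exact RecvError reported to a lagging receiver is my reading of that crate, not something this diff shows.

use async_broadcast::{broadcast, RecvError};

fn main() {
    futures::executor::block_on(async {
        let (mut tx, mut rx) = broadcast::<u32>(1);
        // Without overflow, a full queue would make the send wait instead.
        tx.set_overflow(true);

        // Fits in the single-slot queue: the Ok(None) arm, nothing to log.
        assert_eq!(tx.broadcast_direct(1).await.unwrap(), None);

        // Queue full: the oldest event (1) is evicted and handed back,
        // the Ok(Some(overflowed)) arm broadcast_event logs as an error.
        assert_eq!(tx.broadcast_direct(2).await.unwrap(), Some(1));

        // The lagging receiver is told how many events it missed,
        // then receives the surviving event.
        assert!(matches!(rx.recv().await, Err(RecvError::Overflowed(1))));
        assert_eq!(rx.recv().await.unwrap(), 2);

        // No receivers left: sends return Err(SendError(event)),
        // the arm broadcast_event logs as "stream probably shut down".
        drop(rx);
        assert!(tx.broadcast_direct(3).await.is_err());
    });
}
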
+ /// Filter which returns false for the events that this specific network task cares about + pub filter: fn(&HotShotEvent) -> bool, } -impl> TS +impl> TaskState for NetworkEventTaskState { + type Event = HotShotEvent; + + type Output = HotShotTaskCompleted; + + async fn handle_event( + event: Self::Event, + task: &mut Task, + ) -> Option { + let membership = task.state_mut().membership.clone(); + task.state_mut().handle_event(event, &membership).await + } + + fn should_shutdown(event: &Self::Event) -> bool { + if matches!(event, HotShotEvent::Shutdown) { + error!("Network Task received Shutdown event"); + return true; + } + false + } + + fn filter(&self, event: &Self::Event) -> bool { + (self.filter)(event) + } } impl> @@ -275,7 +352,7 @@ impl> } HotShotEvent::Shutdown => { error!("Networking task shutting down"); - return Some(HotShotTaskCompleted::ShutDown); + return Some(HotShotTaskCompleted); } event => { error!("Received unexpected message in network task {:?}", event); @@ -303,84 +380,4 @@ impl> None } - - /// network filter - pub fn filter(task_kind: NetworkTaskKind) -> FilterEvent> { - match task_kind { - NetworkTaskKind::Quorum => FilterEvent(Arc::new(Self::quorum_filter)), - NetworkTaskKind::Committee => FilterEvent(Arc::new(Self::committee_filter)), - NetworkTaskKind::ViewSync => FilterEvent(Arc::new(Self::view_sync_filter)), - NetworkTaskKind::VID => FilterEvent(Arc::new(Self::vid_filter)), - } - } - - /// quorum filter - fn quorum_filter(event: &HotShotEvent) -> bool { - matches!( - event, - HotShotEvent::QuorumProposalSend(_, _) - | HotShotEvent::QuorumVoteSend(_) - | HotShotEvent::Shutdown - | HotShotEvent::DACSend(_, _) - | HotShotEvent::ViewChange(_) - | HotShotEvent::TimeoutVoteSend(_) - ) - } - - /// committee filter - fn committee_filter(event: &HotShotEvent) -> bool { - matches!( - event, - HotShotEvent::DAProposalSend(_, _) - | HotShotEvent::DAVoteSend(_) - | HotShotEvent::Shutdown - | HotShotEvent::ViewChange(_) - ) - } - - /// vid filter - fn vid_filter(event: &HotShotEvent) -> bool { - matches!( - event, - HotShotEvent::Shutdown - | HotShotEvent::VidDisperseSend(_, _) - | HotShotEvent::ViewChange(_) - ) - } - - /// view sync filter - fn view_sync_filter(event: &HotShotEvent) -> bool { - matches!( - event, - HotShotEvent::ViewSyncPreCommitCertificate2Send(_, _) - | HotShotEvent::ViewSyncCommitCertificate2Send(_, _) - | HotShotEvent::ViewSyncFinalizeCertificate2Send(_, _) - | HotShotEvent::ViewSyncPreCommitVoteSend(_) - | HotShotEvent::ViewSyncCommitVoteSend(_) - | HotShotEvent::ViewSyncFinalizeVoteSend(_) - | HotShotEvent::Shutdown - | HotShotEvent::ViewChange(_) - ) - } } - -/// network error (no errors right now, only stub) -#[derive(Snafu, Debug)] -pub struct NetworkTaskError {} - -/// networking message task types -pub type NetworkMessageTaskTypes = HSTWithMessage< - NetworkTaskError, - Either, Messages>, - // A combination of broadcast and direct streams.
- Merge>, GeneratedStream>>, - NetworkMessageTaskState, ->; - -/// network event task types -pub type NetworkEventTaskTypes = HSTWithEvent< - NetworkTaskError, - HotShotEvent, - ChannelStream>, - NetworkEventTaskState, ->; diff --git a/crates/task-impls/src/transactions.rs b/crates/task-impls/src/transactions.rs index 32b9c0c492..d29142a5a2 100644 --- a/crates/task-impls/src/transactions.rs +++ b/crates/task-impls/src/transactions.rs @@ -1,4 +1,8 @@ -use crate::events::HotShotEvent; +use crate::{ + events::{HotShotEvent, HotShotTaskCompleted}, + helpers::broadcast_event, +}; +use async_broadcast::Sender; use async_compatibility_layer::{ art::async_timeout, async_primitives::subscribable_rwlock::{ReadView, SubscribableRwLock}, }; @@ -6,12 +10,8 @@ use async_compatibility_layer::{ use async_lock::RwLock; use bincode::config::Options; use commit::{Commitment, Committable}; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - global_registry::GlobalRegistry, - task::{HotShotTaskCompleted, TS}, - task_impls::HSTWithEvent, -}; + +use hotshot_task::task::{Task, TaskState}; use hotshot_types::{ consensus::Consensus, data::Leaf, @@ -49,8 +49,6 @@ pub struct TransactionTaskState< > { /// The state's api pub api: A, - /// Global registry task for the state - pub registry: GlobalRegistry, /// View number this view is executing in. pub cur_view: TYPES::Time, @@ -70,9 +68,6 @@ pub struct TransactionTaskState< /// Membership for the quorum pub membership: Arc, - /// Global events stream to publish events - pub event_stream: ChannelStream>, - /// This node's public key pub public_key: TYPES::SignatureKey, /// Our Private Key pub private_key: ::PrivateKey, @@ -87,9 +82,10 @@ impl, A: ConsensusApi + /// main task event handler #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "Transaction Handling Task", level = "error")] - pub async fn handle_event( + pub async fn handle( &mut self, event: HotShotEvent, + event_stream: Sender>, ) -> Option { match event { HotShotEvent::TransactionsRecv(transactions) => { @@ -183,6 +179,7 @@ impl, A: ConsensusApi + return None; } HotShotEvent::ViewChange(view) => { + debug!("view change in transactions to view {:?}", view); if *self.cur_view >= *view { return None; } @@ -196,6 +193,7 @@ impl, A: ConsensusApi + // return if we aren't the next leader or we skipped last view and aren't the current leader.
if !make_block && self.membership.get_leader(self.cur_view + 1) != self.public_key { + debug!("Not next leader for view {:?}", self.cur_view); return None; } @@ -251,18 +249,16 @@ impl, A: ConsensusApi + // send the sequenced transactions to VID and DA tasks let block_view = if make_block { view } else { view + 1 }; - self.event_stream - .publish(HotShotEvent::TransactionsSequenced( - encoded_transactions, - metadata, - block_view, - )) - .await; + broadcast_event( + HotShotEvent::TransactionsSequenced(encoded_transactions, metadata, block_view), + &event_stream, + ) + .await; return None; } HotShotEvent::Shutdown => { - return Some(HotShotTaskCompleted::ShutDown); + return Some(HotShotTaskCompleted); } _ => {} } @@ -333,10 +329,18 @@ impl, A: ConsensusApi + // .collect(); Some(txns) } +} + +/// task state implementation for Transactions Task +impl, A: ConsensusApi + 'static> TaskState + for TransactionTaskState +{ + type Event = HotShotEvent; - /// Event filter for the transaction task - pub fn filter(event: &HotShotEvent) -> bool { - matches!( + type Output = HotShotTaskCompleted; + + fn filter(&self, event: &HotShotEvent) -> bool { + !matches!( event, HotShotEvent::TransactionsRecv(_) | HotShotEvent::LeafDecided(_) @@ -344,18 +348,16 @@ impl, A: ConsensusApi + | HotShotEvent::ViewChange(_) ) } -} -/// task state implementation for Transactions Task -impl, A: ConsensusApi + 'static> TS - for TransactionTaskState -{ -} + async fn handle_event( + event: Self::Event, + task: &mut Task, + ) -> Option { + let sender = task.clone_sender(); + task.state_mut().handle(event, sender).await + } -/// Type alias for DA Task Types -pub type TransactionsTaskTypes = HSTWithEvent< - ConsensusTaskError, - HotShotEvent, - ChannelStream>, - TransactionTaskState, ->; + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } +} diff --git a/crates/task-impls/src/upgrade.rs b/crates/task-impls/src/upgrade.rs index ec0dea8231..a77b6046fb 100644 --- a/crates/task-impls/src/upgrade.rs +++ b/crates/task-impls/src/upgrade.rs @@ -1,15 +1,12 @@ use crate::{ - events::HotShotEvent, + events::{HotShotEvent, HotShotTaskCompleted}, + helpers::broadcast_event, vote::{create_vote_accumulator, AccumulatorInfo, VoteCollectionTaskState}, }; +use async_broadcast::Sender; use async_lock::RwLock; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - global_registry::GlobalRegistry, - task::{HotShotTaskCompleted, TS}, - task_impls::HSTWithEvent, -}; +use hotshot_task::task::TaskState; use hotshot_types::{ event::{Event, EventType}, simple_certificate::UpgradeCertificate, @@ -43,9 +40,6 @@ pub struct UpgradeTaskState< > { /// The state's api pub api: A, - /// Global registry task for the state - pub registry: GlobalRegistry, - /// View number this view is executing in. 
pub cur_view: TYPES::Time, @@ -61,9 +55,6 @@ pub struct UpgradeTaskState< pub vote_collector: RwLock, UpgradeCertificate>>, - /// Global events stream to publish events - pub event_stream: ChannelStream>, - /// This node's public key pub public_key: TYPES::SignatureKey, @@ -79,9 +70,10 @@ impl, A: ConsensusApi + { /// main task event handler #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "Upgrade Task", level = "error")] - pub async fn handle_event( + pub async fn handle( &mut self, event: HotShotEvent, + tx: Sender>, ) -> Option { match event { HotShotEvent::UpgradeProposalRecv(proposal, sender) => { @@ -149,9 +141,7 @@ impl, A: ConsensusApi + return None; }; debug!("Sending upgrade vote {:?}", vote.get_view_number()); - self.event_stream - .publish(HotShotEvent::UpgradeVoteSend(vote)) - .await; + broadcast_event(HotShotEvent::UpgradeVoteSend(vote), &tx).await; } HotShotEvent::UpgradeVoteRecv(ref vote) => { debug!("Upgrade vote recv, Main Task {:?}", vote.get_view_number()); @@ -167,34 +157,33 @@ impl, A: ConsensusApi + } let mut collector = self.vote_collector.write().await; - let maybe_task = collector.take(); - - if maybe_task.is_none() - || vote.get_view_number() > maybe_task.as_ref().unwrap().view + if collector.is_none() || vote.get_view_number() > collector.as_ref().unwrap().view { debug!("Starting vote handle for view {:?}", vote.get_view_number()); let info = AccumulatorInfo { public_key: self.public_key.clone(), membership: self.quorum_membership.clone(), view: vote.get_view_number(), - event_stream: self.event_stream.clone(), id: self.id, - registry: self.registry.clone(), }; *collector = create_vote_accumulator::< TYPES, UpgradeVote, UpgradeCertificate, >(&info, vote.clone(), event, &tx) .await; } else { - let result = maybe_task.unwrap().handle_event(event.clone()).await; - - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + let result = collector + .as_mut() + .unwrap() + .handle_event(event.clone(), &tx) + .await; + + if result == Some(HotShotTaskCompleted) { + *collector = None; // The protocol has finished return None; } - *collector = Some(result.1); } } HotShotEvent::ViewChange(view) => { @@ -211,7 +200,7 @@ impl, A: ConsensusApi + } HotShotEvent::Shutdown => { error!("Shutting down because of shutdown signal!"); - return Some(HotShotTaskCompleted::ShutDown); + return Some(HotShotTaskCompleted); } _ => { error!("unexpected event {:?}", event); @@ -219,10 +208,31 @@ impl, A: ConsensusApi + } None } +} + +/// task state implementation for Upgrade Task +impl, A: ConsensusApi + 'static> TaskState + for UpgradeTaskState +{ + type Event = HotShotEvent; + + type Output = HotShotTaskCompleted; + + async fn handle_event( + event: Self::Event, + task: &mut hotshot_task::task::Task, + ) -> Option { + let sender = task.clone_sender(); + tracing::trace!("sender queue len {}", sender.len()); + task.state_mut().handle(event, sender).await + } - /// Filter the upgrade event.
- pub fn filter(event: &HotShotEvent) -> bool { - matches!( + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } + + fn filter(&self, event: &Self::Event) -> bool { + !matches!( event, HotShotEvent::UpgradeProposalRecv(_, _) | HotShotEvent::UpgradeVoteRecv(_) @@ -232,17 +242,3 @@ impl, A: ConsensusApi + ) } } - -/// task state implementation for DA Task -impl, A: ConsensusApi + 'static> TS - for UpgradeTaskState -{ -} - -/// Type alias for DA Task Types -pub type UpgradeTaskTypes = HSTWithEvent< - ConsensusTaskError, - HotShotEvent, - ChannelStream>, - UpgradeTaskState, ->; diff --git a/crates/task-impls/src/vid.rs b/crates/task-impls/src/vid.rs index a4b9338f87..d07aeb2c10 100644 --- a/crates/task-impls/src/vid.rs +++ b/crates/task-impls/src/vid.rs @@ -1,13 +1,11 @@ -use crate::events::HotShotEvent; +use crate::events::{HotShotEvent, HotShotTaskCompleted}; +use crate::helpers::broadcast_event; +use async_broadcast::Sender; use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::spawn_blocking; -use hotshot_task::{ - event_stream::ChannelStream, - global_registry::GlobalRegistry, - task::{HotShotTaskCompleted, TS}, - task_impls::HSTWithEvent, -}; + +use hotshot_task::task::{Task, TaskState}; use hotshot_types::traits::network::CommunicationChannel; use hotshot_types::{ consensus::Consensus, @@ -27,7 +25,6 @@ use hotshot_types::{ #[cfg(async_executor_impl = "tokio")] use tokio::task::spawn_blocking; -use hotshot_task::event_stream::EventStream; use snafu::Snafu; use std::marker::PhantomData; use std::sync::Arc; @@ -45,8 +42,6 @@ pub struct VIDTaskState< > { /// The state's api pub api: A, - /// Global registry task for the state - pub registry: GlobalRegistry, /// View number this view is executing in. pub cur_view: TYPES::Time, @@ -63,10 +58,6 @@ pub struct VIDTaskState< pub private_key: ::PrivateKey, /// The view and ID of the current vote collection task, if there is one. pub vote_collector: Option<(TYPES::Time, usize, usize)>, - - /// Global events stream to publish events - pub event_stream: ChannelStream>, - /// This state's ID pub id: u64, } @@ -76,9 +67,10 @@ impl, A: ConsensusApi + { /// main task event handler #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "VID Main Task", level = "error")] - pub async fn handle_event( + pub async fn handle( &mut self, event: HotShotEvent, + event_stream: Sender>, ) -> Option { match event { HotShotEvent::TransactionsSequenced(encoded_transactions, metadata, view_number) => { @@ -104,21 +96,25 @@ impl, A: ConsensusApi + // Unwrap here will just propagate any panic from the spawned task, it's not a new place we can panic.
let vid_disperse = vid_disperse.unwrap(); // send the commitment and metadata to consensus for block building - self.event_stream - .publish(HotShotEvent::SendPayloadCommitmentAndMetadata( + broadcast_event( + HotShotEvent::SendPayloadCommitmentAndMetadata( vid_disperse.commit, metadata, view_number, - )) - .await; + ), + &event_stream, + ) + .await; // send the block to the VID dispersal function - self.event_stream - .publish(HotShotEvent::BlockReady( + broadcast_event( + HotShotEvent::BlockReady( VidDisperse::from_membership(view_number, vid_disperse, &self.membership), view_number, - )) - .await; + ), + &event_stream, + ) + .await; } HotShotEvent::BlockReady(vid_disperse, view_number) => { @@ -130,16 +126,18 @@ impl, A: ConsensusApi + return None; }; debug!("publishing VID disperse for view {}", *view_number); - self.event_stream - .publish(HotShotEvent::VidDisperseSend( + broadcast_event( + HotShotEvent::VidDisperseSend( Proposal { signature, data: vid_disperse, _pd: PhantomData, }, self.public_key.clone(), - )) - .await; + ), + &event_stream, + ) + .await; } HotShotEvent::ViewChange(view) => { @@ -169,7 +167,7 @@ impl, A: ConsensusApi + } HotShotEvent::Shutdown => { - return Some(HotShotTaskCompleted::ShutDown); + return Some(HotShotTaskCompleted); } _ => { error!("unexpected event {:?}", event); @@ -177,10 +175,26 @@ impl, A: ConsensusApi + } None } +} - /// Filter the VID event. - pub fn filter(event: &HotShotEvent) -> bool { - matches!( +/// task state implementation for VID Task +impl, A: ConsensusApi + 'static> TaskState + for VIDTaskState +{ + type Event = HotShotEvent; + + type Output = HotShotTaskCompleted; + + async fn handle_event( + event: Self::Event, + task: &mut Task, + ) -> Option { + let sender = task.clone_sender(); + task.state_mut().handle(event, sender).await; + None + } + fn filter(&self, event: &Self::Event) -> bool { + !matches!( event, HotShotEvent::Shutdown | HotShotEvent::TransactionsSequenced(_, _, _) @@ -188,18 +202,7 @@ impl, A: ConsensusApi + | HotShotEvent::ViewChange(_) ) } + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } } - -/// task state implementation for VID Task -impl, A: ConsensusApi + 'static> TS - for VIDTaskState -{ -} - -/// Type alias for VID Task Types -pub type VIDTaskTypes = HSTWithEvent< - ConsensusTaskError, - HotShotEvent, - ChannelStream>, - VIDTaskState, ->; diff --git a/crates/task-impls/src/view_sync.rs b/crates/task-impls/src/view_sync.rs index a123cee52f..04d47cda08 100644 --- a/crates/task-impls/src/view_sync.rs +++ b/crates/task-impls/src/view_sync.rs @@ -1,16 +1,12 @@ #![allow(clippy::module_name_repetitions)] use crate::{ - events::HotShotEvent, - helpers::cancel_task, + events::{HotShotEvent, HotShotTaskCompleted}, + helpers::{broadcast_event, cancel_task}, vote::{create_vote_accumulator, AccumulatorInfo, HandleVoteEvent, VoteCollectionTaskState}, }; +use async_broadcast::Sender; use async_compatibility_layer::art::{async_sleep, async_spawn}; use async_lock::RwLock; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - task::{HotShotTaskCompleted, TS}, - task_impls::HSTWithEvent, -}; use hotshot_types::{ simple_certificate::{ ViewSyncCommitCertificate2, ViewSyncFinalizeCertificate2, ViewSyncPreCommitCertificate2, @@ -29,7 +25,7 @@ use hotshot_types::{ #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; -use hotshot_task::global_registry::GlobalRegistry; +use hotshot_task::task::{Task, TaskState}; use hotshot_types::{ 
message::GeneralConsensusMessage, traits::{ @@ -71,10 +67,6 @@ pub struct ViewSyncTaskState< I: NodeImplementation, A: ConsensusApi + 'static + std::clone::Clone, > { - /// Registry to register sub tasks - pub registry: GlobalRegistry, - /// Event stream to publish events to - pub event_stream: ChannelStream>, /// View HotShot is currently in pub current_view: TYPES::Time, /// View HotShot wishes to be in @@ -119,17 +111,38 @@ impl< TYPES: NodeType, I: NodeImplementation, A: ConsensusApi + 'static + std::clone::Clone, - > TS for ViewSyncTaskState + > TaskState for ViewSyncTaskState { -} + type Event = HotShotEvent; + + type Output = (); -/// Types for the main view sync task -pub type ViewSyncTaskStateTypes = HSTWithEvent< - ViewSyncTaskError, - HotShotEvent, - ChannelStream>, - ViewSyncTaskState, ->; + async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> { + let sender = task.clone_sender(); + task.state_mut().handle(event, sender).await; + None + } + + fn filter(&self, event: &Self::Event) -> bool { + !matches!( + event, + HotShotEvent::ViewSyncPreCommitCertificate2Recv(_) + | HotShotEvent::ViewSyncCommitCertificate2Recv(_) + | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) + | HotShotEvent::ViewSyncPreCommitVoteRecv(_) + | HotShotEvent::ViewSyncCommitVoteRecv(_) + | HotShotEvent::ViewSyncFinalizeVoteRecv(_) + | HotShotEvent::Shutdown + | HotShotEvent::Timeout(_) + | HotShotEvent::ViewSyncTimeout(_, _, _) + | HotShotEvent::ViewChange(_) + ) + } + + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } +} /// State of a view sync replica task pub struct ViewSyncReplicaTaskState< @@ -164,22 +177,40 @@ pub struct ViewSyncReplicaTaskState< pub private_key: ::PrivateKey, /// HotShot consensus API pub api: A, - /// Event stream to publish events to - pub event_stream: ChannelStream>, } -impl, A: ConsensusApi + 'static> TS +impl, A: ConsensusApi + 'static> TaskState for ViewSyncReplicaTaskState { -} + type Event = HotShotEvent; + + type Output = (); + + async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> { + let sender = task.clone_sender(); + task.state_mut().handle(event, sender).await; + None + } + fn filter(&self, event: &Self::Event) -> bool { + !matches!( + event, + HotShotEvent::ViewSyncPreCommitCertificate2Recv(_) + | HotShotEvent::ViewSyncCommitCertificate2Recv(_) + | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) + | HotShotEvent::ViewSyncPreCommitVoteRecv(_) + | HotShotEvent::ViewSyncCommitVoteRecv(_) + | HotShotEvent::ViewSyncFinalizeVoteRecv(_) + | HotShotEvent::Shutdown + | HotShotEvent::Timeout(_) + | HotShotEvent::ViewSyncTimeout(_, _, _) + | HotShotEvent::ViewChange(_) + ) + } -/// Types for view sync replica state -pub type ViewSyncReplicaTaskStateTypes = HSTWithEvent< - ViewSyncTaskError, - HotShotEvent, - ChannelStream>, - ViewSyncReplicaTaskState, ->; + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } +} impl< TYPES: NodeType, @@ -194,6 +225,7 @@ impl< &mut self, event: HotShotEvent, view: TYPES::Time, + sender: &Sender>, ) { // This certificate is old, we can throw it away // If next view = cert round, then that means we should already have a task running for it @@ -204,17 +236,17 @@ impl< let mut task_map = self.replica_task_map.write().await; - if let Some(replica_task) = task_map.remove(&view) { + if let Some(replica_task) = task_map.get_mut(&view) { // Forward event then return debug!("Forwarding message"); - let result = 
replica_task.handle_event(event.clone()).await; + let result = replica_task.handle(event.clone(), sender.clone()).await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { // The protocol has finished + task_map.remove(&view); return; } - task_map.insert(view, result.1); return; } @@ -231,47 +263,52 @@ impl< public_key: self.public_key.clone(), private_key: self.private_key.clone(), api: self.api.clone(), - event_stream: self.event_stream.clone(), view_sync_timeout: self.view_sync_timeout, id: self.id, }; - let result = replica_state.handle_event(event.clone()).await; + let result = replica_state.handle(event.clone(), sender.clone()).await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { // The protocol has finished return; } - replica_state = result.1; - task_map.insert(view, replica_state); } #[instrument(skip_all, fields(id = self.id, view = *self.current_view), name = "View Sync Main Task", level = "error")] #[allow(clippy::type_complexity)] /// Handles incoming events for the main view sync task - pub async fn handle_event(&mut self, event: HotShotEvent) { + pub async fn handle( + &mut self, + event: HotShotEvent, + event_stream: Sender>, + ) { match &event { HotShotEvent::ViewSyncPreCommitCertificate2Recv(certificate) => { debug!("Received view sync cert for phase {:?}", certificate); let view = certificate.get_view_number(); - self.send_to_or_create_replica(event, view).await; + self.send_to_or_create_replica(event, view, &event_stream) + .await; } HotShotEvent::ViewSyncCommitCertificate2Recv(certificate) => { debug!("Received view sync cert for phase {:?}", certificate); let view = certificate.get_view_number(); - self.send_to_or_create_replica(event, view).await; + self.send_to_or_create_replica(event, view, &event_stream) + .await; } HotShotEvent::ViewSyncFinalizeCertificate2Recv(certificate) => { debug!("Received view sync cert for phase {:?}", certificate); let view = certificate.get_view_number(); - self.send_to_or_create_replica(event, view).await; + self.send_to_or_create_replica(event, view, &event_stream) + .await; } HotShotEvent::ViewSyncTimeout(view, _, _) => { debug!("view sync timeout in main task {:?}", view); let view = *view; - self.send_to_or_create_replica(event, view).await; + self.send_to_or_create_replica(event, view, &event_stream) + .await; } HotShotEvent::ViewSyncPreCommitVoteRecv(ref vote) => { @@ -279,15 +316,14 @@ impl< let vote_view = vote.get_view_number(); let relay = vote.get_data().relay; let relay_map = map.entry(vote_view).or_insert(BTreeMap::new()); - if let Some(relay_task) = relay_map.remove(&relay) { + if let Some(relay_task) = relay_map.get_mut(&relay) { debug!("Forwarding message"); - let result = relay_task.handle_event(event.clone()).await; + let result = relay_task.handle_event(event.clone(), &event_stream).await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { // The protocol has finished - return; + map.remove(&vote_view); } - relay_map.insert(relay, result.1); return; } @@ -302,11 +338,10 @@ impl< public_key: self.public_key.clone(), membership: self.membership.clone(), view: vote_view, - event_stream: self.event_stream.clone(), id: self.id, - registry: self.registry.clone(), }; - let vote_collector = create_vote_accumulator(&info, vote.clone(), event).await; + let vote_collector = + create_vote_accumulator(&info, vote.clone(), event, &event_stream).await; if let Some(vote_task) = 
vote_collector { relay_map.insert(relay, vote_task); } @@ -317,16 +352,14 @@ impl< let vote_view = vote.get_view_number(); let relay = vote.get_data().relay; let relay_map = map.entry(vote_view).or_insert(BTreeMap::new()); - if let Some(relay_task) = relay_map.remove(&relay) { + if let Some(relay_task) = relay_map.get_mut(&relay) { debug!("Forwarding message"); - let result = relay_task.handle_event(event.clone()).await; + let result = relay_task.handle_event(event.clone(), &event_stream).await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { // The protocol has finished - return; + map.remove(&vote_view); } - - relay_map.insert(relay, result.1); return; } @@ -341,11 +374,10 @@ impl< public_key: self.public_key.clone(), membership: self.membership.clone(), view: vote_view, - event_stream: self.event_stream.clone(), id: self.id, - registry: self.registry.clone(), }; - let vote_collector = create_vote_accumulator(&info, vote.clone(), event).await; + let vote_collector = + create_vote_accumulator(&info, vote.clone(), event, &event_stream).await; if let Some(vote_task) = vote_collector { relay_map.insert(relay, vote_task); } @@ -356,16 +388,14 @@ impl< let vote_view = vote.get_view_number(); let relay = vote.get_data().relay; let relay_map = map.entry(vote_view).or_insert(BTreeMap::new()); - if let Some(relay_task) = relay_map.remove(&relay) { + if let Some(relay_task) = relay_map.get_mut(&relay) { debug!("Forwarding message"); - let result = relay_task.handle_event(event.clone()).await; + let result = relay_task.handle_event(event.clone(), &event_stream).await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { // The protocol has finished - return; + map.remove(&vote_view); } - - relay_map.insert(relay, result.1); return; } @@ -380,11 +410,10 @@ impl< public_key: self.public_key.clone(), membership: self.membership.clone(), view: vote_view, - event_stream: self.event_stream.clone(), id: self.id, - registry: self.registry.clone(), }; - let vote_collector = create_vote_accumulator(&info, vote.clone(), event).await; + let vote_collector = + create_vote_accumulator(&info, vote.clone(), event, &event_stream).await; if let Some(vote_task) = vote_collector { relay_map.insert(relay, vote_task); } @@ -486,39 +515,23 @@ impl< self.send_to_or_create_replica( HotShotEvent::ViewSyncTrigger(view_number + 1), view_number + 1, + &event_stream, ) .await; } else { // If this is the first timeout we've seen advance to the next view self.current_view = view_number; - self.event_stream - .publish(HotShotEvent::ViewChange(TYPES::Time::new( - *self.current_view, - ))) - .await; + broadcast_event( + HotShotEvent::ViewChange(TYPES::Time::new(*self.current_view)), + &event_stream, + ) + .await; } } _ => {} } } - - /// Filter view sync related events. 
- pub fn filter(event: &HotShotEvent) -> bool { - matches!( - event, - HotShotEvent::ViewSyncPreCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) - | HotShotEvent::ViewSyncPreCommitVoteRecv(_) - | HotShotEvent::ViewSyncCommitVoteRecv(_) - | HotShotEvent::ViewSyncFinalizeVoteRecv(_) - | HotShotEvent::Shutdown - | HotShotEvent::Timeout(_) - | HotShotEvent::ViewSyncTimeout(_, _, _) - | HotShotEvent::ViewChange(_) - ) - } } impl, A: ConsensusApi + 'static> @@ -526,13 +539,11 @@ impl, A: ConsensusApi + { #[instrument(skip_all, fields(id = self.id, view = *self.current_view), name = "View Sync Replica Task", level = "error")] /// Handle incoming events for the view sync replica task - pub async fn handle_event( - mut self, + pub async fn handle( + &mut self, event: HotShotEvent, - ) -> ( - std::option::Option, - ViewSyncReplicaTaskState, - ) { + event_stream: Sender>, + ) -> Option { match event { HotShotEvent::ViewSyncPreCommitCertificate2Recv(certificate) => { let last_seen_certificate = ViewSyncPhase::PreCommit; @@ -541,20 +552,20 @@ impl, A: ConsensusApi + if certificate.get_view_number() < self.next_view { warn!("We're already in a higher round"); - return (None, self); + return None; } // If certificate is not valid, return current state if !certificate.is_valid_cert(self.membership.as_ref()) { error!("Not valid view sync cert! {:?}", certificate.get_data()); - return (None, self); + return None; } // If certificate is for a higher round shutdown this task // since another task should have been started for the higher round if certificate.get_view_number() > self.next_view { - return (Some(HotShotTaskCompleted::ShutDown), self); + return Some(HotShotTaskCompleted); } if certificate.get_data().relay > self.relay { @@ -571,13 +582,12 @@ impl, A: ConsensusApi + &self.private_key, ) else { error!("Failed to sign ViewSyncCommitData!"); - return (None, self); + return None; }; let message = GeneralConsensusMessage::::ViewSyncCommitVote(vote); if let GeneralConsensusMessage::ViewSyncCommitVote(vote) = message { - self.event_stream - .publish(HotShotEvent::ViewSyncCommitVoteSend(vote)) + broadcast_event(HotShotEvent::ViewSyncCommitVoteSend(vote), &event_stream) .await; } @@ -586,18 +596,24 @@ impl, A: ConsensusApi + } self.timeout_task = Some(async_spawn({ - let stream = self.event_stream.clone(); + let stream = event_stream.clone(); let phase = last_seen_certificate; + let relay = self.relay; + let next_view = self.next_view; + let timeout = self.view_sync_timeout; async move { - async_sleep(self.view_sync_timeout).await; - info!("Vote sending timed out in ViewSyncPreCommitCertificateRecv, Relay = {}", self.relay); - stream - .publish(HotShotEvent::ViewSyncTimeout( - TYPES::Time::new(*self.next_view), - self.relay, + async_sleep(timeout).await; + info!("Vote sending timed out in ViewSyncPreCommitCertificateRecv, Relay = {}", relay); + + broadcast_event( + HotShotEvent::ViewSyncTimeout( + TYPES::Time::new(*next_view), + relay, phase, - )) - .await; + ), + &stream, + ) + .await; } })); } @@ -609,20 +625,20 @@ impl, A: ConsensusApi + if certificate.get_view_number() < self.next_view { warn!("We're already in a higher round"); - return (None, self); + return None; } // If certificate is not valid, return current state if !certificate.is_valid_cert(self.membership.as_ref()) { error!("Not valid view sync cert! 
{:?}", certificate.get_data()); - return (None, self); + return None; } // If certificate is for a higher round shutdown this task // since another task should have been started for the higher round if certificate.get_view_number() > self.next_view { - return (Some(HotShotTaskCompleted::ShutDown), self); + return Some(HotShotTaskCompleted); } if certificate.get_data().relay > self.relay { @@ -639,13 +655,12 @@ impl, A: ConsensusApi + &self.private_key, ) else { error!("Failed to sign view sync finalized vote!"); - return (None, self); + return None; }; let message = GeneralConsensusMessage::::ViewSyncFinalizeVote(vote); if let GeneralConsensusMessage::ViewSyncFinalizeVote(vote) = message { - self.event_stream - .publish(HotShotEvent::ViewSyncFinalizeVoteSend(vote)) + broadcast_event(HotShotEvent::ViewSyncFinalizeVoteSend(vote), &event_stream) .await; } @@ -654,33 +669,34 @@ impl, A: ConsensusApi + *self.next_view ); - self.event_stream - .publish(HotShotEvent::ViewChange(self.next_view - 1)) - .await; + broadcast_event(HotShotEvent::ViewChange(self.next_view - 1), &event_stream).await; - self.event_stream - .publish(HotShotEvent::ViewChange(self.next_view)) - .await; + broadcast_event(HotShotEvent::ViewChange(self.next_view), &event_stream).await; if let Some(timeout_task) = self.timeout_task.take() { cancel_task(timeout_task).await; } self.timeout_task = Some(async_spawn({ - let stream = self.event_stream.clone(); + let stream = event_stream.clone(); let phase = last_seen_certificate; + let relay = self.relay; + let next_view = self.next_view; + let timeout = self.view_sync_timeout; async move { - async_sleep(self.view_sync_timeout).await; + async_sleep(timeout).await; info!( "Vote sending timed out in ViewSyncCommitCertificateRecv, relay = {}", - self.relay + relay ); - stream - .publish(HotShotEvent::ViewSyncTimeout( - TYPES::Time::new(*self.next_view), - self.relay, + broadcast_event( + HotShotEvent::ViewSyncTimeout( + TYPES::Time::new(*next_view), + relay, phase, - )) - .await; + ), + &stream, + ) + .await; } })); } @@ -690,20 +706,20 @@ impl, A: ConsensusApi + if certificate.get_view_number() < self.next_view { warn!("We're already in a higher round"); - return (None, self); + return None; } // If certificate is not valid, return current state if !certificate.is_valid_cert(self.membership.as_ref()) { error!("Not valid view sync cert! 
{:?}", certificate.get_data()); - return (None, self); + return None; } // If certificate is for a higher round shutdown this task // since another task should have been started for the higher round if certificate.get_view_number() > self.next_view { - return (Some(HotShotTaskCompleted::ShutDown), self); + return Some(HotShotTaskCompleted); } // cancel poll for votes @@ -728,16 +744,14 @@ impl, A: ConsensusApi + cancel_task(timeout_task).await; } - self.event_stream - .publish(HotShotEvent::ViewChange(self.next_view)) - .await; - return (Some(HotShotTaskCompleted::ShutDown), self); + broadcast_event(HotShotEvent::ViewChange(self.next_view), &event_stream).await; + return Some(HotShotTaskCompleted); } HotShotEvent::ViewSyncTrigger(view_number) => { if self.next_view != TYPES::Time::new(*view_number) { error!("Unexpected view number to triger view sync"); - return (None, self); + return None; } let Ok(vote) = ViewSyncPreCommitVote::::create_signed_vote( @@ -750,32 +764,36 @@ impl, A: ConsensusApi + &self.private_key, ) else { error!("Failed to sign pre commit vote!"); - return (None, self); + return None; }; let message = GeneralConsensusMessage::::ViewSyncPreCommitVote(vote); if let GeneralConsensusMessage::ViewSyncPreCommitVote(vote) = message { - self.event_stream - .publish(HotShotEvent::ViewSyncPreCommitVoteSend(vote)) + broadcast_event(HotShotEvent::ViewSyncPreCommitVoteSend(vote), &event_stream) .await; } self.timeout_task = Some(async_spawn({ - let stream = self.event_stream.clone(); + let stream = event_stream.clone(); + let relay = self.relay; + let next_view = self.next_view; + let timeout = self.view_sync_timeout; async move { - async_sleep(self.view_sync_timeout).await; + async_sleep(timeout).await; info!("Vote sending timed out in ViewSyncTrigger"); - stream - .publish(HotShotEvent::ViewSyncTimeout( - TYPES::Time::new(*self.next_view), - self.relay, + broadcast_event( + HotShotEvent::ViewSyncTimeout( + TYPES::Time::new(*next_view), + relay, ViewSyncPhase::None, - )) - .await; + ), + &stream, + ) + .await; } })); - return (None, self); + return None; } HotShotEvent::ViewSyncTimeout(round, relay, last_seen_certificate) => { @@ -797,15 +815,17 @@ impl, A: ConsensusApi + &self.private_key, ) else { error!("Failed to sign ViewSyncPreCommitData!"); - return (None, self); + return None; }; let message = GeneralConsensusMessage::::ViewSyncPreCommitVote(vote); if let GeneralConsensusMessage::ViewSyncPreCommitVote(vote) = message { - self.event_stream - .publish(HotShotEvent::ViewSyncPreCommitVoteSend(vote)) - .await; + broadcast_event( + HotShotEvent::ViewSyncPreCommitVoteSend(vote), + &event_stream, + ) + .await; } } ViewSyncPhase::Finalize => { @@ -815,28 +835,33 @@ impl, A: ConsensusApi + } self.timeout_task = Some(async_spawn({ - let stream = self.event_stream.clone(); + let stream = event_stream.clone(); + let relay = self.relay; + let next_view = self.next_view; + let timeout = self.view_sync_timeout; async move { - async_sleep(self.view_sync_timeout).await; + async_sleep(timeout).await; info!( "Vote sending timed out in ViewSyncTimeout relay = {}", - self.relay + relay ); - stream - .publish(HotShotEvent::ViewSyncTimeout( - TYPES::Time::new(*self.next_view), - self.relay, + broadcast_event( + HotShotEvent::ViewSyncTimeout( + TYPES::Time::new(*next_view), + relay, last_seen_certificate, - )) - .await; + ), + &stream, + ) + .await; } })); - return (None, self); + return None; } } - _ => return (None, self), + _ => return None, } - (None, self) + None } } diff --git 
a/crates/task-impls/src/vote.rs b/crates/task-impls/src/vote.rs index e41f34cdba..651332f39b 100644 --- a/crates/task-impls/src/vote.rs +++ b/crates/task-impls/src/vote.rs @@ -1,14 +1,14 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; -use crate::events::HotShotEvent; +use crate::{ + events::{HotShotEvent, HotShotTaskCompleted}, + helpers::broadcast_event, +}; +use async_broadcast::Sender; use async_trait::async_trait; use either::Either::{self, Left, Right}; -use hotshot_task::{ - event_stream::{ChannelStream, EventStream}, - global_registry::GlobalRegistry, - task::{HotShotTaskCompleted, TS}, - task_impls::HSTWithEvent, -}; + +use hotshot_task::task::{Task, TaskState}; use hotshot_types::{ simple_certificate::{ DACertificate, QuorumCertificate, TimeoutCertificate, UpgradeCertificate, @@ -46,9 +46,6 @@ pub struct VoteCollectionTaskState< /// The view which we are collecting votes for pub view: TYPES::Time, - /// global event stream - pub event_stream: ChannelStream>, - /// Node id pub id: u64, } @@ -75,9 +72,13 @@ impl< { /// Take one vote and accumulate it. Returns either the cert or the updated state /// after the vote is accumulated - pub async fn accumulate_vote(mut self, vote: &VOTE) -> (Option, Self) { + pub async fn accumulate_vote( + &mut self, + vote: &VOTE, + event_stream: &Sender>, + ) -> Option { if vote.get_leader(&self.membership) != self.public_key { - return (None, self); + return None; } if vote.get_view_number() != self.view { @@ -86,23 +87,19 @@ impl< *vote.get_view_number(), *self.view ); - return (None, self); + return None; } - let Some(accumulator) = self.accumulator else { - return (None, self); + let Some(ref mut accumulator) = self.accumulator else { + return None; }; match accumulator.accumulate(vote, &self.membership) { - Either::Left(acc) => { - self.accumulator = Some(acc); - (None, self) - } + Either::Left(()) => None, Either::Right(cert) => { debug!("Certificate Formed! {:?}", cert); - self.event_stream - .publish(VOTE::make_cert_event(cert, &self.public_key)) - .await; + + broadcast_event(VOTE::make_cert_event(cert, &self.public_key), event_stream).await; self.accumulator = None; - (Some(HotShotTaskCompleted::ShutDown), self) + Some(HotShotTaskCompleted) } } } @@ -120,17 +117,23 @@ impl< + std::marker::Send + std::marker::Sync + 'static, - > TS for VoteCollectionTaskState + > TaskState for VoteCollectionTaskState +where + VoteCollectionTaskState: HandleVoteEvent, { -} + type Event = HotShotEvent; -/// Types for a vote accumulator Task -pub type VoteTaskStateTypes = HSTWithEvent< - VoteTaskError, - HotShotEvent, - ChannelStream>, - VoteCollectionTaskState, ->; + type Output = HotShotTaskCompleted; + + async fn handle_event(event: Self::Event, task: &mut Task) -> Option { + let sender = task.clone_sender(); + task.state_mut().handle_event(event, &sender).await + } + + fn should_shutdown(event: &Self::Event) -> bool { + matches!(event, HotShotEvent::Shutdown) + } +} /// Trait for types which will handle a vote event.
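
accumulate_vote above now mutates the accumulator in place and returns Option<HotShotTaskCompleted>, with the inner accumulate call still signalling progress through Either: Left(()) while votes are being gathered, Right(cert) once the threshold is crossed, after which the accumulator is set to None. A toy threshold accumulator showing the same control flow, using the either crate already in this workspace; the types are hypothetical, and real vote accumulation also checks signatures and stake weights through the Membership.

use either::Either;

/// Hypothetical stand-in for the vote accumulator: counts distinct voters
/// until a threshold is met, then emits a "certificate".
struct ThresholdAccumulator {
    voters: std::collections::HashSet<u64>,
    threshold: usize,
}

impl ThresholdAccumulator {
    /// Left(()) means "keep collecting"; Right(cert) means a cert formed,
    /// mirroring how accumulate_vote uses Either in this patch.
    fn accumulate(&mut self, voter_id: u64) -> Either<(), Vec<u64>> {
        self.voters.insert(voter_id);
        if self.voters.len() >= self.threshold {
            Either::Right(self.voters.iter().copied().collect())
        } else {
            Either::Left(())
        }
    }
}

fn main() {
    let mut acc = Some(ThresholdAccumulator {
        voters: Default::default(),
        threshold: 2,
    });
    for voter in [7, 7, 9] {
        // The Option dance matches VoteCollectionTaskState: once the
        // certificate is out, the accumulator goes to None and collection ends.
        let Some(ref mut a) = acc else { break };
        if let Either::Right(cert) = a.accumulate(voter) {
            println!("certificate formed from voters {cert:?}");
            acc = None; // same as `self.accumulator = None` above
        }
    }
    assert!(acc.is_none());
}
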
#[async_trait] @@ -142,12 +145,10 @@ where { /// Handle a vote event async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> ( - Option, - VoteCollectionTaskState, - ); + sender: &Sender>, + ) -> Option; /// Event filter to use for this event fn filter(event: &HotShotEvent) -> bool; @@ -161,12 +162,8 @@ pub struct AccumulatorInfo { pub membership: Arc, /// View of the votes we are collecting pub view: TYPES::Time, - /// Global event stream shared by all consensus tasks - pub event_stream: ChannelStream>, /// This nodes id pub id: u64, - /// Task Registry for all tasks used by this node - pub registry: GlobalRegistry, } /// Generic function for spawnnig a vote task. Returns the event stream id of the spawned task if created @@ -176,6 +173,7 @@ pub async fn create_vote_accumulator( info: &AccumulatorInfo, vote: VOTE, event: HotShotEvent, + sender: &Sender>, ) -> Option> where TYPES: NodeType, @@ -206,7 +204,6 @@ where }; let mut state = VoteCollectionTaskState:: { - event_stream: info.event_stream.clone(), membership: info.membership.clone(), public_key: info.public_key.clone(), accumulator: Some(new_accumulator), @@ -214,14 +211,13 @@ where id: info.id, }; - let result = state.handle_event(event.clone()).await; + let result = state.handle_event(event.clone(), sender).await; - if result.0 == Some(HotShotTaskCompleted::ShutDown) { + if result == Some(HotShotTaskCompleted) { // The protocol has finished return None; } - state = result.1; Some(state) } @@ -359,12 +355,13 @@ impl HandleVoteEvent, QuorumCertificat for QuorumVoteState { async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> (Option, QuorumVoteState) { + sender: &Sender>, + ) -> Option { match event { - HotShotEvent::QuorumVoteRecv(vote) => self.accumulate_vote(&vote).await, - _ => (None, self), + HotShotEvent::QuorumVoteRecv(vote) => self.accumulate_vote(&vote, sender).await, + _ => None, } } fn filter(event: &HotShotEvent) -> bool { @@ -378,12 +375,13 @@ impl HandleVoteEvent, UpgradeCertific for UpgradeVoteState { async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> (Option, UpgradeVoteState) { + sender: &Sender>, + ) -> Option { match event { - HotShotEvent::UpgradeVoteRecv(vote) => self.accumulate_vote(&vote).await, - _ => (None, self), + HotShotEvent::UpgradeVoteRecv(vote) => self.accumulate_vote(&vote, sender).await, + _ => None, } } fn filter(event: &HotShotEvent) -> bool { @@ -396,12 +394,13 @@ impl HandleVoteEvent, DACertificate for DAVoteState { async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> (Option, DAVoteState) { + sender: &Sender>, + ) -> Option { match event { - HotShotEvent::DAVoteRecv(vote) => self.accumulate_vote(&vote).await, - _ => (None, self), + HotShotEvent::DAVoteRecv(vote) => self.accumulate_vote(&vote, sender).await, + _ => None, } } fn filter(event: &HotShotEvent) -> bool { @@ -414,12 +413,13 @@ impl HandleVoteEvent, TimeoutCertific for TimeoutVoteState { async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> (Option, TimeoutVoteState) { + sender: &Sender>, + ) -> Option { match event { - HotShotEvent::TimeoutVoteRecv(vote) => self.accumulate_vote(&vote).await, - _ => (None, self), + HotShotEvent::TimeoutVoteRecv(vote) => self.accumulate_vote(&vote, sender).await, + _ => None, } } fn filter(event: &HotShotEvent) -> bool { @@ -433,12 +433,15 @@ impl for ViewSyncPreCommitState { async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> (Option, ViewSyncPreCommitState) { + sender: &Sender>, + ) -> 
Option { match event { - HotShotEvent::ViewSyncPreCommitVoteRecv(vote) => self.accumulate_vote(&vote).await, - _ => (None, self), + HotShotEvent::ViewSyncPreCommitVoteRecv(vote) => { + self.accumulate_vote(&vote, sender).await + } + _ => None, } } fn filter(event: &HotShotEvent) -> bool { @@ -452,12 +455,13 @@ impl for ViewSyncCommitVoteState { async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> (Option, ViewSyncCommitVoteState) { + sender: &Sender>, + ) -> Option { match event { - HotShotEvent::ViewSyncCommitVoteRecv(vote) => self.accumulate_vote(&vote).await, - _ => (None, self), + HotShotEvent::ViewSyncCommitVoteRecv(vote) => self.accumulate_vote(&vote, sender).await, + _ => None, } } fn filter(event: &HotShotEvent) -> bool { @@ -471,15 +475,15 @@ impl for ViewSyncFinalizeVoteState { async fn handle_event( - self, + &mut self, event: HotShotEvent, - ) -> ( - Option, - ViewSyncFinalizeVoteState, - ) { + sender: &Sender>, + ) -> Option { match event { - HotShotEvent::ViewSyncFinalizeVoteRecv(vote) => self.accumulate_vote(&vote).await, - _ => (None, self), + HotShotEvent::ViewSyncFinalizeVoteRecv(vote) => { + self.accumulate_vote(&vote, sender).await + } + _ => None, } } fn filter(event: &HotShotEvent) -> bool { diff --git a/crates/task/Cargo.toml b/crates/task/Cargo.toml index 51a80a1829..39c531a637 100644 --- a/crates/task/Cargo.toml +++ b/crates/task/Cargo.toml @@ -1,27 +1,22 @@ [package] authors = ["Espresso Systems "] -description = "Async task abstraction for use in consensus" -edition = "2021" name = "hotshot-task" version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -async-compatibility-layer = { workspace = true } -async-trait = { workspace = true } -either = { workspace = true } -futures = { workspace = true } -serde = { workspace = true } -snafu = { workspace = true } -async-lock = { workspace = true } + +futures = "0.3.30" +async-broadcast = "0.6.0" tracing = { workspace = true } -atomic_enum = "0.2.0" -pin-project = "1.1.4" +async-compatibility-layer = { workspace = true } [target.'cfg(all(async_executor_impl = "tokio"))'.dependencies] -tokio = { workspace = true } - +tokio = { workspace= true, features = ["time", "rt-multi-thread", "macros", "sync"] } [target.'cfg(all(async_executor_impl = "async-std"))'.dependencies] -async-std = { workspace = true } +async-std = { workspace= true, features = ["attributes"] } [lints] -workspace = true +workspace = true \ No newline at end of file diff --git a/crates/task/src/dependency.rs b/crates/task/src/dependency.rs new file mode 100644 index 0000000000..6ae793a7a4 --- /dev/null +++ b/crates/task/src/dependency.rs @@ -0,0 +1,270 @@ +use async_broadcast::{Receiver, RecvError}; +use futures::future::BoxFuture; +use futures::stream::FuturesUnordered; +use futures::stream::StreamExt; +use futures::FutureExt; +use std::future::Future; + +/// Type which describes the idea of waiting for a dependency to complete +pub trait Dependency { + /// Complete will wait until it gets some value `T` then return the value + fn completed(self) -> impl Future> + Send; + /// Create an or dependency from this dependency and another + fn or + Send + 'static>(self, dep: D) -> OrDependency + where + T: Send + Sync + Clone + 'static, + Self: Sized + Send + 'static, + { + let mut or = OrDependency::from_deps(vec![self]); + or.add_dep(dep); + or + } + /// Create an and dependency from this dependency and another + fn and + Send + 'static>(self, dep: 
D) -> AndDependency + where + T: Send + Sync + Clone + 'static, + Self: Sized + Send + 'static, + { + let mut and = AndDependency::from_deps(vec![self]); + and.add_dep(dep); + and + } +} + +/// Used to combine dependencies to create `AndDependency`s or `OrDependency`s +trait CombineDependencies: + Sized + Dependency + Send + 'static +{ +} + +/// Defines a dependency that completes when all of its deps complete +pub struct AndDependency { + /// Dependencies being combined + deps: Vec>>, +} +impl Dependency> for AndDependency { + /// Returns a vector of all of the results from its dependencies. + /// The results will be in a random order + async fn completed(self) -> Option> { + let futures = FuturesUnordered::from_iter(self.deps); + futures + .collect::>>() + .await + .into_iter() + .collect() + } +} + +impl AndDependency { + /// Create from a vec of deps + #[must_use] + pub fn from_deps(deps: Vec + Send + 'static>) -> Self { + let mut pinned = vec![]; + for dep in deps { + pinned.push(dep.completed().boxed()); + } + Self { deps: pinned } + } + /// Add another dependency + pub fn add_dep(&mut self, dep: impl Dependency + Send + 'static) { + self.deps.push(dep.completed().boxed()); + } + /// Add multiple dependencies + pub fn add_deps(&mut self, deps: AndDependency) { + for dep in deps.deps { + self.deps.push(dep); + } + } +} + +/// Defines a dependency that completes when one of its dependencies completes +pub struct OrDependency { + /// Dependencies being combined + deps: Vec>>, +} +impl Dependency for OrDependency { + /// Returns the value of the first completed dependency + async fn completed(self) -> Option { + let mut futures = FuturesUnordered::from_iter(self.deps); + loop { + if let Some(maybe) = futures.next().await { + if maybe.is_some() { + return maybe; + } + } else { + return None; + } + } + } +} + +impl OrDependency { + /// Create an `OrDependency` from a vec of dependencies + #[must_use] + pub fn from_deps(deps: Vec + Send + 'static>) -> Self { + let mut pinned = vec![]; + for dep in deps { + pinned.push(dep.completed().boxed()); + } + Self { deps: pinned } + } + /// Add another dependency + pub fn add_dep(&mut self, dep: impl Dependency + Send + 'static) { + self.deps.push(dep.completed().boxed()); + } +} +
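As a reading aid for the combinators above: `AndDependency` resolves to `Some` of all results only if every child yields `Some`, while `OrDependency` resolves to the first `Some`. The same semantics, sketched with plain `futures` primitives (`join_all`/`select_all` standing in for the crate's `FuturesUnordered` loops — this is not the crate's actual code):

```rust
use futures::future::{join_all, select_all, BoxFuture};
use futures::FutureExt;

async fn and_all<T>(deps: Vec<BoxFuture<'static, Option<T>>>) -> Option<Vec<T>> {
    // None if any child yields None, like AndDependency::completed.
    join_all(deps).await.into_iter().collect()
}

async fn or_first<T>(mut deps: Vec<BoxFuture<'static, Option<T>>>) -> Option<T> {
    // Keep racing the remaining deps until one yields Some.
    while !deps.is_empty() {
        let (res, _idx, rest) = select_all(deps).await;
        if res.is_some() {
            return res;
        }
        deps = rest;
    }
    None
}

#[tokio::main]
async fn main() {
    let (a, b) = (async { Some(1u32) }.boxed(), async { Some(2u32) }.boxed());
    assert_eq!(and_all(vec![a, b]).await, Some(vec![1, 2]));
    let (c, d) = (async { None::<u32> }.boxed(), async { Some(9u32) }.boxed());
    assert_eq!(or_first(vec![c, d]).await, Some(9));
}
```

One divergence worth noting: `join_all` preserves input order, whereas the `FuturesUnordered` implementation above documents that its results arrive in a random order.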
+/// A dependency that listens on a channel for an event +/// that matches some value it wants. +pub struct EventDependency { + /// Channel of incoming events + pub(crate) event_rx: Receiver, + /// Closure which returns true if the incoming `T` is the + /// thing that completes this dependency + pub(crate) match_fn: Box bool + Send>, +} + +impl EventDependency { + /// Create a new `EventDependency` + #[must_use] + pub fn new(receiver: Receiver, match_fn: Box bool + Send>) -> Self { + Self { + event_rx: receiver, + match_fn: Box::new(match_fn), + } + } +} + +impl Dependency for EventDependency { + async fn completed(mut self) -> Option { + loop { + match self.event_rx.recv_direct().await { + Ok(event) => { + if (self.match_fn)(&event) { + return Some(event); + } + } + Err(RecvError::Overflowed(n)) => { + tracing::error!("Dependency Task overloaded, skipping {} events", n); + } + Err(RecvError::Closed) => { + return None; + } + } + } + } +} +
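`completed` above is essentially "receive until the predicate matches", treating `Overflowed` as lossy-but-survivable and `Closed` as a permanent `None`. The same loop as a freestanding helper (the predicate and types are hypothetical; `recv` is used where the impl above uses `recv_direct`):

```rust
use async_broadcast::{broadcast, Receiver, RecvError};

async fn wait_for<T: Clone>(mut rx: Receiver<T>, pred: impl Fn(&T) -> bool) -> Option<T> {
    loop {
        match rx.recv().await {
            Ok(event) if pred(&event) => return Some(event),
            // Not a match: keep listening.
            Ok(_) => {}
            // Fell behind; events were dropped, but the one we want may still arrive.
            Err(RecvError::Overflowed(n)) => eprintln!("skipped {n} events"),
            // Channel closed: the dependency can never complete.
            Err(RecvError::Closed) => return None,
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast(4);
    tx.broadcast(1).await.unwrap();
    tx.broadcast(7).await.unwrap();
    assert_eq!(wait_for(rx, |v| *v == 7).await, Some(7));
}
```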
"tokio", + tokio::test(flavor = "multi_thread", worker_threads = 2) + )] + #[cfg_attr(async_executor_impl = "async-std", async_std::test)] + async fn many_and_dep() { + let (tx, rx) = broadcast(10); + + tx.broadcast(1).await.unwrap(); + tx.broadcast(2).await.unwrap(); + tx.broadcast(3).await.unwrap(); + tx.broadcast(4).await.unwrap(); + tx.broadcast(5).await.unwrap(); + tx.broadcast(6).await.unwrap(); + + let mut and1 = eq_dep(rx.clone(), 4).and(eq_dep(rx.clone(), 6)); + let and2 = eq_dep(rx.clone(), 4).and(eq_dep(rx.clone(), 5)); + and1.add_deps(and2); + let result = and1.completed().await; + assert_eq!(result, Some(vec![4, 6, 4, 5])); + } +} diff --git a/crates/task/src/dependency_task.rs b/crates/task/src/dependency_task.rs new file mode 100644 index 0000000000..9db6786637 --- /dev/null +++ b/crates/task/src/dependency_task.rs @@ -0,0 +1,140 @@ +#[cfg(async_executor_impl = "async-std")] +use async_std::task::{spawn, JoinHandle}; +#[cfg(async_executor_impl = "tokio")] +use tokio::task::{spawn, JoinHandle}; + +use futures::Future; + +use crate::dependency::Dependency; + +/// Defines a type that can handle the result of a dependency +pub trait HandleDepOutput: Send + Sized + Sync + 'static { + /// Type we expect from completed dependency + type Output: Send + Sync + 'static; + + /// Called once when the Dependency completes handles the results + fn handle_dep_result(self, res: Self::Output) -> impl Future + Send; +} + +/// A task that runs until it's dependency completes and it handles the result +pub struct DependencyTask + Send, H: HandleDepOutput + Send> { + /// Dependency this taks waits for + pub(crate) dep: D, + /// Handles the results returned from `self.dep.completed().await` + pub(crate) handle: H, +} + +impl + Send, H: HandleDepOutput + Send> DependencyTask { + /// Create a new `DependencyTask` + #[must_use] + pub fn new(dep: D, handle: H) -> Self { + Self { dep, handle } + } +} + +impl + Send + 'static, H: HandleDepOutput> DependencyTask { + /// Spawn the dependency task + pub fn run(self) -> JoinHandle<()> + where + Self: Sized, + { + spawn(async move { + if let Some(completed) = self.dep.completed().await { + self.handle.handle_dep_result(completed).await; + } + }) + } +} + +#[cfg(test)] +mod test { + + use std::time::Duration; + + use async_broadcast::{broadcast, Receiver, Sender}; + use futures::{stream::FuturesOrdered, StreamExt}; + + #[cfg(async_executor_impl = "async-std")] + use async_std::task::sleep; + #[cfg(async_executor_impl = "tokio")] + use tokio::time::sleep; + + use super::*; + use crate::dependency::*; + + #[derive(Clone, PartialEq, Eq, Debug)] + enum TaskResult { + Success(usize), + // Failure, + } + + struct DummyHandle { + sender: Sender, + } + impl HandleDepOutput for DummyHandle { + type Output = usize; + async fn handle_dep_result(self, res: usize) { + self.sender + .broadcast(TaskResult::Success(res)) + .await + .unwrap(); + } + } + + fn eq_dep(rx: Receiver, val: usize) -> EventDependency { + EventDependency { + event_rx: rx, + match_fn: Box::new(move |v| *v == val), + } + } + + #[cfg_attr( + async_executor_impl = "tokio", + tokio::test(flavor = "multi_thread", worker_threads = 2) + )] + #[cfg_attr(async_executor_impl = "async-std", async_std::test)] + // allow unused for tokio because it's a test + #[allow(unused_must_use)] + async fn it_works() { + let (tx, rx) = broadcast(10); + let (res_tx, mut res_rx) = broadcast(10); + let dep = eq_dep(rx, 2); + let handle = DummyHandle { sender: res_tx }; + let join_handle = DependencyTask { dep, handle }.run(); + 
tx.broadcast(2).await.unwrap(); + assert_eq!(res_rx.recv().await.unwrap(), TaskResult::Success(2)); + + join_handle.await; + } + + #[cfg_attr( + async_executor_impl = "tokio", + tokio::test(flavor = "multi_thread", worker_threads = 2) + )] + #[cfg_attr(async_executor_impl = "async-std", async_std::test)] + async fn many_works() { + let (tx, rx) = broadcast(20); + let (res_tx, mut res_rx) = broadcast(20); + + let mut handles = vec![]; + for i in 0..10 { + let dep = eq_dep(rx.clone(), i); + let handle = DummyHandle { + sender: res_tx.clone(), + }; + handles.push(DependencyTask { dep, handle }.run()); + } + let tx2 = tx.clone(); + spawn(async move { + for i in 0..10 { + tx.broadcast(i).await.unwrap(); + sleep(Duration::from_millis(10)).await; + } + }); + for i in 0..10 { + assert_eq!(res_rx.recv().await.unwrap(), TaskResult::Success(i)); + } + tx2.broadcast(100).await.unwrap(); + FuturesOrdered::from_iter(handles).collect::>().await; + } +} diff --git a/crates/task/src/event_stream.rs b/crates/task/src/event_stream.rs deleted file mode 100644 index 5248fe4373..0000000000 --- a/crates/task/src/event_stream.rs +++ /dev/null @@ -1,268 +0,0 @@ -use async_compatibility_layer::channel::{unbounded, UnboundedSender, UnboundedStream}; -use async_lock::RwLock; -use std::{ - collections::HashMap, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use async_trait::async_trait; -use futures::Stream; - -use crate::task::{FilterEvent, PassType}; - -/// a stream that does nothing. -/// it's immediately closed -#[derive(Clone)] -pub struct DummyStream; - -impl Stream for DummyStream { - type Item = (); - - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(None) - } -} - -#[async_trait] -impl EventStream for DummyStream { - type EventType = (); - - type StreamType = DummyStream; - - async fn publish(&self, _event: Self::EventType) {} - - async fn subscribe( - &self, - _filter: FilterEvent, - ) -> (Self::StreamType, StreamId) { - (DummyStream, 0) - } - - async fn unsubscribe(&self, _id: StreamId) {} - - async fn direct_message(&self, _id: StreamId, _event: Self::EventType) {} -} - -impl SendableStream for DummyStream {} - -/// this is only used for indexing -pub type StreamId = usize; - -/// a stream that plays nicely with async -pub trait SendableStream: Stream + Sync + Send + 'static {} - -/// Async pub sub event stream -/// NOTE: static bound indicates that if the type points to data, that data lives for the lifetime -/// of the program -#[async_trait] -pub trait EventStream: Clone + 'static + Sync + Send { - /// the type of event to process - type EventType: PassType; - /// the type of stream to use - type StreamType: SendableStream; - - /// publish an event to the event stream - async fn publish(&self, event: Self::EventType); - - /// subscribe to a particular set of events - /// specified by `filter`. Filter returns true if the event should be propagated - /// TODO (justin) rethink API, we might be able just to use `StreamExt::filter` and `Filter` - /// That would certainly be cleaner - async fn subscribe(&self, filter: FilterEvent) - -> (Self::StreamType, StreamId); - - /// unsubscribe from the stream - async fn unsubscribe(&self, id: StreamId); - - /// send direct message to node - async fn direct_message(&self, id: StreamId, event: Self::EventType); -} - -/// Event stream implementation using channels as the underlying primitive. -/// We want it to be cloneable -#[derive(Clone)] -pub struct ChannelStream { - /// inner field. 
Useful for having the stream itself - /// be clone - inner: Arc>>, -} - -/// trick to make the event stream clonable -struct ChannelStreamInner { - /// the subscribers to the channel - subscribers: HashMap, UnboundedSender)>, - /// the next unused assignable id - next_stream_id: StreamId, -} - -impl ChannelStream { - /// construct a new event stream - #[must_use] - pub fn new() -> Self { - Self { - inner: Arc::new(RwLock::new(ChannelStreamInner { - subscribers: HashMap::new(), - next_stream_id: 0, - })), - } - } -} - -impl Default for ChannelStream { - fn default() -> Self { - Self::new() - } -} - -impl SendableStream for UnboundedStream {} - -#[async_trait] -impl EventStream for ChannelStream { - type EventType = EVENT; - type StreamType = UnboundedStream; - - async fn direct_message(&self, id: StreamId, event: Self::EventType) { - let inner = self.inner.write().await; - match inner.subscribers.get(&id) { - Some((filter, sender)) => { - if filter(&event) { - match sender.send(event.clone()).await { - Ok(()) => (), - // error sending => stream is closed so remove it - Err(_) => self.unsubscribe(id).await, - } - } - } - None => { - tracing::debug!("Requested stream id not found"); - } - } - } - - /// publish an event to the event stream - async fn publish(&self, event: Self::EventType) { - let inner = self.inner.read().await; - for (uid, (filter, sender)) in &inner.subscribers { - if filter(&event) { - match sender.send(event.clone()).await { - Ok(()) => (), - // error sending => stream is closed so remove it - Err(_) => { - self.unsubscribe(*uid).await; - } - } - } - } - } - - async fn subscribe( - &self, - filter: FilterEvent, - ) -> (Self::StreamType, StreamId) { - let mut inner = self.inner.write().await; - let new_stream_id = inner.next_stream_id; - let (s, r) = unbounded(); - inner.next_stream_id += 1; - // NOTE: can never be already existing. 
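(For contrast with the `ChannelStream` internals being deleted here: its subscribe/publish/unsubscribe bookkeeping collapses to channel cloning with `async_broadcast`, which is what the rest of this PR switches to. A minimal sketch, tokio runtime assumed:)

```rust
use async_broadcast::broadcast;

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast::<u32>(16);
    let mut sub_a = rx.clone(); // was: channel_stream.subscribe(filter)
    let mut sub_b = rx.clone();
    drop(rx); // an unread receiver would otherwise fill the buffer
    tx.broadcast(7).await.unwrap(); // was: channel_stream.publish(event)
    assert_eq!(sub_a.recv().await.unwrap(), 7);
    assert_eq!(sub_b.recv().await.unwrap(), 7);
}
```

Unsubscribing is just dropping the receiver; there are no stream ids or subscriber maps to maintain.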
- // so, this should always return `None` - inner.subscribers.insert(new_stream_id, (filter, s)); - (r.into_stream(), new_stream_id) - } - - async fn unsubscribe(&self, uid: StreamId) { - let mut inner = self.inner.write().await; - inner.subscribers.remove(&uid); - } -} - -#[cfg(test)] -pub mod test { - use crate::{event_stream::EventStream, StreamExt}; - use async_compatibility_layer::art::{async_sleep, async_spawn}; - use std::time::Duration; - - #[derive(Clone, Debug, PartialEq, Eq)] - enum TestMessage { - One, - Two, - Three, - } - - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_channel_stream_basic() { - use crate::task::FilterEvent; - - use super::ChannelStream; - - let channel_stream = ChannelStream::::new(); - let (mut stream, _) = channel_stream.subscribe(FilterEvent::default()).await; - let dup_channel_stream = channel_stream.clone(); - - let dup_dup_channel_stream = channel_stream.clone(); - - async_spawn(async move { - let (mut stream, _) = dup_channel_stream.subscribe(FilterEvent::default()).await; - assert!(stream.next().await.unwrap() == TestMessage::Three); - assert!(stream.next().await.unwrap() == TestMessage::One); - assert!(stream.next().await.unwrap() == TestMessage::Two); - }); - - async_spawn(async move { - dup_dup_channel_stream.publish(TestMessage::Three).await; - dup_dup_channel_stream.publish(TestMessage::One).await; - dup_dup_channel_stream.publish(TestMessage::Two).await; - }); - async_sleep(Duration::new(3, 0)).await; - - assert!(stream.next().await.unwrap() == TestMessage::Three); - assert!(stream.next().await.unwrap() == TestMessage::One); - assert!(stream.next().await.unwrap() == TestMessage::Two); - } - - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_channel_stream_xtreme() { - use crate::task::FilterEvent; - - use super::ChannelStream; - - let channel_stream = ChannelStream::::new(); - let mut streams = Vec::new(); - - for _i in 0..1000 { - let dup_channel_stream = channel_stream.clone(); - let (stream, _) = dup_channel_stream.subscribe(FilterEvent::default()).await; - streams.push(stream); - } - - let dup_dup_channel_stream = channel_stream.clone(); - - for _i in 0..1000 { - let mut stream = streams.pop().unwrap(); - async_spawn(async move { - for event in [TestMessage::One, TestMessage::Two, TestMessage::Three] { - for _ in 0..100 { - assert!(stream.next().await.unwrap() == event); - } - } - }); - } - - async_spawn(async move { - for event in [TestMessage::One, TestMessage::Two, TestMessage::Three] { - for _ in 0..100 { - dup_dup_channel_stream.publish(event.clone()).await; - } - } - }); - } -} diff --git a/crates/task/src/global_registry.rs b/crates/task/src/global_registry.rs deleted file mode 100644 index 1977c21c76..0000000000 --- a/crates/task/src/global_registry.rs +++ /dev/null @@ -1,214 +0,0 @@ -use async_lock::RwLock; -use either::Either; -use futures::{future::BoxFuture, FutureExt}; -use std::{ - collections::{BTreeMap, BTreeSet}, - ops::Deref, - sync::Arc, -}; - -use crate::task_state::{TaskState, TaskStatus}; - -/// function to shut down gobal registry -#[derive(Clone)] -pub struct ShutdownFn(pub Arc BoxFuture<'static, ()> + Sync + Send>); - -// TODO this might cleaner as `run()` -// but then this pattern should change everywhere -impl Deref for ShutdownFn { 
- type Target = dyn Fn() -> BoxFuture<'static, ()> + Sync + Send; - - fn deref(&self) -> &Self::Target { - &*self.0 - } -} - -/// id of task. Usize instead of u64 because -/// used for primarily for indexing -pub type HotShotTaskId = usize; - -/// the global registry provides a place to: -/// - inquire about the state of various tasks -/// - gracefully shut down tasks -#[derive(Debug, Clone)] -pub struct GlobalRegistry { - /// up-to-date shared list of statuses - /// only used if `state_cache` is out of date - /// or if appending - state_list: Arc>>, - /// possibly stale read version of state - /// NOTE: must include entire state in order to - /// support both incrementing and reading. - /// Writing to the status should gracefully shut down the task - state_cache: BTreeMap, -} - -/// function to modify state -#[allow(clippy::type_complexity)] -struct Modifier(Box Either + Send>); - -impl Default for GlobalRegistry { - fn default() -> Self { - Self::new() - } -} - -impl GlobalRegistry { - /// create new registry - #[must_use] - pub fn new() -> Self { - Self { - state_list: Arc::new(RwLock::new(BTreeMap::default())), - state_cache: BTreeMap::default(), - } - } - - /// register with the global registry - /// return a function to the caller (task) that can be used to deregister - /// returns a function to call to shut down the task - /// and the unique identifier of the task - pub async fn register(&mut self, name: &str, status: TaskState) -> (ShutdownFn, HotShotTaskId) { - let mut list = self.state_list.write().await; - let next_id = list - .last_key_value() - .map(|(k, _v)| k) - .copied() - .unwrap_or_default() - + 1; - let new_entry = (status.clone(), name.to_string()); - let new_entry_dup = new_entry.0.clone(); - list.insert(next_id, new_entry.clone()); - - self.state_cache.insert(next_id, new_entry); - - let shutdown_fn = ShutdownFn(Arc::new(move || { - new_entry_dup.set_state(TaskStatus::Completed); - async move {}.boxed() - })); - (shutdown_fn, next_id) - } - - /// update the cache - async fn update_cache(&mut self) { - // NOTE: this can be done much more cleverly - // avoid one intersection by comparing max keys (constant time op vs O(n + m)) - // and debatable how often the other op needs to be run - // probably much much less often - let list = self.state_list.read().await; - let list_keys: BTreeSet = list.keys().copied().collect(); - let cache_keys: BTreeSet = self.state_cache.keys().copied().collect(); - // bleh not as efficient - let missing_key_list = list_keys.difference(&cache_keys); - let expired_key_list = cache_keys.difference(&list_keys); - - for expired_key in expired_key_list { - self.state_cache.remove(expired_key); - } - - for key in missing_key_list { - // technically shouldn't be possible for this to be none since - // we have a read lock - // nevertheless, this seems easier - if let Some(val) = list.get(key) { - self.state_cache.insert(*key, val.clone()); - } - } - } - - /// internal function to run `modifier` on `uid` - /// if it exists - async fn operate_on_task( - &mut self, - uid: HotShotTaskId, - modifier: Modifier, - ) -> Either { - // the happy path - if let Some(ele) = self.state_cache.get(&uid) { - modifier.0(&ele.0) - } - // the sad path - else { - self.update_cache().await; - if let Some(ele) = self.state_cache.get(&uid) { - modifier.0(&ele.0) - } else { - Either::Right(false) - } - } - } - - /// set `uid`'s state to paused - /// returns true upon success and false if `uid` is not registered - pub async fn pause_task(&mut self, uid: HotShotTaskId) -> bool 
{ - let modifier = Modifier(Box::new(|state| { - state.set_state(TaskStatus::Paused); - Either::Right(true) - })); - match self.operate_on_task(uid, modifier).await { - Either::Left(_) => unreachable!(), - Either::Right(b) => b, - } - } - - /// set `uid`'s state to running - /// returns true upon success and false if `uid` is not registered - pub async fn run_task(&mut self, uid: HotShotTaskId) -> bool { - let modifier = Modifier(Box::new(|state| { - state.set_state(TaskStatus::Running); - Either::Right(true) - })); - match self.operate_on_task(uid, modifier).await { - Either::Left(_) => unreachable!(), - Either::Right(b) => b, - } - } - - /// if the `uid` is registered with the global registry - /// return its task status - /// this is a way to subscribe to state changes from the taskstatus - /// since `HotShotTaskStatus` implements stream - pub async fn get_task_state(&mut self, uid: HotShotTaskId) -> Option { - let modifier = Modifier(Box::new(|state| Either::Left(state.get_status()))); - match self.operate_on_task(uid, modifier).await { - Either::Left(state) => Some(state), - Either::Right(false) => None, - Either::Right(true) => unreachable!(), - } - } - - /// shut down a task from a different thread - /// returns true if succeeded - /// returns false if the task does not exist - pub async fn shutdown_task(&mut self, uid: usize) -> bool { - let modifier = Modifier(Box::new(|state| { - state.set_state(TaskStatus::Completed); - Either::Right(true) - })); - let result = match self.operate_on_task(uid, modifier).await { - Either::Left(_) => unreachable!(), - Either::Right(b) => b, - }; - let mut list = self.state_list.write().await; - list.remove(&uid); - result - } - - /// checks if all registered tasks have completed - pub async fn is_shutdown(&mut self) -> bool { - let task_list = self.state_list.read().await; - for task in (*task_list).values() { - if task.0.get_status() != TaskStatus::Completed { - return false; - } - } - true - } - - /// shut down all tasks in registry - pub async fn shutdown_all(&mut self) { - let mut task_list = self.state_list.write().await; - while let Some((_uid, task)) = task_list.pop_last() { - task.0.set_state(TaskStatus::Completed); - } - } -} diff --git a/crates/task/src/lib.rs index 918a0eaded..cf71eb7090 100644 --- a/crates/task/src/lib.rs +++ b/crates/task/src/lib.rs @@ -1,385 +1,8 @@ -//! Abstractions meant for usage with long running consensus tasks -//! and testing harness +//! Task primitives for `HotShot` -use crate::task::PassType; -use either::Either; -use event_stream::SendableStream; -use Poll::{Pending, Ready}; -// The spawner of the task should be able to fire and forget the task if it makes sense. -use futures::{stream::Fuse, Future, Stream, StreamExt}; -use std::{ - pin::Pin, - slice::SliceIndex, - sync::Arc, - task::{Context, Poll}, -}; -// NOTE use pin_project here because we're already bring in procedural macros elsewhere -// so there is no reason to use pin_project_lite -use pin_project::pin_project; - -/// Astractions over the state of a task and a stream -/// interface for task changes. Allows in the happy path -/// for lockless manipulation of tasks -/// and in the sad case, only the use of a `std::sync::mutex` -pub mod task_state; - -/// the global registry storing the status of all tasks -/// as well as the abiliity to terminate them -pub mod global_registry; - -/// mpmc streamable to all subscribed tasks -pub mod event_stream; - -/// The `HotShot` Task. The main point of this library. 
Uses all other abstractions -/// to create an abstraction over tasks +/// Simple Dependency types +pub mod dependency; +/// Task which uses dependencies +pub mod dependency_task; +/// Basic task types pub mod task; - -/// The hotshot task launcher. Useful for constructing tasks -pub mod task_launcher; - -/// the task implementations with different features -pub mod task_impls; - -/// merge `N` streams of the same type -#[pin_project] -pub struct MergeN { - /// Streams to be merged. - #[pin] - streams: Vec>, - /// idx to start polling - idx: usize, -} - -impl MergeN { - /// create a new stream - #[must_use] - pub fn new(streams: Vec) -> MergeN { - let fused_streams = streams.into_iter().map(StreamExt::fuse).collect(); - MergeN { - streams: fused_streams, - idx: 0, - } - } -} - -impl PassType for T {} - -impl SendableStream for MergeN {} - -// NOTE: yoinked from https://github.com/yoshuawuyts/futures-concurrency/ -// we should really just use `futures-concurrency`. I'm being lazy here -// and not bringing in yet another dependency. Note: their merge is implemented much -// more cleverly than this rather naive impl - -// NOTE: If this is implemented through the trait, this will work on both vecs and -// slices. -// -// From: https://github.com/rust-lang/rust/pull/78370/files /// Get a pinned mutable pointer from a list. -pub(crate) fn get_pin_mut_from_vec( - slice: Pin<&mut Vec>, - index: I, -) -> Option> -where - I: SliceIndex<[T]>, -{ - // SAFETY: `get_unchecked_mut` is never used to move the slice inside `self` (`SliceIndex` - // is sealed and all `SliceIndex::get_mut` implementations never move elements). - // `x` is guaranteed to be pinned because it comes from `self` which is pinned. - unsafe { - slice - .get_unchecked_mut() - .get_mut(index) - .map(|x| Pin::new_unchecked(x)) - } -} - -impl Stream for MergeN { - // idx of the stream, item - type Item = (usize, ::Item); - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut me = self.project(); - - let idx = *me.idx; - *me.idx = (idx + 1) % me.streams.len(); - - let first_half = idx..me.streams.len(); - let second_half = 0..idx; - - let iterator = first_half.chain(second_half); - - let mut done = false; - - for i in iterator { - let stream = get_pin_mut_from_vec(me.streams.as_mut(), i).unwrap(); - - match stream.poll_next(cx) { - Ready(Some(val)) => return Ready(Some((i, val))), - Ready(None) => {} - Pending => done = false, - } - } - - if done { - Ready(None) - } else { - Pending - } - } -} - -// NOTE: yoinked /from async-std -// except this is executor agnostic (doesn't rely on async-std streamext/fuse) -// NOTE: usage of this is for combining streams into one main stream -// for usage with `MessageStream` -// TODO move this to async-compatibility-layer -#[pin_project] -/// Stream type that merges two underlying streams -pub struct Merge { - /// first stream to merge - #[pin] - a: Fuse, - /// second stream to merge - #[pin] - b: Fuse, - /// When `true`, poll `a` first, otherwise, `poll` b`. 
- a_first: bool, -} - -impl Merge { - /// create a new Merged stream - pub fn new(a: T, b: U) -> Merge - where - T: Stream, - U: Stream, - { - Merge { - a: a.fuse(), - b: b.fuse(), - a_first: true, - } - } -} - -impl Stream for Merge -where - T: Stream, - U: Stream, -{ - type Item = Either; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let me = self.project(); - let a_first = *me.a_first; - - // Toggle the flag - *me.a_first = !a_first; - - poll_next(me.a, me.b, cx, a_first) - } - - fn size_hint(&self) -> (usize, Option) { - let (a_lower, a_upper) = self.a.size_hint(); - let (b_lower, b_upper) = self.b.size_hint(); - - let upper = match (a_upper, b_upper) { - (Some(a_upper), Some(b_upper)) => Some(a_upper + b_upper), - _ => None, - }; - - (a_lower + b_lower, upper) - } -} - -impl SendableStream for Merge -where - T: Stream + Send + Sync + 'static, - U: Stream + Send + Sync + 'static, -{ -} - -/// poll the next item in the merged stream -fn poll_next( - first: Pin<&mut T>, - second: Pin<&mut U>, - cx: &mut Context<'_>, - order: bool, -) -> Poll>> -where - T: Stream, - U: Stream, -{ - let mut done = true; - - // there's definitely a better way to do this - if order { - match first.poll_next(cx) { - Ready(Some(val)) => return Ready(Some(Either::Left(val))), - Ready(None) => {} - Pending => done = false, - } - - match second.poll_next(cx) { - Ready(Some(val)) => return Ready(Some(Either::Right(val))), - Ready(None) => {} - Pending => done = false, - } - } else { - match second.poll_next(cx) { - Ready(Some(val)) => return Ready(Some(Either::Right(val))), - Ready(None) => {} - Pending => done = false, - } - - match first.poll_next(cx) { - Ready(Some(val)) => return Ready(Some(Either::Left(val))), - Ready(None) => {} - Pending => done = false, - } - } - - if done { - Ready(None) - } else { - Pending - } -} - -/// gotta make the futures sync -pub type BoxSyncFuture<'a, T> = Pin + Send + Sync + 'a>>; - -/// may be treated as a stream -#[pin_project(project = ProjectedStreamableThing)] -pub struct GeneratedStream { - // todo maybe type wrapper is in order - /// Stream generator. - generator: Arc Option> + Sync + Send>, - /// Optional in-progress future. - in_progress_fut: Option>, -} - -impl GeneratedStream { - /// create a generator - pub fn new( - generator: Arc Option> + Sync + Send>, - ) -> Self { - GeneratedStream { - generator, - in_progress_fut: None, - } - } -} - -impl Stream for GeneratedStream { - type Item = O; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let projection = self.project(); - match projection.in_progress_fut { - Some(fut) => { - // NOTE: this is entirely safe. - // We will ONLY poll if we've been awakened. - // otherwise, we won't poll. 
- match fut.as_mut().poll(cx) { - Ready(val) => { - *projection.in_progress_fut = None; - Poll::Ready(Some(val)) - } - Pending => Poll::Pending, - } - } - None => { - let wrapped_fut = (*projection.generator)(); - let Some(mut fut) = wrapped_fut else { - return Poll::Ready(None); - }; - match fut.as_mut().poll(cx) { - Ready(val) => { - *projection.in_progress_fut = None; - Poll::Ready(Some(val)) - } - Pending => { - *projection.in_progress_fut = Some(fut); - Poll::Pending - } - } - } - } - } -} - -/// yoinked from futures crate -pub fn assert_future(future: F) -> F -where - F: Future, -{ - future -} - -/// yoinked from futures crate, adds sync bound that we need -pub fn boxed_sync<'a, F>(fut: F) -> BoxSyncFuture<'a, F::Output> -where - F: Future + Sized + Send + Sync + 'a, -{ - assert_future::(Box::pin(fut)) -} - -impl SendableStream for GeneratedStream {} - -#[cfg(test)] -pub mod test { - use crate::{boxed_sync, Arc, GeneratedStream, StreamExt}; - - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_stream_basic() { - let mut stream = GeneratedStream:: { - generator: Arc::new(move || { - let closure = async move { 5 }; - Some(boxed_sync(closure)) - }), - in_progress_fut: None, - }; - assert!(stream.next().await == Some(5)); - assert!(stream.next().await == Some(5)); - assert!(stream.next().await == Some(5)); - assert!(stream.next().await == Some(5)); - } - - #[cfg(test)] - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_stream_fancy() { - use async_compatibility_layer::art::async_sleep; - use std::{sync::atomic::Ordering, time::Duration}; - - let value = Arc::::default(); - let mut stream = GeneratedStream:: { - generator: Arc::new(move || { - let value = value.clone(); - let closure = async move { - let actual_value = value.load(Ordering::Relaxed); - value.store(actual_value + 1, Ordering::Relaxed); - async_sleep(Duration::new(0, 500)).await; - u32::from(actual_value) - }; - Some(boxed_sync(closure)) - }), - in_progress_fut: None, - }; - assert!(stream.next().await == Some(0)); - assert!(stream.next().await == Some(1)); - assert!(stream.next().await == Some(2)); - assert!(stream.next().await == Some(3)); - } -} diff --git a/crates/task/src/task.rs b/crates/task/src/task.rs index 8435ff0fcf..e87f3465d5 100644 --- a/crates/task/src/task.rs +++ b/crates/task/src/task.rs @@ -1,637 +1,454 @@ -use std::{ - fmt::{Debug, Formatter}, - ops::Deref, - pin::Pin, - task::{Context, Poll}, -}; - -use async_compatibility_layer::art::async_yield_now; -use either::Either::{self, Left, Right}; -use futures::{future::BoxFuture, stream::Fuse, Future, FutureExt, Stream, StreamExt}; -use pin_project::pin_project; use std::sync::Arc; - -use crate::{ - event_stream::{EventStream, SendableStream, StreamId}, - global_registry::{GlobalRegistry, HotShotTaskId, ShutdownFn}, - task_impls::TaskBuilder, - task_state::{TaskState, TaskStatus}, +use std::time::Duration; + +use async_broadcast::{Receiver, SendError, Sender}; +use async_compatibility_layer::art::async_timeout; +#[cfg(async_executor_impl = "async-std")] +use async_std::{ + sync::RwLock, + task::{spawn, JoinHandle}, }; +use futures::{future::select_all, Future}; -/// restrictions on types we wish to pass around. 
-/// Includes messages and events +pub trait PassType: Clone + Debug + Sync + Send + 'static {} +#[cfg(async_executor_impl = "async-std")] +use futures::future::join_all; -/// the task state -pub trait TS: Sync + Send + 'static {} +#[cfg(async_executor_impl = "tokio")] +use futures::future::try_join_all; -/// a task error that has nice qualities -#[allow(clippy::module_name_repetitions)] -pub trait TaskErr: std::error::Error + Sync + Send + 'static {} - -impl TaskErr for T {} +#[cfg(async_executor_impl = "tokio")] +use tokio::{ + sync::RwLock, + task::{spawn, JoinHandle}, +}; +use tracing::error; -/// group of types needed for a hotshot task -pub trait HotShotTaskTypes: 'static { - /// the event type from the event stream - type Event: PassType; - /// the state of the task - type State: TS; - /// the global event stream - type EventStream: EventStream; - /// the message stream to receive - type Message: PassType; - /// the steam of messages from other tasks - type MessageStream: SendableStream; - /// the error to return - type Error: TaskErr + 'static + ?Sized; +use crate::{ + dependency::Dependency, + dependency_task::{DependencyTask, HandleDepOutput}, +}; - /// build a task - /// NOTE: done here and not on `TaskBuilder` because - /// we want specific checks done on each variant - /// NOTE: all generics implement `Sized`, but this bound is - /// NOT applied to `Self` unless we specify - fn build(builder: TaskBuilder) -> HST +/// Type for mutable task state that can be used as the state for a `Task` +pub trait TaskState: Send { + /// Type of event sent and received by the task + type Event: Clone + Send + Sync + 'static; + /// The result returned when this task completes + type Output: Send; + /// Handle event and update state. Return `Some` if the task is finished, + /// `None` otherwise. The handler can access the state through `Task::state_mut` + fn handle_event( + event: Self::Event, + task: &mut Task, + ) -> impl Future> + Send where Self: Sized; -} -/// hot shot task -#[pin_project(project = ProjectedHST)] -#[allow(clippy::type_complexity)] -pub struct HST { - /// Optional ID of the stream. - pub(crate) stream_id: Option, - /// the eventual return value, post-cleanup - r_val: Option, - /// if we have a future for tracking shutdown progress - in_progress_shutdown_fut: Option>, - /// the in progress future - in_progress_fut: Option, HSTT::State)>>, - /// name of task - name: String, - /// state of the task - /// TODO make this boxed. We don't want to assume this is a small future. 
- /// since it concievably may be stored on the stack - #[pin] - status: TaskState, - /// functions performing cleanup - /// one should shut down the task - /// if we're tracking with a global registry - /// the other should unsubscribe from the stream - shutdown_fns: Vec, - /// shared stream - event_stream: MaybePinnedEventStream, - /// stream of messages - message_stream: Option>>>, - /// state - state: Option, - /// handler for events - handle_event: Option>, - /// handler for messages - handle_message: Option>, - /// task id - pub(crate) tid: Option, + /// Return true if the event should be filtered + fn filter(&self, _event: &Self::Event) -> bool { + // default doesn't filter + false + } + /// Do something with the result of the task before it shuts down + fn handle_result(&self, _res: &Self::Output) -> impl std::future::Future + Send { + async {} + } + /// Return true if the event should shut the task down + fn should_shutdown(event: &Self::Event) -> bool; + /// Handle anything before the task is completely shut down + fn shutdown(&mut self) -> impl std::future::Future + Send { + async {} + } } -/// an option of a pinned boxed fused event stream -pub type MaybePinnedEventStream = - Option::EventStream as EventStream>::StreamType>>>>; - -/// ADT for wrapping all possible handler types -#[allow(dead_code)] -pub(crate) enum HotShotTaskHandler { - /// handle an event - HandleEvent(HandleEvent), - /// handle a message - HandleMessage(HandleMessage), - /// filter an event - FilterEvent(FilterEvent), - /// deregister with the registry - Shutdown(ShutdownFn), +/// Task state for a test. Similar to `TaskState` but it handles +/// messages as well as events. Messages are events that are +/// external to this task. (i.e. a test message would be an event from a non-test task) +/// This is used as state for `TestTask` and messages can come from many +/// different input streams. +pub trait TestTaskState: Send { + /// Message type handled by the task + type Message: Clone + Send + Sync + 'static; + /// Result returned by the test task on completion + type Output: Send; + /// The state type + type State: TaskState; + /// Handle an incoming message and return `Some` if the task is finished + fn handle_message( + message: Self::Message, + id: usize, + task: &mut TestTask, + ) -> impl Future> + Send + where + Self: Sized; }
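With both state traits in place, a minimal `TaskState` implementation looks like the sketch below. The `Counter`/`Event` types are invented for illustration and the `hotshot_task::task` module path is assumed; the `async fn`-in-impl style follows what `vote.rs` does above:

```rust
use hotshot_task::task::{Task, TaskState};

#[derive(Clone, Debug)]
enum Event {
    Tick,
    Shutdown,
}

/// Toy state: counts Tick events and finishes after three.
struct Counter {
    ticks: usize,
}

impl TaskState for Counter {
    type Event = Event;
    type Output = usize;

    async fn handle_event(event: Event, task: &mut Task<Self>) -> Option<usize> {
        if let Event::Tick = event {
            task.state_mut().ticks += 1;
            if task.state_mut().ticks == 3 {
                // Some(..) tells the run loop to handle_result, shutdown, and exit.
                return Some(3);
            }
        }
        None
    }

    fn should_shutdown(event: &Event) -> bool {
        matches!(event, Event::Shutdown)
    }
}
```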
+/// A basic task which loops waiting for events to come from `event_receiver` +/// and then handles them using its state +/// It sends events to other `Task`s through `event_sender` +/// This should be used as the primary building block for long running +/// or medium running tasks (i.e. anything that can't be described as a dependency task) +pub struct Task { + /// Sends events to all tasks including itself + event_sender: Sender, + /// Receives events that are broadcast from any task, including itself + event_receiver: Receiver, + /// Contains this task, used to register any spawned tasks + registry: Arc, + /// The state of the task. It is fed events from `event_sender` + /// and mutates its state accordingly. Also it signals the task + /// if it is complete/should shutdown + state: S, +} impl Task { + /// Create a new task + pub fn new( + tx: Sender, + rx: Receiver, + registry: Arc, + state: S, + ) -> Self { + Task { + event_sender: tx, + event_receiver: rx, + registry, + state, + } } -} - -/// Return `true` if the event should be filtered -#[derive(Clone)] -pub struct FilterEvent(pub Arc bool + Send + 'static + Sync>); -impl Default for FilterEvent { - fn default() -> Self { - Self(Arc::new(|_| true)) + /// Spawn the task loop, consuming self. Will continue until + /// the task reaches some shutdown condition + pub fn run(mut self) -> JoinHandle<()> { + spawn(async move { + loop { + match self.event_receiver.recv_direct().await { + Ok(event) => { + if S::should_shutdown(&event) { + self.state.shutdown().await; + break; + } + if self.state.filter(&event) { + continue; + } + if let Some(res) = S::handle_event(event, &mut self).await { + self.state.handle_result(&res).await; + self.state.shutdown().await; + break; + } + } + Err(e) => { + tracing::error!("Failed to receive from event stream. Error: {}", e); + } + } + } + }) } /// Create a new event `Receiver` from this Task's receiver. 
+ /// The returned receiver will get all messages not yet seen by this task + pub fn subscribe(&self) -> Receiver { + self.event_receiver.clone() } -} - -impl Deref for FilterEvent { - type Target = dyn Fn(&EVENT) -> bool + Send + 'static + Sync; - - fn deref(&self) -> &Self::Target { - &*self.0 + /// Get a new sender handle for events + pub fn sender(&self) -> &Sender { + &self.event_sender } -} - -impl HST { - /// Do a consistency check on the `HST` construction - pub(crate) fn base_check(&self) { - assert!(!self.shutdown_fns.is_empty(), "No shutdown functions"); - assert!( - self.in_progress_fut.is_none(), - "This future has already been polled" - ); - - assert!(self.state.is_some(), "Didn't register state"); - - assert!(self.tid.is_some(), "Didn't register global registry"); + /// Clone the sender handle + pub fn clone_sender(&self) -> Sender { + self.event_sender.clone() } - - /// perform event sanity checks - pub(crate) fn event_check(&self) { - assert!( - self.shutdown_fns.len() == 2, - "Expected 2 shutdown functions" - ); - assert!(self.event_stream.is_some(), "Didn't register event stream"); - assert!(self.handle_event.is_some(), "Didn't register event handler"); + /// Broadcast a message to all listening tasks + /// # Errors + /// Errors if the broadcast fails + pub async fn send(&self, event: S::Event) -> Result, SendError> { + self.event_sender.broadcast(event).await } - - /// perform message sanity checks - pub(crate) fn message_check(&self) { - assert!( - self.handle_message.is_some(), - "Didn't register message handler" - ); - assert!( - self.message_stream.is_some(), - "Didn't register message stream" - ); + /// Get a mutable reference to this task's state + pub fn state_mut(&mut self) -> &mut S { + &mut self.state } - - /// register a handler with the task - #[must_use] - pub(crate) fn register_handler(self, handler: HotShotTaskHandler) -> Self { - match handler { - HotShotTaskHandler::HandleEvent(handler) => Self { - handle_event: Some(handler), - ..self - }, - HotShotTaskHandler::HandleMessage(handler) => Self { - handle_message: Some(handler), - ..self - }, - HotShotTaskHandler::FilterEvent(_handler) => unimplemented!(), - HotShotTaskHandler::Shutdown(_handler) => unimplemented!(), - } + /// Spawn a new task and register it. It will get all events not seen + /// by the task creating it. + pub async fn run_sub_task(&self, state: S) { + let task = Task { + event_sender: self.clone_sender(), + event_receiver: self.subscribe(), + registry: self.registry.clone(), + state, + }; + // Note: await here is only awaiting the task to be added to the + // registry, not for the task to run. + self.registry.run_task(task).await; + } +}
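Hypothetical wiring for the impl above, reusing the `Counter`/`Event` sketch from earlier; the channel capacity is arbitrary and the tokio `JoinHandle` flavor is assumed:

```rust
use std::sync::Arc;
use async_broadcast::broadcast;
use hotshot_task::task::{Task, TaskRegistry};

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast(16);
    let registry = Arc::new(TaskRegistry::default());
    let handle = Task::new(tx.clone(), rx, registry, Counter { ticks: 0 }).run();
    for _ in 0..3 {
        tx.broadcast(Event::Tick).await.unwrap();
    }
    handle.await.unwrap(); // resolves once handle_event returns Some(..)
}
```

Because every task clones the same broadcast channel, there is no per-subscriber registration step: publishing is `broadcast`, and `filter` keeps uninterested tasks from doing work on each event.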
Notably +/// it adds message receivers to collect events from many non-test tasks +pub struct TestTask { + /// Task which handles test events + task: Task, + /// Receivers for outside events + message_receivers: Vec>, +} - /// register a message with the task - #[must_use] - pub(crate) fn register_message_stream(self, stream: HSTT::MessageStream) -> Self { +impl< + S: TaskState + Send + 'static, + T: TestTaskState + Send + Sync + 'static, + > TestTask +{ + /// Create a test task + pub fn new(task: Task, rxs: Vec>) -> Self { Self { - message_stream: Some(Box::pin(stream.fuse())), - ..self + task, + message_receivers: rxs, } } + /// Runs the task, taking events from the the test events and the message receivers. + /// Consumes self and runs until some shutdown condition is met. + /// The join handle will return the result of the task, useful for deciding if the test + /// passed or not. + pub fn run(mut self) -> JoinHandle { + spawn(async move { + loop { + let mut futs = vec![]; + + if let Ok(event) = self.task.event_receiver.try_recv() { + if S::should_shutdown(&event) { + self.task.state.shutdown().await; + tracing::error!("Shutting down test task TODO!"); + todo!(); + } + if !self.state().filter(&event) { + if let Some(res) = S::handle_event(event, &mut self.task).await { + self.task.state.handle_result(&res).await; + self.task.state.shutdown().await; + return res; + } + } + } - /// register state with the task - #[must_use] - pub(crate) fn register_state(self, state: HSTT::State) -> Self { - Self { - state: Some(state), - ..self - } + for rx in &mut self.message_receivers { + futs.push(rx.recv()); + } + // if let Ok((Ok(msg), id, _)) = + match async_timeout(Duration::from_secs(1), select_all(futs)).await { + Ok((Ok(msg), id, _)) => { + if let Some(res) = T::handle_message(msg, id, &mut self).await { + self.task.state.handle_result(&res).await; + self.task.state.shutdown().await; + return res; + } + } + Err(e) => { + error!("Failed to get event from task. Error: {:?}", e); + } + Ok((Err(e), _, _)) => { + error!("A task channel returned an Error: {:?}", e); + } + } + } + }) } - /// register with the registry - pub(crate) async fn register_registry(self, registry: &mut GlobalRegistry) -> Self { - let (shutdown_fn, id) = registry.register(&self.name, self.status.clone()).await; - let mut shutdown_fns = self.shutdown_fns; - shutdown_fns.push(shutdown_fn); - Self { - shutdown_fns, - tid: Some(id), - ..self - } + /// Get a ref to state + pub fn state(&self) -> &S { + &self.task.state } - - /// create a new task - pub(crate) fn new(name: String) -> Self { - Self { - stream_id: None, - r_val: None, - name, - status: TaskState::new(), - event_stream: None, - state: None, - handle_event: None, - handle_message: None, - shutdown_fns: vec![], - message_stream: None, - in_progress_fut: None, - in_progress_shutdown_fut: None, - tid: None, - } + /// Get a mutable ref to state + pub fn state_mut(&mut self) -> &mut S { + self.task.state_mut() } - - /// launch the task - /// NOTE: the only way to get a `HST` is by usage - /// of one of the impls. Those all have checks enabled. - /// So, it should be safe to launch. 
- /// launch the task - /// NOTE: the only way to get a `HST` is by usage - /// of one of the impls. Those all have checks enabled. - /// So, it should be safe to launch. - pub fn launch(self) -> BoxFuture<'static, HotShotTaskCompleted> { - Box::pin(self) + /// Send an event to other listening test tasks + /// + /// # Panics + /// panics if the event can't be sent (ok to panic in test) + pub async fn send_event(&self, event: S::Event) { + self.task.send(event).await.unwrap(); + } } -/// enum describing how the tasks completed -pub enum HotShotTaskCompleted { - /// the task shut down successfully - ShutDown, - /// the task encountered an error - Error(Box), - /// the streams the task was listening for died - StreamsDied, - /// we somehow lost the state - /// this is definitely a bug. - LostState, - /// lost the return value somehow - LostReturnValue, - /// Stream exists but missing handler - MissingHandler, } +#[derive(Default)] +/// A collection of tasks which can handle shutdown +pub struct TaskRegistry { + /// Tasks this registry controls + task_handles: RwLock>>, } -impl std::fmt::Debug for HotShotTaskCompleted { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - HotShotTaskCompleted::ShutDown => f.write_str("HotShotTaskCompleted::ShutDown"), - HotShotTaskCompleted::Error(_) => f.write_str("HotShotTaskCompleted::Error"), - HotShotTaskCompleted::StreamsDied => f.write_str("HotShotTaskCompleted::StreamsDied"), - HotShotTaskCompleted::LostState => f.write_str("HotShotTaskCompleted::LostState"), - HotShotTaskCompleted::LostReturnValue => { - f.write_str("HotShotTaskCompleted::LostReturnValue") - } - HotShotTaskCompleted::MissingHandler => { - f.write_str("HotShotTaskCompleted::MissingHandler") - } - } } +impl TaskRegistry { + /// Add a task to the registry + pub async fn register(&self, handle: JoinHandle<()>) { + self.task_handles.write().await.push(handle); } -impl PartialEq for HotShotTaskCompleted { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Error(_l0), Self::Error(_r0)) => false, - _ => core::mem::discriminant(self) == core::mem::discriminant(other), + /// Try to cancel/abort the tasks this registry has + pub async fn shutdown(&self) { + let mut handles = self.task_handles.write().await; + while let Some(handle) = handles.pop() { + #[cfg(async_executor_impl = "async-std")] + handle.cancel().await; + #[cfg(async_executor_impl = "tokio")] + handle.abort(); } } -} -impl<'pin, HSTT: HotShotTaskTypes> ProjectedHST<'pin, HSTT> { - /// launches the shutdown future - fn launch_shutdown_fut(&mut self, cx: &mut Context<'_>) -> Poll { - let fut = self.create_shutdown_fut(); - self.check_ip_shutdown_fut(fut, cx) + /// Take a task, run it, and register it + pub async fn run_task(&self, task: Task) + where + S: TaskState + Send + 'static, + { + self.register(task.run()).await; } - /// checks the in progress shutdown future, `fut` - fn check_ip_shutdown_fut( - &mut self, - mut fut: Pin + Send>>, - cx: &mut Context<'_>, - ) -> Poll { - match fut.as_mut().poll(cx) { - Poll::Ready(()) => Poll::Ready( - self.r_val - .take() - .unwrap_or_else(|| HotShotTaskCompleted::LostReturnValue), - ), - Poll::Pending => { - *self.in_progress_shutdown_fut = Some(fut); - Poll::Pending - } - } + /// Create a new `DependencyTask`, run it, and register it + pub async fn spawn_dependency_task( + &self, + dep: impl Dependency + Send + 'static, + handle: impl HandleDepOutput, + ) { + let join_handle = DependencyTask { dep, handle }.run(); + self.register(join_handle).await; + }
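A short sketch of the registry's lifecycle — hypothetical work, `hotshot_task::task` module path and the tokio executor cfg assumed (so `JoinHandle` is tokio's):

```rust
use std::sync::Arc;
use hotshot_task::task::TaskRegistry;

#[tokio::main]
async fn main() {
    let registry = Arc::new(TaskRegistry::default());
    // run_task/spawn_dependency_task call register internally;
    // a raw JoinHandle can also be registered directly.
    registry.register(tokio::spawn(async { /* work */ })).await;
    // Abort everything still running.
    registry.shutdown().await;
}
```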
= async move { - for shutdown_fn in shutdown_fns { - shutdown_fn().await; - } - } - .boxed(); - fut + /// Wait for the results of all the tasks registered + /// # Panics + /// Panics if one of the tasks paniced + pub async fn join_all(self) -> Vec<()> { + #[cfg(async_executor_impl = "async-std")] + let ret = join_all(self.task_handles.into_inner()).await; + #[cfg(async_executor_impl = "tokio")] + let ret = try_join_all(self.task_handles.into_inner()).await.unwrap(); + ret } +} - /// check the event stream - /// returns either a poll if there's a future IP - /// or a bool stating whether or not the stream is finished - fn check_event_stream( - &mut self, - cx: &mut Context<'_>, - ) -> Either, bool> { - let event_stream = self.event_stream.take(); - if let Some(mut inner_event_stream) = event_stream { - while let Poll::Ready(maybe_event) = inner_event_stream.as_mut().poll_next(cx) { - if let Some(event) = maybe_event { - if let Some(handle_event) = self.handle_event { - let maybe_state = self.state.take(); - if let Some(state) = maybe_state { - let mut fut = handle_event(event, state); - match fut.as_mut().poll(cx) { - Poll::Ready((result, state)) => { - if let Some(completed) = result { - *self.in_progress_fut = None; - *self.state = Some(state); - *self.r_val = Some(completed); - let result = self.launch_shutdown_fut(cx); - *self.event_stream = Some(inner_event_stream); - return Left(result); - } - // run a yield to tell the executor to go do work on other - // tasks if they are available - // this is necessary otherwise we could end up with one - // task that returns really quickly blocking the executor - // from dealing with other tasks. - let mut fut = async move { - async_yield_now().await; - (None, state) - } - .boxed(); - // if the executor has no extra work to do, - // continue to poll the event stream - if let Poll::Ready((_, state)) = fut.as_mut().poll(cx) { - *self.state = Some(state); - *self.in_progress_fut = None; - // NOTE: don't need to set event stream because - // that will be done on the next iteration - continue; - } - // otherwise, return pending and finish executing the - // yield later - *self.event_stream = Some(inner_event_stream); - *self.in_progress_fut = Some(fut); - return Left(Poll::Pending); - } - Poll::Pending => { - *self.in_progress_fut = Some(fut); - *self.event_stream = Some(inner_event_stream); - return Left(Poll::Pending); - } - } - } - // lost state case - *self.r_val = Some(HotShotTaskCompleted::LostState); - let result = self.launch_shutdown_fut(cx); - *self.event_stream = Some(inner_event_stream); - return Left(result); - } - // no handler case - *self.r_val = Some(HotShotTaskCompleted::MissingHandler); - let result = self.launch_shutdown_fut(cx); - *self.event_stream = Some(inner_event_stream); - return Left(result); - } - // this is a fused future so `None` will come every time after the stream - // finishes - *self.event_stream = Some(inner_event_stream); - return Right(true); - } - *self.event_stream = Some(inner_event_stream); - return Right(false); - } - // stream doesn't exist so trivially true - *self.event_stream = event_stream; - Right(true) +#[cfg(test)] +mod tests { + use super::*; + use async_broadcast::broadcast; + #[cfg(async_executor_impl = "async-std")] + use async_std::task::sleep; + use std::{collections::HashSet, time::Duration}; + #[cfg(async_executor_impl = "tokio")] + use tokio::time::sleep; + + #[derive(Default)] + pub struct DummyHandle { + val: usize, + seen: HashSet, } - /// check the message stream - /// returns either a 
poll if there's a future IP - /// or a bool stating whether or not the stream is finished - fn check_message_stream( - &mut self, - cx: &mut Context<'_>, - ) -> Either, bool> { - let message_stream = self.message_stream.take(); - if let Some(mut inner_message_stream) = message_stream { - while let Poll::Ready(maybe_msg) = inner_message_stream.as_mut().poll_next(cx) { - if let Some(msg) = maybe_msg { - if let Some(handle_msg) = self.handle_message { - let maybe_state = self.state.take(); - if let Some(state) = maybe_state { - let mut fut = handle_msg(msg, state); - match fut.as_mut().poll(cx) { - Poll::Ready((result, state)) => { - if let Some(completed) = result { - *self.in_progress_fut = None; - *self.state = Some(state); - *self.r_val = Some(completed); - let result = self.launch_shutdown_fut(cx); - *self.message_stream = Some(inner_message_stream); - return Left(result); - } - // run a yield to tell the executor to go do work on other - // tasks if they are available - // this is necessary otherwise we could end up with one - // task that returns really quickly blocking the executor - // from dealing with other tasks. - let mut fut = async move { - async_yield_now().await; - (None, state) - } - .boxed(); - // if the executor has no extra work to do, - // continue to poll the event stream - if let Poll::Ready((_, state)) = fut.as_mut().poll(cx) { - *self.state = Some(state); - *self.in_progress_fut = None; - // NOTE: don't need to set event stream because - // that will be done on the next iteration - continue; - } - // otherwise, return pending and finish executing the - // yield later - *self.message_stream = Some(inner_message_stream); - *self.in_progress_fut = Some(fut); - return Left(Poll::Pending); - } - Poll::Pending => { - *self.in_progress_fut = Some(fut); - *self.message_stream = Some(inner_message_stream); - return Left(Poll::Pending); - } - } - } - // lost state case - *self.r_val = Some(HotShotTaskCompleted::LostState); - let result = self.launch_shutdown_fut(cx); - *self.message_stream = Some(inner_message_stream); - return Left(result); - } - // no handler case - *self.r_val = Some(HotShotTaskCompleted::MissingHandler); - let result = self.launch_shutdown_fut(cx); - *self.message_stream = Some(inner_message_stream); - return Left(result); - } - // this is a fused future so `None` will come every time after the stream - // finishes - *self.message_stream = Some(inner_message_stream); - return Right(true); + #[allow(clippy::panic)] + impl TaskState for DummyHandle { + type Event = usize; + type Output = (); + async fn handle_event(event: usize, task: &mut Task) -> Option<()> { + sleep(Duration::from_millis(10)).await; + let state = task.state_mut(); + state.seen.insert(event); + if event > state.val { + state.val = event; + assert!( + state.val < 100, + "Test should shutdown before getting an event for 100" + ); + task.send(event + 1).await.unwrap(); } - *self.message_stream = Some(inner_message_stream); - return Right(false); + None } - // stream doesn't exist so trivially true - *self.message_stream = message_stream; - Right(true) - } -} - -// NOTE: this is a Future, but it could easily be a stream. 
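`DummyHandle` above exercises the new push-style dispatch: a task state receives an event, mutates itself, and may re-broadcast, with no hand-rolled `poll` loop. The `Task::run` loop itself is not shown in this hunk; the following is a simplified sketch of the dispatch it implies, with a stripped-down stand-in for `TaskState` and Tokio assumed as the runtime:

use async_broadcast::{broadcast, Receiver, Sender};

// Stripped-down stand-in for the TaskState trait in this PR: one event
// type, an async handler, and a predicate that tells the loop to stop.
struct Counter {
    val: usize,
}

impl Counter {
    async fn handle_event(&mut self, event: usize, tx: &Sender<usize>) {
        if event > self.val {
            self.val = event;
            // Re-broadcast so every other task on the channel sees it.
            let _ = tx.broadcast(event + 1).await;
        }
    }
    fn should_shutdown(event: &usize) -> bool {
        *event >= 98
    }
}

// The run loop the new Task::run implies: receive, test, dispatch.
async fn run(mut state: Counter, tx: Sender<usize>, mut rx: Receiver<usize>) {
    while let Ok(event) = rx.recv_direct().await {
        if Counter::should_shutdown(&event) {
            break;
        }
        state.handle_event(event, &tx).await;
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast(16);
    tx.broadcast(1).await.unwrap();
    run(Counter { val: 0 }, tx, rx).await;
}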
-// but these are semantically equivalent because instead of -// returning when paused, we just return `Poll::Pending` -impl Future for HST { - type Output = HotShotTaskCompleted; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut projected = self.as_mut().project(); - - if let Some(fut) = projected.in_progress_shutdown_fut.take() { - return projected.check_ip_shutdown_fut(fut, cx); + fn should_shutdown(event: &usize) -> bool { + *event >= 98 } - - // check if task is complete - if let Some(state_change) = projected.status.as_mut().try_next() { - match state_change { - TaskStatus::NotStarted | TaskStatus::Paused => { - return Poll::Pending; - } - TaskStatus::Running => {} - TaskStatus::Completed => { - *projected.r_val = Some(HotShotTaskCompleted::ShutDown); - return projected.launch_shutdown_fut(cx); - } + async fn shutdown(&mut self) { + for i in 1..98 { + assert!(self.seen.contains(&i)); } } + } - // check if there's an in progress future - if let Some(in_progress_fut) = projected.in_progress_fut { - match in_progress_fut.as_mut().poll(cx) { - Poll::Ready((result, state)) => { - *projected.in_progress_fut = None; - *projected.state = Some(state); - // if the future errored out, return it, we're done - if let Some(completed) = result { - *projected.r_val = Some(completed); - return projected.launch_shutdown_fut(cx); - } - } - Poll::Pending => { - return Poll::Pending; - } + impl TestTaskState for DummyHandle { + type Message = String; + type Output = (); + type State = Self; + + async fn handle_message( + message: Self::Message, + _: usize, + _: &mut TestTask, + ) -> Option<()> { + if message == *"done".to_string() { + return Some(()); } + None } - - let event_stream_finished = match projected.check_event_stream(cx) { - Left(result) => return result, - Right(finished) => finished, + } + #[cfg_attr( + async_executor_impl = "tokio", + tokio::test(flavor = "multi_thread", worker_threads = 2) + )] + #[cfg_attr(async_executor_impl = "async-std", async_std::test)] + #[allow(unused_must_use)] + async fn it_works() { + let reg = Arc::new(TaskRegistry::default()); + let (tx, rx) = broadcast(10); + let task1 = Task:: { + event_sender: tx.clone(), + event_receiver: rx.clone(), + registry: reg.clone(), + state: DummyHandle::default(), + }; + tx.broadcast(1).await.unwrap(); + let task2 = Task:: { + event_sender: tx.clone(), + event_receiver: rx, + registry: reg, + state: DummyHandle::default(), }; + let handle = task2.run(); + let _res = task1.run().await; + handle.await; + } - let message_stream_finished = match projected.check_message_stream(cx) { - Left(result) => return result, - Right(finished) => finished, + #[cfg_attr( + async_executor_impl = "tokio", + tokio::test(flavor = "multi_thread", worker_threads = 10) + )] + #[cfg_attr(async_executor_impl = "async-std", async_std::test)] + #[allow(clippy::should_panic_without_expect)] + #[should_panic] + async fn test_works() { + let reg = Arc::new(TaskRegistry::default()); + let (tx, rx) = broadcast(10); + let (msg_tx, msg_rx) = broadcast(10); + let task1 = Task:: { + event_sender: tx.clone(), + event_receiver: rx.clone(), + registry: reg.clone(), + state: DummyHandle::default(), + }; + tx.broadcast(1).await.unwrap(); + let task2 = Task:: { + event_sender: tx.clone(), + event_receiver: rx, + registry: reg, + state: DummyHandle::default(), + }; + let test1 = TestTask::<_, DummyHandle> { + task: task1, + message_receivers: vec![msg_rx.clone()], + }; + let test2 = TestTask::<_, DummyHandle> { + task: task2, + 
message_receivers: vec![msg_rx.clone()], }; - if message_stream_finished && event_stream_finished { - tracing::error!("Message and event stream both finished!"); - *projected.r_val = Some(HotShotTaskCompleted::StreamsDied); - let result = projected.launch_shutdown_fut(cx); - return result; + let handle = test1.run(); + let handle2 = test2.run(); + sleep(Duration::from_millis(30)).await; + msg_tx.broadcast("done".into()).await.unwrap(); + #[cfg(async_executor_impl = "tokio")] + { + handle.await.unwrap(); + handle2.await.unwrap(); + } + #[cfg(async_executor_impl = "async-std")] + { + handle.await; + handle2.await; } - - Poll::Pending } } diff --git a/crates/task/src/task_impls.rs b/crates/task/src/task_impls.rs deleted file mode 100644 index 768e011775..0000000000 --- a/crates/task/src/task_impls.rs +++ /dev/null @@ -1,457 +0,0 @@ -use futures::Stream; -use std::marker::PhantomData; - -use crate::{ - event_stream::{DummyStream, EventStream, SendableStream, StreamId}, - global_registry::{GlobalRegistry, HotShotTaskId}, - task::{ - FilterEvent, HandleEvent, HandleMessage, HotShotTaskHandler, HotShotTaskTypes, PassType, - TaskErr, HST, TS, - }, -}; - -/// trait to specify features -pub trait ImplMessageStream {} - -/// trait to specify features -pub trait ImplEventStream {} - -/// builder for task -pub struct TaskBuilder(HST); - -impl TaskBuilder { - /// register an event handler - #[must_use] - pub fn register_event_handler(self, handler: HandleEvent) -> Self - where - HSTT: ImplEventStream, - { - Self( - self.0 - .register_handler(HotShotTaskHandler::HandleEvent(handler)), - ) - } - - /// obtains stream id if it exists - pub fn get_stream_id(&self) -> Option { - self.0.stream_id - } - - /// register a message handler - #[must_use] - pub fn register_message_handler(self, handler: HandleMessage) -> Self - where - HSTT: ImplMessageStream, - { - Self( - self.0 - .register_handler(HotShotTaskHandler::HandleMessage(handler)), - ) - } - - /// register a message stream - #[must_use] - pub fn register_message_stream(self, stream: HSTT::MessageStream) -> Self - where - HSTT: ImplMessageStream, - { - Self(self.0.register_message_stream(stream)) - } - - /// register an event stream - pub async fn register_event_stream( - self, - stream: HSTT::EventStream, - filter: FilterEvent, - ) -> Self - where - HSTT: ImplEventStream, - { - Self(self.0.register_event_stream(stream, filter).await) - } - - /// register the state - #[must_use] - pub fn register_state(self, state: HSTT::State) -> Self { - Self(self.0.register_state(state)) - } - - /// register with the global registry - pub async fn register_registry(self, registry: &mut GlobalRegistry) -> Self { - Self(self.0.register_registry(registry).await) - } - - /// get the task id in the global registry - pub fn get_task_id(&self) -> Option { - self.0.tid - } - - /// create a new task builder - #[must_use] - pub fn new(name: String) -> Self { - Self(HST::new(name)) - } -} - -/// a hotshot task with an event stream -pub struct HSTWithEvent< - ERR: std::error::Error, - EVENT: PassType, - ESTREAM: EventStream, - STATE: TS, -> { - /// phantom data - _pd: PhantomData<(ERR, EVENT, ESTREAM, STATE)>, -} - -impl< - ERR: std::error::Error, - EVENT: PassType, - ESTREAM: EventStream, - STATE: TS, - > ImplEventStream for HSTWithEvent -{ -} - -impl, STATE: TS> - ImplMessageStream for HSTWithMessage -{ -} - -impl, STATE: TS> - HotShotTaskTypes for HSTWithEvent -{ - type Event = EVENT; - type State = STATE; - type EventStream = ESTREAM; - type Message = (); - type MessageStream 
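Both tests above lean on a property of `async-broadcast` that the old `ChannelStream` had to provide through its own subscription machinery: every receiver clone observes every message. A self-contained illustration, independent of HotShot:

use async_broadcast::broadcast;

#[tokio::main]
async fn main() {
    // Capacity-bounded broadcast channel, as used for the test event bus.
    let (tx, rx1) = broadcast::<u32>(10);
    let mut rx2 = rx1.clone();
    let mut rx1 = rx1;

    tx.broadcast(7).await.unwrap();

    // Unlike an mpsc queue, both receivers observe the same message.
    assert_eq!(rx1.recv().await.unwrap(), 7);
    assert_eq!(rx2.recv().await.unwrap(), 7);
}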
= DummyStream; - type Error = ERR; - - fn build(builder: TaskBuilder) -> HST - where - Self: Sized, - { - builder.0.base_check(); - builder.0.event_check(); - builder.0 - } -} - -/// a hotshot task with a message -pub struct HSTWithMessage< - ERR: std::error::Error, - MSG: PassType, - MSTREAM: Stream, - STATE: TS, -> { - /// phantom data - _pd: PhantomData<(ERR, MSG, MSTREAM, STATE)>, -} - -impl, STATE: TS> HotShotTaskTypes - for HSTWithMessage -{ - type Event = (); - type State = STATE; - type EventStream = DummyStream; - type Message = MSG; - type MessageStream = MSTREAM; - type Error = ERR; - - fn build(builder: TaskBuilder) -> HST - where - Self: Sized, - { - builder.0.base_check(); - builder.0.message_check(); - builder.0 - } -} - -/// hotshot task with even and message -pub struct HSTWithEventAndMessage< - ERR: std::error::Error, - EVENT: PassType, - ESTREAM: EventStream, - MSG: PassType, - MSTREAM: Stream, - STATE: TS, -> { - /// phantom data - _pd: PhantomData<(ERR, EVENT, ESTREAM, MSG, MSTREAM, STATE)>, -} - -impl< - ERR: std::error::Error, - EVENT: PassType, - ESTREAM: EventStream, - MSG: PassType, - MSTREAM: Stream, - STATE: TS, - > ImplEventStream for HSTWithEventAndMessage -{ -} - -impl< - ERR: std::error::Error, - EVENT: PassType, - ESTREAM: EventStream, - MSG: PassType, - MSTREAM: Stream, - STATE: TS, - > ImplMessageStream for HSTWithEventAndMessage -{ -} - -impl< - ERR: TaskErr, - EVENT: PassType, - ESTREAM: EventStream, - MSG: PassType, - MSTREAM: SendableStream, - STATE: TS, - > HotShotTaskTypes for HSTWithEventAndMessage -{ - type Event = EVENT; - type State = STATE; - type EventStream = ESTREAM; - type Message = MSG; - type MessageStream = MSTREAM; - type Error = ERR; - - fn build(builder: TaskBuilder) -> HST - where - Self: Sized, - { - builder.0.base_check(); - builder.0.message_check(); - builder.0.event_check(); - builder.0 - } -} - -#[cfg(test)] -pub mod test { - use async_compatibility_layer::channel::{unbounded, UnboundedStream}; - use snafu::Snafu; - - use crate::{event_stream, event_stream::ChannelStream, task::TS}; - - use super::{HSTWithEvent, HSTWithEventAndMessage, HSTWithMessage}; - use crate::{event_stream::EventStream, task::HotShotTaskTypes, task_impls::TaskBuilder}; - use async_compatibility_layer::art::async_spawn; - use futures::FutureExt; - use std::sync::Arc; - - use crate::{ - global_registry::GlobalRegistry, - task::{FilterEvent, HandleEvent, HandleMessage, HotShotTaskCompleted}, - }; - use async_compatibility_layer::logging::setup_logging; - - #[derive(Snafu, Debug)] - pub struct Error {} - - #[derive(Clone, Debug, Eq, PartialEq, Hash)] - pub struct State {} - - #[derive(Clone, Debug, Eq, PartialEq, Hash, Default)] - pub struct CounterState { - num_events_recved: u64, - } - - #[derive(Clone, Debug, Eq, PartialEq, Hash)] - pub enum Event { - Finished, - Dummy, - } - - impl TS for State {} - - impl TS for CounterState {} - - #[derive(Clone, Debug, PartialEq, Eq, Hash)] - pub enum Message { - Finished, - Dummy, - } - - // TODO fill in generics for stream - - pub type AppliedHSTWithEvent = HSTWithEvent, State>; - pub type AppliedHSTWithEventCounterState = - HSTWithEvent, CounterState>; - pub type AppliedHSTWithMessage = - HSTWithMessage, State>; - pub type AppliedHSTWithEventMessage = HSTWithEventAndMessage< - Error, - Event, - ChannelStream, - Message, - UnboundedStream, - State, - >; - - #[cfg(test)] - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = 
"async-std", async_std::test)] - #[allow(clippy::should_panic_without_expect)] - #[should_panic] - async fn test_init_with_event_stream() { - setup_logging(); - let task = TaskBuilder::::new("Test Task".to_string()); - AppliedHSTWithEvent::build(task).launch().await; - } - - // TODO this should be moved to async-compatibility-layer - #[cfg(test)] - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_channel_stream() { - use futures::StreamExt; - let (s, r) = unbounded(); - let mut stream: UnboundedStream = r.into_stream(); - s.send(Message::Dummy).await.unwrap(); - s.send(Message::Finished).await.unwrap(); - assert!(stream.next().await.unwrap() == Message::Dummy); - assert!(stream.next().await.unwrap() == Message::Finished); - } - - #[cfg(test)] - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_task_with_event_stream() { - setup_logging(); - let event_stream: event_stream::ChannelStream = event_stream::ChannelStream::new(); - let mut registry = GlobalRegistry::new(); - - let mut task_runner = crate::task_launcher::TaskRunner::default(); - - for i in 0..10000 { - let state = CounterState::default(); - let event_handler = HandleEvent(Arc::new(move |event, mut state: CounterState| { - async move { - if let Event::Dummy = event { - state.num_events_recved += 1; - } - - if state.num_events_recved == 100 { - (Some(HotShotTaskCompleted::ShutDown), state) - } else { - (None, state) - } - } - .boxed() - })); - let name = format!("Test Task {i:?}").to_string(); - let built_task = TaskBuilder::::new(name.clone()) - .register_event_stream(event_stream.clone(), FilterEvent::default()) - .await - .register_registry(&mut registry) - .await - .register_state(state) - .register_event_handler(event_handler); - let id = built_task.get_task_id().unwrap(); - let result = AppliedHSTWithEventCounterState::build(built_task).launch(); - task_runner = task_runner.add_task(id, name, result); - } - - async_spawn(async move { - for _ in 0..100 { - event_stream.publish(Event::Dummy).await; - } - }); - - let results = task_runner.launch().await; - for result in results { - assert!(result.1 == HotShotTaskCompleted::ShutDown); - } - } - - #[cfg(test)] - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_task_with_event_stream_xtreme() { - setup_logging(); - let event_stream: event_stream::ChannelStream = event_stream::ChannelStream::new(); - - let state = State {}; - - let mut registry = GlobalRegistry::new(); - - let event_handler = HandleEvent(Arc::new(move |event, state| { - async move { - if let Event::Finished = event { - (Some(HotShotTaskCompleted::ShutDown), state) - } else { - (None, state) - } - } - .boxed() - })); - - let built_task = TaskBuilder::::new("Test Task".to_string()) - .register_event_stream(event_stream.clone(), FilterEvent::default()) - .await - .register_registry(&mut registry) - .await - .register_state(state) - .register_event_handler(event_handler); - event_stream.publish(Event::Dummy).await; - event_stream.publish(Event::Dummy).await; - event_stream.publish(Event::Finished).await; - AppliedHSTWithEvent::build(built_task).launch().await; - } - - #[cfg(test)] - #[cfg_attr( - 
async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 2) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - async fn test_task_with_message_stream() { - setup_logging(); - let state = State {}; - - let mut registry = GlobalRegistry::new(); - - let (s, r) = async_compatibility_layer::channel::unbounded(); - - let message_handler = HandleMessage(Arc::new(move |message, state| { - async move { - if let Message::Finished = message { - (Some(HotShotTaskCompleted::ShutDown), state) - } else { - (None, state) - } - } - .boxed() - })); - - let built_task = TaskBuilder::::new("Test Task".to_string()) - .register_message_handler(message_handler) - .register_message_stream(r.into_stream()) - .register_registry(&mut registry) - .await - .register_state(state); - async_spawn(async move { - s.send(Message::Dummy).await.unwrap(); - s.send(Message::Finished).await.unwrap(); - }); - let result = AppliedHSTWithMessage::build(built_task).launch().await; - assert!(result == HotShotTaskCompleted::ShutDown); - } -} diff --git a/crates/task/src/task_launcher.rs b/crates/task/src/task_launcher.rs deleted file mode 100644 index deff065af2..0000000000 --- a/crates/task/src/task_launcher.rs +++ /dev/null @@ -1,68 +0,0 @@ -use futures::future::{join_all, BoxFuture}; - -use crate::{ - global_registry::{GlobalRegistry, HotShotTaskId}, - task::HotShotTaskCompleted, -}; - -// TODO use genericarray + typenum to make this use the number of tasks as a parameter -/// runner for tasks -/// `N` specifies the number of tasks to ensure that the user -/// doesn't forget how many tasks they wished to add. -pub struct TaskRunner -// < -// const N: usize, -// > -{ - /// internal set of tasks to launch - tasks: Vec<( - HotShotTaskId, - String, - BoxFuture<'static, HotShotTaskCompleted>, - )>, - /// global registry - pub registry: GlobalRegistry, -} - -impl Default for TaskRunner { - fn default() -> Self { - Self::new() - } -} - -impl TaskRunner /* */ { - /// create new runner - #[must_use] - pub fn new() -> Self { - Self { - tasks: Vec::new(), - registry: GlobalRegistry::new(), - } - } - - // `name` is for logging purposes only and may be duplicated or inconsistent. 
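The deleted tests here and the new ones earlier in the diff share the same executor-agnostic attribute pattern. Isolated, it looks like this (the `async_executor_impl` cfg must be supplied by the build, as HotShot's `RUSTFLAGS` setup does):

// The cfg_attr pair picks the right test macro at compile time, so one
// test body runs under either runtime. Pattern sketch only; the
// surrounding crate must define the `async_executor_impl` cfg.
#[cfg(test)]
mod executor_agnostic {
    #[cfg_attr(
        async_executor_impl = "tokio",
        tokio::test(flavor = "multi_thread", worker_threads = 2)
    )]
    #[cfg_attr(async_executor_impl = "async-std", async_std::test)]
    async fn smoke() {
        assert_eq!(1 + 1, 2);
    }
}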
- /// to support builder pattern - #[must_use] - pub fn add_task( - mut self, - id: HotShotTaskId, - name: String, - task: BoxFuture<'static, HotShotTaskCompleted>, - ) -> TaskRunner { - self.tasks.push((id, name, task)); - self - } - - /// returns a `Vec` because type isn't known - #[must_use] - pub async fn launch(self) -> Vec<(String, HotShotTaskCompleted)> { - let names = self - .tasks - .iter() - .map(|(_id, name, _)| name.clone()) - .collect::>(); - let result = join_all(self.tasks.into_iter().map(|(_, _, task)| task)).await; - - names.into_iter().zip(result).collect::>() - } -} diff --git a/crates/task/src/task_state.rs b/crates/task/src/task_state.rs deleted file mode 100644 index 01758965a1..0000000000 --- a/crates/task/src/task_state.rs +++ /dev/null @@ -1,182 +0,0 @@ -use atomic_enum::atomic_enum; -use serde::{Deserialize, Serialize}; -use std::{ - fmt::Debug, - sync::{atomic::Ordering, Arc}, -}; - -/// Nit: wish this was for u8 but sadly no -/// Represents the status of a hotshot task -#[atomic_enum] -#[derive(Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] -pub enum TaskStatus { - /// the task hasn't started running - NotStarted = 0, - /// the task is running - Running = 1, - /// NOTE: not useful generally, but VERY useful for byzantine nodes - /// and testing malfunctions - /// we'll have a granular way to, from the registry, stop a task momentarily - /// and inspect/modify its state - Paused = 2, - /// the task completed - Completed = 3, -} - -/// The state of a task -/// `AtomicTaskStatus` + book keeping to notify btwn tasks -#[derive(Clone)] -pub struct TaskState { - /// previous status - prev: Arc, - /// next status - next: Arc, - // using `std::sync::mutex` here because it's faster than async's version - // wakers: Arc>>, -} - -impl Debug for TaskState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TaskState") - .field("status", &self.get_status()) - .finish() - } -} -impl Default for TaskState { - fn default() -> Self { - Self::new() - } -} - -impl TaskState { - /// create a new state - #[must_use] - pub fn new() -> Self { - Self { - prev: Arc::new(TaskStatus::NotStarted.into()), - next: Arc::new(TaskStatus::NotStarted.into()), - // wakers: Arc::default(), - } - } - - /// create a task state from a task status - #[must_use] - pub fn from_status(state: Arc) -> Self { - let prev_state = AtomicTaskStatus::new(state.load(Ordering::SeqCst)); - Self { - prev: Arc::new(prev_state), - next: state, - // wakers: Arc::default(), - } - } - - /// sets the state - /// # Panics - /// should never panic unless internally a lock poison happens - /// this should NOT be possible - pub fn set_state(&self, state: TaskStatus) { - self.next.swap(state, Ordering::SeqCst); - // no panics, so can never be poisoned. - // let mut wakers = self.wakers.lock().unwrap(); - - // drain the wakers - // for waker in wakers.drain(..) 
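The `task_state.rs` module being deleted here amounts to a pair of atomic cells where `try_next` reports each status change exactly once. The same prev/next trick can be sketched with a plain `AtomicU8`, avoiding the removed `atomic_enum` dependency (a simplified stand-in, not the original code):

use std::sync::atomic::{AtomicU8, Ordering};

// 0 = NotStarted, 1 = Running, 2 = Paused, 3 = Completed
struct Status {
    prev: AtomicU8,
    next: AtomicU8,
}

impl Status {
    fn new() -> Self {
        Self {
            prev: AtomicU8::new(0),
            next: AtomicU8::new(0),
        }
    }
    fn set(&self, status: u8) {
        self.next.store(status, Ordering::SeqCst);
    }
    /// Returns Some(status) only the first time a new value is observed,
    /// mirroring the removed `TaskState::try_next`.
    fn try_next(&self) -> Option<u8> {
        let next = self.next.load(Ordering::SeqCst);
        let prev = self.prev.swap(next, Ordering::SeqCst);
        (prev != next).then_some(next)
    }
}

fn main() {
    let s = Status::new();
    assert_eq!(s.try_next(), None);
    s.set(1);
    assert_eq!(s.try_next(), Some(1));
    assert_eq!(s.try_next(), None); // each change reported exactly once
}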
{ - // waker.wake(); - // } - } - /// gets a possibly stale version of the state - #[must_use] - pub fn get_status(&self) -> TaskStatus { - self.next.load(Ordering::SeqCst) - } -} - -// GNARLY bug @jbearer found -// cx gets *really* large in some cases -// impl Stream for TaskState { -// type Item = TaskStatus; -// -// #[unstable] -// fn poll_next( -// self: std::pin::Pin<&mut Self>, -// cx: &mut std::task::Context<'_>, -// ) -> std::task::Poll> { -// let next = self.next.load(Ordering::SeqCst); -// let prev = self.prev.swap(next, Ordering::SeqCst); -// // a new value has been set -// if prev == next { -// // no panics, so impossible to be poisoned -// self.wakers.lock().unwrap().push(cx.waker().clone()); -// -// // no value has been set, poll again later -// std::task::Poll::Pending -// } else { -// std::task::Poll::Ready(Some(next)) -// } -// } -// } - -impl TaskState { - /// Try to get the next task status. - #[must_use] - pub fn try_next(self: std::pin::Pin<&mut Self>) -> Option { - let next = self.next.load(Ordering::SeqCst); - let prev = self.prev.swap(next, Ordering::SeqCst); - // a new value has been set - if prev == next { - None - } else { - // drain the wakers to wake up the stream. - // we did change value - // let mut wakers = self.wakers.lock().unwrap(); - // for waker in wakers.drain(..) { - // waker.wake(); - // } - Some(next) - } - } -} - -#[cfg(test)] -pub mod test { - - // #[cfg(test)] - // #[cfg_attr( - // feature = "tokio-executor", - // tokio::test(flavor = "multi_thread", worker_threads = 2) - // )] - // #[cfg_attr(feature = "async-std-executor", async_std::test)] - // async fn test_state_stream() { - // setup_logging(); - // - // let mut task = crate::task_state::TaskState::new(); - // - // let task_dup = task.clone(); - // - // async_spawn(async move { - // async_sleep(std::time::Duration::from_secs(1)).await; - // task_dup.set_state(crate::task_state::TaskStatus::Running); - // async_sleep(std::time::Duration::from_secs(1)).await; - // task_dup.set_state(crate::task_state::TaskStatus::Paused); - // async_sleep(std::time::Duration::from_secs(1)).await; - // task_dup.set_state(crate::task_state::TaskStatus::Completed); - // }); - // - // // spawn new task that sleeps then increments - // - // assert_eq!( - // task.try_next().unwrap(), - // crate::task_state::TaskStatus::Running - // ); - // assert_eq!( - // task.next().unwrap(), - // crate::task_state::TaskStatus::Paused - // ); - // assert_eq!( - // task.next().unwrap(), - // crate::task_state::TaskStatus::Completed - // ); - // } - // TODO test global registry using either global + lazy_static - // or passing around global registry -} diff --git a/crates/testing/Cargo.toml b/crates/testing/Cargo.toml index f202f94117..ceb54a410f 100644 --- a/crates/testing/Cargo.toml +++ b/crates/testing/Cargo.toml @@ -11,6 +11,7 @@ default = [] slow-tests = [] [dependencies] +async-broadcast = { workspace = true } async-compatibility-layer = { workspace = true } sha3 = "^0.10" bincode = { workspace = true } @@ -24,7 +25,6 @@ hotshot-constants = { path = "../constants" } hotshot-types = { path = "../types", default-features = false } hotshot-utils = { path = "../utils" } hotshot-orchestrator = { version = "0.1.1", path = "../orchestrator", default-features = false } -hotshot-task = { path = "../task", version = "0.1.0", default-features = false } hotshot-task-impls = { path = "../task-impls", version = "0.1.0", default-features = false } rand = { workspace = true } snafu = { workspace = true } @@ -34,11 +34,10 @@ sha2 = { 
workspace = true } async-lock = { workspace = true } bitvec = { workspace = true } ethereum-types = { workspace = true } +hotshot-task = { path = "../task" } [target.'cfg(all(async_executor_impl = "tokio"))'.dependencies] tokio = { workspace = true } [target.'cfg(all(async_executor_impl = "async-std"))'.dependencies] async-std = { workspace = true } -[lints] -workspace = true diff --git a/crates/testing/src/completion_task.rs b/crates/testing/src/completion_task.rs index e5367cb8bd..94efb83b3c 100644 --- a/crates/testing/src/completion_task.rs +++ b/crates/testing/src/completion_task.rs @@ -1,21 +1,19 @@ -use std::{sync::Arc, time::Duration}; +#[cfg(async_executor_impl = "async-std")] +use async_std::task::JoinHandle; +use std::time::Duration; +#[cfg(async_executor_impl = "tokio")] +use tokio::task::JoinHandle; -use async_compatibility_layer::art::async_sleep; -use futures::FutureExt; +use async_broadcast::{Receiver, Sender}; +use async_compatibility_layer::art::{async_spawn, async_timeout}; use hotshot::traits::TestableNodeImplementation; -use hotshot_task::{ - boxed_sync, - event_stream::{ChannelStream, EventStream}, - task::{FilterEvent, HandleEvent, HandleMessage, HotShotTaskCompleted, HotShotTaskTypes, TS}, - task_impls::{HSTWithEventAndMessage, TaskBuilder}, - GeneratedStream, -}; +use hotshot_task_impls::helpers::broadcast_event; use hotshot_types::traits::node_implementation::NodeType; use snafu::Snafu; -use crate::test_runner::Node; +use crate::test_runner::{HotShotTaskCompleted, Node}; -use super::{test_launcher::TaskGenerator, GlobalTestEvent}; +use super::GlobalTestEvent; /// the idea here is to run as long as we want @@ -25,24 +23,39 @@ pub struct CompletionTaskErr {} /// Completion task state pub struct CompletionTask> { - /// the test level event stream - pub(crate) test_event_stream: ChannelStream, + pub tx: Sender, + + pub rx: Receiver, /// handles to the nodes in the test pub(crate) handles: Vec>, + /// Duration of the task. + pub duration: Duration, } -impl> TS for CompletionTask {} - -/// Completion task types -pub type CompletionTaskTypes = HSTWithEventAndMessage< - CompletionTaskErr, - GlobalTestEvent, - ChannelStream, - (), - GeneratedStream<()>, - CompletionTask, ->; - +impl> CompletionTask { + pub fn run(mut self) -> JoinHandle { + async_spawn(async move { + if async_timeout(self.duration, self.wait_for_shutdown()) + .await + .is_err() + { + broadcast_event(GlobalTestEvent::ShutDown, &self.tx).await; + } + for node in &self.handles { + node.handle.clone().shut_down().await; + } + HotShotTaskCompleted::ShutDown + }) + } + async fn wait_for_shutdown(&mut self) { + while let Ok(event) = self.rx.recv_direct().await { + if matches!(event, GlobalTestEvent::ShutDown) { + tracing::error!("Completion Task shutting down"); + return; + } + } + } +} /// Description for a time-based completion task. #[derive(Clone, Debug)] pub struct TimeBasedCompletionTaskDescription { @@ -56,81 +69,3 @@ pub enum CompletionTaskDescription { /// Time-based completion task. TimeBasedCompletionTaskBuilder(TimeBasedCompletionTaskDescription), } - -impl CompletionTaskDescription { - /// Build and launch a completion task. 
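The new `CompletionTask::run` above is a timeout race: wait for someone else to broadcast `ShutDown`, and if the test duration elapses first, broadcast it yourself. The control flow, reduced to `async-broadcast` plus Tokio's `timeout` (the real code uses `async_timeout` from the compatibility layer; the event enum here is a placeholder):

use std::time::Duration;

use async_broadcast::{broadcast, Receiver, Sender};
use tokio::time::timeout;

#[derive(Clone, Debug, PartialEq)]
enum TestEvent {
    ShutDown,
}

async fn wait_for_shutdown(rx: &mut Receiver<TestEvent>) {
    while let Ok(event) = rx.recv_direct().await {
        if event == TestEvent::ShutDown {
            return; // someone else ended the test first
        }
    }
}

async fn completion_task(tx: Sender<TestEvent>, mut rx: Receiver<TestEvent>, duration: Duration) {
    // If nobody shuts the test down within `duration`, do it ourselves.
    if timeout(duration, wait_for_shutdown(&mut rx)).await.is_err() {
        let _ = tx.broadcast(TestEvent::ShutDown).await;
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast(8);
    completion_task(tx.clone(), rx, Duration::from_millis(10)).await;
    // The channel now carries ShutDown for any remaining listeners.
}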
- #[must_use] - pub fn build_and_launch>( - self, - ) -> TaskGenerator> { - match self { - CompletionTaskDescription::TimeBasedCompletionTaskBuilder(td) => td.build_and_launch(), - } - } -} - -impl TimeBasedCompletionTaskDescription { - /// create the task and launch it - /// # Panics - /// if cannot obtain task id after launching - #[must_use] - pub fn build_and_launch>( - self, - ) -> TaskGenerator> { - Box::new(move |state, mut registry, test_event_stream| { - async move { - let event_handler = - HandleEvent::>(Arc::new(move |event, state| { - async move { - match event { - GlobalTestEvent::ShutDown => { - for node in &state.handles { - node.handle.clone().shut_down().await; - } - (Some(HotShotTaskCompleted::ShutDown), state) - } - } - } - .boxed() - })); - let message_handler = - HandleMessage::>(Arc::new(move |(), state| { - async move { - state - .test_event_stream - .publish(GlobalTestEvent::ShutDown) - .await; - for node in &state.handles { - node.handle.clone().shut_down().await; - } - (Some(HotShotTaskCompleted::ShutDown), state) - } - .boxed() - })); - // normally I'd say "let's use Interval from async-std!" - // but doing this is easier than unifying async-std with tokio's slightly different - // interval abstraction - let stream_generator = GeneratedStream::new(Arc::new(move || { - let fut = async move { - async_sleep(self.duration).await; - }; - Some(boxed_sync(fut)) - })); - let builder = TaskBuilder::>::new( - "Test Completion Task".to_string(), - ) - .register_event_stream(test_event_stream, FilterEvent::default()) - .await - .register_registry(&mut registry) - .await - .register_state(state) - .register_event_handler(event_handler) - .register_message_handler(message_handler) - .register_message_stream(stream_generator); - let task_id = builder.get_task_id().unwrap(); - (task_id, CompletionTaskTypes::build(builder).launch()) - } - .boxed() - }) - } -} diff --git a/crates/testing/src/lib.rs b/crates/testing/src/lib.rs index 1c1718a918..c1a84af20a 100644 --- a/crates/testing/src/lib.rs +++ b/crates/testing/src/lib.rs @@ -6,8 +6,6 @@ deprecated = "suspicious usage of testing/demo implementations in non-test/non-debug build" )] -use hotshot_task::{event_stream::ChannelStream, task_impls::HSTWithEvent}; - /// Helpers for initializing system context handle and building tasks. 
pub mod task_helpers; @@ -50,15 +48,3 @@ pub enum GlobalTestEvent { /// the test is shutting down ShutDown, } - -/// the reason for shutting down the test -pub enum ShutDownReason { - /// the test is shutting down because of a safety violation - SafetyViolation, - /// the test is shutting down because the test has completed successfully - SuccessfullyCompleted, -} - -/// type alias for the type of tasks created in testing -pub type TestTask = - HSTWithEvent, STATE>; diff --git a/crates/testing/src/overall_safety_task.rs b/crates/testing/src/overall_safety_task.rs index d145a2fd58..693a0d46dd 100644 --- a/crates/testing/src/overall_safety_task.rs +++ b/crates/testing/src/overall_safety_task.rs @@ -1,15 +1,6 @@ -use async_compatibility_layer::channel::UnboundedStream; -use either::Either; -use futures::FutureExt; use hotshot::{traits::TestableNodeImplementation, HotShotError}; -use hotshot_task::{ - event_stream::ChannelStream, - task::{FilterEvent, HandleEvent, HandleMessage, HotShotTaskCompleted, HotShotTaskTypes, TS}, - task_impls::{HSTWithEventAndMessage, TaskBuilder}, - MergeN, -}; -use hotshot_task::{event_stream::EventStream, Merge}; -use hotshot_task_impls::events::HotShotEvent; + +use hotshot_task::task::{Task, TaskState, TestTaskState}; use hotshot_types::{ data::{Leaf, VidCommitment}, error::RoundTimedoutState, @@ -24,7 +15,7 @@ use std::{ }; use tracing::error; -use crate::{test_launcher::TaskGenerator, test_runner::Node}; +use crate::test_runner::{HotShotTaskCompleted, Node}; /// convenience type alias for state and block pub type StateAndBlock = (Vec, Vec); @@ -77,11 +68,193 @@ pub struct OverallSafetyTask>, /// ctx pub ctx: RoundCtx, - /// event stream for publishing safety violations - pub test_event_stream: ChannelStream, + /// configure properties + pub properties: OverallSafetyPropertiesDescription, } -impl> TS for OverallSafetyTask {} +impl> TaskState + for OverallSafetyTask +{ + type Event = GlobalTestEvent; + + type Output = HotShotTaskCompleted; + + async fn handle_event(event: Self::Event, task: &mut Task) -> Option { + match event { + GlobalTestEvent::ShutDown => { + tracing::error!("Shutting down SafetyTask"); + let state = task.state_mut(); + let OverallSafetyPropertiesDescription { + check_leaf: _, + check_block: _, + num_failed_views: num_failed_rounds_total, + num_successful_views, + threshold_calculator: _, + transaction_threshold: _, + }: OverallSafetyPropertiesDescription = state.properties.clone(); + + let num_incomplete_views = state.ctx.round_results.len() + - state.ctx.successful_views.len() + - state.ctx.failed_views.len(); + + if state.ctx.successful_views.len() < num_successful_views { + return Some(HotShotTaskCompleted::Error(Box::new( + OverallSafetyTaskErr::::NotEnoughDecides { + got: state.ctx.successful_views.len(), + expected: num_successful_views, + }, + ))); + } + + if state.ctx.failed_views.len() + num_incomplete_views >= num_failed_rounds_total { + return Some(HotShotTaskCompleted::Error(Box::new( + OverallSafetyTaskErr::::TooManyFailures { + failed_views: state.ctx.failed_views.clone(), + }, + ))); + } + Some(HotShotTaskCompleted::ShutDown) + } + } + } + + fn should_shutdown(_event: &Self::Event) -> bool { + false + } +} + +impl> TestTaskState + for OverallSafetyTask +{ + type Message = Event; + + type Output = HotShotTaskCompleted; + + type State = Self; + + async fn handle_message( + message: Self::Message, + idx: usize, + task: &mut hotshot_task::task::TestTask, + ) -> Option { + let OverallSafetyPropertiesDescription { + check_leaf, + 
check_block, + num_failed_views, + num_successful_views, + threshold_calculator, + transaction_threshold, + }: OverallSafetyPropertiesDescription = task.state().properties.clone(); + let Event { view_number, event } = message; + let key = match event { + EventType::Error { error } => { + task.state_mut() + .ctx + .insert_error_to_context(view_number, idx, error); + None + } + EventType::Decide { + leaf_chain, + qc, + block_size: maybe_block_size, + } => { + // Skip the genesis leaf. + if leaf_chain.len() == 1 + && leaf_chain[0].get_view_number() == TYPES::Time::genesis() + { + return None; + } + let paired_up = (leaf_chain.to_vec(), (*qc).clone()); + match task.state_mut().ctx.round_results.entry(view_number) { + Entry::Occupied(mut o) => { + o.get_mut() + .insert_into_result(idx, paired_up, maybe_block_size) + } + Entry::Vacant(v) => { + let mut round_result = RoundResult::default(); + let key = round_result.insert_into_result(idx, paired_up, maybe_block_size); + v.insert(round_result); + key + } + } + } + EventType::ReplicaViewTimeout { view_number } => { + let error = Arc::new(HotShotError::::ViewTimeoutError { + view_number, + state: RoundTimedoutState::TestCollectRoundEventsTimedOut, + }); + task.state_mut() + .ctx + .insert_error_to_context(view_number, idx, error); + None + } + _ => return None, + }; + + // update view count + let threshold = + (threshold_calculator)(task.state().handles.len(), task.state().handles.len()); + + let len = task.state().handles.len(); + let view = task + .state_mut() + .ctx + .round_results + .get_mut(&view_number) + .unwrap(); + if let Some(key) = key { + view.update_status( + threshold, + len, + &key, + check_leaf, + check_block, + transaction_threshold, + ); + match view.status.clone() { + ViewStatus::Ok => { + task.state_mut().ctx.successful_views.insert(view_number); + if task.state_mut().ctx.successful_views.len() >= num_successful_views { + task.send_event(GlobalTestEvent::ShutDown).await; + return Some(HotShotTaskCompleted::ShutDown); + } + return None; + } + ViewStatus::Failed => { + task.state_mut().ctx.failed_views.insert(view_number); + if task.state_mut().ctx.failed_views.len() > num_failed_views { + task.send_event(GlobalTestEvent::ShutDown).await; + return Some(HotShotTaskCompleted::Error(Box::new( + OverallSafetyTaskErr::::TooManyFailures { + failed_views: task.state_mut().ctx.failed_views.clone(), + }, + ))); + } + return None; + } + ViewStatus::Err(e) => { + return Some(HotShotTaskCompleted::Error(Box::new(e))); + } + ViewStatus::InProgress => { + return None; + } + } + } else if view.check_if_failed(threshold, len) { + view.status = ViewStatus::Failed; + task.state_mut().ctx.failed_views.insert(view_number); + if task.state_mut().ctx.failed_views.len() > num_failed_views { + task.send_event(GlobalTestEvent::ShutDown).await; + return Some(HotShotTaskCompleted::Error(Box::new( + OverallSafetyTaskErr::::TooManyFailures { + failed_views: task.state_mut().ctx.failed_views.clone(), + }, + ))); + } + return None; + } + None + } +} /// Result of running a round of consensus #[derive(Debug)] @@ -365,249 +538,3 @@ impl Default for OverallSafetyPropertiesDescription { } } } - -impl OverallSafetyPropertiesDescription { - /// build a task - /// # Panics - /// if an internal variant that the prior views are filled is violated - #[must_use] - #[allow(clippy::too_many_lines)] - pub fn build>( - self, - ) -> TaskGenerator> { - let Self { - check_leaf, - check_block, - num_failed_views: num_failed_rounds_total, - num_successful_views, - 
threshold_calculator, - transaction_threshold, - }: Self = self; - - Box::new(move |mut state, mut registry, test_event_stream| { - async move { - let event_handler = HandleEvent::>(Arc::new( - move |event, state| { - async move { - match event { - GlobalTestEvent::ShutDown => { - let num_incomplete_views = state.ctx.round_results.len() - - state.ctx.successful_views.len() - - state.ctx.failed_views.len(); - - if state.ctx.successful_views.len() < num_successful_views { - return ( - Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::NotEnoughDecides { - got: state.ctx.successful_views.len(), - expected: num_successful_views, - }, - ))), - state, - ); - } - - if state.ctx.failed_views.len() + num_incomplete_views - >= num_failed_rounds_total - { - return ( - Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::TooManyFailures { - failed_views: state.ctx.failed_views.clone(), - }, - ))), - state, - ); - } - // TODO check if we got enough successful views - (Some(HotShotTaskCompleted::ShutDown), state) - } - } - } - .boxed() - }, - )); - - let message_handler = HandleMessage::>(Arc::new( - move |msg, mut state| { - let threshold_calculator = threshold_calculator.clone(); - async move { - - let (idx, maybe_event ) : (usize, Either<_, _>)= msg; - if let Either::Left(Event { view_number, event }) = maybe_event { - let key = match event { - EventType::Error { error } => { - state.ctx.insert_error_to_context(view_number, idx, error); - None - } - EventType::Decide { - leaf_chain, - qc, - block_size: maybe_block_size, - } => { - // Skip the genesis leaf. - if leaf_chain.len() == 1 && leaf_chain[0].get_view_number() == TYPES::Time::genesis() { - return (None, state); - } - let paired_up = (leaf_chain.to_vec(), (*qc).clone()); - match state.ctx.round_results.entry(view_number) { - Entry::Occupied(mut o) => o.get_mut().insert_into_result( - idx, - paired_up, - maybe_block_size, - ), - Entry::Vacant(v) => { - let mut round_result = RoundResult::default(); - let key = round_result.insert_into_result( - idx, - paired_up, - maybe_block_size, - ); - v.insert(round_result); - key - } - } - } - EventType::ReplicaViewTimeout { view_number } => { - let error = Arc::new(HotShotError::::ViewTimeoutError { - view_number, - state: RoundTimedoutState::TestCollectRoundEventsTimedOut, - }); - state.ctx.insert_error_to_context(view_number, idx, error); - None - } - _ => return (None, state), - }; - - // update view count - let threshold = - (threshold_calculator)(state.handles.len(), state.handles.len()); - - let view = state.ctx.round_results.get_mut(&view_number).unwrap(); - - if let Some(key) = key { - view.update_status( - threshold, - state.handles.len(), - &key, - check_leaf, - check_block, - transaction_threshold, - ); - match view.status.clone() { - ViewStatus::Ok => { - state.ctx.successful_views.insert(view_number); - if state.ctx.successful_views.len() - >= self.num_successful_views - { - state - .test_event_stream - .publish(GlobalTestEvent::ShutDown) - .await; - return (Some(HotShotTaskCompleted::ShutDown), state); - } - return (None, state); - } - ViewStatus::Failed => { - state.ctx.failed_views.insert(view_number); - if state.ctx.failed_views.len() > self.num_failed_views { - state - .test_event_stream - .publish(GlobalTestEvent::ShutDown) - .await; - return ( - Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::TooManyFailures { - failed_views: state.ctx.failed_views.clone(), - }, - ))), - state, - ); - } - return (None, state); - } - 
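Both the deleted closure-based handler and the new `handle_event` above apply the same pass/fail arithmetic. Distilled into plain functions (hypothetical names, same logic: still-incomplete views count against the failure budget):

use std::collections::HashSet;

// Hypothetical distillation of OverallSafetyTask's shutdown check: pass
// only if enough views decided and the failure budget is not exhausted.
struct ViewTally {
    rounds_seen: usize,
    successful: HashSet<u64>,
    failed: HashSet<u64>,
}

enum Verdict {
    Pass,
    NotEnoughDecides { got: usize, expected: usize },
    TooManyFailures { failed: usize },
}

fn verdict(tally: &ViewTally, num_successful_views: usize, num_failed_views: usize) -> Verdict {
    let incomplete = tally.rounds_seen - tally.successful.len() - tally.failed.len();
    if tally.successful.len() < num_successful_views {
        return Verdict::NotEnoughDecides {
            got: tally.successful.len(),
            expected: num_successful_views,
        };
    }
    if tally.failed.len() + incomplete >= num_failed_views {
        return Verdict::TooManyFailures {
            failed: tally.failed.len() + incomplete,
        };
    }
    Verdict::Pass
}

fn main() {
    let tally = ViewTally {
        rounds_seen: 10,
        successful: (0..7).collect(),
        failed: (7..8).collect(),
    };
    assert!(matches!(verdict(&tally, 5, 5), Verdict::Pass));
}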
ViewStatus::Err(e) => { - return ( - Some(HotShotTaskCompleted::Error(Box::new(e))), - state, - ); - } - ViewStatus::InProgress => { - return (None, state); - } - } - } - else if view.check_if_failed(threshold, state.handles.len()) { - view.status = ViewStatus::Failed; - state.ctx.failed_views.insert(view_number); - if state.ctx.failed_views.len() > self.num_failed_views { - state - .test_event_stream - .publish(GlobalTestEvent::ShutDown) - .await; - return ( - Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::TooManyFailures { - failed_views: state.ctx.failed_views.clone(), - }, - ))), - state, - ); - } - return (None, state); - } - - } - - (None, state) - } - .boxed() - }, - )); - - let mut streams = vec![]; - for handle in &mut state.handles { - let s1 = - handle - .handle - .get_event_stream_known_impl(FilterEvent::default()) - .await - .0; - let s2 = - handle - .handle - .get_internal_event_stream_known_impl(FilterEvent::default()) - .await - .0; - streams.push( - Merge::new(s1, s2) - ); - } - let builder = TaskBuilder::>::new( - "Test Overall Safety Task".to_string(), - ) - .register_event_stream(test_event_stream, FilterEvent::default()) - .await - .register_registry(&mut registry) - .await - .register_message_handler(message_handler) - .register_message_stream(MergeN::new(streams)) - .register_event_handler(event_handler) - .register_state(state); - let task_id = builder.get_task_id().unwrap(); - (task_id, OverallSafetyTaskTypes::build(builder).launch()) - } - .boxed() - }) - } -} - -/// overall types for safety task -pub type OverallSafetyTaskTypes = HSTWithEventAndMessage< - OverallSafetyTaskErr, - GlobalTestEvent, - ChannelStream, - (usize, Either, HotShotEvent>), - MergeN>, UnboundedStream>>>, - OverallSafetyTask, ->; diff --git a/crates/testing/src/per_node_safety_task.rs b/crates/testing/src/per_node_safety_task.rs deleted file mode 100644 index af20f00b79..0000000000 --- a/crates/testing/src/per_node_safety_task.rs +++ /dev/null @@ -1,258 +0,0 @@ -// // TODO rename this file to per-node -// -// use std::{ops::Deref, sync::Arc}; -// -// use async_compatibility_layer::channel::UnboundedStream; -// use either::Either; -// use futures::{ -// future::{BoxFuture, LocalBoxFuture}, -// FutureExt, -// }; -// use hotshot::traits::TestableNodeImplementation; -// use hotshot_task::{ -// event_stream::ChannelStream, -// global_registry::{GlobalRegistry, HotShotTaskId}, -// task::{ -// FilterEvent, HandleEvent, HandleMessage, HotShotTaskCompleted, HotShotTaskTypes, TaskErr, -// HST, TS, -// }, -// task_impls::{HSTWithEvent, HSTWithEventAndMessage, TaskBuilder}, -// }; -// use hotshot_types::{ -// event::{Event, EventType}, -// traits::node_implementation::NodeType, -// }; -// use nll::nll_todo::nll_todo; -// use snafu::Snafu; -// use tracing::log::warn; -// -// use crate::test_errors::ConsensusTestError; -// -// use super::{ -// completion_task::CompletionTask, -// node_ctx::{NodeCtx, ViewFailed, ViewStatus, ViewSuccess}, -// GlobalTestEvent, -// }; -// -// #[derive(Snafu, Debug)] -// pub enum PerNodeSafetyTaskErr { -// // TODO make this more detailed -// TooManyFailures, -// NotEnoughDecides, -// } -// impl TaskErr for PerNodeSafetyTaskErr {} -// -// /// Data availability task state -// /// -// #[derive(Debug)] -// pub struct PerNodeSafetyTask> { -// pub(crate) ctx: NodeCtx, -// } -// -// impl> Default -// for PerNodeSafetyTask -// { -// fn default() -> Self { -// Self { -// ctx: Default::default(), -// } -// } -// } -// -// impl> TS -// for PerNodeSafetyTask -// { -// 
} -// -// /// builder describing custom safety properties -// #[derive(Clone)] -// pub enum PerNodeSafetyTaskDescription< -// TYPES: NodeType, -// I: TestableNodeImplementation, -// > { -// GenProperties(PerNodeSafetyPropertiesDescription), -// CustomProperties(PerNodeSafetyFinisher), -// } -// -// /// properties used for gen -// #[derive(Clone, Debug)] -// pub struct PerNodeSafetyPropertiesDescription { -// /// number failed views -// pub num_failed_views: Option, -// /// number decide events -// pub num_decide_events: Option, -// } -// -// // basic consistency check for single node -// /// Exists for easier overriding -// /// runs at end of all tasks -// #[derive(Clone)] -// #[allow(clippy::type_complexity)] -// pub struct PerNodeSafetyFinisher< -// TYPES: NodeType, -// I: TestableNodeImplementation, -// >( -// pub Arc< -// dyn for<'a> Fn(&'a mut NodeCtx) -> BoxFuture<'a, Result<(), PerNodeSafetyTaskErr>> -// + Send -// + 'static -// + Sync, -// >, -// ); -// -// impl> Deref -// for PerNodeSafetyFinisher -// { -// type Target = dyn for<'a> Fn(&'a mut NodeCtx) -> BoxFuture<'a, Result<(), PerNodeSafetyTaskErr>> -// + Send -// + 'static -// + Sync; -// -// fn deref(&self) -> &Self::Target { -// &*self.0 -// } -// } -// -// impl> -// PerNodeSafetyTaskDescription -// { -// fn gen_finisher(self) -> PerNodeSafetyFinisher { -// match self { -// PerNodeSafetyTaskDescription::CustomProperties(finisher) => finisher, -// PerNodeSafetyTaskDescription::GenProperties(PerNodeSafetyPropertiesDescription { -// num_failed_views, -// num_decide_events, -// }) => PerNodeSafetyFinisher(Arc::new(move |ctx: &mut NodeCtx| { -// async move { -// let mut num_failed = 0; -// let mut num_decided = 0; -// for (_view_num, view_status) in &ctx.round_results { -// match view_status { -// ViewStatus::InProgress(_) => {} -// ViewStatus::ViewFailed(_) => { -// num_failed += 1; -// } -// ViewStatus::ViewSuccess(_) => { -// num_decided += 1; -// } -// } -// } -// if let Some(num_failed_views) = num_failed_views { -// if num_failed >= num_failed_views { -// return Err(PerNodeSafetyTaskErr::TooManyFailures); -// } -// } -// -// if let Some(num_decide_events) = num_decide_events { -// if num_decided < num_decide_events { -// return Err(PerNodeSafetyTaskErr::NotEnoughDecides); -// } -// } -// Ok(()) -// } -// .boxed() -// })), -// } -// } -// -// /// build -// pub fn build( -// self, -// // registry: &mut GlobalRegistry, -// // test_event_stream: ChannelStream, -// // hotshot_event_stream: UnboundedStream>, -// ) -> TaskGenerator< -// PerNodeSafetyTask -// > { -// Box::new( -// move |state, mut registry, test_event_stream, hotshot_event_stream| { -// // TODO this is cursed, there's definitely a better way to do this -// let desc = self.clone(); -// async move { -// let test_event_handler = HandleEvent::>(Arc::new( -// move |event, mut state| { -// let finisher = desc.clone().gen_finisher(); -// async move { -// match event { -// GlobalTestEvent::ShutDown => { -// let finished = finisher(&mut state.ctx).await; -// let result = match finished { -// Ok(()) => HotShotTaskCompleted::ShutDown, -// Err(err) => HotShotTaskCompleted::Error(Box::new(err)), -// }; -// return (Some(result), state); -// } -// _ => { -// unimplemented!() -// } -// } -// } -// .boxed() -// }, -// )); -// let message_handler = HandleMessage::>(Arc::new( -// move |msg, mut state| { -// async move { -// let Event { view_number, event } = msg; -// match event { -// EventType::Error { error } => { -// // TODO better warn with node idx -// warn!("View {:?} failed 
for a replica", view_number); -// state.ctx.round_results.insert( -// view_number, -// ViewStatus::ViewFailed(ViewFailed(error)), -// ); -// } -// EventType::Decide { leaf_chain, qc } => { -// state.ctx.round_results.insert( -// view_number, -// ViewStatus::ViewSuccess(ViewSuccess { -// agreed_state: -// -// }), -// ); -// } -// // these aren't failures -// EventType::ReplicaViewTimeout { view_number } -// | EventType::NextLeaderViewTimeout { view_number } -// | EventType::ViewFinished { view_number } => todo!(), -// _ => todo!(), -// } -// (None, state) -// } -// .boxed() -// }, -// )); -// -// let builder = TaskBuilder::>::new( -// "Safety Check Task".to_string(), -// ) -// .register_event_stream(test_event_stream, FilterEvent::default()) -// .await -// .register_registry(&mut registry) -// .await -// .register_state(state) -// .register_event_handler(test_event_handler) -// .register_message_handler(message_handler) -// .register_message_stream(hotshot_event_stream); -// let task_id = builder.get_task_id().unwrap(); -// (task_id, PerNodeSafetyTaskTypes::build(builder).launch()) -// } -// .boxed() -// }, -// ) -// } -// } -// -// // /// Data Availability task types -// pub type PerNodeSafetyTaskTypes< -// TYPES: NodeType, -// I: TestableNodeImplementation, -// > = HSTWithEventAndMessage< -// PerNodeSafetyTaskErr, -// GlobalTestEvent, -// ChannelStream, -// Event, -// UnboundedStream>, -// PerNodeSafetyTask, -// >; diff --git a/crates/testing/src/soundness_task.rs b/crates/testing/src/soundness_task.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/crates/testing/src/soundness_task.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/testing/src/spinning_task.rs b/crates/testing/src/spinning_task.rs index d3311a8a24..017e1497a0 100644 --- a/crates/testing/src/spinning_task.rs +++ b/crates/testing/src/spinning_task.rs @@ -1,28 +1,15 @@ -use std::{ - collections::{BTreeMap, HashMap}, - sync::Arc, -}; +use std::collections::HashMap; -use async_compatibility_layer::channel::UnboundedStream; -use futures::FutureExt; use hotshot::traits::TestableNodeImplementation; -use hotshot_task::{ - event_stream::ChannelStream, - task::{FilterEvent, HandleEvent, HandleMessage, HotShotTaskCompleted, HotShotTaskTypes, TS}, - task_impls::{HSTWithEventAndMessage, TaskBuilder}, - MergeN, -}; + +use crate::test_runner::HotShotTaskCompleted; +use crate::test_runner::LateStartNode; +use crate::test_runner::Node; +use hotshot_task::task::{Task, TaskState, TestTaskState}; use hotshot_types::traits::network::CommunicationChannel; -use hotshot_types::{ - event::Event, - traits::node_implementation::{ConsensusTime, NodeType}, -}; +use hotshot_types::{event::Event, traits::node_implementation::NodeType}; use snafu::Snafu; - -use crate::{ - test_launcher::TaskGenerator, - test_runner::{LateStartNode, Node}, -}; +use std::collections::BTreeMap; /// convience type for state and block pub type StateAndBlock = (Vec, Vec); @@ -44,7 +31,88 @@ pub struct SpinningTask> { pub(crate) latest_view: Option, } -impl> TS for SpinningTask {} +impl> TaskState for SpinningTask { + type Event = GlobalTestEvent; + + type Output = HotShotTaskCompleted; + + async fn handle_event(event: Self::Event, _task: &mut Task) -> Option { + if matches!(event, GlobalTestEvent::ShutDown) { + return Some(HotShotTaskCompleted::ShutDown); + } + None + } + + fn should_shutdown(_event: &Self::Event) -> bool { + false + } +} + +impl> TestTaskState + for SpinningTask +{ + type Message = Event; + + type Output = HotShotTaskCompleted; + + type State = 
Self;
+
+    async fn handle_message(
+        message: Self::Message,
+        _id: usize,
+        task: &mut hotshot_task::task::TestTask<Self::State, Self>,
+    ) -> Option<HotShotTaskCompleted> {
+        let Event {
+            view_number,
+            event: _,
+        } = message;
+
+        let state = &mut task.state_mut();
+
+        // if we have not seen this view before
+        if state.latest_view.is_none() || view_number > state.latest_view.unwrap() {
+            // perform operations on the nodes
+            if let Some(operations) = state.changes.remove(&view_number) {
+                for ChangeNode { idx, updown } in operations {
+                    match updown {
+                        UpDown::Up => {
+                            if let Some(node) = state.late_start.remove(&idx.try_into().unwrap()) {
+                                tracing::error!("Node {} spinning up late", idx);
+                                let handle = node.context.run_tasks().await;
+                                handle.hotshot.start_consensus().await;
+                            }
+                        }
+                        UpDown::Down => {
+                            if let Some(node) = state.handles.get_mut(idx) {
+                                tracing::error!("Node {} shutting down", idx);
+                                node.handle.shut_down().await;
+                            }
+                        }
+                        UpDown::NetworkUp => {
+                            if let Some(handle) = state.handles.get(idx) {
+                                tracing::error!("Node {} networks resuming", idx);
+                                handle.networks.0.resume();
+                                handle.networks.1.resume();
+                            }
+                        }
+                        UpDown::NetworkDown => {
+                            if let Some(handle) = state.handles.get(idx) {
+                                tracing::error!("Node {} networks pausing", idx);
+                                handle.networks.0.pause();
+                                handle.networks.1.pause();
+                            }
+                        }
+                    }
+                }
+            }
+
+            // update our latest view
+            state.latest_view = Some(view_number);
+        }
+
+        None
+    }
+}
 
 /// Spin the node up or down
 #[derive(Clone, Debug)]
@@ -75,165 +143,3 @@ pub struct SpinningTaskDescription {
     /// the changes in node status, time -> changes
     pub node_changes: Vec<(u64, Vec<ChangeNode>)>,
 }
-
-impl SpinningTaskDescription {
-    /// build a task
-    /// # Panics
-    /// If there is no latest view
-    /// or if the node id is over `u32::MAX`
-    #[must_use]
-    #[allow(clippy::too_many_lines)]
-    pub fn build<TYPES: NodeType, I: TestableNodeImplementation<TYPES>>(
-        self,
-    ) -> TaskGenerator<SpinningTask<TYPES, I>> {
-        Box::new(move |mut state, mut registry, test_event_stream| {
-            async move {
-                let event_handler =
-                    HandleEvent::<SpinningTaskTypes<TYPES, I>>(Arc::new(move |event, state| {
-                        async move {
-                            match event {
-                                GlobalTestEvent::ShutDown => {
-                                    // We do this here as well as in the completion task
-                                    // because that task has no knowledge of our late start handles.
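The `UpDown` match above is the entire spinning policy: each scheduled change boots a late-start node, shuts one down, or pauses/resumes its networks. A distilled version with placeholder node state (`FakeNode` is illustrative only; the real task drives `SystemContextHandle` and the communication channels):

use std::collections::BTreeMap;

// Placeholder for a node handle; the real task starts consensus and
// pauses/resumes CommunicationChannels instead of flipping booleans.
struct FakeNode {
    running: bool,
    network_up: bool,
}

enum UpDown {
    Up,
    Down,
    NetworkUp,
    NetworkDown,
}

fn apply_changes(
    view: u64,
    changes: &mut BTreeMap<u64, Vec<(usize, UpDown)>>,
    nodes: &mut Vec<FakeNode>,
) {
    // Mirrors handle_message: consume the entry for this view, if any.
    if let Some(ops) = changes.remove(&view) {
        for (idx, updown) in ops {
            if let Some(node) = nodes.get_mut(idx) {
                match updown {
                    UpDown::Up => node.running = true,
                    UpDown::Down => node.running = false,
                    UpDown::NetworkUp => node.network_up = true,
                    UpDown::NetworkDown => node.network_up = false,
                }
            }
        }
    }
}

fn main() {
    let mut changes = BTreeMap::from([(3, vec![(0, UpDown::Down)])]);
    let mut nodes = vec![FakeNode {
        running: true,
        network_up: true,
    }];
    apply_changes(3, &mut changes, &mut nodes);
    assert!(!nodes[0].running);
}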
- for node in &state.handles { - node.handle.clone().shut_down().await; - } - - (Some(HotShotTaskCompleted::ShutDown), state) - } - } - } - .boxed() - })); - - let message_handler = HandleMessage::>(Arc::new( - move |msg, mut state| { - async move { - let Event { - view_number, - event: _, - } = msg.1; - - // if we have not seen this view before - if state.latest_view.is_none() - || view_number > state.latest_view.unwrap() - { - // perform operations on the nodes - - // We want to make sure we didn't miss any views (for example, there is no decide event - // if we get a timeout) - let views_with_relevant_changes: Vec<_> = state - .changes - .range(TYPES::Time::new(0)..view_number) - .map(|(k, _v)| *k) - .collect(); - - for view in views_with_relevant_changes { - if let Some(operations) = state.changes.remove(&view) { - for ChangeNode { idx, updown } in operations { - match updown { - UpDown::Up => { - if let Some(node) = state - .late_start - .remove(&idx.try_into().unwrap()) - { - tracing::error!( - "Node {} spinning up late", - idx - ); - - // create node and add to state, so we can shut them down properly later - let node = Node { - node_id: idx.try_into().unwrap(), - networks: node.networks, - handle: node.context.run_tasks().await, - }; - - // bootstrap consensus by sending the event - node.handle.hotshot.start_consensus().await; - - // add nodes to our state - state.handles.push(node); - } - } - UpDown::Down => { - if let Some(node) = state.handles.get_mut(idx) { - tracing::error!( - "Node {} shutting down", - idx - ); - node.handle.shut_down().await; - } - } - UpDown::NetworkUp => { - if let Some(handle) = state.handles.get(idx) { - tracing::error!( - "Node {} networks resuming", - idx - ); - handle.networks.0.resume(); - handle.networks.1.resume(); - } - } - UpDown::NetworkDown => { - if let Some(handle) = state.handles.get(idx) { - tracing::error!( - "Node {} networks pausing", - idx - ); - handle.networks.0.pause(); - handle.networks.1.pause(); - } - } - } - } - } - } - - // update our latest view - state.latest_view = Some(view_number); - } - - (None, state) - } - .boxed() - }, - )); - - let mut streams = vec![]; - for handle in &mut state.handles { - let s1 = handle - .handle - .get_event_stream_known_impl(FilterEvent::default()) - .await - .0; - streams.push(s1); - } - let builder = TaskBuilder::>::new( - "Test Spinning Task".to_string(), - ) - .register_event_stream(test_event_stream, FilterEvent::default()) - .await - .register_registry(&mut registry) - .await - .register_message_handler(message_handler) - .register_message_stream(MergeN::new(streams)) - .register_event_handler(event_handler) - .register_state(state); - let task_id = builder.get_task_id().unwrap(); - (task_id, SpinningTaskTypes::build(builder).launch()) - } - .boxed() - }) - } -} - -/// types for safety task -pub type SpinningTaskTypes = HSTWithEventAndMessage< - SpinningTaskErr, - GlobalTestEvent, - ChannelStream, - (usize, Event), - MergeN>>, - SpinningTask, ->; diff --git a/crates/testing/src/task_helpers.rs b/crates/testing/src/task_helpers.rs index a8c752a1d4..8a7a4744f0 100644 --- a/crates/testing/src/task_helpers.rs +++ b/crates/testing/src/task_helpers.rs @@ -13,7 +13,6 @@ use hotshot::{ types::{BLSPubKey, SignatureKey, SystemContextHandle}, HotShotConsensusApi, HotShotInitializer, Memberships, Networks, SystemContext, }; -use hotshot_task::event_stream::ChannelStream; use hotshot_task_impls::events::HotShotEvent; use hotshot_types::{ consensus::ConsensusMetricsValue, @@ -32,6 +31,7 @@ use 
hotshot_types::{ vote::HasViewNumber, }; +use async_broadcast::{Receiver, Sender}; use async_lock::RwLockUpgradableReadGuard; use bitvec::bitvec; use hotshot_types::simple_vote::QuorumData; @@ -51,7 +51,8 @@ pub async fn build_system_handle( node_id: u64, ) -> ( SystemContextHandle, - ChannelStream>, + Sender>, + Receiver>, ) { let builder = TestMetadata::default_multiple_rounds(); diff --git a/crates/testing/src/test_builder.rs b/crates/testing/src/test_builder.rs index f64f113a63..4cba204166 100644 --- a/crates/testing/src/test_builder.rs +++ b/crates/testing/src/test_builder.rs @@ -216,12 +216,8 @@ impl TestMetadata { min_transactions, timing_data, da_committee_size, - txn_description, - completion_task_description, - overall_safety_properties, - spinning_properties, + unreliable_network, - view_sync_properties, .. } = self.clone(); @@ -286,11 +282,6 @@ impl TestMetadata { a.propose_max_round_time = propose_max_round_time; }; - let txn_task_generator = txn_description.build(); - let completion_task_generator = completion_task_description.build_and_launch(); - let overall_safety_task_generator = overall_safety_properties.build(); - let spinning_task_generator = spinning_properties.build(); - let view_sync_task_generator = view_sync_properties.build(); TestLauncher { resource_generator: ResourceGenerators { channel_generator: >::gen_comm_channels( @@ -303,12 +294,6 @@ impl TestMetadata { config, }, metadata: self, - txn_task_generator, - overall_safety_task_generator, - completion_task_generator, - spinning_task_generator, - view_sync_task_generator, - hooks: vec![], } .modify_default_config(mod_config) } diff --git a/crates/testing/src/test_launcher.rs b/crates/testing/src/test_launcher.rs index 591253af4a..df7d0a6a47 100644 --- a/crates/testing/src/test_launcher.rs +++ b/crates/testing/src/test_launcher.rs @@ -1,24 +1,12 @@ use std::{collections::HashMap, marker::PhantomData, sync::Arc}; -use futures::future::BoxFuture; use hotshot::traits::{NodeImplementation, TestableNodeImplementation}; -use hotshot_task::{ - event_stream::ChannelStream, - global_registry::{GlobalRegistry, HotShotTaskId}, - task::HotShotTaskCompleted, - task_launcher::TaskRunner, -}; use hotshot_types::{ traits::{network::CommunicationChannel, node_implementation::NodeType}, HotShotConfig, }; -use crate::{spinning_task::SpinningTask, view_sync_task::ViewSyncTask}; - -use super::{ - completion_task::CompletionTask, overall_safety_task::OverallSafetyTask, - test_builder::TestMetadata, test_runner::TestRunner, txn_task::TxnTask, GlobalTestEvent, -}; +use super::{test_builder::TestMetadata, test_runner::TestRunner}; /// convience type alias for the networks available pub type Networks = ( @@ -35,25 +23,6 @@ pub type CommitteeNetworkGenerator = Box) -> T + 'static>; /// Wrapper Type for view sync function that takes a `ConnectedNetwork` and returns a `CommunicationChannel` pub type ViewSyncNetworkGenerator = Box) -> T + 'static>; -/// Wrapper type for a task generator. -pub type TaskGenerator = Box< - dyn FnOnce( - TASK, - GlobalRegistry, - ChannelStream, - ) - -> BoxFuture<'static, (HotShotTaskId, BoxFuture<'static, HotShotTaskCompleted>)>, ->; - -/// Wrapper type for a hook. 
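With `ChannelStream` gone, `build_system_handle` hands back the raw async-broadcast pair, and the boxed `TaskGenerator` plumbing removed here becomes unnecessary. The property the tests rely on is that a broadcast channel fans out: every receiver sees every message. A small demonstration, assuming async-broadcast 0.6 with a placeholder event type (the capacity mirrors the new `EVENT_CHANNEL_SIZE` constant):

use async_broadcast::{broadcast, TryRecvError};

#[derive(Clone, Debug, PartialEq)]
enum Event {
    ViewChange(u64),
    ShutDown,
}

#[tokio::main]
async fn main() {
    // Bounded broadcast channel; a send waits once the buffer is full.
    let (tx, rx) = broadcast::<Event>(100_000);

    // Each receiver keeps its own cursor, so clones observe every
    // message independently rather than competing for them.
    let mut rx_a = rx.clone();
    let mut rx_b = rx;

    tx.broadcast(Event::ViewChange(1)).await.unwrap();
    assert_eq!(rx_a.recv().await.unwrap(), Event::ViewChange(1));
    assert_eq!(rx_b.recv().await.unwrap(), Event::ViewChange(1));

    // try_recv gives the non-blocking check that polling tasks use to
    // look for a shutdown signal between rounds of work.
    assert!(matches!(rx_a.try_recv(), Err(TryRecvError::Empty)));
    tx.broadcast(Event::ShutDown).await.unwrap();
    assert_eq!(rx_a.try_recv().unwrap(), Event::ShutDown);
}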
-pub type Hook = Box< - dyn FnOnce( - GlobalRegistry, - ChannelStream, - ) - -> BoxFuture<'static, (HotShotTaskId, BoxFuture<'static, HotShotTaskCompleted>)>, ->; - /// generators for resources used by each node pub struct ResourceGenerators> { /// generate channels @@ -70,18 +39,6 @@ pub struct TestLauncher> { pub resource_generator: ResourceGenerators, /// metadasta used for tasks pub metadata: TestMetadata, - /// overrideable txn task generator function - pub txn_task_generator: TaskGenerator>, - /// overrideable timeout task generator function - pub completion_task_generator: TaskGenerator>, - /// overall safety task generator - pub overall_safety_task_generator: TaskGenerator>, - /// task for spinning nodes up/down - pub spinning_task_generator: TaskGenerator>, - /// task for view sync - pub view_sync_task_generator: TaskGenerator>, - /// extra hooks in case we want to check additional things - pub hooks: Vec, } impl> TestLauncher { @@ -93,93 +50,9 @@ impl> TestLauncher>, - ) -> Self { - Self { - overall_safety_task_generator, - ..self - } - } - - /// override the safety task generator - #[must_use] - pub fn with_spinning_task_generator( - self, - spinning_task_generator: TaskGenerator>, - ) -> Self { - Self { - spinning_task_generator, - ..self - } - } - - /// overridde the completion task generator - #[must_use] - pub fn with_completion_task_generator( - self, - completion_task_generator: TaskGenerator>, - ) -> Self { - Self { - completion_task_generator, - ..self - } - } - - /// override the txn task generator - #[must_use] - pub fn with_txn_task_generator( - self, - txn_task_generator: TaskGenerator>, - ) -> Self { - Self { - txn_task_generator, - ..self - } - } - - /// override the view sync task generator - #[must_use] - pub fn with_view_sync_task_generator( - self, - view_sync_task_generator: TaskGenerator>, - ) -> Self { - Self { - view_sync_task_generator, - ..self - } - } - - /// override resource generators - #[must_use] - pub fn with_resource_generator(self, resource_generator: ResourceGenerators) -> Self { - Self { - resource_generator, - ..self - } - } - - /// add a hook - #[must_use] - pub fn add_hook(mut self, hook: Hook) -> Self { - self.hooks.push(hook); - self - } - - /// overwrite hooks with more hooks - #[must_use] - pub fn with_hooks(self, hooks: Vec) -> Self { - Self { hooks, ..self } - } - /// Modifies the config used when generating nodes with `f` #[must_use] pub fn modify_default_config( diff --git a/crates/testing/src/test_runner.rs b/crates/testing/src/test_runner.rs index 57edbf45e4..7ae67b979a 100644 --- a/crates/testing/src/test_runner.rs +++ b/crates/testing/src/test_runner.rs @@ -5,17 +5,21 @@ use super::{ txn_task::TxnTask, }; use crate::{ - spinning_task::{ChangeNode, UpDown}, + completion_task::CompletionTaskDescription, + spinning_task::{ChangeNode, SpinningTask, UpDown}, state_types::TestInstanceState, test_launcher::{Networks, TestLauncher}, + txn_task::TxnTaskDescription, view_sync_task::ViewSyncTask, }; +use async_broadcast::broadcast; +use futures::future::join_all; use hotshot::{types::SystemContextHandle, Memberships}; use hotshot::{traits::TestableNodeImplementation, HotShotInitializer, SystemContext}; -use hotshot_task::{ - event_stream::ChannelStream, global_registry::GlobalRegistry, task_launcher::TaskRunner, -}; + +use hotshot_constants::EVENT_CHANNEL_SIZE; +use hotshot_task::task::{Task, TaskRegistry, TestTask}; use hotshot_types::traits::{ network::CommunicationChannel, node_implementation::NodeImplementation, }; @@ -30,6 +34,7 @@ use 
hotshot_types::{ use std::{ collections::{BTreeMap, HashMap, HashSet}, marker::PhantomData, + sync::Arc, }; #[allow(deprecated)] @@ -70,12 +75,30 @@ pub struct TestRunner< pub(crate) late_start: HashMap>, /// the next node unique identifier pub(crate) next_node_id: u64, - /// overarching test task - pub(crate) task_runner: TaskRunner, - /// PhantomData for N + /// Phantom for N pub(crate) _pd: PhantomData, } +/// enum describing how the tasks completed +pub enum HotShotTaskCompleted { + /// the task shut down successfully + ShutDown, + /// the task encountered an error + Error(Box), + /// the streams the task was listening for died + StreamsDied, + /// we somehow lost the state + /// this is definitely a bug. + LostState, + /// lost the return value somehow + LostReturnValue, + /// Stream exists but missing handler + MissingHandler, +} + +pub trait TaskErr: std::error::Error + Sync + Send + 'static {} +impl TaskErr for T {} + impl< TYPES: NodeType, I: TestableNodeImplementation, @@ -90,6 +113,7 @@ where /// if the test fails #[allow(clippy::too_many_lines)] pub async fn run_test(mut self) { + let (tx, rx) = broadcast(EVENT_CHANNEL_SIZE); let spinning_changes = self .launcher .metadata @@ -108,45 +132,53 @@ where self.add_nodes(self.launcher.metadata.total_nodes, &late_start_nodes) .await; + let mut event_rxs = vec![]; + let mut internal_event_rxs = vec![]; + + for node in &self.nodes { + let r = node.handle.get_event_stream_known_impl(); + event_rxs.push(r); + } + for node in &self.nodes { + let r = node.handle.get_internal_event_stream_known_impl(); + internal_event_rxs.push(r); + } + + let reg = Arc::new(TaskRegistry::default()); let TestRunner { - launcher, + ref launcher, nodes, late_start, next_node_id: _, - mut task_runner, - _pd: PhantomData, + _pd: _, } = self; - let registry = GlobalRegistry::default(); - let test_event_stream = ChannelStream::new(); - // add transaction task - let txn_task_state = TxnTask { - handles: nodes.clone(), - next_node_idx: Some(0), - }; - let (id, task) = (launcher.txn_task_generator)( - txn_task_state, - registry.clone(), - test_event_stream.clone(), - ) - .await; - task_runner = - task_runner.add_task(id, "Test Transaction Submission Task".to_string(), task); + let mut task_futs = vec![]; + let meta = launcher.metadata.clone(); + + let txn_task = + if let TxnTaskDescription::RoundRobinTimeBased(duration) = meta.txn_description { + let txn_task = TxnTask { + handles: nodes.clone(), + next_node_idx: Some(0), + duration, + shutdown_chan: rx.clone(), + }; + Some(txn_task) + } else { + None + }; // add completion task - let completion_task_state = CompletionTask { + let CompletionTaskDescription::TimeBasedCompletionTaskBuilder(time_based) = + meta.completion_task_description; + let completion_task = CompletionTask { + tx: tx.clone(), + rx: rx.clone(), handles: nodes.clone(), - test_event_stream: test_event_stream.clone(), + duration: time_based.duration, }; - let (id, task) = (launcher.completion_task_generator)( - completion_task_state, - registry.clone(), - test_event_stream.clone(), - ) - .await; - - task_runner = task_runner.add_task(id, "Test Completion Task".to_string(), task); // add spinning task // map spinning to view @@ -158,48 +190,44 @@ where .append(&mut change); } - let spinning_task_state = crate::spinning_task::SpinningTask { + let spinning_task_state = SpinningTask { handles: nodes.clone(), late_start, latest_view: None, changes, }; - - let (id, task) = (launcher.spinning_task_generator)( - spinning_task_state, - registry.clone(), - 
test_event_stream.clone(), - ) - .await; - task_runner = task_runner.add_task(id, "Test Spinning Task".to_string(), task); - + let spinning_task = TestTask::, SpinningTask>::new( + Task::new(tx.clone(), rx.clone(), reg.clone(), spinning_task_state), + event_rxs.clone(), + ); // add safety task let overall_safety_task_state = OverallSafetyTask { handles: nodes.clone(), ctx: RoundCtx::default(), - test_event_stream: test_event_stream.clone(), + properties: self.launcher.metadata.overall_safety_properties, }; - let (id, task) = (launcher.overall_safety_task_generator)( - overall_safety_task_state, - registry.clone(), - test_event_stream.clone(), - ) - .await; - task_runner = task_runner.add_task(id, "Test Overall Safety Task".to_string(), task); + + let safety_task = TestTask::, OverallSafetyTask>::new( + Task::new( + tx.clone(), + rx.clone(), + reg.clone(), + overall_safety_task_state, + ), + event_rxs.clone(), + ); // add view sync task let view_sync_task_state = ViewSyncTask { - handles: nodes.clone(), hit_view_sync: HashSet::new(), + description: self.launcher.metadata.view_sync_properties, + _pd: PhantomData, }; - let (id, task) = (launcher.view_sync_task_generator)( - view_sync_task_state, - registry.clone(), - test_event_stream.clone(), - ) - .await; - task_runner = task_runner.add_task(id, "View Sync Task".to_string(), task); + let view_sync_task = TestTask::, ViewSyncTask>::new( + Task::new(tx.clone(), rx.clone(), reg.clone(), view_sync_task_state), + internal_event_rxs, + ); // wait for networks to be ready for node in &nodes { @@ -212,21 +240,57 @@ where node.handle.hotshot.start_consensus().await; } } - - let results = task_runner.launch().await; - + task_futs.push(safety_task.run()); + task_futs.push(view_sync_task.run()); + if let Some(txn) = txn_task { + task_futs.push(txn.run()); + } + task_futs.push(completion_task.run()); + task_futs.push(spinning_task.run()); let mut error_list = vec![]; - for (name, result) in results { - match result { - hotshot_task::task::HotShotTaskCompleted::ShutDown => { - info!("Task {} shut down successfully", name); + + #[cfg(async_executor_impl = "async-std")] + { + let results = join_all(task_futs).await; + tracing::error!("test tasks joined"); + for result in results { + match result { + HotShotTaskCompleted::ShutDown => { + info!("Task shut down successfully"); + } + HotShotTaskCompleted::Error(e) => error_list.push(e), + _ => { + panic!("Future impl for task abstraction failed! This should never happen"); + } } - hotshot_task::task::HotShotTaskCompleted::Error(e) => error_list.push((name, e)), - _ => { - panic!("Future impl for task abstraction failed! This should never happen"); + } + } + + #[cfg(async_executor_impl = "tokio")] + { + let results = join_all(task_futs).await; + + tracing::error!("test tasks joined"); + for result in results { + match result { + Ok(res) => { + match res { + HotShotTaskCompleted::ShutDown => { + info!("Task shut down successfully"); + } + HotShotTaskCompleted::Error(e) => error_list.push(e), + _ => { + panic!("Future impl for task abstraction failed! This should never happen"); + } + } + } + Err(e) => { + panic!("Error Joining the test task {:?}", e); + } } } } + assert!( error_list.is_empty(), "TEST FAILED! 
Results: {error_list:?}" diff --git a/crates/testing/src/timeout_task.rs b/crates/testing/src/timeout_task.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/crates/testing/src/timeout_task.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/testing/src/txn_task.rs b/crates/testing/src/txn_task.rs index 994a37baeb..4e510cc50a 100644 --- a/crates/testing/src/txn_task.rs +++ b/crates/testing/src/txn_task.rs @@ -1,21 +1,18 @@ -use crate::test_runner::Node; -use async_compatibility_layer::art::{async_sleep, async_timeout}; -use futures::FutureExt; +use crate::test_runner::{HotShotTaskCompleted, Node}; +use async_broadcast::{Receiver, TryRecvError}; +use async_compatibility_layer::art::{async_sleep, async_spawn}; +#[cfg(async_executor_impl = "async-std")] +use async_std::task::JoinHandle; use hotshot::traits::TestableNodeImplementation; -use hotshot_task::{ - boxed_sync, - event_stream::ChannelStream, - task::{FilterEvent, HandleEvent, HandleMessage, HotShotTaskCompleted, HotShotTaskTypes, TS}, - task_impls::{HSTWithEventAndMessage, TaskBuilder}, - GeneratedStream, -}; -use hotshot_types::traits::node_implementation::{NodeImplementation, NodeType}; +use hotshot_types::traits::node_implementation::NodeType; use rand::thread_rng; use snafu::Snafu; -use std::{sync::Arc, time::Duration}; -use tracing::error; +#[cfg(async_executor_impl = "tokio")] +use tokio::task::JoinHandle; -use super::{test_launcher::TaskGenerator, GlobalTestEvent}; +use std::time::Duration; + +use super::GlobalTestEvent; // the obvious idea here is to pass in a "stream" that completes every `n` seconds // the stream construction can definitely be fancier but that's the baseline idea @@ -31,19 +28,58 @@ pub struct TxnTask> { pub handles: Vec>, /// Optional index of the next node. pub next_node_idx: Option, + /// time to wait between txns + pub duration: Duration, + /// + pub shutdown_chan: Receiver, } -impl> TS for TxnTask {} - -/// types for task that deices when things are completed -pub type TxnTaskTypes = HSTWithEventAndMessage< - TxnTaskErr, - GlobalTestEvent, - ChannelStream, - (), - GeneratedStream<()>, - TxnTask, ->; +impl> TxnTask { + pub fn run(mut self) -> JoinHandle { + async_spawn(async move { + async_sleep(Duration::from_millis(100)).await; + loop { + async_sleep(self.duration).await; + match self.shutdown_chan.try_recv() { + Ok(_event) => { + return HotShotTaskCompleted::ShutDown; + } + Err(TryRecvError::Empty) => {} + Err(_) => { + return HotShotTaskCompleted::StreamsDied; + } + } + self.submit_tx().await; + } + }) + } + async fn submit_tx(&mut self) { + if let Some(idx) = self.next_node_idx { + // submit to idx handle + // increment state + self.next_node_idx = Some((idx + 1) % self.handles.len()); + match self.handles.get(idx) { + None => { + tracing::error!("couldn't get node in txn task"); + // should do error + unimplemented!() + } + Some(node) => { + // use rand::seq::IteratorRandom; + // we're assuming all nodes have the same leaf. + // If they don't match, this is probably fine since + // it should be caught by an assertion (and the txn will be rejected anyway) + let leaf = node.handle.get_decided_leaf().await; + let txn = I::leaf_create_random_transaction(&leaf, &mut thread_rng(), 0); + node.handle + .submit_transaction(txn.clone()) + .await + .expect("Could not send transaction"); + } + } + } + } +} /// build the transaction task #[derive(Clone, Debug)] @@ -54,120 +90,3 @@ pub enum TxnTaskDescription { /// TODO DistributionBased, // others? 
} - -impl TxnTaskDescription { - /// build a task - /// # Panics - /// if unable to get task id - #[must_use] - pub fn build>( - self, - ) -> TaskGenerator> - where - TYPES: NodeType, - I: NodeImplementation, - { - Box::new(move |state, mut registry, test_event_stream| { - async move { - // consistency check - match self { - TxnTaskDescription::RoundRobinTimeBased(_) => { - assert!(state.next_node_idx.is_some()); - } - TxnTaskDescription::DistributionBased => assert!(state.next_node_idx.is_none()), - } - // TODO we'll possibly want multiple criterion including: - // - certain number of txns committed - // - anchor of certain depth - // - some other stuff? probably? - let event_handler = - HandleEvent::>(Arc::new(move |event, state| { - async move { - match event { - GlobalTestEvent::ShutDown => { - (Some(HotShotTaskCompleted::ShutDown), state) - } - } - } - .boxed() - })); - let message_handler = - HandleMessage::>(Arc::new(move |(), mut state| { - async move { - if let Some(idx) = state.next_node_idx { - // submit to idx handle - // increment state - state.next_node_idx = Some((idx + 1) % state.handles.len()); - match state.handles.get(idx) { - None => { - // should do error - unimplemented!() - } - Some(node) => { - // use rand::seq::IteratorRandom; - // we're assuming all nodes have the same leaf. - // If they don't match, this is probably fine since - // it should be caught by an assertion (and the txn will be rejected anyway) - - // Attempts to grab the most recently decided leaf. On failure, we don't - // send a transaction. This is to prevent deadlock. - if let Some(leaf) = node.handle.try_get_decided_leaf() { - let txn = I::leaf_create_random_transaction( - &leaf, - &mut thread_rng(), - 0, - ); - - // Time out if we can't get a lock on consensus in a reasonable time. This is to - // prevent deadlock. 
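The generator-stream plumbing removed in this hunk woke the task through a `GeneratedStream` of sleeps; the new `TxnTask::run` above owns its own loop instead. A standalone sketch of that spawn-and-poll shape, using tokio directly in place of the async-compatibility-layer wrappers, with placeholder event and outcome types:

use async_broadcast::{broadcast, Receiver, TryRecvError};
use std::time::Duration;
use tokio::task::JoinHandle;

#[derive(Clone, Debug)]
enum TestEvent {
    ShutDown,
}

#[derive(Debug)]
enum Completed {
    ShutDown,
    StreamsDied,
}

fn run_periodic(mut shutdown: Receiver<TestEvent>, period: Duration) -> JoinHandle<Completed> {
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(period).await;
            // Non-blocking check between rounds: an empty channel means
            // keep working, any other error means the test side hung up.
            match shutdown.try_recv() {
                Ok(TestEvent::ShutDown) => return Completed::ShutDown,
                Err(TryRecvError::Empty) => {}
                Err(_) => return Completed::StreamsDied,
            }
            // ... submit one transaction per iteration here ...
        }
    })
}

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast(8);
    let handle = run_periodic(rx, Duration::from_millis(10));
    tokio::time::sleep(Duration::from_millis(50)).await;
    tx.broadcast(TestEvent::ShutDown).await.unwrap();
    assert!(matches!(handle.await.unwrap(), Completed::ShutDown));
}

One consequence of returning a `JoinHandle` per task is executor-specific join semantics: tokio's handle resolves to a `Result` (hence the extra `Ok`/`Err` match in `run_test` above), while async-std's resolves to the value itself.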
- if let Err(err) = async_timeout( - Duration::from_secs(1), - node.handle.submit_transaction(txn.clone()), - ) - .await - { - error!("Failed to send test transaction: {err}"); - }; - } - - (None, state) - } - } - } else { - // TODO make an issue - // in the case that this is random - // which I haven't implemented yet - unimplemented!() - } - } - .boxed() - })); - let stream_generator = match self { - TxnTaskDescription::RoundRobinTimeBased(duration) => { - GeneratedStream::new(Arc::new(move || { - let fut = async move { - async_sleep(duration).await; - }; - Some(boxed_sync(fut)) - })) - } - TxnTaskDescription::DistributionBased => unimplemented!(), - }; - let builder = TaskBuilder::>::new( - "Test Transaction Submission Task".to_string(), - ) - .register_event_stream(test_event_stream, FilterEvent::default()) - .await - .register_registry(&mut registry) - .await - .register_state(state) - .register_event_handler(event_handler) - .register_message_handler(message_handler) - .register_message_stream(stream_generator); - let task_id = builder.get_task_id().unwrap(); - (task_id, TxnTaskTypes::build(builder).launch()) - } - .boxed() - }) - } -} diff --git a/crates/testing/src/view_sync_task.rs b/crates/testing/src/view_sync_task.rs index 0da12d6a3b..139e6b73fd 100644 --- a/crates/testing/src/view_sync_task.rs +++ b/crates/testing/src/view_sync_task.rs @@ -1,18 +1,10 @@ -use async_compatibility_layer::channel::UnboundedStream; -use futures::FutureExt; -use hotshot_task::task::{HotShotTaskCompleted, HotShotTaskTypes}; -use hotshot_task::{ - event_stream::ChannelStream, - task::{FilterEvent, HandleEvent, HandleMessage, TS}, - task_impls::{HSTWithEventAndMessage, TaskBuilder}, - MergeN, -}; +use hotshot_task::task::{Task, TaskState, TestTaskState}; use hotshot_task_impls::events::HotShotEvent; use hotshot_types::traits::node_implementation::{NodeType, TestableNodeImplementation}; use snafu::Snafu; -use std::{collections::HashSet, sync::Arc}; +use std::{collections::HashSet, marker::PhantomData}; -use crate::{test_launcher::TaskGenerator, test_runner::Node, GlobalTestEvent}; +use crate::{test_runner::HotShotTaskCompleted, GlobalTestEvent}; /// `ViewSync` Task error #[derive(Snafu, Debug, Clone)] @@ -23,23 +15,79 @@ pub struct ViewSyncTaskErr { /// `ViewSync` task state pub struct ViewSyncTask> { - /// the node handles - pub(crate) handles: Vec>, /// nodes that hit view sync pub(crate) hit_view_sync: HashSet, + /// properties of task + pub(crate) description: ViewSyncTaskDescription, + /// Phantom data for TYPES and I + pub(crate) _pd: PhantomData<(TYPES, I)>, } -impl> TS for ViewSyncTask {} +impl> TaskState for ViewSyncTask { + type Event = GlobalTestEvent; -/// `ViewSync` task types -pub type ViewSyncTaskTypes = HSTWithEventAndMessage< - ViewSyncTaskErr, - GlobalTestEvent, - ChannelStream, - (usize, HotShotEvent), - MergeN>>, - ViewSyncTask, ->; + type Output = HotShotTaskCompleted; + + async fn handle_event(event: Self::Event, task: &mut Task) -> Option { + let state = task.state_mut(); + match event { + GlobalTestEvent::ShutDown => match state.description.clone() { + ViewSyncTaskDescription::Threshold(min, max) => { + let num_hits = state.hit_view_sync.len(); + if min <= num_hits && num_hits <= max { + Some(HotShotTaskCompleted::ShutDown) + } else { + Some(HotShotTaskCompleted::Error(Box::new(ViewSyncTaskErr { + hit_view_sync: state.hit_view_sync.clone(), + }))) + } + } + }, + } + } + + fn should_shutdown(_event: &Self::Event) -> bool { + false + } +} + +impl> TestTaskState + for ViewSyncTask 
+{ + type Message = HotShotEvent; + + type Output = HotShotTaskCompleted; + + type State = Self; + + async fn handle_message( + message: Self::Message, + id: usize, + task: &mut hotshot_task::task::TestTask, + ) -> Option { + match message { + // all the view sync events + HotShotEvent::ViewSyncTimeout(_, _, _) + | HotShotEvent::ViewSyncPreCommitVoteRecv(_) + | HotShotEvent::ViewSyncCommitVoteRecv(_) + | HotShotEvent::ViewSyncFinalizeVoteRecv(_) + | HotShotEvent::ViewSyncPreCommitVoteSend(_) + | HotShotEvent::ViewSyncCommitVoteSend(_) + | HotShotEvent::ViewSyncFinalizeVoteSend(_) + | HotShotEvent::ViewSyncPreCommitCertificate2Recv(_) + | HotShotEvent::ViewSyncCommitCertificate2Recv(_) + | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) + | HotShotEvent::ViewSyncPreCommitCertificate2Send(_, _) + | HotShotEvent::ViewSyncCommitCertificate2Send(_, _) + | HotShotEvent::ViewSyncFinalizeCertificate2Send(_, _) + | HotShotEvent::ViewSyncTrigger(_) => { + task.state_mut().hit_view_sync.insert(id); + } + _ => (), + } + None + } +} /// enum desecribing whether a node should hit view sync #[derive(Clone, Debug, Copy)] @@ -58,100 +106,3 @@ pub enum ViewSyncTaskDescription { /// (min, max) number nodes that may hit view sync, inclusive Threshold(usize, usize), } - -impl ViewSyncTaskDescription { - /// build a view sync task from its description - /// # Panics - /// if there is an violation of the view sync description - #[must_use] - pub fn build>( - self, - ) -> TaskGenerator> { - Box::new(move |mut state, mut registry, test_event_stream| { - async move { - let event_handler = - HandleEvent::>(Arc::new(move |event, state| { - let self_dup = self.clone(); - async move { - match event { - GlobalTestEvent::ShutDown => match self_dup.clone() { - ViewSyncTaskDescription::Threshold(min, max) => { - let num_hits = state.hit_view_sync.len(); - if min <= num_hits && num_hits <= max { - (Some(HotShotTaskCompleted::ShutDown), state) - } else { - ( - Some(HotShotTaskCompleted::Error(Box::new( - ViewSyncTaskErr { - hit_view_sync: state.hit_view_sync.clone(), - }, - ))), - state, - ) - } - } - }, - } - } - .boxed() - })); - - let message_handler = HandleMessage::>(Arc::new( - // NOTE: could short circuit on entering view sync if we're not supposed to - // enter view sync. 
I opted not to do this just to gather more information - // (since we'll fail the test later anyway) - move |(id, msg), mut state| { - async move { - match msg { - // all the view sync events - HotShotEvent::ViewSyncTimeout(_, _, _) - | HotShotEvent::ViewSyncPreCommitVoteRecv(_) - | HotShotEvent::ViewSyncCommitVoteRecv(_) - | HotShotEvent::ViewSyncFinalizeVoteRecv(_) - | HotShotEvent::ViewSyncPreCommitVoteSend(_) - | HotShotEvent::ViewSyncCommitVoteSend(_) - | HotShotEvent::ViewSyncFinalizeVoteSend(_) - | HotShotEvent::ViewSyncPreCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) - | HotShotEvent::ViewSyncPreCommitCertificate2Send(_, _) - | HotShotEvent::ViewSyncCommitCertificate2Send(_, _) - | HotShotEvent::ViewSyncFinalizeCertificate2Send(_, _) - | HotShotEvent::ViewSyncTrigger(_) => { - state.hit_view_sync.insert(id); - } - _ => (), - } - (None, state) - } - .boxed() - }, - )); - let mut streams = vec![]; - for handle in &mut state.handles { - let stream = handle - .handle - .get_internal_event_stream_known_impl(FilterEvent::default()) - .await - .0; - streams.push(stream); - } - - let builder = TaskBuilder::>::new( - "Test Completion Task".to_string(), - ) - .register_event_stream(test_event_stream, FilterEvent::default()) - .await - .register_registry(&mut registry) - .await - .register_state(state) - .register_event_handler(event_handler) - .register_message_handler(message_handler) - .register_message_stream(MergeN::new(streams)); - let task_id = builder.get_task_id().unwrap(); - (task_id, ViewSyncTaskTypes::build(builder).launch()) - } - .boxed() - }) - } -} diff --git a/crates/testing/tests/consensus_task.rs b/crates/testing/tests/consensus_task.rs index b5cdcc7887..8ed52d6aea 100644 --- a/crates/testing/tests/consensus_task.rs +++ b/crates/testing/tests/consensus_task.rs @@ -1,7 +1,6 @@ #![allow(clippy::panic)] use commit::Committable; -use hotshot::{tasks::add_consensus_task, types::SystemContextHandle, HotShotConsensusApi}; -use hotshot_task::event_stream::ChannelStream; +use hotshot::{types::SystemContextHandle, HotShotConsensusApi}; use hotshot_task_impls::events::HotShotEvent; use hotshot_testing::{ node_types::{MemoryImpl, TestTypes}, @@ -82,6 +81,7 @@ async fn build_vote( )] #[cfg_attr(async_executor_impl = "async-std", async_std::test)] async fn test_consensus_task() { + use hotshot::tasks::create_consensus_state; use hotshot_task_impls::harness::run_harness; use hotshot_testing::task_helpers::build_system_handle; use hotshot_types::simple_certificate::QuorumCertificate; @@ -108,30 +108,22 @@ async fn test_consensus_task() { input.push(HotShotEvent::Shutdown); - output.insert(HotShotEvent::QCFormed(either::Left(qc)), 1); output.insert( HotShotEvent::QuorumProposalSend(proposal.clone(), public_key), 1, ); - output.insert( - HotShotEvent::QuorumProposalRecv(proposal.clone(), public_key), - 1, - ); + output.insert(HotShotEvent::ViewChange(ViewNumber::new(1)), 1); if let GeneralConsensusMessage::Vote(vote) = build_vote(&handle, proposal.data).await { output.insert(HotShotEvent::QuorumVoteSend(vote.clone()), 1); input.push(HotShotEvent::QuorumVoteRecv(vote.clone())); - output.insert(HotShotEvent::QuorumVoteRecv(vote), 1); } - output.insert(HotShotEvent::Shutdown, 1); - - let build_fn = |task_runner, event_stream| { - add_consensus_task(task_runner, event_stream, ChannelStream::new(), handle) - }; + let consensus_state = + create_consensus_state(handle.hotshot.inner.output_event_stream.0.clone(), 
&handle).await; - run_harness(input, output, None, build_fn, false).await; + run_harness(input, output, consensus_state, false).await; } #[cfg(test)] @@ -141,6 +133,7 @@ async fn test_consensus_task() { )] #[cfg_attr(async_executor_impl = "async-std", async_std::test)] async fn test_consensus_vote() { + use hotshot::tasks::create_consensus_state; use hotshot_task_impls::harness::run_harness; use hotshot_testing::task_helpers::build_system_handle; @@ -161,27 +154,21 @@ async fn test_consensus_vote() { proposal.clone(), public_key, )); - output.insert( - HotShotEvent::QuorumProposalRecv(proposal.clone(), public_key), - 1, - ); + let proposal = proposal.data; if let GeneralConsensusMessage::Vote(vote) = build_vote(&handle, proposal).await { output.insert(HotShotEvent::QuorumVoteSend(vote.clone()), 1); input.push(HotShotEvent::QuorumVoteRecv(vote.clone())); - output.insert(HotShotEvent::QuorumVoteRecv(vote), 1); } output.insert(HotShotEvent::ViewChange(ViewNumber::new(1)), 1); input.push(HotShotEvent::Shutdown); - output.insert(HotShotEvent::Shutdown, 1); - let build_fn = |task_runner, event_stream| { - add_consensus_task(task_runner, event_stream, ChannelStream::new(), handle) - }; + let consensus_state = + create_consensus_state(handle.hotshot.inner.output_event_stream.0.clone(), &handle).await; - run_harness(input, output, None, build_fn, false).await; + run_harness(input, output, consensus_state, false).await; } #[cfg(test)] @@ -215,7 +202,7 @@ async fn test_consensus_with_vid() { async_compatibility_layer::logging::setup_logging(); async_compatibility_layer::logging::setup_backtrace(); - let (handle, _event_stream) = build_system_handle(2).await; + let (handle, _tx, _rx) = build_system_handle(2).await; // We assign node's key pair rather than read from config file since it's a test // In view 2, node 2 is the leader. 
let (private_key_view2, public_key_view2) = key_pair_for_id(2); @@ -283,32 +270,21 @@ async fn test_consensus_with_vid() { public_key_view2, )); - output.insert( - HotShotEvent::QuorumProposalRecv(proposal_view2.clone(), public_key_view2), - 1, - ); - output.insert(HotShotEvent::DACRecv(created_dac_view2), 1); - output.insert(HotShotEvent::VidDisperseRecv(vid_proposal, pub_key), 1); - if let GeneralConsensusMessage::Vote(vote) = build_vote(&handle, proposal_view2.data).await { output.insert(HotShotEvent::QuorumVoteSend(vote.clone()), 1); } - output.insert( - HotShotEvent::ViewChange(ViewNumber::new(1)), - 2, // 2 occurrences: 1 from `QuorumProposalRecv`, 1 from input - ); - output.insert( - HotShotEvent::ViewChange(ViewNumber::new(2)), - 2, // 2 occurrences: 1 from `QuorumProposalRecv`?, 1 from input - ); + output.insert(HotShotEvent::ViewChange(ViewNumber::new(1)), 1); + output.insert(HotShotEvent::ViewChange(ViewNumber::new(2)), 1); input.push(HotShotEvent::Shutdown); output.insert(HotShotEvent::Shutdown, 1); - let build_fn = |task_runner, event_stream| { - add_consensus_task(task_runner, event_stream, ChannelStream::new(), handle) - }; + let consensus_state = hotshot::tasks::create_consensus_state( + handle.hotshot.inner.output_event_stream.0.clone(), + &handle, + ) + .await; - run_harness(input, output, None, build_fn, false).await; + run_harness(input, output, consensus_state, false).await; } diff --git a/crates/testing/tests/da_task.rs b/crates/testing/tests/da_task.rs index fa0e12eb43..9371b4f913 100644 --- a/crates/testing/tests/da_task.rs +++ b/crates/testing/tests/da_task.rs @@ -1,5 +1,5 @@ use hotshot::{types::SignatureKey, HotShotConsensusApi}; -use hotshot_task_impls::events::HotShotEvent; +use hotshot_task_impls::{da::DATaskState, events::HotShotEvent}; use hotshot_testing::{ block_types::TestTransaction, node_types::{MemoryImpl, TestTypes}, @@ -23,7 +23,6 @@ use std::{collections::HashMap, marker::PhantomData}; )] #[cfg_attr(async_executor_impl = "async-std", async_std::test)] async fn test_da_task() { - use hotshot::tasks::add_da_task; use hotshot_task_impls::harness::run_harness; use hotshot_testing::task_helpers::build_system_handle; use hotshot_types::message::Proposal; @@ -83,11 +82,6 @@ async fn test_da_task() { input.push(HotShotEvent::Shutdown); - output.insert(HotShotEvent::ViewChange(ViewNumber::new(1)), 1); - output.insert( - HotShotEvent::TransactionsSequenced(encoded_transactions, (), ViewNumber::new(2)), - 1, - ); output.insert(HotShotEvent::DAProposalSend(message.clone(), pub_key), 1); let da_vote = DAVote::create_signed_vote( DAData { @@ -100,12 +94,17 @@ async fn test_da_task() { .expect("Failed to sign DAData"); output.insert(HotShotEvent::DAVoteSend(da_vote), 1); - output.insert(HotShotEvent::DAProposalRecv(message, pub_key), 1); - - output.insert(HotShotEvent::ViewChange(ViewNumber::new(2)), 1); - output.insert(HotShotEvent::Shutdown, 1); - - let build_fn = |task_runner, event_stream| add_da_task(task_runner, event_stream, handle); - - run_harness(input, output, None, build_fn, false).await; + let da_state = DATaskState { + api: api.clone(), + consensus: handle.hotshot.get_consensus(), + da_membership: api.inner.memberships.da_membership.clone().into(), + da_network: api.inner.networks.da_network.clone().into(), + quorum_membership: api.inner.memberships.quorum_membership.clone().into(), + cur_view: ViewNumber::new(0), + vote_collector: None.into(), + public_key: *api.public_key(), + private_key: api.private_key().clone(), + id: handle.hotshot.inner.id, + 
}; + run_harness(input, output, da_state, false).await; } diff --git a/crates/testing/tests/network_task.rs b/crates/testing/tests/network_task.rs index a3810fc21b..447805d9c7 100644 --- a/crates/testing/tests/network_task.rs +++ b/crates/testing/tests/network_task.rs @@ -20,7 +20,6 @@ use std::{collections::HashMap, marker::PhantomData}; #[ignore] #[allow(clippy::too_many_lines)] async fn test_network_task() { - use hotshot_task_impls::harness::run_harness; use hotshot_testing::task_helpers::build_system_handle; use hotshot_types::{data::VidDisperse, message::Proposal}; @@ -28,7 +27,7 @@ async fn test_network_task() { async_compatibility_layer::logging::setup_backtrace(); // Build the API for node 2. - let (handle, event_stream) = build_system_handle(2).await; + let (handle, _tx, _rx) = build_system_handle(2).await; let api: HotShotConsensusApi = HotShotConsensusApi { inner: handle.hotshot.inner.clone(), }; @@ -143,10 +142,10 @@ async fn test_network_task() { output.insert(HotShotEvent::VidDisperseRecv(vid_proposal, pub_key), 1); output.insert(HotShotEvent::DAProposalRecv(da_proposal, pub_key), 1); - let build_fn = |task_runner, _| async { task_runner }; + // let build_fn = |task_runner, _| async { task_runner }; // There may be extra outputs not in the expected set, e.g., a second `VidDisperseRecv` if the // VID task runs fast. All event types we want to test should be seen by this point, so waiting // for more events will not help us test more cases for now. Therefore, we set // `allow_extra_output` to `true` for deterministic test result. - run_harness(input, output, Some(event_stream), build_fn, true).await; + // run_harness(input, output, Some(event_stream), build_fn, true).await; } diff --git a/crates/testing/tests/vid_task.rs b/crates/testing/tests/vid_task.rs index 85e1d27fb2..041fbc1b75 100644 --- a/crates/testing/tests/vid_task.rs +++ b/crates/testing/tests/vid_task.rs @@ -1,5 +1,5 @@ -use hotshot::{tasks::add_vid_task, types::SignatureKey, HotShotConsensusApi}; -use hotshot_task_impls::events::HotShotEvent; +use hotshot::{types::SignatureKey, HotShotConsensusApi}; +use hotshot_task_impls::{events::HotShotEvent, vid::VIDTaskState}; use hotshot_testing::{ block_types::TestTransaction, node_types::{MemoryImpl, TestTypes}, @@ -68,7 +68,6 @@ async fn test_vid_task() { _pd: PhantomData, }; - // Every event input is seen on the event stream in the output. 
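These harness tests share one pattern, visible in the setup that follows: a `Vec` of input events fed to the task state, and a `HashMap` from expected output event to the number of times it must appear, with the final `allow_extra_output` flag relaxing the exact-count check. A toy model of that bookkeeping with a placeholder event type; the real checking lives in `hotshot_task_impls::harness::run_harness`:

use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Event {
    ViewChange(u64),
    Shutdown,
}

/// Check one observed event off the expectation map; `false` signals an
/// unexpected event or an over-count.
fn check_off(expected: &mut HashMap<Event, usize>, seen: &Event) -> bool {
    match expected.get_mut(seen) {
        Some(count) if *count > 0 => {
            *count -= 1;
            true
        }
        _ => false,
    }
}

fn main() {
    let mut expected = HashMap::new();
    expected.insert(Event::ViewChange(1), 1);

    assert!(check_off(&mut expected, &Event::ViewChange(1)));
    // A second occurrence exceeds the expected count of 1.
    assert!(!check_off(&mut expected, &Event::ViewChange(1)));
    // Shutdown was never expected at all.
    assert!(!check_off(&mut expected, &Event::Shutdown));
    // All expectations have been consumed.
    assert!(expected.values().all(|n| *n == 0));
}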
let mut input = Vec::new(); let mut output = HashMap::new(); @@ -88,15 +87,9 @@ async fn test_vid_task() { input.push(HotShotEvent::VidDisperseRecv(vid_proposal.clone(), pub_key)); input.push(HotShotEvent::Shutdown); - output.insert(HotShotEvent::ViewChange(ViewNumber::new(1)), 1); - output.insert( - HotShotEvent::TransactionsSequenced(encoded_transactions, (), ViewNumber::new(2)), - 1, - ); - output.insert( HotShotEvent::BlockReady(vid_disperse, ViewNumber::new(2)), - 2, + 1, ); output.insert( @@ -105,14 +98,19 @@ async fn test_vid_task() { ); output.insert( HotShotEvent::VidDisperseSend(vid_proposal.clone(), pub_key), - 2, // 2 occurrences: 1 from `input`, 1 from the DA task + 1, ); - output.insert(HotShotEvent::VidDisperseRecv(vid_proposal, pub_key), 1); - output.insert(HotShotEvent::ViewChange(ViewNumber::new(2)), 1); - output.insert(HotShotEvent::Shutdown, 1); - - let build_fn = |task_runner, event_stream| add_vid_task(task_runner, event_stream, handle); - - run_harness(input, output, None, build_fn, false).await; + let vid_state = VIDTaskState { + api: api.clone(), + consensus: handle.hotshot.get_consensus(), + cur_view: ViewNumber::new(0), + vote_collector: None, + network: api.inner.networks.quorum_network.clone().into(), + membership: api.inner.memberships.vid_membership.clone().into(), + public_key: *api.public_key(), + private_key: api.private_key().clone(), + id: handle.hotshot.inner.id, + }; + run_harness(input, output, vid_state, false).await; } diff --git a/crates/testing/tests/view_sync_task.rs b/crates/testing/tests/view_sync_task.rs index 5dd956a145..a9a7a51a96 100644 --- a/crates/testing/tests/view_sync_task.rs +++ b/crates/testing/tests/view_sync_task.rs @@ -11,10 +11,12 @@ use std::collections::HashMap; )] #[cfg_attr(async_executor_impl = "async-std", async_std::test)] async fn test_view_sync_task() { - use hotshot::tasks::add_view_sync_task; use hotshot_task_impls::harness::run_harness; + use hotshot_task_impls::view_sync::ViewSyncTaskState; use hotshot_testing::task_helpers::build_system_handle; use hotshot_types::simple_vote::ViewSyncPreCommitData; + use hotshot_types::traits::consensus_api::ConsensusApi; + use std::time::Duration; async_compatibility_layer::logging::setup_logging(); async_compatibility_layer::logging::setup_backtrace(); @@ -39,7 +41,6 @@ async fn test_view_sync_task() { tracing::error!("Vote in test is {:?}", vote.clone()); - // Every event input is seen on the event stream in the output. 
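The state structs built directly in these tests (`DATaskState`, `VIDTaskState`, and `ViewSyncTaskState` below) wrap their shared collaborators as `Arc`s via `.clone().into()`; that works because the standard library provides `From<T> for Arc<T>`, so the field's declared type alone drives the conversion. A short illustration with a stand-in type:

use std::sync::Arc;

struct Membership {
    nodes: usize,
}

fn main() {
    let membership = Membership { nodes: 4 };
    // `From<T> for Arc<T>` lets the target type pick the wrapper, which
    // keeps call sites terse, as in `membership.clone().into()`.
    let shared: Arc<Membership> = membership.into();
    assert_eq!(shared.nodes, 4);
}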
let mut input = Vec::new(); let mut output = HashMap::new(); @@ -48,16 +49,25 @@ async fn test_view_sync_task() { input.push(HotShotEvent::Shutdown); - output.insert(HotShotEvent::Timeout(ViewNumber::new(2)), 1); - output.insert(HotShotEvent::Timeout(ViewNumber::new(3)), 1); - output.insert(HotShotEvent::ViewChange(ViewNumber::new(2)), 1); output.insert(HotShotEvent::ViewSyncPreCommitVoteSend(vote.clone()), 1); - output.insert(HotShotEvent::Shutdown, 1); - - let build_fn = - |task_runner, event_stream| add_view_sync_task(task_runner, event_stream, handle); - - run_harness(input, output, None, build_fn, false).await; + let view_sync_state = ViewSyncTaskState { + current_view: ViewNumber::new(0), + next_view: ViewNumber::new(0), + network: api.inner.networks.quorum_network.clone().into(), + membership: api.inner.memberships.view_sync_membership.clone().into(), + public_key: *api.public_key(), + private_key: api.private_key().clone(), + api, + num_timeouts_tracked: 0, + replica_task_map: HashMap::default().into(), + pre_commit_relay_map: HashMap::default().into(), + commit_relay_map: HashMap::default().into(), + finalize_relay_map: HashMap::default().into(), + view_sync_timeout: Duration::new(10, 0), + id: handle.hotshot.inner.id, + last_garbage_collected_view: ViewNumber::new(0), + }; + run_harness(input, output, view_sync_state, false).await; } diff --git a/crates/types/Cargo.toml b/crates/types/Cargo.toml index ffe6b819d4..9b8117bf88 100644 --- a/crates/types/Cargo.toml +++ b/crates/types/Cargo.toml @@ -31,7 +31,6 @@ espresso-systems-common = { workspace = true } ethereum-types = { workspace = true } generic-array = { workspace = true } hotshot-constants = { path = "../constants" } -hotshot-task = { path = "../task", default-features = false } hotshot-utils = { path = "../utils" } jf-plonk = { workspace = true } jf-primitives = { workspace = true, features = ["test-srs"] } diff --git a/crates/types/src/lib.rs b/crates/types/src/lib.rs index 4b1cffac1a..388d8ad9bf 100644 --- a/crates/types/src/lib.rs +++ b/crates/types/src/lib.rs @@ -1,6 +1,6 @@ //! Types and Traits for the `HotShot` consensus module use displaydoc::Display; -use std::{num::NonZeroUsize, time::Duration}; +use std::{future::Future, num::NonZeroUsize, pin::Pin, time::Duration}; use traits::{election::ElectionConfig, signature_key::SignatureKey}; pub mod consensus; pub mod data; @@ -17,6 +17,23 @@ pub mod traits; pub mod utils; pub mod vote; +/// Pinned future that is Send and Sync +pub type BoxSyncFuture<'a, T> = Pin + Send + Sync + 'a>>; + +/// yoinked from futures crate +pub fn assert_future(future: F) -> F +where + F: Future, +{ + future +} +/// yoinked from futures crate, adds sync bound that we need +pub fn boxed_sync<'a, F>(fut: F) -> BoxSyncFuture<'a, F::Output> +where + F: Future + Sized + Send + Sync + 'a, +{ + assert_future::(Box::pin(fut)) +} /// the type of consensus to run. 
Either: /// wait for a signal to start a view, /// or constantly run diff --git a/crates/types/src/traits/network.rs b/crates/types/src/traits/network.rs index c968913e2c..a2c8357a29 100644 --- a/crates/types/src/traits/network.rs +++ b/crates/types/src/traits/network.rs @@ -6,7 +6,6 @@ use async_compatibility_layer::art::async_sleep; #[cfg(async_executor_impl = "async-std")] use async_std::future::TimeoutError; use dyn_clone::DynClone; -use hotshot_task::{boxed_sync, BoxSyncFuture}; use libp2p_networking::network::NetworkNodeHandleError; #[cfg(async_executor_impl = "tokio")] use tokio::time::error::Elapsed as TimeoutError; @@ -16,6 +15,7 @@ use super::{node_implementation::NodeType, signature_key::SignatureKey}; use crate::{ data::ViewNumber, message::{Message, MessagePurpose}, + BoxSyncFuture, }; use async_compatibility_layer::channel::UnboundedSendError; use async_trait::async_trait; @@ -451,7 +451,7 @@ pub trait NetworkReliability: Debug + Sync + std::marker::Send + DynClone + 'sta } } }; - boxed_sync(closure) + Box::pin(closure) } } diff --git a/crates/types/src/vote.rs b/crates/types/src/vote.rs index 808f127562..ba49f4732d 100644 --- a/crates/types/src/vote.rs +++ b/crates/types/src/vote.rs @@ -108,17 +108,17 @@ impl, CERT: Certificate Either { + pub fn accumulate(&mut self, vote: &VOTE, membership: &TYPES::Membership) -> Either<(), CERT> { let key = vote.get_signing_key(); let vote_commitment = vote.get_data_commitment(); if !key.validate(&vote.get_signature(), vote_commitment.as_ref()) { error!("Invalid vote! Vote Data {:?}", vote.get_data()); - return Either::Left(self); + return Either::Left(()); } let Some(stake_table_entry) = membership.get_stake(&key) else { - return Either::Left(self); + return Either::Left(()); }; let stake_table = membership.get_committee_qc_stake_table(); let vote_node_id = stake_table @@ -136,7 +136,7 @@ impl, CERT: Certificate, CERT: Certificate, CERT: Certificate