diff --git a/consensus/consensus-types/src/proof_of_store.rs b/consensus/consensus-types/src/proof_of_store.rs index 3fd0972658ede..dd9f18096f333 100644 --- a/consensus/consensus-types/src/proof_of_store.rs +++ b/consensus/consensus-types/src/proof_of_store.rs @@ -310,7 +310,7 @@ impl ProofOfStore { pub fn shuffled_signers(&self, validator: &ValidatorVerifier) -> Vec { let mut ret: Vec = self .multi_signature - .get_voter_addresses(&validator.get_ordered_account_addresses()); + .get_signers_addresses(&validator.get_ordered_account_addresses()); ret.shuffle(&mut thread_rng()); ret } diff --git a/consensus/consensus-types/src/timeout_2chain.rs b/consensus/consensus-types/src/timeout_2chain.rs index a0bbe49fffde5..9fc0ced6f3c50 100644 --- a/consensus/consensus-types/src/timeout_2chain.rs +++ b/consensus/consensus-types/src/timeout_2chain.rs @@ -368,7 +368,7 @@ impl AggregateSignatureWithRounds { &self, ordered_validator_addresses: &[AccountAddress], ) -> Vec { - self.sig.get_voter_addresses(ordered_validator_addresses) + self.sig.get_signers_addresses(ordered_validator_addresses) } pub fn get_voters_and_rounds( @@ -376,7 +376,7 @@ impl AggregateSignatureWithRounds { ordered_validator_addresses: &[AccountAddress], ) -> Vec<(AccountAddress, Round)> { self.sig - .get_voter_addresses(ordered_validator_addresses) + .get_signers_addresses(ordered_validator_addresses) .into_iter() .zip(self.rounds.clone()) .collect() diff --git a/consensus/src/dag/dag_fetcher.rs b/consensus/src/dag/dag_fetcher.rs new file mode 100644 index 0000000000000..194b4df2e7869 --- /dev/null +++ b/consensus/src/dag/dag_fetcher.rs @@ -0,0 +1,134 @@ +// Copyright © Aptos Foundation +// SPDX-License-Identifier: Apache-2.0 + +use crate::dag::{ + dag_store::Dag, + types::{CertifiedNode, DAGMessage, DAGNetworkSender, Node, NodeMetadata}, +}; +use aptos_consensus_types::common::{Author, Round}; +use aptos_infallible::RwLock; +use aptos_logger::error; +use aptos_types::{epoch_state::EpochState, validator_verifier::ValidatorVerifier}; +use serde::{Deserialize, Serialize}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::{ + mpsc::{Receiver, Sender}, + oneshot, +}; + +/// Represents a request to fetch missing dependencies for `target`, `start_round` represents +/// the first round we care about in the DAG, `exists_bitmask` is a two dimensional bitmask represents +/// if a node exist at [start_round + index][validator_index]. +#[derive(Serialize, Deserialize, Clone)] +struct FetchRequest { + target: NodeMetadata, + start_round: Round, + exists_bitmask: Vec>, +} + +/// Represents a response to FetchRequest, `certified_nodes` are indexed by [round][validator_index] +/// It should fill in gaps from the `exists_bitmask` according to the parents from the `target_digest` node. +#[derive(Serialize, Deserialize, Clone)] +struct FetchResponse { + epoch: u64, + certifies_nodes: Vec>, +} + +impl FetchResponse { + pub fn verify( + self, + _request: &FetchRequest, + _validator_verifier: &ValidatorVerifier, + ) -> anyhow::Result { + todo!("verification"); + } +} + +impl DAGMessage for FetchRequest { + fn epoch(&self) -> u64 { + self.target.epoch() + } +} + +impl DAGMessage for FetchResponse { + fn epoch(&self) -> u64 { + self.epoch + } +} + +enum FetchCallback { + Node(Node, oneshot::Sender), + CertifiedNode(CertifiedNode, oneshot::Sender), +} + +impl FetchCallback { + pub fn responders(&self, validators: &[Author]) -> Vec { + match self { + FetchCallback::Node(node, _) => vec![*node.author()], + FetchCallback::CertifiedNode(node, _) => node.certificate().signers(validators), + } + } + + pub fn notify(self) { + if match self { + FetchCallback::Node(node, sender) => sender.send(node).map_err(|_| ()), + FetchCallback::CertifiedNode(node, sender) => sender.send(node).map_err(|_| ()), + } + .is_err() + { + error!("Failed to send node back"); + } + } +} + +struct DagFetcher { + epoch_state: Arc, + network: Arc, + dag: Arc>, + request_rx: Receiver<(FetchRequest, FetchCallback)>, +} + +impl DagFetcher { + pub fn new( + epoch_state: Arc, + network: Arc, + dag: Arc>, + ) -> (Self, Sender<(FetchRequest, FetchCallback)>) { + let (request_tx, request_rx) = tokio::sync::mpsc::channel(16); + ( + Self { + epoch_state, + network, + dag, + request_rx, + }, + request_tx, + ) + } + + pub async fn start(mut self) { + while let Some((request, callback)) = self.request_rx.recv().await { + let responders = + callback.responders(&self.epoch_state.verifier.get_ordered_account_addresses()); + let network_request = request.clone().into_network_message(); + if let Ok(response) = self + .network + .send_rpc_with_fallbacks(responders, network_request, Duration::from_secs(1)) + .await + .and_then(FetchResponse::from_network_message) + .and_then(|response| response.verify(&request, &self.epoch_state.verifier)) + { + // TODO: support chunk response or fallback to state sync + let mut dag_writer = self.dag.write(); + for rounds in response.certifies_nodes { + for node in rounds { + if let Err(e) = dag_writer.add_node(node) { + error!("Failed to add node {}", e); + } + } + } + callback.notify(); + } + } + } +} diff --git a/consensus/src/dag/mod.rs b/consensus/src/dag/mod.rs index eced7bc2c966c..39ece5a63c42e 100644 --- a/consensus/src/dag/mod.rs +++ b/consensus/src/dag/mod.rs @@ -3,6 +3,7 @@ #![allow(dead_code)] mod dag_driver; +mod dag_fetcher; mod dag_store; mod reliable_broadcast; #[cfg(test)] diff --git a/consensus/src/dag/reliable_broadcast.rs b/consensus/src/dag/reliable_broadcast.rs index 2d3367ac81beb..7f4c179b3d3ff 100644 --- a/consensus/src/dag/reliable_broadcast.rs +++ b/consensus/src/dag/reliable_broadcast.rs @@ -1,9 +1,8 @@ // Copyright © Aptos Foundation // SPDX-License-Identifier: Apache-2.0 -use crate::{dag::types::DAGMessage, network_interface::ConsensusMsg}; +use crate::dag::types::{DAGMessage, DAGNetworkSender}; use aptos_consensus_types::common::Author; -use async_trait::async_trait; use futures::{stream::FuturesUnordered, StreamExt}; use std::{future::Future, sync::Arc, time::Duration}; @@ -15,16 +14,6 @@ pub trait BroadcastStatus { fn add(&mut self, peer: Author, ack: Self::Ack) -> anyhow::Result>; } -#[async_trait] -pub trait DAGNetworkSender: Send + Sync { - async fn send_rpc( - &self, - receiver: Author, - message: ConsensusMsg, - timeout: Duration, - ) -> anyhow::Result; -} - pub struct ReliableBroadcast { validators: Vec, network_sender: Arc, diff --git a/consensus/src/dag/tests/reliable_broadcast_tests.rs b/consensus/src/dag/tests/reliable_broadcast_tests.rs index d328e2c7f1037..fba805cafb27d 100644 --- a/consensus/src/dag/tests/reliable_broadcast_tests.rs +++ b/consensus/src/dag/tests/reliable_broadcast_tests.rs @@ -3,8 +3,8 @@ use crate::{ dag::{ - reliable_broadcast::{BroadcastStatus, DAGNetworkSender, ReliableBroadcast}, - types::DAGMessage, + reliable_broadcast::{BroadcastStatus, ReliableBroadcast}, + types::{DAGMessage, DAGNetworkSender}, }, network_interface::ConsensusMsg, }; @@ -101,6 +101,15 @@ impl DAGNetworkSender for TestDAGSender { .insert(receiver, TestMessage::from_network_message(message)?); Ok(TestAck.into_network_message()) } + + async fn send_rpc_with_fallbacks( + &self, + _responders: Vec, + _message: ConsensusMsg, + _timeout: Duration, + ) -> anyhow::Result { + unimplemented!(); + } } #[tokio::test] diff --git a/consensus/src/dag/types.rs b/consensus/src/dag/types.rs index 6e88768abb58a..03a79bcf9ccbe 100644 --- a/consensus/src/dag/types.rs +++ b/consensus/src/dag/types.rs @@ -14,8 +14,9 @@ use aptos_types::{ aggregate_signature::{AggregateSignature, PartialSignatures}, epoch_state::EpochState, }; +use async_trait::async_trait; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::{collections::HashSet, ops::Deref, sync::Arc}; +use std::{collections::HashSet, ops::Deref, sync::Arc, time::Duration}; pub trait DAGMessage: Sized + Clone + Serialize + DeserializeOwned { fn epoch(&self) -> u64; @@ -35,6 +36,25 @@ pub trait DAGMessage: Sized + Clone + Serialize + DeserializeOwned { } } +#[async_trait] +pub trait DAGNetworkSender: Send + Sync { + async fn send_rpc( + &self, + receiver: Author, + message: ConsensusMsg, + timeout: Duration, + ) -> anyhow::Result; + + /// Given a list of potential responders, sending rpc to get response from any of them and could + /// fallback to more in case of failures. + async fn send_rpc_with_fallbacks( + &self, + responders: Vec, + message: ConsensusMsg, + timeout: Duration, + ) -> anyhow::Result; +} + /// Represents the metadata about the node, without payload and parents from Node #[derive(Clone, Serialize, Deserialize)] pub struct NodeMetadata { @@ -57,6 +77,10 @@ impl NodeMetadata { pub fn author(&self) -> &Author { &self.author } + + pub fn epoch(&self) -> u64 { + self.epoch + } } /// Node representation in the DAG, parents contain 2f+1 strong links (links to previous round) @@ -143,6 +167,10 @@ impl Node { pub fn parents(&self) -> &[NodeMetadata] { &self.parents } + + pub fn author(&self) -> &Author { + self.metadata.author() + } } /// Quorum signatures over the node digest @@ -161,9 +189,13 @@ impl NodeCertificate { signatures, } } + + pub fn signers(&self, validators: &[Author]) -> Vec { + self.signatures.get_signers_addresses(validators) + } } -#[derive(Clone)] +#[derive(Serialize, Deserialize, Clone)] pub struct CertifiedNode { node: Node, certificate: NodeCertificate, @@ -173,6 +205,10 @@ impl CertifiedNode { pub fn new(node: Node, certificate: NodeCertificate) -> Self { Self { node, certificate } } + + pub fn certificate(&self) -> &NodeCertificate { + &self.certificate + } } impl Deref for CertifiedNode { diff --git a/types/src/aggregate_signature.rs b/types/src/aggregate_signature.rs index c85e642772678..0abe6f3dd9f38 100644 --- a/types/src/aggregate_signature.rs +++ b/types/src/aggregate_signature.rs @@ -36,11 +36,11 @@ impl AggregateSignature { } } - pub fn get_voters_bitvec(&self) -> &BitVec { + pub fn get_signers_bitvec(&self) -> &BitVec { &self.validator_bitmask } - pub fn get_voter_addresses( + pub fn get_signers_addresses( &self, validator_addresses: &[AccountAddress], ) -> Vec { diff --git a/types/src/ledger_info.rs b/types/src/ledger_info.rs index 3e23a80a2a97d..74a30f97e6d7c 100644 --- a/types/src/ledger_info.rs +++ b/types/src/ledger_info.rs @@ -258,7 +258,7 @@ impl LedgerInfoWithV0 { } pub fn get_voters(&self, validator_addresses: &[AccountAddress]) -> Vec { - self.signatures.get_voter_addresses(validator_addresses) + self.signatures.get_signers_addresses(validator_addresses) } pub fn get_num_voters(&self) -> usize { @@ -266,7 +266,7 @@ impl LedgerInfoWithV0 { } pub fn get_voters_bitvec(&self) -> &BitVec { - self.signatures.get_voters_bitvec() + self.signatures.get_signers_bitvec() } pub fn verify_signatures( diff --git a/types/src/validator_verifier.rs b/types/src/validator_verifier.rs index 908e06043540f..38c3dd3ed0a79 100644 --- a/types/src/validator_verifier.rs +++ b/types/src/validator_verifier.rs @@ -231,10 +231,10 @@ impl ValidatorVerifier { multi_signature: &AggregateSignature, ) -> std::result::Result<(), VerifyError> { // Verify the number of signature is not greater than expected. - Self::check_num_of_voters(self.len() as u16, multi_signature.get_voters_bitvec())?; + Self::check_num_of_voters(self.len() as u16, multi_signature.get_signers_bitvec())?; let mut pub_keys = vec![]; let mut authors = vec![]; - for index in multi_signature.get_voters_bitvec().iter_ones() { + for index in multi_signature.get_signers_bitvec().iter_ones() { let validator = self .validator_infos .get(index) @@ -274,10 +274,10 @@ impl ValidatorVerifier { aggregated_signature: &AggregateSignature, ) -> std::result::Result<(), VerifyError> { // Verify the number of signature is not greater than expected. - Self::check_num_of_voters(self.len() as u16, aggregated_signature.get_voters_bitvec())?; + Self::check_num_of_voters(self.len() as u16, aggregated_signature.get_signers_bitvec())?; let mut pub_keys = vec![]; let mut authors = vec![]; - for index in aggregated_signature.get_voters_bitvec().iter_ones() { + for index in aggregated_signature.get_signers_bitvec().iter_ones() { let validator = self .validator_infos .get(index) @@ -679,7 +679,7 @@ mod tests { .aggregate_signatures(&partial_signature) .unwrap(); assert_eq!( - aggregated_signature.get_voters_bitvec().num_buckets(), + aggregated_signature.get_signers_bitvec().num_buckets(), BitVec::required_buckets(validator_verifier.validator_infos.len() as u16) ); // Check against signatures == N; this will pass. @@ -709,7 +709,7 @@ mod tests { .aggregate_signatures(&partial_signature) .unwrap(); assert_eq!( - aggregated_signature.get_voters_bitvec().num_buckets(), + aggregated_signature.get_signers_bitvec().num_buckets(), BitVec::required_buckets(validator_verifier.validator_infos.len() as u16) ); assert_eq!( @@ -737,7 +737,7 @@ mod tests { .aggregate_signatures(&partial_signature) .unwrap(); assert_eq!( - aggregated_signature.get_voters_bitvec().num_buckets(), + aggregated_signature.get_signers_bitvec().num_buckets(), BitVec::required_buckets(validator_verifier.validator_infos.len() as u16) ); assert_eq!(