diff --git a/consensus/src/dag/bootstrap.rs b/consensus/src/dag/bootstrap.rs index 4bf79c1ee0260..4fd9be4a1b999 100644 --- a/consensus/src/dag/bootstrap.rs +++ b/consensus/src/dag/bootstrap.rs @@ -3,7 +3,7 @@ use super::{ anchor_election::RoundRobinAnchorElection, dag_driver::DagDriver, - dag_fetcher::{DagFetcher, FetchRequestHandler}, + dag_fetcher::{DagFetcherService, FetchRequestHandler}, dag_handler::NetworkHandler, dag_network::TDAGNetworkSender, dag_store::Dag, @@ -78,7 +78,7 @@ pub fn bootstrap_dag( ); let (dag_fetcher, fetch_requester, node_fetch_waiter, certified_node_fetch_waiter) = - DagFetcher::new( + DagFetcherService::new( epoch_state.clone(), dag_network_sender, dag.clone(), diff --git a/consensus/src/dag/dag_fetcher.rs b/consensus/src/dag/dag_fetcher.rs index 0ddddf30f8a6d..568b21de00202 100644 --- a/consensus/src/dag/dag_fetcher.rs +++ b/consensus/src/dag/dag_fetcher.rs @@ -7,12 +7,13 @@ use crate::dag::{ dag_store::Dag, types::{CertifiedNode, FetchResponse, Node, RemoteFetchRequest}, }; -use anyhow::ensure; +use anyhow::{anyhow, ensure}; use aptos_consensus_types::common::Author; use aptos_infallible::RwLock; use aptos_logger::error; use aptos_time_service::TimeService; use aptos_types::epoch_state::EpochState; +use async_trait::async_trait; use futures::{stream::FuturesUnordered, Stream, StreamExt}; use std::{ collections::HashMap, @@ -124,15 +125,14 @@ impl LocalFetchRequest { } } -pub struct DagFetcher { - epoch_state: Arc, - network: Arc, +pub struct DagFetcherService { + inner: DagFetcher, dag: Arc>, request_rx: Receiver, - time_service: TimeService, + ordered_authors: Vec, } -impl DagFetcher { +impl DagFetcherService { pub fn new( epoch_state: Arc, network: Arc, @@ -147,13 +147,13 @@ impl DagFetcher { let (request_tx, request_rx) = tokio::sync::mpsc::channel(16); let (node_tx, node_rx) = tokio::sync::mpsc::channel(100); let (certified_node_tx, certified_node_rx) = tokio::sync::mpsc::channel(100); + let ordered_authors = epoch_state.verifier.get_ordered_account_addresses(); ( Self { - epoch_state, - network, + inner: DagFetcher::new(epoch_state, network, time_service), dag, request_rx, - time_service, + ordered_authors, }, FetchRequester { request_tx, @@ -167,75 +167,126 @@ impl DagFetcher { pub async fn start(mut self) { while let Some(local_request) = self.request_rx.recv().await { - let responders = local_request - .responders(&self.epoch_state.verifier.get_ordered_account_addresses()); - let remote_request = { - let dag_reader = self.dag.read(); - - let missing_parents: Vec = dag_reader - .filter_missing(local_request.node().parents_metadata()) - .cloned() - .collect(); - - if missing_parents.is_empty() { - local_request.notify(); - continue; - } - - let target = local_request.node(); - RemoteFetchRequest::new( - target.metadata().epoch(), - missing_parents, - dag_reader.bitmask(local_request.node().round()), + match self + .fetch( + local_request.node(), + local_request.responders(&self.ordered_authors), ) - }; - - let mut rpc = RpcWithFallback::new( - responders, - remote_request.clone().into(), - Duration::from_millis(500), - Duration::from_secs(1), - self.network.clone(), - self.time_service.clone(), - ); - while let Some(response) = rpc.next().await { - if let Ok(response) = - response - .and_then(FetchResponse::try_from) - .and_then(|response| { - response.verify(&remote_request, &self.epoch_state.verifier) - }) + .await + { + Ok(_) => local_request.notify(), + Err(err) => error!("unable to complete fetch successfully: {}", err), + } + } + } + + pub(super) async fn fetch( + &mut self, + node: &Node, + responders: Vec, + ) -> anyhow::Result<()> { + let remote_request = { + let dag_reader = self.dag.read(); + + let missing_parents: Vec = dag_reader + .filter_missing(node.parents_metadata()) + .cloned() + .collect(); + + if missing_parents.is_empty() { + return Ok(()); + } + + RemoteFetchRequest::new( + node.metadata().epoch(), + missing_parents, + dag_reader.bitmask(node.round()), + ) + }; + self.inner + .fetch(remote_request, responders, self.dag.clone()) + .await + } +} + +#[async_trait] +pub trait TDagFetcher: Send { + async fn fetch( + &self, + remote_request: RemoteFetchRequest, + responders: Vec, + dag: Arc>, + ) -> anyhow::Result<()>; +} + +pub(crate) struct DagFetcher { + network: Arc, + time_service: TimeService, + epoch_state: Arc, +} + +impl DagFetcher { + pub(crate) fn new( + epoch_state: Arc, + network: Arc, + time_service: TimeService, + ) -> Self { + Self { + network, + time_service, + epoch_state, + } + } +} + +#[async_trait] +impl TDagFetcher for DagFetcher { + async fn fetch( + &self, + remote_request: RemoteFetchRequest, + responders: Vec, + dag: Arc>, + ) -> anyhow::Result<()> { + let mut rpc = RpcWithFallback::new( + responders, + remote_request.clone().into(), + Duration::from_millis(500), + Duration::from_secs(1), + self.network.clone(), + self.time_service.clone(), + ); + + // TODO retry + while let Some(response) = rpc.next().await { + if let Ok(response) = response + .and_then(FetchResponse::try_from) + .and_then(|response| response.verify(&remote_request, &self.epoch_state.verifier)) + { + let certified_nodes = response.certified_nodes(); + // TODO: support chunk response or fallback to state sync { - let certified_nodes = response.certified_nodes(); - // TODO: support chunk response or fallback to state sync - { - let mut dag_writer = self.dag.write(); - for node in certified_nodes { - if let Err(e) = dag_writer.add_node(node) { - error!("Failed to add node {}", e); - } + let mut dag_writer = dag.write(); + for node in certified_nodes { + if let Err(e) = dag_writer.add_node(node) { + error!("Failed to add node {}", e); } } + } - if self - .dag - .read() - .all_exists(local_request.node().parents_metadata()) - { - local_request.notify(); - break; - } + if dag.read().all_exists(remote_request.targets().iter()) { + return Ok(()); } } // TODO retry } + Err(anyhow!("fetch failed")) } } #[derive(Debug, ThisError)] pub enum FetchRequestHandleError { - #[error("parents are missing")] - ParentsMissing, + #[error("target nodes are missing")] + TargetsMissing, } pub struct FetchRequestHandler { @@ -266,7 +317,7 @@ impl RpcHandler for FetchRequestHandler { // to satisfy this request. ensure!( dag_reader.all_exists(message.targets().iter()), - FetchRequestHandleError::ParentsMissing + FetchRequestHandleError::TargetsMissing ); let certified_nodes: Vec<_> = dag_reader diff --git a/consensus/src/dag/tests/dag_driver_tests.rs b/consensus/src/dag/tests/dag_driver_tests.rs index 2ee3f995838b0..1e2f519c00978 100644 --- a/consensus/src/dag/tests/dag_driver_tests.rs +++ b/consensus/src/dag/tests/dag_driver_tests.rs @@ -4,7 +4,7 @@ use crate::{ dag::{ anchor_election::RoundRobinAnchorElection, dag_driver::{DagDriver, DagDriverError}, - dag_fetcher::DagFetcher, + dag_fetcher::DagFetcherService, dag_network::{RpcWithFallback, TDAGNetworkSender}, dag_store::Dag, order_rule::OrderRule, @@ -101,7 +101,7 @@ async fn test_certified_node_handler() { storage.clone(), ); - let (_, fetch_requester, _, _) = DagFetcher::new( + let (_, fetch_requester, _, _) = DagFetcherService::new( epoch_state.clone(), network_sender, dag.clone(),