Skip to content

Commit

Permalink
[dag] Handle fetch response in dag handler
Browse files Browse the repository at this point in the history
  • Loading branch information
ibalajiarun committed Aug 9, 2023
1 parent da2fad5 commit 61e587c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 45 deletions.
80 changes: 53 additions & 27 deletions consensus/src/dag/dag_fetcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,60 @@ use aptos_infallible::RwLock;
use aptos_logger::error;
use aptos_time_service::TimeService;
use aptos_types::epoch_state::EpochState;
use futures::{stream::FuturesUnordered, StreamExt};
use tokio::sync::{oneshot, mpsc::{Sender, Receiver}};
use std::{collections::HashMap, sync::Arc, time::Duration};
use futures::{stream::FuturesUnordered, Stream, StreamExt};
use std::{
collections::HashMap,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use thiserror::Error as ThisError;
use tokio::sync::{
mpsc::{Receiver, Sender},
oneshot,
};

pub struct FetchRequester {
request_tx: Sender<LocalFetchRequest>,
node_rx_futures: FuturesUnordered<oneshot::Receiver<Node>>,
certified_node_rx_futures: FuturesUnordered<oneshot::Receiver<CertifiedNode>>,
pub struct FetchWaiter<T> {
rx: Receiver<oneshot::Receiver<T>>,
futures: Pin<Box<FuturesUnordered<oneshot::Receiver<T>>>>,
}

impl FetchRequester {
pub fn new(request_tx: Sender<LocalFetchRequest>) -> Self {
impl<T> FetchWaiter<T> {
fn new(rx: Receiver<oneshot::Receiver<T>>) -> Self {
Self {
request_tx,
node_rx_futures: FuturesUnordered::new(),
certified_node_rx_futures: FuturesUnordered::new(),
rx,
futures: Box::pin(FuturesUnordered::new()),
}
}
}

impl<T> Stream for FetchWaiter<T> {
type Item = Result<T, oneshot::error::RecvError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if let Poll::Ready(Some(rx)) = self.rx.poll_recv(cx) {
self.futures.push(rx);
}

self.futures.as_mut().poll_next(cx)
}
}

pub struct FetchRequester {
request_tx: Sender<LocalFetchRequest>,
node_tx: Sender<oneshot::Receiver<Node>>,
certified_node_tx: Sender<oneshot::Receiver<CertifiedNode>>,
}

impl FetchRequester {
pub fn request_for_node(&self, node: Node) -> anyhow::Result<()> {
let (res_tx, res_rx) = oneshot::channel();
let fetch_req = LocalFetchRequest::Node(node, res_tx);
self.request_tx
.try_send(fetch_req)
.map_err(|e| anyhow::anyhow!("unable to send fetch request to channel: {}", e))?;
self.node_rx_futures.push(res_rx);
self.node_tx.try_send(res_rx)?;
Ok(())
}

Expand All @@ -49,19 +76,9 @@ impl FetchRequester {
self.request_tx
.try_send(fetch_req)
.map_err(|e| anyhow::anyhow!("unable to send fetch request to channel: {}", e))?;
self.certified_node_rx_futures.push(res_rx);
self.certified_node_tx.try_send(res_rx)?;
Ok(())
}

pub async fn next_ready_node(&mut self) -> Option<Result<Node, oneshot::error::RecvError>> {
self.node_rx_futures.next().await
}

pub async fn next_ready_certified_node(
&mut self,
) -> Option<Result<CertifiedNode, oneshot::error::RecvError>> {
self.certified_node_rx_futures.next().await
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -113,8 +130,15 @@ impl DagFetcher {
network: Arc<dyn DAGNetworkSender>,
dag: Arc<RwLock<Dag>>,
time_service: TimeService,
) -> (Self, FetchRequester) {
) -> (
Self,
FetchRequester,
FetchWaiter<Node>,
FetchWaiter<CertifiedNode>,
) {
let (request_tx, request_rx) = tokio::sync::mpsc::channel(16);
let (node_tx, node_rx) = tokio::sync::mpsc::channel(100);
let (certified_node_tx, certified_node_rx) = tokio::sync::mpsc::channel(100);
(
Self {
epoch_state,
Expand All @@ -125,9 +149,11 @@ impl DagFetcher {
},
FetchRequester {
request_tx,
node_rx_futures: FuturesUnordered::new(),
certified_node_rx_futures: FuturesUnordered::new(),
node_tx,
certified_node_tx,
},
FetchWaiter::new(node_rx),
FetchWaiter::new(certified_node_rx),
)
}

Expand Down
44 changes: 33 additions & 11 deletions consensus/src/dag/dag_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

use super::{
dag_driver::DagDriver,
dag_fetcher::{DagFetcher, FetchRequestHandler},
dag_fetcher::{DagFetcher, FetchRequestHandler, FetchRequester, FetchWaiter},
dag_network::DAGNetworkSender,
order_rule::OrderRule,
storage::DAGStorage,
types::TDAGMessage,
CertifiedNode, Node,
};
use crate::{
dag::{
Expand All @@ -28,6 +29,7 @@ use aptos_types::{epoch_state::EpochState, validator_signer::ValidatorSigner};
use bytes::Bytes;
use futures::StreamExt;
use std::sync::Arc;
use tokio::select;
use tokio_retry::strategy::ExponentialBackoff;

struct NetworkHandler {
Expand All @@ -36,6 +38,8 @@ struct NetworkHandler {
dag_driver: DagDriver,
fetch_receiver: FetchRequestHandler,
epoch_state: Arc<EpochState>,
node_fetch_waiter: FetchWaiter<Node>,
certified_node_fetch_waiter: FetchWaiter<CertifiedNode>,
}

impl NetworkHandler {
Expand All @@ -58,12 +62,14 @@ impl NetworkHandler {
ExponentialBackoff::from_millis(10),
aptos_time_service.clone(),
));
let (dag_fetcher, fetch_requester) = DagFetcher::new(
epoch_state.clone(),
dag_network_sender,
dag.clone(),
aptos_time_service,
);
let (dag_fetcher, fetch_requester, node_fetch_waiter, certified_node_fetch_waiter) =
DagFetcher::new(
epoch_state.clone(),
dag_network_sender,
dag.clone(),
aptos_time_service,
);
let fetch_requester = Arc::new(fetch_requester);
Self {
dag_rpc_rx,
node_receiver: NodeBroadcastHandler::new(
Expand All @@ -82,20 +88,36 @@ impl NetworkHandler {
time_service,
storage,
order_rule,
Arc::new(fetch_requester),
fetch_requester,
),
epoch_state: epoch_state.clone(),
fetch_receiver: FetchRequestHandler::new(dag, epoch_state),
node_fetch_waiter,
certified_node_fetch_waiter,
}
}

async fn start(mut self) {
self.dag_driver.try_enter_new_round();

// TODO(ibalajiarun): clean up Reliable Broadcast storage periodically.
while let Some(msg) = self.dag_rpc_rx.next().await {
if let Err(e) = self.process_rpc(msg).await {
warn!(error = ?e, "error processing rpc");
loop {
select! {
Some(msg) = self.dag_rpc_rx.next() => {
if let Err(e) = self.process_rpc(msg).await {
warn!(error = ?e, "error processing rpc");
}
},
Some(res) = self.node_fetch_waiter.next() => {
if let Err(e) = res.map_err(|e| anyhow::anyhow!("recv error: {}", e)).and_then(|node| self.node_receiver.process(node)) {
warn!(error = ?e, "error processing node fetch notification");
}
},
Some(res) = self.certified_node_fetch_waiter.next() => {
if let Err(e) = res.map_err(|e| anyhow::anyhow!("recv error: {}", e)).and_then(|certified_node| self.dag_driver.process(certified_node)) {
warn!(error = ?e, "error processing certified node fetch notification");
}
}
}
}
}
Expand Down
31 changes: 24 additions & 7 deletions consensus/src/dag/tests/dag_driver_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@

use crate::{
dag::{
anchor_election::RoundRobinAnchorElection,
dag_driver::{DagDriver, DagDriverError},
dag_fetcher::{DagFetcher, FetchRequester},
dag_network::{DAGNetworkSender, RpcWithFallback},
dag_store::Dag,
order_rule::OrderRule,
tests::{dag_test::MockStorage, helpers::new_certified_node},
types::{CertifiedAck, DAGMessage},
RpcHandler, order_rule::OrderRule,
anchor_election::RoundRobinAnchorElection, dag_fetcher::FetchRequester,
RpcHandler,
},
test_utils::MockPayloadManager,
util::mock_time_service::SimulatedTimeService,
};
use aptos_consensus_types::common::Author;
use aptos_infallible::RwLock;
use aptos_reliable_broadcast::{RBNetworkSender, ReliableBroadcast};
use aptos_types::{epoch_state::EpochState, validator_verifier::random_validator_verifier, ledger_info::LedgerInfo};
use aptos_types::{
epoch_state::EpochState, ledger_info::LedgerInfo, validator_verifier::random_validator_verifier,
};
use async_trait::async_trait;
use claims::{assert_ok, assert_ok_eq};
use std::{sync::Arc, time::Duration};
use tokio::sync::mpsc::Sender;
use tokio_retry::strategy::ExponentialBackoff;

struct MockNetworkSender {}
Expand Down Expand Up @@ -72,19 +77,31 @@ fn test_certified_node_handler() {

let zeroth_round_node = new_certified_node(0, signers[0].author(), vec![]);

let network_sender = Arc::new(MockNetworkSender {});
let rb = Arc::new(ReliableBroadcast::new(
signers.iter().map(|s| s.author()).collect(),
Arc::new(MockNetworkSender {}),
network_sender.clone(),
ExponentialBackoff::from_millis(10),
aptos_time_service::TimeService::mock(),
));
let time_service = Arc::new(SimulatedTimeService::new());
let (ordered_nodes_sender, _) = futures_channel::mpsc::unbounded();
let validators = signers.iter().map(|vs| vs.author()).collect();
let order_rule = OrderRule::new(epoch_state.clone(), LedgerInfo::mock_genesis(None), dag.clone(), Box::new(RoundRobinAnchorElection::new(validators)), ordered_nodes_sender);
let order_rule = OrderRule::new(
epoch_state.clone(),
LedgerInfo::mock_genesis(None),
dag.clone(),
Box::new(RoundRobinAnchorElection::new(validators)),
ordered_nodes_sender,
);

let (request_tx, _) = tokio::sync::mpsc::channel(10);
let fetch_requester = Arc::new(FetchRequester::new(request_tx));
let (_, fetch_requester, _, _) = DagFetcher::new(
epoch_state.clone(),
network_sender,
dag.clone(),
aptos_time_service::TimeService::mock(),
);
let fetch_requester = Arc::new(fetch_requester);

let mut driver = DagDriver::new(
signers[0].author(),
Expand Down

0 comments on commit 61e587c

Please sign in to comment.