Skip to content

Commit

Permalink
Feature: Add support for external requests to be executed inside of R…
Browse files Browse the repository at this point in the history
…aft core loop

The new feature is also exposed via `RaftRouter` test fixture and tested in the initialization test (in addition to the original checks).
  • Loading branch information
schreter committed Mar 5, 2022
1 parent 0951160 commit 80f8913
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 11 deletions.
24 changes: 18 additions & 6 deletions openraft/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ pub struct RaftCore<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<
tx_compaction: mpsc::Sender<SnapshotUpdate<C>>,
rx_compaction: mpsc::Receiver<SnapshotUpdate<C>>,

rx_api: mpsc::UnboundedReceiver<(RaftMsg<C>, Span)>,
rx_api: mpsc::UnboundedReceiver<(RaftMsg<C, N, S>, Span)>,

tx_metrics: watch::Sender<RaftMetrics<C>>,

Expand All @@ -208,7 +208,7 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> RaftCore<C,
config: Arc<Config>,
network: N,
storage: S,
rx_api: mpsc::UnboundedReceiver<(RaftMsg<C>, Span)>,
rx_api: mpsc::UnboundedReceiver<(RaftMsg<C, N, S>, Span)>,
tx_metrics: watch::Sender<RaftMetrics<C>>,
rx_shutdown: oneshot::Receiver<()>,
) -> JoinHandle<Result<(), Fatal<C>>> {
Expand Down Expand Up @@ -849,7 +849,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> LeaderS
}

#[tracing::instrument(level = "debug", skip(self, msg), fields(state = "leader", id=display(self.core.id)))]
pub async fn handle_msg(&mut self, msg: RaftMsg<C>) -> Result<(), Fatal<C>> {
pub async fn handle_msg(&mut self, msg: RaftMsg<C, N, S>) -> Result<(), Fatal<C>> {
tracing::debug!("recv from rx_api: {}", msg.summary());

match msg {
Expand Down Expand Up @@ -885,6 +885,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> LeaderS
} => {
self.change_membership(members, blocking, turn_to_learner, tx).await?;
}
RaftMsg::ExternalRequest { req } => {
req(State::Leader, &mut self.core.storage, &mut self.core.network);
}
};

Ok(())
Expand Down Expand Up @@ -1016,7 +1019,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Candida
}

#[tracing::instrument(level = "debug", skip(self, msg), fields(state = "candidate", id=display(self.core.id)))]
pub async fn handle_msg(&mut self, msg: RaftMsg<C>) -> Result<(), Fatal<C>> {
pub async fn handle_msg(&mut self, msg: RaftMsg<C, N, S>) -> Result<(), Fatal<C>> {
tracing::debug!("recv from rx_api: {}", msg.summary());
match msg {
RaftMsg::AppendEntries { rpc, tx } => {
Expand All @@ -1043,6 +1046,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Candida
RaftMsg::ChangeMembership { tx, .. } => {
self.core.reject_with_forward_to_leader(tx);
}
RaftMsg::ExternalRequest { req } => {
req(State::Candidate, &mut self.core.storage, &mut self.core.network);
}
};
Ok(())
}
Expand Down Expand Up @@ -1091,7 +1097,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Followe
}

#[tracing::instrument(level = "debug", skip(self, msg), fields(state = "follower", id=display(self.core.id)))]
pub(crate) async fn handle_msg(&mut self, msg: RaftMsg<C>) -> Result<(), Fatal<C>> {
pub(crate) async fn handle_msg(&mut self, msg: RaftMsg<C, N, S>) -> Result<(), Fatal<C>> {
tracing::debug!("recv from rx_api: {}", msg.summary());

match msg {
Expand Down Expand Up @@ -1119,6 +1125,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Followe
RaftMsg::ChangeMembership { tx, .. } => {
self.core.reject_with_forward_to_leader(tx);
}
RaftMsg::ExternalRequest { req } => {
req(State::Follower, &mut self.core.storage, &mut self.core.network);
}
};
Ok(())
}
Expand Down Expand Up @@ -1165,7 +1174,7 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Learner

// TODO(xp): define a handle_msg method in RaftCore that decides what to do by current State.
#[tracing::instrument(level = "debug", skip(self, msg), fields(state = "learner", id=display(self.core.id)))]
pub(crate) async fn handle_msg(&mut self, msg: RaftMsg<C>) -> Result<(), Fatal<C>> {
pub(crate) async fn handle_msg(&mut self, msg: RaftMsg<C, N, S>) -> Result<(), Fatal<C>> {
tracing::debug!("recv from rx_api: {}", msg.summary());

match msg {
Expand Down Expand Up @@ -1193,6 +1202,9 @@ impl<'a, C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Learner
RaftMsg::ChangeMembership { tx, .. } => {
self.core.reject_with_forward_to_leader(tx);
}
RaftMsg::ExternalRequest { req } => {
req(State::Learner, &mut self.core.storage, &mut self.core.network);
}
};
Ok(())
}
Expand Down
37 changes: 32 additions & 5 deletions openraft/src/raft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::NodeId;
use crate::RaftNetworkFactory;
use crate::RaftStorage;
use crate::SnapshotMeta;
use crate::State;
use crate::Vote;

/// Configuration of types used by the [`Raft`] core engine.
Expand Down Expand Up @@ -104,7 +105,7 @@ macro_rules! declare_raft_types {
}

struct RaftInner<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> {
tx_api: mpsc::UnboundedSender<(RaftMsg<C>, Span)>,
tx_api: mpsc::UnboundedSender<(RaftMsg<C, N, S>, Span)>,
rx_metrics: watch::Receiver<RaftMetrics<C>>,
#[allow(clippy::type_complexity)]
raft_handle: Mutex<Option<JoinHandle<Result<(), Fatal<C>>>>>,
Expand Down Expand Up @@ -407,7 +408,7 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Raft<C, N,

/// Invoke RaftCore by sending a RaftMsg and blocks waiting for response.
#[tracing::instrument(level = "debug", skip(self, mes, rx))]
pub(crate) async fn call_core<T, E>(&self, mes: RaftMsg<C>, rx: RaftRespRx<T, E>) -> Result<T, E>
pub(crate) async fn call_core<T, E>(&self, mes: RaftMsg<C, N, S>, rx: RaftRespRx<T, E>) -> Result<T, E>
where E: From<Fatal<C>> {
let span = tracing::Span::current();

Expand Down Expand Up @@ -451,6 +452,25 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Raft<C, N,
res
}

/// Send a request to the Raft core loop in a fire-and-forget manner.
///
/// The request functor will be called with a mutable reference to both the state machine
/// and the network factory and serialized with other Raft core loop processing (e.g., client
/// requests or general state changes). The current state of the system is passed as well.
///
/// If a response is required, then the caller can store the sender of a one-shot channel
/// in the closure of the request functor, which can then be used to send the response
/// asynchronously.
///
/// If the API channel is already closed (Raft is in shutdown), then the request functor is
/// destroyed right away and not called at all.
pub fn external_request<F: FnOnce(State, &mut S, &mut N) + Send + 'static>(&self, req: F) {
let _ignore_error = self.inner.tx_api.send((
RaftMsg::ExternalRequest { req: Box::new(req) },
tracing::span::Span::none(), // fire-and-forget, so no span
));
}

/// Get a handle to the metrics channel.
pub fn metrics(&self) -> watch::Receiver<RaftMetrics<C>> {
self.inner.rx_metrics.clone()
Expand Down Expand Up @@ -513,7 +533,7 @@ pub struct AddLearnerResponse<C: RaftTypeConfig> {
}

/// A message coming from the Raft API.
pub(crate) enum RaftMsg<C: RaftTypeConfig> {
pub(crate) enum RaftMsg<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> {
AppendEntries {
rpc: AppendEntriesRequest<C>,
tx: RaftRespTx<AppendEntriesResponse<C>, AppendEntriesError<C>>,
Expand Down Expand Up @@ -564,10 +584,16 @@ pub(crate) enum RaftMsg<C: RaftTypeConfig> {

tx: RaftRespTx<ClientWriteResponse<C>, ClientWriteError<C>>,
},
ExternalRequest {
req: Box<dyn FnOnce(State, &mut S, &mut N) + Send + 'static>,
},
}

impl<C> MessageSummary for RaftMsg<C>
where C: RaftTypeConfig
impl<C, N, S> MessageSummary for RaftMsg<C, N, S>
where
C: RaftTypeConfig,
N: RaftNetworkFactory<C>,
S: RaftStorage<C>,
{
fn summary(&self) -> String {
match self {
Expand Down Expand Up @@ -601,6 +627,7 @@ where C: RaftTypeConfig
members, blocking, turn_to_learner,
)
}
RaftMsg::ExternalRequest { .. } => "External Request".to_string(),
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions openraft/tests/fixtures/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,19 @@ where
}
}

/// Send external request to the particular node.
pub fn external_request<F: FnOnce(State, &mut StoreExt<C, S>, &mut TypedRaftRouter<C, S>) + Send + 'static>(
&self,
target: C::NodeId,
req: F,
) {
let rt = self.routing_table.lock().unwrap();
rt.get(&target)
.unwrap_or_else(|| panic!("node '{}' does not exist in routing table", target))
.0
.external_request(req)
}

/// Request the current leader from the target node.
pub async fn current_leader(&self, target: C::NodeId) -> Option<C::NodeId> {
let node = self.get_raft_handle(&target).unwrap();
Expand Down
43 changes: 43 additions & 0 deletions openraft/tests/initialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use openraft::Membership;
use openraft::RaftLogReader;
use openraft::RaftStorage;
use openraft::State;
use tokio::sync::oneshot;

#[macro_use]
mod fixtures;
Expand Down Expand Up @@ -47,6 +48,23 @@ async fn initialization() -> Result<()> {
router.wait_for_state(&btreeset![0, 1, 2], State::Learner, timeout(), "empty").await?;
router.assert_pristine_cluster().await;

// Sending an external requests will also find all nodes in Learner state.
//
// This demonstrates fire-and-forget external request, which will be serialized
// with other processing. It is not required for the correctness of the test
//
// Since the execution of API messages is serialized, even if the request executes
// some unknown time in the future (due to fire-and-forget semantics), it will
// properly receive the state before initialization, as that state will appear
// later in the sequence.
//
// Also, this external request will be definitely executed, since it's ordered
// before other requests in the Raft core API queue, which definitely are executed
// (since they are awaited).
for node in [0, 1, 2] {
router.external_request(node, |s, _sm, _net| assert_eq!(s, State::Learner));
}

// Initialize the cluster, then assert that a stable cluster was formed & held.
tracing::info!("--- initializing cluster");
router.initialize_from_single_node(0).await?;
Expand Down Expand Up @@ -78,6 +96,31 @@ async fn initialization() -> Result<()> {
);
}

// At this time, one of the nodes is the leader, all the others are followers.
// Check via an external request as well. Again, this is not required for the
// correctness of the test.
//
// This demonstrates how to synchronize on the execution of the external
// request by using a oneshot channel.
let mut found_leader = false;
let mut follower_count = 0;
for node in [0, 1, 2] {
let (tx, rx) = oneshot::channel();
router.external_request(node, |s, _sm, _net| tx.send(s).unwrap());
match rx.await.unwrap() {
State::Leader => {
assert!(!found_leader);
found_leader = true;
}
State::Follower => {
follower_count += 1;
}
s => panic!("Unexpected node {} state: {:?}", node, s),
}
}
assert!(found_leader);
assert_eq!(2, follower_count);

Ok(())
}

Expand Down

0 comments on commit 80f8913

Please sign in to comment.