Skip to content

Commit

Permalink
change: use snapshot-id to identify a snapshot stream
Browse files Browse the repository at this point in the history
A snapshot stream should be identified by some id, since the server end
should not assume messages are arrived in the correct order.
Without an id, two `install_snapshot` request belonging to different
snapshot data may corrupt the snapshot data, explicitly or even worse,
silently.

- Add SnapshotId to identify a snapshot stream.

- Add SnapshotSegmentId to identify a segment in a snapshot stream.

- Add field `snapshot_id` to snapshot related data structures.

- Add error `RaftError::SnapshotMismatch`.

- `Storage::create_snapshot()` does not need to return and id.
  Since the receiving end only keeps one snapshot stream session at
  most.
  Instead, `Storage::do_log_compaction()` should build a unique id
  everytime it is called.

- When the raft node receives an `install_snapshot` request, the id must
  match to continue.
  A request with a different id should be rejected.
  A new id with offset=0 indicates the sender has started a new stream.
  In this case, the old unfinished stream is dropped and cleaned.

- Add test for `install_snapshot` API.
  • Loading branch information
drmingdrmer committed Jul 9, 2021
1 parent 6409d10 commit 933e0b3
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 14 deletions.
30 changes: 25 additions & 5 deletions async-raft/src/core/install_snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ use crate::raft::InstallSnapshotRequest;
use crate::raft::InstallSnapshotResponse;
use crate::AppData;
use crate::AppDataResponse;
use crate::RaftError;
use crate::RaftNetwork;
use crate::RaftStorage;
use crate::SnapshotSegmentId;
use crate::Update;

impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> RaftCore<D, R, N, S> {
Expand Down Expand Up @@ -61,23 +63,41 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
}

// Compare current snapshot state with received RPC and handle as needed.
// - Init a new state if it is empty or building a snapshot locally.
// - Mismatched id with offset=0 indicates a new stream has been sent, the old one should be dropped and start
// to receive the new snapshot,
// - Mismatched id with offset greater than 0 is an out of order message that should be rejected.
match self.snapshot_state.take() {
None => Ok(self.begin_installing_snapshot(req).await?),
None => return self.begin_installing_snapshot(req).await,
Some(SnapshotState::Snapshotting { handle, .. }) => {
handle.abort(); // Abort the current compaction in favor of installation from leader.
Ok(self.begin_installing_snapshot(req).await?)
return self.begin_installing_snapshot(req).await;
}
Some(SnapshotState::Streaming { snapshot, id, offset }) => {
Ok(self.continue_installing_snapshot(req, offset, id, snapshot).await?)
if req.snapshot_id == id {
return self.continue_installing_snapshot(req, offset, id, snapshot).await;
}

if req.offset == 0 {
return self.begin_installing_snapshot(req).await;
}

Err(RaftError::SnapshotMismatch {
expect: SnapshotSegmentId { id: id.clone(), offset },
got: SnapshotSegmentId {
id: req.snapshot_id.clone(),
offset: req.offset,
},
})
}
}
}

#[tracing::instrument(level = "trace", skip(self, req))]
async fn begin_installing_snapshot(&mut self, req: InstallSnapshotRequest) -> RaftResult<InstallSnapshotResponse> {
// Create a new snapshot and begin writing its contents.
let (id, mut snapshot) =
self.storage.create_snapshot().await.map_err(|err| self.map_fatal_storage_error(err))?;
let id = req.snapshot_id.clone();
let mut snapshot = self.storage.create_snapshot().await.map_err(|err| self.map_fatal_storage_error(err))?;
snapshot.as_mut().write_all(&req.data).await?;

// If this was a small snapshot, and it is already done, then finish up.
Expand Down
3 changes: 2 additions & 1 deletion async-raft/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,8 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
});
tokio::spawn(
async move {
let res = Abortable::new(storage.do_log_compaction(), reg).await;
let f = storage.do_log_compaction();
let res = Abortable::new(f, reg).await;
match res {
Ok(res) => match res {
Ok(snapshot) => {
Expand Down
4 changes: 2 additions & 2 deletions async-raft/src/core/replication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
let res = match event {
ReplicaEvent::RateUpdate { target, is_line_rate } => self.handle_rate_update(target, is_line_rate).await,
ReplicaEvent::RevertToFollower { target, term } => self.handle_revert_to_follower(target, term).await,
ReplicaEvent::UpdateMatchIndex { target, matched } => self.handle_update_match_index(target, matched).await,
ReplicaEvent::UpdateMatchIndex { target, matched } => self.handle_update_matched(target, matched).await,
ReplicaEvent::NeedsSnapshot { target, tx } => self.handle_needs_snapshot(target, tx).await,
ReplicaEvent::Shutdown => {
self.core.set_target_state(State::Shutdown);
Expand Down Expand Up @@ -117,7 +117,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>

/// Handle events from a replication stream which updates the target node's match index.
#[tracing::instrument(level = "trace", skip(self))]
async fn handle_update_match_index(&mut self, target: NodeId, matched: LogId) -> RaftResult<()> {
async fn handle_update_matched(&mut self, target: NodeId, matched: LogId) -> RaftResult<()> {
let mut found = false;

if let Some(state) = self.non_voters.get_mut(&target) {
Expand Down
7 changes: 7 additions & 0 deletions async-raft/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::fmt;

use thiserror::Error;

use crate::raft_types::SnapshotSegmentId;
use crate::AppData;
use crate::NodeId;

Expand All @@ -14,6 +15,12 @@ pub type RaftResult<T> = std::result::Result<T, RaftError>;
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RaftError {
// Streaming-snapshot encountered mismatched snapshot_id/offset
#[error("expect: {expect}, got: {got}")]
SnapshotMismatch {
expect: SnapshotSegmentId,
got: SnapshotSegmentId,
},
/// An error which has come from the `RaftStorage` layer.
#[error("{0}")]
RaftStorage(anyhow::Error),
Expand Down
2 changes: 2 additions & 0 deletions async-raft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub use crate::metrics::RaftMetrics;
pub use crate::network::RaftNetwork;
pub use crate::raft::Raft;
pub use crate::raft_types::LogId;
pub use crate::raft_types::SnapshotId;
pub use crate::raft_types::SnapshotSegmentId;
pub use crate::raft_types::Update;
pub use crate::replication::ReplicationMetrics;
pub use crate::storage::RaftStorage;
Expand Down
4 changes: 4 additions & 0 deletions async-raft/src/raft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::error::RaftError;
use crate::error::RaftResult;
use crate::metrics::RaftMetrics;
use crate::metrics::Wait;
use crate::raft_types::SnapshotId;
use crate::AppData;
use crate::AppDataResponse;
use crate::LogId;
Expand Down Expand Up @@ -576,6 +577,9 @@ pub struct InstallSnapshotRequest {
pub term: u64,
/// The leader's ID. Useful in redirecting clients.
pub leader_id: u64,
/// The Id of a snapshot.
/// Every two snapshots should have different snapshot id.
pub snapshot_id: SnapshotId,
/// The snapshot replaces all log entries up through and including this log.
pub last_included: LogId,
/// The byte offset where this chunk of data is positioned in the snapshot file.
Expand Down
34 changes: 34 additions & 0 deletions async-raft/src/raft_types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::fmt::Display;
use std::fmt::Formatter;

use serde::Deserialize;
use serde::Serialize;

Expand All @@ -15,6 +18,37 @@ impl From<(u64, u64)> for LogId {
}
}

impl Display for LogId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}-{}", self.term, self.index)
}
}

// Everytime a snapshot is created, it is assigned with a globally unique id.
pub type SnapshotId = String;

/// The identity of a segment of a snapshot.
#[derive(Debug, Default, Clone, PartialOrd, PartialEq, Eq, Serialize, Deserialize)]
pub struct SnapshotSegmentId {
pub id: SnapshotId,
pub offset: u64,
}

impl<D: ToString> From<(D, u64)> for SnapshotSegmentId {
fn from(v: (D, u64)) -> Self {
SnapshotSegmentId {
id: v.0.to_string(),
offset: v.1,
}
}
}

impl Display for SnapshotSegmentId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}+{}", self.id, self.offset)
}
}

// An update action with option to update with some value or just ignore this update.
#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Serialize, Deserialize)]
pub enum Update<T> {
Expand Down
2 changes: 2 additions & 0 deletions async-raft/src/replication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>

#[tracing::instrument(level = "trace", skip(self, snapshot))]
async fn stream_snapshot(&mut self, mut snapshot: CurrentSnapshotData<S::Snapshot>) -> RaftResult<()> {
let snapshot_id = snapshot.snapshot_id.clone();
let mut offset = 0;
self.core.next_index = snapshot.index + 1;
self.core.matched = (snapshot.term, snapshot.index).into();
Expand All @@ -842,6 +843,7 @@ impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>
let req = InstallSnapshotRequest {
term: self.core.term,
leader_id: self.core.id,
snapshot_id: snapshot_id.clone(),
last_included: (snapshot.term, snapshot.index).into(),
offset,
data: Vec::from(&buf[..nread]),
Expand Down
9 changes: 6 additions & 3 deletions async-raft/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use tokio::io::AsyncWrite;

use crate::raft::Entry;
use crate::raft::MembershipConfig;
use crate::raft_types::SnapshotId;
use crate::AppData;
use crate::AppDataResponse;
use crate::NodeId;
Expand All @@ -26,6 +27,9 @@ where S: AsyncRead + AsyncSeek + Send + Unpin + 'static
pub index: u64,
/// The latest membership configuration covered by the snapshot.
pub membership: MembershipConfig,

pub snapshot_id: SnapshotId,

/// A read handle to the associated snapshot.
pub snapshot: Box<S>,
}
Expand Down Expand Up @@ -200,15 +204,14 @@ where
/// Errors returned from this method will be logged and retried.
async fn do_log_compaction(&self) -> Result<CurrentSnapshotData<Self::Snapshot>>;

/// Create a new blank snapshot, returning a writable handle to the snapshot object along with
/// the ID of the snapshot.
/// Create a new blank snapshot, returning a writable handle to the snapshot object.
///
/// ### implementation guide
/// See the [storage chapter of the guide](https://async-raft.github.io/async-raft/storage.html)
/// for details on log compaction / snapshotting.
///
/// Errors returned from this method will cause Raft to go into shutdown.
async fn create_snapshot(&self) -> Result<(String, Box<Self::Snapshot>)>;
async fn create_snapshot(&self) -> Result<Box<Self::Snapshot>>;

/// Finalize the installation of a snapshot which has finished streaming from the cluster leader.
///
Expand Down
91 changes: 91 additions & 0 deletions async-raft/tests/api_install_snapshot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
mod fixtures;

use std::sync::Arc;

use anyhow::Result;
use async_raft::raft::InstallSnapshotRequest;
use async_raft::Config;
use async_raft::LogId;
use async_raft::State;
use fixtures::RaftRouter;
use maplit::hashset;

/// API test: install_snapshot with various condition.
///
/// What does this test do?
///
/// - build a stable single node cluster.
/// - send install_snapshot request with matched/mismatched id and offset
///
/// export RUST_LOG=async_raft,memstore,snapshot_ge_half_threshold=trace
/// cargo test -p async-raft --test snapshot_ge_half_threshold
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn snapshot_ge_half_threshold() -> Result<()> {
fixtures::init_tracing();

let config = Arc::new(Config::build("test".into()).validate().expect("failed to build Raft config"));
let router = Arc::new(RaftRouter::new(config.clone()));

let mut want = 0;

tracing::info!("--- initializing cluster");
{
router.new_raft_node(0).await;

router.wait_for_log(&hashset![0], want, None, "empty").await?;
router.wait_for_state(&hashset![0], State::NonVoter, None, "empty").await?;

router.initialize_from_single_node(0).await?;
want += 1;

router.wait_for_log(&hashset![0], want, None, "init leader").await?;
router.assert_stable_cluster(Some(1), Some(want)).await;
}

let n = router.remove_node(0).await.ok_or_else(|| anyhow::anyhow!("node not found"))?;
let req0 = InstallSnapshotRequest {
term: 1,
leader_id: 0,
snapshot_id: "ss1".into(),
last_included: LogId { term: 1, index: 0 },
offset: 0,
data: vec![1, 2, 3],
done: false,
};
tracing::info!("--- install and write ss1:[0,3)");
{
let req = req0.clone();
n.0.install_snapshot(req).await?;
}

tracing::info!("-- continue write with different id");
{
let mut req = req0.clone();
req.offset = 3;
req.snapshot_id = "ss2".into();
let res = n.0.install_snapshot(req).await;
assert_eq!("expect: ss1+3, got: ss2+3", res.unwrap_err().to_string());
}

tracing::info!("-- write from offset=0 with different id, create a new session");
{
let mut req = req0.clone();
req.offset = 0;
req.snapshot_id = "ss2".into();
n.0.install_snapshot(req).await?;

let mut req = req0.clone();
req.offset = 3;
req.snapshot_id = "ss2".into();
n.0.install_snapshot(req).await?;
}

tracing::info!("-- continue write with mismatched offset is allowed");
{
let mut req = req0.clone();
req.offset = 8;
req.snapshot_id = "ss2".into();
n.0.install_snapshot(req).await?;
}
Ok(())
}
Loading

0 comments on commit 933e0b3

Please sign in to comment.