diff --git a/Cargo.lock b/Cargo.lock index 9348db394a09bb..8e0d13a53004f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5506,6 +5506,7 @@ dependencies = [ "spl-token-2022", "static_assertions", "tempfile", + "test-case", "thiserror", "tokio", "tokio-stream", diff --git a/ledger/Cargo.toml b/ledger/Cargo.toml index c8f16585eef955..915bf2038de3c8 100644 --- a/ledger/Cargo.toml +++ b/ledger/Cargo.toml @@ -73,6 +73,7 @@ bs58 = "0.4.0" matches = "0.1.9" solana-account-decoder = { path = "../account-decoder", version = "=1.12.0" } solana-logger = { path = "../logger", version = "=1.12.0" } +test-case = "2.1.0" [build-dependencies] rustc_version = "0.4" diff --git a/ledger/src/blockstore.rs b/ledger/src/blockstore.rs index 5c246b5ab8045c..5bddc02bb90a4e 100644 --- a/ledger/src/blockstore.rs +++ b/ledger/src/blockstore.rs @@ -625,7 +625,7 @@ impl Blockstore { index: &mut Index, erasure_meta: &ErasureMeta, prev_inserted_shreds: &HashMap, - recovered_data_shreds: &mut Vec, + recovered_shreds: &mut Vec, data_cf: &LedgerColumn, code_cf: &LedgerColumn, ) { @@ -646,9 +646,9 @@ impl Blockstore { code_cf, )) .collect(); - if let Ok(mut result) = Shredder::try_recovery(available_shreds) { + if let Ok(mut result) = shred::recover(available_shreds) { Self::submit_metrics(slot, erasure_meta, true, "complete".into(), result.len()); - recovered_data_shreds.append(&mut result); + recovered_shreds.append(&mut result); } else { Self::submit_metrics(slot, erasure_meta, true, "incomplete".into(), 0); } @@ -709,7 +709,7 @@ impl Blockstore { ) -> Vec { let data_cf = db.column::(); let code_cf = db.column::(); - let mut recovered_data_shreds = vec![]; + let mut recovered_shreds = vec![]; // Recovery rules: // 1. Only try recovery around indexes for which new data or coding shreds are received // 2. For new data shreds, check if an erasure set exists. If not, don't try recovery @@ -725,7 +725,7 @@ impl Blockstore { index, erasure_meta, prev_inserted_shreds, - &mut recovered_data_shreds, + &mut recovered_shreds, &data_cf, &code_cf, ); @@ -744,7 +744,7 @@ impl Blockstore { } }; } - recovered_data_shreds + recovered_shreds } /// The main helper function that performs the shred insertion logic @@ -888,15 +888,18 @@ impl Blockstore { metrics.insert_shreds_elapsed_us += start.as_us(); let mut start = Measure::start("Shred recovery"); if let Some(leader_schedule_cache) = leader_schedule { - let recovered_data_shreds = Self::try_shred_recovery( + let recovered_shreds = Self::try_shred_recovery( db, &erasure_metas, &mut index_working_set, &just_inserted_shreds, ); - metrics.num_recovered += recovered_data_shreds.len(); - let recovered_data_shreds: Vec<_> = recovered_data_shreds + metrics.num_recovered += recovered_shreds + .iter() + .filter(|shred| shred.is_data()) + .count(); + let recovered_shreds: Vec<_> = recovered_shreds .into_iter() .filter_map(|shred| { let leader = @@ -905,6 +908,12 @@ impl Blockstore { metrics.num_recovered_failed_sig += 1; return None; } + // Since the data shreds are fully recovered from the + // erasure batch, no need to store coding shreds in + // blockstore. + if shred.is_code() { + return Some(shred); + } match self.check_insert_data_shred( shred.clone(), &mut erasure_metas, @@ -941,10 +950,10 @@ impl Blockstore { // Always collect recovered-shreds so that above insert code is // executed even if retransmit-sender is None. .collect(); - if !recovered_data_shreds.is_empty() { + if !recovered_shreds.is_empty() { if let Some(retransmit_sender) = retransmit_sender { let _ = retransmit_sender.send( - recovered_data_shreds + recovered_shreds .into_iter() .map(Shred::into_payload) .collect(), diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs index e17055b1e7d9a9..cee63cb45df57d 100644 --- a/ledger/src/shred.rs +++ b/ledger/src/shred.rs @@ -61,6 +61,7 @@ use { crate::blockstore::{self, MAX_DATA_SHREDS_PER_SLOT}, bitflags::bitflags, num_enum::{IntoPrimitive, TryFromPrimitive}, + reed_solomon_erasure::Error::TooFewShardsPresent, serde::{Deserialize, Serialize}, solana_entry::entry::{create_ticks, Entry}, solana_perf::packet::Packet, @@ -144,6 +145,10 @@ pub enum Error { InvalidPayloadSize(/*payload size:*/ usize), #[error("Invalid proof size: {0}")] InvalidProofSize(/*proof_size:*/ u8), + #[error("Invalid recovered shred")] + InvalidRecoveredShred, + #[error("Invalid shard size: {0}")] + InvalidShardSize(/*shard_size:*/ usize), #[error("Invalid shred flags: {0}")] InvalidShredFlags(u8), #[error("Invalid {0:?} shred index: {1}")] @@ -211,7 +216,7 @@ struct DataShredHeader { struct CodingShredHeader { num_data_shreds: u16, num_coding_shreds: u16, - position: u16, + position: u16, // [0..num_coding_shreds) } #[derive(Clone, Debug, PartialEq, Eq)] @@ -294,6 +299,8 @@ macro_rules! dispatch { } } +use dispatch; + impl Shred { dispatch!(fn common_header(&self) -> &ShredCommonHeader); dispatch!(fn set_signature(&mut self, signature: Signature)); @@ -494,6 +501,7 @@ impl Shred { } } + #[must_use] pub fn verify(&self, pubkey: &Pubkey) -> bool { let message = self.signed_message(); self.signature().verify(pubkey.as_ref(), message) @@ -642,6 +650,28 @@ impl From for Shred { } } +impl From for Shred { + fn from(shred: merkle::Shred) -> Self { + match shred { + merkle::Shred::ShredCode(shred) => Self::ShredCode(ShredCode::Merkle(shred)), + merkle::Shred::ShredData(shred) => Self::ShredData(ShredData::Merkle(shred)), + } + } +} + +impl TryFrom for merkle::Shred { + type Error = Error; + + fn try_from(shred: Shred) -> Result { + match shred { + Shred::ShredCode(ShredCode::Legacy(_)) => Err(Error::InvalidShredVariant), + Shred::ShredCode(ShredCode::Merkle(shred)) => Ok(Self::ShredCode(shred)), + Shred::ShredData(ShredData::Legacy(_)) => Err(Error::InvalidShredVariant), + Shred::ShredData(ShredData::Merkle(shred)) => Ok(Self::ShredData(shred)), + } + } +} + impl From for ShredType { #[inline] fn from(shred_variant: ShredVariant) -> Self { @@ -682,6 +712,27 @@ impl TryFrom for ShredVariant { } } +pub(crate) fn recover(shreds: Vec) -> Result, Error> { + match shreds + .first() + .ok_or(TooFewShardsPresent)? + .common_header() + .shred_variant + { + ShredVariant::LegacyData | ShredVariant::LegacyCode => Shredder::try_recovery(shreds), + ShredVariant::MerkleCode(_) | ShredVariant::MerkleData(_) => { + let shreds = shreds + .into_iter() + .map(merkle::Shred::try_from) + .collect::>()?; + Ok(merkle::recover(shreds)? + .into_iter() + .map(Shred::from) + .collect()) + } + } +} + // Accepts shreds in the slot range [root + 1, max_slot]. #[must_use] pub fn should_discard_shred( diff --git a/ledger/src/shred/merkle.rs b/ledger/src/shred/merkle.rs index 5b224632a2a880..9d0482b95354a5 100644 --- a/ledger/src/shred/merkle.rs +++ b/ledger/src/shred/merkle.rs @@ -1,12 +1,20 @@ +#[cfg(test)] +use {crate::shred::ShredType, solana_sdk::pubkey::Pubkey}; use { - crate::shred::{ - common::impl_shred_common, - shred_code, shred_data, - traits::{Shred, ShredCode as ShredCodeTrait, ShredData as ShredDataTrait}, - CodingShredHeader, DataShredHeader, Error, ShredCommonHeader, ShredFlags, ShredVariant, - SIZE_OF_CODING_SHRED_HEADERS, SIZE_OF_COMMON_SHRED_HEADER, SIZE_OF_DATA_SHRED_HEADERS, - SIZE_OF_SIGNATURE, + crate::{ + shred::{ + common::impl_shred_common, + dispatch, shred_code, shred_data, + traits::{ + Shred as ShredTrait, ShredCode as ShredCodeTrait, ShredData as ShredDataTrait, + }, + CodingShredHeader, DataShredHeader, Error, ShredCommonHeader, ShredFlags, ShredVariant, + SIZE_OF_CODING_SHRED_HEADERS, SIZE_OF_COMMON_SHRED_HEADER, SIZE_OF_DATA_SHRED_HEADERS, + SIZE_OF_SIGNATURE, + }, + shredder::ReedSolomon, }, + reed_solomon_erasure::Error::{InvalidIndex, TooFewParityShards, TooFewShards}, solana_perf::packet::deserialize_from_with_limit, solana_sdk::{ clock::Slot, @@ -58,12 +66,58 @@ pub struct ShredCode { payload: Vec, } +#[derive(Clone, Debug, Eq, PartialEq)] +pub(super) enum Shred { + ShredCode(ShredCode), + ShredData(ShredData), +} + #[derive(Clone, Debug, Eq, PartialEq)] struct MerkleBranch { root: MerkleRoot, proof: Vec, } +impl Shred { + dispatch!(fn common_header(&self) -> &ShredCommonHeader); + dispatch!(fn erasure_shard_as_slice(&self) -> Result<&[u8], Error>); + dispatch!(fn erasure_shard_index(&self) -> Result); + dispatch!(fn merkle_tree_node(&self) -> Result); + dispatch!(fn sanitize(&self) -> Result<(), Error>); + dispatch!(fn set_merkle_branch(&mut self, merkle_branch: MerkleBranch) -> Result<(), Error>); + + fn merkle_root(&self) -> &MerkleRoot { + match self { + Self::ShredCode(shred) => &shred.merkle_branch.root, + Self::ShredData(shred) => &shred.merkle_branch.root, + } + } +} + +#[cfg(test)] +impl Shred { + dispatch!(fn set_signature(&mut self, signature: Signature)); + dispatch!(fn signed_message(&self) -> &[u8]); + + fn index(&self) -> u32 { + self.common_header().index + } + + fn shred_type(&self) -> ShredType { + ShredType::from(self.common_header().shred_variant) + } + + fn signature(&self) -> Signature { + self.common_header().signature + } + + #[must_use] + fn verify(&self, pubkey: &Pubkey) -> bool { + let message = self.signed_message(); + self.signature().verify(pubkey.as_ref(), message) + } +} + impl ShredData { // proof_size is the number of proof entries in the merkle tree branch. fn proof_size(&self) -> Result { @@ -104,6 +158,52 @@ impl ShredData { let index = self.erasure_shard_index()?; Ok(verify_merkle_proof(index, node, &self.merkle_branch)) } + + fn from_recovered_shard(signature: &Signature, mut shard: Vec) -> Result { + let shard_size = shard.len(); + if shard_size + SIZE_OF_SIGNATURE > Self::SIZE_OF_PAYLOAD { + return Err(Error::InvalidShardSize(shard_size)); + } + shard.resize(Self::SIZE_OF_PAYLOAD, 0u8); + shard.copy_within(0..shard_size, SIZE_OF_SIGNATURE); + shard[0..SIZE_OF_SIGNATURE].copy_from_slice(signature.as_ref()); + // Deserialize headers. + let mut cursor = Cursor::new(&shard[..]); + let common_header: ShredCommonHeader = deserialize_from_with_limit(&mut cursor)?; + let proof_size = match common_header.shred_variant { + ShredVariant::MerkleData(proof_size) => proof_size, + _ => return Err(Error::InvalidShredVariant), + }; + if ShredCode::capacity(proof_size)? != shard_size { + return Err(Error::InvalidShardSize(shard_size)); + } + let data_header = deserialize_from_with_limit(&mut cursor)?; + Ok(Self { + common_header, + data_header, + merkle_branch: MerkleBranch::new_zeroed(proof_size), + payload: shard, + }) + } + + fn set_merkle_branch(&mut self, merkle_branch: MerkleBranch) -> Result<(), Error> { + let proof_size = self.proof_size()?; + if merkle_branch.proof.len() != usize::from(proof_size) { + return Err(Error::InvalidMerkleProof); + } + let offset = Self::SIZE_OF_HEADERS + Self::capacity(proof_size)?; + let mut cursor = Cursor::new( + self.payload + .get_mut(offset..) + .ok_or(Error::InvalidProofSize(proof_size))?, + ); + bincode::serialize_into(&mut cursor, &merkle_branch.root)?; + for entry in &merkle_branch.proof { + bincode::serialize_into(&mut cursor, entry)?; + } + self.merkle_branch = merkle_branch; + Ok(()) + } } impl ShredCode { @@ -154,9 +254,66 @@ impl ShredCode { || self.merkle_branch.root != other.merkle_branch.root || self.common_header.signature != other.common_header.signature } + + fn from_recovered_shard( + common_header: ShredCommonHeader, + coding_header: CodingShredHeader, + mut shard: Vec, + ) -> Result { + let proof_size = match common_header.shred_variant { + ShredVariant::MerkleCode(proof_size) => proof_size, + _ => return Err(Error::InvalidShredVariant), + }; + let shard_size = shard.len(); + if Self::capacity(proof_size)? != shard_size { + return Err(Error::InvalidShardSize(shard_size)); + } + if shard_size + Self::SIZE_OF_HEADERS > Self::SIZE_OF_PAYLOAD { + return Err(Error::InvalidShardSize(shard_size)); + } + shard.resize(Self::SIZE_OF_PAYLOAD, 0u8); + shard.copy_within(0..shard_size, Self::SIZE_OF_HEADERS); + let mut cursor = Cursor::new(&mut shard[..]); + bincode::serialize_into(&mut cursor, &common_header)?; + bincode::serialize_into(&mut cursor, &coding_header)?; + Ok(Self { + common_header, + coding_header, + merkle_branch: MerkleBranch::new_zeroed(proof_size), + payload: shard, + }) + } + + fn set_merkle_branch(&mut self, merkle_branch: MerkleBranch) -> Result<(), Error> { + let proof_size = self.proof_size()?; + if merkle_branch.proof.len() != usize::from(proof_size) { + return Err(Error::InvalidMerkleProof); + } + let offset = Self::SIZE_OF_HEADERS + Self::capacity(proof_size)?; + let mut cursor = Cursor::new( + self.payload + .get_mut(offset..) + .ok_or(Error::InvalidProofSize(proof_size))?, + ); + bincode::serialize_into(&mut cursor, &merkle_branch.root)?; + for entry in &merkle_branch.proof { + bincode::serialize_into(&mut cursor, entry)?; + } + self.merkle_branch = merkle_branch; + Ok(()) + } } -impl Shred for ShredData { +impl MerkleBranch { + fn new_zeroed(proof_size: u8) -> Self { + Self { + root: MerkleRoot::default(), + proof: vec![MerkleProofEntry::default(); usize::from(proof_size)], + } + } +} + +impl ShredTrait for ShredData { impl_shred_common!(); // Also equal to: @@ -249,7 +406,7 @@ impl Shred for ShredData { } } -impl Shred for ShredCode { +impl ShredTrait for ShredCode { impl_shred_common!(); const SIZE_OF_PAYLOAD: usize = shred_code::ShredCode::SIZE_OF_PAYLOAD; const SIZE_OF_HEADERS: usize = SIZE_OF_CODING_SHRED_HEADERS; @@ -391,7 +548,6 @@ fn verify_merkle_proof(index: usize, node: Hash, merkle_branch: &MerkleBranch) - (index, root) == (0usize, &merkle_branch.root[..]) } -#[cfg(test)] fn make_merkle_tree(mut nodes: Vec) -> Vec { let mut size = nodes.len(); while size > 1 { @@ -407,7 +563,6 @@ fn make_merkle_tree(mut nodes: Vec) -> Vec { nodes } -#[cfg(test)] fn make_merkle_branch( mut index: usize, // leaf index ~ shred's erasure shard index. mut size: usize, // number of leaves ~ erasure batch size. @@ -434,9 +589,170 @@ fn make_merkle_branch( Some(MerkleBranch { root, proof }) } +pub(super) fn recover(mut shreds: Vec) -> Result, Error> { + // Grab {common, coding} headers from first coding shred. + let headers = shreds.iter().find_map(|shred| { + let shred = match shred { + Shred::ShredCode(shred) => shred, + Shred::ShredData(_) => return None, + }; + let position = u32::from(shred.coding_header.position); + let common_header = ShredCommonHeader { + index: shred.common_header.index.checked_sub(position)?, + ..shred.common_header + }; + let coding_header = CodingShredHeader { + position: 0u16, + ..shred.coding_header + }; + Some((common_header, coding_header)) + }); + let (common_header, coding_header) = headers.ok_or(TooFewParityShards)?; + debug_assert!(matches!( + common_header.shred_variant, + ShredVariant::MerkleCode(_) + )); + let proof_size = match common_header.shred_variant { + ShredVariant::MerkleCode(proof_size) => proof_size, + ShredVariant::MerkleData(_) | ShredVariant::LegacyCode | ShredVariant::LegacyData => { + return Err(Error::InvalidShredVariant); + } + }; + // Verify that shreds belong to the same erasure batch + // and have consistent headers. + debug_assert!(shreds.iter().all(|shred| { + let ShredCommonHeader { + signature, + shred_variant, + slot, + index: _, + version, + fec_set_index, + } = shred.common_header(); + signature == &common_header.signature + && slot == &common_header.slot + && version == &common_header.version + && fec_set_index == &common_header.fec_set_index + && match shred { + Shred::ShredData(_) => shred_variant == &ShredVariant::MerkleData(proof_size), + Shred::ShredCode(shred) => { + let CodingShredHeader { + num_data_shreds, + num_coding_shreds, + position: _, + } = shred.coding_header; + shred_variant == &ShredVariant::MerkleCode(proof_size) + && num_data_shreds == coding_header.num_data_shreds + && num_coding_shreds == coding_header.num_coding_shreds + } + } + })); + let num_data_shreds = usize::from(coding_header.num_data_shreds); + let num_coding_shreds = usize::from(coding_header.num_coding_shreds); + let num_shards = num_data_shreds + num_coding_shreds; + // Obtain erasure encoded shards from shreds. + let shreds = { + let mut batch = vec![None; num_shards]; + while let Some(shred) = shreds.pop() { + let index = match shred.erasure_shard_index() { + Ok(index) if index < batch.len() => index, + _ => return Err(Error::from(InvalidIndex)), + }; + batch[index] = Some(shred); + } + batch + }; + let mut shards: Vec>> = shreds + .iter() + .map(|shred| Some(shred.as_ref()?.erasure_shard_as_slice().ok()?.to_vec())) + .collect(); + ReedSolomon::new(num_data_shreds, num_coding_shreds)?.reconstruct(&mut shards)?; + let mask: Vec<_> = shreds.iter().map(Option::is_some).collect(); + // Reconstruct code and data shreds from erasure encoded shards. + let mut shreds: Vec<_> = shreds + .into_iter() + .zip(shards) + .enumerate() + .map(|(index, (shred, shard))| { + if let Some(shred) = shred { + return Ok(shred); + } + let shard = shard.ok_or(TooFewShards)?; + if index < num_data_shreds { + let shred = ShredData::from_recovered_shard(&common_header.signature, shard)?; + let ShredCommonHeader { + signature: _, + shred_variant, + slot, + index: _, + version, + fec_set_index, + } = shred.common_header; + if shred_variant != ShredVariant::MerkleData(proof_size) + || common_header.slot != slot + || common_header.version != version + || common_header.fec_set_index != fec_set_index + { + return Err(Error::InvalidRecoveredShred); + } + Ok(Shred::ShredData(shred)) + } else { + let offset = index - num_data_shreds; + let coding_header = CodingShredHeader { + position: offset as u16, + ..coding_header + }; + let common_header = ShredCommonHeader { + index: common_header.index + offset as u32, + ..common_header + }; + let shred = ShredCode::from_recovered_shard(common_header, coding_header, shard)?; + Ok(Shred::ShredCode(shred)) + } + }) + .collect::>()?; + // Compute merkle tree and set the merkle branch on the recovered shreds. + let nodes: Vec<_> = shreds + .iter() + .map(Shred::merkle_tree_node) + .collect::>()?; + let tree = make_merkle_tree(nodes); + let merkle_root = &tree.last().unwrap().as_ref()[..SIZE_OF_MERKLE_ROOT]; + let merkle_root = MerkleRoot::try_from(merkle_root).unwrap(); + for (index, (shred, mask)) in shreds.iter_mut().zip(&mask).enumerate() { + if *mask { + if shred.merkle_root() != &merkle_root { + return Err(Error::InvalidMerkleProof); + } + } else { + let merkle_branch = + make_merkle_branch(index, num_shards, &tree).ok_or(Error::InvalidMerkleProof)?; + if merkle_branch.proof.len() != usize::from(proof_size) { + return Err(Error::InvalidMerkleProof); + } + shred.set_merkle_branch(merkle_branch)?; + } + } + // TODO: No need to verify merkle proof in sanitize here. + shreds + .into_iter() + .zip(mask) + .filter(|(_, mask)| !mask) + .map(|(shred, _)| shred.sanitize().map(|_| shred)) + .collect() +} + #[cfg(test)] mod test { - use {super::*, rand::Rng, std::iter::repeat_with}; + use { + super::*, + itertools::Itertools, + matches::assert_matches, + rand::{seq::SliceRandom, CryptoRng, Rng}, + solana_sdk::signature::{Keypair, Signer}, + std::{cmp::Ordering, iter::repeat_with}, + test_case::test_case, + }; // Total size of a data shred including headers and merkle branch. fn shred_data_size_of_payload(proof_size: u8) -> usize { @@ -525,4 +841,153 @@ mod test { run_merkle_tree_round_trip(size); } } + + #[test_case(37)] + #[test_case(64)] + #[test_case(73)] + fn test_recover_merkle_shreds(num_shreds: usize) { + let mut rng = rand::thread_rng(); + for num_data_shreds in 1..num_shreds { + let num_coding_shreds = num_shreds - num_data_shreds; + run_recover_merkle_shreds(&mut rng, num_data_shreds, num_coding_shreds); + } + } + + fn run_recover_merkle_shreds( + rng: &mut R, + num_data_shreds: usize, + num_coding_shreds: usize, + ) { + let keypair = Keypair::generate(rng); + let num_shreds = num_data_shreds + num_coding_shreds; + let proof_size = (num_shreds as f64).log2().ceil() as u8; + let capacity = ShredData::capacity(proof_size).unwrap(); + let common_header = ShredCommonHeader { + signature: Signature::default(), + shred_variant: ShredVariant::MerkleData(proof_size), + slot: 145865705, + index: 1835, + version: 4978, + fec_set_index: 1835, + }; + let data_header = DataShredHeader { + parent_offset: 25, + flags: unsafe { ShredFlags::from_bits_unchecked(0b0010_1010) }, + size: 0, + }; + let coding_header = CodingShredHeader { + num_data_shreds: num_data_shreds as u16, + num_coding_shreds: num_coding_shreds as u16, + position: 0, + }; + let mut shreds = Vec::with_capacity(num_shreds); + for i in 0..num_data_shreds { + let common_header = ShredCommonHeader { + index: common_header.index + i as u32, + ..common_header + }; + let size = ShredData::SIZE_OF_HEADERS + rng.gen_range(0, capacity); + let data_header = DataShredHeader { + size: size as u16, + ..data_header + }; + let mut payload = vec![0u8; ShredData::SIZE_OF_PAYLOAD]; + let mut cursor = Cursor::new(&mut payload[..]); + bincode::serialize_into(&mut cursor, &common_header).unwrap(); + bincode::serialize_into(&mut cursor, &data_header).unwrap(); + rng.fill(&mut payload[ShredData::SIZE_OF_HEADERS..size]); + let shred = ShredData { + common_header, + data_header, + merkle_branch: MerkleBranch::new_zeroed(proof_size), + payload, + }; + shreds.push(Shred::ShredData(shred)); + } + let data: Vec<_> = shreds + .iter() + .map(Shred::erasure_shard_as_slice) + .collect::>() + .unwrap(); + let mut parity = vec![vec![0u8; data[0].len()]; num_coding_shreds]; + ReedSolomon::new(num_data_shreds, num_coding_shreds) + .unwrap() + .encode_sep(&data, &mut parity[..]) + .unwrap(); + for (i, code) in parity.into_iter().enumerate() { + let common_header = ShredCommonHeader { + shred_variant: ShredVariant::MerkleCode(proof_size), + index: common_header.index + i as u32 + 7, + ..common_header + }; + let coding_header = CodingShredHeader { + position: i as u16, + ..coding_header + }; + let mut payload = vec![0u8; ShredCode::SIZE_OF_PAYLOAD]; + let mut cursor = Cursor::new(&mut payload[..]); + bincode::serialize_into(&mut cursor, &common_header).unwrap(); + bincode::serialize_into(&mut cursor, &coding_header).unwrap(); + payload[ShredCode::SIZE_OF_HEADERS..ShredCode::SIZE_OF_HEADERS + code.len()] + .copy_from_slice(&code); + let shred = ShredCode { + common_header, + coding_header, + merkle_branch: MerkleBranch::new_zeroed(proof_size), + payload, + }; + shreds.push(Shred::ShredCode(shred)); + } + let nodes: Vec<_> = shreds + .iter() + .map(Shred::merkle_tree_node) + .collect::>() + .unwrap(); + let tree = make_merkle_tree(nodes); + for (index, shred) in shreds.iter_mut().enumerate() { + let merkle_branch = make_merkle_branch(index, num_shreds, &tree).unwrap(); + assert_eq!(merkle_branch.proof.len(), usize::from(proof_size)); + shred.set_merkle_branch(merkle_branch).unwrap(); + let signature = keypair.sign_message(shred.signed_message()); + shred.set_signature(signature); + assert!(shred.verify(&keypair.pubkey())); + assert_matches!(shred.sanitize(), Ok(())); + } + assert_eq!(shreds.iter().map(Shred::signature).dedup().count(), 1); + for size in num_data_shreds..num_shreds { + let mut shreds = shreds.clone(); + let mut removed_shreds = Vec::new(); + while shreds.len() > size { + let index = rng.gen_range(0, shreds.len()); + removed_shreds.push(shreds.swap_remove(index)); + } + shreds.shuffle(rng); + // Should at least contain one coding shred. + if shreds.iter().all(|shred| { + matches!( + shred.common_header().shred_variant, + ShredVariant::MerkleData(_) + ) + }) { + assert_matches!( + recover(shreds), + Err(Error::ErasureError(TooFewParityShards)) + ); + continue; + } + let recovered_shreds = recover(shreds).unwrap(); + assert_eq!(size + recovered_shreds.len(), num_shreds); + assert_eq!(recovered_shreds.len(), removed_shreds.len()); + removed_shreds.sort_by(|a, b| { + if a.shred_type() == b.shred_type() { + a.index().cmp(&b.index()) + } else if a.shred_type() == ShredType::Data { + Ordering::Less + } else { + Ordering::Greater + } + }); + assert_eq!(recovered_shreds, removed_shreds); + } + } } diff --git a/ledger/src/shredder.rs b/ledger/src/shredder.rs index d3a50cb82dc1ca..671cc0b7c44c47 100644 --- a/ledger/src/shredder.rs +++ b/ledger/src/shredder.rs @@ -33,7 +33,7 @@ const ERASURE_BATCH_SIZE: [usize; 33] = [ 55, 56, 58, 59, 60, 62, 63, 64, // 32 ]; -type ReedSolomon = reed_solomon_erasure::ReedSolomon; +pub(crate) type ReedSolomon = reed_solomon_erasure::ReedSolomon; #[derive(Debug)] pub struct Shredder {