Skip to content

Commit

Permalink
feat: resharding v3 read-only state split in test loop (#12211)
Browse files Browse the repository at this point in the history
Freeze memtries for source shard and create correct memtries for target
shards after resharding, so that chain state can be properly read after
resharding.

This is the major part of the "early MVP", after which we can consider
testing resharding on real nodes. Although resharding isn't properly
implemented for receipts yet, and state uses only the frozen memtrie, a live
test can already be useful for catching bugs.

A couple more changes:
* use `ChainStoreUpdate` for "chain resharding updates". This is because
`ChainStore` has a cache for chunk extras which also needs to be updated,
* when validating chunk endorsements, get the shard id and index only after
we know that the chunk is new,
* use the correct boundary account and intervals for the state split,
* check that state is non-empty for all shards.
  • Loading branch information
Longarithm authored Oct 15, 2024
1 parent e0d9637 commit 3221b86
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 56 deletions.
2 changes: 2 additions & 0 deletions chain/chain/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1978,7 +1978,9 @@ impl Chain {

if need_storage_update {
// TODO(#12019): consider adding to catchup flow.
let chain_store_update = self.chain_store.store_update();
self.resharding_manager.process_memtrie_resharding_storage_update(
chain_store_update,
&block,
shard_uid,
self.runtime_adapter.get_tries(),
Expand Down
93 changes: 50 additions & 43 deletions chain/chain/src/resharding/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use near_epoch_manager::EpochManagerAdapter;
use near_primitives::block::Block;
use near_primitives::challenge::PartialState;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::get_block_shard_uid;
use near_primitives::stateless_validation::stored_chunk_state_transition_data::StoredChunkStateTransitionData;
use near_primitives::shard_layout::{get_block_shard_uid, ShardLayout};
use near_primitives::types::chunk_extra::ChunkExtra;
use near_primitives::utils::get_block_shard_id;
use near_store::adapter::StoreUpdateAdapter;
use near_store::trie::mem::resharding::RetainMode;
use near_store::{DBCol, PartialStorage, ShardTries, ShardUId, Store};

use crate::ChainStoreUpdate;

pub struct ReshardingManager {
store: Store,
epoch_manager: Arc<dyn EpochManagerAdapter>,
Expand All @@ -40,60 +40,71 @@ impl ReshardingManager {
/// created later.
pub fn process_memtrie_resharding_storage_update(
&mut self,
mut chain_store_update: ChainStoreUpdate,
block: &Block,
shard_uid: ShardUId,
tries: ShardTries,
) -> Result<(), Error> {
let block_hash = block.hash();
let block_height = block.header().height();
let prev_hash = block.header().prev_hash();
let shard_layout = self.epoch_manager.get_shard_layout(&block.header().epoch_id())?;
let next_epoch_id = self.epoch_manager.get_next_epoch_id_from_prev_block(prev_hash)?;
let next_shard_layout = self.epoch_manager.get_shard_layout(&next_epoch_id)?;

let next_block_has_new_shard_layout =
self.epoch_manager.will_shard_layout_change(prev_hash)?
&& self.epoch_manager.is_next_block_epoch_start(block.hash())?;
self.epoch_manager.is_next_block_epoch_start(block_hash)?
&& shard_layout != next_shard_layout;
if !next_block_has_new_shard_layout {
return Ok(());
}

let next_epoch_id = self.epoch_manager.get_next_epoch_id_from_prev_block(prev_hash)?;
let next_shard_layout = self.epoch_manager.get_shard_layout(&next_epoch_id)?;

// Hack to ensure this logic is not applied before ReshardingV3.
// TODO(#12019): proper logic.
if next_shard_layout.version() < 3 {
if !matches!(next_shard_layout, ShardLayout::V2(_)) {
return Ok(());
}

let resharding_event_type =
ReshardingEventType::from_shard_layout(&next_shard_layout, *block_hash, *prev_hash)?;
let Some(ReshardingEventType::SplitShard(split_shard_event)) = resharding_event_type else {
return Ok(());
};
if split_shard_event.parent_shard != shard_uid {
return Ok(());
}

let chunk_extra = self.get_chunk_extra(block_hash, &shard_uid)?;
let Some(mem_tries) = tries.get_mem_tries(shard_uid) else {
// TODO(#12019): what if node doesn't have memtrie? just pause
// processing?
tracing::error!(
"Memtrie not loaded. Cannot process memtrie resharding storage
update for block {:?}, shard {:?}",
block_hash,
shard_uid
);
return Err(Error::Other("Memtrie not loaded".to_string()));
};
// TODO(#12019): what if node doesn't have memtrie? just pause
// processing?
// TODO(#12019): fork handling. if epoch is finalized on different
// blocks, the second finalization will crash.
tries.freeze_mem_tries(
shard_uid,
vec![split_shard_event.left_child_shard, split_shard_event.right_child_shard],
)?;

let chunk_extra = self.get_chunk_extra(block_hash, &shard_uid)?;
let boundary_account = split_shard_event.boundary_account;

let mut trie_store_update = self.store.store_update();

// TODO(#12019): leave only tracked shards.
for (new_shard_uid, retain_mode) in [
(split_shard_event.left_child_shard, RetainMode::Left),
(split_shard_event.right_child_shard, RetainMode::Right),
] {
let Some(mem_tries) = tries.get_mem_tries(new_shard_uid) else {
tracing::error!(
"Memtrie not loaded. Cannot process memtrie resharding storage
update for block {:?}, shard {:?}",
block_hash,
shard_uid
);
return Err(Error::Other("Memtrie not loaded".to_string()));
};

let mut mem_tries = mem_tries.write().unwrap();
let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), true)?;

let (trie_changes, _) =
mem_trie_update.retain_split_shard(boundary_account.clone(), retain_mode);
mem_trie_update.retain_split_shard(&boundary_account, retain_mode);
let partial_state = PartialState::default();
let partial_storage = PartialStorage { nodes: partial_state };
let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap();
Expand All @@ -104,31 +115,27 @@ impl ReshardingManager {
let mut child_chunk_extra = ChunkExtra::clone(&chunk_extra);
*child_chunk_extra.state_root_mut() = new_state_root;

let state_transition_data = StoredChunkStateTransitionData {
base_state: partial_storage.nodes,
receipts_hash: CryptoHash::default(),
};
chain_store_update.save_chunk_extra(block_hash, &new_shard_uid, child_chunk_extra);
chain_store_update.save_state_transition_data(
*block_hash,
new_shard_uid.shard_id(),
Some(partial_storage),
CryptoHash::default(),
);

// TODO(store): Use proper store interface
let mut store_update = self.store.store_update();
store_update.set_ser(
DBCol::ChunkExtra,
&get_block_shard_uid(block_hash, &new_shard_uid),
&child_chunk_extra,
)?;
store_update.set_ser(
DBCol::StateTransitionData,
&get_block_shard_id(block_hash, new_shard_uid.shard_id()),
&state_transition_data,
)?;
// Commit `TrieChanges` directly. They are needed to serve reads of
// new nodes from `DBCol::State` while memtrie is properly created
// from flat storage.
tries.apply_insertions(
&trie_changes,
new_shard_uid,
&mut store_update.trie_store_update(),
&mut trie_store_update.trie_store_update(),
);
store_update.commit()?;
}

chain_store_update.merge(trie_store_update);
chain_store_update.commit()?;

Ok(())
}

Expand Down
6 changes: 4 additions & 2 deletions chain/chain/src/stateless_validation/chunk_endorsement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ pub fn validate_chunk_endorsements_in_block(
let epoch_id = epoch_manager.get_epoch_id_from_prev_block(block.header().prev_hash())?;
let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?;
for (chunk_header, signatures) in block.chunks().iter().zip(block.chunk_endorsements()) {
let shard_id = chunk_header.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id);
// For old chunks, we optimize the block by not including the chunk endorsements.
if chunk_header.height_included() != block.header().height() {
if !signatures.is_empty() {
Expand All @@ -60,8 +58,12 @@ pub fn validate_chunk_endorsements_in_block(
}
continue;
}

// Validation for chunks in each shard
// The signatures from chunk validators for each shard must match the ordered_chunk_validators
let shard_id = chunk_header.shard_id();
let shard_index = shard_layout.get_shard_index(shard_id);

let chunk_validator_assignments = epoch_manager.get_chunk_validator_assignments(
&epoch_id,
shard_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl Client {
)?;

// Fetch all incoming receipts for `prev_chunk`.
// They will be between `prev_prev_chunk.height_included`` (first block containing `prev_prev_chunk`)
// They will be between `prev_prev_chunk.height_included` (first block containing `prev_prev_chunk`)
// and `prev_chunk_original_block`
let incoming_receipt_proofs = self.chain.chain_store().get_incoming_receipts_for_shard(
self.epoch_manager.as_ref(),
Expand Down
2 changes: 0 additions & 2 deletions core/store/src/trie/mem/arena/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ impl From<STArena> for HybridArena {

impl HybridArena {
/// Function to create a new HybridArena from an existing instance of shared memory in FrozenArena.
#[allow(dead_code)]
pub fn from_frozen(name: String, frozen_arena: FrozenArena) -> Self {
let allocator = Allocator::new_with_initial_stats(
name,
Expand All @@ -114,7 +113,6 @@ impl HybridArena {
///
/// Instances of FrozenArena are cloneable and can be used to create new instances of HybridArena with
/// shared memory from FrozenArena.
#[allow(dead_code)]
pub fn freeze(self) -> FrozenArena {
assert!(!self.has_shared_memory(), "Cannot freeze arena with shared memory");
FrozenArena {
Expand Down
1 change: 1 addition & 0 deletions core/store/src/trie/mem/arena/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod frozen;
pub mod hybrid;
mod metrics;
pub mod single_thread;
pub use frozen::FrozenArena;

/// An abstraction of an arena that also allows being implemented differently,
/// specifically in the case of a multi-threaded arena where each arena instance
Expand Down
25 changes: 25 additions & 0 deletions core/store/src/trie/mem/mem_tries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::Trie;
use super::arena::hybrid::{HybridArena, HybridArenaMemory};
use super::arena::single_thread::STArena;
use super::arena::Arena;
use super::arena::FrozenArena;
use super::flexible_data::value::ValueView;
use super::iter::STMemTrieIterator;
use super::lookup::memtrie_lookup;
Expand Down Expand Up @@ -42,6 +43,15 @@ pub struct MemTries {
shard_uid: ShardUId,
}

/// Frozen arena together with supported roots and heights.
/// Used to construct new memtries which share nodes from the same arena.
///
/// Produced by `MemTries::freeze` and consumed by
/// `MemTries::from_frozen_memtries`. The type is `Clone`, so a single frozen
/// value can be used to build more than one `MemTries`.
#[derive(Clone)]
pub struct FrozenMemTries {
/// Frozen arena; passed to `HybridArena::from_frozen` when a memtrie is
/// rebuilt from this value.
arena: FrozenArena,
/// The `roots` map of the memtrie that was frozen, moved over unchanged.
roots: HashMap<StateRoot, Vec<MemTrieNodeId>>,
/// The `heights` map of the memtrie that was frozen, moved over unchanged.
heights: BTreeMap<BlockHeight, Vec<StateRoot>>,
}

impl MemTries {
pub fn new(shard_uid: ShardUId) -> Self {
Self {
Expand All @@ -52,6 +62,15 @@ impl MemTries {
}
}

pub fn from_frozen_memtries(shard_uid: ShardUId, frozen_memtries: FrozenMemTries) -> Self {
Self {
arena: HybridArena::from_frozen(shard_uid.to_string(), frozen_memtries.arena),
roots: frozen_memtries.roots,
heights: frozen_memtries.heights,
shard_uid,
}
}

pub fn new_from_arena_and_root(
shard_uid: ShardUId,
block_height: BlockHeight,
Expand Down Expand Up @@ -192,6 +211,12 @@ impl MemTries {
Ok(memtrie_lookup(root, key, nodes_accessed))
}

/// Freezes memtrie. The result is used as a shared data to construct new
/// memtries.
pub fn freeze(self) -> FrozenMemTries {
FrozenMemTries { arena: self.arena.freeze(), roots: self.roots, heights: self.heights }
}

#[cfg(test)]
pub fn arena(&self) -> &HybridArena {
&self.arena
Expand Down
21 changes: 16 additions & 5 deletions core/store/src/trie/mem/resharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{NibbleSlice, TrieChanges};
use super::arena::ArenaMemory;
use super::updating::{MemTrieUpdate, OldOrUpdatedNodeId, TrieAccesses, UpdatedMemTrieNode};
use itertools::Itertools;
use near_primitives::trie_key::col::COLUMNS_WITH_ACCOUNT_ID_IN_KEY;
use near_primitives::types::AccountId;
use std::ops::Range;

Expand Down Expand Up @@ -32,12 +33,22 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
/// responsibility to apply the changes.
pub fn retain_split_shard(
self,
_boundary_account: AccountId,
_retain_mode: RetainMode,
boundary_account: &AccountId,
retain_mode: RetainMode,
) -> (TrieChanges, TrieAccesses) {
// TODO(#12074): generate intervals in nibbles.

self.retain_multi_range(&[])
let mut intervals = vec![];
// TODO(#12074): generate correct intervals in nibbles.
for (col, _) in COLUMNS_WITH_ACCOUNT_ID_IN_KEY {
match retain_mode {
RetainMode::Left => {
intervals.push(vec![col]..[&[col], boundary_account.as_bytes()].concat())
}
RetainMode::Right => {
intervals.push([&[col], boundary_account.as_bytes()].concat()..vec![col + 1])
}
}
}
self.retain_multi_range(&intervals)
}

/// Retains keys belonging to any of the ranges given in `intervals` from
Expand Down
30 changes: 30 additions & 0 deletions core/store/src/trie/shard_tries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,36 @@ impl ShardTries {
memtries.write().unwrap().delete_until_height(height);
}
}

/// Freezes the in-memory trie of the source shard and registers memtries
/// built from the frozen data for the source shard and every target shard.
///
/// Needed to serve queries for these shards just after resharding, before
/// proper memtries are loaded.
///
/// The entry for `source_shard_uid` is removed from the map and its contents
/// are swapped out for a fresh `MemTries::new(source_shard_uid)` before being
/// frozen. A new entry is then inserted for the source shard and for each of
/// `target_shard_uids`, each created with `MemTries::from_frozen_memtries`.
///
/// # Errors
///
/// Returns `StorageError::MemTrieLoadingError` if no memtrie is registered
/// for `source_shard_uid`.
///
/// # Panics
///
/// Panics if either lock is poisoned. `MemTries::freeze` may also panic;
/// NOTE(review): presumably this happens when the source memtrie was itself
/// built from frozen data (e.g. a second call for the same shard) — confirm.
pub fn freeze_mem_tries(
    &self,
    source_shard_uid: ShardUId,
    target_shard_uids: Vec<ShardUId>,
) -> Result<(), StorageError> {
    let mut all_mem_tries = self.0.mem_tries.write().unwrap();
    let source_lock = match all_mem_tries.remove(&source_shard_uid) {
        Some(lock) => lock,
        None => {
            return Err(StorageError::MemTrieLoadingError("Memtrie not loaded".to_string()));
        }
    };

    // Take the real memtries out of the old lock, leaving a fresh instance
    // behind for anyone still holding a reference to that lock.
    let mut source_guard = source_lock.write().unwrap();
    let placeholder = MemTries::new(source_shard_uid);
    let frozen = std::mem::replace(&mut *source_guard, placeholder).freeze();

    // Source shard first, then the targets, matching the original insertion order.
    let shard_uids = std::iter::once(source_shard_uid).chain(target_shard_uids);
    for shard_uid in shard_uids {
        let mem_tries = MemTries::from_frozen_memtries(shard_uid, frozen.clone());
        all_mem_tries.insert(shard_uid, Arc::new(RwLock::new(mem_tries)));
    }

    Ok(())
}
}

pub struct WrappedTrieChanges {
Expand Down
29 changes: 26 additions & 3 deletions integration-tests/src/test_loop/tests/resharding_v3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use near_primitives::epoch_manager::EpochConfigStore;
use near_primitives::shard_layout::ShardLayout;
use near_primitives::types::{AccountId, ShardId};
use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION};
use near_store::ShardUId;
use std::collections::BTreeMap;
use std::sync::Arc;

Expand Down Expand Up @@ -69,9 +70,9 @@ fn test_resharding_v3() {
let new_shards = vec![max_shard_id, max_shard_id + 1];
shard_ids.extend(new_shards.clone());
shards_split_map.insert(last_shard_id, new_shards);
boundary_accounts.push(AccountId::try_from("xyz.near".to_string()).unwrap());
boundary_accounts.push(AccountId::try_from("account6".to_string()).unwrap());
epoch_config.shard_layout =
ShardLayout::v2(boundary_accounts, shard_ids, Some(shards_split_map));
ShardLayout::v2(boundary_accounts, shard_ids.clone(), Some(shards_split_map));
let expected_num_shards = epoch_config.shard_layout.shard_ids().count();
let epoch_config_store = EpochConfigStore::test(BTreeMap::from_iter(vec![
(base_protocol_version, Arc::new(base_epoch_config)),
Expand Down Expand Up @@ -113,7 +114,29 @@ fn test_resharding_v3() {
let prev_epoch_id =
client.epoch_manager.get_prev_epoch_id_from_prev_block(&tip.prev_block_hash).unwrap();
let epoch_config = client.epoch_manager.get_epoch_config(&prev_epoch_id).unwrap();
epoch_config.shard_layout.shard_ids().count() == expected_num_shards
if epoch_config.shard_layout.shard_ids().count() != expected_num_shards {
return false;
}

// If resharding happened, also check that each shard has non-empty state.
for shard_id in 0..3 {
let shard_uid = ShardUId { version: 3, shard_id: shard_id as u32 };
let chunk_extra =
client.chain.get_chunk_extra(&tip.prev_block_hash, &shard_uid).unwrap();
let trie = client
.runtime_adapter
.get_trie_for_shard(
shard_id,
&tip.prev_block_hash,
*chunk_extra.state_root(),
false,
)
.unwrap();
let items = trie.lock_for_iter().iter().unwrap().count();
assert!(items > 0);
}

return true;
};

test_loop.run_until(
Expand Down

0 comments on commit 3221b86

Please sign in to comment.