Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add index method for StacksEpochId #5350

Merged
merged 13 commits into from
Nov 18, 2024
Merged
81 changes: 81 additions & 0 deletions stacks-common/src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::cmp::Ordering;
use std::fmt;
use std::ops::{Deref, DerefMut, Index, IndexMut};

#[cfg(feature = "canonical")]
pub mod sqlite;
Expand Down Expand Up @@ -460,3 +461,83 @@ impl<L: PartialEq + Eq> Ord for StacksEpoch<L> {
self.epoch_id.cmp(&other.epoch_id)
}
}

/// A wrapper for holding a list of Epochs, indexable by StacksEpochId
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Eq)]
pub struct EpochList<L: Clone>(Vec<StacksEpoch<L>>);

impl<L: Clone> EpochList<L> {
pub fn new(epochs: &[StacksEpoch<L>]) -> EpochList<L> {
EpochList(epochs.to_vec())
}

pub fn get(&self, index: StacksEpochId) -> Option<&StacksEpoch<L>> {
self.0.get(StacksEpoch::find_epoch_by_id(&self.0, index)?)
}

pub fn get_mut(&mut self, index: StacksEpochId) -> Option<&mut StacksEpoch<L>> {
let index = StacksEpoch::find_epoch_by_id(&self.0, index)?;
self.0.get_mut(index)
}

/// Truncates the list after the given epoch id
pub fn truncate_after(&mut self, epoch_id: StacksEpochId) {
if let Some(index) = StacksEpoch::find_epoch_by_id(&self.0, epoch_id) {
self.0.truncate(index + 1);
}
}

/// Determine which epoch, if any, a given burnchain height falls into.
pub fn epoch_id_at_height(&self, height: u64) -> Option<StacksEpochId> {
StacksEpoch::find_epoch(self, height).map(|idx| self.0[idx].epoch_id)
}

/// Determine which epoch, if any, a given burnchain height falls into.
pub fn epoch_at_height(&self, height: u64) -> Option<StacksEpoch<L>> {
StacksEpoch::find_epoch(self, height).map(|idx| self.0[idx].clone())
}

/// Pushes a new `StacksEpoch` to the end of the list
pub fn push(&mut self, epoch: StacksEpoch<L>) {
if let Some(last) = self.0.last() {
assert!(
epoch.start_height == last.end_height && epoch.epoch_id > last.epoch_id,
"Epochs must be pushed in order"
);
}
self.0.push(epoch);
}

pub fn to_vec(&self) -> Vec<StacksEpoch<L>> {
self.0.clone()
}
}

impl<L: Clone> Index<StacksEpochId> for EpochList<L> {
type Output = StacksEpoch<L>;
fn index(&self, index: StacksEpochId) -> &StacksEpoch<L> {
self.get(index)
.expect("Invalid StacksEpochId: could not find corresponding epoch")
}
}

impl<L: Clone> IndexMut<StacksEpochId> for EpochList<L> {
fn index_mut(&mut self, index: StacksEpochId) -> &mut StacksEpoch<L> {
self.get_mut(index)
.expect("Invalid StacksEpochId: could not find corresponding epoch")
}
}

impl<L: Clone> Deref for EpochList<L> {
type Target = [StacksEpoch<L>];

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<L: Clone> DerefMut for EpochList<L> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
14 changes: 7 additions & 7 deletions stackslib/src/burnchains/bitcoin/indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::burnchains::{
Burnchain, BurnchainBlockHeader, Error as burnchain_error, MagicBytes, BLOCKSTACK_MAGIC_MAINNET,
};
use crate::core::{
StacksEpoch, StacksEpochExtension, STACKS_EPOCHS_MAINNET, STACKS_EPOCHS_REGTEST,
EpochList, StacksEpoch, StacksEpochExtension, STACKS_EPOCHS_MAINNET, STACKS_EPOCHS_REGTEST,
STACKS_EPOCHS_TESTNET,
};
use crate::util_lib::db::Error as DBError;
Expand Down Expand Up @@ -91,11 +91,11 @@ impl TryFrom<u32> for BitcoinNetworkType {
/// Get the default epochs definitions for the given BitcoinNetworkType.
/// Should *not* be used except by the BitcoinIndexer when no epochs vector
/// was specified.
pub fn get_bitcoin_stacks_epochs(network_id: BitcoinNetworkType) -> Vec<StacksEpoch> {
pub fn get_bitcoin_stacks_epochs(network_id: BitcoinNetworkType) -> EpochList {
match network_id {
BitcoinNetworkType::Mainnet => STACKS_EPOCHS_MAINNET.to_vec(),
BitcoinNetworkType::Testnet => STACKS_EPOCHS_TESTNET.to_vec(),
BitcoinNetworkType::Regtest => STACKS_EPOCHS_REGTEST.to_vec(),
BitcoinNetworkType::Mainnet => (*STACKS_EPOCHS_MAINNET).clone(),
BitcoinNetworkType::Testnet => (*STACKS_EPOCHS_TESTNET).clone(),
BitcoinNetworkType::Regtest => (*STACKS_EPOCHS_REGTEST).clone(),
}
}

Expand All @@ -112,7 +112,7 @@ pub struct BitcoinIndexerConfig {
pub spv_headers_path: String,
pub first_block: u64,
pub magic_bytes: MagicBytes,
pub epochs: Option<Vec<StacksEpoch>>,
pub epochs: Option<EpochList>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -1041,7 +1041,7 @@ impl BurnchainIndexer for BitcoinIndexer {
/// 2) Use hard-coded static values, otherwise.
///
/// It is an error (panic) to set custom epochs if running on `Mainnet`.
fn get_stacks_epochs(&self) -> Vec<StacksEpoch> {
fn get_stacks_epochs(&self) -> EpochList {
StacksEpoch::get_epochs(self.runtime.network_id, self.config.epochs.as_ref())
}

Expand Down
3 changes: 2 additions & 1 deletion stackslib/src/burnchains/burnchain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use stacks_common::util::hash::to_hex;
use stacks_common::util::vrf::VRFPublicKey;
use stacks_common::util::{get_epoch_time_ms, get_epoch_time_secs, log, sleep_ms};

use super::EpochList;
use crate::burnchains::affirmation::update_pox_affirmation_maps;
use crate::burnchains::bitcoin::address::{
to_c32_version_byte, BitcoinAddress, LegacyBitcoinAddressType,
Expand Down Expand Up @@ -718,7 +719,7 @@ impl Burnchain {
readwrite: bool,
first_block_header_hash: BurnchainHeaderHash,
first_block_header_timestamp: u64,
epochs: Vec<StacksEpoch>,
epochs: EpochList,
) -> Result<(SortitionDB, BurnchainDB), burnchain_error> {
Burnchain::setup_chainstate_dirs(&self.working_dir)?;

Expand Down
2 changes: 1 addition & 1 deletion stackslib/src/burnchains/indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub trait BurnchainIndexer {
fn get_first_block_height(&self) -> u64;
fn get_first_block_header_hash(&self) -> Result<BurnchainHeaderHash, burnchain_error>;
fn get_first_block_header_timestamp(&self) -> Result<u64, burnchain_error>;
fn get_stacks_epochs(&self) -> Vec<StacksEpoch>;
fn get_stacks_epochs(&self) -> EpochList;

fn get_headers_path(&self) -> String;
fn get_headers_height(&self) -> Result<u64, burnchain_error>;
Expand Down
22 changes: 11 additions & 11 deletions stackslib/src/chainstate/burn/db/sortdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2969,9 +2969,9 @@ impl SortitionDB {
db_tx: &Transaction,
epochs: &[StacksEpoch],
) -> Result<(), db_error> {
let epochs = StacksEpoch::validate_epochs(epochs);
let epochs: &[StacksEpoch] = &StacksEpoch::validate_epochs(epochs);
let existing_epochs = Self::get_stacks_epochs(db_tx)?;
if existing_epochs == epochs {
if &existing_epochs == epochs {
return Ok(());
}

Expand Down Expand Up @@ -3482,9 +3482,10 @@ impl SortitionDB {
tx.commit()?;
} else if version == expected_version {
// this transaction is almost never needed
let validated_epochs = StacksEpoch::validate_epochs(epochs);
let validated_epochs: &[StacksEpoch] =
&StacksEpoch::validate_epochs(epochs);
let existing_epochs = Self::get_stacks_epochs(self.conn())?;
if existing_epochs == validated_epochs {
if &existing_epochs == validated_epochs {
return Ok(());
}

Expand Down Expand Up @@ -6636,7 +6637,7 @@ pub mod tests {
pub fn connect_test_with_epochs(
first_block_height: u64,
first_burn_hash: &BurnchainHeaderHash,
epochs: Vec<StacksEpoch>,
epochs: EpochList,
) -> Result<SortitionDB, db_error> {
let mut rng = rand::thread_rng();
let mut buf = [0u8; 32];
Expand Down Expand Up @@ -10930,10 +10931,9 @@ pub mod tests {

fs::create_dir_all(path_root).unwrap();

let mut bad_epochs = STACKS_EPOCHS_MAINNET.to_vec();
let idx = bad_epochs.len() - 2;
bad_epochs[idx].end_height += 1;
bad_epochs[idx + 1].start_height += 1;
let mut bad_epochs = (*STACKS_EPOCHS_MAINNET).clone();
bad_epochs[StacksEpochId::Epoch25].end_height += 1;
bad_epochs[StacksEpochId::Epoch30].start_height += 1;

let sortdb = SortitionDB::connect(
&format!("{}/sortdb.sqlite", &path_root),
Expand All @@ -10948,14 +10948,14 @@ pub mod tests {
.unwrap();

let db_epochs = SortitionDB::get_stacks_epochs(sortdb.conn()).unwrap();
assert_eq!(db_epochs, bad_epochs);
assert_eq!(db_epochs, bad_epochs.to_vec());

let fixed_sortdb = SortitionDB::connect(
&format!("{}/sortdb.sqlite", &path_root),
0,
&BurnchainHeaderHash([0x00; 32]),
0,
&STACKS_EPOCHS_MAINNET.to_vec(),
&STACKS_EPOCHS_MAINNET,
PoxConstants::mainnet_default(),
None,
true,
Expand Down
2 changes: 1 addition & 1 deletion stackslib/src/chainstate/coordinator/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ pub fn setup_states_with_epochs(
pox_consts: Option<PoxConstants>,
initial_balances: Option<Vec<(PrincipalData, u64)>>,
stacks_epoch_id: StacksEpochId,
epochs_opt: Option<Vec<StacksEpoch>>,
epochs_opt: Option<EpochList>,
) {
let mut burn_block = None;
let mut others = vec![];
Expand Down
2 changes: 1 addition & 1 deletion stackslib/src/chainstate/stacks/boot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1640,7 +1640,7 @@ pub mod test {
pub fn instantiate_pox_peer_with_epoch<'a>(
burnchain: &Burnchain,
test_name: &str,
epochs: Option<Vec<StacksEpoch>>,
epochs: Option<EpochList>,
observer: Option<&'a TestEventObserver>,
) -> (TestPeer<'a>, Vec<StacksPrivateKey>) {
let mut peer_config = TestPeerConfig::new(test_name, 0, 0);
Expand Down
Loading