diff --git a/crates/bdk/src/wallet/mod.rs b/crates/bdk/src/wallet/mod.rs index 02f4520a77..e0ac89c1b7 100644 --- a/crates/bdk/src/wallet/mod.rs +++ b/crates/bdk/src/wallet/mod.rs @@ -382,7 +382,7 @@ impl Wallet { } /// Returns the latest checkpoint. - pub fn latest_checkpoint(&self) -> Option> { + pub fn latest_checkpoint(&self) -> Option { self.chain.tip() } @@ -504,7 +504,7 @@ impl Wallet { .range(height..) .next() .ok_or(InsertTxError::ConfirmationHeightCannotBeGreaterThanTip { - tip_height: self.chain.tip().map(|b| b.height), + tip_height: self.chain.tip().map(|b| b.height()), tx_height: height, }) .map(|(&_, cp)| ConfirmationTimeAnchor { @@ -713,8 +713,7 @@ impl Wallet { None => self .chain .tip() - .and_then(|cp| cp.height.into()) - .map(|height| LockTime::from_height(height).expect("Invalid height")), + .map(|cp| LockTime::from_height(cp.height()).expect("Invalid height")), h => h, }; @@ -1286,7 +1285,7 @@ impl Wallet { }); let current_height = sign_options .assume_height - .or(self.chain.tip().map(|b| b.height)); + .or(self.chain.tip().map(|b| b.height())); debug!( "Input #{} - {}, using `confirmation_height` = {:?}, `current_height` = {:?}", diff --git a/crates/bdk/tests/wallet.rs b/crates/bdk/tests/wallet.rs index 9f6488a0aa..ed014f70a1 100644 --- a/crates/bdk/tests/wallet.rs +++ b/crates/bdk/tests/wallet.rs @@ -45,7 +45,7 @@ fn receive_output(wallet: &mut Wallet, value: u64, height: ConfirmationTime) -> fn receive_output_in_latest_block(wallet: &mut Wallet, value: u64) -> OutPoint { let height = match wallet.latest_checkpoint() { Some(cp) => ConfirmationTime::Confirmed { - height: cp.height, + height: cp.height(), time: 0, }, None => ConfirmationTime::Unconfirmed { last_seen: 0 }, @@ -225,7 +225,7 @@ fn test_create_tx_fee_sniping_locktime_last_sync() { // If there's no current_height we're left with using the last sync height assert_eq!( psbt.unsigned_tx.lock_time.0, - wallet.latest_checkpoint().unwrap().height + wallet.latest_checkpoint().unwrap().height() ); } @@ -1485,7 +1485,7 @@ fn test_bump_fee_drain_wallet() { .insert_tx( tx.clone(), ConfirmationTime::Confirmed { - height: wallet.latest_checkpoint().unwrap().height, + height: wallet.latest_checkpoint().unwrap().height(), time: 42_000, }, ) diff --git a/crates/chain/src/local_chain.rs b/crates/chain/src/local_chain.rs index 826d5f4397..9fce8dcb7d 100644 --- a/crates/chain/src/local_chain.rs +++ b/crates/chain/src/local_chain.rs @@ -12,79 +12,116 @@ pub type ChangeSet = BTreeMap>; /// Represents a block of [`LocalChain`]. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct CheckPoint { - /// Block height. - pub height: u32, - /// Block hash. - pub hash: BlockHash, +pub struct CheckPoint(Arc); + +/// The internal contents of [`CheckPoint`]. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct CPInner { + /// Block id (hash and height). + block: BlockId, /// Previous checkpoint (if any). - pub prev: Option>, + prev: Option>, } +/// Occurs when the caller contructs a [`CheckPoint`] with a height that is not higher than the +/// previous checkpoint it points to. +#[derive(Debug, Clone, PartialEq)] +pub struct NewCheckPointError { + /// The height of the new checkpoint. + pub new_height: u32, + /// The height of the previous checkpoint. + pub prev_height: u32, +} + +impl core::fmt::Display for NewCheckPointError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "cannot construct checkpoint with a height ({}) that is not higher than the previous checkpoint ({})", self.new_height, self.prev_height) + } +} + +impl std::error::Error for NewCheckPointError {} + impl CheckPoint { + /// Construct a [`CheckPoint`] from a [`BlockId`]. + pub fn new(block: BlockId) -> Self { + Self(Arc::new(CPInner { block, prev: None })) + } + + /// Construct a [`CheckPoint`] of `block` with a previous checkpoint. + pub fn new_with_prev( + block: BlockId, + prev: Option, + ) -> Result { + if let Some(prev_cp) = &prev { + if prev_cp.height() >= block.height { + return Err(NewCheckPointError { + new_height: block.height, + prev_height: prev_cp.height(), + }); + } + } + + Ok(Self(Arc::new(CPInner { + block, + prev: prev.map(|cp| cp.0), + }))) + } + /// Get the [`BlockId`] of the checkpoint. pub fn block_id(&self) -> BlockId { - BlockId { - height: self.height, - hash: self.hash, - } + self.0.block + } + + /// Get the height of the checkpoint. + pub fn height(&self) -> u32 { + self.0.block.height + } + + /// Get the block hash of the checkpoint. + pub fn hash(&self) -> BlockHash { + self.0.block.hash } /// Detach this checkpoint from the previous. - pub fn detach(self: Arc) -> Arc { - Arc::new(Self { - height: self.height, - hash: self.hash, + pub fn detach(self) -> Self { + Self(Arc::new(CPInner { + block: self.0.block, prev: None, - }) + })) + } + + /// Get previous checkpoint. + pub fn prev(&self) -> Option { + self.0.prev.clone().map(CheckPoint) } /// Iterate - pub fn iter(self: &Arc) -> CheckPointIter { + pub fn iter(&self) -> CheckPointIter { CheckPointIter { - current: Some(Arc::clone(self)), + current: Some(Arc::clone(&self.0)), } } } /// A structure that iterates over checkpoints backwards. pub struct CheckPointIter { - current: Option>, + current: Option>, } impl Iterator for CheckPointIter { - type Item = Arc; + type Item = CheckPoint; fn next(&mut self) -> Option { let current = self.current.clone()?; self.current = current.prev.clone(); - Some(current) + Some(CheckPoint(current)) } } /// This is a local implementation of [`ChainOracle`]. #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct LocalChain { - checkpoints: BTreeMap>, -} - -impl FromIterator> for LocalChain { - fn from_iter>>(iter: T) -> Self { - let mut chain = Self { - checkpoints: iter.into_iter().map(|cp| (cp.height, cp)).collect(), - }; - chain.fix_links(0); - chain - } -} - -impl IntoIterator for LocalChain { - type Item = Arc; - type IntoIter = CheckPointIter; - - fn into_iter(self) -> Self::IntoIter { - self.iter_checkpoints(None) - } + checkpoints: BTreeMap, } impl From for BTreeMap { @@ -92,7 +129,7 @@ impl From for BTreeMap { value .checkpoints .values() - .map(|cp| (cp.height, cp.hash)) + .map(|cp| (cp.height(), cp.hash())) .collect() } } @@ -130,7 +167,7 @@ impl ChainOracle for LocalChain { self.checkpoints.get(&chain_tip.height), ) { (Some(cp), Some(tip_cp)) => { - Some(cp.hash == block.hash && tip_cp.hash == chain_tip.hash) + Some(cp.hash() == block.hash && tip_cp.hash() == chain_tip.hash) } _ => None, }, @@ -138,10 +175,7 @@ impl ChainOracle for LocalChain { } fn get_chain_tip(&self) -> Result, Self::Error> { - Ok(self.checkpoints.values().last().map(|cp| BlockId { - height: cp.height, - hash: cp.hash, - })) + Ok(self.checkpoints.values().last().map(CheckPoint::block_id)) } } @@ -154,7 +188,7 @@ impl LocalChain { } /// Construct a [`LocalChain`] from a given `checkpoint` tip. - pub fn from_checkpoint(checkpoint: Arc) -> Result { + pub fn from_checkpoint(checkpoint: CheckPoint) -> Result { match Self::default().extend_from_last_agreement(checkpoint) { Ok((chain, _)) => Ok(chain), Err(ExtendError::InvalidHeightOrder(err)) => Err(err), @@ -172,16 +206,7 @@ impl LocalChain { let mut chain = Self { checkpoints: blocks .into_iter() - .map(|b| { - ( - b.height, - Arc::new(CheckPoint { - height: b.height, - hash: b.hash, - prev: None, - }), - ) - }) + .map(|b| (b.height, CheckPoint::new(b))) .collect(), }; chain.fix_links(0); @@ -189,7 +214,7 @@ impl LocalChain { } /// Get the highest checkpoint. - pub fn tip(&self) -> Option> { + pub fn tip(&self) -> Option { self.checkpoints.values().last().cloned() } @@ -207,7 +232,7 @@ impl LocalChain { /// [module-level documentation]: crate::local_chain pub fn determine_changeset( &self, - new_tip: &Arc, + new_tip: &CheckPoint, must_include_height: Option, ) -> Result { // The changeset if we only records additions from `update`. @@ -216,24 +241,28 @@ impl LocalChain { let mut agreement_height = Option::::None; for update_cp in new_tip.iter() { - let original_cp = self.checkpoints.get(&update_cp.height); - - match original_cp { - Some(original_cp) if original_cp.hash == update_cp.hash => {} + let update_id = update_cp.block_id(); + let original_id = self + .checkpoints + .get(&update_cp.height()) + .map(CheckPoint::block_id); + + match original_id { + Some(original_id) if original_id.hash == update_id.hash => {} _ => { - additions.insert(update_cp.height, Some(update_cp.hash)); + additions.insert(update_id.height, Some(update_id.hash)); } }; if agreement_height.is_none() { - if let Some(original_cp) = original_cp { - if update_cp.hash == original_cp.hash { - agreement_height = Some(update_cp.height); + if let Some(original_id) = original_id { + if update_id.hash == original_id.hash { + agreement_height = Some(update_id.height); } } } else { if let Some(must_include_height) = must_include_height { - if update_cp.height > must_include_height { + if update_id.height > must_include_height { continue; } } @@ -246,7 +275,7 @@ impl LocalChain { // if there is no agreement, we invalidate all of the original chain None => u32::MIN, // if the agreement is at the update's tip, we don't need to invalidate - Some(height) if height == new_tip.height => u32::MAX, + Some(height) if height == new_tip.height() => u32::MAX, Some(height) => height + 1, }; @@ -279,41 +308,47 @@ impl LocalChain { /// This will replace all checkpoints higher than the last common checkpoint. pub fn extend_from_last_agreement( &self, - new_tip: Arc, + new_tip: CheckPoint, ) -> Result<(Self, ChangeSet), ExtendError> { if self.checkpoints.is_empty() { return Ok(new_tip.iter().fold( Default::default(), |(mut acc_chain, mut acc_changeset), cp| { - acc_changeset.insert(cp.height, Some(cp.hash)); - acc_chain.checkpoints.insert(cp.height, cp); + let cp_block = cp.block_id(); + acc_changeset.insert(cp_block.height, Some(cp_block.hash)); + acc_chain.checkpoints.insert(cp_block.height, cp); (acc_chain, acc_changeset) }, )); } let (new_cps, agreement_point) = { - let mut new_cps = Vec::>::new(); + let mut new_cps = Vec::::new(); let mut agreement_point = Option::::None; let mut last_height = u32::MAX; for cp in new_tip.iter() { + let cp_block = cp.block_id(); // ensure incoming checkpoints are consistent - if cp.height >= last_height { + if cp_block.height >= last_height { return Err(ExtendError::InvalidHeightOrder(InvalidHeightOrderError { last_height, - conflicting_block: BlockId { - height: cp.height, - hash: cp.hash, - }, + conflicting_block: cp_block, })); } - last_height = cp.height; + last_height = cp_block.height; + + let update_cp_ptr = Some(Arc::as_ptr(&cp.0)); + let chain_cp_ptr = self + .checkpoints + .get(&cp_block.height) + .map(|cp| Arc::as_ptr(&cp.0)); - if Some(Arc::as_ptr(&cp)) == self.checkpoints.get(&cp.height).map(Arc::as_ptr) { - agreement_point = Some(cp.height); + if update_cp_ptr == chain_cp_ptr { + agreement_point = Some(cp_block.height); break; } + new_cps.push(cp); } @@ -335,15 +370,22 @@ impl LocalChain { .checkpoints .range(agreement_point + 1..) .map(|(&height, _)| (height, Option::::None)) - .chain(new_cps.iter().map(|cp| (cp.height, Some(cp.hash)))) + .chain(new_cps.iter().map(|cp| (cp.height(), Some(cp.hash())))) .collect::(); - let new_chain = self - .checkpoints - .range(..=agreement_point) - .map(|(_, cp)| Arc::clone(cp)) - .chain(new_cps) - .collect::(); + let new_chain = { + let mut new_chain = Self { + checkpoints: self + .checkpoints + .range(..=agreement_point) + .map(|(_, cp)| cp.clone()) + .chain(new_cps) + .map(|cp| (cp.height(), cp)) + .collect(), + }; + new_chain.fix_links(0); + new_chain + }; Ok((new_chain, changeset)) } @@ -353,14 +395,9 @@ impl LocalChain { if let Some(start_height) = changeset.keys().next().cloned() { for (&height, &hash) in changeset { match hash { - Some(hash) => self.checkpoints.insert( - height, - Arc::new(CheckPoint { - height, - hash, - prev: None, - }), - ), + Some(hash) => self + .checkpoints + .insert(height, CheckPoint::new(BlockId { height, hash })), None => self.checkpoints.remove(&height), }; } @@ -376,7 +413,7 @@ impl LocalChain { /// [`apply_changeset`]: Self::apply_changeset pub fn apply_update( &mut self, - new_tip: Arc, + new_tip: CheckPoint, must_include_height: Option, ) -> Result { let changeset = self.determine_changeset(&new_tip, must_include_height)?; @@ -392,16 +429,12 @@ impl LocalChain { pub fn get_or_insert( &mut self, block_id: BlockId, - ) -> Result<(Arc, ChangeSet), InsertBlockError> { + ) -> Result<(CheckPoint, ChangeSet), InsertBlockError> { use crate::collections::btree_map::Entry; match self.checkpoints.entry(block_id.height) { Entry::Vacant(entry) => { - entry.insert(Arc::new(CheckPoint { - height: block_id.height, - hash: block_id.hash, - prev: None, - })); + entry.insert(CheckPoint::new(block_id)); self.fix_links(block_id.height); let cp = self.checkpoint(block_id.height).expect("must be inserted"); let changeset = @@ -415,7 +448,7 @@ impl LocalChain { } else { Err(InsertBlockError { height: block_id.height, - original_hash: cp.hash, + original_hash: cp.hash(), update_hash: block_id.hash, }) } @@ -431,11 +464,10 @@ impl LocalChain { .map(|(_, cp)| cp.clone()); for (_, cp) in self.checkpoints.range_mut(start_height..) { - if cp.prev.as_ref().map(Arc::as_ptr) != prev.as_ref().map(Arc::as_ptr) { - *cp = Arc::new(CheckPoint { - height: cp.height, - hash: cp.hash, - prev: prev.clone(), + if cp.0.prev.as_ref().map(Arc::as_ptr) != prev.as_ref().map(|cp| Arc::as_ptr(&cp.0)) { + cp.0 = Arc::new(CPInner { + block: cp.block_id(), + prev: prev.clone().map(|cp| cp.0), }); } prev = Some(cp.clone()); @@ -446,12 +478,12 @@ impl LocalChain { /// recover the current chain. pub fn initial_changeset(&self) -> ChangeSet { self.iter_checkpoints(None) - .map(|cp| (cp.height, Some(cp.hash))) + .map(|cp| (cp.height(), Some(cp.hash()))) .collect() } /// Get checkpoint of `height` (if any). - pub fn checkpoint(&self, height: u32) -> Option> { + pub fn checkpoint(&self, height: u32) -> Option { self.checkpoints.get(&height).cloned() } @@ -466,14 +498,14 @@ impl LocalChain { .checkpoints .range(..=height) .last() - .map(|(_, cp)| cp.clone()), - None => self.checkpoints.values().last().cloned(), + .map(|(_, cp)| cp.0.clone()), + None => self.checkpoints.values().last().map(|cp| cp.0.clone()), }, } } /// Get a reference to the internal checkpoint map. - pub fn checkpoints(&self) -> &BTreeMap> { + pub fn checkpoints(&self) -> &BTreeMap { &self.checkpoints } } @@ -506,7 +538,7 @@ impl std::error::Error for InsertBlockError {} #[derive(Clone, Debug, PartialEq)] pub struct CannotConnectError { /// The suggested checkpoint to include to connect the two chains. - pub try_include: Arc, + pub try_include: CheckPoint, } impl core::fmt::Display for CannotConnectError { @@ -514,7 +546,8 @@ impl core::fmt::Display for CannotConnectError { write!( f, "introduced chain cannot connect with the original chain, try include {}:{}", - self.try_include.height, self.try_include.hash + self.try_include.height(), + self.try_include.hash() ) } } diff --git a/crates/chain/tests/test_indexed_tx_graph.rs b/crates/chain/tests/test_indexed_tx_graph.rs index 97b159468a..3319b25940 100644 --- a/crates/chain/tests/test_indexed_tx_graph.rs +++ b/crates/chain/tests/test_indexed_tx_graph.rs @@ -8,7 +8,7 @@ use bdk_chain::{ keychain::{Balance, DerivationAdditions, KeychainTxOutIndex}, local_chain::LocalChain, tx_graph::Additions, - BlockId, ChainPosition, ConfirmationHeightAnchor, + ChainPosition, ConfirmationHeightAnchor, }; use bitcoin::{secp256k1::Secp256k1, BlockHash, OutPoint, Script, Transaction, TxIn, TxOut}; use miniscript::Descriptor; @@ -213,10 +213,7 @@ fn test_list_owned_txouts() { *tx, local_chain .checkpoint(height) - .map(|cp| BlockId { - height: cp.height, - hash: cp.hash, - }) + .map(|cp| cp.block_id()) .map(|anchor_block| ConfirmationHeightAnchor { anchor_block, confirmation_height: anchor_block.height, @@ -234,10 +231,7 @@ fn test_list_owned_txouts() { graph: &IndexedTxGraph>| { let chain_tip = local_chain .checkpoint(height) - .map(|cp| BlockId { - height: cp.height, - hash: cp.hash, - }) + .map(|cp| cp.block_id()) .expect("block must exist"); let txouts = graph .graph() diff --git a/example-crates/example_electrum/src/main.rs b/example-crates/example_electrum/src/main.rs index a38c3fad6b..d0a6d43233 100644 --- a/example-crates/example_electrum/src/main.rs +++ b/example-crates/example_electrum/src/main.rs @@ -141,7 +141,7 @@ fn main() -> anyhow::Result<()> { let c = chain .iter_checkpoints(None) .take(ASSUME_FINAL_DEPTH) - .map(|cp| (cp.height, cp.hash)) + .map(|cp| (cp.height(), cp.hash())) .collect::>(); (keychain_spks, c) @@ -251,7 +251,7 @@ fn main() -> anyhow::Result<()> { let c = chain .iter_checkpoints(None) .take(ASSUME_FINAL_DEPTH) - .map(|cp| (cp.height, cp.hash)) + .map(|cp| (cp.height(), cp.hash())) .collect::>(); // drop lock on graph and chain diff --git a/example-crates/wallet_electrum/src/main.rs b/example-crates/wallet_electrum/src/main.rs index 12aa29a64c..63037e669c 100644 --- a/example-crates/wallet_electrum/src/main.rs +++ b/example-crates/wallet_electrum/src/main.rs @@ -53,7 +53,7 @@ fn main() -> Result<(), Box> { .collect(); let electrum_update = client.scan( - &local_chain.map(|cp| (cp.height, cp.hash)).collect(), + &local_chain.map(|cp| (cp.height(), cp.hash())).collect(), keychain_spks, None, None, diff --git a/example-crates/wallet_esplora/src/main.rs b/example-crates/wallet_esplora/src/main.rs index 5b1a884f56..95fdecfc33 100644 --- a/example-crates/wallet_esplora/src/main.rs +++ b/example-crates/wallet_esplora/src/main.rs @@ -53,7 +53,7 @@ fn main() -> Result<(), Box> { }) .collect(); let update = client.scan( - &local_chain.map(|cp| (cp.height, cp.hash)).collect(), + &local_chain.map(|cp| (cp.height(), cp.hash())).collect(), keychain_spks, None, None, diff --git a/example-crates/wallet_esplora_async/src/main.rs b/example-crates/wallet_esplora_async/src/main.rs index 25d2b0f82f..9310563f31 100644 --- a/example-crates/wallet_esplora_async/src/main.rs +++ b/example-crates/wallet_esplora_async/src/main.rs @@ -55,7 +55,7 @@ async fn main() -> Result<(), Box> { .collect(); let update = client .scan( - &local_chain.map(|cp| (cp.height, cp.hash)).collect(), + &local_chain.map(|cp| (cp.height(), cp.hash())).collect(), keychain_spks, [], [],