diff --git a/crates/chain/src/indexed_tx_graph.rs b/crates/chain/src/indexed_tx_graph.rs index c550d86f07..f69b227a21 100644 --- a/crates/chain/src/indexed_tx_graph.rs +++ b/crates/chain/src/indexed_tx_graph.rs @@ -1,12 +1,9 @@ -use core::convert::Infallible; - use alloc::vec::Vec; -use bitcoin::{OutPoint, Script, Transaction, TxOut}; +use bitcoin::{OutPoint, Transaction, TxOut}; use crate::{ - keychain::Balance, tx_graph::{Additions, TxGraph}, - Anchor, Append, BlockId, ChainOracle, FullTxOut, ObservedAs, + Anchor, Append, }; /// A struct that combines [`TxGraph`] and an [`Indexer`] implementation. @@ -29,6 +26,14 @@ impl Default for IndexedTxGraph { } impl IndexedTxGraph { + /// Construct a new [`IndexedTxGraph`] with a given `index`. + pub fn new(index: I) -> Self { + Self { + index, + graph: TxGraph::default(), + } + } + /// Get a reference of the internal transaction graph. pub fn graph(&self) -> &TxGraph { &self.graph @@ -157,115 +162,6 @@ where } } -impl IndexedTxGraph { - pub fn try_list_owned_txouts<'a, C: ChainOracle + 'a>( - &'a self, - chain: &'a C, - chain_tip: BlockId, - ) -> impl Iterator>, C::Error>> + 'a { - self.graph() - .try_list_chain_txouts(chain, chain_tip) - .filter(|r| { - if let Ok(full_txout) = r { - if !self.index.is_spk_owned(&full_txout.txout.script_pubkey) { - return false; - } - } - true - }) - } - - pub fn list_owned_txouts<'a, C: ChainOracle + 'a>( - &'a self, - chain: &'a C, - chain_tip: BlockId, - ) -> impl Iterator>> + 'a { - self.try_list_owned_txouts(chain, chain_tip) - .map(|r| r.expect("oracle is infallible")) - } - - pub fn try_list_owned_unspents<'a, C: ChainOracle + 'a>( - &'a self, - chain: &'a C, - chain_tip: BlockId, - ) -> impl Iterator>, C::Error>> + 'a { - self.graph() - .try_list_chain_unspents(chain, chain_tip) - .filter(|r| { - if let Ok(full_txout) = r { - if !self.index.is_spk_owned(&full_txout.txout.script_pubkey) { - return false; - } - } - true - }) - } - - pub fn list_owned_unspents<'a, C: ChainOracle + 'a>( - &'a self, - chain: &'a C, - chain_tip: BlockId, - ) -> impl Iterator>> + 'a { - self.try_list_owned_unspents(chain, chain_tip) - .map(|r| r.expect("oracle is infallible")) - } - - pub fn try_balance( - &self, - chain: &C, - chain_tip: BlockId, - mut should_trust: F, - ) -> Result - where - C: ChainOracle, - F: FnMut(&Script) -> bool, - { - let tip_height = chain_tip.height; - - let mut immature = 0; - let mut trusted_pending = 0; - let mut untrusted_pending = 0; - let mut confirmed = 0; - - for res in self.try_list_owned_unspents(chain, chain_tip) { - let txout = res?; - - match &txout.chain_position { - ObservedAs::Confirmed(_) => { - if txout.is_confirmed_and_spendable(tip_height) { - confirmed += txout.txout.value; - } else if !txout.is_mature(tip_height) { - immature += txout.txout.value; - } - } - ObservedAs::Unconfirmed(_) => { - if should_trust(&txout.txout.script_pubkey) { - trusted_pending += txout.txout.value; - } else { - untrusted_pending += txout.txout.value; - } - } - } - } - - Ok(Balance { - immature, - trusted_pending, - untrusted_pending, - confirmed, - }) - } - - pub fn balance(&self, chain: &C, chain_tip: BlockId, should_trust: F) -> Balance - where - C: ChainOracle, - F: FnMut(&Script) -> bool, - { - self.try_balance(chain, chain_tip, should_trust) - .expect("error is infallible") - } -} - /// A structure that represents changes to an [`IndexedTxGraph`]. #[derive(Clone, Debug, PartialEq)] #[cfg_attr( @@ -301,6 +197,10 @@ impl Append for IndexedAdditions { self.graph_additions.append(other.graph_additions); self.index_additions.append(other.index_additions); } + + fn is_empty(&self) -> bool { + self.graph_additions.is_empty() && self.index_additions.is_empty() + } } /// Represents a structure that can index transaction data. @@ -320,9 +220,3 @@ pub trait Indexer { /// Determines whether the transaction should be included in the index. fn is_tx_relevant(&self, tx: &Transaction) -> bool; } - -/// A trait that extends [`Indexer`] to also index "owned" script pubkeys. -pub trait OwnedIndexer: Indexer { - /// Determines whether a given script pubkey (`spk`) is owned. - fn is_spk_owned(&self, spk: &Script) -> bool; -} diff --git a/crates/chain/src/keychain.rs b/crates/chain/src/keychain.rs index 81503049bd..f4d398ab0c 100644 --- a/crates/chain/src/keychain.rs +++ b/crates/chain/src/keychain.rs @@ -84,6 +84,10 @@ impl Append for DerivationAdditions { self.0.append(&mut other.0); } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } } impl Default for DerivationAdditions { diff --git a/crates/chain/src/keychain/txout_index.rs b/crates/chain/src/keychain/txout_index.rs index c7a8dd54b4..397c43386d 100644 --- a/crates/chain/src/keychain/txout_index.rs +++ b/crates/chain/src/keychain/txout_index.rs @@ -1,6 +1,6 @@ use crate::{ collections::*, - indexed_tx_graph::{Indexer, OwnedIndexer}, + indexed_tx_graph::Indexer, miniscript::{Descriptor, DescriptorPublicKey}, spk_iter::BIP32_MAX_INDEX, ForEachTxOut, SpkIterator, SpkTxOutIndex, @@ -109,12 +109,6 @@ impl Indexer for KeychainTxOutIndex { } } -impl OwnedIndexer for KeychainTxOutIndex { - fn is_spk_owned(&self, spk: &Script) -> bool { - self.index_of_spk(spk).is_some() - } -} - impl KeychainTxOutIndex { /// Scans an object for relevant outpoints, which are stored and indexed internally. /// @@ -153,6 +147,11 @@ impl KeychainTxOutIndex { &self.inner } + /// Get a reference to the set of indexed outpoints. + pub fn outpoints(&self) -> &BTreeSet<((K, u32), OutPoint)> { + self.inner.outpoints() + } + /// Return a reference to the internal map of the keychain to descriptors. pub fn keychains(&self) -> &BTreeMap> { &self.keychains diff --git a/crates/chain/src/lib.rs b/crates/chain/src/lib.rs index cf3cda3b02..cbadf1709a 100644 --- a/crates/chain/src/lib.rs +++ b/crates/chain/src/lib.rs @@ -33,6 +33,8 @@ pub mod tx_graph; pub use tx_data_traits::*; mod chain_oracle; pub use chain_oracle::*; +mod persist; +pub use persist::*; #[doc(hidden)] pub mod example_utils; diff --git a/crates/chain/src/persist.rs b/crates/chain/src/persist.rs new file mode 100644 index 0000000000..07ff679573 --- /dev/null +++ b/crates/chain/src/persist.rs @@ -0,0 +1,97 @@ +use core::convert::Infallible; + +use crate::Append; + +/// `Persist` wraps a [`PersistBackend`] (`B`) to create a convenient staging area for changes (`C`) +/// before they are persisted. +/// +/// Not all changes to the in-memory representation needs to be written to disk right away, so +/// [`Persist::stage`] can be used to *stage* changes first and then [`Persist::commit`] can be used +/// to write changes to disk. +#[derive(Debug)] +pub struct Persist { + backend: B, + stage: C, +} + +impl Persist +where + B: PersistBackend, + C: Default + Append, +{ + /// Create a new [`Persist`] from [`PersistBackend`]. + pub fn new(backend: B) -> Self { + Self { + backend, + stage: Default::default(), + } + } + + /// Stage a `changeset` to be commited later with [`commit`]. + /// + /// [`commit`]: Self::commit + pub fn stage(&mut self, changeset: C) { + self.stage.append(changeset) + } + + /// Get the changes that have not been commited yet. + pub fn staged(&self) -> &C { + &self.stage + } + + /// Commit the staged changes to the underlying persistance backend. + /// + /// Changes that are committed (if any) are returned. + /// + /// # Error + /// + /// Returns a backend-defined error if this fails. + pub fn commit(&mut self) -> Result, B::WriteError> { + if self.stage.is_empty() { + return Ok(None); + } + self.backend + .write_changes(&self.stage) + // if written successfully, take and return `self.stage` + .map(|_| Some(core::mem::take(&mut self.stage))) + } +} + +/// A persistence backend for [`Persist`]. +/// +/// `C` represents the changeset; a datatype that records changes made to in-memory data structures +/// that are to be persisted, or retrieved from persistence. +pub trait PersistBackend { + /// The error the backend returns when it fails to write. + type WriteError: core::fmt::Debug; + + /// The error the backend returns when it fails to load changesets `C`. + type LoadError: core::fmt::Debug; + + /// Writes a changeset to the persistence backend. + /// + /// It is up to the backend what it does with this. It could store every changeset in a list or + /// it inserts the actual changes into a more structured database. All it needs to guarantee is + /// that [`load_from_persistence`] restores a keychain tracker to what it should be if all + /// changesets had been applied sequentially. + /// + /// [`load_from_persistence`]: Self::load_from_persistence + fn write_changes(&mut self, changeset: &C) -> Result<(), Self::WriteError>; + + /// Return the aggregate changeset `C` from persistence. + fn load_from_persistence(&mut self) -> Result; +} + +impl PersistBackend for () { + type WriteError = Infallible; + + type LoadError = Infallible; + + fn write_changes(&mut self, _changeset: &C) -> Result<(), Self::WriteError> { + Ok(()) + } + + fn load_from_persistence(&mut self) -> Result { + Ok(C::default()) + } +} diff --git a/crates/chain/src/spk_txout_index.rs b/crates/chain/src/spk_txout_index.rs index ae94414921..0eaec4bb79 100644 --- a/crates/chain/src/spk_txout_index.rs +++ b/crates/chain/src/spk_txout_index.rs @@ -2,7 +2,7 @@ use core::ops::RangeBounds; use crate::{ collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap}, - indexed_tx_graph::{Indexer, OwnedIndexer}, + indexed_tx_graph::Indexer, ForEachTxOut, }; use bitcoin::{self, OutPoint, Script, Transaction, TxOut, Txid}; @@ -75,12 +75,6 @@ impl Indexer for SpkTxOutIndex { } } -impl OwnedIndexer for SpkTxOutIndex { - fn is_spk_owned(&self, spk: &Script) -> bool { - self.spk_indices.get(spk).is_some() - } -} - /// This macro is used instead of a member function of `SpkTxOutIndex`, which would result in a /// compiler error[E0521]: "borrowed data escapes out of closure" when we attempt to take a /// reference out of the `ForEachTxOut` closure during scanning. @@ -126,6 +120,11 @@ impl SpkTxOutIndex { scan_txout!(self, op, txout) } + /// Get a reference to the set of indexed outpoints. + pub fn outpoints(&self) -> &BTreeSet<(I, OutPoint)> { + &self.spk_txouts + } + /// Iterate over all known txouts that spend to tracked script pubkeys. pub fn txouts( &self, diff --git a/crates/chain/src/tx_data_traits.rs b/crates/chain/src/tx_data_traits.rs index 8ec695add3..b130793b30 100644 --- a/crates/chain/src/tx_data_traits.rs +++ b/crates/chain/src/tx_data_traits.rs @@ -1,6 +1,7 @@ use crate::collections::BTreeMap; use crate::collections::BTreeSet; use crate::BlockId; +use alloc::vec::Vec; use bitcoin::{Block, OutPoint, Transaction, TxOut}; /// Trait to do something with every txout contained in a structure. @@ -64,20 +65,56 @@ impl Anchor for &'static A { pub trait Append { /// Append another object of the same type onto `self`. fn append(&mut self, other: Self); + + /// Returns whether the structure is considered empty. + fn is_empty(&self) -> bool; } impl Append for () { fn append(&mut self, _other: Self) {} + + fn is_empty(&self) -> bool { + true + } } impl Append for BTreeMap { fn append(&mut self, mut other: Self) { BTreeMap::append(self, &mut other) } + + fn is_empty(&self) -> bool { + BTreeMap::is_empty(self) + } } impl Append for BTreeSet { fn append(&mut self, mut other: Self) { BTreeSet::append(self, &mut other) } + + fn is_empty(&self) -> bool { + BTreeSet::is_empty(self) + } +} + +impl Append for Vec { + fn append(&mut self, mut other: Self) { + Vec::append(self, &mut other) + } + + fn is_empty(&self) -> bool { + Vec::is_empty(self) + } +} + +impl Append for (A, B) { + fn append(&mut self, other: Self) { + Append::append(&mut self.0, other.0); + Append::append(&mut self.1, other.1); + } + + fn is_empty(&self) -> bool { + Append::is_empty(&self.0) && Append::is_empty(&self.1) + } } diff --git a/crates/chain/src/tx_graph.rs b/crates/chain/src/tx_graph.rs index e75255e4af..335a19197f 100644 --- a/crates/chain/src/tx_graph.rs +++ b/crates/chain/src/tx_graph.rs @@ -56,10 +56,11 @@ //! ``` use crate::{ - collections::*, Anchor, Append, BlockId, ChainOracle, ForEachTxOut, FullTxOut, ObservedAs, + collections::*, keychain::Balance, Anchor, Append, BlockId, ChainOracle, ForEachTxOut, + FullTxOut, ObservedAs, }; use alloc::vec::Vec; -use bitcoin::{OutPoint, Transaction, TxOut, Txid}; +use bitcoin::{OutPoint, Script, Transaction, TxOut, Txid}; use core::{ convert::Infallible, ops::{Deref, RangeInclusive}, @@ -762,107 +763,201 @@ impl TxGraph { .map(|r| r.expect("oracle is infallible")) } - /// List outputs that are in `chain` with `chain_tip`. + /// Get a filtered list of outputs from the given `outpoints` that are in `chain` with + /// `chain_tip`. /// - /// Floating ouputs are not iterated over. + /// `outpoints` is a list of outpoints we are interested in, coupled with an outpoint identifier + /// (`OI`) for convenience. If `OI` is not necessary, the caller can use `()`, or + /// [`Iterator::enumerate`] over a list of [`OutPoint`]s. /// - /// The `filter_predicate` should return true for outputs that we wish to iterate over. + /// Floating outputs are ignored. /// /// # Error /// - /// A returned item can error if the [`ChainOracle`] implementation (`chain`) fails. + /// An [`Iterator::Item`] can be an [`Err`] if the [`ChainOracle`] implementation (`chain`) + /// fails. /// - /// If the [`ChainOracle`] is infallible, [`list_chain_txouts`] can be used instead. + /// If the [`ChainOracle`] implementation is infallible, [`filter_chain_txouts`] can be used + /// instead. /// - /// [`list_chain_txouts`]: Self::list_chain_txouts - pub fn try_list_chain_txouts<'a, C: ChainOracle + 'a>( + /// [`filter_chain_txouts`]: Self::filter_chain_txouts + pub fn try_filter_chain_txouts<'a, C: ChainOracle + 'a, OI: Clone + 'a>( &'a self, chain: &'a C, chain_tip: BlockId, - ) -> impl Iterator>, C::Error>> + 'a { - self.try_list_chain_txs(chain, chain_tip) - .flat_map(move |tx_res| match tx_res { - Ok(canonical_tx) => canonical_tx - .node - .output - .iter() - .enumerate() - .map(|(vout, txout)| { - let outpoint = OutPoint::new(canonical_tx.node.txid, vout as _); - Ok((outpoint, txout.clone(), canonical_tx.clone())) - }) - .collect::>(), - Err(err) => vec![Err(err)], - }) - .map(move |res| -> Result<_, C::Error> { - let ( - outpoint, - txout, - CanonicalTx { - observed_as, - node: tx_node, - }, - ) = res?; - let chain_position = observed_as.cloned(); - let spent_by = self - .try_get_chain_spend(chain, chain_tip, outpoint)? - .map(|(obs_as, txid)| (obs_as.cloned(), txid)); - let is_on_coinbase = tx_node.tx.is_coin_base(); - Ok(FullTxOut { - outpoint, - txout, - chain_position, - spent_by, - is_on_coinbase, - }) - }) + outpoints: impl IntoIterator + 'a, + ) -> impl Iterator>), C::Error>> + 'a { + outpoints + .into_iter() + .map( + move |(spk_i, op)| -> Result)>, C::Error> { + let tx_node = match self.get_tx_node(op.txid) { + Some(n) => n, + None => return Ok(None), + }; + + let txout = match tx_node.tx.output.get(op.vout as usize) { + Some(txout) => txout.clone(), + None => return Ok(None), + }; + + let chain_position = + match self.try_get_chain_position(chain, chain_tip, op.txid)? { + Some(pos) => pos.cloned(), + None => return Ok(None), + }; + + let spent_by = self + .try_get_chain_spend(chain, chain_tip, op)? + .map(|(a, txid)| (a.cloned(), txid)); + + Ok(Some(( + spk_i, + FullTxOut { + outpoint: op, + txout, + chain_position, + spent_by, + is_on_coinbase: tx_node.tx.is_coin_base(), + }, + ))) + }, + ) + .filter_map(Result::transpose) } - /// List outputs that are in `chain` with `chain_tip`. + /// Get a filtered list of outputs from the given `outpoints` that are in `chain` with + /// `chain_tip`. /// - /// This is the infallible version of [`try_list_chain_txouts`]. + /// This is the infallible version of [`try_filter_chain_txouts`]. /// - /// [`try_list_chain_txouts`]: Self::try_list_chain_txouts - pub fn list_chain_txouts<'a, C: ChainOracle + 'a>( + /// [`try_filter_chain_txouts`]: Self::try_filter_chain_txouts + pub fn filter_chain_txouts<'a, C: ChainOracle + 'a, OI: Clone + 'a>( &'a self, chain: &'a C, chain_tip: BlockId, - ) -> impl Iterator>> + 'a { - self.try_list_chain_txouts(chain, chain_tip) - .map(|r| r.expect("error in infallible")) + outpoints: impl IntoIterator + 'a, + ) -> impl Iterator>)> + 'a { + self.try_filter_chain_txouts(chain, chain_tip, outpoints) + .map(|r| r.expect("oracle is infallible")) } - /// List unspent outputs (UTXOs) that are in `chain` with `chain_tip`. + /// Get a filtered list of unspent outputs (UTXOs) from the given `outpoints` that are in + /// `chain` with `chain_tip`. /// - /// Floating outputs are not iterated over. + /// `outpoints` is a list of outpoints we are interested in, coupled with an outpoint identifier + /// (`OI`) for convenience. If `OI` is not necessary, the caller can use `()`, or + /// [`Iterator::enumerate`] over a list of [`OutPoint`]s. + /// + /// Floating outputs are ignored. /// /// # Error /// - /// An item can be an error if the [`ChainOracle`] implementation fails. If the oracle is - /// infallible, [`list_chain_unspents`] can be used instead. + /// An [`Iterator::Item`] can be an [`Err`] if the [`ChainOracle`] implementation (`chain`) + /// fails. + /// + /// If the [`ChainOracle`] implementation is infallible, [`filter_chain_unspents`] can be used + /// instead. /// - /// [`list_chain_unspents`]: Self::list_chain_unspents - pub fn try_list_chain_unspents<'a, C: ChainOracle + 'a>( + /// [`filter_chain_unspents`]: Self::filter_chain_unspents + pub fn try_filter_chain_unspents<'a, C: ChainOracle + 'a, OI: Clone + 'a>( &'a self, chain: &'a C, chain_tip: BlockId, - ) -> impl Iterator>, C::Error>> + 'a { - self.try_list_chain_txouts(chain, chain_tip) - .filter(|r| matches!(r, Ok(txo) if txo.spent_by.is_none())) + outpoints: impl IntoIterator + 'a, + ) -> impl Iterator>), C::Error>> + 'a { + self.try_filter_chain_txouts(chain, chain_tip, outpoints) + .filter(|r| match r { + // keep unspents, drop spents + Ok((_, full_txo)) => full_txo.spent_by.is_none(), + // keep errors + Err(_) => true, + }) } - /// List unspent outputs (UTXOs) that are in `chain` with `chain_tip`. + /// Get a filtered list of unspent outputs (UTXOs) from the given `outpoints` that are in + /// `chain` with `chain_tip`. /// - /// This is the infallible version of [`try_list_chain_unspents`]. + /// This is the infallible version of [`try_filter_chain_unspents`]. /// - /// [`try_list_chain_unspents`]: Self::try_list_chain_unspents - pub fn list_chain_unspents<'a, C: ChainOracle + 'a>( + /// [`try_filter_chain_unspents`]: Self::try_filter_chain_unspents + pub fn filter_chain_unspents<'a, C: ChainOracle + 'a, OI: Clone + 'a>( &'a self, chain: &'a C, - static_block: BlockId, - ) -> impl Iterator>> + 'a { - self.try_list_chain_unspents(chain, static_block) - .map(|r| r.expect("error is infallible")) + chain_tip: BlockId, + txouts: impl IntoIterator + 'a, + ) -> impl Iterator>)> + 'a { + self.try_filter_chain_unspents(chain, chain_tip, txouts) + .map(|r| r.expect("oracle is infallible")) + } + + /// Get the total balance of `outpoints` that are in `chain` of `chain_tip`. + /// + /// The output of `trust_predicate` should return `true` for scripts that we trust. + /// + /// `outpoints` is a list of outpoints we are interested in, coupled with an outpoint identifier + /// (`OI`) for convenience. If `OI` is not necessary, the caller can use `()`, or + /// [`Iterator::enumerate`] over a list of [`OutPoint`]s. + /// + /// If the provided [`ChainOracle`] implementation (`chain`) is infallible, [`balance`] can be + /// used instead. + /// + /// [`balance`]: Self::balance + pub fn try_balance( + &self, + chain: &C, + chain_tip: BlockId, + outpoints: impl IntoIterator, + mut trust_predicate: impl FnMut(&OI, &Script) -> bool, + ) -> Result { + let mut immature = 0; + let mut trusted_pending = 0; + let mut untrusted_pending = 0; + let mut confirmed = 0; + + for res in self.try_filter_chain_unspents(chain, chain_tip, outpoints) { + let (spk_i, txout) = res?; + + match &txout.chain_position { + ObservedAs::Confirmed(_) => { + if txout.is_confirmed_and_spendable(chain_tip.height) { + confirmed += txout.txout.value; + } else if !txout.is_mature(chain_tip.height) { + immature += txout.txout.value; + } + } + ObservedAs::Unconfirmed(_) => { + if trust_predicate(&spk_i, &txout.txout.script_pubkey) { + trusted_pending += txout.txout.value; + } else { + untrusted_pending += txout.txout.value; + } + } + } + } + + Ok(Balance { + immature, + trusted_pending, + untrusted_pending, + confirmed, + }) + } + + /// Get the total balance of `outpoints` that are in `chain` of `chain_tip`. + /// + /// This is the infallible version of [`try_balance`]. + /// + /// [`try_balance`]: Self::try_balance + pub fn balance, OI: Clone>( + &self, + chain: &C, + chain_tip: BlockId, + outpoints: impl IntoIterator, + trust_predicate: impl FnMut(&OI, &Script) -> bool, + ) -> Balance { + self.try_balance(chain, chain_tip, outpoints, trust_predicate) + .expect("oracle is infallible") } } @@ -940,6 +1035,13 @@ impl Append for Additions { .collect::>(), ); } + + fn is_empty(&self) -> bool { + self.tx.is_empty() + && self.txout.is_empty() + && self.anchors.is_empty() + && self.last_seen.is_empty() + } } impl AsRef> for TxGraph { diff --git a/crates/chain/tests/test_indexed_tx_graph.rs b/crates/chain/tests/test_indexed_tx_graph.rs index f32ffe4f0b..f231f76835 100644 --- a/crates/chain/tests/test_indexed_tx_graph.rs +++ b/crates/chain/tests/test_indexed_tx_graph.rs @@ -236,23 +236,36 @@ fn test_list_owned_txouts() { .map(|&hash| BlockId { height, hash }) .expect("block must exist"); let txouts = graph - .list_owned_txouts(&local_chain, chain_tip) + .graph() + .filter_chain_txouts( + &local_chain, + chain_tip, + graph.index.outpoints().iter().cloned(), + ) .collect::>(); let utxos = graph - .list_owned_unspents(&local_chain, chain_tip) + .graph() + .filter_chain_unspents( + &local_chain, + chain_tip, + graph.index.outpoints().iter().cloned(), + ) .collect::>(); - let balance = graph.balance(&local_chain, chain_tip, |spk: &Script| { - trusted_spks.contains(spk) - }); + let balance = graph.graph().balance( + &local_chain, + chain_tip, + graph.index.outpoints().iter().cloned(), + |_, spk: &Script| trusted_spks.contains(spk), + ); assert_eq!(txouts.len(), 5); assert_eq!(utxos.len(), 4); let confirmed_txouts_txid = txouts .iter() - .filter_map(|full_txout| { + .filter_map(|(_, full_txout)| { if matches!(full_txout.chain_position, ObservedAs::Confirmed(_)) { Some(full_txout.outpoint.txid) } else { @@ -263,7 +276,7 @@ fn test_list_owned_txouts() { let unconfirmed_txouts_txid = txouts .iter() - .filter_map(|full_txout| { + .filter_map(|(_, full_txout)| { if matches!(full_txout.chain_position, ObservedAs::Unconfirmed(_)) { Some(full_txout.outpoint.txid) } else { @@ -274,7 +287,7 @@ fn test_list_owned_txouts() { let confirmed_utxos_txid = utxos .iter() - .filter_map(|full_txout| { + .filter_map(|(_, full_txout)| { if matches!(full_txout.chain_position, ObservedAs::Confirmed(_)) { Some(full_txout.outpoint.txid) } else { @@ -285,7 +298,7 @@ fn test_list_owned_txouts() { let unconfirmed_utxos_txid = utxos .iter() - .filter_map(|full_txout| { + .filter_map(|(_, full_txout)| { if matches!(full_txout.chain_position, ObservedAs::Unconfirmed(_)) { Some(full_txout.outpoint.txid) } else { diff --git a/crates/file_store/src/entry_iter.rs b/crates/file_store/src/entry_iter.rs new file mode 100644 index 0000000000..770f264f3a --- /dev/null +++ b/crates/file_store/src/entry_iter.rs @@ -0,0 +1,100 @@ +use bincode::Options; +use std::{ + fs::File, + io::{self, Seek}, + marker::PhantomData, +}; + +use crate::bincode_options; + +/// Iterator over entries in a file store. +/// +/// Reads and returns an entry each time [`next`] is called. If an error occurs while reading the +/// iterator will yield a `Result::Err(_)` instead and then `None` for the next call to `next`. +/// +/// [`next`]: Self::next +pub struct EntryIter<'t, T> { + db_file: Option<&'t mut File>, + + /// The file position for the first read of `db_file`. + start_pos: Option, + types: PhantomData, +} + +impl<'t, T> EntryIter<'t, T> { + pub fn new(start_pos: u64, db_file: &'t mut File) -> Self { + Self { + db_file: Some(db_file), + start_pos: Some(start_pos), + types: PhantomData, + } + } +} + +impl<'t, T> Iterator for EntryIter<'t, T> +where + T: serde::de::DeserializeOwned, +{ + type Item = Result; + + fn next(&mut self) -> Option { + // closure which reads a single entry starting from `self.pos` + let read_one = |f: &mut File, start_pos: Option| -> Result, IterError> { + let pos = match start_pos { + Some(pos) => f.seek(io::SeekFrom::Start(pos))?, + None => f.stream_position()?, + }; + + match bincode_options().deserialize_from(&*f) { + Ok(changeset) => { + f.stream_position()?; + Ok(Some(changeset)) + } + Err(e) => { + if let bincode::ErrorKind::Io(inner) = &*e { + if inner.kind() == io::ErrorKind::UnexpectedEof { + let eof = f.seek(io::SeekFrom::End(0))?; + if pos == eof { + return Ok(None); + } + } + } + f.seek(io::SeekFrom::Start(pos))?; + Err(IterError::Bincode(*e)) + } + } + }; + + let result = read_one(self.db_file.as_mut()?, self.start_pos.take()); + if result.is_err() { + self.db_file = None; + } + result.transpose() + } +} + +impl From for IterError { + fn from(value: io::Error) -> Self { + IterError::Io(value) + } +} + +/// Error type for [`EntryIter`]. +#[derive(Debug)] +pub enum IterError { + /// Failure to read from the file. + Io(io::Error), + /// Failure to decode data from the file. + Bincode(bincode::ErrorKind), +} + +impl core::fmt::Display for IterError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + IterError::Io(e) => write!(f, "io error trying to read entry {}", e), + IterError::Bincode(e) => write!(f, "bincode error while reading entry {}", e), + } + } +} + +impl std::error::Error for IterError {} diff --git a/crates/file_store/src/file_store.rs b/crates/file_store/src/keychain_store.rs similarity index 78% rename from crates/file_store/src/file_store.rs rename to crates/file_store/src/keychain_store.rs index 824e3ccc56..5f5074d511 100644 --- a/crates/file_store/src/file_store.rs +++ b/crates/file_store/src/keychain_store.rs @@ -6,14 +6,15 @@ use bdk_chain::{ keychain::{KeychainChangeSet, KeychainTracker}, sparse_chain, }; -use bincode::{DefaultOptions, Options}; -use core::marker::PhantomData; +use bincode::Options; use std::{ fs::{File, OpenOptions}, io::{self, Read, Seek, Write}, path::Path, }; +use crate::{bincode_options, EntryIter, IterError}; + /// BDK File Store magic bytes length. const MAGIC_BYTES_LEN: usize = 12; @@ -28,10 +29,6 @@ pub struct KeychainStore { changeset_type_params: core::marker::PhantomData<(K, P)>, } -fn bincode() -> impl bincode::Options { - DefaultOptions::new().with_varint_encoding() -} - impl KeychainStore where K: Ord + Clone + core::fmt::Debug, @@ -85,11 +82,8 @@ where /// **WARNING**: This method changes the write position in the underlying file. You should /// always iterate over all entries until `None` is returned if you want your next write to go /// at the end; otherwise, you will write over existing entries. - pub fn iter_changesets(&mut self) -> Result>, io::Error> { - self.db_file - .seek(io::SeekFrom::Start(MAGIC_BYTES_LEN as _))?; - - Ok(EntryIter::new(&mut self.db_file)) + pub fn iter_changesets(&mut self) -> Result>, io::Error> { + Ok(EntryIter::new(MAGIC_BYTES_LEN as u64, &mut self.db_file)) } /// Loads all the changesets that have been stored as one giant changeset. @@ -144,7 +138,7 @@ where return Ok(()); } - bincode() + bincode_options() .serialize_into(&mut self.db_file, changeset) .map_err(|e| match *e { bincode::ErrorKind::Io(inner) => inner, @@ -197,92 +191,6 @@ impl From for FileError { impl std::error::Error for FileError {} -/// Error type for [`EntryIter`]. -#[derive(Debug)] -pub enum IterError { - /// Failure to read from the file. - Io(io::Error), - /// Failure to decode data from the file. - Bincode(bincode::ErrorKind), -} - -impl core::fmt::Display for IterError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - IterError::Io(e) => write!(f, "io error trying to read entry {}", e), - IterError::Bincode(e) => write!(f, "bincode error while reading entry {}", e), - } - } -} - -impl std::error::Error for IterError {} - -/// Iterator over entries in a file store. -/// -/// Reads and returns an entry each time [`next`] is called. If an error occurs while reading the -/// iterator will yield a `Result::Err(_)` instead and then `None` for the next call to `next`. -/// -/// [`next`]: Self::next -pub struct EntryIter<'a, V> { - db_file: &'a mut File, - types: PhantomData, - error_exit: bool, -} - -impl<'a, V> EntryIter<'a, V> { - pub fn new(db_file: &'a mut File) -> Self { - Self { - db_file, - types: PhantomData, - error_exit: false, - } - } -} - -impl<'a, V> Iterator for EntryIter<'a, V> -where - V: serde::de::DeserializeOwned, -{ - type Item = Result; - - fn next(&mut self) -> Option { - let result = (|| { - let pos = self.db_file.stream_position()?; - - match bincode().deserialize_from(&mut self.db_file) { - Ok(changeset) => Ok(Some(changeset)), - Err(e) => { - if let bincode::ErrorKind::Io(inner) = &*e { - if inner.kind() == io::ErrorKind::UnexpectedEof { - let eof = self.db_file.seek(io::SeekFrom::End(0))?; - if pos == eof { - return Ok(None); - } - } - } - - self.db_file.seek(io::SeekFrom::Start(pos))?; - Err(IterError::Bincode(*e)) - } - } - })(); - - let result = result.transpose(); - - if let Some(Err(_)) = &result { - self.error_exit = true; - } - - result - } -} - -impl From for IterError { - fn from(value: io::Error) -> Self { - IterError::Io(value) - } -} - #[cfg(test)] mod test { use super::*; @@ -290,6 +198,7 @@ mod test { keychain::{DerivationAdditions, KeychainChangeSet}, TxHeight, }; + use bincode::DefaultOptions; use std::{ io::{Read, Write}, vec::Vec, diff --git a/crates/file_store/src/lib.rs b/crates/file_store/src/lib.rs index e334741947..b10c8c29ef 100644 --- a/crates/file_store/src/lib.rs +++ b/crates/file_store/src/lib.rs @@ -1,10 +1,51 @@ #![doc = include_str!("../README.md")] -mod file_store; +mod entry_iter; +mod keychain_store; +mod store; +use std::io; + use bdk_chain::{ keychain::{KeychainChangeSet, KeychainTracker, PersistBackend}, sparse_chain::ChainPosition, }; -pub use file_store::*; +use bincode::{DefaultOptions, Options}; +pub use entry_iter::*; +pub use keychain_store::*; +pub use store::*; + +pub(crate) fn bincode_options() -> impl bincode::Options { + DefaultOptions::new().with_varint_encoding() +} + +/// Error that occurs due to problems encountered with the file. +#[derive(Debug)] +pub enum FileError<'a> { + /// IO error, this may mean that the file is too short. + Io(io::Error), + /// Magic bytes do not match what is expected. + InvalidMagicBytes { got: Vec, expected: &'a [u8] }, +} + +impl<'a> core::fmt::Display for FileError<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Io(e) => write!(f, "io error trying to read file: {}", e), + Self::InvalidMagicBytes { got, expected } => write!( + f, + "file has invalid magic bytes: expected={:?} got={:?}", + expected, got, + ), + } + } +} + +impl<'a> From for FileError<'a> { + fn from(value: io::Error) -> Self { + Self::Io(value) + } +} + +impl<'a> std::error::Error for FileError<'a> {} impl PersistBackend for KeychainStore where diff --git a/crates/file_store/src/store.rs b/crates/file_store/src/store.rs new file mode 100644 index 0000000000..a4aa2963ce --- /dev/null +++ b/crates/file_store/src/store.rs @@ -0,0 +1,255 @@ +use std::{ + fmt::Debug, + fs::{File, OpenOptions}, + io::{self, Read, Seek, Write}, + marker::PhantomData, + path::Path, +}; + +use bdk_chain::{Append, PersistBackend}; +use bincode::Options; + +use crate::{bincode_options, EntryIter, FileError, IterError}; + +/// Persists an append-only list of changesets (`C`) to a single file. +/// +/// The changesets are the results of altering a tracker implementation (`T`). +#[derive(Debug)] +pub struct Store<'a, C> { + magic: &'a [u8], + db_file: File, + marker: PhantomData, +} + +impl<'a, C> PersistBackend for Store<'a, C> +where + C: Default + Append + serde::Serialize + serde::de::DeserializeOwned, +{ + type WriteError = std::io::Error; + + type LoadError = IterError; + + fn write_changes(&mut self, changeset: &C) -> Result<(), Self::WriteError> { + self.append_changeset(changeset) + } + + fn load_from_persistence(&mut self) -> Result { + let (changeset, result) = self.aggregate_changesets(); + result.map(|_| changeset) + } +} + +impl<'a, C> Store<'a, C> +where + C: Default + Append + serde::Serialize + serde::de::DeserializeOwned, +{ + /// Creates a new store from a [`File`]. + /// + /// The file must have been opened with read and write permissions. + /// + /// `magic` is the expected prefixed bytes of the file. If this does not match, an error will be + /// returned. + /// + /// [`File`]: std::fs::File + pub fn new(magic: &'a [u8], mut db_file: File) -> Result { + db_file.rewind()?; + + let mut magic_buf = vec![0_u8; magic.len()]; + db_file.read_exact(magic_buf.as_mut())?; + + if magic_buf != magic { + return Err(FileError::InvalidMagicBytes { + got: magic_buf, + expected: magic, + }); + } + + Ok(Self { + magic, + db_file, + marker: Default::default(), + }) + } + + /// Creates or loads a store from `db_path`. + /// + /// If no file exists there, it will be created. + /// + /// Refer to [`new`] for documentation on the `magic` input. + /// + /// [`new`]: Self::new + pub fn new_from_path

(magic: &'a [u8], db_path: P) -> Result + where + P: AsRef, + { + let already_exists = db_path.as_ref().exists(); + + let mut db_file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(db_path)?; + + if !already_exists { + db_file.write_all(magic)?; + } + + Self::new(magic, db_file) + } + + /// Iterates over the stored changeset from first to last, changing the seek position at each + /// iteration. + /// + /// The iterator may fail to read an entry and therefore return an error. However, the first time + /// it returns an error will be the last. After doing so, the iterator will always yield `None`. + /// + /// **WARNING**: This method changes the write position in the underlying file. You should + /// always iterate over all entries until `None` is returned if you want your next write to go + /// at the end; otherwise, you will write over existing entries. + pub fn iter_changesets(&mut self) -> EntryIter { + EntryIter::new(self.magic.len() as u64, &mut self.db_file) + } + + /// Loads all the changesets that have been stored as one giant changeset. + /// + /// This function returns a tuple of the aggregate changeset and a result that indicates + /// whether an error occurred while reading or deserializing one of the entries. If so the + /// changeset will consist of all of those it was able to read. + /// + /// You should usually check the error. In many applications, it may make sense to do a full + /// wallet scan with a stop-gap after getting an error, since it is likely that one of the + /// changesets it was unable to read changed the derivation indices of the tracker. + /// + /// **WARNING**: This method changes the write position of the underlying file. The next + /// changeset will be written over the erroring entry (or the end of the file if none existed). + pub fn aggregate_changesets(&mut self) -> (C, Result<(), IterError>) { + let mut changeset = C::default(); + let result = (|| { + for next_changeset in self.iter_changesets() { + changeset.append(next_changeset?); + } + Ok(()) + })(); + + (changeset, result) + } + + /// Append a new changeset to the file and truncate the file to the end of the appended + /// changeset. + /// + /// The truncation is to avoid the possibility of having a valid but inconsistent changeset + /// directly after the appended changeset. + pub fn append_changeset(&mut self, changeset: &C) -> Result<(), io::Error> { + // no need to write anything if changeset is empty + if changeset.is_empty() { + return Ok(()); + } + + bincode_options() + .serialize_into(&mut self.db_file, changeset) + .map_err(|e| match *e { + bincode::ErrorKind::Io(inner) => inner, + unexpected_err => panic!("unexpected bincode error: {}", unexpected_err), + })?; + + // truncate file after this changeset addition + // if this is not done, data after this changeset may represent valid changesets, however + // applying those changesets on top of this one may result in an inconsistent state + let pos = self.db_file.stream_position()?; + self.db_file.set_len(pos)?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + use bincode::DefaultOptions; + use std::{ + io::{Read, Write}, + vec::Vec, + }; + use tempfile::NamedTempFile; + + const TEST_MAGIC_BYTES_LEN: usize = 12; + const TEST_MAGIC_BYTES: [u8; TEST_MAGIC_BYTES_LEN] = + [98, 100, 107, 102, 115, 49, 49, 49, 49, 49, 49, 49]; + + type TestChangeSet = Vec; + + #[derive(Debug)] + struct TestTracker; + + #[test] + fn new_fails_if_file_is_too_short() { + let mut file = NamedTempFile::new().unwrap(); + file.write_all(&TEST_MAGIC_BYTES[..TEST_MAGIC_BYTES_LEN - 1]) + .expect("should write"); + + match Store::::new(&TEST_MAGIC_BYTES, file.reopen().unwrap()) { + Err(FileError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof), + unexpected => panic!("unexpected result: {:?}", unexpected), + }; + } + + #[test] + fn new_fails_if_magic_bytes_are_invalid() { + let invalid_magic_bytes = "ldkfs0000000"; + + let mut file = NamedTempFile::new().unwrap(); + file.write_all(invalid_magic_bytes.as_bytes()) + .expect("should write"); + + match Store::::new(&TEST_MAGIC_BYTES, file.reopen().unwrap()) { + Err(FileError::InvalidMagicBytes { got, .. }) => { + assert_eq!(got, invalid_magic_bytes.as_bytes()) + } + unexpected => panic!("unexpected result: {:?}", unexpected), + }; + } + + #[test] + fn append_changeset_truncates_invalid_bytes() { + // initial data to write to file (magic bytes + invalid data) + let mut data = [255_u8; 2000]; + data[..TEST_MAGIC_BYTES_LEN].copy_from_slice(&TEST_MAGIC_BYTES); + + let changeset = vec!["one".into(), "two".into(), "three!".into()]; + + let mut file = NamedTempFile::new().unwrap(); + file.write_all(&data).expect("should write"); + + let mut store = Store::::new(&TEST_MAGIC_BYTES, file.reopen().unwrap()) + .expect("should open"); + match store.iter_changesets().next() { + Some(Err(IterError::Bincode(_))) => {} + unexpected_res => panic!("unexpected result: {:?}", unexpected_res), + } + + store.append_changeset(&changeset).expect("should append"); + + drop(store); + + let got_bytes = { + let mut buf = Vec::new(); + file.reopen() + .unwrap() + .read_to_end(&mut buf) + .expect("should read"); + buf + }; + + let expected_bytes = { + let mut buf = TEST_MAGIC_BYTES.to_vec(); + DefaultOptions::new() + .with_varint_encoding() + .serialize_into(&mut buf, &changeset) + .expect("should encode"); + buf + }; + + assert_eq!(got_bytes, expected_bytes); + } +}