Skip to content

Commit

Permalink
feat(wallet)!: add new_or_load methods
Browse files Browse the repository at this point in the history
These methods try to load wallet from persistence and initializes the
wallet instead if non-existant.

An internal helper method `create_signers` is added to reuse code.
Documentation is also improved.
  • Loading branch information
evanlinjin committed Oct 27, 2023
1 parent 844706e commit eb39808
Showing 1 changed file with 189 additions and 50 deletions.
239 changes: 189 additions & 50 deletions crates/bdk/src/wallet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use bdk_chain::{
Append, BlockId, ChainPosition, ConfirmationTime, ConfirmationTimeAnchor, FullTxOut,
IndexedTxGraph, Persist, PersistBackend,
};
use bitcoin::secp256k1::Secp256k1;
use bitcoin::secp256k1::{All, Secp256k1};
use bitcoin::sighash::{EcdsaSighashType, TapSighashType};
use bitcoin::{
absolute, Address, Network, OutPoint, Script, ScriptBuf, Sequence, Transaction, TxOut, Txid,
Expand Down Expand Up @@ -246,8 +246,13 @@ impl Wallet {
}
}

/// The error type when constructing a fresh [`Wallet`].
///
/// Methods [`new`] and [`new_with_genesis_hash`] may return this error.
///
/// [`new`]: Wallet::new
/// [`new_with_genesis_hash`]: Wallet::new_with_genesis_hash
#[derive(Debug)]
/// Error returned from [`Wallet::new`]
pub enum NewError<W> {
/// There was problem with the passed-in descriptor(s).
Descriptor(crate::descriptor::DescriptorError),
Expand All @@ -270,7 +275,11 @@ where
#[cfg(feature = "std")]
impl<W> std::error::Error for NewError<W> where W: core::fmt::Display + core::fmt::Debug {}

/// An error that may occur when loading a [`Wallet`] from persistence.
/// The error type when loading a [`Wallet`] from persistence.
///
/// Method [`load`] may return this error.
///
/// [`load`]: Wallet::load
#[derive(Debug)]
pub enum LoadError<L> {
/// There was a problem with the passed-in descriptor(s).
Expand Down Expand Up @@ -300,6 +309,64 @@ where
#[cfg(feature = "std")]
impl<L> std::error::Error for LoadError<L> where L: core::fmt::Display + core::fmt::Debug {}

/// Error type for when we try load a [`Wallet`] from persistence and creating it if non-existant.
///
/// Methods [`new_or_load`] and [`new_or_load_with_genesis_hash`] may return this error.
///
/// [`new_or_load`]: Wallet::new_or_load
/// [`new_or_load_with_genesis_hash`]: Wallet::new_or_load_with_genesis_hash
#[derive(Debug)]
pub enum NewOrLoadError<W, L> {
/// There is a problem with the passed-in descriptor.
Descriptor(crate::descriptor::DescriptorError),
/// Writing to the persistence backend failed.
Write(W),
/// Loading from the persistence backend failed.
Load(L),
/// The loaded genesis hash does not match what was provided.
LoadedGenesisDoesNotMatch {
/// The expected genesis block hash.
expected: BlockHash,
/// The block hash loaded from persistence.
got: Option<BlockHash>,
},
/// The loaded network type does not match what was provided.
LoadedNetworkDoesNotMatch {
/// The expected network type.
expected: Network,
/// The network type loaded from persistence.
got: Option<Network>,
},
}

impl<W, L> fmt::Display for NewOrLoadError<W, L>
where
W: fmt::Display,
L: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
NewOrLoadError::Descriptor(e) => e.fmt(f),
NewOrLoadError::Write(e) => write!(f, "failed to write to persistence: {}", e),
NewOrLoadError::Load(e) => write!(f, "failed to load from persistence: {}", e),
NewOrLoadError::LoadedGenesisDoesNotMatch { expected, got } => {
write!(f, "loaded genesis hash is not {}, got {:?}", expected, got)
}
NewOrLoadError::LoadedNetworkDoesNotMatch { expected, got } => {
write!(f, "loaded network type is not {}, got {:?}", expected, got)
}
}
}
}

#[cfg(feature = "std")]
impl<W, L> std::error::Error for NewOrLoadError<W, L>
where
W: core::fmt::Display + core::fmt::Debug,
L: core::fmt::Display + core::fmt::Debug,
{
}

/// An error that may occur when inserting a transaction into [`Wallet`].
#[derive(Debug)]
pub enum InsertTxError {
Expand All @@ -314,8 +381,7 @@ pub enum InsertTxError {
}

impl<D> Wallet<D> {
/// Create a wallet from a `descriptor` (and an optional `change_descriptor`) and load related
/// transaction data from `db`.
/// Initialize an empty [`Wallet`].
pub fn new<E: IntoWalletDescriptor>(
descriptor: E,
change_descriptor: Option<E>,
Expand All @@ -329,9 +395,10 @@ impl<D> Wallet<D> {
Self::new_with_genesis_hash(descriptor, change_descriptor, db, network, genesis_hash)
}

/// Create a new [`Wallet`] with a custom genesis hash.
/// Initialize an empty [`Wallet`] with a custom genesis hash.
///
/// This is like [`Wallet::new`] with an additional `custom_genesis_hash` parameter.
/// This is like [`Wallet::new`] with an additional `genesis_hash` parameter. This is useful
/// for syncing from alternative networks.
pub fn new_with_genesis_hash<E: IntoWalletDescriptor>(
descriptor: E,
change_descriptor: Option<E>,
Expand All @@ -343,33 +410,18 @@ impl<D> Wallet<D> {
D: PersistBackend<ChangeSet>,
{
let secp = Secp256k1::new();
let (chain, _) = LocalChain::from_genesis_hash(genesis_hash);
let mut indexed_graph =
IndexedTxGraph::<ConfirmationTimeAnchor, KeychainTxOutIndex<KeychainKind>>::default();
let (chain, chain_changeset) = LocalChain::from_genesis_hash(genesis_hash);
let mut index = KeychainTxOutIndex::<KeychainKind>::default();

let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, &secp, network)
.map_err(NewError::Descriptor)?;
indexed_graph
.index
.add_keychain(KeychainKind::External, descriptor.clone());
let signers = Arc::new(SignersContainer::build(keymap, &descriptor, &secp));

let change_signers = Arc::new(match change_descriptor {
Some(desc) => {
let (descriptor, keymap) = into_wallet_descriptor_checked(desc, &secp, network)
.map_err(NewError::Descriptor)?;
let signers = SignersContainer::build(keymap, &descriptor, &secp);
indexed_graph
.index
.add_keychain(KeychainKind::Internal, descriptor);
signers
}
None => SignersContainer::new(),
});
let (signers, change_signers) =
create_signers(&mut index, &secp, descriptor, change_descriptor, network)
.map_err(NewError::Descriptor)?;

let indexed_graph = IndexedTxGraph::new(index);

let mut persist = Persist::new(db);
persist.stage(ChangeSet {
chain: chain.initial_changeset(),
chain: chain_changeset,
indexed_tx_graph: indexed_graph.initial_changeset(),
network: Some(network),
});
Expand All @@ -386,7 +438,7 @@ impl<D> Wallet<D> {
})
}

/// Load [`Wallet`] from persistence.
/// Load [`Wallet`] from the given persistence backend.
pub fn load<E: IntoWalletDescriptor>(
descriptor: E,
change_descriptor: Option<E>,
Expand All @@ -396,31 +448,15 @@ impl<D> Wallet<D> {
D: PersistBackend<ChangeSet>,
{
let secp = Secp256k1::new();

let changeset = db.load_from_persistence().map_err(LoadError::Load)?;
let network = changeset.network.ok_or(LoadError::MissingNetwork)?;

let chain =
LocalChain::from_changeset(changeset.chain).map_err(|_| LoadError::MissingGenesis)?;

let mut index = KeychainTxOutIndex::<KeychainKind>::default();

let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, &secp, network)
.map_err(LoadError::Descriptor)?;
let signers = Arc::new(SignersContainer::build(keymap, &descriptor, &secp));
index.add_keychain(KeychainKind::External, descriptor);

let change_signers = Arc::new(match change_descriptor {
Some(descriptor) => {
let (descriptor, keymap) =
into_wallet_descriptor_checked(descriptor, &secp, network)
.map_err(LoadError::Descriptor)?;
let signers = SignersContainer::build(keymap, &descriptor, &secp);
index.add_keychain(KeychainKind::Internal, descriptor);
signers
}
None => SignersContainer::new(),
});
let (signers, change_signers) =
create_signers(&mut index, &secp, descriptor, change_descriptor, network)
.map_err(LoadError::Descriptor)?;

let indexed_graph = IndexedTxGraph::new(index);
let persist = Persist::new(db);
Expand All @@ -436,6 +472,85 @@ impl<D> Wallet<D> {
})
}

/// Either loads [`Wallet`] from persistence, or initializes it if it does not exist.
///
/// This method will fail if the loaded [`Wallet`] has different parameters to those provided.
pub fn new_or_load<E: IntoWalletDescriptor>(
descriptor: E,
change_descriptor: Option<E>,
db: D,
network: Network,
) -> Result<Self, NewOrLoadError<D::WriteError, D::LoadError>>
where
D: PersistBackend<ChangeSet>,
{
let genesis_hash = genesis_block(network).block_hash();
Self::new_or_load_with_genesis_hash(
descriptor,
change_descriptor,
db,
network,
genesis_hash,
)
}

/// Either loads [`Wallet`] from persistence, or initializes it if it does not exist (with a
/// custom genesis hash).
///
/// This method will fail if the loaded [`Wallet`] has different parameters to those provided.
/// This is like [`Wallet::new_or_load`] with an additional `genesis_hash` parameter. This is
/// useful for syncing from alternative networks.
pub fn new_or_load_with_genesis_hash<E: IntoWalletDescriptor>(
descriptor: E,
change_descriptor: Option<E>,
mut db: D,
network: Network,
genesis_hash: BlockHash,
) -> Result<Self, NewOrLoadError<D::WriteError, D::LoadError>>
where
D: PersistBackend<ChangeSet>,
{
if db.is_empty().map_err(NewOrLoadError::Load)? {
return Self::new_with_genesis_hash(
descriptor,
change_descriptor,
db,
network,
genesis_hash,
)
.map_err(|e| match e {
NewError::Descriptor(e) => NewOrLoadError::Descriptor(e),
NewError::Write(e) => NewOrLoadError::Write(e),
});
}

let wallet = Self::load(descriptor, change_descriptor, db).map_err(|e| match e {
LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e),
LoadError::Load(e) => NewOrLoadError::Load(e),
LoadError::MissingNetwork => NewOrLoadError::LoadedNetworkDoesNotMatch {
expected: network,
got: None,
},
LoadError::MissingGenesis => NewOrLoadError::LoadedGenesisDoesNotMatch {
expected: genesis_hash,
got: None,
},
})?;
if wallet.chain.genesis_hash() != genesis_hash {
return Err(NewOrLoadError::LoadedGenesisDoesNotMatch {
expected: genesis_hash,
got: Some(wallet.chain.genesis_hash()),
});
}
if wallet.network != network {
return Err(NewOrLoadError::LoadedNetworkDoesNotMatch {
expected: network,
got: Some(wallet.network),
});
}
Ok(wallet)
}

/// Get the Bitcoin network the wallet is using.
pub fn network(&self) -> Network {
self.network
Expand Down Expand Up @@ -2149,6 +2264,30 @@ fn new_local_utxo(
}
}

fn create_signers<E: IntoWalletDescriptor>(
index: &mut KeychainTxOutIndex<KeychainKind>,
secp: &Secp256k1<All>,
descriptor: E,
change_descriptor: Option<E>,
network: Network,
) -> Result<(Arc<SignersContainer>, Arc<SignersContainer>), crate::descriptor::error::Error> {
let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?;
let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp));
index.add_keychain(KeychainKind::External, descriptor);

let change_signers = match change_descriptor {
Some(descriptor) => {
let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?;
let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp));
index.add_keychain(KeychainKind::Internal, descriptor);
signers
}
None => Arc::new(SignersContainer::new()),
};

Ok((signers, change_signers))
}

#[macro_export]
#[doc(hidden)]
/// Macro for getting a wallet for use in a doctest
Expand Down

0 comments on commit eb39808

Please sign in to comment.