diff --git a/mutiny-core/src/lib.rs b/mutiny-core/src/lib.rs index 0489a9a1b..4d4a5a70f 100644 --- a/mutiny-core/src/lib.rs +++ b/mutiny-core/src/lib.rs @@ -1502,61 +1502,26 @@ impl MutinyWallet { return Err(MutinyError::WalletOperationFailed); }; - let (pj, ohttp) = { - use crate::payjoin::{OHTTP_RELAYS, PAYJOIN_DIR}; - use anyhow::anyhow; - - let ohttp_keys = crate::payjoin::fetch_ohttp_keys( - OHTTP_RELAYS[0].to_owned(), - PAYJOIN_DIR.to_owned(), - ) - .await - .map_err(|e| anyhow!("Payjoin OHTTP fetch error {}", e))?; - - let ohttp = base64::encode_config( - ohttp_keys - .encode() - .map_err(|_| MutinyError::PayjoinConfigError)?, - base64::URL_SAFE_NO_PAD, - ); - let mut enroller = pj::receive::v2::Enroller::from_directory_config( - PAYJOIN_DIR.to_owned(), - ohttp_keys, - OHTTP_RELAYS[0].to_owned(), // TODO pick ohttp relay at random - ); - - // enroll client - let (req, context) = enroller.extract_req().unwrap(); - let http_client = reqwest::Client::builder().build().unwrap(); - let ohttp_response = http_client - .post(req.url) - .header("Content-Type", "message/ohttp-req") - .body(req.body) - .send() - .await - .map_err(|_| MutinyError::PayjoinCreateRequest)?; - let ohttp_response = ohttp_response.bytes().await.unwrap(); - let enrolled = enroller - .process_res(ohttp_response.as_ref(), context) - .map_err(|_| MutinyError::PayjoinCreateRequest)?; - let session = self - .node_manager - .storage - .persist_payjoin(enrolled.clone())?; - let pj_uri = enrolled.fallback_target(); - log_debug!(self.logger, "{pj_uri}"); - let wallet = self.node_manager.wallet.clone(); - let stop = self.node_manager.stop.clone(); - let storage = Arc::new(self.node_manager.storage.clone()); - // run await payjoin task in the background as it'll keep polling the relay - let logger = self.logger.clone(); - utils::spawn(async move { - match NodeManager::receive_payjoin(wallet, stop, storage, session).await { - Ok(pj_txid) => log_info!(logger, "Received payjoin txid: {}", pj_txid), - Err(e) => log_error!(logger, "Payjoin error: {e}"), - } - }); - (Some(pj_uri), Some(ohttp)) + let (pj, ohttp) = match self.node_manager.start_payjoin_session().await { + Ok((enrolled, ohttp_keys)) => { + let session = self + .node_manager + .storage + .persist_payjoin(enrolled.clone())?; + let pj_uri = session.enrolled.fallback_target(); + self.node_manager.spawn_payjoin_receiver(session); + let ohttp = base64::encode_config( + ohttp_keys + .encode() + .map_err(|_| MutinyError::PayjoinConfigError)?, + base64::URL_SAFE_NO_PAD, + ); + (Some(pj_uri), Some(ohttp)) + } + Err(e) => { + log_error!(self.logger, "Error enrolling payjoin: {e}"); + (None, None) + } }; Ok(MutinyBip21RawMaterials { diff --git a/mutiny-core/src/nodemanager.rs b/mutiny-core/src/nodemanager.rs index 2264a5a14..cd9a0024e 100644 --- a/mutiny-core/src/nodemanager.rs +++ b/mutiny-core/src/nodemanager.rs @@ -3,7 +3,7 @@ use crate::event::HTLCStatus; use crate::labels::LabelStorage; use crate::ldkstorage::CHANNEL_CLOSURE_PREFIX; use crate::logging::LOGGING_KEY; -use crate::payjoin::PayjoinStorage; +use crate::payjoin::{Error as PayjoinError, PayjoinStorage}; use crate::utils::{sleep, spawn}; use crate::ActivityItem; use crate::MutinyInvoice; @@ -55,7 +55,9 @@ use lightning::util::logger::*; use lightning::{log_debug, log_error, log_info, log_trace, log_warn}; use lightning_invoice::Bolt11Invoice; use lightning_transaction_sync::EsploraSyncClient; +use payjoin::receive::v2::Enrolled; use payjoin::Uri; +use pj::OhttpKeys; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -633,15 +635,7 @@ impl NodeManager { pub(crate) fn resume_payjoins(nm: Arc>) { let all = nm.storage.get_payjoins().unwrap_or_default(); for payjoin in all { - let wallet = nm.wallet.clone(); - let stop = nm.stop.clone(); - let storage = Arc::new(nm.storage.clone()); - utils::spawn(async move { - let pj_txid = Self::receive_payjoin(wallet, stop, storage, payjoin) - .await - .unwrap(); - log::info!("Received payjoin txid: {}", pj_txid); - }); + nm.clone().spawn_payjoin_receiver(payjoin); } } @@ -734,6 +728,33 @@ impl NodeManager { Err(MutinyError::WalletOperationFailed) } + pub async fn start_payjoin_session(&self) -> Result<(Enrolled, OhttpKeys), PayjoinError> { + use crate::payjoin::{OHTTP_RELAYS, PAYJOIN_DIR}; + + let ohttp_keys = + crate::payjoin::fetch_ohttp_keys(OHTTP_RELAYS[0].to_owned(), PAYJOIN_DIR.to_owned()) + .await?; + let http_client = reqwest::Client::builder().build()?; + + let mut enroller = payjoin::receive::v2::Enroller::from_directory_config( + PAYJOIN_DIR.to_owned(), + ohttp_keys.clone(), + OHTTP_RELAYS[0].to_owned(), // TODO pick ohttp relay at random + ); + let (req, context) = enroller.extract_req()?; + let ohttp_response = http_client + .post(req.url) + .header("Content-Type", "message/ohttp-req") + .body(req.body) + .send() + .await?; + let ohttp_response = ohttp_response.bytes().await?; + Ok(( + enroller.process_res(ohttp_response.as_ref(), context)?, + ohttp_keys, + )) + } + // Send v1 payjoin request pub async fn send_payjoin( &self, @@ -809,15 +830,26 @@ impl NodeManager { Ok(txid) } + pub fn spawn_payjoin_receiver(&self, session: crate::payjoin::Session) { + let logger = self.logger.clone(); + let wallet = self.wallet.clone(); + let stop = self.stop.clone(); + let storage = Arc::new(self.storage.clone()); + utils::spawn(async move { + match Self::receive_payjoin(wallet, stop, storage, session).await { + Ok(txid) => log_info!(logger, "Received payjoin txid: {txid}"), + Err(e) => log_error!(logger, "Error receiving payjoin: {e}"), + }; + }); + } + /// Poll the payjoin relay to maintain a payjoin session and create a payjoin proposal. - pub async fn receive_payjoin( + async fn receive_payjoin( wallet: Arc>, stop: Arc, storage: Arc, mut session: crate::payjoin::Session, ) -> Result { - use crate::payjoin::Error as PayjoinError; - let http_client = reqwest::Client::builder() .build() .map_err(PayjoinError::Reqwest)?; diff --git a/mutiny-core/src/payjoin.rs b/mutiny-core/src/payjoin.rs index 71f55fb59..6171be002 100644 --- a/mutiny-core/src/payjoin.rs +++ b/mutiny-core/src/payjoin.rs @@ -69,10 +69,7 @@ impl PayjoinStorage for S { } } -pub async fn fetch_ohttp_keys( - _ohttp_relay: Url, - directory: Url, -) -> Result> { +pub async fn fetch_ohttp_keys(_ohttp_relay: Url, directory: Url) -> Result { let http_client = reqwest::Client::builder().build()?; let ohttp_keys_res = http_client @@ -81,7 +78,7 @@ pub async fn fetch_ohttp_keys( .await? .bytes() .await?; - Ok(OhttpKeys::decode(ohttp_keys_res.as_ref())?) + Ok(OhttpKeys::decode(ohttp_keys_res.as_ref()).map_err(|_| Error::OhttpDecodeFailed)?) } #[derive(Debug)] @@ -89,6 +86,7 @@ pub enum Error { Reqwest(reqwest::Error), ReceiverStateMachine(String), Txid(bitcoin::hashes::hex::Error), + OhttpDecodeFailed, Shutdown, SessionExpired, } @@ -101,6 +99,7 @@ impl std::fmt::Display for Error { Error::Reqwest(e) => write!(f, "Reqwest error: {}", e), Error::ReceiverStateMachine(e) => write!(f, "Payjoin state machine error: {}", e), Error::Txid(e) => write!(f, "Payjoin txid error: {}", e), + Error::OhttpDecodeFailed => write!(f, "Failed to decode ohttp keys"), Error::Shutdown => write!(f, "Payjoin stopped by application shutdown"), Error::SessionExpired => write!(f, "Payjoin session expired. Create a new payment request and have the sender try again."), }