diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 8a92f29f..f44ef1e4 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -46,6 +46,7 @@ use url::Url; use crate::input_type::InputType; use crate::psbt::PsbtExt; use crate::request::Request; +use crate::uri::pj_url::PjUrl; use crate::uri::UriExt; use crate::weight::{varint_size, ComputeWeight}; use crate::{PjUri, Uri}; @@ -211,7 +212,7 @@ impl<'a> RequestBuilder<'a> { .map_err(InternalCreateRequestError::InvalidOriginalInput)?; let endpoint = self.uri.extras.endpoint.clone(); #[cfg(feature = "v2")] - let ohttp_keys = self.uri.extras.ohttp_keys; + let ohttp_keys = self.uri.extras.endpoint.ohttp(); let disable_output_substitution = self.uri.extras.disable_output_substitution || self.disable_output_substitution; let payee = self.uri.address.script_pubkey(); @@ -259,7 +260,7 @@ impl<'a> RequestBuilder<'a> { #[derive(Clone)] pub struct RequestContext { psbt: Psbt, - endpoint: Url, + endpoint: PjUrl, #[cfg(feature = "v2")] ohttp_keys: Option, disable_output_substitution: bool, @@ -276,7 +277,7 @@ pub struct RequestContext { impl PartialEq for RequestContext { fn eq(&self, other: &Self) -> bool { self.psbt == other.psbt - && self.endpoint == other.endpoint + && *self.endpoint == *other.endpoint // KeyConfig is not yet PartialEq && self.ohttp_keys.as_ref().map(|cfg| cfg.encode().unwrap_or_default()) == other.ohttp_keys.as_ref().map(|cfg| cfg.encode().unwrap_or_default()) && self.disable_output_substitution == other.disable_output_substitution @@ -296,7 +297,7 @@ impl RequestContext { /// Extract serialized V1 Request and Context froma Payjoin Proposal pub fn extract_v1(self) -> Result<(Request, ContextV1), CreateRequestError> { let url = serialize_url( - self.endpoint.into(), + self.endpoint, self.disable_output_substitution, self.fee_contribution, self.min_fee_rate, @@ -521,7 +522,7 @@ impl<'de> Deserialize<'de> for RequestContext { Ok(RequestContext { psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?, - endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?, + endpoint: PjUrl(endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?), ohttp_keys, disable_output_substitution: disable_output_substitution .ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?, @@ -1018,7 +1019,7 @@ fn serialize_v2_body( ) -> Result, CreateRequestError> { // Grug say localhost base be discarded anyway. no big brain needed. let placeholder_url = serialize_url( - "http:/localhost".to_string(), + PjUrl(Url::parse("http://localhost").unwrap()), disable_output_substitution, fee_contribution, min_feerate, @@ -1030,12 +1031,12 @@ fn serialize_v2_body( } fn serialize_url( - endpoint: String, + endpoint: PjUrl, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, ) -> Result { - let mut url = Url::parse(&endpoint)?; + let mut url = endpoint.0; url.query_pairs_mut().append_pair("v", "1"); if disable_output_substitution { url.query_pairs_mut().append_pair("disableoutputsubstitution", "1"); @@ -1108,7 +1109,7 @@ mod test { let req_ctx = RequestContext { psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(), - endpoint: Url::parse("http://localhost:1234").unwrap(), + endpoint: PjUrl(Url::parse("http://localhost:1234").unwrap()), ohttp_keys: None, disable_output_substitution: false, fee_contribution: None, diff --git a/payjoin/src/uri/mod.rs b/payjoin/src/uri/mod.rs index 55c46a4c..9743d3a8 100644 --- a/payjoin/src/uri/mod.rs +++ b/payjoin/src/uri/mod.rs @@ -6,8 +6,10 @@ pub use error::PjParseError; use url::Url; use crate::uri::error::InternalPjParseError; +use crate::uri::pj_url::PjUrl; pub mod error; +pub(crate) mod pj_url; #[cfg(feature = "v2")] use crate::OhttpKeys; @@ -29,10 +31,8 @@ impl MaybePayjoinExtras { #[derive(Clone)] pub struct PayjoinExtras { - pub(crate) endpoint: Url, + pub(crate) endpoint: PjUrl, pub(crate) disable_output_substitution: bool, - #[cfg(feature = "v2")] - pub(crate) ohttp_keys: Option, } pub type Uri<'a, NetworkValidation> = bip21::Uri<'a, NetworkValidation, MaybePayjoinExtras>; @@ -90,33 +90,25 @@ pub struct PjUriBuilder { /// Label label: Option, /// Payjoin endpoint url listening for payjoin requests. - pj: Url, + pj: PjUrl, /// Whether or not payjoin output substitution is allowed pjos: bool, - #[cfg(feature = "v2")] - /// Config for ohttp. - /// - /// Required only for v2 payjoin. - ohttp: Option, } impl PjUriBuilder { /// Create a new `PjUriBuilder` with required parameters. + /// + /// address represents a bitcoin address + /// and origin represents either the payjoin endpoint in v1 or the directory in v2 pub fn new( address: Address, - pj: Url, + origin: Url, #[cfg(feature = "v2")] ohttp_keys: Option, ) -> Self { - Self { - address, - amount: None, - message: None, - label: None, - pj, - pjos: false, - #[cfg(feature = "v2")] - ohttp: ohttp_keys, - } + let mut pj = PjUrl(origin); + #[cfg(feature = "v2")] + pj.set_ohttp(ohttp_keys); + Self { address, amount: None, message: None, label: None, pj, pjos: false } } /// Set the amount you want to receive. pub fn amount(mut self, amount: Amount) -> Self { @@ -147,12 +139,7 @@ impl PjUriBuilder { /// Constructs a `bip21::Uri` with PayjoinParams from the /// parameters set in the builder. pub fn build<'a>(self) -> PjUri<'a> { - let extras = PayjoinExtras { - endpoint: self.pj, - disable_output_substitution: self.pjos, - #[cfg(feature = "v2")] - ohttp_keys: self.ohttp, - }; + let extras = PayjoinExtras { endpoint: self.pj, disable_output_substitution: self.pjos }; let mut pj_uri = bip21::Uri::with_extras(self.address, extras); pj_uri.amount = self.amount; pj_uri.label = self.label.map(Into::into); @@ -177,8 +164,6 @@ impl<'a> bip21::de::DeserializeParams<'a> for MaybePayjoinExtras { pub struct DeserializationState { pj: Option, pjos: Option, - #[cfg(feature = "v2")] - ohttp: Option, } impl<'a> bip21::SerializeParams for &'a MaybePayjoinExtras { @@ -200,21 +185,11 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras { type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>; fn serialize_params(self) -> Self::Iterator { - #[allow(unused_mut)] - let mut params = vec![ + vec![ ("pj", self.endpoint.as_str().to_string()), ("pjos", if self.disable_output_substitution { "1" } else { "0" }.to_string()), - ]; - #[cfg(feature = "v2")] - if let Some(ohttp_keys) = self.ohttp_keys.clone().and_then(|c| c.encode().ok()) { - let config = - bitcoin::base64::Config::new(bitcoin::base64::CharacterSet::UrlSafe, false); - let base64_ohttp_keys = bitcoin::base64::encode_config(ohttp_keys, config); - params.push(("ohttp", base64_ohttp_keys)); - } else { - log::warn!("Failed to encode ohttp config, ignoring"); - } - params.into_iter() + ] + .into_iter() } } @@ -232,20 +207,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { ::Error, > { match key { - #[cfg(feature = "v2")] - "ohttp" if self.ohttp.is_none() => { - let base64_config = - Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; - let config_bytes = - bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE) - .map_err(|_| InternalPjParseError::NotBase64)?; - let config = OhttpKeys::decode(&config_bytes) - .map_err(|_| InternalPjParseError::DecodeOhttpKeys)?; - self.ohttp = Some(config); - Ok(bip21::de::ParamKind::Known) - } - #[cfg(feature = "v2")] - "ohttp" => Err(InternalPjParseError::MultipleParams("ohttp").into()), "pj" if self.pj.is_none() => { let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?; @@ -279,10 +240,8 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { && endpoint.domain().unwrap_or_default().ends_with(".onion") { Ok(MaybePayjoinExtras::Supported(PayjoinExtras { - endpoint, + endpoint: PjUrl(endpoint), disable_output_substitution: pjos.unwrap_or(false), - #[cfg(feature = "v2")] - ohttp_keys: self.ohttp, })) } else { Err(InternalPjParseError::UnsecureEndpoint.into()) diff --git a/payjoin/src/uri/pj_url.rs b/payjoin/src/uri/pj_url.rs index 199c922a..26be0b88 100644 --- a/payjoin/src/uri/pj_url.rs +++ b/payjoin/src/uri/pj_url.rs @@ -1,43 +1,73 @@ +use std::borrow::Cow; +use std::ops::{Deref, DerefMut}; + +use serde::Deserialize; use url::Url; -pub struct PjUrl { - url: Url, - ohttp: Option, +use crate::OhttpKeys; +#[derive(Clone, Debug, Deserialize)] +pub struct PjUrl(pub Url); + +impl Deref for PjUrl { + type Target = Url; + + fn deref(&self) -> &Self::Target { &self.0 } } -impl PjUrl { - pub fn new(url: Url) -> Self { - let (url, ohttp) = Self::extract_ohttp(url); - PjUrl { url, ohttp } - } +impl DerefMut for PjUrl { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } +} - fn extract_ohttp(mut url: Url) -> (Url, Option) { - let fragment = &mut url.fragment().and_then(|f| { +impl PjUrl { + pub fn ohttp(&self) -> Option { + self.fragment().and_then(|f| { let parts: Vec<&str> = f.splitn(2, "ohttp=").collect(); if parts.len() == 2 { - Some((parts[0].trim_end_matches('&'), parts[1].to_string())) + let base64_config = Cow::try_from(parts[1]).ok()?; + let config_bytes = + bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE) + .ok()?; + OhttpKeys::decode(&config_bytes).ok() } else { None } - }); + }) + } - if let Some((remaining_fragment, ohttp)) = fragment { - url.set_fragment(Some(remaining_fragment)); - (url, Some(ohttp)) + pub fn set_ohttp(&mut self, ohttp: Option) { + if let Some(ohttp) = ohttp { + let new_ohttp = format!("ohttp={}", ohttp.to_string()); + let mut fragment = self.fragment().unwrap_or("").to_string(); + if let Some(start) = fragment.find("ohttp=") { + let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); + fragment.replace_range(start..end, &new_ohttp); + } else { + if !fragment.is_empty() { + fragment.push('&'); + } + fragment.push_str(&new_ohttp); + } + self.set_fragment(Some(&fragment)); } else { - (url, None) + self.set_fragment(None); } } +} - pub fn into_url(self) -> Url { - let mut url = self.url; - if let Some(ohttp) = self.ohttp { - let fragment = url - .fragment() - .map(|f| format!("{}&ohttp={}", f, ohttp)) - .unwrap_or_else(|| format!("ohttp={}", ohttp)); - url.set_fragment(Some(&fragment)); - } - url +#[cfg(test)] +mod test { + use url::Url; + + use super::PjUrl; + + #[test] + fn test_pj_url() { + let url = PjUrl( + Url::parse( + "https://example.com#ohttp=AQAg3WpRjS0aqAxQUoLvpas2VYjT2oIg6-3XSiB-QiYI1BAABAABAAM", + ) + .unwrap(), + ); + assert!(url.ohttp().is_some()); } } diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs index 7f130b21..bc104656 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/v2.rs @@ -5,6 +5,7 @@ use bitcoin::secp256k1::ecdh::SharedSecret; use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; use chacha20poly1305::aead::{Aead, KeyInit, OsRng, Payload}; use chacha20poly1305::{AeadCore, ChaCha20Poly1305, Nonce}; +use serde::Serialize; pub const PADDED_MESSAGE_BYTES: usize = 7168; // 7KB @@ -249,6 +250,10 @@ impl OhttpKeys { } } +impl std::fmt::Display for OhttpKeys { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { self.serialize(f) } +} + impl PartialEq for OhttpKeys { fn eq(&self, other: &Self) -> bool { match (self.encode(), other.encode()) {