Skip to content

Commit

Permalink
Handle fragments with uri::UrlExt trait
Browse files Browse the repository at this point in the history
This extension trait defines functions to parse and set the ohttp parameter
in the fragment of a `pj=` URL.
  • Loading branch information
DanGould committed Jun 25, 2024
1 parent 7b21179 commit a8b5e97
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
15 changes: 7 additions & 8 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ use url::Url;
use crate::input_type::InputType;
use crate::psbt::PsbtExt;
use crate::request::Request;
use crate::uri::pj_url::PjUrl;
use crate::uri::UriExt;
use crate::weight::{varint_size, ComputeWeight};
use crate::{PjUri, Uri};
Expand Down Expand Up @@ -212,7 +211,7 @@ impl<'a> RequestBuilder<'a> {
.map_err(InternalCreateRequestError::InvalidOriginalInput)?;
let endpoint = self.uri.extras.endpoint.clone();
#[cfg(feature = "v2")]
let ohttp_keys = self.uri.extras.endpoint.ohttp();
let ohttp_keys = crate::uri::UrlExt::ohttp(&self.uri.extras.endpoint);
let disable_output_substitution =
self.uri.extras.disable_output_substitution || self.disable_output_substitution;
let payee = self.uri.address.script_pubkey();
Expand Down Expand Up @@ -260,7 +259,7 @@ impl<'a> RequestBuilder<'a> {
#[derive(Clone, PartialEq)]
pub struct RequestContext {
psbt: Psbt,
endpoint: PjUrl,
endpoint: Url,
#[cfg(feature = "v2")]
ohttp_keys: Option<crate::v2::OhttpKeys>,
disable_output_substitution: bool,
Expand Down Expand Up @@ -476,7 +475,7 @@ impl<'de> Deserialize<'de> for RequestContext {

Ok(RequestContext {
psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?,
endpoint: PjUrl(endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?),
endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?,
ohttp_keys: ohttp_keys.ok_or_else(|| de::Error::missing_field("ohttp_keys"))?,
disable_output_substitution: disable_output_substitution
.ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?,
Expand Down Expand Up @@ -973,7 +972,7 @@ fn serialize_v2_body(
) -> Result<Vec<u8>, CreateRequestError> {
// Grug say localhost base be discarded anyway. no big brain needed.
let placeholder_url = serialize_url(
PjUrl(Url::parse("http://localhost").unwrap()),
Url::parse("http://localhost").unwrap(),
disable_output_substitution,
fee_contribution,
min_feerate,
Expand All @@ -985,12 +984,12 @@ fn serialize_v2_body(
}

fn serialize_url(
endpoint: PjUrl,
endpoint: Url,
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
) -> Result<Url, url::ParseError> {
let mut url = endpoint.0;
let mut url = endpoint;
url.query_pairs_mut().append_pair("v", "1");
if disable_output_substitution {
url.query_pairs_mut().append_pair("disableoutputsubstitution", "1");
Expand Down Expand Up @@ -1063,7 +1062,7 @@ mod test {

let req_ctx = RequestContext {
psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(),
endpoint: PjUrl(Url::parse("http://localhost:1234").unwrap()),
endpoint: Url::parse("http://localhost:1234").unwrap(),
ohttp_keys: None,
disable_output_substitution: false,
fee_contribution: None,
Expand Down
68 changes: 62 additions & 6 deletions payjoin/src/uri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ pub use error::PjParseError;
use url::Url;

use crate::uri::error::InternalPjParseError;
use crate::uri::pj_url::PjUrl;

pub mod error;
pub(crate) mod pj_url;

#[cfg(feature = "v2")]
use crate::OhttpKeys;
Expand All @@ -31,7 +29,7 @@ impl MaybePayjoinExtras {

#[derive(Clone)]
pub struct PayjoinExtras {
pub(crate) endpoint: PjUrl,
pub(crate) endpoint: Url,
pub(crate) disable_output_substitution: bool,
}

Expand Down Expand Up @@ -90,7 +88,7 @@ pub struct PjUriBuilder {
/// Label
label: Option<String>,
/// Payjoin endpoint url listening for payjoin requests.
pj: PjUrl,
pj: Url,
/// Whether or not payjoin output substitution is allowed
pjos: bool,
}
Expand All @@ -106,7 +104,7 @@ impl PjUriBuilder {
#[cfg(feature = "v2")] ohttp_keys: Option<OhttpKeys>,
) -> Self {
#[allow(unused_mut)]
let mut pj = PjUrl(origin);
let mut pj = origin;
#[cfg(feature = "v2")]
pj.set_ohttp(ohttp_keys);
Self { address, amount: None, message: None, label: None, pj, pjos: false }
Expand Down Expand Up @@ -241,7 +239,7 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
&& endpoint.domain().unwrap_or_default().ends_with(".onion")
{
Ok(MaybePayjoinExtras::Supported(PayjoinExtras {
endpoint: PjUrl(endpoint),
endpoint,
disable_output_substitution: pjos.unwrap_or(false),
}))
} else {
Expand All @@ -252,6 +250,50 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
}
}

/// Parse and set fragment parameters from `&pj=` URLs
#[cfg(feature = "v2")]
pub trait UrlExt {
fn ohttp(&self) -> Option<OhttpKeys>;
fn set_ohttp(&mut self, ohttp: Option<OhttpKeys>);
}

#[cfg(feature = "v2")]
impl UrlExt for Url {
fn ohttp(&self) -> Option<OhttpKeys> {
self.fragment().and_then(|f| {
let parts: Vec<&str> = f.splitn(2, "ohttp=").collect();
if parts.len() == 2 {
let base64_config = Cow::try_from(parts[1]).ok()?;
let config_bytes =
bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE)
.ok()?;
OhttpKeys::decode(&config_bytes).ok()
} else {
None
}
})
}

fn set_ohttp(&mut self, ohttp: Option<OhttpKeys>) {
if let Some(ohttp) = ohttp {
let new_ohttp = format!("ohttp={}", ohttp.to_string());
let mut fragment = self.fragment().unwrap_or("").to_string();
if let Some(start) = fragment.find("ohttp=") {
let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i);
fragment.replace_range(start..end, &new_ohttp);
} else {
if !fragment.is_empty() {
fragment.push('&');
}
fragment.push_str(&new_ohttp);
}
self.set_fragment(Some(&fragment));
} else {
self.set_fragment(None);
}
}
}

#[cfg(test)]
mod tests {
use std::convert::TryFrom;
Expand Down Expand Up @@ -358,4 +400,18 @@ mod tests {
}
}
}

#[test]
#[cfg(feature = "v2")]
fn test_url_ext_ohttp_fragment() {
use url::Url;

use super::UrlExt;

let url = Url::parse(
"https://example.com#ohttp=AQAg3WpRjS0aqAxQUoLvpas2VYjT2oIg6-3XSiB-QiYI1BAABAABAAM",
)
.unwrap();
assert!(url.ohttp().is_some());
}
}

0 comments on commit a8b5e97

Please sign in to comment.