Skip to content

Commit

Permalink
Handle fragments with uri::UrlExt trait
Browse files Browse the repository at this point in the history
This extension trait defines functions to parse and set the ohttp parameter
in the fragment of a `pj=` URL.

Close payjoin#298
  • Loading branch information
DanGould committed Jun 26, 2024
1 parent caffc54 commit f534375
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 119 deletions.
20 changes: 5 additions & 15 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,6 @@ impl<'a> RequestBuilder<'a> {
psbt.validate_input_utxos(true)
.map_err(InternalCreateRequestError::InvalidOriginalInput)?;
let endpoint = self.uri.extras.endpoint.clone();
#[cfg(feature = "v2")]
let ohttp_keys = self.uri.extras.ohttp_keys;
let disable_output_substitution =
self.uri.extras.disable_output_substitution || self.disable_output_substitution;
let payee = self.uri.address.script_pubkey();
Expand Down Expand Up @@ -234,8 +232,6 @@ impl<'a> RequestBuilder<'a> {
Ok(RequestContext {
psbt,
endpoint,
#[cfg(feature = "v2")]
ohttp_keys,
disable_output_substitution,
fee_contribution,
payee,
Expand All @@ -252,8 +248,6 @@ impl<'a> RequestBuilder<'a> {
pub struct RequestContext {
psbt: Psbt,
endpoint: Url,
#[cfg(feature = "v2")]
ohttp_keys: Option<crate::v2::OhttpKeys>,
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
Expand Down Expand Up @@ -303,6 +297,7 @@ impl RequestContext {
&mut self,
ohttp_relay: Url,
) -> Result<(Request, ContextV2), CreateRequestError> {
use crate::uri::PjUrlExt;
let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
Expand All @@ -314,7 +309,7 @@ impl RequestContext {
let body = crate::v2::encrypt_message_a(body, self.e, rs)
.map_err(InternalCreateRequestError::Hpke)?;
let (body, ohttp_res) = crate::v2::ohttp_encapsulate(
self.ohttp_keys.as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?,
self.endpoint.ohttp().as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?,
"POST",
url.as_str(),
Some(&body),
Expand Down Expand Up @@ -384,7 +379,6 @@ impl Serialize for RequestContext {
let mut state = serializer.serialize_struct("RequestContext", 8)?;
state.serialize_field("psbt", &self.psbt.to_string())?;
state.serialize_field("endpoint", &self.endpoint.as_str())?;
state.serialize_field("ohttp_keys", &self.ohttp_keys)?;
state.serialize_field("disable_output_substitution", &self.disable_output_substitution)?;
state.serialize_field(
"fee_contribution",
Expand Down Expand Up @@ -433,7 +427,6 @@ impl<'de> Deserialize<'de> for RequestContext {
{
let mut psbt = None;
let mut endpoint = None;
let mut ohttp_keys = None;
let mut disable_output_substitution = None;
let mut fee_contribution = None;
let mut min_fee_rate = None;
Expand All @@ -453,7 +446,6 @@ impl<'de> Deserialize<'de> for RequestContext {
url::Url::from_str(&map.next_value::<String>()?)
.map_err(de::Error::custom)?,
),
"ohttp_keys" => ohttp_keys = Some(map.next_value()?),
"disable_output_substitution" =>
disable_output_substitution = Some(map.next_value()?),
"fee_contribution" => {
Expand All @@ -479,7 +471,6 @@ impl<'de> Deserialize<'de> for RequestContext {
Ok(RequestContext {
psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?,
endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?,
ohttp_keys: ohttp_keys.ok_or_else(|| de::Error::missing_field("ohttp_keys"))?,
disable_output_substitution: disable_output_substitution
.ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?,
fee_contribution,
Expand Down Expand Up @@ -975,7 +966,7 @@ fn serialize_v2_body(
) -> Result<Vec<u8>, CreateRequestError> {
// Grug say localhost base be discarded anyway. no big brain needed.
let placeholder_url = serialize_url(
"http:/localhost".to_string(),
Url::parse("http://localhost").unwrap(),
disable_output_substitution,
fee_contribution,
min_feerate,
Expand All @@ -987,12 +978,12 @@ fn serialize_v2_body(
}

fn serialize_url(
endpoint: String,
endpoint: Url,
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
) -> Result<Url, url::ParseError> {
let mut url = Url::parse(&endpoint)?;
let mut url = endpoint;
url.query_pairs_mut().append_pair("v", "1");
if disable_output_substitution {
url.query_pairs_mut().append_pair("disableoutputsubstitution", "1");
Expand Down Expand Up @@ -1066,7 +1057,6 @@ mod test {
let req_ctx = RequestContext {
psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(),
endpoint: Url::parse("http://localhost:1234").unwrap(),
ohttp_keys: None,
disable_output_substitution: false,
fee_contribution: None,
min_fee_rate: FeeRate::ZERO,
Expand Down
7 changes: 0 additions & 7 deletions payjoin/src/uri/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#[cfg(feature = "v2")]
use crate::uri::OhttpKeysParseError;

#[derive(Debug)]
pub struct PjParseError(InternalPjParseError);

Expand All @@ -11,8 +8,6 @@ pub(crate) enum InternalPjParseError {
MissingEndpoint,
NotUtf8,
BadEndpoint,
#[cfg(feature = "v2")]
ParseOhttpKeys(OhttpKeysParseError),
UnsecureEndpoint,
}

Expand All @@ -30,8 +25,6 @@ impl std::fmt::Display for PjParseError {
InternalPjParseError::MissingEndpoint => write!(f, "Missing payjoin endpoint"),
InternalPjParseError::NotUtf8 => write!(f, "Endpoint is not valid UTF-8"),
InternalPjParseError::BadEndpoint => write!(f, "Endpoint is not valid"),
#[cfg(feature = "v2")]
InternalPjParseError::ParseOhttpKeys(e) => write!(f, "OHTTP Keys are not valid: {}", e),
InternalPjParseError::UnsecureEndpoint => {
write!(f, "Endpoint scheme is not secure (https or onion)")
}
Expand Down
127 changes: 73 additions & 54 deletions payjoin/src/uri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ pub use error::PjParseError;
use url::Url;

use crate::uri::error::InternalPjParseError;
#[cfg(feature = "v2")]
use crate::v2::OhttpKeysParseError;

pub mod error;

#[cfg(feature = "v2")]
Expand All @@ -33,8 +30,6 @@ impl MaybePayjoinExtras {
pub struct PayjoinExtras {
pub(crate) endpoint: Url,
pub(crate) disable_output_substitution: bool,
#[cfg(feature = "v2")]
pub(crate) ohttp_keys: Option<OhttpKeys>,
}

impl PayjoinExtras {
Expand Down Expand Up @@ -99,30 +94,25 @@ pub struct PjUriBuilder {
pj: Url,
/// Whether or not payjoin output substitution is allowed
pjos: bool,
#[cfg(feature = "v2")]
/// Config for ohttp.
///
/// Required only for v2 payjoin.
ohttp: Option<OhttpKeys>,
}

impl PjUriBuilder {
/// Create a new `PjUriBuilder` with required parameters.
///
/// ## Parameters
/// - `address`: Represents a bitcoin address.
/// - `origin`: Represents either the payjoin endpoint in v1 or the directory in v2.
/// - `ohttp_keys`: Optional OHTTP keys for v2 (only available if the "v2" feature is enabled).
pub fn new(
address: Address,
pj: Url,
origin: Url,
#[cfg(feature = "v2")] ohttp_keys: Option<OhttpKeys>,
) -> Self {
Self {
address,
amount: None,
message: None,
label: None,
pj,
pjos: false,
#[cfg(feature = "v2")]
ohttp: ohttp_keys,
}
#[allow(unused_mut)]
let mut pj = origin;
#[cfg(feature = "v2")]
pj.set_ohttp(ohttp_keys);
Self { address, amount: None, message: None, label: None, pj, pjos: false }
}
/// Set the amount you want to receive.
pub fn amount(mut self, amount: Amount) -> Self {
Expand Down Expand Up @@ -153,12 +143,7 @@ impl PjUriBuilder {
/// Constructs a `bip21::Uri` with PayjoinParams from the
/// parameters set in the builder.
pub fn build<'a>(self) -> PjUri<'a> {
let extras = PayjoinExtras {
endpoint: self.pj,
disable_output_substitution: self.pjos,
#[cfg(feature = "v2")]
ohttp_keys: self.ohttp,
};
let extras = PayjoinExtras { endpoint: self.pj, disable_output_substitution: self.pjos };
let mut pj_uri = bip21::Uri::with_extras(self.address, extras);
pj_uri.amount = self.amount;
pj_uri.label = self.label.map(Into::into);
Expand All @@ -183,8 +168,6 @@ impl<'a> bip21::de::DeserializeParams<'a> for MaybePayjoinExtras {
pub struct DeserializationState {
pj: Option<Url>,
pjos: Option<bool>,
#[cfg(feature = "v2")]
ohttp: Option<OhttpKeys>,
}

impl<'a> bip21::SerializeParams for &'a MaybePayjoinExtras {
Expand All @@ -206,18 +189,11 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras {
type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>;

fn serialize_params(self) -> Self::Iterator {
#[allow(unused_mut)]
let mut params = vec![
vec![
("pj", self.endpoint.as_str().to_string()),
("pjos", if self.disable_output_substitution { "1" } else { "0" }.to_string()),
];
#[cfg(feature = "v2")]
if let Some(ohttp_keys) = &self.ohttp_keys {
params.push(("ohttp", ohttp_keys.to_string()));
} else {
log::warn!("Failed to encode ohttp config, ignoring");
}
params.into_iter()
]
.into_iter()
}
}

Expand All @@ -235,19 +211,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
<Self::Value as bip21::DeserializationError>::Error,
> {
match key {
#[cfg(feature = "v2")]
"ohttp" if self.ohttp.is_none() => {
use std::str::FromStr;

let base64_config =
Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let config = OhttpKeys::from_str(&base64_config)
.map_err(InternalPjParseError::ParseOhttpKeys)?;
self.ohttp = Some(config);
Ok(bip21::de::ParamKind::Known)
}
#[cfg(feature = "v2")]
"ohttp" => Err(InternalPjParseError::MultipleParams("ohttp").into()),
"pj" if self.pj.is_none() => {
let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?;
Expand Down Expand Up @@ -283,8 +246,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
Ok(MaybePayjoinExtras::Supported(PayjoinExtras {
endpoint,
disable_output_substitution: pjos.unwrap_or(false),
#[cfg(feature = "v2")]
ohttp_keys: self.ohttp,
}))
} else {
Err(InternalPjParseError::UnsecureEndpoint.into())
Expand All @@ -294,6 +255,50 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
}
}

/// Parse and set fragment parameters from `&pj=` URLs
#[cfg(feature = "v2")]
pub trait PjUrlExt {
fn ohttp(&self) -> Option<OhttpKeys>;
fn set_ohttp(&mut self, ohttp: Option<OhttpKeys>);
}

#[cfg(feature = "v2")]
impl PjUrlExt for Url {
fn ohttp(&self) -> Option<OhttpKeys> {
self.fragment().and_then(|f| {
let parts: Vec<&str> = f.splitn(2, "ohttp=").collect();
if parts.len() == 2 {
let base64_config = Cow::try_from(parts[1]).ok()?;
let config_bytes =
bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE)
.ok()?;
OhttpKeys::decode(&config_bytes).ok()
} else {
None
}
})
}

fn set_ohttp(&mut self, ohttp: Option<OhttpKeys>) {
if let Some(ohttp) = ohttp {
let new_ohttp = format!("ohttp={}", ohttp.to_string());
let mut fragment = self.fragment().unwrap_or("").to_string();
if let Some(start) = fragment.find("ohttp=") {
let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i);
fragment.replace_range(start..end, &new_ohttp);
} else {
if !fragment.is_empty() {
fragment.push('&');
}
fragment.push_str(&new_ohttp);
}
self.set_fragment(Some(&fragment));
} else {
self.set_fragment(None);
}
}
}

#[cfg(test)]
mod tests {
use std::convert::TryFrom;
Expand Down Expand Up @@ -400,4 +405,18 @@ mod tests {
}
}
}

#[test]
#[cfg(feature = "v2")]
fn test_url_ext_ohttp_fragment() {
use url::Url;

use super::PjUrlExt;

let url = Url::parse(
"https://example.com#ohttp=AQAg3WpRjS0aqAxQUoLvpas2VYjT2oIg6-3XSiB-QiYI1BAABAABAAM",
)
.unwrap();
assert!(url.ohttp().is_some());
}
}
43 changes: 0 additions & 43 deletions payjoin/src/uri/pj_url.rs

This file was deleted.

0 comments on commit f534375

Please sign in to comment.