Skip to content

Commit

Permalink
Add pj_url that handles fragments
Browse files Browse the repository at this point in the history
Close #298
  • Loading branch information
DanGould committed Jun 24, 2024
1 parent 5ec0c4d commit fecba55
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 93 deletions.
19 changes: 10 additions & 9 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use url::Url;
use crate::input_type::InputType;
use crate::psbt::PsbtExt;
use crate::request::Request;
use crate::uri::pj_url::PjUrl;
use crate::uri::UriExt;
use crate::weight::{varint_size, ComputeWeight};
use crate::{PjUri, Uri};
Expand Down Expand Up @@ -211,7 +212,7 @@ impl<'a> RequestBuilder<'a> {
.map_err(InternalCreateRequestError::InvalidOriginalInput)?;
let endpoint = self.uri.extras.endpoint.clone();
#[cfg(feature = "v2")]
let ohttp_keys = self.uri.extras.ohttp_keys;
let ohttp_keys = self.uri.extras.endpoint.ohttp();
let disable_output_substitution =
self.uri.extras.disable_output_substitution || self.disable_output_substitution;
let payee = self.uri.address.script_pubkey();
Expand Down Expand Up @@ -259,7 +260,7 @@ impl<'a> RequestBuilder<'a> {
#[derive(Clone)]
pub struct RequestContext {
psbt: Psbt,
endpoint: Url,
endpoint: PjUrl,
#[cfg(feature = "v2")]
ohttp_keys: Option<crate::v2::OhttpKeys>,
disable_output_substitution: bool,
Expand All @@ -276,7 +277,7 @@ pub struct RequestContext {
impl PartialEq for RequestContext {
fn eq(&self, other: &Self) -> bool {
self.psbt == other.psbt
&& self.endpoint == other.endpoint
&& *self.endpoint == *other.endpoint
// KeyConfig is not yet PartialEq
&& self.ohttp_keys.as_ref().map(|cfg| cfg.encode().unwrap_or_default()) == other.ohttp_keys.as_ref().map(|cfg| cfg.encode().unwrap_or_default())
&& self.disable_output_substitution == other.disable_output_substitution
Expand All @@ -296,7 +297,7 @@ impl RequestContext {
/// Extract serialized V1 Request and Context froma Payjoin Proposal
pub fn extract_v1(self) -> Result<(Request, ContextV1), CreateRequestError> {
let url = serialize_url(
self.endpoint.into(),
self.endpoint,
self.disable_output_substitution,
self.fee_contribution,
self.min_fee_rate,
Expand Down Expand Up @@ -521,7 +522,7 @@ impl<'de> Deserialize<'de> for RequestContext {

Ok(RequestContext {
psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?,
endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?,
endpoint: PjUrl(endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?),
ohttp_keys,
disable_output_substitution: disable_output_substitution
.ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?,
Expand Down Expand Up @@ -1018,7 +1019,7 @@ fn serialize_v2_body(
) -> Result<Vec<u8>, CreateRequestError> {
// Grug say localhost base be discarded anyway. no big brain needed.
let placeholder_url = serialize_url(
"http:/localhost".to_string(),
PjUrl(Url::parse("http://localhost").unwrap()),
disable_output_substitution,
fee_contribution,
min_feerate,
Expand All @@ -1030,12 +1031,12 @@ fn serialize_v2_body(
}

fn serialize_url(
endpoint: String,
endpoint: PjUrl,
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
) -> Result<Url, url::ParseError> {
let mut url = Url::parse(&endpoint)?;
let mut url = endpoint.0;
url.query_pairs_mut().append_pair("v", "1");
if disable_output_substitution {
url.query_pairs_mut().append_pair("disableoutputsubstitution", "1");
Expand Down Expand Up @@ -1108,7 +1109,7 @@ mod test {

let req_ctx = RequestContext {
psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(),
endpoint: Url::parse("http://localhost:1234").unwrap(),
endpoint: PjUrl(Url::parse("http://localhost:1234").unwrap()),
ohttp_keys: None,
disable_output_substitution: false,
fee_contribution: None,
Expand Down
75 changes: 17 additions & 58 deletions payjoin/src/uri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ pub use error::PjParseError;
use url::Url;

use crate::uri::error::InternalPjParseError;
use crate::uri::pj_url::PjUrl;

pub mod error;
pub(crate) mod pj_url;

#[cfg(feature = "v2")]
use crate::OhttpKeys;
Expand All @@ -29,10 +31,8 @@ impl MaybePayjoinExtras {

#[derive(Clone)]
pub struct PayjoinExtras {
pub(crate) endpoint: Url,
pub(crate) endpoint: PjUrl,
pub(crate) disable_output_substitution: bool,
#[cfg(feature = "v2")]
pub(crate) ohttp_keys: Option<OhttpKeys>,
}

pub type Uri<'a, NetworkValidation> = bip21::Uri<'a, NetworkValidation, MaybePayjoinExtras>;
Expand Down Expand Up @@ -90,33 +90,25 @@ pub struct PjUriBuilder {
/// Label
label: Option<String>,
/// Payjoin endpoint url listening for payjoin requests.
pj: Url,
pj: PjUrl,
/// Whether or not payjoin output substitution is allowed
pjos: bool,
#[cfg(feature = "v2")]
/// Config for ohttp.
///
/// Required only for v2 payjoin.
ohttp: Option<OhttpKeys>,
}

impl PjUriBuilder {
/// Create a new `PjUriBuilder` with required parameters.
///
/// address represents a bitcoin address
/// and origin represents either the payjoin endpoint in v1 or the directory in v2
pub fn new(
address: Address,
pj: Url,
origin: Url,
#[cfg(feature = "v2")] ohttp_keys: Option<OhttpKeys>,
) -> Self {
Self {
address,
amount: None,
message: None,
label: None,
pj,
pjos: false,
#[cfg(feature = "v2")]
ohttp: ohttp_keys,
}
let mut pj = PjUrl(origin);
#[cfg(feature = "v2")]
pj.set_ohttp(ohttp_keys);
Self { address, amount: None, message: None, label: None, pj, pjos: false }
}
/// Set the amount you want to receive.
pub fn amount(mut self, amount: Amount) -> Self {
Expand Down Expand Up @@ -147,12 +139,7 @@ impl PjUriBuilder {
/// Constructs a `bip21::Uri` with PayjoinParams from the
/// parameters set in the builder.
pub fn build<'a>(self) -> PjUri<'a> {
let extras = PayjoinExtras {
endpoint: self.pj,
disable_output_substitution: self.pjos,
#[cfg(feature = "v2")]
ohttp_keys: self.ohttp,
};
let extras = PayjoinExtras { endpoint: self.pj, disable_output_substitution: self.pjos };
let mut pj_uri = bip21::Uri::with_extras(self.address, extras);
pj_uri.amount = self.amount;
pj_uri.label = self.label.map(Into::into);
Expand All @@ -177,8 +164,6 @@ impl<'a> bip21::de::DeserializeParams<'a> for MaybePayjoinExtras {
pub struct DeserializationState {
pj: Option<Url>,
pjos: Option<bool>,
#[cfg(feature = "v2")]
ohttp: Option<OhttpKeys>,
}

impl<'a> bip21::SerializeParams for &'a MaybePayjoinExtras {
Expand All @@ -200,21 +185,11 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras {
type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>;

fn serialize_params(self) -> Self::Iterator {
#[allow(unused_mut)]
let mut params = vec![
vec![
("pj", self.endpoint.as_str().to_string()),
("pjos", if self.disable_output_substitution { "1" } else { "0" }.to_string()),
];
#[cfg(feature = "v2")]
if let Some(ohttp_keys) = self.ohttp_keys.clone().and_then(|c| c.encode().ok()) {
let config =
bitcoin::base64::Config::new(bitcoin::base64::CharacterSet::UrlSafe, false);
let base64_ohttp_keys = bitcoin::base64::encode_config(ohttp_keys, config);
params.push(("ohttp", base64_ohttp_keys));
} else {
log::warn!("Failed to encode ohttp config, ignoring");
}
params.into_iter()
]
.into_iter()
}
}

Expand All @@ -232,20 +207,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
<Self::Value as bip21::DeserializationError>::Error,
> {
match key {
#[cfg(feature = "v2")]
"ohttp" if self.ohttp.is_none() => {
let base64_config =
Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let config_bytes =
bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE)
.map_err(|_| InternalPjParseError::NotBase64)?;
let config = OhttpKeys::decode(&config_bytes)
.map_err(|_| InternalPjParseError::DecodeOhttpKeys)?;
self.ohttp = Some(config);
Ok(bip21::de::ParamKind::Known)
}
#[cfg(feature = "v2")]
"ohttp" => Err(InternalPjParseError::MultipleParams("ohttp").into()),
"pj" if self.pj.is_none() => {
let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?;
Expand Down Expand Up @@ -279,10 +240,8 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
&& endpoint.domain().unwrap_or_default().ends_with(".onion")
{
Ok(MaybePayjoinExtras::Supported(PayjoinExtras {
endpoint,
endpoint: PjUrl(endpoint),
disable_output_substitution: pjos.unwrap_or(false),
#[cfg(feature = "v2")]
ohttp_keys: self.ohttp,
}))
} else {
Err(InternalPjParseError::UnsecureEndpoint.into())
Expand Down
82 changes: 56 additions & 26 deletions payjoin/src/uri/pj_url.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,73 @@
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};

use serde::Deserialize;
use url::Url;

pub struct PjUrl {
url: Url,
ohttp: Option<String>,
use crate::OhttpKeys;
#[derive(Clone, Debug, Deserialize)]
pub struct PjUrl(pub Url);

impl Deref for PjUrl {
type Target = Url;

fn deref(&self) -> &Self::Target { &self.0 }
}

impl PjUrl {
pub fn new(url: Url) -> Self {
let (url, ohttp) = Self::extract_ohttp(url);
PjUrl { url, ohttp }
}
impl DerefMut for PjUrl {
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
}

fn extract_ohttp(mut url: Url) -> (Url, Option<String>) {
let fragment = &mut url.fragment().and_then(|f| {
impl PjUrl {
pub fn ohttp(&self) -> Option<OhttpKeys> {
self.fragment().and_then(|f| {
let parts: Vec<&str> = f.splitn(2, "ohttp=").collect();
if parts.len() == 2 {
Some((parts[0].trim_end_matches('&'), parts[1].to_string()))
let base64_config = Cow::try_from(parts[1]).ok()?;
let config_bytes =
bitcoin::base64::decode_config(&*base64_config, bitcoin::base64::URL_SAFE)
.ok()?;
OhttpKeys::decode(&config_bytes).ok()
} else {
None
}
});
})
}

if let Some((remaining_fragment, ohttp)) = fragment {
url.set_fragment(Some(remaining_fragment));
(url, Some(ohttp))
pub fn set_ohttp(&mut self, ohttp: Option<OhttpKeys>) {
if let Some(ohttp) = ohttp {
let new_ohttp = format!("ohttp={}", ohttp.to_string());
let mut fragment = self.fragment().unwrap_or("").to_string();
if let Some(start) = fragment.find("ohttp=") {
let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i);
fragment.replace_range(start..end, &new_ohttp);
} else {
if !fragment.is_empty() {
fragment.push('&');
}
fragment.push_str(&new_ohttp);
}
self.set_fragment(Some(&fragment));
} else {
(url, None)
self.set_fragment(None);
}
}
}

pub fn into_url(self) -> Url {
let mut url = self.url;
if let Some(ohttp) = self.ohttp {
let fragment = url
.fragment()
.map(|f| format!("{}&ohttp={}", f, ohttp))
.unwrap_or_else(|| format!("ohttp={}", ohttp));
url.set_fragment(Some(&fragment));
}
url
#[cfg(test)]
mod test {
use url::Url;

use super::PjUrl;

#[test]
fn test_pj_url() {
let url = PjUrl(
Url::parse(
"https://example.com#ohttp=AQAg3WpRjS0aqAxQUoLvpas2VYjT2oIg6-3XSiB-QiYI1BAABAABAAM",
)
.unwrap(),
);
assert!(url.ohttp().is_some());
}
}
5 changes: 5 additions & 0 deletions payjoin/src/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use bitcoin::secp256k1::ecdh::SharedSecret;
use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
use chacha20poly1305::aead::{Aead, KeyInit, OsRng, Payload};
use chacha20poly1305::{AeadCore, ChaCha20Poly1305, Nonce};
use serde::Serialize;

pub const PADDED_MESSAGE_BYTES: usize = 7168; // 7KB

Expand Down Expand Up @@ -249,6 +250,10 @@ impl OhttpKeys {
}
}

impl std::fmt::Display for OhttpKeys {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { self.serialize(f) }
}

impl PartialEq for OhttpKeys {
fn eq(&self, other: &Self) -> bool {
match (self.encode(), other.encode()) {
Expand Down

0 comments on commit fecba55

Please sign in to comment.