Skip to content

Commit

Permalink
Handle fragments with uri::UrlExt trait
Browse files Browse the repository at this point in the history
This extension trait defines functions to parse and set the ohttp parameter
in the fragment of a `pj=` URL.

Close #298
  • Loading branch information
DanGould committed Jul 9, 2024
1 parent cf49f5a commit 78e66f6
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 173 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion payjoin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ exclude = ["tests"]
send = []
receive = ["bitcoin/rand"]
base64 = ["bitcoin/base64"]
v2 = ["bitcoin/rand", "bitcoin/serde", "chacha20poly1305", "dep:http", "bhttp", "ohttp", "serde", "url/serde"]
v2 = ["bitcoin/rand", "bitcoin/serde", "chacha20poly1305", "dep:http", "bhttp", "ohttp", "dep:percent-encoding", "serde", "url/serde"]
io = ["reqwest/rustls-tls"]
danger-local-https = ["io", "reqwest/rustls-tls", "rustls"]

Expand All @@ -30,6 +30,7 @@ log = { version = "0.4.14"}
http = { version = "1", optional = true }
bhttp = { version = "=0.5.1", optional = true }
ohttp = { version = "0.5.1", optional = true }
percent-encoding = { version = "0.1.3", optional = true, package = "percent-encoding-rfc3986" }
serde = { version = "1.0.186", default-features = false, optional = true }
reqwest = { version = "0.12", default-features = false, optional = true }
rustls = { version = "0.22.2", optional = true }
Expand Down
60 changes: 51 additions & 9 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,11 @@ pub(crate) enum InternalCreateRequestError {
#[cfg(feature = "v2")]
OhttpEncapsulation(crate::v2::OhttpEncapsulationError),
#[cfg(feature = "v2")]
SubdirectoryNotBase64(bitcoin::base64::DecodeError),
#[cfg(feature = "v2")]
SubdirectoryInvalidPubkey(bitcoin::secp256k1::Error),
ParseSubdirectory(ParseSubdirectoryError),
#[cfg(feature = "v2")]
MissingOhttpConfig,
#[cfg(feature = "v2")]
PercentEncoding,
}

impl fmt::Display for CreateRequestError {
Expand All @@ -223,11 +223,11 @@ impl fmt::Display for CreateRequestError {
#[cfg(feature = "v2")]
OhttpEncapsulation(e) => write!(f, "v2 error: {}", e),
#[cfg(feature = "v2")]
SubdirectoryNotBase64(e) => write!(f, "subdirectory is not valid base64 error: {}", e),
#[cfg(feature = "v2")]
SubdirectoryInvalidPubkey(e) => write!(f, "subdirectory does not represent a valid pubkey: {}", e),
ParseSubdirectory(e) => write!(f, "cannot parse subdirectory: {}", e),
#[cfg(feature = "v2")]
MissingOhttpConfig => write!(f, "no ohttp configuration with which to make a v2 request available"),
#[cfg(feature = "v2")]
PercentEncoding => write!(f, "fragment is not RFC 3986 percent-encoded"),
}
}
}
Expand Down Expand Up @@ -256,11 +256,11 @@ impl std::error::Error for CreateRequestError {
#[cfg(feature = "v2")]
OhttpEncapsulation(error) => Some(error),
#[cfg(feature = "v2")]
SubdirectoryNotBase64(error) => Some(error),
#[cfg(feature = "v2")]
SubdirectoryInvalidPubkey(error) => Some(error),
ParseSubdirectory(error) => Some(error),
#[cfg(feature = "v2")]
MissingOhttpConfig => None,
#[cfg(feature = "v2")]
PercentEncoding => None,
}
}
}
Expand All @@ -269,6 +269,48 @@ impl From<InternalCreateRequestError> for CreateRequestError {
fn from(value: InternalCreateRequestError) -> Self { CreateRequestError(value) }
}

#[cfg(feature = "v2")]
impl From<ParseSubdirectoryError> for CreateRequestError {
fn from(value: ParseSubdirectoryError) -> Self {
CreateRequestError(InternalCreateRequestError::ParseSubdirectory(value))
}
}

#[cfg(feature = "v2")]
#[derive(Debug)]
pub(crate) enum ParseSubdirectoryError {
MissingSubdirectory,
SubdirectoryNotBase64(bitcoin::base64::DecodeError),
SubdirectoryInvalidPubkey(bitcoin::secp256k1::Error),
}

#[cfg(feature = "v2")]
impl std::fmt::Display for ParseSubdirectoryError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
use ParseSubdirectoryError::*;

match &self {
MissingSubdirectory => write!(f, "subdirectory is missing"),
SubdirectoryNotBase64(e) => write!(f, "subdirectory is not valid base64: {}", e),
SubdirectoryInvalidPubkey(e) =>
write!(f, "subdirectory does not represent a valid pubkey: {}", e),
}
}
}

#[cfg(feature = "v2")]
impl std::error::Error for ParseSubdirectoryError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use ParseSubdirectoryError::*;

match &self {
MissingSubdirectory => None,
SubdirectoryNotBase64(error) => Some(error),
SubdirectoryInvalidPubkey(error) => Some(error),
}
}
}

/// Represent an error returned by Payjoin receiver.
pub enum ResponseError {
/// `WellKnown` Errors are defined in the [`BIP78::ReceiverWellKnownError`] spec.
Expand Down
78 changes: 29 additions & 49 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,6 @@ impl<'a> RequestBuilder<'a> {
psbt.validate_input_utxos(true)
.map_err(InternalCreateRequestError::InvalidOriginalInput)?;
let endpoint = self.uri.extras.endpoint.clone();
#[cfg(feature = "v2")]
let ohttp_keys = self.uri.extras.ohttp_keys;
let disable_output_substitution =
self.uri.extras.disable_output_substitution || self.disable_output_substitution;
let payee = self.uri.address.script_pubkey();
Expand Down Expand Up @@ -234,8 +232,6 @@ impl<'a> RequestBuilder<'a> {
Ok(RequestContext {
psbt,
endpoint,
#[cfg(feature = "v2")]
ohttp_keys,
disable_output_substitution,
fee_contribution,
payee,
Expand All @@ -252,8 +248,6 @@ impl<'a> RequestBuilder<'a> {
pub struct RequestContext {
psbt: Psbt,
endpoint: Url,
#[cfg(feature = "v2")]
ohttp_keys: Option<crate::v2::OhttpKeys>,
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
Expand All @@ -271,7 +265,7 @@ impl RequestContext {
/// Extract serialized V1 Request and Context froma Payjoin Proposal
pub fn extract_v1(self) -> Result<(Request, ContextV1), CreateRequestError> {
let url = serialize_url(
self.endpoint.into(),
self.endpoint,
self.disable_output_substitution,
self.fee_contribution,
self.min_fee_rate,
Expand Down Expand Up @@ -303,6 +297,7 @@ impl RequestContext {
&mut self,
ohttp_relay: Url,
) -> Result<(Request, ContextV2), CreateRequestError> {
use crate::uri::UrlExt;
let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
Expand All @@ -313,13 +308,14 @@ impl RequestContext {
)?;
let body = crate::v2::encrypt_message_a(body, self.e, rs)
.map_err(InternalCreateRequestError::Hpke)?;
let (body, ohttp_res) = crate::v2::ohttp_encapsulate(
self.ohttp_keys.as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?,
"POST",
url.as_str(),
Some(&body),
)
.map_err(InternalCreateRequestError::OhttpEncapsulation)?;
let mut ohttp = self
.endpoint
.ohttp()
.map_err(|_| InternalCreateRequestError::PercentEncoding)?
.ok_or(InternalCreateRequestError::MissingOhttpConfig)?;
let (body, ohttp_res) =
crate::v2::ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body))
.map_err(InternalCreateRequestError::OhttpEncapsulation)?;
log::debug!("ohttp_relay_url: {:?}", ohttp_relay);
Ok((
Request { url: ohttp_relay, body },
Expand All @@ -342,33 +338,22 @@ impl RequestContext {

#[cfg(feature = "v2")]
fn rs_pubkey_from_dir_endpoint(endpoint: &Url) -> Result<PublicKey, CreateRequestError> {
let path_and_query: String;

if let Some(pos) = endpoint.as_str().rfind('/') {
path_and_query = endpoint.as_str()[pos + 1..].to_string();
} else {
path_and_query = endpoint.to_string();
}

let subdirectory: String;

if let Some(pos) = path_and_query.find('?') {
subdirectory = path_and_query[..pos].to_string();
} else {
subdirectory = path_and_query;
}

let pubkey_bytes =
bitcoin::base64::decode_config(subdirectory, bitcoin::base64::URL_SAFE_NO_PAD)
.map_err(InternalCreateRequestError::SubdirectoryNotBase64)?;
Ok(bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
.map_err(InternalCreateRequestError::SubdirectoryInvalidPubkey)?)
}

#[cfg(feature = "v2")]
pub fn public_key(&self) -> PublicKey {
let secp = bitcoin::secp256k1::Secp256k1::new();
self.e.public_key(&secp)
use bitcoin::base64;

use crate::send::error::ParseSubdirectoryError;

let subdirectory = endpoint
.path_segments()
.ok_or(ParseSubdirectoryError::MissingSubdirectory)?
.next()
.ok_or(ParseSubdirectoryError::MissingSubdirectory)?
.to_string();

let pubkey_bytes = base64::decode_config(subdirectory, base64::URL_SAFE_NO_PAD)
.map_err(ParseSubdirectoryError::SubdirectoryNotBase64)?;
bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
.map_err(ParseSubdirectoryError::SubdirectoryInvalidPubkey)
.map_err(CreateRequestError::from)
}

pub fn endpoint(&self) -> &Url { &self.endpoint }
Expand All @@ -383,7 +368,6 @@ impl Serialize for RequestContext {
let mut state = serializer.serialize_struct("RequestContext", 8)?;
state.serialize_field("psbt", &self.psbt.to_string())?;
state.serialize_field("endpoint", &self.endpoint.as_str())?;
state.serialize_field("ohttp_keys", &self.ohttp_keys)?;
state.serialize_field("disable_output_substitution", &self.disable_output_substitution)?;
state.serialize_field(
"fee_contribution",
Expand Down Expand Up @@ -432,7 +416,6 @@ impl<'de> Deserialize<'de> for RequestContext {
{
let mut psbt = None;
let mut endpoint = None;
let mut ohttp_keys = None;
let mut disable_output_substitution = None;
let mut fee_contribution = None;
let mut min_fee_rate = None;
Expand All @@ -452,7 +435,6 @@ impl<'de> Deserialize<'de> for RequestContext {
url::Url::from_str(&map.next_value::<String>()?)
.map_err(de::Error::custom)?,
),
"ohttp_keys" => ohttp_keys = Some(map.next_value()?),
"disable_output_substitution" =>
disable_output_substitution = Some(map.next_value()?),
"fee_contribution" => {
Expand All @@ -478,7 +460,6 @@ impl<'de> Deserialize<'de> for RequestContext {
Ok(RequestContext {
psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?,
endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?,
ohttp_keys: ohttp_keys.ok_or_else(|| de::Error::missing_field("ohttp_keys"))?,
disable_output_substitution: disable_output_substitution
.ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?,
fee_contribution,
Expand Down Expand Up @@ -974,7 +955,7 @@ fn serialize_v2_body(
) -> Result<Vec<u8>, CreateRequestError> {
// Grug say localhost base be discarded anyway. no big brain needed.
let placeholder_url = serialize_url(
"http:/localhost".to_string(),
Url::parse("http://localhost").unwrap(),
disable_output_substitution,
fee_contribution,
min_feerate,
Expand All @@ -986,12 +967,12 @@ fn serialize_v2_body(
}

fn serialize_url(
endpoint: String,
endpoint: Url,
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
) -> Result<Url, url::ParseError> {
let mut url = Url::parse(&endpoint)?;
let mut url = endpoint;
url.query_pairs_mut().append_pair("v", "1");
if disable_output_substitution {
url.query_pairs_mut().append_pair("disableoutputsubstitution", "1");
Expand Down Expand Up @@ -1065,7 +1046,6 @@ mod test {
let req_ctx = RequestContext {
psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(),
endpoint: Url::parse("http://localhost:1234").unwrap(),
ohttp_keys: None,
disable_output_substitution: false,
fee_contribution: None,
min_fee_rate: FeeRate::ZERO,
Expand Down
19 changes: 1 addition & 18 deletions payjoin/src/uri/error.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
#[cfg(feature = "v2")]
use crate::v2::ParseOhttpKeysError;

#[derive(Debug)]
pub struct PjParseError(InternalPjParseError);

impl std::error::Error for PjParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use InternalPjParseError::*;

match &self.0 {
#[cfg(feature = "v2")]
BadOhttpKeys(e) => Some(e),
_ => None,
}
}
}
impl std::error::Error for PjParseError {}

#[derive(Debug)]
pub(crate) enum InternalPjParseError {
Expand All @@ -23,8 +10,6 @@ pub(crate) enum InternalPjParseError {
MissingEndpoint,
NotUtf8,
BadEndpoint,
#[cfg(feature = "v2")]
BadOhttpKeys(ParseOhttpKeysError),
UnsecureEndpoint,
}

Expand All @@ -43,8 +28,6 @@ impl std::fmt::Display for PjParseError {
MissingEndpoint => write!(f, "Missing payjoin endpoint"),
NotUtf8 => write!(f, "Endpoint is not valid UTF-8"),
BadEndpoint => write!(f, "Endpoint is not valid"),
#[cfg(feature = "v2")]
BadOhttpKeys(e) => write!(f, "OHTTP keys are not valid: {}", e),
UnsecureEndpoint => {
write!(f, "Endpoint scheme is not secure (https or onion)")
}
Expand Down
Loading

0 comments on commit 78e66f6

Please sign in to comment.