From b834465ebcbf01069c6a994f3e4bef617e66a066 Mon Sep 17 00:00:00 2001 From: DanGould Date: Sun, 23 Jun 2024 21:58:19 -0400 Subject: [PATCH] Error if send or receive session expired --- payjoin-cli/src/app/v2.rs | 3 +- payjoin/src/receive/v2/error.rs | 44 +++++++++++++++++++++++ payjoin/src/receive/{v2.rs => v2/mod.rs} | 23 ++++++++---- payjoin/src/send/error.rs | 6 ++++ payjoin/src/send/mod.rs | 8 +++++ payjoin/src/uri/mod.rs | 25 ++++++++++--- payjoin/src/uri/url_ext.rs | 46 ++++++++++++++++++++++++ 7 files changed, 141 insertions(+), 14 deletions(-) create mode 100644 payjoin/src/receive/v2/error.rs rename payjoin/src/receive/{v2.rs => v2/mod.rs} (97%) diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index a08cb922..6eed9ef5 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -246,8 +246,7 @@ impl App { session: &mut payjoin::receive::v2::ActiveSession, ) -> Result { loop { - let (req, context) = - session.extract_req().map_err(|_| anyhow!("Failed to extract request"))?; + let (req, context) = session.extract_req()?; println!("Polling receive request..."); let http = http_agent()?; let ohttp_response = http diff --git a/payjoin/src/receive/v2/error.rs b/payjoin/src/receive/v2/error.rs new file mode 100644 index 00000000..c6d7daf2 --- /dev/null +++ b/payjoin/src/receive/v2/error.rs @@ -0,0 +1,44 @@ +use core::fmt; +use std::error; + +use crate::v2::OhttpEncapsulationError; + +#[derive(Debug)] +pub struct SessionError(InternalSessionError); + +#[derive(Debug)] +pub(crate) enum InternalSessionError { + /// The session has expired + Expired(std::time::SystemTime), + /// OHTTP Encapsulation failed + OhttpEncapsulationError(OhttpEncapsulationError), +} + +impl fmt::Display for SessionError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.0 { + InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry), + InternalSessionError::OhttpEncapsulationError(e) => + write!(f, "OHTTP Encapsulation Error: {}", e), + } + } +} + +impl error::Error for SessionError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match &self.0 { + InternalSessionError::Expired(_) => None, + InternalSessionError::OhttpEncapsulationError(e) => Some(e), + } + } +} + +impl From for SessionError { + fn from(e: InternalSessionError) -> Self { SessionError(e) } +} + +impl From for SessionError { + fn from(e: OhttpEncapsulationError) -> Self { + SessionError(InternalSessionError::OhttpEncapsulationError(e)) + } +} diff --git a/payjoin/src/receive/v2.rs b/payjoin/src/receive/v2/mod.rs similarity index 97% rename from payjoin/src/receive/v2.rs rename to payjoin/src/receive/v2/mod.rs index bdabe5b8..7248bdfe 100644 --- a/payjoin/src/receive/v2.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -12,13 +12,15 @@ use serde::ser::SerializeStruct; use serde::{Deserialize, Serialize, Serializer}; use url::Url; +use super::v2::error::{InternalSessionError, SessionError}; use super::{Error, InternalRequestError, RequestError, SelectionError}; use crate::psbt::PsbtExt; use crate::receive::optional_parameters::Params; +use crate::v2::OhttpEncapsulationError; use crate::{OhttpKeys, PjUriBuilder, Request}; -/// The state for a payjoin V2 receive session, including necessary -/// information for communication and cryptographic operations. +pub(crate) mod error; + #[derive(Debug, Clone, PartialEq, Eq)] struct SessionContext { address: Address, @@ -117,8 +119,12 @@ pub struct ActiveSession { } impl ActiveSession { - pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> { - let (body, ohttp_ctx) = self.fallback_req_body()?; + pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> { + if SystemTime::now() > self.context.expiry { + return Err(InternalSessionError::Expired(self.context.expiry).into()); + } + let (body, ohttp_ctx) = + self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?; let url = self.context.ohttp_relay.clone(); let req = Request { url, body }; Ok((req, ohttp_ctx)) @@ -147,14 +153,16 @@ impl ActiveSession { } } - fn fallback_req_body(&mut self) -> Result<(Vec, ohttp::ClientResponse), Error> { + fn fallback_req_body( + &mut self, + ) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { let fallback_target = self.pj_url(); - Ok(crate::v2::ohttp_encapsulate( + crate::v2::ohttp_encapsulate( &mut self.context.ohttp_keys, "GET", fallback_target.as_str(), None, - )?) + ) } fn extract_proposal_from_v1(&mut self, response: String) -> Result { @@ -204,6 +212,7 @@ impl ActiveSession { self.context.address.clone(), self.pj_url(), Some(self.context.ohttp_keys.clone()), + None, ) } diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index 7ef34333..51ba0036 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -197,6 +197,8 @@ pub(crate) enum InternalCreateRequestError { MissingOhttpConfig, #[cfg(feature = "v2")] PercentEncoding, + #[cfg(feature = "v2")] + Expired(std::time::SystemTime), } impl fmt::Display for CreateRequestError { @@ -228,6 +230,8 @@ impl fmt::Display for CreateRequestError { MissingOhttpConfig => write!(f, "no ohttp configuration with which to make a v2 request available"), #[cfg(feature = "v2")] PercentEncoding => write!(f, "fragment is not RFC 3986 percent-encoded"), + #[cfg(feature = "v2")] + Expired(expiry) => write!(f, "session expired at {:?}", expiry), } } } @@ -261,6 +265,8 @@ impl std::error::Error for CreateRequestError { MissingOhttpConfig => None, #[cfg(feature = "v2")] PercentEncoding => None, + #[cfg(feature = "v2")] + Expired(_) => None, } } } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 9f1bdc85..8103b92d 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -299,6 +299,14 @@ impl RequestContext { ohttp_relay: Url, ) -> Result<(Request, ContextV2), CreateRequestError> { use crate::uri::UrlExt; + + if let Some(expiry) = + self.endpoint.exp().map_err(|_| InternalCreateRequestError::PercentEncoding)? + { + if std::time::SystemTime::now() > expiry { + return Err(InternalCreateRequestError::Expired(expiry).into()); + } + } let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?; let url = self.endpoint.clone(); let body = serialize_v2_body( diff --git a/payjoin/src/uri/mod.rs b/payjoin/src/uri/mod.rs index 901874ce..dd61b8c2 100644 --- a/payjoin/src/uri/mod.rs +++ b/payjoin/src/uri/mod.rs @@ -15,7 +15,11 @@ pub mod error; #[cfg(feature = "v2")] pub(crate) mod url_ext; -#[derive(Debug, Clone)] +#[cfg(feature = "v2")] +static TWENTY_FOUR_HOURS_DEFAULT_EXPIRY: std::time::Duration = + std::time::Duration::from_secs(60 * 60 * 24); + +#[derive(Clone, Debug)] pub enum MaybePayjoinExtras { Supported(PayjoinExtras), Unsupported, @@ -111,11 +115,14 @@ impl PjUriBuilder { address: Address, origin: Url, #[cfg(feature = "v2")] ohttp_keys: Option, + #[cfg(feature = "v2")] expiry: Option, ) -> Self { #[allow(unused_mut)] let mut pj = origin; #[cfg(feature = "v2")] let _ = pj.set_ohttp(ohttp_keys); + #[cfg(feature = "v2")] + let _ = pj.set_exp(expiry); Self { address, amount: None, message: None, label: None, pj, pjos: false } } /// Set the amount you want to receive. @@ -204,7 +211,7 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras { impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { type Value = MaybePayjoinExtras; - fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos") } + fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos" | "exp") } fn deserialize_temp( &mut self, @@ -243,9 +250,15 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { (None, None) => Ok(MaybePayjoinExtras::Unsupported), (None, Some(_)) => Err(InternalPjParseError::MissingEndpoint.into()), (Some(endpoint), pjos) => { - if endpoint.scheme() == "https" - || endpoint.scheme() == "http" - && endpoint.domain().unwrap_or_default().ends_with(".onion") + #[cfg(feature = "v2")] + let fragment_checked = endpoint.ohttp().is_ok() && endpoint.exp().is_ok(); + + #[cfg(not(feature = "v2"))] + let fragment_checked = true; + if fragment_checked + && (endpoint.scheme() == "https" + || endpoint.scheme() == "http" + && endpoint.domain().unwrap_or_default().ends_with(".onion")) { Ok(MaybePayjoinExtras::Supported(PayjoinExtras { endpoint, @@ -352,6 +365,8 @@ mod tests { Url::parse(pj).unwrap(), #[cfg(feature = "v2")] None, + #[cfg(feature = "v2")] + None, ) .amount(amount) .message("message".to_string()) diff --git a/payjoin/src/uri/url_ext.rs b/payjoin/src/uri/url_ext.rs index b2048a40..671ba741 100644 --- a/payjoin/src/uri/url_ext.rs +++ b/payjoin/src/uri/url_ext.rs @@ -9,6 +9,8 @@ use crate::OhttpKeys; pub(crate) trait UrlExt { fn ohttp(&self) -> Result, PercentDecodeError>; fn set_ohttp(&mut self, ohttp: Option) -> Result<(), PercentDecodeError>; + fn exp(&self) -> Result, PercentDecodeError>; + fn set_exp(&mut self, exp: Option) -> Result<(), PercentDecodeError>; } // Characters '=' and '&' conflict with BIP21 URI parameters and must be percent-encoded @@ -55,6 +57,50 @@ impl UrlExt for Url { self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) }); Ok(()) } + + /// Retrieve the exp parameter from the URL fragment + fn exp(&self) -> Result, PercentDecodeError> { + if let Some(fragment) = self.fragment() { + let decoded_fragment = + percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy(); + for param in decoded_fragment.split('&') { + if let Some(value) = param.strip_prefix("exp=") { + if let Ok(timestamp) = value.parse::() { + return Ok(Some( + std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp), + )); + } + } + } + } + Ok(None) + } + + /// Set the exp parameter in the URL fragment + fn set_exp(&mut self, exp: Option) -> Result<(), PercentDecodeError> { + let fragment = self.fragment().unwrap_or("").to_string(); + let mut fragment = + percent_encoding::percent_decode_str(&fragment)?.decode_utf8_lossy().to_string(); + if let Some(start) = fragment.find("exp=") { + let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); + fragment.replace_range(start..end, ""); + if fragment.ends_with('&') { + fragment.pop(); + } + } + if let Some(exp) = exp { + let timestamp = exp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + let new_exp = format!("exp={}", timestamp); + if !fragment.is_empty() { + fragment.push('&'); + } + fragment.push_str(&new_exp); + } + let encoded_fragment = + percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string(); + self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) }); + Ok(()) + } } #[cfg(test)]