From d4f8256e5edd614fb26f6a6256c850f89896fa96 Mon Sep 17 00:00:00 2001 From: DanGould Date: Sun, 23 Jun 2024 21:58:19 -0400 Subject: [PATCH] Error if send or receive session expired --- payjoin-cli/src/app/v2.rs | 3 +- payjoin/src/receive/v2/error.rs | 44 +++++++++++++++++ payjoin/src/receive/{v2.rs => v2/mod.rs} | 20 ++++++-- payjoin/src/send/error.rs | 3 ++ payjoin/src/send/mod.rs | 16 +++++++ payjoin/src/uri.rs | 61 ++++++++++++++++++++++-- 6 files changed, 137 insertions(+), 10 deletions(-) create mode 100644 payjoin/src/receive/v2/error.rs rename payjoin/src/receive/{v2.rs => v2/mod.rs} (97%) diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index bb9bfc4d..89a124a0 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -213,8 +213,7 @@ impl App { session: &mut payjoin::receive::v2::ActiveSession, ) -> Result { loop { - let (req, context) = - session.extract_req().map_err(|_| anyhow!("Failed to extract request"))?; + let (req, context) = session.extract_req()?; println!("Polling receive request..."); let http = http_agent()?; let ohttp_response = http diff --git a/payjoin/src/receive/v2/error.rs b/payjoin/src/receive/v2/error.rs new file mode 100644 index 00000000..c6d7daf2 --- /dev/null +++ b/payjoin/src/receive/v2/error.rs @@ -0,0 +1,44 @@ +use core::fmt; +use std::error; + +use crate::v2::OhttpEncapsulationError; + +#[derive(Debug)] +pub struct SessionError(InternalSessionError); + +#[derive(Debug)] +pub(crate) enum InternalSessionError { + /// The session has expired + Expired(std::time::SystemTime), + /// OHTTP Encapsulation failed + OhttpEncapsulationError(OhttpEncapsulationError), +} + +impl fmt::Display for SessionError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.0 { + InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry), + InternalSessionError::OhttpEncapsulationError(e) => + write!(f, "OHTTP Encapsulation Error: {}", e), + } + } +} + +impl error::Error for SessionError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match &self.0 { + InternalSessionError::Expired(_) => None, + InternalSessionError::OhttpEncapsulationError(e) => Some(e), + } + } +} + +impl From for SessionError { + fn from(e: InternalSessionError) -> Self { SessionError(e) } +} + +impl From for SessionError { + fn from(e: OhttpEncapsulationError) -> Self { + SessionError(InternalSessionError::OhttpEncapsulationError(e)) + } +} diff --git a/payjoin/src/receive/v2.rs b/payjoin/src/receive/v2/mod.rs similarity index 97% rename from payjoin/src/receive/v2.rs rename to payjoin/src/receive/v2/mod.rs index ddc78641..a49b3e37 100644 --- a/payjoin/src/receive/v2.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -12,11 +12,15 @@ use serde::ser::SerializeStruct; use serde::{Deserialize, Serialize, Serializer}; use url::Url; +use super::v2::error::{InternalSessionError, SessionError}; use super::{Error, InternalRequestError, RequestError, SelectionError}; use crate::psbt::PsbtExt; use crate::receive::optional_parameters::Params; +use crate::v2::OhttpEncapsulationError; use crate::{OhttpKeys, PjUriBuilder, Request}; +pub(crate) mod error; + #[derive(Debug, Clone, PartialEq, Eq)] struct SessionContext { address: Address, @@ -101,8 +105,12 @@ pub struct ActiveSession { } impl ActiveSession { - pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> { - let (body, ohttp_ctx) = self.fallback_req_body()?; + pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> { + if SystemTime::now() > self.context.expiry { + return Err(InternalSessionError::Expired(self.context.expiry).into()); + } + let (body, ohttp_ctx) = + self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?; let url = self.context.ohttp_relay.clone(); let req = Request { url, body }; Ok((req, ohttp_ctx)) @@ -131,14 +139,16 @@ impl ActiveSession { } } - fn fallback_req_body(&mut self) -> Result<(Vec, ohttp::ClientResponse), Error> { + fn fallback_req_body( + &mut self, + ) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { let fallback_target = self.pj_url(); - Ok(crate::v2::ohttp_encapsulate( + crate::v2::ohttp_encapsulate( &mut self.context.ohttp_keys, "GET", fallback_target.as_str(), None, - )?) + ) } fn extract_proposal_from_v1(&mut self, response: String) -> Result { diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index fc47d9ed..9f048034 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -197,6 +197,7 @@ pub(crate) enum InternalCreateRequestError { SubdirectoryInvalidPubkey(bitcoin::secp256k1::Error), #[cfg(feature = "v2")] MissingOhttpConfig, + Expired(std::time::SystemTime), } impl fmt::Display for CreateRequestError { @@ -228,6 +229,7 @@ impl fmt::Display for CreateRequestError { SubdirectoryInvalidPubkey(e) => write!(f, "subdirectory does not represent a valid pubkey: {}", e), #[cfg(feature = "v2")] MissingOhttpConfig => write!(f, "no ohttp configuration with which to make a v2 request available"), + Expired(expiry) => write!(f, "session expired at {:?}", expiry), } } } @@ -261,6 +263,7 @@ impl std::error::Error for CreateRequestError { SubdirectoryInvalidPubkey(error) => Some(error), #[cfg(feature = "v2")] MissingOhttpConfig => None, + Expired(_) => None, } } } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 7b818d19..9d586534 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -25,6 +25,7 @@ //! wallet and http client. use std::str::FromStr; +use std::time::SystemTime; use bitcoin::psbt::Psbt; #[cfg(feature = "v2")] @@ -204,6 +205,7 @@ impl<'a> RequestBuilder<'a> { let endpoint = self.uri.extras.endpoint.clone(); #[cfg(feature = "v2")] let ohttp_keys = self.uri.extras.ohttp_keys; + let expiry = self.uri.extras.expiry; let disable_output_substitution = self.uri.extras.disable_output_substitution || self.disable_output_substitution; let payee = self.uri.address.script_pubkey(); @@ -236,6 +238,7 @@ impl<'a> RequestBuilder<'a> { endpoint, #[cfg(feature = "v2")] ohttp_keys, + expiry, disable_output_substitution, fee_contribution, payee, @@ -254,6 +257,7 @@ pub struct RequestContext { endpoint: Url, #[cfg(feature = "v2")] ohttp_keys: Option, + expiry: Option, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, @@ -320,6 +324,11 @@ impl RequestContext { &mut self, ohttp_relay: Url, ) -> Result<(Request, ContextV2), CreateRequestError> { + if let Some(expiry) = self.expiry { + if SystemTime::now() > expiry { + return Err(InternalCreateRequestError::Expired(expiry).into()); + } + } let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -401,6 +410,7 @@ impl Serialize for RequestContext { let mut state = serializer.serialize_struct("RequestContext", 8)?; state.serialize_field("psbt", &self.psbt.to_string())?; state.serialize_field("endpoint", &self.endpoint.as_str())?; + state.serialize_field("expiry", &self.expiry)?; let ohttp_string = self.ohttp_keys.as_ref().map_or(Ok("".to_string()), |config| { config .encode() @@ -456,6 +466,7 @@ impl<'de> Deserialize<'de> for RequestContext { { let mut psbt = None; let mut endpoint = None; + let mut expiry = None; let mut ohttp_keys = None; let mut disable_output_substitution = None; let mut fee_contribution = None; @@ -476,6 +487,9 @@ impl<'de> Deserialize<'de> for RequestContext { url::Url::from_str(&map.next_value::()?) .map_err(de::Error::custom)?, ), + "expiry" => { + expiry = map.next_value()?; + } "ohttp_keys" => { let ohttp_base64: String = map.next_value()?; ohttp_keys = if ohttp_base64.is_empty() { @@ -526,6 +540,7 @@ impl<'de> Deserialize<'de> for RequestContext { sequence: sequence.ok_or_else(|| de::Error::missing_field("sequence"))?, payee: payee.ok_or_else(|| de::Error::missing_field("payee"))?, e: e.ok_or_else(|| de::Error::missing_field("e"))?, + expiry, }) } } @@ -1104,6 +1119,7 @@ mod test { psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(), endpoint: Url::parse("http://localhost:1234").unwrap(), ohttp_keys: None, + expiry: None, disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, diff --git a/payjoin/src/uri.rs b/payjoin/src/uri.rs index e9aee45d..42932929 100644 --- a/payjoin/src/uri.rs +++ b/payjoin/src/uri.rs @@ -7,7 +7,11 @@ use url::Url; #[cfg(feature = "v2")] use crate::OhttpKeys; -#[derive(Clone)] +#[cfg(feature = "v2")] +static TWENTY_FOUR_HOURS_DEFAULT_EXPIRY: std::time::Duration = + std::time::Duration::from_secs(60 * 60 * 24); + +#[derive(Clone, Debug)] pub enum MaybePayjoinExtras { Supported(PayjoinExtras), Unsupported, @@ -22,12 +26,13 @@ impl MaybePayjoinExtras { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PayjoinExtras { pub(crate) endpoint: Url, pub(crate) disable_output_substitution: bool, #[cfg(feature = "v2")] pub(crate) ohttp_keys: Option, + pub(crate) expiry: Option, } impl PayjoinExtras { @@ -97,6 +102,10 @@ pub struct PjUriBuilder { /// /// Required only for v2 payjoin. ohttp: Option, + #[cfg(feature = "v2")] + /// Custom expiry for the payjoin request. + /// Default is 24 hours. + expiry: std::time::SystemTime, } impl PjUriBuilder { @@ -115,6 +124,8 @@ impl PjUriBuilder { pjos: false, #[cfg(feature = "v2")] ohttp: ohttp_keys, + #[cfg(feature = "v2")] + expiry: std::time::SystemTime::now() + TWENTY_FOUR_HOURS_DEFAULT_EXPIRY, } } /// Set the amount you want to receive. @@ -141,6 +152,13 @@ impl PjUriBuilder { self } + #[cfg(feature = "v2")] + /// Set custom expiry for the payjoin request. + pub fn expiry(mut self, expiry: std::time::SystemTime) -> Self { + self.expiry = expiry; + self + } + /// Build payjoin URI. /// /// Constructs a `bip21::Uri` with PayjoinParams from the @@ -151,6 +169,10 @@ impl PjUriBuilder { disable_output_substitution: self.pjos, #[cfg(feature = "v2")] ohttp_keys: self.ohttp, + #[cfg(feature = "v2")] + expiry: Some(self.expiry), + #[cfg(not(feature = "v2"))] + expiry: None, }; let mut pj_uri = bip21::Uri::with_extras(self.address, extras); pj_uri.amount = self.amount; @@ -178,6 +200,7 @@ pub struct DeserializationState { pjos: Option, #[cfg(feature = "v2")] ohttp: Option, + expiry: Option, } #[derive(Debug)] @@ -220,6 +243,13 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras { } else { log::warn!("Failed to encode ohttp config, ignoring"); } + if let Some(expiry) = self.expiry { + let expiry = expiry + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .expect("expiry is in the past") + .as_secs(); + params.push(("exp", expiry.to_string())); + } params.into_iter() } } @@ -227,7 +257,7 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras { impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { type Value = MaybePayjoinExtras; - fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos") } + fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos" | "exp") } fn deserialize_temp( &mut self, @@ -252,6 +282,19 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { } #[cfg(feature = "v2")] "ohttp" => Err(PjParseError(InternalPjParseError::MultipleParams("ohttp"))), + "exp" if self.expiry.is_none() => { + let expiry = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; + let expiry = std::time::SystemTime::UNIX_EPOCH + + std::time::Duration::from_secs( + expiry.parse().map_err(|_| InternalPjParseError::NotUtf8)?, + ); + if expiry < std::time::SystemTime::now() { + return Err(InternalPjParseError::Expired(expiry).into()); + } + self.expiry = Some(expiry); + Ok(bip21::de::ParamKind::Known) + } + "exp" => Err(PjParseError(InternalPjParseError::MultipleParams("expiry"))), "pj" if self.pj.is_none() => { let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?; @@ -289,6 +332,7 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { disable_output_substitution: pjos.unwrap_or(false), #[cfg(feature = "v2")] ohttp_keys: self.ohttp, + expiry: self.expiry, })) } else { Err(PjParseError(InternalPjParseError::UnsecureEndpoint)) @@ -315,6 +359,9 @@ impl std::fmt::Display for PjParseError { InternalPjParseError::UnsecureEndpoint => { write!(f, "Endpoint scheme is not secure (https or onion)") } + InternalPjParseError::Expired(expiry) => { + write!(f, "Expiry is in the past: {:?}", expiry) + } } } } @@ -331,6 +378,7 @@ enum InternalPjParseError { #[cfg(feature = "v2")] DecodeOhttpKeys, UnsecureEndpoint, + Expired(std::time::SystemTime), } #[cfg(test)] @@ -374,6 +422,13 @@ mod tests { assert!(Uri::try_from(uri).is_err(), "unencrypted connection"); } + #[test] + fn test_expired() { + let uri = + "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?amount=1&pj=https://example.com&exp=0"; + assert!(Uri::try_from(uri).is_err(), "expired"); + } + #[test] fn test_valid_uris() { let https = "https://example.com";