Skip to content

Commit

Permalink
Error if send or receive session expired
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Jun 24, 2024
1 parent 8b2caa7 commit d4f8256
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 10 deletions.
3 changes: 1 addition & 2 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ impl App {
session: &mut payjoin::receive::v2::ActiveSession,
) -> Result<payjoin::receive::v2::UncheckedProposal> {
loop {
let (req, context) =
session.extract_req().map_err(|_| anyhow!("Failed to extract request"))?;
let (req, context) = session.extract_req()?;
println!("Polling receive request...");
let http = http_agent()?;
let ohttp_response = http
Expand Down
44 changes: 44 additions & 0 deletions payjoin/src/receive/v2/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use core::fmt;
use std::error;

use crate::v2::OhttpEncapsulationError;

#[derive(Debug)]
pub struct SessionError(InternalSessionError);

#[derive(Debug)]
pub(crate) enum InternalSessionError {
/// The session has expired
Expired(std::time::SystemTime),
/// OHTTP Encapsulation failed
OhttpEncapsulationError(OhttpEncapsulationError),
}

impl fmt::Display for SessionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.0 {
InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry),
InternalSessionError::OhttpEncapsulationError(e) =>
write!(f, "OHTTP Encapsulation Error: {}", e),
}
}
}

impl error::Error for SessionError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self.0 {
InternalSessionError::Expired(_) => None,
InternalSessionError::OhttpEncapsulationError(e) => Some(e),
}
}
}

impl From<InternalSessionError> for SessionError {
fn from(e: InternalSessionError) -> Self { SessionError(e) }
}

impl From<OhttpEncapsulationError> for SessionError {
fn from(e: OhttpEncapsulationError) -> Self {
SessionError(InternalSessionError::OhttpEncapsulationError(e))
}
}
20 changes: 15 additions & 5 deletions payjoin/src/receive/v2.rs → payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer};
use url::Url;

use super::v2::error::{InternalSessionError, SessionError};
use super::{Error, InternalRequestError, RequestError, SelectionError};
use crate::psbt::PsbtExt;
use crate::receive::optional_parameters::Params;
use crate::v2::OhttpEncapsulationError;
use crate::{OhttpKeys, PjUriBuilder, Request};

pub(crate) mod error;

#[derive(Debug, Clone, PartialEq, Eq)]
struct SessionContext {
address: Address,
Expand Down Expand Up @@ -101,8 +105,12 @@ pub struct ActiveSession {
}

impl ActiveSession {
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> {
let (body, ohttp_ctx) = self.fallback_req_body()?;
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> {
if SystemTime::now() > self.context.expiry {
return Err(InternalSessionError::Expired(self.context.expiry).into());
}
let (body, ohttp_ctx) =
self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?;
let url = self.context.ohttp_relay.clone();
let req = Request { url, body };
Ok((req, ohttp_ctx))
Expand Down Expand Up @@ -131,14 +139,16 @@ impl ActiveSession {
}
}

fn fallback_req_body(&mut self) -> Result<(Vec<u8>, ohttp::ClientResponse), Error> {
fn fallback_req_body(
&mut self,
) -> Result<(Vec<u8>, ohttp::ClientResponse), OhttpEncapsulationError> {
let fallback_target = self.pj_url();
Ok(crate::v2::ohttp_encapsulate(
crate::v2::ohttp_encapsulate(
&mut self.context.ohttp_keys,
"GET",
fallback_target.as_str(),
None,
)?)
)
}

fn extract_proposal_from_v1(&mut self, response: String) -> Result<UncheckedProposal, Error> {
Expand Down
3 changes: 3 additions & 0 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ pub(crate) enum InternalCreateRequestError {
SubdirectoryInvalidPubkey(bitcoin::secp256k1::Error),
#[cfg(feature = "v2")]
MissingOhttpConfig,
Expired(std::time::SystemTime),
}

impl fmt::Display for CreateRequestError {
Expand Down Expand Up @@ -228,6 +229,7 @@ impl fmt::Display for CreateRequestError {
SubdirectoryInvalidPubkey(e) => write!(f, "subdirectory does not represent a valid pubkey: {}", e),
#[cfg(feature = "v2")]
MissingOhttpConfig => write!(f, "no ohttp configuration with which to make a v2 request available"),
Expired(expiry) => write!(f, "session expired at {:?}", expiry),
}
}
}
Expand Down Expand Up @@ -261,6 +263,7 @@ impl std::error::Error for CreateRequestError {
SubdirectoryInvalidPubkey(error) => Some(error),
#[cfg(feature = "v2")]
MissingOhttpConfig => None,
Expired(_) => None,
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//! wallet and http client.
use std::str::FromStr;
use std::time::SystemTime;

use bitcoin::psbt::Psbt;
#[cfg(feature = "v2")]
Expand Down Expand Up @@ -204,6 +205,7 @@ impl<'a> RequestBuilder<'a> {
let endpoint = self.uri.extras.endpoint.clone();
#[cfg(feature = "v2")]
let ohttp_keys = self.uri.extras.ohttp_keys;
let expiry = self.uri.extras.expiry;
let disable_output_substitution =
self.uri.extras.disable_output_substitution || self.disable_output_substitution;
let payee = self.uri.address.script_pubkey();
Expand Down Expand Up @@ -236,6 +238,7 @@ impl<'a> RequestBuilder<'a> {
endpoint,
#[cfg(feature = "v2")]
ohttp_keys,
expiry,
disable_output_substitution,
fee_contribution,
payee,
Expand All @@ -254,6 +257,7 @@ pub struct RequestContext {
endpoint: Url,
#[cfg(feature = "v2")]
ohttp_keys: Option<crate::v2::OhttpKeys>,
expiry: Option<SystemTime>,
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
Expand Down Expand Up @@ -320,6 +324,11 @@ impl RequestContext {
&mut self,
ohttp_relay: Url,
) -> Result<(Request, ContextV2), CreateRequestError> {
if let Some(expiry) = self.expiry {
if SystemTime::now() > expiry {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}
let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
Expand Down Expand Up @@ -401,6 +410,7 @@ impl Serialize for RequestContext {
let mut state = serializer.serialize_struct("RequestContext", 8)?;
state.serialize_field("psbt", &self.psbt.to_string())?;
state.serialize_field("endpoint", &self.endpoint.as_str())?;
state.serialize_field("expiry", &self.expiry)?;
let ohttp_string = self.ohttp_keys.as_ref().map_or(Ok("".to_string()), |config| {
config
.encode()
Expand Down Expand Up @@ -456,6 +466,7 @@ impl<'de> Deserialize<'de> for RequestContext {
{
let mut psbt = None;
let mut endpoint = None;
let mut expiry = None;
let mut ohttp_keys = None;
let mut disable_output_substitution = None;
let mut fee_contribution = None;
Expand All @@ -476,6 +487,9 @@ impl<'de> Deserialize<'de> for RequestContext {
url::Url::from_str(&map.next_value::<String>()?)
.map_err(de::Error::custom)?,
),
"expiry" => {
expiry = map.next_value()?;
}
"ohttp_keys" => {
let ohttp_base64: String = map.next_value()?;
ohttp_keys = if ohttp_base64.is_empty() {
Expand Down Expand Up @@ -526,6 +540,7 @@ impl<'de> Deserialize<'de> for RequestContext {
sequence: sequence.ok_or_else(|| de::Error::missing_field("sequence"))?,
payee: payee.ok_or_else(|| de::Error::missing_field("payee"))?,
e: e.ok_or_else(|| de::Error::missing_field("e"))?,
expiry,
})
}
}
Expand Down Expand Up @@ -1104,6 +1119,7 @@ mod test {
psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(),
endpoint: Url::parse("http://localhost:1234").unwrap(),
ohttp_keys: None,
expiry: None,
disable_output_substitution: false,
fee_contribution: None,
min_fee_rate: FeeRate::ZERO,
Expand Down
61 changes: 58 additions & 3 deletions payjoin/src/uri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ use url::Url;
#[cfg(feature = "v2")]
use crate::OhttpKeys;

#[derive(Clone)]
#[cfg(feature = "v2")]
static TWENTY_FOUR_HOURS_DEFAULT_EXPIRY: std::time::Duration =
std::time::Duration::from_secs(60 * 60 * 24);

#[derive(Clone, Debug)]
pub enum MaybePayjoinExtras {
Supported(PayjoinExtras),
Unsupported,
Expand All @@ -22,12 +26,13 @@ impl MaybePayjoinExtras {
}
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct PayjoinExtras {
pub(crate) endpoint: Url,
pub(crate) disable_output_substitution: bool,
#[cfg(feature = "v2")]
pub(crate) ohttp_keys: Option<OhttpKeys>,
pub(crate) expiry: Option<std::time::SystemTime>,
}

impl PayjoinExtras {
Expand Down Expand Up @@ -97,6 +102,10 @@ pub struct PjUriBuilder {
///
/// Required only for v2 payjoin.
ohttp: Option<OhttpKeys>,
#[cfg(feature = "v2")]
/// Custom expiry for the payjoin request.
/// Default is 24 hours.
expiry: std::time::SystemTime,
}

impl PjUriBuilder {
Expand All @@ -115,6 +124,8 @@ impl PjUriBuilder {
pjos: false,
#[cfg(feature = "v2")]
ohttp: ohttp_keys,
#[cfg(feature = "v2")]
expiry: std::time::SystemTime::now() + TWENTY_FOUR_HOURS_DEFAULT_EXPIRY,
}
}
/// Set the amount you want to receive.
Expand All @@ -141,6 +152,13 @@ impl PjUriBuilder {
self
}

#[cfg(feature = "v2")]
/// Set custom expiry for the payjoin request.
pub fn expiry(mut self, expiry: std::time::SystemTime) -> Self {
self.expiry = expiry;
self
}

/// Build payjoin URI.
///
/// Constructs a `bip21::Uri` with PayjoinParams from the
Expand All @@ -151,6 +169,10 @@ impl PjUriBuilder {
disable_output_substitution: self.pjos,
#[cfg(feature = "v2")]
ohttp_keys: self.ohttp,
#[cfg(feature = "v2")]
expiry: Some(self.expiry),
#[cfg(not(feature = "v2"))]
expiry: None,
};
let mut pj_uri = bip21::Uri::with_extras(self.address, extras);
pj_uri.amount = self.amount;
Expand Down Expand Up @@ -178,6 +200,7 @@ pub struct DeserializationState {
pjos: Option<bool>,
#[cfg(feature = "v2")]
ohttp: Option<OhttpKeys>,
expiry: Option<std::time::SystemTime>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -220,14 +243,21 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras {
} else {
log::warn!("Failed to encode ohttp config, ignoring");
}
if let Some(expiry) = self.expiry {
let expiry = expiry
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.expect("expiry is in the past")
.as_secs();
params.push(("exp", expiry.to_string()));
}
params.into_iter()
}
}

impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
type Value = MaybePayjoinExtras;

fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos") }
fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos" | "exp") }

fn deserialize_temp(
&mut self,
Expand All @@ -252,6 +282,19 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
}
#[cfg(feature = "v2")]
"ohttp" => Err(PjParseError(InternalPjParseError::MultipleParams("ohttp"))),
"exp" if self.expiry.is_none() => {
let expiry = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let expiry = std::time::SystemTime::UNIX_EPOCH
+ std::time::Duration::from_secs(
expiry.parse().map_err(|_| InternalPjParseError::NotUtf8)?,
);
if expiry < std::time::SystemTime::now() {
return Err(InternalPjParseError::Expired(expiry).into());
}
self.expiry = Some(expiry);
Ok(bip21::de::ParamKind::Known)
}
"exp" => Err(PjParseError(InternalPjParseError::MultipleParams("expiry"))),
"pj" if self.pj.is_none() => {
let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?;
Expand Down Expand Up @@ -289,6 +332,7 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
disable_output_substitution: pjos.unwrap_or(false),
#[cfg(feature = "v2")]
ohttp_keys: self.ohttp,
expiry: self.expiry,
}))
} else {
Err(PjParseError(InternalPjParseError::UnsecureEndpoint))
Expand All @@ -315,6 +359,9 @@ impl std::fmt::Display for PjParseError {
InternalPjParseError::UnsecureEndpoint => {
write!(f, "Endpoint scheme is not secure (https or onion)")
}
InternalPjParseError::Expired(expiry) => {
write!(f, "Expiry is in the past: {:?}", expiry)
}
}
}
}
Expand All @@ -331,6 +378,7 @@ enum InternalPjParseError {
#[cfg(feature = "v2")]
DecodeOhttpKeys,
UnsecureEndpoint,
Expired(std::time::SystemTime),
}

#[cfg(test)]
Expand Down Expand Up @@ -374,6 +422,13 @@ mod tests {
assert!(Uri::try_from(uri).is_err(), "unencrypted connection");
}

#[test]
fn test_expired() {
let uri =
"bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?amount=1&pj=https://example.com&exp=0";
assert!(Uri::try_from(uri).is_err(), "expired");
}

#[test]
fn test_valid_uris() {
let https = "https://example.com";
Expand Down

0 comments on commit d4f8256

Please sign in to comment.