Skip to content

Commit

Permalink
Error if send or receive session expired
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Jul 13, 2024
1 parent c3ac51f commit b834465
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 14 deletions.
3 changes: 1 addition & 2 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ impl App {
session: &mut payjoin::receive::v2::ActiveSession,
) -> Result<payjoin::receive::v2::UncheckedProposal> {
loop {
let (req, context) =
session.extract_req().map_err(|_| anyhow!("Failed to extract request"))?;
let (req, context) = session.extract_req()?;
println!("Polling receive request...");
let http = http_agent()?;
let ohttp_response = http
Expand Down
44 changes: 44 additions & 0 deletions payjoin/src/receive/v2/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use core::fmt;
use std::error;

use crate::v2::OhttpEncapsulationError;

#[derive(Debug)]
pub struct SessionError(InternalSessionError);

#[derive(Debug)]
pub(crate) enum InternalSessionError {
/// The session has expired
Expired(std::time::SystemTime),
/// OHTTP Encapsulation failed
OhttpEncapsulationError(OhttpEncapsulationError),
}

impl fmt::Display for SessionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.0 {
InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry),
InternalSessionError::OhttpEncapsulationError(e) =>
write!(f, "OHTTP Encapsulation Error: {}", e),
}
}
}

impl error::Error for SessionError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self.0 {
InternalSessionError::Expired(_) => None,
InternalSessionError::OhttpEncapsulationError(e) => Some(e),
}
}
}

impl From<InternalSessionError> for SessionError {
fn from(e: InternalSessionError) -> Self { SessionError(e) }
}

impl From<OhttpEncapsulationError> for SessionError {
fn from(e: OhttpEncapsulationError) -> Self {
SessionError(InternalSessionError::OhttpEncapsulationError(e))
}
}
23 changes: 16 additions & 7 deletions payjoin/src/receive/v2.rs → payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer};
use url::Url;

use super::v2::error::{InternalSessionError, SessionError};
use super::{Error, InternalRequestError, RequestError, SelectionError};
use crate::psbt::PsbtExt;
use crate::receive::optional_parameters::Params;
use crate::v2::OhttpEncapsulationError;
use crate::{OhttpKeys, PjUriBuilder, Request};

/// The state for a payjoin V2 receive session, including necessary
/// information for communication and cryptographic operations.
pub(crate) mod error;

#[derive(Debug, Clone, PartialEq, Eq)]
struct SessionContext {
address: Address,
Expand Down Expand Up @@ -117,8 +119,12 @@ pub struct ActiveSession {
}

impl ActiveSession {
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> {
let (body, ohttp_ctx) = self.fallback_req_body()?;
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> {
if SystemTime::now() > self.context.expiry {
return Err(InternalSessionError::Expired(self.context.expiry).into());
}
let (body, ohttp_ctx) =
self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?;
let url = self.context.ohttp_relay.clone();
let req = Request { url, body };
Ok((req, ohttp_ctx))
Expand Down Expand Up @@ -147,14 +153,16 @@ impl ActiveSession {
}
}

fn fallback_req_body(&mut self) -> Result<(Vec<u8>, ohttp::ClientResponse), Error> {
fn fallback_req_body(
&mut self,
) -> Result<(Vec<u8>, ohttp::ClientResponse), OhttpEncapsulationError> {
let fallback_target = self.pj_url();
Ok(crate::v2::ohttp_encapsulate(
crate::v2::ohttp_encapsulate(
&mut self.context.ohttp_keys,
"GET",
fallback_target.as_str(),
None,
)?)
)
}

fn extract_proposal_from_v1(&mut self, response: String) -> Result<UncheckedProposal, Error> {
Expand Down Expand Up @@ -204,6 +212,7 @@ impl ActiveSession {
self.context.address.clone(),
self.pj_url(),
Some(self.context.ohttp_keys.clone()),
None,
)
}

Expand Down
6 changes: 6 additions & 0 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ pub(crate) enum InternalCreateRequestError {
MissingOhttpConfig,
#[cfg(feature = "v2")]
PercentEncoding,
#[cfg(feature = "v2")]
Expired(std::time::SystemTime),
}

impl fmt::Display for CreateRequestError {
Expand Down Expand Up @@ -228,6 +230,8 @@ impl fmt::Display for CreateRequestError {
MissingOhttpConfig => write!(f, "no ohttp configuration with which to make a v2 request available"),
#[cfg(feature = "v2")]
PercentEncoding => write!(f, "fragment is not RFC 3986 percent-encoded"),
#[cfg(feature = "v2")]
Expired(expiry) => write!(f, "session expired at {:?}", expiry),
}
}
}
Expand Down Expand Up @@ -261,6 +265,8 @@ impl std::error::Error for CreateRequestError {
MissingOhttpConfig => None,
#[cfg(feature = "v2")]
PercentEncoding => None,
#[cfg(feature = "v2")]
Expired(_) => None,
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,14 @@ impl RequestContext {
ohttp_relay: Url,
) -> Result<(Request, ContextV2), CreateRequestError> {
use crate::uri::UrlExt;

if let Some(expiry) =
self.endpoint.exp().map_err(|_| InternalCreateRequestError::PercentEncoding)?
{
if std::time::SystemTime::now() > expiry {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}
let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
Expand Down
25 changes: 20 additions & 5 deletions payjoin/src/uri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ pub mod error;
#[cfg(feature = "v2")]
pub(crate) mod url_ext;

#[derive(Debug, Clone)]
#[cfg(feature = "v2")]
static TWENTY_FOUR_HOURS_DEFAULT_EXPIRY: std::time::Duration =
std::time::Duration::from_secs(60 * 60 * 24);

#[derive(Clone, Debug)]
pub enum MaybePayjoinExtras {
Supported(PayjoinExtras),
Unsupported,
Expand Down Expand Up @@ -111,11 +115,14 @@ impl PjUriBuilder {
address: Address,
origin: Url,
#[cfg(feature = "v2")] ohttp_keys: Option<OhttpKeys>,
#[cfg(feature = "v2")] expiry: Option<std::time::SystemTime>,
) -> Self {
#[allow(unused_mut)]
let mut pj = origin;
#[cfg(feature = "v2")]
let _ = pj.set_ohttp(ohttp_keys);
#[cfg(feature = "v2")]
let _ = pj.set_exp(expiry);
Self { address, amount: None, message: None, label: None, pj, pjos: false }
}
/// Set the amount you want to receive.
Expand Down Expand Up @@ -204,7 +211,7 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras {
impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
type Value = MaybePayjoinExtras;

fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos") }
fn is_param_known(&self, param: &str) -> bool { matches!(param, "pj" | "pjos" | "exp") }

fn deserialize_temp(
&mut self,
Expand Down Expand Up @@ -243,9 +250,15 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
(None, None) => Ok(MaybePayjoinExtras::Unsupported),
(None, Some(_)) => Err(InternalPjParseError::MissingEndpoint.into()),
(Some(endpoint), pjos) => {
if endpoint.scheme() == "https"
|| endpoint.scheme() == "http"
&& endpoint.domain().unwrap_or_default().ends_with(".onion")
#[cfg(feature = "v2")]
let fragment_checked = endpoint.ohttp().is_ok() && endpoint.exp().is_ok();

#[cfg(not(feature = "v2"))]
let fragment_checked = true;
if fragment_checked
&& (endpoint.scheme() == "https"
|| endpoint.scheme() == "http"
&& endpoint.domain().unwrap_or_default().ends_with(".onion"))
{
Ok(MaybePayjoinExtras::Supported(PayjoinExtras {
endpoint,
Expand Down Expand Up @@ -352,6 +365,8 @@ mod tests {
Url::parse(pj).unwrap(),
#[cfg(feature = "v2")]
None,
#[cfg(feature = "v2")]
None,
)
.amount(amount)
.message("message".to_string())
Expand Down
46 changes: 46 additions & 0 deletions payjoin/src/uri/url_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::OhttpKeys;
pub(crate) trait UrlExt {
fn ohttp(&self) -> Result<Option<OhttpKeys>, PercentDecodeError>;
fn set_ohttp(&mut self, ohttp: Option<OhttpKeys>) -> Result<(), PercentDecodeError>;
fn exp(&self) -> Result<Option<std::time::SystemTime>, PercentDecodeError>;
fn set_exp(&mut self, exp: Option<std::time::SystemTime>) -> Result<(), PercentDecodeError>;
}

// Characters '=' and '&' conflict with BIP21 URI parameters and must be percent-encoded
Expand Down Expand Up @@ -55,6 +57,50 @@ impl UrlExt for Url {
self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) });
Ok(())
}

/// Retrieve the exp parameter from the URL fragment
fn exp(&self) -> Result<Option<std::time::SystemTime>, PercentDecodeError> {
if let Some(fragment) = self.fragment() {
let decoded_fragment =
percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy();
for param in decoded_fragment.split('&') {
if let Some(value) = param.strip_prefix("exp=") {
if let Ok(timestamp) = value.parse::<u64>() {
return Ok(Some(
std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp),
));
}
}
}
}
Ok(None)
}

/// Set the exp parameter in the URL fragment
fn set_exp(&mut self, exp: Option<std::time::SystemTime>) -> Result<(), PercentDecodeError> {
let fragment = self.fragment().unwrap_or("").to_string();
let mut fragment =
percent_encoding::percent_decode_str(&fragment)?.decode_utf8_lossy().to_string();
if let Some(start) = fragment.find("exp=") {
let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i);
fragment.replace_range(start..end, "");
if fragment.ends_with('&') {
fragment.pop();
}
}
if let Some(exp) = exp {
let timestamp = exp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
let new_exp = format!("exp={}", timestamp);
if !fragment.is_empty() {
fragment.push('&');
}
fragment.push_str(&new_exp);
}
let encoded_fragment =
percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string();
self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) });
Ok(())
}
}

#[cfg(test)]
Expand Down

0 comments on commit b834465

Please sign in to comment.