Skip to content

Commit

Permalink
Test send/recv expiration
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Jul 15, 2024
1 parent c303920 commit 3b8599d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 7 deletions.
2 changes: 1 addition & 1 deletion payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl AppTrait for App {
self.config.pj_directory.clone(),
ohttp_keys.clone(),
self.config.ohttp_relay.clone(),
std::time::Duration::from_secs(60 * 60),
Some(std::time::Duration::from_secs(60 * 60)),
);
let (req, ctx) =
initializer.extract_req().map_err(|e| anyhow!("Failed to extract request {}", e))?;
Expand Down
91 changes: 85 additions & 6 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,77 @@ mod integration {
}
}

#[tokio::test]
async fn test_session_expiration() {
std::env::set_var("RUST_LOG", "debug");
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => assert!(false, "Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => assert!(false, "Directory server is long running"),
res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

async fn do_expiration_tests(
ohttp_relay: Url,
directory: Url,
cert_der: Vec<u8>,
) -> Result<(), BoxError> {
let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver()?;
let agent = Arc::new(http_agent(cert_der.clone())?);
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap();
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let ohttp_keys =
payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone())
.await?;

// **********************
// Inside the Receiver:
let address = receiver.get_new_address(None, None)?.assume_checked();
// test session with expiry in the past
let mut session = initialize_session(
address.clone(),
directory.clone(),
ohttp_keys.clone(),
cert_der,
Some(Duration::from_secs(0)),
)
.await?;
match session.extract_req() {
// Internal error types are private, so check against a string
Err(err) => assert!(err.to_string().contains("expired")),
_ => assert!(false, "Expired receive session should error"),
};
let pj_uri = session.pj_uri_builder().build();

// **********************
// Inside the Sender:
let psbt = build_original_psbt(&sender, &pj_uri)?;
// Test that an expired pj_url errors
let expired_pj_uri = payjoin::PjUriBuilder::new(
address,
directory.clone(),
Some(ohttp_keys),
Some(std::time::SystemTime::now()),
)
.build();
let mut expired_req_ctx = RequestBuilder::from_psbt_and_uri(psbt, expired_pj_uri)?
.build_non_incentivizing()?;
match expired_req_ctx.extract_v2(directory.to_owned()) {
// Internal error types are private, so check against a string
Err(err) => assert!(err.to_string().contains("expired")),
_ => assert!(false, "Expired send session should error"),
};
Ok(())
}
}

#[tokio::test]
async fn v2_to_v2() {
std::env::set_var("RUST_LOG", "debug");
Expand Down Expand Up @@ -376,16 +447,18 @@ mod integration {
// **********************
// Inside the Receiver:
let address = receiver.get_new_address(None, None)?.assume_checked();

// test session with expiry in the future
let mut session = initialize_session(
address,
address.clone(),
directory.clone(),
ohttp_keys.clone(),
cert_der.clone(),
None,
)
.await?;
println!("session: {:#?}", &session);
let pj_uri_string = session.pj_uri_builder().build().to_string();

// Poll receive request
let (req, ctx) = session.extract_req()?;
let response = agent.post(req.url).body(req.body).send().await?;
Expand Down Expand Up @@ -492,9 +565,14 @@ mod integration {
.await?;
let address = receiver.get_new_address(None, None)?.assume_checked();

let mut session =
initialize_session(address, directory, ohttp_keys.clone(), cert_der.clone())
.await?;
let mut session = initialize_session(
address,
directory,
ohttp_keys.clone(),
cert_der.clone(),
None,
)
.await?;

let pj_uri_string = session.pj_uri_builder().build().to_string();

Expand Down Expand Up @@ -614,14 +692,15 @@ mod integration {
directory: Url,
ohttp_keys: OhttpKeys,
cert_der: Vec<u8>,
custom_expire_after: Option<Duration>,
) -> Result<ActiveSession, BoxError> {
let mock_ohttp_relay = directory.clone(); // pass through to directory
let mut initializer = SessionInitializer::new(
address,
directory.clone(),
ohttp_keys,
mock_ohttp_relay.clone(),
Some(Duration::from_secs(60)),
custom_expire_after,
);
let (req, ctx) = initializer.extract_req()?;
println!("enroll req: {:#?}", &req);
Expand Down

0 comments on commit 3b8599d

Please sign in to comment.