Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test flakiness #388

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub async fn listen_tcp_with_tls(
let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
let tls_acceptor = init_tls_acceptor(tls_config)?;
let listener = TcpListener::bind(bind_addr).await?;
println!("Directory listening on tcp://{}", bind_addr);
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
Expand Down
171 changes: 96 additions & 75 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,38 +184,81 @@ mod integration {
use reqwest::{Client, ClientBuilder, Error, Response};
use testcontainers_modules::redis::Redis;
use testcontainers_modules::testcontainers::clients::Cli;
use tokio::sync::OnceCell as AsyncOnceCell;

use super::*;

static TESTS_TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(20));
static WAIT_SERVICE_INTERVAL: Lazy<Duration> = Lazy::new(|| Duration::from_secs(3));
static DIRECTORY_PORT: Lazy<u16> = Lazy::new(find_free_port);
static OHTTP_RELAY_PORT: Lazy<u16> = Lazy::new(find_free_port);
// Shared test infrastructure
static TEST_INFRASTRUCTURE: AsyncOnceCell<TestInfrastructure> = AsyncOnceCell::const_new();

struct TestInfrastructure {
directory: Url,
ohttp_relay: Url,
agent: Arc<Client>,
cert: Vec<u8>,
}

impl TestInfrastructure {
async fn new() -> Result<Self, BoxError> {
let (cert, key) = local_cert_key();
let directory = Url::parse(&format!("https://localhost:{}", *DIRECTORY_PORT))?;
let ohttp_relay = Url::parse(&format!("http://localhost:{}", *OHTTP_RELAY_PORT))?;
let gateway_origin = http::Uri::from_str(directory.as_str())?;

// Start services in background tasks
let _directory_handle =
tokio::spawn(init_directory(*DIRECTORY_PORT, (cert.clone(), key)));
let _relay_handle =
tokio::spawn(ohttp_relay::listen_tcp(*OHTTP_RELAY_PORT, gateway_origin));

// Create HTTP agent
let agent = Arc::new(http_agent(cert.clone())?);

// Wait for services to be ready
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await?;
wait_for_service_ready(directory.clone(), agent.clone()).await?;

Ok(Self { directory, ohttp_relay, agent, cert })
}
}

async fn init_infrastructure() -> &'static TestInfrastructure {
TEST_INFRASTRUCTURE
.get_or_init(|| async {
TestInfrastructure::new()
.await
.expect("Failed to initialize test infrastructure")
})
.await
}

#[tokio::test]
async fn test_bad_ohttp_keys() {
let bad_ohttp_keys =
OhttpKeys::from_str("AQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw")
.expect("Invalid OhttpKeys");

let (cert, key) = local_cert_key();
let port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();
tokio::select!(
_ = init_directory(port, (cert.clone(), key)) => panic!("Directory server is long running"),
res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => {
assert_eq!(
res.unwrap().headers().get("content-type").unwrap(),
"application/problem+json"
);
}
let infra = init_infrastructure().await;
let res = try_request_with_bad_keys(
infra.directory.clone(),
bad_ohttp_keys,
infra.cert.clone(),
)
.await;
assert_eq!(
res.unwrap().headers().get("content-type").unwrap(),
"application/problem+json"
);

async fn try_request_with_bad_keys(
directory: Url,
bad_ohttp_keys: OhttpKeys,
cert_der: Vec<u8>,
) -> Result<Response, Error> {
let agent = Arc::new(http_agent(cert_der.clone()).unwrap());
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let agent = Arc::new(http_agent(cert_der).unwrap());
let mock_ohttp_relay = directory.clone(); // pass through to directory
let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4")
.unwrap()
Expand All @@ -230,18 +273,14 @@ mod integration {
#[tokio::test]
async fn test_session_expiration() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);
let infra = init_infrastructure().await;
let res = do_expiration_tests(
infra.ohttp_relay.clone(),
infra.directory.clone(),
infra.cert.clone(),
)
.await;
assert!(res.is_ok(), "v2 send receive failed: {:#?}", res);

async fn do_expiration_tests(
ohttp_relay: Url,
Expand All @@ -250,8 +289,6 @@ mod integration {
) -> Result<(), BoxError> {
let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?;
let agent = Arc::new(http_agent(cert_der.clone())?);
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap();
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let ohttp_keys =
payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone())
.await?;
Expand Down Expand Up @@ -298,18 +335,14 @@ mod integration {
#[tokio::test]
async fn v2_to_v2() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);
let infra = init_infrastructure().await;
let res = do_v2_send_receive(
infra.ohttp_relay.clone(),
infra.directory.clone(),
infra.cert.clone(),
)
.await;
assert!(res.is_ok(), "v2 send receive failed: {:#?}", res);

async fn do_v2_send_receive(
ohttp_relay: Url,
Expand All @@ -318,8 +351,6 @@ mod integration {
) -> Result<(), BoxError> {
let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?;
let agent = Arc::new(http_agent(cert_der.clone())?);
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap();
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let ohttp_keys =
payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone())
.await?;
Expand Down Expand Up @@ -430,18 +461,14 @@ mod integration {
#[tokio::test]
async fn v2_to_v2_mixed_input_script_types() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);
let infra = init_infrastructure().await;
let res = do_v2_send_receive(
infra.ohttp_relay.clone(),
infra.directory.clone(),
infra.cert.clone(),
)
.await;
assert!(res.is_ok(), "v2 send receive failed: {:#?}", res);

async fn do_v2_send_receive(
ohttp_relay: Url,
Expand All @@ -450,15 +477,12 @@ mod integration {
) -> Result<(), BoxError> {
let (bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?;
let agent = Arc::new(http_agent(cert_der.clone())?);
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap();
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
let ohttp_keys =
payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone())
.await?;
// **********************
// Inside the Receiver:
// make utxos with different script types

let legacy_address =
receiver.get_new_address(None, Some(AddressType::Legacy))?.assume_checked();
let nested_segwit_address =
Expand Down Expand Up @@ -647,18 +671,11 @@ mod integration {
#[tokio::test]
async fn v1_to_v2() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
res = do_v1_to_v2(ohttp_relay, directory, cert) => assert!(res.is_ok()),
);
let infra = init_infrastructure().await;
let res =
do_v1_to_v2(infra.ohttp_relay.clone(), infra.directory.clone(), infra.cert.clone())
.await;
assert!(res.is_ok());

async fn do_v1_to_v2(
ohttp_relay: Url,
Expand All @@ -667,8 +684,6 @@ mod integration {
) -> Result<(), BoxError> {
let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?;
let agent: Arc<Client> = Arc::new(http_agent(cert_der.clone())?);
wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await?;
wait_for_service_ready(directory.clone(), agent.clone()).await?;
let ohttp_keys =
payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone())
.await?;
Expand Down Expand Up @@ -780,13 +795,19 @@ mod integration {
async fn init_directory(
port: u16,
local_cert_key: (Vec<u8>, Vec<u8>),
) -> Result<(), BoxError> {
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let docker: Cli = Cli::default();
let timeout = Duration::from_secs(2);
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key)
.await
.map_err(|e| {
let err_string = e.to_string();
Box::new(std::io::Error::new(std::io::ErrorKind::Other, err_string))
as Box<dyn std::error::Error + Send + Sync + 'static>
})
}

// generates or gets a DER encoded localhost cert and key.
Expand Down Expand Up @@ -1237,7 +1258,7 @@ mod integration {
handle_proposal(proposal, receiver, custom_outputs, drain_script, custom_inputs)?;
assert!(!proposal.is_output_substitution_disabled());
let psbt = proposal.psbt();
tracing::debug!("Receiver's Payjoin proposal PSBT: {:#?}", &psbt);
//tracing::debug!("Receiver's Payjoin proposal PSBT: {:#?}", &psbt);
Ok(psbt.to_string())
}

Expand Down Expand Up @@ -1328,7 +1349,7 @@ mod integration {
let payjoin_psbt = sender.wallet_process_psbt(&psbt.to_string(), None, None, None)?.psbt;
let payjoin_psbt = sender.finalize_psbt(&payjoin_psbt, Some(false))?.psbt.unwrap();
let payjoin_psbt = Psbt::from_str(&payjoin_psbt)?;
tracing::debug!("Sender's Payjoin PSBT: {:#?}", payjoin_psbt);
//tracing::debug!("Sender's Payjoin PSBT: {:#?}", payjoin_psbt);

Ok(payjoin_psbt.extract_tx()?)
}
Expand Down
Loading