diff --git a/.unreleased/LLT-5884 b/.unreleased/LLT-5884 new file mode 100644 index 000000000..0576ebf73 --- /dev/null +++ b/.unreleased/LLT-5884 @@ -0,0 +1 @@ +Restart postquantum if the last handshake has passed the wireguard reject timeout diff --git a/Cargo.lock b/Cargo.lock index 88f77cb5f..57fe4d06e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4863,6 +4863,7 @@ dependencies = [ "hmac", "mockall", "neptun", + "parking_lot", "pnet_packet", "pqcrypto-kyber", "pqcrypto-traits", diff --git a/crates/telio-pq/Cargo.toml b/crates/telio-pq/Cargo.toml index 85110cd01..9773c888d 100644 --- a/crates/telio-pq/Cargo.toml +++ b/crates/telio-pq/Cargo.toml @@ -27,6 +27,7 @@ telio-utils.workspace = true neptun.workspace = true pnet_packet.workspace = true mockall = { workspace = true, optional = true } +parking_lot.workspace = true rand.workspace = true thiserror.workspace = true tokio.workspace = true diff --git a/crates/telio-pq/src/entity.rs b/crates/telio-pq/src/entity.rs index d9e07290b..49007e1c5 100644 --- a/crates/telio-pq/src/entity.rs +++ b/crates/telio-pq/src/entity.rs @@ -1,12 +1,15 @@ use std::net::SocketAddr; use std::sync::Arc; +use parking_lot::Mutex; + use telio_model::features::FeaturePostQuantumVPN; use telio_task::io::chan; use telio_utils::telio_log_debug; struct Peer { pubkey: telio_crypto::PublicKey, + wg_secret: telio_crypto::SecretKey, addr: SocketAddr, /// This is a key rotation task guard, its `Drop` implementation aborts the task _rotation_task: super::conn::ConnKeyRotation, @@ -17,16 +20,16 @@ pub struct Entity { features: FeaturePostQuantumVPN, sockets: Arc, chan: chan::Tx, - peer: Option, + peer: Mutex>, } impl crate::PostQuantum for Entity { fn keys(&self) -> Option { - self.peer.as_ref().and_then(|p| p.keys.clone()) + self.peer.lock().as_ref().and_then(|p| p.keys.clone()) } fn is_rotating_keys(&self) -> bool { - self.peer.is_some() + self.peer.lock().is_some() } } @@ -40,44 +43,47 @@ impl Entity { features, sockets, chan, - peer: None, + peer: Mutex::new(None), } } - pub fn on_event(&mut self, event: super::Event) { - let Some(peer) = &mut self.peer else { - return; - }; - - match event { - super::Event::Handshake(addr, keys) => { - if peer.addr == addr { - peer.keys = Some(keys); + pub fn on_event(&self, event: super::Event) { + if let Some(peer) = self.peer.lock().as_mut() { + match event { + super::Event::Handshake(addr, keys) => { + if peer.addr == addr { + peer.keys = Some(keys); + } } - } - super::Event::Rekey(super::Keys { - wg_secret, - pq_shared, - }) => { - if let Some(keys) = &mut peer.keys { - // Check if we are still talking to the same exit node - if keys.wg_secret == wg_secret { - // and only then update the preshared key, - // otherwise we're connecting to different node already - keys.pq_shared = pq_shared; - } else { - telio_log_debug!( - "PQ secret key does not match, ignoring shared secret rotation" - ); + super::Event::Rekey(super::Keys { + wg_secret, + pq_shared, + }) => { + if let Some(keys) = &mut peer.keys { + // Check if we are still talking to the same exit node + if keys.wg_secret == wg_secret { + // and only then update the preshared key, + // otherwise we're connecting to different node already + keys.pq_shared = pq_shared; + } else { + telio_log_debug!( + "PQ secret key does not match, ignoring shared secret rotation" + ); + } } } + _ => (), } - _ => (), } } - pub async fn stop(&mut self) { - if let Some(peer) = self.peer.take() { + pub fn peer_addr(&self) -> Option { + self.peer.lock().as_ref().map(|p| p.addr) + } + + pub async fn stop(&self) { + let peer = self.peer.lock().take(); + if let Some(peer) = peer { #[allow(mpsc_blocking_send)] let _ = self .chan @@ -87,15 +93,16 @@ impl Entity { } pub async fn start( - &mut self, + &self, addr: SocketAddr, wg_secret: telio_crypto::SecretKey, peer: telio_crypto::PublicKey, ) { self.stop().await; - self.peer = Some(Peer { + *self.peer.lock() = Some(Peer { pubkey: peer, + wg_secret: wg_secret.clone(), addr, _rotation_task: super::conn::ConnKeyRotation::run( self.chan.clone(), @@ -111,4 +118,16 @@ impl Entity { #[allow(mpsc_blocking_send)] let _ = self.chan.send(super::Event::Connecting(peer)).await; } + + pub async fn restart(&self) { + let params = self.peer.lock().as_ref().map(|peer| { + let addr = peer.addr; + let pubkey = peer.pubkey; + let wg_secret = peer.wg_secret.clone(); + (addr, pubkey, wg_secret) + }); + if let Some((addr, pubkey, wg_secret)) = params { + self.start(addr, wg_secret, pubkey).await; + } + } } diff --git a/nat-lab/tests/test_pq.py b/nat-lab/tests/test_pq.py index 525fa4558..71b9a1b11 100644 --- a/nat-lab/tests/test_pq.py +++ b/nat-lab/tests/test_pq.py @@ -1,3 +1,4 @@ +import asyncio import config import pytest from contextlib import AsyncExitStack @@ -136,7 +137,7 @@ async def test_pq_vpn_connection( ) -> None: async with AsyncExitStack() as exit_stack: env = await exit_stack.enter_async_context( - setup_environment(exit_stack, [alpha_setup_params]) + setup_environment(exit_stack, [alpha_setup_params], prepare_vpn=True) ) client_conn, *_ = [conn.connection for conn in env.connections] @@ -238,7 +239,7 @@ async def test_pq_vpn_rekey( async with AsyncExitStack() as exit_stack: env = await exit_stack.enter_async_context( - setup_environment(exit_stack, [alpha_setup_params]) + setup_environment(exit_stack, [alpha_setup_params], prepare_vpn=True) ) client_conn, *_ = [conn.connection for conn in env.connections] @@ -510,3 +511,43 @@ async def test_pq_vpn_upgrade_from_non_pq( preshared = await read_preshared_key_slot(nlx_conn) assert preshared != EMPTY_PRESHARED_KEY_SLOT + + +# Regression test for LLT-5884 +@pytest.mark.timeout(240) +async def test_pq_vpn_handshake_after_nonet() -> None: + public_ip = "10.0.254.1" + async with AsyncExitStack() as exit_stack: + env = await exit_stack.enter_async_context( + setup_environment( + exit_stack, + [ + SetupParameters( + connection_tag=ConnectionTag.DOCKER_CONE_CLIENT_1, + adapter_type_override=TelioAdapterType.NEP_TUN, + is_meshnet=False, + ), + ], + prepare_vpn=True, + ) + ) + + client_conn, *_ = [conn.connection for conn in env.connections] + client_alpha, *_ = env.clients + + ip = await stun.get(client_conn, config.STUN_SERVER) + assert ip == public_ip, f"wrong public IP before connecting to VPN {ip}" + + await _connect_vpn_pq( + client_conn, + client_alpha, + ) + + async with client_alpha.get_router().break_udp_conn_to_host( + str(config.NLX_SERVER["ipv4"]) + ), client_alpha.get_router().break_tcp_conn_to_host( + str(config.NLX_SERVER["ipv4"]) + ): + await asyncio.sleep(195) + + await ping(client_conn, config.PHOTO_ALBUM_IP, timeout=10) diff --git a/src/device/wg_controller.rs b/src/device/wg_controller.rs index 51d33295c..d2d9b815f 100644 --- a/src/device/wg_controller.rs +++ b/src/device/wg_controller.rs @@ -69,6 +69,8 @@ pub async fn consolidate_wg_state( entities: &Entities, features: &Features, ) -> Result { + maybe_restart_pq(entities).await; + let remote_peer_states = if let Some(meshnet_entities) = entities.meshnet.left() { meshnet_entities.derp.get_remote_peer_states().await } else { @@ -147,6 +149,42 @@ pub async fn consolidate_wg_state( Ok(()) } +// Postquantum has a quirk that can cause VPN connections to fail due to the client having a preshared key when the server doesn't +// This can happen if there is a handshake and a preshared key, then the client nonets for more than 180s (wg reject threshold), +// and then tries to handshake again +// +// The purpose of this function is to prevent this by restarting postquantum if the above case is reached. There are two conditions +// that need to be fulfilled for us to restart postquantum: +// 1. postquantum is active (here represented by checking if there is a peer address) +// 2. the last handshake to the VPN server has passed the hardcoded wireguard reject threshold +// (here represented by comparing against `Some(None)`. The `time_since_last_handshake` will tick up to 180s and then be set to `None` so +// if `None` means that wireguard will reject the connection) +async fn maybe_restart_pq(entities: &Entities) { + let should_restart_pq = match entities.postquantum_wg.peer_addr() { + Some(addr) => { + let time_since_last_handshake = entities + .wireguard_interface + .get_interface() + .await + .ok() + .and_then(|ifc| { + ifc.peers.iter().find_map(|(_, peer)| { + if peer.endpoint.is_some_and(|e| e == addr) { + Some(peer.time_since_last_handshake) + } else { + None + } + }) + }); + matches!(time_since_last_handshake, Some(None)) + } + None => false, + }; + if should_restart_pq { + entities.postquantum_wg.restart().await; + } +} + async fn consolidate_wg_private_key( requested_state: &RequestedState, wireguard_interface: &W,