From 4244114f7a15ae7155c2f5a3cf60aec26b997446 Mon Sep 17 00:00:00 2001 From: Rob Date: Tue, 6 Aug 2024 16:54:59 -0400 Subject: [PATCH] deduplicate some code --- .../src/network/stake_table_transport.rs | 377 +++++++++--------- 1 file changed, 183 insertions(+), 194 deletions(-) diff --git a/crates/libp2p-networking/src/network/stake_table_transport.rs b/crates/libp2p-networking/src/network/stake_table_transport.rs index b6cad13cb4..1e2089f877 100644 --- a/crates/libp2p-networking/src/network/stake_table_transport.rs +++ b/crates/libp2p-networking/src/network/stake_table_transport.rs @@ -12,6 +12,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::Poll; use tracing::warn; +use {std::io::Error as IoError, std::io::ErrorKind as IoErrorKind}; use futures::future::poll_fn; use futures::{AsyncReadExt, AsyncWriteExt}; @@ -36,7 +37,8 @@ const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_se /// by performing a handshake that checks if the remote peer is present in the /// stake table. #[pin_project] -pub struct StakeTableAuthentication { +pub struct StakeTableAuthentication +{ #[pin] /// The underlying transport we are wrapping pub inner: T, @@ -51,6 +53,10 @@ pub struct StakeTableAuthentication, } +/// A type alias for the future that upgrades a connection to perform the authentication handshake +type UpgradeFuture = + Pin::Output, ::Error>> + Send>>; + impl StakeTableAuthentication { /// Create a new `StakeTableAuthentication` transport that wraps the given transport /// and authenticates connections against the stake table. @@ -62,6 +68,158 @@ impl StakeTableAuthentica pd: std::marker::PhantomData, } } + + /// Prove to the remote peer that we are in the stake table by sending + /// them our authentication message. + /// + /// # Errors + /// - If we fail to write the message to the stream + pub async fn authenticate_with_remote_peer( + stream: &mut W, + auth_message: Arc>>, + ) -> AnyhowResult<()> { + // If we have an auth message, send it to the remote peer, prefixed with + // the message length + if let Some(auth_message) = auth_message.as_ref() { + // Write the length-delimited message + write_length_delimited(stream, auth_message).await?; + } + + Ok(()) + } + + /// Verify that the remote peer is: + /// - In the stake table + /// - Sending us a valid authentication message + /// - Sending us a valid signature + /// - Matching the peer ID we expect + /// + /// # Errors + /// If the peer fails verification. This can happen if: + /// - We fail to read the message from the stream + /// - The message is too large + /// - The message is invalid + /// - The peer is not in the stake table + /// - The signature is invalid + pub async fn verify_peer_authentication( + stream: &mut R, + stake_table: Arc>>, + required_peer_id: &PeerId, + ) -> AnyhowResult<()> { + // If we have a stake table, check if the remote peer is in it + if let Some(stake_table) = stake_table.as_ref() { + // Read the length-delimited message from the remote peer + let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?; + + // Deserialize the authentication message + let auth_message: AuthMessage = bincode::deserialize(&message) + .with_context(|| "Failed to deserialize auth message")?; + + // Verify the signature on the public keys + let public_key = auth_message + .validate() + .with_context(|| "Failed to verify authentication message")?; + + // Deserialize the `PeerId` + let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes) + .with_context(|| "Failed to deserialize peer ID")?; + + // Verify that the peer ID is the same as the remote peer + if peer_id != *required_peer_id { + return Err(anyhow::anyhow!("Peer ID mismatch")); + } + + // Check if the public key is in the stake table + if !stake_table.contains(&public_key) { + return Err(anyhow::anyhow!("Peer not in stake table")); + } + } + + Ok(()) + } + + /// Wrap the supplied future in an upgrade that performs the authentication handshake. + /// + /// `outgoing` is a boolean that indicates if the connection is incoming or outgoing. + /// This is needed because the flow of the handshake is different for each. + fn gen_handshake> + Send + 'static>( + original_future: F, + outgoing: bool, + stake_table: Arc>>, + auth_message: Arc>>, + ) -> UpgradeFuture + where + T::Error: From<::Error> + From, + T::Output: AsOutput + Send, + + C::Substream: Unpin + Send, + { + // Create a new upgrade that performs the authentication handshake on top + Box::pin(async move { + // Wait for the original future to resolve + let mut stream = original_future.await?; + + // Time out the authentication block + async_timeout(AUTH_HANDSHAKE_TIMEOUT, async { + // Open a substream for the handshake. + // The handshake order depends on whether the connection is incoming or outgoing. + let mut substream = if outgoing { + poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await? + } else { + poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await? + }; + + if outgoing { + // If the connection is outgoing, authenticate with the remote peer first + Self::authenticate_with_remote_peer(&mut substream, auth_message) + .await + .map_err(|e| { + warn!("Failed to authenticate with remote peer: {:?}", e); + IoError::new(IoErrorKind::Other, e) + })?; + + // Verify the remote peer's authentication + Self::verify_peer_authentication( + &mut substream, + stake_table, + stream.as_peer_id(), + ) + .await + .map_err(|e| { + warn!("Failed to verify remote peer: {:?}", e); + IoError::new(IoErrorKind::Other, e) + })?; + } else { + // If it is incoming, verify the remote peer's authentication first + Self::verify_peer_authentication( + &mut substream, + stake_table, + stream.as_peer_id(), + ) + .await + .map_err(|e| { + warn!("Failed to verify remote peer: {:?}", e); + IoError::new(IoErrorKind::Other, e) + })?; + + // Authenticate with the remote peer + Self::authenticate_with_remote_peer(&mut substream, auth_message) + .await + .map_err(|e| { + warn!("Failed to authenticate with remote peer: {:?}", e); + IoError::new(IoErrorKind::Other, e) + })?; + } + + Ok(stream) + }) + .await + .map_err(|e| { + warn!("Timed out performing authentication handshake: {:?}", e); + IoError::new(IoErrorKind::TimedOut, e) + })? + }) + } } /// The deserialized form of an authentication message that is sent to the remote peer @@ -131,86 +289,13 @@ pub fn construct_auth_message( bincode::serialize(&auth_message).with_context(|| "Failed to serialize auth message") } -/// Prove to the remote peer that we are in the stake table by sending -/// them our authentication message. -/// -/// # Errors -/// - If we fail to write the message to the stream -pub async fn authenticate_with_remote_peer( - stream: &mut W, - auth_message: Arc>>, -) -> AnyhowResult<()> { - // If we have an auth message, send it to the remote peer, prefixed with - // the message length - if let Some(auth_message) = auth_message.as_ref() { - // Write the length-delimited message - write_length_delimited(stream, auth_message).await?; - } - - Ok(()) -} - -/// Verify that the remote peer is: -/// - In the stake table -/// - Sending us a valid authentication message -/// - Sending us a valid signature -/// - Matching the peer ID we expect -/// -/// # Errors -/// If the peer fails verification. This can happen if: -/// - We fail to read the message from the stream -/// - The message is too large -/// - The message is invalid -/// - The peer is not in the stake table -/// - The signature is invalid -pub async fn verify_peer_authentication< - R: AsyncReadExt + Unpin, - S: SignatureKey, - H: BuildHasher, ->( - stream: &mut R, - stake_table: Arc>>, - required_peer_id: &PeerId, -) -> AnyhowResult<()> { - // If we have a stake table, check if the remote peer is in it - if let Some(stake_table) = stake_table.as_ref() { - // Read the length-delimited message from the remote peer - let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?; - - // Deserialize the authentication message - let auth_message: AuthMessage = - bincode::deserialize(&message).with_context(|| "Failed to deserialize auth message")?; - - // Verify the signature on the public keys - let public_key = auth_message - .validate() - .with_context(|| "Failed to verify authentication message")?; - - // Deserialize the `PeerId` - let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes) - .with_context(|| "Failed to deserialize peer ID")?; - - // Verify that the peer ID is the same as the remote peer - if peer_id != *required_peer_id { - return Err(anyhow::anyhow!("Peer ID mismatch")); - } - - // Check if the public key is in the stake table - if !stake_table.contains(&public_key) { - return Err(anyhow::anyhow!("Peer not in stake table")); - } - } - - Ok(()) -} - impl Transport for StakeTableAuthentication where T::Dial: Future> + Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsOutput + Send, - T::Error: From<::Error> + From, + T::Error: From<::Error> + From, C::Substream: Unpin + Send, { @@ -237,42 +322,7 @@ where // If the dial was successful, perform the authentication handshake on top match res { - Ok(dial) => Ok(Box::pin(async move { - // Perform the inner dial - let mut stream = dial.await?; - - // Time out the authentication block - async_timeout(AUTH_HANDSHAKE_TIMEOUT, async { - // Open a substream for the handshake - let mut substream = - poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?; - - // (outbound) Authenticate with the remote peer - authenticate_with_remote_peer(&mut substream, auth_message) - .await - .map_err(|e| { - warn!("Failed to authenticate with remote peer: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, e) - })?; - - // (inbound) Verify the remote peer's authentication - verify_peer_authentication(&mut substream, stake_table, stream.as_peer_id()) - .await - .map_err(|e| { - warn!("Failed to verify remote peer: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, e) - })?; - - Ok::<(), T::Error>(()) - }) - .await - .map_err(|e| { - warn!("Timed out during authentication handshake: {:?}", e); - std::io::Error::new(std::io::ErrorKind::TimedOut, e) - })??; - - Ok(stream) - })), + Ok(dial) => Ok(Self::gen_handshake(dial, true, stake_table, auth_message)), Err(err) => Err(err), } } @@ -293,42 +343,7 @@ where // If the dial was successful, perform the authentication handshake on top match res { - Ok(dial) => Ok(Box::pin(async move { - // Perform the inner dial - let mut stream = dial.await?; - - // Time out the authentication block - async_timeout(AUTH_HANDSHAKE_TIMEOUT, async { - // Open a substream for the handshake - let mut substream = - poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?; - - // (inbound) Verify the remote peer's authentication - verify_peer_authentication(&mut substream, stake_table, stream.as_peer_id()) - .await - .map_err(|e| { - warn!("Failed to verify remote peer: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, e) - })?; - - // (outbound) Authenticate with the remote peer - authenticate_with_remote_peer(&mut substream, auth_message) - .await - .map_err(|e| { - warn!("Failed to authenticate with remote peer: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, e) - })?; - - Ok::<(), T::Error>(()) - }) - .await - .map_err(|e| { - warn!("Timed out performing authentication handshake: {:?}", e); - std::io::Error::new(std::io::ErrorKind::TimedOut, e) - })??; - - Ok(stream) - })), + Ok(dial) => Ok(Self::gen_handshake(dial, false, stake_table, auth_message)), Err(err) => Err(err), } } @@ -354,47 +369,9 @@ where let auth_message = Arc::clone(&self.auth_message); let stake_table = Arc::clone(&self.stake_table); - // Create a new upgrade that performs the authentication handshake on top - let auth_upgrade = Box::pin(async move { - // Perform the inner upgrade - let mut stream = upgrade.await?; - - // Time out the authentication block - async_timeout(AUTH_HANDSHAKE_TIMEOUT, async { - // Open a substream for the handshake - let mut substream = - poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await?; - - // (inbound) Verify the remote peer's authentication - verify_peer_authentication( - &mut substream, - stake_table, - stream.as_peer_id(), - ) - .await - .map_err(|e| { - warn!("Failed to verify remote peer: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, e) - })?; - - // (outbound) Authenticate with the remote peer - authenticate_with_remote_peer(&mut substream, auth_message) - .await - .map_err(|e| { - warn!("Failed to authenticate with remote peer: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, e) - })?; - - Ok::<(), T::Error>(()) - }) - .await - .map_err(|e| { - warn!("Timed out performing authentication handshake: {:?}", e); - std::io::Error::new(std::io::ErrorKind::TimedOut, e) - })??; - - Ok(stream) - }); + // Generate the handshake upgrade future (inbound) + let auth_upgrade = + Self::gen_handshake(upgrade, false, stake_table, auth_message); // Return the new event TransportEvent::Incoming { @@ -539,6 +516,7 @@ pub async fn write_length_delimited( #[cfg(test)] mod test { + use libp2p::{core::transport::dummy::DummyTransport, quic::Connection}; use rand::Rng; use std::{collections::HashSet, sync::Arc}; @@ -546,7 +524,10 @@ mod test { use super::write_length_delimited; use hotshot_types::{signature_key::BLSPubKey, traits::signature_key::SignatureKey}; - use super::verify_peer_authentication; + use super::StakeTableAuthentication; + + /// A mock type to help with readability + type MockStakeTableAuth = StakeTableAuthentication; // Helper macro for generating a new identity and authentication message macro_rules! new_identity { @@ -657,8 +638,12 @@ mod test { stake_table.insert(keypair.0); // Verify the authentication message - let result = - verify_peer_authentication(&mut stream, Arc::new(Some(stake_table)), &peer_id).await; + let result = MockStakeTableAuth::verify_peer_authentication( + &mut stream, + Arc::new(Some(stake_table)), + &peer_id, + ) + .await; assert!( result.is_ok(), @@ -679,8 +664,12 @@ mod test { let stake_table: HashSet = std::collections::HashSet::new(); // Verify the authentication message - let result = - verify_peer_authentication(&mut stream, Arc::new(Some(stake_table)), &peer_id).await; + let result = MockStakeTableAuth::verify_peer_authentication( + &mut stream, + Arc::new(Some(stake_table)), + &peer_id, + ) + .await; // Make sure it errored for the right reason assert!( @@ -709,7 +698,7 @@ mod test { stake_table.insert(keypair.0); // Check against the malicious peer ID - let result = verify_peer_authentication( + let result = MockStakeTableAuth::verify_peer_authentication( &mut stream, Arc::new(Some(stake_table)), &malicious_peer_id,