diff --git a/src/transport/common/listener.rs b/src/transport/common/listener.rs index aa7ea2fe..275d7beb 100644 --- a/src/transport/common/listener.rs +++ b/src/transport/common/listener.rs @@ -27,6 +27,10 @@ use multiaddr::{Multiaddr, Protocol}; use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig}; use socket2::{Domain, Socket, Type}; use tokio::net::{TcpListener as TokioTcpListener, TcpStream}; +use trust_dns_resolver::{ + config::{ResolverConfig, ResolverOpts}, + TokioAsyncResolver, +}; use std::{ io, @@ -46,7 +50,73 @@ pub enum AddressType { Socket(SocketAddr), /// DNS address. - Dns(String, u16), + Dns { + address: String, + port: u16, + dns_type: DnsType, + }, +} + +/// The DNS type of the address. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DnsType { + /// DNS supports both IPv4 and IPv6. + Dns, + /// DNS supports only IPv4. + Dns4, + /// DNS supports only IPv6. + Dns6, +} + +impl AddressType { + /// Resolve the address to a concrete IP. + pub async fn lookup_ip(self) -> crate::Result { + let (url, port, dns_type) = match self { + // We already have the IP address. + AddressType::Socket(address) => return Ok(address), + AddressType::Dns { + address, + port, + dns_type, + } => (address, port, dns_type), + }; + + let lookup = + match TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default()) + .lookup_ip(url.clone()) + .await + { + Ok(lookup) => lookup, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?error, + "failed to resolve DNS address `{}`", + url + ); + + return Err(Error::Other(format!("Failed to resolve DNS address {url}"))); + } + }; + + let Some(ip) = lookup.iter().find(|ip| match dns_type { + DnsType::Dns => true, + DnsType::Dns4 => ip.is_ipv4(), + DnsType::Dns6 => ip.is_ipv6(), + }) else { + tracing::debug!( + target: LOG_TARGET, + "Multiaddr DNS type does not match IP version `{}`", + url + ); + + return Err(Error::Other(format!( + "Miss-match in DNS address IP version {url}" + ))); + }; + + Ok(SocketAddr::new(ip, port)) + } } /// Local addresses to use for outbound connections. @@ -167,7 +237,7 @@ impl SocketListener { .into_iter() .filter_map(|address| { let address = match T::multiaddr_to_socket_address(&address).ok()?.0 { - AddressType::Dns(address, port) => { + AddressType::Dns { address, port, .. } => { tracing::debug!( target: LOG_TARGET, ?address, @@ -286,10 +356,14 @@ fn multiaddr_to_socket_address( tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); let mut iter = address.iter(); - let socket_address = match iter.next() { - Some(Protocol::Ip6(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => - AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)), + // Small helper to handle DNS types. + let handle_dns_type = + |address: String, dns_type: DnsType, protocol: Option| match protocol { + Some(Protocol::Tcp(port)) => Ok(AddressType::Dns { + address, + port, + dns_type, + }), protocol => { tracing::error!( target: LOG_TARGET, @@ -298,10 +372,12 @@ fn multiaddr_to_socket_address( ); return Err(Error::AddressError(AddressError::InvalidProtocol)); } - }, - Some(Protocol::Ip4(address)) => match iter.next() { + }; + + let socket_address = match iter.next() { + Some(Protocol::Ip6(address)) => match iter.next() { Some(Protocol::Tcp(port)) => - AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)), + AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)), protocol => { tracing::error!( target: LOG_TARGET, @@ -311,10 +387,9 @@ fn multiaddr_to_socket_address( return Err(Error::AddressError(AddressError::InvalidProtocol)); } }, - Some(Protocol::Dns(address)) - | Some(Protocol::Dns4(address)) - | Some(Protocol::Dns6(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => AddressType::Dns(address.to_string(), port), + Some(Protocol::Ip4(address)) => match iter.next() { + Some(Protocol::Tcp(port)) => + AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)), protocol => { tracing::error!( target: LOG_TARGET, @@ -324,6 +399,11 @@ fn multiaddr_to_socket_address( return Err(Error::AddressError(AddressError::InvalidProtocol)); } }, + Some(Protocol::Dns(address)) => handle_dns_type(address.into(), DnsType::Dns, iter.next())?, + Some(Protocol::Dns4(address)) => + handle_dns_type(address.into(), DnsType::Dns4, iter.next())?, + Some(Protocol::Dns6(address)) => + handle_dns_type(address.into(), DnsType::Dns6, iter.next())?, protocol => { tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); return Err(Error::AddressError(AddressError::InvalidProtocol)); diff --git a/src/transport/tcp/connection.rs b/src/transport/tcp/connection.rs index eef67369..1dd43898 100644 --- a/src/transport/tcp/connection.rs +++ b/src/transport/tcp/connection.rs @@ -28,7 +28,11 @@ use crate::{ multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version}, protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, substream, - transport::{common::listener::AddressType, tcp::substream::Substream, Endpoint}, + transport::{ + common::listener::{AddressType, DnsType}, + tcp::substream::Substream, + Endpoint, + }, types::{protocol::ProtocolName, ConnectionId, SubstreamId}, BandwidthSink, PeerId, }; @@ -455,9 +459,21 @@ impl TcpConnection { AddressType::Socket(address) => Multiaddr::empty() .with(Protocol::from(address.ip())) .with(Protocol::Tcp(address.port())), - AddressType::Dns(address, port) => Multiaddr::empty() - .with(Protocol::Dns(Cow::Owned(address))) - .with(Protocol::Tcp(port)), + AddressType::Dns { + address, + port, + dns_type, + } => match dns_type { + DnsType::Dns => Multiaddr::empty() + .with(Protocol::Dns(Cow::Owned(address))) + .with(Protocol::Tcp(port)), + DnsType::Dns4 => Multiaddr::empty() + .with(Protocol::Dns4(Cow::Owned(address))) + .with(Protocol::Tcp(port)), + DnsType::Dns6 => Multiaddr::empty() + .with(Protocol::Dns6(Cow::Owned(address))) + .with(Protocol::Tcp(port)), + }, }; let endpoint = match role { Role::Dialer => Endpoint::dialer(address, connection_id), diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs index 79073350..eb06a863 100644 --- a/src/transport/tcp/mod.rs +++ b/src/transport/tcp/mod.rs @@ -25,7 +25,7 @@ use crate::{ config::Role, error::Error, transport::{ - common::listener::{AddressType, DialAddresses, GetSocketAddr, SocketListener, TcpAddress}, + common::listener::{DialAddresses, GetSocketAddr, SocketListener, TcpAddress}, manager::TransportHandle, tcp::{ config::Config, @@ -40,13 +40,9 @@ use futures::{ future::BoxFuture, stream::{FuturesUnordered, Stream, StreamExt}, }; -use multiaddr::{Multiaddr, Protocol}; +use multiaddr::Multiaddr; use socket2::{Domain, Socket, Type}; use tokio::net::TcpStream; -use trust_dns_resolver::{ - config::{ResolverConfig, ResolverOpts}, - TokioAsyncResolver, -}; use std::{ collections::{HashMap, HashSet}, @@ -139,55 +135,12 @@ impl TcpTransport { nodelay: bool, ) -> crate::Result<(Multiaddr, TcpStream)> { let (socket_address, _) = TcpAddress::multiaddr_to_socket_address(&address)?; - let remote_address = match socket_address { - AddressType::Socket(address) => address, - AddressType::Dns(url, port) => { - let address = address.clone(); - let future = async move { - match TokioAsyncResolver::tokio( - ResolverConfig::default(), - ResolverOpts::default(), - ) - .lookup_ip(url.clone()) - .await - { - // TODO: ugly - Ok(lookup) => { - let iter = lookup.iter(); - for ip in iter { - match ( - address.iter().next().expect("protocol to exist"), - ip.is_ipv4(), - ) { - (Protocol::Dns(_), true) - | (Protocol::Dns4(_), true) - | (Protocol::Dns6(_), false) => { - tracing::trace!( - target: LOG_TARGET, - ?address, - ?ip, - "address resolved", - ); - - return Ok(SocketAddr::new(ip, port)); - } - _ => {} - } - } - - Err(Error::Unknown) - } - Err(_) => Err(Error::Unknown), - } - }; - - match tokio::time::timeout(connection_open_timeout, future).await { - Err(_) => return Err(Error::Timeout), - Ok(Err(error)) => return Err(error), - Ok(Ok(address)) => address, - } - } - }; + let remote_address = + match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await { + Err(_) => return Err(Error::Timeout), + Ok(Err(error)) => return Err(error), + Ok(Ok(address)) => address, + }; let domain = match remote_address.is_ipv4() { true => Domain::IPV4, diff --git a/src/transport/websocket/mod.rs b/src/transport/websocket/mod.rs index f3f9d6a8..cc2e3740 100644 --- a/src/transport/websocket/mod.rs +++ b/src/transport/websocket/mod.rs @@ -24,9 +24,7 @@ use crate::{ config::Role, error::{AddressError, Error}, transport::{ - common::listener::{ - AddressType, DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress, - }, + common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress}, manager::TransportHandle, websocket::{ config::Config, @@ -43,15 +41,11 @@ use multiaddr::{Multiaddr, Protocol}; use socket2::{Domain, Socket, Type}; use tokio::net::TcpStream; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; -use trust_dns_resolver::{ - config::{ResolverConfig, ResolverOpts}, - TokioAsyncResolver, -}; + use url::Url; use std::{ collections::{HashMap, HashSet}, - net::SocketAddr, pin::Pin, task::{Context, Poll}, time::Duration, @@ -186,57 +180,14 @@ impl WebSocketTransport { nodelay: bool, ) -> crate::Result<(Multiaddr, WebSocketStream>)> { let (url, _) = Self::multiaddr_into_url(address.clone())?; - let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?; - let remote_address = match socket_address { - AddressType::Socket(address) => address, - AddressType::Dns(url, port) => { - let address = address.clone(); - let future = async move { - match TokioAsyncResolver::tokio( - ResolverConfig::default(), - ResolverOpts::default(), - ) - .lookup_ip(url.clone()) - .await - { - // TODO: ugly - Ok(lookup) => { - let iter = lookup.iter(); - for ip in iter { - match ( - address.iter().next().expect("protocol to exist"), - ip.is_ipv4(), - ) { - (Protocol::Dns(_), true) - | (Protocol::Dns4(_), true) - | (Protocol::Dns6(_), false) => { - tracing::trace!( - target: LOG_TARGET, - ?address, - ?ip, - "address resolved", - ); - - return Ok(SocketAddr::new(ip, port)); - } - _ => {} - } - } - - Err(Error::Unknown) - } - Err(_) => Err(Error::Unknown), - } - }; - - match tokio::time::timeout(connection_open_timeout, future).await { - Err(_) => return Err(Error::Timeout), - Ok(Err(error)) => return Err(error), - Ok(Ok(address)) => address, - } - } - }; + let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?; + let remote_address = + match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await { + Err(_) => return Err(Error::Timeout), + Ok(Err(error)) => return Err(error), + Ok(Ok(address)) => address, + }; let domain = match remote_address.is_ipv4() { true => Domain::IPV4,