Skip to content

Commit

Permalink
transport/common: Share DNS lookups between TCP and WebSocket (#151)
Browse files Browse the repository at this point in the history
Move the dns resolving to a dedicated common module and return specific
errors on failures.

Part of: #70

---------

Signed-off-by: Alexandru Vasile <[email protected]>
  • Loading branch information
lexnv authored Jun 13, 2024
1 parent cb21f17 commit e43b0b8
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 130 deletions.
106 changes: 93 additions & 13 deletions src/transport/common/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ use multiaddr::{Multiaddr, Protocol};
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig};
use socket2::{Domain, Socket, Type};
use tokio::net::{TcpListener as TokioTcpListener, TcpStream};
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};

use std::{
io,
Expand All @@ -46,7 +50,73 @@ pub enum AddressType {
Socket(SocketAddr),

/// DNS address.
Dns(String, u16),
Dns {
address: String,
port: u16,
dns_type: DnsType,
},
}

/// The DNS type of the address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsType {
/// DNS supports both IPv4 and IPv6.
Dns,
/// DNS supports only IPv4.
Dns4,
/// DNS supports only IPv6.
Dns6,
}

impl AddressType {
/// Resolve the address to a concrete IP.
pub async fn lookup_ip(self) -> crate::Result<SocketAddr> {
let (url, port, dns_type) = match self {
// We already have the IP address.
AddressType::Socket(address) => return Ok(address),
AddressType::Dns {
address,
port,
dns_type,
} => (address, port, dns_type),
};

let lookup =
match TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default())
.lookup_ip(url.clone())
.await
{
Ok(lookup) => lookup,
Err(error) => {
tracing::debug!(
target: LOG_TARGET,
?error,
"failed to resolve DNS address `{}`",
url
);

return Err(Error::Other(format!("Failed to resolve DNS address {url}")));
}
};

let Some(ip) = lookup.iter().find(|ip| match dns_type {
DnsType::Dns => true,
DnsType::Dns4 => ip.is_ipv4(),
DnsType::Dns6 => ip.is_ipv6(),
}) else {
tracing::debug!(
target: LOG_TARGET,
"Multiaddr DNS type does not match IP version `{}`",
url
);

return Err(Error::Other(format!(
"Miss-match in DNS address IP version {url}"
)));
};

Ok(SocketAddr::new(ip, port))
}
}

/// Local addresses to use for outbound connections.
Expand Down Expand Up @@ -167,7 +237,7 @@ impl SocketListener {
.into_iter()
.filter_map(|address| {
let address = match T::multiaddr_to_socket_address(&address).ok()?.0 {
AddressType::Dns(address, port) => {
AddressType::Dns { address, port, .. } => {
tracing::debug!(
target: LOG_TARGET,
?address,
Expand Down Expand Up @@ -286,10 +356,14 @@ fn multiaddr_to_socket_address(
tracing::trace!(target: LOG_TARGET, ?address, "parse multi address");

let mut iter = address.iter();
let socket_address = match iter.next() {
Some(Protocol::Ip6(address)) => match iter.next() {
Some(Protocol::Tcp(port)) =>
AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)),
// Small helper to handle DNS types.
let handle_dns_type =
|address: String, dns_type: DnsType, protocol: Option<Protocol>| match protocol {
Some(Protocol::Tcp(port)) => Ok(AddressType::Dns {
address,
port,
dns_type,
}),
protocol => {
tracing::error!(
target: LOG_TARGET,
Expand All @@ -298,10 +372,12 @@ fn multiaddr_to_socket_address(
);
return Err(Error::AddressError(AddressError::InvalidProtocol));
}
},
Some(Protocol::Ip4(address)) => match iter.next() {
};

let socket_address = match iter.next() {
Some(Protocol::Ip6(address)) => match iter.next() {
Some(Protocol::Tcp(port)) =>
AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)),
AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)),
protocol => {
tracing::error!(
target: LOG_TARGET,
Expand All @@ -311,10 +387,9 @@ fn multiaddr_to_socket_address(
return Err(Error::AddressError(AddressError::InvalidProtocol));
}
},
Some(Protocol::Dns(address))
| Some(Protocol::Dns4(address))
| Some(Protocol::Dns6(address)) => match iter.next() {
Some(Protocol::Tcp(port)) => AddressType::Dns(address.to_string(), port),
Some(Protocol::Ip4(address)) => match iter.next() {
Some(Protocol::Tcp(port)) =>
AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)),
protocol => {
tracing::error!(
target: LOG_TARGET,
Expand All @@ -324,6 +399,11 @@ fn multiaddr_to_socket_address(
return Err(Error::AddressError(AddressError::InvalidProtocol));
}
},
Some(Protocol::Dns(address)) => handle_dns_type(address.into(), DnsType::Dns, iter.next())?,
Some(Protocol::Dns4(address)) =>
handle_dns_type(address.into(), DnsType::Dns4, iter.next())?,
Some(Protocol::Dns6(address)) =>
handle_dns_type(address.into(), DnsType::Dns6, iter.next())?,
protocol => {
tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol");
return Err(Error::AddressError(AddressError::InvalidProtocol));
Expand Down
24 changes: 20 additions & 4 deletions src/transport/tcp/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ use crate::{
multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version},
protocol::{Direction, Permit, ProtocolCommand, ProtocolSet},
substream,
transport::{common::listener::AddressType, tcp::substream::Substream, Endpoint},
transport::{
common::listener::{AddressType, DnsType},
tcp::substream::Substream,
Endpoint,
},
types::{protocol::ProtocolName, ConnectionId, SubstreamId},
BandwidthSink, PeerId,
};
Expand Down Expand Up @@ -455,9 +459,21 @@ impl TcpConnection {
AddressType::Socket(address) => Multiaddr::empty()
.with(Protocol::from(address.ip()))
.with(Protocol::Tcp(address.port())),
AddressType::Dns(address, port) => Multiaddr::empty()
.with(Protocol::Dns(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
AddressType::Dns {
address,
port,
dns_type,
} => match dns_type {
DnsType::Dns => Multiaddr::empty()
.with(Protocol::Dns(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
DnsType::Dns4 => Multiaddr::empty()
.with(Protocol::Dns4(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
DnsType::Dns6 => Multiaddr::empty()
.with(Protocol::Dns6(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
},
};
let endpoint = match role {
Role::Dialer => Endpoint::dialer(address, connection_id),
Expand Down
63 changes: 8 additions & 55 deletions src/transport/tcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
config::Role,
error::Error,
transport::{
common::listener::{AddressType, DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
common::listener::{DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
manager::TransportHandle,
tcp::{
config::Config,
Expand All @@ -40,13 +40,9 @@ use futures::{
future::BoxFuture,
stream::{FuturesUnordered, Stream, StreamExt},
};
use multiaddr::{Multiaddr, Protocol};
use multiaddr::Multiaddr;
use socket2::{Domain, Socket, Type};
use tokio::net::TcpStream;
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};

use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -139,55 +135,12 @@ impl TcpTransport {
nodelay: bool,
) -> crate::Result<(Multiaddr, TcpStream)> {
let (socket_address, _) = TcpAddress::multiaddr_to_socket_address(&address)?;
let remote_address = match socket_address {
AddressType::Socket(address) => address,
AddressType::Dns(url, port) => {
let address = address.clone();
let future = async move {
match TokioAsyncResolver::tokio(
ResolverConfig::default(),
ResolverOpts::default(),
)
.lookup_ip(url.clone())
.await
{
// TODO: ugly
Ok(lookup) => {
let iter = lookup.iter();
for ip in iter {
match (
address.iter().next().expect("protocol to exist"),
ip.is_ipv4(),
) {
(Protocol::Dns(_), true)
| (Protocol::Dns4(_), true)
| (Protocol::Dns6(_), false) => {
tracing::trace!(
target: LOG_TARGET,
?address,
?ip,
"address resolved",
);

return Ok(SocketAddr::new(ip, port));
}
_ => {}
}
}

Err(Error::Unknown)
}
Err(_) => Err(Error::Unknown),
}
};

match tokio::time::timeout(connection_open_timeout, future).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
}
}
};
let remote_address =
match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
};

let domain = match remote_address.is_ipv4() {
true => Domain::IPV4,
Expand Down
67 changes: 9 additions & 58 deletions src/transport/websocket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ use crate::{
config::Role,
error::{AddressError, Error},
transport::{
common::listener::{
AddressType, DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress,
},
common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress},
manager::TransportHandle,
websocket::{
config::Config,
Expand All @@ -43,15 +41,11 @@ use multiaddr::{Multiaddr, Protocol};
use socket2::{Domain, Socket, Type};
use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};

use url::Url;

use std::{
collections::{HashMap, HashSet},
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
Expand Down Expand Up @@ -186,57 +180,14 @@ impl WebSocketTransport {
nodelay: bool,
) -> crate::Result<(Multiaddr, WebSocketStream<MaybeTlsStream<TcpStream>>)> {
let (url, _) = Self::multiaddr_into_url(address.clone())?;
let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;

let remote_address = match socket_address {
AddressType::Socket(address) => address,
AddressType::Dns(url, port) => {
let address = address.clone();
let future = async move {
match TokioAsyncResolver::tokio(
ResolverConfig::default(),
ResolverOpts::default(),
)
.lookup_ip(url.clone())
.await
{
// TODO: ugly
Ok(lookup) => {
let iter = lookup.iter();
for ip in iter {
match (
address.iter().next().expect("protocol to exist"),
ip.is_ipv4(),
) {
(Protocol::Dns(_), true)
| (Protocol::Dns4(_), true)
| (Protocol::Dns6(_), false) => {
tracing::trace!(
target: LOG_TARGET,
?address,
?ip,
"address resolved",
);

return Ok(SocketAddr::new(ip, port));
}
_ => {}
}
}

Err(Error::Unknown)
}
Err(_) => Err(Error::Unknown),
}
};

match tokio::time::timeout(connection_open_timeout, future).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
}
}
};
let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;
let remote_address =
match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
};

let domain = match remote_address.is_ipv4() {
true => Domain::IPV4,
Expand Down

0 comments on commit e43b0b8

Please sign in to comment.