Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport/common: Share DNS lookups between TCP and WebSocket #151

Merged
merged 5 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 93 additions & 13 deletions src/transport/common/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ use multiaddr::{Multiaddr, Protocol};
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig};
use socket2::{Domain, Socket, Type};
use tokio::net::{TcpListener as TokioTcpListener, TcpStream};
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};

use std::{
io,
Expand All @@ -46,7 +50,73 @@ pub enum AddressType {
Socket(SocketAddr),

/// DNS address.
Dns(String, u16),
Dns {
address: String,
port: u16,
dns_type: DnsType,
},
}

/// The DNS type of the address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsType {
/// DNS supports both IPv4 and IPv6.
Dns,
/// DNS supports only IPv4.
Dns4,
/// DNS supports only IPv6.
Dns6,
}

impl AddressType {
/// Resolve the address to a concrete IP.
pub async fn lookup_ip(self) -> crate::Result<SocketAddr> {
let (url, port, dns_type) = match self {
// We already have the IP address.
AddressType::Socket(address) => return Ok(address),
AddressType::Dns {
address,
port,
dns_type,
} => (address, port, dns_type),
};

let lookup =
match TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default())
.lookup_ip(url.clone())
.await
{
Ok(lookup) => lookup,
Err(error) => {
tracing::debug!(
target: LOG_TARGET,
?error,
"failed to resolve DNS address `{}`",
url
);

return Err(Error::Other(format!("Failed to resolve DNS address {url}")));
}
};

let Some(ip) = lookup.iter().find(|ip| match dns_type {
DnsType::Dns => true,
DnsType::Dns4 => ip.is_ipv4(),
DnsType::Dns6 => ip.is_ipv6(),
}) else {
tracing::debug!(
target: LOG_TARGET,
"Multiaddr DNS type does not match IP version `{}`",
url
);

return Err(Error::Other(format!(
"Miss-match in DNS address IP version {url}"
)));
};

Ok(SocketAddr::new(ip, port))
}
}

/// Local addresses to use for outbound connections.
Expand Down Expand Up @@ -167,7 +237,7 @@ impl SocketListener {
.into_iter()
.filter_map(|address| {
let address = match T::multiaddr_to_socket_address(&address).ok()?.0 {
AddressType::Dns(address, port) => {
AddressType::Dns { address, port, .. } => {
tracing::debug!(
target: LOG_TARGET,
?address,
Expand Down Expand Up @@ -286,10 +356,14 @@ fn multiaddr_to_socket_address(
tracing::trace!(target: LOG_TARGET, ?address, "parse multi address");

let mut iter = address.iter();
let socket_address = match iter.next() {
Some(Protocol::Ip6(address)) => match iter.next() {
Some(Protocol::Tcp(port)) =>
AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)),
// Small helper to handle DNS types.
let handle_dns_type =
|address: String, dns_type: DnsType, protocol: Option<Protocol>| match protocol {
Some(Protocol::Tcp(port)) => Ok(AddressType::Dns {
address,
port,
dns_type,
}),
protocol => {
tracing::error!(
target: LOG_TARGET,
Expand All @@ -298,10 +372,12 @@ fn multiaddr_to_socket_address(
);
return Err(Error::AddressError(AddressError::InvalidProtocol));
}
},
Some(Protocol::Ip4(address)) => match iter.next() {
};

let socket_address = match iter.next() {
Some(Protocol::Ip6(address)) => match iter.next() {
Some(Protocol::Tcp(port)) =>
AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)),
AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)),
protocol => {
tracing::error!(
target: LOG_TARGET,
Expand All @@ -311,10 +387,9 @@ fn multiaddr_to_socket_address(
return Err(Error::AddressError(AddressError::InvalidProtocol));
}
},
Some(Protocol::Dns(address))
| Some(Protocol::Dns4(address))
| Some(Protocol::Dns6(address)) => match iter.next() {
Some(Protocol::Tcp(port)) => AddressType::Dns(address.to_string(), port),
Some(Protocol::Ip4(address)) => match iter.next() {
Some(Protocol::Tcp(port)) =>
AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)),
protocol => {
tracing::error!(
target: LOG_TARGET,
Expand All @@ -324,6 +399,11 @@ fn multiaddr_to_socket_address(
return Err(Error::AddressError(AddressError::InvalidProtocol));
}
},
Some(Protocol::Dns(address)) => handle_dns_type(address.into(), DnsType::Dns, iter.next())?,
Some(Protocol::Dns4(address)) =>
handle_dns_type(address.into(), DnsType::Dns4, iter.next())?,
Some(Protocol::Dns6(address)) =>
handle_dns_type(address.into(), DnsType::Dns6, iter.next())?,
protocol => {
tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol");
return Err(Error::AddressError(AddressError::InvalidProtocol));
Expand Down
24 changes: 20 additions & 4 deletions src/transport/tcp/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ use crate::{
multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version},
protocol::{Direction, Permit, ProtocolCommand, ProtocolSet},
substream,
transport::{common::listener::AddressType, tcp::substream::Substream, Endpoint},
transport::{
common::listener::{AddressType, DnsType},
tcp::substream::Substream,
Endpoint,
},
types::{protocol::ProtocolName, ConnectionId, SubstreamId},
BandwidthSink, PeerId,
};
Expand Down Expand Up @@ -455,9 +459,21 @@ impl TcpConnection {
AddressType::Socket(address) => Multiaddr::empty()
.with(Protocol::from(address.ip()))
.with(Protocol::Tcp(address.port())),
AddressType::Dns(address, port) => Multiaddr::empty()
.with(Protocol::Dns(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
AddressType::Dns {
address,
port,
dns_type,
} => match dns_type {
DnsType::Dns => Multiaddr::empty()
.with(Protocol::Dns(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
DnsType::Dns4 => Multiaddr::empty()
.with(Protocol::Dns4(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
DnsType::Dns6 => Multiaddr::empty()
.with(Protocol::Dns6(Cow::Owned(address)))
.with(Protocol::Tcp(port)),
},
Comment on lines +462 to +476
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we had another miss-match, where regardless of the initial DNS version we'd accept both ipv4 or ipv6

};
let endpoint = match role {
Role::Dialer => Endpoint::dialer(address, connection_id),
Expand Down
63 changes: 8 additions & 55 deletions src/transport/tcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
config::Role,
error::Error,
transport::{
common::listener::{AddressType, DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
common::listener::{DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
manager::TransportHandle,
tcp::{
config::Config,
Expand All @@ -40,13 +40,9 @@ use futures::{
future::BoxFuture,
stream::{FuturesUnordered, Stream, StreamExt},
};
use multiaddr::{Multiaddr, Protocol};
use multiaddr::Multiaddr;
use socket2::{Domain, Socket, Type};
use tokio::net::TcpStream;
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};

use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -139,55 +135,12 @@ impl TcpTransport {
nodelay: bool,
) -> crate::Result<(Multiaddr, TcpStream)> {
let (socket_address, _) = TcpAddress::multiaddr_to_socket_address(&address)?;
let remote_address = match socket_address {
AddressType::Socket(address) => address,
AddressType::Dns(url, port) => {
let address = address.clone();
let future = async move {
match TokioAsyncResolver::tokio(
ResolverConfig::default(),
ResolverOpts::default(),
)
.lookup_ip(url.clone())
.await
{
// TODO: ugly
Ok(lookup) => {
let iter = lookup.iter();
for ip in iter {
match (
address.iter().next().expect("protocol to exist"),
ip.is_ipv4(),
) {
(Protocol::Dns(_), true)
| (Protocol::Dns4(_), true)
| (Protocol::Dns6(_), false) => {
tracing::trace!(
target: LOG_TARGET,
?address,
?ip,
"address resolved",
);

return Ok(SocketAddr::new(ip, port));
}
_ => {}
}
}

Err(Error::Unknown)
}
Err(_) => Err(Error::Unknown),
}
};

match tokio::time::timeout(connection_open_timeout, future).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
}
}
};
let remote_address =
match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
};

let domain = match remote_address.is_ipv4() {
true => Domain::IPV4,
Expand Down
67 changes: 9 additions & 58 deletions src/transport/websocket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ use crate::{
config::Role,
error::{AddressError, Error},
transport::{
common::listener::{
AddressType, DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress,
},
common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress},
manager::TransportHandle,
websocket::{
config::Config,
Expand All @@ -43,15 +41,11 @@ use multiaddr::{Multiaddr, Protocol};
use socket2::{Domain, Socket, Type};
use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};

use url::Url;

use std::{
collections::{HashMap, HashSet},
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
Expand Down Expand Up @@ -186,57 +180,14 @@ impl WebSocketTransport {
nodelay: bool,
) -> crate::Result<(Multiaddr, WebSocketStream<MaybeTlsStream<TcpStream>>)> {
let (url, _) = Self::multiaddr_into_url(address.clone())?;
let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;

let remote_address = match socket_address {
AddressType::Socket(address) => address,
AddressType::Dns(url, port) => {
let address = address.clone();
let future = async move {
match TokioAsyncResolver::tokio(
ResolverConfig::default(),
ResolverOpts::default(),
)
.lookup_ip(url.clone())
.await
{
// TODO: ugly
Ok(lookup) => {
let iter = lookup.iter();
for ip in iter {
match (
address.iter().next().expect("protocol to exist"),
ip.is_ipv4(),
) {
(Protocol::Dns(_), true)
| (Protocol::Dns4(_), true)
| (Protocol::Dns6(_), false) => {
tracing::trace!(
target: LOG_TARGET,
?address,
?ip,
"address resolved",
);

return Ok(SocketAddr::new(ip, port));
}
_ => {}
}
}

Err(Error::Unknown)
}
Err(_) => Err(Error::Unknown),
}
};

match tokio::time::timeout(connection_open_timeout, future).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
}
}
};
let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;
let remote_address =
match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await {
Err(_) => return Err(Error::Timeout),
Ok(Err(error)) => return Err(error),
Ok(Ok(address)) => address,
};

let domain = match remote_address.is_ipv4() {
true => Domain::IPV4,
Expand Down