Skip to content

Commit

Permalink
refactor(tcp): reducing branching in Transport::create_socket
Browse files Browse the repository at this point in the history
Following #4289 (comment), hereby is the PR to also improve the `create_socket`  using [`for_addr`](https://docs.rs/socket2/latest/socket2/struct.Domain.html#method.for_address). We also add a test for listening on IPv4 and IPv6 separately.

Pull-Request: #4328.
  • Loading branch information
jxs authored Aug 18, 2023
1 parent 08292c5 commit bf7fe68
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 11 deletions.
8 changes: 6 additions & 2 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,17 +926,21 @@ mod tests {
let keypair = libp2p_identity::Keypair::generate_ed25519();
let config = Config::new(&keypair);
let mut transport = crate::tokio::Transport::new(config);
let port = {
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
socket.local_addr().unwrap().port()
};

transport
.listen_on(
ListenerId::next(),
"/ip4/0.0.0.0/udp/4001/quic-v1".parse().unwrap(),
format!("/ip4/0.0.0.0/udp/{port}/quic-v1").parse().unwrap(),
)
.unwrap();
transport
.listen_on(
ListenerId::next(),
"/ip6/::/udp/4001/quic-v1".parse().unwrap(),
format!("/ip6/::/udp/{port}/quic-v1").parse().unwrap(),
)
.unwrap();
}
Expand Down
55 changes: 46 additions & 9 deletions transports/tcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,12 @@ where
}
}

fn create_socket(&self, socket_addr: &SocketAddr) -> io::Result<Socket> {
let domain = if socket_addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<Socket> {
let socket = Socket::new(
Domain::for_address(socket_addr),
Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
if socket_addr.is_ipv6() {
socket.set_only_v6(true)?;
}
Expand All @@ -375,7 +374,7 @@ where
id: ListenerId,
socket_addr: SocketAddr,
) -> io::Result<ListenStream<T>> {
let socket = self.create_socket(&socket_addr)?;
let socket = self.create_socket(socket_addr)?;
socket.bind(&socket_addr.into())?;
socket.listen(self.config.backlog as _)?;
socket.set_nonblocking(true)?;
Expand Down Expand Up @@ -476,7 +475,7 @@ where
log::debug!("dialing {}", socket_addr);

let socket = self
.create_socket(&socket_addr)
.create_socket(socket_addr)
.map_err(TransportError::Other)?;

if let Some(addr) = self.port_reuse.local_dial_addr(&socket_addr.ip()) {
Expand Down Expand Up @@ -1329,4 +1328,42 @@ mod tests {
assert!(rt.block_on(cycle_listeners::<tokio::Tcp>()));
}
}

#[test]
fn test_listens_ipv4_ipv6_separately() {
fn test<T: Provider>() {
let port = {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
listener.local_addr().unwrap().port()
};
let mut tcp = Transport::<T>::default().boxed();
let listener_id = ListenerId::next();
tcp.listen_on(
listener_id,
format!("/ip4/0.0.0.0/tcp/{port}").parse().unwrap(),
)
.unwrap();
tcp.listen_on(
ListenerId::next(),
format!("/ip6/::/tcp/{port}").parse().unwrap(),
)
.unwrap();
}
#[cfg(feature = "async-io")]
{
async_std::task::block_on(async {
test::<async_io::Tcp>();
})
}
#[cfg(feature = "tokio")]
{
let rt = ::tokio::runtime::Builder::new_current_thread()
.enable_io()
.build()
.unwrap();
rt.block_on(async {
test::<async_io::Tcp>();
});
}
}
}

0 comments on commit bf7fe68

Please sign in to comment.