diff --git a/core/src/packet.rs b/core/src/packet.rs index 4fee75f5bd9675..e36223712c4091 100644 --- a/core/src/packet.rs +++ b/core/src/packet.rs @@ -9,7 +9,7 @@ use solana_metrics::inc_new_counter_debug; pub use solana_sdk::packet::{Meta, Packet, PACKET_DATA_SIZE}; use std::{io::Result, net::UdpSocket, time::Instant}; -pub fn recv_from(obj: &mut Packets, socket: &UdpSocket) -> Result { +pub fn recv_from(obj: &mut Packets, socket: &UdpSocket, max_wait_ms: usize) -> Result { let mut i = 0; //DOCUMENTED SIDE-EFFECT //Performance out of the IO without poll @@ -20,9 +20,11 @@ pub fn recv_from(obj: &mut Packets, socket: &UdpSocket) -> Result { socket.set_nonblocking(false)?; trace!("receiving on {}", socket.local_addr().unwrap()); let start = Instant::now(); - let mut total_size = 0; loop { - obj.packets.resize(i + NUM_RCVMMSGS, Packet::default()); + obj.packets.resize( + std::cmp::min(i + NUM_RCVMMSGS, PACKETS_PER_BATCH), + Packet::default(), + ); match recv_mmsg(socket, &mut obj.packets[i..]) { Err(_) if i > 0 => { if start.elapsed().as_millis() > 1 { @@ -33,16 +35,15 @@ pub fn recv_from(obj: &mut Packets, socket: &UdpSocket) -> Result { trace!("recv_from err {:?}", e); return Err(e); } - Ok((size, npkts)) => { + Ok((_, npkts)) => { if i == 0 { socket.set_nonblocking(true)?; } trace!("got {} packets", npkts); i += npkts; - total_size += size; // Try to batch into big enough buffers // will cause less re-shuffling later on. - if start.elapsed().as_millis() > 1 || total_size >= PACKETS_BATCH_SIZE { + if start.elapsed().as_millis() > max_wait_ms as u128 || i >= PACKETS_PER_BATCH { break; } } @@ -95,7 +96,7 @@ mod tests { } send_to(&p, &send_socket).unwrap(); - let recvd = recv_from(&mut p, &recv_socket).unwrap(); + let recvd = recv_from(&mut p, &recv_socket, 1).unwrap(); assert_eq!(recvd, p.packets.len()); @@ -127,4 +128,32 @@ mod tests { p2.data[0] = 4; assert!(p1 != p2); } + + #[test] + fn test_packet_resize() { + solana_logger::setup(); + let recv_socket = UdpSocket::bind("127.0.0.1:0").expect("bind"); + let addr = recv_socket.local_addr().unwrap(); + let send_socket = UdpSocket::bind("127.0.0.1:0").expect("bind"); + let mut p = Packets::default(); + p.packets.resize(PACKETS_PER_BATCH, Packet::default()); + + // Should only get PACKETS_PER_BATCH packets per iteration even + // if a lot more were sent, and regardless of packet size + for _ in 0..2 * PACKETS_PER_BATCH { + let mut p = Packets::default(); + p.packets.resize(1, Packet::default()); + for m in p.packets.iter_mut() { + m.meta.set_addr(&addr); + m.meta.size = 1; + } + send_to(&p, &send_socket).unwrap(); + } + + let recvd = recv_from(&mut p, &recv_socket, 100).unwrap(); + + // Check we only got PACKETS_PER_BATCH packets + assert_eq!(recvd, PACKETS_PER_BATCH); + assert_eq!(p.packets.capacity(), PACKETS_PER_BATCH); + } } diff --git a/core/src/retransmit_stage.rs b/core/src/retransmit_stage.rs index b007005555078b..1b66b48cb4f59e 100644 --- a/core/src/retransmit_stage.rs +++ b/core/src/retransmit_stage.rs @@ -331,7 +331,7 @@ mod tests { // it should send this over the sockets. retransmit_sender.send(packets).unwrap(); let mut packets = Packets::new(vec![]); - packet::recv_from(&mut packets, &me_retransmit).unwrap(); + packet::recv_from(&mut packets, &me_retransmit, 1).unwrap(); assert_eq!(packets.packets.len(), 1); assert_eq!(packets.packets[0].meta.repair, false); @@ -347,7 +347,7 @@ mod tests { let packets = Packets::new(vec![repair, Packet::default()]); retransmit_sender.send(packets).unwrap(); let mut packets = Packets::new(vec![]); - packet::recv_from(&mut packets, &me_retransmit).unwrap(); + packet::recv_from(&mut packets, &me_retransmit, 1).unwrap(); assert_eq!(packets.packets.len(), 1); assert_eq!(packets.packets[0].meta.repair, false); } diff --git a/core/src/streamer.rs b/core/src/streamer.rs index 83d183f45fa31f..a929c0e3d77ce5 100644 --- a/core/src/streamer.rs +++ b/core/src/streamer.rs @@ -49,7 +49,7 @@ fn recv_loop( if exit.load(Ordering::Relaxed) { return Ok(()); } - if let Ok(len) = packet::recv_from(&mut msgs, sock) { + if let Ok(len) = packet::recv_from(&mut msgs, sock, 1) { if len == NUM_RCVMMSGS { num_max_received += 1; } diff --git a/perf/src/cuda_runtime.rs b/perf/src/cuda_runtime.rs index ac094e4f7ec4c0..a4d146cd962d5c 100644 --- a/perf/src/cuda_runtime.rs +++ b/perf/src/cuda_runtime.rs @@ -151,6 +151,10 @@ impl PinnedVec { pub fn iter_mut(&mut self) -> PinnedIterMut { PinnedIterMut(self.x.iter_mut()) } + + pub fn capacity(&self) -> usize { + self.x.capacity() + } } impl<'a, T: Clone + Send + Sync + Default + Sized> IntoParallelIterator for &'a PinnedVec {