From a40be0857c7bf48e39f815417b0b5293cd8ed1aa Mon Sep 17 00:00:00 2001 From: Tyler Julian Date: Tue, 10 Jan 2017 19:11:56 -0800 Subject: [PATCH] libstd/net: Add `peek` APIs to UdpSocket and TcpStream These methods enable socket reads without side-effects. That is, repeated calls to peek() return identical data. This is accomplished by providing the POSIX flag MSG_PEEK to the underlying socket read operations. This also moves the current implementation of recv_from out of the platform-independent sys_common and into respective sys/windows and sys/unix implementations. This allows for more platform-dependent implementations. --- src/liblibc | 2 +- src/libstd/lib.rs | 1 + src/libstd/net/tcp.rs | 54 +++++++++++++++++++ src/libstd/net/udp.rs | 97 +++++++++++++++++++++++++++++++++++ src/libstd/sys/unix/net.rs | 45 ++++++++++++++-- src/libstd/sys/windows/c.rs | 1 + src/libstd/sys/windows/net.rs | 44 +++++++++++++++- src/libstd/sys_common/net.rs | 24 +++++---- 8 files changed, 251 insertions(+), 17 deletions(-) diff --git a/src/liblibc b/src/liblibc index 7d57bdcdbb565..cb7f66732175e 160000 --- a/src/liblibc +++ b/src/liblibc @@ -1 +1 @@ -Subproject commit 7d57bdcdbb56540f37afe5a934ce12d33a6ca7fc +Subproject commit cb7f66732175e6171587ed69656b7aae7dd2e6ec diff --git a/src/libstd/lib.rs b/src/libstd/lib.rs index 9557c520c5071..3c06409e3b18e 100644 --- a/src/libstd/lib.rs +++ b/src/libstd/lib.rs @@ -275,6 +275,7 @@ #![feature(oom)] #![feature(optin_builtin_traits)] #![feature(panic_unwind)] +#![feature(peek)] #![feature(placement_in_syntax)] #![feature(prelude_import)] #![feature(pub_restricted)] diff --git a/src/libstd/net/tcp.rs b/src/libstd/net/tcp.rs index ed1f08f9c9090..ba6160cc72331 100644 --- a/src/libstd/net/tcp.rs +++ b/src/libstd/net/tcp.rs @@ -296,6 +296,29 @@ impl TcpStream { self.0.write_timeout() } + /// Receives data on the socket from the remote adress to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// Successive calls return the same data. This is accomplished by passing + /// `MSG_PEEK` as a flag to the underlying `recv` system call. + /// + /// # Examples + /// + /// ```no_run + /// #![feature(peek)] + /// use std::net::TcpStream; + /// + /// let stream = TcpStream::connect("127.0.0.1:8000") + /// .expect("couldn't bind to address"); + /// let mut buf = [0; 10]; + /// let len = stream.peek(&mut buf).expect("peek failed"); + /// ``` + #[unstable(feature = "peek", issue = "38980")] + pub fn peek(&self, buf: &mut [u8]) -> io::Result { + self.0.peek(buf) + } + /// Sets the value of the `TCP_NODELAY` option on this socket. /// /// If set, this option disables the Nagle algorithm. This means that @@ -1405,4 +1428,35 @@ mod tests { Err(e) => panic!("unexpected error {}", e), } } + + #[test] + fn peek() { + each_ip(&mut |addr| { + let (txdone, rxdone) = channel(); + + let srv = t!(TcpListener::bind(&addr)); + let _t = thread::spawn(move|| { + let mut cl = t!(srv.accept()).0; + cl.write(&[1,3,3,7]).unwrap(); + t!(rxdone.recv()); + }); + + let mut c = t!(TcpStream::connect(&addr)); + let mut b = [0; 10]; + for _ in 1..3 { + let len = c.peek(&mut b).unwrap(); + assert_eq!(len, 4); + } + let len = c.read(&mut b).unwrap(); + assert_eq!(len, 4); + + t!(c.set_nonblocking(true)); + match c.peek(&mut b) { + Ok(_) => panic!("expected error"), + Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} + Err(e) => panic!("unexpected error {}", e), + } + t!(txdone.send(())); + }) + } } diff --git a/src/libstd/net/udp.rs b/src/libstd/net/udp.rs index f8a5ec0b3791e..2f28f475dc88b 100644 --- a/src/libstd/net/udp.rs +++ b/src/libstd/net/udp.rs @@ -83,6 +83,30 @@ impl UdpSocket { self.0.recv_from(buf) } + /// Receives data from the socket, without removing it from the queue. + /// + /// Successive calls return the same data. This is accomplished by passing + /// `MSG_PEEK` as a flag to the underlying `recvfrom` system call. + /// + /// On success, returns the number of bytes peeked and the address from + /// whence the data came. + /// + /// # Examples + /// + /// ```no_run + /// #![feature(peek)] + /// use std::net::UdpSocket; + /// + /// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address"); + /// let mut buf = [0; 10]; + /// let (number_of_bytes, src_addr) = socket.peek_from(&mut buf) + /// .expect("Didn't receive data"); + /// ``` + #[unstable(feature = "peek", issue = "38980")] + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.0.peek_from(buf) + } + /// Sends data on the socket to the given address. On success, returns the /// number of bytes written. /// @@ -579,6 +603,37 @@ impl UdpSocket { self.0.recv(buf) } + /// Receives data on the socket from the remote adress to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// Successive calls return the same data. This is accomplished by passing + /// `MSG_PEEK` as a flag to the underlying `recv` system call. + /// + /// # Errors + /// + /// This method will fail if the socket is not connected. The `connect` method + /// will connect this socket to a remote address. + /// + /// # Examples + /// + /// ```no_run + /// #![feature(peek)] + /// use std::net::UdpSocket; + /// + /// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address"); + /// socket.connect("127.0.0.1:8080").expect("connect function failed"); + /// let mut buf = [0; 10]; + /// match socket.peek(&mut buf) { + /// Ok(received) => println!("received {} bytes", received), + /// Err(e) => println!("peek function failed: {:?}", e), + /// } + /// ``` + #[unstable(feature = "peek", issue = "38980")] + pub fn peek(&self, buf: &mut [u8]) -> io::Result { + self.0.peek(buf) + } + /// Moves this UDP socket into or out of nonblocking mode. /// /// On Unix this corresponds to calling fcntl, and on Windows this @@ -869,6 +924,48 @@ mod tests { assert_eq!(b"hello world", &buf[..]); } + #[test] + fn connect_send_peek_recv() { + each_ip(&mut |addr, _| { + let socket = t!(UdpSocket::bind(&addr)); + t!(socket.connect(addr)); + + t!(socket.send(b"hello world")); + + for _ in 1..3 { + let mut buf = [0; 11]; + let size = t!(socket.peek(&mut buf)); + assert_eq!(b"hello world", &buf[..]); + assert_eq!(size, 11); + } + + let mut buf = [0; 11]; + let size = t!(socket.recv(&mut buf)); + assert_eq!(b"hello world", &buf[..]); + assert_eq!(size, 11); + }) + } + + #[test] + fn peek_from() { + each_ip(&mut |addr, _| { + let socket = t!(UdpSocket::bind(&addr)); + t!(socket.send_to(b"hello world", &addr)); + + for _ in 1..3 { + let mut buf = [0; 11]; + let (size, _) = t!(socket.peek_from(&mut buf)); + assert_eq!(b"hello world", &buf[..]); + assert_eq!(size, 11); + } + + let mut buf = [0; 11]; + let (size, _) = t!(socket.recv_from(&mut buf)); + assert_eq!(b"hello world", &buf[..]); + assert_eq!(size, 11); + }) + } + #[test] fn ttl() { let ttl = 100; diff --git a/src/libstd/sys/unix/net.rs b/src/libstd/sys/unix/net.rs index ad287bbec3889..5efddca110f05 100644 --- a/src/libstd/sys/unix/net.rs +++ b/src/libstd/sys/unix/net.rs @@ -10,12 +10,13 @@ use ffi::CStr; use io; -use libc::{self, c_int, size_t, sockaddr, socklen_t, EAI_SYSTEM}; +use libc::{self, c_int, c_void, size_t, sockaddr, socklen_t, EAI_SYSTEM, MSG_PEEK}; +use mem; use net::{SocketAddr, Shutdown}; use str; use sys::fd::FileDesc; use sys_common::{AsInner, FromInner, IntoInner}; -use sys_common::net::{getsockopt, setsockopt}; +use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr}; use time::Duration; pub use sys::{cvt, cvt_r}; @@ -155,8 +156,46 @@ impl Socket { self.0.duplicate().map(Socket) } + fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result { + let ret = cvt(unsafe { + libc::recv(self.0.raw(), + buf.as_mut_ptr() as *mut c_void, + buf.len(), + flags) + })?; + Ok(ret as usize) + } + pub fn read(&self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) + self.recv_with_flags(buf, 0) + } + + pub fn peek(&self, buf: &mut [u8]) -> io::Result { + self.recv_with_flags(buf, MSG_PEEK) + } + + fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int) + -> io::Result<(usize, SocketAddr)> { + let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t; + + let n = cvt(unsafe { + libc::recvfrom(self.0.raw(), + buf.as_mut_ptr() as *mut c_void, + buf.len(), + flags, + &mut storage as *mut _ as *mut _, + &mut addrlen) + })?; + Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?)) + } + + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.recv_from_with_flags(buf, 0) + } + + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.recv_from_with_flags(buf, MSG_PEEK) } pub fn read_to_end(&self, buf: &mut Vec) -> io::Result { diff --git a/src/libstd/sys/windows/c.rs b/src/libstd/sys/windows/c.rs index dc7b2fc9a6bab..9f03f5c9717fc 100644 --- a/src/libstd/sys/windows/c.rs +++ b/src/libstd/sys/windows/c.rs @@ -244,6 +244,7 @@ pub const IP_ADD_MEMBERSHIP: c_int = 12; pub const IP_DROP_MEMBERSHIP: c_int = 13; pub const IPV6_ADD_MEMBERSHIP: c_int = 12; pub const IPV6_DROP_MEMBERSHIP: c_int = 13; +pub const MSG_PEEK: c_int = 0x2; #[repr(C)] pub struct ip_mreq { diff --git a/src/libstd/sys/windows/net.rs b/src/libstd/sys/windows/net.rs index aca6994503ff8..adf6210d82e89 100644 --- a/src/libstd/sys/windows/net.rs +++ b/src/libstd/sys/windows/net.rs @@ -147,12 +147,12 @@ impl Socket { Ok(socket) } - pub fn read(&self, buf: &mut [u8]) -> io::Result { + fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result { // On unix when a socket is shut down all further reads return 0, so we // do the same on windows to map a shut down socket to returning EOF. let len = cmp::min(buf.len(), i32::max_value() as usize) as i32; unsafe { - match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, 0) { + match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) { -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0), -1 => Err(last_error()), n => Ok(n as usize) @@ -160,6 +160,46 @@ impl Socket { } } + pub fn read(&self, buf: &mut [u8]) -> io::Result { + self.recv_with_flags(buf, 0) + } + + pub fn peek(&self, buf: &mut [u8]) -> io::Result { + self.recv_with_flags(buf, c::MSG_PEEK) + } + + fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int) + -> io::Result<(usize, SocketAddr)> { + let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() }; + let mut addrlen = mem::size_of_val(&storage) as c::socklen_t; + let len = cmp::min(buf.len(), ::max_value() as usize) as wrlen_t; + + // On unix when a socket is shut down all further reads return 0, so we + // do the same on windows to map a shut down socket to returning EOF. + unsafe { + match c::recvfrom(self.0, + buf.as_mut_ptr() as *mut c_void, + len, + flags, + &mut storage as *mut _ as *mut _, + &mut addrlen) { + -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => { + Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?)) + }, + -1 => Err(last_error()), + n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)), + } + } + } + + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.recv_from_with_flags(buf, 0) + } + + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.recv_from_with_flags(buf, c::MSG_PEEK) + } + pub fn read_to_end(&self, buf: &mut Vec) -> io::Result { let mut me = self; (&mut me).read_to_end(buf) diff --git a/src/libstd/sys_common/net.rs b/src/libstd/sys_common/net.rs index 10ad61f4c800c..3cdeb51194575 100644 --- a/src/libstd/sys_common/net.rs +++ b/src/libstd/sys_common/net.rs @@ -91,7 +91,7 @@ fn sockname(f: F) -> io::Result } } -fn sockaddr_to_addr(storage: &c::sockaddr_storage, +pub fn sockaddr_to_addr(storage: &c::sockaddr_storage, len: usize) -> io::Result { match storage.ss_family as c_int { c::AF_INET => { @@ -222,6 +222,10 @@ impl TcpStream { self.inner.timeout(c::SO_SNDTIMEO) } + pub fn peek(&self, buf: &mut [u8]) -> io::Result { + self.inner.peek(buf) + } + pub fn read(&self, buf: &mut [u8]) -> io::Result { self.inner.read(buf) } @@ -441,17 +445,11 @@ impl UdpSocket { } pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - let mut storage: c::sockaddr_storage = unsafe { mem::zeroed() }; - let mut addrlen = mem::size_of_val(&storage) as c::socklen_t; - let len = cmp::min(buf.len(), ::max_value() as usize) as wrlen_t; + self.inner.recv_from(buf) + } - let n = cvt(unsafe { - c::recvfrom(*self.inner.as_inner(), - buf.as_mut_ptr() as *mut c_void, - len, 0, - &mut storage as *mut _ as *mut _, &mut addrlen) - })?; - Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?)) + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.inner.peek_from(buf) } pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result { @@ -578,6 +576,10 @@ impl UdpSocket { self.inner.read(buf) } + pub fn peek(&self, buf: &mut [u8]) -> io::Result { + self.inner.peek(buf) + } + pub fn send(&self, buf: &[u8]) -> io::Result { let len = cmp::min(buf.len(), ::max_value() as usize) as wrlen_t; let ret = cvt(unsafe {