diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index 51d5f2867e..13d7b6b098 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -162,57 +162,69 @@ pub fn test_socketpair() { } mod recvfrom { + use nix::Result; use nix::sys::socket::*; use std::thread; + use super::*; + + const MSG: &'static [u8] = b"Hello, World!"; + + fn sendrecv(rsock: RawFd, ssock: RawFd, f: F) -> Option + where F: Fn(RawFd, &[u8], MsgFlags) -> Result + Send + 'static + { + let mut buf: [u8; 13] = [0u8; 13]; + let mut l = 0; + let mut from = None; - #[test] - pub fn datagram() { - let msg = b"Hello, World!"; - let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Datagram, - None, SockFlag::empty()).unwrap(); let send_thread = thread::spawn(move || { let mut l = 0; - while l < std::mem::size_of_val(msg) { - let flags = MsgFlags::empty(); - l += send(fd1, &msg[l..], flags).unwrap(); + while l < std::mem::size_of_val(MSG) { + l += f(ssock, &MSG[l..], MsgFlags::empty()).unwrap(); } }); - let mut buf: [u8; 13] = [0u8; 13]; - let mut l = 0; - - while l < std::mem::size_of_val(msg) { - let (len, from) = recvfrom(fd2, &mut buf[l..]).unwrap(); + while l < std::mem::size_of_val(MSG) { + let (len, from_) = recvfrom(rsock, &mut buf[l..]).unwrap(); + from = from_; l += len; - assert_eq!(AddressFamily::Unix, from.unwrap().family()); } - assert_eq!(&buf, msg); + assert_eq!(&buf, MSG); send_thread.join().unwrap(); + from } #[test] pub fn stream() { - let msg = b"Hello, World!"; - let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, + let (fd2, fd1) = socketpair(AddressFamily::Unix, SockType::Stream, None, SockFlag::empty()).unwrap(); - let send_thread = thread::spawn(move || { - let mut l = 0; - while l < std::mem::size_of_val(msg) { - let flags = MsgFlags::empty(); - l += send(fd1, &msg[l..], flags).unwrap(); - } + // Ignore from for stream sockets + let _ = sendrecv(fd1, fd2, |s, m, flags| { + send(s, m, flags) }); + } - let mut buf: [u8; 13] = [0u8; 13]; - let mut l = 0; - - while l < std::mem::size_of_val(msg) { - let (len, _from) = recvfrom(fd2, &mut buf[l..]).unwrap(); - l += len; - // Ignore _from for stream sockets - } - assert_eq!(&buf, msg); - send_thread.join().unwrap(); + #[test] + pub fn udp() { + let std_sa = SocketAddr::from_str("127.0.0.1:6789").unwrap(); + let inet_addr = InetAddr::from_std(&std_sa); + let sock_addr = SockAddr::new_inet(inet_addr); + let rsock = socket(AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None + ).unwrap(); + bind(rsock, &sock_addr).unwrap(); + let ssock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + ).expect("send socket failed"); + let from = sendrecv(rsock, ssock, move |s, m, flags| { + sendto(s, m, &sock_addr, flags) + }); + // UDP sockets should set the from address + assert_eq!(AddressFamily::Inet, from.unwrap().family()); } }