From d91465d86feb7ea47d80ad29505607b0a9205a10 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Sat, 16 Dec 2023 00:52:53 +0100 Subject: [PATCH] add recvfrom support This extension of the kernel interface solves partly issue hermit-os/kernel#967 --- src/fd/mod.rs | 17 ++++ src/fd/socket/mod.rs | 12 +++ src/fd/socket/tcp.rs | 38 ++++++++ src/fd/socket/udp.rs | 207 +++++++++++++++++++++++++++++++++++++++++++ src/syscalls/net.rs | 24 +++++ 5 files changed, 298 insertions(+) diff --git a/src/fd/mod.rs b/src/fd/mod.rs index d357d14c18..4ba4a7eb3f 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -302,6 +302,23 @@ pub trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { -EINVAL } + /// receive a message from a socket + /// + /// If `address` is not a null pointer, the source address of the message is filled in. The + /// `address_len` argument is a value-result argument, initialized to the size + /// of the buffer associated with address, and modified on return to + /// indicate the actual size of the address stored there. + #[cfg(all(any(feature = "tcp", feature = "udp"), not(feature = "newlib")))] + fn recvfrom( + &self, + _buffer: *mut u8, + _len: usize, + _address: *mut sockaddr, + _address_len: *mut socklen_t, + ) -> isize { + (-ENOSYS).try_into().unwrap() + } + /// shut down part of a full-duplex connection #[cfg(all(any(feature = "tcp", feature = "udp"), not(feature = "newlib")))] fn shutdown(&self, _how: i32) -> i32 { diff --git a/src/fd/socket/mod.rs b/src/fd/socket/mod.rs index 797f1410c7..67d38d7070 100644 --- a/src/fd/socket/mod.rs +++ b/src/fd/socket/mod.rs @@ -186,3 +186,15 @@ pub extern "C" fn __sys_recv(fd: i32, buf: *mut u8, len: usize) -> isize { let obj = get_object(fd); obj.map_or_else(|e| e as isize, |v| (*v).read(buf, len)) } + +pub extern "C" fn __sys_recvfrom( + fd: i32, + buf: *mut u8, + len: usize, + _flags: i32, + addr: *mut sockaddr, + addr_len: *mut socklen_t, +) -> isize { + let obj = get_object(fd); + obj.map_or_else(|e| e as isize, |v| (*v).recvfrom(buf, len, addr, addr_len)) +} diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index 8d464f1dd7..8eac47df2b 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -571,6 +571,25 @@ impl ObjectInterface for Socket { fn ioctl(&self, cmd: i32, argp: *mut c_void) -> i32 { self.ioctl(cmd, argp) } + + fn recvfrom( + &self, + buf: *mut u8, + len: usize, + address: *mut sockaddr, + address_len: *mut socklen_t, + ) -> isize { + let nbytes = self.read(buf, len); + if nbytes >= 0 && !address.is_null() { + let result = self.getpeername(address, address_len); + + if result < 0 { + return result.try_into().unwrap(); + } + } + + nbytes + } } impl ObjectInterface for Socket { @@ -720,4 +739,23 @@ impl ObjectInterface for Socket { fn ioctl(&self, cmd: i32, argp: *mut c_void) -> i32 { self.ioctl(cmd, argp) } + + fn recvfrom( + &self, + buf: *mut u8, + len: usize, + address: *mut sockaddr, + address_len: *mut socklen_t, + ) -> isize { + let nbytes = self.read(buf, len); + if nbytes >= 0 && !address.is_null() { + let result = self.getpeername(address, address_len); + + if result < 0 { + return result.try_into().unwrap(); + } + } + + nbytes + } } diff --git a/src/fd/socket/udp.rs b/src/fd/socket/udp.rs index 4fa86e4fe8..2137c28c48 100644 --- a/src/fd/socket/udp.rs +++ b/src/fd/socket/udp.rs @@ -8,6 +8,7 @@ use core::task::Poll; use crossbeam_utils::atomic::AtomicCell; use smoltcp::socket::udp; +use smoltcp::socket::udp::UdpMetadata; use smoltcp::time::Duration; use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint}; @@ -91,6 +92,38 @@ impl Socket { .await } + async fn async_recvfrom(&self, buffer: &mut [u8]) -> Result<(isize, UdpMetadata), i32> { + future::poll_fn(|cx| { + self.with(|socket| { + if socket.is_open() { + if socket.can_recv() { + match socket.recv_slice(buffer) { + Ok((len, meta)) => match self.endpoint.load() { + Some(ep) => { + if meta.endpoint == ep { + Poll::Ready(Ok((len.try_into().unwrap(), meta))) + } else { + buffer[..len].iter_mut().for_each(|x| *x = 0); + socket.register_recv_waker(cx.waker()); + Poll::Pending + } + } + None => Poll::Ready(Ok((len.try_into().unwrap(), meta))), + }, + _ => Poll::Ready(Err(-crate::errno::EIO)), + } + } else { + socket.register_recv_waker(cx.waker()); + Poll::Pending + } + } else { + Poll::Ready(Err(-crate::errno::EIO)) + } + }) + }) + .await + } + async fn async_write(&self, buffer: &[u8]) -> Result { let endpoint = self.endpoint.load(); if endpoint.is_none() { @@ -262,6 +295,93 @@ impl ObjectInterface for Socket { fn write(&self, buf: *const u8, len: usize) -> isize { self.write(buf, len) } + + fn recvfrom( + &self, + buf: *mut u8, + len: usize, + address: *mut sockaddr, + address_len: *mut socklen_t, + ) -> isize { + if !address_len.is_null() { + let len = unsafe { &mut *address_len }; + if *len < size_of::().try_into().unwrap() { + return (-EINVAL).try_into().unwrap(); + } + } + + if len == 0 { + return (-EINVAL).try_into().unwrap(); + } + + let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) }; + + if self.nonblocking.load(Ordering::Acquire) { + poll_on(self.async_recvfrom(slice), Some(Duration::ZERO)).map_or_else( + |x| { + if x == -ETIME { + (-EAGAIN).try_into().unwrap() + } else { + x.try_into().unwrap() + } + }, + |(x, meta)| { + let len = unsafe { &mut *address_len }; + if address.is_null() { + *len = 0; + } else { + let addr = unsafe { &mut *(address as *mut sockaddr_in) }; + addr.sin_port = meta.endpoint.port.to_be(); + if let IpAddress::Ipv4(ip) = meta.endpoint.addr { + addr.sin_addr.s_addr.copy_from_slice(ip.as_bytes()); + } + *len = size_of::().try_into().unwrap(); + } + x.try_into().unwrap() + }, + ) + } else { + poll_on(self.async_recvfrom(slice), Some(Duration::from_secs(2))).map_or_else( + |x| { + if x == -ETIME { + block_on(self.async_recvfrom(slice), None).map_or_else( + |x| x.try_into().unwrap(), + |(x, meta)| { + let len = unsafe { &mut *address_len }; + if address.is_null() { + *len = 0; + } else { + let addr = unsafe { &mut *(address as *mut sockaddr_in) }; + addr.sin_port = meta.endpoint.port.to_be(); + if let IpAddress::Ipv4(ip) = meta.endpoint.addr { + addr.sin_addr.s_addr.copy_from_slice(ip.as_bytes()); + } + *len = size_of::().try_into().unwrap(); + } + x.try_into().unwrap() + }, + ) + } else { + x.try_into().unwrap() + } + }, + |(x, meta)| { + let len = unsafe { &mut *address_len }; + if address.is_null() { + *len = 0; + } else { + let addr = unsafe { &mut *(address as *mut sockaddr_in) }; + addr.sin_port = meta.endpoint.port.to_be(); + if let IpAddress::Ipv4(ip) = meta.endpoint.addr { + addr.sin_addr.s_addr.copy_from_slice(ip.as_bytes()); + } + *len = size_of::().try_into().unwrap(); + } + x.try_into().unwrap() + }, + ) + } + } } impl ObjectInterface for Socket { @@ -329,6 +449,93 @@ impl ObjectInterface for Socket { self.read(buf, len) } + fn recvfrom( + &self, + buf: *mut u8, + len: usize, + address: *mut sockaddr, + address_len: *mut socklen_t, + ) -> isize { + if !address_len.is_null() { + let len = unsafe { &mut *address_len }; + if *len < size_of::().try_into().unwrap() { + return (-EINVAL).try_into().unwrap(); + } + } + + if len == 0 { + return (-EINVAL).try_into().unwrap(); + } + + let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) }; + + if self.nonblocking.load(Ordering::Acquire) { + poll_on(self.async_recvfrom(slice), Some(Duration::ZERO)).map_or_else( + |x| { + if x == -ETIME { + (-EAGAIN).try_into().unwrap() + } else { + x.try_into().unwrap() + } + }, + |(x, meta)| { + let len = unsafe { &mut *address_len }; + if address.is_null() { + *len = 0; + } else { + let addr = unsafe { &mut *(address as *mut sockaddr_in6) }; + addr.sin6_port = meta.endpoint.port.to_be(); + if let IpAddress::Ipv6(ip) = meta.endpoint.addr { + addr.sin6_addr.s6_addr.copy_from_slice(ip.as_bytes()); + } + *len = size_of::().try_into().unwrap(); + } + x.try_into().unwrap() + }, + ) + } else { + poll_on(self.async_recvfrom(slice), Some(Duration::from_secs(2))).map_or_else( + |x| { + if x == -ETIME { + block_on(self.async_recvfrom(slice), None).map_or_else( + |x| x.try_into().unwrap(), + |(x, meta)| { + let len = unsafe { &mut *address_len }; + if address.is_null() { + *len = 0; + } else { + let addr = unsafe { &mut *(address as *mut sockaddr_in6) }; + addr.sin6_port = meta.endpoint.port.to_be(); + if let IpAddress::Ipv6(ip) = meta.endpoint.addr { + addr.sin6_addr.s6_addr.copy_from_slice(ip.as_bytes()); + } + *len = size_of::().try_into().unwrap(); + } + x.try_into().unwrap() + }, + ) + } else { + x.try_into().unwrap() + } + }, + |(x, meta)| { + let len = unsafe { &mut *address_len }; + if address.is_null() { + *len = 0; + } else { + let addr = unsafe { &mut *(address as *mut sockaddr_in6) }; + addr.sin6_port = meta.endpoint.port.to_be(); + if let IpAddress::Ipv6(ip) = meta.endpoint.addr { + addr.sin6_addr.s6_addr.copy_from_slice(ip.as_bytes()); + } + *len = size_of::().try_into().unwrap(); + } + x.try_into().unwrap() + }, + ) + } + } + fn write(&self, buf: *const u8, len: usize) -> isize { self.write(buf, len) } diff --git a/src/syscalls/net.rs b/src/syscalls/net.rs index 64ce3d0c98..e80a62d85f 100644 --- a/src/syscalls/net.rs +++ b/src/syscalls/net.rs @@ -210,3 +210,27 @@ pub extern "C" fn sys_recv(fd: i32, buf: *mut u8, len: usize, flags: i32) -> isi (-crate::errno::EINVAL).try_into().unwrap() } } + +#[no_mangle] +pub extern "C" fn sys_sendto( + _s: i32, + _mem: *const c_void, + _len: usize, + _flags: i32, + _to: *const sockaddr, + _tolen: socklen_t, +) -> isize { + (-crate::errno::EINVAL).try_into().unwrap() +} + +#[no_mangle] +pub extern "C" fn sys_recvfrom( + socket: i32, + buf: *mut u8, + len: usize, + flags: i32, + addr: *mut sockaddr, + addrlen: *mut socklen_t, +) -> isize { + kernel_function!(__sys_recvfrom(socket, buf, len, flags, addr, addrlen)) +}