Skip to content

Commit

Permalink
add recvfrom support
Browse files Browse the repository at this point in the history
This extension of the kernel interface solves partly issue
hermit-os#967
  • Loading branch information
stlankes committed Jul 22, 2024
1 parent a011b97 commit d91465d
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/fd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,23 @@ pub trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone {
-EINVAL
}

/// receive a message from a socket
///
/// If `address` is not a null pointer, the source address of the message is filled in. The
/// `address_len` argument is a value-result argument, initialized to the size
/// of the buffer associated with address, and modified on return to
/// indicate the actual size of the address stored there.
#[cfg(all(any(feature = "tcp", feature = "udp"), not(feature = "newlib")))]
fn recvfrom(
&self,
_buffer: *mut u8,
_len: usize,
_address: *mut sockaddr,
_address_len: *mut socklen_t,
) -> isize {
(-ENOSYS).try_into().unwrap()
}

/// shut down part of a full-duplex connection
#[cfg(all(any(feature = "tcp", feature = "udp"), not(feature = "newlib")))]
fn shutdown(&self, _how: i32) -> i32 {
Expand Down
12 changes: 12 additions & 0 deletions src/fd/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,15 @@ pub extern "C" fn __sys_recv(fd: i32, buf: *mut u8, len: usize) -> isize {
let obj = get_object(fd);
obj.map_or_else(|e| e as isize, |v| (*v).read(buf, len))
}

pub extern "C" fn __sys_recvfrom(
fd: i32,
buf: *mut u8,
len: usize,
_flags: i32,
addr: *mut sockaddr,
addr_len: *mut socklen_t,
) -> isize {
let obj = get_object(fd);
obj.map_or_else(|e| e as isize, |v| (*v).recvfrom(buf, len, addr, addr_len))
}
38 changes: 38 additions & 0 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,25 @@ impl ObjectInterface for Socket<IPv4> {
fn ioctl(&self, cmd: i32, argp: *mut c_void) -> i32 {
self.ioctl(cmd, argp)
}

fn recvfrom(
&self,
buf: *mut u8,
len: usize,
address: *mut sockaddr,
address_len: *mut socklen_t,
) -> isize {
let nbytes = self.read(buf, len);
if nbytes >= 0 && !address.is_null() {
let result = self.getpeername(address, address_len);

if result < 0 {
return result.try_into().unwrap();
}
}

nbytes
}
}

impl ObjectInterface for Socket<IPv6> {
Expand Down Expand Up @@ -720,4 +739,23 @@ impl ObjectInterface for Socket<IPv6> {
fn ioctl(&self, cmd: i32, argp: *mut c_void) -> i32 {
self.ioctl(cmd, argp)
}

fn recvfrom(
&self,
buf: *mut u8,
len: usize,
address: *mut sockaddr,
address_len: *mut socklen_t,
) -> isize {
let nbytes = self.read(buf, len);
if nbytes >= 0 && !address.is_null() {
let result = self.getpeername(address, address_len);

if result < 0 {
return result.try_into().unwrap();
}
}

nbytes
}
}
207 changes: 207 additions & 0 deletions src/fd/socket/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use core::task::Poll;

use crossbeam_utils::atomic::AtomicCell;
use smoltcp::socket::udp;
use smoltcp::socket::udp::UdpMetadata;
use smoltcp::time::Duration;
use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint};

Expand Down Expand Up @@ -91,6 +92,38 @@ impl<T> Socket<T> {
.await
}

async fn async_recvfrom(&self, buffer: &mut [u8]) -> Result<(isize, UdpMetadata), i32> {
future::poll_fn(|cx| {
self.with(|socket| {
if socket.is_open() {
if socket.can_recv() {
match socket.recv_slice(buffer) {
Ok((len, meta)) => match self.endpoint.load() {
Some(ep) => {
if meta.endpoint == ep {
Poll::Ready(Ok((len.try_into().unwrap(), meta)))
} else {
buffer[..len].iter_mut().for_each(|x| *x = 0);
socket.register_recv_waker(cx.waker());
Poll::Pending
}
}
None => Poll::Ready(Ok((len.try_into().unwrap(), meta))),
},
_ => Poll::Ready(Err(-crate::errno::EIO)),
}
} else {
socket.register_recv_waker(cx.waker());
Poll::Pending
}
} else {
Poll::Ready(Err(-crate::errno::EIO))
}
})
})
.await
}

async fn async_write(&self, buffer: &[u8]) -> Result<isize, i32> {
let endpoint = self.endpoint.load();
if endpoint.is_none() {
Expand Down Expand Up @@ -262,6 +295,93 @@ impl ObjectInterface for Socket<IPv4> {
fn write(&self, buf: *const u8, len: usize) -> isize {
self.write(buf, len)
}

fn recvfrom(
&self,
buf: *mut u8,
len: usize,
address: *mut sockaddr,
address_len: *mut socklen_t,
) -> isize {
if !address_len.is_null() {
let len = unsafe { &mut *address_len };
if *len < size_of::<sockaddr_in>().try_into().unwrap() {
return (-EINVAL).try_into().unwrap();
}
}

if len == 0 {
return (-EINVAL).try_into().unwrap();
}

let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) };

if self.nonblocking.load(Ordering::Acquire) {
poll_on(self.async_recvfrom(slice), Some(Duration::ZERO)).map_or_else(
|x| {
if x == -ETIME {
(-EAGAIN).try_into().unwrap()
} else {
x.try_into().unwrap()
}
},
|(x, meta)| {
let len = unsafe { &mut *address_len };
if address.is_null() {
*len = 0;
} else {
let addr = unsafe { &mut *(address as *mut sockaddr_in) };
addr.sin_port = meta.endpoint.port.to_be();
if let IpAddress::Ipv4(ip) = meta.endpoint.addr {
addr.sin_addr.s_addr.copy_from_slice(ip.as_bytes());
}
*len = size_of::<sockaddr_in>().try_into().unwrap();
}
x.try_into().unwrap()
},
)
} else {
poll_on(self.async_recvfrom(slice), Some(Duration::from_secs(2))).map_or_else(
|x| {
if x == -ETIME {
block_on(self.async_recvfrom(slice), None).map_or_else(
|x| x.try_into().unwrap(),
|(x, meta)| {
let len = unsafe { &mut *address_len };
if address.is_null() {
*len = 0;
} else {
let addr = unsafe { &mut *(address as *mut sockaddr_in) };
addr.sin_port = meta.endpoint.port.to_be();
if let IpAddress::Ipv4(ip) = meta.endpoint.addr {
addr.sin_addr.s_addr.copy_from_slice(ip.as_bytes());
}
*len = size_of::<sockaddr_in>().try_into().unwrap();
}
x.try_into().unwrap()
},
)
} else {
x.try_into().unwrap()
}
},
|(x, meta)| {
let len = unsafe { &mut *address_len };
if address.is_null() {
*len = 0;
} else {
let addr = unsafe { &mut *(address as *mut sockaddr_in) };
addr.sin_port = meta.endpoint.port.to_be();
if let IpAddress::Ipv4(ip) = meta.endpoint.addr {
addr.sin_addr.s_addr.copy_from_slice(ip.as_bytes());
}
*len = size_of::<sockaddr_in>().try_into().unwrap();
}
x.try_into().unwrap()
},
)
}
}
}

impl ObjectInterface for Socket<IPv6> {
Expand Down Expand Up @@ -329,6 +449,93 @@ impl ObjectInterface for Socket<IPv6> {
self.read(buf, len)
}

fn recvfrom(
&self,
buf: *mut u8,
len: usize,
address: *mut sockaddr,
address_len: *mut socklen_t,
) -> isize {
if !address_len.is_null() {
let len = unsafe { &mut *address_len };
if *len < size_of::<sockaddr_in6>().try_into().unwrap() {
return (-EINVAL).try_into().unwrap();
}
}

if len == 0 {
return (-EINVAL).try_into().unwrap();
}

let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) };

if self.nonblocking.load(Ordering::Acquire) {
poll_on(self.async_recvfrom(slice), Some(Duration::ZERO)).map_or_else(
|x| {
if x == -ETIME {
(-EAGAIN).try_into().unwrap()
} else {
x.try_into().unwrap()
}
},
|(x, meta)| {
let len = unsafe { &mut *address_len };
if address.is_null() {
*len = 0;
} else {
let addr = unsafe { &mut *(address as *mut sockaddr_in6) };
addr.sin6_port = meta.endpoint.port.to_be();
if let IpAddress::Ipv6(ip) = meta.endpoint.addr {
addr.sin6_addr.s6_addr.copy_from_slice(ip.as_bytes());
}
*len = size_of::<sockaddr_in6>().try_into().unwrap();
}
x.try_into().unwrap()
},
)
} else {
poll_on(self.async_recvfrom(slice), Some(Duration::from_secs(2))).map_or_else(
|x| {
if x == -ETIME {
block_on(self.async_recvfrom(slice), None).map_or_else(
|x| x.try_into().unwrap(),
|(x, meta)| {
let len = unsafe { &mut *address_len };
if address.is_null() {
*len = 0;
} else {
let addr = unsafe { &mut *(address as *mut sockaddr_in6) };
addr.sin6_port = meta.endpoint.port.to_be();
if let IpAddress::Ipv6(ip) = meta.endpoint.addr {
addr.sin6_addr.s6_addr.copy_from_slice(ip.as_bytes());
}
*len = size_of::<sockaddr_in6>().try_into().unwrap();
}
x.try_into().unwrap()
},
)
} else {
x.try_into().unwrap()
}
},
|(x, meta)| {
let len = unsafe { &mut *address_len };
if address.is_null() {
*len = 0;
} else {
let addr = unsafe { &mut *(address as *mut sockaddr_in6) };
addr.sin6_port = meta.endpoint.port.to_be();
if let IpAddress::Ipv6(ip) = meta.endpoint.addr {
addr.sin6_addr.s6_addr.copy_from_slice(ip.as_bytes());
}
*len = size_of::<sockaddr_in6>().try_into().unwrap();
}
x.try_into().unwrap()
},
)
}
}

fn write(&self, buf: *const u8, len: usize) -> isize {
self.write(buf, len)
}
Expand Down
24 changes: 24 additions & 0 deletions src/syscalls/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,27 @@ pub extern "C" fn sys_recv(fd: i32, buf: *mut u8, len: usize, flags: i32) -> isi
(-crate::errno::EINVAL).try_into().unwrap()
}
}

#[no_mangle]
pub extern "C" fn sys_sendto(
_s: i32,
_mem: *const c_void,
_len: usize,
_flags: i32,
_to: *const sockaddr,
_tolen: socklen_t,
) -> isize {
(-crate::errno::EINVAL).try_into().unwrap()
}

#[no_mangle]
pub extern "C" fn sys_recvfrom(
socket: i32,
buf: *mut u8,
len: usize,
flags: i32,
addr: *mut sockaddr,
addrlen: *mut socklen_t,
) -> isize {
kernel_function!(__sys_recvfrom(socket, buf, len, flags, addr, addrlen))
}

0 comments on commit d91465d

Please sign in to comment.