Skip to content

Commit

Permalink
feat: impl IO for some foreign types; rewrite readv/writev impl for s…
Browse files Browse the repository at this point in the history
…ome types (#313)

* feat: impl AsyncRead/WriteRent for some foreign types; rewrite readv/writev impl for some types

* test: stablize time related tests
  • Loading branch information
ihciah authored Oct 29, 2024
1 parent ead462c commit 8a8ad70
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 51 deletions.
2 changes: 0 additions & 2 deletions monoio-compat/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
//! For compat with tokio AsyncRead and AsyncWrite.
#![cfg_attr(feature = "unstable", feature(new_uninit))]

pub mod box_future;
mod buf;

Expand Down
47 changes: 21 additions & 26 deletions monoio/src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,44 +287,39 @@ mod tests {

#[test]
fn default_pool() {
let shared_pool = Box::new(DefaultThreadPool::new(3));
let shared_pool = Box::new(DefaultThreadPool::new(6));
let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
.attach_thread_pool(shared_pool)
.enable_timer()
.build()
.unwrap();
fn thread_sleep(s: &'static str) -> impl FnOnce() -> &'static str {
move || {
// Simulate a heavy computation.
std::thread::sleep(std::time::Duration::from_millis(500));
s
}
}
rt.block_on(async {
let begin = std::time::Instant::now();
let join1 = crate::spawn_blocking(|| {
// Simulate a heavy computation.
std::thread::sleep(std::time::Duration::from_millis(150));
"hello spawn_blocking1!".to_string()
});
let join2 = crate::spawn_blocking(|| {
// Simulate a heavy computation.
std::thread::sleep(std::time::Duration::from_millis(150));
"hello spawn_blocking2!".to_string()
});
let join3 = crate::spawn_blocking(|| {
// Simulate a heavy computation.
std::thread::sleep(std::time::Duration::from_millis(150));
"hello spawn_blocking3!".to_string()
});
let join4 = crate::spawn_blocking(|| {
// Simulate a heavy computation.
std::thread::sleep(std::time::Duration::from_millis(150));
"hello spawn_blocking4!".to_string()
});
let sleep_async = crate::time::sleep(std::time::Duration::from_millis(150));
let (result1, result2, result3, result4, _) =
crate::join!(join1, join2, join3, join4, sleep_async);
let join1 = crate::spawn_blocking(thread_sleep("hello spawn_blocking1!"));
let join2 = crate::spawn_blocking(thread_sleep("hello spawn_blocking2!"));
let join3 = crate::spawn_blocking(thread_sleep("hello spawn_blocking3!"));
let join4 = crate::spawn_blocking(thread_sleep("hello spawn_blocking4!"));
let join5 = crate::spawn_blocking(thread_sleep("hello spawn_blocking5!"));
let join6 = crate::spawn_blocking(thread_sleep("hello spawn_blocking6!"));
let sleep_async = crate::time::sleep(std::time::Duration::from_millis(500));
let (result1, result2, result3, result4, result5, result6, _) =
crate::join!(join1, join2, join3, join4, join5, join6, sleep_async);
let eps = begin.elapsed();
assert!(eps < std::time::Duration::from_millis(590));
assert!(eps >= std::time::Duration::from_millis(150));
assert!(eps < std::time::Duration::from_millis(3000));
assert!(eps >= std::time::Duration::from_millis(500));
assert_eq!(result1.unwrap(), "hello spawn_blocking1!");
assert_eq!(result2.unwrap(), "hello spawn_blocking2!");
assert_eq!(result3.unwrap(), "hello spawn_blocking3!");
assert_eq!(result4.unwrap(), "hello spawn_blocking4!");
assert_eq!(result5.unwrap(), "hello spawn_blocking5!");
assert_eq!(result6.unwrap(), "hello spawn_blocking6!");
});
}
}
6 changes: 6 additions & 0 deletions monoio/src/buf/io_buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ pub unsafe trait IoBuf: Unpin + 'static {
/// For `Vec`, this is identical to `len()`.
fn bytes_init(&self) -> usize;

/// Returns a slice of the buffer.
#[inline]
fn as_slice(&self) -> &[u8] {
unsafe { core::slice::from_raw_parts(self.read_ptr(), self.bytes_init()) }
}

/// Returns a view of the buffer with the specified range.
#[inline]
fn slice(self, range: impl ops::RangeBounds<usize>) -> Slice<Self>
Expand Down
5 changes: 3 additions & 2 deletions monoio/src/fs/file/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,9 @@ impl File {
Ok(())
}

async fn flush(&mut self) -> io::Result<()> {
Ok(())
#[inline]
fn flush(&mut self) -> impl Future<Output = io::Result<()>> {
std::future::ready(Ok(()))
}

/// Closes the file.
Expand Down
101 changes: 84 additions & 17 deletions monoio/src/io/async_read_rent.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::future::Future;
use std::{future::Future, io::Cursor};

use crate::{
buf::{IoBufMut, IoVecBufMut, RawBuf},
buf::{IoBufMut, IoVecBufMut},
BufResult,
};

Expand Down Expand Up @@ -49,6 +49,17 @@ pub trait AsyncReadRentAt {
) -> impl Future<Output = BufResult<usize, T>>;
}

impl<A: ?Sized + AsyncReadRentAt> AsyncReadRentAt for &mut A {
#[inline]
fn read_at<T: IoBufMut>(
&mut self,
buf: T,
pos: usize,
) -> impl Future<Output = BufResult<usize, T>> {
(**self).read_at(buf, pos)
}
}

impl<A: ?Sized + AsyncReadRent> AsyncReadRent for &mut A {
#[inline]
fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
Expand All @@ -70,29 +81,85 @@ impl AsyncReadRent for &[u8] {
buf.set_init(amt);
}
*self = b;
async move { (Ok(amt), buf) }
std::future::ready((Ok(amt), buf))
}

fn readv<T: IoVecBufMut>(&mut self, mut buf: T) -> impl Future<Output = BufResult<usize, T>> {
// # Safety
// We do it in pure sync way.
let n = match unsafe { RawBuf::new_from_iovec_mut(&mut buf) } {
Some(mut raw_buf) => {
// copy from read to avoid await
let amt = std::cmp::min(self.len(), raw_buf.bytes_total());
let mut sum = 0;
{
#[cfg(windows)]
let buf_slice = unsafe {
std::slice::from_raw_parts_mut(buf.write_wsabuf_ptr(), buf.write_wsabuf_len())
};
#[cfg(unix)]
let buf_slice = unsafe {
std::slice::from_raw_parts_mut(buf.write_iovec_ptr(), buf.write_iovec_len())
};
for buf in buf_slice {
#[cfg(windows)]
let amt = std::cmp::min(self.len(), buf.len as usize);
#[cfg(unix)]
let amt = std::cmp::min(self.len(), buf.iov_len);

let (a, b) = self.split_at(amt);
// # Safety
// The pointer is valid.
unsafe {
raw_buf
.write_ptr()
#[cfg(windows)]
buf.buf
.cast::<u8>()
.copy_from_nonoverlapping(a.as_ptr(), amt);
#[cfg(unix)]
buf.iov_base
.cast::<u8>()
.copy_from_nonoverlapping(a.as_ptr(), amt);
raw_buf.set_init(amt);
}
*self = b;
amt
sum += amt;

if self.is_empty() {
break;
}
}
None => 0,
};
unsafe { buf.set_init(n) };
async move { (Ok(n), buf) }
}

unsafe { buf.set_init(sum) };
std::future::ready((Ok(sum), buf))
}
}

impl<T: AsRef<[u8]>> AsyncReadRent for Cursor<T> {
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
let pos = self.position();
let slice: &[u8] = (*self).get_ref().as_ref();

if pos > slice.len() as u64 {
return (Ok(0), buf);
}

(&slice[pos as usize..]).read(buf).await
}

async fn readv<B: IoVecBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
let pos = self.position();
let slice: &[u8] = (*self).get_ref().as_ref();

if pos > slice.len() as u64 {
return (Ok(0), buf);
}

(&slice[pos as usize..]).readv(buf).await
}
}

impl<T: ?Sized + AsyncReadRent> AsyncReadRent for Box<T> {
#[inline]
fn read<B: IoBufMut>(&mut self, buf: B) -> impl Future<Output = BufResult<usize, B>> {
(**self).read(buf)
}

#[inline]
fn readv<B: IoVecBufMut>(&mut self, buf: B) -> impl Future<Output = BufResult<usize, B>> {
(**self).readv(buf)
}
}
64 changes: 64 additions & 0 deletions monoio/src/io/async_write_rent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ pub trait AsyncWriteRentAt {
) -> impl Future<Output = BufResult<usize, T>>;
}

impl<A: ?Sized + AsyncWriteRentAt> AsyncWriteRentAt for &mut A {
#[inline]
fn write_at<T: IoBuf>(
&mut self,
buf: T,
pos: usize,
) -> impl Future<Output = BufResult<usize, T>> {
(**self).write_at(buf, pos)
}
}

impl<A: ?Sized + AsyncWriteRent> AsyncWriteRent for &mut A {
#[inline]
fn write<T: IoBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
Expand All @@ -104,3 +115,56 @@ impl<A: ?Sized + AsyncWriteRent> AsyncWriteRent for &mut A {
(**self).shutdown()
}
}

impl AsyncWriteRent for Vec<u8> {
fn write<T: IoBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
let slice = buf.as_slice();
self.extend_from_slice(slice);
let len = slice.len();
std::future::ready((Ok(len), buf))
}

fn writev<T: IoVecBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
let mut sum = 0;
{
#[cfg(windows)]
let buf_slice =
unsafe { std::slice::from_raw_parts(buf.read_wsabuf_ptr(), buf.read_wsabuf_len()) };
#[cfg(unix)]
let buf_slice =
unsafe { std::slice::from_raw_parts(buf.read_iovec_ptr(), buf.read_iovec_len()) };
for buf in buf_slice {
#[cfg(windows)]
let len = buf.len as usize;
#[cfg(unix)]
let len = buf.iov_len;

sum += len;
}
self.reserve(sum);
for buf in buf_slice {
#[cfg(windows)]
let ptr = buf.buf.cast::<u8>();
#[cfg(unix)]
let ptr = buf.iov_base.cast::<u8>();
#[cfg(windows)]
let len = buf.len as usize;
#[cfg(unix)]
let len = buf.iov_len;

self.extend_from_slice(unsafe { std::slice::from_raw_parts(ptr, len) });
}
}
std::future::ready((Ok(sum), buf))
}

#[inline]
fn flush(&mut self) -> impl Future<Output = std::io::Result<()>> {
std::future::ready(Ok(()))
}

#[inline]
fn shutdown(&mut self) -> impl Future<Output = std::io::Result<()>> {
std::future::ready(Ok(()))
}
}
8 changes: 4 additions & 4 deletions monoio/src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ impl AsyncWriteRent for TcpStream {
}

#[inline]
async fn flush(&mut self) -> std::io::Result<()> {
fn flush(&mut self) -> impl Future<Output = std::io::Result<()>> {
// Tcp stream does not need flush.
Ok(())
std::future::ready(Ok(()))
}

fn shutdown(&mut self) -> impl Future<Output = std::io::Result<()>> {
Expand All @@ -347,7 +347,7 @@ impl AsyncWriteRent for TcpStream {
-1 => Err(io::Error::last_os_error()),
_ => Ok(()),
};
async move { res }
std::future::ready(res)
}
}

Expand Down Expand Up @@ -403,7 +403,7 @@ impl CancelableAsyncWriteRent for TcpStream {
-1 => Err(io::Error::last_os_error()),
_ => Ok(()),
};
async move { res }
std::future::ready(res)
}
}

Expand Down

0 comments on commit 8a8ad70

Please sign in to comment.