diff --git a/mountpoint-s3-client/src/s3_crt_client/put_object.rs b/mountpoint-s3-client/src/s3_crt_client/put_object.rs
index 1b468fd96..b7a82553d 100644
--- a/mountpoint-s3-client/src/s3_crt_client/put_object.rs
+++ b/mountpoint-s3-client/src/s3_crt_client/put_object.rs
@@ -174,13 +174,18 @@ impl PutObjectRequest for S3PutObjectRequest {
             create_mpu.await.unwrap()?;
         }
 
-        // Write will fail if the request has already finished (because of an error).
-        self.body
-            .meta_request
-            .write(slice, false)
-            .await
-            .map_err(S3RequestError::CrtError)?;
-        self.total_bytes += slice.len() as u64;
+        let meta_request = &mut self.body.meta_request;
+        let mut slice = slice;
+        while !slice.is_empty() {
+            // Write will fail if the request has already finished (because of an error).
+            let remaining = meta_request
+                .write(slice, false)
+                .await
+                .map_err(S3RequestError::CrtError)?;
+            // A write may be partial: count only the bytes that were actually accepted.
+            self.total_bytes += (slice.len() - remaining.len()) as u64;
+            slice = remaining;
+        }
         Ok(())
     }
 
@@ -195,7 +200,8 @@ impl PutObjectRequest for S3PutObjectRequest {
         self.review_callback.set(review_callback);
 
         // Write will fail if the request has already finished (because of an error).
- self.body + _ = self + .body .meta_request .write(&[], true) .await diff --git a/mountpoint-s3-client/tests/put_object.rs b/mountpoint-s3-client/tests/put_object.rs index dd893f78f..e966fdc25 100644 --- a/mountpoint-s3-client/tests/put_object.rs +++ b/mountpoint-s3-client/tests/put_object.rs @@ -14,7 +14,7 @@ use mountpoint_s3_client::types::{ }; use mountpoint_s3_client::{ObjectClient, PutObjectRequest, S3CrtClient, S3RequestError}; use mountpoint_s3_crt::checksums::crc32c; -use mountpoint_s3_crt_sys::aws_s3_errors; +use mountpoint_s3_crt_sys::aws_common_error; use rand::Rng; use test_case::test_case; @@ -236,7 +236,7 @@ async fn test_put_object_write_cancelled() { .await .expect_err("further writes should fail"); assert!( - matches!(err, ObjectClientError::ClientError(S3RequestError::CrtError(e)) if e.raw_error() == aws_s3_errors::AWS_ERROR_S3_REQUEST_HAS_COMPLETED as i32) + matches!(err, ObjectClientError::ClientError(S3RequestError::CrtError(e)) if e.raw_error() == aws_common_error::AWS_ERROR_INVALID_STATE as i32) ); } diff --git a/mountpoint-s3-crt-sys/crt/aws-c-s3 b/mountpoint-s3-crt-sys/crt/aws-c-s3 index f222ada33..fed06318b 160000 --- a/mountpoint-s3-crt-sys/crt/aws-c-s3 +++ b/mountpoint-s3-crt-sys/crt/aws-c-s3 @@ -1 +1 @@ -Subproject commit f222ada3392c94bdf77d0d889400d1128c90ee8c +Subproject commit fed06318b9798526f6403ace1a94adf71a579cb1 diff --git a/mountpoint-s3-crt/src/io/futures.rs b/mountpoint-s3-crt/src/io/futures.rs index 7e168f1c1..11ac95d23 100644 --- a/mountpoint-s3-crt/src/io/futures.rs +++ b/mountpoint-s3-crt/src/io/futures.rs @@ -3,19 +3,13 @@ use std::fmt::Debug; use std::future::Future; -use std::pin::Pin; -use std::ptr::NonNull; use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; +use std::task::{Context, Poll}; use futures::channel::oneshot; use futures::future::BoxFuture; use futures::task::ArcWake; use futures::{FutureExt, TryFutureExt}; -use mountpoint_s3_crt_sys::{ - aws_future_void, aws_future_void_get_error, 
aws_future_void_is_done, aws_future_void_register_callback, - aws_future_void_release, -}; use thiserror::Error; use crate::common::allocator::Allocator; @@ -226,117 +220,10 @@ pub enum JoinError { InternalError(#[from] crate::common::error::Error), } -/// Wraps a [aws_future_void]. -#[derive(Debug)] -pub struct FutureVoid { - inner: NonNull, - waker: Arc>>, -} - -// SAFETY: `aws_future_void` is thread-safe -unsafe impl Send for FutureVoid {} - -impl Drop for FutureVoid { - fn drop(&mut self) { - // SAFETY: `self.inner` contains a valid `aws_future_void`. - unsafe { - aws_future_void_release(self.inner.as_ptr()); - } - } -} - -impl FutureVoid { - /// Return whether the future is done - pub fn is_done(&self) -> bool { - // SAFETY: `self.inner` contains a valid `aws_future_void`. - unsafe { aws_future_void_is_done(self.inner.as_ptr()) } - } - - /// Create a [FutureVoid] from a [aws_future_void]. - /// - /// ## Safety - /// - /// `inner` must be a valid [aws_future_void] with no registered callbacks. - pub unsafe fn from_crt(inner: NonNull) -> Self { - Self { - inner, - waker: Arc::new(Mutex::new(None)), - } - } - - /// Get the result of this future if completed. - fn try_get_result(&self) -> Option> { - if !self.is_done() { - return None; - } - - let result = { - // SAFETY: `self.inner` has completed. - let future_result = unsafe { aws_future_void_get_error(self.inner.as_ptr()) }; - let error_result: crate::common::error::Error = future_result.into(); - if error_result.is_err() { - Err(error_result) - } else { - Ok(()) - } - }; - Some(result) - } -} - -impl Future for FutureVoid { - type Output = Result<(), crate::common::error::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut waker = self.waker.lock().unwrap(); - if let Some(result) = self.try_get_result() { - // The future has completed. Remove the waker, if any, and return the result. - _ = waker.take(); - Poll::Ready(result) - } else { - // The future has not completed yet. 
Do we need to register the callback? - match *waker { - Some(ref mut waker) => { - // The callback has already been registered, just replace the waker. - waker.clone_from(cx.waker()); - } - None => { - // Store the waker. Drop the lock in case the callback runs synchronously during registration. - *waker = Some(cx.waker().clone()); - drop(waker); - - // `user_data` will be cleaned up in `future_void_callback`. - let user_data = Arc::into_raw(self.waker.clone()) as *mut ::libc::c_void; - - // SAFETY: `self.inner.as_ptr()` is a valid `aws_future_void` and this is the only callback we are registering. - unsafe { - aws_future_void_register_callback(self.inner.as_ptr(), Some(future_void_callback), user_data); - } - } - } - Poll::Pending - } - } -} - -/// Safety: Don't call this function directly, only called by the CRT as a callback. -unsafe extern "C" fn future_void_callback(user_data: *mut ::libc::c_void) { - // Take ownership of the `Arc` in `user_data`. - let waker = Arc::from_raw(user_data as *mut Mutex>); - let Some(waker) = waker.lock().unwrap().take() else { - // Waker removed on `poll` finding that the future had already completed. - // Nothing to do here. - return; - }; - // Notify the waker that the future has completed. - waker.wake(); -} - #[cfg(test)] mod test { use futures::executor::block_on; use futures::future::join_all; - use mountpoint_s3_crt_sys::{aws_future_void_new, aws_future_void_set_error, aws_future_void_set_result}; use std::sync::atomic::{AtomicBool, AtomicU64}; use std::time::Duration; @@ -344,7 +231,6 @@ mod test { use crate::common::allocator::Allocator; use crate::io::event_loop::{EventLoopGroup, EventLoopTimer}; use std::sync::atomic::Ordering; - use test_case::test_case; /// Test that running a small future on an event loop works correctly. 
#[test] @@ -451,77 +337,4 @@ mod test { "flag should still be false after cancellation" ); } - - #[test_case(Ok(()))] - #[test_case(Err(42))] - fn test_future_void_already_done(value: Result<(), i32>) { - let allocator = Allocator::default(); - let aws_future = new_aws_future_void(&allocator); - set_aws_future_void_value(aws_future, value); - - // SAFETY: `aws_future` is a valid `aws_future_void`. - let future_void = unsafe { FutureVoid::from_crt(aws_future) }; - - // Verify that the wrapper has completed and contains the set value. - assert!(future_void.is_done()); - let Some(result) = future_void.try_get_result() else { - panic!("result should be available when the future is done"); - }; - assert_eq!(result.map_err(|e| e.raw_error()), value); - - // Verify that the wrapper returns the set value when awaited. - let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap(); - let future_handle = el_group.spawn_future(future_void); - let result = future_handle.wait().unwrap(); - assert_eq!(result.map_err(|e| e.raw_error()), value); - } - - #[test_case(Ok(()))] - #[test_case(Err(42))] - fn test_future_void_wake_up(value: Result<(), i32>) { - let allocator = Allocator::default(); - - let aws_future = new_aws_future_void(&allocator); - // SAFETY: `aws_future` is a valid `aws_future_void`. - let future_void = unsafe { FutureVoid::from_crt(aws_future) }; - - // Set up a flag that will set to true after awaiting future_void. 
- let flag = Arc::new(AtomicBool::new(false)); - - let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap(); - let future_handle = { - let flag = flag.clone(); - el_group.spawn_future(async move { - let result = future_void.await; - flag.store(true, Ordering::SeqCst); - result - }) - }; - assert!( - !flag.load(Ordering::SeqCst), - "the spawned future should not have completed and set the flag" - ); - set_aws_future_void_value(aws_future, value); - let result = future_handle.wait().unwrap(); - assert!( - flag.load(Ordering::SeqCst), - "the spawned future should have set the flag" - ); - assert_eq!(result.map_err(|e| e.raw_error()), value); - } - - fn new_aws_future_void(allocator: &Allocator) -> NonNull { - // SAFETY: `allocator` is a valid `aws_allocator` and `aws_future_void_new` returns a - // pointer to a valid `aws_future_void`. - unsafe { NonNull::new_unchecked(aws_future_void_new(allocator.inner.as_ptr())) } - } - - fn set_aws_future_void_value(aws_future: NonNull, value: Result<(), i32>) { - match value { - // SAFETY: `aws_future` is a valid `aws_future_void`. - Ok(()) => unsafe { aws_future_void_set_result(aws_future.as_ptr()) }, - // SAFETY: `aws_future` is a valid `aws_future_void`. 
-            Err(code) => unsafe { aws_future_void_set_error(aws_future.as_ptr(), code) },
-        }
-    }
 }
diff --git a/mountpoint-s3-crt/src/s3/client.rs b/mountpoint-s3-crt/src/s3/client.rs
index 23c54fb70..2f65a6e33 100644
--- a/mountpoint-s3-crt/src/s3/client.rs
+++ b/mountpoint-s3-crt/src/s3/client.rs
@@ -8,7 +8,6 @@ use crate::common::thread::ThreadId;
 use crate::common::uri::Uri;
 use crate::http::request_response::{Headers, Message};
 use crate::io::channel_bootstrap::ClientBootstrap;
-use crate::io::futures::FutureVoid;
 use crate::io::retry_strategy::RetryStrategy;
 use crate::s3::s3_library_init;
 use crate::{aws_byte_cursor_as_slice, CrtError, ResultExt, ToAwsByteCursor};
@@ -20,6 +19,8 @@ use std::marker::PhantomPinned;
 use std::os::unix::prelude::OsStrExt;
 use std::pin::Pin;
 use std::ptr::NonNull;
+use std::sync::{Arc, Mutex};
+use std::task::Waker;
 use std::time::Duration;
 
 /// A client for high-throughput access to Amazon S3
@@ -539,9 +540,9 @@ impl MetaRequest {
         }
     }
 
-    /// Write a chunk of data and indicate whether it is the last. If invoked before the previous
-    /// write completed, or after setting `eof` to `true`, will return an AWS_ERROR_INVALID_STATE error.
-    pub fn write<'a>(&'a mut self, slice: &'a [u8], eof: bool) -> MetaRequestWrite<'a> {
+    /// Write a chunk of data and indicate whether it is the last. The returned future resolves to the
+    /// unwritten remainder of `slice`, or to AWS_ERROR_INVALID_STATE if invoked after `eof` was `true`.
+    pub fn write<'r, 's: 'r>(&'r mut self, slice: &'s [u8], eof: bool) -> MetaRequestWrite<'r, 's> {
         MetaRequestWrite::new(self, slice, eof)
     }
 }
@@ -573,51 +574,88 @@ unsafe impl Sync for MetaRequest {}
 
-/// Future returned by `MetaRequest::write()`. It will complete when the write completes,
-/// or cancel the meta-request if dropped.
+/// Future returned by `MetaRequest::write()`. It completes once the CRT has accepted all or part
+/// of the data, and resolves to the portion of the slice that still has to be written.
#[derive(Debug)] -pub struct MetaRequestWrite<'a> { - /// Signals when the write operation completes - future: FutureVoid, +pub struct MetaRequestWrite<'r, 's> { /// The meta-request to cancel if this future is dropped - request: &'a mut MetaRequest, + request: &'r mut MetaRequest, + /// The slice to write + slice: &'s [u8], + /// Is end-of-file? + eof: bool, + /// Waker registered with `poll_write` + waker: Arc>>, } -impl Drop for MetaRequestWrite<'_> { - fn drop(&mut self) { - if !self.future.is_done() { - // This future is being dropped before completion. Cancelling the meta-request - // guarantees that the client will not access the provided buffer. - self.request.cancel(); +impl<'r, 's> MetaRequestWrite<'r, 's> { + fn new(request: &'r mut MetaRequest, slice: &'s [u8], eof: bool) -> Self { + Self { + request, + slice, + eof, + waker: Default::default(), } } } -impl<'a> MetaRequestWrite<'a> { - fn new(request: &'a mut MetaRequest, slice: &'a [u8], eof: bool) -> Self { - // SAFETY: `MetaRequestWrite` will ensure that `slice` is alive until the future completes or - // that the meta-request is canceled if the future is dropped, preventing further use of the - // `aws_byte_cursor`. - let data = unsafe { slice.as_aws_byte_cursor() }; - - // SAFETY: `aws_s3_meta_request_write` never returns NULL. - let future = unsafe { - FutureVoid::from_crt(NonNull::new_unchecked(aws_s3_meta_request_write( - request.inner.as_ptr(), +impl<'r, 's> Future for MetaRequestWrite<'r, 's> { + type Output = Result<&'s [u8], Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let mut waker = self.waker.lock().unwrap(); + if let Some(ref mut waker) = *waker { + // The write is still pending, just replace the waker. + waker.clone_from(cx.waker()); + return std::task::Poll::Pending; + } + + // Store the waker. + *waker = Some(cx.waker().clone()); + + // `user_data` will be dropped in `poll_write_waker_callback` (or below). 
+ let user_data = Arc::into_raw(self.waker.clone()) as *mut ::libc::c_void; + + // SAFETY: `aws_s3_meta_request_poll_write` does not store `data`. + let data = unsafe { self.slice.as_aws_byte_cursor() }; + + // SAFETY: `self.request` wraps a valid `aws_s3_meta_request` pointer. + let result = unsafe { + aws_s3_meta_request_poll_write( + self.request.inner.as_ptr(), data, - eof, - ))) + self.eof, + Some(poll_write_waker_callback), + user_data, + ) }; + if result.is_pending { + return std::task::Poll::Pending; + } - Self { future, request } - } -} + // SAFETY: `aws_s3_meta_request_poll_write` completed. It will not invoke `poll_write_waker_callback`, + // so we need to drop `user_data` here. + _ = unsafe { Arc::from_raw(user_data as *mut Mutex>) }; -impl Future for MetaRequestWrite<'_> { - type Output = Result<(), Error>; + let error_result: crate::common::error::Error = result.error_code.into(); + let result = if error_result.is_err() { + Err(error_result) + } else { + Ok(&self.slice[result.bytes_processed..]) + }; - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - Pin::new(&mut self.future).poll(cx) + std::task::Poll::Ready(result) } } +/// Safety: Don't call this function directly, only called by the CRT as a callback. +unsafe extern "C" fn poll_write_waker_callback(user_data: *mut ::libc::c_void) { + // Take ownership of the `Arc` in `user_data`. + let waker = Arc::from_raw(user_data as *mut Mutex>); + let Some(waker) = waker.lock().unwrap().take() else { + return; + }; + // Notify the waker. + waker.wake(); +} + /// Client metrics which represent current workload of a client. /// Overall, num_requests_tracked_requests shows total number of requests being processed by the client at a time. /// It can be broken down into these numbers by states of the client.