diff --git a/mountpoint-s3-client/src/s3_crt_client/put_object.rs b/mountpoint-s3-client/src/s3_crt_client/put_object.rs index 1b468fd96..cbcbee307 100644 --- a/mountpoint-s3-client/src/s3_crt_client/put_object.rs +++ b/mountpoint-s3-client/src/s3_crt_client/put_object.rs @@ -104,7 +104,7 @@ impl S3CrtClient { start_time: Instant::now(), total_bytes: 0, response_headers, - pending_create_mpu: Some(mpu_created), + state: S3PutObjectRequestState::CreatingMPU(mpu_created), }) } } @@ -154,9 +154,20 @@ pub struct S3PutObjectRequest { total_bytes: u64, /// Headers of the CompleteMultipartUpload response, available after the request was finished response_headers: Arc>>, - /// Signal indicating that CreateMultipartUpload completed successfully, or that the MPU failed. - /// Set to [None] once awaited on the first write, meaning the MPU was already created or failed. - pending_create_mpu: Option>>, + state: S3PutObjectRequestState, +} + +/// Internal state for a [S3PutObjectRequest]. +#[derive(Debug)] +enum S3PutObjectRequestState { + /// Initial state indicating that CreateMultipartUpload may still be in progress. To be awaited on first + /// write so errors can be reported early. The signal indicates that CreateMultipartUpload completed + /// successfully, or that the MPU failed. + CreatingMPU(oneshot::Receiver>), + /// A write operation is in progress or was interrupted before completion. + PendingWrite, + /// Idle state between write calls. + Idle, } fn try_get_header_value(headers: &Headers, key: &str) -> Option { @@ -168,19 +179,34 @@ impl PutObjectRequest for S3PutObjectRequest { type ClientError = S3RequestError; async fn write(&mut self, slice: &[u8]) -> ObjectClientResult<(), PutObjectError, Self::ClientError> { - // On first write, check the pending CreateMultipartUpload. - if let Some(create_mpu) = self.pending_create_mpu.take() { - // Wait for CreateMultipartUpload to complete successfully, or the MPU to fail. - create_mpu.await.unwrap()?; + // Writing to the meta request may require multiple calls. Set the internal + // state to `PendingWrite` until we are done. + match std::mem::replace(&mut self.state, S3PutObjectRequestState::PendingWrite) { + S3PutObjectRequestState::CreatingMPU(create_mpu) => { + // On first write, check the pending CreateMultipartUpload so we can report errors. + // Wait for CreateMultipartUpload to complete successfully, or the MPU to fail. + create_mpu.await.unwrap()?; + } + S3PutObjectRequestState::PendingWrite => { + // Fail if a previous write was not completed. + return Err(S3RequestError::RequestCanceled.into()); + } + S3PutObjectRequestState::Idle => {} } - // Write will fail if the request has already finished (because of an error). - self.body - .meta_request - .write(slice, false) - .await - .map_err(S3RequestError::CrtError)?; - self.total_bytes += slice.len() as u64; + let meta_request = &mut self.body.meta_request; + let mut slice = slice; + while !slice.is_empty() { + // Write will fail if the request has already finished (because of an error). + let remaining = meta_request + .write(slice, false) + .await + .map_err(S3RequestError::CrtError)?; + self.total_bytes += (slice.len() - remaining.len()) as u64; + slice = remaining; + } + // Write completed with no errors, we can reset to `Idle`. + self.state = S3PutObjectRequestState::Idle; Ok(()) } @@ -192,10 +218,17 @@ impl PutObjectRequest for S3PutObjectRequest { mut self, review_callback: impl FnOnce(UploadReview) -> bool + Send + 'static, ) -> ObjectClientResult { + // No need to check for `CreatingMPU`: errors will be reported on completing the upload. + if matches!(self.state, S3PutObjectRequestState::PendingWrite) { + // Fail if a previous write was not completed. + return Err(S3RequestError::RequestCanceled.into()); + } + self.review_callback.set(review_callback); // Write will fail if the request has already finished (because of an error). - self.body + _ = self + .body .meta_request .write(&[], true) .await diff --git a/mountpoint-s3-client/tests/put_object.rs b/mountpoint-s3-client/tests/put_object.rs index dd893f78f..478c26f7f 100644 --- a/mountpoint-s3-client/tests/put_object.rs +++ b/mountpoint-s3-client/tests/put_object.rs @@ -14,7 +14,6 @@ use mountpoint_s3_client::types::{ }; use mountpoint_s3_client::{ObjectClient, PutObjectRequest, S3CrtClient, S3RequestError}; use mountpoint_s3_crt::checksums::crc32c; -use mountpoint_s3_crt_sys::aws_s3_errors; use rand::Rng; use test_case::test_case; @@ -221,7 +220,7 @@ async fn test_put_object_write_cancelled() { request.write(&[1, 2, 3, 4]).await.expect("write should succeed"); { - // Write a multiple of `part_size` to ensure the copy is deferred. + // Write a multiple of `part_size` to ensure it will not complete immediately. let size = client.part_size().unwrap() * 10; let buffer = vec![0u8; size]; let write = request.write(&buffer); @@ -235,9 +234,10 @@ async fn test_put_object_write_cancelled() { .write(&[1, 2, 3, 4]) .await .expect_err("further writes should fail"); - assert!( - matches!(err, ObjectClientError::ClientError(S3RequestError::CrtError(e)) if e.raw_error() == aws_s3_errors::AWS_ERROR_S3_REQUEST_HAS_COMPLETED as i32) - ); + assert!(matches!( + err, + ObjectClientError::ClientError(S3RequestError::RequestCanceled) + )); } #[tokio::test] diff --git a/mountpoint-s3-crt/src/io/futures.rs b/mountpoint-s3-crt/src/io/futures.rs index 7e168f1c1..11ac95d23 100644 --- a/mountpoint-s3-crt/src/io/futures.rs +++ b/mountpoint-s3-crt/src/io/futures.rs @@ -3,19 +3,13 @@ use std::fmt::Debug; use std::future::Future; -use std::pin::Pin; -use std::ptr::NonNull; use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; +use std::task::{Context, Poll}; use futures::channel::oneshot; use futures::future::BoxFuture; use futures::task::ArcWake; use futures::{FutureExt, TryFutureExt}; -use mountpoint_s3_crt_sys::{ - aws_future_void, aws_future_void_get_error, aws_future_void_is_done, aws_future_void_register_callback, - aws_future_void_release, -}; use thiserror::Error; use crate::common::allocator::Allocator; @@ -226,117 +220,10 @@ pub enum JoinError { InternalError(#[from] crate::common::error::Error), } -/// Wraps a [aws_future_void]. -#[derive(Debug)] -pub struct FutureVoid { - inner: NonNull, - waker: Arc>>, -} - -// SAFETY: `aws_future_void` is thread-safe -unsafe impl Send for FutureVoid {} - -impl Drop for FutureVoid { - fn drop(&mut self) { - // SAFETY: `self.inner` contains a valid `aws_future_void`. - unsafe { - aws_future_void_release(self.inner.as_ptr()); - } - } -} - -impl FutureVoid { - /// Return whether the future is done - pub fn is_done(&self) -> bool { - // SAFETY: `self.inner` contains a valid `aws_future_void`. - unsafe { aws_future_void_is_done(self.inner.as_ptr()) } - } - - /// Create a [FutureVoid] from a [aws_future_void]. - /// - /// ## Safety - /// - /// `inner` must be a valid [aws_future_void] with no registered callbacks. - pub unsafe fn from_crt(inner: NonNull) -> Self { - Self { - inner, - waker: Arc::new(Mutex::new(None)), - } - } - - /// Get the result of this future if completed. - fn try_get_result(&self) -> Option> { - if !self.is_done() { - return None; - } - - let result = { - // SAFETY: `self.inner` has completed. - let future_result = unsafe { aws_future_void_get_error(self.inner.as_ptr()) }; - let error_result: crate::common::error::Error = future_result.into(); - if error_result.is_err() { - Err(error_result) - } else { - Ok(()) - } - }; - Some(result) - } -} - -impl Future for FutureVoid { - type Output = Result<(), crate::common::error::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut waker = self.waker.lock().unwrap(); - if let Some(result) = self.try_get_result() { - // The future has completed. Remove the waker, if any, and return the result. - _ = waker.take(); - Poll::Ready(result) - } else { - // The future has not completed yet. Do we need to register the callback? - match *waker { - Some(ref mut waker) => { - // The callback has already been registered, just replace the waker. - waker.clone_from(cx.waker()); - } - None => { - // Store the waker. Drop the lock in case the callback runs synchronously during registration. - *waker = Some(cx.waker().clone()); - drop(waker); - - // `user_data` will be cleaned up in `future_void_callback`. - let user_data = Arc::into_raw(self.waker.clone()) as *mut ::libc::c_void; - - // SAFETY: `self.inner.as_ptr()` is a valid `aws_future_void` and this is the only callback we are registering. - unsafe { - aws_future_void_register_callback(self.inner.as_ptr(), Some(future_void_callback), user_data); - } - } - } - Poll::Pending - } - } -} - -/// Safety: Don't call this function directly, only called by the CRT as a callback. -unsafe extern "C" fn future_void_callback(user_data: *mut ::libc::c_void) { - // Take ownership of the `Arc` in `user_data`. - let waker = Arc::from_raw(user_data as *mut Mutex>); - let Some(waker) = waker.lock().unwrap().take() else { - // Waker removed on `poll` finding that the future had already completed. - // Nothing to do here. - return; - }; - // Notify the waker that the future has completed. - waker.wake(); -} - #[cfg(test)] mod test { use futures::executor::block_on; use futures::future::join_all; - use mountpoint_s3_crt_sys::{aws_future_void_new, aws_future_void_set_error, aws_future_void_set_result}; use std::sync::atomic::{AtomicBool, AtomicU64}; use std::time::Duration; @@ -344,7 +231,6 @@ mod test { use crate::common::allocator::Allocator; use crate::io::event_loop::{EventLoopGroup, EventLoopTimer}; use std::sync::atomic::Ordering; - use test_case::test_case; /// Test that running a small future on an event loop works correctly. #[test] @@ -451,77 +337,4 @@ mod test { "flag should still be false after cancellation" ); } - - #[test_case(Ok(()))] - #[test_case(Err(42))] - fn test_future_void_already_done(value: Result<(), i32>) { - let allocator = Allocator::default(); - let aws_future = new_aws_future_void(&allocator); - set_aws_future_void_value(aws_future, value); - - // SAFETY: `aws_future` is a valid `aws_future_void`. - let future_void = unsafe { FutureVoid::from_crt(aws_future) }; - - // Verify that the wrapper has completed and contains the set value. - assert!(future_void.is_done()); - let Some(result) = future_void.try_get_result() else { - panic!("result should be available when the future is done"); - }; - assert_eq!(result.map_err(|e| e.raw_error()), value); - - // Verify that the wrapper returns the set value when awaited. - let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap(); - let future_handle = el_group.spawn_future(future_void); - let result = future_handle.wait().unwrap(); - assert_eq!(result.map_err(|e| e.raw_error()), value); - } - - #[test_case(Ok(()))] - #[test_case(Err(42))] - fn test_future_void_wake_up(value: Result<(), i32>) { - let allocator = Allocator::default(); - - let aws_future = new_aws_future_void(&allocator); - // SAFETY: `aws_future` is a valid `aws_future_void`. - let future_void = unsafe { FutureVoid::from_crt(aws_future) }; - - // Set up a flag that will set to true after awaiting future_void. - let flag = Arc::new(AtomicBool::new(false)); - - let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap(); - let future_handle = { - let flag = flag.clone(); - el_group.spawn_future(async move { - let result = future_void.await; - flag.store(true, Ordering::SeqCst); - result - }) - }; - assert!( - !flag.load(Ordering::SeqCst), - "the spawned future should not have completed and set the flag" - ); - set_aws_future_void_value(aws_future, value); - let result = future_handle.wait().unwrap(); - assert!( - flag.load(Ordering::SeqCst), - "the spawned future should have set the flag" - ); - assert_eq!(result.map_err(|e| e.raw_error()), value); - } - - fn new_aws_future_void(allocator: &Allocator) -> NonNull { - // SAFETY: `allocator` is a valid `aws_allocator` and `aws_future_void_new` returns a - // pointer to a valid `aws_future_void`. - unsafe { NonNull::new_unchecked(aws_future_void_new(allocator.inner.as_ptr())) } - } - - fn set_aws_future_void_value(aws_future: NonNull, value: Result<(), i32>) { - match value { - // SAFETY: `aws_future` is a valid `aws_future_void`. - Ok(()) => unsafe { aws_future_void_set_result(aws_future.as_ptr()) }, - // SAFETY: `aws_future` is a valid `aws_future_void`. - Err(code) => unsafe { aws_future_void_set_error(aws_future.as_ptr(), code) }, - } - } } diff --git a/mountpoint-s3-crt/src/s3/client.rs b/mountpoint-s3-crt/src/s3/client.rs index 23c54fb70..d34ca5300 100644 --- a/mountpoint-s3-crt/src/s3/client.rs +++ b/mountpoint-s3-crt/src/s3/client.rs @@ -8,7 +8,6 @@ use crate::common::thread::ThreadId; use crate::common::uri::Uri; use crate::http::request_response::{Headers, Message}; use crate::io::channel_bootstrap::ClientBootstrap; -use crate::io::futures::FutureVoid; use crate::io::retry_strategy::RetryStrategy; use crate::s3::s3_library_init; use crate::{aws_byte_cursor_as_slice, CrtError, ResultExt, ToAwsByteCursor}; @@ -20,6 +19,8 @@ use std::marker::PhantomPinned; use std::os::unix::prelude::OsStrExt; use std::pin::Pin; use std::ptr::NonNull; +use std::sync::{Arc, Mutex}; +use std::task::Waker; use std::time::Duration; /// A client for high-throughput access to Amazon S3 @@ -539,9 +540,14 @@ impl MetaRequest { } } - /// Write a chunk of data and indicate whether it is the last. If invoked before the previous - /// write completed, or after setting `eof` to `true`, will return an AWS_ERROR_INVALID_STATE error. - pub fn write<'a>(&'a mut self, slice: &'a [u8], eof: bool) -> MetaRequestWrite<'a> { + /// Write a chunk of data and indicate whether it is the last. Returns a [MetaRequestWrite] + /// future that starts writing when polled. May perform incomplete writes: in that case, + /// the future returns the suffix of the input slice that has not been written. The caller + /// is expected to invoke `write` again with the remaining data, until the empty slice is + /// returned. + /// Once an invocation with `eof == true` returns with the empty slice, subsequent invocations + /// will fail with an AWS_ERROR_INVALID_STATE error. + pub fn write<'r, 's>(&'r mut self, slice: &'s [u8], eof: bool) -> MetaRequestWrite<'r, 's> { MetaRequestWrite::new(self, slice, eof) } } @@ -570,54 +576,95 @@ unsafe impl Send for MetaRequest {} // SAFETY: `aws_s3_meta_request` is thread safe unsafe impl Sync for MetaRequest {} -/// Future returned by `MetaRequest::write()`. It will complete when the write completes, -/// or cancel the meta-request if dropped. +/// Future returned by `MetaRequest::write()`. Wraps `aws_s3_meta_request_poll_write`. #[derive(Debug)] -pub struct MetaRequestWrite<'a> { - /// Signals when the write operation completes - future: FutureVoid, - /// The meta-request to cancel if this future is dropped - request: &'a mut MetaRequest, +pub struct MetaRequestWrite<'r, 's> { + /// The meta-request to write to. + request: &'r mut MetaRequest, + /// The slice to write + slice: &'s [u8], + /// Is end-of-file? + eof: bool, + /// Holds the waker from the current context. Passed to `poll_write_waker_callback` + /// in order to trigger another poll. + waker: Arc>>, } -impl Drop for MetaRequestWrite<'_> { - fn drop(&mut self) { - if !self.future.is_done() { - // This future is being dropped before completion. Cancelling the meta-request - // guarantees that the client will not access the provided buffer. - self.request.cancel(); +impl<'r, 's> MetaRequestWrite<'r, 's> { + fn new(request: &'r mut MetaRequest, slice: &'s [u8], eof: bool) -> Self { + Self { + request, + slice, + eof, + waker: Default::default(), } } } -impl<'a> MetaRequestWrite<'a> { - fn new(request: &'a mut MetaRequest, slice: &'a [u8], eof: bool) -> Self { - // SAFETY: `MetaRequestWrite` will ensure that `slice` is alive until the future completes or - // that the meta-request is canceled if the future is dropped, preventing further use of the - // `aws_byte_cursor`. - let data = unsafe { slice.as_aws_byte_cursor() }; - - // SAFETY: `aws_s3_meta_request_write` never returns NULL. - let future = unsafe { - FutureVoid::from_crt(NonNull::new_unchecked(aws_s3_meta_request_write( - request.inner.as_ptr(), +impl<'r, 's> Future for MetaRequestWrite<'r, 's> { + type Output = Result<&'s [u8], Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let mut waker = self.waker.lock().unwrap(); + if let Some(ref mut waker) = *waker { + // The previous `aws_s3_meta_request_poll_write` call returned `Pending` but has not + // invoked the callback yet. Do not call it again, but make sure to store the waker + // from the current context. + waker.clone_from(cx.waker()); + return std::task::Poll::Pending; + } + + // Store the waker. + *waker = Some(cx.waker().clone()); + + // `user_data` will be dropped in `poll_write_waker_callback` (or below). + let user_data = Arc::into_raw(self.waker.clone()) as *mut ::libc::c_void; + + // SAFETY: `aws_s3_meta_request_poll_write` does not store `data`. + let data = unsafe { self.slice.as_aws_byte_cursor() }; + + // SAFETY: `self.request` wraps a valid `aws_s3_meta_request` pointer. + let result = unsafe { + aws_s3_meta_request_poll_write( + self.request.inner.as_ptr(), data, - eof, - ))) + self.eof, + Some(poll_write_waker_callback), + user_data, + ) }; + if result.is_pending { + return std::task::Poll::Pending; + } - Self { future, request } - } -} + // SAFETY: `aws_s3_meta_request_poll_write` completed. It will not invoke `poll_write_waker_callback`, + // so we need to drop `user_data` here. + _ = unsafe { Arc::from_raw(user_data as *mut Mutex>) }; -impl Future for MetaRequestWrite<'_> { - type Output = Result<(), Error>; + let error_result: crate::common::error::Error = result.error_code.into(); + let result = if error_result.is_err() { + Err(error_result) + } else { + Ok(&self.slice[result.bytes_processed..]) + }; - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - Pin::new(&mut self.future).poll(cx) + std::task::Poll::Ready(result) } } +/// Safety: Don't call this function directly, only called by the CRT as a callback. +unsafe extern "C" fn poll_write_waker_callback(user_data: *mut ::libc::c_void) { + // Take ownership of the `Arc` in `user_data`. + let waker = Arc::from_raw(user_data as *mut Mutex>); + // Notify the waker. + waker + .lock() + .unwrap() + .take() + .expect("user_data always contains a waker") + .wake(); +} + /// Client metrics which represent current workload of a client. /// Overall, num_requests_tracked_requests shows total number of requests being processed by the client at a time. /// It can be broken down into these numbers by states of the client.