Skip to content

Commit

Permalink
Adopt polling API for uploading data in PutObject requests (#874)
Browse files Browse the repository at this point in the history
* Adopt polling API for uploading data

Signed-off-by: Alessandro Passaro <[email protected]>

* Detect incomplete writes

Signed-off-by: Alessandro Passaro <[email protected]>

* Improve comments

Signed-off-by: Alessandro Passaro <[email protected]>

* Update `MetaRequestWrite` rustdocs

Signed-off-by: Alessandro Passaro <[email protected]>

* Fix `total_bytes` calculation and expand comments

Signed-off-by: Alessandro Passaro <[email protected]>

* Remove unnecessary lifetime constraint

Signed-off-by: Alessandro Passaro <[email protected]>

---------

Signed-off-by: Alessandro Passaro <[email protected]>
Co-authored-by: Alessandro Passaro <[email protected]>
  • Loading branch information
passaro and Alessandro Passaro authored May 14, 2024
1 parent 50720ab commit 2a3a06f
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 246 deletions.
65 changes: 49 additions & 16 deletions mountpoint-s3-client/src/s3_crt_client/put_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl S3CrtClient {
start_time: Instant::now(),
total_bytes: 0,
response_headers,
pending_create_mpu: Some(mpu_created),
state: S3PutObjectRequestState::CreatingMPU(mpu_created),
})
}
}
Expand Down Expand Up @@ -154,9 +154,20 @@ pub struct S3PutObjectRequest {
total_bytes: u64,
/// Headers of the CompleteMultipartUpload response, available after the request was finished
response_headers: Arc<Mutex<Option<Headers>>>,
/// Signal indicating that CreateMultipartUpload completed successfully, or that the MPU failed.
/// Set to [None] once awaited on the first write, meaning the MPU was already created or failed.
pending_create_mpu: Option<oneshot::Receiver<Result<(), S3RequestError>>>,
state: S3PutObjectRequestState,
}

/// Internal state for a [S3PutObjectRequest].
#[derive(Debug)]
enum S3PutObjectRequestState {
/// Initial state indicating that CreateMultipartUpload may still be in progress. To be awaited on first
/// write so errors can be reported early. The signal indicates that CreateMultipartUpload completed
/// successfully, or that the MPU failed.
CreatingMPU(oneshot::Receiver<Result<(), S3RequestError>>),
/// A write operation is in progress or was interrupted before completion.
PendingWrite,
/// Idle state between write calls.
Idle,
}

fn try_get_header_value(headers: &Headers, key: &str) -> Option<String> {
Expand All @@ -168,19 +179,34 @@ impl PutObjectRequest for S3PutObjectRequest {
type ClientError = S3RequestError;

async fn write(&mut self, slice: &[u8]) -> ObjectClientResult<(), PutObjectError, Self::ClientError> {
// On first write, check the pending CreateMultipartUpload.
if let Some(create_mpu) = self.pending_create_mpu.take() {
// Wait for CreateMultipartUpload to complete successfully, or the MPU to fail.
create_mpu.await.unwrap()?;
// Writing to the meta request may require multiple calls. Set the internal
// state to `PendingWrite` until we are done.
match std::mem::replace(&mut self.state, S3PutObjectRequestState::PendingWrite) {
S3PutObjectRequestState::CreatingMPU(create_mpu) => {
// On first write, check the pending CreateMultipartUpload so we can report errors.
// Wait for CreateMultipartUpload to complete successfully, or the MPU to fail.
create_mpu.await.unwrap()?;
}
S3PutObjectRequestState::PendingWrite => {
// Fail if a previous write was not completed.
return Err(S3RequestError::RequestCanceled.into());
}
S3PutObjectRequestState::Idle => {}
}

// Write will fail if the request has already finished (because of an error).
self.body
.meta_request
.write(slice, false)
.await
.map_err(S3RequestError::CrtError)?;
self.total_bytes += slice.len() as u64;
let meta_request = &mut self.body.meta_request;
let mut slice = slice;
while !slice.is_empty() {
// Write will fail if the request has already finished (because of an error).
let remaining = meta_request
.write(slice, false)
.await
.map_err(S3RequestError::CrtError)?;
self.total_bytes += (slice.len() - remaining.len()) as u64;
slice = remaining;
}
// Write completed with no errors, we can reset to `Idle`.
self.state = S3PutObjectRequestState::Idle;
Ok(())
}

Expand All @@ -192,10 +218,17 @@ impl PutObjectRequest for S3PutObjectRequest {
mut self,
review_callback: impl FnOnce(UploadReview) -> bool + Send + 'static,
) -> ObjectClientResult<PutObjectResult, PutObjectError, Self::ClientError> {
// No need to check for `CreatingMPU`: errors will be reported on completing the upload.
if matches!(self.state, S3PutObjectRequestState::PendingWrite) {
// Fail if a previous write was not completed.
return Err(S3RequestError::RequestCanceled.into());
}

self.review_callback.set(review_callback);

// Write will fail if the request has already finished (because of an error).
self.body
_ = self
.body
.meta_request
.write(&[], true)
.await
Expand Down
10 changes: 5 additions & 5 deletions mountpoint-s3-client/tests/put_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use mountpoint_s3_client::types::{
};
use mountpoint_s3_client::{ObjectClient, PutObjectRequest, S3CrtClient, S3RequestError};
use mountpoint_s3_crt::checksums::crc32c;
use mountpoint_s3_crt_sys::aws_s3_errors;
use rand::Rng;
use test_case::test_case;

Expand Down Expand Up @@ -221,7 +220,7 @@ async fn test_put_object_write_cancelled() {
request.write(&[1, 2, 3, 4]).await.expect("write should succeed");

{
// Write a multiple of `part_size` to ensure the copy is deferred.
// Write a multiple of `part_size` to ensure it will not complete immediately.
let size = client.part_size().unwrap() * 10;
let buffer = vec![0u8; size];
let write = request.write(&buffer);
Expand All @@ -235,9 +234,10 @@ async fn test_put_object_write_cancelled() {
.write(&[1, 2, 3, 4])
.await
.expect_err("further writes should fail");
assert!(
matches!(err, ObjectClientError::ClientError(S3RequestError::CrtError(e)) if e.raw_error() == aws_s3_errors::AWS_ERROR_S3_REQUEST_HAS_COMPLETED as i32)
);
assert!(matches!(
err,
ObjectClientError::ClientError(S3RequestError::RequestCanceled)
));
}

#[tokio::test]
Expand Down
189 changes: 1 addition & 188 deletions mountpoint-s3-crt/src/io/futures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::task::{Context, Poll};

use futures::channel::oneshot;
use futures::future::BoxFuture;
use futures::task::ArcWake;
use futures::{FutureExt, TryFutureExt};
use mountpoint_s3_crt_sys::{
aws_future_void, aws_future_void_get_error, aws_future_void_is_done, aws_future_void_register_callback,
aws_future_void_release,
};
use thiserror::Error;

use crate::common::allocator::Allocator;
Expand Down Expand Up @@ -226,125 +220,17 @@ pub enum JoinError {
InternalError(#[from] crate::common::error::Error),
}

/// Wraps a [aws_future_void].
#[derive(Debug)]
pub struct FutureVoid {
inner: NonNull<aws_future_void>,
waker: Arc<Mutex<Option<Waker>>>,
}

// SAFETY: `aws_future_void` is thread-safe
unsafe impl Send for FutureVoid {}

impl Drop for FutureVoid {
fn drop(&mut self) {
// SAFETY: `self.inner` contains a valid `aws_future_void`.
unsafe {
aws_future_void_release(self.inner.as_ptr());
}
}
}

impl FutureVoid {
/// Return whether the future is done
pub fn is_done(&self) -> bool {
// SAFETY: `self.inner` contains a valid `aws_future_void`.
unsafe { aws_future_void_is_done(self.inner.as_ptr()) }
}

/// Create a [FutureVoid] from a [aws_future_void].
///
/// ## Safety
///
/// `inner` must be a valid [aws_future_void] with no registered callbacks.
pub unsafe fn from_crt(inner: NonNull<aws_future_void>) -> Self {
Self {
inner,
waker: Arc::new(Mutex::new(None)),
}
}

/// Get the result of this future if completed.
fn try_get_result(&self) -> Option<Result<(), crate::common::error::Error>> {
if !self.is_done() {
return None;
}

let result = {
// SAFETY: `self.inner` has completed.
let future_result = unsafe { aws_future_void_get_error(self.inner.as_ptr()) };
let error_result: crate::common::error::Error = future_result.into();
if error_result.is_err() {
Err(error_result)
} else {
Ok(())
}
};
Some(result)
}
}

impl Future for FutureVoid {
type Output = Result<(), crate::common::error::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut waker = self.waker.lock().unwrap();
if let Some(result) = self.try_get_result() {
// The future has completed. Remove the waker, if any, and return the result.
_ = waker.take();
Poll::Ready(result)
} else {
// The future has not completed yet. Do we need to register the callback?
match *waker {
Some(ref mut waker) => {
// The callback has already been registered, just replace the waker.
waker.clone_from(cx.waker());
}
None => {
// Store the waker. Drop the lock in case the callback runs synchronously during registration.
*waker = Some(cx.waker().clone());
drop(waker);

// `user_data` will be cleaned up in `future_void_callback`.
let user_data = Arc::into_raw(self.waker.clone()) as *mut ::libc::c_void;

// SAFETY: `self.inner.as_ptr()` is a valid `aws_future_void` and this is the only callback we are registering.
unsafe {
aws_future_void_register_callback(self.inner.as_ptr(), Some(future_void_callback), user_data);
}
}
}
Poll::Pending
}
}
}

/// Safety: Don't call this function directly, only called by the CRT as a callback.
unsafe extern "C" fn future_void_callback(user_data: *mut ::libc::c_void) {
// Take ownership of the `Arc` in `user_data`.
let waker = Arc::from_raw(user_data as *mut Mutex<Option<Waker>>);
let Some(waker) = waker.lock().unwrap().take() else {
// Waker removed on `poll` finding that the future had already completed.
// Nothing to do here.
return;
};
// Notify the waker that the future has completed.
waker.wake();
}

#[cfg(test)]
mod test {
use futures::executor::block_on;
use futures::future::join_all;
use mountpoint_s3_crt_sys::{aws_future_void_new, aws_future_void_set_error, aws_future_void_set_result};
use std::sync::atomic::{AtomicBool, AtomicU64};
use std::time::Duration;

use super::*;
use crate::common::allocator::Allocator;
use crate::io::event_loop::{EventLoopGroup, EventLoopTimer};
use std::sync::atomic::Ordering;
use test_case::test_case;

/// Test that running a small future on an event loop works correctly.
#[test]
Expand Down Expand Up @@ -451,77 +337,4 @@ mod test {
"flag should still be false after cancellation"
);
}

#[test_case(Ok(()))]
#[test_case(Err(42))]
fn test_future_void_already_done(value: Result<(), i32>) {
let allocator = Allocator::default();
let aws_future = new_aws_future_void(&allocator);
set_aws_future_void_value(aws_future, value);

// SAFETY: `aws_future` is a valid `aws_future_void`.
let future_void = unsafe { FutureVoid::from_crt(aws_future) };

// Verify that the wrapper has completed and contains the set value.
assert!(future_void.is_done());
let Some(result) = future_void.try_get_result() else {
panic!("result should be available when the future is done");
};
assert_eq!(result.map_err(|e| e.raw_error()), value);

// Verify that the wrapper returns the set value when awaited.
let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();
let future_handle = el_group.spawn_future(future_void);
let result = future_handle.wait().unwrap();
assert_eq!(result.map_err(|e| e.raw_error()), value);
}

#[test_case(Ok(()))]
#[test_case(Err(42))]
fn test_future_void_wake_up(value: Result<(), i32>) {
let allocator = Allocator::default();

let aws_future = new_aws_future_void(&allocator);
// SAFETY: `aws_future` is a valid `aws_future_void`.
let future_void = unsafe { FutureVoid::from_crt(aws_future) };

// Set up a flag that will set to true after awaiting future_void.
let flag = Arc::new(AtomicBool::new(false));

let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();
let future_handle = {
let flag = flag.clone();
el_group.spawn_future(async move {
let result = future_void.await;
flag.store(true, Ordering::SeqCst);
result
})
};
assert!(
!flag.load(Ordering::SeqCst),
"the spawned future should not have completed and set the flag"
);
set_aws_future_void_value(aws_future, value);
let result = future_handle.wait().unwrap();
assert!(
flag.load(Ordering::SeqCst),
"the spawned future should have set the flag"
);
assert_eq!(result.map_err(|e| e.raw_error()), value);
}

fn new_aws_future_void(allocator: &Allocator) -> NonNull<aws_future_void> {
// SAFETY: `allocator` is a valid `aws_allocator` and `aws_future_void_new` returns a
// pointer to a valid `aws_future_void`.
unsafe { NonNull::new_unchecked(aws_future_void_new(allocator.inner.as_ptr())) }
}

fn set_aws_future_void_value(aws_future: NonNull<aws_future_void>, value: Result<(), i32>) {
match value {
// SAFETY: `aws_future` is a valid `aws_future_void`.
Ok(()) => unsafe { aws_future_void_set_result(aws_future.as_ptr()) },
// SAFETY: `aws_future` is a valid `aws_future_void`.
Err(code) => unsafe { aws_future_void_set_error(aws_future.as_ptr(), code) },
}
}
}
Loading

0 comments on commit 2a3a06f

Please sign in to comment.