Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt polling API for uploading data in PutObject requests #874

Merged
merged 6 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions mountpoint-s3-client/src/s3_crt_client/put_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl S3CrtClient {
start_time: Instant::now(),
total_bytes: 0,
response_headers,
pending_create_mpu: Some(mpu_created),
state: S3PutObjectRequestState::CreatingMPU(mpu_created),
})
}
}
Expand Down Expand Up @@ -154,9 +154,20 @@ pub struct S3PutObjectRequest {
total_bytes: u64,
/// Headers of the CompleteMultipartUpload response, available after the request was finished
response_headers: Arc<Mutex<Option<Headers>>>,
/// Signal indicating that CreateMultipartUpload completed successfully, or that the MPU failed.
/// Set to [None] once awaited on the first write, meaning the MPU was already created or failed.
pending_create_mpu: Option<oneshot::Receiver<Result<(), S3RequestError>>>,
state: S3PutObjectRequestState,
}

/// Internal state for a [S3PutObjectRequest].
#[derive(Debug)]
enum S3PutObjectRequestState {
/// Initial state indicating that CreateMultipartUpload may still be in progress. To be awaited on first
/// write so errors can be reported early. The signal indicates that CreateMultipartUpload completed
/// successfully, or that the MPU failed.
CreatingMPU(oneshot::Receiver<Result<(), S3RequestError>>),
/// A write operation is in progress or was interrupted before completion.
PendingWrite,
/// Idle state between write calls.
Idle,
vladem marked this conversation as resolved.
Show resolved Hide resolved
}

fn try_get_header_value(headers: &Headers, key: &str) -> Option<String> {
Expand All @@ -168,19 +179,34 @@ impl PutObjectRequest for S3PutObjectRequest {
type ClientError = S3RequestError;

async fn write(&mut self, slice: &[u8]) -> ObjectClientResult<(), PutObjectError, Self::ClientError> {
// On first write, check the pending CreateMultipartUpload.
if let Some(create_mpu) = self.pending_create_mpu.take() {
// Wait for CreateMultipartUpload to complete successfully, or the MPU to fail.
create_mpu.await.unwrap()?;
// Writing to the meta request may require multiple calls. Set the internal
// state to `PendingWrite` until we are done.
match std::mem::replace(&mut self.state, S3PutObjectRequestState::PendingWrite) {
S3PutObjectRequestState::CreatingMPU(create_mpu) => {
// On first write, check the pending CreateMultipartUpload so we can report errors.
// Wait for CreateMultipartUpload to complete successfully, or the MPU to fail.
create_mpu.await.unwrap()?;
}
S3PutObjectRequestState::PendingWrite => {
// Fail if a previous write was not completed.
return Err(S3RequestError::RequestCanceled.into());
}
S3PutObjectRequestState::Idle => {}
}

// Write will fail if the request has already finished (because of an error).
self.body
.meta_request
.write(slice, false)
.await
.map_err(S3RequestError::CrtError)?;
self.total_bytes += slice.len() as u64;
let meta_request = &mut self.body.meta_request;
let mut slice = slice;
while !slice.is_empty() {
// Write will fail if the request has already finished (because of an error).
let remaining = meta_request
.write(slice, false)
.await
.map_err(S3RequestError::CrtError)?;
self.total_bytes += (slice.len() - remaining.len()) as u64;
slice = remaining;
}
// Write completed with no errors, we can reset to `Idle`.
self.state = S3PutObjectRequestState::Idle;
Ok(())
}

Expand All @@ -192,10 +218,17 @@ impl PutObjectRequest for S3PutObjectRequest {
mut self,
review_callback: impl FnOnce(UploadReview) -> bool + Send + 'static,
) -> ObjectClientResult<PutObjectResult, PutObjectError, Self::ClientError> {
// No need to check for `CreatingMPU`: errors will be reported on completing the upload.
if matches!(self.state, S3PutObjectRequestState::PendingWrite) {
vladem marked this conversation as resolved.
Show resolved Hide resolved
// Fail if a previous write was not completed.
return Err(S3RequestError::RequestCanceled.into());
}

self.review_callback.set(review_callback);

// Write will fail if the request has already finished (because of an error).
self.body
_ = self
.body
.meta_request
.write(&[], true)
.await
Expand Down
10 changes: 5 additions & 5 deletions mountpoint-s3-client/tests/put_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use mountpoint_s3_client::types::{
};
use mountpoint_s3_client::{ObjectClient, PutObjectRequest, S3CrtClient, S3RequestError};
use mountpoint_s3_crt::checksums::crc32c;
use mountpoint_s3_crt_sys::aws_s3_errors;
use rand::Rng;
use test_case::test_case;

Expand Down Expand Up @@ -221,7 +220,7 @@ async fn test_put_object_write_cancelled() {
request.write(&[1, 2, 3, 4]).await.expect("write should succeed");

{
// Write a multiple of `part_size` to ensure the copy is deferred.
// Write a multiple of `part_size` to ensure it will not complete immediately.
let size = client.part_size().unwrap() * 10;
let buffer = vec![0u8; size];
let write = request.write(&buffer);
Expand All @@ -235,9 +234,10 @@ async fn test_put_object_write_cancelled() {
.write(&[1, 2, 3, 4])
.await
.expect_err("further writes should fail");
assert!(
matches!(err, ObjectClientError::ClientError(S3RequestError::CrtError(e)) if e.raw_error() == aws_s3_errors::AWS_ERROR_S3_REQUEST_HAS_COMPLETED as i32)
);
assert!(matches!(
err,
ObjectClientError::ClientError(S3RequestError::RequestCanceled)
));
}

#[tokio::test]
Expand Down
189 changes: 1 addition & 188 deletions mountpoint-s3-crt/src/io/futures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@

use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::task::{Context, Poll};

use futures::channel::oneshot;
use futures::future::BoxFuture;
use futures::task::ArcWake;
use futures::{FutureExt, TryFutureExt};
use mountpoint_s3_crt_sys::{
aws_future_void, aws_future_void_get_error, aws_future_void_is_done, aws_future_void_register_callback,
aws_future_void_release,
};
use thiserror::Error;

use crate::common::allocator::Allocator;
Expand Down Expand Up @@ -226,125 +220,17 @@ pub enum JoinError {
InternalError(#[from] crate::common::error::Error),
}

/// Wraps a [aws_future_void] so that it can be awaited as a Rust [Future].
///
/// The wrapper releases the underlying CRT future when it is dropped.
#[derive(Debug)]
pub struct FutureVoid {
/// Pointer to the underlying CRT future. Released in the `Drop` implementation.
inner: NonNull<aws_future_void>,
/// Waker of the task currently waiting on this future, if any. A clone of this `Arc` is handed
/// to the CRT as the `user_data` of the completion callback, which takes the waker and wakes it.
waker: Arc<Mutex<Option<Waker>>>,
}

// SAFETY: `aws_future_void` is thread-safe, so the pointer held in `inner` may be moved to and
// used from another thread. The impl is written by hand because a struct holding a `NonNull`
// is not automatically `Send`.
unsafe impl Send for FutureVoid {}

impl Drop for FutureVoid {
    /// Release the underlying CRT future held by this wrapper.
    fn drop(&mut self) {
        let raw = self.inner.as_ptr();
        // SAFETY: `raw` comes from `self.inner`, which contains a valid `aws_future_void`.
        unsafe {
            aws_future_void_release(raw);
        }
    }
}

impl FutureVoid {
/// Return whether the future is done
pub fn is_done(&self) -> bool {
// SAFETY: `self.inner` contains a valid `aws_future_void`.
unsafe { aws_future_void_is_done(self.inner.as_ptr()) }
}

/// Create a [FutureVoid] from a [aws_future_void].
///
/// ## Safety
///
/// `inner` must be a valid [aws_future_void] with no registered callbacks.
pub unsafe fn from_crt(inner: NonNull<aws_future_void>) -> Self {
Self {
inner,
waker: Arc::new(Mutex::new(None)),
}
}

/// Get the result of this future if completed.
fn try_get_result(&self) -> Option<Result<(), crate::common::error::Error>> {
if !self.is_done() {
return None;
}

let result = {
// SAFETY: `self.inner` has completed.
let future_result = unsafe { aws_future_void_get_error(self.inner.as_ptr()) };
let error_result: crate::common::error::Error = future_result.into();
if error_result.is_err() {
Err(error_result)
} else {
Ok(())
}
};
Some(result)
}
}

/// Awaiting a [FutureVoid] yields `Ok(())`, or the error reported by the underlying CRT future.
impl Future for FutureVoid {
type Output = Result<(), crate::common::error::Error>;

/// Return the result if the CRT future has completed. Otherwise store the waker of the current
/// task and, if no waker was stored before, register `future_void_callback` on the CRT future.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Take the waker lock before checking for completion. The callback locks the same mutex
// before taking the waker, so it cannot do so between the check and the update below.
let mut waker = self.waker.lock().unwrap();
if let Some(result) = self.try_get_result() {
// The future has completed. Remove the waker, if any, and return the result.
_ = waker.take();
Poll::Ready(result)
} else {
// The future has not completed yet. Do we need to register the callback?
match *waker {
Some(ref mut waker) => {
// The callback has already been registered, just replace the waker.
// `clone_from` is used rather than assigning a fresh clone of the new waker.
waker.clone_from(cx.waker());
}
None => {
// Store the waker. Drop the lock in case the callback runs synchronously during registration.
*waker = Some(cx.waker().clone());
drop(waker);

// Hand a clone of the shared waker slot to the CRT as a raw pointer.
// `user_data` will be cleaned up in `future_void_callback`.
let user_data = Arc::into_raw(self.waker.clone()) as *mut ::libc::c_void;

// SAFETY: `self.inner.as_ptr()` is a valid `aws_future_void` and this is the only callback we are registering.
unsafe {
aws_future_void_register_callback(self.inner.as_ptr(), Some(future_void_callback), user_data);
}
}
}
Poll::Pending
}
}
}

/// Completion callback registered on the CRT future by [FutureVoid]'s `poll`.
///
/// Safety: Don't call this function directly, only called by the CRT as a callback.
unsafe extern "C" fn future_void_callback(user_data: *mut ::libc::c_void) {
    // Take ownership of the `Arc` in `user_data`; it is dropped when this function returns.
    let shared_waker = Arc::from_raw(user_data as *mut Mutex<Option<Waker>>);
    // Take the waker in its own statement so the lock is released before waking.
    let pending = shared_waker.lock().unwrap().take();
    match pending {
        // Notify the waker that the future has completed.
        Some(task_waker) => task_waker.wake(),
        // Waker removed on `poll` finding that the future had already completed.
        // Nothing to do here.
        None => {}
    }
}

#[cfg(test)]
mod test {
use futures::executor::block_on;
use futures::future::join_all;
use mountpoint_s3_crt_sys::{aws_future_void_new, aws_future_void_set_error, aws_future_void_set_result};
use std::sync::atomic::{AtomicBool, AtomicU64};
use std::time::Duration;

use super::*;
use crate::common::allocator::Allocator;
use crate::io::event_loop::{EventLoopGroup, EventLoopTimer};
use std::sync::atomic::Ordering;
use test_case::test_case;

/// Test that running a small future on an event loop works correctly.
#[test]
Expand Down Expand Up @@ -451,77 +337,4 @@ mod test {
"flag should still be false after cancellation"
);
}

#[test_case(Ok(()))]
#[test_case(Err(42))]
fn test_future_void_already_done(value: Result<(), i32>) {
    let allocator = Allocator::default();

    // Complete the CRT future before wrapping it.
    let raw_future = new_aws_future_void(&allocator);
    set_aws_future_void_value(raw_future, value);

    // SAFETY: `raw_future` is a valid `aws_future_void`.
    let wrapped = unsafe { FutureVoid::from_crt(raw_future) };

    // The wrapper must report completion and expose the value that was set.
    assert!(wrapped.is_done());
    let immediate = wrapped
        .try_get_result()
        .expect("result should be available when the future is done");
    assert_eq!(immediate.map_err(|e| e.raw_error()), value);

    // Awaiting the wrapper must yield the same value.
    let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();
    let handle = el_group.spawn_future(wrapped);
    let awaited = handle.wait().unwrap();
    assert_eq!(awaited.map_err(|e| e.raw_error()), value);
}

#[test_case(Ok(()))]
#[test_case(Err(42))]
fn test_future_void_wake_up(value: Result<(), i32>) {
    let allocator = Allocator::default();

    let raw_future = new_aws_future_void(&allocator);
    // SAFETY: `raw_future` is a valid `aws_future_void`.
    let wrapped = unsafe { FutureVoid::from_crt(raw_future) };

    // Set up a flag that will be set to true after awaiting the wrapped future.
    let completed = Arc::new(AtomicBool::new(false));
    let task_completed = Arc::clone(&completed);

    let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();
    let handle = el_group.spawn_future(async move {
        let outcome = wrapped.await;
        task_completed.store(true, Ordering::SeqCst);
        outcome
    });

    // The CRT future has no value yet, so the task must still be waiting.
    assert!(
        !completed.load(Ordering::SeqCst),
        "the spawned future should not have completed and set the flag"
    );

    // Completing the CRT future must wake the task and let it finish.
    set_aws_future_void_value(raw_future, value);
    let outcome = handle.wait().unwrap();
    assert!(
        completed.load(Ordering::SeqCst),
        "the spawned future should have set the flag"
    );
    assert_eq!(outcome.map_err(|e| e.raw_error()), value);
}

/// Create a new, not yet completed `aws_future_void` using `allocator`.
fn new_aws_future_void(allocator: &Allocator) -> NonNull<aws_future_void> {
    // SAFETY: `allocator` is a valid `aws_allocator`.
    let raw = unsafe { aws_future_void_new(allocator.inner.as_ptr()) };
    // SAFETY: `aws_future_void_new` returns a pointer to a valid `aws_future_void`.
    unsafe { NonNull::new_unchecked(raw) }
}

/// Complete `aws_future` with success for `Ok(())`, or with the given error code for `Err`.
fn set_aws_future_void_value(aws_future: NonNull<aws_future_void>, value: Result<(), i32>) {
    let raw = aws_future.as_ptr();
    if let Err(code) = value {
        // SAFETY: `aws_future` is a valid `aws_future_void`.
        unsafe { aws_future_void_set_error(raw, code) };
    } else {
        // SAFETY: `aws_future` is a valid `aws_future_void`.
        unsafe { aws_future_void_set_result(raw) };
    }
}
}
Loading
Loading