Skip to content

Commit

Permalink
use custom KeepAliveBody for CompleteMultipartUpload
Browse files Browse the repository at this point in the history
  • Loading branch information
lperlaki committed Apr 11, 2023
1 parent 0fc706f commit 563990a
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 80 deletions.
3 changes: 2 additions & 1 deletion crates/s3s-aws/src/connector.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use hyper::body::HttpBody;
use s3s::service::SharedS3Service;
use s3s::{S3Error, S3Result};

Expand Down Expand Up @@ -59,7 +60,7 @@ fn convert_input(mut req: Request<SdkBody>) -> Request<s3s::Body> {

fn convert_output(result: S3Result<Response<s3s::Body>>) -> Result<Response<SdkBody>, ConnectorError> {
match result {
Ok(res) => Ok(res.map(|s3s_body| SdkBody::from(hyper::Body::from(s3s_body)))),
Ok(res) => Ok(res.map(|s3s_body| SdkBody::from_dyn(s3s_body.boxed()))),
Err(e) => Err(on_err(e)),
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/s3s/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ transform-stream = "0.3.0"
urlencoding = "2.1.2"
zeroize = "1.6.0"

sync_wrapper = { version = "0.1.2", default-features = false }
tokio = { version = "1.27.0", features = ["time"] }


Expand Down
45 changes: 28 additions & 17 deletions crates/s3s/src/http/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ pin_project_lite::pin_project! {
DynStream {
#[pin]
inner: DynByteStream
}
},
HttpBody {
#[pin]
inner: http_body::combinators::BoxBody<Bytes, StdError>,
},
}
}

Expand All @@ -63,6 +67,11 @@ impl Body {
kind: Kind::DynStream { inner: stream },
}
}
pub fn http_body(body: http_body::combinators::BoxBody<Bytes, StdError>) -> Self {
Self {
kind: Kind::HttpBody { inner: body },
}
}
}

impl From<Bytes> for Body {
Expand Down Expand Up @@ -123,11 +132,22 @@ impl http_body::Body for Body {
Stream::poll_next(inner, cx)
//
}
KindProj::HttpBody { inner } => {
http_body::Body::poll_data(inner, cx)
//
}
}
}

fn poll_trailers(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None)) // TODO: How to impl poll_trailers?
let mut this = self.project();
match this.kind.as_mut().project() {
KindProj::Empty => Poll::Ready(Ok(None)),
KindProj::Once { .. } => Poll::Ready(Ok(None)),
KindProj::Hyper { inner } => http_body::Body::poll_trailers(inner, _cx).map_err(|e| Box::new(e) as StdError),
KindProj::DynStream { .. } => Poll::Ready(Ok(None)),
KindProj::HttpBody { inner } => http_body::Body::poll_trailers(inner, _cx),
}
}

fn is_end_stream(&self) -> bool {
Expand All @@ -136,6 +156,7 @@ impl http_body::Body for Body {
Kind::Once { inner } => inner.is_empty(),
Kind::Hyper { inner } => http_body::Body::is_end_stream(inner),
Kind::DynStream { inner } => inner.remaining_length().exact() == Some(0),
Kind::HttpBody { inner } => http_body::Body::is_end_stream(inner),
}
}

Expand All @@ -145,6 +166,7 @@ impl http_body::Body for Body {
Kind::Once { inner } => http_body::SizeHint::with_exact(inner.len() as u64),
Kind::Hyper { inner } => http_body::Body::size_hint(inner),
Kind::DynStream { inner } => inner.remaining_length().into(),
Kind::HttpBody { inner } => http_body::Body::size_hint(inner),
}
}
}
Expand All @@ -164,6 +186,7 @@ impl ByteStream for Body {
Kind::Once { inner } => RemainingLength::new_exact(inner.len()),
Kind::Hyper { inner } => http_body::Body::size_hint(inner).into(),
Kind::DynStream { inner } => inner.remaining_length(),
Kind::HttpBody { inner } => http_body::Body::size_hint(inner).into(),
}
}
}
Expand All @@ -183,6 +206,9 @@ impl fmt::Debug for Body {
d.field("dyn_stream", &"{..}");
d.field("remaining_length", &inner.remaining_length());
}
Kind::HttpBody { inner } => {
d.field("http_body", inner);
}
}
d.finish()
}
Expand All @@ -207,19 +233,4 @@ impl Body {
_ => None,
}
}

fn into_hyper(self) -> hyper::Body {
match self.kind {
Kind::Empty => hyper::Body::empty(),
Kind::Once { inner } => inner.into(),
Kind::Hyper { inner } => inner,
Kind::DynStream { inner } => hyper::Body::wrap_stream(inner),
}
}
}

impl From<Body> for hyper::Body {
fn from(value: Body) -> Self {
value.into_hyper()
}
}
27 changes: 11 additions & 16 deletions crates/s3s/src/http/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ use crate::dto::SelectObjectContentEventStream;
use crate::dto::{Metadata, StreamingBlob, Timestamp, TimestampFormat};
use crate::error::{S3Error, S3Result};
use crate::http::{HeaderName, HeaderValue};
use crate::keep_alive_body::KeepAliveBody;
use crate::{utils, xml};

use std::convert::Infallible;
use std::fmt::Write as _;

use futures::Future;

use http_body::Body as HttpBody;
use hyper::header::{IntoHeaderName, InvalidHeaderValue};

pub fn add_header<N, V>(res: &mut Response, name: N, value: V) -> S3Result
Expand Down Expand Up @@ -104,27 +108,18 @@ pub fn set_xml_body<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result
Ok(())
}

pub async fn set_xml_sending_body(res: &mut Response) -> S3Result<hyper::body::Sender> {
res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
let (mut sender, body) = hyper::Body::channel();
res.body = body.into();
pub fn set_keep_alive_xml_body(
res: &mut Response,
fut: impl Future<Output = Response> + Send + Sync + 'static,
duration: std::time::Duration,
) -> S3Result {
let mut buf = Vec::with_capacity(256);
{
let mut ser = xml::Serializer::new(&mut buf);
ser.decl().map_err(S3Error::internal_error)?;
}

sender.send_data(buf.into()).await.map_err(S3Error::internal_error)?;
Ok(sender)
}

pub async fn send_xml_body<T: xml::Serialize>(res: &mut hyper::body::Sender, val: &T) -> S3Result {
let mut buf = Vec::with_capacity(256);
{
let mut ser = xml::Serializer::new(&mut buf);
val.serialize(&mut ser).map_err(S3Error::internal_error)?;
}
res.send_data(buf.into()).await.map_err(S3Error::internal_error)?;
res.body = Body::http_body(KeepAliveBody::with_initial_body(fut, buf.into(), duration).boxed());
res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
Ok(())
}

Expand Down
84 changes: 84 additions & 0 deletions crates/s3s/src/keep_alive_body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use bytes::Bytes;
use http_body::Body;
use tokio::time::Interval;

use crate::{http::Response, StdError};

// sends whitespace while the future is pending
pin_project_lite::pin_project! {

pub struct KeepAliveBody<F> {
#[pin]
inner: F,
initial_body: Option<Bytes>,
response: Option<Response>,
interval: Interval,
}
}
impl<F> KeepAliveBody<F> {
pub fn new(inner: F, interval: Duration) -> Self {
Self {
inner,
initial_body: None,
response: None,
interval: tokio::time::interval(interval),
}
}

pub fn with_initial_body(inner: F, initial_body: Bytes, interval: Duration) -> Self {
Self {
inner,
initial_body: Some(initial_body),
response: None,
interval: tokio::time::interval(interval),
}
}
}

impl<F> Body for KeepAliveBody<F>
where
F: Future<Output = Response>,
{
type Data = Bytes;

type Error = StdError;

fn poll_data(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let mut this = self.project();
if let Some(initial_body) = this.initial_body.take() {
cx.waker().wake_by_ref();
return Poll::Ready(Some(Ok(initial_body)));
}
loop {
if let Some(response) = this.response {
return Pin::new(&mut response.body).poll_data(cx);
}
match this.inner.as_mut().poll(cx) {
Poll::Ready(response) => {
*this.response = Some(response);
}
Poll::Pending => match this.interval.poll_tick(cx) {
Poll::Ready(_) => return Poll::Ready(Some(Ok(Bytes::from_static(b" ")))),
Poll::Pending => return Poll::Pending,
},
}
}
}

fn poll_trailers(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
let this = self.project();

if let Some(response) = this.response {
return Pin::new(&mut response.body).poll_trailers(cx);
} else {
return Poll::Ready(Ok(None));
}
}
}
2 changes: 2 additions & 0 deletions crates/s3s/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub mod path;
pub mod service;
pub mod stream;

pub mod keep_alive_body;

pub use self::error::*;
pub use self::http::Body;
pub use self::request::S3Request;
Expand Down
86 changes: 40 additions & 46 deletions crates/s3s/src/ops/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
#![allow(clippy::declare_interior_mutable_const)]
#![allow(clippy::borrow_interior_mutable_const)]

use bytes::Bytes;
use futures::FutureExt;

use crate::dto::*;
use crate::error::*;
use crate::header::*;
Expand Down Expand Up @@ -411,52 +408,49 @@ impl CompleteMultipartUpload {
pub async fn call_shared(&self, s3: std::sync::Arc<dyn S3>, req: &mut http::Request) -> S3Result<http::Response> {
let input = Self::deserialize_http(req)?;
let req = super::build_s3_request(input, req);
let fut = async move { s3.complete_multipart_upload(req).await }.fuse();
let mut fut = Box::pin(fut);
futures::select! {
result = &mut fut => {
let res = match result {
Ok(output) => Self::serialize_http(output)?,
Err(err) => super::serialize_error(err)?,
};
return Ok(res)
}
_ = tokio::time::sleep(std::time::Duration::from_millis(100)).fuse() => {
()
}
}
let fut = async move {
let res = s3.complete_multipart_upload(req).await;
match res {
Ok(output) => {
let mut res = http::Response::with_status(http::StatusCode::OK);

let mut buf = Vec::with_capacity(256);

let mut ser = crate::xml::Serializer::new(&mut buf);
crate::xml::Serialize::serialize(&output, &mut ser)
.map_err(S3Error::internal_error)
.unwrap();

res.body = crate::Body::from(buf);

http::add_header(&mut res, X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED, output.bucket_key_enabled)
.unwrap();
http::add_opt_header(&mut res, X_AMZ_EXPIRATION, output.expiration).unwrap();
http::add_opt_header(&mut res, X_AMZ_REQUEST_CHARGED, output.request_charged).unwrap();
http::add_opt_header(&mut res, X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, output.ssekms_key_id).unwrap();
http::add_opt_header(&mut res, X_AMZ_SERVER_SIDE_ENCRYPTION, output.server_side_encryption).unwrap();
http::add_opt_header(&mut res, X_AMZ_VERSION_ID, output.version_id).unwrap();

res
}
Err(err) => {
let mut res = http::Response::with_status(http::StatusCode::OK);

let mut res = http::Response::with_status(http::StatusCode::OK);
let mut sender = http::set_xml_sending_body(&mut res).await?;

tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(2));
loop {
futures::select! {
_ = interval.tick().fuse() => {
sender.send_data(Bytes::from_static(b" ")).await.unwrap();
}
res = &mut fut => {
match res {
Ok(output) => {
http::send_xml_body(&mut sender, &output).await.unwrap();
let mut tmp_res = http::Response::with_status(http::StatusCode::OK);
http::add_header(&mut tmp_res, X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED, output.bucket_key_enabled).unwrap();
http::add_opt_header(&mut tmp_res, X_AMZ_EXPIRATION, output.expiration).unwrap();
http::add_opt_header(&mut tmp_res, X_AMZ_REQUEST_CHARGED, output.request_charged).unwrap();
http::add_opt_header(&mut tmp_res, X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, output.ssekms_key_id).unwrap();
http::add_opt_header(&mut tmp_res, X_AMZ_SERVER_SIDE_ENCRYPTION, output.server_side_encryption).unwrap();
http::add_opt_header(&mut tmp_res, X_AMZ_VERSION_ID, output.version_id).unwrap();

sender.send_trailers(tmp_res.headers).await.unwrap();
},
Err(err) => http::send_xml_body(&mut sender, &err).await.unwrap(),
};
return
}
let mut buf = Vec::with_capacity(256);

let mut ser = crate::xml::Serializer::new(&mut buf);
crate::xml::Serialize::serialize(&err, &mut ser)
.map_err(S3Error::internal_error)
.unwrap();

res.body = crate::Body::from(buf);
res
}
}
});
};

let mut res = http::Response::with_status(http::StatusCode::OK);
http::set_keep_alive_xml_body(&mut res, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?;

Ok(res)
}
Expand Down

0 comments on commit 563990a

Please sign in to comment.