From 563990a5fcee1b5f66547a6e51a48adcb1c7cd2c Mon Sep 17 00:00:00 2001 From: Liam Perlaki Date: Tue, 11 Apr 2023 11:13:06 +0200 Subject: [PATCH] use custom KeepAliveBody for CompleteMultipartUpload --- crates/s3s-aws/src/connector.rs | 3 +- crates/s3s/Cargo.toml | 1 + crates/s3s/src/http/body.rs | 45 ++++++++++------ crates/s3s/src/http/ser.rs | 27 ++++------ crates/s3s/src/keep_alive_body.rs | 84 ++++++++++++++++++++++++++++++ crates/s3s/src/lib.rs | 2 + crates/s3s/src/ops/generated.rs | 86 ++++++++++++++----------------- 7 files changed, 168 insertions(+), 80 deletions(-) create mode 100644 crates/s3s/src/keep_alive_body.rs diff --git a/crates/s3s-aws/src/connector.rs b/crates/s3s-aws/src/connector.rs index 5b3de91d..f803e17c 100644 --- a/crates/s3s-aws/src/connector.rs +++ b/crates/s3s-aws/src/connector.rs @@ -1,3 +1,4 @@ +use hyper::body::HttpBody; use s3s::service::SharedS3Service; use s3s::{S3Error, S3Result}; @@ -59,7 +60,7 @@ fn convert_input(mut req: Request) -> Request { fn convert_output(result: S3Result>) -> Result, ConnectorError> { match result { - Ok(res) => Ok(res.map(|s3s_body| SdkBody::from(hyper::Body::from(s3s_body)))), + Ok(res) => Ok(res.map(|s3s_body| SdkBody::from_dyn(s3s_body.boxed()))), Err(e) => Err(on_err(e)), } } diff --git a/crates/s3s/Cargo.toml b/crates/s3s/Cargo.toml index c97aeaa3..734abc42 100644 --- a/crates/s3s/Cargo.toml +++ b/crates/s3s/Cargo.toml @@ -46,6 +46,7 @@ transform-stream = "0.3.0" urlencoding = "2.1.2" zeroize = "1.6.0" +sync_wrapper = { version = "0.1.2", default-features = false } tokio = { version = "1.27.0", features = ["time"] } diff --git a/crates/s3s/src/http/body.rs b/crates/s3s/src/http/body.rs index 68248d73..f4603699 100644 --- a/crates/s3s/src/http/body.rs +++ b/crates/s3s/src/http/body.rs @@ -36,7 +36,11 @@ pin_project_lite::pin_project! { DynStream { #[pin] inner: DynByteStream - } + }, + HttpBody { + #[pin] + inner: http_body::combinators::BoxBody, + }, } } @@ -63,6 +67,11 @@ impl Body { kind: Kind::DynStream { inner: stream }, } } + pub fn http_body(body: http_body::combinators::BoxBody) -> Self { + Self { + kind: Kind::HttpBody { inner: body }, + } + } } impl From for Body { @@ -123,11 +132,22 @@ impl http_body::Body for Body { Stream::poll_next(inner, cx) // } + KindProj::HttpBody { inner } => { + http_body::Body::poll_data(inner, cx) + // + } } } fn poll_trailers(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) // TODO: How to impl poll_trailers? + let mut this = self.project(); + match this.kind.as_mut().project() { + KindProj::Empty => Poll::Ready(Ok(None)), + KindProj::Once { .. } => Poll::Ready(Ok(None)), + KindProj::Hyper { inner } => http_body::Body::poll_trailers(inner, _cx).map_err(|e| Box::new(e) as StdError), + KindProj::DynStream { .. } => Poll::Ready(Ok(None)), + KindProj::HttpBody { inner } => http_body::Body::poll_trailers(inner, _cx), + } } fn is_end_stream(&self) -> bool { @@ -136,6 +156,7 @@ impl http_body::Body for Body { Kind::Once { inner } => inner.is_empty(), Kind::Hyper { inner } => http_body::Body::is_end_stream(inner), Kind::DynStream { inner } => inner.remaining_length().exact() == Some(0), + Kind::HttpBody { inner } => http_body::Body::is_end_stream(inner), } } @@ -145,6 +166,7 @@ impl http_body::Body for Body { Kind::Once { inner } => http_body::SizeHint::with_exact(inner.len() as u64), Kind::Hyper { inner } => http_body::Body::size_hint(inner), Kind::DynStream { inner } => inner.remaining_length().into(), + Kind::HttpBody { inner } => http_body::Body::size_hint(inner), } } } @@ -164,6 +186,7 @@ impl ByteStream for Body { Kind::Once { inner } => RemainingLength::new_exact(inner.len()), Kind::Hyper { inner } => http_body::Body::size_hint(inner).into(), Kind::DynStream { inner } => inner.remaining_length(), + Kind::HttpBody { inner } => http_body::Body::size_hint(inner).into(), } } } @@ -183,6 +206,9 @@ impl fmt::Debug for Body { d.field("dyn_stream", &"{..}"); d.field("remaining_length", &inner.remaining_length()); } + Kind::HttpBody { inner } => { + d.field("http_body", inner); + } } d.finish() } @@ -207,19 +233,4 @@ impl Body { _ => None, } } - - fn into_hyper(self) -> hyper::Body { - match self.kind { - Kind::Empty => hyper::Body::empty(), - Kind::Once { inner } => inner.into(), - Kind::Hyper { inner } => inner, - Kind::DynStream { inner } => hyper::Body::wrap_stream(inner), - } - } -} - -impl From for hyper::Body { - fn from(value: Body) -> Self { - value.into_hyper() - } } diff --git a/crates/s3s/src/http/ser.rs b/crates/s3s/src/http/ser.rs index 39b2d979..1b7103c7 100644 --- a/crates/s3s/src/http/ser.rs +++ b/crates/s3s/src/http/ser.rs @@ -5,11 +5,15 @@ use crate::dto::SelectObjectContentEventStream; use crate::dto::{Metadata, StreamingBlob, Timestamp, TimestampFormat}; use crate::error::{S3Error, S3Result}; use crate::http::{HeaderName, HeaderValue}; +use crate::keep_alive_body::KeepAliveBody; use crate::{utils, xml}; use std::convert::Infallible; use std::fmt::Write as _; +use futures::Future; + +use http_body::Body as HttpBody; use hyper::header::{IntoHeaderName, InvalidHeaderValue}; pub fn add_header(res: &mut Response, name: N, value: V) -> S3Result @@ -104,27 +108,18 @@ pub fn set_xml_body(res: &mut Response, val: &T) -> S3Result Ok(()) } -pub async fn set_xml_sending_body(res: &mut Response) -> S3Result { - res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML); - let (mut sender, body) = hyper::Body::channel(); - res.body = body.into(); +pub fn set_keep_alive_xml_body( + res: &mut Response, + fut: impl Future + Send + Sync + 'static, + duration: std::time::Duration, +) -> S3Result { let mut buf = Vec::with_capacity(256); { let mut ser = xml::Serializer::new(&mut buf); ser.decl().map_err(S3Error::internal_error)?; } - - sender.send_data(buf.into()).await.map_err(S3Error::internal_error)?; - Ok(sender) -} - -pub async fn send_xml_body(res: &mut hyper::body::Sender, val: &T) -> S3Result { - let mut buf = Vec::with_capacity(256); - { - let mut ser = xml::Serializer::new(&mut buf); - val.serialize(&mut ser).map_err(S3Error::internal_error)?; - } - res.send_data(buf.into()).await.map_err(S3Error::internal_error)?; + res.body = Body::http_body(KeepAliveBody::with_initial_body(fut, buf.into(), duration).boxed()); + res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML); Ok(()) } diff --git a/crates/s3s/src/keep_alive_body.rs b/crates/s3s/src/keep_alive_body.rs new file mode 100644 index 00000000..055aa2ac --- /dev/null +++ b/crates/s3s/src/keep_alive_body.rs @@ -0,0 +1,84 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use bytes::Bytes; +use http_body::Body; +use tokio::time::Interval; + +use crate::{http::Response, StdError}; + +// sends whitespace while the future is pending +pin_project_lite::pin_project! { + + pub struct KeepAliveBody { + #[pin] + inner: F, + initial_body: Option, + response: Option, + interval: Interval, + } +} +impl KeepAliveBody { + pub fn new(inner: F, interval: Duration) -> Self { + Self { + inner, + initial_body: None, + response: None, + interval: tokio::time::interval(interval), + } + } + + pub fn with_initial_body(inner: F, initial_body: Bytes, interval: Duration) -> Self { + Self { + inner, + initial_body: Some(initial_body), + response: None, + interval: tokio::time::interval(interval), + } + } +} + +impl Body for KeepAliveBody +where + F: Future, +{ + type Data = Bytes; + + type Error = StdError; + + fn poll_data(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + let mut this = self.project(); + if let Some(initial_body) = this.initial_body.take() { + cx.waker().wake_by_ref(); + return Poll::Ready(Some(Ok(initial_body))); + } + loop { + if let Some(response) = this.response { + return Pin::new(&mut response.body).poll_data(cx); + } + match this.inner.as_mut().poll(cx) { + Poll::Ready(response) => { + *this.response = Some(response); + } + Poll::Pending => match this.interval.poll_tick(cx) { + Poll::Ready(_) => return Poll::Ready(Some(Ok(Bytes::from_static(b" ")))), + Poll::Pending => return Poll::Pending, + }, + } + } + } + + fn poll_trailers(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll, Self::Error>> { + let this = self.project(); + + if let Some(response) = this.response { + return Pin::new(&mut response.body).poll_trailers(cx); + } else { + return Poll::Ready(Ok(None)); + } + } +} diff --git a/crates/s3s/src/lib.rs b/crates/s3s/src/lib.rs index 2407021e..93ee1d6d 100644 --- a/crates/s3s/src/lib.rs +++ b/crates/s3s/src/lib.rs @@ -35,6 +35,8 @@ pub mod path; pub mod service; pub mod stream; +pub mod keep_alive_body; + pub use self::error::*; pub use self::http::Body; pub use self::request::S3Request; diff --git a/crates/s3s/src/ops/generated.rs b/crates/s3s/src/ops/generated.rs index 5f621c5e..df15802a 100644 --- a/crates/s3s/src/ops/generated.rs +++ b/crates/s3s/src/ops/generated.rs @@ -3,9 +3,6 @@ #![allow(clippy::declare_interior_mutable_const)] #![allow(clippy::borrow_interior_mutable_const)] -use bytes::Bytes; -use futures::FutureExt; - use crate::dto::*; use crate::error::*; use crate::header::*; @@ -411,52 +408,49 @@ impl CompleteMultipartUpload { pub async fn call_shared(&self, s3: std::sync::Arc, req: &mut http::Request) -> S3Result { let input = Self::deserialize_http(req)?; let req = super::build_s3_request(input, req); - let fut = async move { s3.complete_multipart_upload(req).await }.fuse(); - let mut fut = Box::pin(fut); - futures::select! { - result = &mut fut => { - let res = match result { - Ok(output) => Self::serialize_http(output)?, - Err(err) => super::serialize_error(err)?, - }; - return Ok(res) - } - _ = tokio::time::sleep(std::time::Duration::from_millis(100)).fuse() => { - () - } - } + let fut = async move { + let res = s3.complete_multipart_upload(req).await; + match res { + Ok(output) => { + let mut res = http::Response::with_status(http::StatusCode::OK); + + let mut buf = Vec::with_capacity(256); + + let mut ser = crate::xml::Serializer::new(&mut buf); + crate::xml::Serialize::serialize(&output, &mut ser) + .map_err(S3Error::internal_error) + .unwrap(); + + res.body = crate::Body::from(buf); + + http::add_header(&mut res, X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED, output.bucket_key_enabled) + .unwrap(); + http::add_opt_header(&mut res, X_AMZ_EXPIRATION, output.expiration).unwrap(); + http::add_opt_header(&mut res, X_AMZ_REQUEST_CHARGED, output.request_charged).unwrap(); + http::add_opt_header(&mut res, X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, output.ssekms_key_id).unwrap(); + http::add_opt_header(&mut res, X_AMZ_SERVER_SIDE_ENCRYPTION, output.server_side_encryption).unwrap(); + http::add_opt_header(&mut res, X_AMZ_VERSION_ID, output.version_id).unwrap(); + + res + } + Err(err) => { + let mut res = http::Response::with_status(http::StatusCode::OK); - let mut res = http::Response::with_status(http::StatusCode::OK); - let mut sender = http::set_xml_sending_body(&mut res).await?; - - tokio::spawn(async move { - let mut interval = tokio::time::interval(std::time::Duration::from_millis(2)); - loop { - futures::select! { - _ = interval.tick().fuse() => { - sender.send_data(Bytes::from_static(b" ")).await.unwrap(); - } - res = &mut fut => { - match res { - Ok(output) => { - http::send_xml_body(&mut sender, &output).await.unwrap(); - let mut tmp_res = http::Response::with_status(http::StatusCode::OK); - http::add_header(&mut tmp_res, X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED, output.bucket_key_enabled).unwrap(); - http::add_opt_header(&mut tmp_res, X_AMZ_EXPIRATION, output.expiration).unwrap(); - http::add_opt_header(&mut tmp_res, X_AMZ_REQUEST_CHARGED, output.request_charged).unwrap(); - http::add_opt_header(&mut tmp_res, X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, output.ssekms_key_id).unwrap(); - http::add_opt_header(&mut tmp_res, X_AMZ_SERVER_SIDE_ENCRYPTION, output.server_side_encryption).unwrap(); - http::add_opt_header(&mut tmp_res, X_AMZ_VERSION_ID, output.version_id).unwrap(); - - sender.send_trailers(tmp_res.headers).await.unwrap(); - }, - Err(err) => http::send_xml_body(&mut sender, &err).await.unwrap(), - }; - return - } + let mut buf = Vec::with_capacity(256); + + let mut ser = crate::xml::Serializer::new(&mut buf); + crate::xml::Serialize::serialize(&err, &mut ser) + .map_err(S3Error::internal_error) + .unwrap(); + + res.body = crate::Body::from(buf); + res } } - }); + }; + + let mut res = http::Response::with_status(http::StatusCode::OK); + http::set_keep_alive_xml_body(&mut res, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?; Ok(res) }