Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

keep CompleteMultipartUpload alive #31

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions codegen/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes, g: &mut Codegen)
}

if is_xml_output(ty) {
g.ln("http::set_xml_body(&mut res, &x)?;");
if op.name == "CompleteMultipartUpload" {
g.ln("http::set_xml_body_no_decl(&mut res, &x)?;");
} else {
g.ln("http::set_xml_body(&mut res, &x)?;");
}
} else if let Some(field) = ty.fields.iter().find(|x| x.position == "payload") {
match field.type_.as_str() {
"Policy" => {
Expand Down Expand Up @@ -603,13 +607,26 @@ fn codegen_op_http_call(op: &Operation, g: &mut Codegen) {

g.ln("let input = Self::deserialize_http(req)?;");
g.ln("let req = super::build_s3_request(input, req);");
g.ln(f!("let result = s3.{method}(req).await;"));
if op.name == "CompleteMultipartUpload" {
g.ln("let s3 = s3.clone();");
g.ln("let fut = async move {");
g.ln(f!("let result = s3.{method}(req).await;"));
g.ln("match result {");
g.ln("Ok(output) => Self::serialize_http(output).unwrap(),");
g.ln("Err(err) => super::serialize_error_no_decl(err).unwrap(),");
g.ln("}");
g.ln("};");
g.ln("let mut res = http::Response::with_status(http::StatusCode::OK);");
g.ln("http::set_keep_alive_xml_body(&mut res, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?;");
} else {
g.ln(f!("let result = s3.{method}(req).await;"));

g.ln("let res = match result {");
g.ln("Ok(output) => Self::serialize_http(output)?,");
g.ln("let res = match result {");
g.ln("Ok(output) => Self::serialize_http(output)?,");

g.ln("Err(err) => super::serialize_error(err)?,");
g.ln("};");
g.ln("Err(err) => super::serialize_error(err)?,");
g.ln("};");
}

g.ln("Ok(res)");

Expand Down
3 changes: 2 additions & 1 deletion crates/s3s-aws/src/connector.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use hyper::body::HttpBody;
use s3s::service::SharedS3Service;
use s3s::{S3Error, S3Result};

Expand Down Expand Up @@ -59,7 +60,7 @@ fn convert_input(mut req: Request<SdkBody>) -> Request<s3s::Body> {

fn convert_output(result: S3Result<Response<s3s::Body>>) -> Result<Response<SdkBody>, ConnectorError> {
match result {
Ok(res) => Ok(res.map(|s3s_body| SdkBody::from(hyper::Body::from(s3s_body)))),
Ok(res) => Ok(res.map(|s3s_body| SdkBody::from_dyn(s3s_body.boxed()))),
Err(e) => Err(on_err(e)),
}
}
Expand Down
4 changes: 4 additions & 0 deletions crates/s3s/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,9 @@ transform-stream = "0.3.0"
urlencoding = "2.1.2"
zeroize = "1.6.0"

sync_wrapper = { version = "0.1.2", default-features = false }
tokio = { version = "1.27.0", features = ["time"] }


[dev-dependencies]
tokio = { version = "1.27.0", features = ["full"] }
47 changes: 30 additions & 17 deletions crates/s3s/src/http/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ pin_project_lite::pin_project! {
DynStream {
#[pin]
inner: DynByteStream
}
},
HttpBody {
#[pin]
inner: http_body::combinators::BoxBody<Bytes, StdError>,
},
}
}

Expand All @@ -63,6 +67,13 @@ impl Body {
kind: Kind::DynStream { inner: stream },
}
}

#[must_use]
pub fn http_body(body: http_body::combinators::BoxBody<Bytes, StdError>) -> Self {
Self {
kind: Kind::HttpBody { inner: body },
}
}
}

impl From<Bytes> for Body {
Expand Down Expand Up @@ -123,11 +134,22 @@ impl http_body::Body for Body {
Stream::poll_next(inner, cx)
//
}
KindProj::HttpBody { inner } => {
http_body::Body::poll_data(inner, cx)
//
}
}
}

fn poll_trailers(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None)) // TODO: How to impl poll_trailers?
let mut this = self.project();
match this.kind.as_mut().project() {
KindProj::Empty => Poll::Ready(Ok(None)),
KindProj::Once { .. } => Poll::Ready(Ok(None)),
KindProj::Hyper { inner } => http_body::Body::poll_trailers(inner, _cx).map_err(|e| Box::new(e) as StdError),
KindProj::DynStream { .. } => Poll::Ready(Ok(None)),
KindProj::HttpBody { inner } => http_body::Body::poll_trailers(inner, _cx),
}
}

fn is_end_stream(&self) -> bool {
Expand All @@ -136,6 +158,7 @@ impl http_body::Body for Body {
Kind::Once { inner } => inner.is_empty(),
Kind::Hyper { inner } => http_body::Body::is_end_stream(inner),
Kind::DynStream { inner } => inner.remaining_length().exact() == Some(0),
Kind::HttpBody { inner } => http_body::Body::is_end_stream(inner),
}
}

Expand All @@ -145,6 +168,7 @@ impl http_body::Body for Body {
Kind::Once { inner } => http_body::SizeHint::with_exact(inner.len() as u64),
Kind::Hyper { inner } => http_body::Body::size_hint(inner),
Kind::DynStream { inner } => inner.remaining_length().into(),
Kind::HttpBody { inner } => http_body::Body::size_hint(inner),
}
}
}
Expand All @@ -164,6 +188,7 @@ impl ByteStream for Body {
Kind::Once { inner } => RemainingLength::new_exact(inner.len()),
Kind::Hyper { inner } => http_body::Body::size_hint(inner).into(),
Kind::DynStream { inner } => inner.remaining_length(),
Kind::HttpBody { inner } => http_body::Body::size_hint(inner).into(),
}
}
}
Expand All @@ -183,6 +208,9 @@ impl fmt::Debug for Body {
d.field("dyn_stream", &"{..}");
d.field("remaining_length", &inner.remaining_length());
}
Kind::HttpBody { inner } => {
d.field("http_body", inner);
}
}
d.finish()
}
Expand All @@ -207,19 +235,4 @@ impl Body {
_ => None,
}
}

fn into_hyper(self) -> hyper::Body {
match self.kind {
Kind::Empty => hyper::Body::empty(),
Kind::Once { inner } => inner.into(),
Kind::Hyper { inner } => inner,
Kind::DynStream { inner } => hyper::Body::wrap_stream(inner),
}
}
}

impl From<Body> for hyper::Body {
fn from(value: Body) -> Self {
value.into_hyper()
}
}
23 changes: 23 additions & 0 deletions crates/s3s/src/http/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::dto::SelectObjectContentEventStream;
use crate::dto::{Metadata, StreamingBlob, Timestamp, TimestampFormat};
use crate::error::{S3Error, S3Result};
use crate::http::{HeaderName, HeaderValue};
use crate::keep_alive_body::KeepAliveBody;
use crate::{utils, xml};

use std::convert::Infallible;
Expand Down Expand Up @@ -104,6 +105,28 @@ pub fn set_xml_body<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result
Ok(())
}

pub fn set_keep_alive_xml_body(
res: &mut Response,
fut: impl std::future::Future<Output = Response> + Send + Sync + 'static,
duration: std::time::Duration,
) -> S3Result {
let mut buf = Vec::with_capacity(40);
let mut ser = xml::Serializer::new(&mut buf);
ser.decl().map_err(S3Error::internal_error)?;

res.body = Body::http_body(http_body::Body::boxed(KeepAliveBody::with_initial_body(fut, buf.into(), duration)));
res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
Ok(())
}

pub fn set_xml_body_no_decl<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result {
let mut buf = Vec::with_capacity(256);
let mut ser = xml::Serializer::new(&mut buf);
val.serialize(&mut ser).map_err(S3Error::internal_error)?;
res.body = Body::from(buf);
Ok(())
}

pub fn set_stream_body(res: &mut Response, stream: StreamingBlob) {
res.body = Body::from(stream);
}
Expand Down
84 changes: 84 additions & 0 deletions crates/s3s/src/keep_alive_body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use bytes::Bytes;
use http_body::Body;
use tokio::time::Interval;

use crate::{http::Response, StdError};

// sends whitespace while the future is pending
pin_project_lite::pin_project! {

pub struct KeepAliveBody<F> {
#[pin]
inner: F,
initial_body: Option<Bytes>,
response: Option<Response>,
interval: Interval,
}
}
impl<F> KeepAliveBody<F> {
pub fn new(inner: F, interval: Duration) -> Self {
Self {
inner,
initial_body: None,
response: None,
interval: tokio::time::interval(interval),
}
}

pub fn with_initial_body(inner: F, initial_body: Bytes, interval: Duration) -> Self {
Self {
inner,
initial_body: Some(initial_body),
response: None,
interval: tokio::time::interval(interval),
}
}
}

impl<F> Body for KeepAliveBody<F>
where
F: Future<Output = Response>,
{
type Data = Bytes;

type Error = StdError;

fn poll_data(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let mut this = self.project();
if let Some(initial_body) = this.initial_body.take() {
cx.waker().wake_by_ref();
return Poll::Ready(Some(Ok(initial_body)));
}
loop {
if let Some(response) = this.response {
return Pin::new(&mut response.body).poll_data(cx);
}
match this.inner.as_mut().poll(cx) {
Poll::Ready(response) => {
*this.response = Some(response);
}
Poll::Pending => match this.interval.poll_tick(cx) {
Poll::Ready(_) => return Poll::Ready(Some(Ok(Bytes::from_static(b" ")))),
Poll::Pending => return Poll::Pending,
},
}
}
}

fn poll_trailers(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
let this = self.project();

if let Some(response) = this.response {
Pin::new(&mut response.body).poll_trailers(cx)
} else {
Poll::Ready(Ok(None))
}
}
}
2 changes: 2 additions & 0 deletions crates/s3s/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub mod path;
pub mod service;
pub mod stream;

pub mod keep_alive_body;

pub use self::error::*;
pub use self::http::Body;
pub use self::request::S3Request;
Expand Down
15 changes: 10 additions & 5 deletions crates/s3s/src/ops/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ impl CompleteMultipartUpload {

pub fn serialize_http(x: CompleteMultipartUploadOutput) -> S3Result<http::Response> {
let mut res = http::Response::with_status(http::StatusCode::OK);
http::set_xml_body(&mut res, &x)?;
http::set_xml_body_no_decl(&mut res, &x)?;
http::add_header(&mut res, X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED, x.bucket_key_enabled)?;
http::add_opt_header(&mut res, X_AMZ_EXPIRATION, x.expiration)?;
http::add_opt_header(&mut res, X_AMZ_REQUEST_CHARGED, x.request_charged)?;
Expand All @@ -416,11 +416,16 @@ impl super::Operation for CompleteMultipartUpload {
async fn call(&self, s3: &Arc<dyn S3>, req: &mut http::Request) -> S3Result<http::Response> {
let input = Self::deserialize_http(req)?;
let req = super::build_s3_request(input, req);
let result = s3.complete_multipart_upload(req).await;
let res = match result {
Ok(output) => Self::serialize_http(output)?,
Err(err) => super::serialize_error(err)?,
let s3 = s3.clone();
let fut = async move {
let result = s3.complete_multipart_upload(req).await;
match result {
Ok(output) => Self::serialize_http(output).unwrap(),
Err(err) => super::serialize_error_no_decl(err).unwrap(),
}
};
let mut res = http::Response::with_status(http::StatusCode::OK);
http::set_keep_alive_xml_body(&mut res, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?;
Ok(res)
}
}
Expand Down
7 changes: 7 additions & 0 deletions crates/s3s/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ fn serialize_error(x: S3Error) -> S3Result<Response> {
Ok(res)
}

fn serialize_error_no_decl(x: S3Error) -> S3Result<Response> {
let status = x.status_code().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut res = Response::with_status(status);
http::set_xml_body_no_decl(&mut res, &x)?;
Ok(res)
}

fn unknown_operation() -> S3Error {
S3Error::with_message(S3ErrorCode::NotImplemented, "Unknown operation")
}
Expand Down