diff --git a/rust-runtime/aws-smithy-types/src/body.rs b/rust-runtime/aws-smithy-types/src/body.rs index c8c0515363f..d8c8c875c48 100644 --- a/rust-runtime/aws-smithy-types/src/body.rs +++ b/rust-runtime/aws-smithy-types/src/body.rs @@ -32,8 +32,6 @@ pin_project! { /// For handling responses, the type of the body will be controlled /// by the HTTP stack. /// - // TODO(naming): Consider renaming to simply `Body`, although I'm concerned about naming headaches - // between hyper::Body and our Body pub struct SdkBody { #[pin] inner: Inner, @@ -65,9 +63,6 @@ enum BoxBody { feature = "rt-tokio" ))] HttpBody04(http_body_0_4::combinators::BoxBody), - - #[cfg(feature = "http-body-1-x")] - HttpBody10(http_body_util::combinators::BoxBody), } pin_project! { @@ -142,24 +137,6 @@ impl SdkBody { poll_fn(|cx| me.as_mut().poll_next(cx)).await } - #[cfg(feature = "http-body-1-x")] - pub(crate) fn poll_next_frame( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Error>>> { - // There has got to be a way to simplify this matching matchy match - match self.poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(maybe_ready) => match maybe_ready { - None => return Poll::Ready(None), - Some(result) => match result { - Err(err) => return Poll::Ready(Some(Err(err))), - Ok(bytes) => return Poll::Ready(Some(Ok(http_body_1_0::Frame::data(bytes)))), - }, - }, - } - } - pub(crate) fn poll_next( self: Pin<&mut Self>, #[allow(unused)] cx: &mut Context<'_>, diff --git a/rust-runtime/aws-smithy-types/src/body/http_body_1_x.rs b/rust-runtime/aws-smithy-types/src/body/http_body_1_x.rs index 507c376164f..42a36581c56 100644 --- a/rust-runtime/aws-smithy-types/src/body/http_body_1_x.rs +++ b/rust-runtime/aws-smithy-types/src/body/http_body_1_x.rs @@ -24,20 +24,23 @@ impl SdkBody { SdkBody::from_body_0_4_internal(Http1toHttp04::new(body.map_err(Into::into))) } - /// Construct an `SdkBody` from a type that implements [`hyper_1_0::body::Body`](hyper_1_0::body::Body) - pub fn from_hyper_1_x(body: T) -> Self - where - T: hyper_1_0::body::Body + Send + Sync + 'static, - E: Into + 'static, - { - SdkBody { - bytes_contents: None, - inner: super::Inner::Dyn { - inner: super::BoxBody::HttpBody10(http_body_util::combinators::BoxBody::new( - body.map_err(Into::into), - )), + pub(crate) fn poll_data_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Error>>> { + // There has got to be a way to simplify this matching matchy match + match ready!(self.as_mut().poll_next(cx)) { + None => match ready!(self.poll_next_trailers(cx)) { + Ok(Some(trailers)) => Poll::Ready(Some(Ok(http_body_1_0::Frame::trailers( + convert_headers_0x_1x(trailers), + )))), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(e))), + }, + Some(result) => match result { + Err(err) => return Poll::Ready(Some(Err(err))), + Ok(bytes) => return Poll::Ready(Some(Ok(http_body_1_0::Frame::data(bytes)))), }, - rebuild: None, } } } @@ -51,7 +54,7 @@ impl http_body_1_0::Body for SdkBody { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { - self.poll_next_frame(cx) + self.poll_data_frame(cx) } fn is_end_stream(&self) -> bool { @@ -127,7 +130,7 @@ where // already read everything let this = self.project(); match this.trailers.take() { - Some(headers) => Poll::Ready(Ok(Some(convert_header_map(headers)))), + Some(headers) => Poll::Ready(Ok(Some(convert_headers_1x_0x(headers)))), None => Poll::Ready(Ok(None)), } } @@ -151,7 +154,7 @@ where } } -fn convert_header_map(input: http_1x::HeaderMap) -> http::HeaderMap { +fn convert_headers_1x_0x(input: http_1x::HeaderMap) -> http::HeaderMap { let mut map = http::HeaderMap::with_capacity(input.capacity()); let mut mem: Option = None; for (k, v) in input.into_iter() { @@ -165,6 +168,20 @@ fn convert_header_map(input: http_1x::HeaderMap) -> http::HeaderMap { map } +fn convert_headers_0x_1x(input: http::HeaderMap) -> http_1x::HeaderMap { + let mut map = http_1x::HeaderMap::with_capacity(input.capacity()); + let mut mem: Option = None; + for (k, v) in input.into_iter() { + let name = k.or_else(|| mem.clone()).unwrap(); + map.append( + http_1x::HeaderName::from_bytes(name.as_str().as_bytes()).expect("already validated"), + http_1x::HeaderValue::from_bytes(v.as_bytes()).expect("already validated"), + ); + mem = Some(name); + } + map +} + #[cfg(test)] mod test { use std::collections::VecDeque; @@ -176,8 +193,9 @@ mod test { use http_1x::header::{CONTENT_LENGTH as CL1, CONTENT_TYPE as CT1}; use http_1x::{HeaderMap, HeaderName, HeaderValue}; use http_body_1_0::Frame; + use http_body_util::BodyExt; - use crate::body::http_body_1_x::convert_header_map; + use crate::body::http_body_1_x::{convert_headers_1x_0x, Http1toHttp04}; use crate::body::{Error, SdkBody}; use crate::byte_stream::ByteStream; @@ -259,10 +277,46 @@ mod test { while let Some(_data) = http_body_0_4::Body::data(&mut body).await {} assert_eq!( http_body_0_4::Body::trailers(&mut body).await.unwrap(), - Some(convert_header_map(trailers())) + Some(convert_headers_1x_0x(trailers())) ); } + #[tokio::test] + async fn test_read_trailers_as_1x() { + let body = TestBody { + chunks: vec![ + Chunk::Data("123"), + Chunk::Data("456"), + Chunk::Data("789"), + Chunk::Trailers(trailers()), + ] + .into(), + }; + let body = SdkBody::from_body_1_x(body); + + let collected = BodyExt::collect(body).await.expect("should succeed"); + assert_eq!(collected.trailers(), Some(&trailers())); + assert_eq!(collected.to_bytes().as_ref(), b"123456789"); + } + + #[tokio::test] + async fn test_trailers_04x_to_1x() { + let body = TestBody { + chunks: vec![ + Chunk::Data("123"), + Chunk::Data("456"), + Chunk::Data("789"), + Chunk::Trailers(trailers()), + ] + .into(), + }; + let body = SdkBody::from_body_0_4(Http1toHttp04::new(body)); + + let collected = BodyExt::collect(body).await.expect("should succeed"); + assert_eq!(collected.trailers(), Some(&trailers())); + assert_eq!(collected.to_bytes().as_ref(), b"123456789"); + } + #[tokio::test] async fn test_errors() { let body = TestBody { @@ -307,7 +361,7 @@ mod test { expect.insert(CL0, http::HeaderValue::from_static("1234")); - assert_eq!(convert_header_map(http1_headermap), expect); + assert_eq!(convert_headers_1x_0x(http1_headermap), expect); } #[test] @@ -321,7 +375,7 @@ mod test { ] .into(), }; - let body = SdkBody::from_hyper_1_x(body); + let body = SdkBody::from_body_1_x(body); assert!(format!("{:?}", body).contains("BoxBody")); } }