diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 79f95dc718..c500943e45 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -22,3 +22,11 @@ message = "The HTTP `Request`, `Response`, `Headers`, and `HeaderValue` types ha references = ["smithy-rs#3138"] meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" } author = "jdisanti" + +[[smithy-rs]] +message = """ +`Message`, `Header`, `HeaderValue`, and `StrBytes` have been moved to `aws-smithy-types` from `aws-smithy-eventstream`. `Message::read_from` and `Message::write_to` remain in `aws-smithy-eventstream` but they are converted to free functions with the names `read_message_from` and `write_message_to` respectively. +""" +references = ["smithy-rs#3139"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all"} +author = "ysaito1001" diff --git a/aws/rust-runtime/aws-runtime/src/auth/sigv4.rs b/aws/rust-runtime/aws-runtime/src/auth/sigv4.rs index 4c6b7d6bec..1e3dc5ad39 100644 --- a/aws/rust-runtime/aws-runtime/src/auth/sigv4.rs +++ b/aws/rust-runtime/aws-runtime/src/auth/sigv4.rs @@ -223,8 +223,9 @@ mod event_stream { use aws_sigv4::event_stream::{sign_empty_message, sign_message}; use aws_sigv4::sign::v4; use aws_smithy_async::time::SharedTimeSource; - use aws_smithy_eventstream::frame::{Message, SignMessage, SignMessageError}; + use aws_smithy_eventstream::frame::{SignMessage, SignMessageError}; use aws_smithy_runtime_api::client::identity::Identity; + use aws_smithy_types::event_stream::Message; use aws_types::region::SigningRegion; use aws_types::SigningName; @@ -293,7 +294,8 @@ mod event_stream { use crate::auth::sigv4::event_stream::SigV4MessageSigner; use aws_credential_types::Credentials; use aws_smithy_async::time::SharedTimeSource; - use aws_smithy_eventstream::frame::{HeaderValue, Message, SignMessage}; + use aws_smithy_eventstream::frame::SignMessage; + use aws_smithy_types::event_stream::{HeaderValue, Message}; use aws_types::region::Region; use aws_types::region::SigningRegion; diff --git a/aws/rust-runtime/aws-sigv4/Cargo.toml b/aws/rust-runtime/aws-sigv4/Cargo.toml index ea98d34f88..07766f202a 100644 --- a/aws/rust-runtime/aws-sigv4/Cargo.toml +++ b/aws/rust-runtime/aws-sigv4/Cargo.toml @@ -20,6 +20,7 @@ aws-credential-types = { path = "../aws-credential-types" } aws-smithy-eventstream = { path = "../../../rust-runtime/aws-smithy-eventstream", optional = true } aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" } aws-smithy-runtime-api = { path = "../../../rust-runtime/aws-smithy-runtime-api", features = ["client"] } +aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types" } bytes = "1" form_urlencoded = { version = "1.0", optional = true } hex = "0.4" diff --git a/aws/rust-runtime/aws-sigv4/external-types.toml b/aws/rust-runtime/aws-sigv4/external-types.toml index f0c4b533a3..37e06db236 100644 --- a/aws/rust-runtime/aws-sigv4/external-types.toml +++ b/aws/rust-runtime/aws-sigv4/external-types.toml @@ -2,6 +2,6 @@ allowed_external_types = [ # TODO(refactorHttp): Remove this and remove the signing helpers "http::request::Request", # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `event-stream` feature - "aws_smithy_eventstream::frame::Message", + "aws_smithy_types::event_stream::Message", "aws_smithy_runtime_api::client::identity::Identity" ] diff --git a/aws/rust-runtime/aws-sigv4/src/event_stream.rs b/aws/rust-runtime/aws-sigv4/src/event_stream.rs index 03cdc61303..03f8b367d2 100644 --- a/aws/rust-runtime/aws-sigv4/src/event_stream.rs +++ b/aws/rust-runtime/aws-sigv4/src/event_stream.rs @@ -9,7 +9,7 @@ //! //! ```rust //! use aws_sigv4::event_stream::sign_message; -//! use aws_smithy_eventstream::frame::{Header, HeaderValue, Message}; +//! use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; //! use std::time::SystemTime; //! use aws_credential_types::Credentials; //! use aws_smithy_runtime_api::client::identity::Identity; @@ -51,7 +51,8 @@ use crate::http_request::SigningError; use crate::sign::v4::{calculate_signature, generate_signing_key, sha256_hex_string}; use crate::SigningOutput; use aws_credential_types::Credentials; -use aws_smithy_eventstream::frame::{write_headers_to, Header, HeaderValue, Message}; +use aws_smithy_eventstream::frame::{write_headers_to, write_message_to}; +use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::Bytes; use std::io::Write; use std::time::SystemTime; @@ -102,7 +103,7 @@ pub fn sign_message<'a>( ) -> Result, SigningError> { let message_payload = { let mut payload = Vec::new(); - message.write_to(&mut payload).unwrap(); + write_message_to(message, &mut payload).unwrap(); payload }; sign_payload(Some(message_payload), last_signature, params) @@ -161,7 +162,8 @@ mod tests { use crate::event_stream::{calculate_string_to_sign, sign_message, SigningParams}; use crate::sign::v4::sha256_hex_string; use aws_credential_types::Credentials; - use aws_smithy_eventstream::frame::{Header, HeaderValue, Message}; + use aws_smithy_eventstream::frame::write_message_to; + use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use std::time::{Duration, UNIX_EPOCH}; #[test] @@ -171,7 +173,7 @@ mod tests { HeaderValue::String("value".into()), )); let mut message_payload = Vec::new(); - message_to_sign.write_to(&mut message_payload).unwrap(); + write_message_to(&message_to_sign, &mut message_payload).unwrap(); let params = SigningParams { identity: &Credentials::for_tests().into(), diff --git a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs index fce762d82b..5178bd2ae9 100644 --- a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs +++ b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs @@ -7,13 +7,14 @@ use async_stream::stream; use aws_sdk_transcribestreaming::config::{Credentials, Region}; use aws_sdk_transcribestreaming::error::SdkError; use aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionOutput; +use aws_sdk_transcribestreaming::primitives::event_stream::{HeaderValue, Message}; use aws_sdk_transcribestreaming::primitives::Blob; use aws_sdk_transcribestreaming::types::error::{AudioStreamError, TranscriptResultStreamError}; use aws_sdk_transcribestreaming::types::{ AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, }; use aws_sdk_transcribestreaming::{Client, Config}; -use aws_smithy_eventstream::frame::{DecodedFrame, HeaderValue, Message, MessageFrameDecoder}; +use aws_smithy_eventstream::frame::{read_message_from, DecodedFrame, MessageFrameDecoder}; use aws_smithy_runtime::client::http::test_util::dvr::{Event, ReplayingClient}; use bytes::BufMut; use futures_core::Stream; @@ -132,7 +133,7 @@ fn decode_frames(mut body: &[u8]) -> Vec<(Message, Option)> { let inner_msg = if msg.payload().is_empty() { None } else { - Some(Message::read_from(msg.payload().as_ref()).unwrap()) + Some(read_message_from(msg.payload().as_ref()).unwrap()) }; result.push((msg, inner_msg)); } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt index ac39f8d3bc..0216b2618c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt @@ -78,7 +78,16 @@ object ClientRustModule { val Meta = RustModule.public("meta") val Input = RustModule.public("input") val Output = RustModule.public("output") - val Primitives = RustModule.public("primitives") + + /** crate::primitives */ + val primitives = Primitives.self + object Primitives { + /** crate::primitives */ + val self = RustModule.public("primitives") + + /** crate::primitives::event_stream */ + val EventStream = RustModule.public("event_stream", parent = self) + } /** crate::types */ val types = Types.self @@ -110,7 +119,8 @@ class ClientModuleDocProvider( ClientRustModule.Meta -> strDoc("Information about this crate.") ClientRustModule.Input -> PANIC("this module shouldn't exist in the new scheme") ClientRustModule.Output -> PANIC("this module shouldn't exist in the new scheme") - ClientRustModule.Primitives -> strDoc("Primitives such as `Blob` or `DateTime` used by other types.") + ClientRustModule.primitives -> strDoc("Primitives such as `Blob` or `DateTime` used by other types.") + ClientRustModule.Primitives.EventStream -> strDoc("Event stream related primitives such as `Message` or `Header`.") ClientRustModule.types -> strDoc("Data structures used by operation inputs/outputs.") ClientRustModule.Types.Error -> strDoc("Error types that $serviceName can respond with.") else -> TODO("Document this module: $module") diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt index aa3e39ad4e..dca3fcc5c7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt @@ -29,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives +import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitivesEventStream import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization val TestUtilFeature = Feature("test-util", false, listOf()) @@ -85,9 +86,12 @@ class RequiredCustomizations : ClientCodegenDecorator { // Re-export resiliency types ResiliencyReExportCustomization(codegenContext).extras(rustCrate) - rustCrate.withModule(ClientRustModule.Primitives) { + rustCrate.withModule(ClientRustModule.primitives) { pubUseSmithyPrimitives(codegenContext, codegenContext.model, rustCrate)(this) } + rustCrate.withModule(ClientRustModule.Primitives.EventStream) { + pubUseSmithyPrimitivesEventStream(codegenContext, codegenContext.model)(this) + } rustCrate.withModule(ClientRustModule.Error) { rustTemplate( """ diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt index 64f6d9446b..db298141f0 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt @@ -174,5 +174,5 @@ class ClientEnumGenerator(codegenContext: ClientCodegenContext, shape: StringSha codegenContext.model, codegenContext.symbolProvider, shape, - InfallibleEnumType(ClientRustModule.Primitives), + InfallibleEnumType(ClientRustModule.primitives), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt index 1d013a645a..d37ffde29a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember +import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember /** Returns true if the model has normal streaming operations (excluding event streams) */ @@ -80,3 +81,22 @@ fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model, rustCra ) } } + +/** Adds re-export statements for event-stream-related Smithy primitives */ +fun pubUseSmithyPrimitivesEventStream(codegenContext: CodegenContext, model: Model): Writable = writable { + val rc = codegenContext.runtimeConfig + if (codegenContext.serviceShape.hasEventStreamOperations(model)) { + rustTemplate( + """ + pub use #{Header}; + pub use #{HeaderValue}; + pub use #{Message}; + pub use #{StrBytes}; + """, + "Header" to RuntimeType.smithyTypes(rc).resolve("event_stream::Header"), + "HeaderValue" to RuntimeType.smithyTypes(rc).resolve("event_stream::HeaderValue"), + "Message" to RuntimeType.smithyTypes(rc).resolve("event_stream::Message"), + "StrBytes" to RuntimeType.smithyTypes(rc).resolve("str_bytes::StrBytes"), + ) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index 457bf5a53b..03b4f14f4f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -64,14 +64,15 @@ class EventStreamUnmarshallerGenerator( symbolProvider.symbolForEventStreamError(unionShape) } private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) + private val smithyTypes = RuntimeType.smithyTypes(runtimeConfig) private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val codegenScope = arrayOf( "Blob" to RuntimeType.blob(runtimeConfig), "expect_fns" to smithyEventStream.resolve("smithy"), "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), - "Message" to smithyEventStream.resolve("frame::Message"), - "Header" to smithyEventStream.resolve("frame::Header"), - "HeaderValue" to smithyEventStream.resolve("frame::HeaderValue"), + "Message" to smithyTypes.resolve("event_stream::Message"), + "Header" to smithyTypes.resolve("event_stream::Header"), + "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), "Error" to smithyEventStream.resolve("error::Error"), "OpError" to errorSymbol, "SmithyError" to RuntimeType.smithyTypes(runtimeConfig).resolve("Error"), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt index 3a0c5c1b30..b9491f29e8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt @@ -44,6 +44,7 @@ class EventStreamErrorMarshallerGenerator( payloadContentType: String, ) : EventStreamMarshallerGenerator(model, target, runtimeConfig, symbolProvider, unionShape, serializerGenerator, payloadContentType) { private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) + private val smithyTypes = RuntimeType.smithyTypes(runtimeConfig) private val operationErrorSymbol = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() @@ -54,9 +55,9 @@ class EventStreamErrorMarshallerGenerator( private val errorsShape = unionShape.expectTrait() private val codegenScope = arrayOf( "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), - "Message" to smithyEventStream.resolve("frame::Message"), - "Header" to smithyEventStream.resolve("frame::Header"), - "HeaderValue" to smithyEventStream.resolve("frame::HeaderValue"), + "Message" to smithyTypes.resolve("event_stream::Message"), + "Header" to smithyTypes.resolve("event_stream::Header"), + "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), "Error" to smithyEventStream.resolve("error::Error"), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt index 201cd82ed5..ac6cf88ccc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt @@ -54,12 +54,13 @@ open class EventStreamMarshallerGenerator( private val payloadContentType: String, ) { private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) + private val smithyTypes = RuntimeType.smithyTypes(runtimeConfig) private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val codegenScope = arrayOf( "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), - "Message" to smithyEventStream.resolve("frame::Message"), - "Header" to smithyEventStream.resolve("frame::Header"), - "HeaderValue" to smithyEventStream.resolve("frame::HeaderValue"), + "Message" to smithyTypes.resolve("event_stream::Message"), + "Header" to smithyTypes.resolve("event_stream::Header"), + "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), "Error" to smithyEventStream.resolve("error::Error"), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt index 896e3c0102..2850f51c72 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt @@ -35,7 +35,8 @@ object EventStreamMarshallTestCases { val typesModule = codegenContext.symbolProvider.moduleForShape(codegenContext.model.lookup("test#TestStruct")) rustTemplate( """ - use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage}; + use aws_smithy_eventstream::frame::MarshallMessage; + use aws_smithy_types::event_stream::{Message, Header, HeaderValue}; use std::collections::HashMap; use aws_smithy_types::{Blob, DateTime}; use ${typesModule.fullyQualifiedPath()}::*; diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt index 215e7ad9de..3adc813546 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt @@ -28,7 +28,8 @@ object EventStreamUnmarshallTestCases { val typesModule = codegenContext.symbolProvider.moduleForShape(codegenContext.model.lookup("test#TestStruct")) rust( """ - use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage}; + use aws_smithy_eventstream::frame::{UnmarshallMessage, UnmarshalledMessage}; + use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use aws_smithy_types::{Blob, DateTime}; use $testStreamError; use ${typesModule.fullyQualifiedPath()}::*; diff --git a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/corrected_prelude_crc.rs b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/corrected_prelude_crc.rs index 56ffed401b..4a29be7ecf 100644 --- a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/corrected_prelude_crc.rs +++ b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/corrected_prelude_crc.rs @@ -5,7 +5,7 @@ #![no_main] -use aws_smithy_eventstream::frame::Message; +use aws_smithy_eventstream::frame::read_message_from; use bytes::BufMut; use crc32fast::Hasher as Crc; use libfuzzer_sys::fuzz_target; @@ -30,7 +30,7 @@ fuzz_target!(|input: Input| { message.put_u32(crc(&message)); let mut data = &mut &message[..]; - let _ = Message::read_from(&mut data); + let _ = read_message_from(&mut data); }); fn crc(input: &[u8]) -> u32 { diff --git a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/mutated_headers.rs b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/mutated_headers.rs index a39f1697ab..c688da2c2d 100644 --- a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/mutated_headers.rs +++ b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/mutated_headers.rs @@ -5,7 +5,8 @@ #![no_main] -use aws_smithy_eventstream::frame::{Header, HeaderValue, Message}; +use aws_smithy_eventstream::frame::{read_message_from, write_message_to}; +use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use aws_smithy_types::DateTime; use bytes::{Buf, BufMut}; use crc32fast::Hasher as Crc; @@ -18,7 +19,7 @@ use libfuzzer_sys::{fuzz_mutator, fuzz_target}; // so that the fuzzer can actually explore the header parsing logic. fn mutate(data: &mut [u8], size: usize, max_size: usize) -> usize { let input = &mut &data[..size]; - let message = if let Ok(message) = Message::read_from(input) { + let message = if let Ok(message) = read_message_from(input) { message } else { Message::new(&b"some payload"[..]) @@ -44,7 +45,7 @@ fn mutate(data: &mut [u8], size: usize, max_size: usize) -> usize { }; let mut bytes = Vec::new(); - message.write_to(&mut bytes).unwrap(); + write_message_to(&message, &mut bytes).unwrap(); let headers_len = (&bytes[4..8]).get_u32(); let non_header_len = bytes.len() - headers_len as usize; @@ -72,7 +73,7 @@ fuzz_mutator!( fuzz_target!(|data: &[u8]| { let mut message = data; - let _ = Message::read_from(&mut message); + let _ = read_message_from(&mut message); }); fn crc(input: &[u8]) -> u32 { diff --git a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/prelude.rs b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/prelude.rs index fde6f224d7..9c6810f202 100644 --- a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/prelude.rs +++ b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/prelude.rs @@ -5,7 +5,8 @@ #![no_main] -use aws_smithy_eventstream::frame::{Header, HeaderValue, Message}; +use aws_smithy_eventstream::frame::{read_message_from, write_message_to}; +use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::{Buf, BufMut}; use crc32fast::Hasher as Crc; use libfuzzer_sys::fuzz_target; @@ -22,7 +23,7 @@ fuzz_target!(|input: Input| { .add_header(Header::new("str", HeaderValue::String("some str".into()))); let mut bytes = Vec::new(); - message.write_to(&mut bytes).unwrap(); + write_message_to(&message, &mut bytes).unwrap(); let headers_len = (&bytes[4..8]).get_u32(); let headers = &bytes[12..(12 + headers_len as usize)]; @@ -35,7 +36,7 @@ fuzz_target!(|input: Input| { mutated.put_slice(message.payload()); mutated.put_u32(crc(&mutated)); - let _ = Message::read_from(&mut &mutated[..]); + let _ = read_message_from(&mut &mutated[..]); }); fn crc(input: &[u8]) -> u32 { diff --git a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/raw_bytes.rs b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/raw_bytes.rs index 3db2ca7d37..83979e9d10 100644 --- a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/raw_bytes.rs +++ b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/raw_bytes.rs @@ -5,10 +5,10 @@ #![no_main] -use aws_smithy_eventstream::frame::Message; +use aws_smithy_eventstream::frame::read_message_from; use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { let mut message = data; - let _ = Message::read_from(&mut message); + let _ = read_message_from(&mut message); }); diff --git a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/round_trip.rs b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/round_trip.rs index 2294985a0a..9b754addbb 100644 --- a/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/round_trip.rs +++ b/rust-runtime/aws-smithy-eventstream/fuzz/fuzz_targets/round_trip.rs @@ -4,25 +4,23 @@ */ #![no_main] -use aws_smithy_eventstream::error::Error; -use aws_smithy_eventstream::frame::Message; +use aws_smithy_eventstream::arbitrary::ArbMessage; +use aws_smithy_eventstream::frame::{read_message_from, write_message_to}; use libfuzzer_sys::fuzz_target; -fuzz_target!(|message: Message| { +fuzz_target!(|message: ArbMessage| { + let message = message.into(); let mut buffer = Vec::new(); - match message.write_to(&mut buffer) { - Err( - Error::HeadersTooLong - | Error::PayloadTooLong - | Error::MessageTooLong - | Error::InvalidHeaderNameLength - | Error::TimestampValueTooLarge(_), - ) => {} - Err(err) => panic!("unexpected error on write: {}", err), + match write_message_to(&message, &mut buffer) { Ok(_) => { let mut data = &buffer[..]; - let parsed = Message::read_from(&mut data).unwrap(); + let parsed = read_message_from(&mut data).unwrap(); assert_eq!(message, parsed); } + Err(err) => { + if !err.is_invalid_message() { + panic!("unexpected error on write: {}", err), + } + } } }); diff --git a/rust-runtime/aws-smithy-eventstream/src/arbitrary.rs b/rust-runtime/aws-smithy-eventstream/src/arbitrary.rs new file mode 100644 index 0000000000..604c032c25 --- /dev/null +++ b/rust-runtime/aws-smithy-eventstream/src/arbitrary.rs @@ -0,0 +1,97 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Defines new-types wrapping inner types from `aws_smithy_types` to enable the `Arbitrary` trait +//! for fuzz testing. + +use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; +use aws_smithy_types::str_bytes::StrBytes; +use aws_smithy_types::DateTime; +use bytes::Bytes; + +#[derive(Clone, Debug, PartialEq)] +pub struct ArbHeaderValue(HeaderValue); + +impl<'a> arbitrary::Arbitrary<'a> for ArbHeaderValue { + fn arbitrary(unstruct: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let value_type: u8 = unstruct.int_in_range(0..=9)?; + let header_value = match value_type { + crate::frame::TYPE_TRUE => HeaderValue::Bool(true), + crate::frame::TYPE_FALSE => HeaderValue::Bool(false), + crate::frame::TYPE_BYTE => HeaderValue::Byte(i8::arbitrary(unstruct)?), + crate::frame::TYPE_INT16 => HeaderValue::Int16(i16::arbitrary(unstruct)?), + crate::frame::TYPE_INT32 => HeaderValue::Int32(i32::arbitrary(unstruct)?), + crate::frame::TYPE_INT64 => HeaderValue::Int64(i64::arbitrary(unstruct)?), + crate::frame::TYPE_BYTE_ARRAY => { + HeaderValue::ByteArray(Bytes::from(Vec::::arbitrary(unstruct)?)) + } + crate::frame::TYPE_STRING => { + HeaderValue::String(StrBytes::from(String::arbitrary(unstruct)?)) + } + crate::frame::TYPE_TIMESTAMP => { + HeaderValue::Timestamp(DateTime::from_secs(i64::arbitrary(unstruct)?)) + } + crate::frame::TYPE_UUID => HeaderValue::Uuid(u128::arbitrary(unstruct)?), + _ => unreachable!(), + }; + Ok(ArbHeaderValue(header_value)) + } +} + +impl From for HeaderValue { + fn from(header_value: ArbHeaderValue) -> Self { + header_value.0 + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ArbStrBytes(StrBytes); + +#[cfg(feature = "derive-arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for ArbStrBytes { + fn arbitrary(unstruct: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + Ok(ArbStrBytes(String::arbitrary(unstruct)?.into())) + } +} + +impl From for StrBytes { + fn from(str_bytes: ArbStrBytes) -> Self { + str_bytes.0 + } +} + +#[derive(Clone, Debug, PartialEq, derive_arbitrary::Arbitrary)] +pub struct ArbHeader { + name: ArbStrBytes, + value: ArbHeaderValue, +} + +impl From for Header { + fn from(header: ArbHeader) -> Self { + Self::new(header.name, header.value) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ArbMessage(Message); + +impl<'a> arbitrary::Arbitrary<'a> for ArbMessage { + fn arbitrary(unstruct: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let headers: Vec = unstruct + .arbitrary_iter()? + .collect::>()?; + let message = Message::new_from_parts( + headers.into_iter().map(Into::into).collect(), + Bytes::from(Vec::::arbitrary(unstruct)?), + ); + Ok(ArbMessage(message)) + } +} + +impl From for Message { + fn from(message: ArbMessage) -> Self { + message.0 + } +} diff --git a/rust-runtime/aws-smithy-eventstream/src/error.rs b/rust-runtime/aws-smithy-eventstream/src/error.rs index bda5ff900d..99cb4ba759 100644 --- a/rust-runtime/aws-smithy-eventstream/src/error.rs +++ b/rust-runtime/aws-smithy-eventstream/src/error.rs @@ -51,6 +51,20 @@ impl Error { kind: ErrorKind::Unmarshalling(message.into()), } } + + /// Returns true if the error is one generated during serialization + pub fn is_invalid_message(&self) -> bool { + use ErrorKind::*; + matches!( + self.kind, + HeadersTooLong + | PayloadTooLong + | MessageTooLong + | InvalidHeaderNameLength + | TimestampValueTooLarge(_) + | Marshalling(_) + ) + } } impl From for Error { diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs index 202e410827..eb2bb1707c 100644 --- a/rust-runtime/aws-smithy-eventstream/src/frame.rs +++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs @@ -8,8 +8,11 @@ use crate::buf::count::CountBuf; use crate::buf::crc::{CrcBuf, CrcBufMut}; use crate::error::{Error, ErrorKind}; -use crate::str_bytes::StrBytes; -use bytes::{Buf, BufMut, Bytes}; +use aws_smithy_types::config_bag::{Storable, StoreReplace}; +use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; +use aws_smithy_types::str_bytes::StrBytes; +use aws_smithy_types::DateTime; +use bytes::{Buf, BufMut}; use std::convert::{TryFrom, TryInto}; use std::error::Error as StdError; use std::fmt; @@ -22,6 +25,17 @@ const MESSAGE_CRC_LENGTH_BYTES: u32 = size_of::() as u32; const MAX_HEADER_NAME_LEN: usize = 255; const MIN_HEADER_LEN: usize = 2; +pub(crate) const TYPE_TRUE: u8 = 0; +pub(crate) const TYPE_FALSE: u8 = 1; +pub(crate) const TYPE_BYTE: u8 = 2; +pub(crate) const TYPE_INT16: u8 = 3; +pub(crate) const TYPE_INT32: u8 = 4; +pub(crate) const TYPE_INT64: u8 = 5; +pub(crate) const TYPE_BYTE_ARRAY: u8 = 6; +pub(crate) const TYPE_STRING: u8 = 7; +pub(crate) const TYPE_TIMESTAMP: u8 = 8; +pub(crate) const TYPE_UUID: u8 = 9; + pub type SignMessageError = Box; /// Signs an Event Stream message. @@ -168,451 +182,242 @@ pub trait UnmarshallMessage: fmt::Debug { ) -> Result, Error>; } -mod value { - use crate::error::{Error, ErrorKind}; - use crate::frame::checked; - use crate::str_bytes::StrBytes; - use aws_smithy_types::DateTime; - use bytes::{Buf, BufMut, Bytes}; - use std::convert::TryInto; - use std::mem::size_of; - - const TYPE_TRUE: u8 = 0; - const TYPE_FALSE: u8 = 1; - const TYPE_BYTE: u8 = 2; - const TYPE_INT16: u8 = 3; - const TYPE_INT32: u8 = 4; - const TYPE_INT64: u8 = 5; - const TYPE_BYTE_ARRAY: u8 = 6; - const TYPE_STRING: u8 = 7; - const TYPE_TIMESTAMP: u8 = 8; - const TYPE_UUID: u8 = 9; - - /// Event Stream frame header value. - #[non_exhaustive] - #[derive(Clone, Debug, PartialEq)] - pub enum HeaderValue { - Bool(bool), - Byte(i8), - Int16(i16), - Int32(i32), - Int64(i64), - ByteArray(Bytes), - String(StrBytes), - Timestamp(DateTime), - Uuid(u128), - } - - impl HeaderValue { - pub fn as_bool(&self) -> Result { - match self { - HeaderValue::Bool(value) => Ok(*value), - _ => Err(self), - } - } - - pub fn as_byte(&self) -> Result { - match self { - HeaderValue::Byte(value) => Ok(*value), - _ => Err(self), - } - } - - pub fn as_int16(&self) -> Result { - match self { - HeaderValue::Int16(value) => Ok(*value), - _ => Err(self), - } - } - - pub fn as_int32(&self) -> Result { - match self { - HeaderValue::Int32(value) => Ok(*value), - _ => Err(self), - } - } - - pub fn as_int64(&self) -> Result { - match self { - HeaderValue::Int64(value) => Ok(*value), - _ => Err(self), - } - } - - pub fn as_byte_array(&self) -> Result<&Bytes, &Self> { - match self { - HeaderValue::ByteArray(value) => Ok(value), - _ => Err(self), - } - } - - pub fn as_string(&self) -> Result<&StrBytes, &Self> { - match self { - HeaderValue::String(value) => Ok(value), - _ => Err(self), - } - } - - pub fn as_timestamp(&self) -> Result { - match self { - HeaderValue::Timestamp(value) => Ok(*value), - _ => Err(self), - } +macro_rules! read_value { + ($buf:ident, $typ:ident, $size_typ:ident, $read_fn:ident) => { + if $buf.remaining() >= size_of::<$size_typ>() { + Ok(HeaderValue::$typ($buf.$read_fn())) + } else { + Err(ErrorKind::InvalidHeaderValue.into()) } + }; +} - pub fn as_uuid(&self) -> Result { - match self { - HeaderValue::Uuid(value) => Ok(*value), - _ => Err(self), +fn read_header_value_from(mut buffer: B) -> Result { + let value_type = buffer.get_u8(); + match value_type { + TYPE_TRUE => Ok(HeaderValue::Bool(true)), + TYPE_FALSE => Ok(HeaderValue::Bool(false)), + TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8), + TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16), + TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32), + TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64), + TYPE_BYTE_ARRAY | TYPE_STRING => { + if buffer.remaining() > size_of::() { + let len = buffer.get_u16() as usize; + if buffer.remaining() < len { + return Err(ErrorKind::InvalidHeaderValue.into()); + } + let bytes = buffer.copy_to_bytes(len); + if value_type == TYPE_STRING { + Ok(HeaderValue::String( + bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?, + )) + } else { + Ok(HeaderValue::ByteArray(bytes)) + } + } else { + Err(ErrorKind::InvalidHeaderValue.into()) } } - } - - macro_rules! read_value { - ($buf:ident, $typ:ident, $size_typ:ident, $read_fn:ident) => { - if $buf.remaining() >= size_of::<$size_typ>() { - Ok(HeaderValue::$typ($buf.$read_fn())) + TYPE_TIMESTAMP => { + if buffer.remaining() >= size_of::() { + let epoch_millis = buffer.get_i64(); + Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis))) } else { Err(ErrorKind::InvalidHeaderValue.into()) } - }; + } + TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128), + _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()), } +} - impl HeaderValue { - pub(super) fn read_from(mut buffer: B) -> Result { - let value_type = buffer.get_u8(); - match value_type { - TYPE_TRUE => Ok(HeaderValue::Bool(true)), - TYPE_FALSE => Ok(HeaderValue::Bool(false)), - TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8), - TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16), - TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32), - TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64), - TYPE_BYTE_ARRAY | TYPE_STRING => { - if buffer.remaining() > size_of::() { - let len = buffer.get_u16() as usize; - if buffer.remaining() < len { - return Err(ErrorKind::InvalidHeaderValue.into()); - } - let bytes = buffer.copy_to_bytes(len); - if value_type == TYPE_STRING { - Ok(HeaderValue::String( - bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?, - )) - } else { - Ok(HeaderValue::ByteArray(bytes)) - } - } else { - Err(ErrorKind::InvalidHeaderValue.into()) - } - } - TYPE_TIMESTAMP => { - if buffer.remaining() >= size_of::() { - let epoch_millis = buffer.get_i64(); - Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis))) - } else { - Err(ErrorKind::InvalidHeaderValue.into()) - } - } - TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128), - _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()), - } +fn write_header_value_to(value: &HeaderValue, mut buffer: B) -> Result<(), Error> { + use HeaderValue::*; + match value { + Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }), + Byte(val) => { + buffer.put_u8(TYPE_BYTE); + buffer.put_i8(*val); + } + Int16(val) => { + buffer.put_u8(TYPE_INT16); + buffer.put_i16(*val); + } + Int32(val) => { + buffer.put_u8(TYPE_INT32); + buffer.put_i32(*val); + } + Int64(val) => { + buffer.put_u8(TYPE_INT64); + buffer.put_i64(*val); + } + ByteArray(val) => { + buffer.put_u8(TYPE_BYTE_ARRAY); + buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?); + buffer.put_slice(&val[..]); + } + String(val) => { + buffer.put_u8(TYPE_STRING); + buffer.put_u16(checked( + val.as_bytes().len(), + ErrorKind::HeaderValueTooLong.into(), + )?); + buffer.put_slice(&val.as_bytes()[..]); + } + Timestamp(time) => { + buffer.put_u8(TYPE_TIMESTAMP); + buffer.put_i64( + time.to_millis() + .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?, + ); } - - pub(super) fn write_to(&self, mut buffer: B) -> Result<(), Error> { - use HeaderValue::*; - match self { - Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }), - Byte(val) => { - buffer.put_u8(TYPE_BYTE); - buffer.put_i8(*val); - } - Int16(val) => { - buffer.put_u8(TYPE_INT16); - buffer.put_i16(*val); - } - Int32(val) => { - buffer.put_u8(TYPE_INT32); - buffer.put_i32(*val); - } - Int64(val) => { - buffer.put_u8(TYPE_INT64); - buffer.put_i64(*val); - } - ByteArray(val) => { - buffer.put_u8(TYPE_BYTE_ARRAY); - buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?); - buffer.put_slice(&val[..]); - } - String(val) => { - buffer.put_u8(TYPE_STRING); - buffer.put_u16(checked( - val.as_bytes().len(), - ErrorKind::HeaderValueTooLong.into(), - )?); - buffer.put_slice(&val.as_bytes()[..]); - } - Timestamp(time) => { - buffer.put_u8(TYPE_TIMESTAMP); - buffer.put_i64( - time.to_millis() - .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?, - ); - } - Uuid(val) => { - buffer.put_u8(TYPE_UUID); - buffer.put_u128(*val); - } - } - Ok(()) + Uuid(val) => { + buffer.put_u8(TYPE_UUID); + buffer.put_u128(*val); } - } - - #[cfg(feature = "derive-arbitrary")] - impl<'a> arbitrary::Arbitrary<'a> for HeaderValue { - fn arbitrary(unstruct: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - let value_type: u8 = unstruct.int_in_range(0..=9)?; - Ok(match value_type { - TYPE_TRUE => HeaderValue::Bool(true), - TYPE_FALSE => HeaderValue::Bool(false), - TYPE_BYTE => HeaderValue::Byte(i8::arbitrary(unstruct)?), - TYPE_INT16 => HeaderValue::Int16(i16::arbitrary(unstruct)?), - TYPE_INT32 => HeaderValue::Int32(i32::arbitrary(unstruct)?), - TYPE_INT64 => HeaderValue::Int64(i64::arbitrary(unstruct)?), - TYPE_BYTE_ARRAY => { - HeaderValue::ByteArray(Bytes::from(Vec::::arbitrary(unstruct)?)) - } - TYPE_STRING => HeaderValue::String(StrBytes::from(String::arbitrary(unstruct)?)), - TYPE_TIMESTAMP => { - HeaderValue::Timestamp(DateTime::from_secs(i64::arbitrary(unstruct)?)) - } - TYPE_UUID => HeaderValue::Uuid(u128::arbitrary(unstruct)?), - _ => unreachable!(), - }) + _ => { + panic!("matched on unexpected variant in `aws_smithy_types::event_stream::HeaderValue`") } } + Ok(()) } -use aws_smithy_types::config_bag::{Storable, StoreReplace}; -pub use value::HeaderValue; - -/// Event Stream header. -#[non_exhaustive] -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "derive-arbitrary", derive(derive_arbitrary::Arbitrary))] -pub struct Header { - name: StrBytes, - value: HeaderValue, -} - -impl Header { - /// Creates a new header with the given `name` and `value`. - pub fn new(name: impl Into, value: HeaderValue) -> Header { - Header { - name: name.into(), - value, - } - } - - /// Returns the header name. - pub fn name(&self) -> &StrBytes { - &self.name +/// Reads a header from the given `buffer`. +fn read_header_from(mut buffer: B) -> Result<(Header, usize), Error> { + if buffer.remaining() < MIN_HEADER_LEN { + return Err(ErrorKind::InvalidHeadersLength.into()); } - /// Returns the header value. - pub fn value(&self) -> &HeaderValue { - &self.value + let mut counting_buf = CountBuf::new(&mut buffer); + let name_len = counting_buf.get_u8(); + if name_len as usize >= counting_buf.remaining() { + return Err(ErrorKind::InvalidHeaderNameLength.into()); } - /// Reads a header from the given `buffer`. - fn read_from(mut buffer: B) -> Result<(Header, usize), Error> { - if buffer.remaining() < MIN_HEADER_LEN { - return Err(ErrorKind::InvalidHeadersLength.into()); - } - - let mut counting_buf = CountBuf::new(&mut buffer); - let name_len = counting_buf.get_u8(); - if name_len as usize >= counting_buf.remaining() { - return Err(ErrorKind::InvalidHeaderNameLength.into()); - } + let name: StrBytes = counting_buf + .copy_to_bytes(name_len as usize) + .try_into() + .map_err(|_| ErrorKind::InvalidUtf8String)?; + let value = read_header_value_from(&mut counting_buf)?; + Ok((Header::new(name, value), counting_buf.into_count())) +} - let name: StrBytes = counting_buf - .copy_to_bytes(name_len as usize) - .try_into() - .map_err(|_| ErrorKind::InvalidUtf8String)?; - let value = HeaderValue::read_from(&mut counting_buf)?; - Ok((Header::new(name, value), counting_buf.into_count())) +/// Writes the header to the given `buffer`. +fn write_header_to(header: &Header, mut buffer: B) -> Result<(), Error> { + if header.name().as_bytes().len() > MAX_HEADER_NAME_LEN { + return Err(ErrorKind::InvalidHeaderNameLength.into()); } - /// Writes the header to the given `buffer`. - fn write_to(&self, mut buffer: B) -> Result<(), Error> { - if self.name.as_bytes().len() > MAX_HEADER_NAME_LEN { - return Err(ErrorKind::InvalidHeaderNameLength.into()); - } - - buffer.put_u8(u8::try_from(self.name.as_bytes().len()).expect("bounds check above")); - buffer.put_slice(&self.name.as_bytes()[..]); - self.value.write_to(buffer) - } + buffer.put_u8(u8::try_from(header.name().as_bytes().len()).expect("bounds check above")); + buffer.put_slice(&header.name().as_bytes()[..]); + write_header_value_to(header.value(), buffer) } /// Writes the given `headers` to a `buffer`. pub fn write_headers_to(headers: &[Header], mut buffer: B) -> Result<(), Error> { for header in headers { - header.write_to(&mut buffer)?; + write_header_to(header, &mut buffer)?; } Ok(()) } -/// Event Stream message. -#[non_exhaustive] -#[derive(Clone, Debug, PartialEq)] -pub struct Message { - headers: Vec
, - payload: Bytes, -} - -impl Message { - /// Creates a new message with the given `payload`. Headers can be added later. - pub fn new(payload: impl Into) -> Message { - Message { - headers: Vec::new(), - payload: payload.into(), - } - } +// Returns (total_len, header_len) +fn read_prelude_from(mut buffer: B) -> Result<(u32, u32), Error> { + let mut crc_buffer = CrcBuf::new(&mut buffer); - /// Creates a message with the given `headers` and `payload`. - pub fn new_from_parts(headers: Vec
, payload: impl Into) -> Self { - Self { - headers, - payload: payload.into(), - } + // If the buffer doesn't have the entire, then error + let total_len = crc_buffer.get_u32(); + if crc_buffer.remaining() + size_of::() < total_len as usize { + return Err(ErrorKind::InvalidMessageLength.into()); } - /// Adds a header to the message. - pub fn add_header(mut self, header: Header) -> Self { - self.headers.push(header); - self + // Validate the prelude + let header_len = crc_buffer.get_u32(); + let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32()); + if expected_crc != prelude_crc { + return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into()); } - - /// Returns all headers. - pub fn headers(&self) -> &[Header] { - &self.headers + // The header length can be 0 or >= 2, but must fit within the frame size + if header_len == 1 || header_len > max_header_len(total_len)? { + return Err(ErrorKind::InvalidHeadersLength.into()); } + Ok((total_len, header_len)) +} - /// Returns the payload bytes. - pub fn payload(&self) -> &Bytes { - &self.payload +/// Reads a message from the given `buffer`. For streaming use cases, use +/// the [`MessageFrameDecoder`] instead of this. +pub fn read_message_from(mut buffer: B) -> Result { + if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE { + return Err(ErrorKind::InvalidMessageLength.into()); } - // Returns (total_len, header_len) - fn read_prelude_from(mut buffer: B) -> Result<(u32, u32), Error> { - let mut crc_buffer = CrcBuf::new(&mut buffer); + // Calculate a CRC as we go and read the prelude + let mut crc_buffer = CrcBuf::new(&mut buffer); + let (total_len, header_len) = read_prelude_from(&mut crc_buffer)?; - // If the buffer doesn't have the entire, then error - let total_len = crc_buffer.get_u32(); - if crc_buffer.remaining() + size_of::() < total_len as usize { - return Err(ErrorKind::InvalidMessageLength.into()); - } - - // Validate the prelude - let header_len = crc_buffer.get_u32(); - let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32()); - if expected_crc != prelude_crc { - return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into()); - } - // The header length can be 0 or >= 2, but must fit within the frame size - if header_len == 1 || header_len > max_header_len(total_len)? { - return Err(ErrorKind::InvalidHeadersLength.into()); - } - Ok((total_len, header_len)) + // Verify we have the full frame before continuing + let remaining_len = total_len + .checked_sub(PRELUDE_LENGTH_BYTES) + .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; + if crc_buffer.remaining() < remaining_len as usize { + return Err(ErrorKind::InvalidMessageLength.into()); } - /// Reads a message from the given `buffer`. For streaming use cases, use - /// the [`MessageFrameDecoder`] instead of this. - pub fn read_from(mut buffer: B) -> Result { - if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE { - return Err(ErrorKind::InvalidMessageLength.into()); - } - - // Calculate a CRC as we go and read the prelude - let mut crc_buffer = CrcBuf::new(&mut buffer); - let (total_len, header_len) = Self::read_prelude_from(&mut crc_buffer)?; - - // Verify we have the full frame before continuing - let remaining_len = total_len - .checked_sub(PRELUDE_LENGTH_BYTES) - .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; - if crc_buffer.remaining() < remaining_len as usize { - return Err(ErrorKind::InvalidMessageLength.into()); - } - - // Read headers - let mut header_bytes_read = 0; - let mut headers = Vec::new(); - while header_bytes_read < header_len as usize { - let (header, bytes_read) = Header::read_from(&mut crc_buffer)?; - header_bytes_read += bytes_read; - if header_bytes_read > header_len as usize { - return Err(ErrorKind::InvalidHeaderValue.into()); - } - headers.push(header); - } - - // Read payload - let payload_len = payload_len(total_len, header_len)?; - let payload = crc_buffer.copy_to_bytes(payload_len as usize); - - let expected_crc = crc_buffer.into_crc(); - let message_crc = buffer.get_u32(); - if expected_crc != message_crc { - return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into()); + // Read headers + let mut header_bytes_read = 0; + let mut headers = Vec::new(); + while header_bytes_read < header_len as usize { + let (header, bytes_read) = read_header_from(&mut crc_buffer)?; + header_bytes_read += bytes_read; + if header_bytes_read > header_len as usize { + return Err(ErrorKind::InvalidHeaderValue.into()); } - - Ok(Message { headers, payload }) + headers.push(header); } - /// Writes the message to the given `buffer`. - pub fn write_to(&self, buffer: &mut dyn BufMut) -> Result<(), Error> { - let mut headers = Vec::new(); - for header in &self.headers { - header.write_to(&mut headers)?; - } + // Read payload + let payload_len = payload_len(total_len, header_len)?; + let payload = crc_buffer.copy_to_bytes(payload_len as usize); - let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?; - let payload_len = checked(self.payload.len(), ErrorKind::PayloadTooLong.into())?; - let message_len = [ - PRELUDE_LENGTH_BYTES, - headers_len, - payload_len, - MESSAGE_CRC_LENGTH_BYTES, - ] - .iter() - .try_fold(0u32, |acc, v| { - acc.checked_add(*v) - .ok_or_else(|| Error::from(ErrorKind::MessageTooLong)) - })?; - - let mut crc_buffer = CrcBufMut::new(buffer); - crc_buffer.put_u32(message_len); - crc_buffer.put_u32(headers_len); - crc_buffer.put_crc(); - crc_buffer.put(&headers[..]); - crc_buffer.put(&self.payload[..]); - crc_buffer.put_crc(); - Ok(()) + let expected_crc = crc_buffer.into_crc(); + let message_crc = buffer.get_u32(); + if expected_crc != message_crc { + return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into()); } + + Ok(Message::new_from_parts(headers, payload)) } -#[cfg(feature = "derive-arbitrary")] -impl<'a> arbitrary::Arbitrary<'a> for Message { - fn arbitrary(unstruct: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - let headers: arbitrary::Result> = unstruct.arbitrary_iter()?.collect(); - Ok(Message { - headers: headers?, - payload: Bytes::from(Vec::::arbitrary(unstruct)?), - }) - } +/// Writes the `message` to the given `buffer`. +pub fn write_message_to(message: &Message, buffer: &mut dyn BufMut) -> Result<(), Error> { + let mut headers = Vec::new(); + for header in message.headers() { + write_header_to(header, &mut headers)?; + } + + let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?; + let payload_len = checked(message.payload().len(), ErrorKind::PayloadTooLong.into())?; + let message_len = [ + PRELUDE_LENGTH_BYTES, + headers_len, + payload_len, + MESSAGE_CRC_LENGTH_BYTES, + ] + .iter() + .try_fold(0u32, |acc, v| { + acc.checked_add(*v) + .ok_or_else(|| Error::from(ErrorKind::MessageTooLong)) + })?; + + let mut crc_buffer = CrcBufMut::new(buffer); + crc_buffer.put_u32(message_len); + crc_buffer.put_u32(headers_len); + crc_buffer.put_crc(); + crc_buffer.put(&headers[..]); + crc_buffer.put(&message.payload()[..]); + crc_buffer.put_crc(); + Ok(()) } fn checked, U>(from: U, err: Error) -> Result { @@ -637,14 +442,15 @@ fn payload_len(total_len: u32, header_len: u32) -> Result { #[cfg(test)] mod message_tests { + use super::read_message_from; use crate::error::ErrorKind; - use crate::frame::{Header, HeaderValue, Message}; + use crate::frame::{write_message_to, Header, HeaderValue, Message}; use aws_smithy_types::DateTime; use bytes::Bytes; macro_rules! read_message_expect_err { ($bytes:expr, $err:pat) => { - let result = Message::read_from(&mut Bytes::from_static($bytes)); + let result = read_message_from(&mut Bytes::from_static($bytes)); let result = result.as_ref(); assert!(result.is_err(), "Expected error, got {:?}", result); assert!( @@ -702,11 +508,11 @@ mod message_tests { 0x36, ]; - let result = Message::read_from(&mut Bytes::from_static(data)).unwrap(); + let result = read_message_from(&mut Bytes::from_static(data)).unwrap(); assert_eq!(result.headers(), Vec::new()); let expected_payload = b"{'foo':'bar'}"; - assert_eq!(expected_payload, result.payload.as_ref()); + assert_eq!(expected_payload, result.payload().as_ref()); } #[test] @@ -721,7 +527,7 @@ mod message_tests { 0x7d, 0x8D, 0x9C, 0x08, 0xB1, ]; - let result = Message::read_from(&mut Bytes::from_static(data)).unwrap(); + let result = read_message_from(&mut Bytes::from_static(data)).unwrap(); assert_eq!( result.headers(), vec![Header::new( @@ -731,13 +537,13 @@ mod message_tests { ); let expected_payload = b"{'foo':'bar'}"; - assert_eq!(expected_payload, result.payload.as_ref()); + assert_eq!(expected_payload, result.payload().as_ref()); } #[test] fn read_all_headers_and_payload() { let message = include_bytes!("../test_data/valid_with_all_headers_and_payload"); - let result = Message::read_from(&mut Bytes::from_static(message)).unwrap(); + let result = read_message_from(&mut Bytes::from_static(message)).unwrap(); assert_eq!( result.headers(), vec![ @@ -763,7 +569,7 @@ mod message_tests { ] ); - assert_eq!(b"some payload", result.payload.as_ref()); + assert_eq!(b"some payload", result.payload().as_ref()); } #[test] @@ -790,14 +596,14 @@ mod message_tests { )); let mut actual = Vec::new(); - message.write_to(&mut actual).unwrap(); + write_message_to(&message, &mut actual).unwrap(); let expected = include_bytes!("../test_data/valid_with_all_headers_and_payload").to_vec(); assert_eq!(expected, actual); - let result = Message::read_from(&mut Bytes::from(actual)).unwrap(); + let result = read_message_from(&mut Bytes::from(actual)).unwrap(); assert_eq!(message.headers(), result.headers()); - assert_eq!(message.payload().as_ref(), result.payload.as_ref()); + assert_eq!(message.payload().as_ref(), result.payload().as_ref()); } } @@ -866,7 +672,7 @@ impl MessageFrameDecoder { if let Some(remaining_len) = self.remaining_bytes_if_frame_available(&buffer)? { let mut message_buf = (&self.prelude[..]).chain(buffer.take(remaining_len)); - let result = Message::read_from(&mut message_buf).map(DecodedFrame::Complete); + let result = read_message_from(&mut message_buf).map(DecodedFrame::Complete); self.reset(); return result; } @@ -878,7 +684,7 @@ impl MessageFrameDecoder { #[cfg(test)] mod message_frame_decoder_tests { use super::{DecodedFrame, MessageFrameDecoder}; - use crate::frame::Message; + use crate::frame::read_message_from; use bytes::Bytes; use bytes_utils::SegmentedBuf; @@ -899,7 +705,7 @@ mod message_frame_decoder_tests { match decoder.decode_frame(&mut segmented).unwrap() { DecodedFrame::Incomplete => panic!("frame should be complete now"), DecodedFrame::Complete(actual) => { - let expected = Message::read_from(&mut Bytes::from_static(message)).unwrap(); + let expected = read_message_from(&mut Bytes::from_static(message)).unwrap(); assert_eq!(expected, actual); } } @@ -926,9 +732,9 @@ mod message_frame_decoder_tests { } } - let expected1 = Message::read_from(&mut Bytes::from_static(message1)).unwrap(); - let expected2 = Message::read_from(&mut Bytes::from_static(message2)).unwrap(); - let expected3 = Message::read_from(&mut Bytes::from_static(message3)).unwrap(); + let expected1 = read_message_from(&mut Bytes::from_static(message1)).unwrap(); + let expected2 = read_message_from(&mut Bytes::from_static(message2)).unwrap(); + let expected3 = read_message_from(&mut Bytes::from_static(message3)).unwrap(); assert_eq!(3, decoded.len()); assert_eq!(expected1, decoded[0]); assert_eq!(expected2, decoded[1]); diff --git a/rust-runtime/aws-smithy-eventstream/src/lib.rs b/rust-runtime/aws-smithy-eventstream/src/lib.rs index 0d060b914e..5171471d2f 100644 --- a/rust-runtime/aws-smithy-eventstream/src/lib.rs +++ b/rust-runtime/aws-smithy-eventstream/src/lib.rs @@ -13,8 +13,9 @@ //! AWS Event Stream frame serialization/deserialization implementation. +#[cfg(feature = "derive-arbitrary")] +pub mod arbitrary; mod buf; pub mod error; pub mod frame; pub mod smithy; -pub mod str_bytes; diff --git a/rust-runtime/aws-smithy-eventstream/src/smithy.rs b/rust-runtime/aws-smithy-eventstream/src/smithy.rs index 3a076d7eb2..d939602533 100644 --- a/rust-runtime/aws-smithy-eventstream/src/smithy.rs +++ b/rust-runtime/aws-smithy-eventstream/src/smithy.rs @@ -4,8 +4,8 @@ */ use crate::error::{Error, ErrorKind}; -use crate::frame::{Header, HeaderValue, Message}; -use crate::str_bytes::StrBytes; +use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; +use aws_smithy_types::str_bytes::StrBytes; use aws_smithy_types::{Blob, DateTime}; macro_rules! expect_shape_fn { @@ -125,7 +125,7 @@ pub fn parse_response_headers(message: &Message) -> Result, #[cfg(test)] mod tests { use super::parse_response_headers; - use crate::frame::{Header, HeaderValue, Message}; + use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; #[test] fn normal_message() { diff --git a/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs b/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs index 823ac6f516..69e1c4381b 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs @@ -4,10 +4,11 @@ */ use aws_smithy_eventstream::frame::{ - DecodedFrame, Message, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage, + DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage, }; use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError}; use aws_smithy_types::body::SdkBody; +use aws_smithy_types::event_stream::Message; use bytes::Buf; use bytes::Bytes; use bytes_utils::SegmentedBuf; @@ -277,9 +278,10 @@ impl Receiver { mod tests { use super::{Receiver, UnmarshallMessage}; use aws_smithy_eventstream::error::Error as EventStreamError; - use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshalledMessage}; + use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage}; use aws_smithy_runtime_api::client::result::SdkError; use aws_smithy_types::body::SdkBody; + use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::Bytes; use hyper::body::Body; use std::error::Error as StdError; @@ -287,7 +289,7 @@ mod tests { fn encode_initial_response() -> Bytes { let mut buffer = Vec::new(); - Message::new(Bytes::new()) + let message = Message::new(Bytes::new()) .add_header(Header::new( ":message-type", HeaderValue::String("event".into()), @@ -295,17 +297,15 @@ mod tests { .add_header(Header::new( ":event-type", HeaderValue::String("initial-response".into()), - )) - .write_to(&mut buffer) - .unwrap(); + )); + write_message_to(&message, &mut buffer).unwrap(); buffer.into() } fn encode_message(message: &str) -> Bytes { let mut buffer = Vec::new(); - Message::new(Bytes::copy_from_slice(message.as_bytes())) - .write_to(&mut buffer) - .unwrap(); + let message = Message::new(Bytes::copy_from_slice(message.as_bytes())); + write_message_to(&message, &mut buffer).unwrap(); buffer.into() } diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index aa35b6cdee..f0dc00fa97 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_eventstream::frame::{MarshallMessage, SignMessage}; +use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage}; use aws_smithy_runtime_api::client::result::SdkError; use bytes::Bytes; use futures_core::Stream; @@ -165,8 +165,7 @@ impl Stream for MessageStreamAdapter Stream for MessageStreamAdapter { - sign.map_err(SdkError::construction_failure)? - .write_to(&mut buffer) + let message = sign.map_err(SdkError::construction_failure)?; + write_message_to(&message, &mut buffer) .map_err(SdkError::construction_failure)?; trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream"); Poll::Ready(Some(Ok(Bytes::from(buffer)))) @@ -199,9 +198,10 @@ mod tests { use async_stream::stream; use aws_smithy_eventstream::error::Error as EventStreamError; use aws_smithy_eventstream::frame::{ - Header, HeaderValue, Message, NoOpSigner, SignMessage, SignMessageError, + read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError, }; use aws_smithy_runtime_api::client::result::SdkError; + use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::Bytes; use futures_core::Stream; use futures_util::stream::StreamExt; @@ -234,7 +234,7 @@ mod tests { type Input = TestServiceError; fn marshall(&self, _input: Self::Input) -> Result { - Err(Message::read_from(&b""[..]).expect_err("this should always fail")) + Err(read_message_from(&b""[..]).expect_err("this should always fail")) } } @@ -252,7 +252,7 @@ mod tests { impl SignMessage for TestSigner { fn sign(&mut self, message: Message) -> Result { let mut buffer = Vec::new(); - message.write_to(&mut buffer).unwrap(); + write_message_to(&message, &mut buffer).unwrap(); Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true)))) } @@ -299,14 +299,14 @@ mod tests { )); let mut sent_bytes = adapter.next().await.unwrap().unwrap(); - let sent = Message::read_from(&mut sent_bytes).unwrap(); + let sent = read_message_from(&mut sent_bytes).unwrap(); assert_eq!("signed", sent.headers()[0].name().as_str()); assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value()); - let inner = Message::read_from(&mut (&sent.payload()[..])).unwrap(); + let inner = read_message_from(&mut (&sent.payload()[..])).unwrap(); assert_eq!(&b"test"[..], &inner.payload()[..]); let mut end_signal_bytes = adapter.next().await.unwrap().unwrap(); - let end_signal = Message::read_from(&mut end_signal_bytes).unwrap(); + let end_signal = read_message_from(&mut end_signal_bytes).unwrap(); assert_eq!("signed", end_signal.headers()[0].name().as_str()); assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value()); assert_eq!(0, end_signal.payload().len()); diff --git a/rust-runtime/aws-smithy-types/src/event_stream.rs b/rust-runtime/aws-smithy-types/src/event_stream.rs new file mode 100644 index 0000000000..a63c44e8ed --- /dev/null +++ b/rust-runtime/aws-smithy-types/src/event_stream.rs @@ -0,0 +1,185 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Types relevant to event stream serialization/deserialization + +use crate::str_bytes::StrBytes; +use bytes::Bytes; + +mod value { + use crate::str_bytes::StrBytes; + use crate::DateTime; + use bytes::Bytes; + + /// Event Stream frame header value. + #[non_exhaustive] + #[derive(Clone, Debug, PartialEq)] + pub enum HeaderValue { + /// Represents a boolean value. + Bool(bool), + /// Represents a byte value. + Byte(i8), + /// Represents an int16 value. + Int16(i16), + /// Represents an int32 value. + Int32(i32), + /// Represents an int64 value. + Int64(i64), + /// Represents a byte array value. + ByteArray(Bytes), + /// Represents a string value. + String(StrBytes), + /// Represents a timestamp value. + Timestamp(DateTime), + /// Represents a uuid value. + Uuid(u128), + } + + impl HeaderValue { + /// If the `HeaderValue` is a `Bool`, returns the associated `bool`. Returns `Err` otherwise. + pub fn as_bool(&self) -> Result { + match self { + HeaderValue::Bool(value) => Ok(*value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is a `Byte`, returns the associated `i8`. Returns `Err` otherwise. + pub fn as_byte(&self) -> Result { + match self { + HeaderValue::Byte(value) => Ok(*value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is an `Int16`, returns the associated `i16`. Returns `Err` otherwise. + pub fn as_int16(&self) -> Result { + match self { + HeaderValue::Int16(value) => Ok(*value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is an `Int32`, returns the associated `i32`. Returns `Err` otherwise. + pub fn as_int32(&self) -> Result { + match self { + HeaderValue::Int32(value) => Ok(*value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is an `Int64`, returns the associated `i64`. Returns `Err` otherwise. + pub fn as_int64(&self) -> Result { + match self { + HeaderValue::Int64(value) => Ok(*value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is a `ByteArray`, returns the associated [`Bytes`]. Returns `Err` otherwise. + pub fn as_byte_array(&self) -> Result<&Bytes, &Self> { + match self { + HeaderValue::ByteArray(value) => Ok(value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is a `String`, returns the associated [`StrBytes`]. Returns `Err` otherwise. + pub fn as_string(&self) -> Result<&StrBytes, &Self> { + match self { + HeaderValue::String(value) => Ok(value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is a `Timestamp`, returns the associated [`DateTime`]. Returns `Err` otherwise. + pub fn as_timestamp(&self) -> Result { + match self { + HeaderValue::Timestamp(value) => Ok(*value), + _ => Err(self), + } + } + + /// If the `HeaderValue` is a `Uuid`, returns the associated `u128`. Returns `Err` otherwise. + pub fn as_uuid(&self) -> Result { + match self { + HeaderValue::Uuid(value) => Ok(*value), + _ => Err(self), + } + } + } +} + +pub use value::HeaderValue; + +/// Event Stream header. +#[non_exhaustive] +#[derive(Clone, Debug, PartialEq)] +pub struct Header { + name: StrBytes, + value: HeaderValue, +} + +impl Header { + /// Creates a new header with the given `name` and `value`. + pub fn new(name: impl Into, value: impl Into) -> Header { + Header { + name: name.into(), + value: value.into(), + } + } + + /// Returns the header name. + pub fn name(&self) -> &StrBytes { + &self.name + } + + /// Returns the header value. + pub fn value(&self) -> &HeaderValue { + &self.value + } +} + +/// Event Stream message. +#[non_exhaustive] +#[derive(Clone, Debug, PartialEq)] +pub struct Message { + headers: Vec
, + payload: Bytes, +} + +impl Message { + /// Creates a new message with the given `payload`. Headers can be added later. + pub fn new(payload: impl Into) -> Message { + Message { + headers: Vec::new(), + payload: payload.into(), + } + } + + /// Creates a message with the given `headers` and `payload`. + pub fn new_from_parts(headers: Vec
, payload: impl Into) -> Self { + Self { + headers, + payload: payload.into(), + } + } + + /// Adds a header to the message. + pub fn add_header(mut self, header: Header) -> Self { + self.headers.push(header); + self + } + + /// Returns all headers. + pub fn headers(&self) -> &[Header] { + &self.headers + } + + /// Returns the payload bytes. + pub fn payload(&self) -> &Bytes { + &self.payload + } +} diff --git a/rust-runtime/aws-smithy-types/src/lib.rs b/rust-runtime/aws-smithy-types/src/lib.rs index 7c28d18174..386c905733 100644 --- a/rust-runtime/aws-smithy-types/src/lib.rs +++ b/rust-runtime/aws-smithy-types/src/lib.rs @@ -22,6 +22,7 @@ pub mod config_bag; pub mod date_time; pub mod endpoint; pub mod error; +pub mod event_stream; pub mod primitive; pub mod retry; pub mod timeout; @@ -32,6 +33,7 @@ pub mod type_erasure; mod blob; mod document; mod number; +pub mod str_bytes; pub use blob::Blob; pub use date_time::DateTime; diff --git a/rust-runtime/aws-smithy-eventstream/src/str_bytes.rs b/rust-runtime/aws-smithy-types/src/str_bytes.rs similarity index 92% rename from rust-runtime/aws-smithy-eventstream/src/str_bytes.rs rename to rust-runtime/aws-smithy-types/src/str_bytes.rs index aa76e48c31..96661e7797 100644 --- a/rust-runtime/aws-smithy-eventstream/src/str_bytes.rs +++ b/rust-runtime/aws-smithy-types/src/str_bytes.rs @@ -16,7 +16,7 @@ use std::str::Utf8Error; /// /// Example construction from a `&str`: /// ```rust -/// use aws_smithy_eventstream::str_bytes::StrBytes; +/// use aws_smithy_types::str_bytes::StrBytes; /// /// let value: StrBytes = "example".into(); /// assert_eq!("example", value.as_str()); @@ -26,7 +26,7 @@ use std::str::Utf8Error; /// Example construction from `Bytes`: /// ```rust /// use bytes::Bytes; -/// use aws_smithy_eventstream::str_bytes::StrBytes; +/// use aws_smithy_types::str_bytes::StrBytes; /// use std::convert::TryInto; /// /// let bytes = Bytes::from_static(b"example"); @@ -71,13 +71,6 @@ impl StrBytes { } } -#[cfg(feature = "derive-arbitrary")] -impl<'a> arbitrary::Arbitrary<'a> for StrBytes { - fn arbitrary(unstruct: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - Ok(String::arbitrary(unstruct)?.into()) - } -} - impl From for StrBytes { fn from(value: String) -> Self { StrBytes::new(Bytes::from(value))