From fba1006136ddb46ae4cfc038e310550e37e94bf0 Mon Sep 17 00:00:00 2001 From: Luc Talatinian <102624213+lucix-aws@users.noreply.github.com> Date: Fri, 29 Mar 2024 13:22:33 -0400 Subject: [PATCH] feat: add rpcv2 cbor support (#509) --- codegen/smithy-go-codegen/build.gradle.kts | 1 + .../smithy/go/codegen/AddOperationShapes.java | 2 + .../go/codegen/EventStreamGenerator.java | 2 + .../smithy/go/codegen/GoDependency.java | 36 + .../amazon/smithy/go/codegen/GoSettings.java | 21 +- .../smithy/go/codegen/GoStdlibTypes.java | 14 + .../smithy/go/codegen/SmithyGoDependency.java | 2 + .../smithy/go/codegen/SmithyGoTypes.java | 41 +- .../amazon/smithy/go/codegen/SymbolUtils.java | 13 + .../codegen/integration/DefaultProtocols.java | 27 + .../HttpProtocolUnitTestRequestGenerator.java | 5 + ...rotocolUnitTestResponseErrorGenerator.java | 15 +- ...HttpProtocolUnitTestResponseGenerator.java | 15 +- .../DeserializeResponseMiddleware.java | 84 ++ .../go/codegen/protocol/ProtocolUtil.java | 54 + .../protocol/SerializeRequestMiddleware.java | 91 ++ .../Rpc2DeserializeResponseMiddleware.java | 101 ++ .../protocol/rpc2/Rpc2ProtocolGenerator.java | 64 + .../rpc2/Rpc2SerializeRequestMiddleware.java | 91 ++ .../rpc2/cbor/DeserializeMiddleware.java | 75 + .../protocol/rpc2/cbor/ProtocolUtil.java | 62 + .../rpc2/cbor/Rpc2CborProtocolGenerator.java | 209 +++ .../rpc2/cbor/SerializeMiddleware.java | 76 + .../smithy/go/codegen/serde/SerdeUtil.java | 99 ++ .../serde/cbor/CborDeserializerGenerator.java | 420 ++++++ .../serde/cbor/CborSerializerGenerator.java | 323 ++++ .../trait/BackfilledInputOutputTrait.java | 33 + ...mithy.go.codegen.integration.GoIntegration | 2 +- document/cbor/cbor.go | 8 + document/cbor/decode.go | 342 +++++ document/cbor/decode_test.go | 130 ++ document/cbor/encode.go | 228 +++ document/cbor/encode_test.go | 115 ++ encoding/cbor/cbor.go | 139 ++ encoding/cbor/coerce.go | 229 +++ encoding/cbor/coerce_test.go | 531 +++++++ encoding/cbor/const.go | 41 + encoding/cbor/decode.go | 320 ++++ encoding/cbor/decode_test.go | 1334 +++++++++++++++++ encoding/cbor/encode.go | 218 +++ encoding/cbor/encode_test.go | 466 ++++++ encoding/cbor/float16.go | 45 + encoding/cbor/float16_test.go | 41 + encoding/cbor/fuzz_test.go | 114 ++ testing/cbor.go | 169 +++ testing/struct.go | 14 +- testing/struct_test.go | 21 +- 47 files changed, 6471 insertions(+), 12 deletions(-) create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DefaultProtocols.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/DeserializeResponseMiddleware.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/ProtocolUtil.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/SerializeRequestMiddleware.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2DeserializeResponseMiddleware.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2ProtocolGenerator.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2SerializeRequestMiddleware.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/DeserializeMiddleware.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/ProtocolUtil.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/Rpc2CborProtocolGenerator.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/SerializeMiddleware.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/SerdeUtil.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborDeserializerGenerator.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborSerializerGenerator.java create mode 100644 codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/trait/BackfilledInputOutputTrait.java create mode 100644 document/cbor/cbor.go create mode 100644 document/cbor/decode.go create mode 100644 document/cbor/decode_test.go create mode 100644 document/cbor/encode.go create mode 100644 document/cbor/encode_test.go create mode 100644 encoding/cbor/cbor.go create mode 100644 encoding/cbor/coerce.go create mode 100644 encoding/cbor/coerce_test.go create mode 100644 encoding/cbor/const.go create mode 100644 encoding/cbor/decode.go create mode 100644 encoding/cbor/decode_test.go create mode 100644 encoding/cbor/encode.go create mode 100644 encoding/cbor/encode_test.go create mode 100644 encoding/cbor/float16.go create mode 100644 encoding/cbor/float16_test.go create mode 100644 encoding/cbor/fuzz_test.go create mode 100644 testing/cbor.go diff --git a/codegen/smithy-go-codegen/build.gradle.kts b/codegen/smithy-go-codegen/build.gradle.kts index c8e49f11b..2f3651a8e 100644 --- a/codegen/smithy-go-codegen/build.gradle.kts +++ b/codegen/smithy-go-codegen/build.gradle.kts @@ -27,4 +27,5 @@ dependencies { api("org.jsoup:jsoup:1.14.1") api("software.amazon.smithy:smithy-rules-engine:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") + api("software.amazon.smithy:smithy-protocol-traits:$smithyVersion") } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/AddOperationShapes.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/AddOperationShapes.java index f7b6d60b0..e678e010e 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/AddOperationShapes.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/AddOperationShapes.java @@ -17,6 +17,7 @@ import java.util.TreeSet; import java.util.logging.Logger; +import software.amazon.smithy.go.codegen.trait.BackfilledInputOutputTrait; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.shapes.AbstractShapeBuilder; @@ -83,6 +84,7 @@ private static StructureShape emptyOperationStructure(ServiceShape service, Shap return StructureShape.builder() .id(ShapeId.fromParts(CodegenUtils.getSyntheticTypeNamespace(), opShapeId.getName(service) + suffix)) .addTrait(Synthetic.builder().build()) + .addTrait(new BackfilledInputOutputTrait()) .build(); } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/EventStreamGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/EventStreamGenerator.java index 347741f28..d929b74f0 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/EventStreamGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/EventStreamGenerator.java @@ -31,6 +31,8 @@ import software.amazon.smithy.utils.StringUtils; public final class EventStreamGenerator { + public static final String AMZ_CONTENT_TYPE = "application/vnd.amazon.eventstream"; + private static final String EVENT_STREAM_FILE = "eventstream.go"; private final GoSettings settings; diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDependency.java index 8de9745ce..e526bc3d1 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoDependency.java @@ -156,6 +156,42 @@ public Symbol valueSymbol(String name) { return SymbolUtils.createValueSymbolBuilder(name, this).build(); } + /** + * Creates a Symbol for a `const` exported by this package. + * @param name The name. + * @return The symbol. + */ + public Symbol constSymbol(String name) { + return SymbolUtils.createValueSymbolBuilder(name, this).build(); + } + + /** + * Creates a Symbol for a `func` exported by this package. + * @param name The name. + * @return The symbol. + */ + public Symbol func(String name) { + return SymbolUtils.createValueSymbolBuilder(name, this).build(); + } + + /** + * Creates a Symbol for a `struct` exported by this package. + * @param name The name. + * @return The symbol. + */ + public Symbol struct(String name) { + return SymbolUtils.createPointableSymbolBuilder(name, this).build(); + } + + /** + * Creates a Symbol for a `Value` exported by this package. + * @param name The name. + * @return The symbol. + */ + public Symbol interfaceSymbol(String name) { + return SymbolUtils.createValueSymbolBuilder(name, this).build(); + } + /** * Creates a pointable Symbol for a name exported by this package. * @param name The name. diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoSettings.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoSettings.java index f91751763..646865b4c 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoSettings.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoSettings.java @@ -19,12 +19,19 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait; +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait; +import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait; +import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait; +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait; +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait; import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.knowledge.ServiceIndex; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait; import software.amazon.smithy.utils.SmithyInternalApi; /** @@ -32,6 +39,15 @@ */ @SmithyInternalApi public final class GoSettings { + public static final Set PROTOCOLS_BY_PRIORITY = Set.of( + Rpcv2CborTrait.ID, + AwsJson1_0Trait.ID, + AwsJson1_1Trait.ID, + RestJson1Trait.ID, + RestXmlTrait.ID, + AwsQueryTrait.ID, + Ec2QueryTrait.ID + ); private static final String SERVICE = "service"; private static final String MODULE_NAME = "module"; @@ -247,7 +263,10 @@ public ShapeId resolveServiceProtocol( Set resolvedProtocols = serviceIndex.getProtocols(service).keySet(); - return resolvedProtocols.stream() + var byPriority = PROTOCOLS_BY_PRIORITY.stream() + .filter(resolvedProtocols::contains) + .toList(); + return byPriority.stream() .filter(supportedProtocolTraits::contains) .findFirst() .orElseThrow(() -> new UnresolvableProtocolException(String.format( diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoStdlibTypes.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoStdlibTypes.java index 359a564b3..2b642c10f 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoStdlibTypes.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoStdlibTypes.java @@ -24,6 +24,10 @@ public final class GoStdlibTypes { private GoStdlibTypes() { } + public static final class Bytes { + public static final Symbol NewReader = SmithyGoDependency.BYTES.valueSymbol("NewReader"); + } + public static final class Context { public static final Symbol Context = SmithyGoDependency.CONTEXT.valueSymbol("Context"); public static final Symbol Background = SmithyGoDependency.CONTEXT.valueSymbol("Background"); @@ -55,6 +59,15 @@ public static final class Fmt { public static final Symbol Sprintf = SmithyGoDependency.FMT.valueSymbol("Sprintf"); } + public static final class Io { + public static final Symbol ReadAll = SmithyGoDependency.IO.valueSymbol("ReadAll"); + public static final Symbol Copy = SmithyGoDependency.IO.valueSymbol("Copy"); + + public static final class IoUtil { + public static final Symbol Discard = SmithyGoDependency.IOUTIL.valueSymbol("Discard"); + } + } + public static final class Net { public static final class Http { public static final Symbol Request = SmithyGoDependency.NET_HTTP.pointableSymbol("Request"); @@ -62,6 +75,7 @@ public static final class Http { public static final Symbol Server = SmithyGoDependency.NET_HTTP.pointableSymbol("Server"); public static final Symbol Handler = SmithyGoDependency.NET_HTTP.valueSymbol("Handler"); public static final Symbol ResponseWriter = SmithyGoDependency.NET_HTTP.valueSymbol("ResponseWriter"); + public static final Symbol MethodPost = SmithyGoDependency.NET_HTTP.valueSymbol("MethodPost"); } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java index 75cce524c..72b3c4519 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java @@ -60,6 +60,7 @@ public final class SmithyGoDependency { public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding"); public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson"); public static final GoDependency SMITHY_XML = smithy("encoding/xml", "smithyxml"); + public static final GoDependency SMITHY_CBOR = smithy("encoding/cbor", "smithycbor"); public static final GoDependency SMITHY_IO = smithy("io", "smithyio"); public static final GoDependency SMITHY_LOGGING = smithy("logging"); public static final GoDependency SMITHY_PTR = smithy("ptr"); @@ -68,6 +69,7 @@ public final class SmithyGoDependency { public static final GoDependency SMITHY_WAITERS = smithy("waiter", "smithywaiter"); public static final GoDependency SMITHY_DOCUMENT = smithy("document", "smithydocument"); public static final GoDependency SMITHY_DOCUMENT_JSON = smithy("document/json", "smithydocumentjson"); + public static final GoDependency SMITHY_DOCUMENT_CBOR = smithy("document/cbor", "smithydocumentcbor"); public static final GoDependency SMITHY_SYNC = smithy("sync", "smithysync"); public static final GoDependency SMITHY_AUTH = smithy("auth", "smithyauth"); public static final GoDependency SMITHY_AUTH_BEARER = smithy("auth/bearer"); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java index 7ca1ccbb6..e69e6e83e 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoTypes.java @@ -28,6 +28,7 @@ public static final class Smithy { public static final Symbol Properties = SmithyGoDependency.SMITHY.pointableSymbol("Properties"); public static final Symbol OperationError = SmithyGoDependency.SMITHY.pointableSymbol("OperationError"); public static final Symbol InvalidParamsError = SmithyGoDependency.SMITHY.pointableSymbol("InvalidParamsError"); + public static final Symbol SerializationError = SmithyGoDependency.SMITHY.pointableSymbol("SerializationError"); public static final class Document { public static final Symbol NoSerde = SmithyGoDependency.SMITHY_DOCUMENT.pointableSymbol("NoSerde"); @@ -39,15 +40,47 @@ public static final class Time { public static final Symbol FormatDateTime = SmithyGoDependency.SMITHY_TIME.valueSymbol("FormatDateTime"); } + public static final class Rand { + public static final Symbol NewUUID = SmithyGoDependency.SMITHY_RAND.valueSymbol("NewUUID"); + } + public static final class Encoding { public static final class Json { public static final Symbol NewEncoder = SmithyGoDependency.SMITHY_JSON.valueSymbol("NewEncoder"); public static final Symbol Value = SmithyGoDependency.SMITHY_JSON.valueSymbol("Value"); } + + public static final class Cbor { + public static final Symbol Encode = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Encode"); + public static final Symbol Decode = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Decode"); + public static final Symbol Value = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Value"); + public static final Symbol Uint = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Uint"); + public static final Symbol NegInt = SmithyGoDependency.SMITHY_CBOR.valueSymbol("NegInt"); + public static final Symbol Slice = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Slice"); + public static final Symbol String = SmithyGoDependency.SMITHY_CBOR.valueSymbol("String"); + public static final Symbol List = SmithyGoDependency.SMITHY_CBOR.valueSymbol("List"); + public static final Symbol Map = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Map"); + public static final Symbol Tag = SmithyGoDependency.SMITHY_CBOR.pointableSymbol("Tag"); + public static final Symbol Bool = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Bool"); + public static final Symbol Nil = SmithyGoDependency.SMITHY_CBOR.pointableSymbol("Nil"); + public static final Symbol Undefined = SmithyGoDependency.SMITHY_CBOR.pointableSymbol("Undefined"); + public static final Symbol Float32 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Float32"); + public static final Symbol Float64 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("Float64"); + public static final Symbol EncodeRaw = SmithyGoDependency.SMITHY_CBOR.valueSymbol("EncodeRaw"); + public static final Symbol AsInt8 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("AsInt8"); + public static final Symbol AsInt16 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("AsInt16"); + public static final Symbol AsInt32 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("AsInt32"); + public static final Symbol AsInt64 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("AsInt64"); + public static final Symbol AsFloat32 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("AsFloat32"); + public static final Symbol AsFloat64 = SmithyGoDependency.SMITHY_CBOR.valueSymbol("AsFloat64"); + public static final Symbol AsTime = SmithyGoDependency.SMITHY_CBOR.valueSymbol("AsTime"); + } } - public static final class Rand { - public static final Symbol NewUUID = SmithyGoDependency.SMITHY_RAND.valueSymbol("NewUUID"); + public static final class Document { + public static final class Cbor { + public static final Symbol NewEncoder = SmithyGoDependency.SMITHY_DOCUMENT_CBOR.valueSymbol("NewEncoder"); + } } public static final class Ptr { @@ -57,6 +90,9 @@ public static final class Ptr { public static final Symbol Int16 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int16"); public static final Symbol Int32 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int32"); public static final Symbol Int64 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Int64"); + public static final Symbol Float32 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Float32"); + public static final Symbol Float64 = SmithyGoDependency.SMITHY_PTR.valueSymbol("Float64"); + public static final Symbol Time = SmithyGoDependency.SMITHY_PTR.valueSymbol("Time"); } public static final class Middleware { @@ -84,6 +120,7 @@ public static final class Middleware { public static final class Transport { public static final class Http { public static final Symbol Request = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.pointableSymbol("Request"); + public static final Symbol Response = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.pointableSymbol("Response"); public static final Symbol NewStackRequest = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("NewStackRequest"); public static final Symbol NewClientHandler = SmithyGoDependency.SMITHY_HTTP_TRANSPORT.valueSymbol("NewClientHandler"); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolUtils.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolUtils.java index c964d2fba..1a9fed49c 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolUtils.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolUtils.java @@ -210,4 +210,17 @@ public static Symbol getReference(Symbol symbol) { public static Symbol buildPackageSymbol(String name) { return Symbol.builder().name(name).build(); } + + public static Symbol buildSymbol(String name, String namespace) { + return Symbol.builder() + .name(name) + .namespace(namespace, ".") + .build(); + } + + public static boolean isNilable(Symbol symbol) { + return isPointable(symbol) + || symbol.getProperty(SymbolUtils.GO_SLICE).isPresent() + || symbol.getProperty(SymbolUtils.GO_MAP).isPresent(); + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DefaultProtocols.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DefaultProtocols.java new file mode 100644 index 000000000..8d59b472e --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DefaultProtocols.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.integration; + +import java.util.List; +import software.amazon.smithy.go.codegen.protocol.rpc2.cbor.Rpc2CborProtocolGenerator; +import software.amazon.smithy.utils.ListUtils; + +public class DefaultProtocols implements GoIntegration { + @Override + public List getProtocolGenerators() { + return ListUtils.of(new Rpc2CborProtocolGenerator()); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index 3c4f3cea8..0b81909a3 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -190,6 +190,11 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC "return smithytesting.CompareURLFormReaderBytes(actual, []byte(`%s`))", body); break; + case "application/cbor": + compareFunc = String.format( + "return smithytesting.CompareCBOR(actual, `%s`)", + body); + break; default: compareFunc = String.format( "return smithytesting.CompareReaderBytes(actual, []byte(`%s`))", diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java index c00f76e1a..f26664660 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java @@ -107,7 +107,20 @@ protected void generateTestCaseValues(GoWriter writer, HttpResponseTestCase test writeStructField(writer, "BodyMediaType", "$S", mediaType); }); testCase.getBody().ifPresent(body -> { - writeStructField(writer, "Body", "[]byte(`$L`)", body); + var mediaType = testCase.getBodyMediaType().orElse(""); + if (mediaType.equalsIgnoreCase("application/cbor")) { + writeStructField(writer, "Body", """ + func() []byte { + p, err := $T.DecodeString(`$L`) + if err != nil { + panic(err) + } + + return p + }()""", SmithyGoDependency.BASE64.func("StdEncoding"), body); + } else { + writeStructField(writer, "Body", "[]byte(`$L`)", body); + } }); writeStructField(writer, "ExpectError", errorShape, testCase.getParams()); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseGenerator.java index 33a0b9c24..6ff003e9e 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseGenerator.java @@ -97,7 +97,20 @@ protected void generateTestCaseValues(GoWriter writer, HttpResponseTestCase test writeStructField(writer, "BodyMediaType", "$S", mediaType); }); testCase.getBody().ifPresent(body -> { - writeStructField(writer, "Body", "[]byte(`$L`)", body); + var mediaType = testCase.getBodyMediaType().orElse(""); + if (mediaType.equalsIgnoreCase("application/cbor")) { + writeStructField(writer, "Body", """ + func() []byte { + p, err := $T.DecodeString(`$L`) + if err != nil { + panic(err) + } + + return p + }()""", SmithyGoDependency.BASE64.func("StdEncoding"), body); + } else { + writeStructField(writer, "Body", "[]byte(`$L`)", body); + } }); writeStructField(writer, "ExpectResult", outputShape, testCase.getParams()); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/DeserializeResponseMiddleware.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/DeserializeResponseMiddleware.java new file mode 100644 index 000000000..f6d440f09 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/DeserializeResponseMiddleware.java @@ -0,0 +1,84 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol; + +import static software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware; +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.integration.ProtocolGenerator.getDeserializeMiddlewareName; + +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.integration.ProtocolUtils; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public abstract class DeserializeResponseMiddleware implements GoWriter.Writable { + protected final ProtocolGenerator generator; + protected final ProtocolGenerator.GenerationContext ctx; + protected final OperationShape operation; + + protected final StructureShape output; + + public DeserializeResponseMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ) { + this.generator = generator; + this.ctx = ctx; + this.operation = operation; + + this.output = ctx.getModel().expectShape(operation.getOutputShape(), StructureShape.class); + } + + @Override + public void accept(GoWriter writer) { + var middleware = createDeserializeStepMiddleware( + getDeserializeMiddlewareName(operation.getId(), ctx.getService(), generator.getProtocolName()), + ProtocolUtils.OPERATION_DESERIALIZER_MIDDLEWARE_ID + ); + + writer.write(middleware.asWritable(generateHandleDeserialize(), emptyGoTemplate())); + } + + public abstract GoWriter.Writable generateDeserialize(); + + private GoWriter.Writable generateHandleDeserialize() { + return goTemplate(""" + out, metadata, err = next.HandleDeserialize(ctx, in) + if err != nil { + return out, metadata, err + } + + resp, ok := out.RawResponse.($response:P) + if !ok { + return out, metadata, $errorf:T("unexpected transport type %T", out.RawResponse) + } + + $deserialize:W + + return out, metadata, nil + """, + MapUtils.of( + "response", generator.getApplicationProtocol().getResponseType(), + "deserialize", generateDeserialize(), + "errorf", GoStdlibTypes.Fmt.Errorf + )); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/ProtocolUtil.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/ProtocolUtil.java new file mode 100644 index 000000000..0e7c6fe4e --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/ProtocolUtil.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol; + +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.model.traits.StreamingTrait.isEventStream; + +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class ProtocolUtil { + public static final GoWriter.Writable GET_AWS_QUERY_ERROR_CODE = goTemplate(""" + func getAwsQueryErrorCode(resp $P) string { + header := resp.Header.Get("x-amzn-query-error") + if header == "" { + return "" + } + + parts := $T(header, ";") + if len(parts) != 2 { + return "" + } + + return parts[0] + } + """, + SmithyGoDependency.SMITHY_HTTP_TRANSPORT.struct("Response"), + SmithyGoDependency.STRINGS.func("Split") + ); + + private ProtocolUtil() {} + + public static boolean hasEventStream(Model model, Shape shape) { + return shape.members().stream() + .anyMatch(it -> isEventStream(model, it)); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/SerializeRequestMiddleware.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/SerializeRequestMiddleware.java new file mode 100644 index 000000000..cf2fd2932 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/SerializeRequestMiddleware.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol; + +import static software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator.createSerializeStepMiddleware; +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; + +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.integration.ProtocolUtils; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public abstract class SerializeRequestMiddleware implements GoWriter.Writable { + protected final ProtocolGenerator generator; + protected final ProtocolGenerator.GenerationContext ctx; + protected final OperationShape operation; + + protected final StructureShape input; + protected final StructureShape output; + + public SerializeRequestMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ) { + this.generator = generator; + this.ctx = ctx; + this.operation = operation; + + this.input = ctx.getModel().expectShape(operation.getInputShape(), StructureShape.class); + this.output = ctx.getModel().expectShape(operation.getOutputShape(), StructureShape.class); + } + + @Override + public void accept(GoWriter writer) { + var name = ProtocolGenerator.getSerializeMiddlewareName(operation.getId(), ctx.getService(), + generator.getProtocolName()); + var middleware = createSerializeStepMiddleware(name, ProtocolUtils.OPERATION_SERIALIZER_MIDDLEWARE_ID); + + writer.write(middleware.asWritable(generateHandleSerialize(), emptyGoTemplate())); + } + + public abstract GoWriter.Writable generateRouteRequest(); + + public abstract GoWriter.Writable generateSerialize(); + + private GoWriter.Writable generateHandleSerialize() { + return goTemplate(""" + input, ok := in.Parameters.($input:P) + if !ok { + return out, metadata, $errorf:T("unexpected input type %T", in.Parameters) + } + _ = input + + req, ok := in.Request.($request:P) + if !ok { + return out, metadata, $errorf:T("unexpected transport type %T", in.Request) + } + + $route:W + + $serialize:W + + return next.HandleSerialize(ctx, in) + """, + MapUtils.of( + "input", ctx.getSymbolProvider().toSymbol(input), + "request", generator.getApplicationProtocol().getRequestType(), + "route", generateRouteRequest(), + "serialize", generateSerialize(), + "errorf", GoStdlibTypes.Fmt.Errorf + )); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2DeserializeResponseMiddleware.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2DeserializeResponseMiddleware.java new file mode 100644 index 000000000..8c1e7fe0f --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2DeserializeResponseMiddleware.java @@ -0,0 +1,101 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol.rpc2; + +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.protocol.ProtocolUtil.hasEventStream; + +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.protocol.DeserializeResponseMiddleware; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public abstract class Rpc2DeserializeResponseMiddleware extends DeserializeResponseMiddleware { + protected Rpc2DeserializeResponseMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ) { + super(generator, ctx, operation); + } + + protected abstract String getProtocolName(); + + protected abstract GoWriter.Writable deserializeSuccessResponse(); + + @Override + public GoWriter.Writable generateDeserialize() { + return goTemplate(""" + if resp.Header.Get("smithy-protocol") != $protocol:S { + return out, metadata, &$deserError:T{ + Err: $errorf:T( + "unexpected smithy-protocol response header '%s' (HTTP status: %s)", + resp.Header.Get("smithy-protocol"), + resp.Status, + ), + } + } + + if resp.StatusCode != 200 { + return out, metadata, $deserializeError:L(resp) + } + + $handleResponse:W + """, + MapUtils.of( + "deserError", SmithyGoDependency.SMITHY.struct("DeserializationError"), + "protocol", getProtocolName(), + "errorf", GoStdlibTypes.Fmt.Errorf, + "handleResponse", handleResponse(), + "deserializeError", ProtocolGenerator + .getOperationErrorDeserFunctionName(operation, ctx.getService(), "rpc2") + )); + } + + private GoWriter.Writable handleResponse() { + if (output.members().isEmpty()) { + return discardDeserialize(); + } else if (hasEventStream(ctx.getModel(), output)) { + return deserializeEventStream(); + } + return deserializeSuccessResponse(); + } + + private GoWriter.Writable discardDeserialize() { + return goTemplate(""" + if _, err = $copy:T($discard:T, resp.Body); err != nil { + return out, metadata, $errorf:T("discard response body: %w", err) + } + + out.Result = &$result:T{} + """, + MapUtils.of( + "copy", GoStdlibTypes.Io.Copy, + "discard", GoStdlibTypes.Io.IoUtil.Discard, + "errorf", GoStdlibTypes.Fmt.Errorf, + "result", ctx.getSymbolProvider().toSymbol(output) + )); + } + + // Basically a no-op. Event stream deserializer middleware, implemented elsewhere, will handle the wire-up here, + // including handling the initial-response message to deserialize any non-stream members to output. + private GoWriter.Writable deserializeEventStream() { + return goTemplate("out.Result = &$T{}", ctx.getSymbolProvider().toSymbol(output)); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2ProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2ProtocolGenerator.java new file mode 100644 index 000000000..e70153263 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2ProtocolGenerator.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol.rpc2; + +import static software.amazon.smithy.go.codegen.ApplicationProtocol.createDefaultHttpApplicationProtocol; + +import software.amazon.smithy.codegen.core.CodegenException; +import software.amazon.smithy.go.codegen.ApplicationProtocol; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.protocol.DeserializeResponseMiddleware; +import software.amazon.smithy.model.knowledge.TopDownIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public abstract class Rpc2ProtocolGenerator implements ProtocolGenerator { + public static final String SMITHY_PROTOCOL_NAME = "rpc-v2-cbor"; + public static final String CONTENT_TYPE = "application/cbor"; + + public abstract Rpc2SerializeRequestMiddleware getSerializeRequestMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ); + + public abstract DeserializeResponseMiddleware getDeserializeResponseMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ); + + @Override + public final ApplicationProtocol getApplicationProtocol() { + return createDefaultHttpApplicationProtocol(); + } + + @Override + public final void generateRequestSerializers(GenerationContext ctx) { + TopDownIndex.of(ctx.getModel()).getContainedOperations(ctx.getService()).forEach(it -> { + ctx.getWriter().get().write(getSerializeRequestMiddleware(this, ctx, it)); + }); + } + + @Override + public final void generateResponseDeserializers(GenerationContext ctx) { + TopDownIndex.of(ctx.getModel()).getContainedOperations(ctx.getService()).forEach(it -> { + ctx.getWriter().get().write(getDeserializeResponseMiddleware(this, ctx, it)); + }); + } + + @Override + public void generateEventStreamComponents(GenerationContext context) { + throw new CodegenException("event stream codegen is not currently supported in smithy-go"); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2SerializeRequestMiddleware.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2SerializeRequestMiddleware.java new file mode 100644 index 000000000..c1c36ce1a --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/Rpc2SerializeRequestMiddleware.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol.rpc2; + +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; + +import software.amazon.smithy.go.codegen.EventStreamGenerator; +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.protocol.SerializeRequestMiddleware; +import software.amazon.smithy.go.codegen.trait.BackfilledInputOutputTrait; +import software.amazon.smithy.model.knowledge.EventStreamIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public abstract class Rpc2SerializeRequestMiddleware extends SerializeRequestMiddleware { + private final EventStreamIndex eventStreamIndex; + + protected Rpc2SerializeRequestMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ) { + super(generator, ctx, operation); + + this.eventStreamIndex = EventStreamIndex.of(ctx.getModel()); + } + + public abstract String getProtocolName(); + + public abstract String getContentType(); + + @Override + public final GoWriter.Writable generateRouteRequest() { + return goTemplate(""" + req.Method = $methodPost:T + req.URL.Path = "/service/$service:L/operation/$operation:L" + req.Header.Set("smithy-protocol", $protocol:S) + + $contentTypeHeader:W + $acceptHeader:W + """, + MapUtils.of( + "methodPost", GoStdlibTypes.Net.Http.MethodPost, + "service", ctx.getService().getId().getName(), + "operation", operation.getId().getName(), + "protocol", getProtocolName(), + "contentTypeHeader", setContentTypeHeader(), + "acceptHeader", acceptHeader() + )); + } + + private GoWriter.Writable setContentTypeHeader() { + if (input.hasTrait(BackfilledInputOutputTrait.class)) { + return emptyGoTemplate(); + } + + return goTemplate(""" + req.Header.Set("Content-Type", $S) + """, isInputEventStream() ? EventStreamGenerator.AMZ_CONTENT_TYPE : getContentType()); + } + + private GoWriter.Writable acceptHeader() { + return goTemplate(""" + req.Header.Set("Accept", $S) + """, isOutputEventStream() ? EventStreamGenerator.AMZ_CONTENT_TYPE : getContentType()); + } + + private boolean isInputEventStream() { + return eventStreamIndex.getInputInfo(operation).isPresent(); + } + + private boolean isOutputEventStream() { + return eventStreamIndex.getOutputInfo(operation).isPresent(); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/DeserializeMiddleware.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/DeserializeMiddleware.java new file mode 100644 index 000000000..18ebabd15 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/DeserializeMiddleware.java @@ -0,0 +1,75 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol.rpc2.cbor; + +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.protocol.rpc2.Rpc2ProtocolGenerator.SMITHY_PROTOCOL_NAME; +import static software.amazon.smithy.go.codegen.serde.cbor.CborDeserializerGenerator.getDeserializerName; + +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoTypes; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.protocol.rpc2.Rpc2DeserializeResponseMiddleware; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.utils.MapUtils; + +final class DeserializeMiddleware extends Rpc2DeserializeResponseMiddleware { + DeserializeMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ) { + super(generator, ctx, operation); + } + + @Override + protected String getProtocolName() { + return SMITHY_PROTOCOL_NAME; + } + + @Override + public GoWriter.Writable deserializeSuccessResponse() { + return goTemplate(""" + payload, err := $readAll:T(resp.Body) + if err != nil { + return out, metadata, err + } + + if len(payload) == 0 { + out.Result = &$output:T{} + return out, metadata, nil + } + + cv, err := $decode:T(payload) + if err != nil { + return out, metadata, err + } + + output, err := $deserialize:L(cv) + if err != nil { + return out, metadata, err + } + + out.Result = output + """, + MapUtils.of( + "readAll", GoStdlibTypes.Io.ReadAll, + "decode", SmithyGoTypes.Encoding.Cbor.Decode, + "deserialize", getDeserializerName(output), + "output", ctx.getSymbolProvider() + .toSymbol(ctx.getModel().expectShape(operation.getOutputShape())) + )); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/ProtocolUtil.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/ProtocolUtil.java new file mode 100644 index 000000000..7201eef36 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/ProtocolUtil.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol.rpc2.cbor; + +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; + +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoTypes; +import software.amazon.smithy.utils.MapUtils; + +final class ProtocolUtil { + public static final GoWriter.Writable GET_PROTOCOL_ERROR_INFO = goTemplate(""" + func getProtocolErrorInfo(payload []byte) (typ, msg string, v $cborValue:T, err error) { + v, err = $cborDecode:T(payload) + if err != nil { + return "", "", nil, $fmtErrorf:T("decode: %w", err) + } + + mv, ok := v.($cborMap:T) + if !ok { + return "", "", nil, $fmtErrorf:T("unexpected payload type %T", v) + } + + if ctyp, ok := mv["__type"]; ok { + if ttyp, ok := ctyp.($cborString:T); ok { + typ = string(ttyp) + } + } + + if cmsg, ok := mv["message"]; ok { + if tmsg, ok := cmsg.($cborString:T); ok { + msg = string(tmsg) + } + } + + return typ, msg, mv, nil + } + """, + MapUtils.of( + "fmtErrorf", GoStdlibTypes.Fmt.Errorf, + "cborDecode", SmithyGoTypes.Encoding.Cbor.Decode, + "cborValue", SmithyGoTypes.Encoding.Cbor.Value, + "cborMap", SmithyGoTypes.Encoding.Cbor.Map, + "cborString", SmithyGoTypes.Encoding.Cbor.String + )); + + private ProtocolUtil() {} +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/Rpc2CborProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/Rpc2CborProtocolGenerator.java new file mode 100644 index 000000000..9504eb49a --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/Rpc2CborProtocolGenerator.java @@ -0,0 +1,209 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol.rpc2.cbor; + +import static java.util.stream.Collectors.toCollection; +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.protocol.ProtocolUtil.GET_AWS_QUERY_ERROR_CODE; +import static software.amazon.smithy.go.codegen.protocol.rpc2.cbor.ProtocolUtil.GET_PROTOCOL_ERROR_INFO; +import static software.amazon.smithy.go.codegen.serde.SerdeUtil.getShapesToSerde; +import static software.amazon.smithy.go.codegen.serde.cbor.CborDeserializerGenerator.getDeserializerName; + +import java.util.LinkedHashSet; +import java.util.stream.Stream; +import software.amazon.smithy.aws.traits.protocols.AwsQueryCompatibleTrait; +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SmithyGoTypes; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.protocol.DeserializeResponseMiddleware; +import software.amazon.smithy.go.codegen.protocol.rpc2.Rpc2ProtocolGenerator; +import software.amazon.smithy.go.codegen.protocol.rpc2.Rpc2SerializeRequestMiddleware; +import software.amazon.smithy.go.codegen.serde.cbor.CborDeserializerGenerator; +import software.amazon.smithy.go.codegen.serde.cbor.CborSerializerGenerator; +import software.amazon.smithy.model.knowledge.TopDownIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public class Rpc2CborProtocolGenerator extends Rpc2ProtocolGenerator { + @Override + public final ShapeId getProtocol() { + return Rpcv2CborTrait.ID; + } + + @Override + public void generateSharedSerializerComponents(GenerationContext context) { + var model = context.getModel(); + var service = context.getService(); + var shapes = TopDownIndex.of(model).getContainedOperations(service).stream() + .map(it -> model.expectShape(it.getInputShape(), StructureShape.class)) + .flatMap(it -> getShapesToSerde(model, it).stream()) + .sorted() + .collect(toCollection(LinkedHashSet::new)); + var generator = new CborSerializerGenerator(context); + context.getWriter().get().write(generator.generate(shapes)); + } + + @Override + public void generateSharedDeserializerComponents(GenerationContext context) { + var model = context.getModel(); + var service = context.getService(); + var operations = TopDownIndex.of(model).getContainedOperations(service); + + var outputShapes = operations.stream() + .map(it -> model.expectShape(it.getOutputShape(), StructureShape.class)) + .filter(it -> !it.members().isEmpty()) + .flatMap(it -> getShapesToSerde(model, it).stream()); + var errorShapes = operations.stream() + .flatMap(it -> it.getErrors().stream()) + .map(model::expectShape) + .flatMap(it -> getShapesToSerde(model, it).stream()); + + var generator = new CborDeserializerGenerator(context); + var writer = context.getWriter().get(); + writer.write(generator.generate( + Stream.concat(outputShapes, errorShapes) + .sorted() + .collect(toCollection(LinkedHashSet::new))) // in case of overlap + ); + writer.write(GoWriter.ChainWritable.of( + operations.stream() + .sorted() + .map(it -> deserializeOperationError(context, it)) + .toList() + ).compose()); + + writer.write(GET_PROTOCOL_ERROR_INFO); + writer.write(GET_AWS_QUERY_ERROR_CODE); + } + + @Override + public final Rpc2SerializeRequestMiddleware getSerializeRequestMiddleware( + ProtocolGenerator generator, GenerationContext ctx, OperationShape operation + ) { + return new SerializeMiddleware(generator, ctx, operation); + } + + @Override + public final DeserializeResponseMiddleware getDeserializeResponseMiddleware( + ProtocolGenerator generator, GenerationContext ctx, OperationShape operation + ) { + return new DeserializeMiddleware(generator, ctx, operation); + } + + private GoWriter.Writable deserializeOperationError( + ProtocolGenerator.GenerationContext ctx, OperationShape operation + ) { + var model = ctx.getModel(); + var service = ctx.getService(); + return goTemplate(""" + func $func:L(resp $smithyhttpResponse:P) error { + payload, err := $readAll:T(resp.Body) + if err != nil { + return &$deserError:T{Err: $fmtErrorf:T("read response body: %w", err)} + } + + typ, msg, v, err := getProtocolErrorInfo(payload) + if err != nil { + return &$deserError:T{Err: $fmtErrorf:T("get error info: %w", err)} + } + + if len(typ) == 0 { + typ = "UnknownError" + } + if len(msg) == 0 { + msg = "UnknownError" + } + + _ = v + switch string(typ) { + $errors:W + default: + $awsQueryCompatible:W + return &$genericAPIError:T{Code: typ, Message: msg} + } + } + """, + MapUtils.of( + "cborDecode", SmithyGoTypes.Encoding.Cbor.Decode, + "cborMap", SmithyGoTypes.Encoding.Cbor.Map, + "cborString", SmithyGoTypes.Encoding.Cbor.String + ), + MapUtils.of( + "deserError", SmithyGoDependency.SMITHY.pointableSymbol("DeserializationError"), + "fmtErrorf", GoStdlibTypes.Fmt.Errorf, + "func", ProtocolGenerator.getOperationErrorDeserFunctionName(operation, service, "rpc2"), + "genericAPIError", SmithyGoDependency.SMITHY.pointableSymbol("GenericAPIError"), + "readAll", SmithyGoDependency.IO.func("ReadAll"), + "smithyhttpResponse", SmithyGoTypes.Transport.Http.Response, + "awsQueryCompatible", ctx.getService().hasTrait(AwsQueryCompatibleTrait.class) + ? deserializeAwsQueryError() + : emptyGoTemplate(), + "errors", GoWriter.ChainWritable.of( + operation.getErrors(service).stream() + .map(it -> + deserializeErrorCase(ctx, model.expectShape(it, StructureShape.class))) + .toList() + ).compose(false) + )); + } + + private GoWriter.Writable deserializeErrorCase(GenerationContext ctx, StructureShape error) { + return goTemplate(""" + case $type:S: + verr, err := $deserialize:L(v) + if err != nil { + return &$deserError:T{ + Err: $fmtErrorf:T("deserialize $type:L: %w", err), + Snapshot: payload, + } + } + $awsQueryCompatible:W + return verr + """, + MapUtils.of( + "deserError", SmithyGoDependency.SMITHY.pointableSymbol("DeserializationError"), + "deserialize", getDeserializerName(error), + "equalFold", SmithyGoDependency.STRINGS.func("EqualFold"), + "fmtErrorf", GoStdlibTypes.Fmt.Errorf, + "type", error.getId().toString(), + "awsQueryCompatible", ctx.getService().hasTrait(AwsQueryCompatibleTrait.class) + ? deserializeModeledAwsQueryError() + : emptyGoTemplate() + )); + } + + private GoWriter.Writable deserializeAwsQueryError() { + return goTemplate(""" + if qtype := getAwsQueryErrorCode(resp); len(qt) > 0 { + typ = qtype + }"""); + } + + private GoWriter.Writable deserializeModeledAwsQueryError() { + return goTemplate(""" + if qtype := getAwsQueryErrorCode(resp); len(qt) > 0 { + verr.ErrorCodeOverride = $T(qtype) + }""", SmithyGoTypes.Ptr.String); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/SerializeMiddleware.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/SerializeMiddleware.java new file mode 100644 index 000000000..05c2c6f2c --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/protocol/rpc2/cbor/SerializeMiddleware.java @@ -0,0 +1,76 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.protocol.rpc2.cbor; + +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.protocol.rpc2.Rpc2ProtocolGenerator.CONTENT_TYPE; +import static software.amazon.smithy.go.codegen.protocol.rpc2.Rpc2ProtocolGenerator.SMITHY_PROTOCOL_NAME; +import static software.amazon.smithy.go.codegen.serde.cbor.CborSerializerGenerator.getSerializerName; + +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoTypes; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.protocol.rpc2.Rpc2SerializeRequestMiddleware; +import software.amazon.smithy.go.codegen.trait.BackfilledInputOutputTrait; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.utils.MapUtils; + +final class SerializeMiddleware extends Rpc2SerializeRequestMiddleware { + SerializeMiddleware( + ProtocolGenerator generator, ProtocolGenerator.GenerationContext ctx, OperationShape operation + ) { + super(generator, ctx, operation); + } + + @Override + public String getProtocolName() { + return SMITHY_PROTOCOL_NAME; + } + + @Override + public String getContentType() { + return CONTENT_TYPE; + } + + @Override + public GoWriter.Writable generateSerialize() { + if (input.hasTrait(BackfilledInputOutputTrait.class)) { + return emptyGoTemplate(); + } + + return goTemplate(""" + cv, err := $serialize:L(input) + if err != nil { + return out, metadata, &$error:T{Err: err} + } + + payload := $reader:T($encode:T(cv)) + if req, err = req.SetStream(payload); err != nil { + return out, metadata, &$error:T{Err: err} + } + + in.Request = req + """, + MapUtils.of( + "serialize", getSerializerName(input), + "encode", SmithyGoTypes.Encoding.Cbor.Encode, + "reader", GoStdlibTypes.Bytes.NewReader, + "error", SmithyGoTypes.Smithy.SerializationError + )); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/SerdeUtil.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/SerdeUtil.java new file mode 100644 index 000000000..39dadb951 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/SerdeUtil.java @@ -0,0 +1,99 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.serde; + +import static java.util.stream.Collectors.toSet; + +import java.util.HashSet; +import java.util.Set; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.BlobShape; +import software.amazon.smithy.model.shapes.BooleanShape; +import software.amazon.smithy.model.shapes.ByteShape; +import software.amazon.smithy.model.shapes.DoubleShape; +import software.amazon.smithy.model.shapes.FloatShape; +import software.amazon.smithy.model.shapes.IntegerShape; +import software.amazon.smithy.model.shapes.LongShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.shapes.ShortShape; +import software.amazon.smithy.model.shapes.StringShape; +import software.amazon.smithy.model.shapes.TimestampShape; +import software.amazon.smithy.model.traits.StreamingTrait; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class SerdeUtil { + private SerdeUtil() {} + + /** + * Gets the set of all shapes that require serde codegen for the given root shape. This is generally called for + * every input/output shape in a model, with all the results collected into a single set. + * @param model The model + * @param shape The root shape to walk to find serdeables. + * @return The set of shapes that require serde codegen. + */ + public static Set getShapesToSerde(Model model, Shape shape) { + var toSerde = new HashSet(); + visitShapesToSerde(model, shape, toSerde); + + // We don't want to actually generate serde for event stream unions - their variants can target errors, which + // shouldn't be handled generally. We DO want any of their inner members though which is why we didn't filter + // them in the previous visit step. + // + // Serde for the root unions is handled as a special case by event streaming serde codegen. + return toSerde.stream() + .filter(it -> !it.hasTrait(StreamingTrait.class)) + .collect(toSet()); + } + + /** + * Normalizes a scalar shape, erasing any nullability information and giving the shape a single unique synthetic ID. + * Non-scalar shapes are returned unmodified. + * @param shape The shape. + * @return The normalized shape. + */ + public static Shape normalize(Shape shape) { + return switch (shape.getType()) { + case BLOB -> BlobShape.builder().id("com.amazonaws.synthetic#Blob").build(); + case BOOLEAN -> BooleanShape.builder().id("com.amazonaws.synthetic#Bool").build(); + case STRING -> StringShape.builder().id("com.amazonaws.synthetic#String").build(); + case TIMESTAMP -> TimestampShape.builder().id("com.amazonaws.synthetic#Time").build(); + case BYTE -> ByteShape.builder().id("com.amazonaws.synthetic#Int8").build(); + case SHORT -> ShortShape.builder().id("com.amazonaws.synthetic#Int16").build(); + case INTEGER -> IntegerShape.builder().id("com.amazonaws.synthetic#Int32").build(); + case LONG -> LongShape.builder().id("com.amazonaws.synthetic#Int64").build(); + case FLOAT -> FloatShape.builder().id("com.amazonaws.synthetic#Float32").build(); + case DOUBLE -> DoubleShape.builder().id("com.amazonaws.synthetic#Float64").build(); + default -> shape; + }; + } + + private static void visitShapesToSerde(Model model, Shape shape, Set visited) { + if (isUnit(shape.getId()) || visited.contains(shape)) { + return; + } + + visited.add(normalize(shape)); + shape.members().stream() + .map(it -> model.expectShape(it.getTarget())) + .forEach(it -> visitShapesToSerde(model, it, visited)); + } + + private static boolean isUnit(ShapeId id) { + return id.toString().equals("smithy.api#Unit"); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborDeserializerGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborDeserializerGenerator.java new file mode 100644 index 000000000..1837ecc11 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborDeserializerGenerator.java @@ -0,0 +1,420 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.serde.cbor; + +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.SymbolUtils.buildSymbol; +import static software.amazon.smithy.go.codegen.SymbolUtils.getReference; +import static software.amazon.smithy.go.codegen.SymbolUtils.isNilable; +import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable; +import static software.amazon.smithy.go.codegen.serde.SerdeUtil.normalize; + +import java.util.Set; +import software.amazon.smithy.codegen.core.CodegenException; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoSettings; +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.ProtocolDocumentGenerator; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SmithyGoTypes; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.CollectionShape; +import software.amazon.smithy.model.shapes.MapShape; +import software.amazon.smithy.model.shapes.MemberShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeType; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.model.shapes.UnionShape; +import software.amazon.smithy.model.traits.StreamingTrait; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class CborDeserializerGenerator { + private final Model model; + private final SymbolProvider symbolProvider; + private final GoSettings settings; + + public CborDeserializerGenerator(ProtocolGenerator.GenerationContext ctx) { + this.model = ctx.getModel(); + this.symbolProvider = ctx.getSymbolProvider(); + this.settings = ctx.getSettings(); + } + + public static String getDeserializerName(Shape shape) { + return "deserializeCBOR_" + shape.getId().getName(); + } + + public GoWriter.Writable generate(Set shapes) { + return GoWriter.ChainWritable.of( + shapes.stream() + .map(this::deserializeShape) + .toList() + ).compose(); + } + + private GoWriter.Writable deserializeShape(Shape shape) { + return switch (shape.getType()) { + case BIG_INTEGER, BIG_DECIMAL -> + throw new CodegenException("arbitrary-precision nums are not supported (" + shape.getType() + ")"); + case BYTE -> deserializeStatic(shape, SmithyGoTypes.Encoding.Cbor.AsInt8); // special types with coercers + case SHORT -> deserializeStatic(shape, SmithyGoTypes.Encoding.Cbor.AsInt16); + case INTEGER -> deserializeStatic(shape, SmithyGoTypes.Encoding.Cbor.AsInt32); + case LONG -> deserializeStatic(shape, SmithyGoTypes.Encoding.Cbor.AsInt64); + case FLOAT -> deserializeStatic(shape, SmithyGoTypes.Encoding.Cbor.AsFloat32); + case DOUBLE -> deserializeStatic(shape, SmithyGoTypes.Encoding.Cbor.AsFloat64); + case TIMESTAMP -> deserializeStatic(shape, SmithyGoTypes.Encoding.Cbor.AsTime); + case INT_ENUM -> deserializeIntEnum(shape); + case STRING -> deserializeString(shape); + case DOCUMENT -> deserializeDocument(shape); // implemented, but not currently supported + default -> deserializeAssertFunc(shape); // everything else is a static assert + }; + } + + private GoWriter.Writable deserializeStatic(Shape shape, Symbol coercer) { + return goTemplate(""" + func $deserName:L(v $cborValue:T) ($type:T, error) { + return $coercer:T(v) + } + """, + MapUtils.of( + "deserName", getDeserializerName(shape), + "cborValue", SmithyGoTypes.Encoding.Cbor.Value, + "type", symbolProvider.toSymbol(shape), + "coercer", coercer + )); + } + + private GoWriter.Writable deserializeIntEnum(Shape shape) { + return goTemplate(""" + func $name:L(v $cborValue:T) ($shapeType:T, error) { + av, err := $asInt32:T(v) + if err != nil { + return 0, err + } + return $shapeType:T(av), nil + } + """, + MapUtils.of( + "name", getDeserializerName(shape), + "cborValue", SmithyGoTypes.Encoding.Cbor.Value, + "shapeType", symbolProvider.toSymbol(shape), + "asInt32", SmithyGoTypes.Encoding.Cbor.AsInt32 + )); + } + + private GoWriter.Writable deserializeString(Shape shape) { + return goTemplate(""" + func $name:L(v $cborValue:T) (string, error) { + av, ok := v.($assert:T) + if !ok { + return "", $error:T("unexpected value type %T", v) + } + return string(av), nil + } + """, + MapUtils.of( + "name", getDeserializerName(shape), + "cborValue", SmithyGoTypes.Encoding.Cbor.Value, + "assert", SmithyGoTypes.Encoding.Cbor.String, + "error", GoStdlibTypes.Fmt.Errorf + )); + } + + private GoWriter.Writable deserializeAssertFunc(Shape shape) { + return goTemplate(""" + func $name:L(v $cborValue:T) ($shapeType:P, error) { + av, ok := v.($assert:W) + if !ok { + return $zero:W, $error:T("unexpected value type %T", v) + } + $deserialize:W + } + """, + MapUtils.of( + "name", getDeserializerName(shape), + "cborValue", SmithyGoTypes.Encoding.Cbor.Value, + "shapeType", symbolProvider.toSymbol(shape), + "assert", typeAssert(shape), + "zero", zeroValue(shape), + "error", GoStdlibTypes.Fmt.Errorf, + "deserialize", deserializeAsserted(shape, "av") + )); + } + + private GoWriter.Writable typeAssert(Shape shape) { + return switch (shape.getType()) { + case STRING, ENUM -> + goTemplate("$T", SmithyGoTypes.Encoding.Cbor.String); + case BLOB -> + goTemplate("$T", SmithyGoTypes.Encoding.Cbor.Slice); + case LIST, SET -> + goTemplate("$T", SmithyGoTypes.Encoding.Cbor.List); + case MAP, STRUCTURE, UNION -> + goTemplate("$T", SmithyGoTypes.Encoding.Cbor.Map); + case TIMESTAMP, BIG_DECIMAL, BIG_INTEGER -> + goTemplate("$P", SmithyGoTypes.Encoding.Cbor.Tag); + case BOOLEAN -> + goTemplate("$T", SmithyGoTypes.Encoding.Cbor.Bool); + default -> + throw new CodegenException("Unexpected shape for single-assert: " + shape.getType()); + }; + } + + private GoWriter.Writable zeroValue(Shape shape) { + return switch (shape.getType()) { + case STRING -> + goTemplate("\"\""); + case BOOLEAN -> + goTemplate("false"); + case BLOB, LIST, SET, MAP, STRUCTURE, UNION -> + goTemplate("nil"); + case ENUM -> + goTemplate("$T(\"\")", symbolProvider.toSymbol(shape)); + case INT_ENUM -> + goTemplate("$T(0)", symbolProvider.toSymbol(shape)); + default -> + throw new CodegenException("Unexpected shape for zero-value: " + shape.getType()); + }; + } + + private GoWriter.Writable deserializeAsserted(Shape shape, String ident) { + return switch (shape.getType()) { + case STRING -> goTemplate("return string($L), nil", ident); + case ENUM -> goTemplate("return $T($L), nil", symbolProvider.toSymbol(shape), ident); + case BOOLEAN -> goTemplate("return bool($L), nil", ident); + case BLOB -> goTemplate("return []byte($L), nil", ident); + case LIST, SET -> deserializeList((CollectionShape) shape, ident); + case MAP -> deserializeMap((MapShape) shape, ident); + case STRUCTURE -> deserializeStruct((StructureShape) shape, ident); + case UNION -> deserializeUnion((UnionShape) shape, ident); + default -> + throw new CodegenException("Cannot deserialize " + shape.getType()); + }; + } + + private GoWriter.Writable deserializeList(CollectionShape shape, String ident) { + var target = normalize(model.expectShape(shape.getMember().getTarget())); + var symbol = symbolProvider.toSymbol(shape); + var targetSymbol = symbolProvider.toSymbol(target); + return goTemplate(""" + var dl $type:T + for _, si := range $ident:L { + $sparse:W + di, err := $deserialize:L(si) + if err != nil { + return nil, err + } + dl = append(dl, $deref:L di) + } + return dl, nil + """, + MapUtils.of( + "type", symbol, + "ident", ident, + "deserialize", getDeserializerName(target), + "deref", resolveDeref(getReference(symbol), targetSymbol), + "sparse", isNilable(getReference(symbol)) ? handleSparseList() : emptyGoTemplate() + )); + } + + private GoWriter.Writable handleSparseList() { + return goTemplate(""" + if _, ok := si.($P); ok { + dl = append(dl, nil) + continue + } + """, SmithyGoTypes.Encoding.Cbor.Nil); + } + + private GoWriter.Writable deserializeMap(MapShape shape, String ident) { + var value = normalize(model.expectShape(shape.getValue().getTarget())); + var symbol = symbolProvider.toSymbol(shape); + var valueSymbol = symbolProvider.toSymbol(value); + return goTemplate(""" + dm := $type:T{} + for key, sv := range $ident:L { + $sparse:W + dv, err := $deserialize:L(sv) + if err != nil { + return nil, err + } + dm[key] = $deref:L dv + } + return dm, nil + """, + MapUtils.of( + "type", symbol, + "ident", ident, + "deserialize", getDeserializerName(value), + "deref", resolveDeref(getReference(symbol), valueSymbol), + "sparse", isNilable(getReference(symbol)) ? handleSparseMap() : emptyGoTemplate() + )); + } + + private GoWriter.Writable handleSparseMap() { + return goTemplate(""" + if _, ok := sv.($P); ok { + dm[key] = nil + continue + } + """, SmithyGoTypes.Encoding.Cbor.Nil); + } + + private String resolveDeref(Symbol ref, Symbol deserialized) { + if (isPointable(ref) == isPointable(deserialized)) { + return ""; + } + return isPointable(deserialized) ? "*" : "&"; + } + + private GoWriter.Writable deserializeStruct(StructureShape shape, String ident) { + return goTemplate(""" + ds := &$type:T{} + for key, sv := range $ident:L { + _, _ = key, sv + $fields:W + } + return ds, nil + """, + MapUtils.of( + "type", symbolProvider.toSymbol(shape), + "ident", ident, + "fields", GoWriter.ChainWritable.of( + shape.getAllMembers().values().stream() + .map(this::deserializeField) + .toList() + ).compose() + )); + } + + private GoWriter.Writable deserializeField(MemberShape member) { + var target = model.expectShape(member.getTarget()); + if (target.hasTrait(StreamingTrait.class)) { + return emptyGoTemplate(); // event stream, not an actual field + } + + var memberSymbol = symbolProvider.toSymbol(member); + return goTemplate(""" + if key == $field:S { + $nilable:W + dv, err := $deserialize:L(sv) + if err != nil { + return nil, err + } + ds.$fieldName:L = $deref:W + } + """, + MapUtils.of( + "field", member.getMemberName(), + "fieldName", symbolProvider.toMemberName(member), + "deserialize", getDeserializerName(normalize(target)), + "deref", generateStructFieldDeref(member, "dv"), + "nilable", isNilable(memberSymbol) ? handleSparseField() : emptyGoTemplate() + )); + } + + private GoWriter.Writable handleSparseField() { + return goTemplate(""" + if _, ok := sv.($P); ok { + continue + }""", SmithyGoTypes.Encoding.Cbor.Nil); + } + + private GoWriter.Writable generateStructFieldDeref(MemberShape member, String ident) { + var symbol = symbolProvider.toSymbol(member); + if (!isPointable(symbol)) { + return goTemplate(ident); + } + return switch (model.expectShape(member.getTarget()).getType()) { + case BYTE -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int8, ident); + case SHORT -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int16, ident); + case INTEGER -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int32, ident); + case LONG -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Int64, ident); + case FLOAT -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Float32, ident); + case DOUBLE -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Float64, ident); + case STRING -> goTemplate("$T($L)", SmithyGoTypes.Ptr.String, ident); + case BOOLEAN -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Bool, ident); + case TIMESTAMP -> goTemplate("$T($L)", SmithyGoTypes.Ptr.Time, ident); + default -> goTemplate(ident); + }; + } + + private GoWriter.Writable deserializeUnion(UnionShape union, String ident) { + return goTemplate(""" + for key, sv := range $ident:L { + $variants:W + } + return nil, $errorf:T("unrecognized variant") + """, + MapUtils.of( + "type", symbolProvider.toSymbol(union), + "ident", ident, + "errorf", GoStdlibTypes.Fmt.Errorf, + "variants", GoWriter.ChainWritable.of( + union.getAllMembers().values().stream() + .map(it -> deserializeVariant(union, it, "sv")) + .toList() + ).compose() + )); + } + + private GoWriter.Writable deserializeVariant(UnionShape union, MemberShape member, String ident) { + var target = normalize(model.expectShape(member.getTarget())); + var symbol = symbolProvider.toSymbol(union); + var variantSymbol = buildSymbol(symbolProvider.toMemberName(member), symbol.getNamespace()); + return goTemplate(""" + if key == $variantName:S { + if _, ok := $ident:L.($cborNil:P); ok { + continue + } + dv, err := $deserialize:L($ident:L) + if err != nil { + return nil, err + } + return &$variantSymbol:T{Value: $deref:L dv}, nil + } + """, + MapUtils.of( + "cborNil", SmithyGoDependency.SMITHY_CBOR.struct("Nil"), + "variantName", member.getMemberName(), + "deserialize", getDeserializerName(target), + "ident", ident, + "variantSymbol", variantSymbol, + "deref", target.getType() == ShapeType.STRUCTURE ? "*" : "" + )); + } + + private GoWriter.Writable deserializeDocument(Shape shape) { + var unmarshaler = ProtocolDocumentGenerator.Utilities.getInternalDocumentSymbolBuilder(settings, + ProtocolDocumentGenerator.INTERNAL_NEW_DOCUMENT_UNMARSHALER_FUNC).build(); + return goTemplate(""" + func $deser:L(v $cborValue:T) ($document:T, error) { + return $unmarshaler:T(v), nil + } + """, + MapUtils.of( + "deser", getDeserializerName(shape), + "cborValue", SmithyGoTypes.Encoding.Cbor.Value, + "document", symbolProvider.toSymbol(shape), + "unmarshaler", unmarshaler + )); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborSerializerGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborSerializerGenerator.java new file mode 100644 index 000000000..a68dfe99d --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/serde/cbor/CborSerializerGenerator.java @@ -0,0 +1,323 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.serde.cbor; + +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.SymbolUtils.buildSymbol; +import static software.amazon.smithy.go.codegen.SymbolUtils.getReference; +import static software.amazon.smithy.go.codegen.SymbolUtils.isNilable; +import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable; +import static software.amazon.smithy.go.codegen.serde.SerdeUtil.normalize; + +import java.util.Set; +import software.amazon.smithy.codegen.core.CodegenException; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoStdlibTypes; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoTypes; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.CollectionShape; +import software.amazon.smithy.model.shapes.DocumentShape; +import software.amazon.smithy.model.shapes.MapShape; +import software.amazon.smithy.model.shapes.MemberShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeType; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.model.shapes.TimestampShape; +import software.amazon.smithy.model.shapes.UnionShape; +import software.amazon.smithy.model.traits.StreamingTrait; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class CborSerializerGenerator { + private final Model model; + private final SymbolProvider symbolProvider; + + public CborSerializerGenerator(ProtocolGenerator.GenerationContext ctx) { + this.model = ctx.getModel(); + this.symbolProvider = ctx.getSymbolProvider(); + } + + public static String getSerializerName(Shape shape) { + return "serializeCBOR_" + shape.getId().getName(); + } + + public GoWriter.Writable generate(Set shapes) { + return GoWriter.ChainWritable.of( + shapes.stream() + .map(this::generateShapeSerializer) + .toList() + ).compose(); + } + + private GoWriter.Writable generateShapeSerializer(Shape shape) { + return goTemplate(""" + func $name:L(v $shapeType:P) ($cborValue:T, error) { + $serialize:W + } + """, + MapUtils.of( + "name", getSerializerName(shape), + "shapeType", symbolProvider.toSymbol(shape), + "cborValue", SmithyGoTypes.Encoding.Cbor.Value, + "serialize", generateSerializeValue(shape) + )); + } + + private GoWriter.Writable generateSerializeValue(Shape shape) { + return switch (shape.getType()) { + case BYTE, SHORT, INTEGER, LONG, INT_ENUM -> generateSerializeIntegral(); + case FLOAT -> goTemplate("return $T(v), nil", SmithyGoTypes.Encoding.Cbor.Float32); + case DOUBLE -> goTemplate("return $T(v), nil", SmithyGoTypes.Encoding.Cbor.Float64); + case STRING -> goTemplate("return $T(v), nil", SmithyGoTypes.Encoding.Cbor.String); + case BOOLEAN -> goTemplate("return $T(v), nil", SmithyGoTypes.Encoding.Cbor.Bool); + case BLOB -> goTemplate("return $T(v), nil", SmithyGoTypes.Encoding.Cbor.Slice); + case ENUM -> goTemplate("return $T(string(v)), nil", SmithyGoTypes.Encoding.Cbor.String); + case TIMESTAMP -> generateSerializeTimestamp((TimestampShape) shape); + case LIST, SET -> generateSerializeList((CollectionShape) shape); + case MAP -> generateSerializeMap((MapShape) shape); + case STRUCTURE -> generateSerializeStruct((StructureShape) shape); + case UNION -> generateSerializeUnion((UnionShape) shape); + case DOCUMENT -> serializeDocument((DocumentShape) shape); // implemented, but not currently supported + case BIG_INTEGER, BIG_DECIMAL -> + throw new CodegenException("arbitrary-precision nums are not supported (" + shape.getType() + ")"); + case MEMBER, SERVICE, RESOURCE, OPERATION -> + throw new CodegenException("cannot generate serializer for shape type " + shape.getType()); + }; + } + + private GoWriter.Writable generateSerializeIntegral() { + return goTemplate(""" + if v < 0 { + return $T(uint64(-v)), nil + } + return $T(uint64(v)), nil + """, SmithyGoTypes.Encoding.Cbor.NegInt, SmithyGoTypes.Encoding.Cbor.Uint); + } + + private GoWriter.Writable generateSerializeTimestamp(TimestampShape shape) { + return goTemplate(""" + return &$tag:T{ + ID: 1, + Value: $float64:T(float64(v.UnixMilli()) / 1000), + }, nil + """, + MapUtils.of( + "tag", SmithyGoTypes.Encoding.Cbor.Tag, + "float64", SmithyGoTypes.Encoding.Cbor.Float64 + )); + } + + private GoWriter.Writable generateSerializeList(CollectionShape shape) { + var target = normalize(model.expectShape(shape.getMember().getTarget())); + var symbol = symbolProvider.toSymbol(shape); + var targetSymbol = symbolProvider.toSymbol(target); + return goTemplate(""" + vl := $list:T{} + for i := range v { + $sparse:W + ser, err := $serialize:L($indirect:L v[i]) + if err != nil { + return nil, err + } + vl = append(vl, ser) + } + return vl, nil + """, + MapUtils.of( + "list", SmithyGoTypes.Encoding.Cbor.List, + "sparse", isNilable(getReference(symbol)) ? handleSparseList() : emptyGoTemplate(), + "serialize", getSerializerName(target), + "indirect", resolveIndirect(getReference(symbol), targetSymbol) + )); + } + + private GoWriter.Writable handleSparseList() { + return goTemplate(""" + if v[i] == nil { + vl = append(vl, &$T{}) + continue + } + """, SmithyGoTypes.Encoding.Cbor.Nil); + } + + private GoWriter.Writable generateSerializeMap(MapShape shape) { + var value = normalize(model.expectShape(shape.getValue().getTarget())); + var symbol = symbolProvider.toSymbol(shape); + var valueSymbol = symbolProvider.toSymbol(value); + return goTemplate(""" + vm := $map:T{} + for k, vv := range v { + $sparse:W + ser, err := $serialize:L($indirect:L vv) + if err != nil { + return nil, err + } + vm[k] = ser + } + return vm, nil + """, + MapUtils.of( + "map", SmithyGoTypes.Encoding.Cbor.Map, + "sparse", isNilable(getReference(symbol)) ? handleSparseMap() : emptyGoTemplate(), + "serialize", getSerializerName(value), + "indirect", resolveIndirect(getReference(symbol), valueSymbol) + )); + } + + private GoWriter.Writable handleSparseMap() { + return goTemplate(""" + if vv == nil { + vm[k] = &$T{} + continue + } + """, SmithyGoTypes.Encoding.Cbor.Nil); + } + + private GoWriter.Writable generateSerializeStruct(StructureShape shape) { + return goTemplate(""" + vm := $map:T{} + $serialize:W + return vm, nil + """, + MapUtils.of( + "map", SmithyGoTypes.Encoding.Cbor.Map, + "serialize", GoWriter.ChainWritable.of( + shape.getAllMembers().values().stream() + .map(this::generateSerializeField) + .toList() + ).compose(false) + )); + } + + private GoWriter.Writable generateSerializeField(MemberShape member) { + var target = normalize(model.expectShape(member.getTarget())); + if (target.hasTrait(StreamingTrait.class)) { + return emptyGoTemplate(); // event stream, not an actual field + } + + var symbol = symbolProvider.toSymbol(member); + return switch (target.getType()) { + case BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, BOOLEAN, TIMESTAMP -> + isPointable(symbol) + ? serializeNilableMember(member, target, true) + : serializeMember(member, target); + case BLOB, LIST, SET, MAP, STRUCTURE, UNION -> + serializeNilableMember(member, target, false); + default -> + serializeMember(member, target); + }; + } + + private GoWriter.Writable serializeNilableMember(MemberShape member, Shape target, boolean deref) { + return goTemplate(""" + if v.$field:L != nil { + ser, err := $serialize:L($deref:L v.$field:L) + if err != nil { + return nil, err + } + vm[$key:S] = ser + } + """, + MapUtils.of( + "field", symbolProvider.toMemberName(member), + "key", member.getMemberName(), + "serialize", getSerializerName(target), + "deref", deref ? "*" : "" + )); + } + + private GoWriter.Writable serializeMember(MemberShape member, Shape target) { + return goTemplate(""" + ser$key:L, err := $serialize:L(v.$field:L) + if err != nil { + return nil, err + } + vm[$key:S] = ser$key:L + """, + MapUtils.of( + "field", symbolProvider.toMemberName(member), + "key", member.getMemberName(), + "serialize", getSerializerName(target) + )); + } + + private GoWriter.Writable generateSerializeUnion(UnionShape union) { + return goTemplate(""" + vm := $map:T{} + switch uv := v.(type) { + $serialize:W + default: + return nil, $errorf:T("unknown variant type %T", v) + } + return vm, nil + """, + MapUtils.of( + "map", SmithyGoTypes.Encoding.Cbor.Map, + "errorf", GoStdlibTypes.Fmt.Errorf, + "serialize", GoWriter.ChainWritable.of( + union.getAllMembers().values().stream() + .map(it -> serializeVariant(union, it)) + .toList() + ).compose(false) + )); + } + + private GoWriter.Writable serializeVariant(UnionShape union, MemberShape member) { + var target = normalize(model.expectShape(member.getTarget())); + var symbol = symbolProvider.toSymbol(union); + var variantSymbol = buildSymbol(symbolProvider.toMemberName(member), symbol.getNamespace()); + return goTemplate(""" + case *$variant:T: + ser, err := $serialize:L($indirect:L uv.Value) + if err != nil { + return nil, err + } + vm[$key:S] = ser + """, + MapUtils.of( + "variant", variantSymbol, + "serialize", getSerializerName(target), + "key", member.getMemberName(), + "indirect", target.getType() == ShapeType.STRUCTURE ? "&" : "" + )); + } + + private GoWriter.Writable serializeDocument(DocumentShape document) { + return goTemplate(""" + raw, err := v.MarshalSmithyDocument() + if err != nil { + return nil, err + } + return $encodeRaw:T(raw), nil + """, + MapUtils.of( + "encoder", SmithyGoTypes.Document.Cbor.NewEncoder, + "encodeRaw", SmithyGoTypes.Encoding.Cbor.EncodeRaw + )); + } + + private String resolveIndirect(Symbol ref, Symbol serialized) { + if (isPointable(ref) == isPointable(serialized)) { + return ""; + } + return isPointable(serialized) ? "&" : "*"; + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/trait/BackfilledInputOutputTrait.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/trait/BackfilledInputOutputTrait.java new file mode 100644 index 000000000..adfb15ae1 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/trait/BackfilledInputOutputTrait.java @@ -0,0 +1,33 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.trait; + +import static software.amazon.smithy.model.node.Node.objectNode; + +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.traits.AnnotationTrait; + +/** + * Marker trait to indicate that an input or output shape on an operation was backfilled (i.e. the operation DID NOT + * have a target for it in the original model). + */ +public class BackfilledInputOutputTrait extends AnnotationTrait { + public static final ShapeId ID = ShapeId.from("smithy.go.trait#BackfilledInputOutput"); + + public BackfilledInputOutputTrait() { + super(ID, objectNode()); + } +} diff --git a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration index a9c685fcf..4781a49a2 100644 --- a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration +++ b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration @@ -16,4 +16,4 @@ software.amazon.smithy.go.codegen.integration.auth.AnonymousAuthScheme software.amazon.smithy.go.codegen.requestcompression.RequestCompression # server -software.amazon.smithy.go.codegen.server.integration.DefaultProtocols \ No newline at end of file +software.amazon.smithy.go.codegen.server.integration.DefaultProtocols diff --git a/document/cbor/cbor.go b/document/cbor/cbor.go new file mode 100644 index 000000000..a4c0cb893 --- /dev/null +++ b/document/cbor/cbor.go @@ -0,0 +1,8 @@ +// Package cbor implements reflective encoding of Smithy documents for +// CBOR-based protocols. +// +// This package is NOT caller-facing and is not suitable for general +// application use. Callers using the document type with SDK clients should use +// the embedded NewLazyDocument() API in the SDK package to create document +// types. +package cbor diff --git a/document/cbor/decode.go b/document/cbor/decode.go new file mode 100644 index 000000000..c998989a7 --- /dev/null +++ b/document/cbor/decode.go @@ -0,0 +1,342 @@ +package cbor + +import ( + "fmt" + "math/big" + "reflect" + + "github.com/aws/smithy-go/document" + "github.com/aws/smithy-go/document/internal/serde" + "github.com/aws/smithy-go/encoding/cbor" +) + +// decoderOptions is the set of options that can be configured for a Decoder. +// +// FUTURE(rpc2cbor): document support is currently disabled. This API is +// unexported until that changes. +type decoderOptions struct{} + +// decoder is a Smithy document decoder for CBOR-based protocols. +// +// FUTURE(rpc2cbor): document support is currently disabled. This API is +// unexported until that changes. +type decoder struct { + options decoderOptions +} + +// newDecoder returns a Decoder for deserializing Smithy documents. +// +// FUTURE(rpc2cbor): document support is currently disabled. This API is +// unexported until that changes. +func newDecoder(optFns ...func(options *decoderOptions)) *decoder { + o := decoderOptions{} + + for _, fn := range optFns { + fn(&o) + } + + return &decoder{ + options: o, + } +} + +// Decode unmarshals a CBOR Value into the target. +func (d *decoder) Decode(v cbor.Value, to interface{}) error { + if document.IsNoSerde(to) { + return fmt.Errorf("unsupported type: %T", to) + } + + rv := reflect.ValueOf(to) + if rv.Kind() != reflect.Ptr || rv.IsNil() || !rv.IsValid() { + return &document.InvalidUnmarshalError{reflect.TypeOf(to)} + } + + return d.decode(v, rv, serde.Tag{}) +} + +func (d *decoder) decode(cv cbor.Value, rv reflect.Value, tag serde.Tag) error { + if _, ok := cv.(*cbor.Nil); ok { + return d.decodeNil(serde.Indirect(rv, true)) + } + + rv = serde.Indirect(rv, false) + if err := d.unsupportedType(rv); err != nil { + return err + } + + switch v := cv.(type) { + case cbor.Uint, cbor.NegInt: + return d.decodeInt(v, rv) + case cbor.Float64: + return d.decodeFloat(float64(v), rv) + case cbor.String: + return d.decodeString(string(v), rv) + case cbor.Bool: + return d.decodeBool(bool(v), rv) + case cbor.List: + return d.decodeList(v, rv) + case cbor.Map: + return d.decodeMap(v, rv) + case *cbor.Tag: + return d.decodeTag(v, rv) + default: + return fmt.Errorf("unsupported cbor document type %T", v) + } +} + +func (d *decoder) decodeInt(v cbor.Value, rv reflect.Value) error { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := cbor.AsInt64(v) + if err != nil { + return err + } + if rv.OverflowInt(i) { + return &document.UnmarshalTypeError{ + Value: fmt.Sprintf("number overflow, %d", i), + Type: rv.Type(), + } + } + rv.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u, ok := v.(cbor.Uint) + if !ok { + return &document.UnmarshalTypeError{Value: "number", Type: rv.Type()} + } + if rv.OverflowUint(uint64(u)) { + return &document.UnmarshalTypeError{ + Value: fmt.Sprintf("number overflow, %d", u), + Type: rv.Type(), + } + } + rv.SetUint(uint64(u)) + default: + return &document.UnmarshalTypeError{Value: "number", Type: rv.Type()} + } + return nil +} + +func (d *decoder) decodeNil(rv reflect.Value) error { + if rv.IsValid() && rv.CanSet() { + rv.Set(reflect.Zero(rv.Type())) + } + return nil +} + +func (d *decoder) decodeBool(v bool, rv reflect.Value) error { + switch rv.Kind() { + case reflect.Bool, reflect.Interface: + rv.Set(reflect.ValueOf(v).Convert(rv.Type())) + default: + return &document.UnmarshalTypeError{Value: "bool", Type: rv.Type()} + } + return nil +} + +func (d *decoder) decodeFloat(v float64, rv reflect.Value) error { + switch rv.Kind() { + case reflect.Interface: + rv.Set(reflect.ValueOf(v)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, accuracy := big.NewFloat(v).Int64() + if accuracy != big.Exact || rv.OverflowInt(i) { + return &document.UnmarshalTypeError{ + Value: fmt.Sprintf("int overflow, %e", v), + Type: rv.Type(), + } + } + rv.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u, accuracy := big.NewFloat(v).Uint64() + if accuracy != big.Exact || rv.OverflowUint(u) { + return &document.UnmarshalTypeError{ + Value: fmt.Sprintf("uint overflow, %e", v), + Type: rv.Type(), + } + } + rv.SetUint(u) + case reflect.Float32, reflect.Float64: + if rv.OverflowFloat(v) { + return &document.UnmarshalTypeError{ + Value: fmt.Sprintf("float overflow, %e", v), + Type: rv.Type(), + } + } + rv.SetFloat(v) + default: + return &document.UnmarshalTypeError{Value: "number", Type: rv.Type()} + } + return nil +} + +func (d *decoder) decodeList(v cbor.List, rv reflect.Value) error { + var isArray bool + + switch rv.Kind() { + case reflect.Slice: + // Make room for the slice elements if needed + if rv.IsNil() || rv.Cap() < len(v) { + rv.Set(reflect.MakeSlice(rv.Type(), 0, len(v))) + } + case reflect.Array: + // Limited to capacity of existing array. + isArray = true + case reflect.Interface: + s := make([]interface{}, len(v)) + for i, av := range v { + if err := d.decode(av, reflect.ValueOf(&s[i]).Elem(), serde.Tag{}); err != nil { + return err + } + } + rv.Set(reflect.ValueOf(s)) + return nil + default: + return &document.UnmarshalTypeError{Value: "list", Type: rv.Type()} + } + + // If rv is not a slice, array + for i := 0; i < rv.Cap() && i < len(v); i++ { + if !isArray { + rv.SetLen(i + 1) + } + if err := d.decode(v[i], rv.Index(i), serde.Tag{}); err != nil { + return err + } + } + + return nil +} + +func (d *decoder) decodeString(v string, rv reflect.Value) error { + switch rv.Kind() { + case reflect.String: + rv.SetString(v) + case reflect.Interface: + rv.Set(reflect.ValueOf(v).Convert(rv.Type())) + default: + return &document.UnmarshalTypeError{Value: "string", Type: rv.Type()} + } + return nil +} + +func (d *decoder) decodeMap(tv cbor.Map, rv reflect.Value) error { + switch rv.Kind() { + case reflect.Map: + t := rv.Type() + if t.Key().Kind() != reflect.String { + return &document.UnmarshalTypeError{Value: "map string key", Type: t.Key()} + } + if rv.IsNil() { + rv.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + if rv.CanInterface() && document.IsNoSerde(rv.Interface()) { + return &document.UnmarshalTypeError{ + Value: fmt.Sprintf("unsupported type"), + Type: rv.Type(), + } + } + case reflect.Interface: + rv.Set(reflect.MakeMap(serde.ReflectTypeOf.MapStringToInterface)) + rv = rv.Elem() + default: + return &document.UnmarshalTypeError{Value: "map", Type: rv.Type()} + } + + if rv.Kind() == reflect.Map { + for k, kv := range tv { + key := reflect.New(rv.Type().Key()).Elem() + key.SetString(k) + elem := reflect.New(rv.Type().Elem()).Elem() + if err := d.decode(kv, elem, serde.Tag{}); err != nil { + return err + } + rv.SetMapIndex(key, elem) + } + } else if rv.Kind() == reflect.Struct { + fields := serde.GetStructFields(rv.Type()) + for k, kv := range tv { + if f, ok := fields.FieldByName(k); ok { + fv := serde.DecoderFieldByIndex(rv, f.Index) + if err := d.decode(kv, fv, f.Tag); err != nil { + return err + } + } + } + } + + return nil +} + +func (d *decoder) decodeTag(tv *cbor.Tag, rv reflect.Value) error { + rvt := rv.Type() + switch { + case rvt.ConvertibleTo(serde.ReflectTypeOf.BigInt): + i, err := cbor.AsBigInt(tv) + if err != nil { + return &document.UnmarshalTypeError{Value: "tag", Type: rv.Type()} + } + + rv.Set(reflect.ValueOf(*i).Convert(rvt)) + return nil + case rvt.ConvertibleTo(serde.ReflectTypeOf.BigFloat): + i, err := asBigFloat(tv) + if err != nil { + return &document.UnmarshalTypeError{Value: "tag", Type: rv.Type()} + } + + rv.Set(reflect.ValueOf(*i).Convert(rvt)) + return nil + default: + return &document.UnmarshalTypeError{Value: "tag", Type: rv.Type()} + } +} + +func (d *decoder) unsupportedType(rv reflect.Value) error { + if rv.Kind() == reflect.Interface && rv.NumMethod() != 0 { + return &document.UnmarshalTypeError{Value: "non-empty interface", Type: rv.Type()} + } + + if rv.Type().ConvertibleTo(serde.ReflectTypeOf.Time) { + return &document.UnmarshalTypeError{ + Type: rv.Type(), + Value: fmt.Sprintf("time value"), + } + } + return nil +} + +func asBigFloat(tv *cbor.Tag) (*big.Float, error) { + const tagbase10 = 4 + + if tv.ID != tagbase10 { + return nil, fmt.Errorf("invalid tag: %d", tv.ID) + } + + pcs, ok := tv.Value.(cbor.List) + if !ok { + return nil, fmt.Errorf("invalid tagged type: %T", tv.Value) + } + + if len(pcs) != 2 { + return nil, fmt.Errorf("invalid tagged list len: %d", len(pcs)) + } + + eval, mval := pcs[0], pcs[1] + exp, err := cbor.AsBigInt(eval) + if err != nil { + return nil, fmt.Errorf("invalid exp: %w", err) + } + + mant, err := cbor.AsBigInt(mval) + if !ok { + return nil, fmt.Errorf("invalid mant: %w", err) + } + + // We literally re-express this as e and send it through + // bigfloat parse. Not mathematically amazing, but ensures that + // string-borne bignums and this are computed identically. + str := fmt.Sprintf("%se%s", mant.String(), exp.String()) + x, _, err := new(big.Float).Parse(str, 0) + return x, err +} diff --git a/document/cbor/decode_test.go b/document/cbor/decode_test.go new file mode 100644 index 000000000..c2ecf5b54 --- /dev/null +++ b/document/cbor/decode_test.go @@ -0,0 +1,130 @@ +package cbor + +import ( + "math" + "math/big" + "reflect" + "testing" + + "github.com/aws/smithy-go/encoding/cbor" + "github.com/aws/smithy-go/ptr" +) + +func TestDecode_KitchenSink(t *testing.T) { + type target struct { + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + + Slice []byte + String string + + List []target + Map map[string]target + + UintptrNil *uint + UintptrNonnil *uint + Bool bool + Float float64 + + BigInt *big.Int + BigNegInt *big.Int + BigFloat *big.Float + } + + in := cbor.Map{ + "Int8": cbor.NegInt(8), + "Int16": cbor.NegInt(16), + "Int32": cbor.NegInt(32), + "Int64": cbor.NegInt(64), + "Uint8": cbor.Uint(8), + "Uint16": cbor.Uint(16), + "Uint32": cbor.Uint(32), + "Uint64": cbor.Uint(64), + + "String": cbor.String("foo"), + + "List": cbor.List{ + cbor.Map{ + "Int8": cbor.NegInt(8), + }, + }, + "Map": cbor.Map{ + "k0": cbor.Map{ + "Int8": cbor.NegInt(8), + }, + }, + + "UintptrNil": &cbor.Nil{}, + "UintptrNonnil": cbor.Uint(4), + "Bool": cbor.Bool(true), + "Float": cbor.Float64(math.Inf(1)), + + "BigInt": &cbor.Tag{ + ID: 2, + Value: cbor.Slice{1, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + "BigNegInt": &cbor.Tag{ + ID: 3, + Value: cbor.Slice{1, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + "BigFloat": &cbor.Tag{ + ID: 4, + Value: cbor.List{ + cbor.NegInt(200), // exp + cbor.Uint(200), // mant + }, + }, + + "UnknownField": &cbor.Nil{}, + } + + expect := target{ + Int8: -8, + Int16: -16, + Int32: -32, + Int64: -64, + Uint8: 8, + Uint16: 16, + Uint32: 32, + Uint64: 64, + + String: "foo", + + List: []target{ + {Int8: -8}, + }, + Map: map[string]target{ + "k0": {Int8: -8}, + }, + + UintptrNil: nil, + UintptrNonnil: ptr.Uint(4), + Bool: true, + Float: math.Inf(1), + + BigInt: new(big.Int).SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0}), + BigNegInt: new(big.Int).Sub( + big.NewInt(-1), + new(big.Int).SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0}), + ), + BigFloat: func() *big.Float { + x, _ := new(big.Float).SetString("200e-200") + return x + }(), + } + + var actual target + dec := &decoder{} + if err := dec.Decode(in, &actual); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(expect, actual) { + t.Errorf("%v != %v", expect, actual) + } +} diff --git a/document/cbor/encode.go b/document/cbor/encode.go new file mode 100644 index 000000000..7ad7f43ff --- /dev/null +++ b/document/cbor/encode.go @@ -0,0 +1,228 @@ +package cbor + +import ( + "fmt" + "math/big" + "reflect" + + "github.com/aws/smithy-go/document" + "github.com/aws/smithy-go/document/internal/serde" + "github.com/aws/smithy-go/encoding/cbor" +) + +// encoderOptions is the set of options that can be configured for an Encoder. +// +// FUTURE(rpc2cbor): document support is currently disabled. This API is +// unexported until that changes. +type encoderOptions struct{} + +// encoder is a Smithy document encoder for CBOR-based protocols. +// +// FUTURE(rpc2cbor): document support is currently disabled. This API is +// unexported until that changes. +type encoder struct { + options encoderOptions +} + +// newEncoder returns an Encoder for serializing Smithy documents. +// +// FUTURE(rpc2cbor): document support is currently disabled. This API is +// unexported until that changes. +func newEncoder(optFns ...func(options *encoderOptions)) *encoder { + o := encoderOptions{} + + for _, fn := range optFns { + fn(&o) + } + + return &encoder{ + options: o, + } +} + +// Encode returns the CBOR encoding of v. +func (e *encoder) Encode(v interface{}) ([]byte, error) { + cv, err := e.encode(reflect.ValueOf(v), serde.Tag{}) + if err != nil { + return nil, err + } + + return cbor.Encode(cv), nil +} + +func (e *encoder) encode(rv reflect.Value, tag serde.Tag) (cbor.Value, error) { + if serde.IsZeroValue(rv) { + if tag.OmitEmpty { + return nil, nil + } + return e.encodeZeroValue(rv) + } + + rv = serde.ValueElem(rv) + switch rv.Kind() { + case reflect.Struct: + return e.encodeStruct(rv) + case reflect.Map: + return e.encodeMap(rv) + case reflect.Slice, reflect.Array: + return e.encodeSlice(rv) + case reflect.Invalid, reflect.Chan, reflect.Func, reflect.UnsafePointer: + return nil, nil + default: + return e.encodeScalar(rv) + } +} + +func (e *encoder) encodeZeroValue(rv reflect.Value) (cbor.Value, error) { + switch rv.Kind() { + case reflect.Array: + return cbor.List{}, nil + case reflect.String: + return cbor.String(""), nil + case reflect.Bool: + return cbor.Bool(false), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return cbor.EncodeFixedUint(0), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return cbor.EncodeFixedUint(0), nil + case reflect.Float32, reflect.Float64: + return cbor.Float64(0), nil + case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: + return &cbor.Nil{}, nil + default: + return nil, &document.InvalidMarshalError{Message: fmt.Sprintf("unknown value type: %s", rv.String())} + } +} + +func (e *encoder) encodeStruct(rv reflect.Value) (cbor.Value, error) { + if rv.CanInterface() && document.IsNoSerde(rv.Interface()) { + return nil, &document.UnmarshalTypeError{ + Value: fmt.Sprintf("unsupported type"), + Type: rv.Type(), + } + } + + switch { + case rv.Type().ConvertibleTo(serde.ReflectTypeOf.Time): + return nil, &document.InvalidMarshalError{ + Message: fmt.Sprintf("unsupported type %s", rv.Type().String()), + } + case rv.Type().ConvertibleTo(serde.ReflectTypeOf.BigFloat): + fallthrough + case rv.Type().ConvertibleTo(serde.ReflectTypeOf.BigInt): + return e.encodeNumber(rv) + } + + fields := serde.GetStructFields(rv.Type()) + + mv := cbor.Map{} + for _, f := range fields.All() { + if f.Name == "" { + return nil, &document.InvalidMarshalError{Message: "map key cannot be empty"} + } + + fv, found := serde.EncoderFieldByIndex(rv, f.Index) + if !found { + continue + } + + cv, err := e.encode(fv, f.Tag) + if err != nil { + return nil, err + } + if cv == nil { // from omitEmpty + continue + } + + mv[f.Name] = cv + } + + return mv, nil +} + +func (e *encoder) encodeMap(rv reflect.Value) (cbor.Map, error) { + mv := cbor.Map{} + for _, key := range rv.MapKeys() { + keyName := fmt.Sprint(key.Interface()) + if keyName == "" { + return nil, &document.InvalidMarshalError{"map key cannot be empty"} + } + + cv, err := e.encode(rv.MapIndex(key), serde.Tag{}) + if err != nil { + return nil, err + } + + mv[keyName] = cv + } + return mv, nil +} + +func (e *encoder) encodeSlice(rv reflect.Value) (cbor.List, error) { + lv := cbor.List{} + for i := 0; i < rv.Len(); i++ { + cv, err := e.encode(rv.Index(i), serde.Tag{}) + if err != nil { + return nil, err + } + + lv = append(lv, cv) + } + + return lv, nil +} + +func (e *encoder) encodeScalar(rv reflect.Value) (cbor.Value, error) { + switch rv.Kind() { + case reflect.Bool: + return cbor.Bool(rv.Bool()), nil + case reflect.String: + return cbor.String(rv.String()), nil + default: + return e.encodeNumber(rv) + } +} + +func (e *encoder) encodeNumber(rv reflect.Value) (cbor.Value, error) { + const tagbigpos = 2 + const tagbigneg = 3 + const tagbigfloat = 4 + + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + iv := rv.Int() + if iv >= 0 { + return cbor.EncodeFixedUint(uint64(iv)), nil + } + return cbor.EncodeFixedNegInt(uint64(-iv)), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return cbor.EncodeFixedUint(rv.Uint()), nil + case reflect.Float32, reflect.Float64: + return cbor.Float64(rv.Float()), nil + default: + rvt := rv.Type() + switch { + case rvt.ConvertibleTo(serde.ReflectTypeOf.BigInt): + i := rv.Convert(serde.ReflectTypeOf.BigInt).Interface().(big.Int) + if i.Sign() > -1 { + return &cbor.Tag{ + ID: tagbigpos, + Value: cbor.Slice(i.Bytes()), + }, nil + } + + biased := new(big.Int).Add(&i, big.NewInt(1)) + return &cbor.Tag{ + ID: tagbigneg, + Value: cbor.Slice(biased.Bytes()), + }, nil + case rvt.ConvertibleTo(serde.ReflectTypeOf.BigFloat): + // FUTURE(rpc2cbor): when document support is enabled, complete this logic + return &cbor.Tag{}, nil + default: + return nil, &document.InvalidMarshalError{ + Message: fmt.Sprintf("incompatible type: %s", rvt.String()), + } + } + } +} diff --git a/document/cbor/encode_test.go b/document/cbor/encode_test.go new file mode 100644 index 000000000..4909c48f8 --- /dev/null +++ b/document/cbor/encode_test.go @@ -0,0 +1,115 @@ +package cbor + +import ( + "math" + "math/big" + "reflect" + "testing" + + "github.com/aws/smithy-go/encoding/cbor" + "github.com/aws/smithy-go/ptr" +) + +func TestEncode_KitchenSink(t *testing.T) { + type subtarget struct { + Int8 int8 + Int16 int16 + } + type target struct { + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + String string + List []subtarget + Map map[string]subtarget + UintptrNil *uint + UintptrNonnil *uint + Bool bool + Float float64 + BigInt *big.Int + BigNegInt *big.Int + } + + in := target{ + Int8: -8, + Int16: -16, + Int32: -32, + Int64: -64, + Uint8: 8, + Uint16: 16, + Uint32: 32, + Uint64: 64, + String: "foo", + List: []subtarget{ + {Int8: -8}, + }, + Map: map[string]subtarget{ + "k0": {Int8: -8}, + }, + UintptrNil: nil, + UintptrNonnil: ptr.Uint(4), + Bool: true, + Float: math.Inf(1), + BigInt: new(big.Int).SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0}), + BigNegInt: new(big.Int).Sub( + big.NewInt(-1), + new(big.Int).SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0}), + ), + } + + expect := cbor.Map{ + "Int8": cbor.NegInt(8), + "Int16": cbor.NegInt(16), + "Int32": cbor.NegInt(32), + "Int64": cbor.NegInt(64), + "Uint8": cbor.Uint(8), + "Uint16": cbor.Uint(16), + "Uint32": cbor.Uint(32), + "Uint64": cbor.Uint(64), + "String": cbor.String("foo"), + "List": cbor.List{ + cbor.Map{ + "Int8": cbor.NegInt(8), + "Int16": cbor.Uint(0), // implicit + }, + }, + "Map": cbor.Map{ + "k0": cbor.Map{ + "Int8": cbor.NegInt(8), + "Int16": cbor.Uint(0), + }, + }, + "UintptrNil": &cbor.Nil{}, + "UintptrNonnil": cbor.Uint(4), + "Bool": cbor.Bool(true), + "Float": cbor.Float64(math.Inf(1)), + "BigInt": &cbor.Tag{ + ID: 2, + Value: cbor.Slice{1, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + "BigNegInt": &cbor.Tag{ + ID: 3, + Value: cbor.Slice{1, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + } + + enc := &encoder{} + encoded, err := enc.Encode(in) + if err != nil { + t.Fatal(err) + } + + actual, err := cbor.Decode(encoded) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expect, actual) { + t.Errorf("%v != %v", expect, actual) + } +} diff --git a/encoding/cbor/cbor.go b/encoding/cbor/cbor.go new file mode 100644 index 000000000..fe8449a82 --- /dev/null +++ b/encoding/cbor/cbor.go @@ -0,0 +1,139 @@ +// Package cbor implements partial encoding/decoding of concise binary object +// representation (CBOR) described in [RFC 8949]. +// +// This package is intended for use only by the smithy client runtime. The +// exported API therein is not considered stable and is subject to breaking +// changes without notice. More specifically, this package implements a subset +// of the RFC 8949 specification required to support the Smithy RPCv2-CBOR +// protocol and is NOT suitable for general application use. +// +// The following principal restrictions apply: +// - Map (major type 5) keys can only be strings. +// - Float16 (major type 7, 25) values can be read but not encoded. Any +// float16 encountered during decode is converted to float32. +// - Indefinite-length values can be read but not encoded. Since the encoding +// API operates strictly off of a constructed syntax tree, the length of each +// data item in a Value will always be known and the encoder will always +// generate definite-length variants. +// +// It is the responsibility of the caller to determine whether a decoded CBOR +// integral or floating-point Value is suitable for its target (e.g. whether +// the value of a CBOR Uint fits into a field modeled as a Smithy short). +// +// All CBOR tags (major type 6) are implicitly supported since the +// encoder/decoder does not attempt to interpret a tag's contents. It is the +// responsibility of the caller to both provide valid Tag values to encode and +// to assert that a decoded Tag's contents are valid for its tag ID (e.g. +// ensuring whether a Tag with ID 1, indicating an enclosed epoch timestamp, +// actually contains a valid integral or floating-point CBOR Value). +// +// [RFC 8949]: https://www.rfc-editor.org/rfc/rfc8949.html +package cbor + +// Value describes a CBOR data item. +// +// The following types implement Value: +// - [Uint] +// - [NegInt] +// - [Slice] +// - [String] +// - [List] +// - [Map] +// - [Tag] +// - [Bool] +// - [Nil] +// - [Undefined] +// - [Float32] +// - [Float64] +type Value interface { + len() int + encode(p []byte) int +} + +var ( + _ Value = Uint(0) + _ Value = NegInt(0) + _ Value = Slice(nil) + _ Value = String("") + _ Value = List(nil) + _ Value = Map(nil) + _ Value = (*Tag)(nil) + _ Value = Bool(false) + _ Value = (*Nil)(nil) + _ Value = (*Undefined)(nil) + _ Value = Float32(0) + _ Value = Float64(0) +) + +// Uint describes a CBOR uint (major type 0) in the range [0, 2^64-1]. +type Uint uint64 + +// NegInt describes a CBOR negative int (major type 1) in the range [-2^64, -1]. +// +// The "true negative" value of a type 1 is specified by RFC 8949 to be -1 +// minus the encoded value. The encoder/decoder applies this bias +// automatically, e.g. the integral -100 is represented as NegInt(100), which +// will which encode to/from hex 3863 (major 1, minor 24, argument 99). +// +// This implicitly means that the lower bound of this type -2^64 is represented +// as the wraparound value NegInt(0). Deserializer implementations should take +// care to guard against this case when deriving a value for a signed integral +// type which was encoded as NegInt. +type NegInt uint64 + +// Slice describes a CBOR byte slice (major type 2). +type Slice []byte + +// String describes a CBOR text string (major type 3). +type String string + +// List describes a CBOR list (major type 4). +type List []Value + +// Map describes a CBOR map (major type 5). +// +// The type signature of the map's key is restricted to string as it is in +// Smithy. +type Map map[string]Value + +// Tag describes a CBOR-tagged value (major type 6). +type Tag struct { + ID uint64 + Value Value +} + +// Bool describes a boolean value (major type 7, argument 20/21). +type Bool bool + +// Nil is the `nil` / `null` literal (major type 7, argument 22). +type Nil struct{} + +// Undefined is the `undefined` literal (major type 7, argument 23). +type Undefined struct{} + +// Float32 describes an IEEE 754 single-precision floating-point number +// (major type 7, argument 26). +// +// Go does not natively support float16, all values encoded as such (major type +// 7, argument 25) must be represented by this variant instead. +type Float32 float32 + +// Float64 describes an IEEE 754 double-precision floating-point number +// (major type 7, argument 27). +type Float64 float64 + +// Encode returns a byte slice that encodes the given Value. +func Encode(v Value) []byte { + p := make([]byte, v.len()) + v.encode(p) + return p +} + +// Decode returns the Value encoded in the given byte slice. +func Decode(p []byte) (Value, error) { + v, _, err := decode(p) + if err != nil { + return nil, err + } + return v, nil +} diff --git a/encoding/cbor/coerce.go b/encoding/cbor/coerce.go new file mode 100644 index 000000000..0c585cd30 --- /dev/null +++ b/encoding/cbor/coerce.go @@ -0,0 +1,229 @@ +package cbor + +import ( + "fmt" + "math/big" + "time" +) + +func fmtNegint(v NegInt) string { + if v == 0 { + return "-2^64" + } + return fmt.Sprintf("-%d", v) +} + +// AsInt8 coerces a Value to its int8 representation if possible. +func AsInt8(v Value) (int8, error) { + const max8 = 0x7f + + switch vv := v.(type) { + case Uint: + if vv > max8 { + return 0, fmt.Errorf("cbor uint %d exceeds max int8 value", vv) + } + return int8(vv), nil + case NegInt: + if vv > max8+1 || vv == 0 { + return 0, fmt.Errorf("cbor negint %s exceeds min int8 value", fmtNegint(vv)) + } + return -int8(vv), nil + } + return 0, fmt.Errorf("unexpected value type %T", v) +} + +// AsInt16 coerces a Value to its int16 representation if possible. +func AsInt16(v Value) (int16, error) { + const max16 = 0x7fff + + switch vv := v.(type) { + case Uint: + if vv > max16 { + return 0, fmt.Errorf("cbor uint %d exceeds max int16 value", vv) + } + return int16(vv), nil + case NegInt: + if vv > max16+1 || vv == 0 { + return 0, fmt.Errorf("cbor negint %s exceeds min int16 value", fmtNegint(vv)) + } + return -int16(vv), nil + } + return 0, fmt.Errorf("unexpected value type %T", v) +} + +// AsInt32 coerces a Value to its int32 representation if possible. +func AsInt32(v Value) (int32, error) { + const max32 = 0x7fffffff + + switch vv := v.(type) { + case Uint: + if vv > max32 { + return 0, fmt.Errorf("cbor uint %d exceeds max int32 value", vv) + } + return int32(vv), nil + case NegInt: + if vv > max32+1 || vv == 0 { + return 0, fmt.Errorf("cbor negint %s exceeds min int32 value", fmtNegint(vv)) + } + return -int32(vv), nil + } + return 0, fmt.Errorf("unexpected value type %T", v) +} + +// AsInt64 coerces a Value to its int64 representation if possible. +func AsInt64(v Value) (int64, error) { + const max64 = 0x7fffffff_ffffffff + + switch vv := v.(type) { + case Uint: + if vv > max64 { + return 0, fmt.Errorf("cbor uint %d exceeds max int64 value", vv) + } + return int64(vv), nil + case NegInt: + if vv > max64+1 || vv == 0 { + return 0, fmt.Errorf("cbor negint %s exceeds min int64 value", fmtNegint(vv)) + } + return -int64(vv), nil + } + return 0, fmt.Errorf("unexpected value type %T", v) +} + +// AsFloat32 coerces a Value to its float32 representation if possible. +// +// A float32 may be represented by any of the following alternatives: +// - cbor uint (if within lossless range) +// - cbor -int (if within lossless range) +func AsFloat32(v Value) (float32, error) { + const maxLosslessFloat32 = 1 << 24 + + switch vv := v.(type) { + case Float32: + return float32(vv), nil + case Uint: + if vv > maxLosslessFloat32 { + return 0, fmt.Errorf("cbor uint %d exceeds max lossless float32 value", vv) + } + return float32(vv), nil + case NegInt: + if vv > maxLosslessFloat32 || vv == 0 { + return 0, fmt.Errorf("cbor negint %s exceeds min lossless float32 value", fmtNegint(vv)) + } + return -float32(vv), nil + } + return 0, fmt.Errorf("unexpected value type %T", v) +} + +// AsFloat64 coerces a Value to its float64 representation if possible. +// +// A float64 may be represented by any of the following alternatives: +// - float32 +// - cbor uint (if within lossless range) +// - cbor -int (if within lossless range) +func AsFloat64(v Value) (float64, error) { + const maxLosslessFloat64 = 1 << 54 + + switch vv := v.(type) { + case Float64: + return float64(vv), nil + case Float32: + return float64(vv), nil + case Uint: + if vv > maxLosslessFloat64 { + return 0, fmt.Errorf("cbor uint %d exceeds max lossless float64 value", vv) + } + return float64(vv), nil + case NegInt: + if vv > maxLosslessFloat64 || vv == 0 { + return 0, fmt.Errorf("cbor negint %s exceeds min lossless float64 value", fmtNegint(vv)) + } + return -float64(vv), nil + } + return 0, fmt.Errorf("unexpected value type %T", v) +} + +// AsTime coerces a Value to its time.Time representation if possible. +// +// This coercion will check that the given Value is a Tag with the registered +// number (1) for epoch time. The value for time.Time within that tag may be +// derived from any of the following: +// - float32 +// - float64 +// - cbor uint (within int64 bounds) +// - cbor -int (within int64 bounds) +// +// Tag number 0 (date-time RFC3339) is not supported. +func AsTime(v Value) (time.Time, error) { + const tagEpoch = 1 + + tag, ok := v.(*Tag) + if !ok { + return time.Time{}, fmt.Errorf("unexpected value type %T", v) + } + if tag.ID != tagEpoch { + return time.Time{}, fmt.Errorf("unexpected tag ID %d", tag.ID) + } + + switch vv := tag.Value.(type) { + case Float32: + return time.UnixMilli(int64(vv * 1e3)), nil + case Float64: + return time.UnixMilli(int64(vv * 1e3)), nil + } + + as64, err := AsInt64(tag.Value) // will handle fail on non-int types + if err != nil { + return time.Time{}, fmt.Errorf("coerce tag value: %w", err) + } + + return time.Unix(as64, 0), nil +} + +// AsBigInt coerces a Value to its big.Int representation if possible. +// +// A BigInt may be represented by any of the following: +// - Uint +// - NegInt +// - Tag (type 2/3, where tagged value is a Slice) +// - Nil +func AsBigInt(v Value) (*big.Int, error) { + switch vv := v.(type) { + case Uint: + return new(big.Int).SetUint64(uint64(vv)), nil + case NegInt: + i := new(big.Int) + if vv == 0 { + i.SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0}) + } else { + i.SetUint64(uint64(vv)) + } + return i.Neg(i), nil + case *Tag: + return asBigIntFromTag(vv) + case *Nil: + return nil, nil + default: + return nil, fmt.Errorf("unexpected value type %T", v) + } +} + +func asBigIntFromTag(tv *Tag) (*big.Int, error) { + const tagpos = 2 + const tagneg = 3 + + if tv.ID != tagpos && tv.ID != tagneg { + return nil, fmt.Errorf("unexpected tag ID %d", tv.ID) + } + + bytes, ok := tv.Value.(Slice) + if !ok { + return nil, fmt.Errorf("unexpected tag value type %T", tv.Value) + } + + i := new(big.Int).SetBytes([]byte(bytes)) + if tv.ID == tagneg { + i.Sub(big.NewInt(-1), i) + } + + return i, nil +} diff --git a/encoding/cbor/coerce_test.go b/encoding/cbor/coerce_test.go new file mode 100644 index 000000000..e89f37bfb --- /dev/null +++ b/encoding/cbor/coerce_test.go @@ -0,0 +1,531 @@ +package cbor + +import ( + "fmt" + "math/big" + "strings" + "testing" + "time" +) + +func TestAsInt8(t *testing.T) { + const maxv = 0x7f + for name, c := range map[string]struct { + In Value + Expect int8 + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "uint oob": { + In: Uint(maxv + 1), + Err: fmt.Sprintf("cbor uint %d exceeds", maxv+1), + }, + "negint oob": { + In: NegInt(maxv + 2), + Err: fmt.Sprintf("cbor negint %s exceeds", fmtNegint(NegInt(maxv+2))), + }, + "negint wrap oob": { + In: NegInt(0), + Err: "cbor negint -2^64 exceeds", + }, + "uint ok min": { + In: Uint(0), + Expect: 0, + }, + "uint ok max": { + In: Uint(maxv), + Expect: maxv, + }, + "negint ok min": { + In: NegInt(1), + Expect: -1, + }, + "negint ok max": { + In: NegInt(maxv + 1), + Expect: -maxv - 1, + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsInt8(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if actual != c.Expect { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} + +func TestAsInt16(t *testing.T) { + const maxv = 0x7fff + for name, c := range map[string]struct { + In Value + Expect int16 + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "uint oob": { + In: Uint(maxv + 1), + Err: fmt.Sprintf("cbor uint %d exceeds", maxv+1), + }, + "negint oob": { + In: NegInt(maxv + 2), + Err: fmt.Sprintf("cbor negint %s exceeds", fmtNegint(NegInt(maxv+2))), + }, + "negint wrap oob": { + In: NegInt(0), + Err: "cbor negint -2^64 exceeds", + }, + "uint ok min": { + In: Uint(0), + Expect: 0, + }, + "uint ok max": { + In: Uint(maxv), + Expect: maxv, + }, + "negint ok min": { + In: NegInt(1), + Expect: -1, + }, + "negint ok max": { + In: NegInt(maxv + 1), + Expect: -maxv - 1, + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsInt16(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if actual != c.Expect { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} + +func TestAsInt32(t *testing.T) { + const maxv = 0x7fffffff + for name, c := range map[string]struct { + In Value + Expect int32 + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "uint oob": { + In: Uint(maxv + 1), + Err: fmt.Sprintf("cbor uint %d exceeds", maxv+1), + }, + "negint oob": { + In: NegInt(maxv + 2), + Err: fmt.Sprintf("cbor negint %s exceeds", fmtNegint(NegInt(maxv+2))), + }, + "negint wrap oob": { + In: NegInt(0), + Err: "cbor negint -2^64 exceeds", + }, + "uint ok min": { + In: Uint(0), + Expect: 0, + }, + "uint ok max": { + In: Uint(maxv), + Expect: maxv, + }, + "negint ok min": { + In: NegInt(1), + Expect: -1, + }, + "negint ok max": { + In: NegInt(maxv + 1), + Expect: -maxv - 1, + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsInt32(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if actual != c.Expect { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} + +func TestAsInt64(t *testing.T) { + const maxv = 0x7fffffff_ffffffff + for name, c := range map[string]struct { + In Value + Expect int64 + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "uint oob": { + In: Uint(uint64(maxv) + 1), + Err: fmt.Sprintf("cbor uint %d exceeds", uint64(maxv)+1), + }, + "negint oob": { + In: NegInt(uint64(maxv) + 2), + Err: fmt.Sprintf("cbor negint %s exceeds", fmtNegint(NegInt(uint64(maxv)+2))), + }, + "negint wrap oob": { + In: NegInt(0), + Err: "cbor negint -2^64 exceeds", + }, + "uint ok min": { + In: Uint(0), + Expect: 0, + }, + "uint ok max": { + In: Uint(maxv), + Expect: maxv, + }, + "negint ok min": { + In: NegInt(1), + Expect: -1, + }, + "negint ok max": { + In: NegInt(maxv + 1), + Expect: -maxv - 1, + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsInt64(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if actual != c.Expect { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} + +func TestAsFloat32(t *testing.T) { + const maxv = 1 << 24 + for name, c := range map[string]struct { + In Value + Expect float32 + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "uint oob": { + In: Uint(maxv + 1), + Err: fmt.Sprintf("cbor uint %d exceeds", maxv+1), + }, + "negint oob": { + In: NegInt(maxv + 2), + Err: fmt.Sprintf("cbor negint %s exceeds", fmtNegint(NegInt(maxv+2))), + }, + "negint wrap oob": { + In: NegInt(0), + Err: "cbor negint -2^64 exceeds", + }, + "uint ok min": { + In: Uint(0), + Expect: 0, + }, + "uint ok max": { + In: Uint(maxv), + Expect: maxv, + }, + "negint ok min": { + In: NegInt(1), + Expect: -1, + }, + "negint ok max": { + In: NegInt(maxv), + Expect: -maxv, + }, + "direct": { + In: Float32(0.5), + Expect: 0.5, + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsFloat32(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if actual != c.Expect { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} + +func TestAsFloat64(t *testing.T) { + const maxv = 1 << 54 + for name, c := range map[string]struct { + In Value + Expect float64 + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "uint oob": { + In: Uint(maxv + 1), + Err: fmt.Sprintf("cbor uint %d exceeds", maxv+1), + }, + "negint oob": { + In: NegInt(maxv + 2), + Err: fmt.Sprintf("cbor negint %s exceeds", fmtNegint(NegInt(maxv+2))), + }, + "negint wrap oob": { + In: NegInt(0), + Err: "cbor negint -2^64 exceeds", + }, + "uint ok min": { + In: Uint(0), + Expect: 0, + }, + "uint ok max": { + In: Uint(maxv), + Expect: maxv, + }, + "negint ok min": { + In: NegInt(1), + Expect: -1, + }, + "negint ok max": { + In: NegInt(maxv), + Expect: -maxv, + }, + "float32": { + In: Float32(0.5), + Expect: 0.5, + }, + "direct": { + In: Float64(0.5), + Expect: 0.5, + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsFloat64(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if actual != c.Expect { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} + +func TestAsTime(t *testing.T) { + for name, c := range map[string]struct { + In Value + Expect time.Time + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "wrong tag": { + In: &Tag{ID: 2}, + Err: "unexpected tag ID 2", + }, + "wrong tag value": { + In: &Tag{ID: 1, Value: String("")}, + Err: "coerce tag value: unexpected value type cbor.String", + }, + "no tag value": { + In: &Tag{ID: 1}, + Err: "coerce tag value: unexpected value type ", + }, + "negint": { + In: &Tag{ID: 1, Value: Uint(4)}, + Expect: time.UnixMilli(4000), + }, + "float32": { + In: &Tag{ID: 1, Value: Float32(3.997)}, + Expect: time.UnixMilli(3997), + }, + "float64": { + In: &Tag{ID: 1, Value: Float64(3.997)}, + Expect: time.UnixMilli(3997), + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsTime(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if actual != c.Expect { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} + +func TestAsBigInt(t *testing.T) { + for name, c := range map[string]struct { + In Value + Expect *big.Int + Err string + }{ + "wrong type": { + In: String(""), + Err: "unexpected value type cbor.String", + }, + "wrong tag": { + In: &Tag{ID: 1}, + Err: "unexpected tag ID 1", + }, + "wrong tag value": { + In: &Tag{ID: 2, Value: String("")}, + Err: "unexpected tag value type cbor.String", + }, + "uint min": { + In: Uint(0), + Expect: big.NewInt(0), + }, + "uint max": { + In: Uint(0xffffffff_ffffffff), + Expect: new(big.Int).SetBytes( + []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + ), + }, + "negint min": { + In: NegInt(1), + Expect: big.NewInt(-1), + }, + "negint max": { + In: NegInt(0), + Expect: func() *big.Int { + i := new(big.Int).SetBytes( + []byte{1, 0, 0, 0, 0, 0, 0, 0, 0}, + ) + return i.Neg(i) + }(), + }, + "tag 2": { + In: &Tag{ + ID: 2, + Value: Slice{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + Expect: new(big.Int).SetBytes( + []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + ), + }, + "tag 3": { + In: &Tag{ + ID: 3, + Value: Slice{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + Expect: func() *big.Int { + i := new(big.Int).SetBytes( + []byte{1, 0, 0, 0, 0, 0, 0, 0, 0}, + ) + return i.Neg(i) + }(), + }, + "nil": { + In: &Nil{}, + Expect: nil, + }, + } { + t.Run(name, func(t *testing.T) { + actual, err := AsBigInt(c.In) + if c.Err == "" { + if err != nil { + t.Fatalf("expect no err, got %v", err) + } + if c.Expect.Cmp(actual) != 0 { + t.Fatalf("%v != %v", c.Expect, actual) + } + } else { + if err == nil { + t.Fatalf("expect err %v", err) + } + if !strings.Contains(err.Error(), c.Err) { + t.Fatalf("'%v' does not contain '%s'", err, c.Err) + } + } + }) + } +} diff --git a/encoding/cbor/const.go b/encoding/cbor/const.go new file mode 100644 index 000000000..353471557 --- /dev/null +++ b/encoding/cbor/const.go @@ -0,0 +1,41 @@ +package cbor + +// major type in LSB position +type majorType byte + +const ( + majorTypeUint majorType = iota + majorTypeNegInt + majorTypeSlice + majorTypeString + majorTypeList + majorTypeMap + majorTypeTag + majorType7 +) + +// masks for major/minor component in encoded head +const ( + maskMajor = 0b111 << 5 + maskMinor = 0b11111 +) + +// minor value encodings to represent arg bit length (and indefinite) +const ( + minorArg1 = 24 + minorArg2 = 25 + minorArg4 = 26 + minorArg8 = 27 + minorIndefinite = 31 +) + +// minor sentinels for everything in major 7 +const ( + major7False = 20 + major7True = 21 + major7Nil = 22 + major7Undefined = 23 + major7Float16 = minorArg2 + major7Float32 = minorArg4 + major7Float64 = minorArg8 +) diff --git a/encoding/cbor/decode.go b/encoding/cbor/decode.go new file mode 100644 index 000000000..036f3ac34 --- /dev/null +++ b/encoding/cbor/decode.go @@ -0,0 +1,320 @@ +package cbor + +import ( + "encoding/binary" + "fmt" + "math" +) + +func decode(p []byte) (Value, int, error) { + if len(p) == 0 { + return nil, 0, fmt.Errorf("unexpected end of payload") + } + + switch peekMajor(p) { + case majorTypeUint: + return decodeUint(p) + case majorTypeNegInt: + return decodeNegInt(p) + case majorTypeSlice: + return decodeSlice(p, majorTypeSlice) + case majorTypeString: + s, n, err := decodeSlice(p, majorTypeString) + return String(s), n, err + case majorTypeList: + return decodeList(p) + case majorTypeMap: + return decodeMap(p) + case majorTypeTag: + return decodeTag(p) + default: // majorType7 + return decodeMajor7(p) + } +} + +func decodeUint(p []byte) (Uint, int, error) { + i, off, err := decodeArgument(p) + if err != nil { + return 0, 0, fmt.Errorf("decode argument: %w", err) + } + + return Uint(i), off, nil +} + +func decodeNegInt(p []byte) (NegInt, int, error) { + i, off, err := decodeArgument(p) + if err != nil { + return 0, 0, fmt.Errorf("decode argument: %w", err) + } + + return NegInt(i + 1), off, nil +} + +// this routine is used for both string and slice major types, the value of +// inner specifies which context we're in (needed for validating subsegments +// inside indefinite encodings) +func decodeSlice(p []byte, inner majorType) (Slice, int, error) { + minor := peekMinor(p) + if minor == minorIndefinite { + return decodeSliceIndefinite(p, inner) + } + + slen, off, err := decodeArgument(p) + if err != nil { + return nil, 0, fmt.Errorf("decode argument: %w", err) + } + + p = p[off:] + if uint64(len(p)) < slen { + return nil, 0, fmt.Errorf("slice len %d greater than remaining buf len", slen) + } + + return Slice(p[:slen]), off + int(slen), nil +} + +func decodeSliceIndefinite(p []byte, inner majorType) (Slice, int, error) { + p = p[1:] + + s := Slice{} + for off := 0; len(p) > 0; { + if p[0] == 0xff { + return s, off + 2, nil + } + + if major := peekMajor(p); major != inner { + return nil, 0, fmt.Errorf("unexpected major type %d in indefinite slice", major) + } + if peekMinor(p) == minorIndefinite { + return nil, 0, fmt.Errorf("nested indefinite slice") + } + + ss, n, err := decodeSlice(p, inner) + if err != nil { + return nil, 0, fmt.Errorf("decode subslice: %w", err) + } + p = p[n:] + + s = append(s, ss...) + off += n + } + return nil, 0, fmt.Errorf("expected break marker") +} + +func decodeList(p []byte) (List, int, error) { + minor := peekMinor(p) + if minor == minorIndefinite { + return decodeListIndefinite(p) + } + + alen, off, err := decodeArgument(p) + if err != nil { + return nil, 0, fmt.Errorf("decode argument: %w", err) + } + p = p[off:] + + l := List{} + for i := 0; i < int(alen); i++ { + item, n, err := decode(p) + if err != nil { + return nil, 0, fmt.Errorf("decode item: %w", err) + } + p = p[n:] + + l = append(l, item) + off += n + } + + return l, off, nil +} + +func decodeListIndefinite(p []byte) (List, int, error) { + p = p[1:] + + l := List{} + for off := 0; len(p) > 0; { + if p[0] == 0xff { + return l, off + 2, nil + } + + item, n, err := decode(p) + if err != nil { + return nil, 0, fmt.Errorf("decode item: %w", err) + } + p = p[n:] + + l = append(l, item) + off += n + } + return nil, 0, fmt.Errorf("expected break marker") +} + +func decodeMap(p []byte) (Map, int, error) { + minor := peekMinor(p) + if minor == minorIndefinite { + return decodeMapIndefinite(p) + } + + maplen, off, err := decodeArgument(p) + if err != nil { + return nil, 0, fmt.Errorf("decode argument: %w", err) + } + p = p[off:] + + mp := Map{} + for i := 0; i < int(maplen); i++ { + if len(p) == 0 { + return nil, 0, fmt.Errorf("unexpected end of payload") + } + + if major := peekMajor(p); major != majorTypeString { + return nil, 0, fmt.Errorf("unexpected major type %d for map key", major) + } + + key, kn, err := decodeSlice(p, majorTypeString) + if err != nil { + return nil, 0, fmt.Errorf("decode key: %w", err) + } + p = p[kn:] + + value, vn, err := decode(p) + if err != nil { + return nil, 0, fmt.Errorf("decode value: %w", err) + } + p = p[vn:] + + mp[string(key)] = value + off += kn + vn + } + + return mp, off, nil +} + +func decodeMapIndefinite(p []byte) (Map, int, error) { + p = p[1:] + + mp := Map{} + for off := 0; len(p) > 0; { + if p[0] == 0xff { + return mp, off + 2, nil + } + + if major := peekMajor(p); major != majorTypeString { + return nil, 0, fmt.Errorf("unexpected major type %d for map key", major) + } + + key, kn, err := decodeSlice(p, majorTypeString) + if err != nil { + return nil, 0, fmt.Errorf("decode key: %w", err) + } + p = p[kn:] + + value, vn, err := decode(p) + if err != nil { + return nil, 0, fmt.Errorf("decode value: %w", err) + } + p = p[vn:] + + mp[string(key)] = value + off += kn + vn + } + return nil, 0, fmt.Errorf("expected break marker") +} + +func decodeTag(p []byte) (*Tag, int, error) { + id, off, err := decodeArgument(p) + if err != nil { + return nil, 0, fmt.Errorf("decode argument: %w", err) + } + p = p[off:] + + v, n, err := decode(p) + if err != nil { + return nil, 0, fmt.Errorf("decode value: %w", err) + } + + return &Tag{ID: id, Value: v}, off + n, nil +} + +func decodeMajor7(p []byte) (Value, int, error) { + switch m := peekMinor(p); m { + case major7True, major7False: + return Bool(m == major7True), 1, nil + case major7Nil: + return &Nil{}, 1, nil + case major7Undefined: + return &Undefined{}, 1, nil + case major7Float16: + if len(p) < 3 { + return nil, 0, fmt.Errorf("incomplete float16 at end of buf") + } + b := binary.BigEndian.Uint16(p[1:]) + return Float32(math.Float32frombits(float16to32(b))), 3, nil + case major7Float32: + if len(p) < 5 { + return nil, 0, fmt.Errorf("incomplete float32 at end of buf") + } + b := binary.BigEndian.Uint32(p[1:]) + return Float32(math.Float32frombits(b)), 5, nil + case major7Float64: + if len(p) < 9 { + return nil, 0, fmt.Errorf("incomplete float64 at end of buf") + } + b := binary.BigEndian.Uint64(p[1:]) + return Float64(math.Float64frombits(b)), 9, nil + default: + return nil, 0, fmt.Errorf("unexpected minor value %d", m) + } +} + +func peekMajor(p []byte) majorType { + return majorType(p[0] & maskMajor >> 5) +} + +func peekMinor(p []byte) byte { + return p[0] & maskMinor +} + +// pulls the next argument out of the buffer +// +// expects one of the sized arguments and will error otherwise - callers that +// need to check for the indefinite flag must do so externally +func decodeArgument(p []byte) (uint64, int, error) { + minor := peekMinor(p) + if minor < minorArg1 { + return uint64(minor), 1, nil + } + + switch minor { + case minorArg1, minorArg2, minorArg4, minorArg8: + argLen := mtol(minor) + if len(p) < argLen+1 { + return 0, 0, fmt.Errorf("arg len %d greater than remaining buf len", argLen) + } + return readArgument(p[1:], argLen), argLen + 1, nil + default: + return 0, 0, fmt.Errorf("unexpected minor value %d", minor) + } +} + +// minor value to arg len in bytes, assumes minor was checked to be in [24,27] +func mtol(minor byte) int { + if minor == minorArg1 { + return 1 + } else if minor == minorArg2 { + return 2 + } else if minor == minorArg4 { + return 4 + } + return 8 +} + +func readArgument(p []byte, len int) uint64 { + if len == 1 { + return uint64(p[0]) + } else if len == 2 { + return uint64(binary.BigEndian.Uint16(p)) + } else if len == 4 { + return uint64(binary.BigEndian.Uint32(p)) + } + return uint64(binary.BigEndian.Uint64(p)) +} diff --git a/encoding/cbor/decode_test.go b/encoding/cbor/decode_test.go new file mode 100644 index 000000000..c0ac01182 --- /dev/null +++ b/encoding/cbor/decode_test.go @@ -0,0 +1,1334 @@ +package cbor + +import ( + "math" + "reflect" + "strings" + "testing" +) + +func TestDecode_InvalidArgument(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Err string + }{ + "uint/1": { + []byte{0<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "uint/2": { + []byte{0<<5 | 25, 0}, + "arg len 2 greater than remaining buf len", + }, + "uint/4": { + []byte{0<<5 | 26, 0, 0, 0}, + "arg len 4 greater than remaining buf len", + }, + "uint/8": { + []byte{0<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "arg len 8 greater than remaining buf len", + }, + "uint/?": { + []byte{0<<5 | 31}, + "unexpected minor value 31", + }, + "negint/1": { + []byte{1<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "negint/2": { + []byte{1<<5 | 25, 0}, + "arg len 2 greater than remaining buf len", + }, + "negint/4": { + []byte{1<<5 | 26, 0, 0, 0}, + "arg len 4 greater than remaining buf len", + }, + "negint/8": { + []byte{1<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "arg len 8 greater than remaining buf len", + }, + "negint/?": { + []byte{1<<5 | 31}, + "unexpected minor value 31", + }, + "slice/1": { + []byte{2<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "slice/2": { + []byte{2<<5 | 25, 0}, + "arg len 2 greater than remaining buf len", + }, + "slice/4": { + []byte{2<<5 | 26, 0, 0, 0}, + "arg len 4 greater than remaining buf len", + }, + "slice/8": { + []byte{2<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "arg len 8 greater than remaining buf len", + }, + "string/1": { + []byte{3<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "string/2": { + []byte{3<<5 | 25, 0}, + "arg len 2 greater than remaining buf len", + }, + "string/4": { + []byte{3<<5 | 26, 0, 0, 0}, + "arg len 4 greater than remaining buf len", + }, + "string/8": { + []byte{3<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "arg len 8 greater than remaining buf len", + }, + "list/1": { + []byte{4<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "list/2": { + []byte{4<<5 | 25, 0}, + "arg len 2 greater than remaining buf len", + }, + "list/4": { + []byte{4<<5 | 26, 0, 0, 0}, + "arg len 4 greater than remaining buf len", + }, + "list/8": { + []byte{4<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "arg len 8 greater than remaining buf len", + }, + "map/1": { + []byte{5<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "map/2": { + []byte{5<<5 | 25, 0}, + "arg len 2 greater than remaining buf len", + }, + "map/4": { + []byte{5<<5 | 26, 0, 0, 0}, + "arg len 4 greater than remaining buf len", + }, + "map/8": { + []byte{5<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "arg len 8 greater than remaining buf len", + }, + "tag/1": { + []byte{6<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "tag/2": { + []byte{6<<5 | 25, 0}, + "arg len 2 greater than remaining buf len", + }, + "tag/4": { + []byte{6<<5 | 26, 0, 0, 0}, + "arg len 4 greater than remaining buf len", + }, + "tag/8": { + []byte{6<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "arg len 8 greater than remaining buf len", + }, + "tag/?": { + []byte{6<<5 | 31}, + "unexpected minor value 31", + }, + "major7/float16": { + []byte{7<<5 | 25, 0}, + "incomplete float16 at end of buf", + }, + "major7/float32": { + []byte{7<<5 | 26, 0, 0, 0}, + "incomplete float32 at end of buf", + }, + "major7/float64": { + []byte{7<<5 | 27, 0, 0, 0, 0, 0, 0, 0}, + "incomplete float64 at end of buf", + }, + "major7/?": { + []byte{7<<5 | 31}, + "unexpected minor value 31", + }, + } { + t.Run(name, func(t *testing.T) { + _, _, err := decode(c.In) + if err == nil { + t.Errorf("expect err %s", c.Err) + } + if aerr := err.Error(); !strings.Contains(aerr, c.Err) { + t.Errorf("expect err %s, got %s", c.Err, aerr) + } + }) + } +} + +func TestDecode_InvalidSlice(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Err string + }{ + "slice/1, not enough bytes": { + []byte{2<<5 | 24, 1}, + "slice len 1 greater than remaining buf len", + }, + "slice/?, no break": { + []byte{2<<5 | 31}, + "expected break marker", + }, + "slice/?, invalid nested major": { + []byte{2<<5 | 31, 3<<5 | 0}, + "unexpected major type 3 in indefinite slice", + }, + "slice/?, nested indefinite": { + []byte{2<<5 | 31, 2<<5 | 31}, + "nested indefinite slice", + }, + "slice/?, invalid nested definite": { + []byte{2<<5 | 31, 2<<5 | 24, 1}, + "decode subslice: slice len 1 greater than remaining buf len", + }, + "string/1, not enough bytes": { + []byte{3<<5 | 24, 1}, + "slice len 1 greater than remaining buf len", + }, + "string/?, no break": { + []byte{3<<5 | 31}, + "expected break marker", + }, + "string/?, invalid nested major": { + []byte{3<<5 | 31, 2<<5 | 0}, + "unexpected major type 2 in indefinite slice", + }, + "string/?, nested indefinite": { + []byte{3<<5 | 31, 3<<5 | 31}, + "nested indefinite slice", + }, + "string/?, invalid nested definite": { + []byte{3<<5 | 31, 3<<5 | 24, 1}, + "decode subslice: slice len 1 greater than remaining buf len", + }, + } { + t.Run(name, func(t *testing.T) { + _, _, err := decode(c.In) + if err == nil { + t.Errorf("expect err %s", c.Err) + } + if aerr := err.Error(); !strings.Contains(aerr, c.Err) { + t.Errorf("expect err %s, got %s", c.Err, aerr) + } + }) + } +} + +func TestDecode_InvalidList(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Err string + }{ + "[] / eof after head": { + []byte{4<<5 | 1}, + "unexpected end of payload", + }, + "[] / invalid item": { + []byte{4<<5 | 1, 0<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "[_ ] / no break": { + []byte{4<<5 | 31}, + "expected break marker", + }, + "[_ ] / invalid item": { + []byte{4<<5 | 31, 0<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + } { + t.Run(name, func(t *testing.T) { + _, _, err := decode(c.In) + if err == nil { + t.Errorf("expect err %s", c.Err) + } + if aerr := err.Error(); !strings.Contains(aerr, c.Err) { + t.Errorf("expect err %s, got %s", c.Err, aerr) + } + }) + } +} + +func TestDecode_InvalidMap(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Err string + }{ + "{} / eof after head": { + []byte{5<<5 | 1}, + "unexpected end of payload", + }, + "{} / non-string key": { + []byte{5<<5 | 1, 0}, + "unexpected major type 0 for map key", + }, + "{} / invalid key": { + []byte{5<<5 | 1, 3<<5 | 24, 1}, + "slice len 1 greater than remaining buf len", + }, + "{} / invalid value": { + []byte{5<<5 | 1, 3<<5 | 3, 0x66, 0x6f, 0x6f, 0<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "{_ } / no break": { + []byte{5<<5 | 31}, + "expected break marker", + }, + "{_ } / non-string key": { + []byte{5<<5 | 31, 0}, + "unexpected major type 0 for map key", + }, + "{_ } / invalid key": { + []byte{5<<5 | 31, 3<<5 | 24, 1}, + "slice len 1 greater than remaining buf len", + }, + "{_ } / invalid value": { + []byte{5<<5 | 31, 3<<5 | 3, 0x66, 0x6f, 0x6f, 0<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + } { + t.Run(name, func(t *testing.T) { + _, _, err := decode(c.In) + if err == nil { + t.Errorf("expect err %s", c.Err) + } + if aerr := err.Error(); !strings.Contains(aerr, c.Err) { + t.Errorf("expect err %s, got %s", c.Err, aerr) + } + }) + } +} + +func TestDecode_InvalidTag(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Err string + }{ + "invalid value": { + []byte{6<<5 | 1, 0<<5 | 24}, + "arg len 1 greater than remaining buf len", + }, + "eof": { + []byte{6<<5 | 1}, + "unexpected end of payload", + }, + } { + t.Run(name, func(t *testing.T) { + _, _, err := decode(c.In) + if err == nil { + t.Errorf("expect err %s", c.Err) + } + if aerr := err.Error(); !strings.Contains(aerr, c.Err) { + t.Errorf("expect err %s, got %s", c.Err, aerr) + } + }) + } +} + +func TestDecode_Atomic(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "uint/0/min": { + []byte{0<<5 | 0}, + Uint(0), + }, + "uint/0/max": { + []byte{0<<5 | 23}, + Uint(23), + }, + "uint/1/min": { + []byte{0<<5 | 24, 0}, + Uint(0), + }, + "uint/1/max": { + []byte{0<<5 | 24, 0xff}, + Uint(0xff), + }, + "uint/2/min": { + []byte{0<<5 | 25, 0, 0}, + Uint(0), + }, + "uint/2/max": { + []byte{0<<5 | 25, 0xff, 0xff}, + Uint(0xffff), + }, + "uint/4/min": { + []byte{0<<5 | 26, 0, 0, 0, 0}, + Uint(0), + }, + "uint/4/max": { + []byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}, + Uint(0xffffffff), + }, + "uint/8/min": { + []byte{0<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}, + Uint(0), + }, + "uint/8/max": { + []byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + Uint(0xffffffff_ffffffff), + }, + "negint/0/min": { + []byte{1<<5 | 0}, + NegInt(1), + }, + "negint/0/max": { + []byte{1<<5 | 23}, + NegInt(24), + }, + "negint/1/min": { + []byte{1<<5 | 24, 0}, + NegInt(1), + }, + "negint/1/max": { + []byte{1<<5 | 24, 0xff}, + NegInt(0x100), + }, + "negint/2/min": { + []byte{1<<5 | 25, 0, 0}, + NegInt(1), + }, + "negint/2/max": { + []byte{1<<5 | 25, 0xff, 0xff}, + NegInt(0x10000), + }, + "negint/4/min": { + []byte{1<<5 | 26, 0, 0, 0, 0}, + NegInt(1), + }, + "negint/4/max": { + []byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}, + NegInt(0x100000000), + }, + "negint/8/min": { + []byte{1<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}, + NegInt(1), + }, + "negint/8/max": { + []byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + NegInt(0xffffffff_ffffffff), + }, + "true": { + []byte{7<<5 | major7True}, + Bool(true), + }, + "false": { + []byte{7<<5 | major7False}, + Bool(false), + }, + "null": { + []byte{7<<5 | major7Nil}, + &Nil{}, + }, + "undefined": { + []byte{7<<5 | major7Undefined}, + &Undefined{}, + }, + "float16/+Inf": { + []byte{7<<5 | major7Float16, 0x7c, 0}, + Float32(math.Float32frombits(0x7f800000)), + }, + "float16/-Inf": { + []byte{7<<5 | major7Float16, 0xfc, 0}, + Float32(math.Float32frombits(0xff800000)), + }, + "float16/NaN/MSB": { + []byte{7<<5 | major7Float16, 0x7e, 0}, + Float32(math.Float32frombits(0x7fc00000)), + }, + "float16/NaN/LSB": { + []byte{7<<5 | major7Float16, 0x7c, 1}, + Float32(math.Float32frombits(0x7f802000)), + }, + "float32": { + []byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}, + Float32(math.Float32frombits(0x7f800000)), + }, + "float64": { + []byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}, + Float64(math.Float64frombits(0x7ff00000_00000000)), + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer") + } + assertValue(t, c.Expect, actual) + }) + } +} + +func TestDecode_DefiniteSlice(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "len = 0": { + []byte{2<<5 | 0}, + Slice{}, + }, + "len > 0": { + []byte{2<<5 | 3, 0x66, 0x6f, 0x6f}, + Slice{0x66, 0x6f, 0x6f}, + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer") + } + assertValue(t, c.Expect, actual) + }) + } +} + +func TestDecode_IndefiniteSlice(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "len = 0": { + []byte{2<<5 | 31, 0xff}, + Slice{}, + }, + "len = 0, explicit": { + []byte{2<<5 | 31, 2<<5 | 0, 0xff}, + Slice{}, + }, + "len = 0, len > 0": { + []byte{ + 2<<5 | 31, + 2<<5 | 0, + 2<<5 | 3, 0x66, 0x6f, 0x6f, + 0xff, + }, + Slice{0x66, 0x6f, 0x6f}, + }, + "len > 0, len = 0": { + []byte{ + 2<<5 | 31, + 2<<5 | 3, 0x66, 0x6f, 0x6f, + 2<<5 | 0, + 0xff, + }, + Slice{0x66, 0x6f, 0x6f}, + }, + "len > 0, len > 0": { + []byte{ + 2<<5 | 31, + 2<<5 | 3, 0x66, 0x6f, 0x6f, + 2<<5 | 3, 0x66, 0x6f, 0x6f, + 0xff, + }, + Slice{0x66, 0x6f, 0x6f, 0x66, 0x6f, 0x6f}, + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer") + } + assertValue(t, c.Expect, actual) + }) + } +} + +func TestDecode_DefiniteString(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "len = 0": { + []byte{3<<5 | 0}, + String(""), + }, + "len > 0": { + []byte{3<<5 | 3, 0x66, 0x6f, 0x6f}, + String("foo"), + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer") + } + assertValue(t, c.Expect, actual) + }) + } +} + +func TestDecode_IndefiniteString(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "len = 0": { + []byte{3<<5 | 31, 0xff}, + String(""), + }, + "len = 0, explicit": { + []byte{3<<5 | 31, 3<<5 | 0, 0xff}, + String(""), + }, + "len = 0, len > 0": { + []byte{ + 3<<5 | 31, + 3<<5 | 0, + 3<<5 | 3, 0x66, 0x6f, 0x6f, + 0xff, + }, + String("foo"), + }, + "len > 0, len = 0": { + []byte{ + 3<<5 | 31, + 3<<5 | 3, 0x66, 0x6f, 0x6f, + 3<<5 | 0, + 0xff, + }, + String("foo"), + }, + "len > 0, len > 0": { + []byte{ + 3<<5 | 31, + 3<<5 | 3, 0x66, 0x6f, 0x6f, + 3<<5 | 3, 0x66, 0x6f, 0x6f, + 0xff, + }, + String("foofoo"), + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer") + } + assertValue(t, c.Expect, actual) + }) + } +} + +func TestDecode_List(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "[uint/0/min]": { + In: withDefiniteList([]byte{0<<5 | 0}), + Expect: List{Uint(0)}, + }, + "[uint/0/max]": { + In: withDefiniteList([]byte{0<<5 | 23}), + Expect: List{Uint(23)}, + }, + "[uint/1/min]": { + In: withDefiniteList([]byte{0<<5 | 24, 0}), + Expect: List{Uint(0)}, + }, + "[uint/1/max]": { + In: withDefiniteList([]byte{0<<5 | 24, 0xff}), + Expect: List{Uint(0xff)}, + }, + "[uint/2/min]": { + In: withDefiniteList([]byte{0<<5 | 25, 0, 0}), + Expect: List{Uint(0)}, + }, + "[uint/2/max]": { + In: withDefiniteList([]byte{0<<5 | 25, 0xff, 0xff}), + Expect: List{Uint(0xffff)}, + }, + "[uint/4/min]": { + In: withDefiniteList([]byte{0<<5 | 26, 0, 0, 0, 0}), + Expect: List{Uint(0)}, + }, + "[uint/4/max]": { + In: withDefiniteList([]byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: List{Uint(0xffffffff)}, + }, + "[uint/8/min]": { + In: withDefiniteList([]byte{0<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: List{Uint(0)}, + }, + "[uint/8/max]": { + In: withDefiniteList([]byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}), + Expect: List{Uint(0xffffffff_ffffffff)}, + }, + "[negint/0/min]": { + In: withDefiniteList([]byte{1<<5 | 0}), + Expect: List{NegInt(1)}, + }, + "[negint/0/max]": { + In: withDefiniteList([]byte{1<<5 | 23}), + Expect: List{NegInt(24)}, + }, + "[negint/1/min]": { + In: withDefiniteList([]byte{1<<5 | 24, 0}), + Expect: List{NegInt(1)}, + }, + "[negint/1/max]": { + In: withDefiniteList([]byte{1<<5 | 24, 0xff}), + Expect: List{NegInt(0x100)}, + }, + "[negint/2/min]": { + In: withDefiniteList([]byte{1<<5 | 25, 0, 0}), + Expect: List{NegInt(1)}, + }, + "[negint/2/max]": { + In: withDefiniteList([]byte{1<<5 | 25, 0xff, 0xff}), + Expect: List{NegInt(0x10000)}, + }, + "[negint/4/min]": { + In: withDefiniteList([]byte{1<<5 | 26, 0, 0, 0, 0}), + Expect: List{NegInt(1)}, + }, + "[negint/4/max]": { + In: withDefiniteList([]byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: List{NegInt(0x100000000)}, + }, + "[negint/8/min]": { + In: withDefiniteList([]byte{1<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: List{NegInt(1)}, + }, + "[negint/8/max]": { + In: withDefiniteList([]byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}), + Expect: List{NegInt(0xffffffff_ffffffff)}, + }, + "[true]": { + In: withDefiniteList([]byte{7<<5 | major7True}), + Expect: List{Bool(true)}, + }, + "[false]": { + In: withDefiniteList([]byte{7<<5 | major7False}), + Expect: List{Bool(false)}, + }, + "[null]": { + In: withDefiniteList([]byte{7<<5 | major7Nil}), + Expect: List{&Nil{}}, + }, + "[undefined]": { + In: withDefiniteList([]byte{7<<5 | major7Undefined}), + Expect: List{&Undefined{}}, + }, + "[float16/+Inf]": { + In: withDefiniteList([]byte{7<<5 | major7Float16, 0x7c, 0}), + Expect: List{Float32(math.Float32frombits(0x7f800000))}, + }, + "[float16/-Inf]": { + In: withDefiniteList([]byte{7<<5 | major7Float16, 0xfc, 0}), + Expect: List{Float32(math.Float32frombits(0xff800000))}, + }, + "[float16/NaN/MSB]": { + In: withDefiniteList([]byte{7<<5 | major7Float16, 0x7e, 0}), + Expect: List{Float32(math.Float32frombits(0x7fc00000))}, + }, + "[float16/NaN/LSB]": { + In: withDefiniteList([]byte{7<<5 | major7Float16, 0x7c, 1}), + Expect: List{Float32(math.Float32frombits(0x7f802000))}, + }, + "[float32]": { + In: withDefiniteList([]byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}), + Expect: List{Float32(math.Float32frombits(0x7f800000))}, + }, + "[float64]": { + In: withDefiniteList([]byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}), + Expect: List{Float64(math.Float64frombits(0x7ff00000_00000000))}, + }, + "[_ uint/0/min]": { + In: withIndefiniteList([]byte{0<<5 | 0}), + Expect: List{Uint(0)}, + }, + "[_ uint/0/max]": { + In: withIndefiniteList([]byte{0<<5 | 23}), + Expect: List{Uint(23)}, + }, + "[_ uint/1/min]": { + In: withIndefiniteList([]byte{0<<5 | 24, 0}), + Expect: List{Uint(0)}, + }, + "[_ uint/1/max]": { + In: withIndefiniteList([]byte{0<<5 | 24, 0xff}), + Expect: List{Uint(0xff)}, + }, + "[_ uint/2/min]": { + In: withIndefiniteList([]byte{0<<5 | 25, 0, 0}), + Expect: List{Uint(0)}, + }, + "[_ uint/2/max]": { + In: withIndefiniteList([]byte{0<<5 | 25, 0xff, 0xff}), + Expect: List{Uint(0xffff)}, + }, + "[_ uint/4/min]": { + In: withIndefiniteList([]byte{0<<5 | 26, 0, 0, 0, 0}), + Expect: List{Uint(0)}, + }, + "[_ uint/4/max]": { + In: withIndefiniteList([]byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: List{Uint(0xffffffff)}, + }, + "[_ uint/8/min]": { + In: withIndefiniteList([]byte{0<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: List{Uint(0)}, + }, + "[_ uint/8/max]": { + In: withIndefiniteList([]byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}), + Expect: List{Uint(0xffffffff_ffffffff)}, + }, + "[_ negint/0/min]": { + In: withIndefiniteList([]byte{1<<5 | 0}), + Expect: List{NegInt(1)}, + }, + "[_ negint/0/max]": { + In: withIndefiniteList([]byte{1<<5 | 23}), + Expect: List{NegInt(24)}, + }, + "[_ negint/1/min]": { + In: withIndefiniteList([]byte{1<<5 | 24, 0}), + Expect: List{NegInt(1)}, + }, + "[_ negint/1/max]": { + In: withIndefiniteList([]byte{1<<5 | 24, 0xff}), + Expect: List{NegInt(0x100)}, + }, + "[_ negint/2/min]": { + In: withIndefiniteList([]byte{1<<5 | 25, 0, 0}), + Expect: List{NegInt(1)}, + }, + "[_ negint/2/max]": { + In: withIndefiniteList([]byte{1<<5 | 25, 0xff, 0xff}), + Expect: List{NegInt(0x10000)}, + }, + "[_ negint/4/min]": { + In: withIndefiniteList([]byte{1<<5 | 26, 0, 0, 0, 0}), + Expect: List{NegInt(1)}, + }, + "[_ negint/4/max]": { + In: withIndefiniteList([]byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: List{NegInt(0x100000000)}, + }, + "[_ negint/8/min]": { + In: withIndefiniteList([]byte{1<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: List{NegInt(1)}, + }, + "[_ negint/8/max]": { + In: withIndefiniteList([]byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}), + Expect: List{NegInt(0xffffffff_ffffffff)}, + }, + "[_ true]": { + In: withIndefiniteList([]byte{7<<5 | major7True}), + Expect: List{Bool(true)}, + }, + "[_ false]": { + In: withIndefiniteList([]byte{7<<5 | major7False}), + Expect: List{Bool(false)}, + }, + "[_ null]": { + In: withIndefiniteList([]byte{7<<5 | major7Nil}), + Expect: List{&Nil{}}, + }, + "[_ undefined]": { + In: withIndefiniteList([]byte{7<<5 | major7Undefined}), + Expect: List{&Undefined{}}, + }, + "[_ float16/+Inf]": { + In: withIndefiniteList([]byte{7<<5 | major7Float16, 0x7c, 0}), + Expect: List{Float32(math.Float32frombits(0x7f800000))}, + }, + "[_ float16/-Inf]": { + In: withIndefiniteList([]byte{7<<5 | major7Float16, 0xfc, 0}), + Expect: List{Float32(math.Float32frombits(0xff800000))}, + }, + "[_ float16/NaN/MSB]": { + In: withIndefiniteList([]byte{7<<5 | major7Float16, 0x7e, 0}), + Expect: List{Float32(math.Float32frombits(0x7fc00000))}, + }, + "[_ float16/NaN/LSB]": { + In: withIndefiniteList([]byte{7<<5 | major7Float16, 0x7c, 1}), + Expect: List{Float32(math.Float32frombits(0x7f802000))}, + }, + "[_ float32]": { + In: withIndefiniteList([]byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}), + Expect: List{Float32(math.Float32frombits(0x7f800000))}, + }, + "[_ float64]": { + In: withIndefiniteList([]byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}), + Expect: List{Float64(math.Float64frombits(0x7ff00000_00000000))}, + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer (decoded %d of %d)", n, len(c.In)) + } + assertValue(t, c.Expect, actual) + }) + } +} + +func TestDecode_Map(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "{uint/0/min}": { + In: withDefiniteMap([]byte{0<<5 | 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{uint/0/max}": { + In: withDefiniteMap([]byte{0<<5 | 23}), + Expect: Map{"foo": Uint(23)}, + }, + "{uint/1/min}": { + In: withDefiniteMap([]byte{0<<5 | 24, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{uint/1/max}": { + In: withDefiniteMap([]byte{0<<5 | 24, 0xff}), + Expect: Map{"foo": Uint(0xff)}, + }, + "{uint/2/min}": { + In: withDefiniteMap([]byte{0<<5 | 25, 0, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{uint/2/max}": { + In: withDefiniteMap([]byte{0<<5 | 25, 0xff, 0xff}), + Expect: Map{"foo": Uint(0xffff)}, + }, + "{uint/4/min}": { + In: withDefiniteMap([]byte{0<<5 | 26, 0, 0, 0, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{uint/4/max}": { + In: withDefiniteMap([]byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: Map{"foo": Uint(0xffffffff)}, + }, + "{uint/8/min}": { + In: withDefiniteMap([]byte{0<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{uint/8/max}": { + In: withDefiniteMap([]byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}), + Expect: Map{"foo": Uint(0xffffffff_ffffffff)}, + }, + "{negint/0/min}": { + In: withDefiniteMap([]byte{1<<5 | 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{negint/0/max}": { + In: withDefiniteMap([]byte{1<<5 | 23}), + Expect: Map{"foo": NegInt(24)}, + }, + "{negint/1/min}": { + In: withDefiniteMap([]byte{1<<5 | 24, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{negint/1/max}": { + In: withDefiniteMap([]byte{1<<5 | 24, 0xff}), + Expect: Map{"foo": NegInt(0x100)}, + }, + "{negint/2/min}": { + In: withDefiniteMap([]byte{1<<5 | 25, 0, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{negint/2/max}": { + In: withDefiniteMap([]byte{1<<5 | 25, 0xff, 0xff}), + Expect: Map{"foo": NegInt(0x10000)}, + }, + "{negint/4/min}": { + In: withDefiniteMap([]byte{1<<5 | 26, 0, 0, 0, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{negint/4/max}": { + In: withDefiniteMap([]byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: Map{"foo": NegInt(0x100000000)}, + }, + "{negint/8/min}": { + In: withDefiniteMap([]byte{1<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{negint/8/max}": { + In: withDefiniteMap([]byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}), + Expect: Map{"foo": NegInt(0xffffffff_ffffffff)}, + }, + "{true}": { + In: withDefiniteMap([]byte{7<<5 | major7True}), + Expect: Map{"foo": Bool(true)}, + }, + "{false}": { + In: withDefiniteMap([]byte{7<<5 | major7False}), + Expect: Map{"foo": Bool(false)}, + }, + "{null}": { + In: withDefiniteMap([]byte{7<<5 | major7Nil}), + Expect: Map{"foo": &Nil{}}, + }, + "{undefined}": { + In: withDefiniteMap([]byte{7<<5 | major7Undefined}), + Expect: Map{"foo": &Undefined{}}, + }, + "{float16/+Inf}": { + In: withDefiniteMap([]byte{7<<5 | major7Float16, 0x7c, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7f800000))}, + }, + "{float16/-Inf}": { + In: withDefiniteMap([]byte{7<<5 | major7Float16, 0xfc, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0xff800000))}, + }, + "{float16/NaN/MSB}": { + In: withDefiniteMap([]byte{7<<5 | major7Float16, 0x7e, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7fc00000))}, + }, + "{float16/NaN/LSB}": { + In: withDefiniteMap([]byte{7<<5 | major7Float16, 0x7c, 1}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7f802000))}, + }, + "{float32}": { + In: withDefiniteMap([]byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7f800000))}, + }, + "{float64}": { + In: withDefiniteMap([]byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}), + Expect: Map{"foo": Float64(math.Float64frombits(0x7ff00000_00000000))}, + }, + "{_ uint/0/min}": { + In: withIndefiniteMap([]byte{0<<5 | 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{_ uint/0/max}": { + In: withIndefiniteMap([]byte{0<<5 | 23}), + Expect: Map{"foo": Uint(23)}, + }, + "{_ uint/1/min}": { + In: withIndefiniteMap([]byte{0<<5 | 24, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{_ uint/1/max}": { + In: withIndefiniteMap([]byte{0<<5 | 24, 0xff}), + Expect: Map{"foo": Uint(0xff)}, + }, + "{_ uint/2/min}": { + In: withIndefiniteMap([]byte{0<<5 | 25, 0, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{_ uint/2/max}": { + In: withIndefiniteMap([]byte{0<<5 | 25, 0xff, 0xff}), + Expect: Map{"foo": Uint(0xffff)}, + }, + "{_ uint/4/min}": { + In: withIndefiniteMap([]byte{0<<5 | 26, 0, 0, 0, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{_ uint/4/max}": { + In: withIndefiniteMap([]byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: Map{"foo": Uint(0xffffffff)}, + }, + "{_ uint/8/min}": { + In: withIndefiniteMap([]byte{0<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: Map{"foo": Uint(0)}, + }, + "{_ uint/8/max}": { + In: withIndefiniteMap([]byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}), + Expect: Map{"foo": Uint(0xffffffff_ffffffff)}, + }, + "{_ negint/0/min}": { + In: withIndefiniteMap([]byte{1<<5 | 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{_ negint/0/max}": { + In: withIndefiniteMap([]byte{1<<5 | 23}), + Expect: Map{"foo": NegInt(24)}, + }, + "{_ negint/1/min}": { + In: withIndefiniteMap([]byte{1<<5 | 24, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{_ negint/1/max}": { + In: withIndefiniteMap([]byte{1<<5 | 24, 0xff}), + Expect: Map{"foo": NegInt(0x100)}, + }, + "{_ negint/2/min}": { + In: withIndefiniteMap([]byte{1<<5 | 25, 0, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{_ negint/2/max}": { + In: withIndefiniteMap([]byte{1<<5 | 25, 0xff, 0xff}), + Expect: Map{"foo": NegInt(0x10000)}, + }, + "{_ negint/4/min}": { + In: withIndefiniteMap([]byte{1<<5 | 26, 0, 0, 0, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{_ negint/4/max}": { + In: withIndefiniteMap([]byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Expect: Map{"foo": NegInt(0x100000000)}, + }, + "{_ negint/8/min}": { + In: withIndefiniteMap([]byte{1<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0}), + Expect: Map{"foo": NegInt(1)}, + }, + "{_ negint/8/max}": { + In: withIndefiniteMap([]byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}), + Expect: Map{"foo": NegInt(0xffffffff_ffffffff)}, + }, + "{_ true}": { + In: withIndefiniteMap([]byte{7<<5 | major7True}), + Expect: Map{"foo": Bool(true)}, + }, + "{_ false}": { + In: withIndefiniteMap([]byte{7<<5 | major7False}), + Expect: Map{"foo": Bool(false)}, + }, + "{_ null}": { + In: withIndefiniteMap([]byte{7<<5 | major7Nil}), + Expect: Map{"foo": &Nil{}}, + }, + "{_ undefined}": { + In: withIndefiniteMap([]byte{7<<5 | major7Undefined}), + Expect: Map{"foo": &Undefined{}}, + }, + "{_ float16/+Inf}": { + In: withIndefiniteMap([]byte{7<<5 | major7Float16, 0x7c, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7f800000))}, + }, + "{_ float16/-Inf}": { + In: withIndefiniteMap([]byte{7<<5 | major7Float16, 0xfc, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0xff800000))}, + }, + "{_ float16/NaN/MSB}": { + In: withIndefiniteMap([]byte{7<<5 | major7Float16, 0x7e, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7fc00000))}, + }, + "{_ float16/NaN/LSB}": { + In: withIndefiniteMap([]byte{7<<5 | major7Float16, 0x7c, 1}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7f802000))}, + }, + "{_ float32}": { + In: withIndefiniteMap([]byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}), + Expect: Map{"foo": Float32(math.Float32frombits(0x7f800000))}, + }, + "{_ float64}": { + In: withIndefiniteMap([]byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}), + Expect: Map{"foo": Float64(math.Float64frombits(0x7ff00000_00000000))}, + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer (decoded %d of %d)", n, len(c.In)) + } + assertValue(t, c.Expect, actual) + }) + } +} + +func TestDecode_Tag(t *testing.T) { + for name, c := range map[string]struct { + In []byte + Expect Value + }{ + "0/min": { + In: []byte{6<<5 | 0, 1}, + Expect: &Tag{0, Uint(1)}, + }, + "0/max": { + In: []byte{6<<5 | 23, 1}, + Expect: &Tag{23, Uint(1)}, + }, + "1/min": { + In: []byte{6<<5 | 24, 0, 1}, + Expect: &Tag{0, Uint(1)}, + }, + "1/max": { + In: []byte{6<<5 | 24, 0xff, 1}, + Expect: &Tag{0xff, Uint(1)}, + }, + "2/min": { + In: []byte{6<<5 | 25, 0, 0, 1}, + Expect: &Tag{0, Uint(1)}, + }, + "2/max": { + In: []byte{6<<5 | 25, 0xff, 0xff, 1}, + Expect: &Tag{0xffff, Uint(1)}, + }, + "4/min": { + In: []byte{6<<5 | 26, 0, 0, 0, 0, 1}, + Expect: &Tag{0, Uint(1)}, + }, + "4/max": { + In: []byte{6<<5 | 26, 0xff, 0xff, 0xff, 0xff, 1}, + Expect: &Tag{0xffffffff, Uint(1)}, + }, + "8/min": { + In: []byte{6<<5 | 27, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + Expect: &Tag{0, Uint(1)}, + }, + "8/max": { + In: []byte{6<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 1}, + Expect: &Tag{0xffffffff_ffffffff, Uint(1)}, + }, + } { + t.Run(name, func(t *testing.T) { + actual, n, err := decode(c.In) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + if n != len(c.In) { + t.Errorf("didn't decode whole buffer (decoded %d of %d)", n, len(c.In)) + } + assertValue(t, c.Expect, actual) + }) + } +} + +func assertValue(t *testing.T, e, a Value) { + switch v := e.(type) { + case Uint, NegInt, Slice, String, Bool, *Nil, *Undefined: + if !reflect.DeepEqual(e, a) { + t.Errorf("%v != %v", e, a) + } + case List: + assertList(t, v, a) + case Map: + assertMap(t, v, a) + case *Tag: + assertTag(t, v, a) + case Float32: + assertMajor7Float32(t, v, a) + case Float64: + assertMajor7Float64(t, v, a) + default: + t.Errorf("unrecognized variant %T", e) + } +} + +func assertList(t *testing.T, e List, a Value) { + av, ok := a.(List) + if !ok { + t.Errorf("%T != %T", e, a) + return + } + + if len(e) != len(av) { + t.Errorf("length %d != %d", len(e), len(av)) + return + } + + for i := 0; i < len(e); i++ { + assertValue(t, e[i], av[i]) + } +} + +func assertMap(t *testing.T, e Map, a Value) { + av, ok := a.(Map) + if !ok { + t.Errorf("%T != %T", e, a) + return + } + + if len(e) != len(av) { + t.Errorf("length %d != %d", len(e), len(av)) + return + } + + for k, ev := range e { + avv, ok := av[k] + if !ok { + t.Errorf("missing key %s", k) + return + } + + assertValue(t, ev, avv) + } +} + +func assertTag(t *testing.T, e *Tag, a Value) { + av, ok := a.(*Tag) + if !ok { + t.Errorf("%T != %T", e, a) + return + } + + if e.ID != av.ID { + t.Errorf("tag ID %d != %d", e.ID, av.ID) + return + } + + assertValue(t, e.Value, av.Value) +} + +func assertMajor7Float32(t *testing.T, e Float32, a Value) { + av, ok := a.(Float32) + if !ok { + t.Errorf("%T != %T", e, a) + return + } + + if math.Float32bits(float32(e)) != math.Float32bits(float32(av)) { + t.Errorf("float32(%x) != float32(%x)", e, av) + } +} + +func assertMajor7Float64(t *testing.T, e Float64, a Value) { + av, ok := a.(Float64) + if !ok { + t.Errorf("%T != %T", e, a) + return + } + + if math.Float64bits(float64(e)) != math.Float64bits(float64(av)) { + t.Errorf("float64(%x) != float64(%x)", e, av) + } +} + +var mapKeyFoo = []byte{0x63, 0x66, 0x6f, 0x6f} + +func withDefiniteList(p []byte) []byte { + return append([]byte{4<<5 | 1}, p...) +} + +func withIndefiniteList(p []byte) []byte { + p = append([]byte{4<<5 | 31}, p...) + return append(p, 0xff) +} + +func withDefiniteMap(p []byte) []byte { + head := append([]byte{5<<5 | 1}, mapKeyFoo...) + return append(head, p...) +} + +func withIndefiniteMap(p []byte) []byte { + head := append([]byte{5<<5 | 31}, mapKeyFoo...) + p = append(head, p...) + return append(p, 0xff) +} diff --git a/encoding/cbor/encode.go b/encoding/cbor/encode.go new file mode 100644 index 000000000..646f5b68c --- /dev/null +++ b/encoding/cbor/encode.go @@ -0,0 +1,218 @@ +package cbor + +import ( + "encoding/binary" + "math" +) + +func (i Uint) len() int { + return itoarglen(uint64(i)) +} + +func (i Uint) encode(p []byte) int { + return encodeArg(majorTypeUint, uint64(i), p) +} + +func (i NegInt) len() int { + return itoarglen(uint64(i) - 1) +} + +func (i NegInt) encode(p []byte) int { + return encodeArg(majorTypeNegInt, uint64(i-1), p) +} + +func (s Slice) len() int { + return itoarglen(len(s)) + len(s) +} + +func (s Slice) encode(p []byte) int { + off := encodeArg(majorTypeSlice, len(s), p) + copy(p[off:], []byte(s)) + return off + len(s) +} + +func (s String) len() int { + return itoarglen(len(s)) + len(s) +} + +func (s String) encode(p []byte) int { + off := encodeArg(majorTypeString, len(s), p) + copy(p[off:], []byte(s)) + return off + len(s) +} + +func (l List) len() int { + total := itoarglen(len(l)) + for _, v := range l { + total += v.len() + } + return total +} + +func (l List) encode(p []byte) int { + off := encodeArg(majorTypeList, len(l), p) + for _, v := range l { + off += v.encode(p[off:]) + } + return off +} + +func (m Map) len() int { + total := itoarglen(len(m)) + for k, v := range m { + total += String(k).len() + v.len() + } + return total +} + +func (m Map) encode(p []byte) int { + off := encodeArg(majorTypeMap, len(m), p) + for k, v := range m { + off += String(k).encode(p[off:]) + off += v.encode(p[off:]) + } + return off +} + +func (t Tag) len() int { + return itoarglen(t.ID) + t.Value.len() +} + +func (t Tag) encode(p []byte) int { + off := encodeArg(majorTypeTag, t.ID, p) + return off + t.Value.encode(p[off:]) +} + +func (b Bool) len() int { + return 1 +} + +func (b Bool) encode(p []byte) int { + if b { + p[0] = compose(majorType7, major7True) + } else { + p[0] = compose(majorType7, major7False) + } + return 1 +} + +func (*Nil) len() int { + return 1 +} + +func (*Nil) encode(p []byte) int { + p[0] = compose(majorType7, major7Nil) + return 1 +} + +func (*Undefined) len() int { + return 1 +} + +func (*Undefined) encode(p []byte) int { + p[0] = compose(majorType7, major7Undefined) + return 1 +} + +func (f Float32) len() int { + return 5 +} + +func (f Float32) encode(p []byte) int { + p[0] = compose(majorType7, major7Float32) + binary.BigEndian.PutUint32(p[1:], math.Float32bits(float32(f))) + return 5 +} + +func (f Float64) len() int { + return 9 +} + +func (f Float64) encode(p []byte) int { + p[0] = compose(majorType7, major7Float64) + binary.BigEndian.PutUint64(p[1:], math.Float64bits(float64(f))) + return 9 +} + +func compose(major majorType, minor byte) byte { + return byte(major)<<5 | minor +} + +func itoarglen[I int | uint64](v I) int { + vv := uint64(v) + if vv < 24 { + return 1 // type and len in single byte + } else if vv < 0x100 { + return 2 // type + 1-byte len + } else if vv < 0x10000 { + return 3 // type + 2-byte len + } else if vv < 0x100000000 { + return 5 // type + 4-byte len + } + return 9 // type + 8-byte len +} + +func encodeArg[I int | uint64](t majorType, arg I, p []byte) int { + aarg := uint64(arg) + if aarg < 24 { + p[0] = byte(t)<<5 | byte(aarg) + return 1 + } else if aarg < 0x100 { + p[0] = compose(t, minorArg1) + p[1] = byte(aarg) + return 2 + } else if aarg < 0x10000 { + p[0] = compose(t, minorArg2) + binary.BigEndian.PutUint16(p[1:], uint16(aarg)) + return 3 + } else if aarg < 0x100000000 { + p[0] = compose(t, minorArg4) + binary.BigEndian.PutUint32(p[1:], uint32(aarg)) + return 5 + } + + p[0] = compose(t, minorArg8) + binary.BigEndian.PutUint64(p[1:], uint64(aarg)) + return 9 +} + +// EncodeRaw encodes opaque CBOR data. +// +// This is used by the encoder for the purpose of embedding document shapes. +// Decode will never return values of this type. +type EncodeRaw []byte + +func (v EncodeRaw) len() int { return len(v) } + +func (v EncodeRaw) encode(p []byte) int { + copy(p, v) + return len(v) +} + +// FixedUint encodes fixed-width Uint values. +// +// This is used by the encoder for the purpose of embedding integrals in +// document shapes. Decode will never return values of this type. +type EncodeFixedUint uint64 + +func (EncodeFixedUint) len() int { return 9 } + +func (v EncodeFixedUint) encode(p []byte) int { + p[0] = compose(majorTypeUint, minorArg8) + binary.BigEndian.PutUint64(p[1:], uint64(v)) + return 9 +} + +// FixedUint encodes fixed-width NegInt values. +// +// This is used by the encoder for the purpose of embedding integrals in +// document shapes. Decode will never return values of this type. +type EncodeFixedNegInt uint64 + +func (EncodeFixedNegInt) len() int { return 9 } + +func (v EncodeFixedNegInt) encode(p []byte) int { + p[0] = compose(majorTypeNegInt, minorArg8) + binary.BigEndian.PutUint64(p[1:], uint64(v-1)) + return 9 +} diff --git a/encoding/cbor/encode_test.go b/encoding/cbor/encode_test.go new file mode 100644 index 000000000..204ee26d0 --- /dev/null +++ b/encoding/cbor/encode_test.go @@ -0,0 +1,466 @@ +package cbor + +import ( + "bytes" + "encoding/hex" + "math" + "testing" +) + +func TestEncode_Atomic(t *testing.T) { + for name, c := range map[string]struct { + Expect []byte + In Value + }{ + "uint/0/min": { + []byte{0<<5 | 0}, + Uint(0), + }, + "uint/0/max": { + []byte{0<<5 | 23}, + Uint(23), + }, + "uint/1/min": { + []byte{0<<5 | 24, 24}, + Uint(24), + }, + "uint/1/max": { + []byte{0<<5 | 24, 0xff}, + Uint(0xff), + }, + "uint/2/min": { + []byte{0<<5 | 25, 1, 0}, + Uint(0x100), + }, + "uint/2/max": { + []byte{0<<5 | 25, 0xff, 0xff}, + Uint(0xffff), + }, + "uint/4/min": { + []byte{0<<5 | 26, 1, 0, 0, 0}, + Uint(0x1000000), + }, + "uint/4/max": { + []byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}, + Uint(0xffffffff), + }, + "uint/8/min": { + []byte{0<<5 | 27, 1, 0, 0, 0, 0, 0, 0, 0}, + Uint(0x1000000_00000000), + }, + "uint/8/max": { + []byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + Uint(0xffffffff_ffffffff), + }, + "negint/0/min": { + []byte{1<<5 | 0}, + NegInt(1), + }, + "negint/0/max": { + []byte{1<<5 | 23}, + NegInt(24), + }, + "negint/1/min": { + []byte{1<<5 | 24, 24}, + NegInt(25), + }, + "negint/1/max": { + []byte{1<<5 | 24, 0xff}, + NegInt(0x100), + }, + "negint/2/min": { + []byte{1<<5 | 25, 1, 0}, + NegInt(0x101), + }, + "negint/2/max": { + []byte{1<<5 | 25, 0xff, 0xff}, + NegInt(0x10000), + }, + "negint/4/min": { + []byte{1<<5 | 26, 1, 0, 0, 0}, + NegInt(0x1000001), + }, + "negint/4/max": { + []byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}, + NegInt(0x100000000), + }, + "negint/8/min": { + []byte{1<<5 | 27, 1, 0, 0, 0, 0, 0, 0, 0}, + NegInt(0x1000000_00000001), + }, + "negint/8/max": { + []byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + NegInt(0xffffffff_ffffffff), + }, + "true": { + []byte{7<<5 | major7True}, + Bool(true), + }, + "false": { + []byte{7<<5 | major7False}, + Bool(false), + }, + "null": { + []byte{7<<5 | major7Nil}, + &Nil{}, + }, + "undefined": { + []byte{7<<5 | major7Undefined}, + &Undefined{}, + }, + "float32": { + []byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}, + Float32(math.Float32frombits(0x7f800000)), + }, + "float64": { + []byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}, + Float64(math.Float64frombits(0x7ff00000_00000000)), + }, + } { + t.Run(name, func(t *testing.T) { + actual := Encode(c.In) + if !bytes.Equal(c.Expect, actual) { + t.Errorf("bytes not equal (%s != %s)", hex.EncodeToString(c.Expect), hex.EncodeToString(actual)) + } + }) + } +} + +func TestEncode_Slice(t *testing.T) { + for name, c := range map[string]struct { + Expect []byte + In Value + }{ + "len = 0": { + []byte{2<<5 | 0}, + Slice{}, + }, + "len > 0": { + []byte{2<<5 | 3, 0x66, 0x6f, 0x6f}, + Slice{0x66, 0x6f, 0x6f}, + }, + } { + t.Run(name, func(t *testing.T) { + actual := Encode(c.In) + if !bytes.Equal(c.Expect, actual) { + t.Errorf("bytes not equal (%s != %s)", hex.EncodeToString(c.Expect), hex.EncodeToString(actual)) + } + }) + } +} + +func TestEncode_String(t *testing.T) { + for name, c := range map[string]struct { + Expect []byte + In Value + }{ + "len = 0": { + []byte{3<<5 | 0}, + String(""), + }, + "len > 0": { + []byte{3<<5 | 3, 0x66, 0x6f, 0x6f}, + String("foo"), + }, + } { + t.Run(name, func(t *testing.T) { + actual := Encode(c.In) + if !bytes.Equal(c.Expect, actual) { + t.Errorf("bytes not equal (%s != %s)", hex.EncodeToString(c.Expect), hex.EncodeToString(actual)) + } + }) + } +} + +func TestEncode_List(t *testing.T) { + for name, c := range map[string]struct { + Expect []byte + In Value + }{ + "[uint/0/min]": { + withDefiniteList([]byte{0<<5 | 0}), + List{Uint(0)}, + }, + "[uint/0/max]": { + withDefiniteList([]byte{0<<5 | 23}), + List{Uint(23)}, + }, + "[uint/1/min]": { + withDefiniteList([]byte{0<<5 | 24, 24}), + List{Uint(24)}, + }, + "[uint/1/max]": { + withDefiniteList([]byte{0<<5 | 24, 0xff}), + List{Uint(0xff)}, + }, + "[uint/2/min]": { + withDefiniteList([]byte{0<<5 | 25, 1, 0}), + List{Uint(0x100)}, + }, + "[uint/2/max]": { + withDefiniteList([]byte{0<<5 | 25, 0xff, 0xff}), + List{Uint(0xffff)}, + }, + "[uint/4/min]": { + withDefiniteList([]byte{0<<5 | 26, 1, 0, 0, 0}), + List{Uint(0x1000000)}, + }, + "[uint/4/max]": { + withDefiniteList([]byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + List{Uint(0xffffffff)}, + }, + "[uint/8/min]": { + withDefiniteList([]byte{0<<5 | 27, 1, 0, 0, 0, 0, 0, 0, 0}), + List{Uint(0x1000000_00000000)}, + }, + "[uint/8/max]": { + withDefiniteList([]byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}), + List{Uint(0xffffffff_ffffffff)}, + }, + "[negint/0/min]": { + withDefiniteList([]byte{1<<5 | 0}), + List{NegInt(1)}, + }, + "[negint/0/max]": { + withDefiniteList([]byte{1<<5 | 23}), + List{NegInt(24)}, + }, + "[negint/1/min]": { + withDefiniteList([]byte{1<<5 | 24, 24}), + List{NegInt(25)}, + }, + "[negint/1/max]": { + withDefiniteList([]byte{1<<5 | 24, 0xff}), + List{NegInt(0x100)}, + }, + "[negint/2/min]": { + withDefiniteList([]byte{1<<5 | 25, 1, 0}), + List{NegInt(0x101)}, + }, + "[negint/2/max]": { + withDefiniteList([]byte{1<<5 | 25, 0xff, 0xff}), + List{NegInt(0x10000)}, + }, + "[negint/4/min]": { + withDefiniteList([]byte{1<<5 | 26, 1, 0, 0, 0}), + List{NegInt(0x1000001)}, + }, + "[negint/4/max]": { + withDefiniteList([]byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + List{NegInt(0x100000000)}, + }, + "[negint/8/min]": { + withDefiniteList([]byte{1<<5 | 27, 1, 0, 0, 0, 0, 0, 0, 0}), + List{NegInt(0x1000000_00000001)}, + }, + "[negint/8/max]": { + withDefiniteList([]byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}), + List{NegInt(0xffffffff_ffffffff)}, + }, + "[true]": { + withDefiniteList([]byte{7<<5 | major7True}), + List{Bool(true)}, + }, + "[false]": { + withDefiniteList([]byte{7<<5 | major7False}), + List{Bool(false)}, + }, + "[null]": { + withDefiniteList([]byte{7<<5 | major7Nil}), + List{&Nil{}}, + }, + "[undefined]": { + withDefiniteList([]byte{7<<5 | major7Undefined}), + List{&Undefined{}}, + }, + "[float32]": { + withDefiniteList([]byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}), + List{Float32(math.Float32frombits(0x7f800000))}, + }, + "[float64]": { + withDefiniteList([]byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}), + List{Float64(math.Float64frombits(0x7ff00000_00000000))}, + }, + } { + t.Run(name, func(t *testing.T) { + actual := Encode(c.In) + if !bytes.Equal(c.Expect, actual) { + t.Errorf("bytes not equal (%s != %s)", hex.EncodeToString(c.Expect), hex.EncodeToString(actual)) + } + }) + } +} + +func TestEncode_Map(t *testing.T) { + for name, c := range map[string]struct { + Expect []byte + In Value + }{ + "{uint/0/min}": { + withDefiniteMap([]byte{0<<5 | 0}), + Map{"foo": Uint(0)}, + }, + "{uint/0/max}": { + withDefiniteMap([]byte{0<<5 | 23}), + Map{"foo": Uint(23)}, + }, + "{uint/1/min}": { + withDefiniteMap([]byte{0<<5 | 24, 24}), + Map{"foo": Uint(24)}, + }, + "{uint/1/max}": { + withDefiniteMap([]byte{0<<5 | 24, 0xff}), + Map{"foo": Uint(0xff)}, + }, + "{uint/2/min}": { + withDefiniteMap([]byte{0<<5 | 25, 1, 0}), + Map{"foo": Uint(0x100)}, + }, + "{uint/2/max}": { + withDefiniteMap([]byte{0<<5 | 25, 0xff, 0xff}), + Map{"foo": Uint(0xffff)}, + }, + "{uint/4/min}": { + withDefiniteMap([]byte{0<<5 | 26, 1, 0, 0, 0}), + Map{"foo": Uint(0x1000000)}, + }, + "{uint/4/max}": { + withDefiniteMap([]byte{0<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Map{"foo": Uint(0xffffffff)}, + }, + "{uint/8/min}": { + withDefiniteMap([]byte{0<<5 | 27, 1, 0, 0, 0, 0, 0, 0, 0}), + Map{"foo": Uint(0x1000000_00000000)}, + }, + "{uint/8/max}": { + withDefiniteMap([]byte{0<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}), + Map{"foo": Uint(0xffffffff_ffffffff)}, + }, + "{negint/0/min}": { + withDefiniteMap([]byte{1<<5 | 0}), + Map{"foo": NegInt(1)}, + }, + "{negint/0/max}": { + withDefiniteMap([]byte{1<<5 | 23}), + Map{"foo": NegInt(24)}, + }, + "{negint/1/min}": { + withDefiniteMap([]byte{1<<5 | 24, 24}), + Map{"foo": NegInt(25)}, + }, + "{negint/1/max}": { + withDefiniteMap([]byte{1<<5 | 24, 0xff}), + Map{"foo": NegInt(0x100)}, + }, + "{negint/2/min}": { + withDefiniteMap([]byte{1<<5 | 25, 1, 0}), + Map{"foo": NegInt(0x101)}, + }, + "{negint/2/max}": { + withDefiniteMap([]byte{1<<5 | 25, 0xff, 0xff}), + Map{"foo": NegInt(0x10000)}, + }, + "{negint/4/min}": { + withDefiniteMap([]byte{1<<5 | 26, 1, 0, 0, 0}), + Map{"foo": NegInt(0x1000001)}, + }, + "{negint/4/max}": { + withDefiniteMap([]byte{1<<5 | 26, 0xff, 0xff, 0xff, 0xff}), + Map{"foo": NegInt(0x100000000)}, + }, + "{negint/8/min}": { + withDefiniteMap([]byte{1<<5 | 27, 1, 0, 0, 0, 0, 0, 0, 0}), + Map{"foo": NegInt(0x1000000_00000001)}, + }, + "{negint/8/max}": { + withDefiniteMap([]byte{1<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}), + Map{"foo": NegInt(0xffffffff_ffffffff)}, + }, + "{true}": { + withDefiniteMap([]byte{7<<5 | major7True}), + Map{"foo": Bool(true)}, + }, + "{false}": { + withDefiniteMap([]byte{7<<5 | major7False}), + Map{"foo": Bool(false)}, + }, + "{null}": { + withDefiniteMap([]byte{7<<5 | major7Nil}), + Map{"foo": &Nil{}}, + }, + "{undefined}": { + withDefiniteMap([]byte{7<<5 | major7Undefined}), + Map{"foo": &Undefined{}}, + }, + "{float32}": { + withDefiniteMap([]byte{7<<5 | major7Float32, 0x7f, 0x80, 0, 0}), + Map{"foo": Float32(math.Float32frombits(0x7f800000))}, + }, + "{float64}": { + withDefiniteMap([]byte{7<<5 | major7Float64, 0x7f, 0xf0, 0, 0, 0, 0, 0, 0}), + Map{"foo": Float64(math.Float64frombits(0x7ff00000_00000000))}, + }, + } { + t.Run(name, func(t *testing.T) { + actual := Encode(c.In) + if !bytes.Equal(c.Expect, actual) { + t.Errorf("bytes not equal (%s != %s)", hex.EncodeToString(c.Expect), hex.EncodeToString(actual)) + } + }) + } +} + +func TestEncode_Tag(t *testing.T) { + for name, c := range map[string]struct { + Expect []byte + In Value + }{ + "0/min": { + []byte{6<<5 | 0, 1}, + &Tag{0, Uint(1)}, + }, + "0/max": { + []byte{6<<5 | 23, 1}, + &Tag{23, Uint(1)}, + }, + "1/min": { + []byte{6<<5 | 24, 24, 1}, + &Tag{24, Uint(1)}, + }, + "1/max": { + []byte{6<<5 | 24, 0xff, 1}, + &Tag{0xff, Uint(1)}, + }, + "2/min": { + []byte{6<<5 | 25, 1, 0, 1}, + &Tag{0x100, Uint(1)}, + }, + "2/max": { + []byte{6<<5 | 25, 0xff, 0xff, 1}, + &Tag{0xffff, Uint(1)}, + }, + "4/min": { + []byte{6<<5 | 26, 1, 0, 0, 0, 1}, + &Tag{0x1000000, Uint(1)}, + }, + "4/max": { + []byte{6<<5 | 26, 0xff, 0xff, 0xff, 0xff, 1}, + &Tag{0xffffffff, Uint(1)}, + }, + "8/min": { + []byte{6<<5 | 27, 1, 0, 0, 0, 0, 0, 0, 0, 1}, + &Tag{0x1000000_00000000, Uint(1)}, + }, + "8/max": { + []byte{6<<5 | 27, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 1}, + &Tag{0xffffffff_ffffffff, Uint(1)}, + }, + } { + t.Run(name, func(t *testing.T) { + actual := Encode(c.In) + if !bytes.Equal(c.Expect, actual) { + t.Errorf("bytes not equal (%s != %s)", hex.EncodeToString(c.Expect), hex.EncodeToString(actual)) + } + }) + } +} diff --git a/encoding/cbor/float16.go b/encoding/cbor/float16.go new file mode 100644 index 000000000..081eea705 --- /dev/null +++ b/encoding/cbor/float16.go @@ -0,0 +1,45 @@ +package cbor + +func float16to32(f uint16) uint32 { + sign, exp, mant := splitf16(f) + if exp == 0x1f { + return sign | 0xff<<23 | mant // infinity/NaN + } + + if exp == 0 { // subnormal + if mant == 0 { + return sign + } + return normalize(sign, mant) + } + + return sign | (exp+127-15)<<23 | mant // rebias exp by the difference between the two +} + +func splitf16(f uint16) (sign, exp, mantissa uint32) { + const smask = 0x1 << 15 // put sign in float32 position + const emask = 0x1f << 10 // pull exponent as a number (for bias shift) + const mmask = 0x3ff // put mantissa in float32 position + + return uint32(f&smask) << 16, uint32(f&emask) >> 10, uint32(f&mmask) << 13 +} + +// moves a float16 normal into normal float32 space +// to do this we must re-express the float16 mantissa in terms of a normal +// float32 where the hidden bit is 1, e.g. +// +// f16: 0 00000 0001010000 = 0.000101 * 2^(-14), which is equal to +// f32: 0 01101101 01000000000000000000000 = 1.01 * 2^(-18) +// +// this is achieved by shifting the mantissa to the right until the leading bit +// that == 1 reaches position 24, then the number of positions shifted over is +// equal to the offset from the subnormal exponent +func normalize(sign, mant uint32) uint32 { + exp := uint32(-14 + 127) // f16 subnormal exp, with f32 bias + for mant&0x800000 == 0 { // repeat until bit 24 ("hidden" mantissa) is 1 + mant <<= 1 + exp-- // tracking the offset + } + mant &= 0x7fffff // remask to 23bit + return sign | exp<<23 | mant +} diff --git a/encoding/cbor/float16_test.go b/encoding/cbor/float16_test.go new file mode 100644 index 000000000..ef97f3acc --- /dev/null +++ b/encoding/cbor/float16_test.go @@ -0,0 +1,41 @@ +package cbor + +import ( + "testing" +) + +func TestFloat16To32(t *testing.T) { + for name, c := range map[string]struct { + In uint16 + Expect uint32 + }{ + "+infinity": { + 0b0_11111_0000000000, + 0b0_11111111_00000000000000000000000, + }, + "-infinity": { + 0b1_11111_0000000000, + 0b1_11111111_00000000000000000000000, + }, + "NaN": { + 0b0_11111_0101010101, + 0b0_11111111_01010101010000000000000, + }, + "absolute zero": {0, 0}, + "subnormal": { + 0b0_00000_0001010000, + 0b0_01101101_01000000000000000000000, + }, + "normal": { + 0b0_00001_0001010000, + 0b0_0001110001_00010100000000000000000, + }, + } { + t.Run(name, func(t *testing.T) { + if actual := float16to32(c.In); c.Expect != actual { + t.Errorf("%x != %x", c.Expect, actual) + } + }) + } + +} diff --git a/encoding/cbor/fuzz_test.go b/encoding/cbor/fuzz_test.go new file mode 100644 index 000000000..03dd1e231 --- /dev/null +++ b/encoding/cbor/fuzz_test.go @@ -0,0 +1,114 @@ +//go:build fuzz +// +build fuzz + +package cbor + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "testing" +) + +// caught by fuzz: +// - broken typecast from uint64 to int when checking encoded string(mt2,3) length vs buflen +// - huge encoded list/map sizes would cause panics on make() +// - map declaration at end of buffer would attempt to peek p[0] when len(p) == 0 + +func TestDecode_Fuzz(t *testing.T) { + const runs = 1_000_000 + const buflen = 512 + + p := make([]byte, buflen) + + defer func() { + if err := recover(); err != nil { + fmt.Println(hex.EncodeToString(p)) + dump(p) + + t.Fatalf("decode panic: %v\n", err) + } + }() + + for i := 0; i < runs; i++ { + if _, err := rand.Read(p); err != nil { + t.Fatalf("create randbuf: %v", err) + } + + decode(p) + } +} + +func dump(p []byte) { + for len(p) > 0 { + var off int + + major, minor := peekMajor(p), peekMinor(p) + switch major { + case majorTypeUint, majorTypeNegInt, majorType7: + if minor > 27 { + fmt.Printf("%d, %d (invalid)\n", major, minor) + return + } + + arg, n, err := decodeArgument(p) + if err != nil { + panic(err) + } + + fmt.Printf("%d, %d\n", major, arg) + off = n + case majorTypeSlice, majorTypeString: + if minor == 31 { + panic("todo") + } else if minor > 27 { + fmt.Printf("%d, %d (invalid)\n", major, minor) + return + } + + arg, n, err := decodeArgument(p) + if err != nil { + panic(err) + } + + fmt.Printf("str(%d), len %d\n", major, arg) + off = n + int(arg) + case majorTypeList, majorTypeMap: + if minor == 31 { + panic("todo") + } else if minor > 27 { + fmt.Printf("%d, %d (invalid)\n", major, minor) + return + } + + arg, n, err := decodeArgument(p) + if err != nil { + panic(err) + } + + fmt.Printf("container(%d), len %d\n", major, arg) + off = n + case majorTypeTag: + if minor > 27 { + fmt.Printf("tag, %d (invalid)\n", minor) + return + } + + arg, n, err := decodeArgument(p) + if err != nil { + panic(err) + } + + fmt.Printf("tag, %d\n", arg) + off = n + } + + if off > len(p) { + fmt.Println("overflow, stop") + return + } + p = p[off:] + } + + fmt.Println("EOF") +} diff --git a/testing/cbor.go b/testing/cbor.go new file mode 100644 index 000000000..6e8b1b521 --- /dev/null +++ b/testing/cbor.go @@ -0,0 +1,169 @@ +package testing + +import ( + "encoding/base64" + "fmt" + "io" + "math" + "reflect" + + "github.com/aws/smithy-go/encoding/cbor" +) + +// CompareCBOR checks whether two CBOR values are equivalent. +// +// The function signature is tailored for use in smithy protocol tests, where +// the expected encoding is given in base64, and the actual value to check is +// passed from the mock HTTP request body. +func CompareCBOR(actual io.Reader, expect64 string) error { + ap, err := io.ReadAll(actual) + if err != nil { + return fmt.Errorf("read actual: %w", err) + } + + av, err := cbor.Decode(ap) + if err != nil { + return fmt.Errorf("decode actual: %w", err) + } + + ep, err := base64.StdEncoding.DecodeString(expect64) + if err != nil { + return fmt.Errorf("decode expect64: %w", err) + } + + ev, err := cbor.Decode(ep) + if err != nil { + return fmt.Errorf("decode expect: %w", err) + } + + return cmpCBOR(ev, av, "") +} + +func cmpCBOR(e, a cbor.Value, path string) error { + switch v := e.(type) { + case cbor.Uint, cbor.NegInt, cbor.Slice, cbor.String, cbor.Bool, *cbor.Nil, *cbor.Undefined: + if !reflect.DeepEqual(e, a) { + return fmt.Errorf("%s: %v != %v", path, e, a) + } + return nil + case cbor.List: + return cmpList(v, a, path) + case cbor.Map: + return cmpMap(v, a, path) + case *cbor.Tag: + return cmpTag(v, a, path) + case cbor.Float32: + return cmpF32(v, a, path) + case cbor.Float64: + return cmpF64(v, a, path) + default: + return fmt.Errorf("%s: unrecognized variant %T", path, e) + } +} + +func cmpList(e cbor.List, a cbor.Value, path string) error { + av, ok := a.(cbor.List) + if !ok { + return fmt.Errorf("%s: %T != %T", path, e, a) + } + + if len(e) != len(av) { + return fmt.Errorf("%s: length %d != %d", path, len(e), len(av)) + } + + for i := 0; i < len(e); i++ { + ipath := fmt.Sprintf("%s[%d]", path, i) + if err := cmpCBOR(e[i], av[i], ipath); err != nil { + return err + } + } + return nil +} + +func cmpMap(e cbor.Map, a cbor.Value, path string) error { + av, ok := a.(cbor.Map) + if !ok { + return fmt.Errorf("%s: %T != %T", path, e, a) + } + + if len(e) != len(av) { + return fmt.Errorf("%s: length %d != %d", path, len(e), len(av)) + } + + for k, ev := range e { + avv, ok := av[k] + if !ok { + return fmt.Errorf("%s: missing key %s", path, k) + } + + kpath := fmt.Sprintf("%s[%q]", path, k) + if err := cmpCBOR(ev, avv, kpath); err != nil { + return err + } + } + return nil +} + +func cmpTag(e *cbor.Tag, a cbor.Value, path string) error { + av, ok := a.(*cbor.Tag) + if !ok { + return fmt.Errorf("%s: %T != %T", path, e, a) + } + + if e.ID != av.ID { + return fmt.Errorf("%s: tag ID %d != %d", path, e.ID, av.ID) + } + return cmpCBOR(e.Value, av.Value, path) +} + +func cmpF32(e cbor.Float32, a cbor.Value, path string) error { + av, ok := a.(cbor.Float32) + if !ok { + return fmt.Errorf("%s: %T != %T", path, e, a) + } + + ebits, abits := math.Float32bits(float32(e)), math.Float32bits(float32(av)) + if enan, anan := isNaN32(ebits), isNaN32(abits); enan || anan { + if enan != anan { + return fmt.Errorf("%s: NaN: float32(%x) != float32(%x)", path, ebits, abits) + } + return nil + } + + if ebits != abits { + return fmt.Errorf("%s: float32(%x) != float32(%x)", path, ebits, abits) + } + return nil +} + +func cmpF64(e cbor.Float64, a cbor.Value, path string) error { + av, ok := a.(cbor.Float64) + if !ok { + return fmt.Errorf("%s: %T != %T", path, e, a) + } + + ebits, abits := math.Float64bits(float64(e)), math.Float64bits(float64(av)) + if enan, anan := isNaN64(ebits), isNaN64(abits); enan || anan { + if enan != anan { + return fmt.Errorf("%s: NaN: float64(%x) != float64(%x)", path, ebits, abits) + } + return nil + } + + if math.Float64bits(float64(e)) != math.Float64bits(float64(av)) { + return fmt.Errorf("%s: float64(%x) != float64(%x)", path, ebits, abits) + } + return nil +} + +func isNaN32(f uint32) bool { + const infmask = 0x7f800000 + + return f&infmask == infmask && f != infmask && f != (1<<31)|infmask +} + +func isNaN64(f uint64) bool { + const infmask = 0x7ff00000_00000000 + + return f&infmask == infmask && f != infmask && f != (1<<63)|infmask +} diff --git a/testing/struct.go b/testing/struct.go index 3b08f9c2b..9aae8fe83 100644 --- a/testing/struct.go +++ b/testing/struct.go @@ -105,10 +105,16 @@ func deepEqual(expect, actual reflect.Value, path string) error { } return nil case reflect.Float32, reflect.Float64: - // NaN != NaN by definition but we just care about bitwise equality - ef, af := math.Float64bits(expect.Float()), math.Float64bits(actual.Float()) - if ef != af { - return fmt.Errorf("%s: float 0x%x != 0x%x", path, ef, af) + ef, af := expect.Float(), actual.Float() + ebits, abits := math.Float64bits(ef), math.Float64bits(af) + if enan, anan := math.IsNaN(ef), math.IsNaN(af); enan || anan { + if enan != anan { + return fmt.Errorf("%s: NaN: float64(0x%x) != float64(0x%x)", path, ebits, abits) + } + return nil + } + if ebits != abits { + return fmt.Errorf("%s: float64(0x%x) != float64(0x%x)", path, ebits, abits) } return nil default: diff --git a/testing/struct_test.go b/testing/struct_test.go index 2ec400365..c1ff11b20 100644 --- a/testing/struct_test.go +++ b/testing/struct_test.go @@ -120,7 +120,7 @@ func TestCompareValues(t *testing.T) { Bar: 123, }, }, - "float diff": { + "float diff NaN": { A: struct { Foo float64 Bar int @@ -135,8 +135,25 @@ func TestCompareValues(t *testing.T) { Foo: math.Float64frombits(float64NaN - 1), Bar: 123, }, - ExpectErr: ".Foo: float 0x7fffffffffffffff != 0x7ffffffffffffffe", }, + "float diff": { + A: struct { + Foo float64 + Bar int + }{ + Foo: math.Float64frombits(0x100), + Bar: 123, + }, + B: struct { + Foo float64 + Bar int + }{ + Foo: math.Float64frombits(0x101), + Bar: 123, + }, + ExpectErr: ".Foo: float64(0x100) != float64(0x101)", + }, + "document equal": { A: &mockDocumentMarshaler{[]byte("123"), nil}, B: &mockDocumentMarshaler{[]byte("123"), nil},