Skip to content

Commit

Permalink
DeserializeMiddleware Revision 1
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanhit committed Aug 27, 2024
1 parent 0b8259c commit 9cbcce8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

public class DeserializeMiddleware {

public static final String SMITHY_PROTOCOL_NAME = "awsJson10";
protected final ProtocolGenerator.GenerationContext ctx;
protected final OperationShape operation;
protected final GoWriter writer;
Expand All @@ -49,15 +48,15 @@ public DeserializeMiddleware(

this.input = ctx.getModel().expectShape(operation.getInputShape(), StructureShape.class);
this.output = ctx.getModel().expectShape(operation.getOutputShape(), StructureShape.class);

deserialName = getMiddlewareName(operation);
}

public static String getMiddlewareName(OperationShape operation) {
return "awsAwsjson10_deserializeOp" + operation.toShapeId().getName();
}

public GoWriter.Writable generate() {
deserialName = getMiddlewareName(operation);

return goTemplate("""
type $opName:L struct{
Expand Down Expand Up @@ -125,26 +124,15 @@ private GoWriter.Writable handleResponseChecks() {
return out, metadata, $errorf:T("unexpected transport type %T", out.RawResponse)
}
if resp.Header.Get("smithy-protocol") != $protocol:S {
return out, metadata, &$deserError:T{
Err: $errorf:T(
"unexpected smithy-protocol response header '%s' (HTTP status: %s)",
resp.Header.Get("smithy-protocol"),
resp.Status,
),
}
}
if resp.StatusCode != 200 {
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return out, metadata, &$deserError:T{}
}
""",
MapUtils.of(
"response", SMITHY_HTTP_TRANSPORT.pointableSymbol("Response"),
"errorf", GoStdlibTypes.Fmt.Errorf,
"deserError", SmithyGoDependency.SMITHY.struct("DeserializationError"),
"protocol", SMITHY_PROTOCOL_NAME
"deserError", SmithyGoDependency.SMITHY.struct("DeserializationError")
));
}

Expand All @@ -170,16 +158,16 @@ private GoWriter.Writable handlePayload() {
}
decoder := $decoder:T(resp.Body)
var cv map[string]interface{}
err = decoder.Decode(&cv)
var jv map[string]interface{}
err = decoder.Decode(&jv)
if err!= nil {
return out, metadata, err
}
output, err := $deserialize:L(cv)
if err != nil {
return out, metadata, err
}
output, err := $deserialize:L(jv)
if err != nil {
return out, metadata, err
}
out.Result = output
""",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ public SerializeMiddleware(

this.input = ctx.getModel().expectShape(operation.getInputShape(), StructureShape.class);
this.output = ctx.getModel().expectShape(operation.getOutputShape(), StructureShape.class);

serialName = getMiddlewareName(operation);
}

public static String getMiddlewareName(OperationShape operation) {
return "awsAwsjson10_serializeOp" + operation.toShapeId().getName();
}

public GoWriter.Writable generate() {
serialName = getMiddlewareName(operation);

return goTemplate("""
type $opName:L struct{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,13 @@ private GoWriter.Writable generateZeroValue(Shape shape) {

private GoWriter.Writable generateDeserializeAssertedValue(Shape shape, String ident) {
return switch (shape.getType()) {
case BYTE -> generateDeserializeIntegral(ident, "int8", "Int64", Byte.MIN_VALUE, Byte.MAX_VALUE);
case SHORT -> generateDeserializeIntegral(ident, "int16", "Int64", Short.MIN_VALUE, Short.MAX_VALUE);
case INTEGER -> generateDeserializeIntegral(ident, "int32", "Int64", Integer.MIN_VALUE, Integer.MAX_VALUE);
case LONG -> generateDeserializeIntegral(ident, "int64", "Int64", Long.MIN_VALUE, Long.MAX_VALUE);
case BYTE -> generateDeserializeIntegral(ident, "int8", Byte.MIN_VALUE, Byte.MAX_VALUE);
case SHORT -> generateDeserializeIntegral(ident, "int16", Short.MIN_VALUE, Short.MAX_VALUE);
case INTEGER -> generateDeserializeIntegral(ident, "int32", Integer.MIN_VALUE, Integer.MAX_VALUE);
case LONG -> generateDeserializeIntegral(ident, "int64", Long.MIN_VALUE, Long.MAX_VALUE);
case STRING, BOOLEAN -> goTemplate("return $L, nil", ident);
//Int_Enum implementation needs to be tested
case ENUM, INT_ENUM -> goTemplate("return $T($L), nil", symbolProvider.toSymbol(shape), ident);
case FLOAT -> generateDeserializeIntegral(ident, "float32", "Float64",
case FLOAT -> generateDeserializeFloat(ident, "float32",
(long) Float.MIN_VALUE, (long) Float.MAX_VALUE);
case BLOB -> goTemplate("""
p, err := $b64:T.DecodeString($ident:L)
Expand Down Expand Up @@ -257,10 +256,9 @@ yield goTemplate("""
};
}

private GoWriter.Writable generateDeserializeIntegral(String ident, String castTo, String typecast,
long min, long max) {
private GoWriter.Writable generateDeserializeIntegral(String ident, String castTo, long min, long max) {
return goTemplate("""
$nextident:L, err := $ident:L.$typecast:L()
$nextident:L, err := $ident:L.Int64()
if err != nil {
return 0, err
}
Expand All @@ -275,8 +273,28 @@ private GoWriter.Writable generateDeserializeIntegral(String ident, String castT
"nextident", ident + "_",
"min", min,
"max", max,
"cast", castTo,
"typecast", typecast
"cast", castTo
));
}

private GoWriter.Writable generateDeserializeFloat(String ident, String castTo, long min, long max) {
return goTemplate("""
$nextident:L, err := $ident:L.Float64()
if err != nil {
return 0, err
}
if $nextident:L < $min:L || $nextident:L > $max:L {
return 0, $errorf:T("invalid")
}
return $cast:L($nextident:L), nil
""",
MapUtils.of(
"errorf", GoStdlibTypes.Fmt.Errorf,
"ident", ident,
"nextident", ident + "_",
"min", min,
"max", max,
"cast", castTo
));
}

Expand Down

0 comments on commit 9cbcce8

Please sign in to comment.