Skip to content

Commit

Permalink
add header check for responses
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhe committed Jun 19, 2024
1 parent b68a47a commit e8dd17a
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 60 deletions.
6 changes: 6 additions & 0 deletions packages/core/src/submodules/cbor/parseCborBody.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,9 @@ export const loadSmithyRpcV2CborErrorCode = (output: HttpResponse, data: any): s
return sanitizeErrorCode(data.code);
}
};

export const checkCborResponse = (response: HttpResponse): void => {
if (response.headers["smithy-protocol"] !== "rpc-v2-cbor") {
throw new Error("Malformed RPCv2 CBOR response, status: " + response.statusCode);
}
};
2 changes: 2 additions & 0 deletions packages/smithy-client/src/serde-json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* Maps an object through the default JSON serde behavior.
* This means removing nullish fields and un-sparsifying lists.
*
* This is also used by Smithy RPCv2 CBOR as the default serde behavior.
*
* @param obj - to be checked.
* @returns same object with default serde behavior applied.
*/
Expand Down
36 changes: 18 additions & 18 deletions private/smithy-rpcv2-cbor/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@
"@aws-crypto/sha256-browser": "5.2.0",
"@aws-crypto/sha256-js": "5.2.0",
"@aws-sdk/types": "latest",
"@smithy/config-resolver": "^3.0.2",
"@smithy/core": "^2.2.2",
"@smithy/fetch-http-handler": "^3.0.3",
"@smithy/hash-node": "^3.0.1",
"@smithy/invalid-dependency": "^3.0.1",
"@smithy/middleware-content-length": "^3.0.1",
"@smithy/middleware-retry": "^3.0.5",
"@smithy/middleware-serde": "^3.0.1",
"@smithy/middleware-stack": "^3.0.1",
"@smithy/node-config-provider": "^3.1.1",
"@smithy/node-http-handler": "^3.0.1",
"@smithy/protocol-http": "^4.0.1",
"@smithy/smithy-client": "^3.1.3",
"@smithy/types": "^3.1.0",
"@smithy/url-parser": "^3.0.1",
"@smithy/config-resolver": "^3.0.3",
"@smithy/core": "^2.2.3",
"@smithy/fetch-http-handler": "^3.1.0",
"@smithy/hash-node": "^3.0.2",
"@smithy/invalid-dependency": "^3.0.2",
"@smithy/middleware-content-length": "^3.0.2",
"@smithy/middleware-retry": "^3.0.6",
"@smithy/middleware-serde": "^3.0.2",
"@smithy/middleware-stack": "^3.0.2",
"@smithy/node-config-provider": "^3.1.2",
"@smithy/node-http-handler": "^3.1.0",
"@smithy/protocol-http": "^4.0.2",
"@smithy/smithy-client": "^3.1.4",
"@smithy/types": "^3.2.0",
"@smithy/url-parser": "^3.0.2",
"@smithy/util-base64": "^3.0.0",
"@smithy/util-body-length-browser": "^3.0.0",
"@smithy/util-body-length-node": "^3.0.0",
"@smithy/util-defaults-mode-browser": "^3.0.5",
"@smithy/util-defaults-mode-node": "^3.0.5",
"@smithy/util-retry": "^3.0.1",
"@smithy/util-defaults-mode-browser": "^3.0.6",
"@smithy/util-defaults-mode-node": "^3.0.6",
"@smithy/util-retry": "^3.0.2",
"@smithy/util-utf8": "^3.0.0",
"tslib": "^2.6.2"
},
Expand Down
25 changes: 25 additions & 0 deletions private/smithy-rpcv2-cbor/src/protocols/Rpcv2cbor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import {
import {
dateToTag as __dateToTag,
cbor,
checkCborResponse as cr,
loadSmithyRpcV2CborErrorCode,
parseCborBody as parseBody,
parseCborErrorBody as parseErrorBody,
Expand Down Expand Up @@ -277,9 +278,11 @@ export const de_EmptyInputOutputCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<EmptyInputOutputCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = _json(data);
Expand All @@ -297,9 +300,11 @@ export const de_FractionalSecondsCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<FractionalSecondsCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = de_FractionalSecondsOutput(data, context);
Expand All @@ -317,9 +322,11 @@ export const de_GreetingWithErrorsCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<GreetingWithErrorsCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = _json(data);
Expand All @@ -337,9 +344,11 @@ export const de_NoInputOutputCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<NoInputOutputCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

await collectBody(output.body, context);
const response: NoInputOutputCommandOutput = {
$metadata: deserializeMetadata(output),
Expand All @@ -354,9 +363,11 @@ export const de_OperationWithDefaultsCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<OperationWithDefaultsCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = de_OperationWithDefaultsOutput(data, context);
Expand All @@ -374,9 +385,11 @@ export const de_OptionalInputOutputCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<OptionalInputOutputCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = _json(data);
Expand All @@ -394,9 +407,11 @@ export const de_RecursiveShapesCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<RecursiveShapesCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = de_RecursiveShapesInputOutput(data, context);
Expand All @@ -414,9 +429,11 @@ export const de_RpcV2CborDenseMapsCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<RpcV2CborDenseMapsCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = _json(data);
Expand All @@ -434,9 +451,11 @@ export const de_RpcV2CborListsCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<RpcV2CborListsCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = de_RpcV2CborListInputOutput(data, context);
Expand All @@ -454,9 +473,11 @@ export const de_RpcV2CborSparseMapsCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<RpcV2CborSparseMapsCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = de_RpcV2CborSparseMapsInputOutput(data, context);
Expand All @@ -474,9 +495,11 @@ export const de_SimpleScalarPropertiesCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<SimpleScalarPropertiesCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = de_SimpleScalarStructure(data, context);
Expand All @@ -494,9 +517,11 @@ export const de_SparseNullsOperationCommand = async (
output: __HttpResponse,
context: __SerdeContext
): Promise<SparseNullsOperationCommandOutput> => {
cr(output);
if (output.statusCode >= 300) {
return de_CommandError(output, context);
}

const data: any = await parseBody(output.body, context);
let contents: any = {};
contents = de_SparseNullsOperationInputOutput(data, context);
Expand Down
14 changes: 8 additions & 6 deletions private/smithy-rpcv2-cbor/test/functional/rpcv2cbor.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ResponseDeserializationTestHandler implements HttpHandler {
body = "";
}
this.body = body;
this.isBase64Body = Buffer.from(String(body), "base64").toString("base64") === body;
this.isBase64Body = String(body).length > 0 && Buffer.from(String(body), "base64").toString("base64") === body;
}

handle(request: HttpRequest, options?: HttpHandlerOptions): Promise<{ response: HttpResponse }> {
Expand Down Expand Up @@ -201,13 +201,15 @@ function normalizeByteArrayType(data: any) {
if (!data || typeof data !== "object") {
return data;
}
if (data instanceof Uint8Array) {
return Uint8Array.from(data);
}
if (data instanceof String || data instanceof Boolean || data instanceof Number) {
return data.valueOf();
}
const output = {} as any;
for (const key of Object.getOwnPropertyNames(data)) {
if (data[key] instanceof Uint8Array) {
output[key] = Uint8Array.from(data[key]);
} else {
output[key] = normalizeByteArrayType(data[key]);
}
output[key] = normalizeByteArrayType(data[key]);
}
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import java.util.Set;
import java.util.TreeSet;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.OperationShape;
Expand All @@ -16,13 +18,15 @@
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.TimestampFormatTrait;
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait;
import software.amazon.smithy.typescript.codegen.CodegenUtils;
import software.amazon.smithy.typescript.codegen.SmithyCoreSubmodules;
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptSettings;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.EventStreamGenerator;
import software.amazon.smithy.typescript.codegen.integration.HttpProtocolGeneratorUtils;
import software.amazon.smithy.typescript.codegen.integration.HttpRpcProtocolGenerator;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator;
import software.amazon.smithy.typescript.codegen.knowledge.SerdeElisionIndex;
import software.amazon.smithy.typescript.codegen.protocols.SmithyProtocolUtils;
import software.amazon.smithy.utils.SmithyInternalApi;
Expand Down Expand Up @@ -177,6 +181,60 @@ protected void generateDocumentBodyShapeDeserializers(GenerationContext generati
);
}

@Override
protected void generateOperationDeserializer(GenerationContext context, OperationShape operation) {
SymbolProvider symbolProvider = context.getSymbolProvider();
Symbol symbol = symbolProvider.toSymbol(operation);
SymbolReference responseType = getApplicationProtocol().getResponseType();
TypeScriptWriter writer = context.getWriter();

writer.addUseImports(responseType);
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = "de_CommandError";
String serdeContextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
context.getModel(), operation);
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);

writer.writeDocs(methodLongName);
writer.openBlock("""
export const $L = async(
output: $T,
context: $L
): Promise<$T> => {""", "}",
methodName, responseType, serdeContextType, outputType,
() -> {
writer.addSubPathImport(
"checkCborResponse", "cr",
TypeScriptDependency.SMITHY_CORE,
SmithyCoreSubmodules.CBOR
);
writer.write("cr(output);");

writer.write("""
if (output.statusCode >= 300) {
return $L(output, context);
}
""",
errorMethodName
);

readResponseBody(context, operation);

writer.write("""
const response: $T = {
$$metadata: deserializeMetadata(output), $L
};
return response;
""",
outputType,
operation.getOutput().map((o) -> "...contents,").orElse("")
);
}
);
writer.write("");
}

@Override
protected String getOperationPath(GenerationContext generationContext, OperationShape operationShape) {
// TODO(cbor) what is the prefix?
Expand Down
Loading

0 comments on commit e8dd17a

Please sign in to comment.