Skip to content

Commit

Permalink
feat(codegen): xml factory
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhe committed Jul 28, 2022
1 parent 8341b0b commit 8e0bb70
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -735,31 +735,30 @@ private void writeResolvedPath(

// Handle any label bindings.
if (!labelBindings.isEmpty()) {
writer.addImport("resolvedPath", "__resolvedPath", "@aws-sdk/smithy-client");

Model model = context.getModel();
List<Segment> uriLabels = trait.getUri().getLabels();
for (HttpBinding binding : labelBindings) {
String memberName = symbolProvider.toMemberName(binding.getMember());
Shape target = model.expectShape(binding.getMember().getTarget());
String labelValue = getInputValue(context, binding.getLocation(), "input." + memberName,
binding.getMember(), target);

String labelValueProvider = "() => " + getInputValue(
context,
binding.getLocation(),
"input." + memberName + "!",
binding.getMember(),
target
);

// Get the correct label to use.
Segment uriLabel = uriLabels.stream().filter(s -> s.getContent().equals(memberName)).findFirst().get();
writer.addImport("extendedEncodeURIComponent", "__extendedEncodeURIComponent",
"@aws-sdk/smithy-client");
String encodedSegment = uriLabel.isGreedyLabel()
? "labelValue.split(\"/\").map(segment => __extendedEncodeURIComponent(segment)).join(\"/\")"
: "__extendedEncodeURIComponent(labelValue)";

// Set the label's value and throw a clear error if empty or undefined.
writer.write("if (input.$L !== undefined) {", memberName).indent()
.write("const labelValue: string = $L;", labelValue)
.openBlock("if (labelValue.length <= 0) {", "}", () -> {
writer.write("throw new Error('Empty value provided for input HTTP label: $L.');", memberName);
})
.write("resolvedPath = resolvedPath.replace($S, $L);", uriLabel.toString(), encodedSegment).dedent()
.write("} else {").indent()
.write("throw new Error('No value provided for input HTTP label: $L.');", memberName).dedent()
.write("}");
writer.write("resolvedPath = __resolvedPath(resolvedPath, input, '$L', $L, '$L', $L)",
memberName,
labelValueProvider,
uriLabel.toString(),
uriLabel.isGreedyLabel() ? "true" : "false"
);
}
}
}
Expand All @@ -780,15 +779,16 @@ private boolean writeRequestQueryString(
writer.openBlock("const query: any = map({", "});", () -> {
if (!queryLiterals.isEmpty()) {
// Write any query literals present in the uri.
queryLiterals.forEach((k, v) -> writer.write("$S: $S,", k, v));
queryLiterals.forEach((k, v) -> writer.write("$S: [, $S],", k, v));
}
// Handle any additional query params bindings.
// If query string parameter is also present in httpQuery, it would be overwritten.
// Serializing HTTP messages https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#serializing-http-messages
if (!queryParamsBindings.isEmpty()) {
SymbolProvider symbolProvider = context.getSymbolProvider();
String memberName = symbolProvider.toMemberName(queryParamsBindings.get(0).getMember());
writer.write("...(input.$1L !== undefined && input.$1L),", memberName);
writer.addImport("convertMap", "convertMap", "@aws-sdk/smithy-client");
writer.write("...convertMap(input.$L),", memberName);
}
// Handle any additional query bindings.
if (!queryBindings.isEmpty()) {
Expand Down Expand Up @@ -856,27 +856,48 @@ private void writeRequestHeaders(
) {
TypeScriptWriter writer = context.getWriter();

List<HttpBinding> headers = bindingIndex.getRequestBindings(operation, Location.HEADER);
List<HttpBinding> prefixHeaders = bindingIndex.getRequestBindings(operation, Location.PREFIX_HEADERS);
boolean inputPresent = operation.getInput().isPresent();

int normalHeaderCount = headers.size();
int prefixHeaderCount = prefixHeaders.size();

String opening;
String closing;
if (normalHeaderCount + prefixHeaderCount == 0) {
opening = "const headers: any = {";
closing = "};";
} else {
opening = normalHeaderCount > 0
? "const headers: any = map({}, isSerializableHeaderValue, {"
: "const headers: any = map({";
closing = "});";
}

// Headers are always present either from the default document or the payload.
writer.openBlock("const headers: any = map({}, isSerializableHeaderValue, {", "});", () -> {
// Only set the content type if one can be determined.
writeContentTypeHeader(context, operation, true);
writeDefaultInputHeaders(context, operation);

operation.getInput().ifPresent(outputId -> {
for (HttpBinding binding : bindingIndex.getRequestBindings(operation, Location.HEADER)) {
writeNormalHeader(context, binding);
}
});
writer.write(opening);
writer.indent();
// Only set the content type if one can be determined.
writeContentTypeHeader(context, operation, true);
writeDefaultInputHeaders(context, operation);

flushHeadersBuffer(writer);
if (inputPresent) {
for (HttpBinding binding : headers) {
writeNormalHeader(context, binding);
}
}

operation.getInput().ifPresent(outputId -> {
// Handle assembling prefix headers.
for (HttpBinding binding : bindingIndex.getRequestBindings(operation, Location.PREFIX_HEADERS)) {
writePrefixHeaders(context, binding);
}
});
});
flushHeadersBuffer(writer);

if (inputPresent) {
// Handle assembling prefix headers.
for (HttpBinding binding : prefixHeaders) {
writePrefixHeaders(context, binding);
}
}
writer.dedent();
writer.write(closing);
}

private void writeNormalHeader(GenerationContext context, HttpBinding binding) {
Expand Down Expand Up @@ -1650,11 +1671,11 @@ private void generateOperationRequestDeserializer(
.forEach((memberName, memberShape) -> writer.write(
"$L: undefined,", memberName));
});
readRequestHeaders(context, operation, bindingIndex, "output");
});
readQueryString(context, operation, bindingIndex);
readPath(context, operation, bindingIndex, trait);
readHost(context, operation);
readRequestHeaders(context, operation, bindingIndex, "output");
List<HttpBinding> documentBindings = readRequestBody(context, operation, bindingIndex);
// Track all shapes bound to the document so their deserializers may be generated.
documentBindings.forEach(binding -> {
Expand Down Expand Up @@ -1703,7 +1724,7 @@ private void handleContentType(
operation, getDocumentContentType());
writer.write("const contentTypeHeaderKey: string | undefined = Object.keys(output.headers)"
+ ".find(key => key.toLowerCase() === 'content-type');");
writer.openBlock("if (contentTypeHeaderKey !== undefined && contentTypeHeaderKey !== null) {", "};", () -> {
writer.openBlock("if (contentTypeHeaderKey != null) {", "};", () -> {
writer.write("const contentType = output.headers[contentTypeHeaderKey];");
if (optionalContentType.isPresent() || operation.getInput().isPresent()) {
String contentType = optionalContentType.orElse(getDocumentContentType());
Expand Down Expand Up @@ -1758,7 +1779,7 @@ private void handleAccept(
writer.addImport("acceptMatches", "__acceptMatches", "@aws-smithy/server-common");
writer.write("const acceptHeaderKey: string | undefined = Object.keys(output.headers)"
+ ".find(key => key.toLowerCase() === 'accept');");
writer.openBlock("if (acceptHeaderKey !== undefined && acceptHeaderKey !== null) {", "};", () -> {
writer.openBlock("if (acceptHeaderKey != null) {", "};", () -> {
writer.write("const accept = output.headers[acceptHeaderKey];");
String contentType = optionalContentType.orElse(getDocumentContentType());
// Validate that the content type matches the protocol default, or what's modeled if there's
Expand Down Expand Up @@ -1793,7 +1814,7 @@ private void readQueryString(
return;
}
writer.write("const query = output.query");
writer.openBlock("if (query !== undefined && query !== null) {", "}", () -> {
writer.openBlock("if (query != null) {", "}", () -> {
readDirectQueryBindings(context, directQueryBindings);
if (!mappedQueryBindings.isEmpty()) {
// There can only ever be one of these bindings on a given operation.
Expand Down Expand Up @@ -2041,8 +2062,10 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape
+ " context: __SerdeContext\n"
+ "): Promise<$T> => {", "};",
errorDeserMethodName, outputName, errorSymbol, () -> {
writer.write("const contents: any = {};");
readResponseHeaders(context, error, bindingIndex, outputName);
writer.openBlock("const contents: any = map({", "});", () -> {
readResponseHeaders(context, error, bindingIndex, outputName);
});

List<HttpBinding> documentBindings = readErrorResponseBody(context, error, bindingIndex);
// Track all shapes bound to the document so their deserializers may be generated.
documentBindings.forEach(binding -> {
Expand Down Expand Up @@ -2159,26 +2182,27 @@ private void readPrefixHeaders(
TypeScriptWriter writer = context.getWriter();

// Run through the headers one time, matching any prefix groups.
writer.openBlock("...(Object.keys($L.headers).reduce((acc, header) => {", "}, {}));", outputName, () -> {
for (HttpBinding binding : prefixHeaderBindings) {
// Prepare a grab bag for these headers if necessary
String memberName = symbolProvider.toMemberName(binding.getMember());
writer.write("acc.$L = [, {}];", memberName);

// Generate a single block for each group of lower-cased prefix headers.
for (HttpBinding binding : prefixHeaderBindings) {
// Prepare a grab bag for these headers if necessary
String memberName = symbolProvider.toMemberName(binding.getMember());
writer.openBlock("$L: [, ", "],", memberName, () -> {
String headerLocation = binding.getLocationName().toLowerCase(Locale.US);
writer.openBlock("if (header.startsWith($S)) {", "}", headerLocation, () -> {
writer.write(
"Object.keys($L.headers).filter(header => header.startsWith('$L'))",
outputName,
headerLocation
);
writer.indent().openBlock(".reduce((acc, header) => {", "}, {} as any)", () -> {
MapShape prefixMap = model.expectShape(binding.getMember().getTarget()).asMapShape().get();
Shape target = model.expectShape(prefixMap.getValue().getTarget());
String headerValue = getOutputValue(context, binding.getLocation(),
outputName + ".headers[header]", binding.getMember(), target);

// Extract the non-prefix portion as the key.
writer.write("acc.$L[header.substring($L)] = [, $L];",
memberName, headerLocation.length(), headerValue);
outputName + ".headers[header]", binding.getMember(), target);
writer.write("acc[header.substring($L)] = $L;",
headerLocation.length(), headerValue);
writer.write("return acc;");
});
}
});
});
}
}

private List<HttpBinding> readRequestBody(
Expand Down Expand Up @@ -2253,12 +2277,17 @@ private List<HttpBinding> readBody(
// isn't set in the body.
// These are only relevant when a payload is not present, as it cannot
// coexist with a payload.
for (HttpBinding responseCodeBinding : responseCodeBindings) {
// The name of the member to get from the input shape.
String memberName = symbolProvider.toMemberName(responseCodeBinding.getMember());
writer.openBlock("if (contents.$L === undefined) {", "}", memberName, () ->
writer.write("contents.$L = output.statusCode;", memberName));
if (!responseCodeBindings.isEmpty()) {
writer.openBlock("map(contents, {", "});", () -> {
for (HttpBinding responseCodeBinding : responseCodeBindings) {
// The name of the member to get from the input shape.
String memberName = symbolProvider.toMemberName(responseCodeBinding.getMember());

writer.write("$L: [, output.statusCode]", memberName);
}
});
}

if (!documentBindings.isEmpty()) {
return documentBindings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,28 +348,13 @@ static Set<StructureShape> generateErrorDispatcher(

// Error responses must be at least BaseException interface
SymbolReference baseExceptionReference = getClientBaseException(context);
writer.write("let response: $T;", baseExceptionReference);
errorCodeGenerator.accept(context);
writer.openBlock("switch (errorCode) {", "}", () -> {
// Generate the case statement for each error, invoking the specific deserializer.
new TreeSet<>(operationIndex.getErrors(operation, context.getService())).forEach(error -> {
final ShapeId errorId = error.getId();
// Track errors bound to the operation so their deserializers may be generated.
errorShapes.add(error);
Symbol errorSymbol = symbolProvider.toSymbol(error);
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol,
context.getProtocolName()) + "Response";
// Dispatch to the error deserialization function.
String outputParam = shouldParseErrorBody ? "parsedOutput" : "output";
writer.write("case $S:", errorId.getName());
writer.write("case $S:", errorId.toString());
writer.indent()
.write("throw await $L($L, context);", errorDeserMethodName, outputParam)
.dedent();
});

// Build a generic error the best we can for ones we don't know about.
writer.write("default:").indent();
TreeSet<StructureShape> structureShapes = new TreeSet<>(
operationIndex.getErrors(operation, context.getService())
);

Runnable defaultErrorHandler = () -> {
if (shouldParseErrorBody) {
// Body is already parsed above
writer.write("const parsedBody = parsedOutput.body;");
Expand All @@ -378,21 +363,49 @@ static Set<StructureShape> generateErrorDispatcher(
writer.write("const parsedBody = await parseBody(output.body, context);");
}

writer.addImport("throwDefaultError", "throwDefaultError", "@aws-sdk/smithy-client");

// Get the protocol specific error location for retrieving contents.
String errorLocation = bodyErrorLocationModifier.apply(context, "parsedBody");
writer.write("const $$metadata = deserializeMetadata(output);");
writer.write("const statusCode = $$metadata.httpStatusCode ? $$metadata.httpStatusCode"
+ " + '' : undefined;");
writer.openBlock("response = new $T({", "});", baseExceptionReference, () -> {
writer.write("name: $1L.code || $1L.Code || errorCode || statusCode || 'UnknowError',",
errorLocation);
writer.write("$$fault: \"client\",");
writer.write("$$metadata");
writer.openBlock("throwDefaultError({", "})", () -> {
writer.write("output,");
if (errorLocation.equals("parsedBody")) {
writer.write("parsedBody,");
} else {
writer.write("parsedBody: $L,", errorLocation);
}
writer.write("exceptionCtor: $T,", baseExceptionReference);
writer.write("errorCode");
});
writer.addImport("decorateServiceException", "__decorateServiceException",
TypeScriptDependency.AWS_SMITHY_CLIENT.packageName);
writer.write("throw __decorateServiceException(response, $L);", errorLocation);
});
};

if (!structureShapes.isEmpty()) {
writer.openBlock("switch (errorCode) {", "}", () -> {
// Generate the case statement for each error, invoking the specific deserializer.

structureShapes.forEach(error -> {
final ShapeId errorId = error.getId();
// Track errors bound to the operation so their deserializers may be generated.
errorShapes.add(error);
Symbol errorSymbol = symbolProvider.toSymbol(error);
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol,
context.getProtocolName()) + "Response";
// Dispatch to the error deserialization function.
String outputParam = shouldParseErrorBody ? "parsedOutput" : "output";
writer.write("case $S:", errorId.getName());
writer.write("case $S:", errorId.toString());
writer.indent()
.write("throw await $L($L, context);", errorDeserMethodName, outputParam)
.dedent();
});

// Build a generic error the best we can for ones we don't know about.
writer.write("default:").indent();
defaultErrorHandler.run();
});
} else {
defaultErrorHandler.run();
}
});
writer.write("");

Expand Down

0 comments on commit 8e0bb70

Please sign in to comment.