From b2fffd5ff0a68985e778d217b1b54d2fa9c927d7 Mon Sep 17 00:00:00 2001 From: George Fu Date: Tue, 2 Aug 2022 11:24:57 -0400 Subject: [PATCH] feat(codegen): general data mapping function (#576) * feat(codegen): use object mapper to assign fields * feat(codegen): xml factory --- .../HttpBindingProtocolGenerator.java | 287 ++++++++++++------ .../HttpProtocolGeneratorUtils.java | 93 +++--- 2 files changed, 248 insertions(+), 132 deletions(-) diff --git a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpBindingProtocolGenerator.java b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpBindingProtocolGenerator.java index da2925ace64..8cd9ff63866 100644 --- a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpBindingProtocolGenerator.java +++ b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpBindingProtocolGenerator.java @@ -19,9 +19,11 @@ import java.util.Collection; import java.util.Collections; import java.util.Comparator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.TreeMap; @@ -93,6 +95,8 @@ public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator private final boolean isErrorCodeInBody; private final EventStreamGenerator eventStreamGenerator = new EventStreamGenerator(); + private final LinkedHashMap headerBuffer = new LinkedHashMap<>(); + /** * Creates a Http binding protocol generator. * @@ -146,6 +150,10 @@ public final ApplicationProtocol getApplicationProtocol() { @Override public void generateSharedComponents(GenerationContext context) { + TypeScriptWriter writer = context.getWriter(); + writer.addImport("map", "__map", "@aws-sdk/smithy-client"); + writer.write("const map = __map"); + deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error)); serializingErrorShapes.forEach(error -> generateErrorSerializer(context, error)); ServiceShape service = context.getService(); @@ -727,31 +735,30 @@ private void writeResolvedPath( // Handle any label bindings. if (!labelBindings.isEmpty()) { + writer.addImport("resolvedPath", "__resolvedPath", "@aws-sdk/smithy-client"); + Model model = context.getModel(); List uriLabels = trait.getUri().getLabels(); for (HttpBinding binding : labelBindings) { String memberName = symbolProvider.toMemberName(binding.getMember()); Shape target = model.expectShape(binding.getMember().getTarget()); - String labelValue = getInputValue(context, binding.getLocation(), "input." + memberName, - binding.getMember(), target); + + String labelValueProvider = "() => " + getInputValue( + context, + binding.getLocation(), + "input." + memberName + "!", + binding.getMember(), + target + ); + // Get the correct label to use. Segment uriLabel = uriLabels.stream().filter(s -> s.getContent().equals(memberName)).findFirst().get(); - writer.addImport("extendedEncodeURIComponent", "__extendedEncodeURIComponent", - "@aws-sdk/smithy-client"); - String encodedSegment = uriLabel.isGreedyLabel() - ? "labelValue.split(\"/\").map(segment => __extendedEncodeURIComponent(segment)).join(\"/\")" - : "__extendedEncodeURIComponent(labelValue)"; - - // Set the label's value and throw a clear error if empty or undefined. - writer.write("if (input.$L !== undefined) {", memberName).indent() - .write("const labelValue: string = $L;", labelValue) - .openBlock("if (labelValue.length <= 0) {", "}", () -> { - writer.write("throw new Error('Empty value provided for input HTTP label: $L.');", memberName); - }) - .write("resolvedPath = resolvedPath.replace($S, $L);", uriLabel.toString(), encodedSegment).dedent() - .write("} else {").indent() - .write("throw new Error('No value provided for input HTTP label: $L.');", memberName).dedent() - .write("}"); + writer.write("resolvedPath = __resolvedPath(resolvedPath, input, '$L', $L, '$L', $L)", + memberName, + labelValueProvider, + uriLabel.toString(), + uriLabel.isGreedyLabel() ? "true" : "false" + ); } } } @@ -769,10 +776,10 @@ private boolean writeRequestQueryString( // Build the initial query bag. Map queryLiterals = trait.getUri().getQueryLiterals(); if (!queryLiterals.isEmpty() || !queryBindings.isEmpty() || !queryParamsBindings.isEmpty()) { - writer.openBlock("const query: any = {", "};", () -> { + writer.openBlock("const query: any = map({", "});", () -> { if (!queryLiterals.isEmpty()) { // Write any query literals present in the uri. - queryLiterals.forEach((k, v) -> writer.write("$S: $S,", k, v)); + queryLiterals.forEach((k, v) -> writer.write("$S: [, $S],", k, v)); } // Handle any additional query params bindings. // If query string parameter is also present in httpQuery, it would be overwritten. @@ -780,7 +787,8 @@ private boolean writeRequestQueryString( if (!queryParamsBindings.isEmpty()) { SymbolProvider symbolProvider = context.getSymbolProvider(); String memberName = symbolProvider.toMemberName(queryParamsBindings.get(0).getMember()); - writer.write("...(input.$1L !== undefined && input.$1L),", memberName); + writer.addImport("convertMap", "convertMap", "@aws-sdk/smithy-client"); + writer.write("...convertMap(input.$L),", memberName); } // Handle any additional query bindings. if (!queryBindings.isEmpty()) { @@ -808,10 +816,37 @@ private void writeRequestQueryParam( "@aws-sdk/smithy-client"); Shape target = model.expectShape(binding.getMember().getTarget()); - String queryValue = getInputValue(context, binding.getLocation(), "input." + memberName, - binding.getMember(), target); - writer.write("...(input.$L !== undefined && { $S: $L }),", memberName, - binding.getLocationName(), queryValue); + String queryValue = getInputValue( + context, + binding.getLocation(), + "input." + memberName + "!", + binding.getMember(), + target + ); + + if (Objects.equals("input." + memberName + "!", queryValue)) { + // simple undefined check + writer.write( + "$S: [,$L],", + binding.getLocationName(), + queryValue + ); + } else { + // undefined check with lazy eval + writer.write( + "$S: [() => input.$L !== void 0, () => $L],", + binding.getLocationName(), + memberName, + queryValue + ); + } + } + + private void flushHeadersBuffer(TypeScriptWriter writer) { + for (Map.Entry entry : headerBuffer.entrySet()) { + writer.write(entry.getValue()); + } + headerBuffer.clear(); } private void writeRequestHeaders( @@ -821,32 +856,88 @@ private void writeRequestHeaders( ) { TypeScriptWriter writer = context.getWriter(); + List headers = bindingIndex.getRequestBindings(operation, Location.HEADER); + List prefixHeaders = bindingIndex.getRequestBindings(operation, Location.PREFIX_HEADERS); + boolean inputPresent = operation.getInput().isPresent(); + + int normalHeaderCount = headers.size(); + int prefixHeaderCount = prefixHeaders.size(); + + String opening; + String closing; + if (normalHeaderCount + prefixHeaderCount == 0) { + opening = "const headers: any = {"; + closing = "};"; + } else { + opening = normalHeaderCount > 0 + ? "const headers: any = map({}, isSerializableHeaderValue, {" + : "const headers: any = map({"; + closing = "});"; + } + // Headers are always present either from the default document or the payload. - writer.openBlock("const headers: any = {", "};", () -> { - // Only set the content type if one can be determined. - writeContentTypeHeader(context, operation, true); - writeDefaultInputHeaders(context, operation); - - operation.getInput().ifPresent(outputId -> { - for (HttpBinding binding : bindingIndex.getRequestBindings(operation, Location.HEADER)) { - writeNormalHeader(context, binding); - } + writer.write(opening); + writer.indent(); + // Only set the content type if one can be determined. + writeContentTypeHeader(context, operation, true); + writeDefaultInputHeaders(context, operation); - // Handle assembling prefix headers. - for (HttpBinding binding : bindingIndex.getRequestBindings(operation, Location.PREFIX_HEADERS)) { - writePrefixHeaders(context, binding); - } - }); - }); + if (inputPresent) { + for (HttpBinding binding : headers) { + writeNormalHeader(context, binding); + } + } + + flushHeadersBuffer(writer); + + if (inputPresent) { + // Handle assembling prefix headers. + for (HttpBinding binding : prefixHeaders) { + writePrefixHeaders(context, binding); + } + } + writer.dedent(); + writer.write(closing); } private void writeNormalHeader(GenerationContext context, HttpBinding binding) { String memberLocation = "input." + context.getSymbolProvider().toMemberName(binding.getMember()); Shape target = context.getModel().expectShape(binding.getMember().getTarget()); - String headerValue = getInputValue(context, binding.getLocation(), memberLocation + "!", - binding.getMember(), target); - context.getWriter().write("...isSerializableHeaderValue($L) && { $S: $L },", - memberLocation, binding.getLocationName().toLowerCase(Locale.US), headerValue); + + String headerKey = binding.getLocationName().toLowerCase(Locale.US); + String headerValue = getInputValue( + context, + binding.getLocation(), + memberLocation + "!", + binding.getMember(), + target + ); + + if (!Objects.equals(memberLocation + "!", headerValue)) { + String defaultValue = ""; + if (headerBuffer.containsKey(headerKey)) { + String s = headerBuffer.get(headerKey); + defaultValue = " || " + s.substring(s.indexOf(": ") + 2, s.length() - 1); + } + // evaluated value has a function or method call attached + headerBuffer.put(headerKey, String.format( + "'%s': [() => isSerializableHeaderValue(%s), () => %s],", + headerKey, + memberLocation + defaultValue, + headerValue + defaultValue + )); + } else { + String value = headerValue; + if (headerBuffer.containsKey(headerKey)) { + String s = headerBuffer.get(headerKey); + value = headerValue + " || " + s.substring(s.indexOf(": ") + 2, s.length() - 1); + } + headerBuffer.put(headerKey, String.format( + "'%s': %s,", + headerKey, + value + )); + } } private void writePrefixHeaders(GenerationContext context, HttpBinding binding) { @@ -881,7 +972,7 @@ private void writeResponseHeaders( TypeScriptWriter writer = context.getWriter(); // Headers are always present either from the default document or the payload. - writer.openBlock("let headers: any = {", "};", () -> { + writer.openBlock("let headers: any = map({}, isSerializableHeaderValue, {", "});", () -> { writeContentTypeHeader(context, operationOrError, false); injectExtraHeaders.run(); @@ -889,6 +980,8 @@ private void writeResponseHeaders( writeNormalHeader(context, binding); } + flushHeadersBuffer(writer); + // Handle assembling prefix headers. for (HttpBinding binding : bindingIndex.getResponseBindings(operationOrError, Location.PREFIX_HEADERS)) { writePrefixHeaders(context, binding); @@ -908,7 +1001,12 @@ private void writeContentTypeHeader(GenerationContext context, Shape operationOr if (!optionalContentType.isPresent() && shouldWriteDefaultBody(context, operationOrError, isInput)) { optionalContentType = Optional.of(getDocumentContentType()); } - optionalContentType.ifPresent(contentType -> context.getWriter().write("'content-type': $S,", contentType)); + optionalContentType.ifPresent(contentType -> { + // context.getWriter().write("'content-type': $S,", contentType) + headerBuffer.put("content-type", + "'content-type': '" + contentType + "'," + ); + }); } private List writeRequestBody( @@ -1573,18 +1671,18 @@ private void generateOperationRequestDeserializer( .forEach((memberName, memberShape) -> writer.write( "$L: undefined,", memberName)); }); + readRequestHeaders(context, operation, bindingIndex, "output"); }); readQueryString(context, operation, bindingIndex); readPath(context, operation, bindingIndex, trait); readHost(context, operation); - readRequestHeaders(context, operation, bindingIndex, "output"); List documentBindings = readRequestBody(context, operation, bindingIndex); // Track all shapes bound to the document so their deserializers may be generated. documentBindings.forEach(binding -> { Shape target = model.expectShape(binding.getMember().getTarget()); deserializingDocumentShapes.add(target); }); - writer.write("return Promise.resolve(contents);"); + writer.write("return contents;"); }); writer.write(""); } @@ -1626,7 +1724,7 @@ private void handleContentType( operation, getDocumentContentType()); writer.write("const contentTypeHeaderKey: string | undefined = Object.keys(output.headers)" + ".find(key => key.toLowerCase() === 'content-type');"); - writer.openBlock("if (contentTypeHeaderKey !== undefined && contentTypeHeaderKey !== null) {", "};", () -> { + writer.openBlock("if (contentTypeHeaderKey != null) {", "};", () -> { writer.write("const contentType = output.headers[contentTypeHeaderKey];"); if (optionalContentType.isPresent() || operation.getInput().isPresent()) { String contentType = optionalContentType.orElse(getDocumentContentType()); @@ -1681,7 +1779,7 @@ private void handleAccept( writer.addImport("acceptMatches", "__acceptMatches", "@aws-smithy/server-common"); writer.write("const acceptHeaderKey: string | undefined = Object.keys(output.headers)" + ".find(key => key.toLowerCase() === 'accept');"); - writer.openBlock("if (acceptHeaderKey !== undefined && acceptHeaderKey !== null) {", "};", () -> { + writer.openBlock("if (acceptHeaderKey != null) {", "};", () -> { writer.write("const accept = output.headers[acceptHeaderKey];"); String contentType = optionalContentType.orElse(getDocumentContentType()); // Validate that the content type matches the protocol default, or what's modeled if there's @@ -1716,7 +1814,7 @@ private void readQueryString( return; } writer.write("const query = output.query"); - writer.openBlock("if (query !== undefined && query !== null) {", "}", () -> { + writer.openBlock("if (query != null) {", "}", () -> { readDirectQueryBindings(context, directQueryBindings); if (!mappedQueryBindings.isEmpty()) { // There can only ever be one of these bindings on a given operation. @@ -1923,20 +2021,14 @@ private void generateOperationResponseDeserializer( () -> writer.write("return $L(output, context);", errorMethodName)); // Start deserializing the response. - writer.openBlock("const contents: $T = {", "};", outputType, () -> { + writer.openBlock("const contents: any = map({", "});", () -> { writer.write("$$metadata: deserializeMetadata(output),"); - // Only set a type and the members if we have output. - operation.getOutput().ifPresent(outputId -> { - // Set all the members to undefined to meet type constraints. - StructureShape target = model.expectShape(outputId).asStructureShape().get(); - new TreeMap<>(target.getAllMembers()) - .forEach((memberName, memberShape) -> writer.write( - "$L: undefined,", memberName)); - }); + readResponseHeaders(context, operation, bindingIndex, "output"); }); - readResponseHeaders(context, operation, bindingIndex, "output"); + List documentBindings = readResponseBody(context, operation, bindingIndex); + // Track all shapes bound to the document so their deserializers may be generated. documentBindings.forEach(binding -> { Shape target = model.expectShape(binding.getMember().getTarget()); @@ -1944,7 +2036,8 @@ private void generateOperationResponseDeserializer( deserializingDocumentShapes.add(target); } }); - writer.write("return Promise.resolve(contents);"); + + writer.write("return contents;"); }); writer.write(""); // Write out the error deserialization dispatcher. @@ -1969,8 +2062,10 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape + " context: __SerdeContext\n" + "): Promise<$T> => {", "};", errorDeserMethodName, outputName, errorSymbol, () -> { - writer.write("const contents: any = {};"); - readResponseHeaders(context, error, bindingIndex, outputName); + writer.openBlock("const contents: any = map({", "});", () -> { + readResponseHeaders(context, error, bindingIndex, outputName); + }); + List documentBindings = readErrorResponseBody(context, error, bindingIndex); // Track all shapes bound to the document so their deserializers may be generated. documentBindings.forEach(binding -> { @@ -2053,12 +2148,16 @@ private void readNormalHeaders( TypeScriptWriter writer = context.getWriter(); String memberName = context.getSymbolProvider().toMemberName(binding.getMember()); String headerName = binding.getLocationName().toLowerCase(Locale.US); - writer.openBlock("if ($L.headers[$S] !== undefined) {", "}", outputName, headerName, () -> { - Shape target = context.getModel().expectShape(binding.getMember().getTarget()); - String headerValue = getOutputValue(context, binding.getLocation(), - outputName + ".headers['" + headerName + "']", binding.getMember(), target); - writer.write("contents.$L = $L;", memberName, headerValue); - }); + Shape target = context.getModel().expectShape(binding.getMember().getTarget()); + String headerValue = getOutputValue(context, binding.getLocation(), + outputName + ".headers['" + headerName + "']", binding.getMember(), target); + String checkedValue = outputName + ".headers['" + headerName + "']"; + + if (checkedValue.equals(headerValue)) { + writer.write("$L: [, $L],", memberName, headerValue); + } else { + writer.write("$L: [() => void 0 !== $L, () => $L],", memberName, checkedValue, headerValue); + } } } @@ -2083,28 +2182,27 @@ private void readPrefixHeaders( TypeScriptWriter writer = context.getWriter(); // Run through the headers one time, matching any prefix groups. - writer.openBlock("Object.keys($L.headers).forEach(header => {", "});", outputName, () -> { - for (HttpBinding binding : prefixHeaderBindings) { - // Prepare a grab bag for these headers if necessary - String memberName = symbolProvider.toMemberName(binding.getMember()); - writer.openBlock("if (contents.$L === undefined) {", "}", memberName, () -> { - writer.write("contents.$L = {};", memberName); - }); - - // Generate a single block for each group of lower-cased prefix headers. + for (HttpBinding binding : prefixHeaderBindings) { + // Prepare a grab bag for these headers if necessary + String memberName = symbolProvider.toMemberName(binding.getMember()); + writer.openBlock("$L: [, ", "],", memberName, () -> { String headerLocation = binding.getLocationName().toLowerCase(Locale.US); - writer.openBlock("if (header.startsWith($S)) {", "}", headerLocation, () -> { + writer.write( + "Object.keys($L.headers).filter(header => header.startsWith('$L'))", + outputName, + headerLocation + ); + writer.indent().openBlock(".reduce((acc, header) => {", "}, {} as any)", () -> { MapShape prefixMap = model.expectShape(binding.getMember().getTarget()).asMapShape().get(); Shape target = model.expectShape(prefixMap.getValue().getTarget()); String headerValue = getOutputValue(context, binding.getLocation(), - outputName + ".headers[header]", binding.getMember(), target); - - // Extract the non-prefix portion as the key. - writer.write("contents.$L[header.substring($L)] = $L;", - memberName, headerLocation.length(), headerValue); + outputName + ".headers[header]", binding.getMember(), target); + writer.write("acc[header.substring($L)] = $L;", + headerLocation.length(), headerValue); + writer.write("return acc;"); }); - } - }); + }); + } } private List readRequestBody( @@ -2179,12 +2277,17 @@ private List readBody( // isn't set in the body. // These are only relevant when a payload is not present, as it cannot // coexist with a payload. - for (HttpBinding responseCodeBinding : responseCodeBindings) { - // The name of the member to get from the input shape. - String memberName = symbolProvider.toMemberName(responseCodeBinding.getMember()); - writer.openBlock("if (contents.$L === undefined) {", "}", memberName, () -> - writer.write("contents.$L = output.statusCode;", memberName)); + if (!responseCodeBindings.isEmpty()) { + writer.openBlock("map(contents, {", "});", () -> { + for (HttpBinding responseCodeBinding : responseCodeBindings) { + // The name of the member to get from the input shape. + String memberName = symbolProvider.toMemberName(responseCodeBinding.getMember()); + + writer.write("$L: [, output.statusCode]", memberName); + } + }); } + if (!documentBindings.isEmpty()) { return documentBindings; } diff --git a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpProtocolGeneratorUtils.java b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpProtocolGeneratorUtils.java index 913a4f1cc37..c989388ea04 100644 --- a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpProtocolGeneratorUtils.java +++ b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpProtocolGeneratorUtils.java @@ -348,51 +348,64 @@ static Set generateErrorDispatcher( // Error responses must be at least BaseException interface SymbolReference baseExceptionReference = getClientBaseException(context); - writer.write("let response: $T;", baseExceptionReference); errorCodeGenerator.accept(context); - writer.openBlock("switch (errorCode) {", "}", () -> { - // Generate the case statement for each error, invoking the specific deserializer. - new TreeSet<>(operationIndex.getErrors(operation, context.getService())).forEach(error -> { - final ShapeId errorId = error.getId(); - // Track errors bound to the operation so their deserializers may be generated. - errorShapes.add(error); - Symbol errorSymbol = symbolProvider.toSymbol(error); - String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol, + + TreeSet structureShapes = new TreeSet<>( + operationIndex.getErrors(operation, context.getService()) + ); + + Runnable defaultErrorHandler = () -> { + if (shouldParseErrorBody) { + // Body is already parsed above + writer.write("const parsedBody = parsedOutput.body;"); + } else { + // Body is not parsed above, so parse it here + writer.write("const parsedBody = await parseBody(output.body, context);"); + } + + writer.addImport("throwDefaultError", "throwDefaultError", "@aws-sdk/smithy-client"); + + // Get the protocol specific error location for retrieving contents. + String errorLocation = bodyErrorLocationModifier.apply(context, "parsedBody"); + writer.openBlock("throwDefaultError({", "})", () -> { + writer.write("output,"); + if (errorLocation.equals("parsedBody")) { + writer.write("parsedBody,"); + } else { + writer.write("parsedBody: $L,", errorLocation); + } + writer.write("exceptionCtor: $T,", baseExceptionReference); + writer.write("errorCode"); + }); + }; + + if (!structureShapes.isEmpty()) { + writer.openBlock("switch (errorCode) {", "}", () -> { + // Generate the case statement for each error, invoking the specific deserializer. + + structureShapes.forEach(error -> { + final ShapeId errorId = error.getId(); + // Track errors bound to the operation so their deserializers may be generated. + errorShapes.add(error); + Symbol errorSymbol = symbolProvider.toSymbol(error); + String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol, context.getProtocolName()) + "Response"; - // Dispatch to the error deserialization function. - String outputParam = shouldParseErrorBody ? "parsedOutput" : "output"; - writer.write("case $S:", errorId.getName()); - writer.write("case $S:", errorId.toString()); - writer.indent() + // Dispatch to the error deserialization function. + String outputParam = shouldParseErrorBody ? "parsedOutput" : "output"; + writer.write("case $S:", errorId.getName()); + writer.write("case $S:", errorId.toString()); + writer.indent() .write("throw await $L($L, context);", errorDeserMethodName, outputParam) .dedent(); - }); + }); - // Build a generic error the best we can for ones we don't know about. - writer.write("default:").indent(); - if (shouldParseErrorBody) { - // Body is already parsed above - writer.write("const parsedBody = parsedOutput.body;"); - } else { - // Body is not parsed above, so parse it here - writer.write("const parsedBody = await parseBody(output.body, context);"); - } - - // Get the protocol specific error location for retrieving contents. - String errorLocation = bodyErrorLocationModifier.apply(context, "parsedBody"); - writer.write("const $$metadata = deserializeMetadata(output);"); - writer.write("const statusCode = $$metadata.httpStatusCode ? $$metadata.httpStatusCode" - + " + '' : undefined;"); - writer.openBlock("response = new $T({", "});", baseExceptionReference, () -> { - writer.write("name: $1L.code || $1L.Code || errorCode || statusCode || 'UnknowError',", - errorLocation); - writer.write("$$fault: \"client\","); - writer.write("$$metadata"); - }); - writer.addImport("decorateServiceException", "__decorateServiceException", - TypeScriptDependency.AWS_SMITHY_CLIENT.packageName); - writer.write("throw __decorateServiceException(response, $L);", errorLocation); - }); + // Build a generic error the best we can for ones we don't know about. + writer.write("default:").indent(); + defaultErrorHandler.run(); + }); + } else { + defaultErrorHandler.run(); + } }); writer.write("");