Skip to content

Commit

Permalink
feat(codegen): use object mapper to assign fields
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhe committed Jul 27, 2022
1 parent d3d2984 commit 8341b0b
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
Expand Down Expand Up @@ -93,6 +95,8 @@ public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator
private final boolean isErrorCodeInBody;
private final EventStreamGenerator eventStreamGenerator = new EventStreamGenerator();

private final LinkedHashMap<String, String> headerBuffer = new LinkedHashMap<>();

/**
* Creates a Http binding protocol generator.
*
Expand Down Expand Up @@ -146,6 +150,10 @@ public final ApplicationProtocol getApplicationProtocol() {

@Override
public void generateSharedComponents(GenerationContext context) {
TypeScriptWriter writer = context.getWriter();
writer.addImport("map", "__map", "@aws-sdk/smithy-client");
writer.write("const map = __map");

deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error));
serializingErrorShapes.forEach(error -> generateErrorSerializer(context, error));
ServiceShape service = context.getService();
Expand Down Expand Up @@ -769,7 +777,7 @@ private boolean writeRequestQueryString(
// Build the initial query bag.
Map<String, String> queryLiterals = trait.getUri().getQueryLiterals();
if (!queryLiterals.isEmpty() || !queryBindings.isEmpty() || !queryParamsBindings.isEmpty()) {
writer.openBlock("const query: any = {", "};", () -> {
writer.openBlock("const query: any = map({", "});", () -> {
if (!queryLiterals.isEmpty()) {
// Write any query literals present in the uri.
queryLiterals.forEach((k, v) -> writer.write("$S: $S,", k, v));
Expand Down Expand Up @@ -808,10 +816,37 @@ private void writeRequestQueryParam(
"@aws-sdk/smithy-client");

Shape target = model.expectShape(binding.getMember().getTarget());
String queryValue = getInputValue(context, binding.getLocation(), "input." + memberName,
binding.getMember(), target);
writer.write("...(input.$L !== undefined && { $S: $L }),", memberName,
binding.getLocationName(), queryValue);
String queryValue = getInputValue(
context,
binding.getLocation(),
"input." + memberName + "!",
binding.getMember(),
target
);

if (Objects.equals("input." + memberName + "!", queryValue)) {
// simple undefined check
writer.write(
"$S: [,$L],",
binding.getLocationName(),
queryValue
);
} else {
// undefined check with lazy eval
writer.write(
"$S: [() => input.$L !== void 0, () => $L],",
binding.getLocationName(),
memberName,
queryValue
);
}
}

private void flushHeadersBuffer(TypeScriptWriter writer) {
for (Map.Entry<String, String> entry : headerBuffer.entrySet()) {
writer.write(entry.getValue());
}
headerBuffer.clear();
}

private void writeRequestHeaders(
Expand All @@ -822,7 +857,7 @@ private void writeRequestHeaders(
TypeScriptWriter writer = context.getWriter();

// Headers are always present either from the default document or the payload.
writer.openBlock("const headers: any = {", "};", () -> {
writer.openBlock("const headers: any = map({}, isSerializableHeaderValue, {", "});", () -> {
// Only set the content type if one can be determined.
writeContentTypeHeader(context, operation, true);
writeDefaultInputHeaders(context, operation);
Expand All @@ -831,7 +866,11 @@ private void writeRequestHeaders(
for (HttpBinding binding : bindingIndex.getRequestBindings(operation, Location.HEADER)) {
writeNormalHeader(context, binding);
}
});

flushHeadersBuffer(writer);

operation.getInput().ifPresent(outputId -> {
// Handle assembling prefix headers.
for (HttpBinding binding : bindingIndex.getRequestBindings(operation, Location.PREFIX_HEADERS)) {
writePrefixHeaders(context, binding);
Expand All @@ -843,10 +882,41 @@ private void writeRequestHeaders(
private void writeNormalHeader(GenerationContext context, HttpBinding binding) {
String memberLocation = "input." + context.getSymbolProvider().toMemberName(binding.getMember());
Shape target = context.getModel().expectShape(binding.getMember().getTarget());
String headerValue = getInputValue(context, binding.getLocation(), memberLocation + "!",
binding.getMember(), target);
context.getWriter().write("...isSerializableHeaderValue($L) && { $S: $L },",
memberLocation, binding.getLocationName().toLowerCase(Locale.US), headerValue);

String headerKey = binding.getLocationName().toLowerCase(Locale.US);
String headerValue = getInputValue(
context,
binding.getLocation(),
memberLocation + "!",
binding.getMember(),
target
);

if (!Objects.equals(memberLocation + "!", headerValue)) {
String defaultValue = "";
if (headerBuffer.containsKey(headerKey)) {
String s = headerBuffer.get(headerKey);
defaultValue = " || " + s.substring(s.indexOf(": ") + 2, s.length() - 1);
}
// evaluated value has a function or method call attached
headerBuffer.put(headerKey, String.format(
"'%s': [() => isSerializableHeaderValue(%s), () => %s],",
headerKey,
memberLocation + defaultValue,
headerValue + defaultValue
));
} else {
String value = headerValue;
if (headerBuffer.containsKey(headerKey)) {
String s = headerBuffer.get(headerKey);
value = headerValue + " || " + s.substring(s.indexOf(": ") + 2, s.length() - 1);
}
headerBuffer.put(headerKey, String.format(
"'%s': %s,",
headerKey,
value
));
}
}

private void writePrefixHeaders(GenerationContext context, HttpBinding binding) {
Expand Down Expand Up @@ -881,14 +951,16 @@ private void writeResponseHeaders(
TypeScriptWriter writer = context.getWriter();

// Headers are always present either from the default document or the payload.
writer.openBlock("let headers: any = {", "};", () -> {
writer.openBlock("let headers: any = map({}, isSerializableHeaderValue, {", "});", () -> {
writeContentTypeHeader(context, operationOrError, false);
injectExtraHeaders.run();

for (HttpBinding binding : bindingIndex.getResponseBindings(operationOrError, Location.HEADER)) {
writeNormalHeader(context, binding);
}

flushHeadersBuffer(writer);

// Handle assembling prefix headers.
for (HttpBinding binding : bindingIndex.getResponseBindings(operationOrError, Location.PREFIX_HEADERS)) {
writePrefixHeaders(context, binding);
Expand All @@ -908,7 +980,12 @@ private void writeContentTypeHeader(GenerationContext context, Shape operationOr
if (!optionalContentType.isPresent() && shouldWriteDefaultBody(context, operationOrError, isInput)) {
optionalContentType = Optional.of(getDocumentContentType());
}
optionalContentType.ifPresent(contentType -> context.getWriter().write("'content-type': $S,", contentType));
optionalContentType.ifPresent(contentType -> {
// context.getWriter().write("'content-type': $S,", contentType)
headerBuffer.put("content-type",
"'content-type': '" + contentType + "',"
);
});
}

private List<HttpBinding> writeRequestBody(
Expand Down Expand Up @@ -1584,7 +1661,7 @@ private void generateOperationRequestDeserializer(
Shape target = model.expectShape(binding.getMember().getTarget());
deserializingDocumentShapes.add(target);
});
writer.write("return Promise.resolve(contents);");
writer.write("return contents;");
});
writer.write("");
}
Expand Down Expand Up @@ -1923,28 +2000,23 @@ private void generateOperationResponseDeserializer(
() -> writer.write("return $L(output, context);", errorMethodName));

// Start deserializing the response.
writer.openBlock("const contents: $T = {", "};", outputType, () -> {
writer.openBlock("const contents: any = map({", "});", () -> {
writer.write("$$metadata: deserializeMetadata(output),");

// Only set a type and the members if we have output.
operation.getOutput().ifPresent(outputId -> {
// Set all the members to undefined to meet type constraints.
StructureShape target = model.expectShape(outputId).asStructureShape().get();
new TreeMap<>(target.getAllMembers())
.forEach((memberName, memberShape) -> writer.write(
"$L: undefined,", memberName));
});
readResponseHeaders(context, operation, bindingIndex, "output");
});
readResponseHeaders(context, operation, bindingIndex, "output");

List<HttpBinding> documentBindings = readResponseBody(context, operation, bindingIndex);

// Track all shapes bound to the document so their deserializers may be generated.
documentBindings.forEach(binding -> {
Shape target = model.expectShape(binding.getMember().getTarget());
if (!EventStreamGenerator.isEventStreamShape(target)) {
deserializingDocumentShapes.add(target);
}
});
writer.write("return Promise.resolve(contents);");

writer.write("return contents;");
});
writer.write("");
// Write out the error deserialization dispatcher.
Expand Down Expand Up @@ -2053,12 +2125,16 @@ private void readNormalHeaders(
TypeScriptWriter writer = context.getWriter();
String memberName = context.getSymbolProvider().toMemberName(binding.getMember());
String headerName = binding.getLocationName().toLowerCase(Locale.US);
writer.openBlock("if ($L.headers[$S] !== undefined) {", "}", outputName, headerName, () -> {
Shape target = context.getModel().expectShape(binding.getMember().getTarget());
String headerValue = getOutputValue(context, binding.getLocation(),
outputName + ".headers['" + headerName + "']", binding.getMember(), target);
writer.write("contents.$L = $L;", memberName, headerValue);
});
Shape target = context.getModel().expectShape(binding.getMember().getTarget());
String headerValue = getOutputValue(context, binding.getLocation(),
outputName + ".headers['" + headerName + "']", binding.getMember(), target);
String checkedValue = outputName + ".headers['" + headerName + "']";

if (checkedValue.equals(headerValue)) {
writer.write("$L: [, $L],", memberName, headerValue);
} else {
writer.write("$L: [() => void 0 !== $L, () => $L],", memberName, checkedValue, headerValue);
}
}
}

Expand All @@ -2083,13 +2159,11 @@ private void readPrefixHeaders(
TypeScriptWriter writer = context.getWriter();

// Run through the headers one time, matching any prefix groups.
writer.openBlock("Object.keys($L.headers).forEach(header => {", "});", outputName, () -> {
writer.openBlock("...(Object.keys($L.headers).reduce((acc, header) => {", "}, {}));", outputName, () -> {
for (HttpBinding binding : prefixHeaderBindings) {
// Prepare a grab bag for these headers if necessary
String memberName = symbolProvider.toMemberName(binding.getMember());
writer.openBlock("if (contents.$L === undefined) {", "}", memberName, () -> {
writer.write("contents.$L = {};", memberName);
});
writer.write("acc.$L = [, {}];", memberName);

// Generate a single block for each group of lower-cased prefix headers.
String headerLocation = binding.getLocationName().toLowerCase(Locale.US);
Expand All @@ -2100,7 +2174,7 @@ private void readPrefixHeaders(
outputName + ".headers[header]", binding.getMember(), target);

// Extract the non-prefix portion as the key.
writer.write("contents.$L[header.substring($L)] = $L;",
writer.write("acc.$L[header.substring($L)] = [, $L];",
memberName, headerLocation.length(), headerValue);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,28 +370,28 @@ static Set<StructureShape> generateErrorDispatcher(

// Build a generic error the best we can for ones we don't know about.
writer.write("default:").indent();
if (shouldParseErrorBody) {
// Body is already parsed above
writer.write("const parsedBody = parsedOutput.body;");
} else {
// Body is not parsed above, so parse it here
writer.write("const parsedBody = await parseBody(output.body, context);");
}
if (shouldParseErrorBody) {
// Body is already parsed above
writer.write("const parsedBody = parsedOutput.body;");
} else {
// Body is not parsed above, so parse it here
writer.write("const parsedBody = await parseBody(output.body, context);");
}

// Get the protocol specific error location for retrieving contents.
String errorLocation = bodyErrorLocationModifier.apply(context, "parsedBody");
writer.write("const $$metadata = deserializeMetadata(output);");
writer.write("const statusCode = $$metadata.httpStatusCode ? $$metadata.httpStatusCode"
+ " + '' : undefined;");
writer.openBlock("response = new $T({", "});", baseExceptionReference, () -> {
writer.write("name: $1L.code || $1L.Code || errorCode || statusCode || 'UnknowError',",
errorLocation);
writer.write("$$fault: \"client\",");
writer.write("$$metadata");
});
writer.addImport("decorateServiceException", "__decorateServiceException",
TypeScriptDependency.AWS_SMITHY_CLIENT.packageName);
writer.write("throw __decorateServiceException(response, $L);", errorLocation);
// Get the protocol specific error location for retrieving contents.
String errorLocation = bodyErrorLocationModifier.apply(context, "parsedBody");
writer.write("const $$metadata = deserializeMetadata(output);");
writer.write("const statusCode = $$metadata.httpStatusCode ? $$metadata.httpStatusCode"
+ " + '' : undefined;");
writer.openBlock("response = new $T({", "});", baseExceptionReference, () -> {
writer.write("name: $1L.code || $1L.Code || errorCode || statusCode || 'UnknowError',",
errorLocation);
writer.write("$$fault: \"client\",");
writer.write("$$metadata");
});
writer.addImport("decorateServiceException", "__decorateServiceException",
TypeScriptDependency.AWS_SMITHY_CLIENT.packageName);
writer.write("throw __decorateServiceException(response, $L);", errorLocation);
});
});
writer.write("");
Expand Down

0 comments on commit 8341b0b

Please sign in to comment.