Skip to content

Commit

Permalink
feat: native Spring Web workloads (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
deki authored Jan 30, 2024
2 parents e1acefc + 4533c4b commit 7ca2f07
Show file tree
Hide file tree
Showing 37 changed files with 2,110 additions and 103 deletions.
7 changes: 6 additions & 1 deletion aws-serverless-java-container-springboot3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-function-serverless-web</artifactId>
<version>4.0.4</version>
<version>4.0.6</version>
</dependency>
<dependency>
<groupId>com.amazonaws.serverless</groupId>
Expand Down Expand Up @@ -201,6 +201,11 @@
<configuration>
<destFile>${basedir}/target/coverage-reports/jacoco-unit.exec</destFile>
<dataFile>${basedir}/target/coverage-reports/jacoco-unit.exec</dataFile>
<excludes>
<!-- Native AOT implementation is currently not covered (due to complexity of the test setup) -->
<exclude>com/amazonaws/serverless/proxy/spring/AwsSpringWebCustomRuntimeEventLoop*</exclude>
<exclude>com/amazonaws/serverless/proxy/spring/AwsSpringAotTypesProcessor*</exclude>
</excludes>
</configuration>
<executions>
<execution>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.amazonaws.serverless.proxy.spring;

import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;

import com.amazonaws.serverless.proxy.internal.servlet.AwsHttpServletResponse;
import com.amazonaws.serverless.proxy.model.ApiGatewayRequestIdentity;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyRequestContext;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.model.Headers;
import com.amazonaws.serverless.proxy.model.MultiValuedTreeMap;
import com.amazonaws.serverless.proxy.model.SingleValueHeaders;
import com.fasterxml.jackson.core.JsonToken;

/**
* AOT Initialization processor required to register reflective hints for GraalVM.
* This is necessary to ensure proper JSON serialization/deserialization.
* It is registered with META-INF/spring/aot.factories
*
* @author Oleg Zhurakousky
*/
public class AwsSpringAotTypesProcessor implements BeanFactoryInitializationAotProcessor {

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
return new ReflectiveProcessorBeanFactoryInitializationAotContribution();
}

private static final class ReflectiveProcessorBeanFactoryInitializationAotContribution implements BeanFactoryInitializationAotContribution {
@Override
public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) {
RuntimeHints runtimeHints = generationContext.getRuntimeHints();
// known static types

runtimeHints.reflection().registerType(AwsProxyRequest.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(AwsProxyResponse.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(SingleValueHeaders.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(JsonToken.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(MultiValuedTreeMap.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(Headers.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(AwsProxyRequestContext.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(ApiGatewayRequestIdentity.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
runtimeHints.reflection().registerType(AwsHttpServletResponse.class,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS,
MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES, MemberCategory.INTROSPECT_DECLARED_METHODS);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package com.amazonaws.serverless.proxy.spring;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.serverless.web.ServerlessHttpServletRequest;
import org.springframework.cloud.function.serverless.web.ServerlessMVC;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.MultiValueMapAdapter;
import org.springframework.util.StringUtils;

import com.amazonaws.serverless.proxy.AsyncInitializationWrapper;
import com.amazonaws.serverless.proxy.AwsHttpApiV2SecurityContextWriter;
import com.amazonaws.serverless.proxy.AwsProxySecurityContextWriter;
import com.amazonaws.serverless.proxy.RequestReader;
import com.amazonaws.serverless.proxy.SecurityContextWriter;
import com.amazonaws.serverless.proxy.internal.servlet.AwsHttpServletResponse;
import com.amazonaws.serverless.proxy.internal.servlet.AwsProxyHttpServletResponseWriter;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.model.HttpApiV2ProxyRequest;
import com.amazonaws.services.lambda.runtime.Context;
import com.fasterxml.jackson.databind.ObjectMapper;

import jakarta.servlet.ServletContext;
import jakarta.servlet.http.HttpServletRequest;

class AwsSpringHttpProcessingUtils {

private static Log logger = LogFactory.getLog(AwsSpringHttpProcessingUtils.class);
private static final int LAMBDA_MAX_REQUEST_DURATION_MINUTES = 15;

private AwsSpringHttpProcessingUtils() {

}

public static AwsProxyResponse processRequest(HttpServletRequest request, ServerlessMVC mvc,
AwsProxyHttpServletResponseWriter responseWriter) {
CountDownLatch latch = new CountDownLatch(1);
AwsHttpServletResponse response = new AwsHttpServletResponse(request, latch);
try {
mvc.service(request, response);
boolean requestTimedOut = !latch.await(LAMBDA_MAX_REQUEST_DURATION_MINUTES, TimeUnit.MINUTES); // timeout is potentially lower as user configures it
if (requestTimedOut) {
logger.warn("request timed out after " + LAMBDA_MAX_REQUEST_DURATION_MINUTES + " minutes");
}
AwsProxyResponse awsResponse = responseWriter.writeResponse(response, null);
return awsResponse;
}
catch (Exception e) {
e.printStackTrace();
throw new IllegalStateException(e);
}
}

public static String extractVersion() {
try {
String path = AwsSpringHttpProcessingUtils.class.getProtectionDomain().getCodeSource().getLocation().toString();
int endIndex = path.lastIndexOf('.');
if (endIndex < 0) {
return "UNKNOWN-VERSION";
}
int startIndex = path.lastIndexOf("/") + 1;
return path.substring(startIndex, endIndex).replace("spring-cloud-function-serverless-web-", "");
}
catch (Exception e) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to detect version", e);
}
return "UNKNOWN-VERSION";
}

}

public static HttpServletRequest generateHttpServletRequest(InputStream jsonRequest, Context lambdaContext,
ServletContext servletContext, ObjectMapper mapper) {
try {
String text = new String(FileCopyUtils.copyToByteArray(jsonRequest), StandardCharsets.UTF_8);
if (logger.isDebugEnabled()) {
logger.debug("Creating HttpServletRequest from: " + text);
}
return generateHttpServletRequest(text, lambdaContext, servletContext, mapper);
} catch (Exception e) {
throw new IllegalStateException(e);
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
public static HttpServletRequest generateHttpServletRequest(String jsonRequest, Context lambdaContext,
ServletContext servletContext, ObjectMapper mapper) {
Map<String, Object> _request = readValue(jsonRequest, Map.class, mapper);
SecurityContextWriter securityWriter = "2.0".equals(_request.get("version"))
? new AwsHttpApiV2SecurityContextWriter()
: new AwsProxySecurityContextWriter();
HttpServletRequest httpServletRequest = "2.0".equals(_request.get("version"))
? AwsSpringHttpProcessingUtils.generateRequest2(jsonRequest, lambdaContext, securityWriter, mapper, servletContext)
: AwsSpringHttpProcessingUtils.generateRequest1(jsonRequest, lambdaContext, securityWriter, mapper, servletContext);
return httpServletRequest;
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private static HttpServletRequest generateRequest1(String request, Context lambdaContext,
SecurityContextWriter securityWriter, ObjectMapper mapper, ServletContext servletContext) {
AwsProxyRequest v1Request = readValue(request, AwsProxyRequest.class, mapper);

ServerlessHttpServletRequest httpRequest = new ServerlessHttpServletRequest(servletContext, v1Request.getHttpMethod(), v1Request.getPath());
if (v1Request.getMultiValueHeaders() != null) {
MultiValueMapAdapter headers = new MultiValueMapAdapter(v1Request.getMultiValueHeaders());
httpRequest.setHeaders(headers);
}
if (StringUtils.hasText(v1Request.getBody())) {
httpRequest.setContentType("application/json");
httpRequest.setContent(v1Request.getBody().getBytes(StandardCharsets.UTF_8));
}
if (v1Request.getRequestContext() != null) {
httpRequest.setAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY, v1Request.getRequestContext());
httpRequest.setAttribute(RequestReader.ALB_CONTEXT_PROPERTY, v1Request.getRequestContext().getElb());
}
httpRequest.setAttribute(RequestReader.API_GATEWAY_STAGE_VARS_PROPERTY, v1Request.getStageVariables());
httpRequest.setAttribute(RequestReader.API_GATEWAY_EVENT_PROPERTY, v1Request);
httpRequest.setAttribute(RequestReader.LAMBDA_CONTEXT_PROPERTY, lambdaContext);
httpRequest.setAttribute(RequestReader.JAX_SECURITY_CONTEXT_PROPERTY,
securityWriter.writeSecurityContext(v1Request, lambdaContext));
return httpRequest;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private static HttpServletRequest generateRequest2(String request, Context lambdaContext,
SecurityContextWriter securityWriter, ObjectMapper mapper, ServletContext servletContext) {
HttpApiV2ProxyRequest v2Request = readValue(request, HttpApiV2ProxyRequest.class, mapper);
ServerlessHttpServletRequest httpRequest = new ServerlessHttpServletRequest(servletContext,
v2Request.getRequestContext().getHttp().getMethod(), v2Request.getRequestContext().getHttp().getPath());

v2Request.getHeaders().forEach(httpRequest::setHeader);

if (StringUtils.hasText(v2Request.getBody())) {
httpRequest.setContentType("application/json");
httpRequest.setContent(v2Request.getBody().getBytes(StandardCharsets.UTF_8));
}
httpRequest.setAttribute(RequestReader.HTTP_API_CONTEXT_PROPERTY, v2Request.getRequestContext());
httpRequest.setAttribute(RequestReader.HTTP_API_STAGE_VARS_PROPERTY, v2Request.getStageVariables());
httpRequest.setAttribute(RequestReader.HTTP_API_EVENT_PROPERTY, v2Request);
httpRequest.setAttribute(RequestReader.LAMBDA_CONTEXT_PROPERTY, lambdaContext);
httpRequest.setAttribute(RequestReader.JAX_SECURITY_CONTEXT_PROPERTY,
securityWriter.writeSecurityContext(v2Request, lambdaContext));
return httpRequest;
}

private static <T> T readValue(String json, Class<T> clazz, ObjectMapper mapper) {
try {
return mapper.readValue(json, clazz);
}
catch (Exception e) {
throw new IllegalStateException(e);
}
}

}
Loading

0 comments on commit 7ca2f07

Please sign in to comment.