diff --git a/aws-serverless-java-container-springboot3/pom.xml b/aws-serverless-java-container-springboot3/pom.xml
index 42a1edf7b..5a9843e16 100644
--- a/aws-serverless-java-container-springboot3/pom.xml
+++ b/aws-serverless-java-container-springboot3/pom.xml
@@ -25,7 +25,7 @@
org.springframework.cloud
spring-cloud-function-serverless-web
- 4.0.4
+ 4.0.6
com.amazonaws.serverless
@@ -201,6 +201,11 @@
${basedir}/target/coverage-reports/jacoco-unit.exec
${basedir}/target/coverage-reports/jacoco-unit.exec
+
+
+ com/amazonaws/serverless/proxy/spring/AwsSpringWebCustomRuntimeEventLoop*
+ com/amazonaws/serverless/proxy/spring/AwsSpringAotTypesProcessor*
+
diff --git a/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringAotTypesProcessor.java b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringAotTypesProcessor.java
new file mode 100644
index 000000000..f7fbe9e25
--- /dev/null
+++ b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringAotTypesProcessor.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.amazonaws.serverless.proxy.spring;
+
+import org.springframework.aot.generate.GenerationContext;
+import org.springframework.aot.hint.MemberCategory;
+import org.springframework.aot.hint.RuntimeHints;
+import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
+import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
+import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
+import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
+
+import com.amazonaws.serverless.proxy.internal.servlet.AwsHttpServletResponse;
+import com.amazonaws.serverless.proxy.model.ApiGatewayRequestIdentity;
+import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
+import com.amazonaws.serverless.proxy.model.AwsProxyRequestContext;
+import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
+import com.amazonaws.serverless.proxy.model.Headers;
+import com.amazonaws.serverless.proxy.model.MultiValuedTreeMap;
+import com.amazonaws.serverless.proxy.model.SingleValueHeaders;
+import com.fasterxml.jackson.core.JsonToken;
+
+/**
+ * AOT Initialization processor required to register reflective hints for GraalVM.
+ * This is necessary to ensure proper JSON serialization/deserialization.
+ * It is registered with META-INF/spring/aot.factories
+ *
+ * @author Oleg Zhurakousky
+ */
+public class AwsSpringAotTypesProcessor implements BeanFactoryInitializationAotProcessor {
+
+ @Override
+ public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
+ return new ReflectiveProcessorBeanFactoryInitializationAotContribution();
+ }
+
+ private static final class ReflectiveProcessorBeanFactoryInitializationAotContribution implements BeanFactoryInitializationAotContribution {
+ @Override
+ public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) {
+ RuntimeHints runtimeHints = generationContext.getRuntimeHints();
+ // known static types
+
+ runtimeHints.reflection().registerType(AwsProxyRequest.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(AwsProxyResponse.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(SingleValueHeaders.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(JsonToken.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(MultiValuedTreeMap.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(Headers.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(AwsProxyRequestContext.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(ApiGatewayRequestIdentity.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES);
+ runtimeHints.reflection().registerType(AwsHttpServletResponse.class,
+ MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS,
+ MemberCategory.DECLARED_FIELDS, MemberCategory.DECLARED_CLASSES, MemberCategory.INTROSPECT_DECLARED_METHODS);
+ }
+
+ }
+}
diff --git a/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java
new file mode 100644
index 000000000..d268bd2e4
--- /dev/null
+++ b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java
@@ -0,0 +1,163 @@
+package com.amazonaws.serverless.proxy.spring;
+
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.cloud.function.serverless.web.ServerlessHttpServletRequest;
+import org.springframework.cloud.function.serverless.web.ServerlessMVC;
+import org.springframework.util.FileCopyUtils;
+import org.springframework.util.MultiValueMapAdapter;
+import org.springframework.util.StringUtils;
+
+import com.amazonaws.serverless.proxy.AsyncInitializationWrapper;
+import com.amazonaws.serverless.proxy.AwsHttpApiV2SecurityContextWriter;
+import com.amazonaws.serverless.proxy.AwsProxySecurityContextWriter;
+import com.amazonaws.serverless.proxy.RequestReader;
+import com.amazonaws.serverless.proxy.SecurityContextWriter;
+import com.amazonaws.serverless.proxy.internal.servlet.AwsHttpServletResponse;
+import com.amazonaws.serverless.proxy.internal.servlet.AwsProxyHttpServletResponseWriter;
+import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
+import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
+import com.amazonaws.serverless.proxy.model.HttpApiV2ProxyRequest;
+import com.amazonaws.services.lambda.runtime.Context;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import jakarta.servlet.ServletContext;
+import jakarta.servlet.http.HttpServletRequest;
+
+class AwsSpringHttpProcessingUtils {
+
+ private static Log logger = LogFactory.getLog(AwsSpringHttpProcessingUtils.class);
+ private static final int LAMBDA_MAX_REQUEST_DURATION_MINUTES = 15;
+
+ private AwsSpringHttpProcessingUtils() {
+
+ }
+
+ public static AwsProxyResponse processRequest(HttpServletRequest request, ServerlessMVC mvc,
+ AwsProxyHttpServletResponseWriter responseWriter) {
+ CountDownLatch latch = new CountDownLatch(1);
+ AwsHttpServletResponse response = new AwsHttpServletResponse(request, latch);
+ try {
+ mvc.service(request, response);
+ boolean requestTimedOut = !latch.await(LAMBDA_MAX_REQUEST_DURATION_MINUTES, TimeUnit.MINUTES); // timeout is potentially lower as user configures it
+ if (requestTimedOut) {
+ logger.warn("request timed out after " + LAMBDA_MAX_REQUEST_DURATION_MINUTES + " minutes");
+ }
+ AwsProxyResponse awsResponse = responseWriter.writeResponse(response, null);
+ return awsResponse;
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ throw new IllegalStateException(e);
+ }
+ }
+
+ public static String extractVersion() {
+ try {
+ String path = AwsSpringHttpProcessingUtils.class.getProtectionDomain().getCodeSource().getLocation().toString();
+ int endIndex = path.lastIndexOf('.');
+ if (endIndex < 0) {
+ return "UNKNOWN-VERSION";
+ }
+ int startIndex = path.lastIndexOf("/") + 1;
+ return path.substring(startIndex, endIndex).replace("spring-cloud-function-serverless-web-", "");
+ }
+ catch (Exception e) {
+ if (logger.isDebugEnabled()) {
+ logger.debug("Failed to detect version", e);
+ }
+ return "UNKNOWN-VERSION";
+ }
+
+ }
+
+ public static HttpServletRequest generateHttpServletRequest(InputStream jsonRequest, Context lambdaContext,
+ ServletContext servletContext, ObjectMapper mapper) {
+ try {
+ String text = new String(FileCopyUtils.copyToByteArray(jsonRequest), StandardCharsets.UTF_8);
+ if (logger.isDebugEnabled()) {
+ logger.debug("Creating HttpServletRequest from: " + text);
+ }
+ return generateHttpServletRequest(text, lambdaContext, servletContext, mapper);
+ } catch (Exception e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ public static HttpServletRequest generateHttpServletRequest(String jsonRequest, Context lambdaContext,
+ ServletContext servletContext, ObjectMapper mapper) {
+ Map _request = readValue(jsonRequest, Map.class, mapper);
+ SecurityContextWriter securityWriter = "2.0".equals(_request.get("version"))
+ ? new AwsHttpApiV2SecurityContextWriter()
+ : new AwsProxySecurityContextWriter();
+ HttpServletRequest httpServletRequest = "2.0".equals(_request.get("version"))
+ ? AwsSpringHttpProcessingUtils.generateRequest2(jsonRequest, lambdaContext, securityWriter, mapper, servletContext)
+ : AwsSpringHttpProcessingUtils.generateRequest1(jsonRequest, lambdaContext, securityWriter, mapper, servletContext);
+ return httpServletRequest;
+ }
+
+ @SuppressWarnings({ "unchecked", "rawtypes" })
+ private static HttpServletRequest generateRequest1(String request, Context lambdaContext,
+ SecurityContextWriter securityWriter, ObjectMapper mapper, ServletContext servletContext) {
+ AwsProxyRequest v1Request = readValue(request, AwsProxyRequest.class, mapper);
+
+ ServerlessHttpServletRequest httpRequest = new ServerlessHttpServletRequest(servletContext, v1Request.getHttpMethod(), v1Request.getPath());
+ if (v1Request.getMultiValueHeaders() != null) {
+ MultiValueMapAdapter headers = new MultiValueMapAdapter(v1Request.getMultiValueHeaders());
+ httpRequest.setHeaders(headers);
+ }
+ if (StringUtils.hasText(v1Request.getBody())) {
+ httpRequest.setContentType("application/json");
+ httpRequest.setContent(v1Request.getBody().getBytes(StandardCharsets.UTF_8));
+ }
+ if (v1Request.getRequestContext() != null) {
+ httpRequest.setAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY, v1Request.getRequestContext());
+ httpRequest.setAttribute(RequestReader.ALB_CONTEXT_PROPERTY, v1Request.getRequestContext().getElb());
+ }
+ httpRequest.setAttribute(RequestReader.API_GATEWAY_STAGE_VARS_PROPERTY, v1Request.getStageVariables());
+ httpRequest.setAttribute(RequestReader.API_GATEWAY_EVENT_PROPERTY, v1Request);
+ httpRequest.setAttribute(RequestReader.LAMBDA_CONTEXT_PROPERTY, lambdaContext);
+ httpRequest.setAttribute(RequestReader.JAX_SECURITY_CONTEXT_PROPERTY,
+ securityWriter.writeSecurityContext(v1Request, lambdaContext));
+ return httpRequest;
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ private static HttpServletRequest generateRequest2(String request, Context lambdaContext,
+ SecurityContextWriter securityWriter, ObjectMapper mapper, ServletContext servletContext) {
+ HttpApiV2ProxyRequest v2Request = readValue(request, HttpApiV2ProxyRequest.class, mapper);
+ ServerlessHttpServletRequest httpRequest = new ServerlessHttpServletRequest(servletContext,
+ v2Request.getRequestContext().getHttp().getMethod(), v2Request.getRequestContext().getHttp().getPath());
+
+ v2Request.getHeaders().forEach(httpRequest::setHeader);
+
+ if (StringUtils.hasText(v2Request.getBody())) {
+ httpRequest.setContentType("application/json");
+ httpRequest.setContent(v2Request.getBody().getBytes(StandardCharsets.UTF_8));
+ }
+ httpRequest.setAttribute(RequestReader.HTTP_API_CONTEXT_PROPERTY, v2Request.getRequestContext());
+ httpRequest.setAttribute(RequestReader.HTTP_API_STAGE_VARS_PROPERTY, v2Request.getStageVariables());
+ httpRequest.setAttribute(RequestReader.HTTP_API_EVENT_PROPERTY, v2Request);
+ httpRequest.setAttribute(RequestReader.LAMBDA_CONTEXT_PROPERTY, lambdaContext);
+ httpRequest.setAttribute(RequestReader.JAX_SECURITY_CONTEXT_PROPERTY,
+ securityWriter.writeSecurityContext(v2Request, lambdaContext));
+ return httpRequest;
+ }
+
+ private static T readValue(String json, Class clazz, ObjectMapper mapper) {
+ try {
+ return mapper.readValue(json, clazz);
+ }
+ catch (Exception e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+}
diff --git a/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringWebCustomRuntimeEventLoop.java b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringWebCustomRuntimeEventLoop.java
new file mode 100644
index 000000000..db71d56c0
--- /dev/null
+++ b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringWebCustomRuntimeEventLoop.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.amazonaws.serverless.proxy.spring;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.net.URI;
+import java.text.MessageFormat;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.boot.web.servlet.context.ServletWebServerApplicationContext;
+import org.springframework.cloud.function.serverless.web.ServerlessMVC;
+import org.springframework.context.SmartLifecycle;
+import org.springframework.core.env.Environment;
+import org.springframework.http.RequestEntity;
+import org.springframework.http.ResponseEntity;
+import org.springframework.web.client.RestTemplate;
+
+import com.amazonaws.serverless.proxy.internal.servlet.AwsProxyHttpServletResponseWriter;
+import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
+
+/**
+ * Event loop and necessary configurations to support AWS Lambda Custom Runtime
+ * - https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html.
+ *
+ * @author Oleg Zhurakousky
+ * @author Mark Sailes
+ *
+ */
+public final class AwsSpringWebCustomRuntimeEventLoop implements SmartLifecycle {
+
+ private static Log logger = LogFactory.getLog(AwsSpringWebCustomRuntimeEventLoop.class);
+
+ static final String LAMBDA_VERSION_DATE = "2018-06-01";
+ private static final String LAMBDA_ERROR_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/{2}/error";
+ private static final String LAMBDA_RUNTIME_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/next";
+ private static final String LAMBDA_INVOCATION_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/{2}/response";
+ private static final String USER_AGENT_VALUE = String.format("spring-cloud-function/%s-%s",
+ System.getProperty("java.runtime.version"), AwsSpringHttpProcessingUtils.extractVersion());
+
+ private final ServletWebServerApplicationContext applicationContext;
+
+ private volatile boolean running;
+
+ private final ExecutorService executor = Executors.newSingleThreadExecutor();
+
+ public AwsSpringWebCustomRuntimeEventLoop(ServletWebServerApplicationContext applicationContext) {
+ this.applicationContext = applicationContext;
+ }
+
+ public void run() {
+ this.running = true;
+ this.executor.execute(() -> {
+ eventLoop(this.applicationContext);
+ });
+ }
+
+ @Override
+ public void start() {
+ this.run();
+ }
+
+ @Override
+ public void stop() {
+ this.executor.shutdownNow();
+ this.running = false;
+ }
+
+ @Override
+ public boolean isRunning() {
+ return this.running;
+ }
+
+ private void eventLoop(ServletWebServerApplicationContext context) {
+ ServerlessMVC mvc = ServerlessMVC.INSTANCE(context);
+
+ Environment environment = context.getEnvironment();
+ logger.info("Starting AWSWebRuntimeEventLoop");
+ if (logger.isDebugEnabled()) {
+ logger.debug("AWS LAMBDA ENVIRONMENT: " + System.getenv());
+ }
+
+ String runtimeApi = environment.getProperty("AWS_LAMBDA_RUNTIME_API");
+ String eventUri = MessageFormat.format(LAMBDA_RUNTIME_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE);
+ if (logger.isDebugEnabled()) {
+ logger.debug("Event URI: " + eventUri);
+ }
+
+ RequestEntity requestEntity = RequestEntity.get(URI.create(eventUri))
+ .header("User-Agent", USER_AGENT_VALUE).build();
+ RestTemplate rest = new RestTemplate();
+ ObjectMapper mapper = new ObjectMapper();//.getBean(ObjectMapper.class);
+ mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ AwsProxyHttpServletResponseWriter responseWriter = new AwsProxyHttpServletResponseWriter();
+
+ logger.info("Entering event loop");
+ while (this.isRunning()) {
+ logger.debug("Attempting to get new event");
+ ResponseEntity incomingEvent = rest.exchange(requestEntity, String.class);
+
+ if (incomingEvent != null && incomingEvent.hasBody()) {
+ if (logger.isDebugEnabled()) {
+ logger.debug("New Event received from AWS Gateway: " + incomingEvent.getBody());
+ }
+ String requestId = incomingEvent.getHeaders().getFirst("Lambda-Runtime-Aws-Request-Id");
+
+ try {
+ logger.debug("Submitting request to the user's web application");
+
+ AwsProxyResponse awsResponse = AwsSpringHttpProcessingUtils.processRequest(
+ AwsSpringHttpProcessingUtils.generateHttpServletRequest(incomingEvent.getBody(),
+ null, mvc.getServletContext(), mapper), mvc, responseWriter);
+ if (logger.isDebugEnabled()) {
+ logger.debug("Received response - body: " + awsResponse.getBody() +
+ "; status: " + awsResponse.getStatusCode() + "; headers: " + awsResponse.getHeaders());
+ }
+
+ String invocationUrl = MessageFormat.format(LAMBDA_INVOCATION_URL_TEMPLATE, runtimeApi,
+ LAMBDA_VERSION_DATE, requestId);
+
+ ResponseEntity result = rest.exchange(RequestEntity.post(URI.create(invocationUrl))
+ .header("User-Agent", USER_AGENT_VALUE).body(awsResponse), byte[].class);
+ if (logger.isDebugEnabled()) {
+ logger.debug("Response sent: body: " + result.getBody() +
+ "; status: " + result.getStatusCode() + "; headers: " + result.getHeaders());
+ }
+ if (logger.isInfoEnabled()) {
+ logger.info("Result POST status: " + result);
+ }
+ }
+ catch (Exception e) {
+ logger.error(e);
+ this.propagateAwsError(requestId, e, mapper, runtimeApi, rest);
+ }
+ }
+ }
+ }
+
+ private void propagateAwsError(String requestId, Exception e, ObjectMapper mapper, String runtimeApi, RestTemplate rest) {
+ String errorMessage = e.getMessage();
+ String errorType = e.getClass().getSimpleName();
+ StringWriter sw = new StringWriter();
+ PrintWriter pw = new PrintWriter(sw);
+ e.printStackTrace(pw);
+ String stackTrace = sw.toString();
+ Map em = new HashMap<>();
+ em.put("errorMessage", errorMessage);
+ em.put("errorType", errorType);
+ em.put("stackTrace", stackTrace);
+ try {
+ byte[] outputBody = mapper.writeValueAsBytes(em);
+ String errorUrl = MessageFormat.format(LAMBDA_ERROR_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE, requestId);
+ ResponseEntity