Skip to content

Commit

Permalink
Add delegation proxy to spring web proxy
Browse files Browse the repository at this point in the history
This PR allows for separation of responsibilty between AWS and Spring where AWS side remains responsible to create HttpServletRequest from JSON representation of the API Gateway's input stream. Once created it can now delegate to Spring provided module for further interaction with the DispatcherServlet and the rest of thet Spring stack

Introduce a boolean flag to disable delegation to Spring
  • Loading branch information
olegz committed May 31, 2023
1 parent 013babd commit 4a61cfe
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,13 @@ public ServletContext getServletContext() {
* Sets the ServletContext in the handler and initialized a new <code>FilterChainManager</code>
* @param context An initialized ServletContext
*/
protected void setServletContext(final ServletContext context) {
public void setServletContext(final ServletContext context) {
servletContext = context;
// We assume custom implementations of the RequestWriter for HttpServletRequest will reuse
// the existing AwsServletContext object since it has no dependencies other than the Lambda context
filterChainManager = new AwsFilterChainManager((AwsServletContext)servletContext);
if (context instanceof AwsServletContext) {
filterChainManager = new AwsFilterChainManager((AwsServletContext)servletContext);
}
}

protected FilterChain getFilterChain(HttpServletRequest req, Servlet servlet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class AwsProxyRequest {
private String resource;
private AwsProxyRequestContext requestContext;
private MultiValuedTreeMap<String, String> multiValueQueryStringParameters;
private Map<String, String> queryStringParameters;
private Map<String, String> queryStringParameters;
private Headers multiValueHeaders;
private SingleValueHeaders headers;
private Map<String, String> pathParameters;
Expand All @@ -41,12 +41,13 @@ public class AwsProxyRequest {
private String path;
private boolean isBase64Encoded;

public AwsProxyRequest() {
multiValueHeaders = new Headers();
multiValueQueryStringParameters = new MultiValuedTreeMap<>();
pathParameters = new HashMap<>();
stageVariables = new HashMap<>();
}
public AwsProxyRequest() {
this.headers = new SingleValueHeaders();
multiValueHeaders = new Headers();
multiValueQueryStringParameters = new MultiValuedTreeMap<>();
pathParameters = new HashMap<>();
stageVariables = new HashMap<>();
}


//-------------------------------------------------------------
Expand Down Expand Up @@ -131,17 +132,21 @@ public Headers getMultiValueHeaders() {
return multiValueHeaders;
}

public void setMultiValueHeaders(Headers multiValueHeaders) {
this.multiValueHeaders = multiValueHeaders;
}
public void setMultiValueHeaders(Headers multiValueHeaders) {
if (multiValueHeaders != null) {
this.multiValueHeaders = multiValueHeaders;
}
}

public SingleValueHeaders getHeaders() {
return headers;
}

public void setHeaders(SingleValueHeaders headers) {
this.headers = headers;
}
public void setHeaders(SingleValueHeaders headers) {
if (headers != null) {
this.headers = headers;
}
}


public Map<String, String> getPathParameters() {
Expand Down
12 changes: 5 additions & 7 deletions aws-serverless-java-container-spring/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-function-serverless-web</artifactId>
<version>4.0.3</version>
</dependency>
<!-- Core interfaces for the aws-serverless-java-container project -->
<dependency>
<groupId>com.amazonaws.serverless</groupId>
Expand Down Expand Up @@ -57,13 +62,6 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>jakarta.activation</groupId>
<artifactId>jakarta.activation-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.amazonaws.serverless.proxy.internal.servlet.*;
import com.amazonaws.serverless.proxy.model.HttpApiV2ProxyRequest;
import com.amazonaws.services.lambda.runtime.Context;

import org.springframework.web.context.ConfigurableWebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;

Expand Down Expand Up @@ -49,13 +50,19 @@ public class SpringLambdaContainerHandler<RequestType, ResponseType> extends Aws
* @return An initialized instance of the `SpringLambdaContainerHandler`
* @throws ContainerInitializationException When the Spring framework fails to start.
*/
public static SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> getAwsProxyHandler(Class<?>... config) throws ContainerInitializationException {
return new SpringProxyHandlerBuilder<AwsProxyRequest>()
.defaultProxy()
.initializationWrapper(new InitializationWrapper())
.configurationClasses(config)
.buildAndInitialize();
}
public static SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> getAwsProxyHandler(Class<?>... config)
throws ContainerInitializationException {
// Temporary flag. Should be removed once spring-cloud-function delegation model becomes the only path.
boolean delegateToSpringCloudFunction = Boolean.parseBoolean(System.getenv().getOrDefault("spring.cloud.function.enable",
(String) System.getProperties().getOrDefault("spring.cloud.function.enable", "true")));
if (delegateToSpringCloudFunction) {
return getSpringNativeHandler(config);
} else {
return new SpringProxyHandlerBuilder<AwsProxyRequest>().defaultProxy()
.initializationWrapper(new InitializationWrapper()).configurationClasses(config)
.buildAndInitialize();
}
}

/**
* Creates a default SpringLambdaContainerHandler initialized with the `AwsProxyRequest` and `AwsProxyResponse` objects and sets the given profiles as active
Expand Down Expand Up @@ -188,4 +195,11 @@ protected void registerServlets() {
reg.addMapping("/");
reg.setLoadOnStartup(1);
}

private static SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> getSpringNativeHandler(Class<?>... config) throws ContainerInitializationException {
SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler = new SpringProxyHandlerBuilder<AwsProxyRequest>()
.defaultProxy().configurationClasses(config).buildSpringProxy();

return handler;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
package com.amazonaws.serverless.proxy.spring;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.internal.servlet.AwsHttpServletResponse;
import com.amazonaws.serverless.proxy.internal.servlet.ServletLambdaContainerHandlerBuilder;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.services.lambda.runtime.Context;

import org.springframework.cloud.function.serverless.web.ProxyMvc;
import org.springframework.web.context.ConfigurableWebApplicationContext;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.function.BiFunction;


public class SpringProxyHandlerBuilder<RequestType> extends ServletLambdaContainerHandlerBuilder<
RequestType,
Expand Down Expand Up @@ -73,13 +80,54 @@ public SpringLambdaContainerHandler<RequestType, AwsProxyResponse> build() throw
return handler;
}

/**
* Builds an instance of SpringLambdaContainerHandler with "delegate" to Spring provided ProxyMvc. The delegate
* is provided via BiFunction which takes HttpServletRequest and HttpSerbletResponse as input parameters.
* The AWS context is set as attribute of HttpServletRequest under `AWS_CONTEXT` key.
*
* @return instance of SpringLambdaContainerHandler
*/
SpringLambdaContainerHandler<RequestType, AwsProxyResponse> buildSpringProxy() {
ProxyMvc mvc = ProxyMvc.INSTANCE(this.configurationClasses);
BiFunction<HttpServletRequest, HttpServletResponse, Void> handlerDelegate = new BiFunction<HttpServletRequest, HttpServletResponse, Void>() {
@Override
public Void apply(HttpServletRequest request, HttpServletResponse response) {
try {
mvc.service(request, response);
response.flushBuffer();
}
catch (Exception e) {
throw new IllegalStateException(e);
}
return null;
}
};
SpringLambdaContainerHandler<RequestType, AwsProxyResponse> handler = createHandler(mvc.getApplicationContext(), handlerDelegate);
handler.setServletContext(mvc.getServletContext());
return handler;
}

protected SpringLambdaContainerHandler<RequestType, AwsProxyResponse> createHandler(ConfigurableWebApplicationContext ctx) {
return new SpringLambdaContainerHandler<>(
requestTypeClass, responseTypeClass, requestReader, responseWriter,
securityContextWriter, exceptionHandler, ctx, initializationWrapper
);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
protected SpringLambdaContainerHandler<RequestType, AwsProxyResponse> createHandler(ConfigurableWebApplicationContext ctx,
BiFunction<HttpServletRequest, HttpServletResponse, Void> handler) {
return new SpringLambdaContainerHandler(requestTypeClass, responseTypeClass, requestReader, responseWriter,
securityContextWriter, exceptionHandler, ctx, initializationWrapper) {
@Override
protected void handleRequest(HttpServletRequest containerRequest, AwsHttpServletResponse containerResponse,
Context lambdaContext) throws Exception {
containerRequest.setAttribute("AWS_CONTEXT", lambdaContext);
handler.apply(containerRequest, containerResponse);
}
};
}

@Override
public SpringLambdaContainerHandler<RequestType, AwsProxyResponse> buildAndInitialize() throws ContainerInitializationException {
SpringLambdaContainerHandler<RequestType, AwsProxyResponse> handler = build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.internal.LambdaContainerHandler;
import com.amazonaws.serverless.proxy.internal.servlet.AwsLambdaServletContainerHandler;
import com.amazonaws.serverless.proxy.internal.servlet.AwsServletRegistration;
import com.amazonaws.serverless.proxy.model.*;
import com.amazonaws.serverless.proxy.internal.servlet.AwsServletContext;
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
Expand All @@ -21,13 +20,19 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.web.servlet.DispatcherServlet;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.FilterRegistration;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;

import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.DispatcherServlet;

import java.io.IOException;
import java.lang.reflect.Field;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
Expand Down Expand Up @@ -57,8 +62,8 @@ public class SpringAwsProxyTest {
registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/echo/*");
// servlet name mappings are disabled and will throw an exception

//handler.getApplicationInitializer().getDispatcherServlet().setThrowExceptionIfNoHandlerFound(true);
((DispatcherServlet)((AwsServletRegistration)c.getServletRegistration("dispatcherServlet")).getServlet()).setThrowExceptionIfNoHandlerFound(true);
DispatcherServlet dServlet = extractDispatcherServletFromContext(c);
dServlet.setThrowExceptionIfNoHandlerFound(true);
});

private String type;
Expand Down Expand Up @@ -503,5 +508,17 @@ private void validateSingleValueModel(AwsProxyResponse output, String value) {
fail("Exception while parsing response body: " + e.getMessage());
}
}

private static DispatcherServlet extractDispatcherServletFromContext(ServletContext servletContext) {
ServletRegistration servletRegistration = servletContext.getServletRegistration("dispatcherServlet");
Field field = ReflectionUtils.findField(servletRegistration.getClass(), "servlet");
field.setAccessible(true);
try {
return (DispatcherServlet) field.get(servletRegistration);
}
catch (IllegalArgumentException | IllegalAccessException e) {
throw new IllegalStateException(e);
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void profile_defaultProfile() throws Exception {

@Test
void profile_overrideProfile() throws Exception {
System.setProperty("spring.cloud.function.enable", "false");
AwsProxyRequest request = new AwsProxyRequestBuilder("/profile/spring-properties", "GET")
.build();
SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler = SpringLambdaContainerHandler.getAwsProxyHandler(EchoSpringAppConfig.class);
Expand All @@ -67,5 +68,6 @@ void profile_overrideProfile() throws Exception {
assertEquals("override-profile", response.getValues().get("profileTest"));
assertEquals("not-overridden", response.getValues().get("noOverride"));
assertEquals("override-profile-from-bean", response.getValues().get("beanInjectedValue"));
System.setProperty("spring.cloud.function.enable", "true");
}
}
6 changes: 6 additions & 0 deletions samples/spring/pet-store/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
</properties>

<dependencies>
<dependency>
<groupId>io.github.crac</groupId>
<artifactId>org-crac</artifactId>
<version>0.1.3</version>
</dependency>

<dependency>
<groupId>com.amazonaws.serverless</groupId>
<artifactId>aws-serverless-java-container-spring</artifactId>
Expand Down

0 comments on commit 4a61cfe

Please sign in to comment.