Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add delegation proxy to spring web proxy #578

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,13 @@ public ServletContext getServletContext() {
* Sets the ServletContext in the handler and initialized a new <code>FilterChainManager</code>
* @param context An initialized ServletContext
*/
protected void setServletContext(final ServletContext context) {
public void setServletContext(final ServletContext context) {
servletContext = context;
// We assume custom implementations of the RequestWriter for HttpServletRequest will reuse
// the existing AwsServletContext object since it has no dependencies other than the Lambda context
filterChainManager = new AwsFilterChainManager((AwsServletContext)servletContext);
if (context instanceof AwsServletContext) {
filterChainManager = new AwsFilterChainManager((AwsServletContext)servletContext);
}
}

protected FilterChain getFilterChain(HttpServletRequest req, Servlet servlet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class AwsProxyRequest {
private String resource;
private AwsProxyRequestContext requestContext;
private MultiValuedTreeMap<String, String> multiValueQueryStringParameters;
private Map<String, String> queryStringParameters;
private Map<String, String> queryStringParameters;
private Headers multiValueHeaders;
private SingleValueHeaders headers;
private Map<String, String> pathParameters;
Expand All @@ -41,12 +41,13 @@ public class AwsProxyRequest {
private String path;
private boolean isBase64Encoded;

public AwsProxyRequest() {
multiValueHeaders = new Headers();
multiValueQueryStringParameters = new MultiValuedTreeMap<>();
pathParameters = new HashMap<>();
stageVariables = new HashMap<>();
}
public AwsProxyRequest() {
this.headers = new SingleValueHeaders();
multiValueHeaders = new Headers();
multiValueQueryStringParameters = new MultiValuedTreeMap<>();
pathParameters = new HashMap<>();
stageVariables = new HashMap<>();
}


//-------------------------------------------------------------
Expand Down Expand Up @@ -131,17 +132,21 @@ public Headers getMultiValueHeaders() {
return multiValueHeaders;
}

public void setMultiValueHeaders(Headers multiValueHeaders) {
this.multiValueHeaders = multiValueHeaders;
}
public void setMultiValueHeaders(Headers multiValueHeaders) {
if (multiValueHeaders != null) {
this.multiValueHeaders = multiValueHeaders;
}
}

public SingleValueHeaders getHeaders() {
return headers;
}

public void setHeaders(SingleValueHeaders headers) {
this.headers = headers;
}
public void setHeaders(SingleValueHeaders headers) {
if (headers != null) {
this.headers = headers;
}
}


public Map<String, String> getPathParameters() {
Expand Down
12 changes: 5 additions & 7 deletions aws-serverless-java-container-spring/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-function-serverless-web</artifactId>
<version>4.0.3</version>
</dependency>
<!-- Core interfaces for the aws-serverless-java-container project -->
<dependency>
<groupId>com.amazonaws.serverless</groupId>
Expand Down Expand Up @@ -57,13 +62,6 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>jakarta.activation</groupId>
<artifactId>jakarta.activation-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.amazonaws.serverless.proxy.internal.servlet.*;
import com.amazonaws.serverless.proxy.model.HttpApiV2ProxyRequest;
import com.amazonaws.services.lambda.runtime.Context;

import org.springframework.web.context.ConfigurableWebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;

Expand Down Expand Up @@ -49,13 +50,20 @@ public class SpringLambdaContainerHandler<RequestType, ResponseType> extends Aws
* @return An initialized instance of the `SpringLambdaContainerHandler`
* @throws ContainerInitializationException When the Spring framework fails to start.
*/
public static SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> getAwsProxyHandler(Class<?>... config) throws ContainerInitializationException {
return new SpringProxyHandlerBuilder<AwsProxyRequest>()
.defaultProxy()
.initializationWrapper(new InitializationWrapper())
.configurationClasses(config)
.buildAndInitialize();
}
public static SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> getAwsProxyHandler(Class<?>... config)
throws ContainerInitializationException {
// Temporary flag. Should be removed once spring-cloud-function delegation model
// becomes the only path.
boolean delegateToSpringCloudFunction = Boolean.parseBoolean(System.getenv().getOrDefault("spring.cloud.function.enable",
(String) System.getProperties().getOrDefault("spring.cloud.function.enable", "true")));
if (delegateToSpringCloudFunction) {
return getSpringNativeHandler(config);
} else {
return new SpringProxyHandlerBuilder<AwsProxyRequest>().defaultProxy()
.initializationWrapper(new InitializationWrapper()).configurationClasses(config)
.buildAndInitialize();
}
}

/**
* Creates a default SpringLambdaContainerHandler initialized with the `AwsProxyRequest` and `AwsProxyResponse` objects and sets the given profiles as active
Expand Down Expand Up @@ -188,4 +196,11 @@ protected void registerServlets() {
reg.addMapping("/");
reg.setLoadOnStartup(1);
}

private static SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> getSpringNativeHandler(Class<?>... config) throws ContainerInitializationException {
SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler = new SpringProxyHandlerBuilder<AwsProxyRequest>()
.defaultProxy().configurationClasses(config).buildSpringProxy();

return handler;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
package com.amazonaws.serverless.proxy.spring;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.internal.servlet.AwsHttpServletResponse;
import com.amazonaws.serverless.proxy.internal.servlet.ServletLambdaContainerHandlerBuilder;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.services.lambda.runtime.Context;

import org.springframework.cloud.function.serverless.web.ProxyMvc;
import org.springframework.web.context.ConfigurableWebApplicationContext;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.function.BiFunction;


public class SpringProxyHandlerBuilder<RequestType> extends ServletLambdaContainerHandlerBuilder<
RequestType,
Expand Down Expand Up @@ -73,13 +80,55 @@ public SpringLambdaContainerHandler<RequestType, AwsProxyResponse> build() throw
return handler;
}

/**
* Builds an instance of SpringLambdaContainerHandler with "delegate" to Spring provided ProxyMvc. The delegate
* is provided via BiFunction which takes HttpServletRequest and HttpSerbletResponse as input parameters.
* The AWS context is set as attribute of HttpServletRequest under `AWS_CONTEXT` key.
*
* @return instance of SpringLambdaContainerHandler
*/
SpringLambdaContainerHandler<RequestType, AwsProxyResponse> buildSpringProxy() {
ProxyMvc mvc = ProxyMvc.INSTANCE(this.configurationClasses);
BiFunction<HttpServletRequest, HttpServletResponse, Void> handlerDelegate = new BiFunction<HttpServletRequest, HttpServletResponse, Void>() {
@Override
public Void apply(HttpServletRequest request, HttpServletResponse response) {
try {
mvc.service(request, response);
response.flushBuffer();
}
catch (Exception e) {
throw new IllegalStateException(e);
}
return null;
}
};
SpringLambdaContainerHandler<RequestType, AwsProxyResponse> handler = createHandler(mvc.getApplicationContext(),
handlerDelegate);
handler.setServletContext(mvc.getServletContext());
return handler;
}

protected SpringLambdaContainerHandler<RequestType, AwsProxyResponse> createHandler(ConfigurableWebApplicationContext ctx) {
return new SpringLambdaContainerHandler<>(
requestTypeClass, responseTypeClass, requestReader, responseWriter,
securityContextWriter, exceptionHandler, ctx, initializationWrapper
);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
protected SpringLambdaContainerHandler<RequestType, AwsProxyResponse> createHandler(ConfigurableWebApplicationContext ctx,
BiFunction<HttpServletRequest, HttpServletResponse, Void> handler) {
return new SpringLambdaContainerHandler(requestTypeClass, responseTypeClass, requestReader, responseWriter,
securityContextWriter, exceptionHandler, ctx, initializationWrapper) {
@Override
protected void handleRequest(HttpServletRequest containerRequest, AwsHttpServletResponse containerResponse,
Context lambdaContext) throws Exception {
containerRequest.setAttribute("AWS_CONTEXT", lambdaContext);
handler.apply(containerRequest, containerResponse);
}
};
}

@Override
public SpringLambdaContainerHandler<RequestType, AwsProxyResponse> buildAndInitialize() throws ContainerInitializationException {
SpringLambdaContainerHandler<RequestType, AwsProxyResponse> handler = build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.internal.LambdaContainerHandler;
import com.amazonaws.serverless.proxy.internal.servlet.AwsLambdaServletContainerHandler;
import com.amazonaws.serverless.proxy.internal.servlet.AwsServletRegistration;
import com.amazonaws.serverless.proxy.model.*;
import com.amazonaws.serverless.proxy.internal.servlet.AwsServletContext;
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
Expand All @@ -21,13 +20,19 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.web.servlet.DispatcherServlet;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.FilterRegistration;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;

import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.DispatcherServlet;

import java.io.IOException;
import java.lang.reflect.Field;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
Expand Down Expand Up @@ -57,8 +62,8 @@ public class SpringAwsProxyTest {
registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/echo/*");
// servlet name mappings are disabled and will throw an exception

//handler.getApplicationInitializer().getDispatcherServlet().setThrowExceptionIfNoHandlerFound(true);
((DispatcherServlet)((AwsServletRegistration)c.getServletRegistration("dispatcherServlet")).getServlet()).setThrowExceptionIfNoHandlerFound(true);
DispatcherServlet dServlet = extractDispatcherServletFromContext(c);
dServlet.setThrowExceptionIfNoHandlerFound(true);
});

private String type;
Expand Down Expand Up @@ -503,5 +508,17 @@ private void validateSingleValueModel(AwsProxyResponse output, String value) {
fail("Exception while parsing response body: " + e.getMessage());
}
}

private static DispatcherServlet extractDispatcherServletFromContext(ServletContext servletContext) {
ServletRegistration servletRegistration = servletContext.getServletRegistration("dispatcherServlet");
Field field = ReflectionUtils.findField(servletRegistration.getClass(), "servlet");
field.setAccessible(true);
try {
return (DispatcherServlet) field.get(servletRegistration);
}
catch (IllegalArgumentException | IllegalAccessException e) {
throw new IllegalStateException(e);
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void profile_defaultProfile() throws Exception {

@Test
void profile_overrideProfile() throws Exception {
System.setProperty("spring.cloud.function.enable", "false");
AwsProxyRequest request = new AwsProxyRequestBuilder("/profile/spring-properties", "GET")
.build();
SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler = SpringLambdaContainerHandler.getAwsProxyHandler(EchoSpringAppConfig.class);
Expand All @@ -67,5 +68,6 @@ void profile_overrideProfile() throws Exception {
assertEquals("override-profile", response.getValues().get("profileTest"));
assertEquals("not-overridden", response.getValues().get("noOverride"));
assertEquals("override-profile-from-bean", response.getValues().get("beanInjectedValue"));
System.setProperty("spring.cloud.function.enable", "true");
}
}