Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Asynchronous Dispatch Logic in AwsAsyncContext with Spring's DispatcherServlet #631

Merged
merged 18 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class AwsAsyncContext implements AsyncContext {
private long timeout;
private AtomicBoolean dispatched;
private AtomicBoolean completed;
private AtomicBoolean dispatchStarted;

private Logger log = LoggerFactory.getLogger(AwsAsyncContext.class);

Expand All @@ -49,6 +50,7 @@ public AwsAsyncContext(HttpServletRequest request, HttpServletResponse response,
timeout = 3000;
dispatched = new AtomicBoolean(false);
completed = new AtomicBoolean(false);
dispatchStarted = new AtomicBoolean(false);
}

@Override
Expand All @@ -68,16 +70,15 @@ public boolean hasOriginalRequestAndResponse() {

@Override
public void dispatch() {
try {
log.debug("Dispatching request");
if (dispatched.get()) {
throw new IllegalStateException("Dispatching already started");
}
log.debug("Dispatching request");
if (dispatched.get()) {
throw new IllegalStateException("Dispatching already started");
}
if (!dispatchStarted.get()) {
dispatchStarted.set(true);
} else {
dispatched.set(true);
handler.doFilter(req, res, ((AwsServletContext)req.getServletContext()).getServletForPath(req.getRequestURI()));
notifyListeners(NotificationType.START_ASYNC, null);
} catch (ServletException | IOException e) {
notifyListeners(NotificationType.ERROR, e);
}
}
leekib marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -154,6 +155,10 @@ public boolean isCompleted() {
return completed.get();
}

public boolean isDispatchStarted() {
return dispatchStarted.get();
}

private void notifyListeners(NotificationType type, Throwable t) {
listeners.forEach((h) -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.services.lambda.runtime.Context;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import jakarta.servlet.AsyncContext;
Expand All @@ -32,6 +33,20 @@ public class AwsAsyncContextTest {
private AwsServletContextTest.TestServlet srv2 = new AwsServletContextTest.TestServlet("srv2");
private AwsServletContext ctx = getCtx();


@Test
void dispatch_amendsPath() throws InvalidRequestEventException {
AwsProxyHttpServletRequest req = (AwsProxyHttpServletRequest)reader.readRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), null, lambdaCtx, LambdaContainerHandler.getContainerConfig());
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
req.setServletContext(ctx);
req.setContainerHandler(handler);

AsyncContext asyncCtx = req.startAsync();
asyncCtx.dispatch("/srv4/hello");
assertEquals("/srv1/hello", req.getRequestURI());
}

@Disabled("AwsAsyncContext does not sends to servlet anymore")
@Test
void dispatch_sendsToCorrectServlet() {
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), lambdaCtx, null);
Expand All @@ -58,6 +73,7 @@ void dispatch_sendsToCorrectServlet() {
assertEquals(202, handler.getResponse().getStatus());
}

@Disabled("AwsAsyncContext does not sends to servlet anymore")
leekib marked this conversation as resolved.
Show resolved Hide resolved
@Test
void dispatchNewPath_sendsToCorrectServlet() throws InvalidRequestEventException {
AwsProxyHttpServletRequest req = (AwsProxyHttpServletRequest)reader.readRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), null, lambdaCtx, LambdaContainerHandler.getContainerConfig());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.amazonaws.serverless.proxy.internal.servlet.*;
import com.amazonaws.serverless.proxy.model.HttpApiV2ProxyRequest;
import com.amazonaws.services.lambda.runtime.Context;
import jakarta.servlet.AsyncContext;
import org.springframework.web.context.ConfigurableWebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;

Expand Down Expand Up @@ -160,9 +161,21 @@ protected void handleRequest(HttpServletRequest containerRequest, AwsHttpServlet
// process filters
Servlet reqServlet = ((AwsServletContext)getServletContext()).getServletForPath(containerRequest.getPathInfo());
doFilter(containerRequest, containerResponse, reqServlet);
if(requiresAsyncReDispatch(containerRequest)) {
reqServlet = ((AwsServletContext)getServletContext()).getServletForPath(containerRequest.getPathInfo());
doFilter(containerRequest, containerResponse, reqServlet);
}
Timer.stop("SPRING_HANDLE_REQUEST");
}

private boolean requiresAsyncReDispatch(HttpServletRequest request) {
if (request.isAsyncStarted()) {
AsyncContext asyncContext = request.getAsyncContext();
return asyncContext instanceof AwsAsyncContext
&& ((AwsAsyncContext) asyncContext).isDispatchStarted();
}
return false;
}

@Override
public void initialize()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.amazonaws.serverless.proxy.spring;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.spring.springapp.LambdaHandler;
import com.amazonaws.serverless.proxy.spring.springapp.MessageController;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

public class AsyncAppTest {

private static LambdaHandler handler;

@BeforeAll
public static void setUp() {
try {
handler = new LambdaHandler();
} catch (ContainerInitializationException e) {
e.printStackTrace();
fail();
}
}

@Test
void springApp_helloRequest_returnsCorrect() {
AwsProxyRequest req = new AwsProxyRequestBuilder("/hello", "GET").build();
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
assertEquals(200, resp.getStatusCode());
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
}

@Test
void springApp_asyncRequest_returnsCorrect() {
AwsProxyRequest req = new AwsProxyRequestBuilder("/async", "GET").build();
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
assertEquals(200, resp.getStatusCode());
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;

@Configuration
@Import({MessageController.class})
public class AppConfig { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.spring.SpringLambdaContainerHandler;
import com.amazonaws.serverless.proxy.spring.SpringProxyHandlerBuilder;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;

public class LambdaHandler implements RequestHandler<AwsProxyRequest, AwsProxyResponse> {
private SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler;

public LambdaHandler() throws ContainerInitializationException {
handler = new SpringProxyHandlerBuilder<AwsProxyRequest>()
.defaultProxy()
.asyncInit()
.configurationClasses(AppConfig.class)
.buildAndInitialize();
}

@Override
public AwsProxyResponse handleRequest(AwsProxyRequest awsProxyRequest, Context context) {
return handler.proxy(awsProxyRequest, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;

@RestController
@EnableWebMvc
public class MessageController {
public static final String HELLO_MESSAGE = "Hello";

@RequestMapping(path="/hello", method= RequestMethod.GET)
public String hello() {
return HELLO_MESSAGE;
}

@RequestMapping(path="/async", method= RequestMethod.GET)
public DeferredResult<String> asyncHello() {
DeferredResult<String> result = new DeferredResult<>();
result.setResult(HELLO_MESSAGE);
return result;
}
}
40 changes: 40 additions & 0 deletions aws-serverless-java-container-springboot3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,46 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
<version>3.1.3</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-logging</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-tomcat</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat.embed</groupId>
<artifactId>tomcat-embed-core</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat.embed</groupId>
<artifactId>tomcat-embed-websocket</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>2.2.222</version>
<scope>test</scope>
</dependency>


</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.amazonaws.serverless.proxy.spring.embedded.ServerlessReactiveServletEmbeddedServerFactory;
import com.amazonaws.serverless.proxy.spring.embedded.ServerlessServletEmbeddedServerFactory;
import com.amazonaws.services.lambda.runtime.Context;
import jakarta.servlet.AsyncContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.WebApplicationType;
Expand Down Expand Up @@ -172,9 +173,22 @@ protected void handleRequest(HttpServletRequest containerRequest, AwsHttpServlet
((AwsHttpServletRequest)containerRequest).setResponse(containerResponse);
}
doFilter(containerRequest, containerResponse, reqServlet);
if(requiresAsyncReDispatch(containerRequest)) {
reqServlet = ((AwsServletContext)getServletContext()).getServletForPath(containerRequest.getPathInfo());
doFilter(containerRequest, containerResponse, reqServlet);
}
Timer.stop("SPRINGBOOT2_HANDLE_REQUEST");
}

private boolean requiresAsyncReDispatch(HttpServletRequest request) {
if (request.isAsyncStarted()) {
AsyncContext asyncContext = request.getAsyncContext();
return asyncContext instanceof AwsAsyncContext
&& ((AwsAsyncContext) asyncContext).isDispatchStarted();
}
return false;
}


@Override
public void initialize()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.amazonaws.serverless.proxy.spring;

import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.spring.jpaapp.LambdaHandler;
import com.amazonaws.serverless.proxy.spring.jpaapp.MessageController;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.Arrays;
import java.util.Collection;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class JpaAppTest {

LambdaHandler handler;
MockLambdaContext lambdaContext = new MockLambdaContext();

private String type;

public static Collection<Object> data() {
return Arrays.asList(new Object[]{"API_GW", "ALB", "HTTP_API"});
}

public void initJpaAppTest(String reqType) {
type = reqType;
handler = new LambdaHandler(type);
}

@MethodSource("data")
@ParameterizedTest
void asyncRequest(String reqType) {
initJpaAppTest(reqType);
AwsProxyRequestBuilder req = new AwsProxyRequestBuilder("/async", "POST")
.json()
.body("{\"name\":\"kong\"}");
AwsProxyResponse resp = handler.handleRequest(req, lambdaContext);
assertEquals("{\"name\":\"KONG\"}", resp.getBody());
}

@MethodSource("data")
@ParameterizedTest
void helloRequest_respondsWithSingleMessage(String reqType) {
initJpaAppTest(reqType);
AwsProxyRequestBuilder req = new AwsProxyRequestBuilder("/hello", "GET");
AwsProxyResponse resp = handler.handleRequest(req, lambdaContext);
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.amazonaws.serverless.proxy.spring.jpaapp;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import javax.sql.DataSource;

@Configuration
public class DatabaseConfig {

@Bean
public DataSource dataSource() {
DriverManagerDataSource dataSource = new DriverManagerDataSource();
dataSource.setDriverClassName("org.h2.Driver");
dataSource.setUrl("jdbc:h2:mem:testdb");
dataSource.setUsername("sa");
dataSource.setPassword("");

return dataSource;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.amazonaws.serverless.proxy.spring.jpaapp;

import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.logging.LogLevel;
import org.springframework.boot.logging.LoggingSystem;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;

@SpringBootApplication(exclude = {
org.springframework.boot.autoconfigure.security.reactive.ReactiveUserDetailsServiceAutoConfiguration.class,
org.springframework.boot.autoconfigure.security.reactive.ReactiveSecurityAutoConfiguration.class,
org.springframework.boot.autoconfigure.security.servlet.UserDetailsServiceAutoConfiguration.class,
org.springframework.boot.autoconfigure.security.servlet.SecurityAutoConfiguration.class
})
@Import(MessageController.class)
public class JpaApplication {}
Loading