Skip to content

Commit

Permalink
Merge pull request #631 from 2012160085/main
Browse files Browse the repository at this point in the history
Fix Asynchronous Dispatch Logic in AwsAsyncContext with Spring's DispatcherServlet
  • Loading branch information
deki authored Jan 30, 2024
2 parents 7ca2f07 + 1fa314b commit 69883f6
Show file tree
Hide file tree
Showing 19 changed files with 382 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@
public class AwsAsyncContext implements AsyncContext {
private HttpServletRequest req;
private HttpServletResponse res;
private AwsLambdaServletContainerHandler handler;
private List<AsyncListenerHolder> listeners;
private long timeout;
private AtomicBoolean dispatched;
private AtomicBoolean completed;
private AtomicBoolean dispatchStarted;

private Logger log = LoggerFactory.getLogger(AwsAsyncContext.class);

public AwsAsyncContext(HttpServletRequest request, HttpServletResponse response, AwsLambdaServletContainerHandler servletHandler) {
public AwsAsyncContext(HttpServletRequest request, HttpServletResponse response) {
log.debug("Initializing async context for request: " + SecurityUtils.crlf(request.getPathInfo()) + " - " + SecurityUtils.crlf(request.getMethod()));
req = request;
res = response;
handler = servletHandler;
listeners = new ArrayList<>();
timeout = 3000;
dispatched = new AtomicBoolean(false);
completed = new AtomicBoolean(false);
dispatchStarted = new AtomicBoolean(false);
}

@Override
Expand All @@ -68,16 +68,14 @@ public boolean hasOriginalRequestAndResponse() {

@Override
public void dispatch() {
try {
log.debug("Dispatching request");
if (dispatched.get()) {
throw new IllegalStateException("Dispatching already started");
}
log.debug("Dispatching request");

if (dispatched.get()) {
throw new IllegalStateException("Dispatching already started");
}
if (dispatchStarted.getAndSet(true)) {
dispatched.set(true);
handler.doFilter(req, res, ((AwsServletContext)req.getServletContext()).getServletForPath(req.getRequestURI()));
notifyListeners(NotificationType.START_ASYNC, null);
} catch (ServletException | IOException e) {
notifyListeners(NotificationType.ERROR, e);
}
}

Expand Down Expand Up @@ -154,6 +152,10 @@ public boolean isCompleted() {
return completed.get();
}

public boolean isDispatchStarted() {
return dispatchStarted.get();
}

private void notifyListeners(NotificationType type, Throwable t) {
listeners.forEach((h) -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,15 @@ public boolean isAsyncStarted() {

@Override
public AsyncContext startAsync() throws IllegalStateException {
asyncContext = new AwsAsyncContext(this, response, containerHandler);
asyncContext = new AwsAsyncContext(this, response);
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, containerHandler);
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,25 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response

FilterChain chain = getFilterChain(request, servlet);
chain.doFilter(request, response);

if(requiresAsyncReDispatch(request)) {
chain = getFilterChain(request, servlet);
chain.doFilter(request, response);
}
// if for some reason the response wasn't flushed yet, we force it here unless it's being processed asynchronously (WebFlux)
if (!response.isCommitted() && request.getDispatcherType() != DispatcherType.ASYNC) {
response.flushBuffer();
}
}

private boolean requiresAsyncReDispatch(HttpServletRequest request) {
if (request.isAsyncStarted()) {
AsyncContext asyncContext = request.getAsyncContext();
return asyncContext instanceof AwsAsyncContext
&& ((AwsAsyncContext) asyncContext).isDispatchStarted();
}
return false;
}

@Override
public void initialize() throws ContainerInitializationException {
// we expect all servlets to be wrapped in an AwsServletRegistration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ public boolean isAsyncStarted() {
@Override
public AsyncContext startAsync()
throws IllegalStateException {
asyncContext = new AwsAsyncContext(this, response, containerHandler);
asyncContext = new AwsAsyncContext(this, response);
setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
Expand All @@ -506,7 +506,7 @@ public AsyncContext startAsync()
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
throws IllegalStateException {
servletRequest.setAttribute(DISPATCHER_TYPE_ATTRIBUTE, DispatcherType.ASYNC);
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, containerHandler);
asyncContext = new AwsAsyncContext((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
log.debug("Starting async context for request: " + SecurityUtils.crlf(request.getRequestContext().getRequestId()));
return asyncContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.services.lambda.runtime.Context;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import jakarta.servlet.AsyncContext;
Expand All @@ -32,48 +33,20 @@ public class AwsAsyncContextTest {
private AwsServletContextTest.TestServlet srv2 = new AwsServletContextTest.TestServlet("srv2");
private AwsServletContext ctx = getCtx();

@Test
void dispatch_sendsToCorrectServlet() {
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), lambdaCtx, null);
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
req.setServletContext(ctx);
req.setContainerHandler(handler);

AsyncContext asyncCtx = req.startAsync();
handler.setDesiredStatus(201);
asyncCtx.dispatch();
assertNotNull(handler.getSelectedServlet());
assertEquals(srv1, handler.getSelectedServlet());
assertEquals(201, handler.getResponse().getStatus());

req = new AwsProxyHttpServletRequest(new AwsProxyRequestBuilder("/srv5/hello", "GET").build(), lambdaCtx, null);
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
req.setServletContext(ctx);
req.setContainerHandler(handler);
asyncCtx = req.startAsync();
handler.setDesiredStatus(202);
asyncCtx.dispatch();
assertNotNull(handler.getSelectedServlet());
assertEquals(srv2, handler.getSelectedServlet());
assertEquals(202, handler.getResponse().getStatus());
}

@Test
void dispatchNewPath_sendsToCorrectServlet() throws InvalidRequestEventException {
void dispatch_amendsPath() throws InvalidRequestEventException {
AwsProxyHttpServletRequest req = (AwsProxyHttpServletRequest)reader.readRequest(new AwsProxyRequestBuilder("/srv1/hello", "GET").build(), null, lambdaCtx, LambdaContainerHandler.getContainerConfig());
req.setResponse(handler.getContainerResponse(req, new CountDownLatch(1)));
req.setServletContext(ctx);
req.setContainerHandler(handler);

AsyncContext asyncCtx = req.startAsync();
handler.setDesiredStatus(301);
asyncCtx.dispatch("/srv4/hello");
assertNotNull(handler.getSelectedServlet());
assertEquals(srv2, handler.getSelectedServlet());
assertNotNull(handler.getResponse());
assertEquals(301, handler.getResponse().getStatus());
assertEquals("/srv1/hello", req.getRequestURI());
}


private AwsServletContext getCtx() {
AwsServletContext ctx = new AwsServletContext(handler);
handler.setServletContext(ctx);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.amazonaws.serverless.proxy.spring;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.spring.springapp.LambdaHandler;
import com.amazonaws.serverless.proxy.spring.springapp.MessageController;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

public class AsyncAppTest {

private static LambdaHandler handler;

@BeforeAll
public static void setUp() {
try {
handler = new LambdaHandler();
} catch (ContainerInitializationException e) {
e.printStackTrace();
fail();
}
}

@Test
void springApp_helloRequest_returnsCorrect() {
AwsProxyRequest req = new AwsProxyRequestBuilder("/hello", "GET").build();
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
assertEquals(200, resp.getStatusCode());
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
}

@Test
void springApp_asyncRequest_returnsCorrect() {
AwsProxyRequest req = new AwsProxyRequestBuilder("/async", "GET").build();
AwsProxyResponse resp = handler.handleRequest(req, new MockLambdaContext());
assertEquals(200, resp.getStatusCode());
assertEquals(MessageController.HELLO_MESSAGE, resp.getBody());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;

@Configuration
@Import({MessageController.class})
public class AppConfig { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.spring.SpringLambdaContainerHandler;
import com.amazonaws.serverless.proxy.spring.SpringProxyHandlerBuilder;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;

public class LambdaHandler implements RequestHandler<AwsProxyRequest, AwsProxyResponse> {
private SpringLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler;

public LambdaHandler() throws ContainerInitializationException {
handler = new SpringProxyHandlerBuilder<AwsProxyRequest>()
.defaultProxy()
.asyncInit()
.configurationClasses(AppConfig.class)
.buildAndInitialize();
}

@Override
public AwsProxyResponse handleRequest(AwsProxyRequest awsProxyRequest, Context context) {
return handler.proxy(awsProxyRequest, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.amazonaws.serverless.proxy.spring.springapp;

import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;

@RestController
@EnableWebMvc
public class MessageController {
public static final String HELLO_MESSAGE = "Hello";

@RequestMapping(path="/hello", method= RequestMethod.GET)
public String hello() {
return HELLO_MESSAGE;
}

@RequestMapping(path="/async", method= RequestMethod.GET)
public DeferredResult<String> asyncHello() {
DeferredResult<String> result = new DeferredResult<>();
result.setResult(HELLO_MESSAGE);
return result;
}
}
48 changes: 48 additions & 0 deletions aws-serverless-java-container-springboot3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,46 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
<version>3.2.1</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-logging</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-tomcat</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat.embed</groupId>
<artifactId>tomcat-embed-core</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat.embed</groupId>
<artifactId>tomcat-embed-websocket</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>2.2.222</version>
<scope>test</scope>
</dependency>


</dependencies>

<build>
Expand Down Expand Up @@ -282,6 +322,14 @@
<failOnError>false</failOnError>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>10</source>
<target>10</target>
</configuration>
</plugin>
</plugins>
</build>
<repositories>
Expand Down
Loading

0 comments on commit 69883f6

Please sign in to comment.