Skip to content

Commit

Permalink
Moved execution of servlet to the filter chain to allow filters to ha…
Browse files Browse the repository at this point in the history
…lt execution. The servlet is injected in the chain by the FilterChainManager. This should fix the last issues with #65 and fix #66
  • Loading branch information
sapessi committed Oct 12, 2017
1 parent 0433b6b commit c677a20
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.FilterChain;
import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
Expand Down Expand Up @@ -168,6 +169,10 @@ protected void setServletContext(final ServletContext context) {
filterChainManager = new AwsFilterChainManager((AwsServletContext)context);
}

protected FilterChain getFilterChain(ContainerRequestType req, Servlet servlet) {
return filterChainManager.getFilterChain(req, servlet);
}


//-------------------------------------------------------------
// Methods - Protected
Expand All @@ -182,12 +187,11 @@ protected void setServletContext(final ServletContext context) {
* @throws ServletException
*/
protected void doFilter(ContainerRequestType request, ContainerResponseType response, Servlet servlet) throws IOException, ServletException {
FilterChainHolder chain = filterChainManager.getFilterChain(request, servlet);
FilterChain chain = getFilterChain(request, servlet);
log.debug("FilterChainHolder.doFilter {}", chain);
chain.doFilter(request, response);
}


//-------------------------------------------------------------
// Inner Class -
//-------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
* during a request lifecycle
*/
public class FilterChainHolder implements FilterChain {
private final Servlet servlet;

//-------------------------------------------------------------
// Variables - Private
Expand All @@ -46,22 +45,19 @@ public class FilterChainHolder implements FilterChain {

/**
* Creates a new empty <code>FilterChainHolder</code>
* @param servlet
*/
FilterChainHolder(Servlet servlet) {
this(new ArrayList<>(), servlet);
FilterChainHolder() {
this(new ArrayList<>());
}


/**
* Creates a new instance of a filter chain holder
* @param allFilters A populated list of <code>FilterHolder</code> objects
* @param servlet
*/
FilterChainHolder(List<FilterHolder> allFilters, Servlet servlet) {
FilterChainHolder(List<FilterHolder> allFilters) {
filters = allFilters;
resetHolder();
this.servlet = servlet;
}


Expand All @@ -72,34 +68,27 @@ public class FilterChainHolder implements FilterChain {
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) throws IOException, ServletException {
currentFilter++;
if (filters == null || filters.size() == 0 ) {
log.debug("Could not find filters to execute, returning");
return;
} else if (currentFilter > filters.size() - 1) {
if (null != servlet) {
log.debug("Starting servlet {}", servlet);
servlet.service(servletRequest, servletResponse);
log.debug("Executed servlet {}", servlet);
return;
} else {
log.debug("No more filters");
return;
}
}
// TODO: We do not check for async filters here

FilterHolder holder = filters.get(currentFilter);
// if we still have filters, keep running through the chain
if (currentFilter <= filters.size() - 1) {
FilterHolder holder = filters.get(currentFilter);

// lazily initialize filters when they are needed
if (!holder.isFilterInitialized()) {
holder.init();
}
log.debug("Starting {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
currentFilter, holder.getFilterName(), holder.getFilter());
holder.getFilter().doFilter(servletRequest, servletResponse, this);
log.debug("Executed {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
currentFilter, holder.getFilterName(), holder.getFilter());
}

// lazily initialize filters when they are needed
if (!holder.isFilterInitialized()) {
holder.init();
// if for some reason the response wasn't flushed yet, we force it here.
if (!servletResponse.isCommitted()) {
servletResponse.flushBuffer();
}
log.debug("Starting {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
currentFilter, holder.getFilterName(), holder.getFilter());
holder.getFilter().doFilter(servletRequest, servletResponse, this);
log.debug("Executed {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(),
currentFilter, holder.getFilterName(), holder.getFilter());
currentFilter--;
}


Expand Down Expand Up @@ -162,6 +151,6 @@ private void resetHolder() {

@Override
public String toString() {
return "filters=" + filters + ", servlet=" + servlet;
return "filters=" + filters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
package com.amazonaws.serverless.proxy.internal.servlet;

import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -94,10 +101,13 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
return getFilterChainCache(type, targetPath, servlet);
}

FilterChainHolder chainHolder = new FilterChainHolder(servlet);
FilterChainHolder chainHolder = new FilterChainHolder();

Map<String, FilterHolder> registrations = getFilterHolders();
if (registrations == null || registrations.size() == 0) {
if (servlet != null) {
chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servlet), servletContext));
}
return chainHolder;
}
for (String name : registrations.keySet()) {
Expand All @@ -117,6 +127,10 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
// we assume we only ever have one servlet.
}

if (servlet != null) {
chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servlet), servletContext));
}

putFilterChainCache(type, targetPath, chainHolder);
// update total filter size
if (filtersSize != registrations.size()) {
Expand Down Expand Up @@ -148,7 +162,7 @@ private FilterChainHolder getFilterChainCache(final DispatcherType type, final S
return null;
}

return new FilterChainHolder(filterCache.get(key), servlet);
return new FilterChainHolder(filterCache.get(key));
}


Expand Down Expand Up @@ -303,4 +317,34 @@ void setDispatcherType(DispatcherType dispatcherType) {
this.dispatcherType = dispatcherType;
}
}

private class ServletExecutionFilter implements Filter {

private FilterConfig config;
private Servlet handlerServlet;

public ServletExecutionFilter(Servlet handler) {
handlerServlet = handler;
}

@Override
public void init(FilterConfig filterConfig)
throws ServletException {
config = filterConfig;
}


@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
handlerServlet.service(servletRequest, servletResponse);
filterChain.doFilter(servletRequest, servletResponse);
}


@Override
public void destroy() {

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@
import spark.embeddedserver.EmbeddedServerFactory;
import spark.embeddedserver.EmbeddedServers;

import javax.servlet.DispatcherType;
import javax.servlet.FilterRegistration;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.EnumSet;
import java.util.concurrent.CountDownLatch;

/**
Expand Down Expand Up @@ -162,10 +166,12 @@ protected void handleRequest(AwsProxyHttpServletRequest httpServletRequest, AwsH
if (startupHandler != null) {
startupHandler.onStartup(getServletContext());
}

// manually add the spark filter to the chain. This should the last one and match all uris
FilterRegistration.Dynamic sparkRegistration = getServletContext().addFilter("SparkFilter", embeddedServer.getSparkFilter());
sparkRegistration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/*");
}

doFilter(httpServletRequest, httpServletResponse, null);

embeddedServer.handle(httpServletRequest, httpServletResponse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import spark.ssl.SslStores;
import spark.staticfiles.StaticFilesConfiguration;

import javax.servlet.Filter;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand Down Expand Up @@ -95,4 +96,13 @@ public void handle(HttpServletRequest request, HttpServletResponse response)
throws IOException, ServletException {
sparkFilter.doFilter(request, response, null);
}


/**
* Returns the initialized instance of the main Spark filter.
* @return The spark filter instance.
*/
public Filter getSparkFilter() {
return sparkFilter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
import com.amazonaws.serverless.proxy.spark.filter.CustomHeaderFilter;
import com.amazonaws.serverless.proxy.spark.filter.UnauthenticatedFilter;

import org.junit.AfterClass;
import org.junit.BeforeClass;
Expand Down Expand Up @@ -60,6 +61,50 @@ public void filters_onStartupMethod_executeFilters() {

}

@Test
public void filters_unauthenticatedFilter_stopRequestProcessing() {

SparkLambdaContainerHandler<AwsProxyRequest, AwsProxyResponse> handler = null;
try {
handler = SparkLambdaContainerHandler.getAwsProxyHandler();
} catch (ContainerInitializationException e) {
e.printStackTrace();
fail();
}

handler.onStartup(c -> {
if (c == null) {
System.out.println("Null servlet context");
fail();
}
FilterRegistration.Dynamic registration = c.addFilter("UnauthenticatedFilter", UnauthenticatedFilter.class);
// update the registration to map to a path
registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/unauth");
// servlet name mappings are disabled and will throw an exception
});

configureRoutes();

// first we test without the custom header, we expect request processing to complete
// successfully
AwsProxyRequest req = new AwsProxyRequestBuilder().method("GET").path("/unauth").build();
AwsProxyResponse response = handler.proxy(req, new MockLambdaContext());

assertNotNull(response);
assertEquals(200, response.getStatusCode());
assertEquals(RESPONSE_BODY_TEXT, response.getBody());

// now we test with the custom header, this should stop request processing in the
// filter and return an unauthenticated response
AwsProxyRequest unauthReq = new AwsProxyRequestBuilder().method("GET").path("/unauth")
.header(UnauthenticatedFilter.HEADER_NAME, "1").build();
AwsProxyResponse unauthResp = handler.proxy(unauthReq, new MockLambdaContext());

assertNotNull(unauthResp);
assertEquals(UnauthenticatedFilter.RESPONSE_STATUS, unauthResp.getStatusCode());
assertEquals("", unauthResp.getBody());
}

@AfterClass
public static void stopSpark() {
Spark.stop();
Expand All @@ -70,5 +115,10 @@ private static void configureRoutes() {
res.status(200);
return RESPONSE_BODY_TEXT;
});

get("/unauth", (req, res) -> {
res.status(200);
return RESPONSE_BODY_TEXT;
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.amazonaws.serverless.proxy.spark.filter;


import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import java.io.IOException;


public class UnauthenticatedFilter implements Filter {
public static final String HEADER_NAME = "X-Unauthenticated-Response";
public static final int RESPONSE_STATUS = 401;

@Override
public void init(FilterConfig filterConfig)
throws ServletException {

}


@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
System.out.println("Running unauth filter");
if (((HttpServletRequest)servletRequest).getHeader(HEADER_NAME) != null) {
((HttpServletResponse) servletResponse).setStatus(401);
System.out.println("Returning 401");
return;
}
System.out.println("Continue chain");
filterChain.doFilter(servletRequest, servletResponse);
}


@Override
public void destroy() {

}
}
Loading

0 comments on commit c677a20

Please sign in to comment.