diff --git a/README.md b/README.md index b21c87013..210b948ac 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,9 @@ public class LambdaHandler implements RequestHandler { - if (getAuthenticationScheme() == null) { - return null; - } - - if (getAuthenticationScheme().equals(AUTH_SCHEME_CUSTOM)) { - return event.getRequestContext().getAuthorizer().getPrincipalId(); - } else if (getAuthenticationScheme().equals(AUTH_SCHEME_AWS_IAM)) { - // if we received credentials from Cognito Federated Identities then we return the identity id - if (event.getRequestContext().getIdentity().getCognitoIdentityId() != null) { - return event.getRequestContext().getIdentity().getCognitoIdentityId(); - } else { // otherwise the user arn from the credentials - return event.getRequestContext().getIdentity().getUserArn(); + if (getAuthenticationScheme() == null) { + return () -> null; + } + + if (getAuthenticationScheme().equals(AUTH_SCHEME_CUSTOM) || getAuthenticationScheme().equals(AUTH_SCHEME_AWS_IAM)) { + return () -> { + if (getAuthenticationScheme().equals(AUTH_SCHEME_CUSTOM)) { + return event.getRequestContext().getAuthorizer().getPrincipalId(); + } else if (getAuthenticationScheme().equals(AUTH_SCHEME_AWS_IAM)) { + // if we received credentials from Cognito Federated Identities then we return the identity id + if (event.getRequestContext().getIdentity().getCognitoIdentityId() != null) { + return event.getRequestContext().getIdentity().getCognitoIdentityId(); + } else { // otherwise the user arn from the credentials + return event.getRequestContext().getIdentity().getUserArn(); + } } - } else if (getAuthenticationScheme().equals(AUTH_SCHEME_COGNITO_POOL)) { - return event.getRequestContext().getAuthorizer().getClaims().getSubject(); - } - return null; - }; + // return null if we couldn't find a valid scheme + return null; + }; + } + + if (getAuthenticationScheme().equals(AUTH_SCHEME_COGNITO_POOL)) { + return new CognitoUserPoolPrincipal(event.getRequestContext().getAuthorizer().getClaims()); + } + + throw new RuntimeException("Cannot recognize authorization scheme in event"); } @@ -105,4 +113,27 @@ public String getAuthenticationScheme() { return null; } } + + + /** + * Custom object for request authorized with a Cognito User Pool authorizer. By casting the Principal + * object to this you can extract the Claims object included in the token. + */ + public class CognitoUserPoolPrincipal implements Principal { + + private CognitoAuthorizerClaims claims; + + CognitoUserPoolPrincipal(CognitoAuthorizerClaims c) { + claims = c; + } + + @Override + public String getName() { + return claims.getSubject(); + } + + public CognitoAuthorizerClaims getClaims() { + return claims; + } + } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java index e9b141ec1..d8c594826 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.util.HashMap; import java.util.Map; /** @@ -31,7 +32,7 @@ public class AwsProxyRequest { private String resource; private ApiGatewayRequestContext requestContext; private Map queryStringParameters; - private Map headers; + private Map headers = new HashMap<>(); // avoid NPE private Map pathParameters; private String httpMethod; private Map stageVariables; @@ -105,7 +106,11 @@ public Map getHeaders() { public void setHeaders(Map headers) { - this.headers = headers; + if (null != headers) { + this.headers = headers; + } else { + this.headers.clear(); + } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/CognitoAuthorizerClaims.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/CognitoAuthorizerClaims.java index 697bce4e8..6c3702862 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/CognitoAuthorizerClaims.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/CognitoAuthorizerClaims.java @@ -13,10 +13,14 @@ package com.amazonaws.serverless.proxy.internal.model; +import com.fasterxml.jackson.annotation.JsonAnyGetter; +import com.fasterxml.jackson.annotation.JsonAnySetter; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import java.time.format.DateTimeFormatter; +import java.util.HashMap; +import java.util.Map; /** @@ -44,6 +48,8 @@ public class CognitoAuthorizerClaims { // Variables - Private //------------------------------------------------------------- + private Map claims = new HashMap<>(); + @JsonProperty(value = "sub") private String subject; @JsonProperty(value = "aud") @@ -69,6 +75,16 @@ public class CognitoAuthorizerClaims { // Methods - Getter/Setter //------------------------------------------------------------- + @JsonAnyGetter + public String getClaim(String claim) { + return claims.get(claim); + } + + @JsonAnySetter + public void setClaim(String claim, String value) { + claims.put(claim, value); + } + public String getSubject() { return this.subject; } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java index 15ad74dac..c4d85dd2d 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java @@ -12,6 +12,8 @@ */ package com.amazonaws.serverless.proxy.internal.servlet; +import com.amazonaws.serverless.proxy.internal.RequestReader; +import com.amazonaws.serverless.proxy.internal.model.ApiGatewayRequestContext; import com.amazonaws.serverless.proxy.internal.model.ContainerConfig; import com.amazonaws.services.lambda.runtime.Context; @@ -24,6 +26,7 @@ import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpSessionContext; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; import java.net.URLEncoder; @@ -68,6 +71,7 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest { private Context lambdaContext; private Map attributes; private ServletContext servletContext; + private AwsHttpSession session; protected DispatcherType dispatcherType; @@ -101,13 +105,17 @@ public String getRequestedSessionId() { @Override public HttpSession getSession(boolean b) { - return null; + if (b && null == this.session) { + ApiGatewayRequestContext requestContext = (ApiGatewayRequestContext) getAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY); + this.session = new AwsHttpSession(requestContext.getRequestId()); + } + return this.session; } @Override public HttpSession getSession() { - return null; + return this.session; } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java index 4ebfd8938..52b6b80a3 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java @@ -53,6 +53,7 @@ public class AwsHttpServletResponse private int statusCode; private String statusMessage; private String responseBody; + private PrintWriter writer; private ByteArrayOutputStream bodyOutputStream = new ByteArrayOutputStream(); private CountDownLatch writersCountDownLatch; private AwsHttpServletRequest request; @@ -316,7 +317,10 @@ public void close() @Override public PrintWriter getWriter() throws IOException { - return new PrintWriter(bodyOutputStream); + if (null == writer) { + writer = new PrintWriter(bodyOutputStream); + } + return writer; } @@ -358,7 +362,11 @@ public int getBufferSize() { @Override public void flushBuffer() throws IOException { + if (null != writer) { + writer.flush(); + } responseBody = new String(bodyOutputStream.toByteArray()); + log.debug("Response buffer flushed with {} bytes, latch={}", responseBody.length(), writersCountDownLatch.getCount()); isCommitted = true; writersCountDownLatch.countDown(); } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpSession.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpSession.java new file mode 100644 index 000000000..6aca5947c --- /dev/null +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpSession.java @@ -0,0 +1,111 @@ +package com.amazonaws.serverless.proxy.internal.servlet; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.ServletContext; +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpSessionContext; +import java.util.Enumeration; + +public class AwsHttpSession implements HttpSession { + + private static final Logger log = LoggerFactory.getLogger(AwsHttpSession.class); + private String id; + + /** + * @param id API gateway request ID. + */ + public AwsHttpSession(String id) { + if (null == id) { + throw new RuntimeException("HTTP session id (from request ID) cannot be null"); + } + log.debug("Creating session " + id); + this.id = id; + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public String getId() { + return id; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public ServletContext getServletContext() { + return null; + } + + @Override + public void setMaxInactiveInterval(int interval) { + + } + + @Override + public int getMaxInactiveInterval() { + return 0; + } + + @Override + public HttpSessionContext getSessionContext() { + return null; + } + + @Override + public Object getAttribute(String name) { + return null; + } + + @Override + public Object getValue(String name) { + return null; + } + + @Override + public Enumeration getAttributeNames() { + return null; + } + + @Override + public String[] getValueNames() { + return new String[0]; + } + + @Override + public void setAttribute(String name, Object value) { + + } + + @Override + public void putValue(String name, Object value) { + + } + + @Override + public void removeAttribute(String name) { + + } + + @Override + public void removeValue(String name) { + + } + + @Override + public void invalidate() { + + } + + @Override + public boolean isNew() { + return false; + } +} diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java index 39d7806ed..ade16b26b 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java @@ -22,6 +22,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.servlet.FilterChain; +import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -167,6 +169,10 @@ protected void setServletContext(final ServletContext context) { filterChainManager = new AwsFilterChainManager((AwsServletContext)context); } + protected FilterChain getFilterChain(ContainerRequestType req, Servlet servlet) { + return filterChainManager.getFilterChain(req, servlet); + } + //------------------------------------------------------------- // Methods - Protected @@ -176,15 +182,16 @@ protected void setServletContext(final ServletContext context) { * Applies the filter chain in the request lifecycle * @param request The Request object. This must be an implementation of HttpServletRequest * @param response The response object. This must be an implementation of HttpServletResponse + * @param servlet Servlet at the end of the chain (optional). * @throws IOException * @throws ServletException */ - protected void doFilter(ContainerRequestType request, ContainerResponseType response) throws IOException, ServletException { - FilterChainHolder chain = filterChainManager.getFilterChain(request); + protected void doFilter(ContainerRequestType request, ContainerResponseType response, Servlet servlet) throws IOException, ServletException { + FilterChain chain = getFilterChain(request, servlet); + log.debug("FilterChainHolder.doFilter {}", chain); chain.doFilter(request, response); } - //------------------------------------------------------------- // Inner Class - //------------------------------------------------------------- diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java index 5f836ef11..dc2a95504 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java @@ -196,9 +196,13 @@ public String getPathTranslated() { } + /** + * In AWS API Gateway, stage is never given as part of the path. + * @return + */ @Override public String getContextPath() { - return request.getRequestContext().getStage(); + return ""; } @@ -228,7 +232,7 @@ public Principal getUserPrincipal() { @Override public String getRequestURI() { - return request.getPath(); + return (getContextPath().isEmpty() ? "" : "/" + getContextPath()) + request.getPath(); } @@ -382,6 +386,9 @@ public String getContentType() { @Override public ServletInputStream getInputStream() throws IOException { + if (request.getBody() == null) { + return null; + } byte[] bodyBytes = request.getBody().getBytes(); if (request.isBase64Encoded()) { bodyBytes = Base64.getMimeDecoder().decode(request.getBody()); @@ -700,13 +707,8 @@ private Map> getFormUrlEncodedParametersMap() { if (!contentType.startsWith(MediaType.APPLICATION_FORM_URLENCODED) || !getMethod().toLowerCase().equals("post")) { return new HashMap<>(); } - String rawBodyContent; - try { - rawBodyContent = URLDecoder.decode(request.getBody(), DEFAULT_CHARACTER_ENCODING); - } catch (UnsupportedEncodingException e) { - log.warn("Could not decode body content - proceeding as if it was already decoded", e); - rawBodyContent = request.getBody(); - } + + String rawBodyContent = request.getBody(); Map> output = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); for (String parameter : rawBodyContent.split(FORM_DATA_SEPARATOR)) { @@ -718,10 +720,19 @@ private Map> getFormUrlEncodedParametersMap() { if (output.containsKey(parameterKeyValue[0])) { values = output.get(parameterKeyValue[0]); } - values.add(parameterKeyValue[1]); - output.put(parameterKeyValue[0], values); + values.add(decodeValueIfEncoded(parameterKeyValue[1])); + output.put(decodeValueIfEncoded(parameterKeyValue[0]), values); } return output; } + + private String decodeValueIfEncoded(String value) { + try { + return URLDecoder.decode(value, DEFAULT_CHARACTER_ENCODING); + } catch (UnsupportedEncodingException e) { + log.warn("Could not decode body content - proceeding as if it was already decoded", e); + return value; + } + } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java index fcdbcb5ea..957af0f01 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.MalformedURLException; +import java.net.URI; import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Files; @@ -141,9 +142,16 @@ public int getEffectiveMinorVersion() { @Override public String getMimeType(String s) { try { - return Files.probeContentType(Paths.get(s)); + + if (s.startsWith("file:")) { // Support paths such as file:/D:/something/hello.txt + return Files.probeContentType(Paths.get(URI.create(s))); + } else if (s.startsWith("/")) { // Support paths such as file:/D:/something/hello.txt + return Files.probeContentType(Paths.get(URI.create("file://" + s))); + } else { + return Files.probeContentType(Paths.get(s)); + } } catch (IOException e) { - log.warn("Could not find content type for filter", e); + log.warn("Could not find content type for file {}", s, e); return null; } } @@ -364,6 +372,8 @@ public FilterRegistration.Dynamic addFilter(String name, Filter filter) { // filter already exists, we do nothing if (filters.containsKey(name)) { return null; + } else { + log.debug("Adding filter '{}' from {}", name, filter); } FilterHolder newFilter = new FilterHolder(name, filter, this); @@ -376,6 +386,7 @@ public FilterRegistration.Dynamic addFilter(String name, Filter filter) { @Override public FilterRegistration.Dynamic addFilter(String name, Class filterClass) { try { + log.debug("Adding filter '{}' from {}", name, filterClass.getName()); Filter newFilter = createFilter(filterClass); return addFilter(name, newFilter); } catch (ServletException e) { diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java index 47c5ebe6e..af7ba7a70 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java @@ -15,10 +15,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; +import javax.servlet.*; +import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.util.ArrayList; @@ -70,21 +68,27 @@ public class FilterChainHolder implements FilterChain { @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) throws IOException, ServletException { currentFilter++; - if (filters == null || filters.size() == 0 || currentFilter > filters.size() - 1) { - log.debug("Could not find filters to execute, returning"); - return; - } // TODO: We do not check for async filters here - FilterHolder holder = filters.get(currentFilter); + // if we still have filters, keep running through the chain + if (currentFilter <= filters.size() - 1) { + FilterHolder holder = filters.get(currentFilter); + + // lazily initialize filters when they are needed + if (!holder.isFilterInitialized()) { + holder.init(); + } + log.debug("Starting {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(), + currentFilter, holder.getFilterName(), holder.getFilter()); + holder.getFilter().doFilter(servletRequest, servletResponse, this); + log.debug("Executed {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(), + currentFilter, holder.getFilterName(), holder.getFilter()); + } - // lazily initialize filters when they are needed - if (!holder.isFilterInitialized()) { - holder.init(); + // if for some reason the response wasn't flushed yet, we force it here. + if (!servletResponse.isCommitted()) { + servletResponse.flushBuffer(); } - log.debug("Starting filter " + holder.getFilterName()); - holder.getFilter().doFilter(servletRequest, servletResponse, this); - log.debug("Executed filter " + holder.getFilterName()); } @@ -144,4 +148,9 @@ public List getFilters() { private void resetHolder() { currentFilter = -1; } + + @Override + public String toString() { + return "filters=" + filters; + } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java index 79bbbc75d..bd3233496 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java @@ -13,9 +13,17 @@ package com.amazonaws.serverless.proxy.internal.servlet; import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.Servlet; import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -81,21 +89,25 @@ public abstract class FilterChainManagerFilterChainHolder object that can be used to apply the filters to the request */ - FilterChainHolder getFilterChain(final HttpServletRequest request) { + FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servlet) { String targetPath = request.getServletPath(); DispatcherType type = request.getDispatcherType(); // only return the cached result if the filter list hasn't changed in the meanwhile - if (getFilterHolders().size() == filtersSize && getFilterChainCache(type, targetPath) != null) { - return getFilterChainCache(type, targetPath); + if (getFilterHolders().size() == filtersSize && getFilterChainCache(type, targetPath, servlet) != null) { + return getFilterChainCache(type, targetPath, servlet); } FilterChainHolder chainHolder = new FilterChainHolder(); Map registrations = getFilterHolders(); if (registrations == null || registrations.size() == 0) { + if (servlet != null) { + chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servlet), servletContext)); + } return chainHolder; } for (String name : registrations.keySet()) { @@ -115,6 +127,10 @@ FilterChainHolder getFilterChain(final HttpServletRequest request) { // we assume we only ever have one servlet. } + if (servlet != null) { + chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servlet), servletContext)); + } + putFilterChainCache(type, targetPath, chainHolder); // update total filter size if (filtersSize != registrations.size()) { @@ -134,9 +150,10 @@ FilterChainHolder getFilterChain(final HttpServletRequest request) { * initialized with the cached list of {@link FilterHolder} objects * @param type The dispatcher type for the incoming request * @param targetPath The request path - this is extracted with the getPath method of the request object + * @param servlet Servlet to put at the end of the chain (optional). * @return A populated FilterChainHolder */ - private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath) { + private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet) { TargetCacheKey key = new TargetCacheKey(); key.setDispatcherType(type); key.setTargetPath(targetPath); @@ -300,4 +317,34 @@ void setDispatcherType(DispatcherType dispatcherType) { this.dispatcherType = dispatcherType; } } + + private class ServletExecutionFilter implements Filter { + + private FilterConfig config; + private Servlet handlerServlet; + + public ServletExecutionFilter(Servlet handler) { + handlerServlet = handler; + } + + @Override + public void init(FilterConfig filterConfig) + throws ServletException { + config = filterConfig; + } + + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + handlerServlet.service(servletRequest, servletResponse); + filterChain.doFilter(servletRequest, servletResponse); + } + + + @Override + public void destroy() { + + } + } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java index daa318285..a8446b141 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java @@ -29,7 +29,7 @@ public class FilterHolder { //------------------------------------------------------------- private Filter filter; - private FilterConfig filterConfig; + private FilterConfig filterConfig = new Config(); private Registration registration; private String filterName; private Map initParameters; diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java index bdfd90fbd..5ffd29413 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java @@ -62,10 +62,12 @@ public AwsProxyRequestBuilder(String path, String httpMethod) { this.mapper = new ObjectMapper(); this.request = new AwsProxyRequest(); + this.request.setHeaders(new HashMap<>()); // avoid NPE this.request.setHttpMethod(httpMethod); this.request.setPath(path); this.request.setQueryStringParameters(new HashMap<>()); this.request.setRequestContext(new ApiGatewayRequestContext()); + this.request.getRequestContext().setRequestId("test-invoke-request"); this.request.getRequestContext().setStage("test"); ApiGatewayRequestIdentity identity = new ApiGatewayRequestIdentity(); identity.setSourceIp("127.0.0.1"); @@ -187,6 +189,12 @@ public AwsProxyRequestBuilder cognitoUserPool(String identityId) { return this; } + public AwsProxyRequestBuilder claim(String claim, String value) { + this.request.getRequestContext().getAuthorizer().getClaims().setClaim(claim, value); + + return this; + } + public AwsProxyRequestBuilder cognitoIdentity(String identityId, String identityPoolId) { this.request.getRequestContext().getIdentity().setCognitoAuthenticationType("IDENTITY"); diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/jaxrs/AwsProxySecurityContextTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/jaxrs/AwsProxySecurityContextTest.java index c5d631588..255f2a0da 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/jaxrs/AwsProxySecurityContextTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/jaxrs/AwsProxySecurityContextTest.java @@ -4,13 +4,17 @@ import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder; import org.junit.Test; +import java.security.Principal; + import static org.junit.Assert.*; public class AwsProxySecurityContextTest { + private static final String CLAIM_KEY = "custom:claim"; + private static final String CLAIM_VALUE = "customClaimant"; private static final String COGNITO_IDENTITY_ID = "us-east-2:123123123123"; private static final AwsProxyRequest REQUEST_NO_AUTH = new AwsProxyRequestBuilder("/hello", "GET").build(); private static final AwsProxyRequest REQUEST_COGNITO_USER_POOL = new AwsProxyRequestBuilder("/hello", "GET") - .cognitoUserPool(COGNITO_IDENTITY_ID).build(); + .cognitoUserPool(COGNITO_IDENTITY_ID).claim(CLAIM_KEY, CLAIM_VALUE).build(); @Test public void localVars_constructor_nullValues() { @@ -39,4 +43,16 @@ public void authScheme_getPrincipal_userPool() { assertEquals("COGNITO_USER_POOL", context.getAuthenticationScheme()); assertEquals(COGNITO_IDENTITY_ID, context.getUserPrincipal().getName()); } + + @Test + public void userPool_getClaims_retrieveCustomClaim() { + AwsProxySecurityContext context = new AwsProxySecurityContext(null, REQUEST_COGNITO_USER_POOL); + Principal userPrincipal = context.getUserPrincipal(); + assertNotNull(userPrincipal.getName()); + assertEquals(COGNITO_IDENTITY_ID, userPrincipal.getName()); + + assertTrue(userPrincipal instanceof AwsProxySecurityContext.CognitoUserPoolPrincipal); + assertNotNull(((AwsProxySecurityContext.CognitoUserPoolPrincipal)userPrincipal).getClaims().getClaim(CLAIM_KEY)); + assertEquals(CLAIM_VALUE, ((AwsProxySecurityContext.CognitoUserPoolPrincipal)userPrincipal).getClaims().getClaim(CLAIM_KEY)); + } } diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java index 84f320b96..c5ed8f248 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java @@ -115,21 +115,21 @@ public void filterChain_getFilterChain_subsetOfFilters() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter1", fcHolder.getFilter(0).getFilterName()); req = new AwsProxyHttpServletRequest( new AwsProxyRequestBuilder("/second/mime", "GET").build(), lambdaContext, null ); - fcHolder = chainManager.getFilterChain(req); + fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); req = new AwsProxyHttpServletRequest( new AwsProxyRequestBuilder("/second/mime/third", "GET").build(), lambdaContext, null ); - fcHolder = chainManager.getFilterChain(req); + fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); } @@ -140,7 +140,7 @@ public void filterChain_matchMultipleTimes_expectSameMatch() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter1", fcHolder.getFilter(0).getFilterName()); @@ -148,7 +148,7 @@ public void filterChain_matchMultipleTimes_expectSameMatch() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2); + FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2, null); assertEquals(1, fcHolder2.filterCount()); assertEquals("Filter1", fcHolder2.getFilter(0).getFilterName()); } @@ -159,7 +159,7 @@ public void filerChain_executeMultipleFilters_expectRunEachTime() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter1", fcHolder.getFilter(0).getFilterName()); AwsHttpServletResponse resp = new AwsHttpServletResponse(req, new CountDownLatch(1)); @@ -183,7 +183,7 @@ public void filerChain_executeMultipleFilters_expectRunEachTime() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req2.setServletContext(servletContext); - FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2); + FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2, null); assertEquals(1, fcHolder2.filterCount()); assertEquals("Filter1", fcHolder2.getFilter(0).getFilterName()); assertEquals(-1, fcHolder2.currentFilter); @@ -212,14 +212,14 @@ public void filterChain_getFilterChain_multipleFilters() { req.setServletContext(servletContext); FilterRegistration.Dynamic reg = req.getServletContext().addFilter("Filter4", new MockFilter()); reg.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/second/*"); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(2, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); assertEquals("Filter4", fcHolder.getFilter(1).getFilterName()); reg = req.getServletContext().addFilter("Filter5", new MockFilter()); reg.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/second/*"); - fcHolder = chainManager.getFilterChain(req); + fcHolder = chainManager.getFilterChain(req, null); assertEquals(3, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); assertEquals("Filter4", fcHolder.getFilter(1).getFilterName()); diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestFormTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestFormTest.java index f82ab9d93..62f3fe971 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestFormTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestFormTest.java @@ -7,12 +7,14 @@ import org.apache.commons.io.IOUtils; import org.apache.http.HttpEntity; import org.apache.http.client.entity.EntityBuilder; +import org.apache.http.entity.ContentType; import org.apache.http.entity.mime.MultipartEntityBuilder; import org.junit.Test; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; import java.io.IOException; import java.util.Random; @@ -29,6 +31,8 @@ public class AwsProxyHttpServletRequestFormTest { private static final String PART_VALUE_2 = "value2"; private static final String FILE_KEY = "file_upload_1"; + private static final String ENCODED_VALUE = "test123a%3D1%262@3"; + private static final HttpEntity MULTIPART_FORM_DATA = MultipartEntityBuilder.create() .addTextBody(PART_KEY_1, PART_VALUE_1) .addTextBody(PART_KEY_2, PART_VALUE_2) @@ -43,6 +47,24 @@ public class AwsProxyHttpServletRequestFormTest { .addTextBody(PART_KEY_2, PART_VALUE_2) .addBinaryBody(FILE_KEY, FILE_BYTES) .build(); + private static final String ENCODED_FORM_ENTITY = PART_KEY_1 + "=" + ENCODED_VALUE + "&" + PART_KEY_2 + "=" + PART_VALUE_2; + + @Test + public void postForm_getParam_getEncodedFullValue() { + try { + AwsProxyRequest proxyRequest = new AwsProxyRequestBuilder("/form", "POST") + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED) + .body(ENCODED_FORM_ENTITY) + .build(); + + HttpServletRequest request = new AwsProxyHttpServletRequest(proxyRequest, null, null); + assertNotNull(request.getParts()); + assertEquals("test123a=1&2@3", request.getParameter(PART_KEY_1)); + } catch (IOException | ServletException e) { + fail(e.getMessage()); + } + } + @Test public void postForm_getParts_parsing() { try { diff --git a/aws-serverless-java-container-spark/pom.xml b/aws-serverless-java-container-spark/pom.xml index 5901f1fa3..4fb1c3e03 100644 --- a/aws-serverless-java-container-spark/pom.xml +++ b/aws-serverless-java-container-spark/pom.xml @@ -15,7 +15,7 @@ - 2.6.0 + 2.7.1 diff --git a/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java b/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java index b4a059d30..9c61f0812 100644 --- a/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java +++ b/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java @@ -29,9 +29,13 @@ import spark.embeddedserver.EmbeddedServerFactory; import spark.embeddedserver.EmbeddedServers; +import javax.servlet.DispatcherType; +import javax.servlet.FilterRegistration; + import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.EnumSet; import java.util.concurrent.CountDownLatch; /** @@ -162,10 +166,16 @@ protected void handleRequest(AwsProxyHttpServletRequest httpServletRequest, AwsH if (startupHandler != null) { startupHandler.onStartup(getServletContext()); } - } - doFilter(httpServletRequest, httpServletResponse); + // manually add the spark filter to the chain. This should the last one and match all uris + FilterRegistration.Dynamic sparkRegistration = getServletContext().addFilter("SparkFilter", embeddedServer.getSparkFilter()); + sparkRegistration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/*"); + + // adding this call to make sure that the framework is fully initialized. This should address a race + // condition and solve GitHub issue #71. + Spark.awaitInitialization(); + } - embeddedServer.handle(httpServletRequest, httpServletResponse); + doFilter(httpServletRequest, httpServletResponse, null); } } diff --git a/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/embeddedserver/LambdaEmbeddedServer.java b/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/embeddedserver/LambdaEmbeddedServer.java index e6c8541f2..6885cc52d 100644 --- a/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/embeddedserver/LambdaEmbeddedServer.java +++ b/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/embeddedserver/LambdaEmbeddedServer.java @@ -9,6 +9,7 @@ import spark.ssl.SslStores; import spark.staticfiles.StaticFilesConfiguration; +import javax.servlet.Filter; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -38,6 +39,9 @@ public class LambdaEmbeddedServer applicationRoutes = routes; staticFilesConfiguration = filesConfig; hasMultipleHandler = multipleHandlers; + + // try to initialize the filter here. + sparkFilter = new MatcherFilter(applicationRoutes, staticFilesConfiguration, true, hasMultipleHandler); } @@ -48,12 +52,13 @@ public class LambdaEmbeddedServer public int ignite(String s, int i, SslStores sslStores, int i1, int i2, int i3) throws Exception { log.info("Starting Spark server, ignoring port and host"); - sparkFilter = new MatcherFilter(applicationRoutes, staticFilesConfiguration, false, hasMultipleHandler); + // if not initialized yet + if (sparkFilter == null) { + sparkFilter = new MatcherFilter(applicationRoutes, staticFilesConfiguration, true, hasMultipleHandler); + } sparkFilter.init(null); - //countDownLatch.countDown(); - - return 0; + return i; } @@ -92,4 +97,13 @@ public void handle(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { sparkFilter.doFilter(request, response, null); } + + + /** + * Returns the initialized instance of the main Spark filter. + * @return The spark filter instance. + */ + public Filter getSparkFilter() { + return sparkFilter; + } } diff --git a/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/HelloWorldSparkTest.java b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/HelloWorldSparkTest.java index d75595e26..3975a4218 100644 --- a/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/HelloWorldSparkTest.java +++ b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/HelloWorldSparkTest.java @@ -37,7 +37,7 @@ public static void initializeServer() { handler = SparkLambdaContainerHandler.getAwsProxyHandler(); configureRoutes(); - + Spark.awaitInitialization(); } catch (RuntimeException | ContainerInitializationException e) { e.printStackTrace(); fail(); diff --git a/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/InitExceptionHandlerTest.java b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/InitExceptionHandlerTest.java index a0c976d51..a95497c26 100644 --- a/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/InitExceptionHandlerTest.java +++ b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/InitExceptionHandlerTest.java @@ -44,7 +44,7 @@ public void initException_mockException_expectHandlerToRun() { serverFactory); configureRoutes(); - + Spark.awaitInitialization(); } catch (Exception e) { e.printStackTrace(); fail("Error while mocking server"); diff --git a/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandlerTest.java b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandlerTest.java index ecf09226e..f1884537a 100644 --- a/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandlerTest.java +++ b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandlerTest.java @@ -7,8 +7,10 @@ import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder; import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext; import com.amazonaws.serverless.proxy.spark.filter.CustomHeaderFilter; +import com.amazonaws.serverless.proxy.spark.filter.UnauthenticatedFilter; import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import spark.Spark; @@ -48,6 +50,8 @@ public void filters_onStartupMethod_executeFilters() { configureRoutes(); + Spark.awaitInitialization(); + AwsProxyRequest req = new AwsProxyRequestBuilder().method("GET").path("/header-filter").build(); AwsProxyResponse response = handler.proxy(req, new MockLambdaContext()); @@ -59,15 +63,65 @@ public void filters_onStartupMethod_executeFilters() { } + @Test + public void filters_unauthenticatedFilter_stopRequestProcessing() { + + SparkLambdaContainerHandler handler = null; + try { + handler = SparkLambdaContainerHandler.getAwsProxyHandler(); + } catch (ContainerInitializationException e) { + e.printStackTrace(); + fail(); + } + + handler.onStartup(c -> { + if (c == null) { + System.out.println("Null servlet context"); + fail(); + } + FilterRegistration.Dynamic registration = c.addFilter("UnauthenticatedFilter", UnauthenticatedFilter.class); + // update the registration to map to a path + registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/unauth"); + // servlet name mappings are disabled and will throw an exception + }); + + configureRoutes(); + Spark.awaitInitialization(); + + // first we test without the custom header, we expect request processing to complete + // successfully + AwsProxyRequest req = new AwsProxyRequestBuilder().method("GET").path("/unauth").build(); + AwsProxyResponse response = handler.proxy(req, new MockLambdaContext()); + + assertNotNull(response); + assertEquals(200, response.getStatusCode()); + assertEquals(RESPONSE_BODY_TEXT, response.getBody()); + + // now we test with the custom header, this should stop request processing in the + // filter and return an unauthenticated response + AwsProxyRequest unauthReq = new AwsProxyRequestBuilder().method("GET").path("/unauth") + .header(UnauthenticatedFilter.HEADER_NAME, "1").build(); + AwsProxyResponse unauthResp = handler.proxy(unauthReq, new MockLambdaContext()); + + assertNotNull(unauthResp); + assertEquals(UnauthenticatedFilter.RESPONSE_STATUS, unauthResp.getStatusCode()); + assertEquals("", unauthResp.getBody()); + } + @AfterClass public static void stopSpark() { Spark.stop(); } - private void configureRoutes() { + private static void configureRoutes() { get("/header-filter", (req, res) -> { res.status(200); return RESPONSE_BODY_TEXT; }); + + get("/unauth", (req, res) -> { + res.status(200); + return RESPONSE_BODY_TEXT; + }); } } diff --git a/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/filter/UnauthenticatedFilter.java b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/filter/UnauthenticatedFilter.java new file mode 100644 index 000000000..60e8eb429 --- /dev/null +++ b/aws-serverless-java-container-spark/src/test/java/com/amazonaws/serverless/proxy/spark/filter/UnauthenticatedFilter.java @@ -0,0 +1,45 @@ +package com.amazonaws.serverless.proxy.spark.filter; + + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import java.io.IOException; + + +public class UnauthenticatedFilter implements Filter { + public static final String HEADER_NAME = "X-Unauthenticated-Response"; + public static final int RESPONSE_STATUS = 401; + + @Override + public void init(FilterConfig filterConfig) + throws ServletException { + + } + + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + System.out.println("Running unauth filter"); + if (((HttpServletRequest)servletRequest).getHeader(HEADER_NAME) != null) { + ((HttpServletResponse) servletResponse).setStatus(401); + System.out.println("Returning 401"); + return; + } + System.out.println("Continue chain"); + filterChain.doFilter(servletRequest, servletResponse); + } + + + @Override + public void destroy() { + + } +} diff --git a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/LambdaSpringApplicationInitializer.java b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/LambdaSpringApplicationInitializer.java index 45a9e1e0f..28e38a4a8 100644 --- a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/LambdaSpringApplicationInitializer.java +++ b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/LambdaSpringApplicationInitializer.java @@ -27,6 +27,7 @@ import org.springframework.web.servlet.DispatcherServlet; import javax.servlet.*; +import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; @@ -42,7 +43,7 @@ * `currentResponse` private property to the value of the new `HttpServletResponse` object. This is used to intercept * Spring notifications for the `ServletRequestHandledEvent` and call the flush method to release the latch. */ -public class LambdaSpringApplicationInitializer implements WebApplicationInitializer { +public class LambdaSpringApplicationInitializer extends HttpServlet implements WebApplicationInitializer { public static final String ERROR_NO_CONTEXT = "No application context or configuration classes provided"; private static final String DEFAULT_SERVLET_NAME = "aws-servless-java-container"; @@ -98,6 +99,15 @@ public void dispatch(HttpServletRequest request, HttpServletResponse response) dispatcherServlet.service(request, response); } + + /** + * Gets the initialized Spring dispatcher servlet instance. + * @return + */ + public Servlet getDispatcherServlet() { + return dispatcherServlet; + } + public List getSpringProfiles() { return Collections.unmodifiableList(springProfiles); } @@ -151,6 +161,18 @@ private void notifyStartListeners(ServletContext context) { } } + + /////////////////////////////////////////////////////////////// + // HttpServlet implementation // + // This is used to pass the initializer to the filter chain // + // to handle requests // + /////////////////////////////////////////////////////////////// + + @Override + public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException { + dispatch((HttpServletRequest)req, (HttpServletResponse)res); + } + /** * Default configuration class for the DispatcherServlet. This just mocks the behaviour of a default * ServletConfig object with no init parameters diff --git a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java new file mode 100644 index 000000000..de7ce509b --- /dev/null +++ b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java @@ -0,0 +1,254 @@ +/* + * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package com.amazonaws.serverless.proxy.spring; + +import com.amazonaws.serverless.exceptions.ContainerInitializationException; +import com.amazonaws.serverless.proxy.internal.*; +import com.amazonaws.serverless.proxy.internal.model.AwsProxyRequest; +import com.amazonaws.serverless.proxy.internal.model.AwsProxyResponse; +import com.amazonaws.serverless.proxy.internal.servlet.*; +import com.amazonaws.services.lambda.runtime.Context; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.SpringServletContainerInitializer; +import org.springframework.web.WebApplicationInitializer; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; +import org.springframework.web.servlet.DispatcherServlet; + +import javax.servlet.*; +import javax.servlet.http.HttpServletResponse; +import java.util.*; +import java.util.concurrent.CountDownLatch; + +/** + * Spring implementation of the `LambdaContainerHandler` abstract class. This class uses the `LambdaSpringApplicationInitializer` + * object behind the scenes to proxy requests. The default implementation leverages the `AwsProxyHttpServletRequest` and + * `AwsHttpServletResponse` implemented in the `aws-serverless-java-container-core` package. + * + * Important: Make sure to add {@link LambdaFlushResponseListener} in your SpringBootServletInitializer subclass configure(). + * + * @param The incoming event type + * @param The expected return type + */ +public class SpringBootLambdaContainerHandler extends AwsLambdaServletContainerHandler { + static ThreadLocal currentResponse = new ThreadLocal<>(); + private final Class springBootInitializer; + private static final Logger log = LoggerFactory.getLogger(SpringBootLambdaContainerHandler.class); + + // State vars + private boolean initialized; + + /** + * Creates a default SpringLambdaContainerHandler initialized with the `AwsProxyRequest` and `AwsProxyResponse` objects + * @param springBootInitializer {@code SpringBootServletInitializer} class + * @return An initialized instance of the `SpringLambdaContainerHandler` + * @throws ContainerInitializationException + */ + public static SpringBootLambdaContainerHandler getAwsProxyHandler(Class springBootInitializer) + throws ContainerInitializationException { + return new SpringBootLambdaContainerHandler<>( + new AwsProxyHttpServletRequestReader(), + new AwsProxyHttpServletResponseWriter(), + new AwsProxySecurityContextWriter(), + new AwsProxyExceptionHandler(), + springBootInitializer + ); + } + + /** + * Creates a new container handler with the given reader and writer objects + * + * @param requestReader An implementation of `RequestReader` + * @param responseWriter An implementation of `ResponseWriter` + * @param securityContextWriter An implementation of `SecurityContextWriter` + * @param exceptionHandler An implementation of `ExceptionHandler` + * @throws ContainerInitializationException + */ + public SpringBootLambdaContainerHandler(RequestReader requestReader, + ResponseWriter responseWriter, + SecurityContextWriter securityContextWriter, + ExceptionHandler exceptionHandler, + Class springBootInitializer) + throws ContainerInitializationException { + super(requestReader, responseWriter, securityContextWriter, exceptionHandler); + this.springBootInitializer = springBootInitializer; + } + + @Override + protected AwsHttpServletResponse getContainerResponse(AwsProxyHttpServletRequest request, CountDownLatch latch) { + return new AwsHttpServletResponse(request, latch); + } + + @Override + protected void handleRequest(AwsProxyHttpServletRequest containerRequest, AwsHttpServletResponse containerResponse, Context lambdaContext) throws Exception { + // this method of the AwsLambdaServletContainerHandler sets the servlet context + if (getServletContext() == null) { + setServletContext(new SpringBootAwsServletContext()); + } + + // wire up the application context on the first invocation + if (!initialized) { + SpringServletContainerInitializer springServletContainerInitializer = new SpringServletContainerInitializer(); + LinkedHashSet> webAppInitializers = new LinkedHashSet<>(); + webAppInitializers.add(springBootInitializer); + springServletContainerInitializer.onStartup(webAppInitializers, getServletContext()); + initialized = true; + } + + containerRequest.setServletContext(getServletContext()); + + currentResponse.set(containerResponse); + try { + WebApplicationContext applicationContext = WebApplicationContextUtils.getRequiredWebApplicationContext(getServletContext()); + DispatcherServlet dispatcherServlet = applicationContext.getBean("dispatcherServlet", DispatcherServlet.class); + // process filters & invoke servlet + log.debug("Process filters & invoke servlet: {}", dispatcherServlet); + doFilter(containerRequest, containerResponse, dispatcherServlet); + } finally { + // call the flush method to release the latch + SpringBootLambdaContainerHandler.currentResponse.remove(); + currentResponse.remove(); + } + } + + private class SpringBootAwsServletContext extends AwsServletContext { + public SpringBootAwsServletContext() { + super(SpringBootLambdaContainerHandler.this); + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, String s1) { + throw new UnsupportedOperationException("Only dispatcherServlet is supported"); + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Class aClass) { + throw new UnsupportedOperationException("Only dispatcherServlet is supported"); + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Servlet servlet) { + if ("dispatcherServlet".equals(s)) { + try { + servlet.init(new ServletConfig() { + @Override + public String getServletName() { + return s; + } + + @Override + public ServletContext getServletContext() { + return SpringBootAwsServletContext.this; + } + + @Override + public String getInitParameter(String name) { + return null; + } + + @Override + public Enumeration getInitParameterNames() { + return new Enumeration() { + @Override + public boolean hasMoreElements() { + return false; + } + + @Override + public String nextElement() { + return null; + } + }; + } + }); + } catch (ServletException e) { + throw new RuntimeException("Cannot add servlet " + servlet, e); + } + return new ServletRegistration.Dynamic() { + @Override + public String getName() { + return s; + } + + @Override + public String getClassName() { + return null; + } + + @Override + public boolean setInitParameter(String name, String value) { + return false; + } + + @Override + public String getInitParameter(String name) { + return null; + } + + @Override + public Set setInitParameters(Map initParameters) { + return null; + } + + @Override + public Map getInitParameters() { + return null; + } + + @Override + public Set addMapping(String... urlPatterns) { + return null; + } + + @Override + public Collection getMappings() { + return null; + } + + @Override + public String getRunAsRole() { + return null; + } + + @Override + public void setAsyncSupported(boolean isAsyncSupported) { + + } + + @Override + public void setLoadOnStartup(int loadOnStartup) { + + } + + @Override + public Set setServletSecurity(ServletSecurityElement constraint) { + return null; + } + + @Override + public void setMultipartConfig(MultipartConfigElement multipartConfig) { + + } + + @Override + public void setRunAsRole(String roleName) { + + } + }; + } else { + throw new UnsupportedOperationException("Only dispatcherServlet is supported"); + } + } + } +} diff --git a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java index f3a932b0b..9ac03516e 100644 --- a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java +++ b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java @@ -21,7 +21,6 @@ import org.springframework.web.context.ConfigurableWebApplicationContext; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import javax.servlet.ServletContext; import java.util.Arrays; import java.util.concurrent.CountDownLatch; @@ -136,8 +135,6 @@ protected void handleRequest(AwsProxyHttpServletRequest containerRequest, AwsHtt containerRequest.setServletContext(getServletContext()); // process filters - doFilter(containerRequest, containerResponse); - // invoke servlet - initializer.dispatch(containerRequest, containerResponse); + doFilter(containerRequest, containerResponse, initializer); } } diff --git a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java index 9d6ee9e11..1cbeaef4e 100644 --- a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java +++ b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java @@ -6,6 +6,7 @@ import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder; import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext; import com.amazonaws.serverless.proxy.spring.echoapp.EchoSpringAppConfig; +import com.amazonaws.serverless.proxy.spring.echoapp.UnauthenticatedFilter; import com.amazonaws.serverless.proxy.spring.echoapp.model.MapResponseModel; import com.amazonaws.serverless.proxy.spring.echoapp.model.SingleValueModel; import com.fasterxml.jackson.core.JsonProcessingException; @@ -139,6 +140,18 @@ public void error_statusCode_methodNotAllowed() { assertEquals(405, output.getStatusCode()); } + @Test + public void error_unauthenticatedCall_filterStepsRequest() { + AwsProxyRequest request = new AwsProxyRequestBuilder("/echo/status-code", "GET") + .header(UnauthenticatedFilter.HEADER_NAME, "1") + .json() + .queryString("status", "201") + .build(); + + AwsProxyResponse output = handler.proxy(request, lambdaContext); + assertEquals(401, output.getStatusCode()); + } + @Test public void responseBody_responseWriter_validBody() throws JsonProcessingException { SingleValueModel singleValueModel = new SingleValueModel(); @@ -176,6 +189,34 @@ public void base64_binaryResponse_base64Encoding() { assertTrue(Base64.isBase64(response.getBody())); } + @Test + public void injectBody_populatedResponse_noException() { + AwsProxyRequest request = new AwsProxyRequestBuilder("/echo/request-body", "POST") + .body("This is a populated body") + .build(); + + AwsProxyResponse response = handler.proxy(request, lambdaContext); + assertNotNull(response.getBody()); + try { + SingleValueModel output = objectMapper.readValue(response.getBody(), SingleValueModel.class); + assertEquals("true", output.getValue()); + } catch (IOException e) { + e.printStackTrace(); + fail(); + } + + AwsProxyRequest emptyReq = new AwsProxyRequestBuilder("/echo/request-body", "POST") + .build(); + AwsProxyResponse emptyResp = handler.proxy(emptyReq, lambdaContext); + try { + SingleValueModel output = objectMapper.readValue(emptyResp.getBody(), SingleValueModel.class); + assertEquals(null, output.getValue()); + } catch (IOException e) { + e.printStackTrace(); + fail(); + } + } + @Test public void servletRequestEncoding_acceptEncoding_okStatusCode() { SingleValueModel singleValueModel = new SingleValueModel(); @@ -209,16 +250,16 @@ public void request_requestURI() { @Test public void request_requestURL() { - AwsProxyRequest request = new AwsProxyRequestBuilder("/echo/request-Url", "GET") + AwsProxyRequest request = new AwsProxyRequestBuilder("/echo/request-url", "GET") .scheme("https") .serverName("api.myserver.com") .stage("prod") .build(); - + handler.stripBasePath(""); AwsProxyResponse output = handler.proxy(request, lambdaContext); assertEquals(200, output.getStatusCode()); - validateSingleValueModel(output, "https://api.myserver.com/prod/echo/request-Url"); + validateSingleValueModel(output, "https://api.myserver.com/echo/request-url"); } @Test diff --git a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoResource.java b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoResource.java index 55deb7edb..185bc7bdc 100644 --- a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoResource.java +++ b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoResource.java @@ -102,7 +102,7 @@ public SingleValueModel echoRequestURI(HttpServletRequest request) { return valueModel; } - @RequestMapping(path = "/request-Url", method = RequestMethod.GET) + @RequestMapping(path = "/request-url", method = RequestMethod.GET) public SingleValueModel echoRequestURL(HttpServletRequest request) { SingleValueModel valueModel = new SingleValueModel(); valueModel.setValue(request.getRequestURL().toString()); @@ -110,6 +110,17 @@ public SingleValueModel echoRequestURL(HttpServletRequest request) { return valueModel; } + @RequestMapping(path = "/request-body", method = RequestMethod.POST) + public SingleValueModel helloForPopulatedBody(@RequestBody(required = false) String input) { + SingleValueModel valueModel = new SingleValueModel(); + System.out.println("Input: \"" + input + "\""); + if (input != null && !"null".equals(input)) { + valueModel.setValue("true"); + } + + return valueModel; + } + @RequestMapping(path = "/encoded-request-uri/{encoded-var}", method = RequestMethod.GET) public SingleValueModel echoEncodedRequestUri(@PathVariable("encoded-var") String encodedVar) { SingleValueModel valueModel = new SingleValueModel(); diff --git a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoSpringAppConfig.java b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoSpringAppConfig.java index 93134c1bf..1f068f8fe 100644 --- a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoSpringAppConfig.java +++ b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/EchoSpringAppConfig.java @@ -13,6 +13,11 @@ import org.springframework.context.annotation.PropertySource; import org.springframework.web.context.ConfigurableWebApplicationContext; +import javax.servlet.DispatcherType; +import javax.servlet.FilterRegistration; + +import java.util.EnumSet; + @Configuration @ComponentScan("com.amazonaws.serverless.proxy.spring.echoapp") @@ -26,6 +31,12 @@ public class EchoSpringAppConfig { public SpringLambdaContainerHandler springLambdaContainerHandler() throws ContainerInitializationException { SpringLambdaContainerHandler handler = SpringLambdaContainerHandler.getAwsProxyHandler(applicationContext); handler.setRefreshContext(false); + handler.onStartup(c -> { + FilterRegistration.Dynamic registration = c.addFilter("UnauthenticatedFilter", UnauthenticatedFilter.class); + // update the registration to map to a path + registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/echo/*"); + // servlet name mappings are disabled and will throw an exception + }); return handler; } diff --git a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/UnauthenticatedFilter.java b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/UnauthenticatedFilter.java new file mode 100644 index 000000000..b63401572 --- /dev/null +++ b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/echoapp/UnauthenticatedFilter.java @@ -0,0 +1,45 @@ +package com.amazonaws.serverless.proxy.spring.echoapp; + + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import java.io.IOException; + + +public class UnauthenticatedFilter implements Filter { + public static final String HEADER_NAME = "X-Unauthenticated-Response"; + public static final int RESPONSE_STATUS = 401; + + @Override + public void init(FilterConfig filterConfig) + throws ServletException { + + } + + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + System.out.println("Running unauth filter"); + if (((HttpServletRequest)servletRequest).getHeader(HEADER_NAME) != null) { + ((HttpServletResponse) servletResponse).setStatus(401); + System.out.println("Returning 401"); + return; + } + System.out.println("Continue chain"); + filterChain.doFilter(servletRequest, servletResponse); + } + + + @Override + public void destroy() { + + } +} diff --git a/samples/spark/pet-store/pom.xml b/samples/spark/pet-store/pom.xml index 7cad08a18..a291acf72 100644 --- a/samples/spark/pet-store/pom.xml +++ b/samples/spark/pet-store/pom.xml @@ -27,7 +27,7 @@ 1.8 1.8 2.8.5 - 2.6.0 + 2.7.1 diff --git a/samples/spark/pet-store/src/main/java/com/amazonaws/serverless/sample/spark/LambdaHandler.java b/samples/spark/pet-store/src/main/java/com/amazonaws/serverless/sample/spark/LambdaHandler.java index efd481ffc..8eff622e4 100644 --- a/samples/spark/pet-store/src/main/java/com/amazonaws/serverless/sample/spark/LambdaHandler.java +++ b/samples/spark/pet-store/src/main/java/com/amazonaws/serverless/sample/spark/LambdaHandler.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import spark.Spark; import javax.ws.rs.core.Response; import java.util.UUID; @@ -45,6 +46,7 @@ public AwsProxyResponse handleRequest(AwsProxyRequest awsProxyRequest, Context c try { handler = SparkLambdaContainerHandler.getAwsProxyHandler(); defineResources(); + Spark.awaitInitialization(); } catch (ContainerInitializationException e) { log.error("Cannot initialize Spark application", e); return null;