Skip to content

Commit

Permalink
Merge pull request #77 from awslabs/servlet-improvements
Browse files Browse the repository at this point in the history
Servlet improvements merge for 0.8 release
  • Loading branch information
sapessi authored Nov 22, 2017
2 parents 55aae8d + a1b0798 commit d074fae
Show file tree
Hide file tree
Showing 34 changed files with 913 additions and 90 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ public class LambdaHandler implements RequestHandler<AwsProxyRequest, AwsProxyRe
public AwsProxyResponse handleRequest(AwsProxyRequest awsProxyRequest, Context context) {
if (!initialized) {
defineRoutes();
// it's important to call the awaitInitialization method not to run into race
// conditions as routes are loaded asynchronously
Spark.awaitInitialization();
initialized = true;
}
return handler.proxy(awsProxyRequest, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ public class InvalidRequestEventException extends Exception {
public InvalidRequestEventException(String message, Exception e) {
super(message, e);
}

public InvalidRequestEventException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package com.amazonaws.serverless.proxy.internal.jaxrs;

import com.amazonaws.serverless.proxy.internal.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.internal.model.CognitoAuthorizerClaims;
import com.amazonaws.services.lambda.runtime.Context;

import javax.ws.rs.core.SecurityContext;
Expand Down Expand Up @@ -61,26 +62,33 @@ public AwsProxySecurityContext(final Context lambdaContext, final AwsProxyReques
//-------------------------------------------------------------

public Principal getUserPrincipal() {
return () -> {
if (getAuthenticationScheme() == null) {
return null;
}

if (getAuthenticationScheme().equals(AUTH_SCHEME_CUSTOM)) {
return event.getRequestContext().getAuthorizer().getPrincipalId();
} else if (getAuthenticationScheme().equals(AUTH_SCHEME_AWS_IAM)) {
// if we received credentials from Cognito Federated Identities then we return the identity id
if (event.getRequestContext().getIdentity().getCognitoIdentityId() != null) {
return event.getRequestContext().getIdentity().getCognitoIdentityId();
} else { // otherwise the user arn from the credentials
return event.getRequestContext().getIdentity().getUserArn();
if (getAuthenticationScheme() == null) {
return () -> null;
}

if (getAuthenticationScheme().equals(AUTH_SCHEME_CUSTOM) || getAuthenticationScheme().equals(AUTH_SCHEME_AWS_IAM)) {
return () -> {
if (getAuthenticationScheme().equals(AUTH_SCHEME_CUSTOM)) {
return event.getRequestContext().getAuthorizer().getPrincipalId();
} else if (getAuthenticationScheme().equals(AUTH_SCHEME_AWS_IAM)) {
// if we received credentials from Cognito Federated Identities then we return the identity id
if (event.getRequestContext().getIdentity().getCognitoIdentityId() != null) {
return event.getRequestContext().getIdentity().getCognitoIdentityId();
} else { // otherwise the user arn from the credentials
return event.getRequestContext().getIdentity().getUserArn();
}
}
} else if (getAuthenticationScheme().equals(AUTH_SCHEME_COGNITO_POOL)) {
return event.getRequestContext().getAuthorizer().getClaims().getSubject();
}

return null;
};
// return null if we couldn't find a valid scheme
return null;
};
}

if (getAuthenticationScheme().equals(AUTH_SCHEME_COGNITO_POOL)) {
return new CognitoUserPoolPrincipal(event.getRequestContext().getAuthorizer().getClaims());
}

throw new RuntimeException("Cannot recognize authorization scheme in event");
}


Expand All @@ -105,4 +113,27 @@ public String getAuthenticationScheme() {
return null;
}
}


/**
* Custom object for request authorized with a Cognito User Pool authorizer. By casting the Principal
* object to this you can extract the Claims object included in the token.
*/
public class CognitoUserPoolPrincipal implements Principal {

private CognitoAuthorizerClaims claims;

CognitoUserPoolPrincipal(CognitoAuthorizerClaims c) {
claims = c;
}

@Override
public String getName() {
return claims.getSubject();
}

public CognitoAuthorizerClaims getClaims() {
return claims;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;

import java.util.HashMap;
import java.util.Map;

/**
Expand All @@ -31,7 +32,7 @@ public class AwsProxyRequest {
private String resource;
private ApiGatewayRequestContext requestContext;
private Map<String, String> queryStringParameters;
private Map<String, String> headers;
private Map<String, String> headers = new HashMap<>(); // avoid NPE
private Map<String, String> pathParameters;
private String httpMethod;
private Map<String, String> stageVariables;
Expand Down Expand Up @@ -105,7 +106,11 @@ public Map<String, String> getHeaders() {


public void setHeaders(Map<String, String> headers) {
this.headers = headers;
if (null != headers) {
this.headers = headers;
} else {
this.headers.clear();
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
package com.amazonaws.serverless.proxy.internal.model;


import com.fasterxml.jackson.annotation.JsonAnyGetter;
import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.Map;


/**
Expand Down Expand Up @@ -44,6 +48,8 @@ public class CognitoAuthorizerClaims {
// Variables - Private
//-------------------------------------------------------------

private Map<String, String> claims = new HashMap<>();

@JsonProperty(value = "sub")
private String subject;
@JsonProperty(value = "aud")
Expand All @@ -69,6 +75,16 @@ public class CognitoAuthorizerClaims {
// Methods - Getter/Setter
//-------------------------------------------------------------

@JsonAnyGetter
public String getClaim(String claim) {
return claims.get(claim);
}

@JsonAnySetter
public void setClaim(String claim, String value) {
claims.put(claim, value);
}

public String getSubject() { return this.subject; }


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package com.amazonaws.serverless.proxy.internal.servlet;

import com.amazonaws.serverless.proxy.internal.RequestReader;
import com.amazonaws.serverless.proxy.internal.model.ApiGatewayRequestContext;
import com.amazonaws.serverless.proxy.internal.model.ContainerConfig;
import com.amazonaws.services.lambda.runtime.Context;

Expand All @@ -24,6 +26,7 @@
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionContext;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.net.URLEncoder;
Expand Down Expand Up @@ -68,6 +71,7 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest {
private Context lambdaContext;
private Map<String, Object> attributes;
private ServletContext servletContext;
private AwsHttpSession session;

protected DispatcherType dispatcherType;

Expand Down Expand Up @@ -101,13 +105,17 @@ public String getRequestedSessionId() {

@Override
public HttpSession getSession(boolean b) {
return null;
if (b && null == this.session) {
ApiGatewayRequestContext requestContext = (ApiGatewayRequestContext) getAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY);
this.session = new AwsHttpSession(requestContext.getRequestId());
}
return this.session;
}


@Override
public HttpSession getSession() {
return null;
return this.session;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class AwsHttpServletResponse
private int statusCode;
private String statusMessage;
private String responseBody;
private PrintWriter writer;
private ByteArrayOutputStream bodyOutputStream = new ByteArrayOutputStream();
private CountDownLatch writersCountDownLatch;
private AwsHttpServletRequest request;
Expand Down Expand Up @@ -316,7 +317,10 @@ public void close()

@Override
public PrintWriter getWriter() throws IOException {
return new PrintWriter(bodyOutputStream);
if (null == writer) {
writer = new PrintWriter(bodyOutputStream);
}
return writer;
}


Expand Down Expand Up @@ -358,7 +362,11 @@ public int getBufferSize() {

@Override
public void flushBuffer() throws IOException {
if (null != writer) {
writer.flush();
}
responseBody = new String(bodyOutputStream.toByteArray());
log.debug("Response buffer flushed with {} bytes, latch={}", responseBody.length(), writersCountDownLatch.getCount());
isCommitted = true;
writersCountDownLatch.countDown();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package com.amazonaws.serverless.proxy.internal.servlet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.ServletContext;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionContext;
import java.util.Enumeration;

public class AwsHttpSession implements HttpSession {

private static final Logger log = LoggerFactory.getLogger(AwsHttpSession.class);
private String id;

/**
* @param id API gateway request ID.
*/
public AwsHttpSession(String id) {
if (null == id) {
throw new RuntimeException("HTTP session id (from request ID) cannot be null");
}
log.debug("Creating session " + id);
this.id = id;
}

@Override
public long getCreationTime() {
return 0;
}

@Override
public String getId() {
return id;
}

@Override
public long getLastAccessedTime() {
return 0;
}

@Override
public ServletContext getServletContext() {
return null;
}

@Override
public void setMaxInactiveInterval(int interval) {

}

@Override
public int getMaxInactiveInterval() {
return 0;
}

@Override
public HttpSessionContext getSessionContext() {
return null;
}

@Override
public Object getAttribute(String name) {
return null;
}

@Override
public Object getValue(String name) {
return null;
}

@Override
public Enumeration<String> getAttributeNames() {
return null;
}

@Override
public String[] getValueNames() {
return new String[0];
}

@Override
public void setAttribute(String name, Object value) {

}

@Override
public void putValue(String name, Object value) {

}

@Override
public void removeAttribute(String name) {

}

@Override
public void removeValue(String name) {

}

@Override
public void invalidate() {

}

@Override
public boolean isNew() {
return false;
}
}
Loading

0 comments on commit d074fae

Please sign in to comment.