diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java index 9fab71c49..c10d74fb5 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/SecurityUtils.java @@ -7,8 +7,11 @@ import java.io.File; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Set; /** @@ -18,6 +21,51 @@ public final class SecurityUtils { private static Logger log = LoggerFactory.getLogger(SecurityUtils.class); + private static Set SCHEMES = new HashSet() {{ + add("http"); + add("https"); + add("HTTP"); + add("HTTPS"); + }}; + + private static Set PORTS = new HashSet() {{ + add(443); + add(80); + add(3000); // we allow port 3000 for SAM local + }}; + + public static boolean isValidPort(String port) { + if (port == null) { + return false; + } + try { + int intPort = Integer.parseInt(port); + return PORTS.contains(intPort); + } catch (NumberFormatException e) { + log.error("Invalid port parameter: " + crlf(port)); + return false; + } + } + + public static boolean isValidScheme(String scheme) { + return SCHEMES.contains(scheme); + } + + public static boolean isValidHost(String host, String apiId, String region) { + if (host == null) { + return false; + } + if (host.endsWith(".amazonaws.com")) { + String defaultHost = new StringBuilder().append(apiId) + .append(".execute-api.") + .append(region) + .append(".amazonaws.com").toString(); + return host.equals(defaultHost); + } else { + return LambdaContainerHandler.getContainerConfig().getCustomDomainNames().contains(host); + } + } + /** * Replaces CRLF characters in a string with empty string (""). * @param s The string to be cleaned diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java index 8bbbda245..8c7ba8fa6 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java @@ -64,6 +64,9 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest { // We need this to pickup the protocol from the CloudFront header since Lambda doesn't receive this // information from anywhere else static final String CF_PROTOCOL_HEADER_NAME = "CloudFront-Forwarded-Proto"; + static final String PROTOCOL_HEADER_NAME = "X-Forwarded-Proto"; + static final String HOST_HEADER_NAME = "Host"; + static final String PORT_HEADER_NAME = "X-Forwarded-Port"; //------------------------------------------------------------- diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java index 65668b2b8..3586a7fd8 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java @@ -501,22 +501,44 @@ public String getProtocol() { @Override public String getScheme() { - String headerValue = getHeaderCaseInsensitive(CF_PROTOCOL_HEADER_NAME); - if (headerValue == null) { - return "https"; + String cfScheme = getHeaderCaseInsensitive(CF_PROTOCOL_HEADER_NAME); + if (cfScheme != null && SecurityUtils.isValidScheme(cfScheme)) { + return cfScheme; + } + String gwScheme = getHeaderCaseInsensitive(PROTOCOL_HEADER_NAME); + if (gwScheme != null && SecurityUtils.isValidScheme(gwScheme)) { + return gwScheme; } - return headerValue; + // https is our default scheme + return "https"; } - @Override public String getServerName() { - String name = getHeaderCaseInsensitive(HttpHeaders.HOST); + String region = System.getenv("AWS_REGION"); + if (region == null) { + // this is not a critical failure, we just put a static region in the URI + region = "us-east-1"; + } - if (name == null || name.length() == 0) { - name = "lambda.amazonaws.com"; + String hostHeader = getHeaderCaseInsensitive(HOST_HEADER_NAME); + if (hostHeader != null && SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), region)) { + return hostHeader; + } + + return new StringBuilder().append(request.getRequestContext().getApiId()) + .append(".execute-api.") + .append(region) + .append(".amazonaws.com").toString(); + } + + public int getServerPort() { + String port = getHeaderCaseInsensitive(PORT_HEADER_NAME); + if (SecurityUtils.isValidPort(port)) { + return Integer.parseInt(port); + } else { + return 443; // default port } - return name; } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/model/ContainerConfig.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/model/ContainerConfig.java index 04344d3df..676697c3f 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/model/ContainerConfig.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/model/ContainerConfig.java @@ -35,9 +35,11 @@ public static ContainerConfig defaultConfig() { private boolean consolidateSetCookieHeaders; private boolean useStageAsServletContext; private List validFilePaths; + private List customDomainNames; public ContainerConfig() { validFilePaths = new ArrayList<>(); + customDomainNames = new ArrayList<>(); } @@ -168,4 +170,31 @@ public void setValidFilePaths(List validFilePaths) { public void addValidFilePath(String filePath) { validFilePaths.add(filePath); } + + + /** + * Adds a new custom domain name to the list of allowed domains + * @param name The new custom domain name, excluding the scheme ("https") and port + */ + public void addCustomDomain(String name) { + customDomainNames.add(name); + } + + + /** + * Returns the list of custom domain names enabled for the application + * @return The configured custom domain names + */ + public List getCustomDomainNames() { + return customDomainNames; + } + + + /** + * Enables localhost custom domain name for testing. This setting should be used only in local + * with SAM local + */ + public void enableLocalhost() { + customDomainNames.add("localhost"); + } } diff --git a/aws-serverless-java-container-jersey/src/main/java/com/amazonaws/serverless/proxy/jersey/JerseyHandlerFilter.java b/aws-serverless-java-container-jersey/src/main/java/com/amazonaws/serverless/proxy/jersey/JerseyHandlerFilter.java index 1741d49e6..d77870d9d 100644 --- a/aws-serverless-java-container-jersey/src/main/java/com/amazonaws/serverless/proxy/jersey/JerseyHandlerFilter.java +++ b/aws-serverless-java-container-jersey/src/main/java/com/amazonaws/serverless/proxy/jersey/JerseyHandlerFilter.java @@ -165,23 +165,19 @@ private ContainerRequest servletRequestToContainerRequest(ServletRequest request return requestContext; } + @SuppressFBWarnings("SERVLET_SERVER_NAME") private URI getBaseUri(ServletRequest request, String basePath) { - ApiGatewayRequestContext apiGatewayCtx = (ApiGatewayRequestContext) request.getAttribute(API_GATEWAY_CONTEXT_PROPERTY); - String region = System.getenv("AWS_REGION"); - if (region == null) { - // this is not a critical failure, we just put a static region in the URI - region = "us-east-1"; + String finalBasePath = basePath; + if (!finalBasePath.startsWith("/")) { + finalBasePath = "/" + finalBasePath; } - StringBuilder uriBuilder = new StringBuilder(); - uriBuilder.append("https://") // we assume it's always https - .append(apiGatewayCtx.getApiId()) - .append(".execute-api.") - .append(region) - .append(".amazonaws.com") - .append("/"); - - - return UriBuilder.fromUri(uriBuilder.toString()).build(); + String uriString = new StringBuilder().append(request.getScheme()) + .append("://") + .append(request.getServerName()) + .append(":") + .append(request.getServerPort()) + .append(finalBasePath).toString(); + return UriBuilder.fromUri(uriString).build(); } //------------------------------------------------------------- diff --git a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java index 934cc66b5..ca6844855 100644 --- a/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java +++ b/aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringAwsProxyTest.java @@ -1,5 +1,6 @@ package com.amazonaws.serverless.proxy.spring; +import com.amazonaws.serverless.proxy.internal.LambdaContainerHandler; import com.amazonaws.serverless.proxy.model.AwsProxyRequest; import com.amazonaws.serverless.proxy.model.AwsProxyResponse; import com.amazonaws.serverless.proxy.internal.servlet.AwsServletContext; @@ -335,6 +336,7 @@ public void contextPath_generateLink_returnsCorrectPath() { .serverName("api.myserver.com") .stage("prod") .build(); + LambdaContainerHandler.getContainerConfig().addCustomDomain("api.myserver.com"); SpringLambdaContainerHandler.getContainerConfig().setUseStageAsServletContext(true); AwsProxyResponse output = handler.proxy(request, lambdaContext);