Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not fail on X-Forwarded-For by default #1193

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,25 @@
*/
package io.prestosql.dispatcher;

import com.google.common.net.HttpHeaders;
import io.airlift.configuration.Config;
import io.airlift.configuration.ConfigDescription;

import javax.validation.constraints.NotNull;

import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO;

public class DispatcherConfig
{
public enum HeaderSupport
{
WARN,
IGNORE,
ACCEPT,
REJECT,
/**/;
}

// When Presto is not behind a load-balancer, accepting user-provided X-Forwarded-For would be not be safe.
private HeaderSupport forwardedHeaderSupport = HeaderSupport.REJECT;
private HeaderSupport forwardedHeaderSupport = HeaderSupport.WARN;

@NotNull
public HeaderSupport getForwardedHeaderSupport()
Expand All @@ -39,7 +40,7 @@ public HeaderSupport getForwardedHeaderSupport()
}

@Config("dispatcher.forwarded-header")
@ConfigDescription("Support for " + HttpHeaders.X_FORWARDED_PROTO + " header")
@ConfigDescription("Support for " + X_FORWARDED_PROTO + " header")
public DispatcherConfig setForwardedHeaderSupport(HeaderSupport forwardedHeaderSupport)
{
this.forwardedHeaderSupport = forwardedHeaderSupport;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.prestosql.Session.ResourceEstimateBuilder;
Expand Down Expand Up @@ -70,12 +71,15 @@
import static io.prestosql.client.PrestoHeaders.PRESTO_TRANSACTION_ID;
import static io.prestosql.client.PrestoHeaders.PRESTO_USER;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.ACCEPT;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.IGNORE;
import static io.prestosql.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static java.lang.String.format;

public final class HttpRequestSessionContext
implements SessionContext
{
private static final Logger log = Logger.get(HttpRequestSessionContext.class);

private static final Splitter DOT_SPLITTER = Splitter.on('.');

private final String catalog;
Expand Down Expand Up @@ -175,18 +179,22 @@ private static String getRemoteUserAddress(HeaderSupport forwardedHeaderSupport,
// TODO support 'Forwarder' header (here & where other X-Forwarder-* are supported)

switch (forwardedHeaderSupport) {
case REJECT:
case WARN:
if (xForwarderForHeader != null) {
log.warn("Unsupported HTTP header '%s'. Presto needs to be explicitly configured to %s or %s this header", X_FORWARDED_FOR, IGNORE, ACCEPT);
}
return remoteAddess;

case IGNORE:
return remoteAddess;

case ACCEPT:
if (xForwarderForHeader != null) {
assertRequest(forwardedHeaderSupport == ACCEPT, "Unexpected HTTP header. Presto is configured to %s this header: %s", forwardedHeaderSupport, X_FORWARDED_FOR);
List<String> addresses = Splitter.on(",").trimResults().omitEmptyStrings().splitToList(xForwarderForHeader);
if (!addresses.isEmpty()) {
return addresses.get(0);
}
}

// fall-through
case IGNORE:
return remoteAddess;

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class TestDispatcherConfig
public void testDefaults()
{
assertRecordedDefaults(recordDefaults(DispatcherConfig.class)
.setForwardedHeaderSupport(HeaderSupport.REJECT));
.setForwardedHeaderSupport(HeaderSupport.WARN));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import static io.prestosql.client.PrestoHeaders.PRESTO_USER;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.ACCEPT;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.IGNORE;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.REJECT;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.WARN;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;

Expand Down Expand Up @@ -73,7 +73,7 @@ public void testSessionContext()
.build(),
"testRemote");

HttpRequestSessionContext context = new HttpRequestSessionContext(REJECT, request);
HttpRequestSessionContext context = new HttpRequestSessionContext(WARN, request);
assertEquals(context.getSource(), "testSource");
assertEquals(context.getCatalog(), "testCatalog");
assertEquals(context.getSchema(), "testSchema");
Expand Down Expand Up @@ -111,7 +111,7 @@ public void testPreparedStatementsHeaderDoesNotParse()
.put(PRESTO_PREPARED_STATEMENT, "query1=abcdefg")
.build(),
"testRemote");
assertThatThrownBy(() -> new HttpRequestSessionContext(REJECT, request))
assertThatThrownBy(() -> new HttpRequestSessionContext(WARN, request))
.isInstanceOf(WebApplicationException.class)
.hasMessageMatching("Invalid X-Presto-Prepared-Statement header: line 1:1: mismatched input 'abcdefg'. Expecting: .*");
}
Expand All @@ -120,18 +120,16 @@ public void testPreparedStatementsHeaderDoesNotParse()
public void testXForwardedFor()
{
HttpServletRequest plainRequest = requestWithXForwardedFor(Optional.empty(), "remote_address");
HttpServletRequest requestWithXForwardedFor = requestWithXForwardedFor(Optional.of("forwarded_client"), "forwarded_remote_address");
HttpServletRequest requestWithXForwardedFor = requestWithXForwardedFor(Optional.of("forwarded_client"), "proxy_address");

assertEquals(new HttpRequestSessionContext(IGNORE, plainRequest).getRemoteUserAddress(), "remote_address");
assertEquals(new HttpRequestSessionContext(IGNORE, requestWithXForwardedFor).getRemoteUserAddress(), "forwarded_remote_address");
assertEquals(new HttpRequestSessionContext(IGNORE, requestWithXForwardedFor).getRemoteUserAddress(), "proxy_address");

assertEquals(new HttpRequestSessionContext(ACCEPT, plainRequest).getRemoteUserAddress(), "remote_address");
assertEquals(new HttpRequestSessionContext(ACCEPT, requestWithXForwardedFor).getRemoteUserAddress(), "forwarded_client");

assertEquals(new HttpRequestSessionContext(REJECT, plainRequest).getRemoteUserAddress(), "remote_address");
assertThatThrownBy(() -> new HttpRequestSessionContext(REJECT, requestWithXForwardedFor))
.isInstanceOf(WebApplicationException.class)
.hasMessage("Unexpected HTTP header. Presto is configured to REJECT this header: X-Forwarded-For");
assertEquals(new HttpRequestSessionContext(WARN, plainRequest).getRemoteUserAddress(), "remote_address");
assertEquals(new HttpRequestSessionContext(WARN, requestWithXForwardedFor).getRemoteUserAddress(), "proxy_address"); // this generates a warning to logs
}

private static HttpServletRequest requestWithXForwardedFor(Optional<String> xForwardedFor, String remoteAddress)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import static io.prestosql.client.PrestoHeaders.PRESTO_SOURCE;
import static io.prestosql.client.PrestoHeaders.PRESTO_TIME_ZONE;
import static io.prestosql.client.PrestoHeaders.PRESTO_USER;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.REJECT;
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.WARN;
import static io.prestosql.spi.type.TimeZoneKey.getTimeZoneKey;
import static io.prestosql.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static org.testng.Assert.assertEquals;
Expand All @@ -77,7 +77,7 @@ public class TestQuerySessionSupplier
@Test
public void testCreateSession()
{
HttpRequestSessionContext context = new HttpRequestSessionContext(REJECT, TEST_REQUEST);
HttpRequestSessionContext context = new HttpRequestSessionContext(WARN, TEST_REQUEST);
QuerySessionSupplier sessionSupplier = new QuerySessionSupplier(
createTestTransactionManager(),
new AllowAllAccessControl(),
Expand Down Expand Up @@ -115,7 +115,7 @@ public void testEmptyClientTags()
.put(PRESTO_USER, "testUser")
.build(),
"remoteAddress");
HttpRequestSessionContext context1 = new HttpRequestSessionContext(REJECT, request1);
HttpRequestSessionContext context1 = new HttpRequestSessionContext(WARN, request1);
assertEquals(context1.getClientTags(), ImmutableSet.of());

HttpServletRequest request2 = new MockHttpServletRequest(
Expand All @@ -124,7 +124,7 @@ public void testEmptyClientTags()
.put(PRESTO_CLIENT_TAGS, "")
.build(),
"remoteAddress");
HttpRequestSessionContext context2 = new HttpRequestSessionContext(REJECT, request2);
HttpRequestSessionContext context2 = new HttpRequestSessionContext(WARN, request2);
assertEquals(context2.getClientTags(), ImmutableSet.of());
}

Expand All @@ -137,15 +137,15 @@ public void testClientCapabilities()
.put(PRESTO_CLIENT_CAPABILITIES, "foo, bar")
.build(),
"remoteAddress");
HttpRequestSessionContext context1 = new HttpRequestSessionContext(REJECT, request1);
HttpRequestSessionContext context1 = new HttpRequestSessionContext(WARN, request1);
assertEquals(context1.getClientCapabilities(), ImmutableSet.of("foo", "bar"));

HttpServletRequest request2 = new MockHttpServletRequest(
ImmutableListMultimap.<String, String>builder()
.put(PRESTO_USER, "testUser")
.build(),
"remoteAddress");
HttpRequestSessionContext context2 = new HttpRequestSessionContext(REJECT, request2);
HttpRequestSessionContext context2 = new HttpRequestSessionContext(WARN, request2);
assertEquals(context2.getClientCapabilities(), ImmutableSet.of());
}

Expand All @@ -158,7 +158,7 @@ public void testInvalidTimeZone()
.put(PRESTO_TIME_ZONE, "unknown_timezone")
.build(),
"testRemote");
HttpRequestSessionContext context = new HttpRequestSessionContext(REJECT, request);
HttpRequestSessionContext context = new HttpRequestSessionContext(WARN, request);
QuerySessionSupplier sessionSupplier = new QuerySessionSupplier(
createTestTransactionManager(),
new AllowAllAccessControl(),
Expand Down