diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java index b7706fc35f..ded4de718f 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java @@ -17,13 +17,14 @@ package org.springframework.cloud.gateway.handler.predicate; -import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import java.util.function.Predicate; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.cloud.gateway.handler.support.RoutePredicateFactoryUtils; import org.springframework.cloud.gateway.support.SubnetUtils; import org.springframework.tuple.Tuple; import org.springframework.util.Assert; @@ -46,7 +47,7 @@ public Predicate apply(Tuple args) { addSource(sources, (String) arg); } } - return apply(sources); + return apply(sources, false); } public Predicate apply(String... addrs) { @@ -56,14 +57,22 @@ public Predicate apply(String... addrs) { for (String addr : addrs) { addSource(sources, addr); } - return apply(sources); + return apply(sources, false); } - public Predicate apply(List sources) { + /** + * @param respectForwardedHeader whether to check the `X-Forwarded-For` header for the + * remote IP address. + */ + public Predicate apply(List sources, + boolean respectForwardedHeader) { + Function remoteIpResolver = respectForwardedHeader + ? RoutePredicateFactoryUtils::parseRemoteIpRespectingForwardedHeader + : RoutePredicateFactoryUtils::parseRemoteIpIgnoringForwardedHeader; + return exchange -> { - InetSocketAddress remoteAddress = exchange.getRequest().getRemoteAddress(); - if (remoteAddress != null) { - String hostAddress = remoteAddress.getAddress().getHostAddress(); + String hostAddress = remoteIpResolver.apply(exchange); + if (hostAddress != null) { String host = exchange.getRequest().getURI().getHost(); if (!hostAddress.equals(host)) { diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/support/RoutePredicateFactoryUtils.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/support/RoutePredicateFactoryUtils.java index a63dc0db31..5f3abf861f 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/support/RoutePredicateFactoryUtils.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/support/RoutePredicateFactoryUtils.java @@ -17,14 +17,22 @@ package org.springframework.cloud.gateway.handler.support; +import java.net.InetSocketAddress; +import java.util.List; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.cloud.gateway.handler.predicate.RoutePredicateFactory; +import org.springframework.web.server.ServerWebExchange; /** * @author Spencer Gibb + * @author Andrew Fitzgerald */ public class RoutePredicateFactoryUtils { + + public static final String X_FORWARDED_FOR = "X-Forwarded-For"; + private static final Log logger = LogFactory.getLog(RoutePredicateFactory.class); public static void traceMatch(String prefix, Object desired, Object actual, boolean match) { @@ -34,4 +42,22 @@ public static void traceMatch(String prefix, Object desired, Object actual, bool logger.trace(message); } } + + + public static String parseRemoteIpRespectingForwardedHeader(ServerWebExchange exchange) { + List xForwardedValues = exchange.getRequest().getHeaders().get(X_FORWARDED_FOR); + if (xForwardedValues != null && xForwardedValues.size() != 0) { + return xForwardedValues.get(0).split(", ")[0]; + } + return parseRemoteIpIgnoringForwardedHeader(exchange); + } + + public static String parseRemoteIpIgnoringForwardedHeader(ServerWebExchange exchange) { + InetSocketAddress remoteAddress = exchange.getRequest().getRemoteAddress(); + if (remoteAddress != null) { + return remoteAddress.getAddress().getHostAddress(); + } + return null; + } + } diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/support/RoutePredicateFactoryUtilsTest.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/support/RoutePredicateFactoryUtilsTest.java new file mode 100644 index 0000000000..9a8110aa98 --- /dev/null +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/support/RoutePredicateFactoryUtilsTest.java @@ -0,0 +1,56 @@ +package org.springframework.cloud.gateway.handler.support; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; + +import org.junit.Test; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; + +public class RoutePredicateFactoryUtilsTest { + + private InetSocketAddress getRemote0000Address() { + try { + return new InetSocketAddress(InetAddress.getByName("0.0.0.0"), 1234); + } + catch (UnknownHostException e) { + throw new IllegalStateException(); + } + } + + @Test + public void parseRemoteIpPrioritizesFirstForwardedIp() { + MockServerHttpRequest request = MockServerHttpRequest.get("someUrl") + .remoteAddress(getRemote0000Address()) + .header("X-Forwarded-For", "0.0.0.1, 0.0.0.2, 0.0.0.3").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + String actualIp = RoutePredicateFactoryUtils + .parseRemoteIpRespectingForwardedHeader(exchange); + + assertThat(actualIp).isEqualTo("0.0.0.1"); + } + + @Test + public void parseRemoteIpFallsBackToRemoteIp() { + MockServerHttpRequest request = MockServerHttpRequest.get("someUrl") + .remoteAddress(getRemote0000Address()).build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + String actualIp = RoutePredicateFactoryUtils + .parseRemoteIpRespectingForwardedHeader(exchange); + + assertThat(actualIp).isEqualTo("0.0.0.0"); + } + + @Test + public void parseRemoteIpReturnsNullIfNoForwardedOrRemoteIp() { + MockServerHttpRequest request = MockServerHttpRequest.get("someUrl").build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + String actualIp = RoutePredicateFactoryUtils + .parseRemoteIpRespectingForwardedHeader(exchange); + + assertThat(actualIp).isEqualTo(null); + } +} \ No newline at end of file