From ff0a9afc816f30a1d6a6a3499e062d0ed609e4e0 Mon Sep 17 00:00:00 2001 From: Santiago Pericasgeertsen Date: Wed, 18 Oct 2023 17:27:49 -0400 Subject: [PATCH] Initial support for proxy protocol version 2. --- .../helidon/webserver/ConnectionHandler.java | 1 - .../helidon/webserver/ProxyProtocolData.java | 77 ++++++++-- .../webserver/ProxyProtocolHandler.java | 139 +++++++++++++++--- .../webserver/ProxyProtocolHandlerTest.java | 40 ++++- 4 files changed, 223 insertions(+), 34 deletions(-) diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java b/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java index 9f16152e9bd..0af7de7b5fe 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java @@ -34,7 +34,6 @@ import io.helidon.common.socket.HelidonSocket; import io.helidon.common.socket.PeerInfo; import io.helidon.common.socket.PlainSocket; -import io.helidon.common.socket.SocketOptions; import io.helidon.common.socket.SocketWriter; import io.helidon.common.socket.TlsSocket; import io.helidon.common.task.InterruptableTask; diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolData.java b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolData.java index 25201ede9b2..4d2ba1ac566 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolData.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolData.java @@ -21,41 +21,92 @@ public interface ProxyProtocolData { /** - * The protocol family options. + * Protocol family. */ - enum ProtocolFamily { + enum Family { /** - * TCP version 4. + * Unknown family. */ - TCP4, + UNKNOWN, /** - * TCP version 6. + * IP version 4. */ - TCP6, + IPv4, /** - * Protocol family is unknown. + * IP version 6. */ - UNKNOWN + IPv6, + + /** + * Unix. + */ + UNIX; + + static Family fromString(String s) { + return switch (s) { + case "TCP4" -> IPv4; + case "TCP6" -> IPv6; + case "UNIX" -> UNIX; + case "UNKNOWN" -> UNKNOWN; + default -> throw new IllegalArgumentException("Unknown family " + s); + }; + } } /** - * Protocol family from protocol header. + * Protocol type. + */ + enum Protocol { + /** + * Unknown protocol. + */ + UNKNOWN, + + /** + * TCP streams protocol. + */ + TCP, + + /** + * UDP datagram protocol. + */ + UDP; + + static Protocol fromString(String s) { + return switch (s) { + case "TCP4", "TCP6" -> TCP; + case "UDP" -> UDP; + case "UNKNOWN" -> UNKNOWN; + default -> throw new IllegalArgumentException("Unknown protocol " + s); + }; + } + } + + /** + * Family from protocol header. + * + * @return family + */ + Family family(); + + /** + * Protocol from protocol header. * - * @return protocol family + * @return protocol */ - ProtocolFamily protocolFamily(); + Protocol protocol(); /** - * Source address that is either IPv4 or IPv6 depending on {@link #protocolFamily()}}. + * Source address that is either IP4 or IP6 depending on {@link #family()}. * * @return source address */ String sourceAddress(); /** - * Destination address that is either IPv4 or IPv6 depending on {@link #protocolFamily()}}. + * Destination address that is either IP4 or IP46 depending on {@link #family()}. * * @return source address */ diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolHandler.java b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolHandler.java index 138bd8115a2..66285b1019a 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolHandler.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolHandler.java @@ -16,6 +16,7 @@ package io.helidon.webserver; import java.io.IOException; +import java.io.InputStream; import java.io.PushbackInputStream; import java.io.UncheckedIOException; import java.lang.System.Logger.Level; @@ -25,11 +26,14 @@ import io.helidon.http.DirectHandler; import io.helidon.http.RequestException; +import io.helidon.webserver.ProxyProtocolData.Family; +import io.helidon.webserver.ProxyProtocolData.Protocol; class ProxyProtocolHandler implements Supplier { private static final System.Logger LOGGER = System.getLogger(ProxyProtocolHandler.class.getName()); private static final int MAX_V1_FIELD_LENGTH = 40; + private static final int MAX_V2_ADDRESS_LENGTH = 216; static final byte[] V1_PREFIX = { (byte) 'P', @@ -39,12 +43,15 @@ class ProxyProtocolHandler implements Supplier { (byte) 'Y', }; - static final byte[] V2_PREFIX = { + static final byte[] V2_PREFIX_1 = { (byte) 0x0D, (byte) 0x0A, (byte) 0x0D, (byte) 0x0A, (byte) 0x00, + }; + + static final byte[] V2_PREFIX_2 = { (byte) 0x0D, (byte) 0x0A, (byte) 0x51, @@ -80,7 +87,7 @@ public ProxyProtocolData get() { } if (arrayEquals(prefix, V1_PREFIX, V1_PREFIX.length)) { return handleV1Protocol(inputStream); - } else if (arrayEquals(prefix, V2_PREFIX, V1_PREFIX.length)) { + } else if (arrayEquals(prefix, V2_PREFIX_1, V2_PREFIX_1.length)) { return handleV2Protocol(inputStream); } else { throw BAD_PROTOCOL_EXCEPTION; @@ -99,12 +106,14 @@ static ProxyProtocolData handleV1Protocol(PushbackInputStream inputStream) throw // protocol and family n = readUntil(inputStream, buffer, (byte) ' ', (byte) '\r'); - var family = ProxyProtocolData.ProtocolFamily.valueOf(new String(buffer, 0, n)); - byte b = (byte) inputStream.read(); + String familyProtocol = new String(buffer, 0, n); + var family = Family.fromString(familyProtocol); + var protocol = Protocol.fromString(familyProtocol); + byte b = readNext(inputStream); if (b == (byte) '\r') { // special case for just UNKNOWN family - if (family == ProxyProtocolData.ProtocolFamily.UNKNOWN) { - return new ProxyProtocolDataImpl(ProxyProtocolData.ProtocolFamily.UNKNOWN, + if (family == ProxyProtocolData.Family.UNKNOWN) { + return new ProxyProtocolDataImpl(Family.UNKNOWN, Protocol.UNKNOWN, null, null, -1, -1); } } @@ -132,12 +141,103 @@ static ProxyProtocolData handleV1Protocol(PushbackInputStream inputStream) throw match(inputStream, (byte) '\r'); match(inputStream, (byte) '\n'); - return new ProxyProtocolDataImpl(family, sourceAddress, destAddress, sourcePort, destPort); + return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress, sourcePort, destPort); } catch (IllegalArgumentException e) { throw BAD_PROTOCOL_EXCEPTION; } } + static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throws IOException { + // match rest of prefix + match(inputStream, V2_PREFIX_2); + + // only accept version 2, ignore LOCAL/PROXY + int b = readNext(inputStream); + if (b >>> 4 != 0x02) { + throw BAD_PROTOCOL_EXCEPTION; + } + + // protocol and family + b = readNext(inputStream); + var family = switch (b >>> 4) { + case 0x1 -> Family.IPv4; + case 0x2 -> Family.IPv6; + case 0x3 -> Family.UNIX; + default -> Family.UNKNOWN; + }; + var protocol = switch (b & 0x0F) { + case 0x1 -> Protocol.TCP; + case 0x2 -> Protocol.UDP; + default -> Protocol.UNKNOWN; + }; + + // length + b = readNext(inputStream); + int headerLength = ((b << 8) & 0xFF00) | (readNext(inputStream) & 0xFF); + + // decode addresses and ports + String sourceAddress = null; + String destAddress = null; + int sourcePort = -1; + int destPort = -1; + byte[] buffer = new byte[MAX_V2_ADDRESS_LENGTH]; + switch (family) { + case IPv4 -> { + int n = inputStream.read(buffer, 0, 12); + if (n < 12) { + throw BAD_PROTOCOL_EXCEPTION; + } + sourceAddress = (buffer[0] & 0xFF) + + "." + (buffer[1] & 0xFF) + + "." + (buffer[2] & 0xFF) + + "." + (buffer[3] & 0xFF); + destAddress = (buffer[4] & 0xFF) + + "." + (buffer[5] & 0xFF) + + "." + (buffer[6] & 0xFF) + + "." + (buffer[7] & 0xFF); + sourcePort = buffer[9] & 0xFF + | ((buffer[8] << 8) & 0xFF00); + destPort = buffer[11] & 0xFF + | ((buffer[10] << 8) & 0xFF00); + headerLength -= 12; + } + case IPv6 -> { + int n = inputStream.read(buffer, 0, 36); + if (n < 36) { + throw BAD_PROTOCOL_EXCEPTION; + } + headerLength -= 36; + + } + case UNIX -> { + int n = inputStream.read(buffer, 0, 216); + if (n < 216) { + throw BAD_PROTOCOL_EXCEPTION; + } + headerLength -= 216; + } + default -> { + // falls through + } + } + + // skip any TLV vectors + while (headerLength > 0) { + headerLength -= (int) inputStream.skip(headerLength); + } + + return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress, + sourcePort, destPort); + } + + private static byte readNext(InputStream inputStream) throws IOException { + int b = inputStream.read(); + if (b < 0) { + throw BAD_PROTOCOL_EXCEPTION; + } + return (byte) b; + } + private static void match(byte a, byte b) { if (a != b) { throw BAD_PROTOCOL_EXCEPTION; @@ -150,28 +250,30 @@ private static void match(PushbackInputStream inputStream, byte b) throws IOExce } } + private static void match(PushbackInputStream inputStream, byte... bs) throws IOException { + for (byte b : bs) { + int c = inputStream.read(); + if (((byte) c) != b) { + throw BAD_PROTOCOL_EXCEPTION; + } + } + } + private static int readUntil(PushbackInputStream inputStream, byte[] buffer, byte... delims) throws IOException { int n = 0; do { - int b = inputStream.read(); - if (b < 0) { - throw BAD_PROTOCOL_EXCEPTION; - } - if (arrayContains(delims, (byte) b)) { + byte b = readNext(inputStream); + if (arrayContains(delims, b)) { inputStream.unread(b); return n; } - buffer[n++] = (byte) b; + buffer[n++] = b; if (n >= buffer.length) { throw BAD_PROTOCOL_EXCEPTION; } } while (true); } - static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throws IOException { - return null; - } - private static boolean arrayEquals(byte[] array1, byte[] array2, int prefix) { return Arrays.equals(array1, 0, prefix, array2, 0, prefix); } @@ -185,7 +287,8 @@ private static boolean arrayContains(byte[] array, byte b) { return false; } - record ProxyProtocolDataImpl(ProtocolFamily protocolFamily, + record ProxyProtocolDataImpl(Family family, + Protocol protocol, String sourceAddress, String destAddress, int sourcePort, diff --git a/webserver/webserver/src/test/java/io/helidon/webserver/ProxyProtocolHandlerTest.java b/webserver/webserver/src/test/java/io/helidon/webserver/ProxyProtocolHandlerTest.java index c3f889de27c..76db2d6daa4 100644 --- a/webserver/webserver/src/test/java/io/helidon/webserver/ProxyProtocolHandlerTest.java +++ b/webserver/webserver/src/test/java/io/helidon/webserver/ProxyProtocolHandlerTest.java @@ -30,12 +30,15 @@ class ProxyProtocolHandlerTest { + static final String V2_PREFIX_2 = "\0x0D\0x0A\0x51\0x55\0x49\0x54\0x0A"; + @Test void basicV1Test() throws IOException { String header = " TCP4 192.168.0.1 192.168.0.11 56324 443\r\n"; // excludes PROXY prefix ProxyProtocolData data = ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( new ByteArrayInputStream(header.getBytes(StandardCharsets.US_ASCII)))); - assertThat(data.protocolFamily(), is(ProxyProtocolData.ProtocolFamily.TCP4)); + assertThat(data.family(), is(ProxyProtocolData.Family.IPv4)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.TCP)); assertThat(data.sourceAddress(), is("192.168.0.1")); assertThat(data.destAddress(), is("192.168.0.11")); assertThat(data.sourcePort(), is(56324)); @@ -47,7 +50,8 @@ void unknownV1Test() throws IOException { String header = " UNKNOWN\r\n"; // excludes PROXY prefix ProxyProtocolData data = ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( new ByteArrayInputStream(header.getBytes(StandardCharsets.US_ASCII)))); - assertThat(data.protocolFamily(), is(ProxyProtocolData.ProtocolFamily.UNKNOWN)); + assertThat(data.family(), is(ProxyProtocolData.Family.UNKNOWN)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.UNKNOWN)); assertThat(data.sourceAddress(), nullValue()); assertThat(data.destAddress(), nullValue()); assertThat(data.sourcePort(), is(-1)); @@ -73,4 +77,36 @@ void badV1Test() { ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( new ByteArrayInputStream(header4.getBytes(StandardCharsets.US_ASCII))))); } + + @Test + void basicV2Test() throws IOException { + String header = V2_PREFIX_2 + + "\0x20\0x11\0x00\0x0C" // version, family/protocol, length + + "\0xC0\0xA8\0x00\0x01" // 192.168.0.1 + + "\0xC0\0xA8\0x00\0x0B" // 192.168.0.11 + + "\0xDC\0x04" // 56324 + + "\0x01\0xBB"; // 443 + ProxyProtocolData data = ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream( + new ByteArrayInputStream(decodeHexString(header)))); + assertThat(data.family(), is(ProxyProtocolData.Family.IPv4)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.TCP)); + assertThat(data.sourceAddress(), is("192.168.0.1")); + assertThat(data.destAddress(), is("192.168.0.11")); + assertThat(data.sourcePort(), is(56324)); + assertThat(data.destPort(), is(443)); + } + + private static byte[] decodeHexString(String s) { + assert !s.isEmpty() && s.length() % 4 == 0; + + byte[] bytes = new byte[s.length() / 4]; + for (int i = 0, j = 0; i < s.length(); i += 4) { + char c1 = s.charAt(i + 2); + byte b1 = (byte) (Character.isDigit(c1) ? c1 - '0' : c1 - 'A' + 10); + char c2 = s.charAt(i + 3); + byte b2 = (byte) (Character.isDigit(c2) ? c2 - '0' : c2 - 'A' + 10); + bytes[j++] = (byte) (((b1 << 4) & 0xF0) | (b2 & 0x0F)); + } + return bytes; + } }