Skip to content

Commit

Permalink
Initial support for proxy protocol version 2.
Browse files Browse the repository at this point in the history
  • Loading branch information
spericas committed Oct 18, 2023
1 parent 24bc50f commit ff0a9af
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import io.helidon.common.socket.HelidonSocket;
import io.helidon.common.socket.PeerInfo;
import io.helidon.common.socket.PlainSocket;
import io.helidon.common.socket.SocketOptions;
import io.helidon.common.socket.SocketWriter;
import io.helidon.common.socket.TlsSocket;
import io.helidon.common.task.InterruptableTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,41 +21,92 @@
public interface ProxyProtocolData {

/**
* The protocol family options.
* Protocol family.
*/
enum ProtocolFamily {
enum Family {
/**
* TCP version 4.
* Unknown family.
*/
TCP4,
UNKNOWN,

/**
* TCP version 6.
* IP version 4.
*/
TCP6,
IPv4,

/**
* Protocol family is unknown.
* IP version 6.
*/
UNKNOWN
IPv6,

/**
* Unix.
*/
UNIX;

static Family fromString(String s) {
return switch (s) {
case "TCP4" -> IPv4;
case "TCP6" -> IPv6;
case "UNIX" -> UNIX;
case "UNKNOWN" -> UNKNOWN;
default -> throw new IllegalArgumentException("Unknown family " + s);
};
}
}

/**
* Protocol family from protocol header.
* Protocol type.
*/
enum Protocol {
/**
* Unknown protocol.
*/
UNKNOWN,

/**
* TCP streams protocol.
*/
TCP,

/**
* UDP datagram protocol.
*/
UDP;

static Protocol fromString(String s) {
return switch (s) {
case "TCP4", "TCP6" -> TCP;
case "UDP" -> UDP;
case "UNKNOWN" -> UNKNOWN;
default -> throw new IllegalArgumentException("Unknown protocol " + s);
};
}
}

/**
* Family from protocol header.
*
* @return family
*/
Family family();

/**
* Protocol from protocol header.
*
* @return protocol family
* @return protocol
*/
ProtocolFamily protocolFamily();
Protocol protocol();

/**
* Source address that is either IPv4 or IPv6 depending on {@link #protocolFamily()}}.
* Source address that is either IP4 or IP6 depending on {@link #family()}.
*
* @return source address
*/
String sourceAddress();

/**
* Destination address that is either IPv4 or IPv6 depending on {@link #protocolFamily()}}.
* Destination address that is either IP4 or IP46 depending on {@link #family()}.
*
* @return source address
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.helidon.webserver;

import java.io.IOException;
import java.io.InputStream;
import java.io.PushbackInputStream;
import java.io.UncheckedIOException;
import java.lang.System.Logger.Level;
Expand All @@ -25,11 +26,14 @@

import io.helidon.http.DirectHandler;
import io.helidon.http.RequestException;
import io.helidon.webserver.ProxyProtocolData.Family;
import io.helidon.webserver.ProxyProtocolData.Protocol;

class ProxyProtocolHandler implements Supplier<ProxyProtocolData> {
private static final System.Logger LOGGER = System.getLogger(ProxyProtocolHandler.class.getName());

private static final int MAX_V1_FIELD_LENGTH = 40;
private static final int MAX_V2_ADDRESS_LENGTH = 216;

static final byte[] V1_PREFIX = {
(byte) 'P',
Expand All @@ -39,12 +43,15 @@ class ProxyProtocolHandler implements Supplier<ProxyProtocolData> {
(byte) 'Y',
};

static final byte[] V2_PREFIX = {
static final byte[] V2_PREFIX_1 = {
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x00,
};

static final byte[] V2_PREFIX_2 = {
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x51,
Expand Down Expand Up @@ -80,7 +87,7 @@ public ProxyProtocolData get() {
}
if (arrayEquals(prefix, V1_PREFIX, V1_PREFIX.length)) {
return handleV1Protocol(inputStream);
} else if (arrayEquals(prefix, V2_PREFIX, V1_PREFIX.length)) {
} else if (arrayEquals(prefix, V2_PREFIX_1, V2_PREFIX_1.length)) {
return handleV2Protocol(inputStream);
} else {
throw BAD_PROTOCOL_EXCEPTION;
Expand All @@ -99,12 +106,14 @@ static ProxyProtocolData handleV1Protocol(PushbackInputStream inputStream) throw

// protocol and family
n = readUntil(inputStream, buffer, (byte) ' ', (byte) '\r');
var family = ProxyProtocolData.ProtocolFamily.valueOf(new String(buffer, 0, n));
byte b = (byte) inputStream.read();
String familyProtocol = new String(buffer, 0, n);
var family = Family.fromString(familyProtocol);
var protocol = Protocol.fromString(familyProtocol);
byte b = readNext(inputStream);
if (b == (byte) '\r') {
// special case for just UNKNOWN family
if (family == ProxyProtocolData.ProtocolFamily.UNKNOWN) {
return new ProxyProtocolDataImpl(ProxyProtocolData.ProtocolFamily.UNKNOWN,
if (family == ProxyProtocolData.Family.UNKNOWN) {
return new ProxyProtocolDataImpl(Family.UNKNOWN, Protocol.UNKNOWN,
null, null, -1, -1);
}
}
Expand Down Expand Up @@ -132,12 +141,103 @@ static ProxyProtocolData handleV1Protocol(PushbackInputStream inputStream) throw
match(inputStream, (byte) '\r');
match(inputStream, (byte) '\n');

return new ProxyProtocolDataImpl(family, sourceAddress, destAddress, sourcePort, destPort);
return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress, sourcePort, destPort);
} catch (IllegalArgumentException e) {
throw BAD_PROTOCOL_EXCEPTION;
}
}

static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throws IOException {
// match rest of prefix
match(inputStream, V2_PREFIX_2);

// only accept version 2, ignore LOCAL/PROXY
int b = readNext(inputStream);
if (b >>> 4 != 0x02) {
throw BAD_PROTOCOL_EXCEPTION;
}

// protocol and family
b = readNext(inputStream);
var family = switch (b >>> 4) {
case 0x1 -> Family.IPv4;
case 0x2 -> Family.IPv6;
case 0x3 -> Family.UNIX;
default -> Family.UNKNOWN;
};
var protocol = switch (b & 0x0F) {
case 0x1 -> Protocol.TCP;
case 0x2 -> Protocol.UDP;
default -> Protocol.UNKNOWN;
};

// length
b = readNext(inputStream);
int headerLength = ((b << 8) & 0xFF00) | (readNext(inputStream) & 0xFF);

// decode addresses and ports
String sourceAddress = null;
String destAddress = null;
int sourcePort = -1;
int destPort = -1;
byte[] buffer = new byte[MAX_V2_ADDRESS_LENGTH];
switch (family) {
case IPv4 -> {
int n = inputStream.read(buffer, 0, 12);
if (n < 12) {
throw BAD_PROTOCOL_EXCEPTION;
}
sourceAddress = (buffer[0] & 0xFF)
+ "." + (buffer[1] & 0xFF)
+ "." + (buffer[2] & 0xFF)
+ "." + (buffer[3] & 0xFF);
destAddress = (buffer[4] & 0xFF)
+ "." + (buffer[5] & 0xFF)
+ "." + (buffer[6] & 0xFF)
+ "." + (buffer[7] & 0xFF);
sourcePort = buffer[9] & 0xFF
| ((buffer[8] << 8) & 0xFF00);
destPort = buffer[11] & 0xFF
| ((buffer[10] << 8) & 0xFF00);
headerLength -= 12;
}
case IPv6 -> {
int n = inputStream.read(buffer, 0, 36);
if (n < 36) {
throw BAD_PROTOCOL_EXCEPTION;
}
headerLength -= 36;

}
case UNIX -> {
int n = inputStream.read(buffer, 0, 216);
if (n < 216) {
throw BAD_PROTOCOL_EXCEPTION;
}
headerLength -= 216;
}
default -> {
// falls through
}
}

// skip any TLV vectors
while (headerLength > 0) {
headerLength -= (int) inputStream.skip(headerLength);
}

return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress,
sourcePort, destPort);
}

private static byte readNext(InputStream inputStream) throws IOException {
int b = inputStream.read();
if (b < 0) {
throw BAD_PROTOCOL_EXCEPTION;
}
return (byte) b;
}

private static void match(byte a, byte b) {
if (a != b) {
throw BAD_PROTOCOL_EXCEPTION;
Expand All @@ -150,28 +250,30 @@ private static void match(PushbackInputStream inputStream, byte b) throws IOExce
}
}

private static void match(PushbackInputStream inputStream, byte... bs) throws IOException {
for (byte b : bs) {
int c = inputStream.read();
if (((byte) c) != b) {
throw BAD_PROTOCOL_EXCEPTION;
}
}
}

private static int readUntil(PushbackInputStream inputStream, byte[] buffer, byte... delims) throws IOException {
int n = 0;
do {
int b = inputStream.read();
if (b < 0) {
throw BAD_PROTOCOL_EXCEPTION;
}
if (arrayContains(delims, (byte) b)) {
byte b = readNext(inputStream);
if (arrayContains(delims, b)) {
inputStream.unread(b);
return n;
}
buffer[n++] = (byte) b;
buffer[n++] = b;
if (n >= buffer.length) {
throw BAD_PROTOCOL_EXCEPTION;
}
} while (true);
}

static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throws IOException {
return null;
}

private static boolean arrayEquals(byte[] array1, byte[] array2, int prefix) {
return Arrays.equals(array1, 0, prefix, array2, 0, prefix);
}
Expand All @@ -185,7 +287,8 @@ private static boolean arrayContains(byte[] array, byte b) {
return false;
}

record ProxyProtocolDataImpl(ProtocolFamily protocolFamily,
record ProxyProtocolDataImpl(Family family,
Protocol protocol,
String sourceAddress,
String destAddress,
int sourcePort,
Expand Down
Loading

0 comments on commit ff0a9af

Please sign in to comment.