Skip to content

Commit

Permalink
Support for IPv4 in V2.
Browse files Browse the repository at this point in the history
  • Loading branch information
spericas committed Oct 19, 2023
1 parent ff0a9af commit 754262b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.io.PushbackInputStream;
import java.io.UncheckedIOException;
import java.lang.System.Logger.Level;
import java.net.Inet6Address;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.function.Supplier;

Expand All @@ -33,7 +35,6 @@ class ProxyProtocolHandler implements Supplier<ProxyProtocolData> {
private static final System.Logger LOGGER = System.getLogger(ProxyProtocolHandler.class.getName());

private static final int MAX_V1_FIELD_LENGTH = 40;
private static final int MAX_V2_ADDRESS_LENGTH = 216;

static final byte[] V1_PREFIX = {
(byte) 'P',
Expand Down Expand Up @@ -180,11 +181,11 @@ static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throw
String destAddress = null;
int sourcePort = -1;
int destPort = -1;
byte[] buffer = new byte[MAX_V2_ADDRESS_LENGTH];
switch (family) {
case IPv4 -> {
int n = inputStream.read(buffer, 0, 12);
if (n < 12) {
byte[] buffer = new byte[12];
int n = inputStream.read(buffer, 0, buffer.length);
if (n < buffer.length) {
throw BAD_PROTOCOL_EXCEPTION;
}
sourceAddress = (buffer[0] & 0xFF)
Expand All @@ -199,22 +200,39 @@ static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throw
| ((buffer[8] << 8) & 0xFF00);
destPort = buffer[11] & 0xFF
| ((buffer[10] << 8) & 0xFF00);
headerLength -= 12;
headerLength -= buffer.length;
}
case IPv6 -> {
int n = inputStream.read(buffer, 0, 36);
if (n < 36) {
byte[] buffer = new byte[16];
int n = inputStream.read(buffer, 0, buffer.length);
if (n < buffer.length) {
throw BAD_PROTOCOL_EXCEPTION;
}
headerLength -= 36;

sourceAddress = Inet6Address.getByAddress(buffer).getHostAddress();
n = inputStream.read(buffer, 0, buffer.length);
if (n < buffer.length) {
throw BAD_PROTOCOL_EXCEPTION;
}
destAddress = Inet6Address.getByAddress(buffer).getHostAddress();
n = inputStream.read(buffer, 0, 4);
if (n < 4) {
throw BAD_PROTOCOL_EXCEPTION;
}
sourcePort = buffer[1] & 0xFF
| ((buffer[0] << 8) & 0xFF00);
destPort = buffer[3] & 0xFF
| ((buffer[2] << 8) & 0xFF00);
headerLength -= 2 * buffer.length + 4;
}
case UNIX -> {
int n = inputStream.read(buffer, 0, 216);
if (n < 216) {
byte[] buffer = new byte[216];
int n = inputStream.read(buffer, 0, buffer.length);
if (n < buffer.length) {
throw BAD_PROTOCOL_EXCEPTION;
}
headerLength -= 216;
sourceAddress = new String(buffer, 0, 108, StandardCharsets.US_ASCII);
destAddress = new String(buffer, 108, buffer.length, StandardCharsets.US_ASCII);
headerLength -= buffer.length;
}
default -> {
// falls through
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void badV1Test() {
}

@Test
void basicV2Test() throws IOException {
void basicV2TestIPv4() throws IOException {
String header = V2_PREFIX_2
+ "\0x20\0x11\0x00\0x0C" // version, family/protocol, length
+ "\0xC0\0xA8\0x00\0x01" // 192.168.0.1
Expand All @@ -96,6 +96,26 @@ void basicV2Test() throws IOException {
assertThat(data.destPort(), is(443));
}

@Test
void basicV2TestIPv6() throws IOException {
String header = V2_PREFIX_2
+ "\0x20\0x21\0x00\0x0C" // version, family/protocol, length
+ "\0xAA\0xAA\0xBB\0xBB\0xCC\0xCC\0xDD\0xDD"
+ "\0xAA\0xAA\0xBB\0xBB\0xCC\0xCC\0xDD\0xDD" // source
+ "\0xAA\0xAA\0xBB\0xBB\0xCC\0xCC\0xDD\0xDD"
+ "\0xAA\0xAA\0xBB\0xBB\0xCC\0xCC\0xDD\0xDD" // dest
+ "\0xDC\0x04" // 56324
+ "\0x01\0xBB"; // 443
ProxyProtocolData data = ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream(
new ByteArrayInputStream(decodeHexString(header))));
assertThat(data.family(), is(ProxyProtocolData.Family.IPv6));
assertThat(data.protocol(), is(ProxyProtocolData.Protocol.TCP));
assertThat(data.sourceAddress(), is("aaaa:bbbb:cccc:dddd:aaaa:bbbb:cccc:dddd"));
assertThat(data.destAddress(), is("aaaa:bbbb:cccc:dddd:aaaa:bbbb:cccc:dddd"));
assertThat(data.sourcePort(), is(56324));
assertThat(data.destPort(), is(443));
}

private static byte[] decodeHexString(String s) {
assert !s.isEmpty() && s.length() % 4 == 0;

Expand Down

0 comments on commit 754262b

Please sign in to comment.