Skip to content

Commit

Permalink
Don't cache local and remote address as these might change during the… (
Browse files Browse the repository at this point in the history
java-native-access#664)

… life-time

Motivation:

We should not cache the local and remote address as these might change
during the life-time of the connection

Modifications:

- Override methods so we don't cache
- Also duplicate ByteBuffer in the QuicConnectionAddress to make things
safer.

Result:

Correctly return ids
  • Loading branch information
normanmaurer authored Feb 5, 2024
1 parent b6139e7 commit 976a873
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ public final class QuicConnectionAddress extends SocketAddress {

private final String toStr;

// Accessed by QuicheQuicheChannel
final ByteBuffer connId;
private final ByteBuffer connId;

/**
* Create a new instance
Expand All @@ -57,7 +56,7 @@ public QuicConnectionAddress(byte[] connId) {
* @param connId the connection id to use.
*/
public QuicConnectionAddress(ByteBuffer connId) {
this(connId, true);
this(connId.duplicate(), true);
}

private QuicConnectionAddress(ByteBuffer connId, boolean validate) {
Expand All @@ -66,10 +65,11 @@ private QuicConnectionAddress(ByteBuffer connId, boolean validate) {
throw new IllegalArgumentException("Connection ID can only be of max length "
+ Quiche.QUICHE_MAX_CONN_ID_LEN);
}
this.connId = connId;
if (connId == null) {
this.connId = null;
toStr = "QuicConnectionAddress{EPHEMERAL}";
} else {
this.connId = connId.asReadOnlyBuffer().duplicate();
ByteBuf buffer = Unpooled.wrappedBuffer(connId);
try {
toStr = "QuicConnectionAddress{" +
Expand Down Expand Up @@ -102,10 +102,14 @@ public boolean equals(Object obj) {
if (obj == this) {
return true;
}
return connId.equals(address.connId);
}

ByteBuffer id() {
if (connId == null) {
return false;
return ByteBuffer.allocate(0);
}
return connId.equals(address.connId);
return connId.duplicate();
}

/**
Expand All @@ -128,5 +132,4 @@ public static QuicConnectionAddress random(int length) {
public static QuicConnectionAddress random() {
return random(Quiche.QUICHE_MAX_CONN_ID_LEN);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,15 @@ void connectNow(Function<QuicChannel, ? extends QuicSslEngine> engineProvider, E
this.freeTask = freeTask;

QuicConnectionAddress address = this.connectAddress;

if (address == QuicConnectionAddress.EPHEMERAL) {
address = QuicConnectionAddress.random(localConnIdLength);
} else {
if (address.connId.remaining() != localConnIdLength) {
failConnectPromiseAndThrow(new IllegalArgumentException("connectionAddress has length "
+ address.connId.remaining()
+ " instead of " + localConnIdLength));
}
}
ByteBuffer connectId = address.id();
if (connectId.remaining() != localConnIdLength) {
failConnectPromiseAndThrow(new IllegalArgumentException("connectionAddress has length "
+ connectId.remaining()
+ " instead of " + localConnIdLength));
}
QuicSslEngine engine = engineProvider.apply(this);
if (!(engine instanceof QuicheQuicSslEngine)) {
Expand All @@ -351,7 +352,6 @@ void connectNow(Function<QuicChannel, ? extends QuicSslEngine> engineProvider, E
failConnectPromiseAndThrow(new IllegalArgumentException("QuicSslEngine is not create in client mode"));
}
QuicheQuicSslEngine quicheEngine = (QuicheQuicSslEngine) engine;
ByteBuffer connectId = address.connId.duplicate();
ByteBuf idBuffer = alloc().directBuffer(connectId.remaining()).writeBytes(connectId.duplicate());
try {
int fromSockaddrLen = SockaddrIn.setAddress(fromSockaddrMemory, local);
Expand Down Expand Up @@ -560,6 +560,18 @@ protected SocketAddress remoteAddress0() {
return connection == null ? null : connection.destinationId();
}

@Override
public SocketAddress localAddress() {
// Override so we never cache as the sourceId() can change over life-time.
return localAddress0();
}

@Override
public SocketAddress remoteAddress() {
// Override so we never cache as the destinationId() can change over life-time.
return remoteAddress0();
}

@Override
protected void doBind(SocketAddress socketAddress) {
throw new UnsupportedOperationException();
Expand Down Expand Up @@ -935,13 +947,14 @@ List<ByteBuffer> newSourceConnectionIds() {
}
List<ByteBuffer> generatedIds = new ArrayList<>(left);
boolean sendAndFlush = false;
ByteBuffer key = sourceAddr.connId.duplicate();
ByteBuffer key = sourceAddr.id();
ByteBuf connIdBuffer = alloc().directBuffer(key.remaining());

byte[] resetTokenArray = new byte[Quic.RESET_TOKEN_LEN];
try {
do {
ByteBuffer srcId = connectionIdAddressGenerator.newId(key, key.remaining());
ByteBuffer srcId = connectionIdAddressGenerator.newId(key.duplicate(), key.remaining())
.asReadOnlyBuffer();
connIdBuffer.clear();
connIdBuffer.writeBytes(srcId.duplicate());
ByteBuffer resetToken = resetTokenGenerator.newResetToken(srcId.duplicate());
Expand All @@ -953,7 +966,7 @@ List<ByteBuffer> newSourceConnectionIds() {
break;
}
sendAndFlush = true;
generatedIds.add(srcId);
generatedIds.add(srcId.duplicate());
sourceConnectionIds.add(srcId);
} while (--left > 0);
} finally {
Expand Down Expand Up @@ -1435,9 +1448,8 @@ public void connect(SocketAddress remote, SocketAddress local, ChannelPromise ch
return;
}

QuicConnectionAddress address = (QuicConnectionAddress) remote;
connectAddress = (QuicConnectionAddress) remote;
connectPromise = channelPromise;
connectAddress = address;

// Schedule connect timeout.
int connectTimeoutMillis = config().getConnectTimeoutMillis();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.LockSupport;
import java.util.function.Consumer;
import java.util.function.Function;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ private void removeMapping(QuicheQuicChannel channel, ByteBuffer id) {

private void removeChannel(QuicheQuicChannel channel) {
boolean removed = channels.remove(channel);
assert removed;
for (ByteBuffer id : channel.sourceConnectionIds()) {
QuicheQuicChannel ch = connectionIdToChannel.remove(id);
assert ch == channel;
if (removed) {
for (ByteBuffer id : channel.sourceConnectionIds()) {
QuicheQuicChannel ch = connectionIdToChannel.remove(id);
assert ch == channel;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ private static void testConnectWithCustomIdLength(Executor executor, int clientI
ChannelFuture closeFuture = quicChannel.closeFuture().await();
assertTrue(closeFuture.isSuccess());
clientQuicChannelHandler.assertState();
assertEquals(clientIdLength, clientQuicChannelHandler.localAddress().connId.remaining());
assertEquals(serverIdLength, clientQuicChannelHandler.remoteAddress().connId.remaining());
assertEquals(clientIdLength, clientQuicChannelHandler.localAddress().id().remaining());
assertEquals(serverIdLength, clientQuicChannelHandler.remoteAddress().id().remaining());
} finally {
serverQuicChannelHandler.assertState();
assertEquals(serverIdLength, serverQuicChannelHandler.localAddress().connId.remaining());
assertEquals(clientIdLength, serverQuicChannelHandler.remoteAddress().connId.remaining());
assertEquals(serverIdLength, serverQuicChannelHandler.localAddress().id().remaining());
assertEquals(clientIdLength, serverQuicChannelHandler.remoteAddress().id().remaining());
serverQuicStreamHandler.assertState();

server.close().sync();
Expand Down Expand Up @@ -542,10 +542,6 @@ public void testConnectWithoutTokenValidation(Executor executor) throws Throwabl

assertEquals(serverQuicChannelHandler.localAddress(), remoteAddress);
assertEquals(serverQuicChannelHandler.remoteAddress(), localAddress);

// Check if we also can access these after the channel was closed.
assertNotNull(quicChannel.localAddress());
assertNotNull(quicChannel.remoteAddress());
} finally {
serverLatch.await();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ public void testByteArrayIsCloned() {
byte[] bytes = new byte[8];
ThreadLocalRandom.current().nextBytes(bytes);
QuicConnectionAddress address = new QuicConnectionAddress(bytes);
assertEquals(ByteBuffer.wrap(bytes), address.connId);
assertEquals(ByteBuffer.wrap(bytes), address.id());
ThreadLocalRandom.current().nextBytes(bytes);
assertNotEquals(ByteBuffer.wrap(bytes), address.connId);
assertNotEquals(ByteBuffer.wrap(bytes), address.id());
}

@Test
Expand All @@ -52,8 +52,8 @@ public void tesByteBufferIsDuplicated() {
ThreadLocalRandom.current().nextBytes(bytes);
ByteBuffer buffer = ByteBuffer.wrap(bytes);
QuicConnectionAddress address = new QuicConnectionAddress(bytes);
assertEquals(buffer, address.connId);
assertEquals(buffer, address.id());
buffer.position(1);
assertNotEquals(buffer, address.connId);
assertNotEquals(buffer, address.id());
}
}

0 comments on commit 976a873

Please sign in to comment.