From 69fdc1c3b1b77d9d39cccc9633b70f4c47abc654 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Sat, 29 May 2021 23:00:40 +0200 Subject: [PATCH] Update code to use the new network-path aware API of quiche (#291) Motivation: Quiche changed its API to be network path aware and so be able to handle connection migration in the future. Modifications: - Adjust code to use the new API Result: Be able to use the new network-path aware API of quiche --- pom.xml | 2 +- src/main/c/netty_quic_quiche.c | 183 ++++++++++++++++- .../quic/QuicConnectionMigrationEvent.java | 50 +++++ .../io/netty/incubator/codec/quic/Quiche.java | 45 +++- ...eNativeStaticallyReferencedJniMethods.java | 25 +++ .../codec/quic/QuicheQuicChannel.java | 109 ++++++++-- .../codec/quic/QuicheQuicClientCodec.java | 2 +- .../incubator/codec/quic/QuicheQuicCodec.java | 10 +- .../codec/quic/QuicheQuicConnection.java | 60 ++++++ .../codec/quic/QuicheQuicServerCodec.java | 9 +- .../incubator/codec/quic/QuicheRecvInfo.java | 73 +++++++ .../incubator/codec/quic/QuicheSendInfo.java | 115 +++++++++++ .../incubator/codec/quic/SockaddrIn.java | 162 +++++++++++++++ .../codec/quic/QuicChannelConnectTest.java | 34 ++- .../codec/quic/QuicChannelDatagramTest.java | 193 ++++++++++-------- .../codec/quic/QuicChannelEchoTest.java | 40 ++-- .../quic/QuicChannelValidationHandler.java | 46 +++++ .../codec/quic/QuicConnectionStatsTest.java | 47 +++-- .../codec/quic/QuicReadableTest.java | 9 +- .../quic/QuicStreamChannelCloseTest.java | 43 ++-- .../quic/QuicStreamChannelCreationTest.java | 20 +- .../codec/quic/QuicStreamFrameTest.java | 17 +- .../codec/quic/QuicStreamHalfClosureTest.java | 17 +- .../codec/quic/QuicStreamLimitTest.java | 94 +++++---- .../codec/quic/QuicStreamTypeTest.java | 55 +++-- .../codec/quic/QuicWritableTest.java | 20 +- 26 files changed, 1192 insertions(+), 288 deletions(-) create mode 100644 src/main/java/io/netty/incubator/codec/quic/QuicConnectionMigrationEvent.java create mode 100644 src/main/java/io/netty/incubator/codec/quic/QuicheRecvInfo.java create mode 100644 src/main/java/io/netty/incubator/codec/quic/QuicheSendInfo.java create mode 100644 src/main/java/io/netty/incubator/codec/quic/SockaddrIn.java create mode 100644 src/test/java/io/netty/incubator/codec/quic/QuicChannelValidationHandler.java diff --git a/pom.xml b/pom.xml index 696bdc333..4e88b953c 100644 --- a/pom.xml +++ b/pom.xml @@ -94,7 +94,7 @@ ${quicheHomeDir}/include https://github.com/cloudflare/quiche master - 92e9561501a3a5ab179115675dca25f8b0feb185 + 6d070ed8694216806f3ce689d71ceb7ad76d425e ${project.build.directory}/generated-sources ${project.build.directory}/template diff --git a/src/main/c/netty_quic_quiche.c b/src/main/c/netty_quic_quiche.c index f607d3f1f..40324433b 100644 --- a/src/main/c/netty_quic_quiche.c +++ b/src/main/c/netty_quic_quiche.c @@ -18,6 +18,16 @@ #include #include #include + +#ifdef _WIN32 +#include +#include +#else +#include +#include +#include +#endif // _WIN32 + #include #include "netty_jni_util.h" #include "netty_quic_boringssl.h" @@ -46,6 +56,97 @@ jint quic_get_java_env(JNIEnv **env) return (*global_vm)->GetEnv(global_vm, (void **)env, NETTY_JNI_UTIL_JNI_VERSION); } +static jint netty_quiche_afInet(JNIEnv* env, jclass clazz) { + return AF_INET; +} + +static jint netty_quiche_afInet6(JNIEnv* env, jclass clazz) { + return AF_INET6; +} + +static jint netty_quiche_sizeofSockaddrIn(JNIEnv* env, jclass clazz) { + return sizeof(struct sockaddr_in); +} + +static jint netty_quiche_sizeofSockaddrIn6(JNIEnv* env, jclass clazz) { + return sizeof(struct sockaddr_in6); +} + +static jint netty_quiche_sockaddrInOffsetofSinFamily(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in, sin_family); +} + +static jint netty_quiche_sockaddrInOffsetofSinPort(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in, sin_port); +} + +static jint netty_quiche_sockaddrInOffsetofSinAddr(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in, sin_addr); +} + +static jint netty_quiche_inAddressOffsetofSAddr(JNIEnv* env, jclass clazz) { + return offsetof(struct in_addr, s_addr); +} + +static jint netty_quiche_sockaddrIn6OffsetofSin6Family(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in6, sin6_family); +} + +static jint netty_quiche_sockaddrIn6OffsetofSin6Port(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in6, sin6_port); +} + +static jint netty_quiche_sockaddrIn6OffsetofSin6Flowinfo(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in6, sin6_flowinfo); +} + +static jint netty_quiche_sockaddrIn6OffsetofSin6Addr(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in6, sin6_addr); +} + +static jint netty_quiche_sockaddrIn6OffsetofSin6ScopeId(JNIEnv* env, jclass clazz) { + return offsetof(struct sockaddr_in6, sin6_scope_id); +} + +static jint netty_quiche_in6AddressOffsetofS6Addr(JNIEnv* env, jclass clazz) { + return offsetof(struct in6_addr, s6_addr); +} + +static jint netty_quiche_sizeofSockaddrStorage(JNIEnv* env, jclass clazz) { + return sizeof(struct sockaddr_storage); +} +static jint netty_quiche_sizeofSizeT(JNIEnv* env, jclass clazz) { + return sizeof(size_t); +} + +static jint netty_quiche_sizeofSocklenT(JNIEnv* env, jclass clazz) { + return sizeof(socklen_t); +} + +static jint netty_quicheRecvInfoOffsetofFrom(JNIEnv* env, jclass clazz) { + return offsetof(quiche_recv_info, from); +} + +static jint netty_quicheRecvInfoOffsetofFromLen(JNIEnv* env, jclass clazz) { + return offsetof(quiche_recv_info, from_len); +} + +static jint netty_sizeofQuicheRecvInfo(JNIEnv* env, jclass clazz) { + return sizeof(quiche_recv_info); +} + +static jint netty_quicheSendInfoOffsetofTo(JNIEnv* env, jclass clazz) { + return offsetof(quiche_send_info, to); +} + +static jint netty_quicheSendInfoOffsetofToLen(JNIEnv* env, jclass clazz) { + return offsetof(quiche_send_info, to_len); +} + +static jint netty_sizeofQuicheSendInfo(JNIEnv* env, jclass clazz) { + return sizeof(quiche_send_info); +} + static jint netty_quiche_max_conn_id_len(JNIEnv* env, jclass clazz) { return QUICHE_MAX_CONN_ID_LEN; } @@ -176,13 +277,15 @@ static jint netty_quiche_retry(JNIEnv* env, jclass clazz, jlong scid, jint scid_ (uint32_t) version, (uint8_t *) out, (size_t) out_len); } -static jlong netty_quiche_conn_new_with_tls(JNIEnv* env, jclass clazz, jlong scid, jint scid_len, jlong odcid, jint odcid_len, jlong config, jlong ssl, jboolean isServer) { +static jlong netty_quiche_conn_new_with_tls(JNIEnv* env, jclass clazz, jlong scid, jint scid_len, jlong odcid, jint odcid_len, jlong peer, jint peer_len, jlong config, jlong ssl, jboolean isServer) { const uint8_t * odcid_pointer = NULL; if (odcid_len != -1) { odcid_pointer = (const uint8_t *) odcid; } + const struct sockaddr *peer_pointer = (const struct sockaddr*) peer; quiche_conn *conn = quiche_conn_new_with_tls((const uint8_t *) scid, (size_t) scid_len, odcid_pointer, (size_t) odcid_len, + peer_pointer, (size_t) peer_len, (quiche_config *) config, (void*) ssl, isServer == JNI_TRUE ? true : false); if (conn == NULL) { return -1; @@ -226,12 +329,12 @@ static jbyteArray netty_quiche_conn_destination_id(JNIEnv* env, jclass clazz, jl return to_byte_array(env, id, len); } -static jint netty_quiche_conn_recv(JNIEnv* env, jclass clazz, jlong conn, jlong buf, jint buf_len) { - return (jint) quiche_conn_recv((quiche_conn *) conn, (uint8_t *) buf, (size_t) buf_len); +static jint netty_quiche_conn_recv(JNIEnv* env, jclass clazz, jlong conn, jlong buf, jint buf_len, jlong info) { + return (jint) quiche_conn_recv((quiche_conn *) conn, (uint8_t *) buf, (size_t) buf_len, (quiche_recv_info*) info); } -static jint netty_quiche_conn_send(JNIEnv* env, jclass clazz, jlong conn, jlong out, jint out_len) { - return (jint) quiche_conn_send((quiche_conn *) conn, (uint8_t *) out, (size_t) out_len); +static jint netty_quiche_conn_send(JNIEnv* env, jclass clazz, jlong conn, jlong out, jint out_len, jlong info) { + return (jint) quiche_conn_send((quiche_conn *) conn, (uint8_t *) out, (size_t) out_len, (quiche_send_info*) info); } static void netty_quiche_conn_free(JNIEnv* env, jclass clazz, jlong conn) { @@ -472,10 +575,71 @@ static jlong netty_buffer_memory_address(JNIEnv* env, jclass clazz, jobject buff return (jlong) (*env)->GetDirectBufferAddress(env, buffer); } +// Based on https://gist.github.com/kazuho/45eae4f92257daceb73e. +static jint netty_sockaddr_cmp(JNIEnv* env, jclass clazz, jlong addr1, jlong addr2) { + struct sockaddr* x = (struct sockaddr*) addr1; + struct sockaddr* y = (struct sockaddr*) addr2; + + if (x == NULL && y == NULL) { + return 0; + } + if (x != NULL && y == NULL) { + return 1; + } + if (x == NULL && y != NULL) { + return -1; + } + +#define CMP(a, b) if (a != b) return a < b ? -1 : 1 + + CMP(x->sa_family, y->sa_family); + + if (x->sa_family == AF_INET) { + struct sockaddr_in *xin = (void*)x, *yin = (void*)y; + CMP(ntohl(xin->sin_addr.s_addr), ntohl(yin->sin_addr.s_addr)); + CMP(ntohs(xin->sin_port), ntohs(yin->sin_port)); + } else if (x->sa_family == AF_INET6) { + struct sockaddr_in6 *xin6 = (void*)x, *yin6 = (void*)y; + int r = memcmp(xin6->sin6_addr.s6_addr, yin6->sin6_addr.s6_addr, sizeof(xin6->sin6_addr.s6_addr)); + if (r != 0) + return r; + CMP(ntohs(xin6->sin6_port), ntohs(yin6->sin6_port)); + CMP(xin6->sin6_flowinfo, yin6->sin6_flowinfo); + CMP(xin6->sin6_scope_id, yin6->sin6_scope_id); + } + +#undef CMP + return 0; +} + // JNI Registered Methods End // JNI Method Registration Table Begin static const JNINativeMethod statically_referenced_fixed_method_table[] = { + { "afInet", "()I", (void *) netty_quiche_afInet }, + { "afInet6", "()I", (void *) netty_quiche_afInet6 }, + { "sizeofSockaddrIn", "()I", (void *) netty_quiche_sizeofSockaddrIn }, + { "sizeofSockaddrIn6", "()I", (void *) netty_quiche_sizeofSockaddrIn6 }, + { "sockaddrInOffsetofSinFamily", "()I", (void *) netty_quiche_sockaddrInOffsetofSinFamily }, + { "sockaddrInOffsetofSinPort", "()I", (void *) netty_quiche_sockaddrInOffsetofSinPort }, + { "sockaddrInOffsetofSinAddr", "()I", (void *) netty_quiche_sockaddrInOffsetofSinAddr }, + { "inAddressOffsetofSAddr", "()I", (void *) netty_quiche_inAddressOffsetofSAddr }, + { "sockaddrIn6OffsetofSin6Family", "()I", (void *) netty_quiche_sockaddrIn6OffsetofSin6Family }, + { "sockaddrIn6OffsetofSin6Port", "()I", (void *) netty_quiche_sockaddrIn6OffsetofSin6Port }, + { "sockaddrIn6OffsetofSin6Flowinfo", "()I", (void *) netty_quiche_sockaddrIn6OffsetofSin6Flowinfo }, + { "sockaddrIn6OffsetofSin6Addr", "()I", (void *) netty_quiche_sockaddrIn6OffsetofSin6Addr }, + { "sockaddrIn6OffsetofSin6ScopeId", "()I", (void *) netty_quiche_sockaddrIn6OffsetofSin6ScopeId }, + { "in6AddressOffsetofS6Addr", "()I", (void *) netty_quiche_in6AddressOffsetofS6Addr }, + { "sizeofSockaddrStorage", "()I", (void *) netty_quiche_sizeofSockaddrStorage }, + { "sizeofSizeT", "()I", (void *) netty_quiche_sizeofSizeT }, + { "sizeofSocklenT", "()I", (void *) netty_quiche_sizeofSocklenT }, + { "quicheRecvInfoOffsetofFrom", "()I", (void *) netty_quicheRecvInfoOffsetofFrom }, + { "quicheRecvInfoOffsetofFromLen", "()I", (void *) netty_quicheRecvInfoOffsetofFromLen }, + { "sizeofQuicheRecvInfo", "()I", (void *) netty_sizeofQuicheRecvInfo }, + { "quicheSendInfoOffsetofTo", "()I", (void *) netty_quicheSendInfoOffsetofTo }, + { "quicheSendInfoOffsetofToLen", "()I", (void *) netty_quicheSendInfoOffsetofToLen }, + { "sizeofQuicheSendInfo", "()I", (void *) netty_sizeofQuicheSendInfo }, + { "quiche_protocol_version", "()I", (void *) netty_quiche_protocol_version }, { "quiche_max_conn_id_len", "()I", (void *) netty_quiche_max_conn_id_len }, { "quiche_shutdown_read", "()I", (void *) netty_quiche_shutdown_read }, @@ -510,9 +674,9 @@ static const JNINativeMethod fixed_method_table[] = { { "quiche_conn_trace_id", "(J)[B", (void *) netty_quiche_conn_trace_id }, { "quiche_conn_source_id", "(J)[B", (void *) netty_quiche_conn_source_id }, { "quiche_conn_destination_id", "(J)[B", (void *) netty_quiche_conn_destination_id }, - { "quiche_conn_new_with_tls", "(JIJIJJZ)J", (void *) netty_quiche_conn_new_with_tls }, - { "quiche_conn_recv", "(JJI)I", (void *) netty_quiche_conn_recv }, - { "quiche_conn_send", "(JJI)I", (void *) netty_quiche_conn_send }, + { "quiche_conn_new_with_tls", "(JIJIJIJJZ)J", (void *) netty_quiche_conn_new_with_tls }, + { "quiche_conn_recv", "(JJIJ)I", (void *) netty_quiche_conn_recv }, + { "quiche_conn_send", "(JJIJ)I", (void *) netty_quiche_conn_send }, { "quiche_conn_free", "(J)V", (void *) netty_quiche_conn_free }, { "quiche_conn_peer_streams_left_bidi", "(J)J", (void *) netty_quiche_conn_peer_streams_left_bidi }, { "quiche_conn_peer_streams_left_uni", "(J)J", (void *) netty_quiche_conn_peer_streams_left_uni }, @@ -555,7 +719,8 @@ static const JNINativeMethod fixed_method_table[] = { { "quiche_config_set_cc_algorithm", "(JI)V", (void *) netty_quiche_config_set_cc_algorithm }, { "quiche_config_enable_hystart", "(JZ)V", (void *) netty_quiche_config_enable_hystart }, { "quiche_config_free", "(J)V", (void *) netty_quiche_config_free }, - { "buffer_memory_address", "(Ljava/nio/ByteBuffer;)J", (void *) netty_buffer_memory_address} + { "buffer_memory_address", "(Ljava/nio/ByteBuffer;)J", (void *) netty_buffer_memory_address}, + { "sockaddr_cmp", "(JJ)I", (void *) netty_sockaddr_cmp} }; static const jint fixed_method_table_size = sizeof(fixed_method_table) / sizeof(fixed_method_table[0]); diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicConnectionMigrationEvent.java b/src/main/java/io/netty/incubator/codec/quic/QuicConnectionMigrationEvent.java new file mode 100644 index 000000000..f1145159c --- /dev/null +++ b/src/main/java/io/netty/incubator/codec/quic/QuicConnectionMigrationEvent.java @@ -0,0 +1,50 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.incubator.codec.quic; + +import java.net.SocketAddress; + +/** + * {@link QuicEvent} which is fired when an QUIC connection migration was detected. + */ +public final class QuicConnectionMigrationEvent implements QuicEvent { + + private final SocketAddress from; + private final SocketAddress to; + + QuicConnectionMigrationEvent(SocketAddress from, SocketAddress to) { + this.from = from; + this.to = to; + } + + /** + * The old {@link SocketAddress} of the connection. + * + * @return the old {@link SocketAddress} of the connection. + */ + public SocketAddress from() { + return from; + } + + /** + * The new {@link SocketAddress} of the connection. + * + * @return the new {@link SocketAddress} of the connection. + */ + public SocketAddress to() { + return to; + } +} diff --git a/src/main/java/io/netty/incubator/codec/quic/Quiche.java b/src/main/java/io/netty/incubator/codec/quic/Quiche.java index f8f169ee2..6798fe6c3 100644 --- a/src/main/java/io/netty/incubator/codec/quic/Quiche.java +++ b/src/main/java/io/netty/incubator/codec/quic/Quiche.java @@ -76,6 +76,44 @@ private static void loadNativeLibrary() { } } + static final short AF_INET = (short) QuicheNativeStaticallyReferencedJniMethods.afInet(); + static final short AF_INET6 = (short) QuicheNativeStaticallyReferencedJniMethods.afInet6(); + static final int SIZEOF_SOCKADDR_STORAGE = QuicheNativeStaticallyReferencedJniMethods.sizeofSockaddrStorage(); + static final int SIZEOF_SOCKADDR_IN = QuicheNativeStaticallyReferencedJniMethods.sizeofSockaddrIn(); + static final int SIZEOF_SOCKADDR_IN6 = QuicheNativeStaticallyReferencedJniMethods.sizeofSockaddrIn6(); + static final int SOCKADDR_IN_OFFSETOF_SIN_FAMILY = + QuicheNativeStaticallyReferencedJniMethods.sockaddrInOffsetofSinFamily(); + static final int SOCKADDR_IN_OFFSETOF_SIN_PORT = + QuicheNativeStaticallyReferencedJniMethods.sockaddrInOffsetofSinPort(); + static final int SOCKADDR_IN_OFFSETOF_SIN_ADDR = + QuicheNativeStaticallyReferencedJniMethods.sockaddrInOffsetofSinAddr(); + static final int IN_ADDRESS_OFFSETOF_S_ADDR = QuicheNativeStaticallyReferencedJniMethods.inAddressOffsetofSAddr(); + static final int SOCKADDR_IN6_OFFSETOF_SIN6_FAMILY = + QuicheNativeStaticallyReferencedJniMethods.sockaddrIn6OffsetofSin6Family(); + static final int SOCKADDR_IN6_OFFSETOF_SIN6_PORT = + QuicheNativeStaticallyReferencedJniMethods.sockaddrIn6OffsetofSin6Port(); + static final int SOCKADDR_IN6_OFFSETOF_SIN6_FLOWINFO = + QuicheNativeStaticallyReferencedJniMethods.sockaddrIn6OffsetofSin6Flowinfo(); + static final int SOCKADDR_IN6_OFFSETOF_SIN6_ADDR = + QuicheNativeStaticallyReferencedJniMethods.sockaddrIn6OffsetofSin6Addr(); + static final int SOCKADDR_IN6_OFFSETOF_SIN6_SCOPE_ID = + QuicheNativeStaticallyReferencedJniMethods.sockaddrIn6OffsetofSin6ScopeId(); + static final int IN6_ADDRESS_OFFSETOF_S6_ADDR = + QuicheNativeStaticallyReferencedJniMethods.in6AddressOffsetofS6Addr(); + static final int SIZEOF_SOCKLEN_T = QuicheNativeStaticallyReferencedJniMethods.sizeofSocklenT(); + static final int SIZEOF_SIZE_T = QuicheNativeStaticallyReferencedJniMethods.sizeofSizeT(); + + static final int QUICHE_RECV_INFO_OFFSETOF_FROM = + QuicheNativeStaticallyReferencedJniMethods.quicheRecvInfoOffsetofFrom(); + static final int QUICHE_RECV_INFO_OFFSETOF_FROM_LEN = + QuicheNativeStaticallyReferencedJniMethods.quicheRecvInfoOffsetofFromLen(); + static final int SIZEOF_QUICHE_RECV_INFO = QuicheNativeStaticallyReferencedJniMethods.sizeofQuicheRecvInfo(); + static final int QUICHE_SEND_INFO_OFFSETOF_TO = + QuicheNativeStaticallyReferencedJniMethods.quicheSendInfoOffsetofTo(); + static final int QUICHE_SEND_INFO_OFFSETOF_TO_LEN = + QuicheNativeStaticallyReferencedJniMethods.quicheSendInfoOffsetofToLen(); + static final int SIZEOF_QUICHE_SEND_INFO = QuicheNativeStaticallyReferencedJniMethods.sizeofQuicheSendInfo(); + static final int QUICHE_PROTOCOL_VERSION = QuicheNativeStaticallyReferencedJniMethods.quiche_protocol_version(); static final int QUICHE_MAX_CONN_ID_LEN = QuicheNativeStaticallyReferencedJniMethods.quiche_max_conn_id_len(); @@ -229,6 +267,7 @@ static native int quiche_retry(long scidAddr, int scidLen, long dcidAddr, int dc * See quiche_conn_new_with_tls. */ static native long quiche_conn_new_with_tls(long scidAddr, int scidLen, long odcidAddr, int odcidLen, + long peerAddr, int peerLen, long configAddr, long ssl, boolean isServer); /** @@ -240,12 +279,12 @@ static native long quiche_conn_new_with_tls(long scidAddr, int scidLen, long odc /** * See quiche_conn_recv. */ - static native int quiche_conn_recv(long connAddr, long bufAddr, int bufLen); + static native int quiche_conn_recv(long connAddr, long bufAddr, int bufLen, long infoAddr); /** * See quiche_conn_send. */ - static native int quiche_conn_send(long connAddr, long outAddr, int outLen); + static native int quiche_conn_send(long connAddr, long outAddr, int outLen, long infoAddr); /** * See quiche_conn_free. @@ -544,6 +583,8 @@ static native void quiche_config_enable_dgram(long configAddr, boolean enable, private static native long buffer_memory_address(ByteBuffer buffer); + static native int sockaddr_cmp(long addr, long addr2); + /** * Returns the memory address if the {@link ByteBuf} */ diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheNativeStaticallyReferencedJniMethods.java b/src/main/java/io/netty/incubator/codec/quic/QuicheNativeStaticallyReferencedJniMethods.java index b7e70297a..7bf48e758 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheNativeStaticallyReferencedJniMethods.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheNativeStaticallyReferencedJniMethods.java @@ -41,5 +41,30 @@ final class QuicheNativeStaticallyReferencedJniMethods { static native int quiche_cc_reno(); static native int quiche_cc_cubic(); + static native int quicheRecvInfoOffsetofFrom(); + static native int quicheRecvInfoOffsetofFromLen(); + static native int sizeofQuicheRecvInfo(); + static native int quicheSendInfoOffsetofTo(); + static native int quicheSendInfoOffsetofToLen(); + static native int sizeofQuicheSendInfo(); + + static native int afInet(); + static native int afInet6(); + static native int sizeofSockaddrIn(); + static native int sizeofSockaddrIn6(); + static native int sockaddrInOffsetofSinFamily(); + static native int sockaddrInOffsetofSinPort(); + static native int sockaddrInOffsetofSinAddr(); + static native int inAddressOffsetofSAddr(); + static native int sockaddrIn6OffsetofSin6Family(); + static native int sockaddrIn6OffsetofSin6Port(); + static native int sockaddrIn6OffsetofSin6Flowinfo(); + static native int sockaddrIn6OffsetofSin6Addr(); + static native int sockaddrIn6OffsetofSin6ScopeId(); + static native int in6AddressOffsetofS6Addr(); + static native int sizeofSockaddrStorage(); + static native int sizeofSocklenT(); + static native int sizeofSizeT(); + private QuicheNativeStaticallyReferencedJniMethods() { } } diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java index 7cb7c35f9..e747cc805 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java @@ -108,11 +108,7 @@ public void operationComplete(ChannelFuture future) { private final Map.Entry, Object>[] streamOptionsArray; private final Map.Entry, Object>[] streamAttrsArray; private final TimeoutHandler timeoutHandler = new TimeoutHandler(); - private final InetSocketAddress remote; - private volatile QuicheQuicConnection connection; - private volatile QuicConnectionAddress remoteIdAddr; - private volatile QuicConnectionAddress localIdAdrr; private boolean inFireChannelReadCompleteQueue; private boolean fireChannelReadCompletePending; private ByteBuf finBuffer; @@ -123,6 +119,9 @@ public void operationComplete(ChannelFuture future) { private CloseData closeData; private QuicConnectionStats statsAtClose; + private long currentRecvInfoAddress; + private long currentSendInfoAddress; + private InetSocketAddress remote; private boolean supportsDatagram; private boolean recvDatagramPending; private boolean datagramReadable; @@ -138,6 +137,9 @@ public void operationComplete(ChannelFuture future) { private static final int ACTIVE = 2; private volatile int state; private volatile String traceId; + private volatile QuicheQuicConnection connection; + private volatile QuicConnectionAddress remoteIdAddr; + private volatile QuicConnectionAddress localIdAdrr; private static final AtomicLongFieldUpdater UNI_STREAMS_LEFT_UPDATER = AtomicLongFieldUpdater.newUpdater(QuicheQuicChannel.class, "uniStreamsLeft"); @@ -160,6 +162,7 @@ private QuicheQuicChannel(Channel parent, boolean server, ByteBuffer key, this.supportsDatagram = supportsDatagram; this.remote = remote; + this.streamHandler = streamHandler; this.streamOptionsArray = streamOptionsArray; this.streamAttrsArray = streamAttrsArray; @@ -200,11 +203,16 @@ public long peerAllowedStreams(QuicStreamType type) { void attachQuicheConnection(QuicheQuicConnection connection) { this.connection = connection; + byte[] traceId = Quiche.quiche_conn_trace_id(connection.address()); if (traceId != null) { this.traceId = new String(traceId); } + connection.initInfoAddresses(remote); + currentRecvInfoAddress = connection.recvInfoAddress(); + currentSendInfoAddress = connection.sendInfoAddress(); + // Setup QLOG if needed. QLogConfiguration configuration = config.getQLogConfiguration(); if (configuration != null) { @@ -232,7 +240,7 @@ void attachQuicheConnection(QuicheQuicConnection connection) { private void connect(Function engineProvider, long configAddr, int localConnIdLength, - boolean supportsDatagram) throws Exception { + boolean supportsDatagram, long sockaddr) throws Exception { assert this.connection == null; assert this.traceId == null; assert this.key == null; @@ -259,9 +267,11 @@ private void connect(Function engineProvid ByteBuffer connectId = address.connId.duplicate(); ByteBuf idBuffer = alloc().directBuffer(connectId.remaining()).writeBytes(connectId.duplicate()); try { + int sockaddrLen = SockaddrIn.write(sockaddr, remote); QuicheQuicConnection connection = quicheEngine.createConnection(ssl -> Quiche.quiche_conn_new_with_tls(Quiche.memoryAddress(idBuffer) + idBuffer.readerIndex(), - idBuffer.readableBytes(), -1, -1, configAddr, ssl, false)); + idBuffer.readableBytes(), -1, -1, sockaddr, sockaddrLen, + configAddr, ssl, false)); if (connection == null) { failConnectPromiseAndThrow(new ConnectException()); return; @@ -757,8 +767,8 @@ StreamRecvResult streamRecv(long streamId, ByteBuf buffer) throws Exception { /** * Receive some data on a QUIC connection. */ - void recv(ByteBuf buffer) { - ((QuicChannelUnsafe) unsafe()).connectionRecv(buffer); + void recv(InetSocketAddress sender, ByteBuf buffer) { + ((QuicChannelUnsafe) unsafe()).connectionRecv(sender, buffer); } void writable() { @@ -881,10 +891,13 @@ private boolean connectionSendSegments(SegmentedDatagramPacketAllocator segmente ByteBuf out = alloc().directBuffer(bufferSize); int lastWritten = -1; for (;;) { + long sendInfo = connection.nextSendInfoAddress(currentSendInfoAddress); + InetSocketAddress sendToAddress = this.remote; + boolean done; int writerIndex = out.writerIndex(); int written = Quiche.quiche_conn_send( - connAddr, Quiche.memoryAddress(out) + writerIndex, out.writableBytes()); + connAddr, Quiche.memoryAddress(out) + writerIndex, out.writableBytes(), sendInfo); if (written == 0) { // No need to create a new datagram packet. Just try again. continue; @@ -902,10 +915,11 @@ private boolean connectionSendSegments(SegmentedDatagramPacketAllocator segmente int readable = out.readableBytes(); if (readable != 0) { if (lastWritten != -1 && readable > lastWritten) { - parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, remote)); + parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, sendToAddress)); } else { - parent().write(new DatagramPacket(out, remote)); + parent().write(new DatagramPacket(out, sendToAddress)); } + packetWasWritten = true; } else { out.release(); @@ -913,16 +927,39 @@ private boolean connectionSendSegments(SegmentedDatagramPacketAllocator segmente break; } - if (written < lastWritten) { - // The write was smaller then the write before. This means we can write all together as the - // last segment can be smaller then the other segments. + boolean needWriteNow = false; + + if (SockaddrIn.cmp(QuicheSendInfo.sockAddress(currentSendInfoAddress), + QuicheSendInfo.sockAddress(sendInfo)) != 0) { + // Update the current address so we can keep track when it change again. + currentSendInfoAddress = sendInfo; + + // Change the cached address + InetSocketAddress oldRemote = remote; + remote = QuicheSendInfo.read(sendInfo); + pipeline().fireUserEventTriggered( + new QuicConnectionMigrationEvent(oldRemote, remote)); + needWriteNow = true; + } + + // If the remote address changed we need to ensure we write the segment before we try to write the rest. + if (needWriteNow || written < lastWritten) { out.writerIndex(writerIndex + written); - parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, remote)); + + if (lastWritten == -1) { + // This the first write so we shouldnt try to use segments. + parent().write(new DatagramPacket(out, sendToAddress)); + } else { + // The write was smaller then the write before. This means we can write all together as the + // last segment can be smaller then the other segments. + parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, sendToAddress)); + } packetWasWritten = true; out = alloc().directBuffer(bufferSize); lastWritten = -1; numSegments = 0; + continue; } @@ -933,7 +970,7 @@ private boolean connectionSendSegments(SegmentedDatagramPacketAllocator segmente // As the last write was smaller then this write we first need to write what we had before as // a segment can never be bigger then the previous segment. After this we will try to build a new // chain of segments for the writes to follow. - parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, remote)); + parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, sendToAddress)); packetWasWritten = true; out = newOut; @@ -949,7 +986,7 @@ private boolean connectionSendSegments(SegmentedDatagramPacketAllocator segmente // anymore. In this case lets write what we have and start a new chain of segments. if (numSegments == segmentedDatagramPacketAllocator.maxNumSegments() || !out.isWritable()) { - parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, remote)); + parent().write(segmentedDatagramPacketAllocator.newPacket(out, lastWritten, sendToAddress)); packetWasWritten = true; out = alloc().directBuffer(bufferSize); @@ -964,10 +1001,11 @@ private boolean connectionSendSimple() { long connAddr = connection.address(); boolean packetWasWritten = false; for (;;) { + long sendInfo = connection.nextSendInfoAddress(currentSendInfoAddress); ByteBuf out = alloc().directBuffer(Quic.MAX_DATAGRAM_SIZE); int writerIndex = out.writerIndex(); int written = Quiche.quiche_conn_send( - connAddr, Quiche.memoryAddress(out) + writerIndex, out.writableBytes()); + connAddr, Quiche.memoryAddress(out) + writerIndex, out.writableBytes(), sendInfo); try { if (Quiche.throwIfError(written)) { @@ -985,6 +1023,17 @@ private boolean connectionSendSimple() { out.release(); continue; } + if (SockaddrIn.cmp(QuicheSendInfo.sockAddress(currentSendInfoAddress), + QuicheSendInfo.sockAddress(sendInfo)) != 0) { + // Update the current address so we can keep track when it change again. + currentSendInfoAddress = sendInfo; + + // Change the cached address + InetSocketAddress oldRemote = remote; + remote = QuicheSendInfo.read(sendInfo); + pipeline().fireUserEventTriggered( + new QuicConnectionMigrationEvent(oldRemote, remote)); + } out.writerIndex(writerIndex + written); parent().write(new DatagramPacket(out, remote)); packetWasWritten = true; @@ -1104,7 +1153,7 @@ public void connect(SocketAddress remote, SocketAddress local, ChannelPromise ch channelPromise.setFailure(new UnsupportedOperationException()); } - void connectionRecv(ByteBuf buffer) { + void connectionRecv(InetSocketAddress sender, ByteBuf buffer) { if (isConnDestroyed()) { return; } @@ -1127,11 +1176,25 @@ void connectionRecv(ByteBuf buffer) { int bufferReaderIndex = buffer.readerIndex(); long memoryAddress = Quiche.memoryAddress(buffer) + bufferReaderIndex; + long recvInfoAddress = connection.nextRecvInfoAddress(currentRecvInfoAddress); + QuicheRecvInfo.write(recvInfoAddress, sender); + + SocketAddress oldRemote = remote; + + if (SockaddrIn.cmp(QuicheRecvInfo.sockAddress(currentRecvInfoAddress), + QuicheRecvInfo.sockAddress(recvInfoAddress)) != 0) { + // Update the cached address + currentRecvInfoAddress = recvInfoAddress; + remote = sender; + pipeline().fireUserEventTriggered( + new QuicConnectionMigrationEvent(oldRemote, sender)); + } + long connAddr = connection.address(); try { do { // Call quiche_conn_recv(...) until we consumed all bytes or we did receive some error. - int res = Quiche.quiche_conn_recv(connAddr, memoryAddress, bufferReadable); + int res = Quiche.quiche_conn_recv(connAddr, memoryAddress, bufferReadable, recvInfoAddress); boolean done; try { done = Quiche.throwIfError(res); @@ -1183,6 +1246,8 @@ void connectionRecv(ByteBuf buffer) { } } while (bufferReadable > 0); } finally { + // Store for later usage. + currentRecvInfoAddress = recvInfoAddress; buffer.skipBytes((int) (memoryAddress - Quiche.memoryAddress(buffer))); if (tmpBuffer != null) { tmpBuffer.release(); @@ -1367,11 +1432,11 @@ void finishConnect() { // TODO: Come up with something better. static QuicheQuicChannel handleConnect(Function sslEngineProvider, SocketAddress address, long config, int localConnIdLength, - boolean supportsDatagram) throws Exception { + boolean supportsDatagram, long sockaddr) throws Exception { if (address instanceof QuicheQuicChannel.QuicheQuicChannelAddress) { QuicheQuicChannel.QuicheQuicChannelAddress addr = (QuicheQuicChannel.QuicheQuicChannelAddress) address; QuicheQuicChannel channel = addr.channel; - channel.connect(sslEngineProvider, config, localConnIdLength, supportsDatagram); + channel.connect(sslEngineProvider, config, localConnIdLength, supportsDatagram, sockaddr); return channel; } return null; diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicClientCodec.java b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicClientCodec.java index 364d2b69d..34d08895c 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicClientCodec.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicClientCodec.java @@ -54,7 +54,7 @@ public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, final QuicheQuicChannel channel; try { channel = QuicheQuicChannel.handleConnect(sslEngineProvider, remoteAddress, config.nativeAddress(), - localConnIdLength, config.isDatagramSupported()); + localConnIdLength, config.isDatagramSupported(), sockaddrMemory.memoryAddress()); } catch (Exception e) { promise.setFailure(e); return; diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicCodec.java b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicCodec.java index 7e146fc00..2ee873378 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicCodec.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicCodec.java @@ -32,6 +32,8 @@ import java.util.Map; import java.util.Queue; +import static io.netty.incubator.codec.quic.Quiche.allocateNativeOrder; + /** * Abstract base class for QUIC codecs. */ @@ -51,6 +53,8 @@ abstract class QuicheQuicCodec extends ChannelDuplexHandler { protected final QuicheConfig config; protected final int localConnIdLength; + // This buffer is used to copy InetSocketAddress to sockaddr_storage and so pass it down the JNI layer. + protected ByteBuf sockaddrMemory; QuicheQuicCodec(QuicheConfig config, int localConnIdLength, int maxTokenLength, FlushStrategy flushStrategy) { this.config = config; @@ -69,13 +73,14 @@ protected void putChannel(QuicheQuicChannel channel) { @Override public void handlerAdded(ChannelHandlerContext ctx) { + sockaddrMemory = allocateNativeOrder(Quiche.SIZEOF_SOCKADDR_STORAGE); headerParser = new QuicHeaderParser(maxTokenLength, localConnIdLength); parserCallback = (sender, recipient, buffer, type, version, scid, dcid, token) -> { QuicheQuicChannel channel = quicPacketRead(ctx, sender, recipient, type, version, scid, dcid, token); if (channel != null) { - channel.recv(buffer); + channel.recv(sender, buffer); if (channel.markInFireChannelReadCompleteQueue()) { needsFireChannelReadComplete.add(channel); } @@ -97,6 +102,9 @@ public void handlerRemoved(ChannelHandlerContext ctx) { needsFireChannelReadComplete.clear(); } finally { config.free(); + if (sockaddrMemory != null) { + sockaddrMemory.release(); + } if (headerParser != null) { headerParser.close(); headerParser = null; diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicConnection.java b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicConnection.java index 9fe535435..d51d3b4fc 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicConnection.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicConnection.java @@ -15,19 +15,38 @@ */ package io.netty.incubator.codec.quic; +import io.netty.buffer.ByteBuf; import io.netty.util.ReferenceCounted; +import java.net.InetSocketAddress; import java.util.function.Supplier; final class QuicheQuicConnection { + private static final int TOTAL_RECV_INFO_SIZE = Quiche.SIZEOF_QUICHE_RECV_INFO + Quiche.SIZEOF_SOCKADDR_STORAGE; + private static final int QUICHE_SEND_INFOS_OFFSET = 2 * TOTAL_RECV_INFO_SIZE; private final ReferenceCounted refCnt; private final QuicheQuicSslEngine engine; + + // This block of memory is used to store the following structs (in this order): + // - quiche_recv_info + // - sockaddr_storage + // - quiche_recv_info + // - sockaddr_storage + // - quiche_send_info + // - quiche_send_info + // + // We need to have every stored 2 times as we need to check if the last sockaddr has changed between + // quiche_conn_recv and quiche_conn_send calls. If this happens we know a QUIC connection migration did happen. + private final ByteBuf infoBuffer; private long connection; QuicheQuicConnection(long connection, QuicheQuicSslEngine engine, ReferenceCounted refCnt) { this.connection = connection; this.engine = engine; this.refCnt = refCnt; + // TODO: Maybe cache these per thread as we only use them temporary within a limited scope. + infoBuffer = Quiche.allocateNativeOrder(QUICHE_SEND_INFOS_OFFSET + + 2 * Quiche.SIZEOF_QUICHE_SEND_INFO); } void free() { @@ -44,6 +63,7 @@ void free() { } if (release) { refCnt.release(); + infoBuffer.release(); } } @@ -75,6 +95,45 @@ long address() { return connection; } + private long sendInfosAddress() { + return infoBuffer.memoryAddress() + QUICHE_SEND_INFOS_OFFSET; + } + + void initInfoAddresses(InetSocketAddress address) { + // Fill both quiche_recv_info structs with the same address. + QuicheRecvInfo.write(infoBuffer.memoryAddress(), address); + QuicheRecvInfo.write(infoBuffer.memoryAddress() + TOTAL_RECV_INFO_SIZE, address); + + // Fill both quiche_send_info structs with the same address. + long sendInfosAddress = sendInfosAddress(); + QuicheSendInfo.write(sendInfosAddress, address); + QuicheSendInfo.write(sendInfosAddress + Quiche.SIZEOF_QUICHE_SEND_INFO, address); + } + + long recvInfoAddress() { + return infoBuffer.memoryAddress(); + } + + long sendInfoAddress() { + return sendInfosAddress(); + } + + long nextRecvInfoAddress(long previousRecvInfoAddress) { + long memoryAddress = infoBuffer.memoryAddress(); + if (memoryAddress == previousRecvInfoAddress) { + return memoryAddress + TOTAL_RECV_INFO_SIZE; + } + return memoryAddress; + } + + long nextSendInfoAddress(long previousSendInfoAddress) { + long memoryAddress = sendInfosAddress(); + if (memoryAddress == previousSendInfoAddress) { + return memoryAddress + Quiche.SIZEOF_QUICHE_SEND_INFO; + } + return memoryAddress; + } + boolean isClosed() { assert connection != -1; return Quiche.quiche_conn_is_closed(connection); @@ -90,4 +149,5 @@ protected void finalize() throws Throwable { super.finalize(); } } + } diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicServerCodec.java b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicServerCodec.java index 0317faec9..778124987 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicServerCodec.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicServerCodec.java @@ -211,9 +211,12 @@ private QuicheQuicChannel handleServer(ChannelHandlerContext ctx, InetSocketAddr } QuicheQuicSslEngine quicSslEngine = (QuicheQuicSslEngine) engine; - QuicheQuicConnection connection = quicSslEngine.createConnection(ssl -> - Quiche.quiche_conn_new_with_tls(scidAddr, scidLen, ocidAddr, ocidLen, - config.nativeAddress(), ssl, true)); + QuicheQuicConnection connection = quicSslEngine.createConnection(ssl -> { + long peerAddr = sockaddrMemory.memoryAddress(); + int peerLen = SockaddrIn.write(peerAddr, sender); + return Quiche.quiche_conn_new_with_tls(scidAddr, scidLen, ocidAddr, ocidLen, peerAddr, peerLen, + config.nativeAddress(), ssl, true); + }); if (connection == null) { channel.unsafe().closeForcibly(); LOGGER.debug("quiche_accept failed"); diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheRecvInfo.java b/src/main/java/io/netty/incubator/codec/quic/QuicheRecvInfo.java new file mode 100644 index 000000000..87be4f29b --- /dev/null +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheRecvInfo.java @@ -0,0 +1,73 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.incubator.codec.quic; + +import io.netty.util.internal.PlatformDependent; + +import java.net.InetSocketAddress; + +final class QuicheRecvInfo { + + private QuicheRecvInfo() { } + + /** + * Write the {@link InetSocketAddress} into the {@code quiche_recv_info} struct. + * + *
+     * typedef struct {
+     *     struct sockaddr *from;
+     *     socklen_t from_len;
+     * } quiche_recv_info;
+     * 
+ * + * @param memory the memory address of {@code quiche_recv_info}. + * @param address the {@link InetSocketAddress} to write into {@code quiche_recv_info}. + */ + static void write(long memory, InetSocketAddress address) { + long sockaddr = memory + Quiche.SIZEOF_QUICHE_RECV_INFO; + int len = SockaddrIn.write(sockaddr, address); + if (Quiche.SIZEOF_SIZE_T == 4) { + PlatformDependent.putInt(memory + Quiche.QUICHE_RECV_INFO_OFFSETOF_FROM, (int) sockaddr); + } else { + PlatformDependent.putLong(memory + Quiche.QUICHE_RECV_INFO_OFFSETOF_FROM, sockaddr); + } + switch (Quiche.SIZEOF_SOCKLEN_T) { + case 1: + PlatformDependent.putByte(memory + Quiche.QUICHE_RECV_INFO_OFFSETOF_FROM_LEN, (byte) len); + break; + case 2: + PlatformDependent.putShort(memory + Quiche.QUICHE_RECV_INFO_OFFSETOF_FROM_LEN, (short) len); + break; + case 4: + PlatformDependent.putInt(memory + Quiche.QUICHE_RECV_INFO_OFFSETOF_FROM_LEN, len); + break; + case 8: + PlatformDependent.putLong(memory + Quiche.QUICHE_RECV_INFO_OFFSETOF_FROM_LEN, len); + break; + default: + throw new IllegalStateException(); + } + } + + /** + * Return the memory address of the {@code sockaddr} that is contained in {@code quiche_recv_info}. + * @param memory the memory address of {@code quiche_recv_info}. + * @return the memory address of the {@code sockaddr}. + */ + static long sockAddress(long memory) { + return memory + Quiche.SIZEOF_QUICHE_RECV_INFO; + } +} diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheSendInfo.java b/src/main/java/io/netty/incubator/codec/quic/QuicheSendInfo.java new file mode 100644 index 000000000..970895883 --- /dev/null +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheSendInfo.java @@ -0,0 +1,115 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.incubator.codec.quic; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.PlatformDependent; + +import java.net.InetSocketAddress; + +final class QuicheSendInfo { + + private static final FastThreadLocal IPV4_ARRAYS = new FastThreadLocal() { + @Override + protected byte[] initialValue() { + return new byte[SockaddrIn.IPV4_ADDRESS_LENGTH]; + } + }; + + private static final FastThreadLocal IPV6_ARRAYS = new FastThreadLocal() { + @Override + protected byte[] initialValue() { + return new byte[SockaddrIn.IPV6_ADDRESS_LENGTH]; + } + }; + + private QuicheSendInfo() { } + + /** + * Read the {@link InetSocketAddress} out of the {@code quiche_send_info} struct. + * + * @param memory the memory address of {@code quiche_send_info}. + * @return the address that was read. + */ + static InetSocketAddress read(long memory) { + long to = memory + Quiche.QUICHE_SEND_INFO_OFFSETOF_TO; + long len = readLen(memory + Quiche.QUICHE_SEND_INFO_OFFSETOF_TO_LEN); + if (len == Quiche.SIZEOF_SOCKADDR_IN) { + return SockaddrIn.readIPv4(to, IPV4_ARRAYS.get()); + } + assert len == Quiche.SIZEOF_SOCKADDR_IN6; + return SockaddrIn.readIPv6(to, IPV6_ARRAYS.get(), IPV4_ARRAYS.get()); + } + + private static long readLen(long address) { + switch (Quiche.SIZEOF_SOCKLEN_T) { + case 1: + return PlatformDependent.getByte(address); + case 2: + return PlatformDependent.getShort(address); + case 4: + return PlatformDependent.getInt(address); + case 8: + return PlatformDependent.getLong(address); + default: + throw new IllegalStateException(); + } + } + + /** + * Write the {@link InetSocketAddress} into the {@code quiche_send_info} struct. + *
+     *
+     * typedef struct {
+     *     // The address the packet should be sent to.
+     *     struct sockaddr_storage to;
+     *     socklen_t to_len;
+     * } quiche_send_info;
+     * 
+ * + * @param memory the memory address of {@code quiche_send_info}. + * @param address the {@link InetSocketAddress} to write into {@code quiche_send_info}. + */ + static void write(long memory, InetSocketAddress address) { + long sockaddr = sockAddress(memory); + int len = SockaddrIn.write(sockaddr, address); + switch (Quiche.SIZEOF_SOCKLEN_T) { + case 1: + PlatformDependent.putByte(memory + Quiche.QUICHE_SEND_INFO_OFFSETOF_TO_LEN, (byte) len); + break; + case 2: + PlatformDependent.putShort(memory + Quiche.QUICHE_SEND_INFO_OFFSETOF_TO_LEN, (short) len); + break; + case 4: + PlatformDependent.putInt(memory + Quiche.QUICHE_SEND_INFO_OFFSETOF_TO_LEN, len); + break; + case 8: + PlatformDependent.putLong(memory + Quiche.QUICHE_SEND_INFO_OFFSETOF_TO_LEN, len); + break; + default: + throw new IllegalStateException(); + } + } + + /** + * Return the memory address of the {@code sockaddr_storage} that is contained in {@code quiche_send_info}. + * @param memory the memory address of {@code quiche_send_info}. + * @return the memory address of the {@code sockaddr_storage}. + */ + static long sockAddress(long memory) { + return memory + Quiche.QUICHE_SEND_INFO_OFFSETOF_TO; + } +} diff --git a/src/main/java/io/netty/incubator/codec/quic/SockaddrIn.java b/src/main/java/io/netty/incubator/codec/quic/SockaddrIn.java new file mode 100644 index 000000000..2219e8fdb --- /dev/null +++ b/src/main/java/io/netty/incubator/codec/quic/SockaddrIn.java @@ -0,0 +1,162 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.incubator.codec.quic; + +import io.netty.util.internal.PlatformDependent; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; + +import static io.netty.util.internal.PlatformDependent.BIG_ENDIAN_NATIVE_ORDER; + +final class SockaddrIn { + static final byte[] IPV4_MAPPED_IPV6_PREFIX = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, (byte) 0xff, (byte) 0xff }; + static final int IPV4_ADDRESS_LENGTH = 4; + static final int IPV6_ADDRESS_LENGTH = 16; + + private SockaddrIn() { } + + static int cmp(long memory, long memory2) { + return Quiche.sockaddr_cmp(memory, memory2); + } + + static int write(long memory, InetSocketAddress address) { + InetAddress addr = address.getAddress(); + return write(addr instanceof Inet6Address, memory, address); + } + + static int write(boolean ipv6, long memory, InetSocketAddress address) { + if (ipv6) { + return SockaddrIn.writeIPv6(memory, address.getAddress(), address.getPort()); + } else { + return SockaddrIn.writeIPv4(memory, address.getAddress(), address.getPort()); + } + } + + /** + * + * struct sockaddr_in { + * sa_family_t sin_family; // address family: AF_INET + * in_port_t sin_port; // port in network byte order + * struct in_addr sin_addr; // internet address + * }; + * + * // Internet address. + * struct in_addr { + * uint32_t s_addr; // address in network byte order + * }; + * + */ + static int writeIPv4(long memory, InetAddress address, int port) { + PlatformDependent.setMemory(memory, Quiche.SIZEOF_SOCKADDR_IN, (byte) 0); + + PlatformDependent.putShort(memory + Quiche.SOCKADDR_IN_OFFSETOF_SIN_FAMILY, Quiche.AF_INET); + PlatformDependent.putShort(memory + Quiche.SOCKADDR_IN_OFFSETOF_SIN_PORT, handleNetworkOrder((short) port)); + byte[] bytes = address.getAddress(); + int offset = 0; + if (bytes.length == IPV6_ADDRESS_LENGTH) { + // IPV6 mapped IPV4 address, we only need the last 4 bytes. + offset = IPV4_MAPPED_IPV6_PREFIX.length; + } + assert bytes.length == offset + 4; + PlatformDependent.copyMemory(bytes, offset, + memory + Quiche.SOCKADDR_IN_OFFSETOF_SIN_ADDR + Quiche.IN_ADDRESS_OFFSETOF_S_ADDR, 4); + return Quiche.SIZEOF_SOCKADDR_IN; + } + + /** + * struct sockaddr_in6 { + * sa_family_t sin6_family; // AF_INET6 + * in_port_t sin6_port; // port number + * uint32_t sin6_flowinfo; // IPv6 flow information + * struct in6_addr sin6_addr; // IPv6 address + * uint32_t sin6_scope_id; /* Scope ID (new in 2.4) + * }; + * + * struct in6_addr { + * unsigned char s6_addr[16]; // IPv6 address + * }; + */ + static int writeIPv6(long memory, InetAddress address, int port) { + PlatformDependent.setMemory(memory, Quiche.SIZEOF_SOCKADDR_IN6, (byte) 0); + PlatformDependent.putShort(memory + Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_FAMILY, Quiche.AF_INET6); + PlatformDependent.putShort(memory + Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_PORT, handleNetworkOrder((short) port)); + // Skip sin6_flowinfo as we did memset before + byte[] bytes = address.getAddress(); + if (bytes.length == IPV4_ADDRESS_LENGTH) { + int offset = Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_ADDR + Quiche.IN6_ADDRESS_OFFSETOF_S6_ADDR; + PlatformDependent.copyMemory(IPV4_MAPPED_IPV6_PREFIX, 0, memory + offset, IPV4_MAPPED_IPV6_PREFIX.length); + PlatformDependent.copyMemory(bytes, 0, + memory + offset + IPV4_MAPPED_IPV6_PREFIX.length, IPV4_ADDRESS_LENGTH); + // Skip sin6_scope_id as we did memset before + } else { + PlatformDependent.copyMemory( + bytes, 0, memory + Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_ADDR + Quiche.IN6_ADDRESS_OFFSETOF_S6_ADDR, + IPV6_ADDRESS_LENGTH); + PlatformDependent.putInt( + memory + Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_SCOPE_ID, ((Inet6Address) address).getScopeId()); + } + return Quiche.SIZEOF_SOCKADDR_IN6; + } + + static InetSocketAddress readIPv4(long memory, byte[] tmpArray) { + assert tmpArray.length == IPV4_ADDRESS_LENGTH; + int port = handleNetworkOrder(PlatformDependent.getShort( + memory + Quiche.SOCKADDR_IN_OFFSETOF_SIN_PORT)) & 0xFFFF; + PlatformDependent.copyMemory(memory + Quiche.SOCKADDR_IN_OFFSETOF_SIN_ADDR + Quiche.IN_ADDRESS_OFFSETOF_S_ADDR, + tmpArray, 0, IPV4_ADDRESS_LENGTH); + try { + return new InetSocketAddress(InetAddress.getByAddress(tmpArray), port); + } catch (UnknownHostException ignore) { + return null; + } + } + + static InetSocketAddress readIPv6(long memory, byte[] ipv6Array, byte[] ipv4Array) { + assert ipv6Array.length == IPV6_ADDRESS_LENGTH; + assert ipv4Array.length == IPV4_ADDRESS_LENGTH; + + int port = handleNetworkOrder(PlatformDependent.getShort( + memory + Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_PORT)) & 0xFFFF; + PlatformDependent.copyMemory( + memory + Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_ADDR + Quiche.IN6_ADDRESS_OFFSETOF_S6_ADDR, + ipv6Array, 0, IPV6_ADDRESS_LENGTH); + if (PlatformDependent.equals( + ipv6Array, 0, IPV4_MAPPED_IPV6_PREFIX, 0, IPV4_MAPPED_IPV6_PREFIX.length)) { + System.arraycopy(ipv6Array, IPV4_MAPPED_IPV6_PREFIX.length, ipv4Array, 0, IPV4_ADDRESS_LENGTH); + try { + return new InetSocketAddress(Inet4Address.getByAddress(ipv4Array), port); + } catch (UnknownHostException ignore) { + return null; + } + } else { + int scopeId = PlatformDependent.getInt(memory + Quiche.SOCKADDR_IN6_OFFSETOF_SIN6_SCOPE_ID); + try { + return new InetSocketAddress(Inet6Address.getByAddress(null, ipv6Array, scopeId), port); + } catch (UnknownHostException ignore) { + return null; + } + } + } + + private static short handleNetworkOrder(short v) { + return BIG_ENDIAN_NATIVE_ORDER ? v : Short.reverseBytes(v); + } +} diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicChannelConnectTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicChannelConnectTest.java index 8bbd34bd8..e94649d3e 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicChannelConnectTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicChannelConnectTest.java @@ -101,14 +101,16 @@ public void testConnectAndQLogDir() throws Throwable { }); } - private void testQLog(Path path, Consumer consumer) throws Exception { - Channel server = QuicTestUtils.newServer(new ChannelInboundHandlerAdapter(), + private void testQLog(Path path, Consumer consumer) throws Throwable { + QuicChannelValidationHandler serverValidationHandler = new QuicChannelValidationHandler(); + QuicChannelValidationHandler clientValidationHandler = new QuicChannelValidationHandler(); + Channel server = QuicTestUtils.newServer(serverValidationHandler, new ChannelInboundHandlerAdapter()); InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(); try { QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientValidationHandler) .option(QuicChannelOption.QLOG, new QLogConfiguration(path.toString(), "testTitle", "test")) .streamHandler(new ChannelInboundHandlerAdapter()) @@ -123,6 +125,9 @@ private void testQLog(Path path, Consumer consumer) throws Exception { quicChannel.close().sync(); quicChannel.closeFuture().sync(); consumer.accept(path); + + serverValidationHandler.assertState(); + clientValidationHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. @@ -242,10 +247,9 @@ public void testConnectAlreadyConnected() throws Throwable { ChannelFuture closeFuture = quicChannel.closeFuture().await(); assertTrue(closeFuture.isSuccess()); clientQuicChannelHandler.assertState(); - } finally { serverQuicChannelHandler.assertState(); serverQuicStreamHandler.assertState(); - + } finally { server.close().sync(); // Close the parent Datagram channel as well. channel.close().sync(); @@ -304,7 +308,9 @@ public int maxTokenLength() { quicChannel.close().sync(); ChannelFuture closeFuture = quicChannel.closeFuture().await(); assertTrue(closeFuture.isSuccess()); + clientQuicChannelHandler.assertState(); + serverQuicChannelHandler.assertState(); assertEquals(serverQuicChannelHandler.localAddress(), remoteAddress); assertEquals(serverQuicChannelHandler.remoteAddress(), localAddress); @@ -314,7 +320,6 @@ public int maxTokenLength() { assertNotNull(quicChannel.remoteAddress()); } finally { serverLatch.await(); - serverQuicChannelHandler.assertState(); server.close().sync(); // Close the parent Datagram channel as well. @@ -616,8 +621,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } } - private static final class ChannelStateVerifyHandler extends ChannelInboundHandlerAdapter { - private volatile Throwable cause; + private static final class ChannelStateVerifyHandler extends QuicChannelValidationHandler { @Override public void channelActive(ChannelHandlerContext ctx) { ctx.fireChannelActive(); @@ -629,20 +633,9 @@ public void channelInactive(ChannelHandlerContext ctx) { ctx.fireChannelInactive(); fail(); } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - this.cause = cause; - } - - void assertState() throws Throwable { - if (cause != null) { - throw cause; - } - } } - private static final class ChannelActiveVerifyHandler extends ChannelInboundHandlerAdapter { + private static final class ChannelActiveVerifyHandler extends QuicChannelValidationHandler { private final BlockingQueue states = new LinkedBlockingQueue<>(); private volatile QuicConnectionAddress localAddress; private volatile QuicConnectionAddress remoteAddress; @@ -679,6 +672,7 @@ void assertState() throws Throwable { assertEquals(i, (int) states.take()); } assertNull(states.poll()); + super.assertState(); } QuicConnectionAddress localAddress() { diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicChannelDatagramTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicChannelDatagramTest.java index de1235e07..9a664a434 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicChannelDatagramTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicChannelDatagramTest.java @@ -60,9 +60,7 @@ public void testDatagramFlushInChannelReadComplete() throws Throwable { private void testDatagram(boolean flushInReadComplete) throws Throwable { AtomicReference serverEventRef = new AtomicReference<>(); - Channel server = QuicTestUtils.newServer(QuicTestUtils.newQuicServerBuilder() - .datagram(10, 10), - InsecureQuicTokenHandler.INSTANCE, new ChannelInboundHandlerAdapter() { + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { @@ -91,36 +89,45 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc if (evt instanceof QuicDatagramExtensionEvent) { serverEventRef.set((QuicDatagramExtensionEvent) evt); } + super.userEventTriggered(ctx, evt); } - }, new ChannelInboundHandlerAdapter()); + }; + Channel server = QuicTestUtils.newServer(QuicTestUtils.newQuicServerBuilder() + .datagram(10, 10), + InsecureQuicTokenHandler.INSTANCE, serverHandler , new ChannelInboundHandlerAdapter()); InetSocketAddress address = (InetSocketAddress) server.localAddress(); Promise receivedBuffer = ImmediateEventExecutor.INSTANCE.newPromise(); AtomicReference clientEventRef = new AtomicReference<>(); Channel channel = QuicTestUtils.newClient(QuicTestUtils.newQuicClientBuilder() .datagram(10, 10)); - try { - QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter() { - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - if (!receivedBuffer.trySuccess((ByteBuf) msg)) { - ReferenceCountUtil.release(msg); - } - } - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - if (evt instanceof QuicDatagramExtensionEvent) { - clientEventRef.set((QuicDatagramExtensionEvent) evt); - } - } + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (!receivedBuffer.trySuccess((ByteBuf) msg)) { + ReferenceCountUtil.release(msg); + } + } - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - receivedBuffer.tryFailure(cause); - } - }) + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof QuicDatagramExtensionEvent) { + clientEventRef.set((QuicDatagramExtensionEvent) evt); + } + super.userEventTriggered(ctx, evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + receivedBuffer.tryFailure(cause); + super.exceptionCaught(ctx, cause); + } + }; + + try { + QuicChannel quicChannel = QuicChannel.newBootstrap(channel) + .handler(clientHandler) .remoteAddress(address) .connect() .get(); @@ -136,6 +143,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { assertNotEquals(0, clientEventRef.get().maxLength()); quicChannel.close().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. @@ -170,84 +180,86 @@ private void testDatagramNoAutoRead(int maxMessagesPerRead, boolean readLater) t int numDatagrams = 5; AtomicInteger serverReadCount = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(numDatagrams); - Channel server = QuicTestUtils.newServer(QuicTestUtils.newQuicServerBuilder() - .option(ChannelOption.AUTO_READ, false) - .option(ChannelOption.MAX_MESSAGES_PER_READ, maxMessagesPerRead) - .datagram(10, 10), - InsecureQuicTokenHandler.INSTANCE, new ChannelInboundHandlerAdapter() { - private int readPerLoop; + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() { + private int readPerLoop; - @Override - public void channelActive(ChannelHandlerContext ctx) { - ctx.read(); - } + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + readPerLoop++; - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - if (msg instanceof ByteBuf) { - readPerLoop++; - - ctx.writeAndFlush(msg).addListener(future -> { - if (future.isSuccess()) { - latch.countDown(); - } - }); - if (serverReadCount.incrementAndGet() == numDatagrams) { - serverPromise.trySuccess(null); - } - } else { - ctx.fireChannelRead(msg); + ctx.writeAndFlush(msg).addListener(future -> { + if (future.isSuccess()) { + latch.countDown(); } + }); + if (serverReadCount.incrementAndGet() == numDatagrams) { + serverPromise.trySuccess(null); } + } else { + ctx.fireChannelRead(msg); + } + } - @Override - public void channelReadComplete(ChannelHandlerContext ctx) { - if (readPerLoop > maxMessagesPerRead) { - ctx.close(); - serverPromise.tryFailure(new AssertionError( - "Read more then " + maxMessagesPerRead + " time per read loop")); - return; - } - readPerLoop = 0; - if (serverReadCount.get() < numDatagrams) { - if (readLater) { - ctx.executor().execute(ctx::read); - } else { - ctx.read(); - } - } + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + if (readPerLoop > maxMessagesPerRead) { + ctx.close(); + serverPromise.tryFailure(new AssertionError( + "Read more then " + maxMessagesPerRead + " time per read loop")); + return; + } + readPerLoop = 0; + if (serverReadCount.get() < numDatagrams) { + if (readLater) { + ctx.executor().execute(ctx::read); + } else { + ctx.read(); } - }, new ChannelInboundHandlerAdapter()); + } + } + }; + Channel server = QuicTestUtils.newServer(QuicTestUtils.newQuicServerBuilder() + .option(ChannelOption.AUTO_READ, false) + .option(ChannelOption.MAX_MESSAGES_PER_READ, maxMessagesPerRead) + .datagram(10, 10), + InsecureQuicTokenHandler.INSTANCE, serverHandler, new ChannelInboundHandlerAdapter()); InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(QuicTestUtils.newQuicClientBuilder() .datagram(10, 10)); AtomicInteger clientReadCount = new AtomicInteger(); - try { - QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter() { - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - if (msg instanceof ByteBuf) { - - if (clientReadCount.incrementAndGet() == numDatagrams) { - if (!clientPromise.trySuccess((ByteBuf) msg)) { - ReferenceCountUtil.release(msg); - } - } else { - ReferenceCountUtil.release(msg); - } - } else { - ctx.fireChannelRead(msg); - } - } + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler() { - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - clientPromise.tryFailure(cause); + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + + if (clientReadCount.incrementAndGet() == numDatagrams) { + if (!clientPromise.trySuccess((ByteBuf) msg)) { + ReferenceCountUtil.release(msg); } - }) + } else { + ReferenceCountUtil.release(msg); + } + } else { + ctx.fireChannelRead(msg); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + clientPromise.tryFailure(cause); + } + }; + try { + QuicChannel quicChannel = QuicChannel.newBootstrap(channel) + .handler(clientHandler) .remoteAddress(address) .connect() .get(); @@ -269,6 +281,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { expected.release(); quicChannel.close().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicChannelEchoTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicChannelEchoTest.java index aef38d90f..84deef217 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicChannelEchoTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicChannelEchoTest.java @@ -217,7 +217,7 @@ public void testEchoStartedFromClient(boolean autoRead, boolean directBuffer, bo final EchoHandler sh = new EchoHandler(true, autoRead, allocator); final EchoHandler ch = new EchoHandler(false, autoRead, allocator); - Channel server = QuicTestUtils.newServer(new ChannelInboundHandlerAdapter() { + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() { @Override public void channelActive(ChannelHandlerContext ctx) { setAllocator(ctx.channel(), allocator); @@ -233,28 +233,31 @@ public void channelReadComplete(ChannelHandlerContext ctx) { ctx.read(); } } - }, sh); + }; + + Channel server = QuicTestUtils.newServer(serverHandler, sh); setAllocator(server, allocator); InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(); QuicChannel quicChannel = null; try { - quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter() { - @Override - public void channelActive(ChannelHandlerContext ctx) { - if (!autoRead) { - ctx.read(); - } - } + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + if (!autoRead) { + ctx.read(); + } + } - @Override - public void channelReadComplete(ChannelHandlerContext ctx) { - if (!autoRead) { - ctx.read(); - } - } - }) + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + if (!autoRead) { + ctx.read(); + } + } + }; + quicChannel = QuicChannel.newBootstrap(channel) + .handler(clientHandler) .streamHandler(ch) // Use the same allocator for the streams. .streamOption(ChannelOption.ALLOCATOR, allocator) @@ -292,6 +295,9 @@ public void channelReadComplete(ChannelHandlerContext ctx) { sh.channel.parent().close().sync(); ch.channel.parent().close().sync(); checkForException(ch, sh); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().syncUninterruptibly(); QuicTestUtils.closeIfNotNull(quicChannel); diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicChannelValidationHandler.java b/src/test/java/io/netty/incubator/codec/quic/QuicChannelValidationHandler.java new file mode 100644 index 000000000..8940034bf --- /dev/null +++ b/src/test/java/io/netty/incubator/codec/quic/QuicChannelValidationHandler.java @@ -0,0 +1,46 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.incubator.codec.quic; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +import static org.junit.jupiter.api.Assertions.fail; + +class QuicChannelValidationHandler extends ChannelInboundHandlerAdapter { + + private volatile Throwable cause; + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof QuicConnectionMigrationEvent) { + fail("QuicConnectionMigrationEvent should never happen atm"); + } + super.userEventTriggered(ctx, evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + this.cause = cause; + } + + void assertState() throws Throwable { + if (cause != null) { + throw cause; + } + } + +} diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicConnectionStatsTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicConnectionStatsTest.java index 1c3737c4b..51d2e65b5 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicConnectionStatsTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicConnectionStatsTest.java @@ -41,27 +41,29 @@ public void testStatsAreCollected() throws Throwable { Channel channel = null; AtomicInteger counter = new AtomicInteger(); + Promise serverActiveStats = ImmediateEventExecutor.INSTANCE.newPromise(); + Promise serverInactiveStats = ImmediateEventExecutor.INSTANCE.newPromise(); + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + collectStats(ctx, serverActiveStats); + ctx.fireChannelActive(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + collectStats(ctx, serverInactiveStats); + ctx.fireChannelInactive(); + } + + private void collectStats(ChannelHandlerContext ctx, Promise promise) { + QuicheQuicChannel channel = (QuicheQuicChannel) ctx.channel(); + channel.collectStats(promise); + } + }; + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); try { - Promise serverActiveStats = ImmediateEventExecutor.INSTANCE.newPromise(); - Promise serverInactiveStats = ImmediateEventExecutor.INSTANCE.newPromise(); - server = QuicTestUtils.newServer(new ChannelInboundHandlerAdapter() { - @Override - public void channelActive(ChannelHandlerContext ctx) { - collectStats(ctx, serverActiveStats); - ctx.fireChannelActive(); - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) { - collectStats(ctx, serverInactiveStats); - ctx.fireChannelInactive(); - } - - private void collectStats(ChannelHandlerContext ctx, Promise promise) { - QuicheQuicChannel channel = (QuicheQuicChannel) ctx.channel(); - channel.collectStats(promise); - } - }, new ChannelInboundHandlerAdapter() { + server = QuicTestUtils.newServer(serverHandler, new ChannelInboundHandlerAdapter() { @Override public void channelActive(ChannelHandlerContext ctx) { @@ -82,7 +84,7 @@ public boolean isSharable() { channel = QuicTestUtils.newClient(); QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(server.localAddress()) .connect().get(); @@ -116,6 +118,9 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { assertNotNull(serverActiveStats.sync().getNow()); assertStats(serverInactiveStats.sync().getNow()); assertEquals(1, counter.get()); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { QuicTestUtils.closeIfNotNull(channel); QuicTestUtils.closeIfNotNull(server); diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicReadableTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicReadableTest.java index e8a0dd890..80cc08e9a 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicReadableTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicReadableTest.java @@ -41,10 +41,11 @@ public void testCorrectlyHandleReadableStreams() throws Throwable { final AtomicReference serverErrorRef = new AtomicReference<>(); final AtomicReference clientErrorRef = new AtomicReference<>(); + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); Channel server = QuicTestUtils.newServer( QuicTestUtils.newQuicServerBuilder().initialMaxStreamsBidirectional(5000), InsecureQuicTokenHandler.INSTANCE, - null, new ChannelInboundHandlerAdapter() { + serverHandler, new ChannelInboundHandlerAdapter() { private int counter; @Override public void channelRegistered(ChannelHandlerContext ctx) { @@ -80,9 +81,10 @@ public boolean isSharable() { } }); Channel channel = QuicTestUtils.newClient(); + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); try { QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(server.localAddress()) .connect() @@ -112,6 +114,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { throwIfNotNull(serverErrorRef); throwIfNotNull(clientErrorRef); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCloseTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCloseTest.java index df2f611f6..4bc99036b 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCloseTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCloseTest.java @@ -32,36 +32,40 @@ public class QuicStreamChannelCloseTest extends AbstractQuicTest { @Test - public void testCloseFromServerWhileInActiveUnidirectional() throws Exception { + public void testCloseFromServerWhileInActiveUnidirectional() throws Throwable { testCloseFromServerWhileInActive(QuicStreamType.UNIDIRECTIONAL, false); } @Test - public void testCloseFromServerWhileInActiveBidirectional() throws Exception { + public void testCloseFromServerWhileInActiveBidirectional() throws Throwable { testCloseFromServerWhileInActive(QuicStreamType.BIDIRECTIONAL, false); } @Test - public void testHalfCloseFromServerWhileInActiveUnidirectional() throws Exception { + public void testHalfCloseFromServerWhileInActiveUnidirectional() throws Throwable { testCloseFromServerWhileInActive(QuicStreamType.UNIDIRECTIONAL, true); } @Test - public void testHalfCloseFromServerWhileInActiveBidirectional() throws Exception { + public void testHalfCloseFromServerWhileInActiveBidirectional() throws Throwable { testCloseFromServerWhileInActive(QuicStreamType.BIDIRECTIONAL, true); } private static void testCloseFromServerWhileInActive(QuicStreamType type, - boolean halfClose) throws Exception { + boolean halfClose) throws Throwable { Channel server = null; Channel channel = null; try { final Promise streamPromise = ImmediateEventExecutor.INSTANCE.newPromise(); - server = QuicTestUtils.newServer(new StreamCreationHandler(type, halfClose, streamPromise), + QuicChannelValidationHandler serverHandler = new StreamCreationHandler(type, halfClose, streamPromise); + server = QuicTestUtils.newServer(serverHandler, new ChannelInboundHandlerAdapter()); channel = QuicTestUtils.newClient(); + + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); + QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new StreamHandler()) .remoteAddress(server.localAddress()) .connect() @@ -76,6 +80,9 @@ private static void testCloseFromServerWhileInActive(QuicStreamType type, // Wait till the client was closed quicChannel.closeFuture().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { QuicTestUtils.closeIfNotNull(channel); QuicTestUtils.closeIfNotNull(server); @@ -83,35 +90,38 @@ private static void testCloseFromServerWhileInActive(QuicStreamType type, } @Test - public void testCloseFromClientWhileInActiveUnidirectional() throws Exception { + public void testCloseFromClientWhileInActiveUnidirectional() throws Throwable { testCloseFromClientWhileInActive(QuicStreamType.UNIDIRECTIONAL, false); } @Test - public void testCloseFromClientWhileInActiveBidirectional() throws Exception { + public void testCloseFromClientWhileInActiveBidirectional() throws Throwable { testCloseFromClientWhileInActive(QuicStreamType.BIDIRECTIONAL, false); } @Test - public void testHalfCloseFromClientWhileInActiveUnidirectional() throws Exception { + public void testHalfCloseFromClientWhileInActiveUnidirectional() throws Throwable { testCloseFromClientWhileInActive(QuicStreamType.UNIDIRECTIONAL, true); } @Test - public void testHalfCloseFromClientWhileInActiveBidirectional() throws Exception { + public void testHalfCloseFromClientWhileInActiveBidirectional() throws Throwable { testCloseFromClientWhileInActive(QuicStreamType.BIDIRECTIONAL, true); } private static void testCloseFromClientWhileInActive(QuicStreamType type, - boolean halfClose) throws Exception { + boolean halfClose) throws Throwable { Channel server = null; Channel channel = null; try { final Promise streamPromise = ImmediateEventExecutor.INSTANCE.newPromise(); - server = QuicTestUtils.newServer(null, new StreamHandler()); + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); + server = QuicTestUtils.newServer(serverHandler, new StreamHandler()); channel = QuicTestUtils.newClient(); + + StreamCreationHandler creationHandler = new StreamCreationHandler(type, halfClose, streamPromise); QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new StreamCreationHandler(type, halfClose, streamPromise)) + .handler(creationHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(server.localAddress()) .connect() @@ -126,13 +136,16 @@ private static void testCloseFromClientWhileInActive(QuicStreamType type, // Wait till the client was closed quicChannel.closeFuture().sync(); + + serverHandler.assertState(); + creationHandler.assertState(); } finally { QuicTestUtils.closeIfNotNull(channel); QuicTestUtils.closeIfNotNull(server); } } - private static final class StreamCreationHandler extends ChannelInboundHandlerAdapter { + private static final class StreamCreationHandler extends QuicChannelValidationHandler { private final QuicStreamType type; private final boolean halfClose; private final Promise streamPromise; diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCreationTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCreationTest.java index e2c030a17..b916d5078 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCreationTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicStreamChannelCreationTest.java @@ -34,13 +34,16 @@ public class QuicStreamChannelCreationTest extends AbstractQuicTest { @Test public void testCreateStream() throws Throwable { - Channel server = QuicTestUtils.newServer(new ChannelInboundHandlerAdapter(), + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); + Channel server = QuicTestUtils.newServer(serverHandler, new ChannelInboundHandlerAdapter()); InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(); + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); + try { QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(address) .connect() @@ -59,6 +62,9 @@ public void channelRegistered(ChannelHandlerContext ctx) { latch.await(); stream.close().sync(); quicChannel.close().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. @@ -68,13 +74,16 @@ public void channelRegistered(ChannelHandlerContext ctx) { @Test public void testCreateStreamViaBootstrap() throws Throwable { - Channel server = QuicTestUtils.newServer(new ChannelInboundHandlerAdapter(), + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); + Channel server = QuicTestUtils.newServer(serverHandler, new ChannelInboundHandlerAdapter()); InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(); + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); + try { QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(address) .connect() @@ -96,6 +105,9 @@ public void channelRegistered(ChannelHandlerContext ctx) { latch.await(); stream.close().sync(); quicChannel.close().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().syncUninterruptibly(); // Close the parent Datagram channel as well. diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicStreamFrameTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicStreamFrameTest.java index 12a38e5ef..f522a4768 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicStreamFrameTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicStreamFrameTest.java @@ -32,24 +32,26 @@ public class QuicStreamFrameTest extends AbstractQuicTest { @Test - public void testCloseHalfClosureUnidirectional() throws Exception { + public void testCloseHalfClosureUnidirectional() throws Throwable { testCloseHalfClosure(QuicStreamType.UNIDIRECTIONAL); } @Test - public void testCloseHalfClosureBidirectional() throws Exception { + public void testCloseHalfClosureBidirectional() throws Throwable { testCloseHalfClosure(QuicStreamType.BIDIRECTIONAL); } - private static void testCloseHalfClosure(QuicStreamType type) throws Exception { + private static void testCloseHalfClosure(QuicStreamType type) throws Throwable { Channel server = null; Channel channel = null; + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); + QuicChannelValidationHandler clientHandler = new StreamCreationHandler(type); try { StreamHandler handler = new StreamHandler(); - server = QuicTestUtils.newServer(null, handler); + server = QuicTestUtils.newServer(serverHandler, handler); channel = QuicTestUtils.newClient(); QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new StreamCreationHandler(type)) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(server.localAddress()) .connect() @@ -57,13 +59,16 @@ private static void testCloseHalfClosure(QuicStreamType type) throws Exception { handler.assertSequence(); quicChannel.closeFuture().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { QuicTestUtils.closeIfNotNull(channel); QuicTestUtils.closeIfNotNull(server); } } - private static final class StreamCreationHandler extends ChannelInboundHandlerAdapter { + private static final class StreamCreationHandler extends QuicChannelValidationHandler { private final QuicStreamType type; StreamCreationHandler(QuicStreamType type) { diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicStreamHalfClosureTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicStreamHalfClosureTest.java index c043b9346..4c049561d 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicStreamHalfClosureTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicStreamHalfClosureTest.java @@ -34,24 +34,26 @@ public class QuicStreamHalfClosureTest extends AbstractQuicTest { @Test - public void testCloseHalfClosureUnidirectional() throws Exception { + public void testCloseHalfClosureUnidirectional() throws Throwable { testCloseHalfClosure(QuicStreamType.UNIDIRECTIONAL); } @Test - public void testCloseHalfClosureBidirectional() throws Exception { + public void testCloseHalfClosureBidirectional() throws Throwable { testCloseHalfClosure(QuicStreamType.BIDIRECTIONAL); } - private static void testCloseHalfClosure(QuicStreamType type) throws Exception { + private static void testCloseHalfClosure(QuicStreamType type) throws Throwable { Channel server = null; Channel channel = null; + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); + QuicChannelValidationHandler clientHandler = new StreamCreationHandler(type); try { StreamHandler handler = new StreamHandler(); - server = QuicTestUtils.newServer(null, handler); + server = QuicTestUtils.newServer(serverHandler, handler); channel = QuicTestUtils.newClient(); QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new StreamCreationHandler(type)) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(server.localAddress()) .connect() @@ -59,13 +61,16 @@ private static void testCloseHalfClosure(QuicStreamType type) throws Exception { handler.assertSequence(); quicChannel.closeFuture().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { QuicTestUtils.closeIfNotNull(channel); QuicTestUtils.closeIfNotNull(server); } } - private static final class StreamCreationHandler extends ChannelInboundHandlerAdapter { + private static final class StreamCreationHandler extends QuicChannelValidationHandler { private final QuicStreamType type; StreamCreationHandler(QuicStreamType type) { diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicStreamLimitTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicStreamLimitTest.java index 237fb95f0..0decc5d6c 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicStreamLimitTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicStreamLimitTest.java @@ -45,11 +45,12 @@ public void testStreamLimitEnforcedWhenCreatingViaClientUnidirectional() throws } private static void testStreamLimitEnforcedWhenCreatingViaClient(QuicStreamType type) throws Throwable { + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); Channel server = QuicTestUtils.newServer( QuicTestUtils.newQuicServerBuilder().initialMaxStreamsBidirectional(1) .initialMaxStreamsUnidirectional(1), InsecureQuicTokenHandler.INSTANCE, - null, new ChannelInboundHandlerAdapter() { + serverHandler, new ChannelInboundHandlerAdapter() { @Override public boolean isSharable() { return true; @@ -64,22 +65,26 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc }); InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(); + + CountDownLatch latch = new CountDownLatch(1); + CountDownLatch latch2 = new CountDownLatch(1); + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof QuicStreamLimitChangedEvent) { + if (latch.getCount() == 0) { + latch2.countDown(); + } else { + latch.countDown(); + } + } + super.userEventTriggered(ctx, evt); + } + }; try { - CountDownLatch latch = new CountDownLatch(1); - CountDownLatch latch2 = new CountDownLatch(1); - QuicChannel quicChannel = QuicChannel.newBootstrap(channel).handler( - new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof QuicStreamLimitChangedEvent) { - if (latch.getCount() == 0) { - latch2.countDown(); - } else { - latch.countDown(); - } - } - } - }).streamHandler(new ChannelInboundHandlerAdapter()) + QuicChannel quicChannel = QuicChannel.newBootstrap(channel) + .handler(clientHandler) + .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(address) .connect().get(); latch.await(); @@ -99,6 +104,9 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc assertEquals(1, quicChannel.peerAllowedStreams(type)); quicChannel.close().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. @@ -119,35 +127,30 @@ public void testStreamLimitEnforcedWhenCreatingViaServerUnidirectional() throws private static void testStreamLimitEnforcedWhenCreatingViaServer(QuicStreamType type) throws Throwable { Promise streamPromise = ImmediateEventExecutor.INSTANCE.newPromise(); Promise stream2Promise = ImmediateEventExecutor.INSTANCE.newPromise(); + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + QuicChannel channel = (QuicChannel) ctx.channel(); + channel.createStream(type, new ChannelInboundHandlerAdapter()) + .addListener((Future future) -> { + if (future.isSuccess()) { + QuicStreamChannel stream = future.getNow(); + streamPromise.setSuccess(null); + channel.createStream(type, new ChannelInboundHandlerAdapter()) + .addListener((Future f) -> { + stream.close(); + stream2Promise.setSuccess(f.cause()); + }); + } else { + streamPromise.setFailure(future.cause()); + } + }); + } + }; Channel server = QuicTestUtils.newServer( QuicTestUtils.newQuicServerBuilder(), InsecureQuicTokenHandler.INSTANCE, - new ChannelInboundHandlerAdapter() { - - @Override - public void channelActive(ChannelHandlerContext ctx) { - QuicChannel channel = (QuicChannel) ctx.channel(); - channel.createStream(type, new ChannelInboundHandlerAdapter()) - .addListener((Future future) -> { - if (future.isSuccess()) { - QuicStreamChannel stream = future.getNow(); - streamPromise.setSuccess(null); - channel.createStream(type, new ChannelInboundHandlerAdapter()) - .addListener((Future f) -> { - stream.close(); - stream2Promise.setSuccess(f.cause()); - }); - } else { - streamPromise.setFailure(future.cause()); - } - }); - } - - @Override - public boolean isSharable() { - return true; - } - }, new ChannelInboundHandlerAdapter() { + serverHandler, new ChannelInboundHandlerAdapter() { @Override public boolean isSharable() { return true; @@ -156,9 +159,11 @@ public boolean isSharable() { InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(QuicTestUtils.newQuicClientBuilder() .initialMaxStreamsBidirectional(1).initialMaxStreamsUnidirectional(1)); + + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); try { QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(address) .connect().get(); @@ -166,6 +171,9 @@ public boolean isSharable() { // Second stream creation should fail. assertThat(stream2Promise.get(), CoreMatchers.instanceOf(IOException.class)); quicChannel.close().sync(); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicStreamTypeTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicStreamTypeTest.java index cb1d61d03..74e0789bb 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicStreamTypeTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicStreamTypeTest.java @@ -35,12 +35,15 @@ public class QuicStreamTypeTest extends AbstractQuicTest { @Test - public void testUnidirectionalCreatedByClient() throws Exception { + public void testUnidirectionalCreatedByClient() throws Throwable { Channel server = null; Channel channel = null; + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); + try { Promise serverWritePromise = ImmediateEventExecutor.INSTANCE.newPromise(); - server = QuicTestUtils.newServer(null, new ChannelInboundHandlerAdapter() { + server = QuicTestUtils.newServer(serverHandler, new ChannelInboundHandlerAdapter() { @Override public void channelActive(ChannelHandlerContext ctx) { QuicStreamChannel channel = (QuicStreamChannel) ctx.channel(); @@ -58,7 +61,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { channel = QuicTestUtils.newClient(); QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(server.localAddress()) .connect() @@ -73,6 +76,9 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { streamChannel.close().sync(); quicChannel.close().sync(); assertThat(serverWritePromise.get(), instanceOf(UnsupportedOperationException.class)); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { QuicTestUtils.closeIfNotNull(channel); QuicTestUtils.closeIfNotNull(server); @@ -80,31 +86,33 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } @Test - public void testUnidirectionalCreatedByServer() throws Exception { + public void testUnidirectionalCreatedByServer() throws Throwable { Channel server = null; Channel channel = null; - try { - Promise serverWritePromise = ImmediateEventExecutor.INSTANCE.newPromise(); - Promise clientWritePromise = ImmediateEventExecutor.INSTANCE.newPromise(); + Promise serverWritePromise = ImmediateEventExecutor.INSTANCE.newPromise(); + Promise clientWritePromise = ImmediateEventExecutor.INSTANCE.newPromise(); - server = QuicTestUtils.newServer(new ChannelInboundHandlerAdapter() { - @Override - public void channelActive(ChannelHandlerContext ctx) { - QuicChannel channel = (QuicChannel) ctx.channel(); - channel.createStream(QuicStreamType.UNIDIRECTIONAL, new ChannelInboundHandlerAdapter() { - @Override - public void channelActive(ChannelHandlerContext ctx) { - // Do the write which should succeed - ctx.writeAndFlush(Unpooled.buffer().writeZero(8)) - .addListener(new PromiseNotifier<>(serverWritePromise)); - } - }); - } - }, new ChannelInboundHandlerAdapter()); + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + QuicChannel channel = (QuicChannel) ctx.channel(); + channel.createStream(QuicStreamType.UNIDIRECTIONAL, new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + // Do the write which should succeed + ctx.writeAndFlush(Unpooled.buffer().writeZero(8)) + .addListener(new PromiseNotifier<>(serverWritePromise)); + } + }); + } + }; + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); + try { + server = QuicTestUtils.newServer(serverHandler, new ChannelInboundHandlerAdapter()); channel = QuicTestUtils.newClient(); QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter() { @Override public void channelActive(ChannelHandlerContext ctx) { @@ -133,6 +141,9 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { quicChannel.closeFuture().sync(); assertTrue(serverWritePromise.await().isSuccess()); assertThat(clientWritePromise.get(), instanceOf(UnsupportedOperationException.class)); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { QuicTestUtils.closeIfNotNull(channel); QuicTestUtils.closeIfNotNull(server); diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java index 4ec510ab2..49266587e 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java @@ -53,10 +53,11 @@ private static void testCorrectlyHandleWritability(boolean readInComplete) throw Promise writePromise = ImmediateEventExecutor.INSTANCE.newPromise(); final AtomicReference serverErrorRef = new AtomicReference<>(); final AtomicReference clientErrorRef = new AtomicReference<>(); + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); Channel server = QuicTestUtils.newServer( QuicTestUtils.newQuicServerBuilder().initialMaxStreamsBidirectional(5000), InsecureQuicTokenHandler.INSTANCE, - null, new ChannelInboundHandlerAdapter() { + serverHandler, new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { @@ -79,9 +80,11 @@ public boolean isSharable() { InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(QuicTestUtils.newQuicClientBuilder() .initialMaxStreamDataBidirectionalLocal(bufferSize / 4)); + + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); try { QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(address) .connect() @@ -143,6 +146,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { throwIfNotNull(serverErrorRef); throwIfNotNull(clientErrorRef); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well. @@ -160,10 +166,11 @@ public void testBytesUntilUnwritable() throws Throwable { int firstWriteNumBytes = 8; int maxData = 32 * 1024; final AtomicLong beforeWritableRef = new AtomicLong(); + QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler(); Channel server = QuicTestUtils.newServer( QuicTestUtils.newQuicServerBuilder().initialMaxStreamsBidirectional(5000), InsecureQuicTokenHandler.INSTANCE, - null, new ChannelInboundHandlerAdapter() { + serverHandler, new ChannelInboundHandlerAdapter() { private int numBytesRead; @Override @@ -210,9 +217,11 @@ public boolean isSharable() { InetSocketAddress address = (InetSocketAddress) server.localAddress(); Channel channel = QuicTestUtils.newClient(QuicTestUtils.newQuicClientBuilder() .initialMaxStreamDataBidirectionalLocal(maxData)); + + QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler(); try { QuicChannel quicChannel = QuicChannel.newBootstrap(channel) - .handler(new ChannelInboundHandlerAdapter()) + .handler(clientHandler) .streamHandler(new ChannelInboundHandlerAdapter()) .remoteAddress(address) .connect() @@ -264,6 +273,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { throwIfNotNull(serverErrorRef); throwIfNotNull(clientErrorRef); + + serverHandler.assertState(); + clientHandler.assertState(); } finally { server.close().sync(); // Close the parent Datagram channel as well.