diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicChannelInitializer.java b/src/main/java/io/netty/incubator/codec/quic/QuicChannelInitializer.java index b015eb15f..4b92c4c8a 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicChannelInitializer.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicChannelInitializer.java @@ -24,18 +24,42 @@ import java.util.Objects; /** - * {@link ChannelInitializer} for {@link QuicChannel}s. + * {@link ChannelInitializer} for {@link QuicChannel}s and it's accepted {@link QuicStreamChannel}s. */ -public class QuicChannelInitializer extends ChannelInitializer { +public final class QuicChannelInitializer extends ChannelInitializer { + private final ChannelHandler quicChannelHandler; private final ChannelHandler streamChannelHandler; + /** + * Create a new instance. + * + * @param streamChannelHandler The {@link ChannelHandler} that is added to the + * {@link io.netty.channel.ChannelPipeline} of the accepted {@link QuicStreamChannel}. + */ public QuicChannelInitializer(ChannelHandler streamChannelHandler) { + this(null, streamChannelHandler); + } + + /** + * Create a new instance. + * + * @param quicChannelHandler The {@link ChannelHandler} that is added to the + * {@link io.netty.channel.ChannelPipeline} of the {@link QuicChannel} or {@code null} + * if none should be added. + * @param streamChannelHandler The {@link ChannelHandler} that is added to the + * {@link io.netty.channel.ChannelPipeline} of the accepted {@link QuicStreamChannel}. + */ + public QuicChannelInitializer(ChannelHandler quicChannelHandler, ChannelHandler streamChannelHandler) { + this.quicChannelHandler = quicChannelHandler; this.streamChannelHandler = Objects.requireNonNull(streamChannelHandler, "streamChannelHandler"); } @Override - protected final void initChannel(QuicChannel channel) { + protected void initChannel(QuicChannel channel) { + if (quicChannelHandler != null) { + channel.pipeline().addLast(quicChannelHandler); + } channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicExample.java b/src/test/java/io/netty/incubator/codec/quic/QuicExample.java index 065385334..004005e56 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicExample.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicExample.java @@ -53,30 +53,44 @@ public static void main(String[] args) throws Exception { .initialMaxStreamsUnidirectional(100) .disableActiveMigration(true) .enableEarlyData() - .buildServerCodec(InsecureQuicTokenHandler.INSTANCE, new QuicChannelInitializer( - new ChannelInboundHandlerAdapter() { + .buildServerCodec(InsecureQuicTokenHandler.INSTANCE, + new QuicChannelInitializer( + // ChannelHandler that is added into QuicChannel pipeline. + new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + QuicChannel channel = (QuicChannel) ctx.channel(); + // Create streams etc.. + } - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - ByteBuf byteBuf = (ByteBuf) msg; - try { - if (byteBuf.toString(CharsetUtil.US_ASCII).trim().equals("GET /")) { - ByteBuf buffer = ctx.alloc().directBuffer(); - buffer.writeCharSequence("Hello World!\r\n", CharsetUtil.US_ASCII); + @Override + public boolean isSharable() { + return true; + } + }, + // ChannelHandler that is added into QuicStreamChannel pipeline. + new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ByteBuf byteBuf = (ByteBuf) msg; + try { + if (byteBuf.toString(CharsetUtil.US_ASCII).trim().equals("GET /")) { + ByteBuf buffer = ctx.alloc().directBuffer(); + buffer.writeCharSequence("Hello World!\r\n", CharsetUtil.US_ASCII); - // Write the buffer and close the stream once the write completes. - ctx.writeAndFlush(buffer).addListener(ChannelFutureListener.CLOSE); + // Write the buffer and close the stream once the write completes. + ctx.writeAndFlush(buffer).addListener(ChannelFutureListener.CLOSE); + } + } finally { + byteBuf.release(); + } } - } finally { - byteBuf.release(); - } - } - @Override - public boolean isSharable() { - return true; - } - })); + @Override + public boolean isSharable() { + return true; + } + })); try { Bootstrap bs = new Bootstrap(); Channel channel = bs.group(group)