From f66d7fc54d23fa441d0feb9873952a325346cbc4 Mon Sep 17 00:00:00 2001 From: vinodhabib <47808007+vinodhabib@users.noreply.github.com> Date: Tue, 3 Dec 2024 00:39:25 +0530 Subject: [PATCH] netty: Fix ByteBuf leaks in tests (#11593) Part of #3353 --- .../io/grpc/netty/NettyClientHandlerTest.java | 6 ++- .../io/grpc/netty/NettyHandlerTestBase.java | 41 +++++++++++-------- .../io/grpc/netty/NettyServerHandlerTest.java | 28 +++++-------- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 6c5dd6b18bc..e5c97e9efd9 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -217,6 +217,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { // Simulate receipt of initial remote settings. ByteBuf serializedSettings = serializeSettings(new Http2Settings()); channelRead(serializedSettings); + channel().releaseOutbound(); } @Test @@ -342,11 +343,12 @@ public void sendFrameShouldSucceed() throws Exception { createStream(); // Send a frame and verify that it was written. + ByteBuf content = content(); ChannelFuture future - = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true)); + = enqueue(new SendGrpcFrameCommand(streamTransportState, content, true)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(true), any(ChannelPromise.class)); verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verifyNoMoreInteractions(mockKeepAliveManager); diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index eef8d30e05a..c971294fbb6 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -38,7 +38,6 @@ import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; @@ -68,6 +67,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.Delayed; import java.util.concurrent.TimeUnit; +import org.junit.After; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,7 +84,6 @@ public abstract class NettyHandlerTestBase { protected static final int STREAM_ID = 3; - private ByteBuf content; private EmbeddedChannel channel; @@ -106,18 +105,24 @@ protected void manualSetUp() throws Exception {} protected final TransportTracer transportTracer = new TransportTracer(); protected int flowControlWindow = DEFAULT_WINDOW_SIZE; protected boolean autoFlowControl = false; - private final FakeClock fakeClock = new FakeClock(); FakeClock fakeClock() { return fakeClock; } + @After + public void tearDown() throws Exception { + if (channel() != null) { + channel().releaseInbound(); + channel().releaseOutbound(); + } + } + /** * Must be called by subclasses to initialize the handler and channel. */ protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception { - content = Unpooled.copiedBuffer("hello world", UTF_8); frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter())); frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -233,11 +238,11 @@ protected final Http2FrameReader frameReader() { } protected final ByteBuf content() { - return content; + return Unpooled.copiedBuffer(contentAsArray()); } protected final byte[] contentAsArray() { - return ByteBufUtil.getBytes(content()); + return "\000\000\000\000\rhello world".getBytes(UTF_8); } protected final Http2FrameWriter verifyWrite() { @@ -252,8 +257,8 @@ protected final void channelRead(Object obj) throws Exception { channel.writeInbound(obj); } - protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { - final ByteBuf compressionFrame = Unpooled.buffer(content.length); + protected ByteBuf grpcFrame(byte[] message) { + final ByteBuf compressionFrame = Unpooled.buffer(message.length); MessageFramer framer = new MessageFramer( new MessageFramer.Sink() { @Override @@ -262,23 +267,22 @@ public void deliverFrame( if (frame != null) { ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); compressionFrame.writeBytes(bytebuf); + bytebuf.release(); } } }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP); - framer.writePayload(new ByteArrayInputStream(content)); - framer.flush(); - ChannelHandlerContext ctx = newMockContext(); - new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, - newPromise()); - return captureWrite(ctx); + framer.writePayload(new ByteArrayInputStream(message)); + framer.close(); + return compressionFrame; } - protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { - // Need to retain the content since the frameWriter releases it. - content.retain(); + protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { + return dataFrame(streamId, endStream, grpcFrame(content)); + } + protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise()); return captureWrite(ctx); @@ -410,6 +414,7 @@ public void dataSizeSincePingAccumulates() throws Exception { channelRead(dataFrame(3, false, buff.copy())); assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); + buff.release(); } @Test @@ -608,12 +613,14 @@ public void bdpPingWindowResizing() throws Exception { private void readPingAck(long pingData) throws Exception { channelRead(pingFrame(true, pingData)); + channel().releaseOutbound(); } private void readXCopies(int copies, byte[] data) throws Exception { for (int i = 0; i < copies; i++) { channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it stream().request(1); // consume it + channel().releaseOutbound(); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 308079ff62f..54c1375eef2 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -43,6 +43,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -74,7 +75,6 @@ import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; @@ -120,23 +120,16 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase streamListenerMessageQueue = new LinkedList<>(); private int maxConcurrentStreams = Integer.MAX_VALUE; @@ -208,6 +201,7 @@ protected void manualSetUp() throws Exception { // Simulate receipt of initial remote settings. ByteBuf serializedSettings = serializeSettings(new Http2Settings()); channelRead(serializedSettings); + channel().releaseOutbound(); } @Test @@ -229,10 +223,11 @@ public void sendFrameShouldSucceed() throws Exception { createStream(); // Send a frame and verify that it was written. + ByteBuf content = content(); ChannelFuture future = enqueue( - new SendGrpcFrameCommand(stream.transportState(), content(), false)); + new SendGrpcFrameCommand(stream.transportState(), content, false)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(false), any(ChannelPromise.class)); } @@ -267,10 +262,11 @@ private void inboundDataShouldForwardToStreamListener(boolean endStream) throws // Create a data frame and then trigger the handler to read it. ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray()); channelRead(frame); + channel().releaseOutbound(); verify(streamListener, atLeastOnce()) .messagesAvailable(any(StreamListener.MessageProducer.class)); InputStream message = streamListenerMessageQueue.poll(); - assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message)); + assertArrayEquals(contentAsArray(), ByteStreams.toByteArray(message)); message.close(); assertNull("no additional message expected", streamListenerMessageQueue.poll()); @@ -870,7 +866,7 @@ public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception { future.get(); for (int i = 0; i < 10; i++) { future = enqueue( - new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false)); + new SendGrpcFrameCommand(stream.transportState(), content(), false)); future.get(); channel().releaseOutbound(); channelRead(pingFrame(false /* isAck */, 1L)); @@ -1293,6 +1289,7 @@ public void maxRstCount_withinLimit_succeeds() throws Exception { maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); manualSetUp(); rapidReset(maxRstCount); + assertTrue(channel().isOpen()); } @@ -1302,6 +1299,7 @@ public void maxRstCount_exceedsLimit_fails() throws Exception { maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); manualSetUp(); assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1)); + assertFalse(channel().isOpen()); } @@ -1344,11 +1342,7 @@ private void createStream() throws Exception { private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception { ByteBuf buf = NettyTestUtil.messageFrame(""); - try { - return dataFrame(streamId, endStream, buf); - } finally { - buf.release(); - } + return dataFrame(streamId, endStream, buf); } @Override