Skip to content

Commit

Permalink
netty: Fix ByteBuf leaks in tests (#11593)
Browse files Browse the repository at this point in the history
Part of #3353
  • Loading branch information
vinodhabib authored Dec 2, 2024
1 parent 7f9c1f3 commit f66d7fc
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 36 deletions.
6 changes: 4 additions & 2 deletions netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
// Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings);
channel().releaseOutbound();
}

@Test
Expand Down Expand Up @@ -342,11 +343,12 @@ public void sendFrameShouldSucceed() throws Exception {
createStream();

// Send a frame and verify that it was written.
ByteBuf content = content();
ChannelFuture future
= enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true));
= enqueue(new SendGrpcFrameCommand(streamTransportState, content, true));

assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true),
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(true),
any(ChannelPromise.class));
verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive
verifyNoMoreInteractions(mockKeepAliveManager);
Expand Down
41 changes: 24 additions & 17 deletions netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
Expand Down Expand Up @@ -68,6 +67,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -84,7 +84,6 @@
public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {

protected static final int STREAM_ID = 3;
private ByteBuf content;

private EmbeddedChannel channel;

Expand All @@ -106,18 +105,24 @@ protected void manualSetUp() throws Exception {}
protected final TransportTracer transportTracer = new TransportTracer();
protected int flowControlWindow = DEFAULT_WINDOW_SIZE;
protected boolean autoFlowControl = false;

private final FakeClock fakeClock = new FakeClock();

FakeClock fakeClock() {
return fakeClock;
}

@After
public void tearDown() throws Exception {
if (channel() != null) {
channel().releaseInbound();
channel().releaseOutbound();
}
}

/**
* Must be called by subclasses to initialize the handler and channel.
*/
protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception {
content = Unpooled.copiedBuffer("hello world", UTF_8);
frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter()));
frameReader = new DefaultHttp2FrameReader(headersDecoder);

Expand Down Expand Up @@ -233,11 +238,11 @@ protected final Http2FrameReader frameReader() {
}

protected final ByteBuf content() {
return content;
return Unpooled.copiedBuffer(contentAsArray());
}

protected final byte[] contentAsArray() {
return ByteBufUtil.getBytes(content());
return "\000\000\000\000\rhello world".getBytes(UTF_8);
}

protected final Http2FrameWriter verifyWrite() {
Expand All @@ -252,8 +257,8 @@ protected final void channelRead(Object obj) throws Exception {
channel.writeInbound(obj);
}

protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
final ByteBuf compressionFrame = Unpooled.buffer(content.length);
protected ByteBuf grpcFrame(byte[] message) {
final ByteBuf compressionFrame = Unpooled.buffer(message.length);
MessageFramer framer = new MessageFramer(
new MessageFramer.Sink() {
@Override
Expand All @@ -262,23 +267,22 @@ public void deliverFrame(
if (frame != null) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
compressionFrame.writeBytes(bytebuf);
bytebuf.release();
}
}
},
new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT),
StatsTraceContext.NOOP);
framer.writePayload(new ByteArrayInputStream(content));
framer.flush();
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
newPromise());
return captureWrite(ctx);
framer.writePayload(new ByteArrayInputStream(message));
framer.close();
return compressionFrame;
}

protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
// Need to retain the content since the frameWriter releases it.
content.retain();
protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
return dataFrame(streamId, endStream, grpcFrame(content));
}

protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise());
return captureWrite(ctx);
Expand Down Expand Up @@ -410,6 +414,7 @@ public void dataSizeSincePingAccumulates() throws Exception {
channelRead(dataFrame(3, false, buff.copy()));

assertEquals(length * 3, handler.flowControlPing().getDataSincePing());
buff.release();
}

@Test
Expand Down Expand Up @@ -608,12 +613,14 @@ public void bdpPingWindowResizing() throws Exception {

private void readPingAck(long pingData) throws Exception {
channelRead(pingFrame(true, pingData));
channel().releaseOutbound();
}

private void readXCopies(int copies, byte[] data) throws Exception {
for (int i = 0; i < copies; i++) {
channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it
stream().request(1); // consume it
channel().releaseOutbound();
}
}

Expand Down
28 changes: 11 additions & 17 deletions netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
Expand Down Expand Up @@ -74,7 +75,6 @@
import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
Expand Down Expand Up @@ -120,23 +120,16 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10));
@Rule
public final MockitoRule mocks = MockitoJUnit.rule();

private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE");


@Mock
private ServerStreamListener streamListener;

@Mock
private ServerStreamTracer.Factory streamTracerFactory;

private final ServerTransportListener transportListener =
mock(ServerTransportListener.class, delegatesTo(new ServerTransportListenerImpl()));
private final TestServerStreamTracer streamTracer = new TestServerStreamTracer();

private NettyServerStream stream;
private KeepAliveManager spyKeepAliveManager;

final Queue<InputStream> streamListenerMessageQueue = new LinkedList<>();

private int maxConcurrentStreams = Integer.MAX_VALUE;
Expand Down Expand Up @@ -208,6 +201,7 @@ protected void manualSetUp() throws Exception {
// Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings);
channel().releaseOutbound();
}

@Test
Expand All @@ -229,10 +223,11 @@ public void sendFrameShouldSucceed() throws Exception {
createStream();

// Send a frame and verify that it was written.
ByteBuf content = content();
ChannelFuture future = enqueue(
new SendGrpcFrameCommand(stream.transportState(), content(), false));
new SendGrpcFrameCommand(stream.transportState(), content, false));
assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false),
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(false),
any(ChannelPromise.class));
}

Expand Down Expand Up @@ -267,10 +262,11 @@ private void inboundDataShouldForwardToStreamListener(boolean endStream) throws
// Create a data frame and then trigger the handler to read it.
ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray());
channelRead(frame);
channel().releaseOutbound();
verify(streamListener, atLeastOnce())
.messagesAvailable(any(StreamListener.MessageProducer.class));
InputStream message = streamListenerMessageQueue.poll();
assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message));
assertArrayEquals(contentAsArray(), ByteStreams.toByteArray(message));
message.close();
assertNull("no additional message expected", streamListenerMessageQueue.poll());

Expand Down Expand Up @@ -870,7 +866,7 @@ public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception {
future.get();
for (int i = 0; i < 10; i++) {
future = enqueue(
new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false));
new SendGrpcFrameCommand(stream.transportState(), content(), false));
future.get();
channel().releaseOutbound();
channelRead(pingFrame(false /* isAck */, 1L));
Expand Down Expand Up @@ -1293,6 +1289,7 @@ public void maxRstCount_withinLimit_succeeds() throws Exception {
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
rapidReset(maxRstCount);

assertTrue(channel().isOpen());
}

Expand All @@ -1302,6 +1299,7 @@ public void maxRstCount_exceedsLimit_fails() throws Exception {
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1));

assertFalse(channel().isOpen());
}

Expand Down Expand Up @@ -1344,11 +1342,7 @@ private void createStream() throws Exception {

private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception {
ByteBuf buf = NettyTestUtil.messageFrame("");
try {
return dataFrame(streamId, endStream, buf);
} finally {
buf.release();
}
return dataFrame(streamId, endStream, buf);
}

@Override
Expand Down

0 comments on commit f66d7fc

Please sign in to comment.