diff --git a/server/src/main/java/org/opensearch/transport/InboundAggregator.java b/server/src/main/java/org/opensearch/transport/InboundAggregator.java index c19ecd1f72b60..d96eb8fab935a 100644 --- a/server/src/main/java/org/opensearch/transport/InboundAggregator.java +++ b/server/src/main/java/org/opensearch/transport/InboundAggregator.java @@ -128,7 +128,7 @@ public ProtocolInboundMessage finishAggregation() throws IOException { } final BreakerControl breakerControl = new BreakerControl(circuitBreaker); - final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl); + final ProtocolInboundMessage aggregated = new ProtocolInboundMessage(currentHeader, releasableContent, breakerControl); boolean success = false; try { if (aggregated.getHeader().needsToReadVariableHeader()) { @@ -143,7 +143,7 @@ public ProtocolInboundMessage finishAggregation() throws IOException { if (isShortCircuited()) { aggregated.close(); success = true; - return new NativeInboundMessage(aggregated.getHeader(), aggregationException); + return new ProtocolInboundMessage(aggregated.getHeader(), aggregationException); } else { success = true; return aggregated; diff --git a/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java index 276891212e43f..22a1c5ce451bd 100644 --- a/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java @@ -9,24 +9,144 @@ package org.opensearch.transport; import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.lease.Releasables; +import org.opensearch.core.common.bytes.CompositeBytesReference; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; -import java.io.Closeable; import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.function.BiConsumer; /** - * Interface for handling inbound bytes. Can be implemented by different transport protocols. + * Handler for inbound bytes for the native protocol. */ -public interface InboundBytesHandler extends Closeable { +class InboundBytesHandler { + + private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); + + private final ArrayDeque pending; + private final InboundDecoder decoder; + private final InboundAggregator aggregator; + private final StatsTracker statsTracker; + private boolean isClosed = false; + + InboundBytesHandler( + ArrayDeque pending, + InboundDecoder decoder, + InboundAggregator aggregator, + StatsTracker statsTracker + ) { + this.pending = pending; + this.decoder = decoder; + this.aggregator = aggregator; + this.statsTracker = statsTracker; + } + + public void close() { + isClosed = true; + } public void doHandleBytes( TcpChannel channel, ReleasableBytesReference reference, BiConsumer messageHandler - ) throws IOException; + ) throws IOException { + final ArrayList fragments = fragmentList.get(); + boolean continueHandling = true; + + while (continueHandling && isClosed == false) { + boolean continueDecoding = true; + while (continueDecoding && pending.isEmpty() == false) { + try (ReleasableBytesReference toDecode = getPendingBytes()) { + final int bytesDecoded = decoder.decode(toDecode, fragments::add); + if (bytesDecoded != 0) { + releasePendingBytes(bytesDecoded); + if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + continueDecoding = false; + } + } else { + continueDecoding = false; + } + } + } + + if (fragments.isEmpty()) { + continueHandling = false; + } else { + try { + forwardFragments(channel, fragments, messageHandler); + } finally { + for (Object fragment : fragments) { + if (fragment instanceof ReleasableBytesReference) { + ((ReleasableBytesReference) fragment).close(); + } + } + fragments.clear(); + } + } + } + } - public boolean canHandleBytes(ReleasableBytesReference reference); + private ReleasableBytesReference getPendingBytes() { + if (pending.size() == 1) { + return pending.peekFirst().retain(); + } else { + final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; + int index = 0; + for (ReleasableBytesReference pendingReference : pending) { + bytesReferences[index] = pendingReference.retain(); + ++index; + } + final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); + return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); + } + } + + private void releasePendingBytes(int bytesConsumed) { + int bytesToRelease = bytesConsumed; + while (bytesToRelease != 0) { + try (ReleasableBytesReference reference = pending.pollFirst()) { + assert reference != null; + if (bytesToRelease < reference.length()) { + pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); + bytesToRelease -= bytesToRelease; + } else { + bytesToRelease -= reference.length(); + } + } + } + } + + private boolean endOfMessage(Object fragment) { + return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; + } + + private void forwardFragments( + TcpChannel channel, + ArrayList fragments, + BiConsumer messageHandler + ) throws IOException { + for (Object fragment : fragments) { + if (fragment instanceof Header) { + assert aggregator.isAggregating() == false; + aggregator.headerReceived((Header) fragment); + } else if (fragment == InboundDecoder.PING) { + assert aggregator.isAggregating() == false; + messageHandler.accept(channel, ProtocolInboundMessage.PING); + } else if (fragment == InboundDecoder.END_CONTENT) { + assert aggregator.isAggregating(); + try (ProtocolInboundMessage aggregated = aggregator.finishAggregation()) { + statsTracker.markMessageReceived(); + messageHandler.accept(channel, aggregated); + } + } else { + assert aggregator.isAggregating(); + assert fragment instanceof ReleasableBytesReference; + aggregator.aggregate((ReleasableBytesReference) fragment); + } + } + } - @Override - void close(); } diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index 597ab0673ab4b..5590c3e330b19 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -38,11 +38,9 @@ import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.breaker.CircuitBreaker; -import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler; import java.io.IOException; import java.util.ArrayDeque; -import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.LongSupplier; @@ -94,7 +92,7 @@ public InboundPipeline( this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; - this.bytesHandler = new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker); + this.bytesHandler = new InboundBytesHandler(pending, decoder, aggregator, statsTracker); this.messageHandler = messageHandler; } diff --git a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java index 4c972fdc14fa5..651f998d96f86 100644 --- a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java +++ b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java @@ -119,18 +119,17 @@ public void messageReceived( long slowLogThresholdMs, TransportMessageListener messageListener ) throws IOException { - NativeInboundMessage inboundMessage = (NativeInboundMessage) message; - TransportLogger.logInboundMessage(channel, inboundMessage); - if (inboundMessage.isPing()) { + TransportLogger.logInboundMessage(channel, message); + if (message.isPing()) { keepAlive.receiveKeepAlive(channel); } else { - handleMessage(channel, inboundMessage, startTime, slowLogThresholdMs, messageListener); + handleMessage(channel, message, startTime, slowLogThresholdMs, messageListener); } } private void handleMessage( TcpChannel channel, - NativeInboundMessage message, + ProtocolInboundMessage message, long startTime, long slowLogThresholdMs, TransportMessageListener messageListener @@ -202,7 +201,7 @@ private Map> extractHeaders(Map heade private void handleRequest( TcpChannel channel, Header header, - NativeInboundMessage message, + ProtocolInboundMessage message, TransportMessageListener messageListener ) throws IOException { final String action = header.getActionName(); diff --git a/server/src/main/java/org/opensearch/transport/ProtocolInboundMessage.java b/server/src/main/java/org/opensearch/transport/ProtocolInboundMessage.java index d4ecb0f5d2941..ea860e19cc4bd 100644 --- a/server/src/main/java/org/opensearch/transport/ProtocolInboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/ProtocolInboundMessage.java @@ -8,51 +8,52 @@ package org.opensearch.transport; +import java.io.IOException; + import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.core.common.io.stream.StreamInput; /** - * Base class for inbound data as a message. - * Different implementations are used for different protocols. + * Inbound data as a message. * * @opensearch.internal */ @PublicApi(since = "2.14.0") -public abstract class ProtocolInboundMessage implements Releasable { +public class ProtocolInboundMessage implements Releasable { + + static final ProtocolInboundMessage PING = new ProtocolInboundMessage(null, null, null, true, null); protected final Header header; protected final ReleasableBytesReference content; protected final Exception exception; protected final boolean isPing; private Releasable breakerRelease; + private StreamInput streamInput; - public ProtocolInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { - this.header = header; - this.content = content; - this.breakerRelease = breakerRelease; - this.exception = null; - this.isPing = false; + ProtocolInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { + this(header,content, null, false, breakerRelease); } - public ProtocolInboundMessage(Header header, Exception exception) { - this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = exception; - this.isPing = false; + ProtocolInboundMessage(Header header, Exception exception) { + this(header, null, exception, false, null); } - public ProtocolInboundMessage(Header header, boolean isPing) { + private ProtocolInboundMessage(Header header, ReleasableBytesReference content, Exception exception, boolean isPing, Releasable breakerRelease) { this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = null; + this.content = content; + this.exception = exception; this.isPing = isPing; + this.breakerRelease = breakerRelease; } TransportProtocol getTransportProtocol() { + if (isPing) { + return TransportProtocol.NATIVE; + } return header.getTransportProtocol(); } @@ -60,11 +61,11 @@ public String getProtocol() { return header.getTransportProtocol().toString(); } - public Header getHeader() { + Header getHeader() { return header; } - public int getContentLength() { + int getContentLength() { if (content == null) { return 0; } else { @@ -76,15 +77,15 @@ public Exception getException() { return exception; } - public boolean isPing() { + boolean isPing() { return isPing; } - public boolean isShortCircuit() { + boolean isShortCircuit() { return exception != null; } - public Releasable takeBreakerReleaseControl() { + Releasable takeBreakerReleaseControl() { final Releasable toReturn = breakerRelease; breakerRelease = null; if (toReturn != null) { @@ -94,15 +95,23 @@ public Releasable takeBreakerReleaseControl() { } } - + StreamInput openOrGetStreamInput() throws IOException { + assert isPing == false && content != null; + if (streamInput == null) { + streamInput = content.streamInput(); + streamInput.setVersion(header.getVersion()); + } + return streamInput; + } @Override public void close() { + IOUtils.closeWhileHandlingException(streamInput); Releasables.closeWhileHandlingException(content, breakerRelease); } @Override public String toString() { - return "InboundMessage{" + header + "}"; + return "ProtocolInboundMessage{" + header + "}"; } } diff --git a/server/src/main/java/org/opensearch/transport/TransportLogger.java b/server/src/main/java/org/opensearch/transport/TransportLogger.java index e780f643aafd7..d1968c92d16c2 100644 --- a/server/src/main/java/org/opensearch/transport/TransportLogger.java +++ b/server/src/main/java/org/opensearch/transport/TransportLogger.java @@ -65,7 +65,7 @@ static void logInboundMessage(TcpChannel channel, BytesReference message) { } } - static void logInboundMessage(TcpChannel channel, NativeInboundMessage message) { + static void logInboundMessage(TcpChannel channel, ProtocolInboundMessage message) { if (logger.isTraceEnabled()) { try { String logMessage = format(channel, message, "READ"); @@ -137,7 +137,7 @@ private static String format(TcpChannel channel, BytesReference message, String return sb.toString(); } - private static String format(TcpChannel channel, NativeInboundMessage message, String event) throws IOException { + private static String format(TcpChannel channel, ProtocolInboundMessage message, String event) throws IOException { final StringBuilder sb = new StringBuilder(); sb.append(channel); diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java deleted file mode 100644 index 9290ec8985161..0000000000000 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.transport.nativeprotocol; - -import org.opensearch.common.bytes.ReleasableBytesReference; -import org.opensearch.common.lease.Releasable; -import org.opensearch.common.lease.Releasables; -import org.opensearch.core.common.bytes.CompositeBytesReference; -import org.opensearch.transport.Header; -import org.opensearch.transport.InboundAggregator; -import org.opensearch.transport.InboundBytesHandler; -import org.opensearch.transport.InboundDecoder; -import org.opensearch.transport.ProtocolInboundMessage; -import org.opensearch.transport.StatsTracker; -import org.opensearch.transport.TcpChannel; - -import java.io.IOException; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.function.BiConsumer; - -/** - * Handler for inbound bytes for the native protocol. - */ -public class NativeInboundBytesHandler implements InboundBytesHandler { - - private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - private static final NativeInboundMessage PING_MESSAGE = new NativeInboundMessage(null, true); - - private final ArrayDeque pending; - private final InboundDecoder decoder; - private final InboundAggregator aggregator; - private final StatsTracker statsTracker; - private boolean isClosed = false; - - public NativeInboundBytesHandler( - ArrayDeque pending, - InboundDecoder decoder, - InboundAggregator aggregator, - StatsTracker statsTracker - ) { - this.pending = pending; - this.decoder = decoder; - this.aggregator = aggregator; - this.statsTracker = statsTracker; - } - - @Override - public void close() { - isClosed = true; - } - - @Override - public boolean canHandleBytes(ReleasableBytesReference reference) { - return true; - } - - @Override - public void doHandleBytes( - TcpChannel channel, - ReleasableBytesReference reference, - BiConsumer messageHandler - ) throws IOException { - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { - continueDecoding = false; - } - } else { - continueDecoding = false; - } - } - } - - if (fragments.isEmpty()) { - continueHandling = false; - } else { - try { - forwardFragments(channel, fragments, messageHandler); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); - } - } - fragments.clear(); - } - } - } - } - - private ReleasableBytesReference getPendingBytes() { - if (pending.size() == 1) { - return pending.peekFirst().retain(); - } else { - final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; - int index = 0; - for (ReleasableBytesReference pendingReference : pending) { - bytesReferences[index] = pendingReference.retain(); - ++index; - } - final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); - return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); - } - } - - private void releasePendingBytes(int bytesConsumed) { - int bytesToRelease = bytesConsumed; - while (bytesToRelease != 0) { - try (ReleasableBytesReference reference = pending.pollFirst()) { - assert reference != null; - if (bytesToRelease < reference.length()) { - pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); - bytesToRelease -= bytesToRelease; - } else { - bytesToRelease -= reference.length(); - } - } - } - } - - private boolean endOfMessage(Object fragment) { - return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; - } - - private void forwardFragments( - TcpChannel channel, - ArrayList fragments, - BiConsumer messageHandler - ) throws IOException { - for (Object fragment : fragments) { - if (fragment instanceof Header) { - assert aggregator.isAggregating() == false; - aggregator.headerReceived((Header) fragment); - } else if (fragment == InboundDecoder.PING) { - assert aggregator.isAggregating() == false; - messageHandler.accept(channel, PING_MESSAGE); - } else if (fragment == InboundDecoder.END_CONTENT) { - assert aggregator.isAggregating(); - try (ProtocolInboundMessage aggregated = aggregator.finishAggregation()) { - statsTracker.markMessageReceived(); - messageHandler.accept(channel, aggregated); - } - } else { - assert aggregator.isAggregating(); - assert fragment instanceof ReleasableBytesReference; - aggregator.aggregate((ReleasableBytesReference) fragment); - } - } - } - -} diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java index 448b21b5d3668..ac86e4807a432 100644 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java @@ -33,55 +33,16 @@ package org.opensearch.transport.nativeprotocol; import org.opensearch.common.annotation.PublicApi; -import org.opensearch.common.bytes.ReleasableBytesReference; -import org.opensearch.common.lease.Releasable; -import org.opensearch.common.lease.Releasables; -import org.opensearch.common.util.io.IOUtils; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.transport.Header; -import org.opensearch.transport.ProtocolInboundMessage; - -import java.io.IOException; - /** * Inbound data as a message * * @opensearch.api */ @PublicApi(since = "2.14.0") -public class NativeInboundMessage extends ProtocolInboundMessage { +public class NativeInboundMessage { /** * The protocol used to encode this message */ public static String NATIVE_PROTOCOL = "native"; - - private StreamInput streamInput; - - public NativeInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { - super(header, content, breakerRelease); - } - - public NativeInboundMessage(Header header, Exception exception) { - super(header, exception); - } - - public NativeInboundMessage(Header header, boolean isPing) { - super(header, isPing); - } - - public StreamInput openOrGetStreamInput() throws IOException { - assert isPing == false && content != null; - if (streamInput == null) { - streamInput = content.streamInput(); - streamInput.setVersion(header.getVersion()); - } - return streamInput; - } - - @Override - public void close() { - IOUtils.closeWhileHandlingException(streamInput); - super.close(); - } } diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 7a2e79fa8cc1b..fc539ace18dd3 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -57,7 +57,6 @@ import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.After; import org.junit.Before; @@ -152,7 +151,7 @@ public void testPing() throws Exception { ); requestHandlers.registerHandler(registry); - handler.inboundMessage(channel, new NativeInboundMessage(null, true)); + handler.inboundMessage(channel, ProtocolInboundMessage.PING); if (channel.isServerChannel()) { BytesReference ping = channel.getMessageCaptor().get(); assertEquals('E', ping.get(0)); @@ -216,7 +215,7 @@ public TestResponse read(StreamInput in) throws IOException { ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(TransportProtocol.NATIVE, fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( + ProtocolInboundMessage requestMessage = new ProtocolInboundMessage( requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {} @@ -241,7 +240,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( + ProtocolInboundMessage responseMessage = new ProtocolInboundMessage( responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {} @@ -273,7 +272,7 @@ public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exce TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final ProtocolInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -314,7 +313,7 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final ProtocolInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -346,7 +345,7 @@ public void testLogsSlowInboundProcessing() throws Exception { TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = new NativeInboundMessage( + final ProtocolInboundMessage requestMessage = new ProtocolInboundMessage( requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> { @@ -428,7 +427,7 @@ public void onResponseSent(long requestId, String action, Exception error) { BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(TransportProtocol.NATIVE, fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( + ProtocolInboundMessage requestMessage = new ProtocolInboundMessage( requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {} @@ -497,7 +496,7 @@ public void onResponseSent(long requestId, String action, Exception error) { // Create the request payload by intentionally stripping 1 byte away BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1); Header requestHeader = new Header(TransportProtocol.NATIVE, fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( + ProtocolInboundMessage requestMessage = new ProtocolInboundMessage( requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {} @@ -565,7 +564,7 @@ public TestResponse read(StreamInput in) throws IOException { ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(TransportProtocol.NATIVE, fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( + ProtocolInboundMessage requestMessage = new ProtocolInboundMessage( requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {} @@ -591,7 +590,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( + ProtocolInboundMessage responseMessage = new ProtocolInboundMessage( responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {} @@ -659,7 +658,7 @@ public TestResponse read(StreamInput in) throws IOException { ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(TransportProtocol.NATIVE, fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( + ProtocolInboundMessage requestMessage = new ProtocolInboundMessage( requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {} @@ -680,7 +679,7 @@ public TestResponse read(StreamInput in) throws IOException { // Create the response payload by intentionally stripping 1 byte away BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1); Header responseHeader = new Header(TransportProtocol.NATIVE, fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( + ProtocolInboundMessage responseMessage = new ProtocolInboundMessage( responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {} @@ -693,8 +692,8 @@ public TestResponse read(StreamInput in) throws IOException { assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler")); } - private static NativeInboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { - return new NativeInboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { + private static ProtocolInboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { + return new ProtocolInboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { @Override public StreamInput openOrGetStreamInput() { final StreamInput streamInput = new InputStreamStreamInput(new InputStream() { diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 5a89bf1e0ead3..78982a810c7bf 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -82,9 +82,8 @@ public void testPipelineHandlingForNativeProtocol() throws IOException { final List> expected = new ArrayList<>(); final List> actual = new ArrayList<>(); final List toRelease = new ArrayList<>(); - final BiConsumer messageHandler = (c, m) -> { + final BiConsumer messageHandler = (c, message) -> { try { - NativeInboundMessage message = (NativeInboundMessage) m; final Header header = message.getHeader(); final MessageData actualData; final Version version = header.getVersion(); diff --git a/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java index 01f19bea7a37f..cf794040c1a0c 100644 --- a/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/NativeOutboundHandlerTests.java @@ -106,9 +106,8 @@ public void setUp() throws Exception { final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { - NativeInboundMessage m1 = (NativeInboundMessage) m; - Streams.copy(m1.openOrGetStreamInput(), streamOutput); - message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); + Streams.copy(m.openOrGetStreamInput(), streamOutput); + message.set(new Tuple<>(m.getHeader(), streamOutput.bytes())); } catch (IOException e) { throw new AssertionError(e); }