Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
andrross committed Aug 27, 2024
1 parent 4c4c69f commit 769c685
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 272 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public ProtocolInboundMessage finishAggregation() throws IOException {
}

final BreakerControl breakerControl = new BreakerControl(circuitBreaker);
final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl);
final ProtocolInboundMessage aggregated = new ProtocolInboundMessage(currentHeader, releasableContent, breakerControl);
boolean success = false;
try {
if (aggregated.getHeader().needsToReadVariableHeader()) {
Expand All @@ -143,7 +143,7 @@ public ProtocolInboundMessage finishAggregation() throws IOException {
if (isShortCircuited()) {
aggregated.close();
success = true;
return new NativeInboundMessage(aggregated.getHeader(), aggregationException);
return new ProtocolInboundMessage(aggregated.getHeader(), aggregationException);
} else {
success = true;
return aggregated;
Expand Down
134 changes: 127 additions & 7 deletions server/src/main/java/org/opensearch/transport/InboundBytesHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,144 @@
package org.opensearch.transport;

import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;

/**
* Interface for handling inbound bytes. Can be implemented by different transport protocols.
* Handler for inbound bytes for the native protocol.
*/
public interface InboundBytesHandler extends Closeable {
class InboundBytesHandler {

private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);

private final ArrayDeque<ReleasableBytesReference> pending;
private final InboundDecoder decoder;
private final InboundAggregator aggregator;
private final StatsTracker statsTracker;
private boolean isClosed = false;

InboundBytesHandler(
ArrayDeque<ReleasableBytesReference> pending,
InboundDecoder decoder,
InboundAggregator aggregator,
StatsTracker statsTracker
) {
this.pending = pending;
this.decoder = decoder;
this.aggregator = aggregator;
this.statsTracker = statsTracker;
}

public void close() {
isClosed = true;
}

public void doHandleBytes(
TcpChannel channel,
ReleasableBytesReference reference,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException;
) throws IOException {
final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}
}

if (fragments.isEmpty()) {
continueHandling = false;
} else {
try {
forwardFragments(channel, fragments, messageHandler);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
}
}
}

public boolean canHandleBytes(ReleasableBytesReference reference);
private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {
assert reference != null;
if (bytesToRelease < reference.length()) {
pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease));
bytesToRelease -= bytesToRelease;
} else {
bytesToRelease -= reference.length();
}
}
}
}

private boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private void forwardFragments(
TcpChannel channel,
ArrayList<Object> fragments,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, ProtocolInboundMessage.PING);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (ProtocolInboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
}
}
}

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongSupplier;
Expand Down Expand Up @@ -94,7 +92,7 @@ public InboundPipeline(
this.statsTracker = statsTracker;
this.decoder = decoder;
this.aggregator = aggregator;
this.bytesHandler = new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker);
this.bytesHandler = new InboundBytesHandler(pending, decoder, aggregator, statsTracker);
this.messageHandler = messageHandler;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,17 @@ public void messageReceived(
long slowLogThresholdMs,
TransportMessageListener messageListener
) throws IOException {
NativeInboundMessage inboundMessage = (NativeInboundMessage) message;
TransportLogger.logInboundMessage(channel, inboundMessage);
if (inboundMessage.isPing()) {
TransportLogger.logInboundMessage(channel, message);
if (message.isPing()) {
keepAlive.receiveKeepAlive(channel);
} else {
handleMessage(channel, inboundMessage, startTime, slowLogThresholdMs, messageListener);
handleMessage(channel, message, startTime, slowLogThresholdMs, messageListener);
}
}

private void handleMessage(
TcpChannel channel,
NativeInboundMessage message,
ProtocolInboundMessage message,
long startTime,
long slowLogThresholdMs,
TransportMessageListener messageListener
Expand Down Expand Up @@ -202,7 +201,7 @@ private Map<String, Collection<String>> extractHeaders(Map<String, String> heade
private <T extends TransportRequest> void handleRequest(
TcpChannel channel,
Header header,
NativeInboundMessage message,
ProtocolInboundMessage message,
TransportMessageListener messageListener
) throws IOException {
final String action = header.getActionName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,63 +8,64 @@

package org.opensearch.transport;

import java.io.IOException;

import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.common.io.stream.StreamInput;

/**
* Base class for inbound data as a message.
* Different implementations are used for different protocols.
* Inbound data as a message.
*
* @opensearch.internal
*/
@PublicApi(since = "2.14.0")
public abstract class ProtocolInboundMessage implements Releasable {
public class ProtocolInboundMessage implements Releasable {

static final ProtocolInboundMessage PING = new ProtocolInboundMessage(null, null, null, true, null);

protected final Header header;
protected final ReleasableBytesReference content;
protected final Exception exception;
protected final boolean isPing;
private Releasable breakerRelease;
private StreamInput streamInput;

public ProtocolInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.header = header;
this.content = content;
this.breakerRelease = breakerRelease;
this.exception = null;
this.isPing = false;
ProtocolInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this(header,content, null, false, breakerRelease);
}

public ProtocolInboundMessage(Header header, Exception exception) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = exception;
this.isPing = false;
ProtocolInboundMessage(Header header, Exception exception) {
this(header, null, exception, false, null);
}

public ProtocolInboundMessage(Header header, boolean isPing) {
private ProtocolInboundMessage(Header header, ReleasableBytesReference content, Exception exception, boolean isPing, Releasable breakerRelease) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = null;
this.content = content;
this.exception = exception;
this.isPing = isPing;
this.breakerRelease = breakerRelease;
}

TransportProtocol getTransportProtocol() {
if (isPing) {
return TransportProtocol.NATIVE;
}
return header.getTransportProtocol();
}

public String getProtocol() {
return header.getTransportProtocol().toString();
}

public Header getHeader() {
Header getHeader() {
return header;
}

public int getContentLength() {
int getContentLength() {
if (content == null) {
return 0;
} else {
Expand All @@ -76,15 +77,15 @@ public Exception getException() {
return exception;
}

public boolean isPing() {
boolean isPing() {
return isPing;
}

public boolean isShortCircuit() {
boolean isShortCircuit() {
return exception != null;
}

public Releasable takeBreakerReleaseControl() {
Releasable takeBreakerReleaseControl() {
final Releasable toReturn = breakerRelease;
breakerRelease = null;
if (toReturn != null) {
Expand All @@ -94,15 +95,23 @@ public Releasable takeBreakerReleaseControl() {
}
}


StreamInput openOrGetStreamInput() throws IOException {
assert isPing == false && content != null;
if (streamInput == null) {
streamInput = content.streamInput();
streamInput.setVersion(header.getVersion());
}
return streamInput;
}

@Override
public void close() {
IOUtils.closeWhileHandlingException(streamInput);
Releasables.closeWhileHandlingException(content, breakerRelease);
}

@Override
public String toString() {
return "InboundMessage{" + header + "}";
return "ProtocolInboundMessage{" + header + "}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static void logInboundMessage(TcpChannel channel, BytesReference message) {
}
}

static void logInboundMessage(TcpChannel channel, NativeInboundMessage message) {
static void logInboundMessage(TcpChannel channel, ProtocolInboundMessage message) {
if (logger.isTraceEnabled()) {
try {
String logMessage = format(channel, message, "READ");
Expand Down Expand Up @@ -137,7 +137,7 @@ private static String format(TcpChannel channel, BytesReference message, String
return sb.toString();
}

private static String format(TcpChannel channel, NativeInboundMessage message, String event) throws IOException {
private static String format(TcpChannel channel, ProtocolInboundMessage message, String event) throws IOException {
final StringBuilder sb = new StringBuilder();
sb.append(channel);

Expand Down
Loading

0 comments on commit 769c685

Please sign in to comment.