Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Ross <[email protected]>
  • Loading branch information
andrross committed Aug 27, 2024
1 parent 4c4c69f commit 45f545d
Show file tree
Hide file tree
Showing 13 changed files with 251 additions and 286 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -128,7 +127,7 @@ public ProtocolInboundMessage finishAggregation() throws IOException {
}

final BreakerControl breakerControl = new BreakerControl(circuitBreaker);
final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl);
final ProtocolInboundMessage aggregated = new ProtocolInboundMessage(currentHeader, releasableContent, breakerControl);
boolean success = false;
try {
if (aggregated.getHeader().needsToReadVariableHeader()) {
Expand All @@ -143,7 +142,7 @@ public ProtocolInboundMessage finishAggregation() throws IOException {
if (isShortCircuited()) {
aggregated.close();
success = true;
return new NativeInboundMessage(aggregated.getHeader(), aggregationException);
return new ProtocolInboundMessage(aggregated.getHeader(), aggregationException);
} else {
success = true;
return aggregated;
Expand Down
133 changes: 126 additions & 7 deletions server/src/main/java/org/opensearch/transport/InboundBytesHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,143 @@
package org.opensearch.transport;

import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.CompositeBytesReference;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;

/**
* Interface for handling inbound bytes. Can be implemented by different transport protocols.
* Handler for inbound bytes for the native protocol.
*/
public interface InboundBytesHandler extends Closeable {
class InboundBytesHandler {

private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);

private final ArrayDeque<ReleasableBytesReference> pending;
private final InboundDecoder decoder;
private final InboundAggregator aggregator;
private final StatsTracker statsTracker;
private boolean isClosed = false;

InboundBytesHandler(
ArrayDeque<ReleasableBytesReference> pending,
InboundDecoder decoder,
InboundAggregator aggregator,
StatsTracker statsTracker
) {
this.pending = pending;
this.decoder = decoder;
this.aggregator = aggregator;
this.statsTracker = statsTracker;
}

public void close() {
isClosed = true;
}

public void doHandleBytes(
TcpChannel channel,
ReleasableBytesReference reference,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException;
) throws IOException {
final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}
}

if (fragments.isEmpty()) {
continueHandling = false;
} else {
try {
forwardFragments(channel, fragments, messageHandler);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
}
}
}

public boolean canHandleBytes(ReleasableBytesReference reference);
private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {
assert reference != null;
if (bytesToRelease < reference.length()) {
pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease));
bytesToRelease -= bytesToRelease;
} else {
bytesToRelease -= reference.length();
}
}
}
}

private boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private void forwardFragments(
TcpChannel channel,
ArrayList<Object> fragments,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, ProtocolInboundMessage.PING);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (ProtocolInboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
}
}
}

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongSupplier;
Expand Down Expand Up @@ -94,7 +92,7 @@ public InboundPipeline(
this.statsTracker = statsTracker;
this.decoder = decoder;
this.aggregator = aggregator;
this.bytesHandler = new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker);
this.bytesHandler = new InboundBytesHandler(pending, decoder, aggregator, statsTracker);
this.messageHandler = messageHandler;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.nativeprotocol.NativeOutboundHandler;

import java.io.EOFException;
Expand Down Expand Up @@ -119,18 +118,17 @@ public void messageReceived(
long slowLogThresholdMs,
TransportMessageListener messageListener
) throws IOException {
NativeInboundMessage inboundMessage = (NativeInboundMessage) message;
TransportLogger.logInboundMessage(channel, inboundMessage);
if (inboundMessage.isPing()) {
TransportLogger.logInboundMessage(channel, message);
if (message.isPing()) {
keepAlive.receiveKeepAlive(channel);
} else {
handleMessage(channel, inboundMessage, startTime, slowLogThresholdMs, messageListener);
handleMessage(channel, message, startTime, slowLogThresholdMs, messageListener);
}
}

private void handleMessage(
TcpChannel channel,
NativeInboundMessage message,
ProtocolInboundMessage message,
long startTime,
long slowLogThresholdMs,
TransportMessageListener messageListener
Expand Down Expand Up @@ -202,7 +200,7 @@ private Map<String, Collection<String>> extractHeaders(Map<String, String> heade
private <T extends TransportRequest> void handleRequest(
TcpChannel channel,
Header header,
NativeInboundMessage message,
ProtocolInboundMessage message,
TransportMessageListener messageListener
) throws IOException {
final String action = header.getActionName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,66 @@
import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.common.io.stream.StreamInput;

import java.io.IOException;

/**
* Base class for inbound data as a message.
* Different implementations are used for different protocols.
* Inbound data as a message.
*
* @opensearch.internal
*/
@PublicApi(since = "2.14.0")
public abstract class ProtocolInboundMessage implements Releasable {
public class ProtocolInboundMessage implements Releasable {

static final ProtocolInboundMessage PING = new ProtocolInboundMessage(null, null, null, true, null);

protected final Header header;
protected final ReleasableBytesReference content;
protected final Exception exception;
protected final boolean isPing;
private Releasable breakerRelease;
private StreamInput streamInput;

public ProtocolInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.header = header;
this.content = content;
this.breakerRelease = breakerRelease;
this.exception = null;
this.isPing = false;
ProtocolInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this(header, content, null, false, breakerRelease);
}

public ProtocolInboundMessage(Header header, Exception exception) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = exception;
this.isPing = false;
ProtocolInboundMessage(Header header, Exception exception) {
this(header, null, exception, false, null);
}

public ProtocolInboundMessage(Header header, boolean isPing) {
private ProtocolInboundMessage(
Header header,
ReleasableBytesReference content,
Exception exception,
boolean isPing,
Releasable breakerRelease
) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = null;
this.content = content;
this.exception = exception;
this.isPing = isPing;
this.breakerRelease = breakerRelease;
}

TransportProtocol getTransportProtocol() {
if (isPing) {
return TransportProtocol.NATIVE;
}
return header.getTransportProtocol();
}

public String getProtocol() {
return header.getTransportProtocol().toString();
}

public Header getHeader() {
Header getHeader() {
return header;
}

public int getContentLength() {
int getContentLength() {
if (content == null) {
return 0;
} else {
Expand All @@ -76,15 +83,15 @@ public Exception getException() {
return exception;
}

public boolean isPing() {
boolean isPing() {
return isPing;
}

public boolean isShortCircuit() {
boolean isShortCircuit() {
return exception != null;
}

public Releasable takeBreakerReleaseControl() {
Releasable takeBreakerReleaseControl() {
final Releasable toReturn = breakerRelease;
breakerRelease = null;
if (toReturn != null) {
Expand All @@ -94,15 +101,23 @@ public Releasable takeBreakerReleaseControl() {
}
}


StreamInput openOrGetStreamInput() throws IOException {
assert isPing == false && content != null;
if (streamInput == null) {
streamInput = content.streamInput();
streamInput.setVersion(header.getVersion());
}
return streamInput;
}

@Override
public void close() {
IOUtils.closeWhileHandlingException(streamInput);
Releasables.closeWhileHandlingException(content, breakerRelease);
}

@Override
public String toString() {
return "InboundMessage{" + header + "}";
return "ProtocolInboundMessage{" + header + "}";
}
}
Loading

0 comments on commit 45f545d

Please sign in to comment.