diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java index 5f03d08b4c69b..cd2ba31c13d39 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java @@ -32,6 +32,7 @@ import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundHandler; import org.elasticsearch.transport.InboundPipeline; import org.elasticsearch.transport.Transports; @@ -55,8 +56,9 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { Netty4MessageChannelHandler(PageCacheRecycler recycler, Netty4Transport transport) { this.transport = transport; final ThreadPool threadPool = transport.getThreadPool(); + final InboundHandler inboundHandler = transport.getInboundHandler(); this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis, - transport::inboundMessage, transport::inboundDecodeException); + transport.getInflightBreaker(), inboundHandler::getRequestHandler, transport::inboundMessage); } @Override diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java index bf3473a16ce98..d4bd764cb2c50 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport.nio; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -30,10 +31,12 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.Page; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundHandler; import org.elasticsearch.transport.InboundPipeline; import org.elasticsearch.transport.TcpTransport; import java.io.IOException; +import java.util.function.Supplier; public class TcpReadWriteHandler extends BytesWriteHandler { @@ -43,8 +46,10 @@ public class TcpReadWriteHandler extends BytesWriteHandler { public TcpReadWriteHandler(NioTcpChannel channel, PageCacheRecycler recycler, TcpTransport transport) { this.channel = channel; final ThreadPool threadPool = transport.getThreadPool(); + final Supplier breaker = transport.getInflightBreaker(); + final InboundHandler inboundHandler = transport.getInboundHandler(); this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis, - transport::inboundMessage, transport::inboundDecodeException); + breaker, inboundHandler::getRequestHandler, transport::inboundMessage); } @Override diff --git a/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java b/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java index 970747fc23221..4516bfe8b16d0 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java @@ -19,6 +19,8 @@ package org.elasticsearch.transport; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -27,44 +29,69 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; public class InboundAggregator implements Releasable { + private final Supplier circuitBreaker; + private final Predicate requestCanTripBreaker; + private ReleasableBytesReference firstContent; private ArrayList contentAggregation; private Header currentHeader; + private Exception aggregationException; + private boolean canTripBreaker = true; private boolean isClosed = false; + public InboundAggregator(Supplier circuitBreaker, + Function> registryFunction) { + this(circuitBreaker, (Predicate) actionName -> { + final RequestHandlerRegistry reg = registryFunction.apply(actionName); + if (reg == null) { + throw new ActionNotFoundTransportException(actionName); + } else { + return reg.canTripCircuitBreaker(); + } + }); + } + + // Visible for testing + InboundAggregator(Supplier circuitBreaker, Predicate requestCanTripBreaker) { + this.circuitBreaker = circuitBreaker; + this.requestCanTripBreaker = requestCanTripBreaker; + } + public void headerReceived(Header header) { ensureOpen(); assert isAggregating() == false; assert firstContent == null && contentAggregation == null; currentHeader = header; + if (currentHeader.isRequest() && currentHeader.needsToReadVariableHeader() == false) { + initializeRequestState(); + } } public void aggregate(ReleasableBytesReference content) { ensureOpen(); assert isAggregating(); - if (isFirstContent()) { - firstContent = content.retain(); - } else { - if (contentAggregation == null) { - contentAggregation = new ArrayList<>(4); - contentAggregation.add(firstContent); - firstContent = null; + if (isShortCircuited() == false) { + if (isFirstContent()) { + firstContent = content.retain(); + } else { + if (contentAggregation == null) { + contentAggregation = new ArrayList<>(4); + assert firstContent != null; + contentAggregation.add(firstContent); + firstContent = null; + } + contentAggregation.add(content.retain()); } - contentAggregation.add(content.retain()); } } - public Header cancelAggregation() { - ensureOpen(); - assert isAggregating(); - final Header header = this.currentHeader; - closeCurrentAggregation(); - return header; - } - public InboundMessage finishAggregation() throws IOException { ensureOpen(); final ReleasableBytesReference releasableContent; @@ -77,16 +104,30 @@ public InboundMessage finishAggregation() throws IOException { final CompositeBytesReference content = new CompositeBytesReference(references); releasableContent = new ReleasableBytesReference(content, () -> Releasables.close(references)); } - final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent); - resetCurrentAggregation(); + + final BreakerControl breakerControl = new BreakerControl(circuitBreaker); + final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl); boolean success = false; try { if (aggregated.getHeader().needsToReadVariableHeader()) { aggregated.getHeader().finishParsingHeader(aggregated.openOrGetStreamInput()); + if (aggregated.getHeader().isRequest()) { + initializeRequestState(); + } + } + if (isShortCircuited() == false) { + checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl); + } + if (isShortCircuited()) { + aggregated.close(); + success = true; + return new InboundMessage(aggregated.getHeader(), aggregationException); + } else { + success = true; + return aggregated; } - success = true; - return aggregated; } finally { + resetCurrentAggregation(); if (success == false) { aggregated.close(); } @@ -97,6 +138,14 @@ public boolean isAggregating() { return currentHeader != null; } + private void shortCircuit(Exception exception) { + this.aggregationException = exception; + } + + private boolean isShortCircuited() { + return aggregationException != null; + } + private boolean isFirstContent() { return firstContent == null && contentAggregation == null; } @@ -108,18 +157,24 @@ public void close() { } private void closeCurrentAggregation() { + releaseContent(); + resetCurrentAggregation(); + } + + private void releaseContent() { if (contentAggregation == null) { Releasables.close(firstContent); } else { Releasables.close(contentAggregation); } - resetCurrentAggregation(); } private void resetCurrentAggregation() { firstContent = null; contentAggregation = null; currentHeader = null; + aggregationException = null; + canTripBreaker = true; } private void ensureOpen() { @@ -127,4 +182,65 @@ private void ensureOpen() { throw new IllegalStateException("Aggregator is already closed"); } } + + private void initializeRequestState() { + assert currentHeader.needsToReadVariableHeader() == false; + assert currentHeader.isRequest(); + if (currentHeader.isHandshake()) { + canTripBreaker = false; + return; + } + + final String actionName = currentHeader.getActionName(); + try { + canTripBreaker = requestCanTripBreaker.test(actionName); + } catch (ActionNotFoundTransportException e) { + shortCircuit(e); + } + } + + private void checkBreaker(final Header header, final int contentLength, final BreakerControl breakerControl) { + if (header.isRequest() == false) { + return; + } + assert header.needsToReadVariableHeader() == false; + + if (canTripBreaker) { + try { + circuitBreaker.get().addEstimateBytesAndMaybeBreak(contentLength, header.getActionName()); + breakerControl.setReservedBytes(contentLength); + } catch (CircuitBreakingException e) { + shortCircuit(e); + } + } else { + circuitBreaker.get().addWithoutBreaking(contentLength); + breakerControl.setReservedBytes(contentLength); + } + } + + private static class BreakerControl implements Releasable { + + private static final int CLOSED = -1; + + private final Supplier circuitBreaker; + private final AtomicInteger bytesToRelease = new AtomicInteger(0); + + private BreakerControl(Supplier circuitBreaker) { + this.circuitBreaker = circuitBreaker; + } + + private void setReservedBytes(int reservedBytes) { + final boolean set = bytesToRelease.compareAndSet(0, reservedBytes); + assert set : "Expected bytesToRelease to be 0, found " + bytesToRelease.get(); + } + + @Override + public void close() { + final int toRelease = bytesToRelease.getAndSet(CLOSED); + assert toRelease != CLOSED; + if (toRelease > 0) { + circuitBreaker.get().addWithoutBreaking(-toRelease); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java index 4e1d204514c22..a596793508c72 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java @@ -37,7 +37,6 @@ public class InboundDecoder implements Releasable { private final Version version; private final PageCacheRecycler recycler; - private Exception decodingException; private TransportDecompressor decompressor; private int totalNetworkSize = -1; private int bytesConsumed = 0; @@ -86,13 +85,6 @@ public int internalDecode(ReleasableBytesReference reference, Consumer f return headerBytesToRead; } } - } else if (isDecodingFailed()) { - int bytesToConsume = Math.min(reference.length(), totalNetworkSize - bytesConsumed); - bytesConsumed += bytesToConsume; - if (isDone()) { - finishMessage(fragmentConsumer); - } - return bytesToConsume; } else { // There are a minimum number of bytes required to start decompression if (decompressor != null && decompressor.canDecompress(reference.length()) == false) { @@ -130,19 +122,12 @@ public void close() { } private void finishMessage(Consumer fragmentConsumer) { - Object finishMarker; - if (decodingException != null) { - finishMarker = decodingException; - } else { - finishMarker = END_CONTENT; - } cleanDecodeState(); - fragmentConsumer.accept(finishMarker); + fragmentConsumer.accept(END_CONTENT); } private void cleanDecodeState() { IOUtils.closeWhileHandlingException(decompressor); - decodingException = null; decompressor = null; totalNetworkSize = -1; bytesConsumed = 0; @@ -190,7 +175,7 @@ private Header readHeader(int networkMessageSize, BytesReference bytesReference) Header header = new Header(networkMessageSize, requestId, status, remoteVersion); final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake()); if (invalidVersion != null) { - decodingException = invalidVersion; + throw invalidVersion; } else { if (remoteVersion.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) { // Skip since we already have ensured enough data available @@ -206,10 +191,6 @@ private boolean isOnHeader() { return totalNetworkSize == -1; } - private boolean isDecodingFailed() { - return decodingException != null; - } - private void ensureOpen() { if (isClosed) { throw new IllegalStateException("Decoder is already closed"); diff --git a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java index 6fea05ef79570..21b32be9eae4d 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java @@ -23,7 +23,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.Version; -import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; @@ -31,7 +30,6 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; @@ -46,7 +44,6 @@ public class InboundHandler { private final ThreadPool threadPool; private final OutboundHandler outboundHandler; private final NamedWriteableRegistry namedWriteableRegistry; - private final CircuitBreakerService circuitBreakerService; private final TransportHandshaker handshaker; private final TransportKeepAlive keepAlive; @@ -55,11 +52,10 @@ public class InboundHandler { private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; InboundHandler(ThreadPool threadPool, OutboundHandler outboundHandler, NamedWriteableRegistry namedWriteableRegistry, - CircuitBreakerService circuitBreakerService, TransportHandshaker handshaker, TransportKeepAlive keepAlive) { + TransportHandshaker handshaker, TransportKeepAlive keepAlive) { this.threadPool = threadPool; this.outboundHandler = outboundHandler; this.namedWriteableRegistry = namedWriteableRegistry; - this.circuitBreakerService = circuitBreakerService; this.handshaker = handshaker; this.keepAlive = keepAlive; } @@ -72,7 +68,7 @@ synchronized void registerRequestHandler(Requ } @SuppressWarnings("unchecked") - final RequestHandlerRegistry getRequestHandler(String action) { + public final RequestHandlerRegistry getRequestHandler(String action) { return (RequestHandlerRegistry) requestHandlers.get(action); } @@ -95,26 +91,27 @@ void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception if (message.isPing()) { keepAlive.receiveKeepAlive(channel); } else { - messageReceived(message, channel); + messageReceived(channel, message); } } - private void messageReceived(InboundMessage message, TcpChannel channel) throws IOException { + private void messageReceived(TcpChannel channel, InboundMessage message) throws IOException { final InetSocketAddress remoteAddress = channel.getRemoteAddress(); final Header header = message.getHeader(); assert header.needsToReadVariableHeader() == false; - final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(streamInput, header.getVersion()); - ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext existing = threadContext.stashContext()) { // Place the context with the headers from the message threadContext.setHeaders(header.getHeaders()); threadContext.putTransient("_remote_address", remoteAddress); if (header.isRequest()) { - handleRequest(channel, header, streamInput, message.getContentLength()); + handleRequest(channel, header, message); } else { + // Responses do not support short circuiting currently + assert message.isShortCircuit() == false; + final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(streamInput, header.getVersion()); final TransportResponseHandler handler; long requestId = header.getRequestId(); if (header.isHandshake()) { @@ -147,54 +144,59 @@ private void messageReceived(InboundMessage message, TcpChannel channel) throws } } - private void handleRequest(TcpChannel channel, Header header, StreamInput stream, int messageLengthBytes) { + private void handleRequest(TcpChannel channel, Header header, InboundMessage message) throws IOException { final String action = header.getActionName(); final long requestId = header.getRequestId(); final Version version = header.getVersion(); - TransportChannel transportChannel = null; - try { + if (header.isHandshake()) { messageListener.onRequestReceived(requestId, action); - if (header.isHandshake()) { - // Handshakes are not currently circuit broken - transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, - circuitBreakerService, 0, header.isCompressed(), header.isHandshake()); + // Cannot short circuit handshakes + assert message.isShortCircuit() == false; + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final TransportChannel transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, + header.isCompressed(), header.isHandshake(), message.takeBreakerReleaseControl()); + try { handshaker.handleHandshake(transportChannel, requestId, stream); - } else { - final RequestHandlerRegistry reg = getRequestHandler(action); - if (reg == null) { - throw new ActionNotFoundTransportException(action); - } - CircuitBreaker breaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); - if (reg.canTripCircuitBreaker()) { - breaker.addEstimateBytesAndMaybeBreak(messageLengthBytes, ""); - } else { - breaker.addWithoutBreaking(messageLengthBytes); - } - transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, - circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake()); - final T request = reg.newRequest(stream); - request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); - // in case we throw an exception, i.e. when the limit is hit, we don't want to verify - final int nextByte = stream.read(); - // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker - if (nextByte != -1) { - throw new IllegalStateException("Message not fully read (request) for requestId [" + requestId + "], action [" + action - + "], available [" + stream.available() + "]; resetting"); - } - threadPool.executor(reg.getExecutor()).execute(new RequestHandler<>(reg, request, transportChannel)); - } - } catch (Exception e) { - // the circuit breaker tripped - if (transportChannel == null) { - transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, - circuitBreakerService, 0, header.isCompressed(), header.isHandshake()); + } catch (Exception e) { + sendErrorResponse(action, transportChannel, e); } + } else { + final TransportChannel transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, + header.isCompressed(), header.isHandshake(), message.takeBreakerReleaseControl()); try { - transportChannel.sendResponse(e); - } catch (IOException inner) { - inner.addSuppressed(e); - logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", action), inner); + messageListener.onRequestReceived(requestId, action); + if (message.isShortCircuit()) { + sendErrorResponse(action, transportChannel, message.getException()); + } else { + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final RequestHandlerRegistry reg = getRequestHandler(action); + assert reg != null; + final T request = reg.newRequest(stream); + request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); + // in case we throw an exception, i.e. when the limit is hit, we don't want to verify + final int nextByte = stream.read(); + // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker + if (nextByte != -1) { + throw new IllegalStateException("Message not fully read (request) for requestId [" + requestId + "], action [" + + action + "], available [" + stream.available() + "]; resetting"); + } + threadPool.executor(reg.getExecutor()).execute(new RequestHandler<>(reg, request, transportChannel)); + } + } catch (Exception e) { + sendErrorResponse(action, transportChannel, e); } + + } + } + + private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) { + try { + transportChannel.sendResponse(e); + } catch (Exception inner) { + inner.addSuppressed(e); + logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName), inner); } } @@ -277,13 +279,7 @@ public boolean isForceExecution() { @Override public void onFailure(Exception e) { - try { - transportChannel.sendResponse(e); - } catch (Exception inner) { - inner.addSuppressed(e); - logger.warn(() -> new ParameterizedMessage( - "Failed to send error message back to client for action [{}]", reg.getAction()), inner); - } + sendErrorResponse(reg.getAction(), transportChannel, e); } } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java index b8f1dfa14a13b..99dd23e940d27 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java @@ -31,18 +31,32 @@ public class InboundMessage implements Releasable { private final Header header; private final ReleasableBytesReference content; + private final Exception exception; private final boolean isPing; + private Releasable breakerRelease; private StreamInput streamInput; - public InboundMessage(Header header, ReleasableBytesReference content) { + public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { this.header = header; this.content = content; + this.breakerRelease = breakerRelease; + this.exception = null; + this.isPing = false; + } + + public InboundMessage(Header header, Exception exception) { + this.header = header; + this.content = null; + this.breakerRelease = null; + this.exception = exception; this.isPing = false; } public InboundMessage(Header header, boolean isPing) { this.header = header; this.content = null; + this.breakerRelease = null; + this.exception = null; this.isPing = isPing; } @@ -58,10 +72,24 @@ public int getContentLength() { } } + public Exception getException() { + return exception; + } + public boolean isPing() { return isPing; } + public boolean isShortCircuit() { + return exception != null; + } + + public Releasable takeBreakerReleaseControl() { + final Releasable toReturn = breakerRelease; + breakerRelease = null; + return toReturn; + } + public StreamInput openOrGetStreamInput() throws IOException { assert isPing == false && content != null; if (streamInput == null) { @@ -74,6 +102,6 @@ public StreamInput openOrGetStreamInput() throws IOException { @Override public void close() { IOUtils.closeWhileHandlingException(streamInput); - Releasables.closeWhileHandlingException(content); + Releasables.closeWhileHandlingException(content, breakerRelease); } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java index 68740b54742d5..a9e71c55b4f04 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java @@ -20,9 +20,9 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.PageCacheRecycler; @@ -31,7 +31,9 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.function.BiConsumer; +import java.util.function.Function; import java.util.function.LongSupplier; +import java.util.function.Supplier; public class InboundPipeline implements Releasable { @@ -43,26 +45,25 @@ public class InboundPipeline implements Releasable { private final InboundDecoder decoder; private final InboundAggregator aggregator; private final BiConsumer messageHandler; - private final BiConsumer> errorHandler; + private Exception uncaughtException; private ArrayDeque pending = new ArrayDeque<>(2); private boolean isClosed = false; public InboundPipeline(Version version, StatsTracker statsTracker, PageCacheRecycler recycler, LongSupplier relativeTimeInMillis, - BiConsumer messageHandler, - BiConsumer> errorHandler) { - this(statsTracker, relativeTimeInMillis, new InboundDecoder(version, recycler), new InboundAggregator(), messageHandler, - errorHandler); + Supplier circuitBreaker, + Function> registryFunction, + BiConsumer messageHandler) { + this(statsTracker, relativeTimeInMillis, new InboundDecoder(version, recycler), + new InboundAggregator(circuitBreaker, registryFunction), messageHandler); } - private InboundPipeline(StatsTracker statsTracker, LongSupplier relativeTimeInMillis, InboundDecoder decoder, - InboundAggregator aggregator, BiConsumer messageHandler, - BiConsumer> errorHandler) { + public InboundPipeline(StatsTracker statsTracker, LongSupplier relativeTimeInMillis, InboundDecoder decoder, + InboundAggregator aggregator, BiConsumer messageHandler) { this.relativeTimeInMillis = relativeTimeInMillis; this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; this.messageHandler = messageHandler; - this.errorHandler = errorHandler; } @Override @@ -74,6 +75,18 @@ public void close() { } public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { + if (uncaughtException != null) { + throw new IllegalStateException("Pipeline state corrupted by uncaught exception", uncaughtException); + } + try { + doHandleBytes(channel, reference); + } catch (Exception e) { + uncaughtException = e; + throw e; + } + } + + public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); statsTracker.markBytesRead(reference.length()); pending.add(reference.retain()); @@ -128,15 +141,6 @@ private void forwardFragments(TcpChannel channel, ArrayList fragments) t statsTracker.markMessageReceived(); messageHandler.accept(channel, aggregated); } - } else if (fragment instanceof Exception) { - final Header header; - if (aggregator.isAggregating()) { - header = aggregator.cancelAggregation(); - statsTracker.markMessageReceived(); - } else { - header = null; - } - errorHandler.accept(channel, new Tuple<>(header, (Exception) fragment)); } else { assert aggregator.isAggregating(); assert fragment instanceof ReleasableBytesReference; diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index 1c44655a2c6f2..52def6375401b 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -32,7 +32,6 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -84,6 +83,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Supplier; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -113,6 +113,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements protected final PageCacheRecycler pageCacheRecycler; protected final NetworkService networkService; protected final Set profileSettings; + private final CircuitBreakerService circuitBreakerService; private final ConcurrentMap profileBoundAddresses = newConcurrentMap(); private final Map> serverChannels = newConcurrentMap(); @@ -136,6 +137,7 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P this.version = version; this.threadPool = threadPool; this.pageCacheRecycler = pageCacheRecycler; + this.circuitBreakerService = circuitBreakerService; this.networkService = networkService; String nodeName = Node.NODE_NAME_SETTING.get(settings); BigArrays bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS); @@ -146,8 +148,7 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), TransportRequestOptions.EMPTY, v, false, true)); this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); - this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker, - keepAlive); + this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive); } public Version getVersion() { @@ -162,6 +163,14 @@ public ThreadPool getThreadPool() { return threadPool; } + public Supplier getInflightBreaker() { + return () -> circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); + } + + public InboundHandler getInboundHandler() { + return inboundHandler; + } + @Override protected void doStart() { } @@ -677,10 +686,6 @@ public void inboundMessage(TcpChannel channel, InboundMessage message) { } } - public void inboundDecodeException(TcpChannel channel, Tuple tuple) { - onException(channel, tuple.v2()); - } - /** * Validates the first 6 bytes of the message header and returns the length of the message. If 6 bytes * are not available, it returns -1. diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java index faf8ca097b12f..14282b23791af 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java @@ -20,8 +20,7 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; -import org.elasticsearch.common.breaker.CircuitBreaker; -import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.common.lease.Releasable; import java.io.IOException; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,22 +33,20 @@ public final class TcpTransportChannel implements TransportChannel { private final String action; private final long requestId; private final Version version; - private final CircuitBreakerService breakerService; - private final long reservedBytes; private final boolean compressResponse; private final boolean isHandshake; + private final Releasable breakerRelease; TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version, - CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse, boolean isHandshake) { + boolean compressResponse, boolean isHandshake, Releasable breakerRelease) { this.version = version; this.channel = channel; this.outboundHandler = outboundHandler; this.action = action; this.requestId = requestId; - this.breakerService = breakerService; - this.reservedBytes = reservedBytes; this.compressResponse = compressResponse; this.isHandshake = isHandshake; + this.breakerRelease = breakerRelease; } @Override @@ -80,7 +77,7 @@ public void sendResponse(Exception exception) throws IOException { private void release(boolean isExceptionResponse) { if (released.compareAndSet(false, true)) { assert (releaseBy = new Exception()) != null; // easier to debug if it's already closed - breakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS).addWithoutBreaking(-reservedBytes); + breakerRelease.close(); } else if (isExceptionResponse == false) { // only fail if we are not sending an error - we might send the error triggered by the previous // sendResponse call diff --git a/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java b/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java index 4de58dfaddf9e..0a3df4e3a139c 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java @@ -20,6 +20,8 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.TestCircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.collect.Tuple; @@ -32,20 +34,33 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.function.Predicate; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.notNullValue; public class InboundAggregatorTests extends ESTestCase { private final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + private final String unBreakableAction = "non_breakable_action"; + private final String unknownAction = "unknown_action"; private InboundAggregator aggregator; + private TestCircuitBreaker circuitBreaker; @Before @Override public void setUp() throws Exception { super.setUp(); - aggregator = new InboundAggregator(); + Predicate requestCanTripBreaker = action -> { + if (unknownAction.equals(action)) { + throw new ActionNotFoundTransportException(action); + } else { + return unBreakableAction.equals(action) == false; + } + }; + circuitBreaker = new TestCircuitBreaker(); + aggregator = new InboundAggregator(() -> circuitBreaker, requestCanTripBreaker); } public void testInboundAggregation() throws IOException { @@ -95,7 +110,89 @@ public void testInboundAggregation() throws IOException { } } - public void testCancelAndCloseWillCloseContent() { + public void testInboundUnknownAction() throws IOException { + long requestId = randomNonNegativeLong(); + Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + header.actionName = unknownAction; + // Initiate Message + aggregator.headerReceived(header); + + BytesArray bytes = new BytesArray(randomByteArrayOfLength(10)); + final ReleasableBytesReference content = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content); + content.close(); + assertEquals(0, content.refCount()); + + // Signal EOS + InboundMessage aggregated = aggregator.finishAggregation(); + + assertThat(aggregated, notNullValue()); + assertTrue(aggregated.isShortCircuit()); + assertThat(aggregated.getException(), instanceOf(ActionNotFoundTransportException.class)); + } + + public void testCircuitBreak() throws IOException { + circuitBreaker.startBreaking(); + // Actions are breakable + Header breakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + breakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + breakableHeader.actionName = "action_name"; + // Initiate Message + aggregator.headerReceived(breakableHeader); + + BytesArray bytes = new BytesArray(randomByteArrayOfLength(10)); + final ReleasableBytesReference content1 = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content1); + content1.close(); + + // Signal EOS + InboundMessage aggregated1 = aggregator.finishAggregation(); + + assertEquals(0, content1.refCount()); + assertThat(aggregated1, notNullValue()); + assertTrue(aggregated1.isShortCircuit()); + assertThat(aggregated1.getException(), instanceOf(CircuitBreakingException.class)); + + // Actions marked as unbreakable are not broken + Header unbreakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + unbreakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + unbreakableHeader.actionName = unBreakableAction; + // Initiate Message + aggregator.headerReceived(unbreakableHeader); + + final ReleasableBytesReference content2 = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content2); + content2.close(); + + // Signal EOS + InboundMessage aggregated2 = aggregator.finishAggregation(); + + assertEquals(1, content2.refCount()); + assertThat(aggregated2, notNullValue()); + assertFalse(aggregated2.isShortCircuit()); + + // Handshakes are not broken + final byte handshakeStatus = TransportStatus.setHandshake(TransportStatus.setRequest((byte) 0)); + Header handshakeHeader = new Header(randomInt(), randomNonNegativeLong(), handshakeStatus, Version.CURRENT); + handshakeHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + handshakeHeader.actionName = "handshake"; + // Initiate Message + aggregator.headerReceived(handshakeHeader); + + final ReleasableBytesReference content3 = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content3); + content3.close(); + + // Signal EOS + InboundMessage aggregated3 = aggregator.finishAggregation(); + + assertEquals(1, content3.refCount()); + assertThat(aggregated3, notNullValue()); + assertFalse(aggregated3.isShortCircuit()); + } + + public void testCloseWillCloseContent() { long requestId = randomNonNegativeLong(); Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); @@ -121,11 +218,7 @@ public void testCancelAndCloseWillCloseContent() { content2.close(); } - if (randomBoolean()) { - aggregator.cancelAggregation(); - } else { - aggregator.close(); - } + aggregator.close(); for (ReleasableBytesReference reference : references) { assertEquals(0, reference.refCount()); @@ -134,24 +227,40 @@ public void testCancelAndCloseWillCloseContent() { public void testFinishAggregationWillFinishHeader() throws IOException { long requestId = randomNonNegativeLong(); + final String actionName; + final boolean unknownAction = randomBoolean(); + if (unknownAction) { + actionName = this.unknownAction; + } else { + actionName = "action_name"; + } Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); // Initiate Message aggregator.headerReceived(header); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { threadContext.writeTo(streamOutput); - streamOutput.writeString("action_name"); + streamOutput.writeString(actionName); streamOutput.write(randomByteArrayOfLength(10)); - aggregator.aggregate(ReleasableBytesReference.wrap(streamOutput.bytes())); + final ReleasableBytesReference content = ReleasableBytesReference.wrap(streamOutput.bytes()); + aggregator.aggregate(content); + content.close(); // Signal EOS InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(header.needsToReadVariableHeader()); - assertEquals("action_name", header.getActionName()); + assertEquals(actionName, header.getActionName()); + if (unknownAction) { + assertEquals(0, content.refCount()); + assertTrue(aggregated.isShortCircuit()); + } else { + assertEquals(1, content.refCount()); + assertFalse(aggregated.isShortCircuit()); + } } - } + } diff --git a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java index 98958920cedeb..53745ba815f01 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java @@ -33,7 +33,6 @@ import java.util.ArrayList; import static org.hamcrest.Matchers.hasItems; -import static org.hamcrest.Matchers.instanceOf; public class InboundDecoderTests extends ESTestCase { @@ -108,20 +107,14 @@ public void testDecode() throws IOException { public void testDecodePreHeaderSizeVariableInt() throws IOException { // TODO: Can delete test on 9.0 - boolean isRequest = randomBoolean(); boolean isCompressed = randomBoolean(); String action = "test-request"; long requestId = randomNonNegativeLong(); final Version preHeaderVariableInt = Version.V_7_5_0; - OutboundMessage message; final String contentValue = randomAlphaOfLength(100); - if (isRequest) { - message = new OutboundMessage.Request(threadContext, new TestRequest(contentValue), - preHeaderVariableInt, action, requestId, false, isCompressed); - } else { - message = new OutboundMessage.Response(threadContext, new TestResponse(contentValue), - preHeaderVariableInt, requestId, false, isCompressed); - } + // 8.0 is only compatible with handshakes on a pre-variable int version + final OutboundMessage message = new OutboundMessage.Request(threadContext, new TestRequest(contentValue), + preHeaderVariableInt, action, requestId, true, isCompressed); final BytesReference totalBytes = message.serialize(new BytesStreamOutput()); int partialHeaderSize = TcpHeader.headerSize(preHeaderVariableInt); @@ -137,29 +130,17 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException { assertEquals(requestId, header.getRequestId()); assertEquals(preHeaderVariableInt, header.getVersion()); assertEquals(isCompressed, header.isCompressed()); - assertFalse(header.isHandshake()); - if (isRequest) { - assertTrue(header.isRequest()); - } else { - assertTrue(header.isResponse()); - } + assertTrue(header.isHandshake()); + assertTrue(header.isRequest()); assertTrue(header.needsToReadVariableHeader()); fragments.clear(); - final BytesReference bytes2 = totalBytes.slice(bytesConsumed, 2); + final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed); final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2); int bytesConsumed2 = decoder.decode(releasable2, fragments::add); - assertEquals(0, fragments.size()); - assertEquals(2, bytesConsumed2); - - final BytesReference bytes3 = totalBytes.slice(bytesConsumed + 2, totalBytes.length() - bytesConsumed - bytesConsumed2); - final ReleasableBytesReference releasable3 = ReleasableBytesReference.wrap(bytes3); - int bytesConsumed3 = decoder.decode(releasable3, fragments::add); - assertEquals(totalBytes.length() - bytesConsumed - bytesConsumed2, bytesConsumed3); - - final Object exception = fragments.get(0); - - assertThat(exception, instanceOf(IllegalStateException.class)); + assertEquals(2, fragments.size()); + assertEquals(InboundDecoder.END_CONTENT, fragments.get(fragments.size() - 1)); + assertEquals(totalBytes.length() - bytesConsumed, bytesConsumed2); } public void testDecodeHandshakeCompatibility() throws IOException { @@ -296,25 +277,13 @@ public void testVersionIncompatibilityDecodeException() throws IOException { incompatibleVersion, action, requestId, false, true); final BytesReference bytes = message.serialize(new BytesStreamOutput()); - int totalHeaderSize = TcpHeader.headerSize(incompatibleVersion); InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes); - int bytesConsumed = decoder.decode(releasable1, fragments::add); - assertEquals(totalHeaderSize, bytesConsumed); + expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add)); + // No bytes are retained assertEquals(1, releasable1.refCount()); - - final Header header = (Header) fragments.get(0); - assertEquals(requestId, header.getRequestId()); - assertEquals(incompatibleVersion, header.getVersion()); - fragments.clear(); - - final int remaining = bytes.length() - bytesConsumed; - final BytesReference bytes2 = bytes.slice(bytesConsumed, remaining); - final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2); - bytesConsumed = decoder.decode(releasable2, fragments::add); - assertEquals(remaining, bytesConsumed); } public void testEnsureVersionCompatibility() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java index c0291a5a88f91..a9396fbe43601 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java @@ -28,7 +28,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; @@ -62,8 +61,7 @@ public void setUp() throws Exception { TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage); OutboundHandler outboundHandler = new OutboundHandler("node", version, new StatsTracker(), threadPool, BigArrays.NON_RECYCLING_INSTANCE); - final NoneCircuitBreakerService breaker = new NoneCircuitBreakerService(); - handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, breaker, handshaker, keepAlive); + handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive); } @After @@ -129,7 +127,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent)); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -150,7 +148,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(fullRequestBytes.length() - 6, requestId, responseStatus, version); - InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent)); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); diff --git a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java index ab8c857c88709..5dbad59c95a3f 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java @@ -20,6 +20,11 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.breaker.TestCircuitBreaker; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.collect.Tuple; @@ -39,6 +44,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.LongSupplier; +import java.util.function.Predicate; +import java.util.function.Supplier; import static org.hamcrest.Matchers.instanceOf; @@ -59,34 +66,31 @@ public void testPipelineHandling() throws IOException { final boolean isRequest = header.isRequest(); final long requestId = header.getRequestId(); final boolean isCompressed = header.isCompressed(); - if (isRequest) { + if (m.isShortCircuit()) { + actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), null); + } else if (isRequest) { final TestRequest request = new TestRequest(m.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), request.value); } else { final TestResponse response = new TestResponse(m.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, null, response.value); } - actual.add(new Tuple<>(actualData, null)); + actual.add(new Tuple<>(actualData, m.getException())); } catch (IOException e) { throw new AssertionError(e); } }; - final BiConsumer> errorHandler = (c, tuple) -> { - final Header header = tuple.v1(); - final MessageData actualData; - final Version version = header.getVersion(); - final boolean isRequest = header.isRequest(); - final long requestId = header.getRequestId(); - final boolean isCompressed = header.isCompressed(); - actualData = new MessageData(version, requestId, isRequest, isCompressed, null, null); - actual.add(new Tuple<>(actualData, tuple.v2())); - }; - final PageCacheRecycler recycler = PageCacheRecycler.NON_RECYCLING_INSTANCE; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); - final InboundPipeline pipeline = new InboundPipeline(Version.CURRENT, statsTracker, recycler, millisSupplier, messageHandler, - errorHandler); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final String breakThisAction = "break_this_action"; + final String actionName = "actionName"; + final Predicate canTripBreaker = breakThisAction::equals; + final TestCircuitBreaker circuitBreaker = new TestCircuitBreaker(); + circuitBreaker.startBreaking(); + final InboundAggregator aggregator = new InboundAggregator(() -> circuitBreaker, canTripBreaker); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); final FakeTcpChannel channel = new FakeTcpChannel(); final int iterations = randomIntBetween(100, 500); @@ -99,15 +103,7 @@ public void testPipelineHandling() throws IOException { toRelease.clear(); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { while (streamOutput.size() < BYTE_THRESHOLD) { - final boolean invalidVersion = rarely(); - - String actionName = "actionName"; - final Version version; - if (invalidVersion) { - version = Version.CURRENT.minimumCompatibilityVersion().minimumCompatibilityVersion(); - } else { - version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); - } + final Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); final String value = randomAlphaOfLength(randomIntBetween(10, 200)); final boolean isRequest = randomBoolean(); final boolean isCompressed = randomBoolean(); @@ -118,21 +114,18 @@ public void testPipelineHandling() throws IOException { OutboundMessage message; if (isRequest) { - if (invalidVersion) { - expectedExceptionClass = new IllegalStateException(); - messageData = new MessageData(version, requestId, true, isCompressed, null, null); + if (rarely()) { + messageData = new MessageData(version, requestId, true, isCompressed, breakThisAction, null); + message = new OutboundMessage.Request(threadContext, new TestRequest(value), + version, breakThisAction, requestId, false, isCompressed); + expectedExceptionClass = new CircuitBreakingException("", CircuitBreaker.Durability.PERMANENT); } else { messageData = new MessageData(version, requestId, true, isCompressed, actionName, value); + message = new OutboundMessage.Request(threadContext, new TestRequest(value), + version, actionName, requestId, false, isCompressed); } - message = new OutboundMessage.Request(threadContext, new TestRequest(value), - version, actionName, requestId, false, isCompressed); } else { - if (invalidVersion) { - expectedExceptionClass = new IllegalStateException(); - messageData = new MessageData(version, requestId, false, isCompressed, null, null); - } else { - messageData = new MessageData(version, requestId, false, isCompressed, null, value); - } + messageData = new MessageData(version, requestId, false, isCompressed, null, value); message = new OutboundMessage.Response(threadContext, new TestResponse(value), version, requestId, false, isCompressed); } @@ -165,8 +158,8 @@ public void testPipelineHandling() throws IOException { assertEquals(expectedMessageData.requestId, actualMessageData.requestId); assertEquals(expectedMessageData.isRequest, actualMessageData.isRequest); assertEquals(expectedMessageData.isCompressed, actualMessageData.isCompressed); - assertEquals(expectedMessageData.value, actualMessageData.value); assertEquals(expectedMessageData.actionName, actualMessageData.actionName); + assertEquals(expectedMessageData.value, actualMessageData.value); if (expectedTuple.v2() != null) { assertNotNull(actualTuple.v2()); assertThat(actualTuple.v2(), instanceOf(expectedTuple.v2().getClass())); @@ -183,14 +176,51 @@ public void testPipelineHandling() throws IOException { } } + public void testDecodeExceptionIsPropagated() throws IOException { + BiConsumer messageHandler = (c, m) -> {}; + final StatsTracker statsTracker = new StatsTracker(); + final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + String actionName = "actionName"; + final Version invalidVersion = Version.CURRENT.minimumCompatibilityVersion().minimumCompatibilityVersion(); + final String value = randomAlphaOfLength(1000); + final boolean isRequest = randomBoolean(); + final long requestId = randomNonNegativeLong(); + + OutboundMessage message; + if (isRequest) { + message = new OutboundMessage.Request(threadContext, new TestRequest(value), + invalidVersion, actionName, requestId, false, false); + } else { + message = new OutboundMessage.Response(threadContext, new TestResponse(value), + invalidVersion, requestId, false, false); + } + + final BytesReference reference = message.serialize(streamOutput); + try (ReleasableBytesReference releasable = ReleasableBytesReference.wrap(reference)) { + expectThrows(IllegalStateException.class, () -> pipeline.handleBytes(new FakeTcpChannel(), releasable)); + } + + // Pipeline cannot be reused after uncaught exception + final IllegalStateException ise = expectThrows(IllegalStateException.class, + () -> pipeline.handleBytes(new FakeTcpChannel(), ReleasableBytesReference.wrap(BytesArray.EMPTY))); + assertEquals("Pipeline state corrupted by uncaught exception", ise.getMessage()); + } + } + public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { - final PageCacheRecycler recycler = PageCacheRecycler.NON_RECYCLING_INSTANCE; BiConsumer messageHandler = (c, m) -> {}; - BiConsumer> errorHandler = (c, e) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); - final InboundPipeline pipeline = new InboundPipeline(Version.CURRENT, statsTracker, recycler, millisSupplier, messageHandler, - errorHandler); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java index e3bc5f8c5b5c8..282fdaae48819 100644 --- a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -23,6 +23,8 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -47,6 +49,8 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.LongSupplier; +import java.util.function.Predicate; +import java.util.function.Supplier; import static org.hamcrest.Matchers.instanceOf; @@ -56,7 +60,6 @@ public class OutboundHandlerTests extends ESTestCase { private final TransportRequestOptions options = TransportRequestOptions.EMPTY; private final AtomicReference> message = new AtomicReference<>(); private InboundPipeline pipeline; - private StatsTracker statsTracker; private OutboundHandler handler; private FakeTcpChannel channel; private DiscoveryNode node; @@ -67,11 +70,14 @@ public void setUp() throws Exception { channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()); TransportAddress transportAddress = buildNewFakeTransportAddress(); node = new DiscoveryNode("", transportAddress, Version.CURRENT); - statsTracker = new StatsTracker(); + StatsTracker statsTracker = new StatsTracker(); handler = new OutboundHandler("node", Version.CURRENT, statsTracker, threadPool, BigArrays.NON_RECYCLING_INSTANCE); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); - pipeline = new InboundPipeline(Version.CURRENT, new StatsTracker(), PageCacheRecycler.NON_RECYCLING_INSTANCE, millisSupplier, + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { Streams.copy(m.openOrGetStreamInput(), streamOutput); @@ -79,9 +85,7 @@ public void setUp() throws Exception { } catch (IOException e) { throw new AssertionError(e); } - }, (c, t) -> { - throw new AssertionError(t.v2()); - }); + }); } @After diff --git a/test/framework/src/main/java/org/elasticsearch/common/breaker/TestCircuitBreaker.java b/test/framework/src/main/java/org/elasticsearch/common/breaker/TestCircuitBreaker.java new file mode 100644 index 0000000000000..e2deffc52e794 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/common/breaker/TestCircuitBreaker.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.common.breaker; + +import java.util.concurrent.atomic.AtomicBoolean; + +public class TestCircuitBreaker extends NoopCircuitBreaker { + + private final AtomicBoolean shouldBreak = new AtomicBoolean(false); + + public TestCircuitBreaker() { + super("test"); + } + + @Override + public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + if (shouldBreak.get()) { + throw new CircuitBreakingException("broken", getDurability()); + } + return 0; + } + + public void startBreaking() { + shouldBreak.set(true); + } + + public void stopBreaking() { + shouldBreak.set(false); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 610e313fb0610..778ce99071d1e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -1638,9 +1638,10 @@ public String executor() { latch.await(); assertFalse(requestProcessed.get()); + } - service.acceptIncomingRequests(); - + service.acceptIncomingRequests(); + try (Transport.Connection connection = openConnection(serviceA, node, null)) { CountDownLatch latch2 = new CountDownLatch(1); serviceA.sendRequest(connection, "internal:action", new TestRequest(), TransportRequestOptions.EMPTY, new TransportResponseHandler() { @@ -2023,25 +2024,6 @@ public void testKeepAlivePings() throws Exception { public void testTcpHandshake() { assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport); - try (MockTransportService service = buildService("TS_BAD", Version.CURRENT, Settings.EMPTY)) { - service.addMessageListener(new TransportMessageListener() { - @Override - public void onRequestReceived(long requestId, String action) { - if (TransportHandshaker.HANDSHAKE_ACTION_NAME.equals(action)) { - throw new ActionNotFoundTransportException(action); - } - } - }); - service.start(); - service.acceptIncomingRequests(); - // this acts like a node that doesn't have support for handshakes - DiscoveryNode node = - new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); - ConnectTransportException exception = expectThrows(ConnectTransportException.class, () -> connectToNode(serviceA, node)); - assertThat(exception.getCause(), instanceOf(IllegalStateException.class)); - assertEquals("handshake failed", exception.getCause().getMessage()); - } - ConnectionProfile connectionProfile = ConnectionProfile.buildDefaultConnectionProfile(Settings.EMPTY); try (TransportService service = buildService("TS_TPC", Version.CURRENT, Settings.EMPTY)) { DiscoveryNode node = new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index e68b265a67211..19e561440b7d7 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -27,6 +27,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -53,7 +54,9 @@ import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectionProfile; +import org.elasticsearch.transport.InboundHandler; import org.elasticsearch.transport.InboundPipeline; +import org.elasticsearch.transport.StatsTracker; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpServerChannel; import org.elasticsearch.transport.TcpTransport; @@ -274,8 +277,12 @@ private static class MockTcpReadWriteHandler extends BytesWriteHandler { private MockTcpReadWriteHandler(MockSocketChannel channel, PageCacheRecycler recycler, TcpTransport transport) { this.channel = channel; final ThreadPool threadPool = transport.getThreadPool(); - this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, - threadPool::relativeTimeInMillis, transport::inboundMessage, transport::inboundDecodeException); + final Supplier breaker = transport.getInflightBreaker(); + final InboundHandler inboundHandler = transport.getInboundHandler(); + final Version version = transport.getVersion(); + final StatsTracker statsTracker = transport.getStatsTracker(); + this.pipeline = new InboundPipeline(version, statsTracker, recycler, threadPool::relativeTimeInMillis, breaker, + inboundHandler::getRequestHandler, transport::inboundMessage); } @Override