diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java index 87c0ff2817eb7..2d57faf5cb897 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java @@ -34,14 +34,17 @@ public WriteOperation createWriteOperation(SocketChannelContext context, Object return new FlushReadyWrite(context, (ByteBuffer[]) message, listener); } + @Override public List writeToBytes(WriteOperation writeOperation) { assert writeOperation instanceof FlushReadyWrite : "Write operation must be flush ready"; return Collections.singletonList((FlushReadyWrite) writeOperation); } + @Override public List pollFlushOperations() { return EMPTY_LIST; } + @Override public void close() {} } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java index 28d32f50bfc69..bc24789341e04 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java @@ -36,12 +36,12 @@ import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.Scope; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundMessage; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportSettings; import java.io.IOException; -import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -111,13 +111,9 @@ public ExceptionThrowingNetty4Transport( } @Override - protected String handleRequest(TcpChannel channel, String profileName, - StreamInput stream, long requestId, int messageLengthBytes, Version version, - InetSocketAddress remoteAddress, byte status) throws IOException { - String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version, - remoteAddress, status); + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; - return action; } @Override diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java index 334e63dc0bf95..087c3758bb98b 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java @@ -38,12 +38,12 @@ import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.Scope; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundMessage; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportSettings; import java.io.IOException; -import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -113,13 +113,9 @@ public Map> getTransports(Settings settings, ThreadP } @Override - protected String handleRequest(TcpChannel channel, String profileName, - StreamInput stream, long requestId, int messageLengthBytes, Version version, - InetSocketAddress remoteAddress, byte status) throws IOException { - String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version, - remoteAddress, status); + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; - return action; } @Override diff --git a/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java b/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java index 54f4d1d0c8d84..4116f88b14224 100644 --- a/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java +++ b/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java @@ -39,8 +39,8 @@ * written to this stream. If compression is enabled, the proper EOS bytes will be written at that point. * The underlying {@link BytesReference} will be returned. * - * {@link CompressibleBytesOutputStream#close()} should be called when the bytes are no longer needed and - * can be safely released. + * {@link CompressibleBytesOutputStream#close()} will NOT close the underlying stream. The byte stream passed + * in the constructor must be closed individually. */ final class CompressibleBytesOutputStream extends StreamOutput { @@ -92,12 +92,9 @@ public void flush() throws IOException { @Override public void close() throws IOException { - if (stream == bytesStreamOutput) { - assert shouldCompress == false : "If the streams are the same we should not be compressing"; - IOUtils.close(stream); - } else { + if (stream != bytesStreamOutput) { assert shouldCompress : "If the streams are different we should be compressing"; - IOUtils.close(stream, bytesStreamOutput); + IOUtils.close(stream); } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java new file mode 100644 index 0000000000000..44e3b017ed27e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java @@ -0,0 +1,168 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.compress.Compressor; +import org.elasticsearch.common.compress.CompressorFactory; +import org.elasticsearch.common.compress.NotCompressedException; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.internal.io.IOUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Set; +import java.util.TreeSet; + +public abstract class InboundMessage extends NetworkMessage implements Closeable { + + private final StreamInput streamInput; + + InboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { + super(threadContext, version, status, requestId); + this.streamInput = streamInput; + } + + StreamInput getStreamInput() { + return streamInput; + } + + static class Reader { + + private final Version version; + private final NamedWriteableRegistry namedWriteableRegistry; + private final ThreadContext threadContext; + + Reader(Version version, NamedWriteableRegistry namedWriteableRegistry, ThreadContext threadContext) { + this.version = version; + this.namedWriteableRegistry = namedWriteableRegistry; + this.threadContext = threadContext; + } + + InboundMessage deserialize(BytesReference reference) throws IOException { + int messageLengthBytes = reference.length(); + final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; + // we have additional bytes to read, outside of the header + boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0; + StreamInput streamInput = reference.streamInput(); + boolean success = false; + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + long requestId = streamInput.readLong(); + byte status = streamInput.readByte(); + Version remoteVersion = Version.fromId(streamInput.readInt()); + final boolean isHandshake = TransportStatus.isHandshake(status); + ensureVersionCompatibility(remoteVersion, version, isHandshake); + if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamInput.available() > 0) { + Compressor compressor; + try { + final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE; + compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed)); + } catch (NotCompressedException ex) { + int maxToRead = Math.min(reference.length(), 10); + StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [") + .append(maxToRead).append("] content bytes out of [").append(reference.length()) + .append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are ["); + for (int i = 0; i < maxToRead; i++) { + sb.append(reference.get(i)).append(","); + } + sb.append("]"); + throw new IllegalStateException(sb.toString()); + } + streamInput = compressor.streamInput(streamInput); + } + streamInput = new NamedWriteableAwareStreamInput(streamInput, namedWriteableRegistry); + streamInput.setVersion(remoteVersion); + + threadContext.readHeaders(streamInput); + + InboundMessage message; + if (TransportStatus.isRequest(status)) { + final Set features; + if (remoteVersion.onOrAfter(Version.V_6_3_0)) { + features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(streamInput.readStringArray()))); + } else { + features = Collections.emptySet(); + } + final String action = streamInput.readString(); + message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput); + } else { + message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput); + } + success = true; + return message; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(streamInput); + } + } + } + } + + @Override + public void close() throws IOException { + streamInput.close(); + } + + private static void ensureVersionCompatibility(Version version, Version currentVersion, boolean isHandshake) { + // for handshakes we are compatible with N-2 since otherwise we can't figure out our initial version + // since we are compatible with N-1 and N+1 so we always send our minCompatVersion as the initial version in the + // handshake. This looks odd but it's required to establish the connection correctly we check for real compatibility + // once the connection is established + final Version compatibilityVersion = isHandshake ? currentVersion.minimumCompatibilityVersion() : currentVersion; + if (version.isCompatible(compatibilityVersion) == false) { + final Version minCompatibilityVersion = isHandshake ? compatibilityVersion : compatibilityVersion.minimumCompatibilityVersion(); + String msg = "Received " + (isHandshake ? "handshake " : "") + "message from unsupported version: ["; + throw new IllegalStateException(msg + version + "] minimal compatible version is: [" + minCompatibilityVersion + "]"); + } + } + + public static class RequestMessage extends InboundMessage { + + private final String actionName; + private final Set features; + + RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set features, + StreamInput streamInput) { + super(threadContext, version, status, requestId, streamInput); + this.actionName = actionName; + this.features = features; + } + + String getActionName() { + return actionName; + } + + Set getFeatures() { + return features; + } + } + + public static class ResponseMessage extends InboundMessage { + + ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { + super(threadContext, version, status, requestId, streamInput); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java b/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java new file mode 100644 index 0000000000000..7d8dbb8a0f1b6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.util.concurrent.ThreadContext; + +/** + * Represents a transport message sent over the network. Subclasses implement serialization and + * deserialization. + */ +public abstract class NetworkMessage { + + protected final Version version; + protected final ThreadContext threadContext; + protected final ThreadContext.StoredContext storedContext; + protected final long requestId; + protected final byte status; + + NetworkMessage(ThreadContext threadContext, Version version, byte status, long requestId) { + this.threadContext = threadContext; + storedContext = threadContext.stashContext(); + storedContext.restore(); + this.version = version; + this.requestId = requestId; + this.status = status; + } + + public Version getVersion() { + return version; + } + + public long getRequestId() { + return requestId; + } + + boolean isCompress() { + return TransportStatus.isCompress(status); + } + + ThreadContext.StoredContext getStoredContext() { + return storedContext; + } + + boolean isResponse() { + return TransportStatus.isRequest(status) == false; + } + + boolean isRequest() { + return TransportStatus.isRequest(status); + } + + boolean isHandshake() { + return TransportStatus.isHandshake(status); + } + + boolean isError() { + return TransportStatus.isError(status); + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java new file mode 100644 index 0000000000000..31c67d930c538 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -0,0 +1,167 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.NotifyOnceListener; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.metrics.MeanMetric; +import org.elasticsearch.common.network.CloseableChannel; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; + +final class OutboundHandler { + + private static final Logger logger = LogManager.getLogger(OutboundHandler.class); + + private final MeanMetric transmittedBytesMetric = new MeanMetric(); + private final ThreadPool threadPool; + private final BigArrays bigArrays; + private final TransportLogger transportLogger; + + OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) { + this.threadPool = threadPool; + this.bigArrays = bigArrays; + this.transportLogger = transportLogger; + } + + void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener listener) { + channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); + SendContext sendContext = new SendContext(channel, () -> bytes, listener); + try { + internalSendMessage(channel, sendContext); + } catch (IOException e) { + // This should not happen as the bytes are already serialized + throw new AssertionError(e); + } + } + + void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { + channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); + MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); + SendContext sendContext = new SendContext(channel, serializer, listener, serializer); + internalSendMessage(channel, sendContext); + } + + /** + * sends a message to the given channel, using the given callbacks. + */ + private void internalSendMessage(TcpChannel channel, SendContext sendContext) throws IOException { + channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); + BytesReference reference = sendContext.get(); + try { + channel.sendMessage(reference, sendContext); + } catch (RuntimeException ex) { + sendContext.onFailure(ex); + CloseableChannel.closeChannel(channel); + throw ex; + } + + } + + MeanMetric getTransmittedBytes() { + return transmittedBytesMetric; + } + + private static class MessageSerializer implements CheckedSupplier, Releasable { + + private final OutboundMessage message; + private final BigArrays bigArrays; + private volatile ReleasableBytesStreamOutput bytesStreamOutput; + + private MessageSerializer(OutboundMessage message, BigArrays bigArrays) { + this.message = message; + this.bigArrays = bigArrays; + } + + @Override + public BytesReference get() throws IOException { + bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays); + return message.serialize(bytesStreamOutput); + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(bytesStreamOutput); + } + } + + private class SendContext extends NotifyOnceListener implements CheckedSupplier { + + private final TcpChannel channel; + private final CheckedSupplier messageSupplier; + private final ActionListener listener; + private final Releasable optionalReleasable; + private long messageSize = -1; + + private SendContext(TcpChannel channel, CheckedSupplier messageSupplier, + ActionListener listener) { + this(channel, messageSupplier, listener, null); + } + + private SendContext(TcpChannel channel, CheckedSupplier messageSupplier, + ActionListener listener, Releasable optionalReleasable) { + this.channel = channel; + this.messageSupplier = messageSupplier; + this.listener = listener; + this.optionalReleasable = optionalReleasable; + } + + public BytesReference get() throws IOException { + BytesReference message; + try { + message = messageSupplier.get(); + messageSize = message.length(); + transportLogger.logOutboundMessage(channel, message); + return message; + } catch (Exception e) { + onFailure(e); + throw e; + } + } + + @Override + protected void innerOnResponse(Void v) { + assert messageSize != -1 : "If onResponse is being called, the message should have been serialized"; + transmittedBytesMetric.inc(messageSize); + closeAndCallback(() -> listener.onResponse(v)); + } + + @Override + protected void innerOnFailure(Exception e) { + logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e); + closeAndCallback(() -> listener.onFailure(e)); + } + + private void closeAndCallback(Runnable runnable) { + Releasables.close(optionalReleasable, runnable::run); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java b/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java new file mode 100644 index 0000000000000..cc295a68df3c7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.ThreadContext; + +import java.io.IOException; +import java.util.Set; + +abstract class OutboundMessage extends NetworkMessage implements Writeable { + + private final Writeable message; + + OutboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, Writeable message) { + super(threadContext, version, status, requestId); + this.message = message; + } + + BytesReference serialize(BytesStreamOutput bytesStream) throws IOException { + storedContext.restore(); + bytesStream.setVersion(version); + bytesStream.skip(TcpHeader.HEADER_SIZE); + + // The compressible bytes stream will not close the underlying bytes stream + BytesReference reference; + try (CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) { + stream.setVersion(version); + threadContext.writeTo(stream); + writeTo(stream); + reference = writeMessage(stream); + } + bytesStream.seek(0); + TcpHeader.writeHeader(bytesStream, requestId, status, version, reference.length() - TcpHeader.HEADER_SIZE); + return reference; + } + + private BytesReference writeMessage(CompressibleBytesOutputStream stream) throws IOException { + final BytesReference zeroCopyBuffer; + if (message instanceof BytesTransportRequest) { + BytesTransportRequest bRequest = (BytesTransportRequest) message; + bRequest.writeThin(stream); + zeroCopyBuffer = bRequest.bytes; + } else if (message instanceof RemoteTransportException) { + stream.writeException((RemoteTransportException) message); + zeroCopyBuffer = BytesArray.EMPTY; + } else { + message.writeTo(stream); + zeroCopyBuffer = BytesArray.EMPTY; + } + // we have to call materializeBytes() here before accessing the bytes. A CompressibleBytesOutputStream + // might be implementing compression. And materializeBytes() ensures that some marker bytes (EOS marker) + // are written. Otherwise we barf on the decompressing end when we read past EOF on purpose in the + // #validateRequest method. this might be a problem in deflate after all but it's important to write + // the marker bytes. + final BytesReference message = stream.materializeBytes(); + if (zeroCopyBuffer.length() == 0) { + return message; + } else { + return new CompositeBytesReference(message, zeroCopyBuffer); + } + } + + static class Request extends OutboundMessage { + + private final String[] features; + private final String action; + + Request(ThreadContext threadContext, String[] features, Writeable message, Version version, String action, long requestId, + boolean isHandshake, boolean compress) { + super(threadContext, version, setStatus(compress, isHandshake, message), requestId, message); + this.features = features; + this.action = action; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (version.onOrAfter(Version.V_6_3_0)) { + out.writeStringArray(features); + } + out.writeString(action); + } + + private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) { + byte status = 0; + status = TransportStatus.setRequest(status); + if (compress && OutboundMessage.canCompress(message)) { + status = TransportStatus.setCompress(status); + } + if (isHandshake) { + status = TransportStatus.setHandshake(status); + } + + return status; + } + } + + static class Response extends OutboundMessage { + + private final Set features; + + Response(ThreadContext threadContext, Set features, Writeable message, Version version, long requestId, + boolean isHandshake, boolean compress) { + super(threadContext, version, setStatus(compress, isHandshake, message), requestId, message); + this.features = features; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.setFeatures(features); + } + + private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) { + byte status = 0; + status = TransportStatus.setResponse(status); + if (message instanceof RemoteTransportException) { + status = TransportStatus.setError(status); + } + if (compress) { + status = TransportStatus.setCompress(status); + } + if (isHandshake) { + status = TransportStatus.setHandshake(status); + } + + return status; + } + } + + private static boolean canCompress(Writeable message) { + return message instanceof BytesTransportRequest == false; + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index e0d939b37a314..e0b466e2ec631 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -26,24 +26,16 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.NotifyOnceListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Booleans; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.component.Lifecycle; -import org.elasticsearch.common.compress.Compressor; -import org.elasticsearch.common.compress.CompressorFactory; -import org.elasticsearch.common.compress.NotCompressedException; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.metrics.MeanMetric; @@ -64,17 +56,14 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.node.Node; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; -import java.io.Closeable; import java.io.IOException; import java.io.StreamCorruptedException; -import java.io.UncheckedIOException; import java.net.BindException; import java.net.InetAddress; import java.net.InetSocketAddress; @@ -136,8 +125,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private final Map> serverChannels = newConcurrentMap(); private final Set acceptedChannels = ConcurrentCollections.newConcurrentSet(); - private final NamedWriteableRegistry namedWriteableRegistry; - // this lock is here to make sure we close this transport and disconnect all the client nodes // connections while no connect operations is going on private final ReadWriteLock closeLock = new ReentrantReadWriteLock(); @@ -145,15 +132,16 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private final String transportName; private final MeanMetric readBytesMetric = new MeanMetric(); - private final MeanMetric transmittedBytesMetric = new MeanMetric(); private volatile Map> requestHandlers = Collections.emptyMap(); private final ResponseHandlers responseHandlers = new ResponseHandlers(); private final TransportLogger transportLogger; private final TransportHandshaker handshaker; private final TransportKeepAlive keepAlive; + private final InboundMessage.Reader reader; + private final OutboundHandler outboundHandler; private final String nodeName; - public TcpTransport(String transportName, Settings settings, Version version, ThreadPool threadPool, + public TcpTransport(String transportName, Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { this.settings = settings; @@ -163,17 +151,18 @@ public TcpTransport(String transportName, Settings settings, Version version, T this.bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS); this.pageCacheRecycler = pageCacheRecycler; this.circuitBreakerService = circuitBreakerService; - this.namedWriteableRegistry = namedWriteableRegistry; this.networkService = networkService; this.transportName = transportName; this.transportLogger = new TransportLogger(); + this.outboundHandler = new OutboundHandler(threadPool, bigArrays, transportLogger); this.handshaker = new TransportHandshaker(version, threadPool, (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId, TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), - TransportRequestOptions.EMPTY, v, false, TransportStatus.setHandshake((byte) 0)), + TransportRequestOptions.EMPTY, v, false, true), (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId, - TransportHandshaker.HANDSHAKE_ACTION_NAME, false, TransportStatus.setHandshake((byte) 0))); - this.keepAlive = new TransportKeepAlive(threadPool, this::internalSendMessage); + TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true)); + this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); + this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); this.nodeName = Node.NODE_NAME_SETTING.get(settings); final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings); @@ -280,7 +269,7 @@ public void sendRequest(long requestId, String action, TransportRequest request, throw new NodeNotConnectedException(node, "connection already closed"); } TcpChannel channel = channel(options.type()); - sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress, (byte) 0); + sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress); } } @@ -573,7 +562,8 @@ protected final void doStop() { for (Map.Entry> entry : serverChannels.entrySet()) { String profile = entry.getKey(); List channels = entry.getValue(); - ActionListener closeFailLogger = ActionListener.wrap(c -> {}, + ActionListener closeFailLogger = ActionListener.wrap(c -> { + }, e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e)); channels.forEach(c -> c.addCloseListener(closeFailLogger)); CloseableChannel.closeChannels(channels, true); @@ -628,26 +618,7 @@ public void onException(TcpChannel channel, Exception e) { // in case we are able to return data, serialize the exception content and sent it back to the client if (channel.isOpen()) { BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8)); - ActionListener listener = new ActionListener() { - @Override - public void onResponse(Void aVoid) { - CloseableChannel.closeChannel(channel); - } - - @Override - public void onFailure(Exception e) { - logger.debug("failed to send message to httpOnTransport channel", e); - CloseableChannel.closeChannel(channel); - } - }; - // We do not call internalSendMessage because we are not sending a message that is an - // elasticsearch binary message. We are just serializing an exception here. Not formatting it - // as an elasticsearch transport message. - try { - channel.sendMessage(message, new SendListener(channel, message.length(), listener)); - } catch (Exception ex) { - listener.onFailure(ex); - } + outboundHandler.sendBytes(channel, message, ActionListener.wrap(() -> CloseableChannel.closeChannel(channel))); } } else { logger.warn(() -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e); @@ -691,65 +662,21 @@ protected void serverAcceptedChannel(TcpChannel channel) { */ protected abstract void stopInternal(); - private boolean canCompress(TransportRequest request) { - return request instanceof BytesTransportRequest == false; - } - private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, final TransportRequest request, TransportRequestOptions options, Version channelVersion, - boolean compressRequest, byte status) throws IOException, TransportException { - - // only compress if asked and the request is not bytes. Otherwise only - // the header part is compressed, and the "body" can't be extracted as compressed - final boolean compressMessage = compressRequest && canCompress(request); - - status = TransportStatus.setRequest(status); - ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays); - final CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bStream, compressMessage); - boolean addedReleaseListener = false; - try { - if (compressMessage) { - status = TransportStatus.setCompress(status); - } - - // we pick the smallest of the 2, to support both backward and forward compatibility - // note, this is the only place we need to do this, since from here on, we use the serialized version - // as the version to use also when the node receiving this request will send the response with - Version version = Version.min(this.version, channelVersion); - - stream.setVersion(version); - threadPool.getThreadContext().writeTo(stream); - if (version.onOrAfter(Version.V_6_3_0)) { - stream.writeStringArray(features); - } - stream.writeString(action); - BytesReference message = buildMessage(requestId, status, node.getVersion(), request, stream); - final TransportRequestOptions finalOptions = options; - // this might be called in a different thread - ReleaseListener releaseListener = new ReleaseListener(stream, - () -> messageListener.onRequestSent(node, requestId, action, request, finalOptions)); - internalSendMessage(channel, message, releaseListener); - addedReleaseListener = true; - } finally { - if (!addedReleaseListener) { - IOUtils.close(stream); - } - } + boolean compressRequest) throws IOException, TransportException { + sendRequestToChannel(node, channel, requestId, action, request, options, channelVersion, compressRequest, false); } - /** - * sends a message to the given channel, using the given callbacks. - */ - private void internalSendMessage(TcpChannel channel, BytesReference message, ActionListener listener) { - channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); - transportLogger.logOutboundMessage(channel, message); - try { - channel.sendMessage(message, new SendListener(channel, message.length(), listener)); - } catch (Exception ex) { - // call listener to ensure that any resources are released - listener.onFailure(ex); - onException(channel, ex); - } + private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, + final TransportRequest request, TransportRequestOptions options, Version channelVersion, + boolean compressRequest, boolean isHandshake) throws IOException, TransportException { + Version version = Version.min(this.version, channelVersion); + OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, + requestId, isHandshake, compressRequest); + ActionListener listener = ActionListener.wrap(() -> + messageListener.onRequestSent(node, requestId, action, request, options)); + outboundHandler.sendMessage(channel, message, listener); } /** @@ -769,23 +696,13 @@ public void sendErrorResponse( final Exception error, final long requestId, final String action) throws IOException { - try (BytesStreamOutput stream = new BytesStreamOutput()) { - stream.setVersion(nodeVersion); - stream.setFeatures(features); - RemoteTransportException tx = new RemoteTransportException( - nodeName, new TransportAddress(channel.getLocalAddress()), action, error); - threadPool.getThreadContext().writeTo(stream); - stream.writeException(tx); - byte status = 0; - status = TransportStatus.setResponse(status); - status = TransportStatus.setError(status); - final BytesReference bytes = stream.bytes(); - final BytesReference header = buildHeader(requestId, status, nodeVersion, bytes.length()); - CompositeBytesReference message = new CompositeBytesReference(header, bytes); - ReleaseListener releaseListener = new ReleaseListener(null, - () -> messageListener.onResponseSent(requestId, action, error)); - internalSendMessage(channel, message, releaseListener); - } + Version version = Version.min(this.version, nodeVersion); + TransportAddress address = new TransportAddress(channel.getLocalAddress()); + RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId, + false, false); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); + outboundHandler.sendMessage(channel, message, listener); } /** @@ -801,7 +718,7 @@ public void sendResponse( final long requestId, final String action, final boolean compress) throws IOException { - sendResponse(nodeVersion, features, channel, response, requestId, action, compress, (byte) 0); + sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false); } private void sendResponse( @@ -812,82 +729,18 @@ private void sendResponse( final long requestId, final String action, boolean compress, - byte status) throws IOException { - - status = TransportStatus.setResponse(status); - ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays); - CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bStream, compress); - boolean addedReleaseListener = false; - try { - if (compress) { - status = TransportStatus.setCompress(status); - } - threadPool.getThreadContext().writeTo(stream); - stream.setVersion(nodeVersion); - stream.setFeatures(features); - BytesReference message = buildMessage(requestId, status, nodeVersion, response, stream); - - // this might be called in a different thread - ReleaseListener releaseListener = new ReleaseListener(stream, - () -> messageListener.onResponseSent(requestId, action, response)); - internalSendMessage(channel, message, releaseListener); - addedReleaseListener = true; - } finally { - if (!addedReleaseListener) { - IOUtils.close(stream); - } - } - } - - /** - * Writes the Tcp message header into a bytes reference. - * - * @param requestId the request ID - * @param status the request status - * @param protocolVersion the protocol version used to serialize the data in the message - * @param length the payload length in bytes - * @see TcpHeader - */ - private BytesReference buildHeader(long requestId, byte status, Version protocolVersion, int length) throws IOException { - try (BytesStreamOutput headerOutput = new BytesStreamOutput(TcpHeader.HEADER_SIZE)) { - headerOutput.setVersion(protocolVersion); - TcpHeader.writeHeader(headerOutput, requestId, status, protocolVersion, length); - final BytesReference bytes = headerOutput.bytes(); - assert bytes.length() == TcpHeader.HEADER_SIZE : "header size mismatch expected: " + TcpHeader.HEADER_SIZE + " but was: " - + bytes.length(); - return bytes; - } - } - - /** - * Serializes the given message into a bytes representation - */ - private BytesReference buildMessage(long requestId, byte status, Version nodeVersion, TransportMessage message, - CompressibleBytesOutputStream stream) throws IOException { - final BytesReference zeroCopyBuffer; - if (message instanceof BytesTransportRequest) { // what a shitty optimization - we should use a direct send method instead - BytesTransportRequest bRequest = (BytesTransportRequest) message; - assert nodeVersion.equals(bRequest.version()); - bRequest.writeThin(stream); - zeroCopyBuffer = bRequest.bytes; - } else { - message.writeTo(stream); - zeroCopyBuffer = BytesArray.EMPTY; - } - // we have to call materializeBytes() here before accessing the bytes. A CompressibleBytesOutputStream - // might be implementing compression. And materializeBytes() ensures that some marker bytes (EOS marker) - // are written. Otherwise we barf on the decompressing end when we read past EOF on purpose in the - // #validateRequest method. this might be a problem in deflate after all but it's important to write - // the marker bytes. - final BytesReference messageBody = stream.materializeBytes(); - final BytesReference header = buildHeader(requestId, status, stream.getVersion(), messageBody.length() + zeroCopyBuffer.length()); - return new CompositeBytesReference(header, messageBody, zeroCopyBuffer); + boolean isHandshake) throws IOException { + Version version = Version.min(this.version, nodeVersion); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, + requestId, isHandshake, compress); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); + outboundHandler.sendMessage(channel, message, listener); } /** * Handles inbound message that has been decoded. * - * @param channel the channel the message if fomr + * @param channel the channel the message is from * @param message the message */ public void inboundMessage(TcpChannel channel, BytesReference message) { @@ -1055,53 +908,26 @@ public HttpOnTransportException(StreamInput in) throws IOException { * This method handles the message receive part for both request and responses */ public final void messageReceived(BytesReference reference, TcpChannel channel) throws IOException { - String profileName = channel.getProfile(); + readBytesMetric.inc(reference.length() + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE); InetSocketAddress remoteAddress = channel.getRemoteAddress(); - int messageLengthBytes = reference.length(); - final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; - readBytesMetric.inc(totalMessageSize); - // we have additional bytes to read, outside of the header - boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0; - StreamInput streamIn = reference.streamInput(); - boolean success = false; - try (ThreadContext.StoredContext tCtx = threadPool.getThreadContext().stashContext()) { - long requestId = streamIn.readLong(); - byte status = streamIn.readByte(); - Version version = Version.fromId(streamIn.readInt()); - if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamIn.available() > 0) { - Compressor compressor; - try { - final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE; - compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed)); - } catch (NotCompressedException ex) { - int maxToRead = Math.min(reference.length(), 10); - StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [").append(maxToRead) - .append("] content bytes out of [").append(reference.length()) - .append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are ["); - for (int i = 0; i < maxToRead; i++) { - sb.append(reference.get(i)).append(","); - } - sb.append("]"); - throw new IllegalStateException(sb.toString()); - } - streamIn = compressor.streamInput(streamIn); - } - final boolean isHandshake = TransportStatus.isHandshake(status); - ensureVersionCompatibility(version, this.version, isHandshake); - streamIn = new NamedWriteableAwareStreamInput(streamIn, namedWriteableRegistry); - streamIn.setVersion(version); - threadPool.getThreadContext().readHeaders(streamIn); - threadPool.getThreadContext().putTransient("_remote_address", remoteAddress); - if (TransportStatus.isRequest(status)) { - handleRequest(channel, profileName, streamIn, requestId, messageLengthBytes, version, remoteAddress, status); + + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext existing = threadContext.stashContext(); + InboundMessage message = reader.deserialize(reference)) { + // Place the context with the headers from the message + message.getStoredContext().restore(); + threadContext.putTransient("_remote_address", remoteAddress); + if (message.isRequest()) { + handleRequest(channel, (InboundMessage.RequestMessage) message, reference.length()); } else { final TransportResponseHandler handler; - if (isHandshake) { + long requestId = message.getRequestId(); + if (message.isHandshake()) { handler = handshaker.removeHandlerForHandshake(requestId); } else { TransportResponseHandler theHandler = responseHandlers.onResponseReceived(requestId, messageListener); - if (theHandler == null && TransportStatus.isError(status)) { + if (theHandler == null && message.isError()) { handler = handshaker.removeHandlerForHandshake(requestId); } else { handler = theHandler; @@ -1109,40 +935,20 @@ public final void messageReceived(BytesReference reference, TcpChannel channel) } // ignore if its null, the service logs it if (handler != null) { - if (TransportStatus.isError(status)) { - handlerResponseError(streamIn, handler); + if (message.isError()) { + handlerResponseError(message.getStreamInput(), handler); } else { - handleResponse(remoteAddress, streamIn, handler); + handleResponse(remoteAddress, message.getStreamInput(), handler); } // Check the entire message has been read - final int nextByte = streamIn.read(); + final int nextByte = message.getStreamInput().read(); // calling read() is useful to make sure the message is fully read, even if there is an EOS marker if (nextByte != -1) { throw new IllegalStateException("Message not fully read (response) for requestId [" + requestId + "], handler [" - + handler + "], error [" + TransportStatus.isError(status) + "]; resetting"); + + handler + "], error [" + message.isError() + "]; resetting"); } } } - success = true; - } finally { - if (success) { - IOUtils.close(streamIn); - } else { - IOUtils.closeWhileHandlingException(streamIn); - } - } - } - - static void ensureVersionCompatibility(Version version, Version currentVersion, boolean isHandshake) { - // for handshakes we are compatible with N-2 since otherwise we can't figure out our initial version - // since we are compatible with N-1 and N+1 so we always send our minCompatVersion as the initial version in the - // handshake. This looks odd but it's required to establish the connection correctly we check for real compatibility - // once the connection is established - final Version compatibilityVersion = isHandshake ? currentVersion.minimumCompatibilityVersion() : currentVersion; - if (version.isCompatible(compatibilityVersion) == false) { - final Version minCompatibilityVersion = isHandshake ? compatibilityVersion : compatibilityVersion.minimumCompatibilityVersion(); - String msg = "Received " + (isHandshake ? "handshake " : "") + "message from unsupported version: ["; - throw new IllegalStateException(msg + version + "] minimal compatible version is: [" + minCompatibilityVersion + "]"); } } @@ -1198,20 +1004,17 @@ private void handleException(final TransportResponseHandler handler, Throwable e }); } - protected String handleRequest(TcpChannel channel, String profileName, final StreamInput stream, long requestId, - int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status) - throws IOException { - final Set features; - if (version.onOrAfter(Version.V_6_3_0)) { - features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(stream.readStringArray()))); - } else { - features = Collections.emptySet(); - } - final String action = stream.readString(); + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage message, int messageLengthBytes) throws IOException { + final Set features = message.getFeatures(); + final String profileName = channel.getProfile(); + final String action = message.getActionName(); + final long requestId = message.getRequestId(); + final StreamInput stream = message.getStreamInput(); + final Version version = message.getVersion(); messageListener.onRequestReceived(requestId, action); TransportChannel transportChannel = null; try { - if (TransportStatus.isHandshake(status)) { + if (message.isHandshake()) { handshaker.handleHandshake(version, features, channel, requestId, stream); } else { final RequestHandlerRegistry reg = getRequestHandler(action); @@ -1224,9 +1027,9 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes); } transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, profileName, - messageLengthBytes, TransportStatus.isCompress(status)); + messageLengthBytes, message.isCompress()); final TransportRequest request = reg.newRequest(stream); - request.remoteAddress(new TransportAddress(remoteAddress)); + request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); // in case we throw an exception, i.e. when the limit is hit, we don't want to verify validateRequest(stream, requestId, action); threadPool.executor(reg.getExecutor()).execute(new RequestHandler(reg, request, transportChannel)); @@ -1235,7 +1038,7 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str // the circuit breaker tripped if (transportChannel == null) { transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, - profileName, 0, TransportStatus.isCompress(status)); + profileName, 0, message.isCompress()); } try { transportChannel.sendResponse(e); @@ -1244,7 +1047,6 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", action), inner); } } - return action; } // This template method is needed to inject custom error checking logic in tests. @@ -1321,70 +1123,11 @@ protected final void ensureOpen() { } } - /** - * This listener increments the transmitted bytes metric on success. - */ - private class SendListener extends NotifyOnceListener { - - private final TcpChannel channel; - private final long messageSize; - private final ActionListener delegateListener; - - private SendListener(TcpChannel channel, long messageSize, ActionListener delegateListener) { - this.channel = channel; - this.messageSize = messageSize; - this.delegateListener = delegateListener; - } - - @Override - protected void innerOnResponse(Void v) { - transmittedBytesMetric.inc(messageSize); - delegateListener.onResponse(v); - } - - @Override - protected void innerOnFailure(Exception e) { - logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e); - delegateListener.onFailure(e); - } - } - - private class ReleaseListener implements ActionListener { - - private final Closeable optionalCloseable; - private final Runnable transportAdaptorCallback; - - private ReleaseListener(Closeable optionalCloseable, Runnable transportAdaptorCallback) { - this.optionalCloseable = optionalCloseable; - this.transportAdaptorCallback = transportAdaptorCallback; - } - - @Override - public void onResponse(Void aVoid) { - closeAndCallback(null); - } - - @Override - public void onFailure(Exception e) { - closeAndCallback(e); - } - - private void closeAndCallback(final Exception e) { - try { - IOUtils.close(optionalCloseable, transportAdaptorCallback::run); - } catch (final IOException inner) { - if (e != null) { - inner.addSuppressed(e); - } - throw new UncheckedIOException(inner); - } - } - } - @Override public final TransportStats getStats() { - return new TransportStats(acceptedChannels.size(), readBytesMetric.count(), readBytesMetric.sum(), transmittedBytesMetric.count(), - transmittedBytesMetric.sum()); + MeanMetric transmittedBytes = outboundHandler.getTransmittedBytes(); + return new TransportStats(acceptedChannels.size(), readBytesMetric.count(), readBytesMetric.sum(), transmittedBytes.count(), + transmittedBytes.sum()); } /** @@ -1559,7 +1302,7 @@ public void onFailure(Exception ex) { public void onTimeout() { if (countDown.fastForward()) { CloseableChannel.closeChannels(channels, false); - listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout() + "]")); + listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout() + "]")); } } } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportLogger.java b/server/src/main/java/org/elasticsearch/transport/TransportLogger.java index ea01cc4ddbfa6..0198609f4c3b6 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportLogger.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportLogger.java @@ -51,6 +51,10 @@ void logInboundMessage(TcpChannel channel, BytesReference message) { void logOutboundMessage(TcpChannel channel, BytesReference message) { if (logger.isTraceEnabled()) { try { + if (message.get(0) != 'E') { + // This is not an Elasticsearch transport message. + return; + } BytesReference withoutHeader = message.slice(HEADER_SIZE, message.length() - HEADER_SIZE); String logMessage = format(channel, withoutHeader, "WRITE"); logger.trace(logMessage); diff --git a/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java b/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java index 9051f302a591f..e237415a9c60e 100644 --- a/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.discovery.zen.UnicastHostsProvider; import org.elasticsearch.discovery.zen.ZenDiscovery; @@ -53,6 +54,7 @@ import java.util.function.Supplier; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class DiscoveryModuleTests extends ESTestCase { @@ -87,11 +89,12 @@ default Map> getDiscoveryTypes(ThreadPool threadPool @Before public void setupDummyServices() { - transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, null, null); + threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null); masterService = mock(MasterService.class); namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); clusterApplier = mock(ClusterApplier.class); - threadPool = mock(ThreadPool.class); clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); gatewayMetaState = mock(GatewayMetaState.class); } diff --git a/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java b/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java index cc9295cee2e3f..b04a1cebbf3e3 100644 --- a/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java @@ -62,6 +62,7 @@ public void setUp() throws Exception { super.setUp(); threadPool = new TestThreadPool(FileBasedUnicastHostsProviderTests.class.getName()); executorService = Executors.newSingleThreadExecutor(); + createTransportSvc(); } @After @@ -77,8 +78,7 @@ public void tearDown() throws Exception { } } - @Before - public void createTransportSvc() { + private void createTransportSvc() { final MockNioTransport transport = new MockNioTransport(Settings.EMPTY, Version.CURRENT, threadPool, new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, diff --git a/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java b/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java index 58dc2a1e55b47..aeb92dac73479 100644 --- a/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java +++ b/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java @@ -39,6 +39,8 @@ public void testStreamWithoutCompression() throws IOException { stream.write(expectedBytes); BytesReference bytesRef = stream.materializeBytes(); + // Closing compression stream does not close underlying stream + stream.close(); assertFalse(CompressorFactory.COMPRESSOR.isCompressed(bytesRef)); @@ -48,7 +50,8 @@ public void testStreamWithoutCompression() throws IOException { assertEquals(-1, streamInput.read()); assertArrayEquals(expectedBytes, actualBytes); - stream.close(); + + bStream.close(); // The bytes should be zeroed out on close for (byte b : bytesRef.toBytesRef().bytes) { @@ -64,6 +67,7 @@ public void testStreamWithCompression() throws IOException { stream.write(expectedBytes); BytesReference bytesRef = stream.materializeBytes(); + stream.close(); assertTrue(CompressorFactory.COMPRESSOR.isCompressed(bytesRef)); @@ -73,7 +77,8 @@ public void testStreamWithCompression() throws IOException { assertEquals(-1, streamInput.read()); assertArrayEquals(expectedBytes, actualBytes); - stream.close(); + + bStream.close(); // The bytes should be zeroed out on close for (byte b : bytesRef.toBytesRef().bytes) { diff --git a/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java new file mode 100644 index 0000000000000..499b6586543ed --- /dev/null +++ b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java @@ -0,0 +1,228 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.VersionUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +public class InboundMessageTests extends ESTestCase { + + private final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + private final NamedWriteableRegistry registry = new NamedWriteableRegistry(Collections.emptyList()); + + public void testReadRequest() throws IOException { + String[] features = {"feature1", "feature2"}; + String value = randomAlphaOfLength(10); + Message message = new Message(value); + String action = randomAlphaOfLength(10); + long requestId = randomLong(); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + threadContext.putHeader("header", "header_value"); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + OutboundMessage.Request request = new OutboundMessage.Request(threadContext, features, message, version, action, requestId, + isHandshake, compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + // Check that the thread context is not deleted. + assertEquals("header_value", threadContext.getHeader("header")); + + threadContext.stashContext(); + threadContext.putHeader("header", "header_value2"); + + InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.RequestMessage inboundMessage = (InboundMessage.RequestMessage) reader.deserialize(sliced); + // Check that deserialize does not overwrite current thread context. + assertEquals("header_value2", threadContext.getHeader("header")); + inboundMessage.getStoredContext().restore(); + assertEquals("header_value", threadContext.getHeader("header")); + assertEquals(isHandshake, inboundMessage.isHandshake()); + assertEquals(compress, inboundMessage.isCompress()); + assertEquals(version, inboundMessage.getVersion()); + assertEquals(action, inboundMessage.getActionName()); + assertEquals(new HashSet<>(Arrays.asList(features)), inboundMessage.getFeatures()); + assertTrue(inboundMessage.isRequest()); + assertFalse(inboundMessage.isResponse()); + assertFalse(inboundMessage.isError()); + assertEquals(value, new Message(inboundMessage.getStreamInput()).value); + } + + public void testReadResponse() throws IOException { + HashSet features = new HashSet<>(Arrays.asList("feature1", "feature2")); + String value = randomAlphaOfLength(10); + Message message = new Message(value); + long requestId = randomLong(); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + threadContext.putHeader("header", "header_value"); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + OutboundMessage.Response request = new OutboundMessage.Response(threadContext, features, message, version, requestId, isHandshake, + compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + // Check that the thread context is not deleted. + assertEquals("header_value", threadContext.getHeader("header")); + + threadContext.stashContext(); + threadContext.putHeader("header", "header_value2"); + + InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + // Check that deserialize does not overwrite current thread context. + assertEquals("header_value2", threadContext.getHeader("header")); + inboundMessage.getStoredContext().restore(); + assertEquals("header_value", threadContext.getHeader("header")); + assertEquals(isHandshake, inboundMessage.isHandshake()); + assertEquals(compress, inboundMessage.isCompress()); + assertEquals(version, inboundMessage.getVersion()); + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isRequest()); + assertFalse(inboundMessage.isError()); + assertEquals(value, new Message(inboundMessage.getStreamInput()).value); + } + + public void testReadErrorResponse() throws IOException { + HashSet features = new HashSet<>(Arrays.asList("feature1", "feature2")); + RemoteTransportException exception = new RemoteTransportException("error", new IOException()); + long requestId = randomLong(); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + threadContext.putHeader("header", "header_value"); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + OutboundMessage.Response request = new OutboundMessage.Response(threadContext, features, exception, version, requestId, + isHandshake, compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + // Check that the thread context is not deleted. + assertEquals("header_value", threadContext.getHeader("header")); + + threadContext.stashContext(); + threadContext.putHeader("header", "header_value2"); + + InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + // Check that deserialize does not overwrite current thread context. + assertEquals("header_value2", threadContext.getHeader("header")); + inboundMessage.getStoredContext().restore(); + assertEquals("header_value", threadContext.getHeader("header")); + assertEquals(isHandshake, inboundMessage.isHandshake()); + assertEquals(compress, inboundMessage.isCompress()); + assertEquals(version, inboundMessage.getVersion()); + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isRequest()); + assertTrue(inboundMessage.isError()); + assertEquals("[error]", inboundMessage.getStreamInput().readException().getMessage()); + } + + public void testEnsureVersionCompatibility() throws IOException { + testVersionIncompatibility(VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), + Version.CURRENT), Version.CURRENT, randomBoolean()); + + final Version version = Version.fromString("7.0.0"); + testVersionIncompatibility(Version.fromString("6.0.0"), version, true); + IllegalStateException ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("6.0.0"), version, false)); + assertEquals("Received message from unsupported version: [6.0.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + + // For handshake we are compatible with N-2 + testVersionIncompatibility(Version.fromString("5.6.0"), version, true); + ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("5.6.0"), version, false)); + assertEquals("Received message from unsupported version: [5.6.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + + ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("2.3.0"), version, true)); + assertEquals("Received handshake message from unsupported version: [2.3.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + + ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("2.3.0"), version, false)); + assertEquals("Received message from unsupported version: [2.3.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + } + + private void testVersionIncompatibility(Version version, Version currentVersion, boolean isHandshake) throws IOException { + String[] features = {}; + String value = randomAlphaOfLength(10); + Message message = new Message(value); + String action = randomAlphaOfLength(10); + long requestId = randomLong(); + boolean compress = randomBoolean(); + OutboundMessage.Request request = new OutboundMessage.Request(threadContext, features, message, version, action, requestId, + isHandshake, compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.Reader reader = new InboundMessage.Reader(currentVersion, registry, threadContext); + reader.deserialize(sliced); + } + + private static final class Message extends TransportMessage { + + public String value; + + private Message() { + } + + private Message(StreamInput in) throws IOException { + readFrom(in); + } + + private Message(String value) { + this.value = value; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java new file mode 100644 index 0000000000000..01e391a30a732 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -0,0 +1,184 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +public class OutboundHandlerTests extends ESTestCase { + + private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); + private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + private OutboundHandler handler; + private FakeTcpChannel fakeTcpChannel; + + @Before + public void setUp() throws Exception { + super.setUp(); + TransportLogger transportLogger = new TransportLogger(); + fakeTcpChannel = new FakeTcpChannel(randomBoolean()); + handler = new OutboundHandler(threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger); + } + + @After + public void tearDown() throws Exception { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + super.tearDown(); + } + + public void testSendRawBytes() { + BytesArray bytesArray = new BytesArray("message".getBytes(StandardCharsets.UTF_8)); + + AtomicBoolean isSuccess = new AtomicBoolean(false); + AtomicReference exception = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); + handler.sendBytes(fakeTcpChannel, bytesArray, listener); + + BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); + ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + assertTrue(isSuccess.get()); + assertNull(exception.get()); + } else { + IOException e = new IOException("failed"); + sendListener.onFailure(e); + assertFalse(isSuccess.get()); + assertSame(e, exception.get()); + } + + assertEquals(bytesArray, reference); + } + + public void testSendMessage() throws IOException { + OutboundMessage message; + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = Version.CURRENT; + String actionName = "handshake"; + long requestId = randomLongBetween(0, 300); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + String value = "message"; + threadContext.putHeader("header", "header_value"); + Writeable writeable = new Message(value); + + boolean isRequest = randomBoolean(); + if (isRequest) { + message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake, + compress); + } else { + message = new OutboundMessage.Response(threadContext, new HashSet<>(), writeable, version, requestId, isHandshake, compress); + } + + AtomicBoolean isSuccess = new AtomicBoolean(false); + AtomicReference exception = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); + handler.sendMessage(fakeTcpChannel, message, listener); + + BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); + ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + assertTrue(isSuccess.get()); + assertNull(exception.get()); + } else { + IOException e = new IOException("failed"); + sendListener.onFailure(e); + assertFalse(isSuccess.get()); + assertSame(e, exception.get()); + } + + InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); + try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { + assertEquals(version, inboundMessage.getVersion()); + assertEquals(requestId, inboundMessage.getRequestId()); + if (isRequest) { + assertTrue(inboundMessage.isRequest()); + assertFalse(inboundMessage.isResponse()); + } else { + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isRequest()); + } + if (isHandshake) { + assertTrue(inboundMessage.isHandshake()); + } else { + assertFalse(inboundMessage.isHandshake()); + } + if (compress) { + assertTrue(inboundMessage.isCompress()); + } else { + assertFalse(inboundMessage.isCompress()); + } + Message readMessage = new Message(); + readMessage.readFrom(inboundMessage.getStreamInput()); + assertEquals(value, readMessage.value); + + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext(); + assertNull(threadContext.getHeader("header")); + storedContext.restore(); + assertEquals("header_value", threadContext.getHeader("header")); + } + } + } + + private static final class Message extends TransportMessage { + + public String value; + + private Message() { + } + + private Message(String value) { + this.value = value; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java index f3294366b8fe1..a25ac2a551a22 100644 --- a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java @@ -37,7 +37,6 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.VersionUtils; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -158,41 +157,12 @@ public void testAddressLimit() throws Exception { assertEquals(102, addresses[2].getPort()); } - public void testEnsureVersionCompatibility() { - TcpTransport.ensureVersionCompatibility(VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), - Version.CURRENT), Version.CURRENT, randomBoolean()); - - final Version version = Version.fromString("7.0.0"); - TcpTransport.ensureVersionCompatibility(Version.fromString("6.0.0"), version, true); - IllegalStateException ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("6.0.0"), version, false)); - assertEquals("Received message from unsupported version: [6.0.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - - // For handshake we are compatible with N-2 - TcpTransport.ensureVersionCompatibility(Version.fromString("5.6.0"), version, true); - ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("5.6.0"), version, false)); - assertEquals("Received message from unsupported version: [5.6.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - - ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("2.3.0"), version, true)); - assertEquals("Received handshake message from unsupported version: [2.3.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - - ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("2.3.0"), version, false)); - assertEquals("Received message from unsupported version: [2.3.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - } - @SuppressForbidden(reason = "Allow accessing localhost") public void testCompressRequestAndResponse() throws IOException { final boolean compressed = randomBoolean(); Req request = new Req(randomRealisticUnicodeOfLengthBetween(10, 100)); ThreadPool threadPool = new TestThreadPool(TcpTransportTests.class.getName()); - AtomicReference requestCaptor = new AtomicReference<>(); + AtomicReference messageCaptor = new AtomicReference<>(); try { TcpTransport transport = new TcpTransport("test", Settings.EMPTY, Version.CURRENT, threadPool, PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), null, null) { @@ -204,7 +174,7 @@ protected FakeServerChannel bind(String name, InetSocketAddress address) throws @Override protected FakeTcpChannel initiateChannel(DiscoveryNode node) throws IOException { - return new FakeTcpChannel(false, requestCaptor); + return new FakeTcpChannel(false); } @Override @@ -219,7 +189,7 @@ public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, int numConnections = profile.getNumConnections(); ArrayList fakeChannels = new ArrayList<>(numConnections); for (int i = 0; i < numConnections; ++i) { - fakeChannels.add(new FakeTcpChannel(false, requestCaptor)); + fakeChannels.add(new FakeTcpChannel(false, messageCaptor)); } listener.onResponse(new NodeChannels(node, fakeChannels, profile, Version.CURRENT)); return () -> CloseableChannel.closeChannels(fakeChannels, false); @@ -241,12 +211,12 @@ public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, (request1, channel, task) -> channel.sendResponse(TransportResponse.Empty.INSTANCE), ThreadPool.Names.SAME, true, true)); - BytesReference reference = requestCaptor.get(); + BytesReference reference = messageCaptor.get(); assertNotNull(reference); AtomicReference responseCaptor = new AtomicReference<>(); InetSocketAddress address = new InetSocketAddress(InetAddress.getLocalHost(), 0); - FakeTcpChannel responseChannel = new FakeTcpChannel(true, address, address, responseCaptor); + FakeTcpChannel responseChannel = new FakeTcpChannel(true, address, address, "profile", responseCaptor); transport.messageReceived(reference.slice(6, reference.length() - 6), responseChannel); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 43542d48a6c39..87343d4c82087 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -2038,11 +2038,14 @@ public void testTcpHandshake() { new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - protected String handleRequest(TcpChannel mockChannel, String profileName, StreamInput stream, long requestId, - int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status) + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { - return super.handleRequest(mockChannel, profileName, stream, requestId, messageLengthBytes, version, remoteAddress, - (byte) (status & ~(1 << 3))); // we flip the isHandshake bit back and act like the handler is not found + // we flip the isHandshake bit back and act like the handler is not found + byte status = (byte) (request.status & ~(1 << 3)); + Version version = request.getVersion(); + InboundMessage.RequestMessage nonHandshakeRequest = new InboundMessage.RequestMessage(request.threadContext, version, + status, request.getRequestId(), request.getActionName(), request.getFeatures(), request.getStreamInput()); + super.handleRequest(channel, nonHandshakeRequest, messageLengthBytes); } }; diff --git a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java index 63cacfbb093a8..bb392554305c0 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java @@ -31,9 +31,10 @@ public class FakeTcpChannel implements TcpChannel { private final InetSocketAddress localAddress; private final InetSocketAddress remoteAddress; private final String profile; - private final AtomicReference messageCaptor; private final ChannelStats stats = new ChannelStats(); private final CompletableContext closeContext = new CompletableContext<>(); + private final AtomicReference messageCaptor; + private final AtomicReference> listenerCaptor; public FakeTcpChannel() { this(false, "profile", new AtomicReference<>()); @@ -47,11 +48,6 @@ public FakeTcpChannel(boolean isServer, AtomicReference messageC this(isServer, "profile", messageCaptor); } - public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSocketAddress remoteAddress, - AtomicReference messageCaptor) { - this(isServer, localAddress, remoteAddress,"profile", messageCaptor); - } - public FakeTcpChannel(boolean isServer, String profile, AtomicReference messageCaptor) { this(isServer, null, null, profile, messageCaptor); @@ -64,6 +60,7 @@ public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSock this.remoteAddress = remoteAddress; this.profile = profile; this.messageCaptor = messageCaptor; + this.listenerCaptor = new AtomicReference<>(); } @Override @@ -89,6 +86,7 @@ public InetSocketAddress getRemoteAddress() { @Override public void sendMessage(BytesReference reference, ActionListener listener) { messageCaptor.set(reference); + listenerCaptor.set(listener); } @Override @@ -115,4 +113,12 @@ public boolean isOpen() { public ChannelStats getChannelStats() { return stats; } + + public AtomicReference getMessageCaptor() { + return messageCaptor; + } + + public AtomicReference> getListenerCaptor() { + return listenerCaptor; + } }