From 21838d73b5f26490967d0db3d6667a701de6387a Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Mon, 21 Jan 2019 14:14:18 -0700 Subject: [PATCH] Extract message serialization from `TcpTransport` (#37034) This commit introduces a NetworkMessage class. This class has two subclasses - InboundMessage and OutboundMessage. These messages can be serialized and deserialized independent of the transport. This allows more granular testing. Additionally, the serialization mechanism is now a simple Supplier. This builds the framework to eventually move the serialization of transport messages to the network thread. This is the one serialization component that is not currently performed on the network thread (transport deserialization and http serialization and deserialization are all on the network thread). --- .../elasticsearch/nio/BytesWriteHandler.java | 3 + .../transport/netty4/Netty4TransportIT.java | 10 +- .../transport/nio/NioTransportIT.java | 10 +- .../CompressibleBytesOutputStream.java | 11 +- .../transport/InboundMessage.java | 168 ++++++++ .../transport/NetworkMessage.java | 76 ++++ .../transport/OutboundHandler.java | 167 ++++++++ .../transport/OutboundMessage.java | 155 +++++++ .../elasticsearch/transport/TcpTransport.java | 399 ++++-------------- .../transport/TransportLogger.java | 4 + .../discovery/DiscoveryModuleTests.java | 7 +- .../FileBasedUnicastHostsProviderTests.java | 4 +- .../CompressibleBytesOutputStreamTests.java | 9 +- .../transport/InboundMessageTests.java | 228 ++++++++++ .../transport/OutboundHandlerTests.java | 184 ++++++++ .../transport/TcpTransportTests.java | 40 +- .../AbstractSimpleTransportTestCase.java | 11 +- .../transport/FakeTcpChannel.java | 18 +- 18 files changed, 1104 insertions(+), 400 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/transport/InboundMessage.java create mode 100644 server/src/main/java/org/elasticsearch/transport/NetworkMessage.java create mode 100644 server/src/main/java/org/elasticsearch/transport/OutboundHandler.java create mode 100644 server/src/main/java/org/elasticsearch/transport/OutboundMessage.java create mode 100644 server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java create mode 100644 server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java index 87c0ff2817eb7..2d57faf5cb897 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java @@ -34,14 +34,17 @@ public WriteOperation createWriteOperation(SocketChannelContext context, Object return new FlushReadyWrite(context, (ByteBuffer[]) message, listener); } + @Override public List writeToBytes(WriteOperation writeOperation) { assert writeOperation instanceof FlushReadyWrite : "Write operation must be flush ready"; return Collections.singletonList((FlushReadyWrite) writeOperation); } + @Override public List pollFlushOperations() { return EMPTY_LIST; } + @Override public void close() {} } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java index 28d32f50bfc69..bc24789341e04 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java @@ -36,12 +36,12 @@ import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.Scope; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundMessage; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportSettings; import java.io.IOException; -import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -111,13 +111,9 @@ public ExceptionThrowingNetty4Transport( } @Override - protected String handleRequest(TcpChannel channel, String profileName, - StreamInput stream, long requestId, int messageLengthBytes, Version version, - InetSocketAddress remoteAddress, byte status) throws IOException { - String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version, - remoteAddress, status); + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; - return action; } @Override diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java index 334e63dc0bf95..087c3758bb98b 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java @@ -38,12 +38,12 @@ import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.Scope; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundMessage; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportSettings; import java.io.IOException; -import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -113,13 +113,9 @@ public Map> getTransports(Settings settings, ThreadP } @Override - protected String handleRequest(TcpChannel channel, String profileName, - StreamInput stream, long requestId, int messageLengthBytes, Version version, - InetSocketAddress remoteAddress, byte status) throws IOException { - String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version, - remoteAddress, status); + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; - return action; } @Override diff --git a/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java b/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java index 54f4d1d0c8d84..4116f88b14224 100644 --- a/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java +++ b/server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java @@ -39,8 +39,8 @@ * written to this stream. If compression is enabled, the proper EOS bytes will be written at that point. * The underlying {@link BytesReference} will be returned. * - * {@link CompressibleBytesOutputStream#close()} should be called when the bytes are no longer needed and - * can be safely released. + * {@link CompressibleBytesOutputStream#close()} will NOT close the underlying stream. The byte stream passed + * in the constructor must be closed individually. */ final class CompressibleBytesOutputStream extends StreamOutput { @@ -92,12 +92,9 @@ public void flush() throws IOException { @Override public void close() throws IOException { - if (stream == bytesStreamOutput) { - assert shouldCompress == false : "If the streams are the same we should not be compressing"; - IOUtils.close(stream); - } else { + if (stream != bytesStreamOutput) { assert shouldCompress : "If the streams are different we should be compressing"; - IOUtils.close(stream, bytesStreamOutput); + IOUtils.close(stream); } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java new file mode 100644 index 0000000000000..44e3b017ed27e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java @@ -0,0 +1,168 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.compress.Compressor; +import org.elasticsearch.common.compress.CompressorFactory; +import org.elasticsearch.common.compress.NotCompressedException; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.internal.io.IOUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Set; +import java.util.TreeSet; + +public abstract class InboundMessage extends NetworkMessage implements Closeable { + + private final StreamInput streamInput; + + InboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { + super(threadContext, version, status, requestId); + this.streamInput = streamInput; + } + + StreamInput getStreamInput() { + return streamInput; + } + + static class Reader { + + private final Version version; + private final NamedWriteableRegistry namedWriteableRegistry; + private final ThreadContext threadContext; + + Reader(Version version, NamedWriteableRegistry namedWriteableRegistry, ThreadContext threadContext) { + this.version = version; + this.namedWriteableRegistry = namedWriteableRegistry; + this.threadContext = threadContext; + } + + InboundMessage deserialize(BytesReference reference) throws IOException { + int messageLengthBytes = reference.length(); + final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; + // we have additional bytes to read, outside of the header + boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0; + StreamInput streamInput = reference.streamInput(); + boolean success = false; + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + long requestId = streamInput.readLong(); + byte status = streamInput.readByte(); + Version remoteVersion = Version.fromId(streamInput.readInt()); + final boolean isHandshake = TransportStatus.isHandshake(status); + ensureVersionCompatibility(remoteVersion, version, isHandshake); + if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamInput.available() > 0) { + Compressor compressor; + try { + final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE; + compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed)); + } catch (NotCompressedException ex) { + int maxToRead = Math.min(reference.length(), 10); + StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [") + .append(maxToRead).append("] content bytes out of [").append(reference.length()) + .append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are ["); + for (int i = 0; i < maxToRead; i++) { + sb.append(reference.get(i)).append(","); + } + sb.append("]"); + throw new IllegalStateException(sb.toString()); + } + streamInput = compressor.streamInput(streamInput); + } + streamInput = new NamedWriteableAwareStreamInput(streamInput, namedWriteableRegistry); + streamInput.setVersion(remoteVersion); + + threadContext.readHeaders(streamInput); + + InboundMessage message; + if (TransportStatus.isRequest(status)) { + final Set features; + if (remoteVersion.onOrAfter(Version.V_6_3_0)) { + features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(streamInput.readStringArray()))); + } else { + features = Collections.emptySet(); + } + final String action = streamInput.readString(); + message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput); + } else { + message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput); + } + success = true; + return message; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(streamInput); + } + } + } + } + + @Override + public void close() throws IOException { + streamInput.close(); + } + + private static void ensureVersionCompatibility(Version version, Version currentVersion, boolean isHandshake) { + // for handshakes we are compatible with N-2 since otherwise we can't figure out our initial version + // since we are compatible with N-1 and N+1 so we always send our minCompatVersion as the initial version in the + // handshake. This looks odd but it's required to establish the connection correctly we check for real compatibility + // once the connection is established + final Version compatibilityVersion = isHandshake ? currentVersion.minimumCompatibilityVersion() : currentVersion; + if (version.isCompatible(compatibilityVersion) == false) { + final Version minCompatibilityVersion = isHandshake ? compatibilityVersion : compatibilityVersion.minimumCompatibilityVersion(); + String msg = "Received " + (isHandshake ? "handshake " : "") + "message from unsupported version: ["; + throw new IllegalStateException(msg + version + "] minimal compatible version is: [" + minCompatibilityVersion + "]"); + } + } + + public static class RequestMessage extends InboundMessage { + + private final String actionName; + private final Set features; + + RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set features, + StreamInput streamInput) { + super(threadContext, version, status, requestId, streamInput); + this.actionName = actionName; + this.features = features; + } + + String getActionName() { + return actionName; + } + + Set getFeatures() { + return features; + } + } + + public static class ResponseMessage extends InboundMessage { + + ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { + super(threadContext, version, status, requestId, streamInput); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java b/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java new file mode 100644 index 0000000000000..7d8dbb8a0f1b6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.util.concurrent.ThreadContext; + +/** + * Represents a transport message sent over the network. Subclasses implement serialization and + * deserialization. + */ +public abstract class NetworkMessage { + + protected final Version version; + protected final ThreadContext threadContext; + protected final ThreadContext.StoredContext storedContext; + protected final long requestId; + protected final byte status; + + NetworkMessage(ThreadContext threadContext, Version version, byte status, long requestId) { + this.threadContext = threadContext; + storedContext = threadContext.stashContext(); + storedContext.restore(); + this.version = version; + this.requestId = requestId; + this.status = status; + } + + public Version getVersion() { + return version; + } + + public long getRequestId() { + return requestId; + } + + boolean isCompress() { + return TransportStatus.isCompress(status); + } + + ThreadContext.StoredContext getStoredContext() { + return storedContext; + } + + boolean isResponse() { + return TransportStatus.isRequest(status) == false; + } + + boolean isRequest() { + return TransportStatus.isRequest(status); + } + + boolean isHandshake() { + return TransportStatus.isHandshake(status); + } + + boolean isError() { + return TransportStatus.isError(status); + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java new file mode 100644 index 0000000000000..31c67d930c538 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -0,0 +1,167 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.NotifyOnceListener; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.metrics.MeanMetric; +import org.elasticsearch.common.network.CloseableChannel; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; + +final class OutboundHandler { + + private static final Logger logger = LogManager.getLogger(OutboundHandler.class); + + private final MeanMetric transmittedBytesMetric = new MeanMetric(); + private final ThreadPool threadPool; + private final BigArrays bigArrays; + private final TransportLogger transportLogger; + + OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) { + this.threadPool = threadPool; + this.bigArrays = bigArrays; + this.transportLogger = transportLogger; + } + + void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener listener) { + channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); + SendContext sendContext = new SendContext(channel, () -> bytes, listener); + try { + internalSendMessage(channel, sendContext); + } catch (IOException e) { + // This should not happen as the bytes are already serialized + throw new AssertionError(e); + } + } + + void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { + channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); + MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); + SendContext sendContext = new SendContext(channel, serializer, listener, serializer); + internalSendMessage(channel, sendContext); + } + + /** + * sends a message to the given channel, using the given callbacks. + */ + private void internalSendMessage(TcpChannel channel, SendContext sendContext) throws IOException { + channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); + BytesReference reference = sendContext.get(); + try { + channel.sendMessage(reference, sendContext); + } catch (RuntimeException ex) { + sendContext.onFailure(ex); + CloseableChannel.closeChannel(channel); + throw ex; + } + + } + + MeanMetric getTransmittedBytes() { + return transmittedBytesMetric; + } + + private static class MessageSerializer implements CheckedSupplier, Releasable { + + private final OutboundMessage message; + private final BigArrays bigArrays; + private volatile ReleasableBytesStreamOutput bytesStreamOutput; + + private MessageSerializer(OutboundMessage message, BigArrays bigArrays) { + this.message = message; + this.bigArrays = bigArrays; + } + + @Override + public BytesReference get() throws IOException { + bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays); + return message.serialize(bytesStreamOutput); + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(bytesStreamOutput); + } + } + + private class SendContext extends NotifyOnceListener implements CheckedSupplier { + + private final TcpChannel channel; + private final CheckedSupplier messageSupplier; + private final ActionListener listener; + private final Releasable optionalReleasable; + private long messageSize = -1; + + private SendContext(TcpChannel channel, CheckedSupplier messageSupplier, + ActionListener listener) { + this(channel, messageSupplier, listener, null); + } + + private SendContext(TcpChannel channel, CheckedSupplier messageSupplier, + ActionListener listener, Releasable optionalReleasable) { + this.channel = channel; + this.messageSupplier = messageSupplier; + this.listener = listener; + this.optionalReleasable = optionalReleasable; + } + + public BytesReference get() throws IOException { + BytesReference message; + try { + message = messageSupplier.get(); + messageSize = message.length(); + transportLogger.logOutboundMessage(channel, message); + return message; + } catch (Exception e) { + onFailure(e); + throw e; + } + } + + @Override + protected void innerOnResponse(Void v) { + assert messageSize != -1 : "If onResponse is being called, the message should have been serialized"; + transmittedBytesMetric.inc(messageSize); + closeAndCallback(() -> listener.onResponse(v)); + } + + @Override + protected void innerOnFailure(Exception e) { + logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e); + closeAndCallback(() -> listener.onFailure(e)); + } + + private void closeAndCallback(Runnable runnable) { + Releasables.close(optionalReleasable, runnable::run); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java b/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java new file mode 100644 index 0000000000000..cc295a68df3c7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.ThreadContext; + +import java.io.IOException; +import java.util.Set; + +abstract class OutboundMessage extends NetworkMessage implements Writeable { + + private final Writeable message; + + OutboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, Writeable message) { + super(threadContext, version, status, requestId); + this.message = message; + } + + BytesReference serialize(BytesStreamOutput bytesStream) throws IOException { + storedContext.restore(); + bytesStream.setVersion(version); + bytesStream.skip(TcpHeader.HEADER_SIZE); + + // The compressible bytes stream will not close the underlying bytes stream + BytesReference reference; + try (CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) { + stream.setVersion(version); + threadContext.writeTo(stream); + writeTo(stream); + reference = writeMessage(stream); + } + bytesStream.seek(0); + TcpHeader.writeHeader(bytesStream, requestId, status, version, reference.length() - TcpHeader.HEADER_SIZE); + return reference; + } + + private BytesReference writeMessage(CompressibleBytesOutputStream stream) throws IOException { + final BytesReference zeroCopyBuffer; + if (message instanceof BytesTransportRequest) { + BytesTransportRequest bRequest = (BytesTransportRequest) message; + bRequest.writeThin(stream); + zeroCopyBuffer = bRequest.bytes; + } else if (message instanceof RemoteTransportException) { + stream.writeException((RemoteTransportException) message); + zeroCopyBuffer = BytesArray.EMPTY; + } else { + message.writeTo(stream); + zeroCopyBuffer = BytesArray.EMPTY; + } + // we have to call materializeBytes() here before accessing the bytes. A CompressibleBytesOutputStream + // might be implementing compression. And materializeBytes() ensures that some marker bytes (EOS marker) + // are written. Otherwise we barf on the decompressing end when we read past EOF on purpose in the + // #validateRequest method. this might be a problem in deflate after all but it's important to write + // the marker bytes. + final BytesReference message = stream.materializeBytes(); + if (zeroCopyBuffer.length() == 0) { + return message; + } else { + return new CompositeBytesReference(message, zeroCopyBuffer); + } + } + + static class Request extends OutboundMessage { + + private final String[] features; + private final String action; + + Request(ThreadContext threadContext, String[] features, Writeable message, Version version, String action, long requestId, + boolean isHandshake, boolean compress) { + super(threadContext, version, setStatus(compress, isHandshake, message), requestId, message); + this.features = features; + this.action = action; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (version.onOrAfter(Version.V_6_3_0)) { + out.writeStringArray(features); + } + out.writeString(action); + } + + private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) { + byte status = 0; + status = TransportStatus.setRequest(status); + if (compress && OutboundMessage.canCompress(message)) { + status = TransportStatus.setCompress(status); + } + if (isHandshake) { + status = TransportStatus.setHandshake(status); + } + + return status; + } + } + + static class Response extends OutboundMessage { + + private final Set features; + + Response(ThreadContext threadContext, Set features, Writeable message, Version version, long requestId, + boolean isHandshake, boolean compress) { + super(threadContext, version, setStatus(compress, isHandshake, message), requestId, message); + this.features = features; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.setFeatures(features); + } + + private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) { + byte status = 0; + status = TransportStatus.setResponse(status); + if (message instanceof RemoteTransportException) { + status = TransportStatus.setError(status); + } + if (compress) { + status = TransportStatus.setCompress(status); + } + if (isHandshake) { + status = TransportStatus.setHandshake(status); + } + + return status; + } + } + + private static boolean canCompress(Writeable message) { + return message instanceof BytesTransportRequest == false; + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index e0d939b37a314..e0b466e2ec631 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -26,24 +26,16 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.NotifyOnceListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Booleans; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.component.Lifecycle; -import org.elasticsearch.common.compress.Compressor; -import org.elasticsearch.common.compress.CompressorFactory; -import org.elasticsearch.common.compress.NotCompressedException; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.metrics.MeanMetric; @@ -64,17 +56,14 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.node.Node; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; -import java.io.Closeable; import java.io.IOException; import java.io.StreamCorruptedException; -import java.io.UncheckedIOException; import java.net.BindException; import java.net.InetAddress; import java.net.InetSocketAddress; @@ -136,8 +125,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private final Map> serverChannels = newConcurrentMap(); private final Set acceptedChannels = ConcurrentCollections.newConcurrentSet(); - private final NamedWriteableRegistry namedWriteableRegistry; - // this lock is here to make sure we close this transport and disconnect all the client nodes // connections while no connect operations is going on private final ReadWriteLock closeLock = new ReentrantReadWriteLock(); @@ -145,15 +132,16 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private final String transportName; private final MeanMetric readBytesMetric = new MeanMetric(); - private final MeanMetric transmittedBytesMetric = new MeanMetric(); private volatile Map> requestHandlers = Collections.emptyMap(); private final ResponseHandlers responseHandlers = new ResponseHandlers(); private final TransportLogger transportLogger; private final TransportHandshaker handshaker; private final TransportKeepAlive keepAlive; + private final InboundMessage.Reader reader; + private final OutboundHandler outboundHandler; private final String nodeName; - public TcpTransport(String transportName, Settings settings, Version version, ThreadPool threadPool, + public TcpTransport(String transportName, Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { this.settings = settings; @@ -163,17 +151,18 @@ public TcpTransport(String transportName, Settings settings, Version version, T this.bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS); this.pageCacheRecycler = pageCacheRecycler; this.circuitBreakerService = circuitBreakerService; - this.namedWriteableRegistry = namedWriteableRegistry; this.networkService = networkService; this.transportName = transportName; this.transportLogger = new TransportLogger(); + this.outboundHandler = new OutboundHandler(threadPool, bigArrays, transportLogger); this.handshaker = new TransportHandshaker(version, threadPool, (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId, TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), - TransportRequestOptions.EMPTY, v, false, TransportStatus.setHandshake((byte) 0)), + TransportRequestOptions.EMPTY, v, false, true), (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId, - TransportHandshaker.HANDSHAKE_ACTION_NAME, false, TransportStatus.setHandshake((byte) 0))); - this.keepAlive = new TransportKeepAlive(threadPool, this::internalSendMessage); + TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true)); + this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); + this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); this.nodeName = Node.NODE_NAME_SETTING.get(settings); final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings); @@ -280,7 +269,7 @@ public void sendRequest(long requestId, String action, TransportRequest request, throw new NodeNotConnectedException(node, "connection already closed"); } TcpChannel channel = channel(options.type()); - sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress, (byte) 0); + sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress); } } @@ -573,7 +562,8 @@ protected final void doStop() { for (Map.Entry> entry : serverChannels.entrySet()) { String profile = entry.getKey(); List channels = entry.getValue(); - ActionListener closeFailLogger = ActionListener.wrap(c -> {}, + ActionListener closeFailLogger = ActionListener.wrap(c -> { + }, e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e)); channels.forEach(c -> c.addCloseListener(closeFailLogger)); CloseableChannel.closeChannels(channels, true); @@ -628,26 +618,7 @@ public void onException(TcpChannel channel, Exception e) { // in case we are able to return data, serialize the exception content and sent it back to the client if (channel.isOpen()) { BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8)); - ActionListener listener = new ActionListener() { - @Override - public void onResponse(Void aVoid) { - CloseableChannel.closeChannel(channel); - } - - @Override - public void onFailure(Exception e) { - logger.debug("failed to send message to httpOnTransport channel", e); - CloseableChannel.closeChannel(channel); - } - }; - // We do not call internalSendMessage because we are not sending a message that is an - // elasticsearch binary message. We are just serializing an exception here. Not formatting it - // as an elasticsearch transport message. - try { - channel.sendMessage(message, new SendListener(channel, message.length(), listener)); - } catch (Exception ex) { - listener.onFailure(ex); - } + outboundHandler.sendBytes(channel, message, ActionListener.wrap(() -> CloseableChannel.closeChannel(channel))); } } else { logger.warn(() -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e); @@ -691,65 +662,21 @@ protected void serverAcceptedChannel(TcpChannel channel) { */ protected abstract void stopInternal(); - private boolean canCompress(TransportRequest request) { - return request instanceof BytesTransportRequest == false; - } - private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, final TransportRequest request, TransportRequestOptions options, Version channelVersion, - boolean compressRequest, byte status) throws IOException, TransportException { - - // only compress if asked and the request is not bytes. Otherwise only - // the header part is compressed, and the "body" can't be extracted as compressed - final boolean compressMessage = compressRequest && canCompress(request); - - status = TransportStatus.setRequest(status); - ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays); - final CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bStream, compressMessage); - boolean addedReleaseListener = false; - try { - if (compressMessage) { - status = TransportStatus.setCompress(status); - } - - // we pick the smallest of the 2, to support both backward and forward compatibility - // note, this is the only place we need to do this, since from here on, we use the serialized version - // as the version to use also when the node receiving this request will send the response with - Version version = Version.min(this.version, channelVersion); - - stream.setVersion(version); - threadPool.getThreadContext().writeTo(stream); - if (version.onOrAfter(Version.V_6_3_0)) { - stream.writeStringArray(features); - } - stream.writeString(action); - BytesReference message = buildMessage(requestId, status, node.getVersion(), request, stream); - final TransportRequestOptions finalOptions = options; - // this might be called in a different thread - ReleaseListener releaseListener = new ReleaseListener(stream, - () -> messageListener.onRequestSent(node, requestId, action, request, finalOptions)); - internalSendMessage(channel, message, releaseListener); - addedReleaseListener = true; - } finally { - if (!addedReleaseListener) { - IOUtils.close(stream); - } - } + boolean compressRequest) throws IOException, TransportException { + sendRequestToChannel(node, channel, requestId, action, request, options, channelVersion, compressRequest, false); } - /** - * sends a message to the given channel, using the given callbacks. - */ - private void internalSendMessage(TcpChannel channel, BytesReference message, ActionListener listener) { - channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); - transportLogger.logOutboundMessage(channel, message); - try { - channel.sendMessage(message, new SendListener(channel, message.length(), listener)); - } catch (Exception ex) { - // call listener to ensure that any resources are released - listener.onFailure(ex); - onException(channel, ex); - } + private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, + final TransportRequest request, TransportRequestOptions options, Version channelVersion, + boolean compressRequest, boolean isHandshake) throws IOException, TransportException { + Version version = Version.min(this.version, channelVersion); + OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, + requestId, isHandshake, compressRequest); + ActionListener listener = ActionListener.wrap(() -> + messageListener.onRequestSent(node, requestId, action, request, options)); + outboundHandler.sendMessage(channel, message, listener); } /** @@ -769,23 +696,13 @@ public void sendErrorResponse( final Exception error, final long requestId, final String action) throws IOException { - try (BytesStreamOutput stream = new BytesStreamOutput()) { - stream.setVersion(nodeVersion); - stream.setFeatures(features); - RemoteTransportException tx = new RemoteTransportException( - nodeName, new TransportAddress(channel.getLocalAddress()), action, error); - threadPool.getThreadContext().writeTo(stream); - stream.writeException(tx); - byte status = 0; - status = TransportStatus.setResponse(status); - status = TransportStatus.setError(status); - final BytesReference bytes = stream.bytes(); - final BytesReference header = buildHeader(requestId, status, nodeVersion, bytes.length()); - CompositeBytesReference message = new CompositeBytesReference(header, bytes); - ReleaseListener releaseListener = new ReleaseListener(null, - () -> messageListener.onResponseSent(requestId, action, error)); - internalSendMessage(channel, message, releaseListener); - } + Version version = Version.min(this.version, nodeVersion); + TransportAddress address = new TransportAddress(channel.getLocalAddress()); + RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId, + false, false); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); + outboundHandler.sendMessage(channel, message, listener); } /** @@ -801,7 +718,7 @@ public void sendResponse( final long requestId, final String action, final boolean compress) throws IOException { - sendResponse(nodeVersion, features, channel, response, requestId, action, compress, (byte) 0); + sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false); } private void sendResponse( @@ -812,82 +729,18 @@ private void sendResponse( final long requestId, final String action, boolean compress, - byte status) throws IOException { - - status = TransportStatus.setResponse(status); - ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays); - CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bStream, compress); - boolean addedReleaseListener = false; - try { - if (compress) { - status = TransportStatus.setCompress(status); - } - threadPool.getThreadContext().writeTo(stream); - stream.setVersion(nodeVersion); - stream.setFeatures(features); - BytesReference message = buildMessage(requestId, status, nodeVersion, response, stream); - - // this might be called in a different thread - ReleaseListener releaseListener = new ReleaseListener(stream, - () -> messageListener.onResponseSent(requestId, action, response)); - internalSendMessage(channel, message, releaseListener); - addedReleaseListener = true; - } finally { - if (!addedReleaseListener) { - IOUtils.close(stream); - } - } - } - - /** - * Writes the Tcp message header into a bytes reference. - * - * @param requestId the request ID - * @param status the request status - * @param protocolVersion the protocol version used to serialize the data in the message - * @param length the payload length in bytes - * @see TcpHeader - */ - private BytesReference buildHeader(long requestId, byte status, Version protocolVersion, int length) throws IOException { - try (BytesStreamOutput headerOutput = new BytesStreamOutput(TcpHeader.HEADER_SIZE)) { - headerOutput.setVersion(protocolVersion); - TcpHeader.writeHeader(headerOutput, requestId, status, protocolVersion, length); - final BytesReference bytes = headerOutput.bytes(); - assert bytes.length() == TcpHeader.HEADER_SIZE : "header size mismatch expected: " + TcpHeader.HEADER_SIZE + " but was: " - + bytes.length(); - return bytes; - } - } - - /** - * Serializes the given message into a bytes representation - */ - private BytesReference buildMessage(long requestId, byte status, Version nodeVersion, TransportMessage message, - CompressibleBytesOutputStream stream) throws IOException { - final BytesReference zeroCopyBuffer; - if (message instanceof BytesTransportRequest) { // what a shitty optimization - we should use a direct send method instead - BytesTransportRequest bRequest = (BytesTransportRequest) message; - assert nodeVersion.equals(bRequest.version()); - bRequest.writeThin(stream); - zeroCopyBuffer = bRequest.bytes; - } else { - message.writeTo(stream); - zeroCopyBuffer = BytesArray.EMPTY; - } - // we have to call materializeBytes() here before accessing the bytes. A CompressibleBytesOutputStream - // might be implementing compression. And materializeBytes() ensures that some marker bytes (EOS marker) - // are written. Otherwise we barf on the decompressing end when we read past EOF on purpose in the - // #validateRequest method. this might be a problem in deflate after all but it's important to write - // the marker bytes. - final BytesReference messageBody = stream.materializeBytes(); - final BytesReference header = buildHeader(requestId, status, stream.getVersion(), messageBody.length() + zeroCopyBuffer.length()); - return new CompositeBytesReference(header, messageBody, zeroCopyBuffer); + boolean isHandshake) throws IOException { + Version version = Version.min(this.version, nodeVersion); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, + requestId, isHandshake, compress); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); + outboundHandler.sendMessage(channel, message, listener); } /** * Handles inbound message that has been decoded. * - * @param channel the channel the message if fomr + * @param channel the channel the message is from * @param message the message */ public void inboundMessage(TcpChannel channel, BytesReference message) { @@ -1055,53 +908,26 @@ public HttpOnTransportException(StreamInput in) throws IOException { * This method handles the message receive part for both request and responses */ public final void messageReceived(BytesReference reference, TcpChannel channel) throws IOException { - String profileName = channel.getProfile(); + readBytesMetric.inc(reference.length() + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE); InetSocketAddress remoteAddress = channel.getRemoteAddress(); - int messageLengthBytes = reference.length(); - final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; - readBytesMetric.inc(totalMessageSize); - // we have additional bytes to read, outside of the header - boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0; - StreamInput streamIn = reference.streamInput(); - boolean success = false; - try (ThreadContext.StoredContext tCtx = threadPool.getThreadContext().stashContext()) { - long requestId = streamIn.readLong(); - byte status = streamIn.readByte(); - Version version = Version.fromId(streamIn.readInt()); - if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamIn.available() > 0) { - Compressor compressor; - try { - final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE; - compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed)); - } catch (NotCompressedException ex) { - int maxToRead = Math.min(reference.length(), 10); - StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [").append(maxToRead) - .append("] content bytes out of [").append(reference.length()) - .append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are ["); - for (int i = 0; i < maxToRead; i++) { - sb.append(reference.get(i)).append(","); - } - sb.append("]"); - throw new IllegalStateException(sb.toString()); - } - streamIn = compressor.streamInput(streamIn); - } - final boolean isHandshake = TransportStatus.isHandshake(status); - ensureVersionCompatibility(version, this.version, isHandshake); - streamIn = new NamedWriteableAwareStreamInput(streamIn, namedWriteableRegistry); - streamIn.setVersion(version); - threadPool.getThreadContext().readHeaders(streamIn); - threadPool.getThreadContext().putTransient("_remote_address", remoteAddress); - if (TransportStatus.isRequest(status)) { - handleRequest(channel, profileName, streamIn, requestId, messageLengthBytes, version, remoteAddress, status); + + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext existing = threadContext.stashContext(); + InboundMessage message = reader.deserialize(reference)) { + // Place the context with the headers from the message + message.getStoredContext().restore(); + threadContext.putTransient("_remote_address", remoteAddress); + if (message.isRequest()) { + handleRequest(channel, (InboundMessage.RequestMessage) message, reference.length()); } else { final TransportResponseHandler handler; - if (isHandshake) { + long requestId = message.getRequestId(); + if (message.isHandshake()) { handler = handshaker.removeHandlerForHandshake(requestId); } else { TransportResponseHandler theHandler = responseHandlers.onResponseReceived(requestId, messageListener); - if (theHandler == null && TransportStatus.isError(status)) { + if (theHandler == null && message.isError()) { handler = handshaker.removeHandlerForHandshake(requestId); } else { handler = theHandler; @@ -1109,40 +935,20 @@ public final void messageReceived(BytesReference reference, TcpChannel channel) } // ignore if its null, the service logs it if (handler != null) { - if (TransportStatus.isError(status)) { - handlerResponseError(streamIn, handler); + if (message.isError()) { + handlerResponseError(message.getStreamInput(), handler); } else { - handleResponse(remoteAddress, streamIn, handler); + handleResponse(remoteAddress, message.getStreamInput(), handler); } // Check the entire message has been read - final int nextByte = streamIn.read(); + final int nextByte = message.getStreamInput().read(); // calling read() is useful to make sure the message is fully read, even if there is an EOS marker if (nextByte != -1) { throw new IllegalStateException("Message not fully read (response) for requestId [" + requestId + "], handler [" - + handler + "], error [" + TransportStatus.isError(status) + "]; resetting"); + + handler + "], error [" + message.isError() + "]; resetting"); } } } - success = true; - } finally { - if (success) { - IOUtils.close(streamIn); - } else { - IOUtils.closeWhileHandlingException(streamIn); - } - } - } - - static void ensureVersionCompatibility(Version version, Version currentVersion, boolean isHandshake) { - // for handshakes we are compatible with N-2 since otherwise we can't figure out our initial version - // since we are compatible with N-1 and N+1 so we always send our minCompatVersion as the initial version in the - // handshake. This looks odd but it's required to establish the connection correctly we check for real compatibility - // once the connection is established - final Version compatibilityVersion = isHandshake ? currentVersion.minimumCompatibilityVersion() : currentVersion; - if (version.isCompatible(compatibilityVersion) == false) { - final Version minCompatibilityVersion = isHandshake ? compatibilityVersion : compatibilityVersion.minimumCompatibilityVersion(); - String msg = "Received " + (isHandshake ? "handshake " : "") + "message from unsupported version: ["; - throw new IllegalStateException(msg + version + "] minimal compatible version is: [" + minCompatibilityVersion + "]"); } } @@ -1198,20 +1004,17 @@ private void handleException(final TransportResponseHandler handler, Throwable e }); } - protected String handleRequest(TcpChannel channel, String profileName, final StreamInput stream, long requestId, - int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status) - throws IOException { - final Set features; - if (version.onOrAfter(Version.V_6_3_0)) { - features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(stream.readStringArray()))); - } else { - features = Collections.emptySet(); - } - final String action = stream.readString(); + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage message, int messageLengthBytes) throws IOException { + final Set features = message.getFeatures(); + final String profileName = channel.getProfile(); + final String action = message.getActionName(); + final long requestId = message.getRequestId(); + final StreamInput stream = message.getStreamInput(); + final Version version = message.getVersion(); messageListener.onRequestReceived(requestId, action); TransportChannel transportChannel = null; try { - if (TransportStatus.isHandshake(status)) { + if (message.isHandshake()) { handshaker.handleHandshake(version, features, channel, requestId, stream); } else { final RequestHandlerRegistry reg = getRequestHandler(action); @@ -1224,9 +1027,9 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes); } transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, profileName, - messageLengthBytes, TransportStatus.isCompress(status)); + messageLengthBytes, message.isCompress()); final TransportRequest request = reg.newRequest(stream); - request.remoteAddress(new TransportAddress(remoteAddress)); + request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); // in case we throw an exception, i.e. when the limit is hit, we don't want to verify validateRequest(stream, requestId, action); threadPool.executor(reg.getExecutor()).execute(new RequestHandler(reg, request, transportChannel)); @@ -1235,7 +1038,7 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str // the circuit breaker tripped if (transportChannel == null) { transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, - profileName, 0, TransportStatus.isCompress(status)); + profileName, 0, message.isCompress()); } try { transportChannel.sendResponse(e); @@ -1244,7 +1047,6 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", action), inner); } } - return action; } // This template method is needed to inject custom error checking logic in tests. @@ -1321,70 +1123,11 @@ protected final void ensureOpen() { } } - /** - * This listener increments the transmitted bytes metric on success. - */ - private class SendListener extends NotifyOnceListener { - - private final TcpChannel channel; - private final long messageSize; - private final ActionListener delegateListener; - - private SendListener(TcpChannel channel, long messageSize, ActionListener delegateListener) { - this.channel = channel; - this.messageSize = messageSize; - this.delegateListener = delegateListener; - } - - @Override - protected void innerOnResponse(Void v) { - transmittedBytesMetric.inc(messageSize); - delegateListener.onResponse(v); - } - - @Override - protected void innerOnFailure(Exception e) { - logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e); - delegateListener.onFailure(e); - } - } - - private class ReleaseListener implements ActionListener { - - private final Closeable optionalCloseable; - private final Runnable transportAdaptorCallback; - - private ReleaseListener(Closeable optionalCloseable, Runnable transportAdaptorCallback) { - this.optionalCloseable = optionalCloseable; - this.transportAdaptorCallback = transportAdaptorCallback; - } - - @Override - public void onResponse(Void aVoid) { - closeAndCallback(null); - } - - @Override - public void onFailure(Exception e) { - closeAndCallback(e); - } - - private void closeAndCallback(final Exception e) { - try { - IOUtils.close(optionalCloseable, transportAdaptorCallback::run); - } catch (final IOException inner) { - if (e != null) { - inner.addSuppressed(e); - } - throw new UncheckedIOException(inner); - } - } - } - @Override public final TransportStats getStats() { - return new TransportStats(acceptedChannels.size(), readBytesMetric.count(), readBytesMetric.sum(), transmittedBytesMetric.count(), - transmittedBytesMetric.sum()); + MeanMetric transmittedBytes = outboundHandler.getTransmittedBytes(); + return new TransportStats(acceptedChannels.size(), readBytesMetric.count(), readBytesMetric.sum(), transmittedBytes.count(), + transmittedBytes.sum()); } /** @@ -1559,7 +1302,7 @@ public void onFailure(Exception ex) { public void onTimeout() { if (countDown.fastForward()) { CloseableChannel.closeChannels(channels, false); - listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout() + "]")); + listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout() + "]")); } } } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportLogger.java b/server/src/main/java/org/elasticsearch/transport/TransportLogger.java index ea01cc4ddbfa6..0198609f4c3b6 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportLogger.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportLogger.java @@ -51,6 +51,10 @@ void logInboundMessage(TcpChannel channel, BytesReference message) { void logOutboundMessage(TcpChannel channel, BytesReference message) { if (logger.isTraceEnabled()) { try { + if (message.get(0) != 'E') { + // This is not an Elasticsearch transport message. + return; + } BytesReference withoutHeader = message.slice(HEADER_SIZE, message.length() - HEADER_SIZE); String logMessage = format(channel, withoutHeader, "WRITE"); logger.trace(logMessage); diff --git a/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java b/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java index 9051f302a591f..e237415a9c60e 100644 --- a/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.discovery.zen.UnicastHostsProvider; import org.elasticsearch.discovery.zen.ZenDiscovery; @@ -53,6 +54,7 @@ import java.util.function.Supplier; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class DiscoveryModuleTests extends ESTestCase { @@ -87,11 +89,12 @@ default Map> getDiscoveryTypes(ThreadPool threadPool @Before public void setupDummyServices() { - transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, null, null); + threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null); masterService = mock(MasterService.class); namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); clusterApplier = mock(ClusterApplier.class); - threadPool = mock(ThreadPool.class); clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); gatewayMetaState = mock(GatewayMetaState.class); } diff --git a/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java b/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java index cc9295cee2e3f..b04a1cebbf3e3 100644 --- a/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java @@ -62,6 +62,7 @@ public void setUp() throws Exception { super.setUp(); threadPool = new TestThreadPool(FileBasedUnicastHostsProviderTests.class.getName()); executorService = Executors.newSingleThreadExecutor(); + createTransportSvc(); } @After @@ -77,8 +78,7 @@ public void tearDown() throws Exception { } } - @Before - public void createTransportSvc() { + private void createTransportSvc() { final MockNioTransport transport = new MockNioTransport(Settings.EMPTY, Version.CURRENT, threadPool, new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, diff --git a/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java b/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java index 58dc2a1e55b47..aeb92dac73479 100644 --- a/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java +++ b/server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java @@ -39,6 +39,8 @@ public void testStreamWithoutCompression() throws IOException { stream.write(expectedBytes); BytesReference bytesRef = stream.materializeBytes(); + // Closing compression stream does not close underlying stream + stream.close(); assertFalse(CompressorFactory.COMPRESSOR.isCompressed(bytesRef)); @@ -48,7 +50,8 @@ public void testStreamWithoutCompression() throws IOException { assertEquals(-1, streamInput.read()); assertArrayEquals(expectedBytes, actualBytes); - stream.close(); + + bStream.close(); // The bytes should be zeroed out on close for (byte b : bytesRef.toBytesRef().bytes) { @@ -64,6 +67,7 @@ public void testStreamWithCompression() throws IOException { stream.write(expectedBytes); BytesReference bytesRef = stream.materializeBytes(); + stream.close(); assertTrue(CompressorFactory.COMPRESSOR.isCompressed(bytesRef)); @@ -73,7 +77,8 @@ public void testStreamWithCompression() throws IOException { assertEquals(-1, streamInput.read()); assertArrayEquals(expectedBytes, actualBytes); - stream.close(); + + bStream.close(); // The bytes should be zeroed out on close for (byte b : bytesRef.toBytesRef().bytes) { diff --git a/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java new file mode 100644 index 0000000000000..499b6586543ed --- /dev/null +++ b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java @@ -0,0 +1,228 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.VersionUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +public class InboundMessageTests extends ESTestCase { + + private final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + private final NamedWriteableRegistry registry = new NamedWriteableRegistry(Collections.emptyList()); + + public void testReadRequest() throws IOException { + String[] features = {"feature1", "feature2"}; + String value = randomAlphaOfLength(10); + Message message = new Message(value); + String action = randomAlphaOfLength(10); + long requestId = randomLong(); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + threadContext.putHeader("header", "header_value"); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + OutboundMessage.Request request = new OutboundMessage.Request(threadContext, features, message, version, action, requestId, + isHandshake, compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + // Check that the thread context is not deleted. + assertEquals("header_value", threadContext.getHeader("header")); + + threadContext.stashContext(); + threadContext.putHeader("header", "header_value2"); + + InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.RequestMessage inboundMessage = (InboundMessage.RequestMessage) reader.deserialize(sliced); + // Check that deserialize does not overwrite current thread context. + assertEquals("header_value2", threadContext.getHeader("header")); + inboundMessage.getStoredContext().restore(); + assertEquals("header_value", threadContext.getHeader("header")); + assertEquals(isHandshake, inboundMessage.isHandshake()); + assertEquals(compress, inboundMessage.isCompress()); + assertEquals(version, inboundMessage.getVersion()); + assertEquals(action, inboundMessage.getActionName()); + assertEquals(new HashSet<>(Arrays.asList(features)), inboundMessage.getFeatures()); + assertTrue(inboundMessage.isRequest()); + assertFalse(inboundMessage.isResponse()); + assertFalse(inboundMessage.isError()); + assertEquals(value, new Message(inboundMessage.getStreamInput()).value); + } + + public void testReadResponse() throws IOException { + HashSet features = new HashSet<>(Arrays.asList("feature1", "feature2")); + String value = randomAlphaOfLength(10); + Message message = new Message(value); + long requestId = randomLong(); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + threadContext.putHeader("header", "header_value"); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + OutboundMessage.Response request = new OutboundMessage.Response(threadContext, features, message, version, requestId, isHandshake, + compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + // Check that the thread context is not deleted. + assertEquals("header_value", threadContext.getHeader("header")); + + threadContext.stashContext(); + threadContext.putHeader("header", "header_value2"); + + InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + // Check that deserialize does not overwrite current thread context. + assertEquals("header_value2", threadContext.getHeader("header")); + inboundMessage.getStoredContext().restore(); + assertEquals("header_value", threadContext.getHeader("header")); + assertEquals(isHandshake, inboundMessage.isHandshake()); + assertEquals(compress, inboundMessage.isCompress()); + assertEquals(version, inboundMessage.getVersion()); + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isRequest()); + assertFalse(inboundMessage.isError()); + assertEquals(value, new Message(inboundMessage.getStreamInput()).value); + } + + public void testReadErrorResponse() throws IOException { + HashSet features = new HashSet<>(Arrays.asList("feature1", "feature2")); + RemoteTransportException exception = new RemoteTransportException("error", new IOException()); + long requestId = randomLong(); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + threadContext.putHeader("header", "header_value"); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + OutboundMessage.Response request = new OutboundMessage.Response(threadContext, features, exception, version, requestId, + isHandshake, compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + // Check that the thread context is not deleted. + assertEquals("header_value", threadContext.getHeader("header")); + + threadContext.stashContext(); + threadContext.putHeader("header", "header_value2"); + + InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + // Check that deserialize does not overwrite current thread context. + assertEquals("header_value2", threadContext.getHeader("header")); + inboundMessage.getStoredContext().restore(); + assertEquals("header_value", threadContext.getHeader("header")); + assertEquals(isHandshake, inboundMessage.isHandshake()); + assertEquals(compress, inboundMessage.isCompress()); + assertEquals(version, inboundMessage.getVersion()); + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isRequest()); + assertTrue(inboundMessage.isError()); + assertEquals("[error]", inboundMessage.getStreamInput().readException().getMessage()); + } + + public void testEnsureVersionCompatibility() throws IOException { + testVersionIncompatibility(VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), + Version.CURRENT), Version.CURRENT, randomBoolean()); + + final Version version = Version.fromString("7.0.0"); + testVersionIncompatibility(Version.fromString("6.0.0"), version, true); + IllegalStateException ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("6.0.0"), version, false)); + assertEquals("Received message from unsupported version: [6.0.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + + // For handshake we are compatible with N-2 + testVersionIncompatibility(Version.fromString("5.6.0"), version, true); + ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("5.6.0"), version, false)); + assertEquals("Received message from unsupported version: [5.6.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + + ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("2.3.0"), version, true)); + assertEquals("Received handshake message from unsupported version: [2.3.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + + ise = expectThrows(IllegalStateException.class, () -> + testVersionIncompatibility(Version.fromString("2.3.0"), version, false)); + assertEquals("Received message from unsupported version: [2.3.0] minimal compatible version is: [" + + version.minimumCompatibilityVersion() + "]", ise.getMessage()); + } + + private void testVersionIncompatibility(Version version, Version currentVersion, boolean isHandshake) throws IOException { + String[] features = {}; + String value = randomAlphaOfLength(10); + Message message = new Message(value); + String action = randomAlphaOfLength(10); + long requestId = randomLong(); + boolean compress = randomBoolean(); + OutboundMessage.Request request = new OutboundMessage.Request(threadContext, features, message, version, action, requestId, + isHandshake, compress); + BytesReference reference; + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + reference = request.serialize(streamOutput); + } + + BytesReference sliced = reference.slice(6, reference.length() - 6); + InboundMessage.Reader reader = new InboundMessage.Reader(currentVersion, registry, threadContext); + reader.deserialize(sliced); + } + + private static final class Message extends TransportMessage { + + public String value; + + private Message() { + } + + private Message(StreamInput in) throws IOException { + readFrom(in); + } + + private Message(String value) { + this.value = value; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java new file mode 100644 index 0000000000000..01e391a30a732 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -0,0 +1,184 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +public class OutboundHandlerTests extends ESTestCase { + + private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); + private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + private OutboundHandler handler; + private FakeTcpChannel fakeTcpChannel; + + @Before + public void setUp() throws Exception { + super.setUp(); + TransportLogger transportLogger = new TransportLogger(); + fakeTcpChannel = new FakeTcpChannel(randomBoolean()); + handler = new OutboundHandler(threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger); + } + + @After + public void tearDown() throws Exception { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + super.tearDown(); + } + + public void testSendRawBytes() { + BytesArray bytesArray = new BytesArray("message".getBytes(StandardCharsets.UTF_8)); + + AtomicBoolean isSuccess = new AtomicBoolean(false); + AtomicReference exception = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); + handler.sendBytes(fakeTcpChannel, bytesArray, listener); + + BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); + ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + assertTrue(isSuccess.get()); + assertNull(exception.get()); + } else { + IOException e = new IOException("failed"); + sendListener.onFailure(e); + assertFalse(isSuccess.get()); + assertSame(e, exception.get()); + } + + assertEquals(bytesArray, reference); + } + + public void testSendMessage() throws IOException { + OutboundMessage message; + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = Version.CURRENT; + String actionName = "handshake"; + long requestId = randomLongBetween(0, 300); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + String value = "message"; + threadContext.putHeader("header", "header_value"); + Writeable writeable = new Message(value); + + boolean isRequest = randomBoolean(); + if (isRequest) { + message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake, + compress); + } else { + message = new OutboundMessage.Response(threadContext, new HashSet<>(), writeable, version, requestId, isHandshake, compress); + } + + AtomicBoolean isSuccess = new AtomicBoolean(false); + AtomicReference exception = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); + handler.sendMessage(fakeTcpChannel, message, listener); + + BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); + ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + assertTrue(isSuccess.get()); + assertNull(exception.get()); + } else { + IOException e = new IOException("failed"); + sendListener.onFailure(e); + assertFalse(isSuccess.get()); + assertSame(e, exception.get()); + } + + InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); + try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { + assertEquals(version, inboundMessage.getVersion()); + assertEquals(requestId, inboundMessage.getRequestId()); + if (isRequest) { + assertTrue(inboundMessage.isRequest()); + assertFalse(inboundMessage.isResponse()); + } else { + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isRequest()); + } + if (isHandshake) { + assertTrue(inboundMessage.isHandshake()); + } else { + assertFalse(inboundMessage.isHandshake()); + } + if (compress) { + assertTrue(inboundMessage.isCompress()); + } else { + assertFalse(inboundMessage.isCompress()); + } + Message readMessage = new Message(); + readMessage.readFrom(inboundMessage.getStreamInput()); + assertEquals(value, readMessage.value); + + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext(); + assertNull(threadContext.getHeader("header")); + storedContext.restore(); + assertEquals("header_value", threadContext.getHeader("header")); + } + } + } + + private static final class Message extends TransportMessage { + + public String value; + + private Message() { + } + + private Message(String value) { + this.value = value; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java index f3294366b8fe1..a25ac2a551a22 100644 --- a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java @@ -37,7 +37,6 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.VersionUtils; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -158,41 +157,12 @@ public void testAddressLimit() throws Exception { assertEquals(102, addresses[2].getPort()); } - public void testEnsureVersionCompatibility() { - TcpTransport.ensureVersionCompatibility(VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), - Version.CURRENT), Version.CURRENT, randomBoolean()); - - final Version version = Version.fromString("7.0.0"); - TcpTransport.ensureVersionCompatibility(Version.fromString("6.0.0"), version, true); - IllegalStateException ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("6.0.0"), version, false)); - assertEquals("Received message from unsupported version: [6.0.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - - // For handshake we are compatible with N-2 - TcpTransport.ensureVersionCompatibility(Version.fromString("5.6.0"), version, true); - ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("5.6.0"), version, false)); - assertEquals("Received message from unsupported version: [5.6.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - - ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("2.3.0"), version, true)); - assertEquals("Received handshake message from unsupported version: [2.3.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - - ise = expectThrows(IllegalStateException.class, () -> - TcpTransport.ensureVersionCompatibility(Version.fromString("2.3.0"), version, false)); - assertEquals("Received message from unsupported version: [2.3.0] minimal compatible version is: [" - + version.minimumCompatibilityVersion() + "]", ise.getMessage()); - } - @SuppressForbidden(reason = "Allow accessing localhost") public void testCompressRequestAndResponse() throws IOException { final boolean compressed = randomBoolean(); Req request = new Req(randomRealisticUnicodeOfLengthBetween(10, 100)); ThreadPool threadPool = new TestThreadPool(TcpTransportTests.class.getName()); - AtomicReference requestCaptor = new AtomicReference<>(); + AtomicReference messageCaptor = new AtomicReference<>(); try { TcpTransport transport = new TcpTransport("test", Settings.EMPTY, Version.CURRENT, threadPool, PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), null, null) { @@ -204,7 +174,7 @@ protected FakeServerChannel bind(String name, InetSocketAddress address) throws @Override protected FakeTcpChannel initiateChannel(DiscoveryNode node) throws IOException { - return new FakeTcpChannel(false, requestCaptor); + return new FakeTcpChannel(false); } @Override @@ -219,7 +189,7 @@ public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, int numConnections = profile.getNumConnections(); ArrayList fakeChannels = new ArrayList<>(numConnections); for (int i = 0; i < numConnections; ++i) { - fakeChannels.add(new FakeTcpChannel(false, requestCaptor)); + fakeChannels.add(new FakeTcpChannel(false, messageCaptor)); } listener.onResponse(new NodeChannels(node, fakeChannels, profile, Version.CURRENT)); return () -> CloseableChannel.closeChannels(fakeChannels, false); @@ -241,12 +211,12 @@ public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, (request1, channel, task) -> channel.sendResponse(TransportResponse.Empty.INSTANCE), ThreadPool.Names.SAME, true, true)); - BytesReference reference = requestCaptor.get(); + BytesReference reference = messageCaptor.get(); assertNotNull(reference); AtomicReference responseCaptor = new AtomicReference<>(); InetSocketAddress address = new InetSocketAddress(InetAddress.getLocalHost(), 0); - FakeTcpChannel responseChannel = new FakeTcpChannel(true, address, address, responseCaptor); + FakeTcpChannel responseChannel = new FakeTcpChannel(true, address, address, "profile", responseCaptor); transport.messageReceived(reference.slice(6, reference.length() - 6), responseChannel); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 43542d48a6c39..87343d4c82087 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -2038,11 +2038,14 @@ public void testTcpHandshake() { new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - protected String handleRequest(TcpChannel mockChannel, String profileName, StreamInput stream, long requestId, - int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status) + protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { - return super.handleRequest(mockChannel, profileName, stream, requestId, messageLengthBytes, version, remoteAddress, - (byte) (status & ~(1 << 3))); // we flip the isHandshake bit back and act like the handler is not found + // we flip the isHandshake bit back and act like the handler is not found + byte status = (byte) (request.status & ~(1 << 3)); + Version version = request.getVersion(); + InboundMessage.RequestMessage nonHandshakeRequest = new InboundMessage.RequestMessage(request.threadContext, version, + status, request.getRequestId(), request.getActionName(), request.getFeatures(), request.getStreamInput()); + super.handleRequest(channel, nonHandshakeRequest, messageLengthBytes); } }; diff --git a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java index 63cacfbb093a8..bb392554305c0 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java @@ -31,9 +31,10 @@ public class FakeTcpChannel implements TcpChannel { private final InetSocketAddress localAddress; private final InetSocketAddress remoteAddress; private final String profile; - private final AtomicReference messageCaptor; private final ChannelStats stats = new ChannelStats(); private final CompletableContext closeContext = new CompletableContext<>(); + private final AtomicReference messageCaptor; + private final AtomicReference> listenerCaptor; public FakeTcpChannel() { this(false, "profile", new AtomicReference<>()); @@ -47,11 +48,6 @@ public FakeTcpChannel(boolean isServer, AtomicReference messageC this(isServer, "profile", messageCaptor); } - public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSocketAddress remoteAddress, - AtomicReference messageCaptor) { - this(isServer, localAddress, remoteAddress,"profile", messageCaptor); - } - public FakeTcpChannel(boolean isServer, String profile, AtomicReference messageCaptor) { this(isServer, null, null, profile, messageCaptor); @@ -64,6 +60,7 @@ public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSock this.remoteAddress = remoteAddress; this.profile = profile; this.messageCaptor = messageCaptor; + this.listenerCaptor = new AtomicReference<>(); } @Override @@ -89,6 +86,7 @@ public InetSocketAddress getRemoteAddress() { @Override public void sendMessage(BytesReference reference, ActionListener listener) { messageCaptor.set(reference); + listenerCaptor.set(listener); } @Override @@ -115,4 +113,12 @@ public boolean isOpen() { public ChannelStats getChannelStats() { return stats; } + + public AtomicReference getMessageCaptor() { + return messageCaptor; + } + + public AtomicReference> getListenerCaptor() { + return listenerCaptor; + } }