diff --git a/server/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java b/server/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java index 109efb400bc93..aa0672d80ba1d 100644 --- a/server/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java +++ b/server/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java @@ -21,7 +21,7 @@ import com.carrotsearch.hppc.cursors.ObjectCursor; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.logging.log4j.util.Supplier; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; @@ -56,6 +56,7 @@ import org.elasticsearch.transport.TransportService; import java.io.Closeable; +import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -89,6 +90,7 @@ final class TransportClientNodesService extends AbstractComponent implements Clo private final Object mutex = new Object(); private volatile List nodes = Collections.emptyList(); + // Filtered nodes are nodes whose cluster name does not match the configured cluster name private volatile List filteredNodes = Collections.emptyList(); private final AtomicInteger tempNodeIdGenerator = new AtomicInteger(); @@ -268,7 +270,7 @@ public static class RetryListener implements ActionListener private volatile int i; RetryListener(NodeListenerCallback callback, ActionListener listener, - List nodes, int index, TransportClient.HostFailureListener hostFailureListener) { + List nodes, int index, TransportClient.HostFailureListener hostFailureListener) { this.callback = callback; this.listener = listener; this.nodes = nodes; @@ -361,10 +363,10 @@ public void sample() { protected abstract void doSample(); /** - * validates a set of potentially newly discovered nodes and returns an immutable - * list of the nodes that has passed. + * Establishes the node connections. If validateInHandshake is set to true, the connection will fail if + * node returned in the handshake response is different than the discovery node. */ - protected List validateNewNodes(Set nodes) { + List establishNodeConnections(Set nodes) { for (Iterator it = nodes.iterator(); it.hasNext(); ) { DiscoveryNode node = it.next(); if (!transportService.nodeConnected(node)) { @@ -380,7 +382,6 @@ protected List validateNewNodes(Set nodes) { return Collections.unmodifiableList(new ArrayList<>(nodes)); } - } class ScheduledNodeSampler implements Runnable { @@ -402,14 +403,16 @@ class SimpleNodeSampler extends NodeSampler { @Override protected void doSample() { HashSet newNodes = new HashSet<>(); - HashSet newFilteredNodes = new HashSet<>(); + ArrayList newFilteredNodes = new ArrayList<>(); for (DiscoveryNode listedNode : listedNodes) { try (Transport.Connection connection = transportService.openConnection(listedNode, LISTED_NODES_PROFILE)){ final PlainTransportFuture handler = new PlainTransportFuture<>( new FutureTransportResponseHandler() { @Override - public LivenessResponse newInstance() { - return new LivenessResponse(); + public LivenessResponse read(StreamInput in) throws IOException { + LivenessResponse response = new LivenessResponse(); + response.readFrom(in); + return response; } }); transportService.sendRequest(connection, TransportLivenessAction.NAME, new LivenessRequest(), @@ -435,8 +438,8 @@ public LivenessResponse newInstance() { } } - nodes = validateNewNodes(newNodes); - filteredNodes = Collections.unmodifiableList(new ArrayList<>(newFilteredNodes)); + nodes = establishNodeConnections(newNodes); + filteredNodes = Collections.unmodifiableList(newFilteredNodes); } } @@ -557,7 +560,7 @@ public void handleException(TransportException e) { } } - nodes = validateNewNodes(newNodes); + nodes = establishNodeConnections(newNodes); filteredNodes = Collections.unmodifiableList(new ArrayList<>(newFilteredNodes)); } } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index 5d3ac517c7a31..7798129f6a883 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -21,6 +21,8 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.client.Client; +import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.Version; import org.elasticsearch.action.admin.cluster.node.liveness.TransportLivenessAction; @@ -124,6 +126,8 @@ protected boolean removeEldestEntry(Map.Entry eldest) { private final RemoteClusterService remoteClusterService; + private final boolean validateConnections; + /** if set will call requests sent to this id to shortcut and executed locally */ volatile DiscoveryNode localNode = null; private final Transport.Connection localNodeConnection = new Transport.Connection() { @@ -153,6 +157,9 @@ public TransportService(Settings settings, Transport transport, ThreadPool threa Function localNodeFactory, @Nullable ClusterSettings clusterSettings, Set taskHeaders) { super(settings); + // The only time we do not want to validate node connections is when this is a transport client using the simple node sampler + this.validateConnections = TransportClient.CLIENT_TYPE.equals(settings.get(Client.CLIENT_TYPE_SETTING_S.getKey())) == false || + TransportClient.CLIENT_TRANSPORT_SNIFF.get(settings); this.transport = transport; this.threadPool = threadPool; this.localNodeFactory = localNodeFactory; @@ -314,6 +321,11 @@ public boolean nodeConnected(DiscoveryNode node) { return isLocalNode(node) || transport.nodeConnected(node); } + /** + * Connect to the specified node with the default connection profile + * + * @param node the node to connect to + */ public void connectToNode(DiscoveryNode node) throws ConnectTransportException { connectToNode(node, null); } @@ -331,7 +343,7 @@ public void connectToNode(final DiscoveryNode node, ConnectionProfile connection transport.connectToNode(node, connectionProfile, (newConnection, actualProfile) -> { // We don't validate cluster names to allow for tribe node connections. final DiscoveryNode remote = handshake(newConnection, actualProfile.getHandshakeTimeout().millis(), cn -> true); - if (node.equals(remote) == false) { + if (validateConnections && node.equals(remote) == false) { throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote); } }); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java index 08d88ad2e0486..ed1dfded0782f 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java @@ -20,6 +20,8 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.client.Client; +import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -178,6 +180,42 @@ public void testNodeConnectWithDifferentNodeId() { assertFalse(handleA.transportService.nodeConnected(discoveryNode)); } + public void testNodeConnectWithDifferentNodeIdSucceedsIfThisIsTransportClientOfSimpleNodeSampler() { + Settings.Builder settings = Settings.builder().put("cluster.name", "test"); + Settings transportClientSettings = settings.put(Client.CLIENT_TYPE_SETTING_S.getKey(), TransportClient.CLIENT_TYPE).build(); + NetworkHandle handleA = startServices("TS_A", transportClientSettings, Version.CURRENT); + NetworkHandle handleB = startServices("TS_B", settings.build(), Version.CURRENT); + DiscoveryNode discoveryNode = new DiscoveryNode( + randomAlphaOfLength(10), + handleB.discoveryNode.getAddress(), + emptyMap(), + emptySet(), + handleB.discoveryNode.getVersion()); + + handleA.transportService.connectToNode(discoveryNode, MockTcpTransport.LIGHT_PROFILE); + assertTrue(handleA.transportService.nodeConnected(discoveryNode)); + } + + public void testNodeConnectWithDifferentNodeIdFailsWhenSnifferTransportClient() { + Settings.Builder settings = Settings.builder().put("cluster.name", "test"); + Settings transportClientSettings = settings.put(Client.CLIENT_TYPE_SETTING_S.getKey(), TransportClient.CLIENT_TYPE) + .put(TransportClient.CLIENT_TRANSPORT_SNIFF.getKey(), true) + .build(); + NetworkHandle handleA = startServices("TS_A", transportClientSettings, Version.CURRENT); + NetworkHandle handleB = startServices("TS_B", settings.build(), Version.CURRENT); + DiscoveryNode discoveryNode = new DiscoveryNode( + randomAlphaOfLength(10), + handleB.discoveryNode.getAddress(), + emptyMap(), + emptySet(), + handleB.discoveryNode.getVersion()); + ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> { + handleA.transportService.connectToNode(discoveryNode, MockTcpTransport.LIGHT_PROFILE); + }); + assertThat(ex.getMessage(), containsString("unexpected remote node")); + assertFalse(handleA.transportService.nodeConnected(discoveryNode)); + } + private static class NetworkHandle { private TransportService transportService;