diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java index 2c9bcf139319c..861f7fa1ab695 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java @@ -69,6 +69,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.search.fetch.FetchSubPhaseProcessor; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESIntegTestCase; import java.io.IOException; @@ -133,9 +134,10 @@ public void testLocalClusterAlias() { indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); IndexResponse indexResponse = client().index(indexRequest).actionGet(); assertEquals(RestStatus.CREATED, indexResponse.status()); + TaskId parentTaskId = new TaskId("node", randomNonNegativeLong()); { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, + SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY, "local", nowInMillis, randomBoolean()); SearchResponse searchResponse = client().search(searchRequest).actionGet(); assertEquals(1, searchResponse.getHits().getTotalHits().value); @@ -147,7 +149,7 @@ public void testLocalClusterAlias() { assertEquals("1", hit.getId()); } { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, + SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY, "", nowInMillis, randomBoolean()); SearchResponse searchResponse = client().search(searchRequest).actionGet(); assertEquals(1, searchResponse.getHits().getTotalHits().value); @@ -161,6 +163,7 @@ public void testLocalClusterAlias() { } public void testAbsoluteStartMillis() { + TaskId parentTaskId = new TaskId("node", randomNonNegativeLong()); { IndexRequest indexRequest = new IndexRequest("test-1970.01.01"); indexRequest.id("1"); @@ -189,13 +192,13 @@ public void testAbsoluteStartMillis() { assertEquals(0, searchResponse.getTotalShards()); } { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY, "", 0, randomBoolean()); SearchResponse searchResponse = client().search(searchRequest).actionGet(); assertEquals(2, searchResponse.getHits().getTotalHits().value); } { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY, "", 0, randomBoolean()); searchRequest.indices(""); SearchResponse searchResponse = client().search(searchRequest).actionGet(); @@ -203,7 +206,7 @@ public void testAbsoluteStartMillis() { assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); } { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY, "", 0, randomBoolean()); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date"); @@ -219,6 +222,7 @@ public void testAbsoluteStartMillis() { public void testFinalReduce() { long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); + TaskId taskId = new TaskId("node", randomNonNegativeLong()); { IndexRequest indexRequest = new IndexRequest("test"); indexRequest.id("1"); @@ -245,7 +249,7 @@ public void testFinalReduce() { source.aggregation(terms); { - SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest, + SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(taskId, originalRequest, Strings.EMPTY_ARRAY, "remote", nowInMillis, true); SearchResponse searchResponse = client().search(searchRequest).actionGet(); assertEquals(2, searchResponse.getHits().getTotalHits().value); @@ -254,7 +258,7 @@ public void testFinalReduce() { assertEquals(1, longTerms.getBuckets().size()); } { - SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest, + SearchRequest searchRequest = SearchRequest.subSearchRequest(taskId, originalRequest, Strings.EMPTY_ARRAY, "remote", nowInMillis, false); SearchResponse searchResponse = client().search(searchRequest).actionGet(); assertEquals(2, searchResponse.getHits().getTotalHits().value); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index fb47511ee56a4..b7efac874447e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -19,12 +19,18 @@ package org.elasticsearch.search.ccs; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; +import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.query.MatchAllQueryBuilder; @@ -33,15 +39,21 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.test.InternalTestCluster; import org.elasticsearch.transport.TransportService; +import org.hamcrest.Matchers; import org.junit.Before; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -101,6 +113,70 @@ public void testProxyConnectionDisconnect() throws Exception { } } + public void testCancel() throws Exception { + assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo")); + indexDocs(client(LOCAL_CLUSTER), "demo"); + final InternalTestCluster remoteCluster = cluster("cluster_a"); + remoteCluster.ensureAtLeastNumDataNodes(1); + final Settings.Builder allocationFilter = Settings.builder(); + if (randomBoolean()) { + remoteCluster.ensureAtLeastNumDataNodes(3); + List remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false) + .filter(DiscoveryNode::isDataNode) + .map(DiscoveryNode::getName) + .collect(Collectors.toList()); + assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(3)); + List seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes); + disconnectFromRemoteClusters(); + configureRemoteCluster("cluster_a", seedNodes); + if (randomBoolean()) { + // Using proxy connections + allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes)); + } else { + allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes)); + } + } + assertAcked(client("cluster_a").admin().indices().prepareCreate("prod") + .setSettings(Settings.builder().put(allocationFilter.build()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0))); + assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod") + .setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut()); + indexDocs(client("cluster_a"), "prod"); + SearchListenerPlugin.blockQueryPhase(); + PlainActionFuture queryFuture = new PlainActionFuture<>(); + SearchRequest searchRequest = new SearchRequest("demo", "cluster_a:prod"); + searchRequest.allowPartialSearchResults(false); + searchRequest.setCcsMinimizeRoundtrips(false); + searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1000)); + client(LOCAL_CLUSTER).search(searchRequest, queryFuture); + SearchListenerPlugin.waitSearchStarted(); + // Get the search task and cancelled + final TaskInfo rootTask = client().admin().cluster().prepareListTasks() + .setActions(SearchAction.INSTANCE.name()) + .get().getTasks().stream().filter(t -> t.getParentTaskId().isSet() == false) + .findFirst().get(); + final CancelTasksRequest cancelRequest = new CancelTasksRequest().setTaskId(rootTask.getTaskId()); + cancelRequest.setWaitForCompletion(randomBoolean()); + final ActionFuture cancelFuture = client().admin().cluster().cancelTasks(cancelRequest); + assertBusy(() -> { + final Iterable transportServices = cluster("cluster_a").getInstances(TransportService.class); + for (TransportService transportService : transportServices) { + Collection cancellableTasks = transportService.getTaskManager().getCancellableTasks().values(); + for (CancellableTask cancellableTask : cancellableTasks) { + assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled()); + } + } + }); + SearchListenerPlugin.allowQueryPhase(); + assertBusy(() -> assertTrue(queryFuture.isDone())); + assertBusy(() -> assertTrue(cancelFuture.isDone())); + assertBusy(() -> { + final Iterable transportServices = cluster("cluster_a").getInstances(TransportService.class); + for (TransportService transportService : transportServices) { + assertThat(transportService.getTaskManager().getBannedTaskIds(), Matchers.empty()); + } + }); + } + @Override protected Collection> nodePlugins(String clusterAlias) { if (clusterAlias.equals(LOCAL_CLUSTER)) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index 4623893c8153e..5da95412bf9ad 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -142,21 +142,25 @@ public SearchRequest(String[] indices, SearchSourceBuilder source) { * Used when a {@link SearchRequest} is created and executed as part of a cross-cluster search request * performing reduction on each cluster in order to minimize network round-trips between the coordinating node and the remote clusters. * + * @param parentTaskId the parent taskId of the original search request * @param originalSearchRequest the original search request * @param indices the indices to search against * @param clusterAlias the alias to prefix index names with in the returned search results * @param absoluteStartMillis the absolute start time to be used on the remote clusters to ensure that the same value is used * @param finalReduce whether the reduction should be final or not */ - static SearchRequest subSearchRequest(SearchRequest originalSearchRequest, String[] indices, + static SearchRequest subSearchRequest(TaskId parentTaskId, SearchRequest originalSearchRequest, String[] indices, String clusterAlias, long absoluteStartMillis, boolean finalReduce) { + Objects.requireNonNull(parentTaskId, "parentTaskId must be specified"); Objects.requireNonNull(originalSearchRequest, "search request must not be null"); validateIndices(indices); Objects.requireNonNull(clusterAlias, "cluster alias must not be null"); if (absoluteStartMillis < 0) { throw new IllegalArgumentException("absoluteStartMillis must not be negative but was [" + absoluteStartMillis + "]"); } - return new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce); + final SearchRequest request = new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce); + request.setParentTask(parentTaskId); + return request; } private SearchRequest(SearchRequest searchRequest, String[] indices, String localClusterAlias, long absoluteStartMillis, @@ -320,7 +324,7 @@ boolean isFinalReduce() { /** * Returns the current time in milliseconds from the time epoch, to be used for the execution of this search request. Used to * ensure that the same value, determined by the coordinating node, is used on all nodes involved in the execution of the search - * request. When created through {@link #subSearchRequest(SearchRequest, String[], String, long, boolean)}, this method returns + * request. When created through {@link #subSearchRequest(TaskId, SearchRequest, String[], String, long, boolean)}, this method returns * the provided current time, otherwise it will return {@link System#currentTimeMillis()}. */ long getOrCreateAbsoluteStartMillis() { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index d818bd5096c53..9eaedf07865d8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -66,6 +66,7 @@ import org.elasticsearch.search.profile.ProfileShardResult; import org.elasticsearch.search.profile.SearchProfileShardResults; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.transport.RemoteClusterService; @@ -296,7 +297,8 @@ private void executeRequest(Task task, SearchRequest searchRequest, task, timeProvider, searchRequest, localIndices, clusterState, listener, searchContext, searchAsyncActionProvider); } else { if (shouldMinimizeRoundtrips(searchRequest)) { - ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider, + final TaskId parentTaskId = task.taskInfo(clusterService.localNode().getId(), false).getTaskId(); + ccsRemoteReduce(parentTaskId, searchRequest, localIndices, remoteClusterIndices, timeProvider, searchService.aggReduceContextBuilder(searchRequest), remoteClusterService, threadPool, listener, (r, l) -> executeLocalSearch( @@ -358,8 +360,9 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) { source.collapse().getInnerHits().isEmpty(); } - static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map remoteIndices, - SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, + static void ccsRemoteReduce(TaskId parentTaskId, SearchRequest searchRequest, OriginalIndices localIndices, + Map remoteIndices, SearchTimeProvider timeProvider, + InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener listener, BiConsumer> localSearchConsumer) { @@ -370,7 +373,7 @@ static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIn String clusterAlias = entry.getKey(); boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); - SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(), + SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), true); Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias); remoteClusterClient.search(ccsSearchRequest, new ActionListener() { @@ -408,7 +411,7 @@ public void onFailure(Exception e) { String clusterAlias = entry.getKey(); boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); - SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(), + SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), false); ActionListener ccsListener = createCCSListener(clusterAlias, skipUnavailable, countDown, skippedClusters, exceptions, searchResponseMerger, totalClusters, listener); @@ -418,7 +421,7 @@ public void onFailure(Exception e) { if (localIndices != null) { ActionListener ccsListener = createCCSListener(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, false, countDown, skippedClusters, exceptions, searchResponseMerger, totalClusters, listener); - SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(searchRequest, localIndices.indices(), + SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, localIndices.indices(), RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false); localSearchConsumer.accept(ccsLocalSearchRequest, ccsListener); } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java index f415d64a1565d..633daed2eb45d 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java @@ -145,6 +145,7 @@ private void setBanOnChildConnections(String reason, boolean waitForCompletion, GroupedActionListener groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size()); final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion); for (Transport.Connection connection : childConnections) { + assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped"; transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, banRequest, TransportRequestOptions.EMPTY, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { @Override @@ -167,6 +168,7 @@ private void removeBanOnChildConnections(CancellableTask task, Collection extends TransportRequest { super(in); targetNode = new DiscoveryNode(in); wrapped = reader.read(in); + setParentTask(wrapped.getParentTask()); } @Override diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index 9424ef2a878b1..90cd43ef2a302 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -701,6 +701,17 @@ public final void sendRequest(final DiscoveryNode sendRequest(connection, action, request, options, handler); } + /** + * Unwraps and returns the actual underlying connection of the given connection. + */ + public static Transport.Connection unwrapConnection(Transport.Connection connection) { + Transport.Connection unwrapped = connection; + while (unwrapped instanceof RemoteConnectionManager.ProxyConnection) { + unwrapped = ((RemoteConnectionManager.ProxyConnection) unwrapped).getConnection(); + } + return unwrapped; + } + /** * Sends a request on the specified connection. If there is a failure sending the request, the specified handler is invoked. * @@ -718,7 +729,18 @@ public final void sendRequest(final Transport.Conn try { final TransportResponseHandler delegate; if (request.getParentTask().isSet()) { - final Releasable unregisterChildNode = taskManager.registerChildConnection(request.getParentTask().getId(), connection); + // If the connection is a proxy connection, then we will create a cancellable proxy task on the proxy node and an actual + // child task on the target node of the remote cluster. + // ----> a parent task on the local cluster + // | + // ----> a proxy task on the proxy node on the remote cluster + // | + // ----> an actual child task on the target node on the remote cluster + // To cancel the child task on the remote cluster, we must send a cancel request to the proxy node instead of the target + // node as the parent task of the child task is the proxy task not the parent task on the local cluster. Hence, here we + // unwrap the connection and keep track of the connection to the proxy node instead of the proxy connection. + final Transport.Connection unwrappedConn = unwrapConnection(connection); + final Releasable unregisterChildNode = taskManager.registerChildConnection(request.getParentTask().getId(), unwrappedConn); delegate = new TransportResponseHandler() { @Override public void handleResponse(T response) { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 2898e203a13a1..0dacaaed99d02 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -69,6 +69,7 @@ import org.elasticsearch.search.suggest.completion.CompletionSuggestion; import org.elasticsearch.search.suggest.phrase.PhraseSuggestion; import org.elasticsearch.search.suggest.term.TermSuggestion; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalAggregationTestCase; import org.elasticsearch.threadpool.TestThreadPool; @@ -395,7 +396,7 @@ private static AtomicArray generateFetchResults(int nShards, } private static SearchRequest randomSearchRequest() { - return randomBoolean() ? new SearchRequest() : SearchRequest.subSearchRequest(new SearchRequest(), + return randomBoolean() ? new SearchRequest() : SearchRequest.subSearchRequest(new TaskId("n", 1), new SearchRequest(), Strings.EMPTY_ARRAY, "remote", 0, randomBoolean()); } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java index cbf231b611004..5b875e8c0a13a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java @@ -57,21 +57,23 @@ protected SearchRequest createSearchRequest() throws IOException { return request; } //clusterAlias and absoluteStartMillis do not have public getters/setters hence we randomize them only in this test specifically. - return SearchRequest.subSearchRequest(request, request.indices(), + return SearchRequest.subSearchRequest(new TaskId("node", 1), request, request.indices(), randomAlphaOfLengthBetween(5, 10), randomNonNegativeLong(), randomBoolean()); } public void testWithLocalReduction() { - expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(null, Strings.EMPTY_ARRAY, "", 0, randomBoolean())); + final TaskId taskId = new TaskId("n", 1); + expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest( + taskId, null, Strings.EMPTY_ARRAY, "", 0, randomBoolean())); SearchRequest request = new SearchRequest(); - expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(request, null, "", 0, randomBoolean())); - expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(request, + expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(taskId, request, null, "", 0, randomBoolean())); + expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(taskId, request, new String[]{null}, "", 0, randomBoolean())); - expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(request, + expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(taskId, request, Strings.EMPTY_ARRAY, null, 0, randomBoolean())); - expectThrows(IllegalArgumentException.class, () -> SearchRequest.subSearchRequest(request, + expectThrows(IllegalArgumentException.class, () -> SearchRequest.subSearchRequest(taskId, request, Strings.EMPTY_ARRAY, "", -1, randomBoolean())); - SearchRequest searchRequest = SearchRequest.subSearchRequest(request, Strings.EMPTY_ARRAY, "", 0, randomBoolean()); + SearchRequest searchRequest = SearchRequest.subSearchRequest(taskId, request, Strings.EMPTY_ARRAY, "", 0, randomBoolean()); assertNull(searchRequest.validate()); } diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index dfaf1be3c2311..0b7925913abcd 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -62,6 +62,7 @@ import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.sort.SortBuilders; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; @@ -391,7 +392,7 @@ public void testCCSRemoteReduceMergeFails() throws Exception { AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch); - TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, + TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider, emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l))); if (localIndices == null) { assertNull(setOnce.get()); @@ -436,7 +437,7 @@ public void testCCSRemoteReduce() throws Exception { AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch); - TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, + TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider, emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l))); if (localIndices == null) { assertNull(setOnce.get()); @@ -462,7 +463,7 @@ public void testCCSRemoteReduce() throws Exception { AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch); - TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, + TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider, emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l))); if (localIndices == null) { assertNull(setOnce.get()); @@ -509,7 +510,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch); - TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, + TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider, emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l))); if (localIndices == null) { assertNull(setOnce.get()); @@ -538,7 +539,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch); - TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, + TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider, emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l))); if (localIndices == null) { assertNull(setOnce.get()); @@ -578,7 +579,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch); - TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, + TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider, emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l))); if (localIndices == null) { assertNull(setOnce.get()); diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java index 2bcd3f2c89f64..72caaf3df554e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java @@ -53,7 +53,7 @@ import static org.elasticsearch.discovery.SettingsBasedSeedHostsProvider.DISCOVERY_SEED_HOSTS_SETTING; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasKey; -import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.not; public abstract class AbstractMultiClustersTestCase extends ESTestCase { public static final String LOCAL_CLUSTER = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -145,34 +145,32 @@ protected void disconnectFromRemoteClusters() throws Exception { } protected void configureAndConnectsToRemoteClusters() throws Exception { - Map> seedNodes = new HashMap<>(); for (String clusterAlias : clusterGroup.clusterAliases()) { if (clusterAlias.equals(LOCAL_CLUSTER) == false) { final InternalTestCluster cluster = clusterGroup.getCluster(clusterAlias); final String[] allNodes = cluster.getNodeNames(); - final List selectedNodes = randomSubsetOf(randomIntBetween(1, Math.min(3, allNodes.length)), allNodes); - seedNodes.put(clusterAlias, selectedNodes); + final List seedNodes = randomSubsetOf(randomIntBetween(1, Math.min(3, allNodes.length)), allNodes); + configureRemoteCluster(clusterAlias, seedNodes); } } - if (seedNodes.isEmpty()) { - return; - } + } + + protected void configureRemoteCluster(String clusterAlias, Collection seedNodes) throws Exception { Settings.Builder settings = Settings.builder(); - for (Map.Entry> entry : seedNodes.entrySet()) { - final String clusterAlias = entry.getKey(); - final String seeds = entry.getValue().stream() - .map(node -> cluster(clusterAlias).getInstance(TransportService.class, node).boundAddress().publishAddress().toString()) - .collect(Collectors.joining(",")); - settings.put("cluster.remote." + clusterAlias + ".seeds", seeds); - } + final String seed = seedNodes.stream() + .map(node -> { + final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, node); + return transportService.boundAddress().publishAddress().toString(); + }) + .collect(Collectors.joining(",")); + settings.put("cluster.remote." + clusterAlias + ".seeds", seed); client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get(); assertBusy(() -> { List remoteConnectionInfos = client() .execute(RemoteInfoAction.INSTANCE, new RemoteInfoRequest()).actionGet().getInfos() - .stream().filter(RemoteConnectionInfo::isConnected) + .stream().filter(c -> c.isConnected() && c.getClusterAlias().equals(clusterAlias)) .collect(Collectors.toList()); - final long totalConnections = seedNodes.values().stream().map(List::size).count(); - assertThat(remoteConnectionInfos, hasSize(Math.toIntExact(totalConnections))); + assertThat(remoteConnectionInfos, not(empty())); }); }