diff --git a/buildSrc/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java b/buildSrc/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java index b66b30f9a1311..370dbabd60809 100644 --- a/buildSrc/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java +++ b/buildSrc/src/main/java/org/elasticsearch/gradle/testclusters/ElasticsearchCluster.java @@ -133,10 +133,16 @@ public void setNumberOfNodes(int numberOfNodes) { } @Internal - ElasticsearchNode getFirstNode() { + public ElasticsearchNode getFirstNode() { return nodes.getAt(clusterName + "-0"); } + @Internal + public ElasticsearchNode getLastNode() { + int index = nodes.size() - 1; + return nodes.getAt(clusterName + "-" + index); + } + @Internal public int getNumberOfNodes() { return nodes.size(); diff --git a/qa/ccs-rolling-upgrade-remote-cluster/build.gradle b/qa/ccs-rolling-upgrade-remote-cluster/build.gradle new file mode 100644 index 0000000000000..47f0480f8c60e --- /dev/null +++ b/qa/ccs-rolling-upgrade-remote-cluster/build.gradle @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +import org.elasticsearch.gradle.Version +import org.elasticsearch.gradle.info.BuildParams +import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask + +apply plugin: 'elasticsearch.testclusters' +apply plugin: 'elasticsearch.standalone-test' +apply from: "$rootDir/gradle/bwc-test.gradle" +apply plugin: 'elasticsearch.rest-resources' + +dependencies { + testImplementation project(':client:rest-high-level') +} + +for (Version bwcVersion : BuildParams.bwcVersions.wireCompatible) { + String baseName = "v${bwcVersion}" + String bwcVersionStr = "${bwcVersion}" + + /** + * We execute tests 3 times. + * - The local cluster is unchanged and it consists of an old version node and a new version node. + * - Nodes in the remote cluster are upgraded one by one in three steps. + * - Only node-0 and node-2 of the remote cluster can accept remote connections. This creates a test + * scenario where a query request and fetch request are sent via **proxy nodes** that have different versions. 
+ */ + testClusters { + "${baseName}-local" { + numberOfNodes = 2 + versions = [bwcVersionStr, project.version] + setting 'cluster.remote.node.attr', 'gateway' + } + "${baseName}-remote" { + numberOfNodes = 3 + versions = [bwcVersionStr, project.version] + firstNode.setting 'node.attr.gateway', 'true' + lastNode.setting 'node.attr.gateway', 'true' + } + } + + tasks.withType(StandaloneRestIntegTestTask).matching { it.name.startsWith("${baseName}#") }.configureEach { + useCluster testClusters."${baseName}-local" + useCluster testClusters."${baseName}-remote" + systemProperty 'tests.upgrade_from_version', bwcVersionStr.replace('-SNAPSHOT', '') + + doFirst { + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}-local".allHttpSocketURI.join(",")}") + nonInputProperties.systemProperty('tests.rest.remote_cluster', "${-> testClusters."${baseName}-remote".allHttpSocketURI.join(",")}") + } + } + + tasks.register("${baseName}#oneThirdUpgraded", StandaloneRestIntegTestTask) { + dependsOn "processTestResources" + mustRunAfter("precommit") + doFirst { + testClusters."${baseName}-local".nextNodeToNextVersion() + testClusters."${baseName}-remote".nextNodeToNextVersion() + } + } + + tasks.register("${baseName}#twoThirdUpgraded", StandaloneRestIntegTestTask) { + dependsOn "${baseName}#oneThirdUpgraded" + doFirst { + testClusters."${baseName}-remote".nextNodeToNextVersion() + } + } + + tasks.register("${baseName}#fullUpgraded", StandaloneRestIntegTestTask) { + dependsOn "${baseName}#twoThirdUpgraded" + doFirst { + testClusters."${baseName}-remote".nextNodeToNextVersion() + } + } + + tasks.register(bwcTaskName(bwcVersion)) { + dependsOn tasks.named("${baseName}#fullUpgraded") + } +} diff --git a/qa/ccs-rolling-upgrade-remote-cluster/src/test/java/org/elasticsearch/upgrades/SearchStatesIT.java b/qa/ccs-rolling-upgrade-remote-cluster/src/test/java/org/elasticsearch/upgrades/SearchStatesIT.java new file mode 100644 index 0000000000000..8e7f3a399e3df 
--- /dev/null +++ b/qa/ccs-rolling-upgrade-remote-cluster/src/test/java/org/elasticsearch/upgrades/SearchStatesIT.java @@ -0,0 +1,263 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.upgrades; + +import org.apache.http.HttpHost; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.Version; +import org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; +import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.RestClient; +import org.elasticsearch.client.RestHighLevelClient; +import org.elasticsearch.client.indices.CreateIndexRequest; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.rest.action.document.RestIndexAction; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.test.rest.yaml.ObjectPath; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.not; + +/** + * This test ensures that we keep the search states of a CCS request correctly when the local and remote clusters + * have different but compatible versions. 
See SearchService#createAndPutReaderContext + */ +public class SearchStatesIT extends ESRestTestCase { + + private static final Logger LOGGER = LogManager.getLogger(SearchStatesIT.class); + private static final Version UPGRADE_FROM_VERSION = Version.fromString(System.getProperty("tests.upgrade_from_version")); + private static final String CLUSTER_ALIAS = "remote_cluster"; + + static class Node { + final String id; + final String name; + final Version version; + final String transportAddress; + final String httpAddress; + final Map attributes; + + Node(String id, String name, Version version, String transportAddress, String httpAddress, Map attributes) { + this.id = id; + this.name = name; + this.version = version; + this.transportAddress = transportAddress; + this.httpAddress = httpAddress; + this.attributes = attributes; + } + + @Override + public String toString() { + return "Node{" + + "id='" + id + '\'' + + ", name='" + name + '\'' + + ", version=" + version + + ", transportAddress='" + transportAddress + '\'' + + ", httpAddress='" + httpAddress + '\'' + + ", attributes=" + attributes + + '}'; + } + } + + static List getNodes(RestClient restClient) throws IOException { + Response response = restClient.performRequest(new Request("GET", "_nodes")); + ObjectPath objectPath = ObjectPath.createFromResponse(response); + final Map nodeMap = objectPath.evaluate("nodes"); + final List nodes = new ArrayList<>(); + for (String id : nodeMap.keySet()) { + final String name = objectPath.evaluate("nodes." + id + ".name"); + final Version version = Version.fromString(objectPath.evaluate("nodes." + id + ".version")); + final String transportAddress = objectPath.evaluate("nodes." + id + ".transport.publish_address"); + final String httpAddress = objectPath.evaluate("nodes." + id + ".http.publish_address"); + final Map attributes = objectPath.evaluate("nodes." 
+ id + ".attributes"); + nodes.add(new Node(id, name, version, transportAddress, httpAddress, attributes)); + } + return nodes; + } + + static List parseHosts(String props) { + final String address = System.getProperty(props); + assertNotNull("[" + props + "] is not configured", address); + String[] stringUrls = address.split(","); + List hosts = new ArrayList<>(stringUrls.length); + for (String stringUrl : stringUrls) { + int portSeparator = stringUrl.lastIndexOf(':'); + if (portSeparator < 0) { + throw new IllegalArgumentException("Illegal cluster url [" + stringUrl + "]"); + } + String host = stringUrl.substring(0, portSeparator); + int port = Integer.parseInt(stringUrl.substring(portSeparator + 1)); + hosts.add(new HttpHost(host, port, "http")); + } + assertThat("[" + props + "] is empty", hosts, not(empty())); + return hosts; + } + + public static void configureRemoteClusters(List remoteNodes) throws Exception { + assertThat(remoteNodes, hasSize(3)); + final String remoteClusterSettingPrefix = "cluster.remote." 
+ CLUSTER_ALIAS + "."; + try (RestHighLevelClient localClient = newLocalClient()) { + final Settings remoteConnectionSettings; + if (UPGRADE_FROM_VERSION.before(Version.V_7_6_0) || randomBoolean()) { + final List seeds = remoteNodes.stream() + .filter(n -> n.attributes.containsKey("gateway")) + .map(n -> n.transportAddress) + .collect(Collectors.toList()); + assertThat(seeds, hasSize(2)); + LOGGER.info("--> use sniff mode with seed [{}], remote nodes [{}]", seeds, remoteNodes); + if (UPGRADE_FROM_VERSION.before(Version.V_7_6_0)) { + remoteConnectionSettings = Settings.builder() + .putList(remoteClusterSettingPrefix + "seeds", seeds) + .build(); + } else { + remoteConnectionSettings = Settings.builder() + .putNull(remoteClusterSettingPrefix + "proxy_address") + .put(remoteClusterSettingPrefix + "mode", "sniff") + .putList(remoteClusterSettingPrefix + "seeds", seeds) + .build(); + } + } else { + final Node proxyNode = randomFrom(remoteNodes); + LOGGER.info("--> use proxy node [{}], remote nodes [{}]", proxyNode, remoteNodes); + remoteConnectionSettings = Settings.builder() + .putNull(remoteClusterSettingPrefix + "seeds") + .put(remoteClusterSettingPrefix + "mode", "proxy") + .put(remoteClusterSettingPrefix + "proxy_address", proxyNode.transportAddress) + .build(); + } + assertTrue( + localClient.cluster() + .putSettings(new ClusterUpdateSettingsRequest().persistentSettings(remoteConnectionSettings), RequestOptions.DEFAULT) + .isAcknowledged() + ); + assertBusy(() -> { + final Response resp = localClient.getLowLevelClient().performRequest(new Request("GET", "/_remote/info")); + assertOK(resp); + final ObjectPath objectPath = ObjectPath.createFromResponse(resp); + assertNotNull(objectPath.evaluate(CLUSTER_ALIAS)); + assertTrue(objectPath.evaluate(CLUSTER_ALIAS + ".connected")); + }, 60, TimeUnit.SECONDS); + } + } + + static RestHighLevelClient newLocalClient() { + final List hosts = parseHosts("tests.rest.cluster"); + final int index = random().nextInt(hosts.size()); + 
LOGGER.info("Using client node {}", index); + return new RestHighLevelClient(RestClient.builder(hosts.get(index))); + } + + static RestHighLevelClient newRemoteClient() { + return new RestHighLevelClient(RestClient.builder(randomFrom(parseHosts("tests.rest.remote_cluster")))); + } + + static int indexDocs(RestHighLevelClient client, String index, int numDocs) throws IOException { + for (int i = 0; i < numDocs; i++) { + Request indexDoc = new Request("PUT", index + "/type/" + i); + indexDoc.setJsonEntity("{\"f\":" + i + "}"); + indexDoc.setOptions(expectWarnings(RestIndexAction.TYPES_DEPRECATION_MESSAGE)); + client.getLowLevelClient().performRequest(indexDoc); + } + client.indices().refresh(new RefreshRequest(index), RequestOptions.DEFAULT); + return numDocs; + } + + void verifySearch(String localIndex, int localNumDocs, String remoteIndex, int remoteNumDocs) { + try (RestHighLevelClient localClient = newLocalClient()) { + Request request = new Request("POST", "/_search"); + final int expectedDocs; + if (randomBoolean()) { + request.addParameter("index", remoteIndex); + expectedDocs = remoteNumDocs; + } else { + request.addParameter("index", localIndex + "," + remoteIndex); + expectedDocs = localNumDocs + remoteNumDocs; + } + if (UPGRADE_FROM_VERSION.onOrAfter(Version.V_7_0_0)) { + request.addParameter("ccs_minimize_roundtrips", Boolean.toString(randomBoolean())); + } + int size = between(1, 100); + request.setJsonEntity("{\"sort\": \"f\", \"size\": " + size + "}"); + Response response = localClient.getLowLevelClient().performRequest(request); + try (XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent())) { + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + ElasticsearchAssertions.assertNoFailures(searchResponse); + ElasticsearchAssertions.assertHitCount(searchResponse, expectedDocs); + } + } catch (IOException e) { + 
throw new UncheckedIOException(e); + } + } + + public void testBWCSearchStates() throws Exception { + String localIndex = "test_bwc_search_states_index"; + String remoteIndex = "test_bwc_search_states_remote_index"; + try (RestHighLevelClient localClient = newLocalClient(); + RestHighLevelClient remoteClient = newRemoteClient()) { + localClient.indices().create(new CreateIndexRequest(localIndex) + .settings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 5))), + RequestOptions.DEFAULT); + int localNumDocs = indexDocs(localClient, localIndex, between(10, 100)); + + remoteClient.indices().create(new CreateIndexRequest(remoteIndex) + .settings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 5))), + RequestOptions.DEFAULT); + int remoteNumDocs = indexDocs(remoteClient, remoteIndex, between(10, 100)); + + configureRemoteClusters(getNodes(remoteClient.getLowLevelClient())); + int iterations = between(1, 20); + for (int i = 0; i < iterations; i++) { + verifySearch(localIndex, localNumDocs, CLUSTER_ALIAS + ":" + remoteIndex, remoteNumDocs); + } + localClient.indices().delete(new DeleteIndexRequest(localIndex), RequestOptions.DEFAULT); + remoteClient.indices().delete(new DeleteIndexRequest(remoteIndex), RequestOptions.DEFAULT); + } + } +} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index 847ddc35760ab..a58eaf149a754 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -26,6 +26,8 @@ import org.elasticsearch.index.shard.SearchOperationListener; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.internal.LegacyReaderContext; +import 
org.elasticsearch.search.internal.ReaderContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskInfo; @@ -45,6 +47,8 @@ import java.util.stream.StreamSupport; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; public class CrossClusterSearchIT extends AbstractMultiClustersTestCase { @@ -206,6 +210,11 @@ static void waitSearchStarted() throws InterruptedException { @Override public void onIndexModule(IndexModule indexModule) { indexModule.addSearchOperationListener(new SearchOperationListener() { + @Override + public void onNewReaderContext(ReaderContext readerContext) { + assertThat(readerContext, not(instanceOf(LegacyReaderContext.class))); + } + @Override public void onPreQueryPhase(SearchContext searchContext) { startedLatch.get().countDown(); @@ -222,4 +231,5 @@ public void onPreQueryPhase(SearchContext searchContext) { super.onIndexModule(indexModule); } } + } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java index e26772cc1867d..5795fc6514790 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java @@ -14,19 +14,27 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; 
+import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.TermQueryBuilder; +import org.elasticsearch.index.shard.SearchOperationListener; +import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.internal.LegacyReaderContext; +import org.elasticsearch.search.internal.ReaderContext; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.ESIntegTestCase; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.concurrent.ExecutionException; @@ -43,9 +51,16 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; public class SimpleSearchIT extends ESIntegTestCase { + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopy(super.nodePlugins(), VerifyReaderContextPlugin.class); + } + public void testSearchNullIndex() { expectThrows(NullPointerException.class, () -> client().prepareSearch((String) null).setQuery(QueryBuilders.termQuery("_id", "XXX1")).get()); @@ -517,4 +532,27 @@ private void assertRescoreWindowFails(int windowSize) { assertThat(e.toString(), containsString( "This limit can be set by changing the [" + IndexSettings.MAX_RESCORE_WINDOW_SETTING.getKey() + "] index level setting.")); } + + public static class VerifyReaderContextPlugin extends Plugin { + @Override + public void onIndexModule(IndexModule indexModule) { + super.onIndexModule(indexModule); + indexModule.addSearchOperationListener(new SearchOperationListener() { + @Override + public void 
onNewReaderContext(ReaderContext readerContext) { + assertThat(readerContext, not(instanceOf(LegacyReaderContext.class))); + } + + @Override + public void onQueryPhase(SearchContext searchContext, long tookInNanos) { + assertThat(searchContext.readerContext(), not(instanceOf(LegacyReaderContext.class))); + } + + @Override + public void onFetchPhase(SearchContext searchContext, long tookInNanos) { + assertThat(searchContext.readerContext(), not(instanceOf(LegacyReaderContext.class))); + } + }); + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 7f5a64398a338..0707bb99cfc22 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -8,7 +8,6 @@ package org.elasticsearch.action.search; -import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.IndicesRequest; @@ -277,10 +276,6 @@ public void writeTo(StreamOutput out) throws IOException { } } - static boolean keepStatesInContext(Version version) { - return version.before(Version.V_7_10_0); - } - public static void registerRequestHandler(TransportService transportService, SearchService searchService) { transportService.registerRequestHandler(FREE_CONTEXT_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, ScrollFreeContextRequest::new, (request, channel, task) -> { @@ -305,17 +300,17 @@ public static void registerRequestHandler(TransportService transportService, Sea transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, (request, channel, task) -> - searchService.executeDfsPhase(request, keepStatesInContext(channel.getVersion()), (SearchShardTask) task, + searchService.executeDfsPhase(request, 
(SearchShardTask) task, new ChannelActionListener<>(channel, DFS_ACTION_NAME, request)) ); TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new); transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, - (request, channel, task) -> { - searchService.executeQueryPhase(request, keepStatesInContext(channel.getVersion()), (SearchShardTask) task, - new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request)); - }); + (request, channel, task) -> + searchService.executeQueryPhase(request, (SearchShardTask) task, + new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request)) + ); TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, true, (request) -> ((ShardSearchRequest)request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new); diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 00fbff5dee87b..fa7ac08fe58cb 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -322,14 +322,13 @@ protected void doClose() { keepAliveReaper.cancel(); } - public void executeDfsPhase(ShardSearchRequest request, boolean keepStatesInContext, - SearchShardTask task, ActionListener listener) { + public void executeDfsPhase(ShardSearchRequest request, SearchShardTask task, ActionListener listener) { final IndexShard shard = getShard(request); rewriteAndFetchShardRequest(shard, request, new ActionListener() { @Override public void onResponse(ShardSearchRequest rewritten) { // fork the execution in the search thread pool - runAsync(getExecutor(shard), () -> executeDfsPhase(request, task, keepStatesInContext), listener); + runAsync(getExecutor(shard), () -> executeDfsPhase(request, task), listener); } @Override @@ -339,10 
+338,8 @@ public void onFailure(Exception exc) { }); } - private DfsSearchResult executeDfsPhase(ShardSearchRequest request, - SearchShardTask task, - boolean keepStatesInContext) throws IOException { - ReaderContext readerContext = createOrGetReaderContext(request, keepStatesInContext); + private DfsSearchResult executeDfsPhase(ShardSearchRequest request, SearchShardTask task) throws IOException { + ReaderContext readerContext = createOrGetReaderContext(request); try (Releasable ignored = readerContext.markAsUsed(getKeepAlive(request)); SearchContext context = createContext(readerContext, request, task, true)) { dfsPhase.execute(context); @@ -367,8 +364,7 @@ private void loadOrExecuteQueryPhase(final ShardSearchRequest request, final Sea } } - public void executeQueryPhase(ShardSearchRequest request, boolean keepStatesInContext, - SearchShardTask task, ActionListener listener) { + public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task, ActionListener listener) { assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1 : "empty responses require more than one shard"; final IndexShard shard = getShard(request); @@ -392,7 +388,7 @@ public void onResponse(ShardSearchRequest orig) { } } // fork the execution in the search thread pool - runAsync(getExecutor(shard), () -> executeQueryPhase(orig, task, keepStatesInContext), listener); + runAsync(getExecutor(shard), () -> executeQueryPhase(orig, task), listener); } @Override @@ -420,10 +416,8 @@ private void runAsync(Executor executor, CheckedSupplier execu executor.execute(ActionRunnable.supply(listener, executable::get)); } - private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, - SearchShardTask task, - boolean keepStatesInContext) throws Exception { - final ReaderContext readerContext = createOrGetReaderContext(request, keepStatesInContext); + private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchShardTask task) throws 
Exception { + final ReaderContext readerContext = createOrGetReaderContext(request); try (Releasable ignored = readerContext.markAsUsed(getKeepAlive(request)); SearchContext context = createContext(readerContext, request, task, true)) { final long afterQueryTime; @@ -618,9 +612,9 @@ private ReaderContext findReaderContext(ShardSearchContextId id, TransportReques return reader; } - final ReaderContext createOrGetReaderContext(ShardSearchRequest request, boolean keepStatesInContext) { + final ReaderContext createOrGetReaderContext(ShardSearchRequest request) { if (request.readerId() != null) { - assert keepStatesInContext == false; + assert request.scroll() == null : "scroll can't be used with pit"; try { return findReaderContext(request.readerId(), request); } catch (SearchContextMissingException e) { @@ -635,19 +629,19 @@ final ReaderContext createOrGetReaderContext(ShardSearchRequest request, boolean searcherSupplier.close(); throw e; } - return createAndPutReaderContext(request, indexService, shard, searcherSupplier, false, defaultKeepAlive); + return createAndPutReaderContext(request, indexService, shard, searcherSupplier, defaultKeepAlive); } } else { final long keepAliveInMillis = getKeepAlive(request); final IndexService indexService = indicesService.indexServiceSafe(request.shardId().getIndex()); final IndexShard shard = indexService.getShard(request.shardId().id()); final Engine.SearcherSupplier searcherSupplier = shard.acquireSearcherSupplier(); - return createAndPutReaderContext(request, indexService, shard, searcherSupplier, keepStatesInContext, keepAliveInMillis); + return createAndPutReaderContext(request, indexService, shard, searcherSupplier, keepAliveInMillis); } } final ReaderContext createAndPutReaderContext(ShardSearchRequest request, IndexService indexService, IndexShard shard, - Engine.SearcherSupplier reader, boolean keepStatesInContext, long keepAliveInMillis) { + Engine.SearcherSupplier reader, long keepAliveInMillis) { ReaderContext 
readerContext = null; Releasable decreaseScrollContexts = null; try { @@ -661,7 +655,23 @@ final ReaderContext createAndPutReaderContext(ShardSearchRequest request, IndexS } } final ShardSearchContextId id = new ShardSearchContextId(sessionId, idGenerator.incrementAndGet()); - if (keepStatesInContext || request.scroll() != null) { + // Previously, the search states were stored in ReaderContext on data nodes. Since 7.10, they are now + // sent to the coordinating node in QuerySearchResult and the coordinating node then sends them back + // in ShardFetchSearchRequest. We must keep the search states in ReaderContext unless the coordinating + // node is guaranteed to send them back in the fetch phase. + // There are three cases in which we have to keep the search states in ReaderContext: + // 1. Scroll requests + // 2. The coordinating node or a proxy node (i.e. CCS) is on the old version. The `channelVersion` + // of ShardSearchRequest, which is the minimum version of the nodes that the request has passed through, + // can be used to determine this. + // 3. Any node in the cluster is on the old version. This extra check is to avoid a situation where a + // ShardSearchRequest is sent via a new proxy node, but a ShardFetchSearchRequest is sent via an old proxy node. + // + // Note that it's ok to keep the search states in ReaderContext even when the coordinating node also sends + // them back in the fetch phase; this only happens in a mixed cluster. 
+ if (request.scroll() != null || + request.getChannelVersion().before(Version.V_7_12_1) || + clusterService.state().nodes().getMinNodeVersion().before(Version.V_7_12_1)) { readerContext = new LegacyReaderContext(id, indexService, shard, reader, request, keepAliveInMillis); if (request.scroll() != null) { readerContext.addOnClose(decreaseScrollContexts); diff --git a/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java b/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java index bbcc2e2cd60eb..aeb0b2815d4e5 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java @@ -30,7 +30,7 @@ public LegacyReaderContext(ShardSearchContextId id, IndexService indexService, I assert shardSearchRequest.readerId() == null; assert shardSearchRequest.keepAlive() == null; assert id.getSearcherId() == null : "Legacy reader context must not have searcher id"; - this.shardSearchRequest = Objects.requireNonNull(shardSearchRequest); + this.shardSearchRequest = Objects.requireNonNull(shardSearchRequest, "ShardSearchRequest must be provided"); if (shardSearchRequest.scroll() != null) { // Search scroll requests are special, they don't hold indices names so we have // to reuse the searcher created on the request that initialized the scroll. 
diff --git a/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java b/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java index ea2b5b037ef46..d5865fcff5681 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java @@ -140,7 +140,7 @@ public boolean isExpired() { // BWC public ShardSearchRequest getShardSearchRequest(ShardSearchRequest other) { - return Objects.requireNonNull(other); + return Objects.requireNonNull(other, "ShardSearchRequest must be sent back in a fetch request"); } public ScrollContext scrollContext() { @@ -156,7 +156,7 @@ public void setAggregatedDfs(AggregatedDfs aggregatedDfs) { } public RescoreDocIds getRescoreDocIds(RescoreDocIds other) { - return Objects.requireNonNull(other); + return Objects.requireNonNull(other, "RescoreDocIds must be sent back in a fetch request"); } public void setRescoreDocIds(RescoreDocIds rescoreDocIds) { diff --git a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java index 3dda1b9867092..5050e735d09be 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java @@ -80,6 +80,8 @@ public class ShardSearchRequest extends TransportRequest implements IndicesReque private final ShardSearchContextId readerId; private final TimeValue keepAlive; + private final Version channelVersion; + public ShardSearchRequest(OriginalIndices originalIndices, SearchRequest searchRequest, ShardId shardId, @@ -166,6 +168,7 @@ private ShardSearchRequest(OriginalIndices originalIndices, this.readerId = readerId; this.keepAlive = keepAlive; assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; + this.channelVersion = 
Version.CURRENT; } public ShardSearchRequest(StreamInput in) throws IOException { @@ -207,8 +210,13 @@ public ShardSearchRequest(StreamInput in) throws IOException { this.readerId = null; this.keepAlive = null; } - originalIndices = OriginalIndices.readOriginalIndices(in); assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; + if (in.getVersion().onOrAfter(Version.V_7_12_1)) { + channelVersion = Version.min(Version.readVersion(in), in.getVersion()); + } else { + channelVersion = in.getVersion(); + } + originalIndices = OriginalIndices.readOriginalIndices(in); } public ShardSearchRequest(ShardSearchRequest clone) { @@ -230,6 +238,7 @@ public ShardSearchRequest(ShardSearchRequest clone) { this.originalIndices = clone.originalIndices; this.readerId = clone.readerId; this.keepAlive = clone.keepAlive; + this.channelVersion = clone.channelVersion; } @Override @@ -275,6 +284,9 @@ protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOExce out.writeOptionalWriteable(readerId); out.writeOptionalTimeValue(keepAlive); } + if (out.getVersion().onOrAfter(Version.V_7_12_1)) { + Version.writeVersion(channelVersion, out); + } } @Override @@ -532,4 +544,13 @@ public static QueryBuilder parseAliasFilter(CheckedFunction getRuntimeMappings() { return source == null ? emptyMap() : source.runtimeMappings(); } + + /** + * Returns the minimum version of the channels this request has passed through. If the request was never serialized, then the channel + * version is {@link Version#CURRENT}; otherwise, it's the minimum version of the coordinating node and data node (and the proxy node + * in case the request is sent to the proxy node of the remote cluster before reaching the data node). 
+ */ + public Version getChannelVersion() { + return channelVersion; + } } diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index 66d660fc56e05..e74ba55865be0 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.search.Queries; @@ -77,6 +78,7 @@ import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.LegacyReaderContext; import org.elasticsearch.search.internal.ReaderContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchContextId; @@ -84,10 +86,12 @@ import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.VersionUtils; import org.elasticsearch.threadpool.ThreadPool; import org.junit.Before; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -114,6 +118,7 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.Matchers.not; public class SearchServiceTests extends ESSingleNodeTestCase { @@ -312,7 +317,6 @@ public void 
onFailure(Exception e) { new ShardSearchRequest(OriginalIndices.NONE, useScroll ? scrollSearchRequest : searchRequest, indexShard.shardId(), 0, 1, new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null), - true, new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), result); SearchPhaseResult searchPhaseResult = result.get(); IntArrayList intCursors = new IntArrayList(1); @@ -380,7 +384,7 @@ public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws Executi new ShardSearchRequest(OriginalIndices.NONE, useScroll ? scrollSearchRequest : searchRequest, new ShardId(resolveIndex("index"), 0), 0, 1, new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null), - randomBoolean(), new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), result); + new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), result); try { result.get(); @@ -573,7 +577,7 @@ public void testMaxOpenScrollContexts() throws Exception { final ShardScrollRequestTest request = new ShardScrollRequestTest(indexShard.shardId()); ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> service.createAndPutReaderContext(request, indexService, indexShard, indexShard.acquireSearcherSupplier(), - randomBoolean(), SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis())); + SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis())); assertEquals( "Trying to create too many scroll contexts. Must be less than or equal to: [" + SearchService.MAX_OPEN_SCROLL_CONTEXT.get(Settings.EMPTY) + "]. 
" + @@ -602,7 +606,7 @@ public void testOpenScrollContextsConcurrently() throws Exception { final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier(); try { final ShardScrollRequestTest request = new ShardScrollRequestTest(indexShard.shardId()); - searchService.createAndPutReaderContext(request, indexService, indexShard, reader, true, + searchService.createAndPutReaderContext(request, indexService, indexShard, reader, SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis()); } catch (ElasticsearchException e) { assertThat(e.getMessage(), equalTo( @@ -747,7 +751,7 @@ public void testCanMatch() throws Exception { CountDownLatch latch = new CountDownLatch(1); SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); - service.executeQueryPhase(request, randomBoolean(), task, new ActionListener() { + service.executeQueryPhase(request, task, new ActionListener() { @Override public void onResponse(SearchPhaseResult searchPhaseResult) { try { @@ -933,6 +937,64 @@ public void testCreateSearchContext() throws IOException { } } + public void testKeepStatesInContext() { + String index = randomAlphaOfLengthBetween(5, 10).toLowerCase(Locale.ROOT); + IndexService indexService = createIndex(index); + SearchService searchService = getInstanceFromNode(SearchService.class); + ShardId shardId = new ShardId(indexService.index(), 0); + String clusterAlias = randomBoolean() ? 
null : randomAlphaOfLengthBetween(3, 10); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.allowPartialSearchResults(randomBoolean()); + long nowInMillis = System.currentTimeMillis(); + int numIndices = randomInt(10); + String[] indices = new String[numIndices]; + for (int j = 0; j < indices.length; j++) { + indices[j] = randomAlphaOfLength(randomIntBetween(1, 10)); + } + ShardSearchRequest request = new ShardSearchRequest(new OriginalIndices(indices, IndicesOptions.lenientExpandOpen()), + searchRequest, shardId, 0, indexService.numberOfShards(), AliasFilter.EMPTY, 1f, nowInMillis, clusterAlias); + { + assertThat(request.getChannelVersion(), equalTo(Version.CURRENT)); + ReaderContext readerContext = searchService.createOrGetReaderContext(request); + assertThat(readerContext, not(instanceOf(LegacyReaderContext.class))); + NullPointerException error = expectThrows(NullPointerException.class, () -> readerContext.getShardSearchRequest(null)); + assertThat(error.getMessage(), equalTo("ShardSearchRequest must be sent back in a fetch request")); + searchService.freeReaderContext(readerContext.id()); + } + if (randomBoolean()) { + final Version version = VersionUtils.randomVersionBetween(random(), Version.V_7_12_1, Version.CURRENT); + request = serialize(request, version); + assertThat(request.getChannelVersion(), equalTo(version)); + ReaderContext readerContext = searchService.createOrGetReaderContext(request); + assertThat(readerContext, not(instanceOf(LegacyReaderContext.class))); + NullPointerException error = expectThrows(NullPointerException.class, () -> readerContext.getShardSearchRequest(null)); + assertThat(error.getMessage(), equalTo("ShardSearchRequest must be sent back in a fetch request")); + searchService.freeReaderContext(readerContext.id()); + } else { + final Version version = VersionUtils.randomVersionBetween( + random(), Version.V_7_0_0, VersionUtils.getPreviousVersion(Version.V_7_12_1)); + request = serialize(request, version); + 
assertThat(request.getChannelVersion(), equalTo(version)); + ReaderContext readerContext = searchService.createOrGetReaderContext(request); + assertThat(readerContext, instanceOf(LegacyReaderContext.class)); + assertThat(readerContext.getShardSearchRequest(null), equalTo(request)); + searchService.freeReaderContext(readerContext.id()); + } + } + + private static ShardSearchRequest serialize(ShardSearchRequest request, Version version) { + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.setVersion(version); + request.writeTo(out); + try (StreamInput in = out.bytes().streamInput()) { + in.setVersion(version); + return new ShardSearchRequest(in); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + /** * While we have no NPE in DefaultContext constructor anymore, we still want to guard against it (or other failures) in the future to * avoid leaking searchers. @@ -977,7 +1039,7 @@ public void testMatchNoDocsEmptyResponse() throws InterruptedException { { CountDownLatch latch = new CountDownLatch(1); shardRequest.source().query(new MatchAllQueryBuilder()); - service.executeQueryPhase(shardRequest, randomBoolean(), task, new ActionListener() { + service.executeQueryPhase(shardRequest, task, new ActionListener() { @Override public void onResponse(SearchPhaseResult result) { try { @@ -1007,7 +1069,7 @@ public void onFailure(Exception exc) { { CountDownLatch latch = new CountDownLatch(1); shardRequest.source().query(new MatchNoneQueryBuilder()); - service.executeQueryPhase(shardRequest, randomBoolean(), task, new ActionListener() { + service.executeQueryPhase(shardRequest, task, new ActionListener() { @Override public void onResponse(SearchPhaseResult result) { try { @@ -1037,7 +1099,7 @@ public void onFailure(Exception exc) { { CountDownLatch latch = new CountDownLatch(1); shardRequest.canReturnNullResponseIfMatchNoDocs(true); - service.executeQueryPhase(shardRequest, randomBoolean(), task, new ActionListener() { + 
service.executeQueryPhase(shardRequest, task, new ActionListener() { @Override public void onResponse(SearchPhaseResult result) { try { @@ -1105,12 +1167,13 @@ public void testLookUpSearchContext() throws Exception { CountDownLatch latch = new CountDownLatch(1); indexShard.getThreadPool().executor(ThreadPool.Names.SEARCH).execute(() -> { try { + // TODO: Add the context for (int i = 0; i < numContexts; i++) { ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, new SearchRequest().allowPartialSearchResults(true), indexShard.shardId(), 0, 1, new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null); final ReaderContext context = searchService.createAndPutReaderContext(request, indexService, indexShard, - indexShard.acquireSearcherSupplier(), randomBoolean(), + indexShard.acquireSearcherSupplier(), SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis()); assertThat(context.id().getId(), equalTo((long) (i + 1))); contextIds.add(context.id()); diff --git a/server/src/test/java/org/elasticsearch/search/internal/ShardSearchRequestTests.java b/server/src/test/java/org/elasticsearch/search/internal/ShardSearchRequestTests.java index b2c53cbbe2e36..00d3a39666045 100644 --- a/server/src/test/java/org/elasticsearch/search/internal/ShardSearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/search/internal/ShardSearchRequestTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.collect.Map; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; @@ -197,4 +198,56 @@ public QueryBuilder aliasFilter(IndexMetadata indexMetadata, String... 
aliasName } }, indexMetadata, aliasNames); } + + ShardSearchRequest serialize(ShardSearchRequest request, Version version) throws IOException { + if (version.before(Version.V_7_11_0)) { + if (request.source() != null) { + request.source().runtimeMappings(Map.of()); + } + } + return copyWriteable(request, namedWriteableRegistry, ShardSearchRequest::new, version); + } + + public void testChannelVersion() throws Exception { + ShardSearchRequest request = createShardSearchRequest(); + Version channelVersion = Version.CURRENT; + assertThat(request.getChannelVersion(), equalTo(channelVersion)); + int iterations = between(0, 5); + // Versions on or after 7.12.1 carry the channel version on the wire, so the minimum seen so far is kept + for (int i = 0; i < iterations; i++) { + Version newVersion = VersionUtils.randomVersionBetween(random(), Version.V_7_12_1, Version.CURRENT); + request = serialize(request, newVersion); + channelVersion = Version.min(newVersion, channelVersion); + assertThat(request.getChannelVersion(), equalTo(channelVersion)); + if (randomBoolean()) { + request = new ShardSearchRequest(request); + } + } + // Versions before 7.12.1 do not carry it, so the channel version becomes the stream version + iterations = between(1, 5); + for (int i = 0; i < iterations; i++) { + channelVersion = VersionUtils.randomVersionBetween(random(), + Version.V_7_0_0, VersionUtils.getPreviousVersion(Version.V_7_12_1)); + request = serialize(request, channelVersion); + assertThat(request.getChannelVersion(), equalTo(channelVersion)); + if (randomBoolean()) { + request = new ShardSearchRequest(request); + } + } + // Any version: apply whichever of the two rules above matches the randomly chosen version + iterations = between(1, 5); + for (int i = 0; i < iterations; i++) { + Version version = VersionUtils.randomVersion(random()); + request = serialize(request, version); + if (version.onOrAfter(Version.V_7_12_1)) { + channelVersion = Version.min(channelVersion, version); + } else { + channelVersion = version; + } + assertThat(request.getChannelVersion(), equalTo(channelVersion)); + if (randomBoolean()) { + request = new ShardSearchRequest(request); + } + } + } }