From 04995892bda2cd27f0e8db76eed24d42b09d2019 Mon Sep 17 00:00:00 2001
From: Armin Braun
Date: Fri, 22 Nov 2024 15:06:15 +0100
Subject: [PATCH] Fix unnecessary context switch in RankFeaturePhase (#113232) (#117339)

If we don't actually execute this phase, we shouldn't fork unnecessarily.
We can compute the RankFeaturePhaseRankCoordinatorContext on the transport
thread and move on to fetch without forking. Fetch itself will then fork,
and we can run the reduce as part of fetch instead of in a separate
search-pool task (this is how it worked until the recent introduction of
RankFeaturePhase; this change fixes that regression).
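
Roughly, the coordinator-side flow after this change looks like the sketch
below (illustrative only, simplified from the actual diff that follows; the
local variable name is shortened and the onFailure body is omitted):

    // sketch: RankFeaturePhase.run() after this change
    @Override
    public void run() {
        // cheap to compute, so do it right on the transport thread
        RankFeaturePhaseRankCoordinatorContext ctx = coordinatorContext(context.getRequest().source());
        if (ctx == null) {
            // nothing to rank: hand over to fetch directly; fetch forks itself
            // and performs the reduce as part of its own work
            moveToNextPhase(queryPhaseResults, null);
            return;
        }
        // only fork to the search pool when there is actual rank-feature work
        context.execute(new AbstractRunnable() {
            @Override
            protected void doRun() throws Exception {
                innerRun(ctx);
            }

            @Override
            public void onFailure(Exception e) {
                // same failure handling as before (omitted in this sketch)
            }
        });
    }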
---
 .../action/search/FetchSearchPhase.java      | 26 +++++---
 .../action/search/RankFeaturePhase.java      | 62 +++++++++----------
 .../action/search/RankFeaturePhaseTests.java | 10 ++-
 .../rank/RankFeatureShardPhaseTests.java     |  8 ++-
 4 files changed, 61 insertions(+), 45 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
index a3bebd5d7c9cd..7ad61f60c0088 100644
--- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
+++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
@@ -12,6 +12,7 @@
 import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.dfs.AggregatedDfs;
@@ -39,13 +40,15 @@ final class FetchSearchPhase extends SearchPhase {
     private final Logger logger;
     private final SearchProgressListener progressListener;
     private final AggregatedDfs aggregatedDfs;
+    @Nullable
+    private final SearchPhaseResults<SearchPhaseResult> resultConsumer;
     private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase;
 
     FetchSearchPhase(
         SearchPhaseResults<SearchPhaseResult> resultConsumer,
         AggregatedDfs aggregatedDfs,
         SearchPhaseContext context,
-        SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+        @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase
     ) {
         this(
             resultConsumer,
@@ -64,7 +67,7 @@ final class FetchSearchPhase extends SearchPhase {
         SearchPhaseResults<SearchPhaseResult> resultConsumer,
         AggregatedDfs aggregatedDfs,
         SearchPhaseContext context,
-        SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
+        @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
         BiFunction<SearchPhaseController.ReducedQueryPhase, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
     ) {
         super("fetch");
@@ -85,6 +88,7 @@ final class FetchSearchPhase extends SearchPhase {
         this.logger = context.getLogger();
         this.progressListener = context.getTask().getProgressListener();
         this.reducedQueryPhase = reducedQueryPhase;
+        this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null;
     }
 
     @Override
@@ -92,7 +96,7 @@ public void run() {
         context.execute(new AbstractRunnable() {
 
             @Override
-            protected void doRun() {
+            protected void doRun() throws Exception {
                 innerRun();
             }
 
@@ -103,7 +107,10 @@ public void onFailure(Exception e) {
         });
     }
 
-    private void innerRun() {
+    private void innerRun() throws Exception {
+        assert this.reducedQueryPhase == null ^ this.resultConsumer == null;
+        // depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already
+        final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase;
         final int numShards = context.getNumShards();
         // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
         // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
@@ -114,7 +121,7 @@ private void innerRun() {
         if (queryAndFetchOptimization) {
             assert assertConsistentWithQueryAndFetchOptimization();
             // query AND fetch optimization
-            moveToNextPhase(searchPhaseShardResults);
+            moveToNextPhase(searchPhaseShardResults, reducedQueryPhase);
         } else {
             ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
             // no docs to fetch -- sidestep everything and return
@@ -122,7 +129,7 @@ private void innerRun() {
                 // we have to release contexts here to free up resources
                 searchPhaseShardResults.asList()
                     .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
-                moveToNextPhase(fetchResults.getAtomicArray());
+                moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase);
             } else {
                 final boolean shouldExplainRank = shouldExplainRankScores(context.getRequest());
                 final List<Map<Integer, RankDoc>> rankDocsPerShard = false == shouldExplainRank
@@ -135,7 +142,7 @@ private void innerRun() {
                 final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(
                     fetchResults,
                     docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
-                    () -> moveToNextPhase(fetchResults.getAtomicArray()),
+                    () -> moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase),
                     context
                 );
                 for (int i = 0; i < docIdsToLoad.length; i++) {
@@ -244,7 +251,10 @@ public void onFailure(Exception e) {
         );
     }
 
-    private void moveToNextPhase(AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
+    private void moveToNextPhase(
+        AtomicArray<? extends SearchPhaseResult> fetchResultsArr,
+        SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+    ) {
         var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
         context.addReleasable(resp::decRef);
         fetchResults.close();
diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
index 49e5c1b6d69e3..81053a70eca9f 100644
--- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
+++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
@@ -70,6 +70,12 @@ public class RankFeaturePhase extends SearchPhase {
 
     @Override
     public void run() {
+        RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
+        if (rankFeaturePhaseRankCoordinatorContext == null) {
+            moveToNextPhase(queryPhaseResults, null);
+            return;
+        }
+
         context.execute(new AbstractRunnable() {
             @Override
             protected void doRun() throws Exception {
@@ -77,7 +83,7 @@ protected void doRun() throws Exception {
                 // was set up at FetchSearchPhase.
                 // we do the heavy lifting in this inner run method where we reduce aggs etc
-                innerRun();
+                innerRun(rankFeaturePhaseRankCoordinatorContext);
             }
 
             @Override
@@ -87,51 +93,39 @@ public void onFailure(Exception e) {
         });
     }
 
-    void innerRun() throws Exception {
+    void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) throws Exception {
         // if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call
         // to operate on the first `rank_window_size * num_shards` results and merge them appropriately.
         SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
-        RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
-        if (rankFeaturePhaseRankCoordinatorContext != null) {
-            ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
-            final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
-            final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
-                rankPhaseResults,
-                context.getNumShards(),
-                () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
-                context
-            );
+        ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
+        final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
+        final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
+            rankPhaseResults,
+            context.getNumShards(),
+            () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
+            context
+        );
 
-            // we send out a request to each shard in order to fetch the needed feature info
-            for (int i = 0; i < docIdsToLoad.length; i++) {
-                List<Integer> entry = docIdsToLoad[i];
-                SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
-                if (entry == null || entry.isEmpty()) {
-                    if (queryResult != null) {
-                        releaseIrrelevantSearchContext(queryResult, context);
-                        progressListener.notifyRankFeatureResult(i);
-                    }
-                    rankRequestCounter.countDown();
-                } else {
-                    executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
+        // we send out a request to each shard in order to fetch the needed feature info
+        for (int i = 0; i < docIdsToLoad.length; i++) {
+            List<Integer> entry = docIdsToLoad[i];
+            SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
+            if (entry == null || entry.isEmpty()) {
+                if (queryResult != null) {
+                    releaseIrrelevantSearchContext(queryResult, context);
+                    progressListener.notifyRankFeatureResult(i);
                 }
+                rankRequestCounter.countDown();
+            } else {
+                executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
             }
-        } else {
-            moveToNextPhase(queryPhaseResults, reducedQueryPhase);
         }
     }
 
     private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) {
         return source == null || source.rankBuilder() == null
             ? null
-            : context.getRequest()
-                .source()
-                .rankBuilder()
-                .buildRankFeaturePhaseCoordinatorContext(
-                    context.getRequest().source().size(),
-                    context.getRequest().source().from(),
-                    client
-                );
+            : source.rankBuilder().buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client);
     }
 
     private void executeRankFeatureShardPhase(
diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java
index 68a161d426f48..82463d601d164 100644
--- a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java
+++ b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java
@@ -536,7 +536,7 @@ public void sendExecuteRankFeature(
         // override the RankFeaturePhase to raise an exception
         RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
             @Override
-            void innerRun() {
+            void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) {
                 throw new IllegalArgumentException("simulated failure");
             }
 
@@ -1142,7 +1142,13 @@ public void moveToNextPhase(
             ) {
                 // this is called after the RankFeaturePhaseCoordinatorContext has been executed
                 phaseDone.set(true);
-                finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs();
+                try {
+                    finalResults[0] = reducedQueryPhase == null
+                        ? queryPhaseResults.reduce().sortedTopDocs().scoreDocs()
+                        : reducedQueryPhase.sortedTopDocs().scoreDocs();
+                } catch (Exception e) {
+                    throw new AssertionError(e);
+                }
                 logger.debug("Skipping moving to next phase");
             }
         };
diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java
index cce10fb86c775..6250d1679fda3 100644
--- a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java
+++ b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java
@@ -14,6 +14,7 @@
 import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.document.DocumentField;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -170,7 +171,12 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
             // no work to be done on the coordinator node for the rank feature phase
             @Override
             public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
-                return null;
+                return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
+                    @Override
+                    protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                        throw new AssertionError("not expected");
+                    }
+                };
             }
 
             @Override