Skip to content

Commit

Permalink
Fix unnecessary context switch in RankFeaturePhase (#113232) (#117339)
Browse files Browse the repository at this point in the history
If we don't actually execute this phase we shouldn't fork the phase
unnecessarily. We can compute the RankFeaturePhaseRankCoordinatorContext
on the transport thread and move on to fetch without forking.
Fetch itself will then fork and we can run the reduce as part of fetch instead of in
a separate search pool task (this is the way it worked up until the recent introduction
of RankFeaturePhase; this fixes that regression).
  • Loading branch information
original-brownbear authored Nov 22, 2024
1 parent 892d4ff commit 0499589
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.dfs.AggregatedDfs;
Expand Down Expand Up @@ -39,13 +40,15 @@ final class FetchSearchPhase extends SearchPhase {
private final Logger logger;
private final SearchProgressListener progressListener;
private final AggregatedDfs aggregatedDfs;
@Nullable
private final SearchPhaseResults<SearchPhaseResult> resultConsumer;
private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase;

FetchSearchPhase(
SearchPhaseResults<SearchPhaseResult> resultConsumer,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
this(
resultConsumer,
Expand All @@ -64,7 +67,7 @@ final class FetchSearchPhase extends SearchPhase {
SearchPhaseResults<SearchPhaseResult> resultConsumer,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
) {
super("fetch");
Expand All @@ -85,14 +88,15 @@ final class FetchSearchPhase extends SearchPhase {
this.logger = context.getLogger();
this.progressListener = context.getTask().getProgressListener();
this.reducedQueryPhase = reducedQueryPhase;
this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null;
}

@Override
public void run() {
context.execute(new AbstractRunnable() {

@Override
protected void doRun() {
protected void doRun() throws Exception {
innerRun();
}

Expand All @@ -103,7 +107,10 @@ public void onFailure(Exception e) {
});
}

private void innerRun() {
private void innerRun() throws Exception {
assert this.reducedQueryPhase == null ^ this.resultConsumer == null;
// depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already
final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase;
final int numShards = context.getNumShards();
// Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
// still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
Expand All @@ -114,15 +121,15 @@ private void innerRun() {
if (queryAndFetchOptimization) {
assert assertConsistentWithQueryAndFetchOptimization();
// query AND fetch optimization
moveToNextPhase(searchPhaseShardResults);
moveToNextPhase(searchPhaseShardResults, reducedQueryPhase);
} else {
ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
// no docs to fetch -- sidestep everything and return
if (scoreDocs.length == 0) {
// we have to release contexts here to free up resources
searchPhaseShardResults.asList()
.forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
moveToNextPhase(fetchResults.getAtomicArray());
moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase);
} else {
final boolean shouldExplainRank = shouldExplainRankScores(context.getRequest());
final List<Map<Integer, RankDoc>> rankDocsPerShard = false == shouldExplainRank
Expand All @@ -135,7 +142,7 @@ private void innerRun() {
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(
fetchResults,
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
() -> moveToNextPhase(fetchResults.getAtomicArray()),
() -> moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase),
context
);
for (int i = 0; i < docIdsToLoad.length; i++) {
Expand Down Expand Up @@ -244,7 +251,10 @@ public void onFailure(Exception e) {
);
}

private void moveToNextPhase(AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
private void moveToNextPhase(
AtomicArray<? extends SearchPhaseResult> fetchResultsArr,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
context.addReleasable(resp::decRef);
fetchResults.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,20 @@ public class RankFeaturePhase extends SearchPhase {

@Override
public void run() {
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
if (rankFeaturePhaseRankCoordinatorContext == null) {
moveToNextPhase(queryPhaseResults, null);
return;
}

context.execute(new AbstractRunnable() {
@Override
protected void doRun() throws Exception {
// we need to reduce the results at this point instead of fetch phase, so we fork this process similarly to how
// was set up at FetchSearchPhase.

// we do the heavy lifting in this inner run method where we reduce aggs etc
innerRun();
innerRun(rankFeaturePhaseRankCoordinatorContext);
}

@Override
Expand All @@ -87,51 +93,39 @@ public void onFailure(Exception e) {
});
}

void innerRun() throws Exception {
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) throws Exception {
// if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call
// to operate on the first `rank_window_size * num_shards` results and merge them appropriately.
SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
if (rankFeaturePhaseRankCoordinatorContext != null) {
ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
rankPhaseResults,
context.getNumShards(),
() -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
context
);
ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
rankPhaseResults,
context.getNumShards(),
() -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
context
);

// we send out a request to each shard in order to fetch the needed feature info
for (int i = 0; i < docIdsToLoad.length; i++) {
List<Integer> entry = docIdsToLoad[i];
SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
if (entry == null || entry.isEmpty()) {
if (queryResult != null) {
releaseIrrelevantSearchContext(queryResult, context);
progressListener.notifyRankFeatureResult(i);
}
rankRequestCounter.countDown();
} else {
executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
// we send out a request to each shard in order to fetch the needed feature info
for (int i = 0; i < docIdsToLoad.length; i++) {
List<Integer> entry = docIdsToLoad[i];
SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
if (entry == null || entry.isEmpty()) {
if (queryResult != null) {
releaseIrrelevantSearchContext(queryResult, context);
progressListener.notifyRankFeatureResult(i);
}
rankRequestCounter.countDown();
} else {
executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
}
} else {
moveToNextPhase(queryPhaseResults, reducedQueryPhase);
}
}

private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) {
return source == null || source.rankBuilder() == null
? null
: context.getRequest()
.source()
.rankBuilder()
.buildRankFeaturePhaseCoordinatorContext(
context.getRequest().source().size(),
context.getRequest().source().from(),
client
);
: source.rankBuilder().buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client);
}

private void executeRankFeatureShardPhase(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ public void sendExecuteRankFeature(
// override the RankFeaturePhase to raise an exception
RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
@Override
void innerRun() {
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) {
throw new IllegalArgumentException("simulated failure");
}

Expand Down Expand Up @@ -1142,7 +1142,13 @@ public void moveToNextPhase(
) {
// this is called after the RankFeaturePhaseCoordinatorContext has been executed
phaseDone.set(true);
finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs();
try {
finalResults[0] = reducedQueryPhase == null
? queryPhaseResults.reduce().sortedTopDocs().scoreDocs()
: reducedQueryPhase.sortedTopDocs().scoreDocs();
} catch (Exception e) {
throw new AssertionError(e);
}
logger.debug("Skipping moving to next phase");
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -170,7 +171,12 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
// no work to be done on the coordinator node for the rank feature phase
@Override
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
return null;
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
throw new AssertionError("not expected");
}
};
}

@Override
Expand Down

0 comments on commit 0499589

Please sign in to comment.