Fix unnecessary context switch in RankFeaturePhase #113232

Merged
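In outline, as read from the diff itself rather than any official summary: RankFeaturePhase previously always forked to the search pool, even when the request carried no rank builder and the phase had nothing to do. After this change it checks for a RankFeaturePhaseRankCoordinatorContext up front and, when there is none, moves straight to the fetch phase on the calling thread, handing FetchSearchPhase a null reducedQueryPhase so the fetch phase performs the reduce itself. A toy sketch of that pattern follows (stand-in names like SkipForkSketch and Reduced, not Elasticsearch code):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

// Toy model of the control flow this PR introduces: fork to the pool only
// when the rank feature phase has real work; otherwise stay on the calling
// thread and let the next phase compute the reduced result lazily.
public class SkipForkSketch {
    record Reduced(String topDocs) {} // stand-in for ReducedQueryPhase

    static void rankFeaturePhase(ExecutorService pool, boolean hasCoordinatorContext, Supplier<Reduced> reduce) {
        if (hasCoordinatorContext == false) {
            // no rank work: no context switch here; pass null downstream
            fetchPhase(pool, null, reduce);
            return;
        }
        // heavy lifting (reduce + rank) is forked, as before
        pool.execute(() -> fetchPhase(pool, reduce.get(), reduce));
    }

    static void fetchPhase(ExecutorService pool, Reduced precomputed, Supplier<Reduced> reduce) {
        pool.execute(() -> {
            // mirrors FetchSearchPhase: reduce lazily if nothing was handed in
            Reduced reduced = precomputed == null ? reduce.get() : precomputed;
            System.out.println("fetching docs for " + reduced.topDocs());
        });
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        rankFeaturePhase(pool, false, () -> new Reduced("top-10")); // one fork instead of two
        pool.shutdown();
        pool.awaitTermination(5, TimeUnit.SECONDS);
    }
}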
@@ -12,6 +12,7 @@
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.dfs.AggregatedDfs;
@@ -39,13 +40,15 @@ final class FetchSearchPhase extends SearchPhase {
private final Logger logger;
private final SearchProgressListener progressListener;
private final AggregatedDfs aggregatedDfs;
@Nullable
private final SearchPhaseResults<SearchPhaseResult> resultConsumer;
private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase;

FetchSearchPhase(
SearchPhaseResults<SearchPhaseResult> resultConsumer,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
this(
resultConsumer,
@@ -64,7 +67,7 @@ final class FetchSearchPhase extends SearchPhase {
SearchPhaseResults<SearchPhaseResult> resultConsumer,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
) {
super("fetch");
@@ -85,14 +88,15 @@ final class FetchSearchPhase extends SearchPhase {
this.logger = context.getLogger();
this.progressListener = context.getTask().getProgressListener();
this.reducedQueryPhase = reducedQueryPhase;
this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null;
}

@Override
public void run() {
context.execute(new AbstractRunnable() {

@Override
protected void doRun() {
protected void doRun() throws Exception {
innerRun();
}

@@ -103,7 +107,10 @@ public void onFailure(Exception e) {
});
}

private void innerRun() {
private void innerRun() throws Exception {
assert this.reducedQueryPhase == null ^ this.resultConsumer == null;
// depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already
final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase;
final int numShards = context.getNumShards();
// Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
// still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
@@ -113,15 +120,15 @@ private void innerRun() {
if (queryAndFetchOptimization) {
assert assertConsistentWithQueryAndFetchOptimization();
// query AND fetch optimization
moveToNextPhase(searchPhaseShardResults);
moveToNextPhase(searchPhaseShardResults, reducedQueryPhase);
} else {
ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
// no docs to fetch -- sidestep everything and return
if (scoreDocs.length == 0) {
// we have to release contexts here to free up resources
searchPhaseShardResults.asList()
.forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
moveToNextPhase(fetchResults.getAtomicArray());
moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase);
} else {
final boolean shouldExplainRank = shouldExplainRankScores(context.getRequest());
final List<Map<Integer, RankDoc>> rankDocsPerShard = false == shouldExplainRank
@@ -134,7 +141,7 @@ private void innerRun() {
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(
fetchResults,
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
() -> moveToNextPhase(fetchResults.getAtomicArray()),
() -> moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase),
context
);
for (int i = 0; i < docIdsToLoad.length; i++) {
@@ -243,7 +250,10 @@ public void onFailure(Exception e) {
);
}

private void moveToNextPhase(AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
private void moveToNextPhase(
AtomicArray<? extends SearchPhaseResult> fetchResultsArr,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
context.addReleasable(resp::decRef);
fetchResults.close();
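With reducedQueryPhase now nullable, the constructor retains exactly one of the pair: the result consumer is kept only when no reduced result was handed in, and innerRun() asserts the either-or with reducedQueryPhase == null ^ resultConsumer == null before reducing lazily. The same idiom in isolation (a toy LazyOrEager class, assuming nothing beyond the JDK):

import java.util.concurrent.Callable;

// Toy version of the either-or field pair guarded by the XOR assertion in
// FetchSearchPhase#innerRun: hold either an eagerly computed value or the
// means to compute it, never both and never neither.
final class LazyOrEager<T> {
    private final T eager;           // may be null
    private final Callable<T> lazy;  // kept only when eager is null

    LazyOrEager(T eager, Callable<T> lazy) {
        this.eager = eager;
        this.lazy = eager == null ? lazy : null; // keep exactly one
    }

    T get() throws Exception {
        assert (eager == null) ^ (lazy == null);
        return eager == null ? lazy.call() : eager;
    }
}

Dropping the consumer when a reduced result is supplied keeps the two code paths from diverging silently: any caller that provided both would trip the assertion on the first run.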
@@ -70,14 +70,20 @@ public class RankFeaturePhase extends SearchPhase {

@Override
public void run() {
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
if (rankFeaturePhaseRankCoordinatorContext == null) {
moveToNextPhase(queryPhaseResults, null);
return;
}

context.execute(new AbstractRunnable() {
@Override
protected void doRun() throws Exception {
// we need to reduce the results at this point instead of in the fetch phase, so we fork this process similarly to how
// it is set up in FetchSearchPhase.

// we do the heavy lifting in this inner run method where we reduce aggs etc
innerRun();
innerRun(rankFeaturePhaseRankCoordinatorContext);
}

@Override
@@ -87,51 +93,39 @@ public void onFailure(Exception e) {
});
}

void innerRun() throws Exception {
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) throws Exception {
// if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call
// to operate on the first `rank_window_size * num_shards` results and merge them appropriately.
SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
if (rankFeaturePhaseRankCoordinatorContext != null) {
ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
rankPhaseResults,
context.getNumShards(),
() -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
context
);
ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
rankPhaseResults,
context.getNumShards(),
() -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
context
);

// we send out a request to each shard in order to fetch the needed feature info
for (int i = 0; i < docIdsToLoad.length; i++) {
List<Integer> entry = docIdsToLoad[i];
SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
if (entry == null || entry.isEmpty()) {
if (queryResult != null) {
releaseIrrelevantSearchContext(queryResult, context);
progressListener.notifyRankFeatureResult(i);
}
rankRequestCounter.countDown();
} else {
executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
// we send out a request to each shard in order to fetch the needed feature info
for (int i = 0; i < docIdsToLoad.length; i++) {
List<Integer> entry = docIdsToLoad[i];
SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
if (entry == null || entry.isEmpty()) {
if (queryResult != null) {
releaseIrrelevantSearchContext(queryResult, context);
progressListener.notifyRankFeatureResult(i);
}
rankRequestCounter.countDown();
} else {
executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
}
} else {
moveToNextPhase(queryPhaseResults, reducedQueryPhase);
}
}

private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) {
return source == null || source.rankBuilder() == null
? null
: context.getRequest()
.source()
.rankBuilder()
.buildRankFeaturePhaseCoordinatorContext(
context.getRequest().source().size(),
context.getRequest().source().from(),
client
);
: source.rankBuilder().buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client);
}

private void executeRankFeatureShardPhase(
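Because the null check now happens in run(), innerRun() can assume a non-null coordinator context, which is why the old if/else around the shard fan-out flattens into straight-line code above. The fan-out itself relies on counting down once per shard, whether or not the shard contributed documents, so the collector fires exactly when the last shard is accounted for. A minimal stand-in for that countdown (a toy PhaseCountdown class, not the CountedCollector from the diff):

import java.util.concurrent.atomic.AtomicInteger;

// Toy analogue of the per-shard countdown in innerRun: every shard counts
// down exactly once, including shards with no docs to load, and the last
// countDown triggers the move to the next phase.
final class PhaseCountdown {
    private final AtomicInteger remaining;
    private final Runnable onAllShardsDone;

    PhaseCountdown(int numShards, Runnable onAllShardsDone) {
        this.remaining = new AtomicInteger(numShards);
        this.onAllShardsDone = onAllShardsDone;
    }

    void countDown() {
        if (remaining.decrementAndGet() == 0) {
            onAllShardsDone.run(); // all shards accounted for
        }
    }

    public static void main(String[] args) {
        PhaseCountdown c = new PhaseCountdown(3, () -> System.out.println("moveToNextPhase"));
        c.countDown(); // shard 0: empty doc list, still counted
        c.countDown(); // shard 1: rank feature request completed
        c.countDown(); // shard 2: last one fires the callback
    }
}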
@@ -536,7 +536,7 @@ public void sendExecuteRankFeature(
// override the RankFeaturePhase to raise an exception
RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
@Override
void innerRun() {
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) {
throw new IllegalArgumentException("simulated failure");
}

@@ -1142,7 +1142,13 @@ public void moveToNextPhase(
) {
// this is called after the RankFeaturePhaseCoordinatorContext has been executed
phaseDone.set(true);
finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs();
try {
finalResults[0] = reducedQueryPhase == null
? queryPhaseResults.reduce().sortedTopDocs().scoreDocs()
: reducedQueryPhase.sortedTopDocs().scoreDocs();
} catch (Exception e) {
throw new AssertionError(e);
}
logger.debug("Skipping moving to next phase");
}
};
@@ -14,6 +14,7 @@
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -170,7 +171,12 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
// no work to be done on the coordinator node for the rank feature phase
@Override
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
return null;
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
@Override
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
throw new AssertionError("not expected");
}
};
}

@Override
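A side effect of the new early exit explains this last hunk: a test rank builder whose buildRankFeaturePhaseCoordinatorContext returned null would now bypass RankFeaturePhase entirely, so the stub returns a real context whose computeScores fails loudly if anything ever invokes it. The same fail-loudly stub idiom in isolation (toy Scorer interface, assumed names, not the test classes):

// Toy illustration of the fail-loudly stub from the test change: hand back
// a non-null object so the caller takes the code path under test, but make
// the hook that should never run throw if it is reached.
interface Scorer {
    float[] computeScores(int[] docIds);
}

final class StubScorerDemo {
    static Scorer neverCalled() {
        return docIds -> {
            throw new AssertionError("not expected");
        };
    }

    public static void main(String[] args) {
        Scorer scorer = neverCalled(); // safe to construct and pass around
        System.out.println("stub ready: " + (scorer != null));
    }
}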