Adding RankFeature implementation (elastic#108538)
pmpailis authored Jun 6, 2024
1 parent 13afe0f commit 4a1d742
Showing 46 changed files with 4,207 additions and 109 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/108538.yaml
@@ -0,0 +1,5 @@
+pr: 108538
+summary: Adding RankFeature search phase implementation
+area: Search
+type: feature
+issues: []

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
@@ -362,6 +362,7 @@
     exports org.elasticsearch.search.query;
     exports org.elasticsearch.search.rank;
     exports org.elasticsearch.search.rank.context;
+    exports org.elasticsearch.search.rank.feature;
     exports org.elasticsearch.search.rescore;
     exports org.elasticsearch.search.retriever;
     exports org.elasticsearch.search.runtime;
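Exporting org.elasticsearch.search.rank.feature makes the new package's types readable from other JPMS modules. A minimal sketch of a consumer module, assuming one exists (my.custom.ranker is illustrative; org.elasticsearch.server is the name the server module declares):

// module-info.java of a hypothetical consumer module
module my.custom.ranker {
    requires org.elasticsearch.server;
    // types in org.elasticsearch.search.rank.feature can now be referenced here
}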
@@ -184,7 +184,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_EMBEDDINGS_ADDED = def(8_675_00_0);
     public static final TransportVersion ADD_MISTRAL_EMBEDDINGS_INFERENCE = def(8_676_00_0);
     public static final TransportVersion ML_CHUNK_INFERENCE_OPTION = def(8_677_00_0);
-
+    public static final TransportVersion RANK_FEATURE_PHASE_ADDED = def(8_678_00_0);
     /*
      * STOP! READ THIS FIRST! No, really,
      * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
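The new constant follows the codebase's usual transport-version gating idiom: fields introduced with the rank feature phase are only written and read when the stream's version is on or after RANK_FEATURE_PHASE_ADDED, so mixed-version clusters stay compatible. A minimal sketch of that idiom; the message class and its rankFeatureScore field are illustrative, not part of this commit:

import java.io.IOException;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

// Hypothetical wire message with one field gated on the new transport version.
public class ExampleRankFeatureMessage implements Writeable {

    private final float rankFeatureScore; // hypothetical field added with the rank feature phase

    public ExampleRankFeatureMessage(StreamInput in) throws IOException {
        if (in.getTransportVersion().onOrAfter(TransportVersions.RANK_FEATURE_PHASE_ADDED)) {
            this.rankFeatureScore = in.readFloat();
        } else {
            this.rankFeatureScore = 0f; // older nodes never send the field
        }
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // only emit the field when the receiving node understands it
        if (out.getTransportVersion().onOrAfter(TransportVersions.RANK_FEATURE_PHASE_ADDED)) {
            out.writeFloat(rankFeatureScore);
        }
    }
}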
@@ -260,6 +260,24 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
         }
     }
 
+    /**
+     * Executed when a shard returns a rank feature result.
+     *
+     * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}.
+     */
+    @Override
+    public void onRankFeatureResult(int shardIndex) {}
+
+    /**
+     * Executed when a shard reports a rank feature failure.
+     *
+     * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}.
+     * @param shardTarget The last shard target that threw an exception.
+     * @param exc The cause of the failure.
+     */
+    @Override
+    public void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
+
     /**
      * Executed when a shard returns a fetch result.
      *
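The empty overrides above correspond to the two per-shard callbacks that SearchProgressListener gains for the rank feature phase, mirroring the existing onFetchResult/onFetchFailure pair. A minimal sketch of a listener that consumes them; the class name and counters are illustrative, assuming the subclass is registered wherever a SearchProgressListener is normally installed:

import java.util.concurrent.atomic.AtomicInteger;

import org.elasticsearch.action.search.SearchProgressListener;
import org.elasticsearch.search.SearchShardTarget;

// Illustrative listener that counts rank feature shard outcomes.
public class RankFeatureProgressListener extends SearchProgressListener {

    private final AtomicInteger successfulShards = new AtomicInteger();
    private final AtomicInteger failedShards = new AtomicInteger();

    @Override
    public void onRankFeatureResult(int shardIndex) {
        // a shard completed its rank feature phase
        successfulShards.incrementAndGet();
    }

    @Override
    public void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
        // a shard failed; shardTarget identifies the last node that threw
        failedShards.incrementAndGet();
    }
}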
@@ -17,8 +17,6 @@
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.ShardFetchSearchRequest;
 import org.elasticsearch.search.internal.ShardSearchContextId;
-import org.elasticsearch.search.query.QuerySearchResult;
-import org.elasticsearch.transport.Transport;
 
 import java.util.List;
 import java.util.function.BiFunction;
@@ -29,7 +27,7 @@
  */
 final class FetchSearchPhase extends SearchPhase {
     private final ArraySearchPhaseResults<FetchSearchResult> fetchResults;
-    private final AtomicArray<SearchPhaseResult> queryResults;
+    private final AtomicArray<SearchPhaseResult> searchPhaseShardResults;
     private final BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
     private final SearchPhaseContext context;
     private final Logger logger;
@@ -74,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase {
         }
         this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards());
         context.addReleasable(fetchResults);
-        this.queryResults = resultConsumer.getAtomicArray();
+        this.searchPhaseShardResults = resultConsumer.getAtomicArray();
         this.aggregatedDfs = aggregatedDfs;
         this.nextPhaseFactory = nextPhaseFactory;
         this.context = context;
@@ -103,19 +101,20 @@ private void innerRun() {
         final int numShards = context.getNumShards();
         // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
         // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
-        final boolean queryAndFetchOptimization = queryResults.length() == 1
+        final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1
             && context.getRequest().hasKnnSearch() == false
-            && reducedQueryPhase.rankCoordinatorContext() == null;
+            && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null;
         if (queryAndFetchOptimization) {
             assert assertConsistentWithQueryAndFetchOptimization();
             // query AND fetch optimization
-            moveToNextPhase(queryResults);
+            moveToNextPhase(searchPhaseShardResults);
         } else {
             ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
             // no docs to fetch -- sidestep everything and return
             if (scoreDocs.length == 0) {
                 // we have to release contexts here to free up resources
-                queryResults.asList().stream().map(SearchPhaseResult::queryResult).forEach(this::releaseIrrelevantSearchContext);
+                searchPhaseShardResults.asList()
+                    .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
                 moveToNextPhase(fetchResults.getAtomicArray());
             } else {
                 final ScoreDoc[] lastEmittedDocPerShard = context.getRequest().scroll() != null
@@ -130,51 +129,53 @@ private void innerRun() {
                 );
                 for (int i = 0; i < docIdsToLoad.length; i++) {
                     List<Integer> entry = docIdsToLoad[i];
-                    SearchPhaseResult queryResult = queryResults.get(i);
+                    SearchPhaseResult shardPhaseResult = searchPhaseShardResults.get(i);
                     if (entry == null) { // no results for this shard ID
-                        if (queryResult != null) {
+                        if (shardPhaseResult != null) {
                             // if we got some hits from this shard we have to release the context there
                             // we do this as we go since it will free up resources and passing on the request on the
                             // transport layer is cheap.
-                            releaseIrrelevantSearchContext(queryResult.queryResult());
+                            releaseIrrelevantSearchContext(shardPhaseResult, context);
                             progressListener.notifyFetchResult(i);
                         }
                         // in any case we count down this result since we don't talk to this shard anymore
                         counter.countDown();
                     } else {
-                        executeFetch(queryResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null);
+                        executeFetch(shardPhaseResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null);
                     }
                 }
             }
         }
     }
 
     private boolean assertConsistentWithQueryAndFetchOptimization() {
-        var phaseResults = queryResults.asList();
+        var phaseResults = searchPhaseShardResults.asList();
         assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null
             : "phaseResults empty [" + phaseResults.isEmpty() + "], single result: " + phaseResults.get(0).fetchResult();
         return true;
     }
 
     private void executeFetch(
-        SearchPhaseResult queryResult,
+        SearchPhaseResult shardPhaseResult,
         final CountedCollector<FetchSearchResult> counter,
         final List<Integer> entry,
         ScoreDoc lastEmittedDocForShard
     ) {
-        final SearchShardTarget shardTarget = queryResult.getSearchShardTarget();
-        final int shardIndex = queryResult.getShardIndex();
-        final ShardSearchContextId contextId = queryResult.queryResult().getContextId();
+        final SearchShardTarget shardTarget = shardPhaseResult.getSearchShardTarget();
+        final int shardIndex = shardPhaseResult.getShardIndex();
+        final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null
+            ? shardPhaseResult.queryResult().getContextId()
+            : shardPhaseResult.rankFeatureResult().getContextId();
         context.getSearchTransport()
             .sendExecuteFetch(
                 context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()),
                 new ShardFetchSearchRequest(
-                    context.getOriginalIndices(queryResult.getShardIndex()),
+                    context.getOriginalIndices(shardPhaseResult.getShardIndex()),
                     contextId,
-                    queryResult.getShardSearchRequest(),
+                    shardPhaseResult.getShardSearchRequest(),
                     entry,
                     lastEmittedDocForShard,
-                    queryResult.getRescoreDocIds(),
+                    shardPhaseResult.getRescoreDocIds(),
                     aggregatedDfs
                 ),
                 context.getTask(),
@@ -199,40 +200,17 @@ public void onFailure(Exception e) {
                         // the search context might not be cleared on the node where the fetch was executed for example
                         // because the action was rejected by the thread pool. in this case we need to send a dedicated
                         // request to clear the search context.
-                        releaseIrrelevantSearchContext(queryResult.queryResult());
+                        releaseIrrelevantSearchContext(shardPhaseResult, context);
                     }
                 }
             }
         );
     }
 
-    /**
-     * Releases shard targets that are not used in the docsIdsToLoad.
-     */
-    private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) {
-        // we only release search context that we did not fetch from, if we are not scrolling
-        // or using a PIT and if it has at least one hit that didn't make it to the global topDocs
-        if (queryResult.hasSearchContext()
-            && context.getRequest().scroll() == null
-            && (context.isPartOfPointInTime(queryResult.getContextId()) == false)) {
-            try {
-                SearchShardTarget shardTarget = queryResult.getSearchShardTarget();
-                Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
-                context.sendReleaseSearchContext(
-                    queryResult.getContextId(),
-                    connection,
-                    context.getOriginalIndices(queryResult.getShardIndex())
-                );
-            } catch (Exception e) {
-                logger.trace("failed to release context", e);
-            }
-        }
-    }
-
     private void moveToNextPhase(AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
         var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
         context.addReleasable(resp::decRef);
         fetchResults.close();
-        context.executeNextPhase(this, nextPhaseFactory.apply(resp, queryResults));
+        context.executeNextPhase(this, nextPhaseFactory.apply(resp, searchPhaseShardResults));
     }
 }
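The private releaseIrrelevantSearchContext(QuerySearchResult) helper is deleted above, while every call site now passes the whole SearchPhaseResult together with the phase context. This suggests the logic was hoisted to a shared location (plausibly the SearchPhase base class) so the new rank feature phase can reuse it. A sketch of what such a shared helper could look like, generalized from the removed code; its actual home, signature, and null checks are assumptions, not shown in this excerpt:

package org.elasticsearch.action.search;

import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.transport.Transport;

// Sketch only: the shard result may now come from the query phase or the rank
// feature phase, and both carry the search context id to release.
final class ReleaseContextSketch {

    static void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, SearchPhaseContext context) {
        if (searchPhaseResult == null) {
            return;
        }
        SearchPhaseResult phaseResult = searchPhaseResult.queryResult() != null
            ? searchPhaseResult.queryResult()
            : searchPhaseResult.rankFeatureResult();
        // as in the removed method: only release contexts we will not fetch from,
        // and never for scrolls or point-in-time readers
        if (phaseResult != null
            && phaseResult.getContextId() != null
            && context.getRequest().scroll() == null
            && context.isPartOfPointInTime(phaseResult.getContextId()) == false) {
            try {
                SearchShardTarget shardTarget = phaseResult.getSearchShardTarget();
                Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
                context.sendReleaseSearchContext(
                    phaseResult.getContextId(),
                    connection,
                    context.getOriginalIndices(phaseResult.getShardIndex())
                );
            } catch (Exception e) {
                // best effort: failing to free a remote context must not fail the phase
            }
        }
    }
}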