diff --git a/docs/changelog/108538.yaml b/docs/changelog/108538.yaml new file mode 100644 index 0000000000000..10ae49f0c1670 --- /dev/null +++ b/docs/changelog/108538.yaml @@ -0,0 +1,5 @@ +pr: 108538 +summary: Adding RankFeature search phase implementation +area: Search +type: feature +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java new file mode 100644 index 0000000000000..a4e2fda0fd3c9 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java @@ -0,0 +1,811 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.search.rank; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchPhaseController; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import 
java.util.concurrent.TimeUnit; + +import static org.elasticsearch.index.query.QueryBuilders.boolQuery; +import static org.elasticsearch.index.query.QueryBuilders.constantScoreQuery; +import static org.elasticsearch.index.query.QueryBuilders.matchQuery; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.hamcrest.Matchers.equalTo; + +@ESIntegTestCase.ClusterScope(minNumDataNodes = 3) +public class FieldBasedRerankerIT extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return List.of(FieldBasedRerankerPlugin.class); + } + + public void testFieldBasedReranker() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, 
"A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10), + response -> { + assertHitCount(response, 5L); + int rank = 1; + for (SearchHit searchHit : response.getHits().getHits()) { + assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); + assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f); + assertThat(searchHit, hasRank(rank)); + assertNotNull(searchHit.getFields().get(searchField)); + rank++; + } + } + ); + assertNoOpenContext(indexName); + } + + public void testFieldBasedRerankerPagination() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, 
"C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(2) + .setFrom(2), + response -> { + assertHitCount(response, 5L); + int rank = 3; + for (SearchHit searchHit : response.getHits().getHits()) { + assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); + assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f); + assertThat(searchHit, hasRank(rank)); + assertNotNull(searchHit.getFields().get(searchField)); + rank++; + } + } + ); + assertNoOpenContext(indexName); + } + + public void testFieldBasedRerankerPaginationOutsideOfBounds() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + 
.should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(2) + .setFrom(10), + response -> { + assertHitCount(response, 5L); + assertEquals(0, response.getHits().getHits().length); + } + ); + assertNoOpenContext(indexName); + } + + public void testNotAllShardsArePresentInFetchPhase() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 10).build()); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A").setRouting("A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B").setRouting("B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C").setRouting("C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D").setRouting("C"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E").setRouting("C") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(0.1f)) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(0.3f)) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(0.3f)) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(0.3f)) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(2), + response -> { + assertHitCount(response, 4L); + assertEquals(2, response.getHits().getHits().length); 
+ int rank = 1; + for (SearchHit searchHit : response.getHits().getHits()) { + assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); + assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f); + assertThat(searchHit, hasRank(rank)); + assertNotNull(searchHit.getFields().get(searchField)); + rank++; + } + } + ); + assertNoOpenContext(indexName); + } + + public void testFieldBasedRerankerNoMatchingDocs() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery(boolQuery().should(constantScoreQuery(matchQuery(searchField, "F")).boost(randomFloat()))) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10), + response -> { + assertHitCount(response, 0L); + } + ); + assertNoOpenContext(indexName); + } + + public void testQueryPhaseShardThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + // this test is irrespective of the number of shards, as we will always reach QueryPhaseRankShardContext#combineQueryPhaseResults + // even with no results. 
So, when we get back to the coordinator, all shards will have failed, and the whole response + // will be marked as a failure + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_QUERY_PHASE_SHARD_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + public void testQueryPhaseCoordinatorThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + 
prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + // when we throw on the coordinator, the onPhaseFailure handler will be invoked, which in turn will mark the whole + // search request as a failure (i.e. no partial results) + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_QUERY_PHASE_COORDINATOR_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + public void testRankFeaturePhaseShardThrowingRankBuilderAllContextsAreClosedPartialFailures() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 10).build()); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + 
prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + // we have 10 shards and 5 documents, so when the exception is thrown we know that not all shards will report failures + assertResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10), + response -> { + assertTrue(response.getFailedShards() > 0); + assertTrue( + Arrays.stream(response.getShardFailures()) + .allMatch(failure -> failure.getCause().getMessage().contains("rfs - simulated failure")) + ); + assertHitCount(response, 5); + assertTrue(response.getHits().getHits().length == 0); + } + ); + assertNoOpenContext(indexName); + } + + public void testRankFeaturePhaseShardThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + // we have 1 shard and 5 documents, so when the exception is thrown we know that all shards will have failed + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).build()); + indexRandom( + true, + 
prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + public void testRankFeaturePhaseCoordinatorThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, 
searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + private void assertNoOpenContext(final String indexName) throws Exception { + assertBusy( + () -> assertThat(indicesAdmin().prepareStats(indexName).get().getTotal().getSearch().getOpenContexts(), equalTo(0L)), + 1, + TimeUnit.SECONDS + ); + } + + public static class FieldBasedRankBuilder extends RankBuilder { + + public static final ParseField FIELD_FIELD = new ParseField("field"); + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "field-based-rank", + args -> { + int rankWindowSize = args[0] == null ? 
DEFAULT_RANK_WINDOW_SIZE : (int) args[0]; + String field = (String) args[1]; + if (field == null || field.isEmpty()) { + throw new IllegalArgumentException("Field cannot be null or empty"); + } + return new FieldBasedRankBuilder(rankWindowSize, field); + } + ); + + static { + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareString(constructorArg(), FIELD_FIELD); + } + + protected final String field; + + public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + public FieldBasedRankBuilder(final int rankWindowSize, final String field) { + super(rankWindowSize); + this.field = field; + } + + public FieldBasedRankBuilder(StreamInput in) throws IOException { + super(in); + this.field = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(FIELD_FIELD.getPreferredName(), field); + } + + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, rankWindowSize()) { + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + Map rankDocs = new HashMap<>(); + rankResults.forEach(topDocs -> { + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + rankDocs.compute(scoreDoc.doc, (key, value) -> { + if (value == null) { + return new RankFeatureDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + } else { + value.score = Math.max(scoreDoc.score, rankDocs.get(scoreDoc.doc).score); + return value; + } + }); + } + }); + RankFeatureDoc[] sortedResults = rankDocs.values().toArray(RankFeatureDoc[]::new); + Arrays.sort(sortedResults, (o1, o2) -> Float.compare(o2.score, o1.score)); + return new 
RankFeatureShardResult(sortedResults); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(rankWindowSize()) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + RankFeatureShardResult shardResult = (RankFeatureShardResult) querySearchResult.getRankShardResult(); + for (RankFeatureDoc frd : shardResult.rankFeatureDocs) { + frd.shardIndex = i; + rankDocs.add(frd); + } + } + // no support for sort field atm + // should pass needed info to make use of org.elasticsearch.action.search.SearchPhaseController.sortDocs? + rankDocs.sort(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankFeatureDoc[]::new); + + assert topDocStats.fetchHits == 0; + topDocStats.fetchHits = topResults.length; + + return topResults; + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + try { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); + rankFeatureDocs[i].featureData(hits.getHits()[i].field(field).getValue().toString()); + } + return new RankFeatureShardResult(rankFeatureDocs); + } catch (Exception ex) { + throw ex; + } + } + }; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int 
size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = Float.parseFloat(featureDocs[i].featureData); + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + protected boolean doEquals(RankBuilder other) { + return other instanceof FieldBasedRankBuilder && Objects.equals(field, ((FieldBasedRankBuilder) other).field); + } + + @Override + protected int doHashCode() { + return Objects.hash(field); + } + + @Override + public String getWriteableName() { + return "field-based-rank"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_FEATURE_PHASE_ADDED; + } + } + + public static class ThrowingRankBuilder extends FieldBasedRankBuilder { + + public enum ThrowingRankBuilderType { + THROWING_QUERY_PHASE_SHARD_CONTEXT, + THROWING_QUERY_PHASE_COORDINATOR_CONTEXT, + THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT, + THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT; + } + + protected final ThrowingRankBuilderType throwingRankBuilderType; + + public static final ParseField FIELD_FIELD = new ParseField("field"); + public static final ParseField THROWING_TYPE_FIELD = new ParseField("throwing-type"); + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("throwing-rank", args -> { + int rankWindowSize = args[0] == null ? 
DEFAULT_RANK_WINDOW_SIZE : (int) args[0]; + String field = (String) args[1]; + if (field == null || field.isEmpty()) { + throw new IllegalArgumentException("Field cannot be null or empty"); + } + String throwingType = (String) args[2]; + return new ThrowingRankBuilder(rankWindowSize, field, throwingType); + }); + + static { + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareString(constructorArg(), FIELD_FIELD); + PARSER.declareString(constructorArg(), THROWING_TYPE_FIELD); + } + + public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + public ThrowingRankBuilder(final int rankWindowSize, final String field, final String throwingType) { + super(rankWindowSize, field); + this.throwingRankBuilderType = ThrowingRankBuilderType.valueOf(throwingType); + } + + public ThrowingRankBuilder(StreamInput in) throws IOException { + super(in); + this.throwingRankBuilderType = in.readEnum(ThrowingRankBuilderType.class); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + super.doWriteTo(out); + out.writeEnum(throwingRankBuilderType); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + super.doXContent(builder, params); + builder.field(THROWING_TYPE_FIELD.getPreferredName(), throwingRankBuilderType); + } + + @Override + public String getWriteableName() { + return "throwing-rank"; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_QUERY_PHASE_SHARD_CONTEXT) + return new QueryPhaseRankShardContext(queries, rankWindowSize()) { + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + throw new UnsupportedOperationException("qps - simulated failure"); + } + }; + else { + return super.buildQueryPhaseShardContext(queries, from); + 
} + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_QUERY_PHASE_COORDINATOR_CONTEXT) + return new QueryPhaseRankCoordinatorContext(rankWindowSize()) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + throw new UnsupportedOperationException("qpc - simulated failure"); + } + }; + else { + return super.buildQueryPhaseCoordinatorContext(size, from); + } + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT) + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + throw new UnsupportedOperationException("rfs - simulated failure"); + } + }; + else { + return super.buildRankFeaturePhaseShardContext(); + } + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT) + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + throw new UnsupportedOperationException("rfc - simulated failure"); + } + }; + else { + return super.buildRankFeaturePhaseCoordinatorContext(size, from); + } + } + } + + public static class FieldBasedRerankerPlugin extends Plugin implements SearchPlugin { + + private static final String FIELD_BASED_RANK_BUILDER_NAME = "field-based-rank"; + private static final String THROWING_RANK_BUILDER_NAME = "throwing-rank"; + + @Override + public List getNamedWriteables() { + return List.of( + new 
NamedWriteableRegistry.Entry(RankBuilder.class, FIELD_BASED_RANK_BUILDER_NAME, FieldBasedRankBuilder::new), + new NamedWriteableRegistry.Entry(RankBuilder.class, THROWING_RANK_BUILDER_NAME, ThrowingRankBuilder::new), + new NamedWriteableRegistry.Entry(RankShardResult.class, "rank_feature_shard", RankFeatureShardResult::new) + ); + } + + @Override + public List getNamedXContent() { + return List.of( + new NamedXContentRegistry.Entry( + RankBuilder.class, + new ParseField(FIELD_BASED_RANK_BUILDER_NAME), + FieldBasedRankBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + RankBuilder.class, + new ParseField(THROWING_RANK_BUILDER_NAME), + ThrowingRankBuilder::fromXContent + ) + ); + } + } +} diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index d8682500c49d6..2f08129b4080d 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -362,6 +362,7 @@ exports org.elasticsearch.search.query; exports org.elasticsearch.search.rank; exports org.elasticsearch.search.rank.context; + exports org.elasticsearch.search.rank.feature; exports org.elasticsearch.search.rescore; exports org.elasticsearch.search.retriever; exports org.elasticsearch.search.runtime; diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e8a33217b937d..72771855ff622 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -184,7 +184,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_EMBEDDINGS_ADDED = def(8_675_00_0); public static final TransportVersion ADD_MISTRAL_EMBEDDINGS_INFERENCE = def(8_676_00_0); public static final TransportVersion ML_CHUNK_INFERENCE_OPTION = def(8_677_00_0); - + public static final TransportVersion RANK_FEATURE_PHASE_ADDED = def(8_678_00_0); /* * STOP! 
READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java index 3b594c94db9a7..0504d0cde8986 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java @@ -260,6 +260,24 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna } } + /** + * Executed when a shard returns a rank feature result. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}. + */ + @Override + public void onRankFeatureResult(int shardIndex) {} + + /** + * Executed when a shard reports a rank feature failure. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}. + * @param shardTarget The last shard target that threw an exception. + * @param exc The cause of the failure. + */ + @Override + public void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {} + /** * Executed when a shard returns a fetch result. 
* diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index f804ab31faf8e..2308f5fcc8085 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -17,8 +17,6 @@ import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchSearchRequest; import org.elasticsearch.search.internal.ShardSearchContextId; -import org.elasticsearch.search.query.QuerySearchResult; -import org.elasticsearch.transport.Transport; import java.util.List; import java.util.function.BiFunction; @@ -29,7 +27,7 @@ */ final class FetchSearchPhase extends SearchPhase { private final ArraySearchPhaseResults fetchResults; - private final AtomicArray queryResults; + private final AtomicArray searchPhaseShardResults; private final BiFunction, SearchPhase> nextPhaseFactory; private final SearchPhaseContext context; private final Logger logger; @@ -74,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase { } this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards()); context.addReleasable(fetchResults); - this.queryResults = resultConsumer.getAtomicArray(); + this.searchPhaseShardResults = resultConsumer.getAtomicArray(); this.aggregatedDfs = aggregatedDfs; this.nextPhaseFactory = nextPhaseFactory; this.context = context; @@ -103,19 +101,20 @@ private void innerRun() { final int numShards = context.getNumShards(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. 
- final boolean queryAndFetchOptimization = queryResults.length() == 1 + final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 && context.getRequest().hasKnnSearch() == false - && reducedQueryPhase.rankCoordinatorContext() == null; + && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null; if (queryAndFetchOptimization) { assert assertConsistentWithQueryAndFetchOptimization(); // query AND fetch optimization - moveToNextPhase(queryResults); + moveToNextPhase(searchPhaseShardResults); } else { ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // no docs to fetch -- sidestep everything and return if (scoreDocs.length == 0) { // we have to release contexts here to free up resources - queryResults.asList().stream().map(SearchPhaseResult::queryResult).forEach(this::releaseIrrelevantSearchContext); + searchPhaseShardResults.asList() + .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context)); moveToNextPhase(fetchResults.getAtomicArray()); } else { final ScoreDoc[] lastEmittedDocPerShard = context.getRequest().scroll() != null @@ -130,19 +129,19 @@ private void innerRun() { ); for (int i = 0; i < docIdsToLoad.length; i++) { List entry = docIdsToLoad[i]; - SearchPhaseResult queryResult = queryResults.get(i); + SearchPhaseResult shardPhaseResult = searchPhaseShardResults.get(i); if (entry == null) { // no results for this shard ID - if (queryResult != null) { + if (shardPhaseResult != null) { // if we got some hits from this shard we have to release the context there // we do this as we go since it will free up resources and passing on the request on the // transport layer is cheap. 
- releaseIrrelevantSearchContext(queryResult.queryResult()); + releaseIrrelevantSearchContext(shardPhaseResult, context); progressListener.notifyFetchResult(i); } // in any case we count down this result since we don't talk to this shard anymore counter.countDown(); } else { - executeFetch(queryResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null); + executeFetch(shardPhaseResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null); } } } @@ -150,31 +149,33 @@ private void innerRun() { } private boolean assertConsistentWithQueryAndFetchOptimization() { - var phaseResults = queryResults.asList(); + var phaseResults = searchPhaseShardResults.asList(); assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" + phaseResults.isEmpty() + "], single result: " + phaseResults.get(0).fetchResult(); return true; } private void executeFetch( - SearchPhaseResult queryResult, + SearchPhaseResult shardPhaseResult, final CountedCollector counter, final List entry, ScoreDoc lastEmittedDocForShard ) { - final SearchShardTarget shardTarget = queryResult.getSearchShardTarget(); - final int shardIndex = queryResult.getShardIndex(); - final ShardSearchContextId contextId = queryResult.queryResult().getContextId(); + final SearchShardTarget shardTarget = shardPhaseResult.getSearchShardTarget(); + final int shardIndex = shardPhaseResult.getShardIndex(); + final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null + ? 
shardPhaseResult.queryResult().getContextId() + : shardPhaseResult.rankFeatureResult().getContextId(); context.getSearchTransport() .sendExecuteFetch( context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()), new ShardFetchSearchRequest( - context.getOriginalIndices(queryResult.getShardIndex()), + context.getOriginalIndices(shardPhaseResult.getShardIndex()), contextId, - queryResult.getShardSearchRequest(), + shardPhaseResult.getShardSearchRequest(), entry, lastEmittedDocForShard, - queryResult.getRescoreDocIds(), + shardPhaseResult.getRescoreDocIds(), aggregatedDfs ), context.getTask(), @@ -199,40 +200,17 @@ public void onFailure(Exception e) { // the search context might not be cleared on the node where the fetch was executed for example // because the action was rejected by the thread pool. in this case we need to send a dedicated // request to clear the search context. - releaseIrrelevantSearchContext(queryResult.queryResult()); + releaseIrrelevantSearchContext(shardPhaseResult, context); } } } ); } - /** - * Releases shard targets that are not used in the docsIdsToLoad. 
- */ - private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) { - // we only release search context that we did not fetch from, if we are not scrolling - // or using a PIT and if it has at least one hit that didn't make it to the global topDocs - if (queryResult.hasSearchContext() - && context.getRequest().scroll() == null - && (context.isPartOfPointInTime(queryResult.getContextId()) == false)) { - try { - SearchShardTarget shardTarget = queryResult.getSearchShardTarget(); - Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); - context.sendReleaseSearchContext( - queryResult.getContextId(), - connection, - context.getOriginalIndices(queryResult.getShardIndex()) - ); - } catch (Exception e) { - logger.trace("failed to release context", e); - } - } - } - private void moveToNextPhase(AtomicArray fetchResultsArr) { var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); context.addReleasable(resp::decRef); fetchResults.close(); - context.executeNextPhase(this, nextPhaseFactory.apply(resp, queryResults)); + context.executeNextPhase(this, nextPhaseFactory.apply(resp, searchPhaseShardResults)); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 767597625edc6..291982dd9bdd3 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -7,23 +7,39 @@ */ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import 
org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.dfs.AggregatedDfs; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; + +import java.util.List; /** * This search phase is responsible for executing any re-ranking needed for the given search request, iff that is applicable. - * It starts by retrieving {code num_shards * window_size} results from the query phase and reduces them to a global list of + * It starts by retrieving {@code num_shards * window_size} results from the query phase and reduces them to a global list of * the top {@code window_size} results. It then reaches out to the shards to extract the needed feature data, * and finally passes all this information to the appropriate {@code RankFeatureRankCoordinatorContext} which is responsible for reranking * the results. If no rank query is specified, it proceeds directly to the next phase (FetchSearchPhase) by first reducing the results. 
*/ -public final class RankFeaturePhase extends SearchPhase { +public class RankFeaturePhase extends SearchPhase { + private static final Logger logger = LogManager.getLogger(RankFeaturePhase.class); private final SearchPhaseContext context; - private final SearchPhaseResults queryPhaseResults; - + final SearchPhaseResults queryPhaseResults; + final SearchPhaseResults rankPhaseResults; private final AggregatedDfs aggregatedDfs; + private final SearchProgressListener progressListener; RankFeaturePhase(SearchPhaseResults queryPhaseResults, AggregatedDfs aggregatedDfs, SearchPhaseContext context) { super("rank-feature"); @@ -38,6 +54,9 @@ public final class RankFeaturePhase extends SearchPhase { this.context = context; this.queryPhaseResults = queryPhaseResults; this.aggregatedDfs = aggregatedDfs; + this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards()); + context.addReleasable(rankPhaseResults); + this.progressListener = context.getTask().getProgressListener(); } @Override @@ -59,16 +78,154 @@ public void onFailure(Exception e) { }); } - private void innerRun() throws Exception { - // other than running reduce, this is currently close to a no-op + void innerRun() throws Exception { + // if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call + // to operate on the first `window_size * num_shards` results and merge them appropriately. 
SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce(); - moveToNextPhase(queryPhaseResults, reducedQueryPhase); + RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source()); + if (rankFeaturePhaseRankCoordinatorContext != null) { + ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size + final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs); + final CountedCollector rankRequestCounter = new CountedCollector<>( + rankPhaseResults, + context.getNumShards(), + () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase), + context + ); + + // we send out a request to each shard in order to fetch the needed feature info + for (int i = 0; i < docIdsToLoad.length; i++) { + List entry = docIdsToLoad[i]; + SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i); + if (entry == null || entry.isEmpty()) { + if (queryResult != null) { + releaseIrrelevantSearchContext(queryResult, context); + progressListener.notifyRankFeatureResult(i); + } + rankRequestCounter.countDown(); + } else { + executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry); + } + } + } else { + moveToNextPhase(queryPhaseResults, reducedQueryPhase); + } + } + + private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) { + return source == null || source.rankBuilder() == null + ? 
null + : context.getRequest() + .source() + .rankBuilder() + .buildRankFeaturePhaseCoordinatorContext(context.getRequest().source().size(), context.getRequest().source().from()); } - private void moveToNextPhase( - SearchPhaseResults phaseResults, + private void executeRankFeatureShardPhase( + SearchPhaseResult queryResult, + final CountedCollector rankRequestCounter, + final List entry + ) { + final SearchShardTarget shardTarget = queryResult.queryResult().getSearchShardTarget(); + final ShardSearchContextId contextId = queryResult.queryResult().getContextId(); + final int shardIndex = queryResult.getShardIndex(); + context.getSearchTransport() + .sendExecuteRankFeature( + context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()), + new RankFeatureShardRequest( + context.getOriginalIndices(queryResult.getShardIndex()), + queryResult.getContextId(), + queryResult.getShardSearchRequest(), + entry + ), + context.getTask(), + new SearchActionListener<>(shardTarget, shardIndex) { + @Override + protected void innerOnResponse(RankFeatureResult response) { + try { + progressListener.notifyRankFeatureResult(shardIndex); + rankRequestCounter.onResult(response); + } catch (Exception e) { + context.onPhaseFailure(RankFeaturePhase.this, "", e); + } + } + + @Override + public void onFailure(Exception e) { + try { + logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e); + progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e); + rankRequestCounter.onFailure(shardIndex, shardTarget, e); + } finally { + releaseIrrelevantSearchContext(queryResult, context); + } + } + } + ); + } + + private void onPhaseDone( + RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext, SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { + assert rankFeaturePhaseRankCoordinatorContext != null; + ThreadedActionListener rankResultListener = new ThreadedActionListener<>(context, new ActionListener<>() { + @Override + 
public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { + RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores); + SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults( + reducedQueryPhase, + topResults + ); + moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase); + } + + @Override + public void onFailure(Exception e) { + context.onPhaseFailure(RankFeaturePhase.this, "Computing updated ranks for results failed", e); + } + }); + rankFeaturePhaseRankCoordinatorContext.rankGlobalResults( + rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(), + rankResultListener + ); + } + + private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults( + SearchPhaseController.ReducedQueryPhase reducedQueryPhase, + ScoreDoc[] scoreDocs + ) { + + return new SearchPhaseController.ReducedQueryPhase( + reducedQueryPhase.totalHits(), + reducedQueryPhase.fetchHits(), + maxScore(scoreDocs), + reducedQueryPhase.timedOut(), + reducedQueryPhase.terminatedEarly(), + reducedQueryPhase.suggest(), + reducedQueryPhase.aggregations(), + reducedQueryPhase.profileBuilder(), + new SearchPhaseController.SortedTopDocs(scoreDocs, false, null, null, null, 0), + reducedQueryPhase.sortValueFormats(), + reducedQueryPhase.queryPhaseRankCoordinatorContext(), + reducedQueryPhase.numReducePhases(), + reducedQueryPhase.size(), + reducedQueryPhase.from(), + reducedQueryPhase.isEmptyResult() + ); + } + + private float maxScore(ScoreDoc[] scoreDocs) { + float maxScore = Float.NaN; + for (ScoreDoc scoreDoc : scoreDocs) { + if (Float.isNaN(maxScore) || scoreDoc.score > maxScore) { + maxScore = scoreDoc.score; + } + } + return maxScore; + } + + void moveToNextPhase(SearchPhaseResults phaseResults, SearchPhaseController.ReducedQueryPhase reducedQueryPhase) { context.executeNextPhase(this, new FetchSearchPhase(phaseResults, aggregatedDfs, context, 
reducedQueryPhase)); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java index 9d3eadcc42bf9..5ed449667fe57 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java @@ -9,6 +9,9 @@ import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.core.CheckedRunnable; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.transport.Transport; import java.io.IOException; import java.io.UncheckedIOException; @@ -62,4 +65,35 @@ static void doCheckNoMissingShards(String phaseName, SearchRequest request, Grou } } } + + /** + * Releases shard targets that are not used in the docsIdsToLoad. + */ + protected void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, SearchPhaseContext context) { + // we only release search context that we did not fetch from, if we are not scrolling + // or using a PIT and if it has at least one hit that didn't make it to the global topDocs + if (searchPhaseResult == null) { + return; + } + // phaseResult.getContextId() is the same for query & rank feature results + SearchPhaseResult phaseResult = searchPhaseResult.queryResult() != null + ? 
searchPhaseResult.queryResult() + : searchPhaseResult.rankFeatureResult(); + if (phaseResult != null + && phaseResult.hasSearchContext() + && context.getRequest().scroll() == null + && (context.isPartOfPointInTime(phaseResult.getContextId()) == false)) { + try { + SearchShardTarget shardTarget = phaseResult.getSearchShardTarget(); + Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); + context.sendReleaseSearchContext( + phaseResult.getContextId(), + connection, + context.getOriginalIndices(phaseResult.getShardIndex()) + ); + } catch (Exception e) { + context.getLogger().trace("failed to release context", e); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 1b894dfe3d8bd..1d3859b9038fe 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -456,7 +456,7 @@ private static SearchHits getHits( : "not enough hits fetched. 
index [" + index + "] length: " + fetchResult.hits().getHits().length; SearchHit searchHit = fetchResult.hits().getHits()[index]; searchHit.shard(fetchResult.getSearchShardTarget()); - if (reducedQueryPhase.rankCoordinatorContext != null) { + if (reducedQueryPhase.queryPhaseRankCoordinatorContext != null) { assert shardDoc instanceof RankDoc; searchHit.setRank(((RankDoc) shardDoc).rank); searchHit.score(shardDoc.score); @@ -747,7 +747,7 @@ public record ReducedQueryPhase( // sort value formats used to sort / format the result DocValueFormat[] sortValueFormats, // the rank context if ranking is used - QueryPhaseRankCoordinatorContext rankCoordinatorContext, + QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext, // the number of reduces phases int numReducePhases, // the size of the top hits to return diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java index f5d280a01257c..3b5e03cb5ac4a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java @@ -88,6 +88,22 @@ protected void onPartialReduce(List shards, TotalHits totalHits, In */ protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} + /** + * Executed when a shard returns a rank feature result. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}. + */ + protected void onRankFeatureResult(int shardIndex) {} + + /** + * Executed when a shard reports a rank feature failure. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}. + * @param shardTarget The last shard target that threw an exception. + * @param exc The cause of the failure. 
+ */ + protected void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {} + /** * Executed when a shard returns a fetch result. * @@ -160,6 +176,22 @@ protected final void notifyFinalReduce(List shards, TotalHits total } } + final void notifyRankFeatureResult(int shardIndex) { + try { + onRankFeatureResult(shardIndex); + } catch (Exception e) { + logger.warn(() -> "[" + shards.get(shardIndex) + "] Failed to execute progress listener on rank-feature result", e); + } + } + + final void notifyRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { + try { + onRankFeatureFailure(shardIndex, shardTarget, exc); + } catch (Exception e) { + logger.warn(() -> "[" + shards.get(shardIndex) + "] Failed to execute progress listener on rank-feature failure", e); + } + } + final void notifyFetchResult(int shardIndex) { try { onFetchResult(shardIndex); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index 4e3fdbc9633b9..3e4f6dfec9fdb 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -407,7 +407,7 @@ public ActionRequestValidationException validate() { ); } int queryCount = source.subSearches().size() + source.knnSearch().size(); - if (queryCount < 2) { + if (source.rankBuilder().isCompoundBuilder() && queryCount < 2) { validationException = addValidationError( "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", validationException diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java index 93b8e22d0d7cd..9f8896f169350 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java +++ 
b/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java @@ -19,6 +19,7 @@ public class SearchTransportAPMMetrics { public static final String DFS_ACTION_METRIC = "dfs_query_then_fetch/shard_dfs_phase"; public static final String QUERY_ID_ACTION_METRIC = "dfs_query_then_fetch/shard_query_phase"; public static final String QUERY_ACTION_METRIC = "query_then_fetch/shard_query_phase"; + public static final String RANK_SHARD_FEATURE_ACTION_METRIC = "rank/shard_feature_phase"; public static final String FREE_CONTEXT_ACTION_METRIC = "shard_release_context"; public static final String FETCH_ID_ACTION_METRIC = "shard_fetch_phase"; public static final String QUERY_SCROLL_ACTION_METRIC = "scroll/shard_query_phase"; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 66c395cf51d96..d627da9b0e33b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -39,6 +39,8 @@ import org.elasticsearch.search.query.QuerySearchRequest; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.query.ScrollQuerySearchResult; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteClusterService; @@ -70,6 +72,7 @@ import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_FETCH_SCROLL_ACTION_METRIC; import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_ID_ACTION_METRIC; import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_SCROLL_ACTION_METRIC; +import static 
org.elasticsearch.action.search.SearchTransportAPMMetrics.RANK_SHARD_FEATURE_ACTION_METRIC; /** * An encapsulation of {@link org.elasticsearch.search.SearchService} operations exposed through @@ -96,6 +99,8 @@ public class SearchTransportService { public static final String FETCH_ID_SCROLL_ACTION_NAME = "indices:data/read/search[phase/fetch/id/scroll]"; public static final String FETCH_ID_ACTION_NAME = "indices:data/read/search[phase/fetch/id]"; + public static final String RANK_FEATURE_SHARD_ACTION_NAME = "indices:data/read/search[phase/rank/feature]"; + /** * The Can-Match phase. It is executed to pre-filter shards that a search request hits. It rewrites the query on * the shard and checks whether the result of the rewrite matches no documents, in which case the shard can be @@ -250,6 +255,21 @@ public void sendExecuteScrollQuery( ); } + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + transportService.sendChildRequest( + connection, + RANK_FEATURE_SHARD_ACTION_NAME, + request, + task, + new ConnectionCountingHandler<>(listener, RankFeatureResult::new, connection) + ); + } + public void sendExecuteScrollFetch( Transport.Connection connection, final InternalScrollSearchRequest request, @@ -539,6 +559,16 @@ public static void registerRequestHandler( ); TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new); + final TransportRequestHandler rankShardFeatureRequest = (request, channel, task) -> searchService + .executeRankFeaturePhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); + transportService.registerRequestHandler( + RANK_FEATURE_SHARD_ACTION_NAME, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + RankFeatureShardRequest::new, + instrumentedHandler(RANK_SHARD_FEATURE_ACTION_METRIC, transportService, searchTransportMetrics, 
rankShardFeatureRequest) + ); + TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new); + final TransportRequestHandler shardFetchRequestHandler = (request, channel, task) -> searchService .executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); transportService.registerRequestHandler( diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index e0ca0f7a48cdd..fd2aabce8e952 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -1044,6 +1044,7 @@ record PluginServiceInstances( threadPool, scriptService, bigArrays, + searchModule.getRankFeatureShardPhase(), searchModule.getFetchPhase(), responseCollectorService, circuitBreakerService, diff --git a/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java b/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java index ab90ca42bca98..914dd51d0c6b2 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java +++ b/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java @@ -33,6 +33,7 @@ import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; @@ -116,6 +117,7 @@ SearchService newSearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -128,6 +130,7 @@ SearchService newSearchService( 
threadPool, scriptService, bigArrays, + rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index 9bacf19a9169d..4f16d3a5720fb 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -70,6 +70,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.SortAndFormats; @@ -102,6 +103,7 @@ final class DefaultSearchContext extends SearchContext { private final ContextIndexSearcher searcher; private DfsSearchResult dfsResult; private QuerySearchResult queryResult; + private RankFeatureResult rankFeatureResult; private FetchSearchResult fetchResult; private final float queryBoost; private final boolean lowLevelCancellation; @@ -308,6 +310,17 @@ static boolean isParallelCollectionSupportedForResults( return false; } + @Override + public void addRankFeatureResult() { + this.rankFeatureResult = new RankFeatureResult(this.readerContext.id(), this.shardTarget, this.request); + addReleasable(rankFeatureResult::decRef); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return rankFeatureResult; + } + @Override public void addFetchResult() { this.fetchResult = new FetchSearchResult(this.readerContext.id(), this.shardTarget); diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 8d5fa0a7ac155..d93ff91a6ffe4 100644 --- 
a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -226,6 +226,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.fetch.subphase.highlight.PlainHighlighter; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; @@ -1252,6 +1253,10 @@ private void registerQuery(QuerySpec spec) { ); } + public RankFeatureShardPhase getRankFeatureShardPhase() { + return new RankFeatureShardPhase(); + } + public FetchPhase getFetchPhase() { return new FetchPhase(fetchSubPhases); } diff --git a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java index 254cd7d3370b5..450b98b22f39c 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java +++ b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java @@ -15,6 +15,7 @@ import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.transport.TransportResponse; import java.io.IOException; @@ -43,6 +44,14 @@ protected SearchPhaseResult(StreamInput in) throws IOException { super(in); } + /** + * Specifies whether the specific search phase results are associated with an opened SearchContext on the shards that + * executed the request. 
+ */ + public boolean hasSearchContext() { + return false; + } + /** * Returns the search context ID that is used to reference the search context on the executing node * or null if no context was created. @@ -81,6 +90,13 @@ public QuerySearchResult queryResult() { return null; } + /** + * Returns the rank feature result iff it's included in this response otherwise null + */ + public RankFeatureResult rankFeatureResult() { + return null; + } + /** * Returns the fetch result iff it's included in this response otherwise null */ diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 41796967c3870..3f9dd7895f6a7 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -112,6 +112,9 @@ import org.elasticsearch.search.query.QuerySearchRequest; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.query.ScrollQuerySearchResult; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; @@ -151,6 +154,7 @@ import static org.elasticsearch.core.TimeValue.timeValueMillis; import static org.elasticsearch.core.TimeValue.timeValueMinutes; import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; +import static org.elasticsearch.search.rank.feature.RankFeatureShardPhase.EMPTY_RESULT; public class SearchService extends AbstractLifecycleComponent implements IndexEventListener { private static final Logger logger = LogManager.getLogger(SearchService.class); @@ -276,6 +280,7 @@ public class SearchService extends 
AbstractLifecycleComponent implements IndexEv private final DfsPhase dfsPhase = new DfsPhase(); private final FetchPhase fetchPhase; + private final RankFeatureShardPhase rankFeatureShardPhase; private volatile boolean enableSearchWorkerThreads; private volatile boolean enableQueryPhaseParallelCollection; @@ -314,6 +319,7 @@ public SearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -327,6 +333,7 @@ public SearchService( this.scriptService = scriptService; this.responseCollectorService = responseCollectorService; this.bigArrays = bigArrays; + this.rankFeatureShardPhase = rankFeatureShardPhase; this.fetchPhase = fetchPhase; this.multiBucketConsumerService = new MultiBucketConsumerService( clusterService, @@ -713,6 +720,32 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh } } + public void executeRankFeaturePhase(RankFeatureShardRequest request, SearchShardTask task, ActionListener listener) { + final ReaderContext readerContext = findReaderContext(request.contextId(), request); + final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); + final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); + runAsync(getExecutor(readerContext.indexShard()), () -> { + try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.RANK_FEATURE, false)) { + int[] docIds = request.getDocIds(); + if (docIds == null || docIds.length == 0) { + searchContext.rankFeatureResult().shardResult(EMPTY_RESULT); + searchContext.rankFeatureResult().incRef(); + return searchContext.rankFeatureResult(); + } + rankFeatureShardPhase.prepareForFetch(searchContext, request); + fetchPhase.execute(searchContext, docIds); + 
rankFeatureShardPhase.processFetch(searchContext); + var rankFeatureResult = searchContext.rankFeatureResult(); + rankFeatureResult.incRef(); + return rankFeatureResult; + } catch (Exception e) { + assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); + // we handle the failure in the failure listener below + throw e; + } + }, wrapFailureListener(listener, readerContext, markAsUsed)); + } + private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchContext context, long afterQueryTime) { try ( Releasable scope = tracer.withScope(context.getTask()); @@ -1559,6 +1592,12 @@ void addResultsObject(SearchContext context) { context.addQueryResult(); } }, + RANK_FEATURE { + @Override + void addResultsObject(SearchContext context) { + context.addRankFeatureResult(); + } + }, FETCH { @Override void addResultsObject(SearchContext context) { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index 0c54e8ff89589..4ba191794413d 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -98,7 +98,6 @@ public Source getSource(LeafReaderContext ctx, int doc) { } private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Profiler profiler) { - FetchContext fetchContext = new FetchContext(context); SourceLoader sourceLoader = context.newSourceLoader(); diff --git a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java index d5c3c00c00ce1..e32397e25d773 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java @@ -35,6 +35,7 @@ import org.elasticsearch.search.profile.Profilers; import 
org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -374,6 +375,16 @@ public float getMaxScore() { return in.getMaxScore(); } + @Override + public void addRankFeatureResult() { + in.addRankFeatureResult(); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return in.rankFeatureResult(); + } + @Override public FetchSearchResult fetchResult() { return in.fetchResult(); diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 35f96ee2dc102..9bc622034184c 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -42,6 +42,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -332,6 +333,10 @@ public Query rewrittenQuery() { public abstract float getMaxScore(); + public abstract void addRankFeatureResult(); + + public abstract RankFeatureResult rankFeatureResult(); + public abstract FetchPhase fetchPhase(); public abstract FetchSearchResult fetchResult(); diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 828c6d2b4f3e8..0d2610aa34282 100644 --- 
a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -87,35 +87,38 @@ static void executeRank(SearchContext searchContext) throws QueryPhaseExecutionE boolean searchTimedOut = querySearchResult.searchTimedOut(); long serviceTimeEWMA = querySearchResult.serviceTimeEWMA(); int nodeQueueSize = querySearchResult.nodeQueueSize(); - - // run each of the rank queries - for (Query rankQuery : queryPhaseRankShardContext.queries()) { - // if a search timeout occurs, exit with partial results - if (searchTimedOut) { - break; - } - try ( - RankSearchContext rankSearchContext = new RankSearchContext( - searchContext, - rankQuery, - queryPhaseRankShardContext.rankWindowSize() - ) - ) { - QueryPhase.addCollectorsAndSearch(rankSearchContext); - QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult(); - rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs); - serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA(); - nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize()); - searchTimedOut = rrfQuerySearchResult.searchTimedOut(); + try { + // run each of the rank queries + for (Query rankQuery : queryPhaseRankShardContext.queries()) { + // if a search timeout occurs, exit with partial results + if (searchTimedOut) { + break; + } + try ( + RankSearchContext rankSearchContext = new RankSearchContext( + searchContext, + rankQuery, + queryPhaseRankShardContext.rankWindowSize() + ) + ) { + QueryPhase.addCollectorsAndSearch(rankSearchContext); + QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult(); + rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs); + serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA(); + nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize()); + searchTimedOut = rrfQuerySearchResult.searchTimedOut(); + } } - } - 
querySearchResult.setRankShardResult(queryPhaseRankShardContext.combineQueryPhaseResults(rrfRankResults)); + querySearchResult.setRankShardResult(queryPhaseRankShardContext.combineQueryPhaseResults(rrfRankResults)); - // record values relevant to all queries - querySearchResult.searchTimedOut(searchTimedOut); - querySearchResult.serviceTimeEWMA(serviceTimeEWMA); - querySearchResult.nodeQueueSize(nodeQueueSize); + // record values relevant to all queries + querySearchResult.searchTimedOut(searchTimedOut); + querySearchResult.serviceTimeEWMA(serviceTimeEWMA); + querySearchResult.nodeQueueSize(nodeQueueSize); + } catch (Exception e) { + throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute rank query", e); + } } static void executeQuery(SearchContext searchContext) throws QueryPhaseExecutionException { diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java index 7118c9f49b36d..f496758c3f5c6 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java @@ -16,6 +16,8 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -32,7 +34,7 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContent public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); - public static final int DEFAULT_WINDOW_SIZE = SearchService.DEFAULT_SIZE; + public static 
final int DEFAULT_RANK_WINDOW_SIZE = SearchService.DEFAULT_SIZE; private final int rankWindowSize; @@ -68,6 +70,12 @@ public int rankWindowSize() { return rankWindowSize; } + /** + * Specify whether this rank builder is a compound builder or not. A compound builder is a rank builder that requires + * two or more queries to be executed in order to generate the final result. + */ + public abstract boolean isCompoundBuilder(); + /** * Generates a context used to execute required searches during the query phase on the shard. */ @@ -78,6 +86,19 @@ public int rankWindowSize() { */ public abstract QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from); + /** + * Generates a context used to execute the rank feature phase on the shard. This is responsible for retrieving any needed + * feature data, and passing them back to the coordinator through the appropriate {@link RankShardResult}. + */ + public abstract RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(); + + /** + * Generates a context used to perform global ranking during the RankFeature phase, + * on the coordinator based on all the individual shard results. The output of this will be a `size` ranked list of ordered results, + * which will then be passed to fetch phase. 
+ */ + public abstract RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from); + @Override public final boolean equals(Object obj) { if (this == obj) { diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java index 1cb5843dfc7da..7f8e99971d61b 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java @@ -43,6 +43,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -57,14 +58,14 @@ public class RankSearchContext extends SearchContext { private final SearchContext parent; private final Query rankQuery; - private final int windowSize; + private final int rankWindowSize; private final QuerySearchResult querySearchResult; @SuppressWarnings("this-escape") - public RankSearchContext(SearchContext parent, Query rankQuery, int windowSize) { + public RankSearchContext(SearchContext parent, Query rankQuery, int rankWindowSize) { this.parent = parent; this.rankQuery = parent.buildFilteredQuery(rankQuery); - this.windowSize = windowSize; + this.rankWindowSize = rankWindowSize; this.querySearchResult = new QuerySearchResult(parent.readerContext().id(), parent.shardTarget(), parent.request()); this.addReleasable(querySearchResult::decRef); } @@ -182,7 +183,7 @@ public int from() { @Override public int size() { - return windowSize; + return rankWindowSize; } /** @@ -492,6 +493,16 @@ public FetchPhase fetchPhase() { throw new UnsupportedOperationException(); 
} + @Override + public void addRankFeatureResult() { + throw new UnsupportedOperationException(); + } + + @Override + public RankFeatureResult rankFeatureResult() { + throw new UnsupportedOperationException(); + } + @Override public FetchSearchResult fetchResult() { throw new UnsupportedOperationException(); } diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java new file mode 100644 index 0000000000000..b8951a4779166 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.context; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +import static org.elasticsearch.search.SearchService.DEFAULT_FROM; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; + +/** + * {@code RankFeaturePhaseRankCoordinatorContext} is a base class that runs on the coordinating node and is responsible for retrieving + * {@code rank_window_size} total results from all shards, ranking them, and then producing a final paginated response of [from, from+size] results.
+ */ +public abstract class RankFeaturePhaseRankCoordinatorContext { + + protected final int size; + protected final int from; + protected final int rankWindowSize; + + public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { + this.size = size < 0 ? DEFAULT_SIZE : size; + this.from = from < 0 ? DEFAULT_FROM : from; + this.rankWindowSize = rankWindowSize; + } + + /** + * Computes the updated scores for a list of features (i.e. document-based data). We also pass along an ActionListener + * that should be called with the new scores, and will continue execution to the next phase + */ + protected abstract void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener); + + /** + * This method is responsible for ranking the global results based on the provided rank feature results from each shard. + *

+ * We first start by extracting ordered feature data through a {@code List} + * from the provided rankSearchResults, and then compute the updated score for each of the documents. + * Once all the scores have been computed, we sort the results, perform any pagination needed, and then notify the `rankListener` + * with the final array of {@link ScoreDoc} results. + * + * @param rankSearchResults a list of rank feature results from each shard + * @param rankListener a rankListener to handle the global ranking result + */ + public void rankGlobalResults(List rankSearchResults, ActionListener rankListener) { + // extract feature data from each shard rank-feature phase result + RankFeatureDoc[] featureDocs = extractFeatureDocs(rankSearchResults); + + // compute the updated scores; the final paginated `topResults` are produced later (see `rankAndPaginate`) and reach the fetch phase through the `rankListener` + computeScores(featureDocs, rankListener.delegateFailureAndWrap((listener, scores) -> { + for (int i = 0; i < featureDocs.length; i++) { + featureDocs[i].score = scores[i]; + } + listener.onResponse(featureDocs); + })); + } + + /** + * Ranks the provided {@link RankFeatureDoc} array and paginates the results based on the `from` and `size` parameters.
+ */ + public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) { + Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))]; + for (int rank = 0; rank < topResults.length; ++rank) { + topResults[rank] = rankFeatureDocs[from + rank]; + topResults[rank].rank = from + rank + 1; + } + return topResults; + } + + private RankFeatureDoc[] extractFeatureDocs(List rankSearchResults) { + List docFeatures = new ArrayList<>(); + for (RankFeatureResult rankFeatureResult : rankSearchResults) { + RankFeatureShardResult shardResult = rankFeatureResult.shardResult(); + for (RankFeatureDoc rankFeatureDoc : shardResult.rankFeatureDocs) { + if (rankFeatureDoc.featureData != null) { + docFeatures.add(rankFeatureDoc); + } + } + } + return docFeatures.toArray(new RankFeatureDoc[0]); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java new file mode 100644 index 0000000000000..5d3f30bce757a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.search.rank.context; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.rank.RankShardResult; + +/** + * {@link RankFeaturePhaseRankShardContext} is a base class used to execute the RankFeature phase on each shard. + * In this class, we can fetch the feature data for a given set of documents and pass them back to the coordinator + * through the {@link RankShardResult}. + */ +public abstract class RankFeaturePhaseRankShardContext { + + protected final String field; + + public RankFeaturePhaseRankShardContext(final String field) { + this.field = field; + } + + public String getField() { + return field; + } + + /** + * This is used to fetch the feature data for a given set of documents, using the {@link org.elasticsearch.search.fetch.FetchPhase} + * and the {@link org.elasticsearch.search.fetch.subphase.FetchFieldsPhase} subphase. + * The feature data is then stored in a {@link org.elasticsearch.search.rank.feature.RankFeatureDoc} and passed back to the coordinator. + */ + @Nullable + public abstract RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId); +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java new file mode 100644 index 0000000000000..8eb3f2fc8339b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankDoc; + +import java.io.IOException; +import java.util.Objects; + +/** + * A {@link RankDoc} that contains field data to be used later by the reranker on the coordinator node. + */ +public class RankFeatureDoc extends RankDoc { + + // todo: update to support more than 1 fields; and not restrict to string data + public String featureData; + + public RankFeatureDoc(int doc, float score, int shardIndex) { + super(doc, score, shardIndex); + } + + public RankFeatureDoc(StreamInput in) throws IOException { + super(in); + featureData = in.readOptionalString(); + } + + public void featureData(String featureData) { + this.featureData = featureData; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeOptionalString(featureData); + } + + @Override + protected boolean doEquals(RankDoc rd) { + RankFeatureDoc other = (RankFeatureDoc) rd; + return Objects.equals(this.featureData, other.featureData); + } + + @Override + protected int doHashCode() { + return Objects.hashCode(featureData); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java new file mode 100644 index 0000000000000..1e16d18cda367 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.internal.ShardSearchRequest; + +import java.io.IOException; + +/** + * The result of a rank feature search phase. + * Each instance holds a {@code RankFeatureShardResult} along with the references associated with it. + */ +public class RankFeatureResult extends SearchPhaseResult { + + private RankFeatureShardResult rankShardResult; + + public RankFeatureResult() {} + + public RankFeatureResult(ShardSearchContextId id, SearchShardTarget shardTarget, ShardSearchRequest request) { + this.contextId = id; + setSearchShardTarget(shardTarget); + setShardSearchRequest(request); + } + + public RankFeatureResult(StreamInput in) throws IOException { + super(in); + contextId = new ShardSearchContextId(in); + rankShardResult = in.readOptionalWriteable(RankFeatureShardResult::new); + setShardSearchRequest(in.readOptionalWriteable(ShardSearchRequest::new)); + setSearchShardTarget(in.readOptionalWriteable(SearchShardTarget::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + assert hasReferences(); + contextId.writeTo(out); + out.writeOptionalWriteable(rankShardResult); + out.writeOptionalWriteable(getShardSearchRequest()); + out.writeOptionalWriteable(getSearchShardTarget()); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return this; + } + + public void shardResult(RankFeatureShardResult shardResult) { + this.rankShardResult = shardResult; + } + + public RankFeatureShardResult shardResult() { + return rankShardResult; + } + + @Override + public boolean hasSearchContext() { + return rankShardResult != null; + } +} diff --git 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.search.rank.feature;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.search.SearchContextSourcePrinter;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.fetch.subphase.FetchFieldsContext;
import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
import org.elasticsearch.tasks.TaskCancelledException;

import java.util.Arrays;
import java.util.Collections;

/**
 * The {@code RankFeatureShardPhase} executes the rank feature phase on the shard, iff there is a {@code RankBuilder} that requires it.
 * This phase is responsible for reading field data for a set of docids. To do this, it reuses the {@code FetchPhase} to read the required
 * fields for all requested documents using the `FetchFieldPhase` sub-phase.
 */
public final class RankFeatureShardPhase {

    private static final Logger logger = LogManager.getLogger(RankFeatureShardPhase.class);

    /** Shared sentinel for shards that produced no rank feature docs. */
    public static final RankFeatureShardResult EMPTY_RESULT = new RankFeatureShardResult(new RankFeatureDoc[0]);

    public RankFeatureShardPhase() {}

    /**
     * Configures the search context so the fetch phase loads only the rank feature field
     * (no stored fields) for the requested doc ids.
     *
     * @param searchContext the shard-level search context to configure
     * @param request       the rank feature request carrying the doc ids to fetch
     * @throws TaskCancelledException if the task was cancelled before the phase ran
     */
    public void prepareForFetch(SearchContext searchContext, RankFeatureShardRequest request) {
        if (logger.isTraceEnabled()) {
            logger.trace("{}", new SearchContextSourcePrinter(searchContext));
        }

        if (searchContext.isCancelled()) {
            throw new TaskCancelledException("cancelled");
        }

        RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext);
        if (rankFeaturePhaseRankShardContext != null) {
            assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null";
            // fetch only the single field the rank feature context needs, with default formatting
            searchContext.fetchFieldsContext(
                new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null)))
            );
            // explicitly disable stored-field loading; we only need the fetch field above
            searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_)));
            searchContext.addFetchResult();
            // fetch phase expects doc ids in ascending order
            Arrays.sort(request.getDocIds());
        }
    }

    /**
     * Converts the hits produced by the fetch phase into a {@link RankFeatureShardResult}
     * and stores it on the search context.
     *
     * @param searchContext the shard-level search context holding the fetch result
     * @throws TaskCancelledException if the task was cancelled before the phase ran
     */
    public void processFetch(SearchContext searchContext) {
        if (logger.isTraceEnabled()) {
            logger.trace("{}", new SearchContextSourcePrinter(searchContext));
        }

        if (searchContext.isCancelled()) {
            throw new TaskCancelledException("cancelled");
        }

        // Use the null-safe helper: the previous inline lookup dereferenced
        // request().source() without a null check, unlike prepareForFetch.
        RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext);
        if (rankFeaturePhaseRankShardContext != null) {
            // TODO: here we populate the profile part of the fetchResult as well
            // we need to see what info we want to include on the overall profiling section. This is something that is per-shard
            // so most likely we will still care about the `FetchFieldPhase` profiling info as we could potentially
            // operate on `rank_window_size` instead of just `size` results, so this could be much more expensive.
            FetchSearchResult fetchSearchResult = searchContext.fetchResult();
            if (fetchSearchResult == null || fetchSearchResult.hits() == null) {
                return;
            }
            // this cannot be null; as we have either already checked for it, or we would have thrown in
            // FetchSearchResult#shardResult()
            SearchHits hits = fetchSearchResult.hits();
            RankFeatureShardResult featureRankShardResult = (RankFeatureShardResult) rankFeaturePhaseRankShardContext
                .buildRankFeatureShardResult(hits, searchContext.shardTarget().getShardId().id());
            // save the result in the search context
            // need to add profiling info as well available from fetch
            if (featureRankShardResult != null) {
                searchContext.rankFeatureResult().shardResult(featureRankShardResult);
            }
        }
    }

    /**
     * @return the rank feature shard context from the request's rank builder,
     *         or null when the request has no source or no rank builder
     */
    private RankFeaturePhaseRankShardContext shardContext(SearchContext searchContext) {
        return searchContext.request().source() != null && searchContext.request().source().rankBuilder() != null
            ? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext()
            : null;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.search.rank.feature;

import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportRequest;

import java.io.IOException;
import java.util.List;
import java.util.Map;

/**
 * Shard level request for extracting all needed feature for a global reranker
 */
public class RankFeatureShardRequest extends TransportRequest implements IndicesRequest {

    private final OriginalIndices originalIndices;
    private final ShardSearchRequest shardSearchRequest;

    private final ShardSearchContextId contextId;

    // lucene doc ids (shard-local) whose feature data should be extracted
    private final int[] docIds;

    /**
     * @param originalIndices    the indices the user originally targeted
     * @param contextId          the shard search context this request operates on
     * @param shardSearchRequest the originating shard-level search request
     * @param docIds             the shard-local doc ids to extract features for
     */
    public RankFeatureShardRequest(
        OriginalIndices originalIndices,
        ShardSearchContextId contextId,
        ShardSearchRequest shardSearchRequest,
        List<Integer> docIds
    ) {
        this.originalIndices = originalIndices;
        this.shardSearchRequest = shardSearchRequest;
        // unbox to a primitive array; simpler than the previous flatMapToInt(IntStream::of) detour
        this.docIds = docIds.stream().mapToInt(Integer::intValue).toArray();
        this.contextId = contextId;
    }

    /**
     * Deserializes a request from the wire; read order must mirror {@link #writeTo(StreamOutput)}.
     */
    public RankFeatureShardRequest(StreamInput in) throws IOException {
        super(in);
        originalIndices = OriginalIndices.readOriginalIndices(in);
        shardSearchRequest = in.readOptionalWriteable(ShardSearchRequest::new);
        docIds = in.readIntArray();
        contextId = in.readOptionalWriteable(ShardSearchContextId::new);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        // write order is the wire protocol; keep in sync with the StreamInput constructor
        OriginalIndices.writeOriginalIndices(originalIndices, out);
        out.writeOptionalWriteable(shardSearchRequest);
        out.writeIntArray(docIds);
        out.writeOptionalWriteable(contextId);
    }

    @Override
    public String[] indices() {
        if (originalIndices == null) {
            return null;
        }
        return originalIndices.indices();
    }

    @Override
    public IndicesOptions indicesOptions() {
        if (originalIndices == null) {
            return null;
        }
        return originalIndices.indicesOptions();
    }

    public ShardSearchRequest getShardSearchRequest() {
        return shardSearchRequest;
    }

    public int[] getDocIds() {
        return docIds;
    }

    public ShardSearchContextId contextId() {
        return contextId;
    }

    @Override
    public SearchShardTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
        return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers);
    }
}
+ */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankShardResult; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * The result set of {@link RankFeatureDoc} docs for the shard. + */ +public class RankFeatureShardResult implements RankShardResult { + + public final RankFeatureDoc[] rankFeatureDocs; + + public RankFeatureShardResult(RankFeatureDoc[] rankFeatureDocs) { + this.rankFeatureDocs = Objects.requireNonNull(rankFeatureDocs); + } + + public RankFeatureShardResult(StreamInput in) throws IOException { + rankFeatureDocs = in.readArray(RankFeatureDoc::new, RankFeatureDoc[]::new); + } + + @Override + public String getWriteableName() { + return "rank_feature_shard"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_FEATURE_PHASE_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray(rankFeatureDocs); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RankFeatureShardResult that = (RankFeatureShardResult) o; + return Arrays.equals(rankFeatureDocs, that.rankFeatureDocs); + } + + @Override + public int hashCode() { + return 31 * Arrays.hashCode(rankFeatureDocs); + } + + @Override + public String toString() { + return this.getClass().getSimpleName() + "{rankFeatureDocs=" + Arrays.toString(rankFeatureDocs) + '}'; + } +} diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java new file mode 100644 index 0000000000000..9716749562eae --- /dev/null +++ 
b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java @@ -0,0 +1,1170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ +package org.elasticsearch.action.search; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.tests.store.MockDirectoryWrapper; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankShardResult; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import 
org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.InternalAggregationTestCase; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +public class RankFeaturePhaseTests extends ESTestCase { + + private static final int DEFAULT_RANK_WINDOW_SIZE = 10; + private static final int DEFAULT_FROM = 0; + private static final int DEFAULT_SIZE = 10; + private static final String DEFAULT_FIELD = "some_field"; + + private final RankBuilder DEFAULT_RANK_BUILDER = rankBuilder( + DEFAULT_RANK_WINDOW_SIZE, + defaultQueryPhaseRankShardContext(new ArrayList<>(), DEFAULT_RANK_WINDOW_SIZE), + defaultQueryPhaseRankCoordinatorContext(DEFAULT_RANK_WINDOW_SIZE), + defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD), + defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE) + ); + + private record ExpectedRankFeatureDoc(int doc, int rank, float score, String featureData) {} + + public void testRankFeaturePhaseWith1Shard() { + // request params used within SearchSourceBuilder and *RankContext classes + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + 
SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); + QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); + try { + queryResult.setShardIndex(shard1Target.getShardId().getId()); + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 + int totalHits = randomIntBetween(2, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResult, totalHits, shard1Docs); + results.consumeResult(queryResult, () -> {}); + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + totalHits, + shard1Docs + ); + 
listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResult.decRef(); + } + + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); + + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(1, rankPhaseResults.getAtomicArray().length()); + assertEquals(1, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShardResults = List.of( + new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"), + new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2") + ); + List expectedFinalResults = new ArrayList<>(expectedShardResults); + assertShardResults(shard1Result, expectedShardResults); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseWithMultipleShardsOneEmpty() { + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 
1), null); + SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null); + QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard2Target, null); + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + queryResultShard3.setShardIndex(shard3Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final int shard2Results = randomIntBetween(1, 100); + final int shard3Results = 0; + + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + final ScoreDoc[] shard3Docs = new ScoreDoc[0]; + populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard3, () -> {}); + 
results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + // first shard + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + shard1Results, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 2 })) { + // second shard + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard2Target, + shard2Results, + shard2Docs + ); + listener.onResponse(rankFeatureResult); + } else if (request.contextId().getId() == 789) { + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + queryResultShard3.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(3, 
rankPhaseResults.getAtomicArray().length()); + // one result is null + assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + assertShardResults(shard1Result, expectedShard1Results); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2")); + assertShardResults(shard2Result, expectedShard2Results); + + SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); + assertNull(shard3Result); + + List expectedFinalResults = List.of( + new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"), + new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2") + ); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseNoNeedForFetchingFieldData() { + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // build the appropriate RankBuilder; using a null rankFeaturePhaseRankShardContext + // and non-field based rankFeaturePhaseRankCoordinatorContext + RankBuilder rankBuilder = rankBuilder( + DEFAULT_RANK_WINDOW_SIZE, + defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE), + negatingScoresQueryFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE), + null, + null + ); + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new 
SearchShardTarget("node0", new ShardId("test", "na", 0), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 + final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); + QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); + + try { + queryResult.setShardIndex(shard1Target.getShardId().getId()); + int totalHits = randomIntBetween(2, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResult, totalHits, shard1Docs); + results.consumeResult(queryResult, () -> {}); + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) { + listener.onFailure(new UnsupportedOperationException("should not have reached here")); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResult.decRef(); + } + // override the RankFeaturePhase to skip moving to next phase + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + 
mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + + // in this case there was no additional "RankFeature" results on shards, so we shortcut directly to queryPhaseResults + SearchPhaseResults rankPhaseResults = rankFeaturePhase.queryPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(1, rankPhaseResults.getAtomicArray().length()); + assertEquals(1, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shardResult = rankPhaseResults.getAtomicArray().get(0); + assertTrue(shardResult instanceof QuerySearchResult); + QuerySearchResult rankResult = (QuerySearchResult) shardResult; + assertNull(rankResult.rankFeatureResult()); + assertNotNull(rankResult.queryResult()); + + List expectedFinalResults = List.of( + new ExpectedRankFeatureDoc(2, 1, -9.0F, null), + new ExpectedRankFeatureDoc(1, 2, -10.0F, null) + ); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseOneShardFails() { + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = 
searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null); + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + + final int shard2Results = randomIntBetween(1, 100); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + // first shard + if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 2 })) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + buildRankFeatureResult( + 
mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard2Target, + shard2Results, + shard2Docs + ); + listener.onResponse(rankFeatureResult); + + } else if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + // other shard; this one throws an exception + listener.onFailure(new IllegalArgumentException("simulated failure")); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + + mockSearchPhaseContext.assertNoFailure(); + assertEquals(1, mockSearchPhaseContext.failures.size()); + assertTrue(mockSearchPhaseContext.failures.get(0).getCause().getMessage().contains("simulated failure")); + assertTrue(phaseDone.get()); + + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(2, rankPhaseResults.getAtomicArray().length()); + // one shard failed + assertEquals(1, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + assertNull(shard1Result); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2")); + List expectedFinalResults = new ArrayList<>(expectedShard2Results); + assertShardResults(shard2Result, expectedShard2Results); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseExceptionThrownOnPhase() { + AtomicBoolean 
phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 + final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); + QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); + try { + queryResult.setShardIndex(shard1Target.getShardId().getId()); + int totalHits = randomIntBetween(2, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResult, totalHits, shard1Docs); + results.consumeResult(queryResult, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + buildRankFeatureResult( + 
mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + totalHits, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResult.decRef(); + } + // override the RankFeaturePhase to raise an exception + RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext) { + @Override + void innerRun() { + throw new IllegalArgumentException("simulated failure"); + } + + @Override + public void moveToNextPhase( + SearchPhaseResults phaseResults, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { + // this is called after the RankFeaturePhaseCoordinatorContext has been executed + phaseDone.set(true); + finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs(); + logger.debug("Skipping moving to next phase"); + } + }; + assertEquals("rank-feature", rankFeaturePhase.getName()); + try { + rankFeaturePhase.run(); + assertNotNull(mockSearchPhaseContext.phaseFailure.get()); + assertTrue(mockSearchPhaseContext.phaseFailure.get().getMessage().contains("simulated failure")); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertFalse(phaseDone.get()); + assertTrue(rankFeaturePhase.rankPhaseResults.getAtomicArray().asList().isEmpty()); + assertNull(finalResults[0][0]); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeatureWithPagination() { + // request params used within SearchSourceBuilder and *RankContext classes + final int from = 1; + final int size = 1; + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // build the appropriate RankBuilder + RankBuilder rankBuilder = rankBuilder( + DEFAULT_RANK_WINDOW_SIZE, + 
defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE), + defaultQueryPhaseRankCoordinatorContext(DEFAULT_RANK_WINDOW_SIZE), + defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD), + defaultRankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) + ); + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null); + SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 4 results, with doc ids 1 and (11, 2, 200) found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null); + QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard3Target, null); + + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + 
queryResultShard3.setShardIndex(shard3Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + + final int shard2Results = randomIntBetween(1, 100); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { + new ScoreDoc(11, 100.0F, -1), + new ScoreDoc(2, 9.0F), + new ScoreDoc(200, 1F, -1) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + + final int shard3Results = 0; + final ScoreDoc[] shard3Docs = new ScoreDoc[0]; + populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard3, () -> {}); + results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + // make sure to match the context id generated above, otherwise we throw + // first shard + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + shard1Results, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 11, 2, 200 })) { + // second shard + + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + 
shard2Target, + shard2Results, + shard2Docs + ); + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + queryResultShard3.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(3, rankPhaseResults.getAtomicArray().length()); + // one result is null + assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + assertShardResults(shard1Result, expectedShard1Results); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of( + new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"), + new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2"), + new ExpectedRankFeatureDoc(200, 3, 101.0F, "ranked_200") + + ); + assertShardResults(shard2Result, expectedShard2Results); + + SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); + assertNull(shard3Result); + + List expectedFinalResults = List.of(new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1")); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeatureCollectOnlyRankWindowSizeFeatures() { + // request 
params used within SearchSourceBuilder and *RankContext classes + final int rankWindowSize = 2; + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // build the appropriate RankBuilder + RankBuilder rankBuilder = rankBuilder( + rankWindowSize, + defaultQueryPhaseRankShardContext(Collections.emptyList(), rankWindowSize), + defaultQueryPhaseRankCoordinatorContext(rankWindowSize), + defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD), + defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, rankWindowSize) + ); + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null); + SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 3 results, with doc ids 1, and (11, 2) found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, 
shard2Target, null); + QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard3Target, null); + + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + queryResultShard3.setShardIndex(shard3Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + + final int shard2Results = randomIntBetween(1, 100); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(11, 100.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + + final int shard3Results = 0; + final ScoreDoc[] shard3Docs = new ScoreDoc[0]; + populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard3, () -> {}); + results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + // make sure to match the context id generated above, otherwise we throw + // first shard + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + shard1Results, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else if 
(request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 11 })) { + // second shard + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard2Target, + shard2Results, + new ScoreDoc[] { shard2Docs[0] } + ); + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + queryResultShard3.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(3, rankPhaseResults.getAtomicArray().length()); + // one result is null + assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShardResults = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + assertShardResults(shard1Result, expectedShardResults); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11")); + assertShardResults(shard2Result, expectedShard2Results); + + SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); + assertNull(shard3Result); + + List expectedFinalResults = List.of( + new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"), + new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1") + ); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if 
(mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + private RankFeaturePhaseRankCoordinatorContext defaultRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize) { + + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + // no-op + // this one is handled directly in rankGlobalResults to create a RankFeatureDoc + // and avoid modifying in-place the ScoreDoc's rank + } + + @Override + public void rankGlobalResults(List rankSearchResults, ActionListener rankListener) { + List features = new ArrayList<>(); + for (RankFeatureResult rankFeatureResult : rankSearchResults) { + RankFeatureShardResult shardResult = rankFeatureResult.shardResult(); + features.addAll(Arrays.stream(shardResult.rankFeatureDocs).toList()); + } + rankListener.onResponse(features.toArray(new RankFeatureDoc[0])); + } + + @Override + public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) { + Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))]; + // perform pagination + for (int rank = 0; rank < topResults.length; ++rank) { + RankFeatureDoc rfd = rankFeatureDocs[from + rank]; + topResults[rank] = new RankFeatureDoc(rfd.doc, rfd.score, rfd.shardIndex); + topResults[rank].rank = from + rank + 1; + } + return topResults; + } + }; + } + + private QueryPhaseRankCoordinatorContext negatingScoresQueryFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { + return new QueryPhaseRankCoordinatorContext(rankWindowSize) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List rankSearchResults, + SearchPhaseController.TopDocsStats topDocsStats + ) { + List docScores = new 
ArrayList<>(); + for (QuerySearchResult phaseResults : rankSearchResults) { + docScores.addAll(Arrays.asList(phaseResults.topDocs().topDocs.scoreDocs)); + } + ScoreDoc[] sortedDocs = docScores.toArray(new ScoreDoc[0]); + // negating scores + Arrays.stream(sortedDocs).forEach(doc -> doc.score *= -1); + + Arrays.sort(sortedDocs, Comparator.comparing((ScoreDoc doc) -> doc.score).reversed()); + sortedDocs = Arrays.stream(sortedDocs).limit(rankWindowSize).toArray(ScoreDoc[]::new); + RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, sortedDocs.length - from))]; + // perform pagination + for (int rank = 0; rank < topResults.length; ++rank) { + ScoreDoc base = sortedDocs[from + rank]; + topResults[rank] = new RankFeatureDoc(base.doc, base.score, base.shardIndex); + topResults[rank].rank = from + rank + 1; + } + topDocsStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + private RankFeaturePhaseRankShardContext defaultRankFeaturePhaseRankShardContext(String field) { + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].score += 100f; + rankFeatureDocs[i].featureData("ranked_" + hit.docId()); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + + private QueryPhaseRankCoordinatorContext defaultQueryPhaseRankCoordinatorContext(int rankWindowSize) { + return new QueryPhaseRankCoordinatorContext(rankWindowSize) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < 
querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + RankFeatureShardResult shardResult = (RankFeatureShardResult) querySearchResult.getRankShardResult(); + for (RankFeatureDoc frd : shardResult.rankFeatureDocs) { + frd.shardIndex = i; + rankDocs.add(frd); + } + } + rankDocs.sort(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankFeatureDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + private QueryPhaseRankShardContext defaultQueryPhaseRankShardContext(List queries, int rankWindowSize) { + return new QueryPhaseRankShardContext(queries, rankWindowSize) { + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + throw new UnsupportedOperationException( + "shard-level QueryPhase context should not be accessed as part of the RankFeature phase" + ); + } + }; + } + + private SearchPhaseController searchPhaseController() { + return new SearchPhaseController((task, request) -> InternalAggregationTestCase.emptyReduceContextBuilder()); + } + + private RankBuilder rankBuilder( + int rankWindowSize, + QueryPhaseRankShardContext queryPhaseRankShardContext, + QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext, + RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext, + RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext + ) { + return new RankBuilder(rankWindowSize) { + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + // no-op + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + // no-op + } + + @Override + public boolean isCompoundBuilder() { + return true; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return queryPhaseRankShardContext; + } + + @Override + 
public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return queryPhaseRankCoordinatorContext; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return rankFeaturePhaseRankShardContext; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return rankFeaturePhaseRankCoordinatorContext; + } + + @Override + protected boolean doEquals(RankBuilder other) { + return other != null && other.rankWindowSize() == rankWindowSize; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return "test-rank-builder"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_12_0; + } + }; + } + + private SearchSourceBuilder searchSourceWithRankBuilder(RankBuilder rankBuilder) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(rankBuilder); + return searchSourceBuilder; + } + + private SearchPhaseResults searchPhaseResults( + SearchPhaseController controller, + MockSearchPhaseContext mockSearchPhaseContext + ) { + return controller.newSearchPhaseResults( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + () -> false, + SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), + mockSearchPhaseContext.numShards, + exc -> {} + ); + } + + private void buildRankFeatureResult( + RankBuilder shardRankBuilder, + RankFeatureResult rankFeatureResult, + SearchShardTarget shardTarget, + int totalHits, + ScoreDoc[] scoreDocs + ) { + rankFeatureResult.setSearchShardTarget(shardTarget); + // these are the SearchHits generated by the FetchFieldPhase processor + SearchHit[] searchHits = new SearchHit[scoreDocs.length]; + float maxScore = Float.MIN_VALUE; + for (int i = 0; i < searchHits.length; i++) { + searchHits[i] = 
SearchHit.unpooled(scoreDocs[i].doc); + searchHits[i].shard(shardTarget); + searchHits[i].score(scoreDocs[i].score); + searchHits[i].setDocumentField(DEFAULT_FIELD, new DocumentField(DEFAULT_FIELD, Collections.singletonList(scoreDocs[i].doc))); + if (scoreDocs[i].score > maxScore) { + maxScore = scoreDocs[i].score; + } + } + SearchHits hits = null; + try { + hits = SearchHits.unpooled(searchHits, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), maxScore); + // construct the appropriate RankFeatureDoc objects based on the rank builder + RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardRankBuilder.buildRankFeaturePhaseShardContext(); + RankFeatureShardResult rankShardResult = (RankFeatureShardResult) rankFeaturePhaseRankShardContext.buildRankFeatureShardResult( + hits, + shardTarget.getShardId().id() + ); + rankFeatureResult.shardResult(rankShardResult); + } finally { + if (hits != null) { + hits.decRef(); + } + } + } + + private void populateQuerySearchResult(QuerySearchResult queryResult, int totalHits, ScoreDoc[] scoreDocs) { + // this would have been populated during the QueryPhase by the appropriate QueryPhaseShardContext + float maxScore = Float.MIN_VALUE; + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + if (scoreDocs[i].score > maxScore) { + maxScore = scoreDocs[i].score; + } + rankFeatureDocs[i] = new RankFeatureDoc(scoreDocs[i].doc, scoreDocs[i].score, scoreDocs[i].shardIndex); + } + queryResult.setRankShardResult(new RankFeatureShardResult(rankFeatureDocs)); + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs), + maxScore + + ), + new DocValueFormat[0] + ); + queryResult.size(totalHits); + } + + private RankFeaturePhase rankFeaturePhase( + SearchPhaseResults results, + MockSearchPhaseContext mockSearchPhaseContext, + ScoreDoc[][] finalResults, + AtomicBoolean phaseDone + ) { + // 
override the RankFeaturePhase to skip moving to next phase + return new RankFeaturePhase(results, null, mockSearchPhaseContext) { + @Override + public void moveToNextPhase( + SearchPhaseResults phaseResults, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { + // this is called after the RankFeaturePhaseCoordinatorContext has been executed + phaseDone.set(true); + finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs(); + logger.debug("Skipping moving to next phase"); + } + }; + } + + private void assertRankFeatureResults(RankFeatureShardResult rankFeatureShardResult, List expectedResults) { + assertEquals(expectedResults.size(), rankFeatureShardResult.rankFeatureDocs.length); + for (int i = 0; i < expectedResults.size(); i++) { + ExpectedRankFeatureDoc expected = expectedResults.get(i); + RankFeatureDoc actual = rankFeatureShardResult.rankFeatureDocs[i]; + assertEquals(expected.doc, actual.doc); + assertEquals(expected.rank, actual.rank); + assertEquals(expected.score, actual.score, 10E-5); + assertEquals(expected.featureData, actual.featureData); + } + } + + private void assertFinalResults(ScoreDoc[] finalResults, List expectedResults) { + assertEquals(expectedResults.size(), finalResults.length); + for (int i = 0; i < expectedResults.size(); i++) { + ExpectedRankFeatureDoc expected = expectedResults.get(i); + RankFeatureDoc actual = (RankFeatureDoc) finalResults[i]; + assertEquals(expected.doc, actual.doc); + assertEquals(expected.rank, actual.rank); + assertEquals(expected.score, actual.score, 10E-5); + } + } + + private void assertShardResults(SearchPhaseResult shardResult, List expectedShardResults) { + assertTrue(shardResult instanceof RankFeatureResult); + RankFeatureResult rankResult = (RankFeatureResult) shardResult; + assertNotNull(rankResult.rankFeatureResult()); + assertNull(rankResult.queryResult()); + assertNotNull(rankResult.rankFeatureResult().shardResult()); + RankFeatureShardResult rankFeatureShardResult = 
rankResult.rankFeatureResult().shardResult(); + assertRankFeatureResults(rankFeatureShardResult, expectedShardResults); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java index 59acb227385f6..4d58471f4817a 100644 --- a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java @@ -644,8 +644,8 @@ public void testIsParallelCollectionSupportedForResults() { ToLongFunction fieldCardinality = name -> -1; for (var resultsType : SearchService.ResultsType.values()) { switch (resultsType) { - case NONE, FETCH -> assertFalse( - "NONE and FETCH phases do not support parallel collection.", + case NONE, RANK_FEATURE, FETCH -> assertFalse( + "NONE, RANK_FEATURE, and FETCH phases do not support parallel collection.", DefaultSearchContext.isParallelCollectionSupportedForResults( resultsType, searchSourceBuilderOrNull, diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index d2c6c55634ec6..2af20a6ffef4a 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -13,6 +13,8 @@ import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHitCountCollectorManager; import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.util.SetOnce; @@ -27,6 +29,7 @@ import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.ClosePointInTimeRequest; import org.elasticsearch.action.search.OpenPointInTimeRequest; +import 
org.elasticsearch.action.search.SearchPhaseController; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; @@ -92,6 +95,7 @@ import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; +import org.elasticsearch.search.fetch.ShardFetchSearchRequest; import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.ContextIndexSearcher; @@ -102,12 +106,26 @@ import org.elasticsearch.search.query.NonCountingTermQuery; import org.elasticsearch.search.query.QuerySearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankShardResult; +import org.elasticsearch.search.rank.TestRankBuilder; +import org.elasticsearch.search.rank.TestRankDoc; +import org.elasticsearch.search.rank.TestRankShardResult; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.tasks.TaskCancelHelper; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import 
org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; @@ -115,8 +133,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -136,8 +156,8 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; -import static org.elasticsearch.indices.cluster.AbstractIndicesClusterStateServiceTestCase.awaitIndexShardCloseAsyncTasks; import static org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason.DELETED; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.search.SearchService.QUERY_PHASE_PARALLEL_COLLECTION_ENABLED; import static org.elasticsearch.search.SearchService.SEARCH_WORKER_THREADS_ENABLED; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -371,7 +391,7 @@ public void testSearchWhileIndexDeleted() throws InterruptedException { -1, null ), - new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), + new SearchShardTask(123L, "", "", "", null, emptyMap()), result.delegateFailure((l, r) -> { r.incRef(); l.onResponse(r); @@ -387,7 +407,7 @@ public void testSearchWhileIndexDeleted() throws InterruptedException { null/* not a scroll */ ); PlainActionFuture listener = new PlainActionFuture<>(); - service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), listener); + service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener); 
listener.get(); if (useScroll) { // have to free context since this test does not remove the index from IndicesService. @@ -422,6 +442,711 @@ public void testSearchWhileIndexDeleted() throws InterruptedException { assertEquals(0, totalStats.getFetchCurrent()); } + public void testRankFeaturePhaseSearchPhases() throws InterruptedException, ExecutionException { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + final SearchService service = getInstanceFromNode(SearchService.class); + + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex(indexName)); + final IndexShard indexShard = indexService.getShard(0); + SearchShardTask searchTask = new SearchShardTask(123L, "", "", "", null, emptyMap()); + + // create a SearchRequest that will return all documents and defines a TestRankBuilder with shard-level only operations + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true) + .source( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(DEFAULT_SIZE) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public 
QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = (numDocs - i) + randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + } + ) + ); + + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + QuerySearchResult queryResult = null; + RankFeatureResult rankResult = null; + try { + // Execute the query phase and store the result in a SearchPhaseResult container using a PlainActionFuture + PlainActionFuture queryPhaseResults = new PlainActionFuture<>(); + service.executeQueryPhase(request, searchTask, queryPhaseResults); + queryResult = (QuerySearchResult) queryPhaseResults.get(); + + // these are the matched docs from the query phase + 
final TestRankDoc[] queryRankDocs = ((TestRankShardResult) queryResult.getRankShardResult()).testRankDocs; + + // assume that we have cut down to these from the coordinator node as the top-docs to run the rank feature phase upon + List topRankWindowSizeDocs = randomNonEmptySubsetOf(Arrays.stream(queryRankDocs).map(x -> x.doc).toList()); + + // now we create a RankFeatureShardRequest to extract feature info for the top-docs above + RankFeatureShardRequest rankFeatureShardRequest = new RankFeatureShardRequest( + OriginalIndices.NONE, + queryResult.getContextId(), // use the context from the query phase + request, + topRankWindowSizeDocs + ); + PlainActionFuture rankPhaseResults = new PlainActionFuture<>(); + service.executeRankFeaturePhase(rankFeatureShardRequest, searchTask, rankPhaseResults); + rankResult = rankPhaseResults.get(); + + assertNotNull(rankResult); + assertNotNull(rankResult.rankFeatureResult()); + RankFeatureShardResult rankFeatureShardResult = rankResult.rankFeatureResult().shardResult(); + assertNotNull(rankFeatureShardResult); + + List sortedRankWindowDocs = topRankWindowSizeDocs.stream().sorted().toList(); + assertEquals(sortedRankWindowDocs.size(), rankFeatureShardResult.rankFeatureDocs.length); + for (int i = 0; i < sortedRankWindowDocs.size(); i++) { + assertEquals((long) sortedRankWindowDocs.get(i), rankFeatureShardResult.rankFeatureDocs[i].doc); + assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, "aardvark_" + sortedRankWindowDocs.get(i)); + } + + List globalTopKResults = randomNonEmptySubsetOf( + Arrays.stream(rankFeatureShardResult.rankFeatureDocs).map(x -> x.doc).toList() + ); + + // finally let's create a fetch request to bring back fetch info for the top results + ShardFetchSearchRequest fetchRequest = new ShardFetchSearchRequest( + OriginalIndices.NONE, + rankResult.getContextId(), + request, + globalTopKResults, + null, + rankResult.getRescoreDocIds(), + null + ); + + // execute fetch phase and perform any validations 
once we retrieve the response + // the difference in how we do assertions here is needed because once the transport service sends back the response + // it decrements the reference to the FetchSearchResult (through the ActionListener#respondAndRelease) and sets hits to null + service.executeFetchPhase(fetchRequest, searchTask, new ActionListener<>() { + @Override + public void onResponse(FetchSearchResult fetchSearchResult) { + assertNotNull(fetchSearchResult); + assertNotNull(fetchSearchResult.hits()); + + int totalHits = fetchSearchResult.hits().getHits().length; + assertEquals(globalTopKResults.size(), totalHits); + for (int i = 0; i < totalHits; i++) { + // rank and score are set by the SearchPhaseController#merge so no need to validate that here + SearchHit hit = fetchSearchResult.hits().getAt(i); + assertNotNull(hit.getFields().get(fetchFieldName)); + assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); + } + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError("No failure should have been raised", e); + } + }); + } catch (Exception ex) { + if (queryResult != null) { + if (queryResult.hasReferences()) { + queryResult.decRef(); + } + service.freeReaderContext(queryResult.getContextId()); + } + if (rankResult != null && rankResult.hasReferences()) { + rankResult.decRef(); + } + throw ex; + } + } + + public void testRankFeaturePhaseUsingClient() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 4; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + 
"aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + ElasticsearchAssertions.assertResponse( + client().prepareSearch(indexName) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(2) + .from(2) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (TestRankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed()); + TestRankDoc[] topResults = 
rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + } + ) + ), + (response) -> { + SearchHits hits = response.getHits(); + assertEquals(hits.getTotalHits().value, numDocs); + assertEquals(hits.getHits().length, 2); + int index = 0; + for (SearchHit hit : hits.getHits()) { + assertEquals(hit.getRank(), 3 + index); + assertTrue(hit.getScore() >= 0); + assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); + index++; + } + } + ); + } + + public void 
testRankFeaturePhaseExceptionOnCoordinatingNode() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + expectThrows( + SearchPhaseExecutionException.class, + () -> client().prepareSearch(indexName) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(2) + .from(2) + .fetchField(fetchFieldName) + .rankBuilder(new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + throw new IllegalStateException("should have failed earlier"); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + throw new 
UnsupportedOperationException("simulated failure"); + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + }) + ) + .get() + ); + } + + public void testRankFeaturePhaseExceptionAllShardFail() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + 
prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + expectThrows( + SearchPhaseExecutionException.class, + () -> client().prepareSearch(indexName) + .setAllowPartialSearchResults(true) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (TestRankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + 
rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed()); + TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + } + } + ) + ) + .get() + ); + } + + public void testRankFeaturePhaseExceptionOneShardFails() { + // if we have only one shard and it fails, it will fallback to context.onPhaseFailure which will eventually clean up all contexts. + // in this test we want to make sure that even if one shard (of many) fails during the RankFeaturePhase, then the appropriate + // context will have been cleaned up. 
+ final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2).build()); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + assertResponse( + client().prepareSearch(indexName) + .setAllowPartialSearchResults(true) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public 
ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (TestRankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed()); + TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + if (shardId == 0) { + throw new UnsupportedOperationException("simulated failure"); + } else { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + 
rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + } + }; + } + } + ) + ), + (searchResponse) -> { + assertEquals(1, searchResponse.getSuccessfulShards()); + assertEquals("simulated failure", searchResponse.getShardFailures()[0].getCause().getMessage()); + assertNotEquals(0, searchResponse.getHits().getHits().length); + for (SearchHit hit : searchResponse.getHits().getHits()) { + assertEquals(fetchFieldValue + "_" + hit.getId(), hit.getFields().get(fetchFieldName).getValue()); + assertEquals(1, hit.getShard().getShardId().id()); + } + } + ); + } + public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws ExecutionException, InterruptedException { createIndex("index"); prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); @@ -457,7 +1182,7 @@ public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws Executi -1, null ), - new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), + new SearchShardTask(123L, "", "", "", null, emptyMap()), result ); @@ -694,7 +1419,7 @@ public void testMaxScriptFieldsSearch() throws IOException { for (int i = 0; i < maxScriptFields; i++) { searchSourceBuilder.scriptField( "field" + i, - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap()) + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) ); } final ShardSearchRequest request = new ShardSearchRequest( @@ -723,7 +1448,7 @@ public void testMaxScriptFieldsSearch() throws IOException { } searchSourceBuilder.scriptField( "anotherScriptField", - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap()) + new Script(ScriptType.INLINE, MockScriptEngine.NAME, 
CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) ); IllegalArgumentException ex = expectThrows( IllegalArgumentException.class, @@ -752,7 +1477,7 @@ public void testIgnoreScriptfieldIfSizeZero() throws IOException { searchRequest.source(searchSourceBuilder); searchSourceBuilder.scriptField( "field" + 0, - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap()) + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) ); searchSourceBuilder.size(0); final ShardSearchRequest request = new ShardSearchRequest( @@ -1036,7 +1761,7 @@ public void testCanMatch() throws Exception { ); CountDownLatch latch = new CountDownLatch(1); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); // Because the foo field used in alias filter is unmapped the term query builder rewrite can resolve to a match no docs query, // without acquiring a searcher and that means the wrapper is not called assertEquals(5, numWrapInvocations.get()); @@ -1330,7 +2055,7 @@ public void testMatchNoDocsEmptyResponse() throws InterruptedException { 0, null ); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); { CountDownLatch latch = new CountDownLatch(1); @@ -1705,7 +2430,7 @@ public void testWaitOnRefresh() throws ExecutionException, InterruptedException final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, searchRequest, @@ -1740,7 +2465,7 @@ 
public void testWaitOnRefreshFailsWithRefreshesDisabled() { final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); PlainActionFuture future = new PlainActionFuture<>(); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, @@ -1778,7 +2503,7 @@ public void testWaitOnRefreshFailsIfCheckpointNotIndexed() { final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); PlainActionFuture future = new PlainActionFuture<>(); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, @@ -1815,7 +2540,7 @@ public void testWaitOnRefreshTimeout() { final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); PlainActionFuture future = new PlainActionFuture<>(); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, @@ -1901,7 +2626,7 @@ public void testDfsQueryPhaseRewrite() { PlainActionFuture plainActionFuture = new PlainActionFuture<>(); service.executeQueryPhase( new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)), - new SearchShardTask(42L, "", "", "", null, Collections.emptyMap()), + new SearchShardTask(42L, "", "", "", null, emptyMap()), plainActionFuture ); diff --git 
a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java new file mode 100644 index 0000000000000..cf464044cd701 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java @@ -0,0 +1,409 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.fetch.subphase.FetchFieldsContext; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import 
org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TestSearchContext; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class RankFeatureShardPhaseTests extends ESTestCase { + + private SearchContext getSearchContext() { + return new TestSearchContext((SearchExecutionContext) null) { + + private FetchSearchResult fetchResult; + private RankFeatureResult rankFeatureResult; + private FetchFieldsContext fetchFieldsContext; + private StoredFieldsContext storedFieldsContext; + + @Override + public FetchSearchResult fetchResult() { + return fetchResult; + } + + @Override + public void addFetchResult() { + this.fetchResult = new FetchSearchResult(); + this.addReleasable(fetchResult::decRef); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return rankFeatureResult; + } + + @Override + public void addRankFeatureResult() { + this.rankFeatureResult = new RankFeatureResult(); + this.addReleasable(rankFeatureResult::decRef); + } + + @Override + public SearchContext fetchFieldsContext(FetchFieldsContext fetchFieldsContext) { + this.fetchFieldsContext = fetchFieldsContext; + return this; + } + + @Override + public FetchFieldsContext 
fetchFieldsContext() { + return fetchFieldsContext; + } + + @Override + public SearchContext storedFieldsContext(StoredFieldsContext storedFieldsContext) { + this.storedFieldsContext = storedFieldsContext; + return this; + } + + @Override + public StoredFieldsContext storedFieldsContext() { + return storedFieldsContext; + } + + @Override + public boolean isCancelled() { + return false; + } + }; + } + + private RankBuilder getRankBuilder(final String field) { + return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + // no-op + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + // no-op + } + + @Override + public boolean isCompoundBuilder() { + return false; + } + + // no work to be done on the query phase + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return null; + } + + // no work to be done on the query phase + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return null; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(field).getValue()); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + + // no work to be done on the coordinator node for the rank feature phase + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int 
size, int from) { + return null; + } + + @Override + protected boolean doEquals(RankBuilder other) { + return false; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return "rank_builder_rank_feature_shard_phase_enabled"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_FEATURE_PHASE_ADDED; + } + }; + } + + public void testPrepareForFetch() { + + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + try (SearchContext searchContext = spy(getSearchContext())) { + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + rankFeatureShardPhase.prepareForFetch(searchContext, request); + + assertNotNull(searchContext.fetchFieldsContext()); + assertEquals(searchContext.fetchFieldsContext().fields().size(), 1); + assertEquals(searchContext.fetchFieldsContext().fields().get(0).field, fieldName); + assertNotNull(searchContext.storedFieldsContext()); + assertNull(searchContext.storedFieldsContext().fieldNames()); + assertFalse(searchContext.storedFieldsContext().fetchFields()); + assertNotNull(searchContext.fetchResult()); + } + } + + public void testPrepareForFetchNoRankFeatureContext() { + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(null); + + ShardSearchRequest 
searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + try (SearchContext searchContext = spy(getSearchContext())) { + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + rankFeatureShardPhase.prepareForFetch(searchContext, request); + + assertNull(searchContext.fetchFieldsContext()); + assertNull(searchContext.fetchResult()); + } + } + + public void testPrepareForFetchWhileTaskIsCancelled() { + + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + try (SearchContext searchContext = spy(getSearchContext())) { + when(searchContext.isCancelled()).thenReturn(true); + when(searchContext.request()).thenReturn(searchRequest); + + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + expectThrows(TaskCancelledException.class, () -> rankFeatureShardPhase.prepareForFetch(searchContext, request)); + } + } + + public void testProcessFetch() { + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + Map expectedFieldData = Map.of(4, "doc_4_aardvark", 9, "doc_9_aardvark", numDocs - 1, "last_doc_aardvark"); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + 
ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + SearchShardTarget shardTarget = new SearchShardTarget( + "node_id", + new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0), + null + ); + SearchHits searchHits = null; + try (SearchContext searchContext = spy(getSearchContext())) { + searchContext.addFetchResult(); + SearchHit[] hits = new SearchHit[3]; + hits[0] = SearchHit.unpooled(4); + hits[0].setDocumentField(fieldName, new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(4)))); + + hits[1] = SearchHit.unpooled(9); + hits[1].setDocumentField(fieldName, new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(9)))); + + hits[2] = SearchHit.unpooled(numDocs - 1); + hits[2].setDocumentField( + fieldName, + new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(numDocs - 1))) + ); + searchHits = SearchHits.unpooled(hits, new TotalHits(3, TotalHits.Relation.EQUAL_TO), 1.0f); + searchContext.fetchResult().shardResult(searchHits, null); + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + when(searchContext.shardTarget()).thenReturn(shardTarget); + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + // this is called as part of the search context initialization + // with the ResultsType.RANK_FEATURE type + searchContext.addRankFeatureResult(); + rankFeatureShardPhase.processFetch(searchContext); + + assertNotNull(searchContext.rankFeatureResult()); + assertNotNull(searchContext.rankFeatureResult().rankFeatureResult()); + for (RankFeatureDoc rankFeatureDoc : searchContext.rankFeatureResult().rankFeatureResult().shardResult().rankFeatureDocs) { + 
assertTrue(expectedFieldData.containsKey(rankFeatureDoc.doc)); + assertEquals(rankFeatureDoc.featureData, expectedFieldData.get(rankFeatureDoc.doc)); + } + } finally { + if (searchHits != null) { + searchHits.decRef(); + } + } + } + + public void testProcessFetchEmptyHits() { + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + SearchShardTarget shardTarget = new SearchShardTarget( + "node_id", + new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0), + null + ); + + SearchHits searchHits = null; + try (SearchContext searchContext = spy(getSearchContext())) { + searchContext.addFetchResult(); + SearchHit[] hits = new SearchHit[0]; + searchHits = SearchHits.unpooled(hits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f); + searchContext.fetchResult().shardResult(searchHits, null); + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + when(searchContext.shardTarget()).thenReturn(shardTarget); + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + // this is called as part of the search context initialization + // with the ResultsType.RANK_FEATURE type + searchContext.addRankFeatureResult(); + rankFeatureShardPhase.processFetch(searchContext); + + assertNotNull(searchContext.rankFeatureResult()); + assertNotNull(searchContext.rankFeatureResult().rankFeatureResult()); + assertEquals(searchContext.rankFeatureResult().rankFeatureResult().shardResult().rankFeatureDocs.length, 0); + } finally { + if (searchHits != null) { + 
searchHits.decRef(); + } + } + } + + public void testProcessFetchWhileTaskIsCancelled() { + + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + SearchShardTarget shardTarget = new SearchShardTarget( + "node_id", + new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0), + null + ); + + SearchHits searchHits = null; + try (SearchContext searchContext = spy(getSearchContext())) { + searchContext.addFetchResult(); + SearchHit[] hits = new SearchHit[0]; + searchHits = SearchHits.unpooled(hits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f); + searchContext.fetchResult().shardResult(searchHits, null); + when(searchContext.isCancelled()).thenReturn(true); + when(searchContext.request()).thenReturn(searchRequest); + when(searchContext.shardTarget()).thenReturn(shardTarget); + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + // this is called as part of the search context initialization + // with the ResultsType.RANK_FEATURE type + searchContext.addRankFeatureResult(); + expectThrows(TaskCancelledException.class, () -> rankFeatureShardPhase.processFetch(searchContext)); + } finally { + if (searchHits != null) { + searchHits.decRef(); + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 697b40671ee8b..6419759ab5962 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ 
b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -178,6 +178,7 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.test.ClusterServiceUtils; @@ -2249,6 +2250,7 @@ public RecyclerBytesStreamOutput newNetworkBytesStream() { threadPool, scriptService, bigArrays, + new RankFeatureShardPhase(), new FetchPhase(Collections.emptyList()), responseCollectorService, new NoneCircuitBreakerService(), diff --git a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java index ef29f9fca4f93..520aff77497ba 100644 --- a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java +++ b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java @@ -40,6 +40,7 @@ import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.test.ESTestCase; @@ -97,6 +98,7 @@ SearchService newSearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -111,6 +113,7 @@ SearchService newSearchService( threadPool, scriptService, bigArrays, + rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, @@ -124,6 +127,7 @@ SearchService newSearchService( threadPool, scriptService, bigArrays, + 
rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, diff --git a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java index aa1889e15d594..747eff1d21708 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.internal.ReaderContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; @@ -81,6 +82,7 @@ public MockSearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -93,6 +95,7 @@ public MockSearchService( threadPool, scriptService, bigArrays, + rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java index 8e2a2c96a31ab..862c4d2ea3270 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java @@ -15,6 +15,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import 
org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; @@ -31,7 +33,7 @@ public class TestRankBuilder extends RankBuilder { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, - args -> new TestRankBuilder(args[0] == null ? DEFAULT_WINDOW_SIZE : (int) args[0]) + args -> new TestRankBuilder(args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0]) ); static { @@ -74,6 +76,11 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep // do nothing } + @Override + public boolean isCompoundBuilder() { + return true; + } + @Override public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { throw new UnsupportedOperationException(); @@ -84,6 +91,16 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si throw new UnsupportedOperationException(); } + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + throw new UnsupportedOperationException(); + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + throw new UnsupportedOperationException(); + } + @Override protected boolean doEquals(RankBuilder other) { return true; diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java index cba2b41d279bb..fa414cd8121d6 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java @@ -44,6 +44,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import 
org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -463,6 +464,16 @@ public float getMaxScore() { return queryResult.getMaxScore(); } + @Override + public void addRankFeatureResult() { + // this space intentionally left blank + } + + @Override + public RankFeatureResult rankFeatureResult() { + return null; + } + @Override public FetchSearchResult fetchResult() { return null; diff --git a/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java b/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java index a7f21bd206c62..bf9eba87ee809 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java +++ b/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java @@ -687,6 +687,10 @@ public static Matcher hasScore(final float score) { return transformedMatch(SearchHit::getScore, equalTo(score)); } + public static Matcher hasRank(final int rank) { + return transformedMatch(SearchHit::getRank, equalTo(rank)); + } + public static T assertBooleanSubQuery(Query query, Class subqueryType, int i) { assertThat(query, instanceOf(BooleanQuery.class)); BooleanQuery q = (BooleanQuery) query; diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index 04b0b11ad38d4..c0305f873327d 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -397,6 +397,14 @@ protected void onQueryResult(int shardIndex, 
QuerySearchResult queryResult) { } } + @Override + protected void onRankFeatureResult(int shardIndex) { + checkCancellation(); + if (delegate != null) { + delegate.onRankFeatureResult(shardIndex); + } + } + @Override protected void onFetchResult(int shardIndex) { checkCancellation(); @@ -420,6 +428,12 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc ); } + @Override + protected void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { + // best effort to cancel expired tasks + checkCancellation(); + } + @Override protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { // best effort to cancel expired tasks diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index 8f3ed15037c08..5c39c6c32fd06 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -16,6 +16,8 @@ import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -38,7 +40,7 @@ public class RRFRankBuilder extends RankBuilder { public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant"); static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(RRFRankPlugin.NAME, args -> { - int windowSize = args[0] == null ? 
DEFAULT_WINDOW_SIZE : (int) args[0]; + int windowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0]; int rankConstant = args[1] == null ? DEFAULT_RANK_CONSTANT : (int) args[1]; if (rankConstant < 1) { throw new IllegalArgumentException("[rank_constant] must be greater than [0] for [rrf]"); @@ -94,6 +96,11 @@ public int rankConstant() { return rankConstant; } + @Override + public boolean isCompoundBuilder() { + return true; + } + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { return new RRFQueryPhaseRankShardContext(queries, rankWindowSize(), rankConstant); } @@ -103,6 +110,16 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si return new RRFQueryPhaseRankCoordinatorContext(size, from, rankWindowSize(), rankConstant); } + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return null; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return null; + } + @Override protected boolean doEquals(RankBuilder other) { return Objects.equals(rankConstant, ((RRFRankBuilder) other).rankConstant); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 077c933fa9add..e5a7983107278 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -71,7 +71,7 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP } List retrieverBuilders = Collections.emptyList(); - int rankWindowSize = RRFRankBuilder.DEFAULT_WINDOW_SIZE; + int rankWindowSize = RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE; int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT; @Override 
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java index aeb6bfc8de796..221b7a65e1f8f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java @@ -45,6 +45,7 @@ public final class PreAuthorizationUtils { SearchTransportService.QUERY_ACTION_NAME, SearchTransportService.QUERY_ID_ACTION_NAME, SearchTransportService.FETCH_ID_ACTION_NAME, + SearchTransportService.RANK_FEATURE_SHARD_ACTION_NAME, SearchTransportService.QUERY_CAN_MATCH_NODE_NAME ) );