diff --git a/docs/changelog/116957.yaml b/docs/changelog/116957.yaml
new file mode 100644
index 0000000000000..1020190de180d
--- /dev/null
+++ b/docs/changelog/116957.yaml
@@ -0,0 +1,5 @@
+pr: 116957
+summary: Propagate scoring function through random sampler
+area: Machine Learning
+type: bug
+issues: [ 110134 ]
diff --git a/modules/aggregations/build.gradle b/modules/aggregations/build.gradle
index 5df0a890af753..2835180904620 100644
--- a/modules/aggregations/build.gradle
+++ b/modules/aggregations/build.gradle
@@ -20,7 +20,7 @@ esplugin {
 
 restResources {
   restApi {
-    include '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
+    include 'capabilities', '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
   }
   restTests {
     // Pulls in all aggregation tests from core AND the forwards v7's core for forwards compatibility
diff --git a/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/random_sampler.yml b/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/random_sampler.yml
index 5b2c2dc379cb9..4d8efe2a6f9d8 100644
--- a/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/random_sampler.yml
+++ b/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/random_sampler.yml
@@ -142,6 +142,66 @@ setup:
           }
   - match: { aggregations.sampled.mean.value: 1.0 }
 ---
+"Test random_sampler aggregation with scored subagg":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ random_sampler_with_scored_subaggs ]
+      test_runner_features: capabilities
+      reason: "Support for random sampler with scored subaggs capability required"
+  - do:
+      search:
+        index: data
+        size: 0
+        body: >
+          {
+            "query": {
+              "function_score": {
+                "random_score": {}
+              }
+            },
+            "aggs": {
+              "sampled": {
+                "random_sampler": {
+                  "probability": 0.5
+                },
+                "aggs": {
+                  "top": {
+                    "top_hits": {}
+                  }
+                }
+              }
+            }
+          }
+  - is_true: aggregations.sampled.top.hits
+  - do:
+      search:
+        index: data
+        size: 0
+        body: >
+          {
+            "query": {
+              "function_score": {
+                "random_score": {}
+              }
+            },
+            "aggs": {
+              "sampled": {
+                "random_sampler": {
+                  "probability": 1.0
+                },
+                "aggs": {
+                  "top": {
+                    "top_hits": {}
+                  }
+                }
+              }
+            }
+          }
+  - match: { aggregations.sampled.top.hits.total.value: 6 }
+  - is_true: aggregations.sampled.top.hits.hits.0._score
+---
 "Test random_sampler aggregation with poor settings":
   - requires:
       cluster_features: ["gte_v8.2.0"]
diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
index 7b57481ad5716..241f30b367782 100644
--- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
+++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
@@ -41,6 +41,8 @@ private SearchCapabilities() {}
     /** Support multi-dense-vector script field access. */
     private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";
 
+    private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";
+
     public static final Set<String> CAPABILITIES;
     static {
         HashSet<String> capabilities = new HashSet<>();
@@ -50,6 +52,7 @@ private SearchCapabilities() {}
         capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
         capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
         capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
+        capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS);
         if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
             capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
             capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS);
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java
index bf9116207b375..1ea7769b33384 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java
@@ -40,7 +40,7 @@ public abstract class AggregatorBase extends Aggregator {
 
     protected final String name;
     protected final Aggregator parent;
-    private final AggregationContext context;
+    protected final AggregationContext context;
     private final Map<String, Object> metadata;
 
     protected final Aggregator[] subAggregators;
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java
index 921cbb96385ad..699b8c6b5d500 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java
@@ -9,12 +9,15 @@
 
 package org.elasticsearch.search.aggregations.bucket.sampler.random;
 
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Bits;
-import org.elasticsearch.common.CheckedSupplier;
 import org.elasticsearch.common.util.LongArray;
 import org.elasticsearch.search.aggregations.AggregationExecutionContext;
 import org.elasticsearch.search.aggregations.Aggregator;
@@ -22,6 +25,7 @@
 import org.elasticsearch.search.aggregations.CardinalityUpperBound;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
+import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
 import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
 import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
 import org.elasticsearch.search.aggregations.support.AggregationContext;
@@ -34,14 +38,13 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
     private final int seed;
     private final Integer shardSeed;
     private final double probability;
-    private final CheckedSupplier<Weight, IOException> weightSupplier;
+    private Weight weight;
 
     RandomSamplerAggregator(
         String name,
         int seed,
         Integer shardSeed,
         double probability,
-        CheckedSupplier<Weight, IOException> weightSupplier,
         AggregatorFactories factories,
         AggregationContext context,
         Aggregator parent,
@@ -56,10 +59,33 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
                 RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
             );
         }
-        this.weightSupplier = weightSupplier;
         this.shardSeed = shardSeed;
     }
 
+    /**
+     * This creates the query weight which will be used in the aggregator.
+     *
+     * This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
+     * the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
+     * @return weight to be used, is cached for additional usages
+     * @throws IOException when building the weight or queries fails;
+     */
+    private Weight getWeight() throws IOException {
+        if (weight == null) {
+            ScoreMode scoreMode = scoreMode();
+            BooleanQuery.Builder fullQuery = new BooleanQuery.Builder().add(
+                context.query(),
+                scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER
+            );
+            if (probability < 1.0) {
+                Query sampleQuery = new RandomSamplingQuery(probability, seed, shardSeed == null ? context.shardRandomSeed() : shardSeed);
+                fullQuery.add(sampleQuery, BooleanClause.Occur.FILTER);
+            }
+            weight = context.searcher().createWeight(context.searcher().rewrite(fullQuery.build()), scoreMode, 1f);
+        }
+        return weight;
+    }
+
     @Override
     public InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException {
         return buildAggregationsForSingleBucket(
@@ -101,22 +127,26 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt
         if (sub.isNoop()) {
             return LeafBucketCollector.NO_OP_COLLECTOR;
         }
+
+        Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
+        // This means there are no docs to iterate, possibly due to the fields not existing
+        if (scorer == null) {
+            return LeafBucketCollector.NO_OP_COLLECTOR;
+        }
+        sub.setScorer(scorer);
+
         // No sampling is being done, collect all docs
+        // TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
         if (probability >= 1.0) {
             grow(1);
-            return new LeafBucketCollector() {
+            return new LeafBucketCollectorBase(sub, null) {
                 @Override
                 public void collect(int doc, long owningBucketOrd) throws IOException {
                     collectExistingBucket(sub, doc, 0);
                 }
             };
         }
-        // TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
-        Scorer scorer = weightSupplier.get().scorer(aggCtx.getLeafReaderContext());
-        // This means there are no docs to iterate, possibly due to the fields not existing
-        if (scorer == null) {
-            return LeafBucketCollector.NO_OP_COLLECTOR;
-        }
+
         final DocIdSetIterator docIt = scorer.iterator();
         final Bits liveDocs = aggCtx.getLeafReaderContext().reader().getLiveDocs();
         try {
@@ -136,5 +166,4 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
         // Since we have done our own collection, there is nothing for the leaf collector to do
         return LeafBucketCollector.NO_OP_COLLECTOR;
     }
-
 }
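The heart of the aggregator change above is getWeight(): when a sub-aggregation needs scores, the top-level query goes into the boolean query as a MUST (scoring) clause, while the sampling query is only ever a FILTER clause, so the Scorer handed to sub.setScorer(...) still produces real relevance scores. Below is a rough, self-contained Lucene sketch of that idea; it is not Elasticsearch code, and the index setup, the field names, and the TermQuery standing in for RandomSamplingQuery are illustrative assumptions only.

    import org.apache.lucene.analysis.standard.StandardAnalyzer;
    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.Field;
    import org.apache.lucene.document.StringField;
    import org.apache.lucene.document.TextField;
    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.index.LeafReaderContext;
    import org.apache.lucene.index.Term;
    import org.apache.lucene.search.BooleanClause;
    import org.apache.lucene.search.BooleanQuery;
    import org.apache.lucene.search.DocIdSetIterator;
    import org.apache.lucene.search.IndexSearcher;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.ScoreMode;
    import org.apache.lucene.search.Scorer;
    import org.apache.lucene.search.TermQuery;
    import org.apache.lucene.search.Weight;
    import org.apache.lucene.store.ByteBuffersDirectory;
    import org.apache.lucene.store.Directory;

    public class SampledScoringSketch {
        public static void main(String[] args) throws Exception {
            try (Directory dir = new ByteBuffersDirectory()) {
                try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(new StandardAnalyzer()))) {
                    for (int i = 0; i < 10; i++) {
                        Document doc = new Document();
                        doc.add(new TextField("body", "quick brown fox " + i, Field.Store.NO));
                        // the "parity" field only exists to give the stand-in sampling filter something to match
                        doc.add(new StringField("parity", i % 2 == 0 ? "even" : "odd", Field.Store.NO));
                        writer.addDocument(doc);
                    }
                }
                try (DirectoryReader reader = DirectoryReader.open(dir)) {
                    IndexSearcher searcher = new IndexSearcher(reader);
                    Query userQuery = new TermQuery(new Term("body", "fox"));
                    // stand-in for RandomSamplingQuery: any query that merely filters out part of the doc set
                    Query samplingFilter = new TermQuery(new Term("parity", "even"));
                    ScoreMode scoreMode = ScoreMode.COMPLETE; // a scored sub-aggregation (e.g. top_hits) needs scores
                    BooleanQuery.Builder sampled = new BooleanQuery.Builder()
                        // the user's query is the scoring clause when scores are needed, otherwise a plain filter
                        .add(userQuery, scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER)
                        // the sampling clause is always a FILTER, so it never contributes to the score
                        .add(samplingFilter, BooleanClause.Occur.FILTER);
                    Weight weight = searcher.createWeight(searcher.rewrite(sampled.build()), scoreMode, 1f);
                    for (LeafReaderContext leaf : reader.leaves()) {
                        Scorer scorer = weight.scorer(leaf);
                        if (scorer == null) {
                            continue; // nothing to iterate in this segment
                        }
                        DocIdSetIterator it = scorer.iterator();
                        for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
                            // only the sampled docs are visited, but each keeps the user query's relevance score
                            System.out.println("doc=" + (leaf.docBase + doc) + " score=" + scorer.score());
                        }
                    }
                }
            }
        }
    }

Because a FILTER clause never contributes to the score, sampling only decides which documents are visited; each visited document scores exactly as it would under the user's query alone, which is what the new YAML test and testAggregationSampling_withScores below assert.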
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java
index 67c958046dac7..50921501896d3 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java
@@ -9,10 +9,6 @@
 
 package org.elasticsearch.search.aggregations.bucket.sampler.random;
 
-import org.apache.lucene.search.BooleanClause;
-import org.apache.lucene.search.BooleanQuery;
-import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Weight;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,7 +26,6 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
     private final Integer shardSeed;
     private final double probability;
     private final SamplingContext samplingContext;
-    private Weight weight;
 
     RandomSamplerAggregatorFactory(
         String name,
@@ -57,41 +52,6 @@ public Optional<SamplingContext> getSamplingContext() {
     @Override
     public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
         throws IOException {
-        return new RandomSamplerAggregator(
-            name,
-            seed,
-            shardSeed,
-            probability,
-            this::getWeight,
-            factories,
-            context,
-            parent,
-            cardinality,
-            metadata
-        );
+        return new RandomSamplerAggregator(name, seed, shardSeed, probability, factories, context, parent, cardinality, metadata);
     }
-
-    /**
-     * This creates the query weight which will be used in the aggregator.
-     *
-     * This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
-     * the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
-     * @return weight to be used, is cached for additional usages
-     * @throws IOException when building the weight or queries fails;
-     */
-    private Weight getWeight() throws IOException {
-        if (weight == null) {
-            RandomSamplingQuery query = new RandomSamplingQuery(
-                probability,
-                seed,
-                shardSeed == null ? context.shardRandomSeed() : shardSeed
-            );
-            BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
-                .add(context.query(), BooleanClause.Occur.FILTER)
-                .build();
-            weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
-        }
-        return weight;
-    }
-
 }
diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorTests.java
index f75f9f474c8e8..2f51a5a09a8ac 100644
--- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorTests.java
+++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorTests.java
@@ -11,22 +11,29 @@
 
 import org.apache.lucene.document.LongPoint;
 import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
 import org.elasticsearch.search.aggregations.bucket.filter.Filter;
 import org.elasticsearch.search.aggregations.metrics.Avg;
 import org.elasticsearch.search.aggregations.metrics.Max;
 import org.elasticsearch.search.aggregations.metrics.Min;
+import org.elasticsearch.search.aggregations.metrics.TopHits;
 import org.hamcrest.Description;
 import org.hamcrest.Matcher;
 import org.hamcrest.TypeSafeMatcher;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.DoubleStream;
@@ -37,6 +44,8 @@
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.notANumber;
@@ -76,6 +85,35 @@ public void testAggregationSampling() throws IOException {
         assertThat(avgAvg, closeTo(1.5, 0.5));
     }
 
+    public void testAggregationSampling_withScores() throws IOException {
+        long[] counts = new long[5];
+        AtomicInteger integer = new AtomicInteger();
+        do {
+            testCase(RandomSamplerAggregatorTests::writeTestDocs, (InternalRandomSampler result) -> {
+                counts[integer.get()] = result.getDocCount();
+                if (result.getDocCount() > 0) {
+                    TopHits agg = result.getAggregations().get("top");
+                    List<SearchHit> hits = Arrays.asList(agg.getHits().getHits());
+                    assertThat(Strings.toString(result), hits, hasSize(1));
+                    assertThat(Strings.toString(result), hits.get(0).getScore(), allOf(greaterThan(0.0f), lessThan(1.0f)));
+                }
+            },
+                new AggTestConfig(
+                    new RandomSamplerAggregationBuilder("my_agg").subAggregation(AggregationBuilders.topHits("top").size(1))
+                        .setProbability(0.25),
+                    longField(NUMERIC_FIELD_NAME)
+                ).withQuery(
+                    new BooleanQuery.Builder().add(
+                        new TermQuery(new Term(KEYWORD_FIELD_NAME, KEYWORD_FIELD_VALUE)),
+                        BooleanClause.Occur.SHOULD
+                    ).build()
+                )
+            );
+        } while (integer.incrementAndGet() < 5);
+        long avgCount = LongStream.of(counts).sum() / integer.get();
+        assertThat(avgCount, allOf(greaterThanOrEqualTo(20L), lessThanOrEqualTo(70L)));
+    }
+
     public void testAggregationSamplingNestedAggsScaled() throws IOException {
         // in case 0 docs get sampled, which can rarely happen
         // in case the test index has many segments.