Propagate scoring function through random sampler (elastic#116957)
* Propagate scoring function through random sampler.

* Update docs/changelog/116957.yaml

* Correct score mode in random sampler weight

* Fix random sampling with scores and p=1.0

* Unit test with scores

* YAML test

* Add capability
jan-elastic authored Nov 20, 2024
1 parent c3f73d0 commit dea1e7d
Showing 8 changed files with 150 additions and 55 deletions.
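The underlying problem (elastic#110134): the sampler built its weight with ScoreMode.COMPLETE_NO_SCORES and added the top-level query as a FILTER clause, so sub-aggregations such as top_hits never saw usable scores. The following self-contained, plain-Lucene sketch illustrates the effect; it is not Elasticsearch code, the class and field names are invented, and a MatchAllDocsQuery stands in for RandomSamplingQuery. The only difference between the two searches is whether the scoring query is added as FILTER or MUST.

import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.ByteBuffersDirectory;

public class FilterVsMustScores {
    public static void main(String[] args) throws Exception {
        ByteBuffersDirectory dir = new ByteBuffersDirectory();
        try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(new WhitespaceAnalyzer()))) {
            Document doc = new Document();
            doc.add(new TextField("body", "quick brown fox", Field.Store.NO));
            writer.addDocument(doc);
        }
        try (DirectoryReader reader = DirectoryReader.open(dir)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            Query scoring = new TermQuery(new Term("body", "fox")); // the "top-level query"
            Query sampling = new MatchAllDocsQuery();               // stand-in for RandomSamplingQuery

            // Old behavior: both clauses are FILTERs, so every match scores 0.
            Query allFilter = new BooleanQuery.Builder()
                .add(scoring, BooleanClause.Occur.FILTER)
                .add(sampling, BooleanClause.Occur.FILTER)
                .build();
            // New behavior when a sub-aggregation needs scores: the top-level
            // query becomes a MUST clause and contributes its score.
            Query mustPlusFilter = new BooleanQuery.Builder()
                .add(scoring, BooleanClause.Occur.MUST)
                .add(sampling, BooleanClause.Occur.FILTER)
                .build();

            System.out.println("all-FILTER score:  " + searcher.search(allFilter, 1).scoreDocs[0].score);      // 0.0
            System.out.println("MUST+FILTER score: " + searcher.search(mustPlusFilter, 1).scoreDocs[0].score); // > 0.0
        }
    }
}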
5 changes: 5 additions & 0 deletions docs/changelog/116957.yaml
@@ -0,0 +1,5 @@
pr: 116957
summary: Propagate scoring function through random sampler
area: Machine Learning
type: bug
issues: [ 110134 ]
2 changes: 1 addition & 1 deletion modules/aggregations/build.gradle
@@ -20,7 +20,7 @@ esplugin {

restResources {
restApi {
include '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
include 'capabilities', '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
}
restTests {
// Pulls in all aggregation tests from core AND the forwards v7's core for forwards compatibility
random_sampler.yml (aggregations YAML REST tests)
@@ -142,6 +142,66 @@ setup:
}
- match: { aggregations.sampled.mean.value: 1.0 }
---
"Test random_sampler aggregation with scored subagg":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ random_sampler_with_scored_subaggs ]
test_runner_features: capabilities
reason: "Support for random sampler with scored subaggs capability required"
- do:
search:
index: data
size: 0
body: >
{
"query": {
"function_score": {
"random_score": {}
}
},
"aggs": {
"sampled": {
"random_sampler": {
"probability": 0.5
},
"aggs": {
"top": {
"top_hits": {}
}
}
}
}
}
- is_true: aggregations.sampled.top.hits
- do:
search:
index: data
size: 0
body: >
{
"query": {
"function_score": {
"random_score": {}
}
},
"aggs": {
"sampled": {
"random_sampler": {
"probability": 1.0
},
"aggs": {
"top": {
"top_hits": {}
}
}
}
}
}
- match: { aggregations.sampled.top.hits.total.value: 6 }
- is_true: aggregations.sampled.top.hits.hits.0._score
---
"Test random_sampler aggregation with poor settings":
- requires:
cluster_features: ["gte_v8.2.0"]
SearchCapabilities.java
@@ -41,6 +41,8 @@ private SearchCapabilities() {}
/** Support multi-dense-vector script field access. */
private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";

private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";

public static final Set<String> CAPABILITIES;
static {
HashSet<String> capabilities = new HashSet<>();
@@ -50,6 +52,7 @@ private SearchCapabilities() {}
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS);
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS);
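The random_sampler_with_scored_subaggs capability lets clients probe a node before relying on the new behavior, which is exactly what the YAML test's requires block does. Below is a sketch using the low-level Java REST client; the /_capabilities endpoint, its parameters, and the response shape are my assumptions about the capabilities API, not something shown in this diff.

import org.apache.http.HttpHost;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class CapabilityProbe {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            // Assumed endpoint: GET /_capabilities?method=POST&path=/_search&capabilities=...
            Request request = new Request("GET", "/_capabilities");
            request.addParameter("method", "POST");
            request.addParameter("path", "/_search");
            request.addParameter("capabilities", "random_sampler_with_scored_subaggs");
            Response response = client.performRequest(request);
            // Assumed response shape: {"supported":true} on nodes with this change.
            System.out.println(EntityUtils.toString(response.getEntity()));
        }
    }
}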
AggregatorBase.java
@@ -40,7 +40,7 @@ public abstract class AggregatorBase extends Aggregator {

protected final String name;
protected final Aggregator parent;
private final AggregationContext context;
protected final AggregationContext context;
private final Map<String, Object> metadata;

protected final Aggregator[] subAggregators;
RandomSamplerAggregator.java
@@ -9,19 +9,23 @@

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.search.aggregations.AggregationExecutionContext;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
import org.elasticsearch.search.aggregations.support.AggregationContext;
@@ -34,14 +38,13 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
private final int seed;
private final Integer shardSeed;
private final double probability;
private final CheckedSupplier<Weight, IOException> weightSupplier;
private Weight weight;

RandomSamplerAggregator(
String name,
int seed,
Integer shardSeed,
double probability,
CheckedSupplier<Weight, IOException> weightSupplier,
AggregatorFactories factories,
AggregationContext context,
Aggregator parent,
@@ -56,10 +59,33 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
);
}
this.weightSupplier = weightSupplier;
this.shardSeed = shardSeed;
}

/**
* Creates the query weight used by this aggregator.
*
* The weight is a boolean query combining the {@link RandomSamplingQuery} with the configured top-level query of the search.
* This lets the aggregation iterate matching documents directly, sampling in the background rather than in the foreground.
* @return the weight to be used; it is cached for subsequent calls
* @throws IOException if building the queries or the weight fails
*/
private Weight getWeight() throws IOException {
if (weight == null) {
ScoreMode scoreMode = scoreMode();
BooleanQuery.Builder fullQuery = new BooleanQuery.Builder().add(
context.query(),
scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER
);
if (probability < 1.0) {
Query sampleQuery = new RandomSamplingQuery(probability, seed, shardSeed == null ? context.shardRandomSeed() : shardSeed);
fullQuery.add(sampleQuery, BooleanClause.Occur.FILTER);
}
weight = context.searcher().createWeight(context.searcher().rewrite(fullQuery.build()), scoreMode, 1f);
}
return weight;
}

@Override
public InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException {
return buildAggregationsForSingleBucket(
@@ -101,22 +127,26 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt
if (sub.isNoop()) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
// This means there are no docs to iterate, possibly due to the fields not existing
if (scorer == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
sub.setScorer(scorer);

// No sampling is being done, collect all docs
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
if (probability >= 1.0) {
grow(1);
return new LeafBucketCollector() {
return new LeafBucketCollectorBase(sub, null) {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
collectExistingBucket(sub, doc, 0);
}
};
}
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
Scorer scorer = weightSupplier.get().scorer(aggCtx.getLeafReaderContext());
// This means there are no docs to iterate, possibly due to the fields not existing
if (scorer == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

final DocIdSetIterator docIt = scorer.iterator();
final Bits liveDocs = aggCtx.getLeafReaderContext().reader().getLiveDocs();
try {
Expand All @@ -136,5 +166,4 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
// Since we have done our own collection, there is nothing for the leaf collector to do
return LeafBucketCollector.NO_OP_COLLECTOR;
}

}
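For reference, here is the clause-selection logic of getWeight() distilled into a standalone helper (hypothetical class and method names; the aggregator itself builds the weight lazily and caches it). Note also the complementary fix in the p >= 1.0 path above: the explicit sub.setScorer(scorer) call plus returning a LeafBucketCollectorBase wrapper, rather than a bare LeafBucketCollector, keeps the scorer visible to sub-aggregations even when no sampling happens.

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;

final class SampledQueries {
    private SampledQueries() {}

    // The top-level query contributes scores (MUST) only when some
    // sub-aggregation consumes them; the sampling query never scores and is
    // dropped entirely at probability == 1.0, where no sampling is needed.
    static Query sampledQuery(Query topLevel, Query samplingQuery, ScoreMode scoreMode, double probability) {
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        builder.add(topLevel, scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER);
        if (probability < 1.0) {
            builder.add(samplingQuery, BooleanClause.Occur.FILTER);
        }
        return builder.build();
    }
}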
RandomSamplerAggregatorFactory.java
@@ -9,10 +9,6 @@

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,7 +26,6 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
private final Integer shardSeed;
private final double probability;
private final SamplingContext samplingContext;
private Weight weight;

RandomSamplerAggregatorFactory(
String name,
@@ -57,41 +52,6 @@ public Optional<SamplingContext> getSamplingContext() {
@Override
public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
throws IOException {
return new RandomSamplerAggregator(
name,
seed,
shardSeed,
probability,
this::getWeight,
factories,
context,
parent,
cardinality,
metadata
);
return new RandomSamplerAggregator(name, seed, shardSeed, probability, factories, context, parent, cardinality, metadata);
}

/**
* This creates the query weight which will be used in the aggregator.
*
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
* @return weight to be used, is cached for additional usages
* @throws IOException when building the weight or queries fails;
*/
private Weight getWeight() throws IOException {
if (weight == null) {
RandomSamplingQuery query = new RandomSamplingQuery(
probability,
seed,
shardSeed == null ? context.shardRandomSeed() : shardSeed
);
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
.add(context.query(), BooleanClause.Occur.FILTER)
.build();
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
}
return weight;
}

}
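Moving weight construction from the factory into the aggregator is what makes the fix possible: the correct ScoreMode depends on which sub-aggregations a given aggregator instance has, and the factory-level weight removed above hard-coded COMPLETE_NO_SCORES for everyone. A rough sketch of how such a per-instance score mode can be derived (an assumption about the shape of the logic, not the exact Elasticsearch implementation):

import java.util.List;

import org.apache.lucene.search.Collector;
import org.apache.lucene.search.ScoreMode;

final class ScoreModes {
    private ScoreModes() {}

    // Hypothetical: an aggregator needs scores exactly when one of its
    // sub-collectors does, so the decision is per-instance, not per-factory.
    static ScoreMode scoreModeFor(List<? extends Collector> subCollectors) {
        for (Collector sub : subCollectors) {
            if (sub.scoreMode().needsScores()) {
                return ScoreMode.COMPLETE;
            }
        }
        return ScoreMode.COMPLETE_NO_SCORES;
    }
}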
RandomSamplerAggregatorTests.java
@@ -11,22 +11,29 @@

import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Avg;
import org.elasticsearch.search.aggregations.metrics.Max;
import org.elasticsearch.search.aggregations.metrics.Min;
import org.elasticsearch.search.aggregations.metrics.TopHits;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.DoubleStream;
@@ -37,6 +44,8 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notANumber;
@@ -76,6 +85,35 @@ public void testAggregationSampling() throws IOException {
assertThat(avgAvg, closeTo(1.5, 0.5));
}

public void testAggregationSampling_withScores() throws IOException {
long[] counts = new long[5];
AtomicInteger integer = new AtomicInteger();
do {
testCase(RandomSamplerAggregatorTests::writeTestDocs, (InternalRandomSampler result) -> {
counts[integer.get()] = result.getDocCount();
if (result.getDocCount() > 0) {
TopHits agg = result.getAggregations().get("top");
List<SearchHit> hits = Arrays.asList(agg.getHits().getHits());
assertThat(Strings.toString(result), hits, hasSize(1));
assertThat(Strings.toString(result), hits.get(0).getScore(), allOf(greaterThan(0.0f), lessThan(1.0f)));
}
},
new AggTestConfig(
new RandomSamplerAggregationBuilder("my_agg").subAggregation(AggregationBuilders.topHits("top").size(1))
.setProbability(0.25),
longField(NUMERIC_FIELD_NAME)
).withQuery(
new BooleanQuery.Builder().add(
new TermQuery(new Term(KEYWORD_FIELD_NAME, KEYWORD_FIELD_VALUE)),
BooleanClause.Occur.SHOULD
).build()
)
);
} while (integer.incrementAndGet() < 5);
long avgCount = LongStream.of(counts).sum() / integer.get();
assertThat(avgCount, allOf(greaterThanOrEqualTo(20L), lessThanOrEqualTo(70L)));
}

public void testAggregationSamplingNestedAggsScaled() throws IOException {
// in case 0 docs get sampled, which can rarely happen
// in case the test index has many segments.
