Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move bulkScorer() from Weight to ScorerSupplier #13408

Merged
merged 13 commits into from
May 27, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -211,26 +211,6 @@ private Scorer scorerForIterator(DocIdSetIterator iterator) {
return new ConstantScoreScorer(this, score(), scoreMode, iterator);
}

/**
 * Builds a {@link BulkScorer} for the given leaf by rewriting the multi-term query against the
 * leaf's terms.
 *
 * @param context the leaf to score
 * @return the bulk scorer of the rewritten weight when the rewrite produced a weight, a {@link
 *     DefaultBulkScorer} around the iterator-based scorer when it produced an iterator, or
 *     {@code null} when the reader has no terms for the field, the rewrite produced nothing, or
 *     no scorer could be built for the iterator
 * @throws IOException if reading terms or rewriting fails
 */
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
  final Terms fieldTerms = context.reader().terms(q.getField());
  if (fieldTerms == null) {
    // The reader exposes no terms for this field: nothing to score.
    return null;
  }

  final WeightOrDocIdSetIterator rewritten = rewrite(context, fieldTerms);
  if (rewritten == null) {
    return null;
  }
  if (rewritten.weight != null) {
    // The rewrite produced a full Weight: let it pick its own bulk scorer.
    return rewritten.weight.bulkScorer(context);
  }

  // Otherwise the rewrite produced a plain iterator; wrap it in a scorer.
  final Scorer iteratorScorer = scorerForIterator(rewritten.iterator);
  return iteratorScorer == null ? null : new DefaultBulkScorer(iteratorScorer);
}

@Override
public Matches matches(LeafReaderContext context, int doc) throws IOException {
final Terms terms = context.reader().terms(q.field);
Expand Down Expand Up @@ -276,6 +256,32 @@ public Scorer get(long leadCost) throws IOException {
() -> new ConstantScoreScorer(weight, score(), scoreMode, DocIdSetIterator.empty()));
}

/**
 * Rewrites the query against this leaf and returns a bulk scorer for the result.
 *
 * <p>Never returns {@code null}: when the rewrite yields nothing, or the rewritten weight has
 * no bulk scorer, a bulk scorer over an empty iterator is returned instead.
 */
@Override
public BulkScorer bulkScorer() throws IOException {
  final WeightOrDocIdSetIterator rewritten = rewrite(context, terms);

  BulkScorer result = null;
  if (rewritten != null) {
    if (rewritten.weight != null) {
      result = rewritten.weight.bulkScorer(context);
    } else {
      result =
          new DefaultBulkScorer(
              new ConstantScoreScorer(weight, score(), scoreMode, rewritten.iterator));
    }
  }
  if (result != null) {
    return result;
  }

  // Returning a null scorer from a non-null ScorerSupplier is against the API contract.
  // This supplier was non-null (i.e., it thought there might be hits), so if it now turns
  // out that there are no hits, hand back an empty BulkScorer rather than null.
  return new DefaultBulkScorer(
      new ConstantScoreScorer(weight, score(), scoreMode, DocIdSetIterator.empty()));
}

@Override
public long cost() {
return cost;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,7 @@ public int count() throws IOException {

private final DocIdStreamView docIdStreamView = new DocIdStreamView();

BooleanScorer(
BooleanWeight weight,
Collection<BulkScorer> scorers,
int minShouldMatch,
boolean needsScores) {
BooleanScorer(Collection<BulkScorer> scorers, int minShouldMatch, boolean needsScores) {
if (minShouldMatch < 1 || minShouldMatch > scorers.size()) {
throw new IllegalArgumentException(
"minShouldMatch should be within 1..num_scorers. Got " + minShouldMatch);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,29 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalLong;
import java.util.stream.Stream;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.Weight.DefaultBulkScorer;
import org.apache.lucene.util.Bits;

final class Boolean2ScorerSupplier extends ScorerSupplier {
final class BooleanScorerSupplier extends ScorerSupplier {

private final Weight weight;
private final Map<BooleanClause.Occur, Collection<ScorerSupplier>> subs;
private final ScoreMode scoreMode;
private final int minShouldMatch;
private final int maxDoc;
private long cost = -1;
private boolean topLevelScoringClause;

Boolean2ScorerSupplier(
BooleanScorerSupplier(
Weight weight,
Map<Occur, Collection<ScorerSupplier>> subs,
ScoreMode scoreMode,
int minShouldMatch) {
int minShouldMatch,
int maxDoc) {
if (minShouldMatch < 0) {
throw new IllegalArgumentException(
"minShouldMatch must be positive, but got: " + minShouldMatch);
Expand All @@ -64,6 +69,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
this.subs = subs;
this.scoreMode = scoreMode;
this.minShouldMatch = minShouldMatch;
this.maxDoc = maxDoc;
}

private long computeCost() {
Expand Down Expand Up @@ -166,6 +172,217 @@ private Scorer getInternal(long leadCost) throws IOException {
}
}

/**
 * Returns the bulk scorer from {@link #booleanScorer()} when bulk scoring is applicable, and
 * otherwise falls back to the Scorer-based implementation (BS2) of the superclass.
 */
@Override
public BulkScorer bulkScorer() throws IOException {
  final BulkScorer specialized = booleanScorer();
  // A null result means bulk scoring is not applicable for this clause mix.
  return specialized != null ? specialized : super.bulkScorer();
}

/**
 * Tries to build a specialized bulk scorer for this boolean query.
 *
 * <p>Two clause shapes are handled: purely optional clauses (subject to a cost heuristic) and
 * purely required clauses with no {@code minShouldMatch}. Any MUST_NOT clauses are then applied
 * on top through a {@link ReqExclBulkScorer}.
 *
 * @return a bulk scorer, or {@code null} when no specialization applies
 */
BulkScorer booleanScorer() throws IOException {
  final int optionalCount = subs.get(Occur.SHOULD).size();
  final int requiredCount = subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size();

  final BulkScorer positiveScorer;
  if (requiredCount == 0) {
    // TODO: what is the right heuristic here?
    //
    // When all clauses are optional (minShouldMatch <= 1), use BooleanScorer aggressively:
    // a threshold of -1 never rejects.
    // TODO: is there actually a threshold under which we should rather use the regular scorer?
    //
    // When a minimum number of clauses should match, BooleanScorer is going to score all
    // windows that have at least minNrShouldMatch matches in the window. But there is no way
    // to know if there is an intersection (all clauses might match a different doc ID and
    // there will be no matches in the end), so only use BooleanScorer if matches are very
    // dense.
    final long costThreshold = minShouldMatch <= 1 ? -1 : maxDoc / 3;
    if (cost() < costThreshold) {
      return null;
    }
    positiveScorer = optionalBulkScorer();
  } else if (optionalCount == 0 && minShouldMatch == 0) {
    // Only required clauses (requiredCount is known to be positive here).
    positiveScorer = requiredBulkScorer();
  } else {
    // TODO: there are some cases where BooleanScorer
    // would handle conjunctions faster than
    // BooleanScorer2...
    return null;
  }

  if (positiveScorer == null) {
    return null;
  }

  // The positive side leads iteration, so its cost is the lead cost of each exclusion.
  final long leadCost = positiveScorer.cost();
  final List<Scorer> excluded = new ArrayList<>();
  for (ScorerSupplier supplier : subs.get(Occur.MUST_NOT)) {
    excluded.add(supplier.get(leadCost));
  }
  if (excluded.isEmpty()) {
    return positiveScorer;
  }

  final Scorer exclusion;
  if (excluded.size() == 1) {
    exclusion = excluded.get(0);
  } else {
    exclusion = new DisjunctionSumScorer(weight, excluded, ScoreMode.COMPLETE_NO_SCORES);
  }
  return new ReqExclBulkScorer(positiveScorer, exclusion);
}

/**
 * Wraps {@code scorer} so that the collector never sees the real {@link Scorable}: whatever
 * scorable the wrapped bulk scorer supplies, the collector is handed a fresh {@link Score}
 * instance instead. Collected documents and the reported cost are passed through unchanged.
 *
 * @param scorer the bulk scorer to wrap; must not be {@code null}
 * @return a bulk scorer delegating to {@code scorer} with scoring hidden from the collector
 * @throws NullPointerException if {@code scorer} is {@code null}
 */
static BulkScorer disableScoring(final BulkScorer scorer) {
  Objects.requireNonNull(scorer);
  return new BulkScorer() {

    @Override
    public int score(final LeafCollector collector, Bits acceptDocs, int min, int max)
        throws IOException {
      // One placeholder per score() call, never updated by this wrapper.
      final Score placeholder = new Score();
      return scorer.score(
          new LeafCollector() {
            @Override
            public void setScorer(Scorable ignored) throws IOException {
              // Deliberately drop the real scorable and expose the placeholder instead.
              collector.setScorer(placeholder);
            }

            @Override
            public void collect(int doc) throws IOException {
              collector.collect(doc);
            }
          },
          acceptDocs,
          min,
          max);
    }

    @Override
    public long cost() {
      return scorer.cost();
    }
  };
}

// Return a BulkScorer for the optional clauses only,
// or null if it is not applicable (there are no SHOULD clauses).
// pkg-private for forcing use of BooleanScorer in tests
BulkScorer optionalBulkScorer() throws IOException {
  final Collection<ScorerSupplier> shouldSuppliers = subs.get(Occur.SHOULD);
  final boolean singleMatchSuffices = minShouldMatch <= 1;

  if (shouldSuppliers.isEmpty()) {
    return null;
  }
  if (shouldSuppliers.size() == 1 && singleMatchSuffices) {
    // A lone optional clause: its own bulk scorer is all that is needed.
    return shouldSuppliers.iterator().next().bulkScorer();
  }

  if (scoreMode == ScoreMode.TOP_SCORES && singleMatchSuffices) {
    // Top-scores collection over a plain disjunction: use the MaxScore bulk scorer.
    final List<Scorer> clauseScorers = new ArrayList<>();
    for (ScorerSupplier supplier : shouldSuppliers) {
      clauseScorers.add(supplier.get(Long.MAX_VALUE));
    }
    return new MaxScoreBulkScorer(maxDoc, clauseScorers);
  }

  final List<BulkScorer> clauseBulkScorers = new ArrayList<>();
  for (ScorerSupplier supplier : shouldSuppliers) {
    clauseBulkScorers.add(supplier.bulkScorer());
  }
  // BooleanScorer requires minShouldMatch >= 1, hence the clamp.
  final int effectiveMinShouldMatch = Math.max(1, minShouldMatch);
  return new BooleanScorer(clauseBulkScorers, effectiveMinShouldMatch, scoreMode.needsScores());
}

// Return a BulkScorer for the required clauses only (MUST and FILTER),
// or null if there are no required clauses at all.
//
// With a single required clause, that clause's own bulk scorer is used; a lone FILTER clause
// is wrapped with disableScoring when the score mode needs scores. With several clauses,
// specialized conjunction bulk scorers are used when no clause exposes a two-phase iterator,
// and otherwise a regular conjunction Scorer is wrapped in a DefaultBulkScorer.
private BulkScorer requiredBulkScorer() throws IOException {
  final Collection<ScorerSupplier> mustSuppliers = subs.get(Occur.MUST);
  final Collection<ScorerSupplier> filterSuppliers = subs.get(Occur.FILTER);
  final int numRequired = mustSuppliers.size() + filterSuppliers.size();

  if (numRequired == 0) {
    // No required clauses at all.
    return null;
  } else if (numRequired == 1) {
    BulkScorer scorer;
    if (mustSuppliers.isEmpty() == false) {
      scorer = mustSuppliers.iterator().next().bulkScorer();
    } else {
      scorer = filterSuppliers.iterator().next().bulkScorer();
      if (scoreMode.needsScores()) {
        // FILTER clauses must not contribute to the score.
        scorer = disableScoring(scorer);
      }
    }
    return scorer;
  }

  // The conjunction is led by its cheapest clause, whichever group it belongs to, so the
  // lead cost is the minimum over BOTH the MUST and the FILTER clauses. (Previously the
  // FILTER minimum replaced the MUST minimum whenever any FILTER clause existed, which
  // over-estimated the lead cost when a MUST clause was the cheapest one.)
  final long mustLeadCost =
      mustSuppliers.stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE);
  final long filterLeadCost =
      filterSuppliers.stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE);
  final long leadCost = Math.min(mustLeadCost, filterLeadCost);

  List<Scorer> requiredNoScoring = new ArrayList<>();
  for (ScorerSupplier ss : filterSuppliers) {
    requiredNoScoring.add(ss.get(leadCost));
  }
  List<Scorer> requiredScoring = new ArrayList<>();
  for (ScorerSupplier ss : mustSuppliers) {
    if (mustSuppliers.size() == 1) {
      // Exactly one scoring clause: flag it as the top-level scoring clause before get().
      ss.setTopLevelScoringClause();
    }
    requiredScoring.add(ss.get(leadCost));
  }

  if (scoreMode == ScoreMode.TOP_SCORES
      && requiredNoScoring.isEmpty()
      && requiredScoring.size() > 1
      // Only specialize top-level conjunctions for clauses that don't have a two-phase iterator.
      && requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
    return new BlockMaxConjunctionBulkScorer(maxDoc, requiredScoring);
  }
  if (scoreMode != ScoreMode.TOP_SCORES
      && requiredScoring.size() + requiredNoScoring.size() >= 2
      && requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
      && requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
    return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
  }

  // Generic path: build a conjunction Scorer and wrap it in a DefaultBulkScorer.
  if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) {
    // Collapse the scoring clauses into a single block-max conjunction scorer.
    requiredScoring =
        Collections.singletonList(new BlockMaxConjunctionScorer(weight, requiredScoring));
  }
  Scorer conjunctionScorer;
  if (requiredNoScoring.size() + requiredScoring.size() == 1) {
    if (requiredScoring.size() == 1) {
      conjunctionScorer = requiredScoring.get(0);
    } else {
      conjunctionScorer = requiredNoScoring.get(0);
      if (scoreMode.needsScores()) {
        // A lone non-scoring clause must report a score of zero.
        Scorer inner = conjunctionScorer;
        conjunctionScorer =
            new FilterScorer(inner) {
              @Override
              public float score() throws IOException {
                return 0f;
              }

              @Override
              public float getMaxScore(int upTo) throws IOException {
                return 0f;
              }
            };
      }
    }
  } else {
    List<Scorer> required = new ArrayList<>();
    required.addAll(requiredScoring);
    required.addAll(requiredNoScoring);
    conjunctionScorer = new ConjunctionScorer(weight, required, requiredScoring);
  }
  return new DefaultBulkScorer(conjunctionScorer);
}

/**
* Create a new scorer for the given required clauses. Note that {@code requiredScoring} is a
* subset of {@code required} containing required clauses that should participate in scoring.
Expand Down
Loading