diff --git a/server/src/main/java/org/elasticsearch/script/ScriptTermStats.java b/server/src/main/java/org/elasticsearch/script/ScriptTermStats.java index 9dde32cc75e6a..b27019765e33b 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptTermStats.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptTermStats.java @@ -12,9 +12,8 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; -import org.apache.lucene.index.TermState; import org.apache.lucene.index.TermStates; -import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.features.NodeFeature; @@ -71,17 +70,15 @@ public int uniqueTermsCount() { public int matchedTermsCount() { final int docId = docIdSupplier.getAsInt(); int matchedTerms = 0; + advancePostings(docId); - try { - for (PostingsEnum postingsEnum : postingsSupplier.get()) { - if (postingsEnum != null && postingsEnum.advance(docId) == docId && postingsEnum.freq() > 0) { - matchedTerms++; - } + for (PostingsEnum postingsEnum : postingsSupplier.get()) { + if (postingsEnum != null && postingsEnum.docID() == docId) { + matchedTerms++; } - return matchedTerms; - } catch (IOException e) { - throw new UncheckedIOException(e); } + + return matchedTerms; } /** @@ -150,8 +147,9 @@ public StatsSummary termFreq() { final int docId = docIdSupplier.getAsInt(); try { + advancePostings(docId); for (PostingsEnum postingsEnum : postingsSupplier.get()) { - if (postingsEnum == null || postingsEnum.advance(docId) != docId) { + if (postingsEnum == null || postingsEnum.docID() != docId) { statsSummary.accept(0); } else { statsSummary.accept(postingsEnum.freq()); @@ -170,12 +168,13 @@ public StatsSummary termFreq() { * @return statistics on termPositions for the terms of the query in the current dac */ public StatsSummary termPositions() { - try { - statsSummary.reset(); - int docId = docIdSupplier.getAsInt(); + statsSummary.reset(); + int docId = docIdSupplier.getAsInt(); + try { + advancePostings(docId); for (PostingsEnum postingsEnum : postingsSupplier.get()) { - if (postingsEnum == null || postingsEnum.advance(docId) != docId) { + if (postingsEnum == null || postingsEnum.docID() != docId) { continue; } for (int i = 0; i < postingsEnum.freq(); i++) { @@ -206,25 +205,9 @@ private TermStates[] loadTermContexts() { private PostingsEnum[] loadPostings() { try { PostingsEnum[] postings = new PostingsEnum[terms.length]; - TermStates[] contexts = termContextsSupplier.get(); for (int i = 0; i < terms.length; i++) { - TermStates termStates = contexts[i]; - if (termStates.docFreq() == 0) { - postings[i] = null; - continue; - } - - TermState state = termStates.get(leafReaderContext); - if (state == null) { - postings[i] = null; - continue; - } - - TermsEnum termsEnum = leafReaderContext.reader().terms(terms[i].field()).iterator(); - termsEnum.seekExact(terms[i].bytes(), state); - - postings[i] = termsEnum.postings(null, PostingsEnum.ALL); + postings[i] = leafReaderContext.reader().postings(terms[i], PostingsEnum.POSITIONS); } return postings; @@ -232,4 +215,16 @@ private PostingsEnum[] loadPostings() { throw new UncheckedIOException(e); } } + + private void advancePostings(int targetDocId) { + try { + for (PostingsEnum posting : postingsSupplier.get()) { + if (posting != null && posting.docID() < targetDocId && posting.docID() != DocIdSetIterator.NO_MORE_DOCS) { + posting.advance(targetDocId); + } + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } } diff --git a/server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java b/server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java index b1b6a11764120..239c90bdee2fd 100644 --- a/server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java +++ b/server/src/test/java/org/elasticsearch/script/ScriptTermStatsTests.java @@ -48,9 +48,9 @@ public void testMatchedTermsCount() throws IOException { // Partial match assertAllDocs( - Set.of(new Term("field", "foo"), new Term("field", "baz")), + Set.of(new Term("field", "foo"), new Term("field", "qux"), new Term("field", "baz")), ScriptTermStats::matchedTermsCount, - Map.of("doc-1", equalTo(1), "doc-2", equalTo(1), "doc-3", equalTo(0)) + Map.of("doc-1", equalTo(2), "doc-2", equalTo(1), "doc-3", equalTo(0)) ); // Always returns 0 when no term is provided. @@ -211,12 +211,12 @@ public void testTermFreq() throws IOException { // With missing terms { assertAllDocs( - Set.of(new Term("field", "foo"), new Term("field", "baz")), + Set.of(new Term("field", "foo"), new Term("field", "qux"), new Term("field", "baz")), ScriptTermStats::termFreq, Map.ofEntries( - Map.entry("doc-1", equalTo(new StatsSummary(2, 1, 0, 1))), - Map.entry("doc-2", equalTo(new StatsSummary(2, 2, 0, 2))), - Map.entry("doc-3", equalTo(new StatsSummary(2, 0, 0, 0))) + Map.entry("doc-1", equalTo(new StatsSummary(3, 2, 0, 1))), + Map.entry("doc-2", equalTo(new StatsSummary(3, 2, 0, 2))), + Map.entry("doc-3", equalTo(new StatsSummary(3, 0, 0, 0))) ) ); } @@ -274,10 +274,10 @@ public void testTermPositions() throws IOException { // With missing terms { assertAllDocs( - Set.of(new Term("field", "foo"), new Term("field", "baz")), + Set.of(new Term("field", "foo"), new Term("field", "qux"), new Term("field", "baz")), ScriptTermStats::termPositions, Map.ofEntries( - Map.entry("doc-1", equalTo(new StatsSummary(1, 1, 1, 1))), + Map.entry("doc-1", equalTo(new StatsSummary(2, 4, 1, 3))), Map.entry("doc-2", equalTo(new StatsSummary(2, 3, 1, 2))), Map.entry("doc-3", equalTo(new StatsSummary())) ) @@ -311,7 +311,7 @@ private void withIndexSearcher(CheckedConsumer consu Document doc = new Document(); doc.add(new TextField("id", "doc-1", Field.Store.YES)); - doc.add(new TextField("field", "foo bar", Field.Store.YES)); + doc.add(new TextField("field", "foo bar qux", Field.Store.YES)); w.addDocument(doc); doc = new Document();