
Commit

Continue implementation.
afoucret committed Apr 30, 2024
1 parent 388aa8a commit 002d21d
Showing 11 changed files with 188 additions and 8 deletions.
@@ -352,7 +352,7 @@ private static FieldScript.LeafFactory newFieldScript(Expression expr, SearchLoo
private static FilterScript.LeafFactory newFilterScript(Expression expr, SearchLookup lookup, @Nullable Map<String, Object> vars) {
ScoreScript.LeafFactory searchLeafFactory = newScoreScript(expr, lookup, vars);
return docReader -> {
ScoreScript script = searchLeafFactory.newInstance(docReader, null);
ScoreScript script = searchLeafFactory.newInstance(docReader);
return new FilterScript(vars, lookup, docReader) {
@Override
public boolean execute() {
@@ -32,5 +32,6 @@ static_import {
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
java.util.DoubleSummaryStatistics documentFrequencyStatistics(org.elasticsearch.script.ScoreScript, String, String) bound_to org.elasticsearch.script.TermStatsScriptUtils$DocumentFrequencyStatistics
java.util.DoubleSummaryStatistics termFrequencyStatistics(org.elasticsearch.script.ScoreScript, String, String) bound_to org.elasticsearch.script.TermStatsScriptUtils$TermFrequencyStatistics
}

@@ -172,7 +172,7 @@ public void selectBestTerms() throws IOException {
continue;
}

Terms terms = fields.terms(fieldName);
Terms terms = fields.terms(fieldName);
Terms topLevelTerms = topLevelFields.terms(fieldName);

// if no terms found, take the retrieved term vector fields for stats
@@ -14,6 +14,7 @@
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
@@ -59,6 +60,8 @@
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.xcontent.XContentParserConfiguration;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -508,7 +511,15 @@ public void setLookupProviders(
new FieldDataContext(getFullyQualifiedIndex().getName(), searchLookup, this::sourcePath, fielddataOperation)
),
sourceProvider,
fieldLookupProvider
fieldLookupProvider,
(term) -> {
try {
TermStates termStates = TermStates.build(searcher, term, true);
return termStates;
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
);
}
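
The lambda added above is the term-statistics provider handed to SearchLookup: for a given Term it builds shard-level TermStates via the IndexSearcher. A minimal sketch (not part of the commit) of what such a provider yields, assuming an IndexSearcher `searcher` like the one the lambda captures; the field and term values are placeholders:

import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.IndexSearcher;

import java.io.IOException;

final class TermStatesProviderSketch {
    // Hypothetical helper mirroring the provider lambda above.
    static void printTermStats(IndexSearcher searcher, String field, String text) throws IOException {
        Term term = new Term(field, text);
        TermStates states = TermStates.build(searcher, term, true);
        // docFreq: documents containing the term; totalTermFreq: total occurrences across the shard.
        System.out.println("docFreq=" + states.docFreq() + ", totalTermFreq=" + states.totalTermFreq());
    }
}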

3 changes: 3 additions & 0 deletions server/src/main/java/org/elasticsearch/script/DocReader.java
@@ -12,6 +12,7 @@
import org.elasticsearch.script.field.Field;
import org.elasticsearch.search.lookup.Source;

import java.io.IOException;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Stream;
@@ -41,4 +42,6 @@ public interface DocReader {

/** Helper for source access */
Supplier<Source> source();

TermsStatsReader termsStatsReader(String fieldName, String query) throws IOException;
}
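
Every DocReader implementation now has to supply a TermsStatsReader for a (field, query) pair; DocValuesDocReader (next file) does so by delegating to a per-leaf provider. A sketch, not part of the commit, of how a caller might consume the new method, assuming a `docReader` bound to the current segment and a document id from the scoring loop; the field and query values are placeholders:

import org.apache.lucene.index.Term;
import org.elasticsearch.script.DocReader;
import org.elasticsearch.script.TermsStatsReader;

import java.io.IOException;
import java.util.Map;

final class DocReaderTermsStatsSketch {
    // Hypothetical consumer of the new interface method.
    static void dump(DocReader docReader, int docId) throws IOException {
        TermsStatsReader reader = docReader.termsStatsReader("title", "quick brown fox");
        Map<Term, Integer> docFreqs = reader.docFrequencies();        // shard-level: documents per term
        Map<Term, Integer> termFreqs = reader.termFrequencies(docId); // per-document: occurrences per term
        docFreqs.forEach((term, df) ->
            System.out.println(term.text() + " df=" + df + " tf=" + termFreqs.getOrDefault(term, 0)));
    }
}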
@@ -17,6 +17,7 @@
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.search.lookup.Source;

import java.io.IOException;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Stream;
@@ -33,11 +34,13 @@ public class DocValuesDocReader implements DocReader, LeafReaderContextSupplier

/** A leaf lookup for the bound segment this proxy will operate on. */
protected LeafSearchLookup leafSearchLookup;
protected final TermsStatsReaderProvider termsStatsReaderProvider;

public DocValuesDocReader(SearchLookup searchLookup, LeafReaderContext leafContext) {
this.searchLookup = searchLookup;
this.leafReaderContext = leafContext;
this.leafSearchLookup = searchLookup.getLeafSearchLookup(leafReaderContext);
this.termsStatsReaderProvider = new TermsStatsReaderProvider(searchLookup, leafContext);
}

@Override
@@ -80,4 +83,8 @@ public LeafReaderContext getLeafReaderContext() {
public Supplier<Source> source() {
return leafSearchLookup.source();
}

public TermsStatsReader termsStatsReader(String fieldName, String query) throws IOException {
return termsStatsReaderProvider.termStatsReader(fieldName, query);
}
}
@@ -189,6 +189,10 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}

public void _termsStatsReader(String fieldName, String query, String analyzer) throws IOException {
docReader.termsStatsReader(fieldName, query);
}

/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {

@@ -8,19 +8,36 @@

package org.elasticsearch.script;

import java.io.IOException;
import java.util.DoubleSummaryStatistics;
import java.util.stream.DoubleStream;
import java.util.function.Supplier;

public class TermStatsScriptUtils {

public static final class DocumentFrequencyStatistics {
public DocumentFrequencyStatistics(ScoreScript scoreScript, String fieldName, String query) {
private final TermsStatsReader termsStatsReader;

public DocumentFrequencyStatistics(ScoreScript scoreScript, String fieldName, String query) throws IOException {
this.termsStatsReader = scoreScript.docReader.termsStatsReader(fieldName, query);
}

public DoubleSummaryStatistics documentFrequencyStatistics() {
return DoubleStream.empty().summaryStatistics();
return termsStatsReader.docFrequencies().values().stream().mapToDouble(Integer::doubleValue).summaryStatistics();
}
}

public static final class TermFrequencyStatistics {
private final TermsStatsReader termsStatsReader;
private final Supplier<Integer> docIdSupplier;


public TermFrequencyStatistics(ScoreScript scoreScript, String fieldName, String query) throws IOException {
this.termsStatsReader = scoreScript.docReader.termsStatsReader(fieldName, query);
this.docIdSupplier = scoreScript::_getDocId;
}

public DoubleSummaryStatistics termFrequencyStatistics() throws IOException {
return termsStatsReader.termFrequencies(docIdSupplier.get()).values().stream().mapToDouble(Integer::doubleValue).summaryStatistics();
}
}
}
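
Both helpers reduce the per-term statistics to a java.util.DoubleSummaryStatistics, so a script can read min, max, sum, and average without iterating terms itself. A sketch, not part of the commit, of a scoring helper that combines them, assuming a bound ScoreScript instance; the weighting is purely illustrative:

import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.TermStatsScriptUtils.DocumentFrequencyStatistics;
import org.elasticsearch.script.TermStatsScriptUtils.TermFrequencyStatistics;

import java.io.IOException;
import java.util.DoubleSummaryStatistics;

final class TermStatsScoringSketch {
    // Hypothetical score contribution.
    static double score(ScoreScript scoreScript, String field, String query) throws IOException {
        DoubleSummaryStatistics df = new DocumentFrequencyStatistics(scoreScript, field, query).documentFrequencyStatistics();
        DoubleSummaryStatistics tf = new TermFrequencyStatistics(scoreScript, field, query).termFrequencyStatistics();
        // Favor documents where the query terms occur often, damped by the highest document frequency among the terms.
        return df.getMax() > 0 ? tf.getSum() / df.getMax() : 0.0;
    }
}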
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.script;

import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

public class TermsStatsReader {
private final Map<Term, TermStates> termStates;

private final Map<Term, PostingsEnum> postingsEnums;

public TermsStatsReader(Map<Term, TermStates> termStates, Map<Term, PostingsEnum> postingsEnums) {
this.termStates = termStates;
this.postingsEnums = postingsEnums;
}

public Map<Term, Integer> docFrequencies() {
return termStates.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().docFreq()));
}

public Map<Term, Integer> termFrequencies(int docId) throws IOException {
Map<Term, Integer> termFreqs = new HashMap<>();

for (Term term: postingsEnums.keySet()) {
PostingsEnum postingsEnum = postingsEnums.get(term);
if (postingsEnum.advance(docId) == docId){
termFreqs.put(term, postingsEnum.freq());
} else {
termFreqs.put(term, 0);
}
}

return termFreqs;
}
}
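
docFrequencies() is derived from the cached TermStates and is constant for the query, while termFrequencies(docId) advances each term's PostingsEnum to the current document. A small aggregation sketch, not part of the commit; the helper name is hypothetical:

import org.apache.lucene.index.Term;
import org.elasticsearch.script.TermsStatsReader;

import java.io.IOException;
import java.util.Map;

final class TermFrequencyAggregationSketch {
    // Hypothetical helper: average frequency of the analyzed query terms within one document.
    static double averageTermFrequency(TermsStatsReader reader, int docId) throws IOException {
        Map<Term, Integer> freqs = reader.termFrequencies(docId);
        return freqs.values().stream().mapToInt(Integer::intValue).average().orElse(0.0);
    }
}

Since PostingsEnum.advance only moves forward, the per-document calls are expected to arrive in increasing docId order, which matches how documents are scored within a segment.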
@@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.script;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.TermsEnum;
import org.elasticsearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class TermsStatsReaderProvider {

private final SearchLookup searchLookup;

private final LeafReaderContext leafReaderContext;

private final Map<CacheKey, TermsStatsReader> termsStatsReaderCache = new HashMap<>();

public TermsStatsReaderProvider(SearchLookup searchLookup, LeafReaderContext leafReaderContext) {
this.searchLookup = searchLookup;
this.leafReaderContext = leafReaderContext;
}

public TermsStatsReader termStatsReader(String fieldName, String query) throws IOException {
CacheKey cacheKey = new CacheKey(fieldName, query);

if (termsStatsReaderCache.containsKey(cacheKey) == false) {
termsStatsReaderCache.put(cacheKey, buildTermStatsReader(fieldName, query));
}

return termsStatsReaderCache.get(cacheKey);
}

private TermsStatsReader buildTermStatsReader(String fieldName, String query) throws IOException {

IndexReaderContext topLevelContext = ReaderUtil.getTopLevelContext(leafReaderContext);

final Map<Term, TermStates> termStates = new HashMap<>();
final Map<Term, PostingsEnum> postingsEnums = new HashMap<>();

try(
Analyzer analyzer = searchLookup.fieldType(fieldName).getTextSearchInfo().searchAnalyzer();
TokenStream ts = analyzer.tokenStream(fieldName, query);
) {
TermToBytesRefAttribute termAttr = ts.getAttribute(TermToBytesRefAttribute.class);
ts.reset();
while (ts.incrementToken()) {
Term term = new Term(fieldName, termAttr.getBytesRef());
TermStates termContext = searchLookup.getTermStates(term);
termStates.put(term, termContext);

TermsEnum termsEnum = leafReaderContext.reader().terms(term.field()).iterator();
termsEnum.seekExact(term.bytes(), termContext.get(leafReaderContext));
postingsEnums.put(term, termsEnum.postings(null, PostingsEnum.ALL));
}
}

return new TermsStatsReader(termStates, postingsEnums);
}

private record CacheKey(String fieldName, String query) { }
}
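
The provider analyzes the query once with the field's search analyzer, resolves TermStates through the SearchLookup, positions one PostingsEnum per term on the current leaf, and memoizes the result per (fieldName, query) pair. A usage sketch, not part of the commit, assuming `searchLookup` and `leafContext` come from the surrounding search execution context:

import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.script.TermsStatsReader;
import org.elasticsearch.script.TermsStatsReaderProvider;
import org.elasticsearch.search.lookup.SearchLookup;

import java.io.IOException;

final class TermsStatsReaderProviderSketch {
    // Hypothetical usage.
    static TermsStatsReader demo(SearchLookup searchLookup, LeafReaderContext leafContext) throws IOException {
        TermsStatsReaderProvider provider = new TermsStatsReaderProvider(searchLookup, leafContext);
        TermsStatsReader first = provider.termStatsReader("title", "quick fox");
        TermsStatsReader second = provider.termStatsReader("title", "quick fox");
        assert first == second; // built once, then served from the per-leaf cache
        return first;
    }
}

Caching per (field, query) matters because a script typically evaluates the same query string for every document in the segment.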
@@ -9,6 +9,8 @@
package org.elasticsearch.search.lookup;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.MappedFieldType;
@@ -52,6 +54,8 @@ public class SearchLookup implements SourceProvider {
IndexFieldData<?>> fieldDataLookup;
private final Function<LeafReaderContext, LeafFieldLookupProvider> fieldLookupProvider;

private final Function<Term, TermStates> termStatesProvider;

/**
* Create a new SearchLookup, using the default stored fields provider
* @param fieldTypeLookup defines how to look up field types
@@ -63,7 +67,7 @@ public SearchLookup(
TriFunction<MappedFieldType, Supplier<SearchLookup>, MappedFieldType.FielddataOperation, IndexFieldData<?>> fieldDataLookup,
SourceProvider sourceProvider
) {
this(fieldTypeLookup, fieldDataLookup, sourceProvider, LeafFieldLookupProvider.fromStoredFields());
this(fieldTypeLookup, fieldDataLookup, sourceProvider, LeafFieldLookupProvider.fromStoredFields(), null);
}

/**
@@ -77,13 +81,15 @@ public SearchLookup(
Function<String, MappedFieldType> fieldTypeLookup,
TriFunction<MappedFieldType, Supplier<SearchLookup>, MappedFieldType.FielddataOperation, IndexFieldData<?>> fieldDataLookup,
SourceProvider sourceProvider,
Function<LeafReaderContext, LeafFieldLookupProvider> fieldLookupProvider
Function<LeafReaderContext, LeafFieldLookupProvider> fieldLookupProvider,
Function<Term, TermStates> termStatesProvider
) {
this.fieldTypeLookup = fieldTypeLookup;
this.fieldChain = Collections.emptySet();
this.sourceProvider = sourceProvider;
this.fieldDataLookup = fieldDataLookup;
this.fieldLookupProvider = fieldLookupProvider;
this.termStatesProvider = termStatesProvider;
}

/**
@@ -99,6 +105,7 @@ private SearchLookup(SearchLookup searchLookup, Set<String> fieldChain) {
this.fieldTypeLookup = searchLookup.fieldTypeLookup;
this.fieldDataLookup = searchLookup.fieldDataLookup;
this.fieldLookupProvider = searchLookup.fieldLookupProvider;
this.termStatesProvider = null;
}

/**
Expand Down Expand Up @@ -143,4 +150,8 @@ public IndexFieldData<?> getForField(MappedFieldType fieldType, MappedFieldType.
public Source getSource(LeafReaderContext ctx, int doc) throws IOException {
return sourceProvider.getSource(ctx, doc);
}

public TermStates getTermStates(Term term) {
return termStatesProvider.apply(term);
}
}
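
getTermStates delegates to the provider supplied at construction; TermsStatsReaderProvider above is its first consumer, calling it once per analyzed query term. A sketch, not part of the commit, of that call path, assuming a SearchLookup built with a non-null termStatesProvider (as SearchExecutionContext now does):

import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.elasticsearch.search.lookup.SearchLookup;

final class TermStatesLookupSketch {
    // Hypothetical usage.
    static int docFreq(SearchLookup searchLookup, String field, String termText) {
        TermStates states = searchLookup.getTermStates(new Term(field, termText));
        return states.docFreq(); // shard-wide document frequency for the term
    }
}

Note that the copy constructor and the older four-argument constructor still leave the provider null, so the new accessor is only usable on lookups created through the five-argument constructor.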
