diff --git a/CHANGELOG.md b/CHANGELOG.md index c64865a01eb37..199461fd93cd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Make SearchTemplateRequest implement IndicesRequest.Replaceable ([#9122]()https://github.com/opensearch-project/OpenSearch/pull/9122) - [BWC and API enforcement] Define the initial set of annotations, their meaning and relations between them ([#9223](https://github.com/opensearch-project/OpenSearch/pull/9223)) - [Segment Replication] Support realtime reads for GET requests ([#9212](https://github.com/opensearch-project/OpenSearch/pull/9212)) +- [Feature] Expose term frequency in Painless script score context ([#9081](https://github.com/opensearch-project/OpenSearch/pull/9081)) ### Dependencies - Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307)) @@ -164,4 +165,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Security [Unreleased 3.0]: https://github.com/opensearch-project/OpenSearch/compare/2.x...HEAD -[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x \ No newline at end of file +[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x diff --git a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java index 6be299146a181..3932559f7685c 100644 --- a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java +++ b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java @@ -66,7 +66,7 @@ public boolean needs_score() { @Override public ScoreScript newInstance(final LeafReaderContext leaf) throws IOException { - return new ScoreScript(null, null, null) { + return new ScoreScript(null, null, null, null) { // Fake the scorer until setScorer is called. DoubleValues values = source.getValues(leaf, new DoubleValues() { @Override diff --git a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java index 1c3dc69359952..035d2402857e0 100644 --- a/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java +++ b/modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScriptEngine.java @@ -37,6 +37,7 @@ import org.apache.lucene.expressions.js.JavascriptCompiler; import org.apache.lucene.expressions.js.VariableContext; import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.SpecialPermission; import org.opensearch.common.Nullable; import org.opensearch.index.fielddata.IndexFieldData; @@ -110,7 +111,7 @@ public FilterScript.LeafFactory newFactory(Map params, SearchLoo contexts.put(ScoreScript.CONTEXT, (Expression expr) -> new ScoreScript.Factory() { @Override - public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup) { + public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup, IndexSearcher indexSearcher) { return newScoreScript(expr, lookup, params); } diff --git a/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java b/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java index f5193b393ee88..67b298eee7973 100644 --- a/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java +++ b/modules/lang-painless/src/main/java/org/opensearch/painless/action/PainlessExecuteAction.java @@ -558,7 +558,11 @@ static Response innerShardOperation(Request request, ScriptService scriptService } else if (scriptContext == ScoreScript.CONTEXT) { return prepareRamIndex(request, (context, leafReaderContext) -> { ScoreScript.Factory factory = scriptService.compile(request.script, ScoreScript.CONTEXT); - ScoreScript.LeafFactory leafFactory = factory.newFactory(request.getScript().getParams(), context.lookup()); + ScoreScript.LeafFactory leafFactory = factory.newFactory( + request.getScript().getParams(), + context.lookup(), + context.searcher() + ); ScoreScript scoreScript = leafFactory.newInstance(leafReaderContext); scoreScript.setDocument(0); diff --git a/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt b/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt index 61d53608a30c8..5533f0bc55522 100644 --- a/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt +++ b/modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt @@ -23,6 +23,10 @@ class org.opensearch.script.ScoreScript @no_import { } static_import { + int termFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TermFreq + float tf(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TF + long totalTermFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TotalTermFreq + long sumTotalTermFreq(org.opensearch.script.ScoreScript, String) bound_to org.opensearch.script.ScoreScriptUtils$SumTotalTermFreq double saturation(double, double) from_class org.opensearch.script.ScoreScriptUtils double sigmoid(double, double, double) from_class org.opensearch.script.ScoreScriptUtils double randomScore(org.opensearch.script.ScoreScript, int, String) bound_to org.opensearch.script.ScoreScriptUtils$RandomScoreField diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/120_script_score_term_frequency.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/120_script_score_term_frequency.yml new file mode 100644 index 0000000000000..b3ff66251938d --- /dev/null +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/120_script_score_term_frequency.yml @@ -0,0 +1,95 @@ +--- +setup: + - skip: + version: " - 2.9.99" + reason: "termFreq functions for script_score was introduced in 2.10.0" + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + mappings: + properties: + f1: + type: keyword + f2: + type: text + - do: + bulk: + refresh: true + body: + - '{"index": {"_index": "test", "_id": "doc1"}}' + - '{"f1": "v0", "f2": "v1"}' + - '{"index": {"_index": "test", "_id": "doc2"}}' + - '{"f2": "v2"}' + +--- +"Script score function using the termFreq function": + - do: + search: + index: test + rest_total_hits_as_int: true + body: + query: + function_score: + query: + match_all: {} + script_score: + script: + source: "termFreq(params.field, params.term)" + params: + field: "f1" + term: "v0" + - match: { hits.total: 2 } + - match: { hits.hits.0._id: "doc1" } + - match: { hits.hits.1._id: "doc2" } + - match: { hits.hits.0._score: 1.0 } + - match: { hits.hits.1._score: 0.0 } + +--- +"Script score function using the totalTermFreq function": + - do: + search: + index: test + rest_total_hits_as_int: true + body: + query: + function_score: + query: + match_all: {} + script_score: + script: + source: "if (doc[params.field].size() == 0) return params.default_value; else { return totalTermFreq(params.field, params.term); }" + params: + default_value: 0.5 + field: "f1" + term: "v0" + - match: { hits.total: 2 } + - match: { hits.hits.0._id: "doc1" } + - match: { hits.hits.1._id: "doc2" } + - match: { hits.hits.0._score: 1.0 } + - match: { hits.hits.1._score: 0.5 } + +--- +"Script score function using the sumTotalTermFreq function": + - do: + search: + index: test + rest_total_hits_as_int: true + body: + query: + function_score: + query: + match_all: {} + script_score: + script: + source: "if (doc[params.field].size() == 0) return params.default_value; else { return sumTotalTermFreq(params.field); }" + params: + default_value: 0.5 + field: "f1" + - match: { hits.total: 2 } + - match: { hits.hits.0._id: "doc1" } + - match: { hits.hits.1._id: "doc2" } + - match: { hits.hits.0._score: 1.0 } + - match: { hits.hits.1._score: 0.5 } diff --git a/plugins/examples/script-expert-scoring/src/main/java/org/opensearch/example/expertscript/ExpertScriptPlugin.java b/plugins/examples/script-expert-scoring/src/main/java/org/opensearch/example/expertscript/ExpertScriptPlugin.java index e7615d9ad7204..07c2d4d6435d7 100644 --- a/plugins/examples/script-expert-scoring/src/main/java/org/opensearch/example/expertscript/ExpertScriptPlugin.java +++ b/plugins/examples/script-expert-scoring/src/main/java/org/opensearch/example/expertscript/ExpertScriptPlugin.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.common.settings.Settings; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.ScriptPlugin; @@ -120,20 +121,22 @@ public boolean isResultDeterministic() { @Override public LeafFactory newFactory( Map params, - SearchLookup lookup + SearchLookup lookup, + IndexSearcher indexSearcher ) { - return new PureDfLeafFactory(params, lookup); + return new PureDfLeafFactory(params, lookup, indexSearcher); } } private static class PureDfLeafFactory implements LeafFactory { private final Map params; private final SearchLookup lookup; + private final IndexSearcher indexSearcher; private final String field; private final String term; private PureDfLeafFactory( - Map params, SearchLookup lookup) { + Map params, SearchLookup lookup, IndexSearcher indexSearcher) { if (params.containsKey("field") == false) { throw new IllegalArgumentException( "Missing parameter [field]"); @@ -144,6 +147,7 @@ private PureDfLeafFactory( } this.params = params; this.lookup = lookup; + this.indexSearcher = indexSearcher; field = params.get("field").toString(); term = params.get("term").toString(); } @@ -163,7 +167,7 @@ public ScoreScript newInstance(LeafReaderContext context) * the field and/or term don't exist in this segment, * so always return 0 */ - return new ScoreScript(params, lookup, context) { + return new ScoreScript(params, lookup, indexSearcher, context) { @Override public double execute( ExplanationHolder explanation @@ -172,7 +176,7 @@ public double execute( } }; } - return new ScoreScript(params, lookup, context) { + return new ScoreScript(params, lookup, indexSearcher, context) { int currentDocid = -1; @Override public void setDocument(int docid) { diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java index 3651a7354e5de..f329677a94340 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java @@ -34,6 +34,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.action.index.IndexRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchType; @@ -93,7 +94,7 @@ public String getType() { public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { assert scriptSource.equals("explainable_script"); assert context == ScoreScript.CONTEXT; - ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() { + ScoreScript.Factory factory = (params1, lookup, indexSearcher) -> new ScoreScript.LeafFactory() { @Override public boolean needs_score() { return false; @@ -101,7 +102,7 @@ public boolean needs_score() { @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return new MyScript(params1, lookup, ctx); + return new MyScript(params1, lookup, indexSearcher, ctx); } }; return context.factoryClazz.cast(factory); @@ -117,8 +118,8 @@ public Set> getSupportedContexts() { static class MyScript extends ScoreScript implements ExplainableScoreScript { - MyScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { - super(params, lookup, leafContext); + MyScript(Map params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) { + super(params, lookup, indexSearcher, leafContext); } @Override diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java index e241211911502..3dadaeada2e60 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java @@ -114,7 +114,7 @@ protected int doHashCode() { protected ScoreFunction doToFunction(QueryShardContext context) { try { ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT); - ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup()); + ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup(), context.searcher()); return new ScriptScoreFunction( script, searchScript, diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java index 51c4362b6e257..e302ebcee4ba7 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreQueryBuilder.java @@ -187,7 +187,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { ); } ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT); - ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup()); + ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup(), context.searcher()); final QueryBuilder queryBuilder = this.query; Query query = queryBuilder.toQuery(context); return new ScriptScoreQuery( diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java new file mode 100644 index 0000000000000..95fbecc53f4ae --- /dev/null +++ b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunction.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query.functionscore; + +import java.io.IOException; + +/** + * An interface representing a term frequency function used to compute document scores + * based on specific term frequency calculations. Implementations of this interface should + * provide a way to execute the term frequency function for a given document ID. + * + * @opensearch.internal + */ +public interface TermFrequencyFunction { + Object execute(int docId) throws IOException; +} diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java new file mode 100644 index 0000000000000..4edcd34889abd --- /dev/null +++ b/server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java @@ -0,0 +1,95 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query.functionscore; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TFValueSource; +import org.apache.lucene.queries.function.valuesource.TermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource; +import org.apache.lucene.search.IndexSearcher; +import org.opensearch.common.lucene.BytesRefs; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * A factory class for creating instances of {@link TermFrequencyFunction}. + * This class provides methods for creating different term frequency functions based on + * the specified function name, field, and term. Each term frequency function is designed + * to compute document scores based on specific term frequency calculations. + * + * @opensearch.internal + */ +public class TermFrequencyFunctionFactory { + public static TermFrequencyFunction createFunction( + TermFrequencyFunctionName functionName, + String field, + String term, + LeafReaderContext readerContext, + IndexSearcher indexSearcher + ) throws IOException { + switch (functionName) { + case TERM_FREQ: + TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)); + FunctionValues functionValues = termFreqValueSource.getValues(null, readerContext); + return docId -> functionValues.intVal(docId); + case TF: + TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term)); + Map tfContext = new HashMap<>() { + { + put("searcher", indexSearcher); + } + }; + functionValues = tfValueSource.getValues(tfContext, readerContext); + return docId -> functionValues.floatVal(docId); + case TOTAL_TERM_FREQ: + TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource( + field, + term, + field, + BytesRefs.toBytesRef(term) + ); + Map ttfContext = new HashMap<>(); + totalTermFreqValueSource.createWeight(ttfContext, indexSearcher); + functionValues = totalTermFreqValueSource.getValues(ttfContext, readerContext); + return docId -> functionValues.longVal(docId); + case SUM_TOTAL_TERM_FREQ: + SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field); + Map sttfContext = new HashMap<>(); + sumTotalTermFreqValueSource.createWeight(sttfContext, indexSearcher); + functionValues = sumTotalTermFreqValueSource.getValues(sttfContext, readerContext); + return docId -> functionValues.longVal(docId); + default: + throw new IllegalArgumentException("Unsupported function: " + functionName); + } + } + + /** + * An enumeration representing the names of supported term frequency functions. + */ + public enum TermFrequencyFunctionName { + TERM_FREQ("termFreq"), + TF("tf"), + TOTAL_TERM_FREQ("totalTermFreq"), + SUM_TOTAL_TERM_FREQ("sumTotalTermFreq"); + + private final String termFrequencyFunctionName; + + TermFrequencyFunctionName(String termFrequencyFunctionName) { + this.termFrequencyFunctionName = termFrequencyFunctionName; + } + + public String getTermFrequencyFunctionName() { + return termFrequencyFunctionName; + } + } +} diff --git a/server/src/main/java/org/opensearch/script/ScoreScript.java b/server/src/main/java/org/opensearch/script/ScoreScript.java index 5c6553ffc2a28..70de636a655f2 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScript.java +++ b/server/src/main/java/org/opensearch/script/ScoreScript.java @@ -33,11 +33,14 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Scorable; import org.opensearch.Version; import org.opensearch.common.logging.DeprecationLogger; import org.opensearch.index.fielddata.ScriptDocValues; +import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName; import org.opensearch.search.lookup.LeafSearchLookup; +import org.opensearch.search.lookup.LeafTermFrequencyLookup; import org.opensearch.search.lookup.SearchLookup; import org.opensearch.search.lookup.SourceLookup; @@ -107,6 +110,9 @@ public Explanation get(double score, Explanation subQueryExplanation) { /** A leaf lookup for the bound segment this script will operate on. */ private final LeafSearchLookup leafLookup; + /** A leaf term frequency lookup for the bound segment this script will operate on. */ + private final LeafTermFrequencyLookup leafTermFrequencyLookup; + private DoubleSupplier scoreSupplier = () -> 0.0; private final int docBase; @@ -115,16 +121,18 @@ public Explanation get(double score, Explanation subQueryExplanation) { private String indexName = null; private Version indexVersion = null; - public ScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + public ScoreScript(Map params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) { // null check needed b/c of expression engine subclass if (lookup == null) { assert params == null; assert leafContext == null; this.params = null; this.leafLookup = null; + this.leafTermFrequencyLookup = null; this.docBase = 0; } else { this.leafLookup = lookup.getLeafSearchLookup(leafContext); + this.leafTermFrequencyLookup = new LeafTermFrequencyLookup(indexSearcher, leafLookup); params = new HashMap<>(params); params.putAll(leafLookup.asMap()); this.params = new DynamicMap(params, PARAMS_FUNCTIONS); @@ -144,6 +152,10 @@ public Map> getDoc() { return leafLookup.doc(); } + public Object getTermFrequency(TermFrequencyFunctionName functionName, String field, String val) throws IOException { + return leafTermFrequencyLookup.getTermFrequency(functionName, field, val, docId); + } + /** Set the current document to run the script on next. */ public void setDocument(int docid) { this.docId = docid; @@ -268,7 +280,7 @@ public interface LeafFactory { */ public interface Factory extends ScriptFactory { - ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup); + ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup, IndexSearcher indexSearcher); } diff --git a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java index 76d0a8bb44da0..0767c29fa1b31 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java +++ b/server/src/main/java/org/opensearch/script/ScoreScriptUtils.java @@ -48,6 +48,10 @@ import java.time.ZonedDateTime; import static org.opensearch.common.util.BitMixer.mix32; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.SUM_TOTAL_TERM_FREQ; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TERM_FREQ; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TF; +import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TOTAL_TERM_FREQ; /** * Utilities for scoring scripts @@ -70,6 +74,90 @@ public static double sigmoid(double value, double k, double a) { return Math.pow(value, a) / (Math.pow(k, a) + Math.pow(value, a)); } + /** + * Retrieves the term frequency within a field for a specific term. + * + * @opensearch.internal + */ + public static final class TermFreq { + private final ScoreScript scoreScript; + + public TermFreq(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public int termFreq(String field, String term) { + try { + return (int) scoreScript.getTermFrequency(TERM_FREQ, field, term); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + } + + /** + * Calculates the term frequency-inverse document frequency (tf-idf) for a specific term within a field. + * + * @opensearch.internal + */ + public static final class TF { + private final ScoreScript scoreScript; + + public TF(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public float tf(String field, String term) { + try { + return (float) scoreScript.getTermFrequency(TF, field, term); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + } + + /** + * Retrieves the total term frequency within a field for a specific term. + * + * @opensearch.internal + */ + public static final class TotalTermFreq { + private final ScoreScript scoreScript; + + public TotalTermFreq(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public long totalTermFreq(String field, String term) { + try { + return (long) scoreScript.getTermFrequency(TOTAL_TERM_FREQ, field, term); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + } + + /** + * Retrieves the sum of total term frequencies within a field. + * + * @opensearch.internal + */ + public static final class SumTotalTermFreq { + private final ScoreScript scoreScript; + + public SumTotalTermFreq(ScoreScript scoreScript) { + this.scoreScript = scoreScript; + } + + public long sumTotalTermFreq(String field) { + try { + return (long) scoreScript.getTermFrequency(SUM_TOTAL_TERM_FREQ, field, null); + } catch (Exception e) { + throw ExceptionsHelper.convertToOpenSearchException(e); + } + } + } + /** * random score based on the documents' values of the given field * diff --git a/server/src/main/java/org/opensearch/search/lookup/LeafTermFrequencyLookup.java b/server/src/main/java/org/opensearch/search/lookup/LeafTermFrequencyLookup.java new file mode 100644 index 0000000000000..d02313ada1db9 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/lookup/LeafTermFrequencyLookup.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.lookup; + +import org.apache.lucene.search.IndexSearcher; +import org.opensearch.index.query.functionscore.TermFrequencyFunction; +import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory; +import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +/** + * Looks up term frequency per-segment + * + * @opensearch.internal + */ +public class LeafTermFrequencyLookup { + + private final IndexSearcher indexSearcher; + private final LeafSearchLookup leafLookup; + private final Map termFreqCache; + + public LeafTermFrequencyLookup(IndexSearcher indexSearcher, LeafSearchLookup leafLookup) { + this.indexSearcher = indexSearcher; + this.leafLookup = leafLookup; + this.termFreqCache = new HashMap<>(); + } + + public Object getTermFrequency(TermFrequencyFunctionName functionName, String field, String val, int docId) throws IOException { + TermFrequencyFunction termFrequencyFunction = getOrCreateTermFrequencyFunction(functionName, field, val); + return termFrequencyFunction.execute(docId); + } + + private TermFrequencyFunction getOrCreateTermFrequencyFunction(TermFrequencyFunctionName functionName, String field, String val) + throws IOException { + String cacheKey = (val == null) + ? String.format(Locale.ROOT, "%s-%s", functionName, field) + : String.format(Locale.ROOT, "%s-%s-%s", functionName, field, val); + + if (!termFreqCache.containsKey(cacheKey)) { + TermFrequencyFunction termFrequencyFunction = TermFrequencyFunctionFactory.createFunction( + functionName, + field, + val, + leafLookup.ctx, + indexSearcher + ); + termFreqCache.put(cacheKey, termFrequencyFunction); + } + + return termFreqCache.get(cacheKey); + } +} diff --git a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java index e1002e114822e..ca4b7dc49f6f0 100644 --- a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java +++ b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java @@ -184,6 +184,7 @@ private ScoreScript.LeafFactory newFactory( ) { SearchLookup lookup = mock(SearchLookup.class); LeafSearchLookup leafLookup = mock(LeafSearchLookup.class); + IndexSearcher indexSearcher = mock(IndexSearcher.class); when(lookup.getLeafSearchLookup(any())).thenReturn(leafLookup); return new ScoreScript.LeafFactory() { @Override @@ -193,7 +194,7 @@ public boolean needs_score() { @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return new ScoreScript(script.getParams(), lookup, leafReaderContext) { + return new ScoreScript(script.getParams(), lookup, indexSearcher, leafReaderContext) { @Override public double execute(ExplanationHolder explanation) { return function.apply(explanation); diff --git a/test/framework/src/main/java/org/opensearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/opensearch/script/MockScriptEngine.java index 98912e53c9d6a..cb0614ddeb808 100644 --- a/test/framework/src/main/java/org/opensearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/opensearch/script/MockScriptEngine.java @@ -33,6 +33,7 @@ package org.opensearch.script; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Scorable; import org.opensearch.index.query.IntervalFilterScript; import org.opensearch.index.similarity.ScriptedSimilarity.Doc; @@ -624,7 +625,7 @@ public MockScoreScript(MockDeterministicScript script) { } @Override - public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup) { + public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup, IndexSearcher indexSearcher) { return new ScoreScript.LeafFactory() { @Override public boolean needs_score() { @@ -634,7 +635,7 @@ public boolean needs_score() { @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { Scorable[] scorerHolder = new Scorable[1]; - return new ScoreScript(params, lookup, ctx) { + return new ScoreScript(params, lookup, indexSearcher, ctx) { @Override public double execute(ExplanationHolder explanation) { Map vars = new HashMap<>(getParams());