From 7ae48dfadf637aa185078cdb5f8657e57d22088f Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Sun, 1 Oct 2023 08:28:41 +0800 Subject: [PATCH] Add max token score for SparseEncodingQueryBuilder and do renaming (#348) * add lucene FeatureQuery Signed-off-by: zhichao-aws * add max token score Signed-off-by: zhichao-aws * add comments Signed-off-by: zhichao-aws * add check and test Signed-off-by: zhichao-aws * add doc Signed-off-by: zhichao-aws * add change log Signed-off-by: zhichao-aws * Address PR comments and change processor and query clause name. (#9) * Address code review comments Signed-off-by: zane-neo * Change lower case neural_sparse to upper case Signed-off-by: zane-neo * Change back processor type name to sparse_encoding Signed-off-by: zane-neo * Change names Signed-off-by: zane-neo * Format code Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --------- Signed-off-by: zhichao-aws Signed-off-by: zane-neo Co-authored-by: zane-neo --- CHANGELOG.md | 1 + .../lucene/BoundedLinearFeatureQuery.java | 235 ++++++++++++++++++ .../neuralsearch/plugin/NeuralSearch.java | 6 +- ...Processor.java => InferenceProcessor.java} | 10 +- .../processor/SparseEncodingProcessor.java | 2 +- .../processor/TextEmbeddingProcessor.java | 2 +- ...der.java => NeuralSparseQueryBuilder.java} | 50 ++-- .../neuralsearch/util/TokenWeightUtil.java | 2 +- ...ava => NeuralSparseQueryBuilderTests.java} | 109 +++++--- ...gQueryIT.java => NeuralSparseQueryIT.java} | 137 ++++++---- .../SparseEncodingIndexMappings.json | 2 +- .../SparseEncodingPipelineConfiguration.json | 2 +- .../UploadSparseEncodingModelRequestBody.json | 2 +- 13 files changed, 455 insertions(+), 105 deletions(-) create mode 100644 src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java rename src/main/java/org/opensearch/neuralsearch/processor/{NLPProcessor.java => InferenceProcessor.java} (97%) rename src/main/java/org/opensearch/neuralsearch/query/{SparseEncodingQueryBuilder.java => 
NeuralSparseQueryBuilder.java} (81%) rename src/test/java/org/opensearch/neuralsearch/query/{SparseEncodingQueryBuilderTests.java => NeuralSparseQueryBuilderTests.java} (73%) rename src/test/java/org/opensearch/neuralsearch/query/{SparseEncodingQueryIT.java => NeuralSparseQueryIT.java} (61%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f448684a..f552c534f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements +Add `max_token_score` parameter to improve the execution efficiency for `neural_sparse` query clause ([#348](https://github.com/opensearch-project/neural-search/pull/348)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java b/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java new file mode 100644 index 000000000..617662363 --- /dev/null +++ b/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java @@ -0,0 +1,235 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +/* + * This class is built based on Lucene FeatureQuery. We use LinearFunction to + * build the query and add an upper bound to it. + */ + +package org.apache.lucene; + +import java.io.IOException; +import java.util.Objects; + +import org.apache.lucene.index.ImpactsEnum; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.ImpactsDISI; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.similarities.Similarity.SimScorer; +import org.apache.lucene.util.BytesRef; + +/** + * The feature queries of input tokens are wrapped by a Lucene BooleanQuery, which uses the WAND algorithm + * to accelerate the execution. The WAND algorithm leverages the score upper bound of sub-queries to + * skip non-competitive tokens. However, the original Lucene FeatureQuery uses Float.MAX_VALUE as the score + * upper bound, and this invalidates WAND.
+ * + * To mitigate this issue, we rewrite the FeatureQuery to BoundedLinearFeatureQuery. The caller can + * set the token score upperbound of this query. And according to our use case, we use LinearFunction + * as the score function. + * + * This class combines both FeatureQuery + * and FeatureField together + * and will be deprecated after OpenSearch upgraded lucene to version 9.8. + */ + +public final class BoundedLinearFeatureQuery extends Query { + + private final String fieldName; + private final String featureName; + private final Float scoreUpperBound; + + public BoundedLinearFeatureQuery(String fieldName, String featureName, Float scoreUpperBound) { + this.fieldName = Objects.requireNonNull(fieldName); + this.featureName = Objects.requireNonNull(featureName); + this.scoreUpperBound = Objects.requireNonNull(scoreUpperBound); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + // LinearFunction return same object for rewrite + return super.rewrite(indexSearcher); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) { + return false; + } + BoundedLinearFeatureQuery that = (BoundedLinearFeatureQuery) obj; + return Objects.equals(fieldName, that.fieldName) + && Objects.equals(featureName, that.featureName) + && Objects.equals(scoreUpperBound, that.scoreUpperBound); + } + + @Override + public int hashCode() { + int h = getClass().hashCode(); + h = 31 * h + fieldName.hashCode(); + h = 31 * h + featureName.hashCode(); + h = 31 * h + scoreUpperBound.hashCode(); + return h; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (!scoreMode.needsScores()) { + // We don't need scores (e.g. 
for faceting), and since features are stored as terms, + // allow TermQuery to optimize in this case + TermQuery tq = new TermQuery(new Term(fieldName, featureName)); + return searcher.rewrite(tq).createWeight(searcher, scoreMode, boost); + } + + return new Weight(this) { + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + String desc = "weight(" + getQuery() + " in " + doc + ") [\" BoundedLinearFeatureQuery \"]"; + + Terms terms = context.reader().terms(fieldName); + if (terms == null) { + return Explanation.noMatch(desc + ". Field " + fieldName + " doesn't exist."); + } + TermsEnum termsEnum = terms.iterator(); + if (termsEnum.seekExact(new BytesRef(featureName)) == false) { + return Explanation.noMatch(desc + ". Feature " + featureName + " doesn't exist."); + } + + PostingsEnum postings = termsEnum.postings(null, PostingsEnum.FREQS); + if (postings.advance(doc) != doc) { + return Explanation.noMatch(desc + ". 
Feature " + featureName + " isn't set."); + } + + int freq = postings.freq(); + float featureValue = decodeFeatureValue(freq); + float score = boost * featureValue; + return Explanation.match( + score, + "Linear function on the " + fieldName + " field for the " + featureName + " feature, computed as w * S from:", + Explanation.match(boost, "w, weight of this function"), + Explanation.match(featureValue, "S, feature value") + ); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + Terms terms = Terms.getTerms(context.reader(), fieldName); + TermsEnum termsEnum = terms.iterator(); + if (termsEnum.seekExact(new BytesRef(featureName)) == false) { + return null; + } + + final SimScorer scorer = new SimScorer() { + @Override + public float score(float freq, long norm) { + return boost * decodeFeatureValue(freq); + } + }; + final ImpactsEnum impacts = termsEnum.impacts(PostingsEnum.FREQS); + final ImpactsDISI impactsDisi = new ImpactsDISI(impacts, impacts, scorer); + + return new Scorer(this) { + + @Override + public int docID() { + return impacts.docID(); + } + + @Override + public float score() throws IOException { + return scorer.score(impacts.freq(), 1L); + } + + @Override + public DocIdSetIterator iterator() { + return impactsDisi; + } + + @Override + public int advanceShallow(int target) throws IOException { + return impactsDisi.advanceShallow(target); + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return impactsDisi.getMaxScore(upTo); + } + + @Override + public void setMinCompetitiveScore(float minScore) { + impactsDisi.setMinCompetitiveScore(minScore); + } + }; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) { + if (visitor.acceptField(fieldName)) { + visitor.visitLeaf(this); + } + } + + @Override + public String toString(String field) { + return "BoundedLinearFeatureQuery(field=" + fieldName + ", feature=" + featureName + ", scoreUpperBound=" + scoreUpperBound + ")"; + } + + 
// the field and decodeFeatureValue are modified from FeatureField.decodeFeatureValue + static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15; + + // Rewriting this function to make scoreUpperBound work. + private float decodeFeatureValue(float freq) { + if (freq > MAX_FREQ) { + return scoreUpperBound; + } + int tf = (int) freq; // lossless + int featureBits = tf << 15; + return Math.min(Float.intBitsToFloat(featureBits), scoreUpperBound); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 2ac8853e4..3ce9582f4 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -41,7 +41,7 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; -import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; @@ -81,7 +81,7 @@ public Collection createComponents( final Supplier repositoriesServiceSupplier ) { NeuralQueryBuilder.initialize(clientAccessor); - SparseEncodingQueryBuilder.initialize(clientAccessor); + NeuralSparseQueryBuilder.initialize(clientAccessor); normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); return List.of(clientAccessor); } @@ -91,7 +91,7 @@ public List> getQueries() { return Arrays.asList( new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent), new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent), - 
new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent) + new QuerySpec<>(NeuralSparseQueryBuilder.NAME, NeuralSparseQueryBuilder::new, NeuralSparseQueryBuilder::fromXContent) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java similarity index 97% rename from src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 4ac63d419..acf0eb32b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -32,7 +32,7 @@ * and set the target fields according to the field name map. */ @Log4j2 -public abstract class NLPProcessor extends AbstractProcessor { +public abstract class InferenceProcessor extends AbstractProcessor { public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; @@ -51,7 +51,7 @@ public abstract class NLPProcessor extends AbstractProcessor { private final Environment environment; - public NLPProcessor( + public InferenceProcessor( String tag, String description, String type, @@ -249,7 +249,7 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { - NLPProcessor.IndexWrapper indexWrapper = new NLPProcessor.IndexWrapper(0); + InferenceProcessor.IndexWrapper indexWrapper = new InferenceProcessor.IndexWrapper(0); Map result = new LinkedHashMap<>(); for (Map.Entry knnMapEntry : processorMap.entrySet()) { String knnKey = knnMapEntry.getKey(); @@ -270,7 +270,7 @@ private void putNLPResultToSourceMapForMapType( String processorKey, Object sourceValue, List results, - NLPProcessor.IndexWrapper indexWrapper, + 
InferenceProcessor.IndexWrapper indexWrapper, Map sourceAndMetadataMap ) { if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; @@ -294,7 +294,7 @@ private void putNLPResultToSourceMapForMapType( private List> buildNLPResultForListType( List sourceValue, List results, - NLPProcessor.IndexWrapper indexWrapper + InferenceProcessor.IndexWrapper indexWrapper ) { List> keyToResult = new ArrayList<>(); IntStream.range(0, sourceValue.size()) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 275117809..b5bb85aac 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -22,7 +22,7 @@ * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results. */ @Log4j2 -public final class SparseEncodingProcessor extends NLPProcessor { +public final class SparseEncodingProcessor extends InferenceProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 1df60baea..c30d14caf 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -21,7 +21,7 @@ * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results. 
*/ @Log4j2 -public final class TextEmbeddingProcessor extends NLPProcessor { +public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java similarity index 81% rename from src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java rename to src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 4b8b6f0d4..fd15b431b 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -21,9 +21,10 @@ import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; -import org.apache.lucene.document.FeatureField; +import org.apache.lucene.BoundedLinearFeatureQuery; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Query; import org.opensearch.common.SetOnce; import org.opensearch.core.ParseField; @@ -44,7 +45,7 @@ import com.google.common.annotations.VisibleForTesting; /** - * SparseEncodingQueryBuilder is responsible for handling "sparse_encoding" query types. It uses an ML SPARSE_ENCODING model + * NeuralSparseQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model * or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed * to Lucene FeatureQuery wrapped by Lucene BooleanQuery.
*/ @@ -55,22 +56,25 @@ @Accessors(chain = true, fluent = true) @NoArgsConstructor @AllArgsConstructor -public class SparseEncodingQueryBuilder extends AbstractQueryBuilder { - public static final String NAME = "sparse_encoding"; +public class NeuralSparseQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "neural_sparse"; @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); @VisibleForTesting static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); + @VisibleForTesting + static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score"); private static MLCommonsClientAccessor ML_CLIENT; public static void initialize(MLCommonsClientAccessor mlClient) { - SparseEncodingQueryBuilder.ML_CLIENT = mlClient; + NeuralSparseQueryBuilder.ML_CLIENT = mlClient; } private String fieldName; private String queryText; private String modelId; + private Float maxTokenScore; private Supplier> queryTokensSupplier; /** @@ -79,11 +83,12 @@ public static void initialize(MLCommonsClientAccessor mlClient) { * @param in StreamInput to initialize object from * @throws IOException thrown if unable to read from input stream */ - public SparseEncodingQueryBuilder(StreamInput in) throws IOException { + public NeuralSparseQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.queryText = in.readString(); this.modelId = in.readString(); + this.maxTokenScore = in.readOptionalFloat(); } @Override @@ -91,6 +96,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeString(queryText); out.writeString(modelId); + out.writeOptionalFloat(maxTokenScore); } @Override @@ -99,6 +105,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws xContentBuilder.startObject(fieldName); xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); 
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore); printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -108,15 +115,16 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * The expected parsing form looks like: * "SAMPLE_FIELD": { * "query_text": "string", - * "model_id": "string" + * "model_id": "string", + * "max_token_score": float (optional) * } * * @param parser XContentParser * @return NeuralQueryBuilder * @throws IOException can be thrown by parser */ - public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder(); + public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throws IOException { + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder(); if (parser.currentToken() != XContentParser.Token.START_OBJECT) { throw new ParsingException(parser.getTokenLocation(), "First token of " + NAME + "query must be START_OBJECT"); } @@ -146,11 +154,14 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr sparseEncodingQueryBuilder.modelId(), String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) ); + if (sparseEncodingQueryBuilder.maxTokenScore != null && sparseEncodingQueryBuilder.maxTokenScore <= 0) { + throw new IllegalArgumentException(MAX_TOKEN_SCORE_FIELD.getPreferredName() + " must be larger than 0."); + } return sparseEncodingQueryBuilder; } - private static void parseQueryParams(XContentParser parser, SparseEncodingQueryBuilder sparseEncodingQueryBuilder) throws IOException { + private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBuilder
sparseEncodingQueryBuilder) throws IOException { XContentParser.Token token; String currentFieldName = ""; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { @@ -165,6 +176,8 @@ private static void parseQueryParams(XContentParser parser, SparseEncodingQueryB sparseEncodingQueryBuilder.queryText(parser.text()); } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { sparseEncodingQueryBuilder.modelId(parser.text()); + } else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseEncodingQueryBuilder.maxTokenScore(parser.floatValue()); } else { throw new ParsingException( parser.getTokenLocation(), @@ -200,9 +213,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws }, actionListener::onFailure) )) ); - return new SparseEncodingQueryBuilder().fieldName(fieldName) + return new NeuralSparseQueryBuilder().fieldName(fieldName) .queryText(queryText) .modelId(modelId) + .maxTokenScore(maxTokenScore) .queryTokensSupplier(queryTokensSetOnce::get); } @@ -214,9 +228,14 @@ protected Query doToQuery(QueryShardContext context) throws IOException { Map queryTokens = queryTokensSupplier.get(); validateQueryTokens(queryTokens); + final Float scoreUpperBound = maxTokenScore != null ? 
maxTokenScore : Float.MAX_VALUE; + BooleanQuery.Builder builder = new BooleanQuery.Builder(); for (Map.Entry entry : queryTokens.entrySet()) { - builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD); + builder.add( + new BoostQuery(new BoundedLinearFeatureQuery(fieldName, entry.getKey(), scoreUpperBound), entry.getValue()), + BooleanClause.Occur.SHOULD + ); } return builder.build(); } @@ -254,18 +273,19 @@ private static void validateQueryTokens(Map queryTokens) { } @Override - protected boolean doEquals(SparseEncodingQueryBuilder obj) { + protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) - .append(modelId, obj.modelId); + .append(modelId, obj.modelId) + .append(maxTokenScore, obj.maxTokenScore); return equalsBuilder.isEquals(); } @Override protected int doHashCode() { - return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode(); + return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore).toHashCode(); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 76ce0fa16..853fc743d 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -12,7 +12,7 @@ import java.util.stream.Collectors; /** - * Utility class for working with sparse_encoding queries and ingest processor. + * Utility class for working with neural_sparse queries and ingest processor. 
* Used to fetch the (token, weight) Map from the response returned by {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} * */ diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java similarity index 73% rename from src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java rename to src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 6cb122c4f..34850dcb7 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -11,9 +11,10 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; -import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.MODEL_ID_FIELD; -import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.NAME; -import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD; +import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MAX_TOKEN_SCORE_FIELD; +import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME; +import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD; import java.io.IOException; import java.util.List; @@ -42,13 +43,14 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.test.OpenSearchTestCase; -public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { +public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase { private static final String FIELD_NAME = "testField"; 
private static final String QUERY_TEXT = "Hello world!"; private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; + private static final Float MAX_TOKEN_SCORE = 123f; private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f); @SneakyThrows @@ -71,7 +73,7 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser); assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText()); @@ -85,6 +87,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { "VECTOR_FIELD": { "query_text": "string", "model_id": "string", + "max_token_score": 123.0, "boost": 10.0, "_name": "something", } @@ -95,6 +98,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { .startObject(FIELD_NAME) .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE) .field(BOOST_FIELD.getPreferredName(), BOOST) .field(NAME_FIELD.getPreferredName(), QUERY_NAME) .endObject() @@ -102,11 +106,12 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser); assertEquals(FIELD_NAME, 
sparseEncodingQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText()); assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); + assertEquals(MAX_TOKEN_SCORE, sparseEncodingQueryBuilder.maxTokenScore(), 0.0); assertEquals(BOOST, sparseEncodingQueryBuilder.boost(), 0.0); assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName()); } @@ -137,7 +142,7 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); - expectThrows(ParsingException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + expectThrows(ParsingException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); } @SneakyThrows @@ -158,7 +163,31 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() { XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); - expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithNegativeMaxTokenScore_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "max_token_score": -1 + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), -1f) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); } @SneakyThrows @@ -179,7 +208,7 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() { XContentParser contentParser = 
createParser(xContentBuilder); contentParser.nextToken(); - expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); } @SneakyThrows @@ -206,15 +235,16 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); - expectThrows(IOException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); } @SuppressWarnings("unchecked") @SneakyThrows public void testToXContent() { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) .modelId(MODEL_ID) - .queryText(QUERY_TEXT); + .queryText(QUERY_TEXT) + .maxTokenScore(MAX_TOKEN_SCORE); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -239,13 +269,15 @@ public void testToXContent() { assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); + assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0); } @SneakyThrows public void testStreams() { - SparseEncodingQueryBuilder original = new SparseEncodingQueryBuilder(); + NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); + original.maxTokenScore(MAX_TOKEN_SCORE); original.modelId(MODEL_ID); original.boost(BOOST); original.queryName(QUERY_NAME); @@ -260,7 +292,7 @@ public void testStreams() { ) ); - SparseEncodingQueryBuilder 
copy = new SparseEncodingQueryBuilder(filterStreamInput); + NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput); assertEquals(original, copy); } @@ -271,64 +303,82 @@ public void testHashAndEquals() { String queryText2 = "query text 2"; String modelId1 = "model-1"; String modelId2 = "model-2"; + float maxTokenScore1 = 1.1f; + float maxTokenScore2 = 2.2f; float boost1 = 1.8f; float boost2 = 3.8f; String queryName1 = "query-1"; String queryName2 = "query-2"; - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except default boost and query name - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName( - fieldName1 - ).queryText(queryText1).modelId(modelId1); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .maxTokenScore(maxTokenScore1); // Identical to sparseEncodingQueryBuilder_baseline except diff field name - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2) + NeuralSparseQueryBuilder 
sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff query text - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText2) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff model ID - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId2) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff boost - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder().fieldName(fieldName1) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost2) .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff query name - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName2); + // Identical to 
sparseEncodingQueryBuilder_baseline except diff max token score + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffMaxTokenScore = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .maxTokenScore(maxTokenScore2) + .boost(boost1) + .queryName(queryName1); + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); @@ -352,11 +402,14 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode()); } @SneakyThrows public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID); Map expectedMap = Map.of("1", 1f, "2", 2f); @@ -366,7 +419,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() listener.onResponse(List.of(Map.of("response", List.of(expectedMap)))); return null; }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any()); - SparseEncodingQueryBuilder.initialize(mlCommonsClientAccessor); + NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); @@ -382,7 +435,7 @@ public void 
testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() return null; }).when(queryRewriteContext).registerAsyncAction(any()); - SparseEncodingQueryBuilder queryBuilder = (SparseEncodingQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext); + NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext); assertNotNull(queryBuilder.queryTokensSupplier()); assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get()); @@ -390,7 +443,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() @SneakyThrows public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID) .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java similarity index 61% rename from src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java rename to src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java index 54991d7e2..672ab2940 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java @@ -21,16 +21,16 @@ import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; -public class SparseEncodingQueryIT extends BaseSparseEncodingIT { +public class NeuralSparseQueryIT extends BaseSparseEncodingIT { private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; - private static final String 
TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; - private static final String TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; + private static final String TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; + private static final String TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; private static final String TEST_NESTED_INDEX_NAME = "test-sparse-nested-index"; private static final String TEST_QUERY_TEXT = "Hello world a b"; - private static final String TEST_SPARSE_ENCODING_FIELD_NAME_1 = "test-sparse-encoding-1"; - private static final String TEST_SPARSE_ENCODING_FIELD_NAME_2 = "test-sparse-encoding-2"; + private static final String TEST_NEURAL_SPARSE_FIELD_NAME_1 = "test-sparse-encoding-1"; + private static final String TEST_NEURAL_SPARSE_FIELD_NAME_2 = "test-sparse-encoding-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field"; - private static final String TEST_SPARSE_ENCODING_FIELD_NAME_NESTED = "nested.sparse_encoding.field"; + private static final String TEST_NEURAL_SPARSE_FIELD_NAME_NESTED = "nested.neural_sparse.field"; private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c"); @@ -55,7 +55,7 @@ public void tearDown() { * Tests basic query: * { * "query": { - * "sparse_encoding": { + * "neural_sparse": { * "text_sparse": { * "query_text": "Hello world a b", * "model_id": "dcsdcasd" @@ -68,9 +68,9 @@ public void tearDown() { public void testBasicQueryUsingQueryText() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); String modelId = getDeployedModelId(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( - TEST_SPARSE_ENCODING_FIELD_NAME_1 - ).queryText(TEST_QUERY_TEXT).modelId(modelId); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + 
.queryText(TEST_QUERY_TEXT) + .modelId(modelId); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); @@ -83,7 +83,47 @@ public void testBasicQueryUsingQueryText() { * Tests basic query: * { * "query": { - * "sparse_encoding": { + * "neural_sparse": { + * "text_sparse": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd", + * "max_token_score": float + * } + * } + * } + * } + */ + @SneakyThrows + public void testBasicQueryWithMaxTokenScore() { + float maxTokenScore = 0.00001f; + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .maxTokenScore(maxTokenScore); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + Map queryTokens = runSparseModelInference(modelId, TEST_QUERY_TEXT); + float expectedScore = 0f; + for (Map.Entry entry : queryTokens.entrySet()) { + if (testRankFeaturesDoc.containsKey(entry.getKey())) { + expectedScore += entry.getValue() * Math.min( + getFeatureFieldCompressedNumber(testRankFeaturesDoc.get(entry.getKey())), + maxTokenScore + ); + } + } + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } + + /** + * Tests basic query: + * { + * "query": { + * "neural_sparse": { * "text_sparse": { * "query_text": "Hello world a b", * "model_id": "dcsdcasd", @@ -97,9 +137,10 @@ public void testBasicQueryUsingQueryText() { public void testBoostQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); String modelId = getDeployedModelId(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( - 
TEST_SPARSE_ENCODING_FIELD_NAME_1 - ).queryText(TEST_QUERY_TEXT).modelId(modelId).boost(2.0f); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); @@ -117,7 +158,7 @@ public void testBoostQuery() { * "rescore": { * "query": { * "rescore_query": { - * "sparse_encoding": { + * "neural_sparse": { * "text_sparse": { * * "query_text": "Hello world a b", * * "model_id": "dcsdcasd" @@ -133,9 +174,9 @@ public void testRescoreQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); String modelId = getDeployedModelId(); MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( - TEST_SPARSE_ENCODING_FIELD_NAME_1 - ).queryText(TEST_QUERY_TEXT).modelId(modelId); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, matchAllQueryBuilder, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); @@ -150,13 +191,13 @@ public void testRescoreQuery() { * "query": { * "bool" : { * "should": [ - * "sparse_encoding": { + * "neural_sparse": { * "field1": { * "query_text": "Hello world a b", * "model_id": "dcsdcasd" * } * }, - * "sparse_encoding": { + * "neural_sparse": { * "field2": { * "query_text": "Hello world a b", * "model_id": "dcsdcasd" @@ -169,20 +210,20 @@ public void testRescoreQuery() { */ @SneakyThrows public void testBooleanQuery_withMultipleSparseEncodingQueries() { - initializeIndexIfNotExist(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME); + 
initializeIndexIfNotExist(TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME); String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder1 = new SparseEncodingQueryBuilder().fieldName( - TEST_SPARSE_ENCODING_FIELD_NAME_1 - ).queryText(TEST_QUERY_TEXT).modelId(modelId); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder2 = new SparseEncodingQueryBuilder().fieldName( - TEST_SPARSE_ENCODING_FIELD_NAME_2 - ).queryText(TEST_QUERY_TEXT).modelId(modelId); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_2) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); - Map searchResponseAsMap = search(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1); + Map searchResponseAsMap = search(TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME, boolQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); @@ -196,13 +237,13 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { * "query": { * "bool" : { * "should": [ - * "sparse_encoding": { + * "neural_sparse": { * "field1": { * "query_text": "Hello world a b", * "model_id": "dcsdcasd" * } * }, - * "sparse_encoding": { + * "neural_sparse": { * "field2": { * "query_text": "Hello world a b", * "model_id": "dcsdcasd" @@ -215,17 +256,17 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { */ @SneakyThrows public void testBooleanQuery_withSparseEncodingAndBM25Queries() { - initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME); + 
initializeIndexIfNotExist(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME); String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( - TEST_SPARSE_ENCODING_FIELD_NAME_1 - ).queryText(TEST_QUERY_TEXT).modelId(modelId); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); boolQueryBuilder.should(sparseEncodingQueryBuilder).should(matchQueryBuilder); - Map searchResponseAsMap = search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1); + Map searchResponseAsMap = search(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME, boolQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); @@ -235,41 +276,41 @@ public void testBooleanQuery_withSparseEncodingAndBM25Queries() { @SneakyThrows public void testBasicQueryUsingQueryText_whenQueryWrongFieldType_thenFail() { - initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME); + initializeIndexIfNotExist(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME); String modelId = getDeployedModelId(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(TEST_TEXT_FIELD_NAME_1) + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_TEXT_FIELD_NAME_1) .queryText(TEST_QUERY_TEXT) .modelId(modelId); - expectThrows(ResponseException.class, () -> search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, sparseEncodingQueryBuilder, 1)); + expectThrows(ResponseException.class, () -> search(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME, sparseEncodingQueryBuilder, 1)); } @SneakyThrows protected void 
initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { - prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1)); - addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), List.of(testRankFeaturesDoc)); + prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1), List.of(testRankFeaturesDoc)); assertEquals(1, getDocCount(indexName)); } - if (TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { - prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2)); + if (TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1, TEST_NEURAL_SPARSE_FIELD_NAME_2)); addSparseEncodingDoc( indexName, "1", - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2), + List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1, TEST_NEURAL_SPARSE_FIELD_NAME_2), List.of(testRankFeaturesDoc, testRankFeaturesDoc) ); assertEquals(1, getDocCount(indexName)); } - if (TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { - prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1)); + if (TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1)); addSparseEncodingDoc( indexName, "1", - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), + List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1), List.of(testRankFeaturesDoc), List.of(TEST_TEXT_FIELD_NAME_1), List.of(TEST_QUERY_TEXT) @@ -278,8 +319,8 @@ protected void initializeIndexIfNotExist(String indexName) { } if (TEST_NESTED_INDEX_NAME.equals(indexName) && 
!indexExists(indexName)) { - prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED)); - addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), List.of(testRankFeaturesDoc)); + prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_NESTED)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_NEURAL_SPARSE_FIELD_NAME_NESTED), List.of(testRankFeaturesDoc)); assertEquals(1, getDocCount(TEST_NESTED_INDEX_NAME)); } } diff --git a/src/test/resources/processor/SparseEncodingIndexMappings.json b/src/test/resources/processor/SparseEncodingIndexMappings.json index 87dee278e..9748e8f3d 100644 --- a/src/test/resources/processor/SparseEncodingIndexMappings.json +++ b/src/test/resources/processor/SparseEncodingIndexMappings.json @@ -23,4 +23,4 @@ } } } -} \ No newline at end of file +} diff --git a/src/test/resources/processor/SparseEncodingPipelineConfiguration.json b/src/test/resources/processor/SparseEncodingPipelineConfiguration.json index 82d13c8fe..04a4baf80 100644 --- a/src/test/resources/processor/SparseEncodingPipelineConfiguration.json +++ b/src/test/resources/processor/SparseEncodingPipelineConfiguration.json @@ -15,4 +15,4 @@ } } ] -} \ No newline at end of file +} diff --git a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json index c45334bae..50b4b8a9b 100644 --- a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json +++ b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json @@ -7,4 +7,4 @@ "model_group_id": "", "model_content_hash_value": "b345e9e943b62c405a8dd227ef2c46c84c5ff0a0b71b584be9132b37bce91a9a", "url": "https://github.com/opensearch-project/ml-commons/raw/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/sparse_encoding/sparse_demo.zip" -} \ No newline at end of file +}