Add max token score for SparseEncodingQueryBuilder and do renaming (#348)

* add lucene FeatureQuery

Signed-off-by: zhichao-aws <[email protected]>

* add max token score

Signed-off-by: zhichao-aws <[email protected]>

* add comments

Signed-off-by: zhichao-aws <[email protected]>

* add check and test

Signed-off-by: zhichao-aws <[email protected]>

* add doc

Signed-off-by: zhichao-aws <[email protected]>

* add change log

Signed-off-by: zhichao-aws <[email protected]>

* Address PR comments and change processor and query clause name. (#9)

* Address code review comments

Signed-off-by: zane-neo <[email protected]>

* Change lower case neural_sparse to upper case

Signed-off-by: zane-neo <[email protected]>

* Change back processor type name to sparse_encoding

Signed-off-by: zane-neo <[email protected]>

* Change names

Signed-off-by: zane-neo <[email protected]>

* Format code

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
Signed-off-by: zane-neo <[email protected]>
Co-authored-by: zane-neo <[email protected]>
zhichao-aws and zane-neo authored Oct 1, 2023
1 parent 9e12de8 commit 7ae48df
Showing 13 changed files with 455 additions and 105 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features
 Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
 ### Enhancements
+Add `max_token_score` parameter to improve the execution efficiency for `neural_sparse` query clause ([#348](https://github.com/opensearch-project/neural-search/pull/348))
 ### Bug Fixes
 ### Infrastructure
 ### Documentation
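For illustration, here is a minimal sketch of how a caller might set the new `max_token_score` parameter on the renamed query builder. The chained setters and the values below are assumptions inferred from the class and parameter names in this commit, not an API shown in this excerpt:

import org.opensearch.index.query.QueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;

public class NeuralSparseQueryExample {
    // Sketch only: setter names are assumed from the fields named in this commit.
    public static QueryBuilder sparseQuery() {
        return new NeuralSparseQueryBuilder()
            .fieldName("passage_embedding")         // rank_features field written by the sparse_encoding processor
            .queryText("what is the capital city")  // text the sparse model expands into weighted tokens
            .modelId("my-sparse-model-id")          // hypothetical model id
            .maxTokenScore(3.5f);                   // per-token upper bound enforced by BoundedLinearFeatureQuery
    }
}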
235 changes: 235 additions & 0 deletions src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java
@@ -0,0 +1,235 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

/*
 * This class is built on top of the Lucene FeatureQuery. We use LinearFunction to
 * build the query and add an upper bound to it.
*/

package org.apache.lucene;

import java.io.IOException;
import java.util.Objects;

import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ImpactsDISI;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.BytesRef;

/**
 * The feature queries for the input tokens are wrapped in a Lucene BooleanQuery, which uses the
 * WAND algorithm to accelerate execution. The WAND algorithm leverages the score upper bound of
 * its sub-queries to skip non-competitive tokens. However, the original Lucene FeatureQuery uses
 * Float.MAX_VALUE as the score upper bound, which invalidates WAND.
 *
 * To mitigate this issue, we rewrite FeatureQuery as BoundedLinearFeatureQuery, whose token score
 * upper bound can be set by the caller. To match our use case, we use LinearFunction as the
 * score function.
 *
 * This class combines <a href="https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java">FeatureQuery</a>
 * and <a href="https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/document/FeatureField.java">FeatureField</a>
 * and will be deprecated after OpenSearch upgrades Lucene to version 9.8.
 */

public final class BoundedLinearFeatureQuery extends Query {

    private final String fieldName;
    private final String featureName;
    private final Float scoreUpperBound;

    public BoundedLinearFeatureQuery(String fieldName, String featureName, Float scoreUpperBound) {
        this.fieldName = Objects.requireNonNull(fieldName);
        this.featureName = Objects.requireNonNull(featureName);
        this.scoreUpperBound = Objects.requireNonNull(scoreUpperBound);
    }

    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        // LinearFunction returns the same object on rewrite
        return super.rewrite(indexSearcher);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        BoundedLinearFeatureQuery that = (BoundedLinearFeatureQuery) obj;
        return Objects.equals(fieldName, that.fieldName)
            && Objects.equals(featureName, that.featureName)
            && Objects.equals(scoreUpperBound, that.scoreUpperBound);
    }

    @Override
    public int hashCode() {
        int h = getClass().hashCode();
        h = 31 * h + fieldName.hashCode();
        h = 31 * h + featureName.hashCode();
        h = 31 * h + scoreUpperBound.hashCode();
        return h;
    }

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        if (!scoreMode.needsScores()) {
            // We don't need scores (e.g. for faceting), and since features are stored as terms,
            // allow TermQuery to optimize in this case
            TermQuery tq = new TermQuery(new Term(fieldName, featureName));
            return searcher.rewrite(tq).createWeight(searcher, scoreMode, boost);
        }

        return new Weight(this) {

            @Override
            public boolean isCacheable(LeafReaderContext ctx) {
                return false;
            }

            @Override
            public Explanation explain(LeafReaderContext context, int doc) throws IOException {
                String desc = "weight(" + getQuery() + " in " + doc + ") [\" BoundedLinearFeatureQuery \"]";

                Terms terms = context.reader().terms(fieldName);
                if (terms == null) {
                    return Explanation.noMatch(desc + ". Field " + fieldName + " doesn't exist.");
                }
                TermsEnum termsEnum = terms.iterator();
                if (termsEnum.seekExact(new BytesRef(featureName)) == false) {
                    return Explanation.noMatch(desc + ". Feature " + featureName + " doesn't exist.");
                }

                PostingsEnum postings = termsEnum.postings(null, PostingsEnum.FREQS);
                if (postings.advance(doc) != doc) {
                    return Explanation.noMatch(desc + ". Feature " + featureName + " isn't set.");
                }

                int freq = postings.freq();
                float featureValue = decodeFeatureValue(freq);
                float score = boost * featureValue;
                return Explanation.match(
                    score,
                    "Linear function on the " + fieldName + " field for the " + featureName + " feature, computed as w * S from:",
                    Explanation.match(boost, "w, weight of this function"),
                    Explanation.match(featureValue, "S, feature value")
                );
            }

            @Override
            public Scorer scorer(LeafReaderContext context) throws IOException {
                Terms terms = Terms.getTerms(context.reader(), fieldName);
                TermsEnum termsEnum = terms.iterator();
                if (termsEnum.seekExact(new BytesRef(featureName)) == false) {
                    return null;
                }

                final SimScorer scorer = new SimScorer() {
                    @Override
                    public float score(float freq, long norm) {
                        return boost * decodeFeatureValue(freq);
                    }
                };
                // ImpactsDISI reads per-block impacts (maximum term frequencies), so once a
                // minimum competitive score is set it can skip blocks whose bounded score
                // cannot compete. This is where the finite score upper bound pays off.
                final ImpactsEnum impacts = termsEnum.impacts(PostingsEnum.FREQS);
                final ImpactsDISI impactsDisi = new ImpactsDISI(impacts, impacts, scorer);

                return new Scorer(this) {

                    @Override
                    public int docID() {
                        return impacts.docID();
                    }

                    @Override
                    public float score() throws IOException {
                        return scorer.score(impacts.freq(), 1L);
                    }

                    @Override
                    public DocIdSetIterator iterator() {
                        return impactsDisi;
                    }

                    @Override
                    public int advanceShallow(int target) throws IOException {
                        return impactsDisi.advanceShallow(target);
                    }

                    @Override
                    public float getMaxScore(int upTo) throws IOException {
                        return impactsDisi.getMaxScore(upTo);
                    }

                    @Override
                    public void setMinCompetitiveScore(float minScore) {
                        impactsDisi.setMinCompetitiveScore(minScore);
                    }
                };
            }
        };
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(fieldName)) {
            visitor.visitLeaf(this);
        }
    }

    @Override
    public String toString(String field) {
        return "BoundedLinearFeatureQuery(field=" + fieldName + ", feature=" + featureName + ", scoreUpperBound=" + scoreUpperBound + ")";
    }

    // This field and decodeFeatureValue are modified from FeatureField.decodeFeatureValue.
    static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15;

    // Rewritten to make scoreUpperBound work. FeatureField encodes a feature's float
    // value into the term frequency as (floatBits >>> 15), so shifting back by 15
    // recovers an approximation of the original value, which is clamped at scoreUpperBound.
    private float decodeFeatureValue(float freq) {
        if (freq > MAX_FREQ) {
            return scoreUpperBound;
        }
        int tf = (int) freq; // lossless
        int featureBits = tf << 15;
        return Math.min(Float.intBitsToFloat(featureBits), scoreUpperBound);
    }
}
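To make the WAND interaction described in the class javadoc concrete, the sketch below combines one BoundedLinearFeatureQuery per query token into a BooleanQuery, which is the shape of query this class is designed to serve. The method and token-map names are illustrative assumptions; the actual query builder code is not part of this excerpt.

import java.util.Map;

import org.apache.lucene.BoundedLinearFeatureQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Query;

public class SparseTokenUnionExample {
    // Sketch: each token becomes one SHOULD clause whose score is bounded by
    // maxTokenScore, so BooleanQuery's block-max WAND can skip documents that
    // cannot reach the minimum competitive score.
    public static Query buildTokenUnion(String fieldName, Map<String, Float> queryTokens, float maxTokenScore) {
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        for (Map.Entry<String, Float> token : queryTokens.entrySet()) {
            Query featureQuery = new BoundedLinearFeatureQuery(fieldName, token.getKey(), maxTokenScore);
            // The token weight produced by the sparse model is applied as a boost.
            builder.add(new BoostQuery(featureQuery, token.getValue()), BooleanClause.Occur.SHOULD);
        }
        return builder.build();
    }
}

With this construction, each clause's maximum score is roughly boost * maxTokenScore instead of Float.MAX_VALUE, which is what restores WAND's ability to skip non-competitive documents.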
@@ -41,7 +41,7 @@
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
-import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
+import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
 import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.plugins.ActionPlugin;
 import org.opensearch.plugins.ExtensiblePlugin;
@@ -81,7 +81,7 @@ public Collection<Object> createComponents(
         final Supplier<RepositoriesService> repositoriesServiceSupplier
     ) {
         NeuralQueryBuilder.initialize(clientAccessor);
-        SparseEncodingQueryBuilder.initialize(clientAccessor);
+        NeuralSparseQueryBuilder.initialize(clientAccessor);
         normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
         return List.of(clientAccessor);
     }
@@ -91,7 +91,7 @@ public List<QuerySpec<?>> getQueries() {
         return Arrays.asList(
             new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
             new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
-            new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
+            new QuerySpec<>(NeuralSparseQueryBuilder.NAME, NeuralSparseQueryBuilder::new, NeuralSparseQueryBuilder::fromXContent)
         );
     }
@@ -32,7 +32,7 @@
  * and set the target fields according to the field name map.
  */
 @Log4j2
-public abstract class NLPProcessor extends AbstractProcessor {
+public abstract class InferenceProcessor extends AbstractProcessor {
 
     public static final String MODEL_ID_FIELD = "model_id";
     public static final String FIELD_MAP_FIELD = "field_map";
@@ -51,7 +51,7 @@ public abstract class NLPProcessor extends AbstractProcessor {
 
     private final Environment environment;
 
-    public NLPProcessor(
+    public InferenceProcessor(
         String tag,
         String description,
         String type,
@@ -249,7 +249,7 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
     @SuppressWarnings({ "unchecked" })
     @VisibleForTesting
     Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
-        NLPProcessor.IndexWrapper indexWrapper = new NLPProcessor.IndexWrapper(0);
+        InferenceProcessor.IndexWrapper indexWrapper = new InferenceProcessor.IndexWrapper(0);
         Map<String, Object> result = new LinkedHashMap<>();
         for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
             String knnKey = knnMapEntry.getKey();
@@ -270,7 +270,7 @@ private void putNLPResultToSourceMapForMapType(
         String processorKey,
         Object sourceValue,
         List<?> results,
-        NLPProcessor.IndexWrapper indexWrapper,
+        InferenceProcessor.IndexWrapper indexWrapper,
         Map<String, Object> sourceAndMetadataMap
     ) {
         if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
@@ -294,7 +294,7 @@ private void putNLPResultToSourceMapForMapType(
     private List<Map<String, Object>> buildNLPResultForListType(
         List<String> sourceValue,
         List<?> results,
-        NLPProcessor.IndexWrapper indexWrapper
+        InferenceProcessor.IndexWrapper indexWrapper
     ) {
         List<Map<String, Object>> keyToResult = new ArrayList<>();
         IntStream.range(0, sourceValue.size())
@@ -22,7 +22,7 @@
  * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results.
  */
 @Log4j2
-public final class SparseEncodingProcessor extends NLPProcessor {
+public final class SparseEncodingProcessor extends InferenceProcessor {
 
     public static final String TYPE = "sparse_encoding";
     public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
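The javadoc context above mentions the processor's `field_map`; the following sketch shows a configuration map consistent with that description and with the `model_id`/`field_map` keys defined in InferenceProcessor. The field names and model id are hypothetical placeholders, not values taken from this commit.

import java.util.Map;

public class SparseEncodingConfigExample {
    // Sketch of a sparse_encoding processor configuration: "model_id" selects the
    // sparse encoding model, and "field_map" maps each source text field to the
    // destination field that will hold its token-weight map.
    public static Map<String, Object> sparseEncodingProcessorConfig() {
        return Map.of(
            "model_id", "my-sparse-model-id", // hypothetical model id
            "field_map", Map.of("passage_text", "passage_embedding")
        );
    }
}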
@@ -21,7 +21,7 @@
  * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
  */
 @Log4j2
-public final class TextEmbeddingProcessor extends NLPProcessor {
+public final class TextEmbeddingProcessor extends InferenceProcessor {
 
     public static final String TYPE = "text_embedding";
     public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
