diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/boost/DelimitedBoostTokenFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/boost/DelimitedBoostTokenFilter.java index 9e693ca8710b..5946247c0462 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/boost/DelimitedBoostTokenFilter.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/boost/DelimitedBoostTokenFilter.java @@ -19,8 +19,8 @@ import java.io.IOException; import org.apache.lucene.analysis.TokenFilter; import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.BoostAttribute; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.util.IgnoreRandomChains; /** diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java index a8db4c4c764a..1219c058fcbe 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java @@ -23,6 +23,7 @@ import org.apache.lucene.analysis.TokenFilter; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.synonym.SynonymGraphFilter; +import org.apache.lucene.analysis.tokenattributes.BoostAttribute; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute; @@ -41,11 +42,13 @@ public final class Word2VecSynonymFilter extends TokenFilter { private final PositionIncrementAttribute posIncrementAtt = addAttribute(PositionIncrementAttribute.class); private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class); + private final BoostAttribute boostAtt = addAttribute(BoostAttribute.class); private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class); private final Word2VecSynonymProvider synonymProvider; private final int maxSynonymsPerTerm; private final float minAcceptedSimilarity; + private final boolean similarityAsBoost; private final LinkedList synonymBuffer = new LinkedList<>(); private State lastState; @@ -62,7 +65,8 @@ public Word2VecSynonymFilter( TokenStream input, Word2VecSynonymProvider synonymProvider, int maxSynonymsPerTerm, - float minAcceptedSimilarity) { + float minAcceptedSimilarity, + boolean similarityAsBoost) { super(input); if (synonymProvider == null) { throw new IllegalArgumentException("The SynonymProvider must be non-null"); @@ -70,6 +74,7 @@ public Word2VecSynonymFilter( this.synonymProvider = synonymProvider; this.maxSynonymsPerTerm = maxSynonymsPerTerm; this.minAcceptedSimilarity = minAcceptedSimilarity; + this.similarityAsBoost = similarityAsBoost; } @Override @@ -81,6 +86,7 @@ public boolean incrementToken() throws IOException { restoreState(this.lastState); termAtt.setEmpty(); termAtt.append(synonym.term.utf8ToString()); + boostAtt.setBoost(this.similarityAsBoost ? synonym.boost : 1); typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM); posLenAtt.setPositionLength(1); posIncrementAtt.setPositionIncrement(0); diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java index 32b6288926fc..5d1810e55e82 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java @@ -45,6 +45,7 @@ public class Word2VecSynonymFilterFactory extends TokenFilterFactory private final float minAcceptedSimilarity; private final Word2VecSupportedFormats format; private final String word2vecModelFileName; + private final boolean similarityAsBoost; private Word2VecSynonymProvider synonymProvider; @@ -61,6 +62,7 @@ public Word2VecSynonymFilterFactory(Map args) { } catch (IllegalArgumentException exc) { throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc); } + this.similarityAsBoost = getBoolean(args, "similarityAsBoost", true); if (!args.isEmpty()) { throw new IllegalArgumentException("Unknown parameters: " + args); @@ -90,7 +92,7 @@ public TokenStream create(TokenStream input) { return synonymProvider == null ? input : new Word2VecSynonymFilter( - input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); + input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity, similarityAsBoost); } @Override diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/boost/TestDelimitedBoostTokenFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/boost/TestDelimitedBoostTokenFilter.java index 3a4774c465a1..8a14b5b84ad7 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/boost/TestDelimitedBoostTokenFilter.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/boost/TestDelimitedBoostTokenFilter.java @@ -17,8 +17,8 @@ package org.apache.lucene.analysis.boost; import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.BoostAttribute; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; public class TestDelimitedBoostTokenFilter extends BaseTokenStreamTestCase { diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java index 3999931dd758..9203710b0b4f 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java @@ -19,6 +19,7 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Tokenizer; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; import org.apache.lucene.tests.analysis.MockTokenizer; import org.apache.lucene.util.BytesRef; @@ -41,7 +42,17 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() thro Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); + float similarityAWithB = // 0.9969 + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(1)); + float similarityAWithC = // 0.9993 + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(2)); + float similarityAWithD = // 1.0 + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(3)); + float similarityAWithE = // 0.9999 + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(4)); + // float similarityAWithF = 0.8166 (not accepted) + + Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity, true); assertAnalyzesTo( a, "pre a post", // input @@ -50,7 +61,38 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() thro new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements - new int[] {1, 1, 1, 1, 1, 1, 1}); // posLenghts + new int[] {1, 1, 1, 1, 1, 1, 1}, // posLenghts + new float[] { + 1, 1, similarityAWithD, similarityAWithE, similarityAWithC, similarityAWithB, 1 + }); // boost + a.close(); + } + + @Test + public void synonymExpansion_oneCandidate_shouldBeExpandedWithNoBoost() throws Exception { + int maxSynonymPerTerm = 10; + float minAcceptedSimilarity = 0.9f; + Word2VecModel model = new Word2VecModel(6, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity, false); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output + new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset + new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset + new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types + new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements + new int[] {1, 1, 1, 1, 1, 1, 1}, // posLenghts + new float[] {1, 1, 1, 1, 1, 1, 1}); // boost a.close(); } @@ -67,7 +109,14 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() thr Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); + // float similarityAWithB = 0.9969 (not in top 2) + // float similarityAWithC = 0.9993 (not in top 2) + float similarityAWithD = // 1.0 + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(3)); + float similarityAWithE = // 0.9999 + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(4)); + + Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity, true); assertAnalyzesTo( a, "pre a post", // input @@ -76,7 +125,8 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() thr new int[] {3, 5, 5, 5, 10}, // end offset new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types new int[] {1, 1, 0, 0, 1}, // posIncrements - new int[] {1, 1, 1, 1, 1}); // posLenghts + new int[] {1, 1, 1, 1, 1}, // posLenghts + new float[] {1, 1, similarityAWithD, similarityAWithE, 1}); // boost a.close(); } @@ -94,7 +144,18 @@ public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Excepti Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f); + float similarityAWithB = + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(1)); + float similarityAWithC = + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(2)); + float similarityAWithD = + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(3)); + float similarityAWithE = + VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(4)); + float similarityPostWithAfter = + VectorSimilarityFunction.COSINE.compare(model.vectorValue(6), model.vectorValue(7)); + + Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f, true); assertAnalyzesTo( a, "pre a post", // input @@ -105,7 +166,17 @@ public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Excepti "word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM" }, new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements - new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths + new int[] {1, 1, 1, 1, 1, 1, 1, 1}, // posLengths + new float[] { + 1, + 1, + similarityAWithD, + similarityAWithE, + similarityAWithC, + similarityAWithB, + 1, + similarityPostWithAfter + }); // boost a.close(); } @@ -120,7 +191,7 @@ public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f); + Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f, true); assertAnalyzesTo( a, "pre a post", // input @@ -129,14 +200,16 @@ public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms new int[] {3, 5, 10}, // end offset new String[] {"word", "word", "word"}, // types new int[] {1, 1, 1}, // posIncrements - new int[] {1, 1, 1}); // posLengths + new int[] {1, 1, 1}, // posLengths + new float[] {1, 1, 1}); // boost a.close(); } private Analyzer getAnalyzer( Word2VecSynonymProvider synonymProvider, int maxSynonymsPerTerm, - float minAcceptedSimilarity) { + float minAcceptedSimilarity, + boolean similarityAsBoost) { return new Analyzer() { @Override protected TokenStreamComponents createComponents(String fieldName) { @@ -144,7 +217,11 @@ protected TokenStreamComponents createComponents(String fieldName) { // Make a local variable so testRandomHuge doesn't share it across threads! Word2VecSynonymFilter synFilter = new Word2VecSynonymFilter( - tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); + tokenizer, + synonymProvider, + maxSynonymsPerTerm, + minAcceptedSimilarity, + similarityAsBoost); return new TokenStreamComponents(tokenizer, synFilter); } }; diff --git a/lucene/core/src/java/org/apache/lucene/analysis/tokenattributes/BoostAttribute.java b/lucene/core/src/java/org/apache/lucene/analysis/tokenattributes/BoostAttribute.java new file mode 100644 index 000000000000..4c7df36d75fd --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/analysis/tokenattributes/BoostAttribute.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.analysis.tokenattributes; + +import org.apache.lucene.util.Attribute; + +/** + * Add this {@link BoostAttribute} if you want to manipulate the token stream in order to update the + * boost associated to a token + * + *

Please note: This attribute does not work at index time + * + * @lucene.internal + */ +public interface BoostAttribute extends Attribute { + /** Sets the boost in this attribute */ + public void setBoost(float boost); + /** Retrieves the boost, default is {@code 1.0f}. */ + public float getBoost(); +} diff --git a/lucene/core/src/java/org/apache/lucene/analysis/tokenattributes/BoostAttributeImpl.java b/lucene/core/src/java/org/apache/lucene/analysis/tokenattributes/BoostAttributeImpl.java new file mode 100644 index 000000000000..7934b9cb9ab3 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/analysis/tokenattributes/BoostAttributeImpl.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.analysis.tokenattributes; + +import java.util.Objects; +import org.apache.lucene.util.AttributeImpl; +import org.apache.lucene.util.AttributeReflector; + +/** + * Implementation class for {@link BoostAttribute}. + * + * @lucene.internal + */ +public final class BoostAttributeImpl extends AttributeImpl implements BoostAttribute { + private float boost = 1.0f; + + /** Initialize this attribute with no boost. */ + public BoostAttributeImpl() {} + + @Override + public void setBoost(float boost) { + if (Float.isFinite(boost) == false || Float.compare(boost, 0f) < 0) { + throw new IllegalArgumentException("boost must be a positive float, got " + boost); + } + this.boost = boost; + } + + @Override + public float getBoost() { + return boost; + } + + @Override + public void clear() { + boost = 1.0f; + } + + @Override + public BoostAttributeImpl clone() { + BoostAttributeImpl clone = (BoostAttributeImpl) super.clone(); + clone.boost = this.boost; + return clone; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + BoostAttributeImpl that = (BoostAttributeImpl) other; + return Float.compare(that.boost, boost) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(boost); + } + + @Override + public void copyTo(AttributeImpl target) { + ((BoostAttribute) target).setBoost(boost); + } + + @Override + public void reflectWith(AttributeReflector reflector) { + reflector.reflect(BoostAttribute.class, "boost", boost); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/FuzzyTermsEnum.java b/lucene/core/src/java/org/apache/lucene/search/FuzzyTermsEnum.java index cfd6ed232de5..e82e0597ef21 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FuzzyTermsEnum.java +++ b/lucene/core/src/java/org/apache/lucene/search/FuzzyTermsEnum.java @@ -49,7 +49,7 @@ public final class FuzzyTermsEnum extends TermsEnum { // We use this to communicate the score (boost) of the current matched term we are on back to // MultiTermQuery.TopTermsBlendedFreqScoringRewrite that is collecting the best (default 50) // matched terms: - private final BoostAttribute boostAtt; + private final MultiTermQueryBoostAttribute boostAtt; // MultiTermQuery.TopTermsBlendedFreqScoringRewrite tells us the worst boost still in its queue // using this att, @@ -142,7 +142,7 @@ private FuzzyTermsEnum( this.term = term; this.maxBoostAtt = atts.addAttribute(MaxNonCompetitiveBoostAttribute.class); - this.boostAtt = atts.addAttribute(BoostAttribute.class); + this.boostAtt = atts.addAttribute(MultiTermQueryBoostAttribute.class); atts.addAttributeImpl(new AutomatonAttributeImpl()); AutomatonAttribute aa = atts.addAttribute(AutomatonAttribute.class); diff --git a/lucene/core/src/java/org/apache/lucene/search/BoostAttribute.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryBoostAttribute.java similarity index 96% rename from lucene/core/src/java/org/apache/lucene/search/BoostAttribute.java rename to lucene/core/src/java/org/apache/lucene/search/MultiTermQueryBoostAttribute.java index 0a1570e2edf1..148cc7b6949a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BoostAttribute.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryBoostAttribute.java @@ -33,7 +33,7 @@ * * @lucene.internal */ -public interface BoostAttribute extends Attribute { +public interface MultiTermQueryBoostAttribute extends Attribute { float DEFAULT_BOOST = 1.0f; /** Sets the boost in this attribute */ public void setBoost(float boost); diff --git a/lucene/core/src/java/org/apache/lucene/search/BoostAttributeImpl.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryBoostAttributeImpl.java similarity index 80% rename from lucene/core/src/java/org/apache/lucene/search/BoostAttributeImpl.java rename to lucene/core/src/java/org/apache/lucene/search/MultiTermQueryBoostAttributeImpl.java index f7f218e33603..790131996237 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BoostAttributeImpl.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryBoostAttributeImpl.java @@ -20,11 +20,12 @@ import org.apache.lucene.util.AttributeReflector; /** - * Implementation class for {@link BoostAttribute}. + * Implementation class for {@link MultiTermQueryBoostAttribute}. * * @lucene.internal */ -public final class BoostAttributeImpl extends AttributeImpl implements BoostAttribute { +public final class MultiTermQueryBoostAttributeImpl extends AttributeImpl + implements MultiTermQueryBoostAttribute { private float boost = 1.0f; @Override @@ -44,11 +45,11 @@ public void clear() { @Override public void copyTo(AttributeImpl target) { - ((BoostAttribute) target).setBoost(boost); + ((MultiTermQueryBoostAttribute) target).setBoost(boost); } @Override public void reflectWith(AttributeReflector reflector) { - reflector.reflect(BoostAttribute.class, "boost", boost); + reflector.reflect(MultiTermQueryBoostAttribute.class, "boost", boost); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java b/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java index 9adb0b1f94b2..e6c72d971704 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java +++ b/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java @@ -129,12 +129,12 @@ final class ParallelArraysTermCollector extends TermCollector { new BytesRefHash(new ByteBlockPool(new ByteBlockPool.DirectAllocator()), 16, array); TermsEnum termsEnum; - private BoostAttribute boostAtt; + private MultiTermQueryBoostAttribute boostAtt; @Override public void setNextEnum(TermsEnum termsEnum) { this.termsEnum = termsEnum; - this.boostAtt = termsEnum.attributes().addAttribute(BoostAttribute.class); + this.boostAtt = termsEnum.attributes().addAttribute(MultiTermQueryBoostAttribute.class); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java b/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java index fe95a2fb3e31..e6a91d15149a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java @@ -74,7 +74,7 @@ public final Query rewrite(IndexSearcher indexSearcher, final MultiTermQuery que private final Map visitedTerms = new HashMap<>(); private TermsEnum termsEnum; - private BoostAttribute boostAtt; + private MultiTermQueryBoostAttribute boostAtt; private ScoreTerm st; @Override @@ -85,7 +85,7 @@ public void setNextEnum(TermsEnum termsEnum) { // lazy init the initial ScoreTerm because comparator is not known on ctor: if (st == null) st = new ScoreTerm(new TermStates(topReaderContext)); - boostAtt = termsEnum.attributes().addAttribute(BoostAttribute.class); + boostAtt = termsEnum.attributes().addAttribute(MultiTermQueryBoostAttribute.class); } // for assert: diff --git a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java index 3ad17dca5ce9..63d7d80e9bba 100644 --- a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java @@ -16,7 +16,7 @@ */ package org.apache.lucene.util; -import static org.apache.lucene.search.BoostAttribute.DEFAULT_BOOST; +import static org.apache.lucene.search.MultiTermQueryBoostAttribute.DEFAULT_BOOST; import java.io.IOException; import java.util.ArrayList; @@ -31,9 +31,9 @@ import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.MultiPhraseQuery; +import org.apache.lucene.search.MultiTermQueryBoostAttribute; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SynonymQuery; @@ -372,7 +372,7 @@ protected Query createFieldQuery( /** Creates simple term query from the cached tokenstream contents */ protected Query analyzeTerm(String field, TokenStream stream) throws IOException { TermToBytesRefAttribute termAtt = stream.getAttribute(TermToBytesRefAttribute.class); - BoostAttribute boostAtt = stream.addAttribute(BoostAttribute.class); + MultiTermQueryBoostAttribute boostAtt = stream.addAttribute(MultiTermQueryBoostAttribute.class); stream.reset(); if (!stream.incrementToken()) { @@ -385,7 +385,7 @@ protected Query analyzeTerm(String field, TokenStream stream) throws IOException /** Creates simple boolean query from the cached tokenstream contents */ protected Query analyzeBoolean(String field, TokenStream stream) throws IOException { TermToBytesRefAttribute termAtt = stream.getAttribute(TermToBytesRefAttribute.class); - BoostAttribute boostAtt = stream.addAttribute(BoostAttribute.class); + MultiTermQueryBoostAttribute boostAtt = stream.addAttribute(MultiTermQueryBoostAttribute.class); stream.reset(); List terms = new ArrayList<>(); @@ -419,7 +419,7 @@ protected Query analyzeMultiBoolean( TermToBytesRefAttribute termAtt = stream.getAttribute(TermToBytesRefAttribute.class); PositionIncrementAttribute posIncrAtt = stream.getAttribute(PositionIncrementAttribute.class); - BoostAttribute boostAtt = stream.addAttribute(BoostAttribute.class); + MultiTermQueryBoostAttribute boostAtt = stream.addAttribute(MultiTermQueryBoostAttribute.class); stream.reset(); while (stream.incrementToken()) { @@ -440,7 +440,7 @@ protected Query analyzePhrase(String field, TokenStream stream, int slop) throws builder.setSlop(slop); TermToBytesRefAttribute termAtt = stream.getAttribute(TermToBytesRefAttribute.class); - BoostAttribute boostAtt = stream.addAttribute(BoostAttribute.class); + MultiTermQueryBoostAttribute boostAtt = stream.addAttribute(MultiTermQueryBoostAttribute.class); PositionIncrementAttribute posIncrAtt = stream.getAttribute(PositionIncrementAttribute.class); int position = -1; float phraseBoost = DEFAULT_BOOST; @@ -545,7 +545,8 @@ public Query next() { .map( s -> { TermToBytesRefAttribute t = s.addAttribute(TermToBytesRefAttribute.class); - BoostAttribute b = s.addAttribute(BoostAttribute.class); + MultiTermQueryBoostAttribute b = + s.addAttribute(MultiTermQueryBoostAttribute.class); return new TermAndBoost(t.getBytesRef(), b.getBoost()); }) .toArray(TermAndBoost[]::new); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMultiTermQueryRewrites.java b/lucene/core/src/test/org/apache/lucene/search/TestMultiTermQueryRewrites.java index 451e7648765b..47d0b483ead4 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestMultiTermQueryRewrites.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestMultiTermQueryRewrites.java @@ -167,7 +167,8 @@ private void checkBoosts(MultiTermQuery.RewriteMethod method) throws Exception { protected TermsEnum getTermsEnum(Terms terms, AttributeSource atts) throws IOException { return new FilteredTermsEnum(terms.iterator()) { - final BoostAttribute boostAtt = attributes().addAttribute(BoostAttribute.class); + final MultiTermQueryBoostAttribute boostAtt = + attributes().addAttribute(MultiTermQueryBoostAttribute.class); @Override protected AcceptStatus accept(BytesRef term) { diff --git a/lucene/core/src/test/org/apache/lucene/util/TestQueryBuilder.java b/lucene/core/src/test/org/apache/lucene/util/TestQueryBuilder.java index 2ea575d5274f..cc0346479804 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestQueryBuilder.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestQueryBuilder.java @@ -27,10 +27,10 @@ import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MultiPhraseQuery; +import org.apache.lucene.search.MultiTermQueryBoostAttribute; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SynonymQuery; @@ -550,7 +550,7 @@ public void testMaxBooleanClause() throws Exception { private static final class MockBoostTokenFilter extends TokenFilter { - final BoostAttribute boostAtt = addAttribute(BoostAttribute.class); + final MultiTermQueryBoostAttribute boostAtt = addAttribute(MultiTermQueryBoostAttribute.class); final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class); protected MockBoostTokenFilter(TokenStream input) { diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java index f834e66d33ef..71212468595b 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java @@ -34,11 +34,11 @@ import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.FuzzyTermsEnum; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MultiTermQueryBoostAttribute; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; @@ -217,7 +217,8 @@ private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws int numVariants = 0; int totalVariantDocFreqs = 0; BytesRef possibleMatch; - BoostAttribute boostAtt = fe.attributes().addAttribute(BoostAttribute.class); + MultiTermQueryBoostAttribute boostAtt = + fe.attributes().addAttribute(MultiTermQueryBoostAttribute.class); while ((possibleMatch = fe.next()) != null) { numVariants++; totalVariantDocFreqs += fe.docFreq(); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java index f9cd607ccacc..4cb5e5a3a34f 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java @@ -40,6 +40,7 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.TokenStreamToAutomaton; +import org.apache.lucene.analysis.tokenattributes.BoostAttribute; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.FlagsAttribute; import org.apache.lucene.analysis.tokenattributes.KeywordAttribute; @@ -55,7 +56,6 @@ import org.apache.lucene.document.TextField; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableFieldType; -import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -251,7 +251,7 @@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's - if (boostAtt != null) boostAtt.setBoost(-1f); + if (boostAtt != null) boostAtt.setBoost(1.0f); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before assertTrue("token " + i + " does not exist", ts.incrementToken()); @@ -417,7 +417,7 @@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's - if (boostAtt != null) boostAtt.setBoost(-1); + if (boostAtt != null) boostAtt.setBoost(1.0f); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before