Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the similarity as boost functionality to the Word2VecSynonyFilter #12433

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import java.io.IOException;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.BoostAttribute;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.search.BoostAttribute;
import org.apache.lucene.util.IgnoreRandomChains;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.synonym.SynonymGraphFilter;
import org.apache.lucene.analysis.tokenattributes.BoostAttribute;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
Expand All @@ -41,11 +42,13 @@ public final class Word2VecSynonymFilter extends TokenFilter {
private final PositionIncrementAttribute posIncrementAtt =
addAttribute(PositionIncrementAttribute.class);
private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class);
private final BoostAttribute boostAtt = addAttribute(BoostAttribute.class);
private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class);

private final Word2VecSynonymProvider synonymProvider;
private final int maxSynonymsPerTerm;
private final float minAcceptedSimilarity;
private final boolean similarityAsBoost;
private final LinkedList<TermAndBoost> synonymBuffer = new LinkedList<>();
private State lastState;

Expand All @@ -62,14 +65,16 @@ public Word2VecSynonymFilter(
TokenStream input,
Word2VecSynonymProvider synonymProvider,
int maxSynonymsPerTerm,
float minAcceptedSimilarity) {
float minAcceptedSimilarity,
boolean similarityAsBoost) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a back-compat ctor that doesn't take similarityAsBoost, and defaults it to false I guess?

super(input);
if (synonymProvider == null) {
throw new IllegalArgumentException("The SynonymProvider must be non-null");
}
this.synonymProvider = synonymProvider;
this.maxSynonymsPerTerm = maxSynonymsPerTerm;
this.minAcceptedSimilarity = minAcceptedSimilarity;
this.similarityAsBoost = similarityAsBoost;
}

@Override
Expand All @@ -81,6 +86,7 @@ public boolean incrementToken() throws IOException {
restoreState(this.lastState);
termAtt.setEmpty();
termAtt.append(synonym.term.utf8ToString());
boostAtt.setBoost(this.similarityAsBoost ? synonym.boost : 1);
typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM);
posLenAtt.setPositionLength(1);
posIncrementAtt.setPositionIncrement(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class Word2VecSynonymFilterFactory extends TokenFilterFactory
private final float minAcceptedSimilarity;
private final Word2VecSupportedFormats format;
private final String word2vecModelFileName;
private final boolean similarityAsBoost;

private Word2VecSynonymProvider synonymProvider;

Expand All @@ -61,6 +62,7 @@ public Word2VecSynonymFilterFactory(Map<String, String> args) {
} catch (IllegalArgumentException exc) {
throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc);
}
this.similarityAsBoost = getBoolean(args, "similarityAsBoost", true);

if (!args.isEmpty()) {
throw new IllegalArgumentException("Unknown parameters: " + args);
Expand Down Expand Up @@ -90,7 +92,7 @@ public TokenStream create(TokenStream input) {
return synonymProvider == null
? input
: new Word2VecSynonymFilter(
input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity, similarityAsBoost);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package org.apache.lucene.analysis.boost;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.BoostAttribute;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.search.BoostAttribute;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;

public class TestDelimitedBoostTokenFilter extends BaseTokenStreamTestCase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.tests.analysis.MockTokenizer;
import org.apache.lucene.util.BytesRef;
Expand All @@ -41,7 +42,17 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() thro

Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);

Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
float similarityAWithB = // 0.9969
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(1));
float similarityAWithC = // 0.9993
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(2));
float similarityAWithD = // 1.0
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(3));
float similarityAWithE = // 0.9999
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(4));
// float similarityAWithF = 0.8166 (not accepted)

Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity, true);
assertAnalyzesTo(
a,
"pre a post", // input
Expand All @@ -50,7 +61,38 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() thro
new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset
new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types
new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements
new int[] {1, 1, 1, 1, 1, 1, 1}); // posLenghts
new int[] {1, 1, 1, 1, 1, 1, 1}, // posLenghts
new float[] {
1, 1, similarityAWithD, similarityAWithE, similarityAWithC, similarityAWithB, 1
}); // boost
a.close();
}

@Test
public void synonymExpansion_oneCandidate_shouldBeExpandedWithNoBoost() throws Exception {
int maxSynonymPerTerm = 10;
float minAcceptedSimilarity = 0.9f;
Word2VecModel model = new Word2VecModel(6, 2);
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));

Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);

Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity, false);
assertAnalyzesTo(
a,
"pre a post", // input
new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output
new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset
new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset
new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types
new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements
new int[] {1, 1, 1, 1, 1, 1, 1}, // posLenghts
new float[] {1, 1, 1, 1, 1, 1, 1}); // boost
a.close();
}

Expand All @@ -67,7 +109,14 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() thr

Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);

Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
// float similarityAWithB = 0.9969 (not in top 2)
// float similarityAWithC = 0.9993 (not in top 2)
float similarityAWithD = // 1.0
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(3));
float similarityAWithE = // 0.9999
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(4));

Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity, true);
assertAnalyzesTo(
a,
"pre a post", // input
Expand All @@ -76,7 +125,8 @@ public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() thr
new int[] {3, 5, 5, 5, 10}, // end offset
new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types
new int[] {1, 1, 0, 0, 1}, // posIncrements
new int[] {1, 1, 1, 1, 1}); // posLenghts
new int[] {1, 1, 1, 1, 1}, // posLenghts
new float[] {1, 1, similarityAWithD, similarityAWithE, 1}); // boost
a.close();
}

Expand All @@ -94,7 +144,18 @@ public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Excepti

Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);

Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f);
float similarityAWithB =
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(1));
float similarityAWithC =
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(2));
float similarityAWithD =
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(3));
float similarityAWithE =
VectorSimilarityFunction.COSINE.compare(model.vectorValue(0), model.vectorValue(4));
float similarityPostWithAfter =
VectorSimilarityFunction.COSINE.compare(model.vectorValue(6), model.vectorValue(7));

Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f, true);
assertAnalyzesTo(
a,
"pre a post", // input
Expand All @@ -105,7 +166,17 @@ public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Excepti
"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM"
},
new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements
new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths
new int[] {1, 1, 1, 1, 1, 1, 1, 1}, // posLengths
new float[] {
1,
1,
similarityAWithD,
similarityAWithE,
similarityAWithC,
similarityAWithB,
1,
similarityPostWithAfter
}); // boost
a.close();
}

Expand All @@ -120,7 +191,7 @@ public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms

Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);

Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f);
Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f, true);
assertAnalyzesTo(
a,
"pre a post", // input
Expand All @@ -129,22 +200,28 @@ public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms
new int[] {3, 5, 10}, // end offset
new String[] {"word", "word", "word"}, // types
new int[] {1, 1, 1}, // posIncrements
new int[] {1, 1, 1}); // posLengths
new int[] {1, 1, 1}, // posLengths
new float[] {1, 1, 1}); // boost
a.close();
}

private Analyzer getAnalyzer(
Word2VecSynonymProvider synonymProvider,
int maxSynonymsPerTerm,
float minAcceptedSimilarity) {
float minAcceptedSimilarity,
boolean similarityAsBoost) {
return new Analyzer() {
@Override
protected TokenStreamComponents createComponents(String fieldName) {
Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false);
// Make a local variable so testRandomHuge doesn't share it across threads!
Word2VecSynonymFilter synFilter =
new Word2VecSynonymFilter(
tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
tokenizer,
synonymProvider,
maxSynonymsPerTerm,
minAcceptedSimilarity,
similarityAsBoost);
return new TokenStreamComponents(tokenizer, synFilter);
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.tokenattributes;

import org.apache.lucene.util.Attribute;

/**
* Add this {@link BoostAttribute} if you want to manipulate the token stream in order to update the
* boost associated to a token
*
* <p><b>Please note:</b> This attribute does not work at index time
*
* @lucene.internal
*/
public interface BoostAttribute extends Attribute {
/** Sets the boost in this attribute */
public void setBoost(float boost);
/** Retrieves the boost, default is {@code 1.0f}. */
public float getBoost();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.tokenattributes;

import java.util.Objects;
import org.apache.lucene.util.AttributeImpl;
import org.apache.lucene.util.AttributeReflector;

/**
* Implementation class for {@link BoostAttribute}.
*
* @lucene.internal
*/
public final class BoostAttributeImpl extends AttributeImpl implements BoostAttribute {
private float boost = 1.0f;

/** Initialize this attribute with no boost. */
public BoostAttributeImpl() {}

@Override
public void setBoost(float boost) {
if (Float.isFinite(boost) == false || Float.compare(boost, 0f) < 0) {
throw new IllegalArgumentException("boost must be a positive float, got " + boost);
}
this.boost = boost;
}

@Override
public float getBoost() {
return boost;
}

@Override
public void clear() {
boost = 1.0f;
}

@Override
public BoostAttributeImpl clone() {
BoostAttributeImpl clone = (BoostAttributeImpl) super.clone();
clone.boost = this.boost;
return clone;
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}

if (other == null || getClass() != other.getClass()) {
return false;
}
BoostAttributeImpl that = (BoostAttributeImpl) other;
return Float.compare(that.boost, boost) == 0;
}

@Override
public int hashCode() {
return Objects.hash(boost);
}

@Override
public void copyTo(AttributeImpl target) {
((BoostAttribute) target).setBoost(boost);
}

@Override
public void reflectWith(AttributeReflector reflector) {
reflector.reflect(BoostAttribute.class, "boost", boost);
}
}
Loading