Skip to content

Commit

Permalink
Introduced the Word2VecSynonymFilter (#12169)
Browse files Browse the repository at this point in the history
Co-authored-by: Alessandro Benedetti <[email protected]>
  • Loading branch information
dantuzi and alessandrobenedetti committed Apr 24, 2023
1 parent f517a79 commit 53709cc
Show file tree
Hide file tree
Showing 24 changed files with 1,450 additions and 23 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ New Features
crash the JVM. To disable this feature, pass the following sysprop on Java command line:
"-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler)

* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti)

Improvements
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.stempel.StempelStemmer;
import org.apache.lucene.analysis.synonym.SynonymMap;
import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel;
import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.tests.analysis.MockTokenFilter;
Expand All @@ -99,8 +101,10 @@
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
import org.apache.lucene.util.AttributeFactory;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IgnoreRandomChains;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.Version;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
Expand Down Expand Up @@ -415,6 +419,27 @@ private String randomNonEmptyString(Random random) {
}
}
});
put(
Word2VecSynonymProvider.class,
random -> {
final int numEntries = atLeast(10);
final int vectorDimension = random.nextInt(99) + 1;
Word2VecModel model = new Word2VecModel(numEntries, vectorDimension);
for (int j = 0; j < numEntries; j++) {
String s = TestUtil.randomSimpleString(random, 10, 20);
float[] vec = new float[vectorDimension];
for (int i = 0; i < vectorDimension; i++) {
vec[i] = random.nextFloat();
}
model.addTermAndVector(new TermAndVector(new BytesRef(s), vec));
}
try {
return new Word2VecSynonymProvider(model);
} catch (IOException e) {
Rethrow.rethrow(e);
return null; // unreachable code
}
});
put(
DateFormat.class,
random -> {
Expand Down
2 changes: 2 additions & 0 deletions lucene/analysis/common/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
exports org.apache.lucene.analysis.sr;
exports org.apache.lucene.analysis.sv;
exports org.apache.lucene.analysis.synonym;
exports org.apache.lucene.analysis.synonym.word2vec;
exports org.apache.lucene.analysis.ta;
exports org.apache.lucene.analysis.te;
exports org.apache.lucene.analysis.th;
Expand Down Expand Up @@ -256,6 +257,7 @@
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory,
org.apache.lucene.analysis.synonym.SynonymFilterFactory,
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory,
org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory,
org.apache.lucene.analysis.core.FlattenGraphFilterFactory,
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory,
org.apache.lucene.analysis.te.TeluguStemFilterFactory,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.analysis.synonym.word2vec;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Locale;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndVector;

/**
* Dl4jModelReader reads the file generated by the library Deeplearning4j and provide a
* Word2VecModel with normalized vectors
*
* <p>Dl4j Word2Vec documentation:
* https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to
* generate a model using dl4j:
* https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java
*
* @lucene.experimental
*/
public class Dl4jModelReader implements Closeable {

private static final String MODEL_FILE_NAME_PREFIX = "syn0";

private final ZipInputStream word2VecModelZipFile;

public Dl4jModelReader(InputStream stream) {
this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream));
}

public Word2VecModel read() throws IOException {

ZipEntry entry;
while ((entry = word2VecModelZipFile.getNextEntry()) != null) {
String fileName = entry.getName();
if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) {
BufferedReader reader =
new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8));

String header = reader.readLine();
String[] headerValues = header.split(" ");
int dictionarySize = Integer.parseInt(headerValues[0]);
int vectorDimension = Integer.parseInt(headerValues[1]);

Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension);
String line = reader.readLine();
boolean isTermB64Encoded = false;
if (line != null) {
String[] tokens = line.split(" ");
isTermB64Encoded =
tokens[0].substring(0, 3).toLowerCase(Locale.ROOT).compareTo("b64") == 0;
model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
}
while ((line = reader.readLine()) != null) {
String[] tokens = line.split(" ");
model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
}
return model;
}
}
throw new IllegalArgumentException(
"Cannot read Dl4j word2vec model - '"
+ MODEL_FILE_NAME_PREFIX
+ "' file is missing in the zip. '"
+ MODEL_FILE_NAME_PREFIX
+ "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library.");
}

private static TermAndVector extractTermAndVector(
String[] tokens, int vectorDimension, boolean isTermB64Encoded) {
BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0]));

float[] vector = new float[tokens.length - 1];

if (vectorDimension != vector.length) {
throw new RuntimeException(
String.format(
Locale.ROOT,
"Word2Vec model file corrupted. "
+ "Declared vectors of size %d but found vector of size %d for word %s (%s)",
vectorDimension,
vector.length,
tokens[0],
term.utf8ToString()));
}

for (int i = 1; i < tokens.length; i++) {
vector[i - 1] = Float.parseFloat(tokens[i]);
}
return new TermAndVector(term, vector);
}

static BytesRef decodeB64Term(String term) {
byte[] buffer = Base64.getDecoder().decode(term.substring(4));
return new BytesRef(buffer, 0, buffer.length);
}

@Override
public void close() throws IOException {
word2VecModelZipFile.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.analysis.synonym.word2vec;

import java.io.IOException;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
* word in dictionary
*
* @lucene.experimental
*/
public class Word2VecModel implements RandomAccessVectorValues<float[]> {

private final int dictionarySize;
private final int vectorDimension;
private final TermAndVector[] termsAndVectors;
private final BytesRefHash word2Vec;
private int loadedCount = 0;

public Word2VecModel(int dictionarySize, int vectorDimension) {
this.dictionarySize = dictionarySize;
this.vectorDimension = vectorDimension;
this.termsAndVectors = new TermAndVector[dictionarySize];
this.word2Vec = new BytesRefHash();
}

private Word2VecModel(
int dictionarySize,
int vectorDimension,
TermAndVector[] termsAndVectors,
BytesRefHash word2Vec) {
this.dictionarySize = dictionarySize;
this.vectorDimension = vectorDimension;
this.termsAndVectors = termsAndVectors;
this.word2Vec = word2Vec;
}

public void addTermAndVector(TermAndVector modelEntry) {
modelEntry.normalizeVector();
this.termsAndVectors[loadedCount++] = modelEntry;
this.word2Vec.add(modelEntry.getTerm());
}

@Override
public float[] vectorValue(int targetOrd) {
return termsAndVectors[targetOrd].getVector();
}

public float[] vectorValue(BytesRef term) {
int termOrd = this.word2Vec.find(term);
if (termOrd < 0) return null;
TermAndVector entry = this.termsAndVectors[termOrd];
return (entry == null) ? null : entry.getVector();
}

public BytesRef termValue(int targetOrd) {
return termsAndVectors[targetOrd].getTerm();
}

@Override
public int dimension() {
return vectorDimension;
}

@Override
public int size() {
return dictionarySize;
}

@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
return new Word2VecModel(
this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec);
}
}
Loading

0 comments on commit 53709cc

Please sign in to comment.