[ML] Fix language identification bug when multi-languages are present #80675
CustomWordEmbedding.java
@@ -20,10 +20,12 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;

@@ -43,6 +45,24 @@
 */
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

    public static class ByteSizeAndEmbedding {
        final int utf8ByteSize;
        final double[] embedding;

        public ByteSizeAndEmbedding(int utf8ByteSize, double[] embedding) {
            this.utf8ByteSize = utf8ByteSize;
            this.embedding = embedding;
        }

        public int getUtf8ByteSize() {
            return utf8ByteSize;
        }

        public double[] getEmbedding() {
            return embedding;
        }
    }

    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
    public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
    public static final ParseField NAME = new ParseField("custom_word_embedding");

@@ -214,10 +234,75 @@ public void process(Map<String, Object> fields) {
        text = FeatureUtils.cleanAndLowerText(text);
        text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
        String finalText = text;

Review comment: Might be clearer if it's explicitly `final`, i.e. `final String finalText = text;`.
        List<FeatureValue[]> processedFeatures = FEATURE_EXTRACTORS.stream()
            .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
            .collect(Collectors.toList());
        fields.put(destField, concatEmbeddings(processedFeatures));
        if (text.isEmpty() || text.isBlank()) {

Review comment: It seems potentially confusing to mix `isEmpty()` and `isBlank()`; `isBlank()` already returns true for an empty string, so it would suffice on its own.
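A quick standalone illustration of that point (not part of the diff):

public class IsBlankDemo {
    public static void main(String[] args) {
        // isBlank() subsumes isEmpty(): it is true for "" and for
        // whitespace-only strings, so `text.isEmpty() || text.isBlank()`
        // reduces to `text.isBlank()`.
        System.out.println("".isEmpty() + " " + "".isBlank());       // true true
        System.out.println("   ".isEmpty() + " " + "   ".isBlank()); // false true
    }
}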
            fields.put(
                destField,
                Arrays.asList(
                    new ByteSizeAndEmbedding(
                        // Don't count white spaces as bytes for the prediction
                        finalText.trim().getBytes(StandardCharsets.UTF_8).length,

Review comment (with a suggested change, presumably replacing the computed length with 0): If this is wrong please add a comment saying how the trimmed length of a blank or empty string can be > 0.

                        concatEmbeddings(
                            FEATURE_EXTRACTORS.stream()
                                .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
                                .collect(Collectors.toList())
                        )
                    )
                )
            );
            return;
        }
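On the reviewer's question: whether this branch can ever see a non-zero byte count depends on what FeatureUtils.cleanAndLowerText does upstream, but in plain Java there is at least one case where a blank string survives trim(). A standalone sketch:

import java.nio.charset.StandardCharsets;

public class BlankTrimDemo {
    public static void main(String[] args) {
        // For ordinary blanks, trim() yields "" and the byte length is 0 ...
        System.out.println(" \t\n".trim().getBytes(StandardCharsets.UTF_8).length); // 0
        // ... but trim() only strips code points <= U+0020, while isBlank()
        // uses Character.isWhitespace(), which also matches e.g. U+2003 (em space).
        String emSpace = "\u2003";
        System.out.println(emSpace.isBlank());                                   // true
        System.out.println(emSpace.trim().getBytes(StandardCharsets.UTF_8).length); // 3
    }
}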
        List<ByteSizeAndEmbedding> embeddings = new ArrayList<>();
        int[] codePoints = finalText.codePoints().toArray();
        for (int i = 0; i < codePoints.length - 1;) {
            while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
                i++;
            }
            if (i >= codePoints.length) {
                break;
            }
            ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
            int j = i + 1;
            for (; j < codePoints.length; j++) {
                while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
                    j++;
                }
                if (j >= codePoints.length) {
                    break;
                }
                ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
                if (j1 != currentCode && j1 != ScriptCode.Inherited) {
                    if (j < codePoints.length - 1) {
                        ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
                        if (j2 != ScriptCode.Common && j2 != currentCode) {
                            break;
                        }
                    }
                }
            }
            // Knowing the start and the end of the section is important for feature building, so make sure it's wrapped in spaces
            String str = new String(codePoints, i, j - i);
            StringBuilder builder = new StringBuilder();
            if (str.startsWith(" ") == false) {
                builder.append(" ");
            }
            builder.append(str);
            if (str.endsWith(" ") == false) {
                builder.append(" ");
            }
            embeddings.add(
                new ByteSizeAndEmbedding(
                    // Don't count white spaces as bytes for the prediction
                    str.trim().getBytes(StandardCharsets.UTF_8).length,
                    concatEmbeddings(
                        FEATURE_EXTRACTORS.stream()
                            .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
                            .collect(Collectors.toList())
                    )
                )
            );
            i = j;
        }
        fields.put(destField, embeddings);
    }
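For intuition about what the loop above produces, here is a deliberately simplified standalone sketch (class and helper names are hypothetical, and it drops the real code's space padding and the ScriptCode.Inherited/Common handling): it splits the letters of a string into runs that share one Unicode script, which is the shape of segmentation the PR then feeds through the feature extractors run by run.

import java.util.ArrayList;
import java.util.List;

public class ScriptRunsDemo {
    public static void main(String[] args) {
        System.out.println(scriptRuns("hello 你好 world")); // [hello, 你好, world]
    }

    static List<String> scriptRuns(String text) {
        List<String> runs = new ArrayList<>();
        StringBuilder run = new StringBuilder();
        Character.UnicodeScript current = null;
        for (int cp : text.codePoints().toArray()) {
            if (Character.isLetter(cp) == false) {
                continue; // the PR keeps surrounding spaces; dropped here for brevity
            }
            Character.UnicodeScript script = Character.UnicodeScript.of(cp);
            if (current != null && script != current) {
                // script changed: close the current run and start a new one
                runs.add(run.toString());
                run.setLength(0);
            }
            current = script;
            run.appendCodePoint(cp);
        }
        if (run.length() > 0) {
            runs.add(run.toString());
        }
        return runs;
    }
}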

    @Override
LangIdentNeuralNetwork.java
@@ -15,6 +15,7 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;

@@ -36,6 +37,8 @@
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.divMut;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {

@@ -217,27 +220,32 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
            throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
        }
        Object vector = fields.get(embeddedVectorFeatureName);
        if (vector instanceof double[] == false) {
        if (vector instanceof List<?> == false) {
            throw ExceptionsHelper.badRequestException(
                "[{}] model could not find non-null numerical array named [{}]",
                "[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. "
                    + "Please verify that the input is a string.",
                NAME.getPreferredName(),
                embeddedVectorFeatureName
            );
        }
        double[] embeddedVector = (double[]) vector;
        if (embeddedVector.length != EMBEDDING_VECTOR_LENGTH) {
            throw ExceptionsHelper.badRequestException(
                "[{}] model is expecting embedding vector of length [{}] but got [{}]",
                NAME.getPreferredName(),
                EMBEDDING_VECTOR_LENGTH,
                embeddedVector.length
            );
        List<?> embeddedVector = (List<?>) vector;
        double[] scores = new double[LANGUAGE_NAMES.size()];
        int totalByteSize = 0;
        for (Object vec : embeddedVector) {
            if (vec instanceof CustomWordEmbedding.ByteSizeAndEmbedding == false) {
                continue;
            }
            CustomWordEmbedding.ByteSizeAndEmbedding byteSizeAndEmbedding = (CustomWordEmbedding.ByteSizeAndEmbedding) vec;
            int square = (int) Math.pow(byteSizeAndEmbedding.getUtf8ByteSize(), 2);

Review comment: I strongly suspect multiplying two integers is much faster than using some generic x^y algorithm that works on arbitrary floating point numbers.
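A minimal sketch of what the reviewer is suggesting (the value is made up): squaring with a plain integer multiply avoids the int-to-double round trip through Math.pow, and since MAX_STRING_SIZE_IN_BYTES caps the byte size at 10000, the square cannot overflow an int.

public class SquareDemo {
    public static void main(String[] args) {
        int utf8ByteSize = 42; // hypothetical byte size of one script run
        // Plain integer multiply: no conversion to double, no generic pow routine.
        int square = utf8ByteSize * utf8ByteSize;
        System.out.println(square);                          // 1764
        System.out.println((int) Math.pow(utf8ByteSize, 2)); // 1764, same result
    }
}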
            totalByteSize += square;
            double[] h0 = hiddenLayer.productPlusBias(false, byteSizeAndEmbedding.getEmbedding());
            double[] score = softmaxLayer.productPlusBias(true, h0);
            sumDoubleArrays(scores, score, Math.max(square, 1));
        }
        if (totalByteSize != 0) {
            divMut(scores, totalByteSize);
        }
        double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
        double[] scores = softmaxLayer.productPlusBias(true, h0);

        double[] probabilities = softMax(scores);

        ClassificationConfig classificationConfig = (ClassificationConfig) config;
        Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
            probabilities,
Review comment: I find it very strange that the weighting is the number of UTF-8 bytes, not the number of characters. That means that if I have some text that's 100 characters of Roman alphabet and 100 Chinese characters then the Chinese could get a weighting of 300 while the western language gets a weighting of 100. Is the byte count a sneaky heuristic for saying each Chinese character embeds more information than a Roman alphabet character? It would be good to add a comment with the justification.
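To make the reviewer's arithmetic concrete, a standalone sketch (illustrative only; note the loop above additionally squares the byte size before weighting each run's scores with sumDoubleArrays and normalizing with divMut, so the skew is amplified further):

import java.nio.charset.StandardCharsets;

public class ByteWeightDemo {
    public static void main(String[] args) {
        String roman = "a".repeat(100);    // 100 Latin characters
        String chinese = "中".repeat(100); // 100 Han characters
        // Latin letters are 1 byte each in UTF-8, most Han characters 3 bytes,
        // so equal character counts give the Han run a 3x byte weighting
        // (9x once squared as in the inference loop).
        System.out.println(roman.getBytes(StandardCharsets.UTF_8).length);   // 100
        System.out.println(chinese.getBytes(StandardCharsets.UTF_8).length); // 300
    }
}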