Skip to content

Commit

Permalink
[ML] fix LangIdent model when multiple unicode scripts are present (#…
Browse files Browse the repository at this point in the history
…81876)

LangIdent was recently updated to handle multiple unicode scripts (#80675). But this introduced some bugs fixed with this commit.

1. Sections with the same scripted were weighted by Java string length (utf-16) encoding. This is not accurate as certain languages (like Chinese and Korean) convey much more information with fewer utf-16 characters. FIX weight by utf-8 length.
2. The weighing of different language scores was done via the raw score from the neural network. This caused languages with a high score (but low compared to most likely language) from the network to be inaccurately weighted. FIX We are now instead weighing the probabilities of the sections of the text.
3. To split the input across the multiple scripts, we split on the "paired down" CDL3 script types. Java has superior support for unicode script blocks. FIX split by Java unicode script blocks not by the paired down CDL3 scripts
  • Loading branch information
benwtrent authored Dec 17, 2021
1 parent 3d5337d commit 4b0864d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -45,16 +45,16 @@
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

public static class StringLengthAndEmbedding {
final int stringLen;
final int utf8StringLen;
final double[] embedding;

public StringLengthAndEmbedding(int stringLen, double[] embedding) {
this.stringLen = stringLen;
public StringLengthAndEmbedding(int utf8StringLen, double[] embedding) {
this.utf8StringLen = utf8StringLen;
this.embedding = embedding;
}

public int getStringLen() {
return stringLen;
public int getUtf8StringLen() {
return utf8StringLen;
}

public double[] getEmbedding() {
Expand Down Expand Up @@ -258,7 +258,7 @@ public void process(Map<String, Object> fields) {
if (i >= codePoints.length) {
break;
}
ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
Character.UnicodeScript currentCode = Character.UnicodeScript.of(codePoints[i]);
int j = i + 1;
for (; j < codePoints.length; j++) {
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
Expand All @@ -267,11 +267,11 @@ public void process(Map<String, Object> fields) {
if (j >= codePoints.length) {
break;
}
ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
if (j1 != currentCode && j1 != ScriptCode.Inherited) {
Character.UnicodeScript j1 = Character.UnicodeScript.of(codePoints[j]);
if (j1 != currentCode && j1 != Character.UnicodeScript.INHERITED) {
if (j < codePoints.length - 1) {
ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
if (j2 != ScriptCode.Common && j2 != currentCode) {
Character.UnicodeScript j2 = Character.UnicodeScript.of(codePoints[j + 1]);
if (j2 != Character.UnicodeScript.COMMON && j2 != currentCode) {
break;
}
}
Expand All @@ -290,7 +290,11 @@ public void process(Map<String, Object> fields) {
embeddings.add(
new StringLengthAndEmbedding(
// Don't count white spaces as bytes for the prediction
str.trim().length(),
// We ues utf-8 length here as
// * The original C++ implementation does this when measuring string length
// * Languages with complex characters (like zh) convey more information per a single utf-16 character and
// using utf-8 length captures that.
str.trim().getBytes(StandardCharsets.UTF_8).length,
concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,22 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
);
}
List<?> embeddedVector = (List<?>) vector;
double[] scores = new double[LANGUAGE_NAMES.size()];
double[] probabilities = new double[LANGUAGE_NAMES.size()];
int totalLen = 0;
for (Object vec : embeddedVector) {
if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
continue;
}
CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
int square = stringLengthAndEmbedding.getUtf8StringLen() * stringLengthAndEmbedding.getUtf8StringLen();
totalLen += square;
double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
double[] score = softmaxLayer.productPlusBias(true, h0);
sumDoubleArrays(scores, score, Math.max(square, 1));
sumDoubleArrays(probabilities, softMax(score), Math.max(square, 1));
}
if (totalLen != 0) {
divMut(scores, totalLen);
divMut(probabilities, totalLen);
}
double[] probabilities = softMax(scores);
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.mockito.Mockito.mock;

public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
Expand Down Expand Up @@ -103,6 +102,12 @@ public void testMixedLangInference() throws Exception {
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("매트 스미스는 BBC äôs Doctor Who를 그만둔다."),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
Expand All @@ -112,6 +117,34 @@ public void testMixedLangInference() throws Exception {
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"김걸도혁(金乞都革) 김공소(金公疎) 김교합(金咬哈) 김다롱합(金多弄哈) 김마상개(金麻尙介) 김우리개(金于里介) 김상미(金尙美) 김아도을치(金阿都乙赤) "
+ "김아라(金阿喇) 김아랑합(金阿郞哈) 김아을가(金阿乙加) 김역류(金易留) 김우두(金于豆) 김우허내(金右虛乃) 김유리가(金留里加) 김윤적(金允績) "
+ "김이랑합(金伊郞哈) 김인을개(金引乙介) 김입성(金入成) 김주창개(金主昌介) 김지하리(金之下里) 김차독(金箚禿) 김지칭가(金只稱哥) 김자라노(金者羅老)."
),
classificationConfig
);
// Half the string is ko the other half is zh
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.5, 0.1));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("zh"));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.5, 0.1));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"[ Republic of Korea ],\n"
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+ " !대한민국(, 영어: Republic of Korea, KOR)은 동아시아의 한반도 남부에 자리한 민주공화국이다. 서쪽으로 중화인민공화국과 황해를 사이에 두고"
),
classificationConfig
);
// Majority of the text is obviously Thai, but a close second is Korean
assertThat(singleValueInferenceResults.valueAsString(), equalTo("th"));
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.6, 0.1));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("ko"));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.4, 0.1));
}

public void testLangInference() throws Exception {
Expand All @@ -131,7 +164,9 @@ public void testLangInference() throws Exception {
);

assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
// The stored language example is a mixture of `ja` and other languages, it should not be predicted with 1.0 accuracy as the
// cld3 probability indicates.
Matcher<Double> matcher = entry.getLanguage().equals("ja") ? closeTo(cld3Probability, 0.11) : closeTo(cld3Probability, .01);
assertThat(
"mismatch probability for language " + cld3Actual,
singleValueInferenceResults.getTopClasses().get(0).getProbability(),
Expand Down

0 comments on commit 4b0864d

Please sign in to comment.