Skip to content

Commit

Permalink
[ML] fix LangIdent model when multiple unicode scripts are present
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Dec 17, 2021
1 parent 03954f1 commit 2ee148a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -258,7 +258,7 @@ public void process(Map<String, Object> fields) {
if (i >= codePoints.length) {
break;
}
ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
Character.UnicodeScript currentCode = Character.UnicodeScript.of(codePoints[i]);
int j = i + 1;
for (; j < codePoints.length; j++) {
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
Expand All @@ -267,11 +267,11 @@ public void process(Map<String, Object> fields) {
if (j >= codePoints.length) {
break;
}
ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
if (j1 != currentCode && j1 != ScriptCode.Inherited) {
Character.UnicodeScript j1 = Character.UnicodeScript.of(codePoints[j]);
if (j1 != currentCode && j1 != Character.UnicodeScript.INHERITED) {
if (j < codePoints.length - 1) {
ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
if (j2 != ScriptCode.Common && j2 != currentCode) {
Character.UnicodeScript j2 = Character.UnicodeScript.of(codePoints[j + 1]);
if (j2 != Character.UnicodeScript.COMMON && j2 != currentCode) {
break;
}
}
Expand All @@ -290,7 +290,7 @@ public void process(Map<String, Object> fields) {
embeddings.add(
new StringLengthAndEmbedding(
// Don't count leading or trailing whitespace as bytes for the prediction
str.trim().length(),
str.trim().getBytes(StandardCharsets.UTF_8).length,
concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
);
}
List<?> embeddedVector = (List<?>) vector;
double[] scores = new double[LANGUAGE_NAMES.size()];
double[] probabilities = new double[LANGUAGE_NAMES.size()];
int totalLen = 0;
for (Object vec : embeddedVector) {
if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
Expand All @@ -240,12 +240,11 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
totalLen += square;
double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
double[] score = softmaxLayer.productPlusBias(true, h0);
sumDoubleArrays(scores, score, Math.max(square, 1));
sumDoubleArrays(probabilities, softMax(score), Math.max(square, 1));
}
if (totalLen != 0) {
divMut(scores, totalLen);
divMut(probabilities, totalLen);
}
double[] probabilities = softMax(scores);
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.mockito.Mockito.mock;

public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
Expand Down Expand Up @@ -103,6 +102,12 @@ public void testMixedLangInference() throws Exception {
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("매트 스미스는 BBC äôs Doctor Who를 그만둔다."),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
Expand All @@ -112,6 +117,34 @@ public void testMixedLangInference() throws Exception {
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"김걸도혁(金乞都革) 김공소(金公疎) 김교합(金咬哈) 김다롱합(金多弄哈) 김마상개(金麻尙介) 김우리개(金于里介) 김상미(金尙美) 김아도을치(金阿都乙赤) "
+ "김아라(金阿喇) 김아랑합(金阿郞哈) 김아을가(金阿乙加) 김역류(金易留) 김우두(金于豆) 김우허내(金右虛乃) 김유리가(金留里加) 김윤적(金允績) "
+ "김이랑합(金伊郞哈) 김인을개(金引乙介) 김입성(金入成) 김주창개(金主昌介) 김지하리(金之下里) 김차독(金箚禿) 김지칭가(金只稱哥) 김자라노(金者羅老)."
),
classificationConfig
);
// Half the string is ko; the other half is zh
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.5, 0.1));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("zh"));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.5, 0.1));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"[ Republic of Korea ],\n"
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+ " !대한민국(, 영어: Republic of Korea, KOR)은 동아시아의 한반도 남부에 자리한 민주공화국이다. 서쪽으로 중화인민공화국과 황해를 사이에 두고"
),
classificationConfig
);
// Majority of the text is obviously Thai, but a close second is Korean
assertThat(singleValueInferenceResults.valueAsString(), equalTo("th"));
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.6, 0.1));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("ko"));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.4, 0.1));
}

public void testLangInference() throws Exception {
Expand All @@ -131,7 +164,9 @@ public void testLangInference() throws Exception {
);

assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
// The stored language example is a mixture of `ja` and other languages, so it should not be predicted with a probability
// of 1.0, even though that is what the cld3 probability indicates.
Matcher<Double> matcher = entry.getLanguage().equals("ja") ? closeTo(cld3Probability, 0.11) : closeTo(cld3Probability, .01);
assertThat(
"mismatch probability for language " + cld3Actual,
singleValueInferenceResults.getTopClasses().get(0).getProbability(),
Expand Down

0 comments on commit 2ee148a

Please sign in to comment.