Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] fix LangIdent model when multiple unicode scripts are present #81876

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -258,7 +258,7 @@ public void process(Map<String, Object> fields) {
if (i >= codePoints.length) {
break;
}
ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
Character.UnicodeScript currentCode = Character.UnicodeScript.of(codePoints[i]);
int j = i + 1;
for (; j < codePoints.length; j++) {
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
Expand All @@ -267,11 +267,11 @@ public void process(Map<String, Object> fields) {
if (j >= codePoints.length) {
break;
}
ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
if (j1 != currentCode && j1 != ScriptCode.Inherited) {
Character.UnicodeScript j1 = Character.UnicodeScript.of(codePoints[j]);
if (j1 != currentCode && j1 != Character.UnicodeScript.INHERITED) {
if (j < codePoints.length - 1) {
ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
if (j2 != ScriptCode.Common && j2 != currentCode) {
Character.UnicodeScript j2 = Character.UnicodeScript.of(codePoints[j + 1]);
if (j2 != Character.UnicodeScript.COMMON && j2 != currentCode) {
break;
}
}
Expand All @@ -290,7 +290,7 @@ public void process(Map<String, Object> fields) {
embeddings.add(
new StringLengthAndEmbedding(
// Don't count white spaces as bytes for the prediction
str.trim().length(),
str.trim().getBytes(StandardCharsets.UTF_8).length,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment here to say that using the number of UTF-8 bytes:

  • Matches what the equivalent Python code did
  • Acts as a heuristic to account for the fact that languages like Chinese embed more information in each character so using the number of UTF-8 bytes gives them a boost to compensate for shorter words

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will!

concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
);
}
List<?> embeddedVector = (List<?>) vector;
double[] scores = new double[LANGUAGE_NAMES.size()];
double[] probabilities = new double[LANGUAGE_NAMES.size()];
int totalLen = 0;
for (Object vec : embeddedVector) {
if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
Expand All @@ -240,12 +240,11 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
totalLen += square;
double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
double[] score = softmaxLayer.productPlusBias(true, h0);
sumDoubleArrays(scores, score, Math.max(square, 1));
sumDoubleArrays(probabilities, softMax(score), Math.max(square, 1));
}
if (totalLen != 0) {
divMut(scores, totalLen);
divMut(probabilities, totalLen);
Copy link
Contributor

@tveasey tveasey Dec 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it is worth noting somewhere (maybe in docs) in multilingual cases the probabilities we report are related to the fraction of the document which is classified with the language type. (We can probably just gloss over the fact we give short fragments less weight though.)

}
double[] probabilities = softMax(scores);
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.mockito.Mockito.mock;

public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
Expand Down Expand Up @@ -103,6 +102,12 @@ public void testMixedLangInference() throws Exception {
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("매트 스미스는 BBC äôs Doctor Who를 그만둔다."),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
Expand All @@ -112,6 +117,34 @@ public void testMixedLangInference() throws Exception {
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"김걸도혁(金乞都革) 김공소(金公疎) 김교합(金咬哈) 김다롱합(金多弄哈) 김마상개(金麻尙介) 김우리개(金于里介) 김상미(金尙美) 김아도을치(金阿都乙赤) "
+ "김아라(金阿喇) 김아랑합(金阿郞哈) 김아을가(金阿乙加) 김역류(金易留) 김우두(金于豆) 김우허내(金右虛乃) 김유리가(金留里加) 김윤적(金允績) "
+ "김이랑합(金伊郞哈) 김인을개(金引乙介) 김입성(金入成) 김주창개(金主昌介) 김지하리(金之下里) 김차독(金箚禿) 김지칭가(金只稱哥) 김자라노(金者羅老)."
),
classificationConfig
);
// Half the string is ko the other half is zh
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.5, 0.1));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("zh"));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.5, 0.1));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"[ Republic of Korea ],\n"
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
+ " !대한민국(, 영어: Republic of Korea, KOR)은 동아시아의 한반도 남부에 자리한 민주공화국이다. 서쪽으로 중화인민공화국과 황해를 사이에 두고"
),
classificationConfig
);
// Majority of the text is obviously Thai, but a close second is Korean
assertThat(singleValueInferenceResults.valueAsString(), equalTo("th"));
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.6, 0.1));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("ko"));
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.4, 0.1));
}

public void testLangInference() throws Exception {
Expand All @@ -131,7 +164,9 @@ public void testLangInference() throws Exception {
);

assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
// The stored language example is a mixture of `ja` and other languages, it should not be predicted with 1.0 accuracy as the
// cld3 probability indicates.
Matcher<Double> matcher = entry.getLanguage().equals("ja") ? closeTo(cld3Probability, 0.11) : closeTo(cld3Probability, .01);
assertThat(
"mismatch probability for language " + cld3Actual,
singleValueInferenceResults.getTopClasses().get(0).getProbability(),
Expand Down