diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java index 4ee799d12c18f..e420a4de51b5d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java @@ -22,10 +22,10 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue; import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor; -import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode; import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -46,16 +46,16 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { public static class StringLengthAndEmbedding { - final int stringLen; + final int utf8StringLen; final double[] embedding; - public StringLengthAndEmbedding(int stringLen, double[] embedding) { - this.stringLen = stringLen; + public StringLengthAndEmbedding(int utf8StringLen, double[] embedding) { + this.utf8StringLen = utf8StringLen; this.embedding = embedding; } - public int getStringLen() { - return stringLen; + public int getUtf8StringLen() { + return utf8StringLen; } public double[] getEmbedding() { @@ -282,7 +282,7 @@ public void process(Map fields) { if (i >= codePoints.length) { break; } - ScriptCode currentCode = 
ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i])); + Character.UnicodeScript currentCode = Character.UnicodeScript.of(codePoints[i]); int j = i + 1; for (; j < codePoints.length; j++) { while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) { @@ -291,11 +291,11 @@ public void process(Map fields) { if (j >= codePoints.length) { break; } - ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j])); - if (j1 != currentCode && j1 != ScriptCode.Inherited) { + Character.UnicodeScript j1 = Character.UnicodeScript.of(codePoints[j]); + if (j1 != currentCode && j1 != Character.UnicodeScript.INHERITED) { if (j < codePoints.length - 1) { - ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1])); - if (j2 != ScriptCode.Common && j2 != currentCode) { + Character.UnicodeScript j2 = Character.UnicodeScript.of(codePoints[j + 1]); + if (j2 != Character.UnicodeScript.COMMON && j2 != currentCode) { break; } } @@ -314,7 +314,11 @@ public void process(Map fields) { embeddings.add( new StringLengthAndEmbedding( // Don't count white spaces as bytes for the prediction - str.trim().length(), + // We use utf-8 length here as + // * The original C++ implementation does this when measuring string length + // * Languages with complex characters (like zh) convey more information per single utf-16 character and + // using utf-8 length captures that. 
+ str.trim().getBytes(StandardCharsets.UTF_8).length, concatEmbeddings( FEATURE_EXTRACTORS.stream() .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString())) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java index 5f69f7a961b10..f732419c7e98c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java @@ -230,23 +230,22 @@ public InferenceResults infer(Map fields, InferenceConfig config ); } List embeddedVector = (List) vector; - double[] scores = new double[LANGUAGE_NAMES.size()]; + double[] probabilities = new double[LANGUAGE_NAMES.size()]; int totalLen = 0; for (Object vec : embeddedVector) { if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) { continue; } CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec; - int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen(); + int square = stringLengthAndEmbedding.getUtf8StringLen() * stringLengthAndEmbedding.getUtf8StringLen(); totalLen += square; double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding()); double[] score = softmaxLayer.productPlusBias(true, h0); - sumDoubleArrays(scores, score, Math.max(square, 1)); + sumDoubleArrays(probabilities, softMax(score), Math.max(square, 1)); } if (totalLen != 0) { - divMut(scores, totalLen); + divMut(probabilities, totalLen); } - double[] probabilities = softMax(scores); ClassificationConfig classificationConfig = (ClassificationConfig) config; Tuple> topClasses = 
InferenceHelpers.topClasses( probabilities, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java index a572d47eafb6a..ed7fe76d23f16 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java @@ -29,7 +29,6 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.closeTo; -import static org.hamcrest.Matchers.greaterThan; import static org.mockito.Mockito.mock; public class LangIdentNeuralNetworkInferenceTests extends ESTestCase { @@ -103,6 +102,12 @@ public void testMixedLangInference() throws Exception { singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig); assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer( + inferenceObj("매트 스미스는 BBC äôs Doctor Who를 그만둔다."), + classificationConfig + ); + assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer( inferenceObj( "@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual." 
@@ -112,6 +117,34 @@ public void testMixedLangInference() throws Exception { ); assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer( + inferenceObj( + "김걸도혁(金乞都革) 김공소(金公疎) 김교합(金咬哈) 김다롱합(金多弄哈) 김마상개(金麻尙介) 김우리개(金于里介) 김상미(金尙美) 김아도을치(金阿都乙赤) " + + "김아라(金阿喇) 김아랑합(金阿郞哈) 김아을가(金阿乙加) 김역류(金易留) 김우두(金于豆) 김우허내(金右虛乃) 김유리가(金留里加) 김윤적(金允績) " + + "김이랑합(金伊郞哈) 김인을개(金引乙介) 김입성(金入成) 김주창개(金主昌介) 김지하리(金之下里) 김차독(金箚禿) 김지칭가(金只稱哥) 김자라노(金者羅老)." + ), + classificationConfig + ); + // Half the string is ko the other half is zh + assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.5, 0.1)); + assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("zh")); + assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.5, 0.1)); + + singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer( + inferenceObj( + "[ Republic of Korea ],\n" + + "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n" + + "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n" + + " !대한민국(, 영어: Republic of Korea, KOR)은 동아시아의 한반도 남부에 자리한 민주공화국이다. 서쪽으로 중화인민공화국과 황해를 사이에 두고" + ), + classificationConfig + ); + // Majority of the text is obviously Thai, but a close second is Korean + assertThat(singleValueInferenceResults.valueAsString(), equalTo("th")); + assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.6, 0.1)); + assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("ko")); + assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.4, 0.1)); } public void testLangInference() throws Exception { @@ -131,7 +164,9 @@ public void testLangInference() throws Exception { ); assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual)); - Matcher matcher = entry.getLanguage().equals("hr") ? 
greaterThan(cld3Probability) : closeTo(cld3Probability, .00001); + // The stored language example is a mixture of `ja` and other languages, so it should not be predicted with 1.0 accuracy as + // the cld3 probability indicates. + Matcher matcher = entry.getLanguage().equals("ja") ? closeTo(cld3Probability, 0.11) : closeTo(cld3Probability, .01); assertThat( "mismatch probability for language " + cld3Actual, singleValueInferenceResults.getTopClasses().get(0).getProbability(),