From f339cdeffb0c1ce5ac727fb7abb88cee515d4564 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 11 Nov 2021 08:18:10 -0500
Subject: [PATCH 1/4] [ML] Fix language identification bug when multiple languages are present
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Language identification works fairly well when only one language and script type is present, but when multiple are present it can return unexpected results.

Example:

"행 레이블 this is english text obviously and 생성 tom said to test it"

To a human this is clearly English text (Latin script) with some Korean (Hangul script) mixed in, yet it is erroneously categorized as Japanese. It should be categorized as English, the dominant language and script type.

This commit fixes the bug by doing the following:

 - Input text is partitioned into contiguous sections of a common Unicode script
 - Language scores are gathered for each section individually
 - Each section's scores are weighted by the number of UTF-8 bytes in that section
 - The resulting weighted scores are transformed into probabilities
 - The final probabilities are the ones returned to the user

(A standalone sketch of this weighting scheme follows the patch series.)
---
 .../preprocessing/CustomWordEmbedding.java    |  93 +++++++++++++-
 .../trainedmodel/InferenceHelpers.java        |  18 ++-
 .../langident/LangIdentNeuralNetwork.java     |  36 +++---
 .../LangIdentNeuralNetworkInferenceTests.java | 116 +++++++++++++++---
 4 files changed, 228 insertions(+), 35 deletions(-)

diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java
index 989f53f19ef2e..c251253809783 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java
@@ -20,10 +20,12 @@
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -43,6 +45,24 @@
  */
 public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

+    public static class ByteSizeAndEmbedding {
+        final int utf8ByteSize;
+        final double[] embedding;
+
+        public ByteSizeAndEmbedding(int utf8ByteSize, double[] embedding) {
+            this.utf8ByteSize = utf8ByteSize;
+            this.embedding = embedding;
+        }
+
+        public int getUtf8ByteSize() {
+            return utf8ByteSize;
+        }
+
+        public double[] getEmbedding() {
+            return embedding;
+        }
+    }
+
     private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
     public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
     public static final ParseField NAME = new ParseField("custom_word_embedding");
@@ -214,10 +234,75 @@ public void
process(Map fields) {
         text = FeatureUtils.cleanAndLowerText(text);
         text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
         String finalText = text;
-        List<FeatureValue[]> processedFeatures = FEATURE_EXTRACTORS.stream()
-            .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
-            .collect(Collectors.toList());
-        fields.put(destField, concatEmbeddings(processedFeatures));
+        if (text.isEmpty() || text.isBlank()) {
+            fields.put(
+                destField,
+                Arrays.asList(
+                    new ByteSizeAndEmbedding(
+                        // Don't count white spaces as bytes for the prediction
+                        finalText.trim().getBytes(StandardCharsets.UTF_8).length,
+                        concatEmbeddings(
+                            FEATURE_EXTRACTORS.stream()
+                                .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
+                                .collect(Collectors.toList())
+                        )
+                    )
+                )
+            );
+            return;
+        }
+        List<ByteSizeAndEmbedding> embeddings = new ArrayList<>();
+        int[] codePoints = finalText.codePoints().toArray();
+        for (int i = 0; i < codePoints.length - 1;) {
+            while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
+                i++;
+            }
+            if (i >= codePoints.length) {
+                break;
+            }
+            ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
+            int j = i + 1;
+            for (; j < codePoints.length; j++) {
+                while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
+                    j++;
+                }
+                if (j >= codePoints.length) {
+                    break;
+                }
+                ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
+                if (j1 != currentCode && j1 != ScriptCode.Inherited) {
+                    if (j < codePoints.length - 1) {
+                        ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
+                        if (j2 != ScriptCode.Common && j2 != currentCode) {
+                            break;
+                        }
+                    }
+                }
+            }
+            // Knowing the start and the end of the section is important for feature building, so make sure it's wrapped in spaces
+            String str = new String(codePoints, i, j - i);
+            StringBuilder builder = new StringBuilder();
+            if (str.startsWith(" ") == false) {
+                builder.append(" ");
+            }
+            builder.append(str);
+            if (str.endsWith(" ") == false) {
+                builder.append(" ");
+            }
+            embeddings.add(
+                new ByteSizeAndEmbedding(
+                    // Don't count white spaces as bytes for the prediction
+                    str.trim().getBytes(StandardCharsets.UTF_8).length,
+                    concatEmbeddings(
+                        FEATURE_EXTRACTORS.stream()
+                            .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
+                            .collect(Collectors.toList())
+                    )
+                )
+            );
+            i = j;
+        }
+        fields.put(destField, embeddings);
     }

     @Override
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
index 0011abb45fc70..f214070e20949 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
@@ -188,13 +188,29 @@ public static List transformFeatureImportanceCl
     }

     public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
+        return sumDoubleArrays(sumTo, inc, 1);
+    }
+
+    public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
         assert sumTo != null && inc != null && sumTo.length == inc.length;
         for (int i = 0; i < inc.length; i++) {
-            sumTo[i] += inc[i];
+            sumTo[i] += (inc[i] * weight);
         }
         return sumTo;
     }

+    public static void divMut(double[] xs, int v) {
+        if (xs.length
== 0) {
+            return;
+        }
+        if (v == 0) {
+            throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
+        }
+        for (int i = 0; i < xs.length; i++) {
+            xs[i] /= v;
+        }
+    }
+
     public static class TopClassificationValue {
         private final int value;
         private final double probability;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
index 9890eda059172..b08fc27e31101 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
@@ -15,6 +15,7 @@
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -36,6 +37,8 @@
 import java.util.Objects;

 import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.divMut;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

 public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {
@@ -217,27 +220,32 @@ public InferenceResults infer(Map fields, InferenceConfig config
             throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
         }
         Object vector = fields.get(embeddedVectorFeatureName);
-        if (vector instanceof double[] == false) {
+        if (vector instanceof List == false) {
             throw ExceptionsHelper.badRequestException(
-                "[{}] model could not find non-null numerical array named [{}]",
+                "[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. 
" + + "Please verify that the input is a string.", NAME.getPreferredName(), embeddedVectorFeatureName ); } - double[] embeddedVector = (double[]) vector; - if (embeddedVector.length != EMBEDDING_VECTOR_LENGTH) { - throw ExceptionsHelper.badRequestException( - "[{}] model is expecting embedding vector of length [{}] but got [{}]", - NAME.getPreferredName(), - EMBEDDING_VECTOR_LENGTH, - embeddedVector.length - ); + List embeddedVector = (List) vector; + double[] scores = new double[LANGUAGE_NAMES.size()]; + int totalByteSize = 0; + for (Object vec : embeddedVector) { + if (vec instanceof CustomWordEmbedding.ByteSizeAndEmbedding == false) { + continue; + } + CustomWordEmbedding.ByteSizeAndEmbedding byteSizeAndEmbedding = (CustomWordEmbedding.ByteSizeAndEmbedding) vec; + int square = (int) Math.pow(byteSizeAndEmbedding.getUtf8ByteSize(), 2); + totalByteSize += square; + double[] h0 = hiddenLayer.productPlusBias(false, byteSizeAndEmbedding.getEmbedding()); + double[] score = softmaxLayer.productPlusBias(true, h0); + sumDoubleArrays(scores, score, Math.max(square, 1)); + } + if (totalByteSize != 0) { + divMut(scores, totalByteSize); } - double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector); - double[] scores = softmaxLayer.productPlusBias(true, h0); - double[] probabilities = softMax(scores); - ClassificationConfig classificationConfig = (ClassificationConfig) config; Tuple> topClasses = InferenceHelpers.topClasses( probabilities, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java index 7283c19891e51..51dd1addf744e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java @@ -20,30 +20,95 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.hamcrest.Matcher; +import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.greaterThan; import static org.mockito.Mockito.mock; public class LangIdentNeuralNetworkInferenceTests extends ESTestCase { - public void testLangInference() throws Exception { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); - PlainActionFuture future = new PlainActionFuture<>(); - // Should be OK as we don't make any client calls - trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future); - TrainedModelConfig config = future.actionGet(); + public void testAdverseScenarios() throws Exception { + InferenceDefinition inferenceDefinition = grabModel(); + ClassificationConfig classificationConfig = new ClassificationConfig(5); - config.ensureParsedDefinition(xContentRegistry()); - TrainedModelDefinition trainedModelDefinition = config.getModelDefinition(); - InferenceDefinition inferenceDefinition = new 
InferenceDefinition(
-            (LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
-            trainedModelDefinition.getPreProcessors()
+        ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(""),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(" "),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("!@#$%^&*()"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("1234567890"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("-----=-=--=-=+__+_+__==-=-!@#$%^&*()"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("A"), classificationConfig);
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("lb"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("„ÍÎÏ◊˝Ïδ„€‹›fifl‡°·‚∏ØÒÚÒ˘ÚÆ’ÆÚ”∏Ø\uF8FFÔÓ˝Ïδ„‹›fiˇflÁ¨ˆØ"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("vi"));
+    }
+
+    public void testMixedLangInference() throws Exception {
+        InferenceDefinition inferenceDefinition = grabModel();
+        ClassificationConfig classificationConfig = new ClassificationConfig(5);
+
+        ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("행 레이블 this is english text obviously and 생성 tom said to test it "),
+            classificationConfig
         );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("en"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("행 레이블 Dashboard ISSUE Qual. Plan Qual. Report Qual. 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(
+                "@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
+ + " 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈" + ), + classificationConfig + ); + assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + } + + public void testLangInference() throws Exception { + + InferenceDefinition inferenceDefinition = grabModel(); List examples = new LanguageExamples().getLanguageExamples(); ClassificationConfig classificationConfig = new ClassificationConfig(1); @@ -52,23 +117,42 @@ public void testLangInference() throws Exception { String cld3Actual = entry.getPredictedLanguage(); double cld3Probability = entry.getProbability(); - Map inferenceFields = new HashMap<>(); - inferenceFields.put("text", text); ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer( - inferenceFields, + inferenceObj(text), classificationConfig ); assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual)); - double eps = entry.getLanguage().equals("hr") ? 0.001 : 0.00001; + Matcher matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001); assertThat( "mismatch probability for language " + cld3Actual, singleValueInferenceResults.getTopClasses().get(0).getProbability(), - closeTo(cld3Probability, eps) + matcher ); } } + InferenceDefinition grabModel() throws IOException { + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + PlainActionFuture future = new PlainActionFuture<>(); + // Should be OK as we don't make any client calls + trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future); + TrainedModelConfig config = future.actionGet(); + + config.ensureParsedDefinition(xContentRegistry()); + TrainedModelDefinition trainedModelDefinition = config.getModelDefinition(); + return new InferenceDefinition( + (LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(), + trainedModelDefinition.getPreProcessors() + ); + } + + private static Map inferenceObj(String text) { + Map inferenceFields = new HashMap<>(); + inferenceFields.put("text", text); + return inferenceFields; + } + @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); From b40a9cacb91cbc4131edec5c916d2bc1ed03ae43 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 11 Nov 2021 15:17:22 -0500 Subject: [PATCH 2/4] addressing PR comments --- .../preprocessing/CustomWordEmbedding.java | 30 +++++++++---------- .../langident/LangIdentNeuralNetwork.java | 16 +++++----- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java index c251253809783..95bb521964be4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.core.ml.utils.MlParserUtils; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -45,17 +44,17 @@ */ public class 
CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

-    public static class ByteSizeAndEmbedding {
-        final int utf8ByteSize;
+    public static class StringLengthAndEmbedding {
+        final int stringLen;
         final double[] embedding;

-        public ByteSizeAndEmbedding(int utf8ByteSize, double[] embedding) {
-            this.utf8ByteSize = utf8ByteSize;
+        public StringLengthAndEmbedding(int stringLen, double[] embedding) {
+            this.stringLen = stringLen;
             this.embedding = embedding;
         }

-        public int getUtf8ByteSize() {
-            return utf8ByteSize;
+        public int getStringLen() {
+            return stringLen;
         }

         public double[] getEmbedding() {
@@ -233,14 +232,13 @@ public void process(Map fields) {
         String text = (String) field;
         text = FeatureUtils.cleanAndLowerText(text);
         text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
-        String finalText = text;
-        if (text.isEmpty() || text.isBlank()) {
+        final String finalText = text;
+        if (finalText.isEmpty() || finalText.isBlank()) {
             fields.put(
                 destField,
-                Arrays.asList(
-                    new ByteSizeAndEmbedding(
-                        // Don't count white spaces as bytes for the prediction
-                        finalText.trim().getBytes(StandardCharsets.UTF_8).length,
+                Collections.singletonList(
+                    new StringLengthAndEmbedding(
+                        0,
                         concatEmbeddings(
                             FEATURE_EXTRACTORS.stream()
                                 .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
@@ -251,7 +249,7 @@ public void process(Map fields) {
             );
             return;
         }
-        List<ByteSizeAndEmbedding> embeddings = new ArrayList<>();
+        List<StringLengthAndEmbedding> embeddings = new ArrayList<>();
         int[] codePoints = finalText.codePoints().toArray();
         for (int i = 0; i < codePoints.length - 1;) {
             while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
@@ -290,9 +288,9 @@ public void process(Map fields) {
                 builder.append(" ");
             }
             embeddings.add(
-                new ByteSizeAndEmbedding(
+                new StringLengthAndEmbedding(
                     // Don't count white spaces as bytes for the prediction
-                    str.trim().getBytes(StandardCharsets.UTF_8).length,
+                    str.trim().length(),
                     concatEmbeddings(
                         FEATURE_EXTRACTORS.stream()
                             .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
index b08fc27e31101..047c24dce1a95 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
@@ -230,20 +230,20 @@ public InferenceResults infer(Map fields, InferenceConfig config
         }
         List<?> embeddedVector = (List<?>) vector;
         double[] scores = new double[LANGUAGE_NAMES.size()];
-        int totalByteSize = 0;
+        int totalLen = 0;
         for (Object vec : embeddedVector) {
-            if (vec instanceof CustomWordEmbedding.ByteSizeAndEmbedding == false) {
+            if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
                 continue;
             }
-            CustomWordEmbedding.ByteSizeAndEmbedding byteSizeAndEmbedding = (CustomWordEmbedding.ByteSizeAndEmbedding) vec;
-            int square = (int) Math.pow(byteSizeAndEmbedding.getUtf8ByteSize(), 2);
-            totalByteSize += square;
-            double[] h0 = hiddenLayer.productPlusBias(false, byteSizeAndEmbedding.getEmbedding());
+            CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
+            int square =
stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
+            totalLen += square;
+            double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
             double[] score = softmaxLayer.productPlusBias(true, h0);
             sumDoubleArrays(scores, score, Math.max(square, 1));
         }
-        if (totalByteSize != 0) {
-            divMut(scores, totalByteSize);
+        if (totalLen != 0) {
+            divMut(scores, totalLen);
         }
         double[] probabilities = softMax(scores);
         ClassificationConfig classificationConfig = (ClassificationConfig) config;

From ff529c4a07b9f9d4e671d3eaa59c3e4f0ea5532d Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Fri, 12 Nov 2021 08:02:53 -0500
Subject: [PATCH 3/4] adding more tests
---
 .../langident/LangIdentNeuralNetworkInferenceTests.java | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
index 51dd1addf744e..6459de0876690 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
@@ -75,6 +75,13 @@ public void testAdverseScenarios() throws Exception {
             classificationConfig
         );
         assertThat(singleValueInferenceResults.valueAsString(), equalTo("vi"));
+
+        // Should not throw
+        inferenceDefinition.infer(inferenceObj("행 A A"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행 A성 xx"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행 A성 성x"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행A A성 x성"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행A 성 x"), classificationConfig);
     }

     public void testMixedLangInference() throws Exception {
@@ -104,6 +111,8 @@ public void testMixedLangInference() throws Exception {
             classificationConfig
         );
         assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
+
+
     }

     public void testLangInference() throws Exception {

From f945f01ba7bb42a85c2469281190ac095f866e34 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Fri, 12 Nov 2021 08:04:23 -0500
Subject: [PATCH 4/4] fixing format
---
 .../langident/LangIdentNeuralNetworkInferenceTests.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
index 6459de0876690..a572d47eafb6a 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
@@ -112,7 +112,6 @@ public void testMixedLangInference() throws Exception {
         );
         assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

-
     }

     public void testLangInference() throws Exception {
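
The weighting scheme described in the first commit message can be illustrated outside the Elasticsearch codebase. The sketch below is a simplified, self-contained Java illustration, not the production code: it splits text into contiguous runs that share a Unicode script, scores each run with a placeholder scoreSection method (a stand-in for the real feature embedding plus the hidden and softmax layers), weights each run's scores by the square of its trimmed length (the first revision of the patch squares the UTF-8 byte count, the second squares the character count), averages the weighted scores, and applies a softmax to turn them into probabilities. The class name, the two-language label set, the toy scoring heuristic, and the simplified run partitioning are all invented for this illustration.

import java.lang.Character.UnicodeScript;
import java.util.ArrayList;
import java.util.List;

public class WeightedLangIdSketch {

    // A contiguous run of text whose letters share one Unicode script.
    record ScriptRun(UnicodeScript script, String text) {}

    // Split text into contiguous runs of a single Unicode script. Non-letter
    // code points (spaces, digits, punctuation) stay attached to the current run.
    static List<ScriptRun> splitByScript(String text) {
        List<ScriptRun> runs = new ArrayList<>();
        int[] cps = text.codePoints().toArray();
        int start = 0;
        UnicodeScript current = null;
        for (int i = 0; i < cps.length; i++) {
            if (Character.isLetter(cps[i]) == false) {
                continue;
            }
            UnicodeScript script = UnicodeScript.of(cps[i]);
            if (current == null) {
                current = script;
            } else if (script != current && script != UnicodeScript.COMMON && script != UnicodeScript.INHERITED) {
                runs.add(new ScriptRun(current, new String(cps, start, i - start)));
                start = i;
                current = script;
            }
        }
        if (current != null && start < cps.length) {
            runs.add(new ScriptRun(current, new String(cps, start, cps.length - start)));
        }
        return runs;
    }

    // Placeholder for the per-section model scores (one entry per language).
    // Toy heuristic: Hangul runs vote for "ko", everything else votes for "en".
    static double[] scoreSection(ScriptRun run) {
        return run.script() == UnicodeScript.HANGUL ? new double[] { 0.0, 5.0 } : new double[] { 5.0, 0.0 };
    }

    // Weight each section's scores by its squared trimmed length, average, then softmax.
    static double[] weightedProbabilities(String text) {
        double[] scores = new double[2]; // index 0 = "en", index 1 = "ko" in this toy setup
        long totalWeight = 0;
        for (ScriptRun run : splitByScript(text)) {
            long len = run.text().trim().length();
            long weight = Math.max(len * len, 1); // squared section size, never zero
            totalWeight += weight;
            double[] sectionScores = scoreSection(run);
            for (int i = 0; i < scores.length; i++) {
                scores[i] += sectionScores[i] * weight;
            }
        }
        if (totalWeight > 0) {
            for (int i = 0; i < scores.length; i++) {
                scores[i] /= totalWeight; // same role as InferenceHelpers.divMut in the patch
            }
        }
        return softmax(scores);
    }

    // Standard softmax over the averaged scores.
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sum = 0;
        double[] probs = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    public static void main(String[] args) {
        double[] probs = weightedProbabilities("행 레이블 this is english text obviously and 생성 tom said to test it");
        System.out.printf("en=%.3f ko=%.3f%n", probs[0], probs[1]);
    }
}

Squaring the section size makes long runs dominate short ones, which is why a short burst of Hangul no longer outweighs a longer English sentence, while Math.max(weight, 1) keeps empty or whitespace-only sections from zeroing out the update.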