diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java
index 989f53f19ef2e..95bb521964be4 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java
@@ -20,6 +20,7 @@
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
 import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

@@ -43,6 +44,24 @@
  */
 public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

+    public static class StringLengthAndEmbedding {
+        final int stringLen;
+        final double[] embedding;
+
+        public StringLengthAndEmbedding(int stringLen, double[] embedding) {
+            this.stringLen = stringLen;
+            this.embedding = embedding;
+        }
+
+        public int getStringLen() {
+            return stringLen;
+        }
+
+        public double[] getEmbedding() {
+            return embedding;
+        }
+    }
+
     private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
     public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
     public static final ParseField NAME = new ParseField("custom_word_embedding");
@@ -213,11 +232,75 @@ public void process(Map<String, Object> fields) {
         String text = (String) field;
         text = FeatureUtils.cleanAndLowerText(text);
         text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
-        String finalText = text;
-        List<FeatureValue[]> processedFeatures = FEATURE_EXTRACTORS.stream()
-            .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
-            .collect(Collectors.toList());
-        fields.put(destField, concatEmbeddings(processedFeatures));
+        final String finalText = text;
+        if (finalText.isEmpty() || finalText.isBlank()) {
+            fields.put(
+                destField,
+                Collections.singletonList(
+                    new StringLengthAndEmbedding(
+                        0,
+                        concatEmbeddings(
+                            FEATURE_EXTRACTORS.stream()
+                                .map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
+                                .collect(Collectors.toList())
+                        )
+                    )
+                )
+            );
+            return;
+        }
+        List<StringLengthAndEmbedding> embeddings = new ArrayList<>();
+        int[] codePoints = finalText.codePoints().toArray();
+        for (int i = 0; i < codePoints.length - 1;) {
+            while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
+                i++;
+            }
+            if (i >= codePoints.length) {
+                break;
+            }
+            ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
+            int j = i + 1;
+            for (; j < codePoints.length; j++) {
+                while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
+                    j++;
+                }
+                if (j >= codePoints.length) {
+                    break;
+                }
+                ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
+                if (j1 != currentCode && j1 != ScriptCode.Inherited) {
+                    if (j < codePoints.length - 1) {
+                        ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
+                        if (j2 != ScriptCode.Common && j2 != currentCode) {
+                            break;
+                        }
+                    }
+                }
+            }
+            // Knowing the start and the end of the section is important for feature building, so make sure it's wrapped in spaces
+            String str = new String(codePoints, i, j - i);
+            StringBuilder builder = new StringBuilder();
+            if (str.startsWith(" ") == false) {
+                builder.append(" ");
+            }
+            builder.append(str);
+            if (str.endsWith(" ") == false) {
+                builder.append(" ");
+            }
+            embeddings.add(
+                new StringLengthAndEmbedding(
+                    // Don't count white spaces as bytes for the prediction
+                    str.trim().length(),
+                    concatEmbeddings(
+                        FEATURE_EXTRACTORS.stream()
+                            .map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
+                            .collect(Collectors.toList())
+                    )
+                )
+            );
+            i = j;
+        }
+        fields.put(destField, embeddings);
     }

     @Override
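The sectioning loop added to CustomWordEmbedding.process() above walks the code points, skips non-letters, and cuts a new section whenever the Unicode script of the following letters changes. As a rough standalone sketch of that idea (illustrative only, not the Elasticsearch implementation: it uses Character.UnicodeScript directly instead of the ScriptCode mapping and ignores the Inherited/Common special cases handled above, and the class name is made up):

    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical helper: split text into runs of letters that share a Unicode script.
    final class ScriptRunSplitter {
        static List<String> split(String text) {
            List<String> runs = new ArrayList<>();
            int[] cps = text.codePoints().toArray();
            int i = 0;
            while (i < cps.length) {
                // skip separators between runs
                while (i < cps.length && Character.isLetter(cps[i]) == false) {
                    i++;
                }
                if (i >= cps.length) {
                    break;
                }
                Character.UnicodeScript current = Character.UnicodeScript.of(cps[i]);
                int j = i + 1;
                // extend the run until a letter from a different script is seen
                while (j < cps.length && (Character.isLetter(cps[j]) == false || Character.UnicodeScript.of(cps[j]) == current)) {
                    j++;
                }
                runs.add(new String(cps, i, j - i));
                i = j;
            }
            return runs;
        }
    }

    // ScriptRunSplitter.split("행 레이블 this is english text") -> ["행 레이블 ", "this is english text"]

Each section is then padded with spaces and run through the existing n-gram/script feature extractors, producing one StringLengthAndEmbedding per section instead of a single embedding for the whole string.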
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
index 0011abb45fc70..f214070e20949 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
@@ -188,13 +188,29 @@ public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
     }

     public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
+        return sumDoubleArrays(sumTo, inc, 1);
+    }
+
+    public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
         assert sumTo != null && inc != null && sumTo.length == inc.length;
         for (int i = 0; i < inc.length; i++) {
-            sumTo[i] += inc[i];
+            sumTo[i] += (inc[i] * weight);
         }
         return sumTo;
     }

+    public static void divMut(double[] xs, int v) {
+        if (xs.length == 0) {
+            return;
+        }
+        if (v == 0) {
+            throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
+        }
+        for (int i = 0; i < xs.length; i++) {
+            xs[i] /= v;
+        }
+    }
+
     public static class TopClassificationValue {
         private final int value;
         private final double probability;
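To make the intended use of the two helpers concrete, a small hedged sketch (not part of the change itself): sumDoubleArrays with a weight accumulates weighted copies of score vectors, and divMut then divides by the total weight, turning the accumulated sum into a weighted mean.

    // Hypothetical usage of the helpers above: average two score vectors with weights 2 and 1.
    double[] scores = new double[] { 0.0, 0.0 };
    InferenceHelpers.sumDoubleArrays(scores, new double[] { 1.0, 3.0 }, 2); // scores = {2.0, 6.0}
    InferenceHelpers.sumDoubleArrays(scores, new double[] { 4.0, 0.0 }, 1); // scores = {6.0, 6.0}
    InferenceHelpers.divMut(scores, 3);                                     // scores = {2.0, 2.0}

divMut throws on a zero divisor rather than silently producing infinities, and is a no-op for empty arrays.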
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
index 9890eda059172..047c24dce1a95 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
@@ -15,6 +15,7 @@
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -36,6 +37,8 @@
 import java.util.Objects;

 import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.divMut;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

 public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {
@@ -217,27 +220,32 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
             throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
         }
         Object vector = fields.get(embeddedVectorFeatureName);
-        if (vector instanceof double[] == false) {
+        if (vector instanceof List == false) {
             throw ExceptionsHelper.badRequestException(
-                "[{}] model could not find non-null numerical array named [{}]",
+                "[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. "
+                    + "Please verify that the input is a string.",
                 NAME.getPreferredName(),
                 embeddedVectorFeatureName
             );
         }
-        double[] embeddedVector = (double[]) vector;
-        if (embeddedVector.length != EMBEDDING_VECTOR_LENGTH) {
-            throw ExceptionsHelper.badRequestException(
-                "[{}] model is expecting embedding vector of length [{}] but got [{}]",
-                NAME.getPreferredName(),
-                EMBEDDING_VECTOR_LENGTH,
-                embeddedVector.length
-            );
+        List<?> embeddedVector = (List<?>) vector;
+        double[] scores = new double[LANGUAGE_NAMES.size()];
+        int totalLen = 0;
+        for (Object vec : embeddedVector) {
+            if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
+                continue;
+            }
+            CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
+            int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
+            totalLen += square;
+            double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
+            double[] score = softmaxLayer.productPlusBias(true, h0);
+            sumDoubleArrays(scores, score, Math.max(square, 1));
+        }
+        if (totalLen != 0) {
+            divMut(scores, totalLen);
         }
-        double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
-        double[] scores = softmaxLayer.productPlusBias(true, h0);
-
         double[] probabilities = softMax(scores);
-
         ClassificationConfig classificationConfig = (ClassificationConfig) config;
         Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
             probabilities,
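The rewritten infer() above scores each script section separately and combines the per-section softmax-layer outputs into a weighted average, where the weight is the square of the section's trimmed character count (floored at 1 in the sumDoubleArrays call), so longer sections dominate the prediction. A hedged worked example of the arithmetic, with made-up section lengths:

    // Hypothetical example: a 12-character Korean section and a 30-character English section.
    int koLen = 12, enLen = 30;
    int koWeight = koLen * koLen;        // 144
    int enWeight = enLen * enLen;        // 900
    int totalLen = koWeight + enWeight;  // 1044
    // sumDoubleArrays(scores, koScores, 144) and sumDoubleArrays(scores, enScores, 900)
    // accumulate the weighted score vectors; divMut(scores, 1044) then normalizes them, so the
    // English section contributes 900 / 1044 ≈ 86% of the averaged scores passed to softMax().

Note that the probabilities come from applying softMax to the averaged raw scores, not from averaging per-section probabilities.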
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
index 7283c19891e51..a572d47eafb6a 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
@@ -20,30 +20,103 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.hamcrest.Matcher;

+import java.io.IOException;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;

 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.mockito.Mockito.mock;

 public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {

-    public void testLangInference() throws Exception {
-        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
-        PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
-        // Should be OK as we don't make any client calls
-        trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
-        TrainedModelConfig config = future.actionGet();
+    public void testAdverseScenarios() throws Exception {
+        InferenceDefinition inferenceDefinition = grabModel();
+        ClassificationConfig classificationConfig = new ClassificationConfig(5);

-        config.ensureParsedDefinition(xContentRegistry());
-        TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
-        InferenceDefinition inferenceDefinition = new InferenceDefinition(
-            (LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
-            trainedModelDefinition.getPreProcessors()
+        ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(""),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj(" "),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("!@#$%^&*()"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("1234567890"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("-----=-=--=-=+__+_+__==-=-!@#$%^&*()"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("A"), classificationConfig);
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("lb"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("„ÍÎÏ◊˝Ïδ„€‹›fifl‡°·‚∏ØÒÚÒ˘ÚÆ’ÆÚ”∏Ø\uF8FFÔÓ˝Ïδ„‹›fiˇflÁ¨ˆØ"),
+            classificationConfig
+        );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("vi"));
+
+        // Should not throw
+        inferenceDefinition.infer(inferenceObj("행 A A"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행 A성 xx"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행 A성 성x"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행A A성 x성"), classificationConfig);
+        inferenceDefinition.infer(inferenceObj("행A 성 x"), classificationConfig);
+    }
+
+    public void testMixedLangInference() throws Exception {
+        InferenceDefinition inferenceDefinition = grabModel();
+        ClassificationConfig classificationConfig = new ClassificationConfig(5);
+
+        ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+            inferenceObj("행 레이블 this is english text obviously and 생성 tom said to test it "),
+            classificationConfig
         );
+        assertThat(singleValueInferenceResults.valueAsString(), equalTo("en"));
+
+        singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
+ inferenceObj("행 레이블 Dashboard ISSUE Qual. Plan Qual. Report Qual. 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"), + classificationConfig + ); + assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + + singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig); + assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + + singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer( + inferenceObj( + "@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual." + + " 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈" + ), + classificationConfig + ); + assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko")); + + } + + public void testLangInference() throws Exception { + + InferenceDefinition inferenceDefinition = grabModel(); List examples = new LanguageExamples().getLanguageExamples(); ClassificationConfig classificationConfig = new ClassificationConfig(1); @@ -52,23 +125,42 @@ public void testLangInference() throws Exception { String cld3Actual = entry.getPredictedLanguage(); double cld3Probability = entry.getProbability(); - Map inferenceFields = new HashMap<>(); - inferenceFields.put("text", text); ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer( - inferenceFields, + inferenceObj(text), classificationConfig ); assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual)); - double eps = entry.getLanguage().equals("hr") ? 0.001 : 0.00001; + Matcher matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001); assertThat( "mismatch probability for language " + cld3Actual, singleValueInferenceResults.getTopClasses().get(0).getProbability(), - closeTo(cld3Probability, eps) + matcher ); } } + InferenceDefinition grabModel() throws IOException { + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + PlainActionFuture future = new PlainActionFuture<>(); + // Should be OK as we don't make any client calls + trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future); + TrainedModelConfig config = future.actionGet(); + + config.ensureParsedDefinition(xContentRegistry()); + TrainedModelDefinition trainedModelDefinition = config.getModelDefinition(); + return new InferenceDefinition( + (LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(), + trainedModelDefinition.getPreProcessors() + ); + } + + private static Map inferenceObj(String text) { + Map inferenceFields = new HashMap<>(); + inferenceFields.put("text", text); + return inferenceFields; + } + @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());