Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Fix language identification bug when multi-languages are present #80675

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

Expand All @@ -43,6 +44,24 @@
*/
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

public static class StringLengthAndEmbedding {
final int stringLen;
final double[] embedding;

public StringLengthAndEmbedding(int stringLen, double[] embedding) {
this.stringLen = stringLen;
this.embedding = embedding;
}

public int getStringLen() {
return stringLen;
}

public double[] getEmbedding() {
return embedding;
}
}

private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
public static final ParseField NAME = new ParseField("custom_word_embedding");
Expand Down Expand Up @@ -213,11 +232,75 @@ public void process(Map<String, Object> fields) {
String text = (String) field;
text = FeatureUtils.cleanAndLowerText(text);
text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
String finalText = text;
List<FeatureValue[]> processedFeatures = FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
.collect(Collectors.toList());
fields.put(destField, concatEmbeddings(processedFeatures));
final String finalText = text;
if (finalText.isEmpty() || finalText.isBlank()) {
fields.put(
destField,
Collections.singletonList(
new StringLengthAndEmbedding(
0,
concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
.collect(Collectors.toList())
)
)
)
);
return;
}
List<StringLengthAndEmbedding> embeddings = new ArrayList<>();
int[] codePoints = finalText.codePoints().toArray();
for (int i = 0; i < codePoints.length - 1;) {
while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
i++;
}
if (i >= codePoints.length) {
break;
}
ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
int j = i + 1;
for (; j < codePoints.length; j++) {
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
j++;
}
if (j >= codePoints.length) {
break;
}
ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
if (j1 != currentCode && j1 != ScriptCode.Inherited) {
if (j < codePoints.length - 1) {
ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
if (j2 != ScriptCode.Common && j2 != currentCode) {
break;
}
}
}
}
// Knowing the start and the end of the section is important for feature building, so make sure its wrapped in spaces
String str = new String(codePoints, i, j - i);
StringBuilder builder = new StringBuilder();
if (str.startsWith(" ") == false) {
builder.append(" ");
}
builder.append(str);
if (str.endsWith(" ") == false) {
builder.append(" ");
}
embeddings.add(
new StringLengthAndEmbedding(
// Don't count white spaces as bytes for the prediction
str.trim().length(),
concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
.collect(Collectors.toList())
)
)
);
i = j;
}
fields.put(destField, embeddings);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,29 @@ public static List<ClassificationFeatureImportance> transformFeatureImportanceCl
}

public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
return sumDoubleArrays(sumTo, inc, 1);
}

public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
assert sumTo != null && inc != null && sumTo.length == inc.length;
for (int i = 0; i < inc.length; i++) {
sumTo[i] += inc[i];
sumTo[i] += (inc[i] * weight);
}
return sumTo;
}

public static void divMut(double[] xs, int v) {
if (xs.length == 0) {
return;
}
if (v == 0) {
throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
}
for (int i = 0; i < xs.length; i++) {
xs[i] /= v;
}
}

public static class TopClassificationValue {
private final int value;
private final double probability;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
Expand All @@ -36,6 +37,8 @@
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.divMut;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {
Expand Down Expand Up @@ -217,27 +220,32 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
}
Object vector = fields.get(embeddedVectorFeatureName);
if (vector instanceof double[] == false) {
if (vector instanceof List<?> == false) {
throw ExceptionsHelper.badRequestException(
"[{}] model could not find non-null numerical array named [{}]",
"[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. "
+ "Please verify that the input is a string.",
NAME.getPreferredName(),
embeddedVectorFeatureName
);
}
double[] embeddedVector = (double[]) vector;
if (embeddedVector.length != EMBEDDING_VECTOR_LENGTH) {
throw ExceptionsHelper.badRequestException(
"[{}] model is expecting embedding vector of length [{}] but got [{}]",
NAME.getPreferredName(),
EMBEDDING_VECTOR_LENGTH,
embeddedVector.length
);
List<?> embeddedVector = (List<?>) vector;
double[] scores = new double[LANGUAGE_NAMES.size()];
int totalLen = 0;
for (Object vec : embeddedVector) {
if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
continue;
}
CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
totalLen += square;
double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
double[] score = softmaxLayer.productPlusBias(true, h0);
sumDoubleArrays(scores, score, Math.max(square, 1));
}
if (totalLen != 0) {
divMut(scores, totalLen);
}
double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
double[] scores = softmaxLayer.productPlusBias(true, h0);

double[] probabilities = softMax(scores);

ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,95 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.hamcrest.Matcher;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.mockito.Mockito.mock;

public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {

public void testLangInference() throws Exception {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
TrainedModelConfig config = future.actionGet();
public void testAdverseScenarios() throws Exception {
InferenceDefinition inferenceDefinition = grabModel();
ClassificationConfig classificationConfig = new ClassificationConfig(5);

config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
InferenceDefinition inferenceDefinition = new InferenceDefinition(
(LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
trainedModelDefinition.getPreProcessors()
ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(""),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(" "),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("!@#$%^&*()"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("1234567890"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("-----=-=--=-=+__+_+__==-=-!@#$%^&*()"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("A"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("lb"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("„ÍÎÏ◊˝Ïδ„€‹›fifl‡°·‚∏ØÒÚÒ˘ÚÆ’ÆÚ”∏Ø\uF8FFÔÓ˝Ïδ„‹›fiˇflÁ¨ˆØ"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("vi"));
}

public void testMixedLangInference() throws Exception {
InferenceDefinition inferenceDefinition = grabModel();
ClassificationConfig classificationConfig = new ClassificationConfig(5);

ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("행 레이블 this is english text obviously and 생성 tom said to test it "),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("en"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("행 레이블 Dashboard ISSUE Qual. Plan Qual. Report Qual. 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
+ " 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"
),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
}

public void testLangInference() throws Exception {

InferenceDefinition inferenceDefinition = grabModel();
List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
ClassificationConfig classificationConfig = new ClassificationConfig(1);

Expand All @@ -52,23 +117,42 @@ public void testLangInference() throws Exception {
String cld3Actual = entry.getPredictedLanguage();
double cld3Probability = entry.getProbability();

Map<String, Object> inferenceFields = new HashMap<>();
inferenceFields.put("text", text);
ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceFields,
inferenceObj(text),
classificationConfig
);

assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
double eps = entry.getLanguage().equals("hr") ? 0.001 : 0.00001;
Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
assertThat(
"mismatch probability for language " + cld3Actual,
singleValueInferenceResults.getTopClasses().get(0).getProbability(),
closeTo(cld3Probability, eps)
matcher
);
}
}

InferenceDefinition grabModel() throws IOException {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
TrainedModelConfig config = future.actionGet();

config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
return new InferenceDefinition(
(LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
trainedModelDefinition.getPreProcessors()
);
}

private static Map<String, Object> inferenceObj(String text) {
Map<String, Object> inferenceFields = new HashMap<>();
inferenceFields.put("text", text);
return inferenceFields;
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
Expand Down