Skip to content

Commit

Permalink
[7.16] [ML] Fix language identification bug when multi-languages are …
Browse files Browse the repository at this point in the history
…present (#80675) (#80707)

* [ML] Fix language identification bug when multi-languages are present (#80675)

Language identification works fairly well when only one language and
script type is present. But when multiple are present, it can return
some unexpected results Example: "행 레이블 this is english text obviously
and 생성 tom said to test it" Which appears to a human to be english text
(Latin unicode) with Korean via Hangul unicode is erroneously
categorized as Japanese. It should be categorized as English as it is
the dominate language and script type. This commit fixes this bug by
doing the following:  - Input text is partitioned into common,
continuous, unicode script    sections  - Those sections individual
language scores are gathered  - Each score is then weighted according to
the number of characters in    each section  - The resulting weight
scores are transformed into probabilities  - The final probabilities are
the ones returned to the user.

* fixing compilation
  • Loading branch information
benwtrent authored Nov 15, 2021
1 parent 9acb783 commit 2db72da
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;

import java.io.IOException;
Expand All @@ -44,6 +45,24 @@
*/
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

public static class StringLengthAndEmbedding {
final int stringLen;
final double[] embedding;

public StringLengthAndEmbedding(int stringLen, double[] embedding) {
this.stringLen = stringLen;
this.embedding = embedding;
}

public int getStringLen() {
return stringLen;
}

public double[] getEmbedding() {
return embedding;
}
}

private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
public static final ParseField NAME = new ParseField("custom_word_embedding");
Expand Down Expand Up @@ -237,11 +256,75 @@ public void process(Map<String, Object> fields) {
String text = (String) field;
text = FeatureUtils.cleanAndLowerText(text);
text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
String finalText = text;
List<FeatureValue[]> processedFeatures = FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
.collect(Collectors.toList());
fields.put(destField, concatEmbeddings(processedFeatures));
final String finalText = text;
if (finalText.isEmpty() || finalText.codePoints().allMatch(Character::isWhitespace)) {
fields.put(
destField,
Collections.singletonList(
new StringLengthAndEmbedding(
0,
concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
.collect(Collectors.toList())
)
)
)
);
return;
}
List<StringLengthAndEmbedding> embeddings = new ArrayList<>();
int[] codePoints = finalText.codePoints().toArray();
for (int i = 0; i < codePoints.length - 1;) {
while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
i++;
}
if (i >= codePoints.length) {
break;
}
ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
int j = i + 1;
for (; j < codePoints.length; j++) {
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
j++;
}
if (j >= codePoints.length) {
break;
}
ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
if (j1 != currentCode && j1 != ScriptCode.Inherited) {
if (j < codePoints.length - 1) {
ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
if (j2 != ScriptCode.Common && j2 != currentCode) {
break;
}
}
}
}
// Knowing the start and the end of the section is important for feature building, so make sure its wrapped in spaces
String str = new String(codePoints, i, j - i);
StringBuilder builder = new StringBuilder();
if (str.startsWith(" ") == false) {
builder.append(" ");
}
builder.append(str);
if (str.endsWith(" ") == false) {
builder.append(" ");
}
embeddings.add(
new StringLengthAndEmbedding(
// Don't count white spaces as bytes for the prediction
str.trim().length(),
concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
.collect(Collectors.toList())
)
)
);
i = j;
}
fields.put(destField, embeddings);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,29 @@ public static List<ClassificationFeatureImportance> transformFeatureImportanceCl
}

public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
return sumDoubleArrays(sumTo, inc, 1);
}

public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
assert sumTo != null && inc != null && sumTo.length == inc.length;
for (int i = 0; i < inc.length; i++) {
sumTo[i] += inc[i];
sumTo[i] += (inc[i] * weight);
}
return sumTo;
}

public static void divMut(double[] xs, int v) {
if (xs.length == 0) {
return;
}
if (v == 0) {
throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
}
for (int i = 0; i < xs.length; i++) {
xs[i] /= v;
}
}

public static class TopClassificationValue {
private final int value;
private final double probability;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
Expand All @@ -37,6 +38,8 @@
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.divMut;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {
Expand Down Expand Up @@ -218,27 +221,32 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
}
Object vector = fields.get(embeddedVectorFeatureName);
if (vector instanceof double[] == false) {
if (vector instanceof List<?> == false) {
throw ExceptionsHelper.badRequestException(
"[{}] model could not find non-null numerical array named [{}]",
"[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. "
+ "Please verify that the input is a string.",
NAME.getPreferredName(),
embeddedVectorFeatureName
);
}
double[] embeddedVector = (double[]) vector;
if (embeddedVector.length != EMBEDDING_VECTOR_LENGTH) {
throw ExceptionsHelper.badRequestException(
"[{}] model is expecting embedding vector of length [{}] but got [{}]",
NAME.getPreferredName(),
EMBEDDING_VECTOR_LENGTH,
embeddedVector.length
);
List<?> embeddedVector = (List<?>) vector;
double[] scores = new double[LANGUAGE_NAMES.size()];
int totalLen = 0;
for (Object vec : embeddedVector) {
if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
continue;
}
CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
totalLen += square;
double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
double[] score = softmaxLayer.productPlusBias(true, h0);
sumDoubleArrays(scores, score, Math.max(square, 1));
}
if (totalLen != 0) {
divMut(scores, totalLen);
}
double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
double[] scores = softmaxLayer.productPlusBias(true, h0);

double[] probabilities = softMax(scores);

ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,103 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.hamcrest.Matcher;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.mockito.Mockito.mock;

public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {

public void testLangInference() throws Exception {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
TrainedModelConfig config = future.actionGet();
public void testAdverseScenarios() throws Exception {
InferenceDefinition inferenceDefinition = grabModel();
ClassificationConfig classificationConfig = new ClassificationConfig(5);

config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
InferenceDefinition inferenceDefinition = new InferenceDefinition(
(LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
trainedModelDefinition.getPreProcessors()
ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(""),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(" "),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("!@#$%^&*()"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("1234567890"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("-----=-=--=-=+__+_+__==-=-!@#$%^&*()"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("A"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("lb"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("„ÍÎÏ◊˝Ïδ„€‹›fifl‡°·‚∏ØÒÚÒ˘ÚÆ’ÆÚ”∏Ø\uF8FFÔÓ˝Ïδ„‹›fiˇflÁ¨ˆØ"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("vi"));

// Should not throw
inferenceDefinition.infer(inferenceObj("행 A A"), classificationConfig);
inferenceDefinition.infer(inferenceObj("행 A성 xx"), classificationConfig);
inferenceDefinition.infer(inferenceObj("행 A성 성x"), classificationConfig);
inferenceDefinition.infer(inferenceObj("행A A성 x성"), classificationConfig);
inferenceDefinition.infer(inferenceObj("행A 성 x"), classificationConfig);
}

public void testMixedLangInference() throws Exception {
InferenceDefinition inferenceDefinition = grabModel();
ClassificationConfig classificationConfig = new ClassificationConfig(5);

ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("행 레이블 this is english text obviously and 생성 tom said to test it "),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("en"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("행 레이블 Dashboard ISSUE Qual. Plan Qual. Report Qual. 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
+ " 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"
),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

}

public void testLangInference() throws Exception {

InferenceDefinition inferenceDefinition = grabModel();
List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
ClassificationConfig classificationConfig = new ClassificationConfig(1);

Expand All @@ -52,23 +125,42 @@ public void testLangInference() throws Exception {
String cld3Actual = entry.getPredictedLanguage();
double cld3Probability = entry.getProbability();

Map<String, Object> inferenceFields = new HashMap<>();
inferenceFields.put("text", text);
ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceFields,
inferenceObj(text),
classificationConfig
);

assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
double eps = entry.getLanguage().equals("hr") ? 0.001 : 0.00001;
Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
assertThat(
"mismatch probability for language " + cld3Actual,
singleValueInferenceResults.getTopClasses().get(0).getProbability(),
closeTo(cld3Probability, eps)
matcher
);
}
}

InferenceDefinition grabModel() throws IOException {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
TrainedModelConfig config = future.actionGet();

config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
return new InferenceDefinition(
(LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
trainedModelDefinition.getPreProcessors()
);
}

private static Map<String, Object> inferenceObj(String text) {
Map<String, Object> inferenceFields = new HashMap<>();
inferenceFields.put("text", text);
return inferenceFields;
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
Expand Down

0 comments on commit 2db72da

Please sign in to comment.