[ML] Fix language identification bug when multi-languages are present #80675

Merged
CustomWordEmbedding.java
@@ -20,10 +20,12 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -43,6 +45,24 @@
*/
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

public static class ByteSizeAndEmbedding {
final int utf8ByteSize;
Contributor: I find it very strange that the weighting is the number of UTF-8 bytes, not the number of characters.

That means that if I have some text that's 100 characters of the Roman alphabet and 100 Chinese characters, then the Chinese could get a weighting of 300 while the Western language gets a weighting of 100. Is the byte count a sneaky heuristic for saying each Chinese character embeds more information than a Roman-alphabet character? It would be good to add a comment with the justification.
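Editorial aside (not from the PR thread): the 300-vs-100 figure follows from UTF-8 itself, since most Latin letters encode as one byte while most CJK characters encode as three. A quick check in plain Java:

    import java.nio.charset.StandardCharsets;

    public class Utf8ByteSizes {
        public static void main(String[] args) {
            // Most Latin letters are 1 UTF-8 byte; most CJK characters are 3.
            System.out.println("a".getBytes(StandardCharsets.UTF_8).length);  // 1
            System.out.println("中".getBytes(StandardCharsets.UTF_8).length); // 3
            // So 100 Latin characters weigh 100 bytes, 100 Chinese characters weigh 300.
        }
    }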

final double[] embedding;

public ByteSizeAndEmbedding(int utf8ByteSize, double[] embedding) {
this.utf8ByteSize = utf8ByteSize;
this.embedding = embedding;
}

public int getUtf8ByteSize() {
return utf8ByteSize;
}

public double[] getEmbedding() {
return embedding;
}
}

private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
public static final ParseField NAME = new ParseField("custom_word_embedding");
@@ -214,10 +234,75 @@ public void process(Map<String, Object> fields) {
text = FeatureUtils.cleanAndLowerText(text);
text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
String finalText = text;
Contributor: Might be clearer if it's explicitly final.

Suggested change:
- String finalText = text;
+ final String finalText = text;

List<FeatureValue[]> processedFeatures = FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
.collect(Collectors.toList());
fields.put(destField, concatEmbeddings(processedFeatures));
if (text.isEmpty() || text.isBlank()) {
Contributor: It seems potentially confusing to mix text and finalText in the main algorithm. Since finalText needs to be used in lambdas, I'd just use it everywhere to avoid making the reader double-check if there's a difference.

Suggested change:
- if (text.isEmpty() || text.isBlank()) {
+ if (finalText.isEmpty() || finalText.isBlank()) {

fields.put(
destField,
Arrays.asList(
Contributor:

Suggested change:
- Arrays.asList(
+ Collections.singletonList(

(because Arrays.asList with 1 item causes an IntelliJ warning)

new ByteSizeAndEmbedding(
// Don't count white spaces as bytes for the prediction
finalText.trim().getBytes(StandardCharsets.UTF_8).length,
Contributor:

Suggested change:
- finalText.trim().getBytes(StandardCharsets.UTF_8).length,
+ 0,

If this is wrong, please add a comment saying how the trimmed length of a blank or empty string can be > 0.
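Editorial aside (not from the PR thread, and only one possible answer): the trimmed length of a blank string can in fact be > 0, because String.trim() only strips code points at or below U+0020 while String.isBlank() uses Character.isWhitespace, which also matches characters such as U+2028 LINE SEPARATOR. A minimal demonstration:

    import java.nio.charset.StandardCharsets;

    public class BlankButNotTrimmable {
        public static void main(String[] args) {
            String s = "\u2028"; // LINE SEPARATOR: whitespace, but above U+0020
            System.out.println(s.isBlank());                                   // true
            System.out.println(s.trim().getBytes(StandardCharsets.UTF_8).length); // 3
        }
    }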

concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(finalText))
.collect(Collectors.toList())
)
)
)
);
return;
}
List<ByteSizeAndEmbedding> embeddings = new ArrayList<>();
int[] codePoints = finalText.codePoints().toArray();
for (int i = 0; i < codePoints.length - 1;) {
while (i < codePoints.length - 1 && Character.isLetter(codePoints[i]) == false) {
i++;
}
if (i >= codePoints.length) {
break;
}
ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
int j = i + 1;
for (; j < codePoints.length; j++) {
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
j++;
}
if (j >= codePoints.length) {
break;
}
ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
if (j1 != currentCode && j1 != ScriptCode.Inherited) {
if (j < codePoints.length - 1) {
ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
if (j2 != ScriptCode.Common && j2 != currentCode) {
break;
}
}
}
}
// Knowing the start and the end of the section is important for feature building, so make sure it's wrapped in spaces
String str = new String(codePoints, i, j - i);
StringBuilder builder = new StringBuilder();
if (str.startsWith(" ") == false) {
builder.append(" ");
}
builder.append(str);
if (str.endsWith(" ") == false) {
builder.append(" ");
}
embeddings.add(
new ByteSizeAndEmbedding(
// Don't count white spaces as bytes for the prediction
str.trim().getBytes(StandardCharsets.UTF_8).length,
concatEmbeddings(
FEATURE_EXTRACTORS.stream()
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))
.collect(Collectors.toList())
)
)
);
i = j;
}
fields.put(destField, embeddings);
}
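
Editorial sketch (not code from the PR): the new process() logic above splits the input into runs of letters that share a Unicode script and embeds each run separately, which is what makes per-language weighting possible. A simplified, standalone illustration of the segmentation idea, using the JDK's Character.UnicodeScript directly instead of the PR's ScriptCode mapping and ignoring the Inherited/Common special cases:

    import java.util.ArrayList;
    import java.util.List;

    public class ScriptRuns {
        // Split text into maximal runs whose letters share one Unicode script.
        public static List<String> splitByScript(String text) {
            List<String> runs = new ArrayList<>();
            int[] cps = text.codePoints().toArray();
            int i = 0;
            while (i < cps.length) {
                if (Character.isLetter(cps[i]) == false) {
                    i++; // skip leading non-letters, as the PR's loop does
                    continue;
                }
                Character.UnicodeScript script = Character.UnicodeScript.of(cps[i]);
                int j = i + 1;
                while (j < cps.length
                        && (Character.isLetter(cps[j]) == false
                            || Character.UnicodeScript.of(cps[j]) == script)) {
                    j++;
                }
                runs.add(new String(cps, i, j - i));
                i = j;
            }
            return runs;
        }

        public static void main(String[] args) {
            // Two runs for mixed-script input: one Hangul, one Latin.
            System.out.println(splitByScript("행 레이블 this is english"));
        }
    }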

@Override
InferenceHelpers.java
@@ -188,13 +188,29 @@ public static List<ClassificationFeatureImportance> transformFeatureImportanceCl
}

public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
return sumDoubleArrays(sumTo, inc, 1);
}

public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
assert sumTo != null && inc != null && sumTo.length == inc.length;
for (int i = 0; i < inc.length; i++) {
sumTo[i] += inc[i];
sumTo[i] += (inc[i] * weight);
}
return sumTo;
}

public static void divMut(double[] xs, int v) {
if (xs.length == 0) {
return;
}
if (v == 0) {
throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
}
for (int i = 0; i < xs.length; i++) {
xs[i] /= v;
}
}
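
Taken together, the two new helpers implement a weighted average: repeated sumDoubleArrays(acc, x, w) calls accumulate w * x into acc, and a final divMut(acc, totalWeight) normalizes. A small usage sketch (values invented for illustration, assuming static imports of the two helpers):

    double[] acc = new double[] { 0.0, 0.0 };
    sumDoubleArrays(acc, new double[] { 0.2, 0.8 }, 1); // acc = {0.2, 0.8}
    sumDoubleArrays(acc, new double[] { 0.6, 0.4 }, 3); // acc = {2.0, 2.0}
    divMut(acc, 1 + 3);                                 // acc = {0.5, 0.5}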

public static class TopClassificationValue {
private final int value;
private final double probability;
LangIdentNeuralNetwork.java
@@ -15,6 +15,7 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -36,6 +37,8 @@
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.divMut;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {
@@ -217,27 +220,32 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
}
Object vector = fields.get(embeddedVectorFeatureName);
if (vector instanceof double[] == false) {
if (vector instanceof List<?> == false) {
throw ExceptionsHelper.badRequestException(
"[{}] model could not find non-null numerical array named [{}]",
"[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. "
+ "Please verify that the input is a string.",
NAME.getPreferredName(),
embeddedVectorFeatureName
);
}
double[] embeddedVector = (double[]) vector;
if (embeddedVector.length != EMBEDDING_VECTOR_LENGTH) {
throw ExceptionsHelper.badRequestException(
"[{}] model is expecting embedding vector of length [{}] but got [{}]",
NAME.getPreferredName(),
EMBEDDING_VECTOR_LENGTH,
embeddedVector.length
);
List<?> embeddedVector = (List<?>) vector;
double[] scores = new double[LANGUAGE_NAMES.size()];
int totalByteSize = 0;
for (Object vec : embeddedVector) {
if (vec instanceof CustomWordEmbedding.ByteSizeAndEmbedding == false) {
continue;
}
CustomWordEmbedding.ByteSizeAndEmbedding byteSizeAndEmbedding = (CustomWordEmbedding.ByteSizeAndEmbedding) vec;
int square = (int) Math.pow(byteSizeAndEmbedding.getUtf8ByteSize(), 2);
Contributor: I strongly suspect multiplying two integers is much faster than using some generic x^y algorithm that works on arbitrary floating-point numbers.

Suggested change:
- int square = (int) Math.pow(byteSizeAndEmbedding.getUtf8ByteSize(), 2);
+ int square = byteSizeAndEmbedding.getUtf8ByteSize() * byteSizeAndEmbedding.getUtf8ByteSize();

totalByteSize += square;
double[] h0 = hiddenLayer.productPlusBias(false, byteSizeAndEmbedding.getEmbedding());
double[] score = softmaxLayer.productPlusBias(true, h0);
sumDoubleArrays(scores, score, Math.max(square, 1));
}
if (totalByteSize != 0) {
divMut(scores, totalByteSize);
}
double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
double[] scores = softmaxLayer.productPlusBias(true, h0);

double[] probabilities = softMax(scores);

ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
LangIdentNeuralNetworkInferenceTests.java
@@ -20,30 +20,95 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.hamcrest.Matcher;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.mockito.Mockito.mock;

public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {

public void testLangInference() throws Exception {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
TrainedModelConfig config = future.actionGet();
public void testAdverseScenarios() throws Exception {
InferenceDefinition inferenceDefinition = grabModel();
ClassificationConfig classificationConfig = new ClassificationConfig(5);

config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
InferenceDefinition inferenceDefinition = new InferenceDefinition(
(LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
trainedModelDefinition.getPreProcessors()
ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(""),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(" "),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("!@#$%^&*()"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("1234567890"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("-----=-=--=-=+__+_+__==-=-!@#$%^&*()"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ja"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("A"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("lb"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("„ÍÎÏ◊˝Ïδ„€‹›fifl‡°·‚∏ØÒÚÒ˘ÚÆ’ÆÚ”∏Ø\uF8FFÔÓ˝Ïδ„‹›fiˇflÁ¨ˆØ"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("vi"));
}

public void testMixedLangInference() throws Exception {
InferenceDefinition inferenceDefinition = grabModel();
ClassificationConfig classificationConfig = new ClassificationConfig(5);

ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("행 레이블 this is english text obviously and 생성 tom said to test it "),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("en"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj("행 레이블 Dashboard ISSUE Qual. Plan Qual. Report Qual. 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));

singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceObj(
"@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
+ " 현황 Risk Task생성 개발과제지정 개발모델 개발목표 개발비 개발팀별 현황 과제이슈"
),
classificationConfig
);
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
}

public void testLangInference() throws Exception {

InferenceDefinition inferenceDefinition = grabModel();
List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
ClassificationConfig classificationConfig = new ClassificationConfig(1);

@@ -52,23 +117,42 @@ public void testLangInference() throws Exception {
String cld3Actual = entry.getPredictedLanguage();
double cld3Probability = entry.getProbability();

Map<String, Object> inferenceFields = new HashMap<>();
inferenceFields.put("text", text);
ClassificationInferenceResults singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
inferenceFields,
inferenceObj(text),
classificationConfig
);

assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
double eps = entry.getLanguage().equals("hr") ? 0.001 : 0.00001;
Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
assertThat(
"mismatch probability for language " + cld3Actual,
singleValueInferenceResults.getTopClasses().get(0).getProbability(),
closeTo(cld3Probability, eps)
matcher
);
}
}

InferenceDefinition grabModel() throws IOException {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), future);
TrainedModelConfig config = future.actionGet();

config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
return new InferenceDefinition(
(LangIdentNeuralNetwork) trainedModelDefinition.getTrainedModel(),
trainedModelDefinition.getPreProcessors()
);
}

private static Map<String, Object> inferenceObj(String text) {
Map<String, Object> inferenceFields = new HashMap<>();
inferenceFields.put("text", text);
return inferenceFields;
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());