From 08f69f09beeb562e8c98edd12dce6941ac267511 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 2 Aug 2022 12:40:33 -0400 Subject: [PATCH] [8.4] [ML] fix NLP inference_config bwc serialization tests (#89011) (#89042) * [ML] fix NLP inference_config bwc serialization tests (#89011) The tests were failing because of span not being nulled out for question_answering and text_similarity tasks. But, this change also attempts to make it more future proof so that if changes occur to the nlp task or tokenization configurations it will cause a failure more quickly and require handling the bwc testing. closes: #89008 (cherry picked from commit 480479d288f2a4b3f5d48e1757027dfcdf643c57) * fixing backport --- .../ZeroShotClassificationConfig.java | 4 +- .../ZeroShotClassificationConfigUpdate.java | 4 +- .../InferenceConfigItemTestCase.java | 37 +++++++++++++++++ .../ml/inference/TrainedModelConfigTests.java | 6 +-- .../trainedmodel/BertTokenizationTests.java | 15 ++++++- .../trainedmodel/FillMaskConfigTests.java | 11 ++++- .../InferenceConfigTestScaffolding.java | 14 +++++++ .../trainedmodel/MPNetTokenizationTests.java | 15 ++++++- .../trainedmodel/NerConfigTests.java | 11 ++++- .../trainedmodel/PassThroughConfigTests.java | 10 ++++- .../QuestionAnsweringConfigTests.java | 12 +++++- .../RobertaTokenizationTests.java | 15 ++++++- .../TextClassificationConfigTests.java | 41 ++++--------------- .../TextEmbeddingConfigTests.java | 10 ++++- .../ZeroShotClassificationConfigTests.java | 14 ++++++- ...roShotClassificationConfigUpdateTests.java | 6 +-- .../nlp/ZeroShotClassificationProcessor.java | 6 +-- 17 files changed, 177 insertions(+), 54 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java index 3ee09ffc1e837..710a2855167cf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java @@ -248,8 +248,8 @@ public String getHypothesisTemplate() { return hypothesisTemplate; } - public List getLabels() { - return Optional.ofNullable(labels).orElse(List.of()); + public Optional> getLabels() { + return Optional.ofNullable(labels); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java index 3cf9f8c8f8354..acfd726ca27a5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java @@ -147,13 +147,13 @@ public InferenceConfig apply(InferenceConfig originalConfig) { tokenizationUpdate == null ? zeroShotConfig.getTokenization() : tokenizationUpdate.apply(zeroShotConfig.getTokenization()), zeroShotConfig.getHypothesisTemplate(), Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()), - Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels()), + Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels().orElse(null)), Optional.ofNullable(resultsField).orElse(zeroShotConfig.getResultsField()) ); } boolean isNoop(ZeroShotClassificationConfig originalConfig) { - return (labels == null || labels.equals(originalConfig.getLabels())) + return (labels == null || labels.equals(originalConfig.getLabels().orElse(null))) && (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel())) && (resultsField == null || resultsField.equals(originalConfig.getResultsField())) && super.isNoop(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java index 79157bcb5ab27..97eb38790d071 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java @@ -15,6 +15,22 @@ import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigTests; import java.util.ArrayList; import java.util.Collections; @@ -25,6 +41,27 @@ public abstract class InferenceConfigItemTestCase extends AbstractBWCSerializationTestCase< T> { + + static InferenceConfig mutateForVersion(NlpConfig inferenceConfig, Version version) { + if (inferenceConfig instanceof TextClassificationConfig textClassificationConfig) { + return TextClassificationConfigTests.mutateForVersion(textClassificationConfig, version); + } else if (inferenceConfig instanceof FillMaskConfig fillMaskConfig) { + return FillMaskConfigTests.mutateForVersion(fillMaskConfig, version); + } else if (inferenceConfig instanceof QuestionAnsweringConfig questionAnsweringConfig) { + return QuestionAnsweringConfigTests.mutateForVersion(questionAnsweringConfig, version); + } else if (inferenceConfig instanceof NerConfig nerConfig) { + return NerConfigTests.mutateForVersion(nerConfig, version); + } else if (inferenceConfig instanceof PassThroughConfig passThroughConfig) { + return PassThroughConfigTests.mutateForVersion(passThroughConfig, version); + } else if (inferenceConfig instanceof TextEmbeddingConfig textEmbeddingConfig) { + return TextEmbeddingConfigTests.mutateForVersion(textEmbeddingConfig, version); + } else if (inferenceConfig instanceof ZeroShotClassificationConfig zeroShotClassificationConfig) { + return ZeroShotClassificationConfigTests.mutateForVersion(zeroShotClassificationConfig, version); + } else { + throw new IllegalArgumentException("unknown inference config [" + inferenceConfig.getName() + "]"); + } + } + @Override protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 8f4bc321fa1fb..9b014d338e6bd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -28,9 +28,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocationTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -391,8 +391,8 @@ protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instanc builder.setModelType(null); builder.setLocation(null); } - if (instance.getInferenceConfig()instanceof TextClassificationConfig textClassificationConfig) { - builder.setInferenceConfig(TextClassificationConfigTests.mutateInstance(textClassificationConfig, version)); + if (instance.getInferenceConfig()instanceof NlpConfig nlpConfig) { + builder.setInferenceConfig(InferenceConfigItemTestCase.mutateForVersion(nlpConfig, version)); } return builder.build(); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java index 9a84c254c5452..952e6b4372534 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java @@ -19,6 +19,19 @@ public class BertTokenizationTests extends AbstractBWCSerializationTestCase getRandomFieldsExcludeFilter() { return field -> field.isEmpty() == false; @@ -44,7 +53,7 @@ protected FillMaskConfig createTestInstance() { @Override protected FillMaskConfig mutateInstanceForVersion(FillMaskConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static FillMaskConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java index 43020fe23e114..228cdb40e3a89 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java @@ -7,8 +7,22 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; + public final class InferenceConfigTestScaffolding { + static Tokenization mutateTokenizationForVersion(Tokenization tokenization, Version version) { + if (tokenization instanceof BertTokenization bertTokenization) { + return BertTokenizationTests.mutateForVersion(bertTokenization, version); + } else if (tokenization instanceof MPNetTokenization mpNetTokenization) { + return MPNetTokenizationTests.mutateForVersion(mpNetTokenization, version); + } else if (tokenization instanceof RobertaTokenization robertaTokenization) { + return RobertaTokenizationTests.mutateForVersion(robertaTokenization, version); + } else { + throw new IllegalArgumentException("unknown tokenization [" + tokenization.getName() + "]"); + } + } + static Tokenization cloneWithNewTruncation(Tokenization tokenization, Tokenization.Truncate truncate) { if (tokenization instanceof MPNetTokenization) { return new MPNetTokenization( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java index 4c01935a7ef43..dead82c736445 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java @@ -19,6 +19,19 @@ public class MPNetTokenizationTests extends AbstractBWCSerializationTestCase { + public static NerConfig mutateForVersion(NerConfig instance, Version version) { + return new NerConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getClassificationLabels(), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -48,7 +57,7 @@ protected NerConfig createTestInstance() { @Override protected NerConfig mutateInstanceForVersion(NerConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static NerConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java index 3701a07b73d5b..28e107101d288 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java @@ -17,6 +17,14 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase { + public static PassThroughConfig mutateForVersion(PassThroughConfig instance, Version version) { + return new PassThroughConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -44,7 +52,7 @@ protected PassThroughConfig createTestInstance() { @Override protected PassThroughConfig mutateInstanceForVersion(PassThroughConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static PassThroughConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java index 2ad335d3cf4b0..0f8f2f0783660 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java @@ -17,6 +17,16 @@ public class QuestionAnsweringConfigTests extends InferenceConfigItemTestCase { + public static QuestionAnsweringConfig mutateForVersion(QuestionAnsweringConfig instance, Version version) { + return new QuestionAnsweringConfig( + instance.getNumTopClasses(), + instance.getMaxAnswerLength(), + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -44,7 +54,7 @@ protected QuestionAnsweringConfig createTestInstance() { @Override protected QuestionAnsweringConfig mutateInstanceForVersion(QuestionAnsweringConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static QuestionAnsweringConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java index 0803fec7304bc..920933be7450e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java @@ -19,6 +19,19 @@ public class RobertaTokenizationTests extends AbstractBWCSerializationTestCase { - public static TextClassificationConfig mutateInstance(TextClassificationConfig instance, Version version) { - if (version.before(Version.V_8_2_0)) { - final Tokenization tokenization; - if (instance.getTokenization() instanceof BertTokenization) { - tokenization = new BertTokenization( - instance.getTokenization().doLowerCase, - instance.getTokenization().withSpecialTokens, - instance.getTokenization().maxSequenceLength, - instance.getTokenization().truncate, - null - ); - } else if (instance.getTokenization() instanceof MPNetTokenization) { - tokenization = new MPNetTokenization( - instance.getTokenization().doLowerCase, - instance.getTokenization().withSpecialTokens, - instance.getTokenization().maxSequenceLength, - instance.getTokenization().truncate, - null - ); - } else { - throw new UnsupportedOperationException("unknown tokenization type: " + instance.getTokenization().getName()); - } - return new TextClassificationConfig( - instance.getVocabularyConfig(), - tokenization, - instance.getClassificationLabels(), - instance.getNumTopClasses(), - instance.getResultsField() - ); - } - return instance; + public static TextClassificationConfig mutateForVersion(TextClassificationConfig instance, Version version) { + return new TextClassificationConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getClassificationLabels(), + instance.getNumTopClasses(), + instance.getResultsField() + ); } @Override @@ -81,7 +58,7 @@ protected TextClassificationConfig createTestInstance() { @Override protected TextClassificationConfig mutateInstanceForVersion(TextClassificationConfig instance, Version version) { - return mutateInstance(instance, version); + return mutateForVersion(instance, version); } public void testInvalidClassificationLabels() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java index 373f3d3102e15..d60a8b28107da 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java @@ -17,6 +17,14 @@ public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase { + public static TextEmbeddingConfig mutateForVersion(TextEmbeddingConfig instance, Version version) { + return new TextEmbeddingConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -44,7 +52,7 @@ protected TextEmbeddingConfig createTestInstance() { @Override protected TextEmbeddingConfig mutateInstanceForVersion(TextEmbeddingConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static TextEmbeddingConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java index 63b271c04dffb..48e4b25ea7316 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java @@ -18,6 +18,18 @@ public class ZeroShotClassificationConfigTests extends InferenceConfigItemTestCase { + public static ZeroShotClassificationConfig mutateForVersion(ZeroShotClassificationConfig instance, Version version) { + return new ZeroShotClassificationConfig( + instance.getClassificationLabels(), + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getHypothesisTemplate(), + instance.isMultiLabel(), + instance.getLabels().orElse(null), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -45,7 +57,7 @@ protected ZeroShotClassificationConfig createTestInstance() { @Override protected ZeroShotClassificationConfig mutateInstanceForVersion(ZeroShotClassificationConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static ZeroShotClassificationConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java index 7aa80885ed7f4..2d424edac4c94 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java @@ -125,7 +125,7 @@ public void testApply() { originalConfig.getTokenization(), originalConfig.getHypothesisTemplate(), true, - originalConfig.getLabels(), + originalConfig.getLabels().orElse(null), originalConfig.getResultsField() ), equalTo(new ZeroShotClassificationConfigUpdate.Builder().setMultiLabel(true).build().apply(originalConfig)) @@ -137,7 +137,7 @@ public void testApply() { originalConfig.getTokenization(), originalConfig.getHypothesisTemplate(), originalConfig.isMultiLabel(), - originalConfig.getLabels(), + originalConfig.getLabels().orElse(null), "updated-field" ), equalTo(new ZeroShotClassificationConfigUpdate.Builder().setResultsField("updated-field").build().apply(originalConfig)) @@ -152,7 +152,7 @@ public void testApply() { tokenization, originalConfig.getHypothesisTemplate(), originalConfig.isMultiLabel(), - originalConfig.getLabels(), + originalConfig.getLabels().orElse(null), originalConfig.getResultsField() ), equalTo( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java index e932df01604ad..6deca79272d1c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java @@ -55,7 +55,7 @@ public class ZeroShotClassificationProcessor extends NlpTask.Processor { "zero_shot_classification requires [entailment] and [contradiction] in classification_labels" ); } - this.labels = Optional.ofNullable(config.getLabels()).orElse(List.of()).toArray(String[]::new); + this.labels = config.getLabels().orElse(List.of()).toArray(String[]::new); this.hypothesisTemplate = config.getHypothesisTemplate(); this.isMultiLabel = config.isMultiLabel(); this.resultsField = config.getResultsField(); @@ -70,7 +70,7 @@ public void validateInputs(List inputs) { public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) { final String[] labelsValue; if (nlpConfig instanceof ZeroShotClassificationConfig zeroShotConfig) { - labelsValue = zeroShotConfig.getLabels().toArray(new String[0]); + labelsValue = zeroShotConfig.getLabels().orElse(List.of()).toArray(new String[0]); } else { labelsValue = this.labels; } @@ -86,7 +86,7 @@ public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) { final boolean isMultiLabelValue; final String resultsFieldValue; if (nlpConfig instanceof ZeroShotClassificationConfig zeroShotConfig) { - labelsValue = zeroShotConfig.getLabels().toArray(new String[0]); + labelsValue = zeroShotConfig.getLabels().orElse(List.of()).toArray(new String[0]); isMultiLabelValue = zeroShotConfig.isMultiLabel(); resultsFieldValue = zeroShotConfig.getResultsField(); } else {