diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java index 3ee09ffc1e837..710a2855167cf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java @@ -248,8 +248,8 @@ public String getHypothesisTemplate() { return hypothesisTemplate; } - public List getLabels() { - return Optional.ofNullable(labels).orElse(List.of()); + public Optional> getLabels() { + return Optional.ofNullable(labels); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java index 3cf9f8c8f8354..acfd726ca27a5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java @@ -147,13 +147,13 @@ public InferenceConfig apply(InferenceConfig originalConfig) { tokenizationUpdate == null ? zeroShotConfig.getTokenization() : tokenizationUpdate.apply(zeroShotConfig.getTokenization()), zeroShotConfig.getHypothesisTemplate(), Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()), - Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels()), + Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels().orElse(null)), Optional.ofNullable(resultsField).orElse(zeroShotConfig.getResultsField()) ); } boolean isNoop(ZeroShotClassificationConfig originalConfig) { - return (labels == null || labels.equals(originalConfig.getLabels())) + return (labels == null || labels.equals(originalConfig.getLabels().orElse(null))) && (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel())) && (resultsField == null || resultsField.equals(originalConfig.getResultsField())) && super.isNoop(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java index 79157bcb5ab27..37b37940a5780 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java @@ -15,6 +15,24 @@ import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigTests; import java.util.ArrayList; import java.util.Collections; @@ -25,6 +43,29 @@ public abstract class InferenceConfigItemTestCase extends AbstractBWCSerializationTestCase< T> { + + static InferenceConfig mutateForVersion(NlpConfig inferenceConfig, Version version) { + if (inferenceConfig instanceof TextClassificationConfig textClassificationConfig) { + return TextClassificationConfigTests.mutateForVersion(textClassificationConfig, version); + } else if (inferenceConfig instanceof FillMaskConfig fillMaskConfig) { + return FillMaskConfigTests.mutateForVersion(fillMaskConfig, version); + } else if (inferenceConfig instanceof QuestionAnsweringConfig questionAnsweringConfig) { + return QuestionAnsweringConfigTests.mutateForVersion(questionAnsweringConfig, version); + } else if (inferenceConfig instanceof NerConfig nerConfig) { + return NerConfigTests.mutateForVersion(nerConfig, version); + } else if (inferenceConfig instanceof PassThroughConfig passThroughConfig) { + return PassThroughConfigTests.mutateForVersion(passThroughConfig, version); + } else if (inferenceConfig instanceof TextEmbeddingConfig textEmbeddingConfig) { + return TextEmbeddingConfigTests.mutateForVersion(textEmbeddingConfig, version); + } else if (inferenceConfig instanceof TextSimilarityConfig textSimilarityConfig) { + return TextSimilarityConfigTests.mutateForVersion(textSimilarityConfig, version); + } else if (inferenceConfig instanceof ZeroShotClassificationConfig zeroShotClassificationConfig) { + return ZeroShotClassificationConfigTests.mutateForVersion(zeroShotClassificationConfig, version); + } else { + throw new IllegalArgumentException("unknown inference config [" + inferenceConfig.getName() + "]"); + } + } + @Override protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 90b37a67cf6f8..bbff114de5d4c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.core.ml.inference; -import org.apache.lucene.tests.util.LuceneTestCase; import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.bytes.BytesReference; @@ -29,10 +28,10 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocationTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests; @@ -60,7 +59,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/89008") public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase { private boolean lenient; @@ -397,8 +395,8 @@ protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instanc builder.setModelType(null); builder.setLocation(null); } - if (instance.getInferenceConfig()instanceof TextClassificationConfig textClassificationConfig) { - builder.setInferenceConfig(TextClassificationConfigTests.mutateInstance(textClassificationConfig, version)); + if (instance.getInferenceConfig()instanceof NlpConfig nlpConfig) { + builder.setInferenceConfig(InferenceConfigItemTestCase.mutateForVersion(nlpConfig, version)); } return builder.build(); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java index 9a84c254c5452..952e6b4372534 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java @@ -19,6 +19,19 @@ public class BertTokenizationTests extends AbstractBWCSerializationTestCase getRandomFieldsExcludeFilter() { return field -> field.isEmpty() == false; @@ -44,7 +53,7 @@ protected FillMaskConfig createTestInstance() { @Override protected FillMaskConfig mutateInstanceForVersion(FillMaskConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static FillMaskConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java index 43020fe23e114..228cdb40e3a89 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java @@ -7,8 +7,22 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; + public final class InferenceConfigTestScaffolding { + static Tokenization mutateTokenizationForVersion(Tokenization tokenization, Version version) { + if (tokenization instanceof BertTokenization bertTokenization) { + return BertTokenizationTests.mutateForVersion(bertTokenization, version); + } else if (tokenization instanceof MPNetTokenization mpNetTokenization) { + return MPNetTokenizationTests.mutateForVersion(mpNetTokenization, version); + } else if (tokenization instanceof RobertaTokenization robertaTokenization) { + return RobertaTokenizationTests.mutateForVersion(robertaTokenization, version); + } else { + throw new IllegalArgumentException("unknown tokenization [" + tokenization.getName() + "]"); + } + } + static Tokenization cloneWithNewTruncation(Tokenization tokenization, Tokenization.Truncate truncate) { if (tokenization instanceof MPNetTokenization) { return new MPNetTokenization( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java index 4c01935a7ef43..dead82c736445 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java @@ -19,6 +19,19 @@ public class MPNetTokenizationTests extends AbstractBWCSerializationTestCase { + public static NerConfig mutateForVersion(NerConfig instance, Version version) { + return new NerConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getClassificationLabels(), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -48,7 +57,7 @@ protected NerConfig createTestInstance() { @Override protected NerConfig mutateInstanceForVersion(NerConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static NerConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java index 3701a07b73d5b..28e107101d288 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java @@ -17,6 +17,14 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase { + public static PassThroughConfig mutateForVersion(PassThroughConfig instance, Version version) { + return new PassThroughConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -44,7 +52,7 @@ protected PassThroughConfig createTestInstance() { @Override protected PassThroughConfig mutateInstanceForVersion(PassThroughConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static PassThroughConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java index 4f3b09259f8f9..0f8f2f0783660 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; -import org.apache.lucene.tests.util.LuceneTestCase; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xcontent.XContentParser; @@ -16,9 +15,18 @@ import java.io.IOException; import java.util.function.Predicate; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/89008") public class QuestionAnsweringConfigTests extends InferenceConfigItemTestCase { + public static QuestionAnsweringConfig mutateForVersion(QuestionAnsweringConfig instance, Version version) { + return new QuestionAnsweringConfig( + instance.getNumTopClasses(), + instance.getMaxAnswerLength(), + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -46,7 +54,7 @@ protected QuestionAnsweringConfig createTestInstance() { @Override protected QuestionAnsweringConfig mutateInstanceForVersion(QuestionAnsweringConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static QuestionAnsweringConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java index 0803fec7304bc..920933be7450e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationTests.java @@ -19,6 +19,19 @@ public class RobertaTokenizationTests extends AbstractBWCSerializationTestCase { - public static TextClassificationConfig mutateInstance(TextClassificationConfig instance, Version version) { - if (version.before(Version.V_8_2_0)) { - final Tokenization tokenization; - if (instance.getTokenization() instanceof BertTokenization) { - tokenization = new BertTokenization( - instance.getTokenization().doLowerCase, - instance.getTokenization().withSpecialTokens, - instance.getTokenization().maxSequenceLength, - instance.getTokenization().truncate, - null - ); - } else if (instance.getTokenization() instanceof MPNetTokenization) { - tokenization = new MPNetTokenization( - instance.getTokenization().doLowerCase, - instance.getTokenization().withSpecialTokens, - instance.getTokenization().maxSequenceLength, - instance.getTokenization().truncate, - null - ); - } else { - throw new UnsupportedOperationException("unknown tokenization type: " + instance.getTokenization().getName()); - } - return new TextClassificationConfig( - instance.getVocabularyConfig(), - tokenization, - instance.getClassificationLabels(), - instance.getNumTopClasses(), - instance.getResultsField() - ); - } - return instance; + public static TextClassificationConfig mutateForVersion(TextClassificationConfig instance, Version version) { + return new TextClassificationConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getClassificationLabels(), + instance.getNumTopClasses(), + instance.getResultsField() + ); } @Override @@ -81,7 +58,7 @@ protected TextClassificationConfig createTestInstance() { @Override protected TextClassificationConfig mutateInstanceForVersion(TextClassificationConfig instance, Version version) { - return mutateInstance(instance, version); + return mutateForVersion(instance, version); } public void testInvalidClassificationLabels() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java index 373f3d3102e15..d60a8b28107da 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java @@ -17,6 +17,14 @@ public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase { + public static TextEmbeddingConfig mutateForVersion(TextEmbeddingConfig instance, Version version) { + return new TextEmbeddingConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -44,7 +52,7 @@ protected TextEmbeddingConfig createTestInstance() { @Override protected TextEmbeddingConfig mutateInstanceForVersion(TextEmbeddingConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static TextEmbeddingConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigTests.java index 77dd5dcf38d61..e8976ce1dd7c5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextSimilarityConfigTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; -import org.apache.lucene.tests.util.LuceneTestCase; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xcontent.XContentParser; @@ -17,9 +16,17 @@ import java.util.Arrays; import java.util.function.Predicate; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/89008") public class TextSimilarityConfigTests extends InferenceConfigItemTestCase { + public static TextSimilarityConfig mutateForVersion(TextSimilarityConfig instance, Version version) { + return new TextSimilarityConfig( + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getResultsField(), + instance.getSpanScoreFunction().toString() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -47,7 +54,7 @@ protected TextSimilarityConfig createTestInstance() { @Override protected TextSimilarityConfig mutateInstanceForVersion(TextSimilarityConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static TextSimilarityConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java index 63b271c04dffb..48e4b25ea7316 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java @@ -18,6 +18,18 @@ public class ZeroShotClassificationConfigTests extends InferenceConfigItemTestCase { + public static ZeroShotClassificationConfig mutateForVersion(ZeroShotClassificationConfig instance, Version version) { + return new ZeroShotClassificationConfig( + instance.getClassificationLabels(), + instance.getVocabularyConfig(), + InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version), + instance.getHypothesisTemplate(), + instance.isMultiLabel(), + instance.getLabels().orElse(null), + instance.getResultsField() + ); + } + @Override protected boolean supportsUnknownFields() { return true; @@ -45,7 +57,7 @@ protected ZeroShotClassificationConfig createTestInstance() { @Override protected ZeroShotClassificationConfig mutateInstanceForVersion(ZeroShotClassificationConfig instance, Version version) { - return instance; + return mutateForVersion(instance, version); } public static ZeroShotClassificationConfig createRandom() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java index 7aa80885ed7f4..2d424edac4c94 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java @@ -125,7 +125,7 @@ public void testApply() { originalConfig.getTokenization(), originalConfig.getHypothesisTemplate(), true, - originalConfig.getLabels(), + originalConfig.getLabels().orElse(null), originalConfig.getResultsField() ), equalTo(new ZeroShotClassificationConfigUpdate.Builder().setMultiLabel(true).build().apply(originalConfig)) @@ -137,7 +137,7 @@ public void testApply() { originalConfig.getTokenization(), originalConfig.getHypothesisTemplate(), originalConfig.isMultiLabel(), - originalConfig.getLabels(), + originalConfig.getLabels().orElse(null), "updated-field" ), equalTo(new ZeroShotClassificationConfigUpdate.Builder().setResultsField("updated-field").build().apply(originalConfig)) @@ -152,7 +152,7 @@ public void testApply() { tokenization, originalConfig.getHypothesisTemplate(), originalConfig.isMultiLabel(), - originalConfig.getLabels(), + originalConfig.getLabels().orElse(null), originalConfig.getResultsField() ), equalTo( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java index e19529b705d77..eff6916d61609 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java @@ -52,7 +52,7 @@ public class ZeroShotClassificationProcessor extends NlpTask.Processor { "zero_shot_classification requires [entailment] and [contradiction] in classification_labels" ); } - this.labels = Optional.ofNullable(config.getLabels()).orElse(List.of()).toArray(String[]::new); + this.labels = config.getLabels().orElse(List.of()).toArray(String[]::new); this.hypothesisTemplate = config.getHypothesisTemplate(); this.isMultiLabel = config.isMultiLabel(); this.resultsField = config.getResultsField(); @@ -67,7 +67,7 @@ public void validateInputs(List inputs) { public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) { final String[] labelsValue; if (nlpConfig instanceof ZeroShotClassificationConfig zeroShotConfig) { - labelsValue = zeroShotConfig.getLabels().toArray(new String[0]); + labelsValue = zeroShotConfig.getLabels().orElse(List.of()).toArray(new String[0]); } else { labelsValue = this.labels; } @@ -83,7 +83,7 @@ public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) { final boolean isMultiLabelValue; final String resultsFieldValue; if (nlpConfig instanceof ZeroShotClassificationConfig zeroShotConfig) { - labelsValue = zeroShotConfig.getLabels().toArray(new String[0]); + labelsValue = zeroShotConfig.getLabels().orElse(List.of()).toArray(new String[0]); isMultiLabelValue = zeroShotConfig.isMultiLabel(); resultsFieldValue = zeroShotConfig.getResultsField(); } else {