Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] fix NLP inference_config bwc serialization tests #89011

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ public String getHypothesisTemplate() {
return hypothesisTemplate;
}

public List<String> getLabels() {
return Optional.ofNullable(labels).orElse(List.of());
public Optional<List<String>> getLabels() {
return Optional.ofNullable(labels);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ public InferenceConfig apply(InferenceConfig originalConfig) {
tokenizationUpdate == null ? zeroShotConfig.getTokenization() : tokenizationUpdate.apply(zeroShotConfig.getTokenization()),
zeroShotConfig.getHypothesisTemplate(),
Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()),
Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels()),
Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels().orElse(null)),
Optional.ofNullable(resultsField).orElse(zeroShotConfig.getResultsField())
);
}

boolean isNoop(ZeroShotClassificationConfig originalConfig) {
return (labels == null || labels.equals(originalConfig.getLabels()))
return (labels == null || labels.equals(originalConfig.getLabels().orElse(null)))
&& (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()))
&& (resultsField == null || resultsField.equals(originalConfig.getResultsField()))
&& super.isNoop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigTests;

import java.util.ArrayList;
import java.util.Collections;
Expand All @@ -25,6 +43,29 @@

public abstract class InferenceConfigItemTestCase<T extends VersionedNamedWriteable & ToXContent> extends AbstractBWCSerializationTestCase<
T> {

/**
 * Routes an NLP inference config to the {@code mutateForVersion} helper of the
 * test class that matches its concrete type.
 *
 * @param inferenceConfig the NLP config to adapt
 * @param version the version the config is being adapted for
 * @return whatever the type-specific helper returns for the given config and version
 * @throws IllegalArgumentException if the config is none of the types handled here
 */
static InferenceConfig mutateForVersion(NlpConfig inferenceConfig, Version version) {
    if (inferenceConfig instanceof TextClassificationConfig textClassification) {
        return TextClassificationConfigTests.mutateForVersion(textClassification, version);
    }
    if (inferenceConfig instanceof FillMaskConfig fillMask) {
        return FillMaskConfigTests.mutateForVersion(fillMask, version);
    }
    if (inferenceConfig instanceof QuestionAnsweringConfig questionAnswering) {
        return QuestionAnsweringConfigTests.mutateForVersion(questionAnswering, version);
    }
    if (inferenceConfig instanceof NerConfig ner) {
        return NerConfigTests.mutateForVersion(ner, version);
    }
    if (inferenceConfig instanceof PassThroughConfig passThrough) {
        return PassThroughConfigTests.mutateForVersion(passThrough, version);
    }
    if (inferenceConfig instanceof TextEmbeddingConfig textEmbedding) {
        return TextEmbeddingConfigTests.mutateForVersion(textEmbedding, version);
    }
    if (inferenceConfig instanceof TextSimilarityConfig textSimilarity) {
        return TextSimilarityConfigTests.mutateForVersion(textSimilarity, version);
    }
    if (inferenceConfig instanceof ZeroShotClassificationConfig zeroShot) {
        return ZeroShotClassificationConfigTests.mutateForVersion(zeroShot, version);
    }
    // No known NLP config type matched.
    throw new IllegalArgumentException("unknown inference config [" + inferenceConfig.getName() + "]");
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
*/
package org.elasticsearch.xpack.core.ml.inference;

import org.apache.lucene.tests.util.LuceneTestCase;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.bytes.BytesReference;
Expand All @@ -29,10 +28,10 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocationTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests;
Expand Down Expand Up @@ -60,7 +59,6 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/89008")
public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<TrainedModelConfig> {

private boolean lenient;
Expand Down Expand Up @@ -397,8 +395,8 @@ protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instanc
builder.setModelType(null);
builder.setLocation(null);
}
if (instance.getInferenceConfig()instanceof TextClassificationConfig textClassificationConfig) {
builder.setInferenceConfig(TextClassificationConfigTests.mutateInstance(textClassificationConfig, version));
if (instance.getInferenceConfig()instanceof NlpConfig nlpConfig) {
builder.setInferenceConfig(InferenceConfigItemTestCase.mutateForVersion(nlpConfig, version));
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ public class BertTokenizationTests extends AbstractBWCSerializationTestCase<Bert

private boolean lenient;

/**
 * Adapts a {@link BertTokenization} for the given version.
 *
 * @param instance the tokenization to adapt
 * @param version the version the instance is being adapted for
 * @return {@code instance} itself unless {@code version} is before 8.2.0; otherwise a copy
 *         whose fifth constructor argument is {@code null}
 */
public static BertTokenization mutateForVersion(BertTokenization instance, Version version) {
    if (version.before(Version.V_8_2_0) == false) {
        return instance;
    }
    // NOTE(review): the trailing null presumably clears a setting that pre-8.2.0 versions
    // do not know about — confirm against the BertTokenization constructor.
    return new BertTokenization(
        instance.doLowerCase,
        instance.withSpecialTokens,
        instance.maxSequenceLength,
        instance.truncate,
        null
    );
}

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
Expand All @@ -41,7 +54,7 @@ protected BertTokenization createTestInstance() {

@Override
protected BertTokenization mutateInstanceForVersion(BertTokenization instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static BertTokenization createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ protected boolean supportsUnknownFields() {
return true;
}

/**
 * Builds a copy of the given {@link FillMaskConfig} whose tokenization has been passed through
 * {@code InferenceConfigTestScaffolding.mutateTokenizationForVersion}; the vocabulary config,
 * number of top classes and results field are carried over unchanged.
 *
 * @param instance the config to copy
 * @param version the version handed to the tokenization helper
 * @return a new config built from the instance's values and the adapted tokenization
 */
public static FillMaskConfig mutateForVersion(FillMaskConfig instance, Version version) {
    var versionedTokenization = InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version);
    return new FillMaskConfig(
        instance.getVocabularyConfig(),
        versionedTokenization,
        instance.getNumTopClasses(),
        instance.getResultsField()
    );
}

@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.isEmpty() == false;
Expand All @@ -44,7 +53,7 @@ protected FillMaskConfig createTestInstance() {

@Override
protected FillMaskConfig mutateInstanceForVersion(FillMaskConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static FillMaskConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,22 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;

public final class InferenceConfigTestScaffolding {

/**
 * Routes a tokenization to the {@code mutateForVersion} helper of the test class that
 * matches its concrete type.
 *
 * @param tokenization the tokenization to adapt
 * @param version the version the tokenization is being adapted for
 * @return whatever the type-specific helper returns
 * @throws IllegalArgumentException if the tokenization is none of the types handled here
 */
static Tokenization mutateTokenizationForVersion(Tokenization tokenization, Version version) {
    if (tokenization instanceof BertTokenization bert) {
        return BertTokenizationTests.mutateForVersion(bert, version);
    }
    if (tokenization instanceof MPNetTokenization mpNet) {
        return MPNetTokenizationTests.mutateForVersion(mpNet, version);
    }
    if (tokenization instanceof RobertaTokenization roberta) {
        return RobertaTokenizationTests.mutateForVersion(roberta, version);
    }
    // No known tokenization type matched.
    throw new IllegalArgumentException("unknown tokenization [" + tokenization.getName() + "]");
}

static Tokenization cloneWithNewTruncation(Tokenization tokenization, Tokenization.Truncate truncate) {
if (tokenization instanceof MPNetTokenization) {
return new MPNetTokenization(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ public class MPNetTokenizationTests extends AbstractBWCSerializationTestCase<MPN

private boolean lenient;

/**
 * Adapts an {@link MPNetTokenization} for the given version.
 *
 * @param instance the tokenization to adapt
 * @param version the version the instance is being adapted for
 * @return {@code instance} itself unless {@code version} is before 8.2.0; otherwise a copy
 *         whose fifth constructor argument is {@code null}
 */
static MPNetTokenization mutateForVersion(MPNetTokenization instance, Version version) {
    if (version.before(Version.V_8_2_0) == false) {
        return instance;
    }
    // NOTE(review): the trailing null presumably clears a setting that pre-8.2.0 versions
    // do not know about — confirm against the MPNetTokenization constructor.
    return new MPNetTokenization(
        instance.doLowerCase,
        instance.withSpecialTokens,
        instance.maxSequenceLength,
        instance.truncate,
        null
    );
}

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
Expand All @@ -41,7 +54,7 @@ protected MPNetTokenization createTestInstance() {

@Override
protected MPNetTokenization mutateInstanceForVersion(MPNetTokenization instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static MPNetTokenization createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@

public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {

/**
 * Builds a copy of the given {@link NerConfig} whose tokenization has been passed through
 * {@code InferenceConfigTestScaffolding.mutateTokenizationForVersion}; the vocabulary config,
 * classification labels and results field are carried over unchanged.
 *
 * @param instance the config to copy
 * @param version the version handed to the tokenization helper
 * @return a new config built from the instance's values and the adapted tokenization
 */
public static NerConfig mutateForVersion(NerConfig instance, Version version) {
    var versionedTokenization = InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version);
    return new NerConfig(
        instance.getVocabularyConfig(),
        versionedTokenization,
        instance.getClassificationLabels(),
        instance.getResultsField()
    );
}

@Override
protected boolean supportsUnknownFields() {
return true;
Expand Down Expand Up @@ -48,7 +57,7 @@ protected NerConfig createTestInstance() {

@Override
protected NerConfig mutateInstanceForVersion(NerConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static NerConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@

public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThroughConfig> {

/**
 * Builds a copy of the given {@link PassThroughConfig} whose tokenization has been passed
 * through {@code InferenceConfigTestScaffolding.mutateTokenizationForVersion}; the vocabulary
 * config and results field are carried over unchanged.
 *
 * @param instance the config to copy
 * @param version the version handed to the tokenization helper
 * @return a new config built from the instance's values and the adapted tokenization
 */
public static PassThroughConfig mutateForVersion(PassThroughConfig instance, Version version) {
    var versionedTokenization = InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version);
    return new PassThroughConfig(instance.getVocabularyConfig(), versionedTokenization, instance.getResultsField());
}

@Override
protected boolean supportsUnknownFields() {
return true;
Expand Down Expand Up @@ -44,7 +52,7 @@ protected PassThroughConfig createTestInstance() {

@Override
protected PassThroughConfig mutateInstanceForVersion(PassThroughConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static PassThroughConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.apache.lucene.tests.util.LuceneTestCase;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.XContentParser;
Expand All @@ -16,9 +15,18 @@
import java.io.IOException;
import java.util.function.Predicate;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/89008")
public class QuestionAnsweringConfigTests extends InferenceConfigItemTestCase<QuestionAnsweringConfig> {

/**
 * Builds a copy of the given {@link QuestionAnsweringConfig} whose tokenization has been passed
 * through {@code InferenceConfigTestScaffolding.mutateTokenizationForVersion}; the number of top
 * classes, maximum answer length, vocabulary config and results field are carried over unchanged.
 *
 * @param instance the config to copy
 * @param version the version handed to the tokenization helper
 * @return a new config built from the instance's values and the adapted tokenization
 */
public static QuestionAnsweringConfig mutateForVersion(QuestionAnsweringConfig instance, Version version) {
    var versionedTokenization = InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version);
    return new QuestionAnsweringConfig(
        instance.getNumTopClasses(),
        instance.getMaxAnswerLength(),
        instance.getVocabularyConfig(),
        versionedTokenization,
        instance.getResultsField()
    );
}

@Override
protected boolean supportsUnknownFields() {
return true;
Expand Down Expand Up @@ -46,7 +54,7 @@ protected QuestionAnsweringConfig createTestInstance() {

@Override
protected QuestionAnsweringConfig mutateInstanceForVersion(QuestionAnsweringConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static QuestionAnsweringConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ public class RobertaTokenizationTests extends AbstractBWCSerializationTestCase<R

private boolean lenient;

/**
 * Adapts a {@link RobertaTokenization} for the given version.
 *
 * @param instance the tokenization to adapt
 * @param version the version the instance is being adapted for
 * @return {@code instance} itself unless {@code version} is before 8.2.0; otherwise a copy
 *         whose fifth constructor argument is {@code null}
 */
public static RobertaTokenization mutateForVersion(RobertaTokenization instance, Version version) {
    if (version.before(Version.V_8_2_0) == false) {
        return instance;
    }
    // NOTE(review): the trailing null presumably clears a setting that pre-8.2.0 versions
    // do not know about — confirm against the RobertaTokenization constructor.
    return new RobertaTokenization(
        instance.withSpecialTokens,
        instance.isAddPrefixSpace(),
        instance.maxSequenceLength,
        instance.truncate,
        null
    );
}

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
Expand All @@ -41,7 +54,7 @@ protected RobertaTokenization createTestInstance() {

@Override
protected RobertaTokenization mutateInstanceForVersion(RobertaTokenization instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static RobertaTokenization createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,14 @@

public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> {

public static TextClassificationConfig mutateInstance(TextClassificationConfig instance, Version version) {
if (version.before(Version.V_8_2_0)) {
final Tokenization tokenization;
if (instance.getTokenization() instanceof BertTokenization) {
tokenization = new BertTokenization(
instance.getTokenization().doLowerCase,
instance.getTokenization().withSpecialTokens,
instance.getTokenization().maxSequenceLength,
instance.getTokenization().truncate,
null
);
} else if (instance.getTokenization() instanceof MPNetTokenization) {
tokenization = new MPNetTokenization(
instance.getTokenization().doLowerCase,
instance.getTokenization().withSpecialTokens,
instance.getTokenization().maxSequenceLength,
instance.getTokenization().truncate,
null
);
} else {
throw new UnsupportedOperationException("unknown tokenization type: " + instance.getTokenization().getName());
}
return new TextClassificationConfig(
instance.getVocabularyConfig(),
tokenization,
instance.getClassificationLabels(),
instance.getNumTopClasses(),
instance.getResultsField()
);
}
return instance;
/**
 * Builds a copy of the given {@link TextClassificationConfig} whose tokenization has been passed
 * through {@code InferenceConfigTestScaffolding.mutateTokenizationForVersion}; the vocabulary
 * config, classification labels, number of top classes and results field are carried over
 * unchanged.
 *
 * @param instance the config to copy
 * @param version the version handed to the tokenization helper
 * @return a new config built from the instance's values and the adapted tokenization
 */
public static TextClassificationConfig mutateForVersion(TextClassificationConfig instance, Version version) {
    var versionedTokenization = InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version);
    return new TextClassificationConfig(
        instance.getVocabularyConfig(),
        versionedTokenization,
        instance.getClassificationLabels(),
        instance.getNumTopClasses(),
        instance.getResultsField()
    );
}

@Override
Expand Down Expand Up @@ -81,7 +58,7 @@ protected TextClassificationConfig createTestInstance() {

@Override
protected TextClassificationConfig mutateInstanceForVersion(TextClassificationConfig instance, Version version) {
return mutateInstance(instance, version);
return mutateForVersion(instance, version);
}

public void testInvalidClassificationLabels() {
Expand Down
Loading