Skip to content

Commit

Permalink
[8.4] [ML] fix NLP inference_config bwc serialization tests (#89011) (#…
Browse files Browse the repository at this point in the history
…89042)

* [ML] fix NLP inference_config bwc serialization tests (#89011)

The tests were failing because of span not being nulled out for question_answering and text_similarity tasks.

But, this change also attempts to make it more future proof so that if changes occur to the nlp task or tokenization configurations it will cause a failure more quickly and require handling the bwc testing.

closes: #89008
(cherry picked from commit 480479d)

* fixing backport
  • Loading branch information
benwtrent authored Aug 2, 2022
1 parent dd68bb7 commit 08f69f0
Show file tree
Hide file tree
Showing 17 changed files with 177 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ public String getHypothesisTemplate() {
return hypothesisTemplate;
}

public List<String> getLabels() {
return Optional.ofNullable(labels).orElse(List.of());
public Optional<List<String>> getLabels() {
return Optional.ofNullable(labels);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ public InferenceConfig apply(InferenceConfig originalConfig) {
tokenizationUpdate == null ? zeroShotConfig.getTokenization() : tokenizationUpdate.apply(zeroShotConfig.getTokenization()),
zeroShotConfig.getHypothesisTemplate(),
Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()),
Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels()),
Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels().orElse(null)),
Optional.ofNullable(resultsField).orElse(zeroShotConfig.getResultsField())
);
}

boolean isNoop(ZeroShotClassificationConfig originalConfig) {
return (labels == null || labels.equals(originalConfig.getLabels()))
return (labels == null || labels.equals(originalConfig.getLabels().orElse(null)))
&& (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()))
&& (resultsField == null || resultsField.equals(originalConfig.getResultsField()))
&& super.isNoop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigTests;

import java.util.ArrayList;
import java.util.Collections;
Expand All @@ -25,6 +41,27 @@

public abstract class InferenceConfigItemTestCase<T extends VersionedNamedWriteable & ToXContent> extends AbstractBWCSerializationTestCase<
T> {

static InferenceConfig mutateForVersion(NlpConfig inferenceConfig, Version version) {
if (inferenceConfig instanceof TextClassificationConfig textClassificationConfig) {
return TextClassificationConfigTests.mutateForVersion(textClassificationConfig, version);
} else if (inferenceConfig instanceof FillMaskConfig fillMaskConfig) {
return FillMaskConfigTests.mutateForVersion(fillMaskConfig, version);
} else if (inferenceConfig instanceof QuestionAnsweringConfig questionAnsweringConfig) {
return QuestionAnsweringConfigTests.mutateForVersion(questionAnsweringConfig, version);
} else if (inferenceConfig instanceof NerConfig nerConfig) {
return NerConfigTests.mutateForVersion(nerConfig, version);
} else if (inferenceConfig instanceof PassThroughConfig passThroughConfig) {
return PassThroughConfigTests.mutateForVersion(passThroughConfig, version);
} else if (inferenceConfig instanceof TextEmbeddingConfig textEmbeddingConfig) {
return TextEmbeddingConfigTests.mutateForVersion(textEmbeddingConfig, version);
} else if (inferenceConfig instanceof ZeroShotClassificationConfig zeroShotClassificationConfig) {
return ZeroShotClassificationConfigTests.mutateForVersion(zeroShotClassificationConfig, version);
} else {
throw new IllegalArgumentException("unknown inference config [" + inferenceConfig.getName() + "]");
}
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocationTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
Expand Down Expand Up @@ -391,8 +391,8 @@ protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instanc
builder.setModelType(null);
builder.setLocation(null);
}
if (instance.getInferenceConfig()instanceof TextClassificationConfig textClassificationConfig) {
builder.setInferenceConfig(TextClassificationConfigTests.mutateInstance(textClassificationConfig, version));
if (instance.getInferenceConfig()instanceof NlpConfig nlpConfig) {
builder.setInferenceConfig(InferenceConfigItemTestCase.mutateForVersion(nlpConfig, version));
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ public class BertTokenizationTests extends AbstractBWCSerializationTestCase<Bert

private boolean lenient;

public static BertTokenization mutateForVersion(BertTokenization instance, Version version) {
if (version.before(Version.V_8_2_0)) {
return new BertTokenization(
instance.doLowerCase,
instance.withSpecialTokens,
instance.maxSequenceLength,
instance.truncate,
null
);
}
return instance;
}

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
Expand All @@ -41,7 +54,7 @@ protected BertTokenization createTestInstance() {

@Override
protected BertTokenization mutateInstanceForVersion(BertTokenization instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static BertTokenization createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ protected boolean supportsUnknownFields() {
return true;
}

public static FillMaskConfig mutateForVersion(FillMaskConfig instance, Version version) {
return new FillMaskConfig(
instance.getVocabularyConfig(),
InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version),
instance.getNumTopClasses(),
instance.getResultsField()
);
}

@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.isEmpty() == false;
Expand All @@ -44,7 +53,7 @@ protected FillMaskConfig createTestInstance() {

@Override
protected FillMaskConfig mutateInstanceForVersion(FillMaskConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static FillMaskConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,22 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;

public final class InferenceConfigTestScaffolding {

static Tokenization mutateTokenizationForVersion(Tokenization tokenization, Version version) {
if (tokenization instanceof BertTokenization bertTokenization) {
return BertTokenizationTests.mutateForVersion(bertTokenization, version);
} else if (tokenization instanceof MPNetTokenization mpNetTokenization) {
return MPNetTokenizationTests.mutateForVersion(mpNetTokenization, version);
} else if (tokenization instanceof RobertaTokenization robertaTokenization) {
return RobertaTokenizationTests.mutateForVersion(robertaTokenization, version);
} else {
throw new IllegalArgumentException("unknown tokenization [" + tokenization.getName() + "]");
}
}

static Tokenization cloneWithNewTruncation(Tokenization tokenization, Tokenization.Truncate truncate) {
if (tokenization instanceof MPNetTokenization) {
return new MPNetTokenization(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ public class MPNetTokenizationTests extends AbstractBWCSerializationTestCase<MPN

private boolean lenient;

static MPNetTokenization mutateForVersion(MPNetTokenization instance, Version version) {
if (version.before(Version.V_8_2_0)) {
return new MPNetTokenization(
instance.doLowerCase,
instance.withSpecialTokens,
instance.maxSequenceLength,
instance.truncate,
null
);
}
return instance;
}

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
Expand All @@ -41,7 +54,7 @@ protected MPNetTokenization createTestInstance() {

@Override
protected MPNetTokenization mutateInstanceForVersion(MPNetTokenization instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static MPNetTokenization createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@

public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {

public static NerConfig mutateForVersion(NerConfig instance, Version version) {
return new NerConfig(
instance.getVocabularyConfig(),
InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version),
instance.getClassificationLabels(),
instance.getResultsField()
);
}

@Override
protected boolean supportsUnknownFields() {
return true;
Expand Down Expand Up @@ -48,7 +57,7 @@ protected NerConfig createTestInstance() {

@Override
protected NerConfig mutateInstanceForVersion(NerConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static NerConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@

public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThroughConfig> {

public static PassThroughConfig mutateForVersion(PassThroughConfig instance, Version version) {
return new PassThroughConfig(
instance.getVocabularyConfig(),
InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version),
instance.getResultsField()
);
}

@Override
protected boolean supportsUnknownFields() {
return true;
Expand Down Expand Up @@ -44,7 +52,7 @@ protected PassThroughConfig createTestInstance() {

@Override
protected PassThroughConfig mutateInstanceForVersion(PassThroughConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static PassThroughConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@

public class QuestionAnsweringConfigTests extends InferenceConfigItemTestCase<QuestionAnsweringConfig> {

public static QuestionAnsweringConfig mutateForVersion(QuestionAnsweringConfig instance, Version version) {
return new QuestionAnsweringConfig(
instance.getNumTopClasses(),
instance.getMaxAnswerLength(),
instance.getVocabularyConfig(),
InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version),
instance.getResultsField()
);
}

@Override
protected boolean supportsUnknownFields() {
return true;
Expand Down Expand Up @@ -44,7 +54,7 @@ protected QuestionAnsweringConfig createTestInstance() {

@Override
protected QuestionAnsweringConfig mutateInstanceForVersion(QuestionAnsweringConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static QuestionAnsweringConfig createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ public class RobertaTokenizationTests extends AbstractBWCSerializationTestCase<R

private boolean lenient;

public static RobertaTokenization mutateForVersion(RobertaTokenization instance, Version version) {
if (version.before(Version.V_8_2_0)) {
return new RobertaTokenization(
instance.withSpecialTokens,
instance.isAddPrefixSpace(),
instance.maxSequenceLength,
instance.truncate,
null
);
}
return instance;
}

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
Expand All @@ -41,7 +54,7 @@ protected RobertaTokenization createTestInstance() {

@Override
protected RobertaTokenization mutateInstanceForVersion(RobertaTokenization instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static RobertaTokenization createRandom() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,14 @@

public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> {

public static TextClassificationConfig mutateInstance(TextClassificationConfig instance, Version version) {
if (version.before(Version.V_8_2_0)) {
final Tokenization tokenization;
if (instance.getTokenization() instanceof BertTokenization) {
tokenization = new BertTokenization(
instance.getTokenization().doLowerCase,
instance.getTokenization().withSpecialTokens,
instance.getTokenization().maxSequenceLength,
instance.getTokenization().truncate,
null
);
} else if (instance.getTokenization() instanceof MPNetTokenization) {
tokenization = new MPNetTokenization(
instance.getTokenization().doLowerCase,
instance.getTokenization().withSpecialTokens,
instance.getTokenization().maxSequenceLength,
instance.getTokenization().truncate,
null
);
} else {
throw new UnsupportedOperationException("unknown tokenization type: " + instance.getTokenization().getName());
}
return new TextClassificationConfig(
instance.getVocabularyConfig(),
tokenization,
instance.getClassificationLabels(),
instance.getNumTopClasses(),
instance.getResultsField()
);
}
return instance;
public static TextClassificationConfig mutateForVersion(TextClassificationConfig instance, Version version) {
return new TextClassificationConfig(
instance.getVocabularyConfig(),
InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version),
instance.getClassificationLabels(),
instance.getNumTopClasses(),
instance.getResultsField()
);
}

@Override
Expand Down Expand Up @@ -81,7 +58,7 @@ protected TextClassificationConfig createTestInstance() {

@Override
protected TextClassificationConfig mutateInstanceForVersion(TextClassificationConfig instance, Version version) {
return mutateInstance(instance, version);
return mutateForVersion(instance, version);
}

public void testInvalidClassificationLabels() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@

public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase<TextEmbeddingConfig> {

public static TextEmbeddingConfig mutateForVersion(TextEmbeddingConfig instance, Version version) {
return new TextEmbeddingConfig(
instance.getVocabularyConfig(),
InferenceConfigTestScaffolding.mutateTokenizationForVersion(instance.getTokenization(), version),
instance.getResultsField()
);
}

@Override
protected boolean supportsUnknownFields() {
return true;
Expand Down Expand Up @@ -44,7 +52,7 @@ protected TextEmbeddingConfig createTestInstance() {

@Override
protected TextEmbeddingConfig mutateInstanceForVersion(TextEmbeddingConfig instance, Version version) {
return instance;
return mutateForVersion(instance, version);
}

public static TextEmbeddingConfig createRandom() {
Expand Down
Loading

0 comments on commit 08f69f0

Please sign in to comment.