diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java index 8e8174c3a1834..68995958b68b6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java @@ -48,15 +48,15 @@ public class TrainedModelCacheMetadata implements Metadata.Custom { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, true, - args -> new TrainedModelCacheMetadata((Map) args[0]) + args -> new TrainedModelCacheMetadata((Map) args[0]) ); static { PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { - Map entries = new HashMap<>(); + Map entries = new HashMap<>(); while (p.nextToken() != XContentParser.Token.END_OBJECT) { String modelId = p.currentName(); - entries.put(modelId, TrainedModelCustomMetadataEntry.fromXContent(p)); + entries.put(modelId, TrainedModelCacheMetadataEntry.fromXContent(p)); } return entries; }, ENTRIES); @@ -80,8 +80,8 @@ public static Set getUpdatedModelIds(ClusterChangedEvent event) { return Collections.emptySet(); } - Map oldCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.previousState()).entries(); - Map newCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.state()).entries(); + Map oldCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.previousState()).entries(); + Map newCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.state()).entries(); return Sets.union(oldCacheMetadataEntries.keySet(), newCacheMetadataEntries.keySet()).stream() .filter(modelId -> { @@ -94,17 +94,17 @@ public static Set getUpdatedModelIds(ClusterChangedEvent event) { .collect(Collectors.toSet()); } - private final Map entries; + private final Map entries; - public TrainedModelCacheMetadata(Map entries) { + public TrainedModelCacheMetadata(Map entries) { this.entries = entries; } public TrainedModelCacheMetadata(StreamInput in) throws IOException { - this.entries = in.readImmutableMap(TrainedModelCustomMetadataEntry::new); + this.entries = in.readImmutableMap(TrainedModelCacheMetadataEntry::new); } - public Map entries() { + public Map entries() { return entries; } @@ -153,7 +153,7 @@ public int hashCode() { } public static class TrainedModelCacheMetadataDiff implements NamedDiff { - final Diff> entriesDiff; + final Diff> entriesDiff; TrainedModelCacheMetadataDiff(TrainedModelCacheMetadata before, TrainedModelCacheMetadata after) { this.entriesDiff = DiffableUtils.diff(before.entries, after.entries, DiffableUtils.getStringKeySerializer()); @@ -163,8 +163,8 @@ public static class TrainedModelCacheMetadataDiff implements NamedDiff, ToXContentObject { - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + public static class TrainedModelCacheMetadataEntry implements SimpleDiffable, ToXContentObject { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "trained_model_cache_metadata_entry", true, - args -> new TrainedModelCustomMetadataEntry((String) args[0]) + args -> new TrainedModelCacheMetadataEntry((String) args[0]) ); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); } - private static Diff readDiffFrom(StreamInput in) throws IOException { - return SimpleDiffable.readDiffFrom(TrainedModelCustomMetadataEntry::new, in); + private static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(TrainedModelCacheMetadataEntry::new, in); } - private static TrainedModelCustomMetadataEntry fromXContent(XContentParser parser) { + private static TrainedModelCacheMetadataEntry fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } private final String modelId; - public TrainedModelCustomMetadataEntry(String modelId) { + public TrainedModelCacheMetadataEntry(String modelId) { this.modelId = modelId; } - TrainedModelCustomMetadataEntry(StreamInput in) throws IOException { + TrainedModelCacheMetadataEntry(StreamInput in) throws IOException { this.modelId = in.readString(); } @@ -238,7 +238,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TrainedModelCustomMetadataEntry that = (TrainedModelCustomMetadataEntry) o; + TrainedModelCacheMetadataEntry that = (TrainedModelCacheMetadataEntry) o; return Objects.equals(modelId, that.modelId); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadataTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadataTests.java new file mode 100644 index 0000000000000..33f8b2ad50949 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadataTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractChunkedSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata.TrainedModelCacheMetadataEntry; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class TrainedModelCacheMetadataTests extends AbstractChunkedSerializingTestCase { + public static TrainedModelCacheMetadataEntry randomEntry() { + return new TrainedModelCacheMetadataEntry(randomIdentifier()); + } + + public static TrainedModelCacheMetadata randomInstance() { + Map entries = Stream.generate(TrainedModelCacheMetadataTests::randomEntry) + .limit(randomInt(5)) + .collect(Collectors.toMap(TrainedModelCacheMetadataEntry::getModelId, Function.identity(), (k, k1) -> k, HashMap::new)); + + return new TrainedModelCacheMetadata(entries); + } + + @Override + protected TrainedModelCacheMetadata doParseInstance(XContentParser parser) throws IOException { + return TrainedModelCacheMetadata.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return TrainedModelCacheMetadata::new; + } + + @Override + protected TrainedModelCacheMetadata createTestInstance() { + return randomInstance(); + } + + @Override + protected TrainedModelCacheMetadata mutateInstance(TrainedModelCacheMetadata instance) { + return randomValueOtherThan(instance, () -> randomInstance()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelCacheMetadataTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java similarity index 95% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelCacheMetadataTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java index 25c3b3cf3f13f..3bed1cb3af10b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelCacheMetadataTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java @@ -16,7 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class TrainedModelCacheMetadataTests extends AbstractBWCSerializationTestCase { +public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase { private boolean lenient; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java index f3234c38c20d2..cc7764a98cefa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java @@ -76,7 +76,7 @@ private static class PutModelCacheMetadataTask extends ModelCacheMetadataManagem protected TrainedModelCacheMetadata execute(TrainedModelCacheMetadata currentCacheMetadata, TaskContext taskContext) { var entries = new HashMap<>(currentCacheMetadata.entries()); - entries.put(modelId, new TrainedModelCacheMetadata.TrainedModelCustomMetadataEntry(modelId)); + entries.put(modelId, new TrainedModelCacheMetadata.TrainedModelCacheMetadataEntry(modelId)); taskContext.success(() -> listener.onResponse(AcknowledgedResponse.TRUE)); return new TrainedModelCacheMetadata(entries); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 94e0c533ef5fc..f73513bd9ca32 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -67,7 +67,7 @@ public class TrainedModelProviderTests extends ESTestCase { public void testDeleteModelStoredAsResource() { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), mock(TrainedModelCacheMetadataService.class), xContentRegistry()); PlainActionFuture future = new PlainActionFuture<>(); // Should be OK as we don't make any client calls trainedModelProvider.deleteTrainedModel("lang_ident_model_1", future); @@ -77,7 +77,7 @@ public void testDeleteModelStoredAsResource() { public void testPutModelThatExistsAsResource() { TrainedModelConfig config = TrainedModelConfigTests.createTestInstance("lang_ident_model_1").build(); - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), mock(TrainedModelCacheMetadataService.class), xContentRegistry()); PlainActionFuture future = new PlainActionFuture<>(); trainedModelProvider.storeTrainedModel(config, future); ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet); @@ -85,7 +85,7 @@ public void testPutModelThatExistsAsResource() { } public void testGetModelThatExistsAsResource() throws Exception { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), mock(TrainedModelCacheMetadataService.class), xContentRegistry()); for (String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) { PlainActionFuture future = new PlainActionFuture<>(); trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, future); @@ -180,7 +180,7 @@ public void testExpandIdsPagination() { } public void testGetModelThatExistsAsResourceButIsMissing() { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), mock(TrainedModelCacheMetadataService.class), xContentRegistry()); ElasticsearchException ex = expectThrows( ElasticsearchException.class, () -> trainedModelProvider.loadModelFromResource("missing_model", randomBoolean()) @@ -350,7 +350,7 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreate() { try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("inferenceEntityId").build(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelConfig(config, future); @@ -362,7 +362,7 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreateWhen try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("inferenceEntityId").build(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelConfig(config, future, false); @@ -374,7 +374,7 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationIndex() { try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("inferenceEntityId").build(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelConfig(config, future, true); @@ -386,7 +386,7 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCr try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("inferenceEntityId"); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModel(config, future); @@ -398,7 +398,7 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCr try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("inferenceEntityId"); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModel(config, future, false); @@ -410,7 +410,7 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationIn try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("inferenceEntityId"); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModel(config, future, true); @@ -422,7 +422,7 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCre try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelDefinitionDoc(config, future); @@ -434,7 +434,7 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCre try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelDefinitionDoc(config, "index", future, false); @@ -446,7 +446,7 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationInd try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelDefinitionDoc(config, "index", future, true); @@ -458,7 +458,7 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreate try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var vocab = createVocabulary(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelVocabulary("inferenceEntityId", mock(VocabularyConfig.class), vocab, future); @@ -470,7 +470,7 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreate try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var vocab = createVocabulary(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelVocabulary("inferenceEntityId", mock(VocabularyConfig.class), vocab, future, false); @@ -482,7 +482,7 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationIndex( try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var vocab = createVocabulary(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelVocabulary("inferenceEntityId", mock(VocabularyConfig.class), vocab, future, true); @@ -494,7 +494,7 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreate() try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelMetadata(metadata, future); @@ -506,7 +506,7 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreateWh try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelMetadata(metadata, future, false); @@ -518,7 +518,7 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationIndex() try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelMetadata(metadata, future, true); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java index 0139b0d500341..6c771522220e7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.hamcrest.Matcher; @@ -176,7 +177,7 @@ public void testLangInference() throws Exception { } InferenceDefinition grabModel() throws IOException { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), mock(TrainedModelCacheMetadataService.class), xContentRegistry()); PlainActionFuture future = new PlainActionFuture<>(); // Should be OK as we don't make any client calls trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), null, future);