diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java similarity index 54% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelMetadata.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java index 5ff52190ba4fe..cf0659e37b4fb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference; import org.elasticsearch.TransportVersion; +import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.cluster.NamedDiff; @@ -34,52 +35,61 @@ import java.util.function.Function; import java.util.stream.Collectors; -public class TrainedModelMetadata implements Metadata.Custom { +public class TrainedModelCacheMetadata implements Metadata.Custom { - public static final String NAME = "trained_model_metadata"; + public static final String NAME = "trained_model_cache_metadata"; - public static final TrainedModelMetadata EMPTY = new TrainedModelMetadata(new HashMap<>()); - private static final ParseField MODELS = new ParseField("models"); + public static final TrainedModelCacheMetadata EMPTY = new TrainedModelCacheMetadata(new HashMap<>()); + private static final ParseField ENTRIES = new ParseField("entries"); private static final ParseField MODEL_ID = new ParseField("model_id"); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, true, - args -> new TrainedModelMetadata((Map) args[0]) + args -> new TrainedModelCacheMetadata((Map) args[0]) ); static { PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { - List models = new ArrayList<>(); + List entries = new ArrayList<>(); while (p.nextToken() != XContentParser.Token.END_ARRAY) { - models.add(TrainedModelMetadataEntry.fromXContent(p)); + entries.add(TrainedModelCustomMetadataEntry.fromXContent(p)); } - return models.stream().collect(Collectors.toMap(TrainedModelMetadataEntry::getModelId, Function.identity())); - }, MODELS); + return entries.stream().collect(Collectors.toMap(TrainedModelCustomMetadataEntry::getModelId, Function.identity())); + }, ENTRIES); } - public static TrainedModelMetadata fromXContent(XContentParser parser) { + public static TrainedModelCacheMetadata fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } + public static TrainedModelCacheMetadata fromState(ClusterState clusterState) { + TrainedModelCacheMetadata cacheMetadata = clusterState.getMetadata().custom(NAME); + return cacheMetadata == null ? EMPTY : cacheMetadata; + } + public static NamedDiff readDiffFrom(StreamInput in) throws IOException { - return new TrainedModelMetadataDiff(in); + return new TrainedModelCacheMetadataDiff(in); } - private final Map models; + private final Map entries; + + public TrainedModelCacheMetadata(Map entries) { + this.entries = entries; + } - public TrainedModelMetadata(Map models) { - this.models = models; + public TrainedModelCacheMetadata(StreamInput in) throws IOException { + this.entries = in.readImmutableMap(TrainedModelCustomMetadataEntry::new); } - public TrainedModelMetadata(StreamInput in) throws IOException { - this.models = in.readImmutableMap(TrainedModelMetadataEntry::new); + public Map entries() { + return entries; } @Override public Iterator toXContentChunked(ToXContent.Params ignored) { - return ChunkedToXContentHelper.xContentValuesMap(MODELS.getPreferredName(), models); + return ChunkedToXContentHelper.xContentValuesMap(ENTRIES.getPreferredName(), entries); } @Override @@ -100,46 +110,46 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap(this.models, StreamOutput::writeWriteable); + out.writeMap(this.entries, StreamOutput::writeWriteable); } @Override public Diff diff(Metadata.Custom previousState) { - return new TrainedModelMetadataDiff((TrainedModelMetadata) previousState, this); + return new TrainedModelCacheMetadataDiff((TrainedModelCacheMetadata) previousState, this); } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TrainedModelMetadata that = (TrainedModelMetadata) o; - return Objects.equals(models, that.models); + TrainedModelCacheMetadata that = (TrainedModelCacheMetadata) o; + return Objects.equals(entries, that.entries); } @Override public int hashCode() { - return Objects.hash(models); + return Objects.hash(entries); } - public static class TrainedModelMetadataDiff implements NamedDiff { - final Diff> modelsDiff; + public static class TrainedModelCacheMetadataDiff implements NamedDiff { + final Diff> entriesDiff; - TrainedModelMetadataDiff(TrainedModelMetadata before, TrainedModelMetadata after) { - this.modelsDiff = DiffableUtils.diff(before.models, after.models, DiffableUtils.getStringKeySerializer()); + TrainedModelCacheMetadataDiff(TrainedModelCacheMetadata before, TrainedModelCacheMetadata after) { + this.entriesDiff = DiffableUtils.diff(before.entries, after.entries, DiffableUtils.getStringKeySerializer()); } - TrainedModelMetadataDiff(StreamInput in) throws IOException { - this.modelsDiff = DiffableUtils.readJdkMapDiff( + TrainedModelCacheMetadataDiff(StreamInput in) throws IOException { + this.entriesDiff = DiffableUtils.readJdkMapDiff( in, DiffableUtils.getStringKeySerializer(), - TrainedModelMetadataEntry::new, - TrainedModelMetadataEntry::readDiffFrom + TrainedModelCustomMetadataEntry::new, + TrainedModelCustomMetadataEntry::readDiffFrom ); } @Override public Metadata.Custom apply(Metadata.Custom part) { - return new TrainedModelMetadata(modelsDiff.apply(((TrainedModelMetadata) part).models)); + return new TrainedModelCacheMetadata(entriesDiff.apply(((TrainedModelCacheMetadata) part).entries)); } @Override @@ -149,7 +159,7 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - modelsDiff.writeTo(out); + entriesDiff.writeTo(out); } @Override @@ -158,31 +168,31 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersion.current(); } } - public static class TrainedModelMetadataEntry implements SimpleDiffable, ToXContentObject { - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "trained_model_metadata_entry", + public static class TrainedModelCustomMetadataEntry implements SimpleDiffable, ToXContentObject { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "trained_model_cache_metadata_entry", true, - args -> new TrainedModelMetadataEntry((String) args[0]) + args -> new TrainedModelCustomMetadataEntry((String) args[0]) ); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); } - private static Diff readDiffFrom(StreamInput in) throws IOException { - return SimpleDiffable.readDiffFrom(TrainedModelMetadataEntry::new, in); + private static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(TrainedModelCustomMetadataEntry::new, in); } - private static TrainedModelMetadataEntry fromXContent(XContentParser parser) { + private static TrainedModelCustomMetadataEntry fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } private final String modelId; - public TrainedModelMetadataEntry(String modelId) { + public TrainedModelCustomMetadataEntry(String modelId) { this.modelId = modelId; } - TrainedModelMetadataEntry(StreamInput in) throws IOException { + TrainedModelCustomMetadataEntry(StreamInput in) throws IOException { this.modelId = in.readString(); } @@ -207,7 +217,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TrainedModelMetadataEntry that = (TrainedModelMetadataEntry) o; + TrainedModelCustomMetadataEntry that = (TrainedModelCustomMetadataEntry) o; return Objects.equals(modelId, that.modelId); } @@ -218,7 +228,7 @@ public int hashCode() { @Override public String toString() { - return "TrainedModelMetadataEntry{modelId='" + modelId + "'}"; + return "TrainedModelCacheMetadataEntry{modelId='" + modelId + "'}"; } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelCacheMetadataTests.java similarity index 95% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelCacheMetadataTests.java index 3bed1cb3af10b..25c3b3cf3f13f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelCacheMetadataTests.java @@ -16,7 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase { +public class TrainedModelCacheMetadataTests extends AbstractBWCSerializationTestCase { private boolean lenient; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 14841079dc9ef..889ec8a711927 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -194,7 +194,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelMetadata; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; @@ -335,6 +335,7 @@ import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorerBuilder; import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankService; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.pytorch.process.BlackHolePyTorchProcess; import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory; @@ -1133,7 +1134,8 @@ public Collection createComponents(PluginServices services) { clusterService, threadPool ); - final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + final TrainedModelCacheMetadataService trainedModelCacheMetadataService = new TrainedModelCacheMetadataService(clusterService); + final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, trainedModelCacheMetadataService, xContentRegistry); final ModelLoadingService modelLoadingService = new ModelLoadingService( trainedModelProvider, inferenceAuditor, @@ -1825,8 +1827,8 @@ public List getNamedXContent() { namedXContent.add( new NamedXContentRegistry.Entry( Metadata.Custom.class, - new ParseField((TrainedModelMetadata.NAME)), - TrainedModelMetadata::fromXContent + new ParseField((TrainedModelCacheMetadata.NAME)), + TrainedModelCacheMetadata::fromXContent ) ); namedXContent.add( @@ -1867,8 +1869,8 @@ public List getNamedWriteables() { // Custom metadata namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, TrainedModelMetadata.NAME, TrainedModelMetadata::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, TrainedModelMetadata.NAME, TrainedModelMetadata::readDiffFrom)); + namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::readDiffFrom)); namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom)); namedWriteables.add( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index b502e0d6db341..9110a4cb2ce00 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -127,9 +127,11 @@ public class TrainedModelProvider { private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; private final NamedXContentRegistry xContentRegistry; + private final TrainedModelCacheMetadataService modelCacheMetadataService; - public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) { + public TrainedModelProvider(Client client, TrainedModelCacheMetadataService modelCacheMetadataService, NamedXContentRegistry xContentRegistry) { this.client = client; + this.modelCacheMetadataService = modelCacheMetadataService; this.xContentRegistry = xContentRegistry; } @@ -894,7 +896,11 @@ public void deleteTrainedModel(String modelId, ActionListener listener) listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return; } - listener.onResponse(true); + + modelCacheMetadataService.deleteCacheMetadataEntry(modelId, ActionListener.wrap( + acknowledgedResponse -> listener.onResponse(true), + listener::onFailure + )); }, e -> { if (e.getClass() == IndexNotFoundException.class) { listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));