Skip to content

Commit

Permalink
Create a service to manage cache entries in the cluster state.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Apr 3, 2024
1 parent f6eef1e commit bb8f6f2
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.Diff;
import org.elasticsearch.cluster.DiffableUtils;
import org.elasticsearch.cluster.NamedDiff;
Expand All @@ -34,52 +35,61 @@
import java.util.function.Function;
import java.util.stream.Collectors;

public class TrainedModelMetadata implements Metadata.Custom {
public class TrainedModelCacheMetadata implements Metadata.Custom {

public static final String NAME = "trained_model_metadata";
public static final String NAME = "trained_model_cache_metadata";

public static final TrainedModelMetadata EMPTY = new TrainedModelMetadata(new HashMap<>());
private static final ParseField MODELS = new ParseField("models");
public static final TrainedModelCacheMetadata EMPTY = new TrainedModelCacheMetadata(new HashMap<>());
private static final ParseField ENTRIES = new ParseField("entries");
private static final ParseField MODEL_ID = new ParseField("model_id");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<TrainedModelMetadata, Void> PARSER = new ConstructingObjectParser<>(
private static final ConstructingObjectParser<TrainedModelCacheMetadata, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
args -> new TrainedModelMetadata((Map<String, TrainedModelMetadataEntry>) args[0])
args -> new TrainedModelCacheMetadata((Map<String, TrainedModelCustomMetadataEntry>) args[0])
);

static {
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
List<TrainedModelMetadataEntry> models = new ArrayList<>();
List<TrainedModelCustomMetadataEntry> entries = new ArrayList<>();
while (p.nextToken() != XContentParser.Token.END_ARRAY) {
models.add(TrainedModelMetadataEntry.fromXContent(p));
entries.add(TrainedModelCustomMetadataEntry.fromXContent(p));
}
return models.stream().collect(Collectors.toMap(TrainedModelMetadataEntry::getModelId, Function.identity()));
}, MODELS);
return entries.stream().collect(Collectors.toMap(TrainedModelCustomMetadataEntry::getModelId, Function.identity()));
}, ENTRIES);
}

public static TrainedModelMetadata fromXContent(XContentParser parser) {
public static TrainedModelCacheMetadata fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

/**
 * Looks up the trained model cache metadata registered under {@link #NAME} in the metadata
 * customs of the given cluster state.
 *
 * @param clusterState the cluster state to read the custom metadata from
 * @return the stored cache metadata, or {@link #EMPTY} when the cluster state holds none
 */
public static TrainedModelCacheMetadata fromState(ClusterState clusterState) {
    final TrainedModelCacheMetadata stored = clusterState.getMetadata().custom(NAME);
    if (stored != null) {
        return stored;
    }
    // Nothing registered yet: fall back to the shared empty instance rather than returning null.
    return EMPTY;
}

public static NamedDiff<Metadata.Custom> readDiffFrom(StreamInput in) throws IOException {
return new TrainedModelMetadataDiff(in);
return new TrainedModelCacheMetadataDiff(in);
}

private final Map<String, TrainedModelMetadataEntry> models;
private final Map<String, TrainedModelCustomMetadataEntry> entries;

/**
 * Creates the cache metadata from the given entries, keyed by model id.
 * <p>
 * The map is copied into an unmodifiable map, so later changes to the caller's map cannot alter
 * this instance and the map handed out by {@link #entries()} cannot be mutated. This matters in
 * particular for the shared {@link #EMPTY} instance, which is built from a fresh {@code HashMap}.
 *
 * @param entries the cache metadata entries; must not be {@code null} and must not contain
 *                {@code null} keys or values
 * @throws NullPointerException if {@code entries}, or any of its keys or values, is {@code null}
 */
public TrainedModelCacheMetadata(Map<String, TrainedModelCustomMetadataEntry> entries) {
    // Defensive, unmodifiable copy: this object must stay immutable once built.
    this.entries = Map.copyOf(entries);
}

public TrainedModelMetadata(Map<String, TrainedModelMetadataEntry> models) {
this.models = models;
public TrainedModelCacheMetadata(StreamInput in) throws IOException {
this.entries = in.readImmutableMap(TrainedModelCustomMetadataEntry::new);
}

public TrainedModelMetadata(StreamInput in) throws IOException {
this.models = in.readImmutableMap(TrainedModelMetadataEntry::new);
public Map<String, TrainedModelCustomMetadataEntry> entries() {
return entries;
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params ignored) {
return ChunkedToXContentHelper.xContentValuesMap(MODELS.getPreferredName(), models);
return ChunkedToXContentHelper.xContentValuesMap(ENTRIES.getPreferredName(), entries);
}

@Override
Expand All @@ -100,46 +110,46 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeMap(this.models, StreamOutput::writeWriteable);
out.writeMap(this.entries, StreamOutput::writeWriteable);
}

@Override
public Diff<Metadata.Custom> diff(Metadata.Custom previousState) {
return new TrainedModelMetadataDiff((TrainedModelMetadata) previousState, this);
return new TrainedModelCacheMetadataDiff((TrainedModelCacheMetadata) previousState, this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelMetadata that = (TrainedModelMetadata) o;
return Objects.equals(models, that.models);
TrainedModelCacheMetadata that = (TrainedModelCacheMetadata) o;
return Objects.equals(entries, that.entries);
}

@Override
public int hashCode() {
return Objects.hash(models);
return Objects.hash(entries);
}

public static class TrainedModelMetadataDiff implements NamedDiff<Metadata.Custom> {
final Diff<Map<String, TrainedModelMetadataEntry>> modelsDiff;
public static class TrainedModelCacheMetadataDiff implements NamedDiff<Metadata.Custom> {
final Diff<Map<String, TrainedModelCustomMetadataEntry>> entriesDiff;

TrainedModelMetadataDiff(TrainedModelMetadata before, TrainedModelMetadata after) {
this.modelsDiff = DiffableUtils.diff(before.models, after.models, DiffableUtils.getStringKeySerializer());
TrainedModelCacheMetadataDiff(TrainedModelCacheMetadata before, TrainedModelCacheMetadata after) {
this.entriesDiff = DiffableUtils.diff(before.entries, after.entries, DiffableUtils.getStringKeySerializer());
}

TrainedModelMetadataDiff(StreamInput in) throws IOException {
this.modelsDiff = DiffableUtils.readJdkMapDiff(
TrainedModelCacheMetadataDiff(StreamInput in) throws IOException {
this.entriesDiff = DiffableUtils.readJdkMapDiff(
in,
DiffableUtils.getStringKeySerializer(),
TrainedModelMetadataEntry::new,
TrainedModelMetadataEntry::readDiffFrom
TrainedModelCustomMetadataEntry::new,
TrainedModelCustomMetadataEntry::readDiffFrom
);
}

@Override
public Metadata.Custom apply(Metadata.Custom part) {
return new TrainedModelMetadata(modelsDiff.apply(((TrainedModelMetadata) part).models));
return new TrainedModelCacheMetadata(entriesDiff.apply(((TrainedModelCacheMetadata) part).entries));
}

@Override
Expand All @@ -149,7 +159,7 @@ public String getWriteableName() {

@Override
public void writeTo(StreamOutput out) throws IOException {
modelsDiff.writeTo(out);
entriesDiff.writeTo(out);
}

@Override
Expand All @@ -158,31 +168,31 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
}
}
public static class TrainedModelMetadataEntry implements SimpleDiffable<TrainedModelMetadataEntry>, ToXContentObject {
private static final ConstructingObjectParser<TrainedModelMetadataEntry, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_metadata_entry",
public static class TrainedModelCustomMetadataEntry implements SimpleDiffable<TrainedModelCustomMetadataEntry>, ToXContentObject {
private static final ConstructingObjectParser<TrainedModelCustomMetadataEntry, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_cache_metadata_entry",
true,
args -> new TrainedModelMetadataEntry((String) args[0])
args -> new TrainedModelCustomMetadataEntry((String) args[0])
);
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
}

private static Diff<TrainedModelMetadataEntry> readDiffFrom(StreamInput in) throws IOException {
return SimpleDiffable.readDiffFrom(TrainedModelMetadataEntry::new, in);
private static Diff<TrainedModelCustomMetadataEntry> readDiffFrom(StreamInput in) throws IOException {
return SimpleDiffable.readDiffFrom(TrainedModelCustomMetadataEntry::new, in);
}

private static TrainedModelMetadataEntry fromXContent(XContentParser parser) {
private static TrainedModelCustomMetadataEntry fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final String modelId;

public TrainedModelMetadataEntry(String modelId) {
public TrainedModelCustomMetadataEntry(String modelId) {
this.modelId = modelId;
}

TrainedModelMetadataEntry(StreamInput in) throws IOException {
TrainedModelCustomMetadataEntry(StreamInput in) throws IOException {
this.modelId = in.readString();
}

Expand All @@ -207,7 +217,7 @@ public void writeTo(StreamOutput out) throws IOException {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelMetadataEntry that = (TrainedModelMetadataEntry) o;
TrainedModelCustomMetadataEntry that = (TrainedModelCustomMetadataEntry) o;
return Objects.equals(modelId, that.modelId);
}

Expand All @@ -218,7 +228,7 @@ public int hashCode() {

@Override
public String toString() {
return "TrainedModelMetadataEntry{modelId='" + modelId + "'}";
return "TrainedModelCacheMetadataEntry{modelId='" + modelId + "'}";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase<TrainedModelMetadata> {
public class TrainedModelCacheMetadataTests extends AbstractBWCSerializationTestCase<TrainedModelMetadata> {

private boolean lenient;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
Expand Down Expand Up @@ -335,6 +335,7 @@
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorerBuilder;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankService;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.inference.pytorch.process.BlackHolePyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory;
Expand Down Expand Up @@ -1133,7 +1134,8 @@ public Collection<?> createComponents(PluginServices services) {
clusterService,
threadPool
);
final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
final TrainedModelCacheMetadataService trainedModelCacheMetadataService = new TrainedModelCacheMetadataService(clusterService);
final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, trainedModelCacheMetadataService, xContentRegistry);
final ModelLoadingService modelLoadingService = new ModelLoadingService(
trainedModelProvider,
inferenceAuditor,
Expand Down Expand Up @@ -1825,8 +1827,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
namedXContent.add(
new NamedXContentRegistry.Entry(
Metadata.Custom.class,
new ParseField((TrainedModelMetadata.NAME)),
TrainedModelMetadata::fromXContent
new ParseField((TrainedModelCacheMetadata.NAME)),
TrainedModelCacheMetadata::fromXContent
)
);
namedXContent.add(
Expand Down Expand Up @@ -1867,8 +1869,8 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
// Custom metadata
namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, TrainedModelMetadata.NAME, TrainedModelMetadata::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, TrainedModelMetadata.NAME, TrainedModelMetadata::readDiffFrom));
namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::readDiffFrom));
namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom));
namedWriteables.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,11 @@ public class TrainedModelProvider {
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
private final Client client;
private final NamedXContentRegistry xContentRegistry;
private final TrainedModelCacheMetadataService modelCacheMetadataService;

public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) {
public TrainedModelProvider(Client client, TrainedModelCacheMetadataService modelCacheMetadataService, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.modelCacheMetadataService = modelCacheMetadataService;
this.xContentRegistry = xContentRegistry;
}

Expand Down Expand Up @@ -894,7 +896,11 @@ public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener)
listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return;
}
listener.onResponse(true);

modelCacheMetadataService.deleteCacheMetadataEntry(modelId, ActionListener.wrap(
acknowledgedResponse -> listener.onResponse(true),
listener::onFailure
));
}, e -> {
if (e.getClass() == IndexNotFoundException.class) {
listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
Expand Down

0 comments on commit bb8f6f2

Please sign in to comment.