Skip to content

Commit

Permalink
Code refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Apr 4, 2024
1 parent b794742 commit 16024b3
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.Diff;
import org.elasticsearch.cluster.DiffableUtils;
Expand All @@ -16,6 +17,7 @@
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -25,14 +27,13 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.Set;
import java.util.stream.Collectors;

public class TrainedModelCacheMetadata implements Metadata.Custom {
Expand All @@ -52,11 +53,12 @@ public class TrainedModelCacheMetadata implements Metadata.Custom {

static {
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
List<TrainedModelCustomMetadataEntry> entries = new ArrayList<>();
while (p.nextToken() != XContentParser.Token.END_ARRAY) {
entries.add(TrainedModelCustomMetadataEntry.fromXContent(p));
Map<String, TrainedModelCustomMetadataEntry> entries = new HashMap<>();
while (p.nextToken() != XContentParser.Token.END_OBJECT) {
String modelId = p.currentName();
entries.put(modelId, TrainedModelCustomMetadataEntry.fromXContent(p));
}
return entries.stream().collect(Collectors.toMap(TrainedModelCustomMetadataEntry::getModelId, Function.identity()));
return entries;
}, ENTRIES);
}

Expand All @@ -73,6 +75,25 @@ public static NamedDiff<Metadata.Custom> readDiffFrom(StreamInput in) throws IOE
return new TrainedModelCacheMetadataDiff(in);
}

public static Set<String> getUpdatedModelIds(ClusterChangedEvent event) {
if (event.changedCustomMetadataSet().contains(TrainedModelCacheMetadata.NAME) == false) {
return Collections.emptySet();
}

Map<String, TrainedModelCustomMetadataEntry> oldCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.previousState()).entries();
Map<String, TrainedModelCustomMetadataEntry> newCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.state()).entries();

return Sets.union(oldCacheMetadataEntries.keySet(), newCacheMetadataEntries.keySet()).stream()
.filter(modelId -> {
if ((oldCacheMetadataEntries.containsKey(modelId) && newCacheMetadataEntries.containsKey(modelId)) == false) {
return true;
}

return Objects.equals(oldCacheMetadataEntries.get(modelId), newCacheMetadataEntries.get(modelId)) == false;
})
.collect(Collectors.toSet());
}

private final Map<String, TrainedModelCustomMetadataEntry> entries;

public TrainedModelCacheMetadata(Map<String, TrainedModelCustomMetadataEntry> entries) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
Expand Down Expand Up @@ -750,7 +751,11 @@ private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer>

@Override
public void clusterChanged(ClusterChangedEvent event) {
logger.debug("Need to check trained model changes to flush the cache if needed");
TrainedModelCacheMetadata.getUpdatedModelIds(event).forEach(modelId -> {
localModelCache.invalidate(modelId);
logger.debug("Invalidated cache for model " + modelId);
});

final boolean prefetchModels = event.state().nodes().getLocalNode().isIngestNode();
// If we are not prefetching models and there were no model alias changes, don't bother handling the changes
if ((prefetchModels == false)
Expand Down Expand Up @@ -910,13 +915,13 @@ private Map<String, String> gatherLazyChangedAliasesAndUpdateModelAliases(
) {
Map<String, String> changedAliases = new HashMap<>();
if (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME)) {
final Map<java.lang.String, ModelAliasMetadata.ModelAliasEntry> modelAliasesToIds = new HashMap<>(
final Map<String, ModelAliasMetadata.ModelAliasEntry> modelAliasesToIds = new HashMap<>(
ModelAliasMetadata.fromState(event.state()).modelAliases()
);
modelIdToModelAliases.clear();
for (Map.Entry<java.lang.String, ModelAliasMetadata.ModelAliasEntry> aliasToId : modelAliasesToIds.entrySet()) {
for (Map.Entry<String, ModelAliasMetadata.ModelAliasEntry> aliasToId : modelAliasesToIds.entrySet()) {
modelIdToModelAliases.computeIfAbsent(aliasToId.getValue().getModelId(), k -> new HashSet<>()).add(aliasToId.getKey());
java.lang.String modelId = modelAliasToId.get(aliasToId.getKey());
String modelId = modelAliasToId.get(aliasToId.getKey());
if (modelId != null && modelId.equals(aliasToId.getValue().getModelId()) == false) {
if (prefetchModels && allReferencedModelKeys.contains(aliasToId.getKey())) {
changedAliases.put(aliasToId.getKey(), aliasToId.getValue().getModelId());
Expand All @@ -928,7 +933,7 @@ private Map<String, String> gatherLazyChangedAliasesAndUpdateModelAliases(
modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId());
}
}
Set<java.lang.String> removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet());
Set<String> removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet());
modelAliasToId.keySet().removeAll(removedAliases);
}
return changedAliases;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

import java.util.HashMap;

Expand All @@ -42,6 +43,11 @@ public void deleteCacheMetadataEntry(String modelId, ActionListener<Acknowledged
this.modelCacheMetadataManagementTaskQueue.submitTask(deleteModelCacheMetadataTask.getDescription(), deleteModelCacheMetadataTask, null);
}

public void saveCacheMetadataEntry(TrainedModelConfig modelConfig, ActionListener<AcknowledgedResponse> listener) {
ModelCacheMetadataManagementTask putModelCacheMetadataTask = new PutModelCacheMetadataTask(modelConfig.getModelId(), listener);
this.modelCacheMetadataManagementTaskQueue.submitTask(putModelCacheMetadataTask.getDescription(), putModelCacheMetadataTask, null);
}

private abstract static class ModelCacheMetadataManagementTask implements ClusterStateTaskListener {
protected final ActionListener<AcknowledgedResponse> listener;

Expand All @@ -60,6 +66,27 @@ public void onFailure(@Nullable Exception e) {
}
}

private static class PutModelCacheMetadataTask extends ModelCacheMetadataManagementTask {
private final String modelId;

PutModelCacheMetadataTask(String modelId, ActionListener<AcknowledgedResponse> listener) {
super(listener);
this.modelId = modelId;
}

protected TrainedModelCacheMetadata execute(TrainedModelCacheMetadata currentCacheMetadata, TaskContext<ModelCacheMetadataManagementTask> taskContext) {
var entries = new HashMap<>(currentCacheMetadata.entries());
entries.put(modelId, new TrainedModelCacheMetadata.TrainedModelCustomMetadataEntry(modelId));
taskContext.success(() -> listener.onResponse(AcknowledgedResponse.TRUE));
return new TrainedModelCacheMetadata(entries);
}

@Override
protected String getDescription() {
return "saving cache metadata for model [" + modelId + "]";
}
}

private static class DeleteModelCacheMetadataTask extends ModelCacheMetadataManagementTask {
private final String modelId;

Expand All @@ -80,7 +107,7 @@ protected TrainedModelCacheMetadata execute(TrainedModelCacheMetadata currentCac
updatedCacheMetadata = currentCacheMetadata;
}

listener.onResponse(AcknowledgedResponse.TRUE);
taskContext.success(() -> listener.onResponse(AcknowledgedResponse.TRUE));
return updatedCacheMetadata;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,28 @@ public void storeTrainedModelConfig(TrainedModelConfig trainedModelConfig, Actio
ML_ORIGIN,
TransportIndexAction.TYPE,
request,
ActionListener.wrap(indexResponse -> listener.onResponse(true), e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
listener.onFailure(
new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())
)
);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL, trainedModelConfig.getModelId()),
RestStatus.INTERNAL_SERVER_ERROR,
e
)
);
}
})
ActionListener.wrap(
indexResponse -> modelCacheMetadataService.saveCacheMetadataEntry(
trainedModelConfig,
ActionListener.wrap(resp -> listener.onResponse(true), listener::onFailure)
),
e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
listener.onFailure(
new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())
)
);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL, trainedModelConfig.getModelId()),
RestStatus.INTERNAL_SERVER_ERROR,
e
)
);
}
})
);
}

Expand Down Expand Up @@ -523,7 +528,11 @@ private void storeTrainedModelAndDefinition(
wrappedListener.onFailure(firstFailure);
return;
}
wrappedListener.onResponse(true);

modelCacheMetadataService.saveCacheMetadataEntry(
trainedModelConfig,
ActionListener.wrap(resp -> wrappedListener.onResponse(true), wrappedListener::onFailure)
);
}, wrappedListener::onFailure);

executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
Expand Down

0 comments on commit 16024b3

Please sign in to comment.