From 80bafa3b1bc6e737b4633b0f0b1cb33da688b860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20FOUCRET?= Date: Mon, 22 Apr 2024 19:16:23 +0200 Subject: [PATCH] Adding trained model metadata class. (#106988) --- .../org/elasticsearch/TransportVersions.java | 1 + .../action/FlushTrainedModelCacheAction.java | 54 ++++++ .../inference/TrainedModelCacheMetadata.java | 109 +++++++++++ ...shTrainedModelCacheActionRequestTests.java | 42 +++++ .../TrainedModelCacheMetadataTests.java | 36 ++++ .../MlLearningToRankRescorerIT.java | 98 +++++++++- .../ml/integration/MlNativeIntegTestCase.java | 7 + .../ChunkedTrainedModelPersisterIT.java | 14 +- .../integration/ModelInferenceActionIT.java | 14 +- .../integration/TrainedModelProviderIT.java | 14 +- .../xpack/ml/MachineLearning.java | 29 ++- ...TransportFlushTrainedModelCacheAction.java | 67 +++++++ .../loadingservice/ModelLoadingService.java | 15 +- .../TrainedModelCacheMetadataService.java | 131 +++++++++++++ .../persistence/TrainedModelProvider.java | 26 ++- ...portFlushTrainedModelCacheActionTests.java | 85 +++++++++ ...TrainedModelCacheMetadataServiceTests.java | 178 ++++++++++++++++++ .../TrainedModelProviderTests.java | 54 ++++-- .../LangIdentNeuralNetworkInferenceTests.java | 7 +- .../xpack/security/operator/Constants.java | 1 + 20 files changed, 948 insertions(+), 34 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadataTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheActionTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataServiceTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index b073cabbe79c7..456497c167294 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -178,6 +178,7 @@ static TransportVersion def(int id) { public static final TransportVersion WATERMARK_THRESHOLDS_STATS = def(8_637_00_0); public static final TransportVersion ENRICH_CACHE_ADDITIONAL_STATS = def(8_638_00_0); public static final TransportVersion ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED = def(8_639_00_0); + public static final TransportVersion ML_TRAINED_MODEL_CACHE_METADATA_ADDED = def(8_640_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java new file mode 100644 index 0000000000000..bdba626676b2d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.TimeValue; + +import java.io.IOException; +import java.util.Objects; + +public class FlushTrainedModelCacheAction extends ActionType { + + public static final FlushTrainedModelCacheAction INSTANCE = new FlushTrainedModelCacheAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/clear_model_cache"; + + private FlushTrainedModelCacheAction() { + super(NAME); + } + + public static class Request extends AcknowledgedRequest { + public Request() { + super(); + } + + Request(TimeValue timeout) { + super(timeout); + } + + public Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public int hashCode() { + return Objects.hashCode(ackTimeout()); + } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + if (other == null || getClass() != other.getClass()) return false; + Request that = (Request) other; + return Objects.equals(that.ackTimeout(), ackTimeout()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java new file mode 100644 index 0000000000000..35c6bf96a09e1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadata.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.cluster.AbstractNamedDiffable; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Iterator; +import java.util.Objects; + +public class TrainedModelCacheMetadata extends AbstractNamedDiffable implements Metadata.Custom { + public static final String NAME = "trained_model_cache_metadata"; + public static final TrainedModelCacheMetadata EMPTY = new TrainedModelCacheMetadata(0L); + private static final ParseField VERSION_FIELD = new ParseField("version"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + args -> new TrainedModelCacheMetadata((long) args[0]) + ); + + static { + PARSER.declareLong(ConstructingObjectParser.constructorArg(), VERSION_FIELD); + } + + public static TrainedModelCacheMetadata fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static TrainedModelCacheMetadata fromState(ClusterState clusterState) { + TrainedModelCacheMetadata cacheMetadata = clusterState.getMetadata().custom(NAME); + return cacheMetadata == null ? EMPTY : cacheMetadata; + } + + public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException { + return readDiffFrom(Metadata.Custom.class, NAME, streamInput); + } + + private final long version; + + public TrainedModelCacheMetadata(long version) { + this.version = version; + } + + public TrainedModelCacheMetadata(StreamInput in) throws IOException { + this.version = in.readVLong(); + } + + public long version() { + return version; + } + + @Override + public Iterator toXContentChunked(ToXContent.Params ignored) { + return Iterators.single(((builder, params) -> { return builder.field(VERSION_FIELD.getPreferredName(), version); })); + } + + @Override + public EnumSet context() { + return Metadata.ALL_CONTEXTS; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_TRAINED_MODEL_CACHE_METADATA_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(version); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelCacheMetadata that = (TrainedModelCacheMetadata) o; + return Objects.equals(version, that.version); + } + + @Override + public int hashCode() { + return Objects.hash(version); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheActionRequestTests.java new file mode 100644 index 0000000000000..dd4c2fc33723b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheActionRequestTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction.Request; + +import java.io.IOException; + +public class FlushTrainedModelCacheActionRequestTests extends AbstractBWCWireSerializationTestCase { + @Override + protected Request createTestInstance() { + return randomBoolean() ? new Request() : new Request(randomTimeout()); + } + + @Override + protected Request mutateInstance(Request instance) throws IOException { + return new Request(randomValueOtherThan(instance.timeout(), this::randomTimeout)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + + @Override + protected Request mutateInstanceForVersion(Request instance, TransportVersion version) { + return instance; + } + + private TimeValue randomTimeout() { + return TimeValue.parseTimeValue(randomTimeValue(), null, "timeout"); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadataTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadataTests.java new file mode 100644 index 0000000000000..577cb3e288676 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelCacheMetadataTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractChunkedSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +public class TrainedModelCacheMetadataTests extends AbstractChunkedSerializingTestCase { + @Override + protected TrainedModelCacheMetadata doParseInstance(XContentParser parser) throws IOException { + return TrainedModelCacheMetadata.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return TrainedModelCacheMetadata::new; + } + + @Override + protected TrainedModelCacheMetadata createTestInstance() { + return new TrainedModelCacheMetadata(randomNonNegativeLong()); + } + + @Override + protected TrainedModelCacheMetadata mutateInstance(TrainedModelCacheMetadata instance) { + return new TrainedModelCacheMetadata(randomValueOtherThan(instance.version(), () -> randomNonNegativeLong())); + } +} diff --git a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java index 0dab4f9e4256c..e5238c4aa44f0 100644 --- a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java +++ b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java @@ -28,7 +28,7 @@ public class MlLearningToRankRescorerIT extends ESRestTestCase { @Before public void setupModelAndData() throws IOException { - putRegressionModel(MODEL_ID, """ + putLearningToRankModel(MODEL_ID, """ { "description": "super complex model for tests", "input": { "field_names": ["cost", "product"] }, @@ -328,6 +328,95 @@ public void testLtrCanMatch() throws Exception { assertThat(response.toString(), (List) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0)); } + @SuppressWarnings("unchecked") + public void testModelCacheIsFlushedOnModelChange() throws IOException { + String searchBody = """ + { + "rescore": { + "window_size": 10, + "learning_to_rank": { + "model_id": "basic-ltr-model" + } + } + }"""; + + Response searchResponse = searchDfs(searchBody); + Map response = responseAsMap(searchResponse); + assertThat( + response.toString(), + (List) XContentMapValues.extractValue("hits.hits._score", response), + contains(20.0, 20.0, 9.0, 9.0, 6.0) + ); + + deleteLearningToRankModel(MODEL_ID); + putLearningToRankModel(MODEL_ID, """ + { + "input": { "field_names": ["cost"] }, + "inference_config": { + "learning_to_rank": { + "feature_extractors": [ + { + "query_extractor": { + "feature_name": "cost", + "query": { + "script_score": { + "query": { "match_all": {} }, + "script": { "source": "return doc[\\"cost\\"].value" } + } + } + } + } + ] + } + }, + "definition": { + "trained_model": { + "ensemble": { + "feature_names": ["cost"], + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["cost"], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 1000, + "decision_type": "lt", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 1.0 + }, + { + "node_index": 2, + "leaf_value": 10 + } + ], + "target_type": "regression" + } + } + ] + } + } + } + } + """); + + searchResponse = searchDfs(searchBody); + response = responseAsMap(searchResponse); + assertThat( + response.toString(), + (List) XContentMapValues.extractValue("hits.hits._score", response), + contains(10.0, 1.0, 1.0, 1.0, 1.0) + ); + } + private void indexData(String data) throws IOException { Request request = new Request("POST", INDEX_NAME + "/_doc"); request.setJsonEntity(data); @@ -354,7 +443,12 @@ private Response searchCanMatch(String searchBody, boolean dfs) throws IOExcepti return client().performRequest(request); } - private void putRegressionModel(String modelId, String body) throws IOException { + private void deleteLearningToRankModel(String modelId) throws IOException { + Request model = new Request("DELETE", "_ml/trained_models/" + modelId); + assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); + } + + private void putLearningToRankModel(String modelId, String body) throws IOException { Request model = new Request("PUT", "_ml/trained_models/" + modelId); model.setJsonEntity(body); assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index c4640bc17845a..7addedf779450 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -81,6 +81,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; @@ -352,6 +353,12 @@ protected void ensureClusterStateConsistency() throws IOException { ); entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom)); + entries.add( + new NamedWriteableRegistry.Entry(Metadata.Custom.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::new) + ); + entries.add( + new NamedWriteableRegistry.Entry(NamedDiff.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::readDiffFrom) + ); entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, IndexLifecycleMetadata.TYPE, IndexLifecycleMetadata::new)); entries.add( diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 9b3326a4ba348..8c9c527382106 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.core.Tuple; @@ -39,6 +41,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.junit.Before; @@ -57,14 +60,23 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.startsWith; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { private TrainedModelProvider trainedModelProvider; @Before + @SuppressWarnings("unchecked") public void createComponents() throws Exception { - trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + TrainedModelCacheMetadataService modelCacheMetadataService = mock(TrainedModelCacheMetadataService.class); + doAnswer(invocationOnMock -> { + invocationOnMock.getArgument(0, ActionListener.class).onResponse(AcknowledgedResponse.TRUE); + return Void.TYPE; + }).when(modelCacheMetadataService).updateCacheVersion(any(ActionListener.class)); + trainedModelProvider = new TrainedModelProvider(client(), modelCacheMetadataService, xContentRegistry()); waitForMlTemplates(); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index e03445912175a..39f6ea87e4e2a 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.ml.MlConfigVersion; @@ -30,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.junit.Before; @@ -52,14 +55,23 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; public class ModelInferenceActionIT extends MlSingleNodeTestCase { private TrainedModelProvider trainedModelProvider; @Before + @SuppressWarnings("unchecked") public void createComponents() throws Exception { - trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + TrainedModelCacheMetadataService modelCacheMetadataService = mock(TrainedModelCacheMetadataService.class); + doAnswer(invocationOnMock -> { + invocationOnMock.getArgument(0, ActionListener.class).onResponse(AcknowledgedResponse.TRUE); + return Void.TYPE; + }).when(modelCacheMetadataService).updateCacheVersion(any(ActionListener.class)); + trainedModelProvider = new TrainedModelProvider(client(), modelCacheMetadataService, xContentRegistry()); waitForMlTemplates(); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index ffe70d9747a56..e103d439d269a 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -13,6 +14,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.broadcast.BroadcastResponse; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.license.License; @@ -31,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.junit.Before; @@ -48,14 +51,23 @@ import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; public class TrainedModelProviderIT extends MlSingleNodeTestCase { private TrainedModelProvider trainedModelProvider; @Before + @SuppressWarnings("unchecked") public void createComponents() throws Exception { - trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + TrainedModelCacheMetadataService modelCacheMetadataService = mock(TrainedModelCacheMetadataService.class); + doAnswer(invocationOnMock -> { + invocationOnMock.getArgument(0, ActionListener.class).onResponse(AcknowledgedResponse.TRUE); + return Void.TYPE; + }).when(modelCacheMetadataService).updateCacheVersion(any(ActionListener.class)); + trainedModelProvider = new TrainedModelProvider(client(), modelCacheMetadataService, xContentRegistry()); waitForMlTemplates(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 7fa2bcca952bf..1b5951ffdb0e0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -119,6 +119,7 @@ import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; import org.elasticsearch.xpack.core.ml.action.ForecastJobAction; import org.elasticsearch.xpack.core.ml.action.GetBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetCalendarEventsAction; @@ -192,6 +193,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; @@ -225,6 +227,7 @@ import org.elasticsearch.xpack.ml.action.TransportExternalInferModelAction; import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction; import org.elasticsearch.xpack.ml.action.TransportFlushJobAction; +import org.elasticsearch.xpack.ml.action.TransportFlushTrainedModelCacheAction; import org.elasticsearch.xpack.ml.action.TransportForecastJobAction; import org.elasticsearch.xpack.ml.action.TransportGetBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetCalendarEventsAction; @@ -332,6 +335,7 @@ import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorerBuilder; import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankService; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.pytorch.process.BlackHolePyTorchProcess; import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory; @@ -1128,7 +1132,15 @@ public Collection createComponents(PluginServices services) { clusterService, threadPool ); - final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + final TrainedModelCacheMetadataService trainedModelCacheMetadataService = new TrainedModelCacheMetadataService( + clusterService, + client + ); + final TrainedModelProvider trainedModelProvider = new TrainedModelProvider( + client, + trainedModelCacheMetadataService, + xContentRegistry + ); final ModelLoadingService modelLoadingService = new ModelLoadingService( trainedModelProvider, inferenceAuditor, @@ -1310,6 +1322,7 @@ public Collection createComponents(PluginServices services) { dataFrameAnalyticsConfigProvider, nativeStorageProvider, modelLoadingService, + trainedModelCacheMetadataService, trainedModelProvider, trainedModelAssignmentService, trainedModelAllocationClusterServiceSetOnce.get(), @@ -1558,6 +1571,7 @@ public List getRestHandlers( actionHandlers.add( new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class) ); + actionHandlers.add(new ActionHandler<>(FlushTrainedModelCacheAction.INSTANCE, TransportFlushTrainedModelCacheAction.class)); actionHandlers.add(new ActionHandler<>(InferModelAction.INSTANCE, TransportInternalInferModelAction.class)); actionHandlers.add(new ActionHandler<>(InferModelAction.EXTERNAL_INSTANCE, TransportExternalInferModelAction.class)); actionHandlers.add(new ActionHandler<>(GetDeploymentStatsAction.INSTANCE, TransportGetDeploymentStatsAction.class)); @@ -1817,6 +1831,13 @@ public List getNamedXContent() { namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); + namedXContent.add( + new NamedXContentRegistry.Entry( + Metadata.Custom.class, + new ParseField((TrainedModelCacheMetadata.NAME)), + TrainedModelCacheMetadata::fromXContent + ) + ); namedXContent.add( new NamedXContentRegistry.Entry( Metadata.Custom.class, @@ -1855,6 +1876,12 @@ public List getNamedWriteables() { // Custom metadata namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(Metadata.Custom.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(NamedDiff.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::readDiffFrom) + ); namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom)); namedWriteables.add( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java new file mode 100644 index 0000000000000..ab7c9f399fdc1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; + +public class TransportFlushTrainedModelCacheAction extends AcknowledgedTransportMasterNodeAction { + + private final TrainedModelCacheMetadataService modelCacheMetadataService; + + @Inject + public TransportFlushTrainedModelCacheAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + TrainedModelCacheMetadataService modelCacheMetadataService + ) { + super( + FlushTrainedModelCacheAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + FlushTrainedModelCacheAction.Request::new, + indexNameExpressionResolver, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.modelCacheMetadataService = modelCacheMetadataService; + } + + @Override + protected void masterOperation( + Task task, + FlushTrainedModelCacheAction.Request request, + ClusterState state, + ActionListener listener + ) { + modelCacheMetadataService.updateCacheVersion(listener); + } + + @Override + protected ClusterBlockException checkBlock(FlushTrainedModelCacheAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 5869f353c80c9..43e20a6581e07 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -38,6 +38,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; @@ -750,6 +751,12 @@ private void cacheEvictionListener(RemovalNotification @Override public void clusterChanged(ClusterChangedEvent event) { + if (event.changedCustomMetadataSet().contains(TrainedModelCacheMetadata.NAME)) { + // Flush all models cache since we are detecting some changes. + logger.trace("Trained model cache invalidated on node [{}]", () -> event.state().nodes().getLocalNodeId()); + localModelCache.invalidateAll(); + } + final boolean prefetchModels = event.state().nodes().getLocalNode().isIngestNode(); // If we are not prefetching models and there were no model alias changes, don't bother handling the changes if ((prefetchModels == false) @@ -909,13 +916,13 @@ private Map gatherLazyChangedAliasesAndUpdateModelAliases( ) { Map changedAliases = new HashMap<>(); if (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME)) { - final Map modelAliasesToIds = new HashMap<>( + final Map modelAliasesToIds = new HashMap<>( ModelAliasMetadata.fromState(event.state()).modelAliases() ); modelIdToModelAliases.clear(); - for (Map.Entry aliasToId : modelAliasesToIds.entrySet()) { + for (Map.Entry aliasToId : modelAliasesToIds.entrySet()) { modelIdToModelAliases.computeIfAbsent(aliasToId.getValue().getModelId(), k -> new HashSet<>()).add(aliasToId.getKey()); - java.lang.String modelId = modelAliasToId.get(aliasToId.getKey()); + String modelId = modelAliasToId.get(aliasToId.getKey()); if (modelId != null && modelId.equals(aliasToId.getValue().getModelId()) == false) { if (prefetchModels && allReferencedModelKeys.contains(aliasToId.getKey())) { changedAliases.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); @@ -927,7 +934,7 @@ private Map gatherLazyChangedAliasesAndUpdateModelAliases( modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); } } - Set removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet()); + Set removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet()); modelAliasToId.keySet().removeAll(removedAliases); } return changedAliases; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java new file mode 100644 index 0000000000000..bd7510a09a013 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.persistence; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.ClusterStateTaskExecutor; +import org.elasticsearch.cluster.ClusterStateTaskExecutor.TaskContext; +import org.elasticsearch.cluster.ClusterStateTaskListener; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.MasterServiceTaskQueue; +import org.elasticsearch.common.Priority; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class TrainedModelCacheMetadataService implements ClusterStateListener { + private static final Logger LOGGER = LogManager.getLogger(TrainedModelCacheMetadataService.class); + static final String TASK_QUEUE_NAME = "trained-models-cache-metadata-management"; + private final MasterServiceTaskQueue metadataUpdateTaskQueue; + private final Client client; + private volatile boolean isMasterNode = false; + + public TrainedModelCacheMetadataService(ClusterService clusterService, Client client) { + this.client = new OriginSettingClient(client, ML_ORIGIN); + CacheMetadataUpdateTaskExecutor metadataUpdateTaskExecutor = new CacheMetadataUpdateTaskExecutor(); + this.metadataUpdateTaskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.IMMEDIATE, metadataUpdateTaskExecutor); + clusterService.addListener(this); + } + + public void updateCacheVersion(ActionListener listener) { + if (this.isMasterNode == false) { + client.execute(FlushTrainedModelCacheAction.INSTANCE, new FlushTrainedModelCacheAction.Request(), listener); + return; + } + + CacheMetadataUpdateTask updateMetadataTask = new RefreshCacheMetadataVersionTask(listener); + this.metadataUpdateTaskQueue.submitTask(updateMetadataTask.getDescription(), updateMetadataTask, null); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.state().clusterRecovered() == false || event.state().nodes().getMasterNode() == null) { + return; + } + this.isMasterNode = event.localNodeMaster(); + } + + abstract static class CacheMetadataUpdateTask implements ClusterStateTaskListener { + protected final ActionListener listener; + + CacheMetadataUpdateTask(ActionListener listener) { + this.listener = listener; + } + + protected abstract TrainedModelCacheMetadata execute( + TrainedModelCacheMetadata currentCacheMetadata, + TaskContext taskContext + ); + + protected abstract String getDescription(); + + @Override + public void onFailure(@Nullable Exception e) { + LOGGER.error("unexpected failure during [" + getDescription() + "]", e); + listener.onFailure(e); + } + } + + static class RefreshCacheMetadataVersionTask extends CacheMetadataUpdateTask { + RefreshCacheMetadataVersionTask(ActionListener listener) { + super(listener); + } + + @Override + protected TrainedModelCacheMetadata execute( + TrainedModelCacheMetadata currentCacheMetadata, + TaskContext taskContext + ) { + long newVersion = currentCacheMetadata.version() < Long.MAX_VALUE ? currentCacheMetadata.version() + 1 : 1L; + taskContext.success(() -> listener.onResponse(AcknowledgedResponse.TRUE)); + return new TrainedModelCacheMetadata(newVersion); + } + + @Override + protected String getDescription() { + return "refresh trained model cache version"; + } + } + + static class CacheMetadataUpdateTaskExecutor implements ClusterStateTaskExecutor { + @Override + public ClusterState execute(BatchExecutionContext batchExecutionContext) { + final var initialState = batchExecutionContext.initialState(); + XPackPlugin.checkReadyForXPackCustomMetadata(initialState); + + final TrainedModelCacheMetadata originalCacheMetadata = TrainedModelCacheMetadata.fromState(initialState); + TrainedModelCacheMetadata currentCacheMetadata = originalCacheMetadata; + + for (final var taskContext : batchExecutionContext.taskContexts()) { + try (var ignored = taskContext.captureResponseHeaders()) { + currentCacheMetadata = taskContext.getTask().execute(currentCacheMetadata, taskContext); + } + } + + if (currentCacheMetadata == originalCacheMetadata) { + return initialState; + } + + return ClusterState.builder(initialState) + .metadata(Metadata.builder(initialState.metadata()).putCustom(TrainedModelCacheMetadata.NAME, currentCacheMetadata)) + .build(); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index b9b38cb07fa39..f493c735d87ea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -30,6 +30,7 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.broadcast.BroadcastResponse; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Numbers; @@ -127,9 +128,15 @@ public class TrainedModelProvider { private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; private final NamedXContentRegistry xContentRegistry; + private final TrainedModelCacheMetadataService modelCacheMetadataService; - public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) { + public TrainedModelProvider( + Client client, + TrainedModelCacheMetadataService modelCacheMetadataService, + NamedXContentRegistry xContentRegistry + ) { this.client = client; + this.modelCacheMetadataService = modelCacheMetadataService; this.xContentRegistry = xContentRegistry; } @@ -208,7 +215,7 @@ public void storeTrainedModelConfig(TrainedModelConfig trainedModelConfig, Actio ML_ORIGIN, TransportIndexAction.TYPE, request, - ActionListener.wrap(indexResponse -> listener.onResponse(true), e -> { + ActionListener.wrap(indexResponse -> refreshCacheVersion(listener), e -> { if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { listener.onFailure( new ResourceAlreadyExistsException( @@ -521,7 +528,8 @@ private void storeTrainedModelAndDefinition( wrappedListener.onFailure(firstFailure); return; } - wrappedListener.onResponse(true); + + refreshCacheVersion(wrappedListener); }, wrappedListener::onFailure); executeAsyncWithOrigin(client, ML_ORIGIN, TransportBulkAction.TYPE, bulkRequest.request(), bulkResponseActionListener); @@ -894,7 +902,8 @@ public void deleteTrainedModel(String modelId, ActionListener listener) listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return; } - listener.onResponse(true); + + refreshCacheVersion(listener); }, e -> { if (e.getClass() == IndexNotFoundException.class) { listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); @@ -1376,4 +1385,13 @@ private static IndexRequest createRequest(IndexRequest request, String docId, To throw ExceptionsHelper.serverError("Unexpected serialization exception for [" + docId + "]", ex); } } + + private void refreshCacheVersion(ActionListener listener) { + modelCacheMetadataService.updateCacheVersion(ActionListener.wrap(resp -> { + // Checking the response is always AcknowledgedResponse.TRUE because AcknowledgedResponse.FALSE does not make sense. + // Errors should be reported through the onFailure method of the listener. + assert resp.equals(AcknowledgedResponse.TRUE); + listener.onResponse(true); + }, listener::onFailure)); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheActionTests.java new file mode 100644 index 0000000000000..61d39a8a962b6 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheActionTests.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; +import org.junit.After; +import org.junit.Before; + +import java.util.concurrent.atomic.AtomicReference; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class TransportFlushTrainedModelCacheActionTests extends ESTestCase { + + private ThreadPool threadPool; + private TrainedModelCacheMetadataService modelCacheMetadataService; + + @Before + @SuppressWarnings({ "unchecked", "rawtypes" }) + private void setupMocks() { + threadPool = new TestThreadPool(getTestName()); + modelCacheMetadataService = mock(TrainedModelCacheMetadataService.class); + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(0, ActionListener.class); + listener.onResponse(AcknowledgedResponse.TRUE); + return null; + }).when(modelCacheMetadataService).updateCacheVersion(any(ActionListener.class)); + } + + @After + public void shutdown() throws Exception { + if (threadPool != null) { + threadPool.shutdownNow(); + } + threadPool = null; + } + + public void testOperation() { + ClusterService clusterService = mock(ClusterService.class); + TransportFlushTrainedModelCacheAction action = createAction(clusterService); + + ClusterState clusterState = ClusterState.builder(new ClusterName("flush-trained-model-cache-metadata-tests")).build(); + + FlushTrainedModelCacheAction.Request request = new FlushTrainedModelCacheAction.Request(); + AtomicReference ack = new AtomicReference<>(); + ActionListener listener = ActionTestUtils.assertNoFailureListener(ack::set); + + action.masterOperation(null, request, clusterState, listener); + + assertTrue(ack.get().isAcknowledged()); + verify(modelCacheMetadataService).updateCacheVersion(listener); + } + + private TransportFlushTrainedModelCacheAction createAction(ClusterService clusterService) { + return new TransportFlushTrainedModelCacheAction( + mock(TransportService.class), + clusterService, + threadPool, + mock(ActionFilters.class), + mock(IndexNameExpressionResolver.class), + modelCacheMetadataService + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataServiceTests.java new file mode 100644 index 0000000000000..b333b3596aba5 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataServiceTests.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.persistence; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateTaskExecutor.TaskContext; +import org.elasticsearch.cluster.ClusterStateTaskListener; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.MasterServiceTaskQueue; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService.CacheMetadataUpdateTask; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService.CacheMetadataUpdateTaskExecutor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService.RefreshCacheMetadataVersionTask; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TrainedModelCacheMetadataServiceTests extends ESTestCase { + private ClusterService clusterService; + private Client client; + private MasterServiceTaskQueue taskQueue; + + @Before + @SuppressWarnings("unchecked") + public void setUpMocks() { + clusterService = mockClusterService(); + client = mockClient(); + taskQueue = (MasterServiceTaskQueue) mock(MasterServiceTaskQueue.class); + + Mockito.when(clusterService.createTaskQueue(eq(TrainedModelCacheMetadataService.TASK_QUEUE_NAME), any(), any())) + .thenReturn(taskQueue); + } + + public void testRefreshCacheVersionOnMasterNode() { + final var taskExecutorCaptor = ArgumentCaptor.forClass(CacheMetadataUpdateTaskExecutor.class); + final TrainedModelCacheMetadataService modelCacheMetadataService = new TrainedModelCacheMetadataService(clusterService, client); + verify(clusterService).createTaskQueue(eq(TrainedModelCacheMetadataService.TASK_QUEUE_NAME), any(), taskExecutorCaptor.capture()); + + DiscoveryNodes clusterNodes = mock(DiscoveryNodes.class); + when(clusterNodes.getMasterNode()).thenReturn(mock(DiscoveryNode.class)); + when(clusterNodes.isLocalNodeElectedMaster()).thenReturn(true); + + ClusterState clusterState = mock(ClusterState.class); + when(clusterState.clusterRecovered()).thenReturn(true); + when(clusterState.nodes()).thenReturn(clusterNodes); + + modelCacheMetadataService.clusterChanged(new ClusterChangedEvent("test", clusterState, ClusterState.EMPTY_STATE)); + + @SuppressWarnings("unchecked") + final ActionListener listener = mock(ActionListener.class); + modelCacheMetadataService.updateCacheVersion(listener); + + // Verify a cluster state update task were submitted. + ArgumentCaptor updateTaskCaptor = ArgumentCaptor.forClass(RefreshCacheMetadataVersionTask.class); + verify(taskQueue).submitTask(any(String.class), updateTaskCaptor.capture(), isNull()); + assertThat(updateTaskCaptor.getValue().listener, is(listener)); + + verify(client, never()).execute(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + public void testRefreshCacheVersionOnNonMasterNode() { + final var taskExecutorCaptor = ArgumentCaptor.forClass(CacheMetadataUpdateTaskExecutor.class); + final TrainedModelCacheMetadataService modelCacheMetadataService = new TrainedModelCacheMetadataService(clusterService, client); + verify(clusterService).createTaskQueue(eq(TrainedModelCacheMetadataService.TASK_QUEUE_NAME), any(), taskExecutorCaptor.capture()); + + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(2, ActionListener.class); + listener.onResponse(AcknowledgedResponse.TRUE); + return null; + }).when(client).execute(any(ActionType.class), any(FlushTrainedModelCacheAction.Request.class), any(ActionListener.class)); + + @SuppressWarnings("unchecked") + final ActionListener listener = mock(ActionListener.class); + modelCacheMetadataService.updateCacheVersion(listener); + + // Check a FlushTrainedModelCacheAction request is emitted to the master node, that will flush the cache. + verify(client).execute( + eq(FlushTrainedModelCacheAction.INSTANCE), + any(FlushTrainedModelCacheAction.Request.class), + any(ActionListener.class) + ); + verify(listener).onResponse(eq(AcknowledgedResponse.TRUE)); + + // Verify no cluster state update task were submitted on a non-master node. + verify(taskQueue, never()).submitTask(any(String.class), any(RefreshCacheMetadataVersionTask.class), any(TimeValue.class)); + } + + public void testRefreshCacheMetadataVersionTaskExecution() { + @SuppressWarnings("unchecked") + final ActionListener listener = mock(ActionListener.class); + final RefreshCacheMetadataVersionTask task = new RefreshCacheMetadataVersionTask(listener); + + final TrainedModelCacheMetadata currentCacheMetadata = new TrainedModelCacheMetadata( + randomValueOtherThan(Long.MAX_VALUE, () -> randomNonNegativeLong()) + ); + + @SuppressWarnings("unchecked") + final TaskContext taskContext = mock(TaskContext.class); + doAnswer(invocationOnMock -> { + invocationOnMock.getArgument(0, Runnable.class).run(); + return null; + }).when(taskContext).success(any(Runnable.class)); + + final TrainedModelCacheMetadata updatedCacheMetadata = task.execute(currentCacheMetadata, taskContext); + + // Check the version is incremented correctly + assertThat(updatedCacheMetadata.version(), equalTo(currentCacheMetadata.version() + 1)); + + // Check the task is marked as successful and the listener is called. + verify(taskContext).success(any(Runnable.class)); + verify(listener).onResponse(eq(AcknowledgedResponse.TRUE)); + } + + public void testRefreshCacheMetadataVersionTaskExecutionWithMaxVersion() { + @SuppressWarnings("unchecked") + final ActionListener listener = mock(ActionListener.class); + final RefreshCacheMetadataVersionTask task = new RefreshCacheMetadataVersionTask(listener); + + final TrainedModelCacheMetadata currentCacheMetadata = new TrainedModelCacheMetadata(Long.MAX_VALUE); + @SuppressWarnings("unchecked") + final TaskContext taskContext = mock(TaskContext.class); + final TrainedModelCacheMetadata updatedCacheMetadata = task.execute(currentCacheMetadata, taskContext); + + // Check the version counter is reset to 1 + assertThat(updatedCacheMetadata.version(), equalTo(1L)); + } + + private static Client mockClient() { + final Client client = mock(Client.class); + ThreadPool threadpool = mock(ThreadPool.class); + when(threadpool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + when(client.threadPool()).thenReturn(threadpool); + return client; + } + + private static ClusterService mockClusterService() { + final ClusterState clusterState = mock(ClusterState.class); + Mockito.when(clusterState.metadata()).thenReturn(Metadata.EMPTY_METADATA); + + final ClusterService clusterService = mock(ClusterService.class); + Mockito.when(clusterService.state()).thenReturn(clusterState); + Mockito.when(clusterService.getClusterName()).thenReturn(ClusterName.DEFAULT); + + return clusterService; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 3daeed561e88b..3a59a08242552 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -67,7 +67,11 @@ public class TrainedModelProviderTests extends ESTestCase { public void testDeleteModelStoredAsResource() { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider( + mock(Client.class), + mock(TrainedModelCacheMetadataService.class), + xContentRegistry() + ); PlainActionFuture future = new PlainActionFuture<>(); // Should be OK as we don't make any client calls trainedModelProvider.deleteTrainedModel("lang_ident_model_1", future); @@ -77,7 +81,11 @@ public void testDeleteModelStoredAsResource() { public void testPutModelThatExistsAsResource() { TrainedModelConfig config = TrainedModelConfigTests.createTestInstance("lang_ident_model_1").build(); - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider( + mock(Client.class), + mock(TrainedModelCacheMetadataService.class), + xContentRegistry() + ); PlainActionFuture future = new PlainActionFuture<>(); trainedModelProvider.storeTrainedModel(config, future); ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet); @@ -85,7 +93,11 @@ public void testPutModelThatExistsAsResource() { } public void testGetModelThatExistsAsResource() throws Exception { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider( + mock(Client.class), + mock(TrainedModelCacheMetadataService.class), + xContentRegistry() + ); for (String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) { PlainActionFuture future = new PlainActionFuture<>(); trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, future); @@ -180,7 +192,11 @@ public void testExpandIdsPagination() { } public void testGetModelThatExistsAsResourceButIsMissing() { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider( + mock(Client.class), + mock(TrainedModelCacheMetadataService.class), + xContentRegistry() + ); ElasticsearchException ex = expectThrows( ElasticsearchException.class, () -> trainedModelProvider.loadModelFromResource("missing_model", randomBoolean()) @@ -350,7 +366,7 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreate() { try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("inferenceEntityId").build(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelConfig(config, future); @@ -362,7 +378,7 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationCreateWhen try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("inferenceEntityId").build(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelConfig(config, future, false); @@ -374,7 +390,7 @@ public void testStoreTrainedModelConfigCallsClientExecuteWithOperationIndex() { try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelConfigTests.createTestInstance("inferenceEntityId").build(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelConfig(config, future, true); @@ -386,7 +402,7 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCr try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("inferenceEntityId"); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModel(config, future); @@ -398,7 +414,7 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationCr try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("inferenceEntityId"); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModel(config, future, false); @@ -410,7 +426,7 @@ public void testStoreTrainedModelWithDefinitionCallsClientExecuteWithOperationIn try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = createTrainedModelConfigWithDefinition("inferenceEntityId"); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModel(config, future, true); @@ -422,7 +438,7 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCre try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelDefinitionDoc(config, future); @@ -434,7 +450,7 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationCre try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelDefinitionDoc(config, "index", future, false); @@ -446,7 +462,7 @@ public void testStoreTrainedModelDefinitionDocCallsClientExecuteWithOperationInd try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var config = TrainedModelDefinitionDocTests.createDefinitionDocInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelDefinitionDoc(config, "index", future, true); @@ -458,7 +474,7 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreate try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var vocab = createVocabulary(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelVocabulary("inferenceEntityId", mock(VocabularyConfig.class), vocab, future); @@ -470,7 +486,7 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationCreate try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var vocab = createVocabulary(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelVocabulary("inferenceEntityId", mock(VocabularyConfig.class), vocab, future, false); @@ -482,7 +498,7 @@ public void testStoreTrainedModelVocabularyCallsClientExecuteWithOperationIndex( try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var vocab = createVocabulary(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelVocabulary("inferenceEntityId", mock(VocabularyConfig.class), vocab, future, true); @@ -494,7 +510,7 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreate() try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelMetadata(metadata, future); @@ -506,7 +522,7 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationCreateWh try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelMetadata(metadata, future, false); @@ -518,7 +534,7 @@ public void testStoreTrainedModelMetadataCallsClientExecuteWithOperationIndex() try (var threadPool = createThreadPool()) { final var client = createMockClient(threadPool); var metadata = TrainedModelMetadataTests.randomInstance(); - var trainedModelProvider = new TrainedModelProvider(client, xContentRegistry()); + var trainedModelProvider = new TrainedModelProvider(client, mock(TrainedModelCacheMetadataService.class), xContentRegistry()); var future = new PlainActionFuture(); trainedModelProvider.storeTrainedModelMetadata(metadata, future, true); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java index 0139b0d500341..cd8673fd4d301 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.hamcrest.Matcher; @@ -176,7 +177,11 @@ public void testLangInference() throws Exception { } InferenceDefinition grabModel() throws IOException { - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider( + mock(Client.class), + mock(TrainedModelCacheMetadataService.class), + xContentRegistry() + ); PlainActionFuture future = new PlainActionFuture<>(); // Should be OK as we don't make any client calls trainedModelProvider.getTrainedModel("lang_ident_model_1", GetTrainedModelsAction.Includes.forModelDefinition(), null, future); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 126b05aec7f2f..65651b4a7eb65 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -201,6 +201,7 @@ public class Constants { "cluster:admin/xpack/ml/inference/put", "cluster:admin/xpack/ml/inference/model_aliases/put", "cluster:admin/xpack/ml/inference/model_aliases/delete", + "cluster:admin/xpack/ml/inference/clear_model_cache", "cluster:admin/xpack/ml/job/close", "cluster:admin/xpack/ml/job/data/post", "cluster:admin/xpack/ml/job/delete",