diff --git a/docs/changelog/110399.yaml b/docs/changelog/110399.yaml
new file mode 100644
index 0000000000000..9e04e2656809e
--- /dev/null
+++ b/docs/changelog/110399.yaml
@@ -0,0 +1,6 @@
+pr: 110399
+summary: "[Inference API] Prevent inference endpoints from being deleted if they are\
+  \ referenced by semantic text"
+area: Machine Learning
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index ff50d1513d28a..f64a43d463d47 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -209,6 +209,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0);
     public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0);
     public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0);
+    public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_703_00_0);

     /*
      * STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java
index dfb77ccd49fc2..e9d612751e48f 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java
@@ -11,8 +11,10 @@
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.AcknowledgedRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.XContentBuilder;

@@ -105,10 +107,16 @@ public static class Response extends AcknowledgedResponse {

         private final String PIPELINE_IDS = "pipelines";
         Set<String> pipelineIds;
+        private final String REFERENCED_INDEXES = "indexes";
+        Set<String> indexes;
+        private final String DRY_RUN_MESSAGE = "error_message"; // error message only returned in response for dry_run
+        String dryRunMessage;

-        public Response(boolean acknowledged, Set<String> pipelineIds) {
+        public Response(boolean acknowledged, Set<String> pipelineIds, Set<String> semanticTextIndexes, @Nullable String dryRunMessage) {
             super(acknowledged);
             this.pipelineIds = pipelineIds;
+            this.indexes = semanticTextIndexes;
+            this.dryRunMessage = dryRunMessage;
         }

         public Response(StreamInput in) throws IOException {
@@ -118,6 +126,15 @@ public Response(StreamInput in) throws IOException {
             } else {
                 pipelineIds = Set.of();
             }
+
+            if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
+                indexes = in.readCollectionAsSet(StreamInput::readString);
+                dryRunMessage = in.readOptionalString();
+            } else {
+                indexes = Set.of();
+                dryRunMessage = null;
+            }
+
         }

         @Override
@@ -126,23 +143,25 @@ public void writeTo(StreamOutput out) throws IOException {
             if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ENHANCE_DELETE_ENDPOINT)) {
                 out.writeCollection(pipelineIds, StreamOutput::writeString);
             }
+            if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
+                out.writeCollection(indexes, StreamOutput::writeString);
+                out.writeOptionalString(dryRunMessage);
+            }
         }

         @Override
         protected void addCustomFields(XContentBuilder builder, Params params) throws IOException {
             super.addCustomFields(builder, params);
             builder.field(PIPELINE_IDS, pipelineIds);
+            builder.field(REFERENCED_INDEXES, indexes);
+            if (dryRunMessage != null) {
+                builder.field(DRY_RUN_MESSAGE, dryRunMessage);
+            }
         }

         @Override
         public String toString() {
-            StringBuilder returnable = new StringBuilder();
-            returnable.append("acknowledged: ").append(this.acknowledged);
-            returnable.append(", pipelineIdsByEndpoint: ");
-            for (String entry : pipelineIds) {
-                returnable.append(entry).append(", ");
-            }
-            return returnable.toString();
+            return Strings.toString(this);
         }
     }
 }
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java
new file mode 100644
index 0000000000000..544c1e344c91f
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ *
+ * this file was contributed to by a Generative AI
+ */
+
+package org.elasticsearch.xpack.core.ml.utils;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.transport.Transports;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+public class SemanticTextInfoExtractor {
+    private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class);
+
+    public static Set<String> extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set<String> endpointIds) {
+        assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
+        assert endpointIds.isEmpty() == false;
+        assert metadata != null;
+
+        Set<String> referenceIndices = new HashSet<>();
+
+        Map<String, IndexMetadata> indices = metadata.indices();
+
+        indices.forEach((indexName, indexMetadata) -> {
+            if (indexMetadata.getInferenceFields() != null) {
+                Map<String, InferenceFieldMetadata> inferenceFields = indexMetadata.getInferenceFields();
+                if (inferenceFields.entrySet()
+                    .stream()
+                    .anyMatch(
+                        entry -> entry.getValue().getInferenceId() != null && endpointIds.contains(entry.getValue().getInferenceId())
+                    )) {
+                    referenceIndices.add(indexName);
+                }
+            }
+        });
+
+        return referenceIndices;
+    }
+}
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
index 419869c0c4a5e..f30f2e8fe201a 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
@@ -126,6 +126,25 @@ protected void deleteModel(String modelId, TaskType taskType) throws IOException
         assertOkOrCreated(response);
     }

+    protected void putSemanticText(String endpointId, String indexName) throws IOException {
+        var request = new Request("PUT", Strings.format("%s", indexName));
+        String body = Strings.format("""
+            {
+              "mappings": {
+                "properties": {
+                  "inference_field": {
+                    "type": "semantic_text",
+                    "inference_id": "%s"
+                  }
+                }
+              }
+            }
+            """, endpointId);
+        request.setJsonEntity(body);
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+    }
+
     protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
         String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
         return putRequest(endpoint, modelConfig);
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
index 75e392b6d155f..242f786e95364 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
@@ -16,6 +16,7 @@

 import java.io.IOException;
 import java.util.List;
+import java.util.Set;

 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
@@ -124,14 +125,15 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException {
         putPipeline(pipelineId, endpointId);
         {
+            var errorString = new StringBuilder().append("Inference endpoint ")
+                .append(endpointId)
+                .append(" is referenced by pipelines: ")
+                .append(Set.of(pipelineId))
+                .append(". ")
+                .append("Ensure that no pipelines are using this inference endpoint, ")
+                .append("or use force to ignore this warning and delete the inference endpoint.");
             var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
-            assertThat(
-                e.getMessage(),
-                containsString(
-                    "Inference endpoint endpoint_referenced_by_pipeline is referenced by pipelines and cannot be deleted. "
-                        + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it."
-                )
-            );
+            assertThat(e.getMessage(), containsString(errorString.toString()));
         }
         {
             var response = deleteModel(endpointId, "dry_run=true");
@@ -146,4 +148,78 @@
         }
         deletePipeline(pipelineId);
     }
+
+    public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException {
+        String endpointId = "endpoint_referenced_by_semantic_text";
+        putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        String indexName = randomAlphaOfLength(10).toLowerCase();
+        putSemanticText(endpointId, indexName);
+        {
+
+            var errorString = new StringBuilder().append(" Inference endpoint ")
+                .append(endpointId)
+                .append(" is being used in the mapping for indexes: ")
+                .append(Set.of(indexName))
+                .append(". ")
") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); + assertThat(e.getMessage(), containsString(errorString.toString())); + } + { + var response = deleteModel(endpointId, "dry_run=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":false")); + assertThat(entityString, containsString(indexName)); + } + { + var response = deleteModel(endpointId, "force=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":true")); + } + deleteIndex(indexName); + } + + public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws IOException { + String endpointId = "endpoint_referenced_by_semantic_text"; + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + String indexName = randomAlphaOfLength(10).toLowerCase(); + putSemanticText(endpointId, indexName); + var pipelineId = "pipeline_referencing_model"; + putPipeline(pipelineId, endpointId); + { + + var errorString = new StringBuilder().append("Inference endpoint ") + .append(endpointId) + .append(" is referenced by pipelines: ") + .append(Set.of(pipelineId)) + .append(". ") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint.") + .append(" Inference endpoint ") + .append(endpointId) + .append(" is being used in the mapping for indexes: ") + .append(Set.of(indexName)) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + + var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); + assertThat(e.getMessage(), containsString(errorString.toString())); + } + { + var response = deleteModel(endpointId, "dry_run=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":false")); + assertThat(entityString, containsString(indexName)); + assertThat(entityString, containsString(pipelineId)); + } + { + var response = deleteModel(endpointId, "force=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":true")); + } + deletePipeline(pipelineId); + deleteIndex(indexName); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 07d5e1e618578..e59ac4e1356f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -3,6 +3,8 @@ * or more contributor license agreements. Licensed under the Elastic License * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
+ *
+ * this file was contributed to by a Generative AI
  */

 package org.elasticsearch.xpack.inference.action;
@@ -11,6 +13,7 @@
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.master.TransportMasterNodeAction;
@@ -18,12 +21,10 @@
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
-import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.inference.InferenceServiceRegistry;
-import org.elasticsearch.ingest.IngestMetadata;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -34,6 +35,10 @@
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;

 import java.util.Set;
+import java.util.concurrent.Executor;
+
+import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints;
+import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

 public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction<
     DeleteInferenceEndpointAction.Request,
@@ -42,6 +47,7 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private static final Logger logger = LogManager.getLogger(TransportDeleteInferenceEndpointAction.class);
+    private final Executor executor;

     @Inject
     public TransportDeleteInferenceEndpointAction(
@@ -66,6 +72,7 @@ public TransportDeleteInferenceEndpointAction(
         );
         this.modelRegistry = modelRegistry;
         this.serviceRegistry = serviceRegistry;
+        this.executor = threadPool.executor(UTILITY_THREAD_POOL_NAME);
     }

     @Override
@@ -74,6 +81,15 @@ protected void masterOperation(
         DeleteInferenceEndpointAction.Request request,
         ClusterState state,
         ActionListener<DeleteInferenceEndpointAction.Response> masterListener
+    ) {
+        // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
+        executor.execute(ActionRunnable.wrap(masterListener, l -> doExecuteForked(request, state, l)));
+    }
+
+    private void doExecuteForked(
+        DeleteInferenceEndpointAction.Request request,
+        ClusterState state,
+        ActionListener<DeleteInferenceEndpointAction.Response> masterListener
     ) {
         SubscribableListener.newForked(modelConfigListener -> {
             // Get the model from the registry
@@ -89,17 +105,15 @@ protected void masterOperation(
             }

             if (request.isDryRun()) {
-                masterListener.onResponse(
-                    new DeleteInferenceEndpointAction.Response(
-                        false,
-                        InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId()))
-                    )
-                );
+                handleDryRun(request, state, masterListener);
                 return;
-            } else if (request.isForceDelete() == false
-                && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), listener)) {
+            } else if (request.isForceDelete() == false) {
+                var errorString = endpointIsReferencedInPipelinesOrIndexes(state, request.getInferenceEndpointId());
+                if (errorString != null) {
+                    listener.onFailure(new ElasticsearchStatusException(errorString, RestStatus.CONFLICT));
                     return;
                 }
+            }

             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isPresent()) {
@@ -126,47 +140,83 @@ && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), list
             })
             .addListener(
                 masterListener.delegateFailure(
-                    (l3, didDeleteModel) -> masterListener.onResponse(new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of()))
+                    (l3, didDeleteModel) -> masterListener.onResponse(
+                        new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of(), Set.of(), null)
+                    )
                 )
             );
     }

-    private static boolean endpointIsReferencedInPipelines(
-        final ClusterState state,
-        final String inferenceEndpointId,
-        ActionListener<Boolean> listener
+    private static void handleDryRun(
+        DeleteInferenceEndpointAction.Request request,
+        ClusterState state,
+        ActionListener<DeleteInferenceEndpointAction.Response> masterListener
     ) {
-        Metadata metadata = state.getMetadata();
-        if (metadata == null) {
-            listener.onFailure(
-                new ElasticsearchStatusException(
-                    " Could not determine if the endpoint is referenced in a pipeline as cluster state metadata was unexpectedly null. "
-                        + "Use `force` to delete it anyway",
-                    RestStatus.INTERNAL_SERVER_ERROR
-                )
-            );
-            // Unsure why the ClusterState metadata would ever be null, but in this case it seems safer to assume the endpoint is referenced
-            return true;
+        Set<String> pipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId()));
+
+        Set<String> indexesReferencedBySemanticText = extractIndexesReferencingInferenceEndpoints(
+            state.getMetadata(),
+            Set.of(request.getInferenceEndpointId())
+        );
+
+        masterListener.onResponse(
+            new DeleteInferenceEndpointAction.Response(
+                false,
+                pipelines,
+                indexesReferencedBySemanticText,
+                buildErrorString(request.getInferenceEndpointId(), pipelines, indexesReferencedBySemanticText)
+            )
+        );
+    }
+
+    private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterState state, final String inferenceEndpointId) {
+
+        var pipelines = endpointIsReferencedInPipelines(state, inferenceEndpointId);
+        var indexes = endpointIsReferencedInIndex(state, inferenceEndpointId);
+
+        if (pipelines.isEmpty() == false || indexes.isEmpty() == false) {
+            return buildErrorString(inferenceEndpointId, pipelines, indexes);
         }
-        IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE);
-        if (ingestMetadata == null) {
-            logger.debug("No ingest metadata found in cluster state while attempting to delete inference endpoint");
-        } else {
-            Set<String> modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.getModelIdsFromInferenceProcessors(ingestMetadata);
-            if (modelIdsReferencedByPipelines.contains(inferenceEndpointId)) {
-                listener.onFailure(
-                    new ElasticsearchStatusException(
-                        "Inference endpoint "
-                            + inferenceEndpointId
-                            + " is referenced by pipelines and cannot be deleted. "
-                            + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it.",
-                        RestStatus.CONFLICT
-                    )
-                );
-                return true;
-            }
+        return null;
+    }
+
+    private static String buildErrorString(String inferenceEndpointId, Set<String> pipelines, Set<String> indexes) {
+        StringBuilder errorString = new StringBuilder();
+
+        if (pipelines.isEmpty() == false) {
+            errorString.append("Inference endpoint ")
+                .append(inferenceEndpointId)
+                .append(" is referenced by pipelines: ")
+                .append(pipelines)
+                .append(". ")
") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); } - return false; + + if (indexes.isEmpty() == false) { + errorString.append(" Inference endpoint ") + .append(inferenceEndpointId) + .append(" is being used in the mapping for indexes: ") + .append(indexes) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + } + + return errorString.toString(); + } + + private static Set endpointIsReferencedInIndex(final ClusterState state, final String inferenceEndpointId) { + Set indexes = extractIndexesReferencingInferenceEndpoints(state.getMetadata(), Set.of(inferenceEndpointId)); + return indexes; + } + + private static Set endpointIsReferencedInPipelines(final ClusterState state, final String inferenceEndpointId) { + Set modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource( + state, + Set.of(inferenceEndpointId) + ); + return modelIdsReferencedByPipelines; } @Override diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml index fd656c9d5d950..f6a7073914609 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml @@ -81,6 +81,7 @@ setup: - do: inference.delete: inference_id: sparse-inference-id + force: true - do: inference.put: @@ -119,6 +120,7 @@ setup: - do: inference.delete: inference_id: dense-inference-id + force: true - do: inference.put: @@ -155,6 +157,7 @@ setup: - do: inference.delete: inference_id: dense-inference-id + force: true - do: inference.put: