From 1be0f2b5aedce0294efffdf88df97f9771286e33 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Wed, 3 Jul 2024 17:27:12 -0400 Subject: [PATCH] =?UTF-8?q?Revert=20"[Inference=20API]=20Prevent=20inferen?= =?UTF-8?q?ce=20endpoints=20from=20being=20deleted=20if=20the=E2=80=A6"=20?= =?UTF-8?q?(#110446)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 89a1bd9c2da88160ba3d38ef572e98777f902e6a. --- docs/changelog/110399.yaml | 6 - .../org/elasticsearch/TransportVersions.java | 1 - .../action/DeleteInferenceEndpointAction.java | 29 +--- .../ml/utils/SemanticTextInfoExtractor.java | 48 ------- .../inference/InferenceBaseRestTest.java | 19 --- .../xpack/inference/InferenceCrudIT.java | 88 +------------ ...ransportDeleteInferenceEndpointAction.java | 124 +++++++----------- ..._text_query_inference_endpoint_changes.yml | 3 - 8 files changed, 52 insertions(+), 266 deletions(-) delete mode 100644 docs/changelog/110399.yaml delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java diff --git a/docs/changelog/110399.yaml b/docs/changelog/110399.yaml deleted file mode 100644 index 9e04e2656809e..0000000000000 --- a/docs/changelog/110399.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 110399 -summary: "[Inference API] Prevent inference endpoints from being deleted if they are\ - \ referenced by semantic text" -area: Machine Learning -type: enhancement -issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index fe87a055146d8..2004c6fda8ce5 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -208,7 +208,6 @@ static TransportVersion def(int id) { public static final TransportVersion TEXT_SIMILARITY_RERANKER_RETRIEVER = def(8_699_00_0); public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0); public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0); - public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_702_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java index 00debb5bf9366..dfb77ccd49fc2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java @@ -13,7 +13,6 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; @@ -106,16 +105,10 @@ public static class Response extends AcknowledgedResponse { private final String PIPELINE_IDS = "pipelines"; Set pipelineIds; - private final String REFERENCED_INDEXES = "indexes"; - Set indexes; - private final String DRY_RUN_MESSAGE = "error_message"; // error message only returned in response for dry_run - String dryRunMessage; - public Response(boolean acknowledged, Set pipelineIds, Set semanticTextIndexes, @Nullable String dryRunMessage) { + public Response(boolean acknowledged, Set pipelineIds) { super(acknowledged); this.pipelineIds = pipelineIds; - this.indexes = semanticTextIndexes; - this.dryRunMessage = dryRunMessage; } public Response(StreamInput in) throws IOException { @@ -125,15 +118,6 @@ public Response(StreamInput in) throws IOException { } else { pipelineIds = Set.of(); } - - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { - indexes = in.readCollectionAsSet(StreamInput::readString); - dryRunMessage = in.readOptionalString(); - } else { - indexes = Set.of(); - dryRunMessage = null; - } - } @Override @@ -142,18 +126,12 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ENHANCE_DELETE_ENDPOINT)) { out.writeCollection(pipelineIds, StreamOutput::writeString); } - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { - out.writeCollection(indexes, StreamOutput::writeString); - out.writeOptionalString(dryRunMessage); - } } @Override protected void addCustomFields(XContentBuilder builder, Params params) throws IOException { super.addCustomFields(builder, params); builder.field(PIPELINE_IDS, pipelineIds); - builder.field(REFERENCED_INDEXES, indexes); - builder.field(DRY_RUN_MESSAGE, dryRunMessage); } @Override @@ -164,11 +142,6 @@ public String toString() { for (String entry : pipelineIds) { returnable.append(entry).append(", "); } - returnable.append(", semanticTextFieldsByIndex: "); - for (String entry : indexes) { - returnable.append(entry).append(", "); - } - returnable.append(", dryRunMessage: ").append(dryRunMessage); return returnable.toString(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java deleted file mode 100644 index ed021baf31828..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - * - * this file was contributed to by a Generative AI - */ - -package org.elasticsearch.xpack.core.ml.utils; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.cluster.metadata.Metadata; - -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -public class SemanticTextInfoExtractor { - private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class); - - public static Set extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set endpointIds) { - assert endpointIds.isEmpty() == false; - assert metadata != null; - - Set referenceIndices = new HashSet<>(); - - Map indices = metadata.indices(); - - indices.forEach((indexName, indexMetadata) -> { - if (indexMetadata.getInferenceFields() != null) { - Map inferenceFields = indexMetadata.getInferenceFields(); - if (inferenceFields.entrySet() - .stream() - .anyMatch( - entry -> entry.getValue().getInferenceId() != null && endpointIds.contains(entry.getValue().getInferenceId()) - )) { - referenceIndices.add(indexName); - } - } - }); - - return referenceIndices; - } -} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index f30f2e8fe201a..419869c0c4a5e 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -126,25 +126,6 @@ protected void deleteModel(String modelId, TaskType taskType) throws IOException assertOkOrCreated(response); } - protected void putSemanticText(String endpointId, String indexName) throws IOException { - var request = new Request("PUT", Strings.format("%s", indexName)); - String body = Strings.format(""" - { - "mappings": { - "properties": { - "inference_field": { - "type": "semantic_text", - "inference_id": "%s" - } - } - } - } - """, endpointId); - request.setJsonEntity(body); - var response = client().performRequest(request); - assertOkOrCreated(response); - } - protected Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { String endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return putRequest(endpoint, modelConfig); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 034457ec28a79..75e392b6d155f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -16,7 +16,6 @@ import java.io.IOException; import java.util.List; -import java.util.Set; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; @@ -125,15 +124,14 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException { putPipeline(pipelineId, endpointId); { - var errorString = new StringBuilder().append("Inference endpoint ") - .append(endpointId) - .append(" is referenced by pipelines: ") - .append(Set.of(pipelineId)) - .append(". ") - .append("Ensure that no pipelines are using this inference endpoint, ") - .append("or use force to ignore this warning and delete the inference endpoint."); var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); - assertThat(e.getMessage(), containsString(errorString.toString())); + assertThat( + e.getMessage(), + containsString( + "Inference endpoint endpoint_referenced_by_pipeline is referenced by pipelines and cannot be deleted. " + + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it." + ) + ); } { var response = deleteModel(endpointId, "dry_run=true"); @@ -148,76 +146,4 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException { } deletePipeline(pipelineId); } - - public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException { - String endpointId = "endpoint_referenced_by_semantic_text"; - putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); - String indexName = randomAlphaOfLength(10).toLowerCase(); - putSemanticText(endpointId, indexName); - { - - var errorString = new StringBuilder().append(" Inference endpoint ") - .append(endpointId) - .append(" is being used in the mapping for indexes: ") - .append(Set.of(indexName)) - .append(". ") - .append("Ensure that no index mappings are using this inference endpoint, ") - .append("or use force to ignore this warning and delete the inference endpoint."); - var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); - assertThat(e.getMessage(), containsString(errorString.toString())); - } - { - var response = deleteModel(endpointId, "dry_run=true"); - var entityString = EntityUtils.toString(response.getEntity()); - assertThat(entityString, containsString("\"acknowledged\":false")); - assertThat(entityString, containsString(indexName)); - } - { - var response = deleteModel(endpointId, "force=true"); - var entityString = EntityUtils.toString(response.getEntity()); - assertThat(entityString, containsString("\"acknowledged\":true")); - } - } - - public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws IOException { - String endpointId = "endpoint_referenced_by_semantic_text"; - putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); - String indexName = randomAlphaOfLength(10).toLowerCase(); - putSemanticText(endpointId, indexName); - var pipelineId = "pipeline_referencing_model"; - putPipeline(pipelineId, endpointId); - { - - var errorString = new StringBuilder().append("Inference endpoint ") - .append(endpointId) - .append(" is referenced by pipelines: ") - .append(Set.of(pipelineId)) - .append(". ") - .append("Ensure that no pipelines are using this inference endpoint, ") - .append("or use force to ignore this warning and delete the inference endpoint.") - .append(" Inference endpoint ") - .append(endpointId) - .append(" is being used in the mapping for indexes: ") - .append(Set.of(indexName)) - .append(". ") - .append("Ensure that no index mappings are using this inference endpoint, ") - .append("or use force to ignore this warning and delete the inference endpoint."); - - var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); - assertThat(e.getMessage(), containsString(errorString.toString())); - } - { - var response = deleteModel(endpointId, "dry_run=true"); - var entityString = EntityUtils.toString(response.getEntity()); - assertThat(entityString, containsString("\"acknowledged\":false")); - assertThat(entityString, containsString(indexName)); - assertThat(entityString, containsString(pipelineId)); - } - { - var response = deleteModel(endpointId, "force=true"); - var entityString = EntityUtils.toString(response.getEntity()); - assertThat(entityString, containsString("\"acknowledged\":true")); - } - deletePipeline(pipelineId); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 9a84f572a6d60..07d5e1e618578 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -3,8 +3,6 @@ * or more contributor license agreements. Licensed under the Elastic License * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. - * - * this file was contributed to by a Generative AI */ package org.elasticsearch.xpack.inference.action; @@ -20,10 +18,12 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -35,8 +35,6 @@ import java.util.Set; -import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints; - public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction< DeleteInferenceEndpointAction.Request, DeleteInferenceEndpointAction.Response> { @@ -91,15 +89,17 @@ protected void masterOperation( } if (request.isDryRun()) { - handleDryRun(request, state, masterListener); + masterListener.onResponse( + new DeleteInferenceEndpointAction.Response( + false, + InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())) + ) + ); return; - } else if (request.isForceDelete() == false) { - var errorString = endpointIsReferencedInPipelinesOrIndexes(state, request.getInferenceEndpointId()); - if (errorString != null) { - listener.onFailure(new ElasticsearchStatusException(errorString, RestStatus.CONFLICT)); + } else if (request.isForceDelete() == false + && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), listener)) { return; } - } var service = serviceRegistry.getService(unparsedModel.service()); if (service.isPresent()) { @@ -126,83 +126,47 @@ protected void masterOperation( }) .addListener( masterListener.delegateFailure( - (l3, didDeleteModel) -> masterListener.onResponse( - new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of(), Set.of(), null) - ) + (l3, didDeleteModel) -> masterListener.onResponse(new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of())) ) ); } - private static void handleDryRun( - DeleteInferenceEndpointAction.Request request, - ClusterState state, - ActionListener masterListener + private static boolean endpointIsReferencedInPipelines( + final ClusterState state, + final String inferenceEndpointId, + ActionListener listener ) { - Set pipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())); - - Set indexesReferencedBySemanticText = extractIndexesReferencingInferenceEndpoints( - state.getMetadata(), - Set.of(request.getInferenceEndpointId()) - ); - - masterListener.onResponse( - new DeleteInferenceEndpointAction.Response( - false, - pipelines, - indexesReferencedBySemanticText, - buildErrorString(request.getInferenceEndpointId(), pipelines, indexesReferencedBySemanticText) - ) - ); - } - - private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterState state, final String inferenceEndpointId) { - - var pipelines = endpointIsReferencedInPipelines(state, inferenceEndpointId); - var indexes = endpointIsReferencedInIndex(state, inferenceEndpointId); - - if (pipelines.isEmpty() == false || indexes.isEmpty() == false) { - return buildErrorString(inferenceEndpointId, pipelines, indexes); - } - return null; - } - - private static String buildErrorString(String inferenceEndpointId, Set pipelines, Set indexes) { - StringBuilder errorString = new StringBuilder(); - - if (pipelines.isEmpty() == false) { - errorString.append("Inference endpoint ") - .append(inferenceEndpointId) - .append(" is referenced by pipelines: ") - .append(pipelines) - .append(". ") - .append("Ensure that no pipelines are using this inference endpoint, ") - .append("or use force to ignore this warning and delete the inference endpoint."); + Metadata metadata = state.getMetadata(); + if (metadata == null) { + listener.onFailure( + new ElasticsearchStatusException( + " Could not determine if the endpoint is referenced in a pipeline as cluster state metadata was unexpectedly null. " + + "Use `force` to delete it anyway", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + // Unsure why the ClusterState metadata would ever be null, but in this case it seems safer to assume the endpoint is referenced + return true; } - - if (indexes.isEmpty() == false) { - errorString.append(" Inference endpoint ") - .append(inferenceEndpointId) - .append(" is being used in the mapping for indexes: ") - .append(indexes) - .append(". ") - .append("Ensure that no index mappings are using this inference endpoint, ") - .append("or use force to ignore this warning and delete the inference endpoint."); + IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE); + if (ingestMetadata == null) { + logger.debug("No ingest metadata found in cluster state while attempting to delete inference endpoint"); + } else { + Set modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.getModelIdsFromInferenceProcessors(ingestMetadata); + if (modelIdsReferencedByPipelines.contains(inferenceEndpointId)) { + listener.onFailure( + new ElasticsearchStatusException( + "Inference endpoint " + + inferenceEndpointId + + " is referenced by pipelines and cannot be deleted. " + + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it.", + RestStatus.CONFLICT + ) + ); + return true; + } } - - return errorString.toString(); - } - - private static Set endpointIsReferencedInIndex(final ClusterState state, final String inferenceEndpointId) { - Set indexes = extractIndexesReferencingInferenceEndpoints(state.getMetadata(), Set.of(inferenceEndpointId)); - return indexes; - } - - private static Set endpointIsReferencedInPipelines(final ClusterState state, final String inferenceEndpointId) { - Set modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource( - state, - Set.of(inferenceEndpointId) - ); - return modelIdsReferencedByPipelines; + return false; } @Override diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml index f6a7073914609..fd656c9d5d950 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml @@ -81,7 +81,6 @@ setup: - do: inference.delete: inference_id: sparse-inference-id - force: true - do: inference.put: @@ -120,7 +119,6 @@ setup: - do: inference.delete: inference_id: dense-inference-id - force: true - do: inference.put: @@ -157,7 +155,6 @@ setup: - do: inference.delete: inference_id: dense-inference-id - force: true - do: inference.put: