Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference API] Semantic text delete inference #110487

Merged
merged 19 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
a774ca7
Prevent inference endpoints from being deleted if they are referenced…
maxhniebergall Jul 2, 2024
b75ded7
Update docs/changelog/110399.yaml
maxhniebergall Jul 2, 2024
e9249da
fix tests
maxhniebergall Jul 2, 2024
6b29685
Merge branch 'semanticTextDeleteInference' of https://github.com/elas…
maxhniebergall Jul 2, 2024
2c9c581
remove erroneous logging
maxhniebergall Jul 2, 2024
16ebaba
Apply suggestions from code review
maxhniebergall Jul 3, 2024
b7ca93a
Fix serialization problem
maxhniebergall Jul 3, 2024
493f29d
Update error messages
maxhniebergall Jul 3, 2024
e4c7724
Merge branch 'main' into semanticTextDeleteInference
maxhniebergall Jul 3, 2024
eec2fbc
Update Delete response to include new fields
maxhniebergall Jul 3, 2024
251b011
Refactor Delete Transport Action to return the error message on dry run
maxhniebergall Jul 3, 2024
0e73416
Fix tests including disabling failing yaml tests
maxhniebergall Jul 3, 2024
b9bc116
Fix YAML tests
maxhniebergall Jul 3, 2024
0d39190
move work off of transport thread onto utility threadpool
maxhniebergall Jul 4, 2024
a568816
Merge branch 'main' into semanticTextDeleteInference
elasticmachine Jul 4, 2024
325c339
clean up semantic text indexes after IT
maxhniebergall Jul 4, 2024
341af67
Merge branch 'semanticTextDeleteInference' of https://github.com/elas…
maxhniebergall Jul 4, 2024
1f0f06e
Merge branch 'main' into semanticTextDeleteInference
maxhniebergall Jul 8, 2024
13937b8
improvements from review
maxhniebergall Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/110399.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Release-notes changelog entry for PR #110399 (rendered into the docs by the changelog tooling).
pr: 110399
summary: "[Inference API] Prevent inference endpoints from being deleted if they are\
  \ referenced by semantic text"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ static TransportVersion def(int id) {
public static final TransportVersion TEXT_SIMILARITY_RERANKER_RETRIEVER = def(8_699_00_0);
public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0);
public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0);
public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_702_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentBuilder;

Expand Down Expand Up @@ -105,10 +106,16 @@ public static class Response extends AcknowledgedResponse {

private final String PIPELINE_IDS = "pipelines";
Set<String> pipelineIds;
private final String REFERENCED_INDEXES = "indexes";
Set<String> indexes;
private final String DRY_RUN_MESSAGE = "error_message"; // error message only returned in response for dry_run
String dryRunMessage;

public Response(boolean acknowledged, Set<String> pipelineIds) {
/**
 * Creates a delete-inference-endpoint response.
 *
 * @param acknowledged whether the deletion was acknowledged
 * @param pipelineIds ids of ingest pipelines that reference the endpoint
 * @param semanticTextIndexes names of indexes whose mappings reference the endpoint
 * @param dryRunMessage error message; only returned in the response for dry_run
 *        requests (see {@code DRY_RUN_MESSAGE}), otherwise null
 */
public Response(boolean acknowledged, Set<String> pipelineIds, Set<String> semanticTextIndexes, @Nullable String dryRunMessage) {
    super(acknowledged);
    this.pipelineIds = pipelineIds;
    this.indexes = semanticTextIndexes;
    this.dryRunMessage = dryRunMessage;
}

public Response(StreamInput in) throws IOException {
Expand All @@ -118,6 +125,15 @@ public Response(StreamInput in) throws IOException {
} else {
pipelineIds = Set.of();
}

if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
indexes = in.readCollectionAsSet(StreamInput::readString);
dryRunMessage = in.readOptionalString();
} else {
indexes = Set.of();
dryRunMessage = null;
}

}

@Override
Expand All @@ -126,12 +142,18 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ENHANCE_DELETE_ENDPOINT)) {
out.writeCollection(pipelineIds, StreamOutput::writeString);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
out.writeCollection(indexes, StreamOutput::writeString);
out.writeOptionalString(dryRunMessage);
}
}

@Override
protected void addCustomFields(XContentBuilder builder, Params params) throws IOException {
super.addCustomFields(builder, params);
builder.field(PIPELINE_IDS, pipelineIds);
builder.field(REFERENCED_INDEXES, indexes);
builder.field(DRY_RUN_MESSAGE, dryRunMessage);
maxhniebergall marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
Expand All @@ -142,6 +164,11 @@ public String toString() {
for (String entry : pipelineIds) {
returnable.append(entry).append(", ");
}
returnable.append(", semanticTextFieldsByIndex: ");
maxhniebergall marked this conversation as resolved.
Show resolved Hide resolved
for (String entry : indexes) {
returnable.append(entry).append(", ");
}
returnable.append(", dryRunMessage: ").append(dryRunMessage);
return returnable.toString();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a Generative AI
*/

package org.elasticsearch.xpack.core.ml.utils;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.transport.Transports;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class SemanticTextInfoExtractor {
private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class);

public static Set<String> extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set<String> endpointIds) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

assert endpointIds.isEmpty() == false;
assert metadata != null;

Set<String> referenceIndices = new HashSet<>();

Map<String, IndexMetadata> indices = metadata.indices();

indices.forEach((indexName, indexMetadata) -> {
if (indexMetadata.getInferenceFields() != null) {
Map<String, InferenceFieldMetadata> inferenceFields = indexMetadata.getInferenceFields();
if (inferenceFields.entrySet()
.stream()
.anyMatch(
entry -> entry.getValue().getInferenceId() != null && endpointIds.contains(entry.getValue().getInferenceId())
)) {
referenceIndices.add(indexName);
}
}
});

return referenceIndices;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,25 @@ protected void deleteModel(String modelId, TaskType taskType) throws IOException
assertOkOrCreated(response);
}

/**
 * Creates {@code indexName} with a single {@code semantic_text} field
 * ({@code inference_field}) whose {@code inference_id} points at the given endpoint.
 *
 * @param endpointId inference endpoint id to reference from the mapping
 * @param indexName name of the index to create
 * @throws IOException if the REST call fails
 */
protected void putSemanticText(String endpointId, String indexName) throws IOException {
    // PUT <indexName> creates the index; no need to run the name through Strings.format.
    var request = new Request("PUT", indexName);
    String body = Strings.format("""
        {
          "mappings": {
            "properties": {
              "inference_field": {
                "type": "semantic_text",
                "inference_id": "%s"
              }
            }
          }
        }
        """, endpointId);
    request.setJsonEntity(body);
    var response = client().performRequest(request);
    assertOkOrCreated(response);
}

protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
return putRequest(endpoint, modelConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -124,14 +125,15 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException {
putPipeline(pipelineId, endpointId);

{
var errorString = new StringBuilder().append("Inference endpoint ")
.append(endpointId)
.append(" is referenced by pipelines: ")
.append(Set.of(pipelineId))
.append(". ")
.append("Ensure that no pipelines are using this inference endpoint, ")
.append("or use force to ignore this warning and delete the inference endpoint.");
var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
assertThat(
e.getMessage(),
containsString(
"Inference endpoint endpoint_referenced_by_pipeline is referenced by pipelines and cannot be deleted. "
+ "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it."
)
);
assertThat(e.getMessage(), containsString(errorString.toString()));
}
{
var response = deleteModel(endpointId, "dry_run=true");
Expand All @@ -146,4 +148,78 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException {
}
deletePipeline(pipelineId);
}

/**
 * Deleting an endpoint referenced by a semantic_text mapping must fail with an
 * explanatory error; dry_run must report the referencing index without deleting;
 * force must delete anyway.
 */
public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException {
    String endpointId = "endpoint_referenced_by_semantic_text";
    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
    // Locale.ROOT keeps the generated name locale-independent (bare toLowerCase()
    // maps 'I' to dotless 'ı' under a Turkish default locale).
    String indexName = randomAlphaOfLength(10).toLowerCase(java.util.Locale.ROOT);
    putSemanticText(endpointId, indexName);
    {
        // NOTE(review): the leading space appears intentional — this clause is
        // concatenated after other text in the server's message; confirm against the action.
        var errorString = new StringBuilder().append(" Inference endpoint ")
            .append(endpointId)
            .append(" is being used in the mapping for indexes: ")
            .append(Set.of(indexName))
            .append(". ")
            .append("Ensure that no index mappings are using this inference endpoint, ")
            .append("or use force to ignore this warning and delete the inference endpoint.");
        var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
        assertThat(e.getMessage(), containsString(errorString.toString()));
    }
    {
        // dry_run reports the blockers but must not delete the endpoint.
        var response = deleteModel(endpointId, "dry_run=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":false"));
        assertThat(entityString, containsString(indexName));
    }
    {
        // force bypasses the semantic_text reference check.
        var response = deleteModel(endpointId, "force=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":true"));
    }
    deleteIndex(indexName);
}

/**
 * Deleting an endpoint referenced by BOTH a semantic_text mapping and an ingest
 * pipeline must fail reporting both blockers; dry_run must list both; force must
 * delete anyway.
 */
public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws IOException {
    // Id made distinct from testDeleteEndpointWhileReferencedBySemanticText's
    // fixture so the two tests cannot collide on endpoint state.
    String endpointId = "endpoint_referenced_by_semantic_text_and_pipeline";
    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
    // Locale.ROOT keeps the generated name locale-independent (Turkish-I problem).
    String indexName = randomAlphaOfLength(10).toLowerCase(java.util.Locale.ROOT);
    putSemanticText(endpointId, indexName);
    var pipelineId = "pipeline_referencing_model";
    putPipeline(pipelineId, endpointId);
    {
        // Expected message is the pipeline clause followed by the semantic_text
        // clause; the single space between them comes from the server-side concatenation.
        var errorString = new StringBuilder().append("Inference endpoint ")
            .append(endpointId)
            .append(" is referenced by pipelines: ")
            .append(Set.of(pipelineId))
            .append(". ")
            .append("Ensure that no pipelines are using this inference endpoint, ")
            .append("or use force to ignore this warning and delete the inference endpoint.")
            .append(" Inference endpoint ")
            .append(endpointId)
            .append(" is being used in the mapping for indexes: ")
            .append(Set.of(indexName))
            .append(". ")
            .append("Ensure that no index mappings are using this inference endpoint, ")
            .append("or use force to ignore this warning and delete the inference endpoint.");

        var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
        assertThat(e.getMessage(), containsString(errorString.toString()));
    }
    {
        // dry_run reports both blockers but must not delete the endpoint.
        var response = deleteModel(endpointId, "dry_run=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":false"));
        assertThat(entityString, containsString(indexName));
        assertThat(entityString, containsString(pipelineId));
    }
    {
        // force bypasses both reference checks.
        var response = deleteModel(endpointId, "force=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":true"));
    }
    deletePipeline(pipelineId);
    deleteIndex(indexName);
}
}
Loading
Loading