Skip to content

Commit

Permalink
Add Search Inference ID To Semantic Text Mapping (elastic#113051)
Browse files Browse the repository at this point in the history
Adds a search_inference_id parameter to the semantic_text mapping. This parameter defines the inference endpoint that is used to generate embeddings at query time.
  • Loading branch information
Mikep86 committed Sep 24, 2024
1 parent ce06812 commit 3aeea7a
Show file tree
Hide file tree
Showing 11 changed files with 562 additions and 80 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/113051.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 113051
summary: Add Search Inference ID To Semantic Text Mapping
area: Mapping
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_AGGREGATION_OPERATOR_STATUS_FINISH_NANOS = def(8_747_00_0);
public static final TransportVersion ML_TELEMETRY_MEMORY_ADDED = def(8_748_00_0);
public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE = def(8_749_00_0);
public static final TransportVersion SEMANTIC_TEXT_SEARCH_INFERENCE_ID = def(8_750_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_SEARCH_INFERENCE_ID;

/**
* Contains inference field data for fields.
* As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need
Expand All @@ -32,28 +34,43 @@
*/
public final class InferenceFieldMetadata implements SimpleDiffable<InferenceFieldMetadata>, ToXContentFragment {
private static final String INFERENCE_ID_FIELD = "inference_id";
private static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
private static final String SOURCE_FIELDS_FIELD = "source_fields";

private final String name;
private final String inferenceId;
private final String searchInferenceId;
private final String[] sourceFields;

public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) {
this(name, inferenceId, inferenceId, sourceFields);
}

public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) {
this.name = Objects.requireNonNull(name);
this.inferenceId = Objects.requireNonNull(inferenceId);
this.searchInferenceId = Objects.requireNonNull(searchInferenceId);
this.sourceFields = Objects.requireNonNull(sourceFields);
}

public InferenceFieldMetadata(StreamInput input) throws IOException {
this.name = input.readString();
this.inferenceId = input.readString();
if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_SEARCH_INFERENCE_ID)) {
this.searchInferenceId = input.readString();
} else {
this.searchInferenceId = this.inferenceId;
}
this.sourceFields = input.readStringArray();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeString(inferenceId);
if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_SEARCH_INFERENCE_ID)) {
out.writeString(searchInferenceId);
}
out.writeStringArray(sourceFields);
}

Expand All @@ -64,12 +81,13 @@ public boolean equals(Object o) {
InferenceFieldMetadata that = (InferenceFieldMetadata) o;
return Objects.equals(name, that.name)
&& Objects.equals(inferenceId, that.inferenceId)
&& Objects.equals(searchInferenceId, that.searchInferenceId)
&& Arrays.equals(sourceFields, that.sourceFields);
}

@Override
public int hashCode() {
int result = Objects.hash(name, inferenceId);
int result = Objects.hash(name, inferenceId, searchInferenceId);
result = 31 * result + Arrays.hashCode(sourceFields);
return result;
}
Expand All @@ -82,6 +100,10 @@ public String getInferenceId() {
return inferenceId;
}

public String getSearchInferenceId() {
return searchInferenceId;
}

public String[] getSourceFields() {
return sourceFields;
}
Expand All @@ -94,6 +116,9 @@ public static Diff<InferenceFieldMetadata> readDiffFrom(StreamInput in) throws I
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(name);
builder.field(INFERENCE_ID_FIELD, inferenceId);
if (searchInferenceId.equals(inferenceId) == false) {
builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId);
}
builder.array(SOURCE_FIELDS_FIELD, sourceFields);
return builder.endObject();
}
Expand All @@ -106,13 +131,16 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws

String currentFieldName = null;
String inferenceId = null;
String searchInferenceId = null;
List<String> inputFields = new ArrayList<>();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token == XContentParser.Token.VALUE_STRING) {
if (INFERENCE_ID_FIELD.equals(currentFieldName)) {
inferenceId = parser.text();
} else if (SEARCH_INFERENCE_ID_FIELD.equals(currentFieldName)) {
searchInferenceId = parser.text();
}
} else if (token == XContentParser.Token.START_ARRAY) {
if (SOURCE_FIELDS_FIELD.equals(currentFieldName)) {
Expand All @@ -128,6 +156,11 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
parser.skipChildren();
}
}
return new InferenceFieldMetadata(name, inferenceId, inputFields.toArray(String[]::new));
return new InferenceFieldMetadata(
name,
inferenceId,
searchInferenceId == null ? inferenceId : searchInferenceId,
inputFields.toArray(String[]::new)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ protected boolean supportsUnknownFields() {
private static InferenceFieldMetadata createTestItem() {
String name = randomAlphaOfLengthBetween(3, 10);
String inferenceId = randomIdentifier();
String searchInferenceId = randomIdentifier();
String[] inputFields = generateRandomStringArray(5, 10, false, false);
return new InferenceFieldMetadata(name, inferenceId, inputFields);
return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields);
}

public void testNullCtorArgsThrowException() {
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0]));
assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

Expand All @@ -23,7 +24,8 @@ public class InferenceFeatures implements FeatureSpecification {
public Set<NodeFeature> getFeatures() {
return Set.of(
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID
);
}

Expand Down
Loading

0 comments on commit 3aeea7a

Please sign in to comment.