-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Semantic text dense vector support #105515
Changes from 17 commits
8a2dbd4
fc76918
bd4e19e
896ec49
af763f0
fbefa0b
59194d7
8b5489b
ebd49b0
7f1dfb4
22daa5d
724f8d1
ce4125a
02b7cc4
7e98736
436b40f
b2769f3
3500c9c
bac3502
2e46037
9d7be42
fd553a9
cdc579f
139c94b
e9ff896
1e76b18
7191758
4f2c8ca
61b3d98
fbbfbd5
7c6120b
335afe5
fe13a04
a3bdabf
0f5e7a3
7eddba0
360256d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,11 +24,13 @@ | |
import org.elasticsearch.inference.InputType; | ||
import org.elasticsearch.inference.Model; | ||
import org.elasticsearch.inference.ModelRegistry; | ||
import org.elasticsearch.inference.ModelSettings; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.HashSet; | ||
import java.util.LinkedHashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
@@ -46,10 +48,10 @@ public class BulkShardRequestInferenceProvider { | |
public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; | ||
|
||
// Contains the original text for the field | ||
public static final String TEXT_SUBFIELD_NAME = "text"; | ||
|
||
// Contains the inference result when it's a sparse vector | ||
public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; | ||
public static final String INFERENCE_RESULTS = "inference_results"; | ||
public static final String INFERENCE_CHUNKS_RESULTS = "inference"; | ||
public static final String INFERENCE_CHUNKS_TEXT = "text"; | ||
|
||
private final ClusterState clusterState; | ||
private final Map<String, InferenceProvider> inferenceProvidersMap; | ||
|
@@ -90,7 +92,13 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { | |
var service = inferenceServiceRegistry.getService(unparsedModel.service()); | ||
if (service.isEmpty() == false) { | ||
InferenceProvider inferenceProvider = new InferenceProvider( | ||
service.get().parsePersistedConfig(inferenceId, unparsedModel.taskType(), unparsedModel.settings()), | ||
service.get() | ||
.parsePersistedConfigWithSecrets( | ||
inferenceId, | ||
unparsedModel.taskType(), | ||
unparsedModel.settings(), | ||
unparsedModel.secrets() | ||
), | ||
service.get() | ||
); | ||
inferenceProviderMap.put(inferenceId, inferenceProvider); | ||
|
@@ -105,7 +113,7 @@ public void onFailure(Exception e) { | |
} | ||
}; | ||
|
||
modelRegistry.getModel(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); | ||
modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); | ||
} | ||
} | ||
} | ||
|
@@ -259,35 +267,22 @@ public void onResponse(InferenceServiceResults results) { | |
} | ||
|
||
int i = 0; | ||
for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { | ||
String fieldName = inferenceFieldNames.get(i++); | ||
List<Map<String, Object>> inferenceFieldResultList; | ||
try { | ||
inferenceFieldResultList = (List<Map<String, Object>>) rootInferenceFieldMap.computeIfAbsent( | ||
fieldName, | ||
k -> new ArrayList<>() | ||
); | ||
} catch (ClassCastException e) { | ||
onBulkItemFailure.apply( | ||
bulkItemRequest, | ||
itemIndex, | ||
new IllegalArgumentException( | ||
"Inference result field [" + ROOT_INFERENCE_FIELD + "." + fieldName + "] is not an object" | ||
for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { | ||
String inferenceFieldName = inferenceFieldNames.get(i++); | ||
Map<String, Object> inferenceFieldResult = new LinkedHashMap<>(); | ||
inferenceFieldResult.putAll(new ModelSettings(inferenceProvider.model).asMap()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add model settings information to make it available to the field mapping. |
||
inferenceFieldResult.put( | ||
INFERENCE_RESULTS, | ||
List.of( | ||
Map.of( | ||
INFERENCE_CHUNKS_RESULTS, | ||
inferenceResults.asMap("output").get("output"), | ||
INFERENCE_CHUNKS_TEXT, | ||
docMap.get(inferenceFieldName) | ||
) | ||
); | ||
return; | ||
} | ||
// Remove previous inference results if any | ||
inferenceFieldResultList.clear(); | ||
|
||
// TODO Check inference result type to change subfield name | ||
var inferenceFieldMap = Map.of( | ||
SPARSE_VECTOR_SUBFIELD_NAME, | ||
inferenceResults.asMap("output").get("output"), | ||
TEXT_SUBFIELD_NAME, | ||
docMap.get(fieldName) | ||
) | ||
); | ||
inferenceFieldResultList.add(inferenceFieldMap); | ||
rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult); | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0 and the Server Side Public License, v 1; you may not use this file except | ||
* in compliance with, at your election, the Elastic License 2.0 or the Server | ||
* Side Public License, v 1. | ||
*/ | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not a fan of the renaming; why not keep the original name? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense, renaming it back to SemanticTextModelSettings. |
||
package org.elasticsearch.inference; | ||
|
||
import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.xcontent.ParseField; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
public class ModelSettings { | ||
carlosdelest marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
public static final String NAME = "model_settings"; | ||
public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); | ||
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); | ||
public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); | ||
public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); | ||
private final TaskType taskType; | ||
private final String inferenceId; | ||
private final Integer dimensions; | ||
private final SimilarityMeasure similarity; | ||
|
||
public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { | ||
Objects.requireNonNull(taskType, "task type must not be null"); | ||
kderusso marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Objects.requireNonNull(inferenceId, "inferenceId must not be null"); | ||
this.taskType = taskType; | ||
this.inferenceId = inferenceId; | ||
this.dimensions = dimensions; | ||
this.similarity = similarity; | ||
} | ||
|
||
public ModelSettings(Model model) { | ||
this( | ||
model.getTaskType(), | ||
model.getInferenceEntityId(), | ||
model.getServiceSettings().dimensions(), | ||
model.getServiceSettings().similarity() | ||
); | ||
} | ||
|
||
public static ModelSettings parse(XContentParser parser) throws IOException { | ||
return PARSER.apply(parser, null); | ||
} | ||
|
||
private static final ConstructingObjectParser<ModelSettings, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> { | ||
TaskType taskType = TaskType.fromString((String) args[0]); | ||
String inferenceId = (String) args[1]; | ||
Integer dimensions = (Integer) args[2]; | ||
SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]); | ||
return new ModelSettings(taskType, inferenceId, dimensions, similarity); | ||
}); | ||
static { | ||
PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); | ||
PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); | ||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); | ||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); | ||
} | ||
|
||
public Map<String, Object> asMap() { | ||
Map<String, Object> attrsMap = new HashMap<>(); | ||
attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); | ||
attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); | ||
if (dimensions != null) { | ||
attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); | ||
} | ||
if (similarity != null) { | ||
attrsMap.put(SIMILARITY_FIELD.getPreferredName(), similarity); | ||
} | ||
return Map.of(NAME, attrsMap); | ||
} | ||
|
||
public TaskType taskType() { | ||
return taskType; | ||
} | ||
|
||
public String inferenceId() { | ||
return inferenceId; | ||
} | ||
|
||
public Integer dimensions() { | ||
return dimensions; | ||
} | ||
|
||
public SimilarityMeasure similarity() { | ||
return similarity; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Secrets are needed in order to perform inference on external services