Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semantic text dense vector support #105515

Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
8a2dbd4
Move SimilarityMeasure to server code
carlosdelest Feb 14, 2024
fc76918
Add dimensions and similarity to ServiceSettings, create ModelSetting…
carlosdelest Feb 14, 2024
bd4e19e
Change implementation of asMap() to avoid extra nesting in inference …
carlosdelest Feb 14, 2024
896ec49
Change inference results structure
carlosdelest Feb 14, 2024
af763f0
Field mapper uses new inference results structure
carlosdelest Feb 14, 2024
fbefa0b
Fix BulkOperationTests
carlosdelest Feb 14, 2024
59194d7
Fix tests of SemanticTextInferenceResultFieldMapper
carlosdelest Feb 14, 2024
8b5489b
Revert "Change implementation of asMap() to avoid extra nesting in in…
carlosdelest Feb 15, 2024
ebd49b0
Use coordination format instead of changing the asMap() implementation
carlosdelest Feb 15, 2024
7f1dfb4
Fix tests for new results structure
carlosdelest Feb 15, 2024
22daa5d
Add service extension for dense vector embeddings
carlosdelest Feb 15, 2024
724f8d1
Add tests for dense vector embeddings
carlosdelest Feb 15, 2024
ce4125a
Fix bug in model settings
carlosdelest Feb 15, 2024
02b7cc4
Initial work on inference field results mapping for validation
carlosdelest Feb 15, 2024
7e98736
Fix spotless
carlosdelest Feb 15, 2024
436b40f
Refactored inference services with common abstract class
carlosdelest Feb 15, 2024
b2769f3
Merge branch 'feature/semantic-text' into carlosdelest/semantic-text-…
carlosdelest Feb 20, 2024
3500c9c
Add javadoc
carlosdelest Feb 21, 2024
bac3502
Merge branch 'feature/semantic-text' into carlosdelest/semantic-text-…
carlosdelest Feb 21, 2024
2e46037
Fixing tests after merge
carlosdelest Feb 21, 2024
9d7be42
Fix tests
carlosdelest Feb 21, 2024
fd553a9
Change back javadoc
carlosdelest Mar 5, 2024
cdc579f
Merge branch 'feature/semantic-text' into carlosdelest/semantic-text-…
carlosdelest Mar 5, 2024
139c94b
Fix node_selector in new esql test (#105943)
ldematte Mar 5, 2024
e9ff896
Fix gradle run on Serverless (#105938)
tvernum Mar 5, 2024
1e76b18
YAML test framework: re-introduce `requires` section and `cluster_fea…
ldematte Mar 5, 2024
7191758
Unmute testRollupNonTSIndex() and (#105949)
martijnvg Mar 5, 2024
4f2c8ca
Test mute for #105952 (#105953)
benwtrent Mar 5, 2024
61b3d98
Add note about optional times and epochs (#105786)
benwtrent Mar 5, 2024
fbbfbd5
[test] Disable index.shard.check_on_startup for searchable snapshot t…
arteam Mar 5, 2024
7c6120b
Fix TransportSLMGetExpiredSnapshotsActionTests (#105950)
DaveCTurner Mar 5, 2024
335afe5
Fix performance bug in `SourceConfirmedTextQuery#matches` (#105930)
jimczi Mar 5, 2024
fe13a04
Bugfix for mixed version cluster queries using text expansion (#105912)
kderusso Mar 5, 2024
a3bdabf
Merge branch 'main' into carlosdelest/semantic-text-dense-vector-support
carlosdelest Mar 5, 2024
0f5e7a3
Remove duplicate class
carlosdelest Mar 5, 2024
7eddba0
Check mappings are as expected
carlosdelest Mar 5, 2024
360256d
I hate YAML
carlosdelest Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.inference.ModelSettings;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -46,10 +48,10 @@ public class BulkShardRequestInferenceProvider {
public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference";

// Contains the original text for the field
public static final String TEXT_SUBFIELD_NAME = "text";

// Contains the inference result when it's a sparse vector
public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding";
public static final String INFERENCE_RESULTS = "inference_results";
public static final String INFERENCE_CHUNKS_RESULTS = "inference";
public static final String INFERENCE_CHUNKS_TEXT = "text";

private final ClusterState clusterState;
private final Map<String, InferenceProvider> inferenceProvidersMap;
Expand Down Expand Up @@ -90,7 +92,13 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) {
var service = inferenceServiceRegistry.getService(unparsedModel.service());
if (service.isEmpty() == false) {
InferenceProvider inferenceProvider = new InferenceProvider(
service.get().parsePersistedConfig(inferenceId, unparsedModel.taskType(), unparsedModel.settings()),
service.get()
.parsePersistedConfigWithSecrets(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Secrets are needed in order to perform inference on external services

inferenceId,
unparsedModel.taskType(),
unparsedModel.settings(),
unparsedModel.secrets()
),
service.get()
);
inferenceProviderMap.put(inferenceId, inferenceProvider);
Expand All @@ -105,7 +113,7 @@ public void onFailure(Exception e) {
}
};

modelRegistry.getModel(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire()));
modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire()));
}
}
}
Expand Down Expand Up @@ -259,35 +267,22 @@ public void onResponse(InferenceServiceResults results) {
}

int i = 0;
for (InferenceResults inferenceResults : results.transformToLegacyFormat()) {
String fieldName = inferenceFieldNames.get(i++);
List<Map<String, Object>> inferenceFieldResultList;
try {
inferenceFieldResultList = (List<Map<String, Object>>) rootInferenceFieldMap.computeIfAbsent(
fieldName,
k -> new ArrayList<>()
);
} catch (ClassCastException e) {
onBulkItemFailure.apply(
bulkItemRequest,
itemIndex,
new IllegalArgumentException(
"Inference result field [" + ROOT_INFERENCE_FIELD + "." + fieldName + "] is not an object"
for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) {
String inferenceFieldName = inferenceFieldNames.get(i++);
Map<String, Object> inferenceFieldResult = new LinkedHashMap<>();
inferenceFieldResult.putAll(new ModelSettings(inferenceProvider.model).asMap());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add model settings information to make it available to field mapping

inferenceFieldResult.put(
INFERENCE_RESULTS,
List.of(
Map.of(
INFERENCE_CHUNKS_RESULTS,
inferenceResults.asMap("output").get("output"),
INFERENCE_CHUNKS_TEXT,
docMap.get(inferenceFieldName)
)
);
return;
}
// Remove previous inference results if any
inferenceFieldResultList.clear();

// TODO Check inference result type to change subfield name
var inferenceFieldMap = Map.of(
SPARSE_VECTOR_SUBFIELD_NAME,
inferenceResults.asMap("output").get("output"),
TEXT_SUBFIELD_NAME,
docMap.get(fieldName)
)
);
inferenceFieldResultList.add(inferenceFieldMap);
rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,16 @@ protected Parameter<?>[] getParameters() {
return new Parameter<?>[] { elementType, dims, indexed, similarity, indexOptions, meta };
}

public Builder similarity(VectorSimilarity vectorSimilarity) {
similarity.setValue(vectorSimilarity);
return this;
}

public Builder dimensions(int dimensions) {
this.dims.setValue(dimensions);
return this;
}

@Override
public DenseVectorFieldMapper build(MapperBuilderContext context) {
return new DenseVectorFieldMapper(
Expand Down Expand Up @@ -724,7 +734,7 @@ static Function<StringBuilder, StringBuilder> errorByteElementsAppender(byte[] v
ElementType.FLOAT
);

enum VectorSimilarity {
public enum VectorSimilarity {
L2_NORM {
@Override
float score(float similarity, ElementType elementType, int dim) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ public interface InferenceServiceResults extends NamedWriteable, ToXContentFragm
List<? extends InferenceResults> transformToLegacyFormat();

/**
* Convert the result to a map to aid with test assertions
* Retrieves a map representation of the results. It should be equivalent to parsing the
* XContent representation of the results.
*
* @return the results as a map
*/
Map<String, Object> asMap();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a fan of the renaming, why not keeping the original name? ModelSettings is too generic imo and we only need these settings for the field mapping.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, renaming it back to SemanticTextModelSettings

package org.elasticsearch.inference;

import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class ModelSettings {
carlosdelest marked this conversation as resolved.
Show resolved Hide resolved

public static final String NAME = "model_settings";
public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type");
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
public static final ParseField SIMILARITY_FIELD = new ParseField("similarity");
private final TaskType taskType;
private final String inferenceId;
private final Integer dimensions;
private final SimilarityMeasure similarity;

public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) {
Objects.requireNonNull(taskType, "task type must not be null");
kderusso marked this conversation as resolved.
Show resolved Hide resolved
Objects.requireNonNull(inferenceId, "inferenceId must not be null");
this.taskType = taskType;
this.inferenceId = inferenceId;
this.dimensions = dimensions;
this.similarity = similarity;
}

public ModelSettings(Model model) {
this(
model.getTaskType(),
model.getInferenceEntityId(),
model.getServiceSettings().dimensions(),
model.getServiceSettings().similarity()
);
}

public static ModelSettings parse(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}

private static final ConstructingObjectParser<ModelSettings, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
TaskType taskType = TaskType.fromString((String) args[0]);
String inferenceId = (String) args[1];
Integer dimensions = (Integer) args[2];
SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]);
return new ModelSettings(taskType, inferenceId, dimensions, similarity);
});
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD);
}

public Map<String, Object> asMap() {
Map<String, Object> attrsMap = new HashMap<>();
attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString());
attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
if (dimensions != null) {
attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions);
}
if (similarity != null) {
attrsMap.put(SIMILARITY_FIELD.getPreferredName(), similarity);
}
return Map.of(NAME, attrsMap);
}

public TaskType taskType() {
return taskType;
}

public String inferenceId() {
return inferenceId;
}

public Integer dimensions() {
return dimensions;
}

public SimilarityMeasure similarity() {
return similarity;
}
}
Loading