Skip to content

Commit

Permalink
Infer model ID
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 18, 2024
1 parent e30bf6b commit a136ec1
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_QUERY_BUILDER_IN_SEARCH_FUNCTIONS = def(8_808_00_0);
public static final TransportVersion EQL_ALLOW_PARTIAL_SEARCH_RESULTS = def(8_809_00_0);
public static final TransportVersion NODE_VERSION_INFORMATION_WITH_MIN_READ_ONLY_INDEX_VERSION = def(8_810_00_0);
public static final TransportVersion TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID = def(8_811_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

Expand All @@ -46,7 +48,7 @@ public class TextEmbeddingQueryVectorBuilder implements QueryVectorBuilder {
);

static {
PARSER.declareString(constructorArg(), TrainedModelConfig.MODEL_ID);
PARSER.declareString(optionalConstructorArg(), TrainedModelConfig.MODEL_ID);
PARSER.declareString(constructorArg(), MODEL_TEXT);
}

Expand All @@ -63,8 +65,13 @@ public TextEmbeddingQueryVectorBuilder(String modelId, String modelText) {
}

public TextEmbeddingQueryVectorBuilder(StreamInput in) throws IOException {
this.modelId = in.readString();
this.modelText = in.readString();
if (in.getTransportVersion().onOrAfter(TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID)) {
this.modelId = in.readOptionalString();
this.modelText = in.readString();
} else {
this.modelId = in.readString();
this.modelText = in.readString();
}
}

@Override
Expand All @@ -79,14 +86,20 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
if (out.getTransportVersion().onOrAfter(TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID)) {
out.writeOptionalString(modelId);
} else {
out.writeString(modelId);
}
out.writeString(modelText);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
if (modelId != null) {
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
}
builder.field(MODEL_TEXT.getPreferredName(), modelText);
builder.endObject();
return builder;
Expand All @@ -101,6 +114,11 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
);

if (modelId == null) {
throw new IllegalArgumentException("Required [model_id]");
}

inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
// We always perform nested subqueries on semantic_text fields, to support
// sparse_vector queries using query vectors.
// knn queries using query vectors.
for (String inferenceId : inferenceIdsIndices.keySet()) {
boolQueryBuilder.should(
createSubQueryForIndices(inferenceIdsIndices.get(inferenceId), buildNestedQueryFromKnnVectorQuery(knnVec, inferenceId))
Expand All @@ -99,19 +99,33 @@ private QueryBuilder buildNestedQueryFromKnnVectorQuery(QueryBuilder queryBuilde
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
QueryVectorBuilder queryVectorBuilder = knnVectorQueryBuilder.queryVectorBuilder();
TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) queryVectorBuilder;
if (queryVectorBuilder != null) {
assert (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder);
TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) queryVectorBuilder;
if (searchInferenceId != null) {
queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(searchInferenceId, textEmbeddingQueryVectorBuilder.getModelText());
}
}
return QueryBuilders.nestedQuery(
SemanticTextField.getChunksFieldName(knnVectorQueryBuilder.getFieldName()),
buildNewKnnVectorQuery(SemanticTextField.getEmbeddingsFieldName(knnVectorQueryBuilder.getFieldName()), knnVectorQueryBuilder),
buildNewKnnVectorQuery(
SemanticTextField.getEmbeddingsFieldName(knnVectorQueryBuilder.getFieldName()),
knnVectorQueryBuilder,
queryVectorBuilder
),
ScoreMode.Max
);
}

private KnnVectorQueryBuilder buildNewKnnVectorQuery(String fieldName, KnnVectorQueryBuilder original) {
private KnnVectorQueryBuilder buildNewKnnVectorQuery(
String fieldName,
KnnVectorQueryBuilder original,
QueryVectorBuilder queryVectorBuilder
) {
if (original.queryVectorBuilder() != null) {
return new KnnVectorQueryBuilder(
fieldName,
original.queryVectorBuilder(),
queryVectorBuilder,
original.k(),
original.numCands(),
original.getVectorSimilarity()
Expand Down

0 comments on commit a136ec1

Please sign in to comment.