From bd9dff82c48e394eac59ef596e575854647bf122 Mon Sep 17 00:00:00 2001
From: Kathleen DeRusso
Date: Thu, 2 May 2024 14:13:33 -0400
Subject: [PATCH] Get sparse vector inference call working

---
 .../ml/queries/SparseVectorQueryBuilder.java  | 187 ++++------
 1 file changed, 55 insertions(+), 132 deletions(-)

diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java
index 23600aa34bdb5..cbb5d98bd96a7 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java
@@ -29,13 +29,12 @@
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.SearchExecutionContext;
-import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
@@ -44,6 +43,7 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;
 
 import java.io.IOException;
 import java.util.List;
@@ -56,6 +56,7 @@
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
 
 public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQueryBuilder> {
+    private static final Logger logger = LogManager.getLogger(SparseVectorQueryBuilder.class);
     public static final String NAME = "sparse_vector";
     public static final ParseField FIELD_FIELD = new ParseField("field");
     public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
@@ -238,163 +239,85 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
         if (tokens != null) {
             // Weighted tokens
+            logger.info("rewriting weighted tokens");
             return weightedTokensToQuery(fieldName, tokens);
         } else {
             // Inference
             if (weightedTokensSupplier != null) {
                 TextExpansionResults textExpansionResults = weightedTokensSupplier.get();
                 if (textExpansionResults == null) {
+                    logger.info("no text expansion results yet, trying again");
                     return this;
                 }
+                logger.info("text expansions to weighted tokens");
                 return weightedTokensToQuery(fieldName, textExpansionResults.getWeightedTokens());
             }
-            // Get model ID from inference ID
-            GetInferenceModelAction.Request getInferenceModelActionRequest = new GetInferenceModelAction.Request(
+            logger.info("running inference");
+            CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
                 inferenceId,
-                TaskType.SPARSE_EMBEDDING
+                List.of(text),
+                TextExpansionConfigUpdate.EMPTY_UPDATE,
+                false,
+                InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
             );
+            inferRequest.setHighPriority(true);
+            inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
+            SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
             queryRewriteContext.registerAsyncAction(
                 (client, listener) -> executeAsyncWithOrigin(
                     client,
                     ML_ORIGIN,
-                    GetInferenceModelAction.INSTANCE,
-                    getInferenceModelActionRequest,
-                    new ActionListener<>() {
-                        @Override
-                        public void onResponse(GetInferenceModelAction.Response response) {
-                            String modelId = response.getModels().get(0).getInferenceEntityId();
-                            CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
-                                modelId,
-                                List.of(text),
-                                TextExpansionConfigUpdate.EMPTY_UPDATE,
-                                false,
-                                InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
-                            );
-                            inferRequest.setHighPriority(true);
-                            inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
-
-                            queryRewriteContext.registerAsyncAction(
-                                (client1, listener1) -> executeAsyncWithOrigin(
-                                    client1,
-                                    ML_ORIGIN,
-                                    InferenceAction.INSTANCE,
-                                    inferRequest,
-                                    ActionListener.wrap(inferenceResponse -> {
-                                        if (inferenceResponse.getResults().asMap().isEmpty()) {
-                                            listener1.onFailure(new IllegalStateException("inference response contain no results"));
-                                            return;
-                                        }
-
-                                        if (inferenceResponse.getResults()
-                                            .transformToLegacyFormat()
-                                            .get(0) instanceof TextExpansionResults textExpansionResults) {
-                                            weightedTokensSupplier.set(textExpansionResults);
-                                            listener1.onResponse(null);
-                                        } else if (inferenceResponse.getResults()
-                                            .transformToLegacyFormat()
-                                            .get(0) instanceof WarningInferenceResults warning) {
-                                            listener1.onFailure(new IllegalStateException(warning.getWarning()));
-                                        } else {
-                                            listener1.onFailure(
-                                                new IllegalStateException(
-                                                    "expected a result of type ["
-                                                        + TextExpansionResults.NAME
-                                                        + "] received ["
-                                                        + inferenceResponse.getResults()
-                                                            .transformToLegacyFormat()
-                                                            .get(0)
-                                                            .getWriteableName()
-                                                        + "]. Is ["
-                                                        + inferenceId
-                                                        + "] a compatible inferenceId?"
-                                                )
-                                            );
-                                        }
-                                    }, listener1::onFailure)
-                                )
-                            );
+                    CoordinatedInferenceAction.INSTANCE,
+                    inferRequest,
+                    ActionListener.wrap(inferenceResponse -> {
+
+                        if (inferenceResponse.getInferenceResults().isEmpty()) {
+                            listener.onFailure(new IllegalStateException("inference response contain no results"));
+                            return;
                         }
-                        @Override
-                        public void onFailure(Exception e) {
-                            listener.onFailure(e);
+                        if (inferenceResponse.getInferenceResults().get(0) instanceof TextExpansionResults textExpansionResults) {
+                            textExpansionResultsSupplier.set(textExpansionResults);
+                            listener.onResponse(null);
+                        } else if (inferenceResponse.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
+                            listener.onFailure(new IllegalStateException(warning.getWarning()));
+                        } else {
+                            listener.onFailure(
+                                new IllegalStateException(
+                                    "expected a result of type ["
+                                        + TextExpansionResults.NAME
+                                        + "] received ["
+                                        + inferenceResponse.getInferenceResults().get(0).getWriteableName()
+                                        + "]. Is ["
+                                        + inferenceId
+                                        + "] a compatible model?"
+                                )
+                            );
                         }
-                    }
+                    }, listener::onFailure)
                 )
             );
-            // queryRewriteContext.registerAsyncAction(
-            //     (client, listener) -> executeAsyncWithOrigin(
-            //         client,
-            //         ML_ORIGIN,
-            //         GetInferenceModelAction.INSTANCE,
-            //         getInferenceModelActionRequest,
-            //         ActionListener.wrap(getInferenceModelActionResponse -> {
-            //             if (getInferenceModelActionResponse.getModels().isEmpty()) {
-            //                 listener.onFailure(new IllegalStateException("Inference ID not found: " + inferenceId));
-            //                 return;
-            //             }
-            //             String modelId = getInferenceModelActionResponse.getModels().get(0).getInferenceEntityId();
-            //             CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
-            //                 modelId,
-            //                 List.of(text),
-            //                 TextExpansionConfigUpdate.EMPTY_UPDATE,
-            //                 false,
-            //                 InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
-            //             );
-            //             inferRequest.setHighPriority(true);
-            //             inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
-            //
-            //             queryRewriteContext.registerAsyncAction(
-            //                 (client1, listener1) -> executeAsyncWithOrigin(
-            //                     client1,
-            //                     ML_ORIGIN,
-            //                     InferenceAction.INSTANCE,
-            //                     inferRequest,
-            //                     ActionListener.wrap(inferenceResponse -> {
-            //                         if (inferenceResponse.getResults().asMap().isEmpty()) {
-            //                             listener1.onFailure(new IllegalStateException("inference response contain no results"));
-            //                             return;
-            //                         }
-            //
-            //                         if (inferenceResponse.getResults()
-            //                             .transformToLegacyFormat()
-            //                             .get(0) instanceof TextExpansionResults textExpansionResults) {
-            //                             weightedTokensSupplier.set(textExpansionResults);
-            //                             listener1.onResponse(null);
-            //                         } else if (inferenceResponse.getResults()
-            //                             .transformToLegacyFormat()
-            //                             .get(0) instanceof WarningInferenceResults warning) {
-            //                             listener1.onFailure(new IllegalStateException(warning.getWarning()));
-            //                         } else {
-            //                             listener1.onFailure(
-            //                                 new IllegalStateException(
-            //                                     "expected a result of type ["
-            //                                         + TextExpansionResults.NAME
-            //                                         + "] received ["
-            //                                         + inferenceResponse.getResults().transformToLegacyFormat().get(0).getWriteableName()
-            //                                         + "]. Is ["
-            //                                         + inferenceId
-            //                                         + "] a compatible inferenceId?"
-            //                                     )
-            //                                 );
-            //                             }
-            //                         }, listener1::onFailure)
-            //                     )
-            //                 );
-            //                 listener.onResponse(null);
-            //             }, listener::onFailure)
-            //         )
-            //     );
-
-            return new SparseVectorQueryBuilder(this, weightedTokensSupplier);
+            return new SparseVectorQueryBuilder(this, textExpansionResultsSupplier);
         }
     }
 
     private QueryBuilder weightedTokensToQuery(String fieldName, List<WeightedToken> weightedTokens) {
-        // TODO support pruning config
+        if (tokenPruningConfig != null) {
+            WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
+                fieldName,
+                weightedTokens,
+                tokenPruningConfig
+            );
+            weightedTokensQueryBuilder.queryName(queryName);
+            weightedTokensQueryBuilder.boost(boost);
+            return weightedTokensQueryBuilder;
+        }
+        // Note: Weighted tokens queries were introduced in 8.13.0. To support mixed version clusters prior to 8.13.0,
+        // if no token pruning configuration is specified we fall back to a boolean query.
+        // TODO this should be updated to always use a WeightedTokensQueryBuilder once it's in all supported versions.
         var boolQuery = QueryBuilders.boolQuery();
         for (var weightedToken : weightedTokens) {
            boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight()));
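
For context on the fallback branch in weightedTokensToQuery: when no token pruning config is given, the rewrite produces a plain bool query with one term clause per expanded token, each boosted by that token's weight, so the rewritten query can still run on pre-8.13.0 nodes that do not know the weighted_tokens query. The snippet below is a minimal standalone sketch of that shape using the server's org.elasticsearch.index.query.QueryBuilders API, not code from the patch; the field name and token weights are invented for illustration.

import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;

import java.util.Map;

public class SparseVectorFallbackSketch {
    public static void main(String[] args) {
        // Hypothetical expansion output (token -> weight); in the patch these come from the
        // TextExpansionResults returned by the coordinated inference call.
        Map<String, Float> expandedTokens = Map.of("pug", 0.85f, "dog", 0.6f, "animal", 0.1f);
        String fieldName = "ml.tokens"; // hypothetical field holding the sparse vector

        // Same shape as the fallback branch: one term clause per token, boosted by its weight,
        // combined as "should" clauses of a bool query.
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        expandedTokens.forEach((token, weight) -> boolQuery.should(QueryBuilders.termQuery(fieldName, token).boost(weight)));

        // Query builders render themselves as JSON via toString().
        System.out.println(boolQuery);
    }
}

When a pruning config is present, the patch instead delegates to WeightedTokensQueryBuilder, which carries the token weights and pruning settings natively on 8.13.0+ nodes.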