Skip to content

Commit

Permalink
Get sparse vector inference call working
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed May 2, 2024
1 parent d2dd882 commit bd9dff8
Showing 1 changed file with 55 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
Expand All @@ -44,6 +43,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;

import java.io.IOException;
import java.util.List;
Expand All @@ -56,6 +56,7 @@
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQueryBuilder> {
private static final Logger logger = LogManager.getLogger(SparseVectorQueryBuilder.class);
public static final String NAME = "sparse_vector";
public static final ParseField FIELD_FIELD = new ParseField("field");
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
Expand Down Expand Up @@ -238,163 +239,85 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {

if (tokens != null) {
// Weighted tokens
logger.info("rewriting weighted tokens");
return weightedTokensToQuery(fieldName, tokens);
} else {
// Inference
if (weightedTokensSupplier != null) {
TextExpansionResults textExpansionResults = weightedTokensSupplier.get();
if (textExpansionResults == null) {
logger.info("no text expansion results yet, trying again");
return this;
}
logger.info("text expansions to weighted tokens");
return weightedTokensToQuery(fieldName, textExpansionResults.getWeightedTokens());
}

// Get model ID from inference ID
GetInferenceModelAction.Request getInferenceModelActionRequest = new GetInferenceModelAction.Request(
logger.info("running inference");
CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
inferenceId,
TaskType.SPARSE_EMBEDDING
List.of(text),
TextExpansionConfigUpdate.EMPTY_UPDATE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
);
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);

SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
queryRewriteContext.registerAsyncAction(
(client, listener) -> executeAsyncWithOrigin(
client,
ML_ORIGIN,
GetInferenceModelAction.INSTANCE,
getInferenceModelActionRequest,
new ActionListener<>() {
@Override
public void onResponse(GetInferenceModelAction.Response response) {
String modelId = response.getModels().get(0).getInferenceEntityId();
CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
modelId,
List.of(text),
TextExpansionConfigUpdate.EMPTY_UPDATE,
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
);
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);

queryRewriteContext.registerAsyncAction(
(client1, listener1) -> executeAsyncWithOrigin(
client1,
ML_ORIGIN,
InferenceAction.INSTANCE,
inferRequest,
ActionListener.wrap(inferenceResponse -> {
if (inferenceResponse.getResults().asMap().isEmpty()) {
listener1.onFailure(new IllegalStateException("inference response contain no results"));
return;
}

if (inferenceResponse.getResults()
.transformToLegacyFormat()
.get(0) instanceof TextExpansionResults textExpansionResults) {
weightedTokensSupplier.set(textExpansionResults);
listener1.onResponse(null);
} else if (inferenceResponse.getResults()
.transformToLegacyFormat()
.get(0) instanceof WarningInferenceResults warning) {
listener1.onFailure(new IllegalStateException(warning.getWarning()));
} else {
listener1.onFailure(
new IllegalStateException(
"expected a result of type ["
+ TextExpansionResults.NAME
+ "] received ["
+ inferenceResponse.getResults()
.transformToLegacyFormat()
.get(0)
.getWriteableName()
+ "]. Is ["
+ inferenceId
+ "] a compatible inferenceId?"
)
);
}
}, listener1::onFailure)
)
);
CoordinatedInferenceAction.INSTANCE,
inferRequest,
ActionListener.wrap(inferenceResponse -> {

if (inferenceResponse.getInferenceResults().isEmpty()) {
listener.onFailure(new IllegalStateException("inference response contain no results"));
return;
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
if (inferenceResponse.getInferenceResults().get(0) instanceof TextExpansionResults textExpansionResults) {
textExpansionResultsSupplier.set(textExpansionResults);
listener.onResponse(null);
} else if (inferenceResponse.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
listener.onFailure(
new IllegalStateException(
"expected a result of type ["
+ TextExpansionResults.NAME
+ "] received ["
+ inferenceResponse.getInferenceResults().get(0).getWriteableName()
+ "]. Is ["
+ inferenceId
+ "] a compatible model?"
)
);
}
}
}, listener::onFailure)
)
);

// queryRewriteContext.registerAsyncAction(
// (client, listener) -> executeAsyncWithOrigin(
// client,
// ML_ORIGIN,
// GetInferenceModelAction.INSTANCE,
// getInferenceModelActionRequest,
// ActionListener.wrap(getInferenceModelActionResponse -> {
// if (getInferenceModelActionResponse.getModels().isEmpty()) {
// listener.onFailure(new IllegalStateException("Inference ID not found: " + inferenceId));
// return;
// }
// String modelId = getInferenceModelActionResponse.getModels().get(0).getInferenceEntityId();
// CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
// modelId,
// List.of(text),
// TextExpansionConfigUpdate.EMPTY_UPDATE,
// false,
// InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
// );
// inferRequest.setHighPriority(true);
// inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
//
// queryRewriteContext.registerAsyncAction(
// (client1, listener1) -> executeAsyncWithOrigin(
// client1,
// ML_ORIGIN,
// InferenceAction.INSTANCE,
// inferRequest,
// ActionListener.wrap(inferenceResponse -> {
// if (inferenceResponse.getResults().asMap().isEmpty()) {
// listener1.onFailure(new IllegalStateException("inference response contain no results"));
// return;
// }
//
// if (inferenceResponse.getResults()
// .transformToLegacyFormat()
// .get(0) instanceof TextExpansionResults textExpansionResults) {
// weightedTokensSupplier.set(textExpansionResults);
// listener1.onResponse(null);
// } else if (inferenceResponse.getResults()
// .transformToLegacyFormat()
// .get(0) instanceof WarningInferenceResults warning) {
// listener1.onFailure(new IllegalStateException(warning.getWarning()));
// } else {
// listener1.onFailure(
// new IllegalStateException(
// "expected a result of type ["
// + TextExpansionResults.NAME
// + "] received ["
// + inferenceResponse.getResults().transformToLegacyFormat().get(0).getWriteableName()
// + "]. Is ["
// + inferenceId
// + "] a compatible inferenceId?"
// )
// );
// }
// }, listener1::onFailure)
// )
// );
// listener.onResponse(null);
// }, listener::onFailure)
// )
// );

return new SparseVectorQueryBuilder(this, weightedTokensSupplier);
return new SparseVectorQueryBuilder(this, textExpansionResultsSupplier);
}
}

private QueryBuilder weightedTokensToQuery(String fieldName, List<WeightedToken> weightedTokens) {
// TODO support pruning config
if (tokenPruningConfig != null) {
WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
fieldName,
weightedTokens,
tokenPruningConfig
);
weightedTokensQueryBuilder.queryName(queryName);
weightedTokensQueryBuilder.boost(boost);
return weightedTokensQueryBuilder;
}
// Note: Weighted tokens queries were introduced in 8.13.0. To support mixed version clusters prior to 8.13.0,
// if no token pruning configuration is specified we fall back to a boolean query.
// TODO this should be updated to always use a WeightedTokensQueryBuilder once it's in all supported versions.
var boolQuery = QueryBuilders.boolQuery();
for (var weightedToken : weightedTokens) {
boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight()));
Expand Down

0 comments on commit bd9dff8

Please sign in to comment.