diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 80c8579bf..2637d5931 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -122,13 +122,15 @@ public String getType() { * * @param queryTokens the queryTokens map, key is the token String, value is the score. * @param thresholdRatio The ratio that control how tokens map be split. - * @return A map has two element, {[True, token map whose value above threshold],[False, token map whose value below threshold]} + * @return A tuple with two elements, { token map whose values are above the threshold, token map whose values are below the threshold } */ public static Tuple, Map> splitQueryTokensByRatioedMaxScoreAsThreshold( final Map queryTokens, final float thresholdRatio ) { - + if (Objects.isNull(queryTokens)) { + throw new IllegalArgumentException("Query tokens cannot be null or empty."); + } float max = 0f; for (Float value : queryTokens.values()) { max = Math.max(value, max); @@ -143,8 +145,8 @@ public static Tuple, Map> splitQueryTokensByRa Map highScoreTokens = queryTokensByScore.get(Boolean.TRUE); Map lowScoreTokens = queryTokensByScore.get(Boolean.FALSE); - if (Objects.isNull(highScoreTokens) || highScoreTokens.isEmpty()) { - throw new IllegalArgumentException("Query tokens cannot be null or empty."); + if (Objects.isNull(highScoreTokens)) { + highScoreTokens = Collections.emptyMap(); } if (Objects.isNull(lowScoreTokens)) { lowScoreTokens = Collections.emptyMap(); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index ea4f295f6..383e4b468 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ 
b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -79,10 +79,16 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder twoPhaseSharedQueryToken; private Supplier> queryTokensSupplier; - // A parameter that can detect if twoPhase are enabled, when twoPhasePruneRatio equals -1f, it means it two-phase rescoreQueryBuilder's - // subQueryBuilder. + // A field used by neural_sparse_two_phase_processor: if twoPhaseSharedQueryToken is not null, + // it means it's the original NeuralSparseQueryBuilder and should split the low score tokens from itself, then put them into + // twoPhaseSharedQueryToken. + private Map twoPhaseSharedQueryToken; + // A parameter with a default value 0F, + // 1. If the query request is using neural_sparse_two_phase_processor and is collected, + // its value will be the ratio of the processor. + // 2. If it's the sub query built only for two-phase, the value will be set to -1 * ratio of the processor. + // Then in doToQuery, we can use this to determine which type this queryBuilder is. private float twoPhasePruneRatio = 0F; private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0; @@ -133,7 +139,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float .queryText(this.queryText) .modelId(this.modelId) .maxTokenScore(this.maxTokenScore) - .twoPhasePruneRatio(-1f); + .twoPhasePruneRatio(-1f * ratio); if (Objects.nonNull(this.queryTokensSupplier)) { Map tokens = queryTokensSupplier.get(); // Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1, @@ -307,6 +313,7 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { // We need to inference the sentence to get the queryTokens. 
The logic is similar to NeuralQueryBuilder + // If the inference is finished, then rewrite to self and call doToQuery, otherwise, continue doRewrite // If two-phase is enabled( twoPhaseSharedQueryToken is not null ), will split the queryTokens into high score tokens // and low score tokens, and assign them to queryTokensSupplier and twoPhaseSharedQueryToken. if (Objects.nonNull(queryTokensSupplier)) { @@ -324,6 +331,20 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { .twoPhasePruneRatio(twoPhasePruneRatio); } + /** + * Creates a BiConsumer to asynchronously perform model inference on a given text query, processing + * the inference results to either update a shared token set or set a new token set. + * This method leverages the ML_CLIENT to infer sentence models and processes the results to handle + * token weights. Depending on the existence of two-phase shared query tokens, it may split the tokens + * based on a specified ratio and update the shared token set with the tokens whose score is lower than ratio * maxScoreOfAllToken. + * + * @param setOnce A SetOnce instance to store the results of the token weights after processing. + * It captures the high score part of the split tokens or all tokens if no split is required. + * @return A BiConsumer taking an OpenSearch Client and an ActionListener. The ActionListener + * is used to handle the result of the inference task or to catch failures. + * The BiConsumer can be utilized in asynchronous task pipelines where an OpenSearch + * client interaction is involved. 
+ */ private BiConsumer> getModelInferenceAsync(SetOnce> setOnce) { return ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult( modelId(), @@ -336,7 +357,7 @@ private BiConsumer> getModelInferenceAsync(SetOnce queryTokens = new HashMap<>(); + public void testGetSplitSetOnceByScoreThreshold_whenNullQueryToken_thenThrowException() { + Map queryTokens = null; expectThrows( IllegalArgumentException.class, () -> NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold(queryTokens, 0.4f)