Skip to content

Commit

Permalink
Add some comments, remove some redundant lines, fix some format.
Browse files Browse the repository at this point in the history
Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed May 28, 2024
1 parent 0f5eab9 commit c5d0e99
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ public String getType() {
*
* @param queryTokens the queryTokens map, key is the token String, value is the score.
* @param thresholdRatio The ratio that controls how the token map is split.
* @return A map has two element, {[True, token map whose value above threshold],[False, token map whose value below threshold]}
* @return A tuple has two element, { token map whose value above threshold, token map whose value below threshold }
*/
public static Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokensByRatioedMaxScoreAsThreshold(
final Map<String, Float> queryTokens,
final float thresholdRatio
) {

if (Objects.isNull(queryTokens)) {
throw new IllegalArgumentException("Query tokens cannot be null or empty.");
}
float max = 0f;
for (Float value : queryTokens.values()) {
max = Math.max(value, max);
Expand All @@ -143,8 +145,8 @@ public static Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokensByRa

Map<String, Float> highScoreTokens = queryTokensByScore.get(Boolean.TRUE);
Map<String, Float> lowScoreTokens = queryTokensByScore.get(Boolean.FALSE);
if (Objects.isNull(highScoreTokens) || highScoreTokens.isEmpty()) {
throw new IllegalArgumentException("Query tokens cannot be null or empty.");
if (Objects.isNull(highScoreTokens)) {
highScoreTokens = Collections.emptyMap();
}
if (Objects.isNull(lowScoreTokens)) {
lowScoreTokens = Collections.emptyMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,16 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
private String queryText;
private String modelId;
private Float maxTokenScore;
private Map<String, Float> twoPhaseSharedQueryToken;
private Supplier<Map<String, Float>> queryTokensSupplier;
// A parameter that can detect if twoPhase are enabled, when twoPhasePruneRatio equals -1f, it means it two-phase rescoreQueryBuilder's
// subQueryBuilder.
// A field for neural_sparse_two_phase_processor. If twoPhaseSharedQueryToken is not null,
// this is the original NeuralSparseQueryBuilder, and it should split the low score tokens from itself and put them into
// twoPhaseSharedQueryToken.
private Map<String, Float> twoPhaseSharedQueryToken;
// A parameter with a default value of 0F.
// 1. If the query request uses neural_sparse_two_phase_processor and this builder is collected,
// its value will be the ratio of the processor.
// 2. If it is the sub query built only for two-phase, the value will be set to -1 * ratio of the processor.
// Then in doToQuery, we can use this to determine which type this queryBuilder is.
private float twoPhasePruneRatio = 0F;

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;
Expand Down Expand Up @@ -133,7 +139,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
.queryText(this.queryText)
.modelId(this.modelId)
.maxTokenScore(this.maxTokenScore)
.twoPhasePruneRatio(-1f);
.twoPhasePruneRatio(-1f * ratio);
if (Objects.nonNull(this.queryTokensSupplier)) {
Map<String, Float> tokens = queryTokensSupplier.get();
// Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1,
Expand Down Expand Up @@ -307,6 +313,7 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui
@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// We need to inference the sentence to get the queryTokens. The logic is similar to NeuralQueryBuilder
// If the inference is finished, then rewrite to self and call doToQuery, otherwise, continue doRewrite
// If two-phase is enabled (twoPhaseSharedQueryToken is not null), the queryTokens are split into high score tokens
// and low score tokens, which are assigned to queryTokensSupplier and twoPhaseSharedQueryToken.
if (Objects.nonNull(queryTokensSupplier)) {
Expand All @@ -324,6 +331,20 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
.twoPhasePruneRatio(twoPhasePruneRatio);
}

/**
* Creates a BiConsumer to asynchronously perform model inference on a given text query, processing
* the inference results to either update a shared token set or set a new token set.
* This method leverages the ML_CLIENT to infer sentence models and processes the results to handle
* token weights. Depending on the existence of two-phase shared query tokens, it may split the tokens
* based on a specified ratio and update the shared token set with the tokens whose scores are lower than ratio * maxScoreOfAllToken.
*
* @param setOnce A SetOnce instance to store the results of the token weights after processing.
* It captures high score part of the split tokens or all tokens if no split is required.
* @return A BiConsumer taking an OpenSearch Client and an ActionListener. The ActionListener
* is used to handle the result of the inference task or to catch failures.
* The BiConsumer can be utilized in asynchronous task pipelines where an OpenSearch
* client interaction is involved.
*/
private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map<String, Float>> setOnce) {
return ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(
modelId(),
Expand All @@ -336,7 +357,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
twoPhasePruneRatio
);
setOnce.set(splitQueryTokens.v1());
twoPhaseSharedQueryToken = (splitQueryTokens.v2());
twoPhaseSharedQueryToken = splitQueryTokens.v2();
} else {
setOnce.set(queryTokens);
}
Expand Down Expand Up @@ -381,22 +402,28 @@ private static void validateFieldType(MappedFieldType fieldType) {

@Override
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) return false;
if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
if (this == obj) {
return true;
}
if (Objects.isNull(obj) || getClass() != obj.getClass()) {
return false;
}
if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) {
return false;
}
if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) {
return false;
}

EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId)
.append(maxTokenScore, obj.maxTokenScore)
.append(twoPhasePruneRatio, obj.twoPhasePruneRatio);
.append(twoPhasePruneRatio, obj.twoPhasePruneRatio)
.append(twoPhaseSharedQueryToken, obj.twoPhaseSharedQueryToken);
if (Objects.nonNull(queryTokensSupplier)) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
if (Objects.nonNull(twoPhaseSharedQueryToken)) {
equalsBuilder.append(twoPhaseSharedQueryToken, obj.twoPhaseSharedQueryToken);
}
return equalsBuilder.isEquals();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ public void testGetSplitSetOnceByScoreThreshold() {
}

@SneakyThrows
public void testGetSplitSetOnceByScoreThreshold_whenHighScoreTokenIsNull_thenThrowException() {
Map<String, Float> queryTokens = new HashMap<>();
public void testGetSplitSetOnceByScoreThreshold_whenNullQueryToken_thenThrowException() {
Map<String, Float> queryTokens = null;
expectThrows(
IllegalArgumentException.class,
() -> NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold(queryTokens, 0.4f)
Expand Down

0 comments on commit c5d0e99

Please sign in to comment.