Skip to content

Commit

Permalink
Add default boundary for neuralsparse twophase settings, and complement test cases.
Browse files Browse the repository at this point in the history

Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed Apr 22, 2024
1 parent 63340cb commit d19132a
Show file tree
Hide file tree
Showing 7 changed files with 870 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName()));
}

if (sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() <= 0
|| sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() >= 1) {
if (sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() < 0
|| sparseEncodingQueryBuilder.neuralSparseTwoPhaseParameters.pruning_ratio() > 1) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] %s field value must be in range (0,1)",
"[%s] %s field value must be in range (0,1]",
NeuralSparseTwoPhaseParameters.NAME.getPreferredName(),
NeuralSparseTwoPhaseParameters.PRUNING_RATIO.getPreferredName()
)
Expand Down Expand Up @@ -434,7 +434,7 @@ private Map<String, Float> getFilteredScoreTokens(boolean aboveThreshold, float
}
return queryTokens.entrySet()
.stream()
.filter(entry -> (aboveThreshold == (entry.getValue() > threshold)))
.filter(entry -> (aboveThreshold == (entry.getValue() >= threshold)))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,16 @@ public final class NeuralSearchSettings {
public static final Setting<Float> NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION = Setting.floatSetting(
"plugins.neural_search.neural_sparse.two_phase.default_window_size_expansion",
5f,
1f,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

public static final Setting<Float> NEURAL_SPARSE_TWO_PHASE_DEFAULT_PRUNING_RATIO = Setting.floatSetting(
"plugins.neural_search.neural_sparse.two_phase.default_pruning_ratio",
0.4f,
0.0f,
1.0f,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
Expand All @@ -59,6 +62,7 @@ public final class NeuralSearchSettings {
public static final Setting<Integer> NEURAL_SPARSE_TWO_PHASE_MAX_WINDOW_SIZE = Setting.intSetting(
"plugins.neural_search.neural_sparse.two_phase.max_window_size",
10000,
0,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.query;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -17,6 +18,7 @@
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TOKENS_FIELD;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -47,9 +49,11 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.settings.NeuralSearchSettings;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
Expand Down Expand Up @@ -230,15 +234,14 @@ public void testFromXContent_whenBuiltWithIllegalTwoPhaseWindowSizeExpansion_the
"query_text": "string",
"model_id": "string",
"two_phase_settings":{
"window_size_expansion": 0.5,
"window_size_expansion": 0.4,
"pruning_ratio": 0.4,
"enabled": true
}
}
}
*/
NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings()
.window_size_expansion(NeuralSparseTwoPhaseParametersTests.TEST_WINDOW_SIZE_EXPANSION);
NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings().window_size_expansion(0.4f);
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
Expand All @@ -261,13 +264,13 @@ public void testFromXContent_whenBuiltWithIllegalTwoPhasePruningRate_thenBuildSu
"model_id": "string",
"two_phase_settings":{
"window_size_expansion": 0.5,
"pruning_ratio": 0, // or 1
"pruning_ratio": -0.001, // or 1.001
"enabled": true
}
}
}
*/
NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings().pruning_ratio(0f);
NeuralSparseTwoPhaseParameters parameters = NeuralSparseTwoPhaseParameters.getDefaultSettings().pruning_ratio(-0.001f);
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
Expand All @@ -280,7 +283,7 @@ public void testFromXContent_whenBuiltWithIllegalTwoPhasePruningRate_thenBuildSu
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));

parameters.pruning_ratio(1f);
parameters.pruning_ratio(1.001f);
XContentBuilder xContentBuilder2 = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
Expand Down Expand Up @@ -880,4 +883,43 @@ public void testBuildFeatureFieldQueryFormTokens() {
assertSame(booleanQuery.clauses().size(), 2);
}

@SneakyThrows
public void testTokenDividedByScores_whenDefaultSettings() {
Map<String, Float> map = new HashMap<>();
for (int i = 1; i < 11; i++) {
map.put(String.valueOf(i), (float) i);
}
final Supplier<Map<String, Float>> tokenSupplier = () -> map;
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName("rank_features")
.queryText(QUERY_TEXT)
.neuralSparseTwoPhaseParameters(NeuralSparseTwoPhaseParameters.getDefaultSettings())
.modelId(MODEL_ID)
.queryTokensSupplier(tokenSupplier);
QueryShardContext context = mock(QueryShardContext.class);
MappedFieldType mappedFieldType = mock(MappedFieldType.class);
when(mappedFieldType.typeName()).thenReturn("rank_features");
when(context.fieldMapper(anyString())).thenReturn(mappedFieldType);
NeuralSparseQuery neuralSparseQuery = (NeuralSparseQuery) sparseEncodingQueryBuilder.doToQuery(context);
BooleanQuery highScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getHighScoreTokenQuery();
BooleanQuery lowScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getLowScoreTokenQuery();
assertNotNull(highScoreTokenQuery.clauses());
assertNotNull(lowScoreTokenQuery.clauses());
assertEquals(highScoreTokenQuery.clauses().size(), 7);
assertEquals(lowScoreTokenQuery.clauses().size(), 3);
sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName("rank_features")
.queryText(QUERY_TEXT)
.neuralSparseTwoPhaseParameters(
new NeuralSparseTwoPhaseParameters().enabled(true).window_size_expansion(5f).pruning_ratio(0.6f)
)
.modelId(MODEL_ID)
.queryTokensSupplier(tokenSupplier);
neuralSparseQuery = (NeuralSparseQuery) sparseEncodingQueryBuilder.doToQuery(context);
highScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getHighScoreTokenQuery();
lowScoreTokenQuery = (BooleanQuery) neuralSparseQuery.getLowScoreTokenQuery();
assertNotNull(highScoreTokenQuery.clauses());
assertNotNull(lowScoreTokenQuery.clauses());
assertEquals(highScoreTokenQuery.clauses().size(), 5);
assertEquals(lowScoreTokenQuery.clauses().size(), 5);
}

}
Loading

0 comments on commit d19132a

Please sign in to comment.