Skip to content

Commit

Permalink
Replace some NeuralSparseIT test case's model inference with fixed qu…
Browse files Browse the repository at this point in the history
…eryToken, add some comments, replace some String with String.format.

Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed Apr 25, 2024
1 parent feb9606 commit 69c3b0d
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ public final class NeuralSparseQuery extends Query {
*/
@Override
public String toString(String field) {
return "NeuralSparseQuery("
+ currentQuery.toString(field)
+ ","
+ highScoreTokenQuery.toString(field)
+ ", "
+ lowScoreTokenQuery.toString(field)
+ ")";

return String.format(
"NeuralSparseQuery(%s,%s,%s)",
currentQuery.toString(field),
highScoreTokenQuery.toString(field),
lowScoreTokenQuery.toString(field)
);
}

public Query rewrite(IndexSearcher indexSearcher) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,14 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) {

@Override
protected int doHashCode() {
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore);
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName)
.append(queryText)
.append(modelId)
.append(maxTokenScore)
.append(neuralSparseTwoPhaseParameters);
if (queryTokensSupplier != null) {
builder.append(queryTokensSupplier.get());
}
if (Objects.nonNull(neuralSparseTwoPhaseParameters)) {
builder.append(neuralSparseTwoPhaseParameters.hashcode());
}
return builder.toHashCode();
}

Expand Down Expand Up @@ -455,6 +456,9 @@ private Map<String, Float> getFilteredScoreTokens(Map<String, Float> queryTokens
.stream()
.filter(entry -> (aboveThreshold == (entry.getValue() >= threshold)))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
// This call will generate a filter from queryTokens.
// When aboveThreshold is true, will filter out all key-value pairs whose value >= threshold to a return map.
// When aboveThreshold is false, will filter out all key-value pairs whose value <= threshold to a return map.
}

private BooleanQuery buildFeatureFieldQueryFromTokens(Map<String, Float> tokens, String fieldName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class NeuralSparseTwoPhaseParameters implements Writeable {
* @param clusterService The opensearch clusterService.
* @param settings The env settings to initialize.
*/
public static void initialize(ClusterService clusterService, Settings settings) {
public static void initialize(final ClusterService clusterService, final Settings settings) {
DEFAULT_ENABLED = NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_ENABLED.get(settings);
DEFAULT_WINDOW_SIZE_EXPANSION = NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION.get(settings);
DEFAULT_PRUNING_RATIO = NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_PRUNING_RATIO.get(settings);
Expand Down Expand Up @@ -96,14 +96,14 @@ public static NeuralSparseTwoPhaseParameters getDefaultSettings() {
* @param in StreamInput to initialize object from
* @throws IOException thrown if unable to read from input stream
*/
public NeuralSparseTwoPhaseParameters(StreamInput in) throws IOException {
public NeuralSparseTwoPhaseParameters(final StreamInput in) throws IOException {
window_size_expansion = in.readFloat();
pruning_ratio = in.readFloat();
enabled = in.readBoolean();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
public void writeTo(final StreamOutput out) throws IOException {
out.writeFloat(window_size_expansion);
out.writeFloat(pruning_ratio);
out.writeBoolean(enabled);
Expand All @@ -116,7 +116,7 @@ public void writeTo(StreamOutput out) throws IOException {
* @return the given XContentBuilder with object content added.
* @throws IOException if building the content fails.
*/
public XContentBuilder doXContent(XContentBuilder builder) throws IOException {
public XContentBuilder doXContent(final XContentBuilder builder) throws IOException {
builder.startObject(NAME.getPreferredName());
builder.field(WINDOW_SIZE_EXPANSION.getPreferredName(), window_size_expansion);
builder.field(PRUNING_RATIO.getPreferredName(), pruning_ratio);
Expand All @@ -132,7 +132,7 @@ public XContentBuilder doXContent(XContentBuilder builder) throws IOException {
* @return a new instance of NeuralSparseTwoPhaseParameters initialized from the parser.
* @throws IOException if parsing fails.
*/
public static NeuralSparseTwoPhaseParameters parseFromXContent(XContentParser parser) throws IOException {
public static NeuralSparseTwoPhaseParameters parseFromXContent(final XContentParser parser) throws IOException {
XContentParser.Token token;
String currentFieldName = "";
NeuralSparseTwoPhaseParameters neuralSparseTwoPhaseParameters = NeuralSparseTwoPhaseParameters.getDefaultSettings();
Expand Down Expand Up @@ -177,7 +177,8 @@ public boolean equals(Object obj) {
&& Objects.equals(pruning_ratio, other.pruning_ratio);
}

public int hashcode() {
@Override
public int hashCode() {
HashCodeBuilder builder = new HashCodeBuilder().append(enabled).append(window_size_expansion).append(pruning_ratio);
return builder.toHashCode();
}
Expand All @@ -188,7 +189,7 @@ public int hashcode() {
* @param neuralSparseTwoPhaseParameters The parameters to check.
* @return true if enabled, false otherwise.
*/
public static boolean isEnabled(NeuralSparseTwoPhaseParameters neuralSparseTwoPhaseParameters) {
public static boolean isEnabled(final NeuralSparseTwoPhaseParameters neuralSparseTwoPhaseParameters) {
if (!isClusterOnOrAfterMinReqVersionForTwoPhaseSearchSupport() || Objects.isNull(neuralSparseTwoPhaseParameters)) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,36 @@
*/
public class NeuralSparseTwoPhaseUtil {
/**
* This function determine any neuralSparseQuery from query, extract lowScoreTokenQuery from each of them,
* And as that, the neuralSparseQuery be extracted 's currentQuery will be changed from allScoreTokenQuery to highScoreTokenQuery.
* Then build a QueryRescoreContext of these extra lowScoreTokenQuery and add the built QueryRescoreContext to the searchContext.
* Finally, the score of TopDocs will be sum of highScoreTokenQuery and lowScoreTokenQuery, which equals to allTokenQuery.
* @param query The whole query include neuralSparseQuery to executed.
* @param searchContext The searchContext with this query.
*/
public static void addRescoreContextFromNeuralSparseQuery(final Query query, SearchContext searchContext) {
public static void addRescoreContextFromNeuralSparseQuery(final Query query, final SearchContext searchContext) {
Map<Query, Float> query2weight = new HashMap<>();
float windowSizeExpansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, query2weight, 1.0f, 1.0f);
Query twoPhaseQuery;
if (query2weight.isEmpty()) {
return;
} else if (query2weight.size() == 1) {
Map.Entry<Query, Float> entry = query2weight.entrySet().stream().findFirst().get();
}
if (query2weight.size() == 1) {
Map.Entry<Query, Float> entry = query2weight.entrySet().iterator().next();
twoPhaseQuery = new BoostQuery(entry.getKey(), entry.getValue());
} else {
twoPhaseQuery = getNestedTwoPhaseQuery(query2weight);
}
int curWindowSize = (int) (searchContext.size() * windowSizeExpansion);
if (curWindowSize < 0 || curWindowSize > NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE) {
throw new IllegalArgumentException(
"Two phase final windowSize "
+ curWindowSize
+ " out of score with limit "
+ NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE
+ "."
+ "You can change the value of cluster setting "
+ "[plugins.neural_search.neural_sparse.two_phase.max_window_size] to a integer at least 50."
String.format(
"Two phase final windowSize %d out of score with limit %d. "
+ "You can change the value of cluster setting [plugins.neural_search.neural_sparse.two_phase.max_window_size] "
+ "to a integer at least 50.",
curWindowSize,
NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE
)
);
}
QueryRescorer.QueryRescoreContext rescoreContext = new QueryRescorer.QueryRescoreContext(curWindowSize);
Expand Down
Loading

0 comments on commit 69c3b0d

Please sign in to comment.