Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for post filter in hybrid query #633

Merged
merged 9 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630))
- Support for post filter in hybrid query ([#633](https://github.com/opensearch-project/neural-search/pull/633))
### Bug Fixes
- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
Expand Down Expand Up @@ -46,6 +52,9 @@
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
@Nullable
private final Weight filterWeight;
public static final float boost_factor = 1f;
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved

/**
* Create new instance of HybridCollectorManager depending on the concurrent search beeing enabled or disabled.
Expand All @@ -60,27 +69,44 @@
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

Weight filteringWeight = null;
// Check for post filter to create weight for filter query and later use that weight in the search workflow
if (Objects.nonNull(searchContext.parsedPostFilter()) && Objects.nonNull(searchContext.parsedPostFilter().query())) {
Query filterQuery = searchContext.parsedPostFilter().query();
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
ContextIndexSearcher searcher = searchContext.searcher();
// ScoreMode COMPLETE_NO_SCORES will be passed as post_filter does not contribute in scoring. COMPLETE_NO_SCORES means it is not
// a scoring clause
// Boost factor 1f is taken because if boost is multiplicative of 1 then it means "no boost"
// Previously this code in OpenSearch looked like
// https://github.com/opensearch-project/OpenSearch/commit/36a5cf8f35e5cbaa1ff857b5a5db8c02edc1a187
filteringWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, boost_factor);
}

return searchContext.shouldUseConcurrentSearch()
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort()
searchContext.sort(),
filteringWeight
)
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort()
searchContext.sort(),
filteringWeight
);
}

@Override
public Collector newCollector() {
Collector hybridcollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker);
return hybridcollector;
// Check if filterWeight is present. If it is present then return wrap Hybrid collector object underneath the FilteredCollector
// object and return it.
return filterWeight == null ? hybridcollector : new FilteredCollector(hybridcollector, filterWeight);
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down Expand Up @@ -108,7 +134,10 @@
}
} else if (collector instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector);
}
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());

Check warning on line 139 in src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java#L139

Added line #L139 was not covered by tests
}
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
Expand Down Expand Up @@ -216,9 +245,10 @@
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats
SortAndFormats sortAndFormats,
Weight filteringWeight
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should make this parameter as Optional optionalFilterWeight.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If i do that then in newCollector method

 if (optionalFilterWeight.isEmpty()) {
            return hybridcollector;
        }

This code breaks with below error

{
    "error": {
        "root_cause": [
            {
                "type": "null_pointer_exception",
                "reason": "Cannot invoke \"java.util.Optional.isEmpty()\" because \"this.optionalFilterWeight\" is null"
            }
        ],
        "type": "search_phase_execution_exception",
        "reason": "all shards failed",
        "phase": "query",
        "grouped": true,
        "failed_shards": [
            {
                "shard": 0,
                "index": "movies",
                "node": "MpUK-BFmRRiigVIII84uIQ",
                "reason": {
                    "type": "null_pointer_exception",
                    "reason": "Cannot invoke \"java.util.Optional.isEmpty()\" because \"this.optionalFilterWeight\" is null"
                }
            }
        ],
        "caused_by": {
            "type": "null_pointer_exception",
            "reason": "Cannot invoke \"java.util.Optional.isEmpty()\" because \"this.optionalFilterWeight\" is null",
            "caused_by": {
                "type": "null_pointer_exception",
                "reason": "Cannot invoke \"java.util.Optional.isEmpty()\" because \"this.optionalFilterWeight\" is null"
            }
        }
    },
    "status": 500
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vibrantvarun did you construct optional with ofNullable method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. @navneet1v and me hopped on call and determined that it is kind of potiential bug and not the right practice to use Optional in class variables.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wonder what is the potential bug, it's not a blocker for this PR but using optional definitely isn't a bug

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@navneet1v would you like to answer martin question

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was not a bug. The implementation was not correct due to which above exception happened.

but overall when we used optional on class member the IDE was adding warnings. Hence I suggested we can skip that. It was more of a coding practice and not a serious issue

) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats);
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
}

Expand All @@ -245,9 +275,10 @@
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats);
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import lombok.SneakyThrows;

import org.junit.Before;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
Expand Down Expand Up @@ -46,6 +47,7 @@ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT {
"test-neural-aggs-pipeline-multi-doc-index-multiple-shards";
private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-neural-aggs-multi-doc-index-single-shard";
private static final String TEST_QUERY_TEXT3 = "hello";
private static final String TEST_QUERY_TEXT4 = "everyone";
private static final String TEST_QUERY_TEXT5 = "welcome";
private static final String TEST_DOC_TEXT1 = "Hello world";
private static final String TEST_DOC_TEXT2 = "Hi to this place";
Expand Down Expand Up @@ -182,6 +184,204 @@ public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSu
}
}

@SneakyThrows
public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", false);
testPostFilterWithSimpleHybridQuery(false, true);
testPostFilterWithComplexHybridQuery(false, true);
}

@SneakyThrows
public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
testPostFilterWithSimpleHybridQuery(false, true);
testPostFilterWithComplexHybridQuery(false, true);
}

@SneakyThrows
private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean hasPostFilterQuery) {
try {
if (isSingleShard) {
prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE);
} else {
prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE);
}

HybridQueryBuilder simpleHybridQueryBuilder = createHybridQueryBuilder(false);

QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000);

Map<String, Object> searchResponseAsMap;

if (isSingleShard && hasPostFilterQuery) {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD,
simpleHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
rangeFilterQuery
);

assertHitResultsFromQuery(1, searchResponseAsMap);
} else if (isSingleShard && !hasPostFilterQuery) {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD,
simpleHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
null
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else if (!isSingleShard && hasPostFilterQuery) {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
simpleHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
rangeFilterQuery
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
simpleHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
null
);
assertHitResultsFromQuery(3, searchResponseAsMap);
}

// assert post-filter
List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);

List<Integer> docIndexes = new ArrayList<>();
for (Map<String, Object> oneHit : hitsNestedList) {
assertNotNull(oneHit.get("_source"));
Map<String, Object> source = (Map<String, Object>) oneHit.get("_source");
int docIndex = (int) source.get(INTEGER_FIELD_1);
docIndexes.add(docIndex);
}
if (isSingleShard && hasPostFilterQuery) {
assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());

} else if (isSingleShard && !hasPostFilterQuery) {
assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());

} else if (!isSingleShard && hasPostFilterQuery) {
assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());
} else {
assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());
}
} finally {
if (isSingleShard) {
wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE);
} else {
wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
}
}
}

@SneakyThrows
private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean hasPostFilterQuery) {
try {
if (isSingleShard) {
prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE);
} else {
prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE);
}

HybridQueryBuilder complexHybridQueryBuilder = createHybridQueryBuilder(true);

QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000);

Map<String, Object> searchResponseAsMap;

if (isSingleShard && hasPostFilterQuery) {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD,
complexHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
rangeFilterQuery
);

assertHitResultsFromQuery(1, searchResponseAsMap);
} else if (isSingleShard && !hasPostFilterQuery) {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD,
complexHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
null
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else if (!isSingleShard && hasPostFilterQuery) {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
complexHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
rangeFilterQuery
);
assertHitResultsFromQuery(4, searchResponseAsMap);
} else {
searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS,
complexHybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
null
);
assertHitResultsFromQuery(3, searchResponseAsMap);
}

// assert post-filter
List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);

List<Integer> docIndexes = new ArrayList<>();
for (Map<String, Object> oneHit : hitsNestedList) {
assertNotNull(oneHit.get("_source"));
Map<String, Object> source = (Map<String, Object>) oneHit.get("_source");
int docIndex = (int) source.get(INTEGER_FIELD_1);
docIndexes.add(docIndex);
}
if (isSingleShard && hasPostFilterQuery) {
assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());

} else if (isSingleShard && !hasPostFilterQuery) {
assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());

} else if (!isSingleShard && hasPostFilterQuery) {
assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());
} else {
assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count());
}
} finally {
if (isSingleShard) {
wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE);
} else {
wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
}
}
}

vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
@SneakyThrows
private void testAvgSumMinMaxAggs() {
try {
Expand Down Expand Up @@ -227,6 +427,20 @@ private void testAvgSumMinMaxAggs() {
}
}

@SneakyThrows
public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", false);
testPostFilterWithSimpleHybridQuery(true, true);
testPostFilterWithComplexHybridQuery(true, true);
}

@SneakyThrows
public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings("search.concurrent_segment_search.enabled", true);
testPostFilterWithSimpleHybridQuery(true, true);
testPostFilterWithComplexHybridQuery(true, true);
}

private void testMaxAggsOnSingleShardCluster() throws Exception {
try {
prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE);
Expand Down Expand Up @@ -594,4 +808,29 @@ private void assertHitResultsFromQuery(int expected, Map<String, Object> searchR
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

private HybridQueryBuilder createHybridQueryBuilder(boolean isComplex) {
if (isComplex) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should().add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000);

QueryBuilder matchQuery = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(boolQueryBuilder).add(rangeFilterQuery).add(matchQuery);
return hybridQueryBuilder;

} else {
TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);

HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1);
hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder3);
return hybridQueryBuilderNeuralThenTerm;
}
}

}
Loading
Loading