Skip to content

Commit

Permalink
Address Navneet's comments
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Dec 28, 2023
1 parent e404775 commit 6801844
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 29 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Multiple identical subqueries in Hybrid query ([#524](https://github.com/opensearch-project/neural-search/pull/524))
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -155,7 +156,19 @@ private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(
SearchHits searchHits = fetchSearchResultOptional.get().hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
if (Objects.isNull(searchHitArray)) {
log.info("array of search hits in fetch phase results is null");
return true;

Check warning on line 161 in src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java#L160-L161

Added lines #L160 - L161 were not covered by tests
}
if (searchHitArray.length != docIds.size()) {
log.info(
String.format(
Locale.ROOT,
"number of documents in fetch results [%d] and query results [%d] is different",
searchHitArray.length,
docIds.size()
)
);
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);
SearchHit[] searchHitArray = getSearchHits(fetchSearchResult);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
Expand Down Expand Up @@ -169,21 +169,9 @@ private void updateOriginalFetchResults(
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
private SearchHit[] getSearchHits(final FetchSearchResult fetchSearchResult) {
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
if (Objects.isNull(searchHitArray)) {
throw new IllegalStateException(
"Score normalization processor cannot produce final query result, for one shard case fetch does not have any results"
);
}
if (searchHitArray.length != docIds.size()) {
throw new IllegalStateException(
"Score normalization processor cannot produce final query result, for one shard case number of fetched documents does not match number of search hits"
);
}
return searchHitArray;
return searchHits.getHits();
}

private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.query;

import static java.util.Locale.ROOT;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -118,12 +120,21 @@ public float[] hybridScores() throws IOException {
}
Query query = scorer.getWeight().getQuery();
List<Integer> indexes = queryToIndex.get(query);
// we need to find the index of first sub-query that hasn't been updated yet
// we need to find the index of first sub-query that hasn't been set yet. Such score will have initial value of "0.0"
int index = indexes.stream()
.mapToInt(idx -> idx)
.filter(index1 -> Float.compare(scores[index1], 0.0f) == 0)
.filter(idx -> Float.compare(scores[idx], 0.0f) == 0)
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect score for subquery"));
.orElseThrow(
() -> new IllegalStateException(
String.format(

Check warning on line 130 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L129-L130

Added lines #L129 - L130 were not covered by tests
ROOT,
"cannot set score for one of hybrid search subquery [%s] and document [%d]",
query.toString(),
scorer.docID()

Check warning on line 134 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L133-L134

Added lines #L133 - L134 were not covered by tests
)
)
);
scores[index] = scorer.score();
}
return scores;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccess() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
Expand Down Expand Up @@ -282,14 +282,12 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

expectThrows(
IllegalStateException.class,
() -> normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
)
normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);
TestUtils.assertQueryResultScores(querySearchResults);
}
}
69 changes: 69 additions & 0 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,75 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() {
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

/**
 * Tests a hybrid query with multiple sub-queries, where some sub-queries are identical.
 * In this test the first and third term queries are the same (same field, same query text),
 * so the equivalent request body is:
 * {
 *     "query": {
 *         "hybrid": {
 *             "queries": [
 *                 {
 *                     "term": {
 *                         "text": "word1"
 *                     }
 *                 },
 *                 {
 *                     "term": {
 *                         "text": "word2"
 *                     }
 *                 },
 *                 {
 *                     "term": {
 *                         "text": "word1"
 *                     }
 *                 }
 *             ]
 *         }
 *     }
 * }
 * Verifies that the query succeeds, returned document ids are unique (no duplicates from the
 * repeated sub-query), and scores are ordered descending.
 */
@SneakyThrows
public void testComplexQuery_whenMultipleIdenticalSubQueries_thenSuccessful() {
initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);

// builders 1 and 3 are intentionally identical — regression test for duplicated sub-queries in hybrid query
TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

HybridQueryBuilder hybridQueryBuilderThreeTerms = new HybridQueryBuilder();
hybridQueryBuilderThreeTerms.add(termQueryBuilder1);
hybridQueryBuilderThreeTerms.add(termQueryBuilder2);
hybridQueryBuilderThreeTerms.add(termQueryBuilder3);

Map<String, Object> searchResponseAsMap1 = search(
TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
hybridQueryBuilderThreeTerms,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

// only 2 docs are expected to match — the duplicate sub-query must not produce extra hits
assertEquals(2, getHitCount(searchResponseAsMap1));

List<Map<String, Object>> hits1NestedList = getNestedHits(searchResponseAsMap1);
List<String> ids = new ArrayList<>();
List<Double> scores = new ArrayList<>();
for (Map<String, Object> oneHit : hits1NestedList) {
ids.add((String) oneHit.get("_id"));
scores.add((Double) oneHit.get("_score"));
}

// verify that scores are in desc order
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());

Map<String, Object> total = getTotalHits(searchResponseAsMap1);
assertNotNull(total.get("value"));
assertEquals(2, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME);
Expand Down

0 comments on commit 6801844

Please sign in to comment.