diff --git a/CHANGELOG.md b/CHANGELOG.md
index b92edd850..fa866ca7e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features
 ### Enhancements
 ### Bug Fixes
+Fix async actions being left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438))
+Fixed exception for the case when a hybrid query is wrapped into a bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
 ### Infrastructure
 ### Documentation
 ### Maintenance
diff --git a/build.gradle b/build.gradle
index 335157549..7220fda23 100644
--- a/build.gradle
+++ b/build.gradle
@@ -153,7 +153,7 @@ dependencies {
     runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
     runtimeOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
     runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
-    runtimeOnly group: 'org.json', name: 'json', version: '20230227'
+    runtimeOnly group: 'org.json', name: 'json', version: '20231013'
 }
 
 // In order to add the jar to the classpath, we need to unzip the
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
index 42f0a56e6..9ee3e9dbb 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
@@ -83,7 +83,7 @@ public DocIdSetIterator iterator() {
      */
    @Override
     public float getMaxScore(int upTo) throws IOException {
-        return subScorers.stream().filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
+        return subScorers.stream().filter(Objects::nonNull).filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
             try {
                 return scorer.getMaxScore(upTo);
             } catch (IOException e) {
diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
index 3e181c73f..86859a054 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
@@ -9,6 +9,7 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Objects;
 import java.util.function.Supplier;
 
 import lombok.AllArgsConstructor;
@@ -84,6 +85,10 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
         this.fieldName = in.readString();
         this.queryText = in.readString();
         this.modelId = in.readString();
+        if (in.readBoolean()) {
+            Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
+            this.queryTokensSupplier = () -> queryTokens;
+        }
     }
 
     @Override
@@ -91,6 +96,12 @@ protected void doWriteTo(StreamOutput out) throws IOException {
         out.writeString(fieldName);
         out.writeString(queryText);
         out.writeString(modelId);
+        if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
+            out.writeBoolean(true);
+            out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
+        } else {
+            out.writeBoolean(false);
+        }
     }
 
     @Override
@@ -256,16 +267,25 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
     @Override
     protected boolean doEquals(NeuralSparseQueryBuilder obj) {
         if (this == obj) return true;
-        if (obj == null || getClass() != obj.getClass()) return false;
+        if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
+        if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false;
+        if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
         EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
             .append(queryText, obj.queryText)
             .append(modelId, obj.modelId);
+        if (!Objects.isNull(queryTokensSupplier)) {
+            equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
+        }
         return equalsBuilder.isEquals();
     }
 
     @Override
     protected int doHashCode() {
-        return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode();
+        HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId);
+        if (!Objects.isNull(queryTokensSupplier)) {
+            builder.append(queryTokensSupplier.get());
+        }
+        return builder.toHashCode();
     }
 
     @Override
diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
index 69791681e..32901cf12 100644
--- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
@@ -11,6 +11,7 @@
 import java.util.Map;
 import java.util.Optional;
 
+import org.opensearch.indices.IndicesService;
 import org.opensearch.ingest.IngestService;
 import org.opensearch.ingest.Processor;
 import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
@@ -66,7 +67,8 @@ public void testProcessors() {
             null,
             mock(IngestService.class),
             null,
-            null
+            null,
+            mock(IndicesService.class)
         );
         Map<String, Processor.Factory> processors = plugin.getProcessors(processorParams);
         assertNotNull(processors);
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index eec6955ff..171d2f4a4 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -22,6 +22,7 @@
 import org.junit.After;
 import org.junit.Before;
 import org.opensearch.index.query.BoolQueryBuilder;
+import org.opensearch.index.query.MatchQueryBuilder;
 import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.knn.index.SpaceType;
@@ -33,6 +34,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
     private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index";
     private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
     private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index";
+    private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index";
     private static final String TEST_QUERY_TEXT = "greetings";
     private static final String TEST_QUERY_TEXT2 = "salute";
     private static final String TEST_QUERY_TEXT3 = "hello";
@@ -188,6 +190,35 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult(
         assertEquals(RELATION_EQUAL_TO, total.get("relation"));
     }
 
+    @SneakyThrows
+    public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() {
+        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
+
+        MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
+        MatchQueryBuilder matchQuery2Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
+        HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder();
+        hybridQueryBuilderOnlyTerm.add(matchQueryBuilder);
+        hybridQueryBuilderOnlyTerm.add(matchQuery2Builder);
+        MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
+        BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder);
+
+        Map<String, Object> searchResponseAsMap = search(
+            TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
+            boolQueryBuilder,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
+        );
+
+        assertTrue(getHitCount(searchResponseAsMap) > 0);
+        assertTrue(getMaxScore(searchResponseAsMap).isPresent());
+        assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f);
+
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertTrue((int) total.get("value") > 0);
+    }
+
     private void initializeIndexIfNotExist(String indexName) throws IOException {
         if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
             prepareKnnIndex(
@@ -242,32 +273,45 @@
                 TEST_MULTI_DOC_INDEX_NAME,
                 Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE))
             );
-            addKnnDoc(
-                TEST_MULTI_DOC_INDEX_NAME,
-                "1",
-                Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-                Collections.singletonList(Floats.asList(testVector1).toArray()),
-                Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
-                Collections.singletonList(TEST_DOC_TEXT1)
-            );
-            addKnnDoc(
-                TEST_MULTI_DOC_INDEX_NAME,
-                "2",
-                Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-                Collections.singletonList(Floats.asList(testVector2).toArray())
-            );
-            addKnnDoc(
-                TEST_MULTI_DOC_INDEX_NAME,
-                "3",
-                Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-                Collections.singletonList(Floats.asList(testVector3).toArray()),
-                Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
-                Collections.singletonList(TEST_DOC_TEXT2)
+            addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME);
+        }
+
+        if (TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD)) {
+            prepareKnnIndex(
+                TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
+                Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
+                1
             );
-            assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME));
+            addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
         }
     }
 
+    private void addDocsToIndex(final String testMultiDocIndexName) {
+        addKnnDoc(
+            testMultiDocIndexName,
+            "1",
+            Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
+            Collections.singletonList(Floats.asList(testVector1).toArray()),
+            Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
+            Collections.singletonList(TEST_DOC_TEXT1)
+        );
+        addKnnDoc(
+            testMultiDocIndexName,
+            "2",
+            Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
+            Collections.singletonList(Floats.asList(testVector2).toArray())
+        );
+        addKnnDoc(
+            testMultiDocIndexName,
+            "3",
+            Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
+            Collections.singletonList(Floats.asList(testVector3).toArray()),
+            Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
+            Collections.singletonList(TEST_DOC_TEXT2)
+        );
+        assertEquals(3, getDocCount(testMultiDocIndexName));
+    }
+
     private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
         Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
         return (List<Map<String, Object>>) hitsMap.get("hits");
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java
index 62ddb64f6..77ca3e64e 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java
@@ -6,7 +6,9 @@
 package org.opensearch.neuralsearch.query;
 
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -21,6 +23,7 @@
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.MatchNoDocsQuery;
+import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.tests.util.TestUtil;
 
@@ -169,6 +172,63 @@ public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenR
         testWithQuery(docs, scores, hybridQueryScorer);
     }
 
+    @SneakyThrows
+    public void testMaxScore_whenMultipleScorers_thenSuccessful() {
+        int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
+        Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
+        int[] docs = docsAndScores.getLeft();
+        float[] scores = docsAndScores.getRight();
+
+        Weight weight = mock(Weight.class);
+
+        HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(
+            weight,
+            Arrays.asList(
+                scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())),
+                scorer(docs, scores, fakeWeight(new MatchNoDocsQuery()))
+            )
+        );
+
+        float maxScore = hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE);
+        assertTrue(maxScore > 0.0f);
+
+        HybridQueryScorer hybridQueryScorerWithSomeNullSubScorers = new HybridQueryScorer(
+            weight,
+            Arrays.asList(null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null)
+        );
+
+        maxScore = hybridQueryScorerWithSomeNullSubScorers.getMaxScore(Integer.MAX_VALUE);
+        assertTrue(maxScore > 0.0f);
+
+        HybridQueryScorer hybridQueryScorerWithAllNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(null, null));
+
+        maxScore = hybridQueryScorerWithAllNullSubScorers.getMaxScore(Integer.MAX_VALUE);
+        assertEquals(0.0f, maxScore, 0.0f);
+    }
+
+    @SneakyThrows
+    public void testMaxScoreFailures_whenScorerThrowsException_thenFail() {
+        int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
+        Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
+        int[] docs = docsAndScores.getLeft();
+        float[] scores = docsAndScores.getRight();
+
+        Weight weight = mock(Weight.class);
+
+        Scorer scorer = mock(Scorer.class);
+        when(scorer.getWeight()).thenReturn(fakeWeight(new MatchAllDocsQuery()));
+        when(scorer.iterator()).thenReturn(iterator(docs));
+        when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception"));
+
+        HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer));
+
+        RuntimeException runtimeException = expectThrows(
+            RuntimeException.class,
+            () -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE)
+        );
+        assertTrue(runtimeException.getMessage().contains("Test exception"));
+    }
+
     private Pair<int[], float[]> generateDocuments(int maxDocId) {
         final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
         final int[] docs = new int[numDocs];
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java
index c876621a2..8656d7f04 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java
@@ -61,11 +61,11 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() {
             List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
         );
         IndexSearcher searcher = newSearcher(reader);
-        Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f);
+        Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);
 
         assertNotNull(weight);
 
-        LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
+        LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
         Scorer scorer = weight.scorer(leafReaderContext);
 
         assertNotNull(scorer);
diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
index 7ff6ca0cb..f3fa3264d 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
@@ -26,6 +26,7 @@
 import lombok.SneakyThrows;
 
 import org.opensearch.client.Client;
+import org.opensearch.common.SetOnce;
 import org.opensearch.common.io.stream.BytesStreamOutput;
 import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.core.action.ActionListener;
@@ -262,6 +263,23 @@ public void testStreams() {
 
         NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
         assertEquals(original, copy);
+
+        SetOnce<Map<String, Float>> queryTokensSetOnce = new SetOnce<>();
+        queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
+        original.queryTokensSupplier(queryTokensSetOnce::get);
+
+        BytesStreamOutput streamOutput2 = new BytesStreamOutput();
+        original.writeTo(streamOutput2);
+
+        filterStreamInput = new NamedWriteableAwareStreamInput(
+            streamOutput2.bytes().streamInput(),
+            new NamedWriteableRegistry(
+                List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
+            )
+        );
+
+        copy = new NeuralSparseQueryBuilder(filterStreamInput);
+        assertEquals(original, copy);
     }
 
     public void testHashAndEquals() {
@@ -275,6 +293,8 @@ public void testHashAndEquals() {
         float boost2 = 3.8f;
         String queryName1 = "query-1";
         String queryName2 = "query-2";
+        Map<String, Float> queryTokens1 = Map.of("hello", 1.0f, "world", 2.0f);
+        Map<String, Float> queryTokens2 = Map.of("hello", 1.0f, "world", 2.2f);
 
         NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
             .queryText(queryText1)
@@ -329,6 +349,22 @@ public void testHashAndEquals() {
             .boost(boost1)
             .queryName(queryName2);
 
+        // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
+        NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
+            .queryText(queryText1)
+            .modelId(modelId1)
+            .boost(boost1)
+            .queryName(queryName1)
+            .queryTokensSupplier(() -> queryTokens1);
+
+        // Identical to sparseEncodingQueryBuilder_baseline except different non-null query tokens supplier
+        NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
+            .queryText(queryText1)
+            .modelId(modelId1)
+            .boost(boost1)
+            .queryName(queryName1)
+            .queryTokensSupplier(() -> queryTokens2);
+
         assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
         assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());
 
@@ -352,6 +388,12 @@ public void testHashAndEquals() {
 
         assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName);
         assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode());
+
+        assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
+        assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());
+
+        assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens);
+        assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode());
     }
 
     @SneakyThrows