From cc3bef85c9ae29f1c7392d5f2616f0ff86ae12bf Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 21 Nov 2023 18:37:52 -0800 Subject: [PATCH 1/8] Fixed nested field case Signed-off-by: Martin Gaievski --- CHANGELOG.md | 5 +- .../neuralsearch/query/HybridQuery.java | 5 ++ .../query/HybridQueryPhaseSearcher.java | 73 ++++++++++++++++- .../common/BaseNeuralSearchIT.java | 16 +++- .../neuralsearch/query/HybridQueryIT.java | 80 +++++++++++++++++-- .../query/HybridQueryPhaseSearcherTests.java | 7 +- 6 files changed, 174 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa866ca7e..9c68a924f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes -Fix async actions are left in neural_sparse query ([438](https://github.com/opensearch-project/neural-search/pull/438)) -Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490) +Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438)) +Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)) +Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 2c79c56e5..c96187a65 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -148,6 +148,11 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } + public void addSubQuery(final Query query) { + Objects.requireNonNull(subQueries, "collection of queries must not be null"); + subQueries.add(query); + } + /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index f65e30222..c838a7d8e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -19,12 +19,17 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.SeqNoFieldMapper; +import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; @@ -48,6 +53,8 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { + final static int MAX_NESTED_SUBQUERY_LIMIT = 20; + public 
HybridQueryPhaseSearcher() { super(); } @@ -55,17 +62,79 @@ public HybridQueryPhaseSearcher() { public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, - final Query query, + Query query, final LinkedList collectors, final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (query instanceof HybridQuery) { + if (isHybridQuery(query, searchContext)) { + query = extractHybridQuery(searchContext, query); return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } + validateHybridQuery(query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } + private boolean isHybridQuery(final Query query, final SearchContext searchContext) { + if (query instanceof HybridQuery) { + return true; + } else if (hasNestedFieldOrNestedDocs(query, searchContext) && mightBeWrappedHybridQuery(query)) { + BooleanQuery booleanQuery = (BooleanQuery) query; + return booleanQuery.clauses() + .stream() + .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .allMatch( + clause -> clause.getOccur() == BooleanClause.Occur.FILTER + && clause.getQuery() instanceof FieldExistsQuery + && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()) + ); + } + return false; + } + + private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); + } + + private boolean mightBeWrappedHybridQuery(final Query query) { + return query instanceof BooleanQuery + && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); + } + + private Query extractHybridQuery(final SearchContext searchContext, final Query query) { + if (hasNestedFieldOrNestedDocs(query, searchContext) + && mightBeWrappedHybridQuery(query) + && ((BooleanQuery) query).clauses().size() > 0) { + // extract hybrid query and replace bool with hybrid query + List booleanClauses = ((BooleanQuery) query).clauses(); + return booleanClauses.stream().findFirst().get().getQuery(); + } + return query; + } + + private void validateHybridQuery(final Query query) { + if (query instanceof BooleanQuery) { + List booleanClauses = ((BooleanQuery) query).clauses(); + for (BooleanClause booleanClause : booleanClauses) { + validateNestedBooleanQuery(booleanClause.getQuery(), 1); + } + } + } + + private void validateNestedBooleanQuery(final Query query, int level) { + if (query instanceof HybridQuery) { + throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); + } + if (level >= MAX_NESTED_SUBQUERY_LIMIT) { + throw new IllegalStateException("reached max nested query limit, cannot process query"); + } + if (query instanceof BooleanQuery) { + for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) { + validateNestedBooleanQuery(booleanClause.getQuery(), level + 1); + } + } + } + @VisibleForTesting protected boolean searchWithCollector( final SearchContext searchContext, diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 33cdff9a0..61c6c13df 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ 
b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -523,7 +523,16 @@ protected boolean checkComplete(Map node) { } @SneakyThrows - private String buildIndexConfiguration(List knnFieldConfigs, int numberOfShards) { + protected String buildIndexConfiguration(final List knnFieldConfigs, final int numberOfShards) { + return buildIndexConfiguration(knnFieldConfigs, Collections.emptyList(), numberOfShards); + } + + @SneakyThrows + protected String buildIndexConfiguration( + final List knnFieldConfigs, + final List nestedFields, + final int numberOfShards + ) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject("settings") @@ -544,6 +553,11 @@ private String buildIndexConfiguration(List knnFieldConfigs, int .endObject() .endObject(); } + + for (String nestedField : nestedFields) { + xContentBuilder.startObject(nestedField).field("type", "nested").endObject(); + } + xContentBuilder.endObject().endObject().endObject(); return xContentBuilder.toString(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 171d2f4a4..176628eb8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.query; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; @@ -21,6 +23,7 @@ import org.junit.After; import org.junit.Before; +import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -35,6 +38,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = + "test-neural-multi-doc-nested-type--single-shard-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -191,7 +196,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( } @SneakyThrows - public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() { + public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenFail() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -202,23 +207,71 @@ public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder); + ResponseException exceptionNoNestedTypes = expectThrows( + ResponseException.class, + () -> 
search(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, boolQueryBuilder, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE)) + ); + + org.hamcrest.MatcherAssert.assertThat( + exceptionNoNestedTypes.getMessage(), + allOf( + containsString("hybrid query must be a top level query and cannot be wrapped into other queries"), + containsString("illegal_argument_exception") + ) + ); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + ResponseException exceptionQWithNestedTypes = expectThrows( + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + boolQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exceptionQWithNestedTypes.getMessage(), + allOf( + containsString("hybrid query must be a top level query and cannot be wrapped into other queries"), + containsString("illegal_argument_exception") + ) + ); + } + + @SneakyThrows + public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + TermQueryBuilder termQuery2Builder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(termQueryBuilder); + hybridQueryBuilderOnlyTerm.add(termQuery2Builder); + Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - boolQueryBuilder, + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + hybridQueryBuilderOnlyTerm, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertTrue(getHitCount(searchResponseAsMap) > 0); + assertEquals(0, getHitCount(searchResponseAsMap)); assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f); + assertEquals(0.0f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); Map total = getTotalHits(searchResponseAsMap); assertNotNull(total.get("value")); - assertTrue((int) total.get("value") > 0); + assertEquals(0, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { prepareKnnIndex( @@ -284,6 +337,21 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { ); addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } + + if (TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of("user"), + 1 + ), + "" + ); + + addDocsToIndex(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + } } private void addDocsToIndex(final String testMultiDocIndexName) { diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e9c55cc54..b8dc8fa4a 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ 
b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -41,6 +41,7 @@ import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -204,6 +205,8 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() Query query = termSubQuery.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + MapperService mapperService = mock(MapperService.class); + when(searchContext.mapperService()).thenReturn(mapperService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -217,7 +220,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -265,6 +269,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); From 9a0c2093a080f2fcc2c314bacaf6cae1049dbd23 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 28 Nov 2023 14:09:21 -0800 Subject: [PATCH 2/8] Add unit tests Signed-off-by: Martin Gaievski --- .../neuralsearch/query/HybridQuery.java | 5 - .../query/HybridQueryPhaseSearcher.java | 4 +- .../common/BaseNeuralSearchIT.java | 28 +- .../neuralsearch/query/HybridQueryIT.java | 53 ++- .../query/HybridQueryPhaseSearcherTests.java | 322 +++++++++++++++++- 5 files changed, 398 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index c96187a65..2c79c56e5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -148,11 +148,6 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } - public void addSubQuery(final Query query) { - Objects.requireNonNull(subQueries, "collection of queries must not be null"); - subQueries.add(query); - } - /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index c838a7d8e..7b9224c39 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -71,7 +71,7 @@ public boolean searchWith( query = extractHybridQuery(searchContext, query); return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - validateHybridQuery(query); + validateQuery(query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } @@ -112,7 +112,7 @@ && mightBeWrappedHybridQuery(query) return query; } - private void validateHybridQuery(final Query query) { + private void validateQuery(final Query query) { if (query instanceof BooleanQuery) { List booleanClauses = ((BooleanQuery) query).clauses(); for (BooleanClause booleanClause : booleanClauses) { diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 61c6c13df..e3e57a141 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -413,6 +413,18 @@ protected void addKnnDoc(String index, String docId, List vectorFieldNam addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList()); } + @SneakyThrows + protected void addKnnDoc( + String index, + String docId, + List vectorFieldNames, + List vectors, + List textFieldNames, + List texts + ) { + addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList()); + } + /** * Add a set of knn vectors and text to an index * @@ -422,6 +434,8 @@ protected void addKnnDoc(String index, String docId, List vectorFieldNam * @param vectors List of vectors corresponding to those fields * @param textFieldNames List of text fields to be added * @param texts List of text corresponding to those fields + * @param nestedFieldNames List of nested fields to be added + * @param nestedFields List of fields and values corresponding to those fields */ @SneakyThrows protected void addKnnDoc( @@ -430,7 +444,9 @@ protected void addKnnDoc( List vectorFieldNames, List vectors, List textFieldNames, - List texts + List texts, + List nestedFieldNames, + List> nestedFields ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -441,6 +457,16 @@ protected void addKnnDoc( for (int i = 0; i < textFieldNames.size(); i++) { builder.field(textFieldNames.get(i), texts.get(i)); } + + for (int i = 0; i < nestedFieldNames.size(); i++) { + builder.field(nestedFieldNames.get(i)); + builder.startObject(); + Map nestedValues = nestedFields.get(i); + for (Map.Entry entry : nestedValues.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } builder.endObject(); request.setJsonEntity(builder.toString()); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 176628eb8..4c3ff2cb9 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -7,6 +7,7 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; +import static 
org.opensearch.index.query.QueryBuilders.matchQuery; import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; @@ -21,11 +22,13 @@ import lombok.SneakyThrows; +import org.apache.lucene.search.join.ScoreMode; import org.junit.After; import org.junit.Before; import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.SpaceType; @@ -51,9 +54,14 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user"; private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final String NESTED_FIELD_1 = "firstname"; + private static final String NESTED_FIELD_2 = "lastname"; + private static final String NESTED_FIELD_1_VALUE = "john"; + private static final String NESTED_FIELD_2_VALUE = "black"; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -271,6 +279,39 @@ public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + @SneakyThrows + public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + NestedQueryBuilder nestedQueryBuilder = QueryBuilders.nestedQuery( + TEST_NESTED_TYPE_FIELD_NAME_1, + matchQuery(TEST_NESTED_TYPE_FIELD_NAME_1 + "." 
+ NESTED_FIELD_1, NESTED_FIELD_1_VALUE), + ScoreMode.Total + ); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(termQueryBuilder); + hybridQueryBuilderOnlyTerm.add(nestedQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + hybridQueryBuilderOnlyTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertEquals(1, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(getMaxScore(searchResponseAsMap).get() > 0); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(1, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -344,13 +385,23 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { indexName, buildIndexConfiguration( Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), - List.of("user"), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), 1 ), "" ); addDocsToIndex(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + List.of(), + List.of(), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), + List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE)) + ); } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index b8dc8fa4a..4e00cf0f5 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -5,8 +5,10 @@ package org.opensearch.neuralsearch.search.query; +import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -14,6 +16,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAX_NESTED_SUBQUERY_LIMIT; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; @@ -21,6 +24,7 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import java.util.Set; import lombok.SneakyThrows; @@ -30,6 +34,8 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import 
org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -38,11 +44,13 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -83,7 +91,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -126,6 +135,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -152,7 +162,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -196,6 +207,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.queryResult()).thenReturn(new QuerySearchResult()); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -205,8 +217,6 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() Query query = termSubQuery.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); - MapperService mapperService = mock(MapperService.class); - when(searchContext.mapperService()).thenReturn(mapperService); 
hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -315,7 +325,8 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -365,6 +376,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -409,6 +421,295 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes releaseResources(directory, w, reader); } + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3); + + Query query = boolQueryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("hybrid query must be a top level query and cannot be wrapped into other queries") + ); + + releaseResources(directory, w, reader); + } + + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + MapperService mapperService = createMapperService(mapping(b -> { + b.startObject("field"); + b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject(); + b.endObject(); + b.startObject("user"); + b.field("type", "nested"); + b.endObject(); + })); + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + 
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + when(mockQueryShardContext.getMapperService()).thenReturn(mapperService); + when(mockQueryShardContext.simpleMatchToIndexNames(anyString())).thenReturn(Set.of(TEXT_FIELD_NAME)); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + int docId4 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD) + .add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER); + Query query = builder.build(); + + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertTrue(topDocs.totalHits.value > 0); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + assertNotNull(scoreDocs); + assertTrue(scoreDocs.length > 0); + 
assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); + assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); + List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + + TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); + List expectedIds1 = List.of(docId1); + assertQueryResults(subQueryTopDocs1, expectedIds1, reader); + + TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); + List expectedIds2 = List.of(); + assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + + releaseResources(directory, w, reader); + } + + @SneakyThrows + public void testBoolQuery_whenTooManyNestedLevels_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + Query query = createNestedBoolQuery( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2).toQuery(mockQueryShardContext), + MAX_NESTED_SUBQUERY_LIMIT + 1 + ); + + when(searchContext.query()).thenReturn(query); + + IllegalStateException 
exception = expectThrows( + IllegalStateException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("reached max nested query limit, cannot process query") + ); + + releaseResources(directory, w, reader); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); @@ -452,4 +753,15 @@ private List getSubQueryResultsForSingleShard(final TopDocs topDocs) { } return topDocsList; } + + private BooleanQuery createNestedBoolQuery(final Query query1, final Query query2, int level) { + if (level == 0) { + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(query1, BooleanClause.Occur.SHOULD).add(query2, BooleanClause.Occur.SHOULD); + return builder.build(); + } + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(createNestedBoolQuery(query1, query2, level - 1), BooleanClause.Occur.MUST); + return builder.build(); + } } From b8f794b936577b2f4c23658d99eea6765ad97613 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 28 Nov 2023 17:26:22 -0800 Subject: [PATCH 3/8] Adding comments, minor refactoring Signed-off-by: Martin Gaievski --- .../query/HybridQueryPhaseSearcher.java | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 7b9224c39..6c4568936 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -79,8 +79,31 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte if (query instanceof HybridQuery) { return true; } else if (hasNestedFieldOrNestedDocs(query, searchContext) && mightBeWrappedHybridQuery(query)) { - BooleanQuery booleanQuery = (BooleanQuery) query; - return booleanQuery.clauses() + // checking if this is a hybrid query that is wrapped into a Bool query by core OpenSearch code + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. + // the main reason for that is performance optimization; at time of writing we are ok with losing some performance if that unblocks + // hybrid query for indexes with nested field types. + // in such case we consider the query a valid hybrid query. Later in the code we will extract it and execute it as the main query for + // this search request.
+ // below is sample structure of such query: + // + // Boolean { + // should: { + // hybrid: { + // sub_query1 {} + // sub_query2 {} + // } + // } + // filter: { + // exists: { + // field: "_primary_term" + // } + // } + // } + if (query instanceof BooleanQuery == false) { + return false; + } + return ((BooleanQuery) query).clauses() .stream() .filter(clause -> clause.getQuery() instanceof HybridQuery == false) .allMatch( From 19791bbb65b369058e141de3f015154abfb5b118 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 28 Nov 2023 18:26:52 -0800 Subject: [PATCH 4/8] Addressing review comments Signed-off-by: Martin Gaievski --- .../search/query/HybridQueryPhaseSearcher.java | 8 ++++++-- .../opensearch/neuralsearch/query/HybridQueryIT.java | 10 +++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 6c4568936..1f2939362 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -53,7 +53,7 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { - final static int MAX_NESTED_SUBQUERY_LIMIT = 20; + final static int MAX_NESTED_SUBQUERY_LIMIT = 50; public HybridQueryPhaseSearcher() { super(); @@ -130,7 +130,11 @@ && mightBeWrappedHybridQuery(query) && ((BooleanQuery) query).clauses().size() > 0) { // extract hybrid query and replace bool with hybrid query List booleanClauses = ((BooleanQuery) query).clauses(); - return booleanClauses.stream().findFirst().get().getQuery(); + Query hybridQuery = booleanClauses.stream().findFirst().get().getQuery(); + if (!(hybridQuery instanceof HybridQuery)) { + throw new IllegalStateException("cannot find hybrid type query in expected location"); + } + return hybridQuery; } return query; } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 4c3ff2cb9..4a8f0d065 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -254,7 +254,7 @@ public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenFail() { public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQuery2Builder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2); HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); hybridQueryBuilderOnlyTerm.add(termQueryBuilder); @@ -268,13 +268,13 @@ public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertEquals(0, getHitCount(searchResponseAsMap)); + assertEquals(1, getHitCount(searchResponseAsMap)); assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertEquals(0.0f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.5f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); Map total = getTotalHits(searchResponseAsMap); 
assertNotNull(total.get("value")); - assertEquals(0, total.get("value")); + assertEquals(1, total.get("value")); assertNotNull(total.get("relation")); assertEquals(RELATION_EQUAL_TO, total.get("relation")); } @@ -303,7 +303,7 @@ public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess( assertEquals(1, getHitCount(searchResponseAsMap)); assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(getMaxScore(searchResponseAsMap).get() > 0); + assertEquals(0.5f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); Map total = getTotalHits(searchResponseAsMap); assertNotNull(total.get("value")); From e33de9bcc6ea689e7cd10242fff9a05b3e89abeb Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 29 Nov 2023 10:30:49 -0800 Subject: [PATCH 5/8] Added comments, improved exception handling Signed-off-by: Martin Gaievski --- .../query/HybridQueryPhaseSearcher.java | 49 +++++-- .../query/HybridQueryPhaseSearcherTests.java | 131 +++++++++++++++++- 2 files changed, 163 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 1f2939362..cf0c9c85e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -28,6 +28,8 @@ import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; @@ -53,8 +55,6 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { - final static int MAX_NESTED_SUBQUERY_LIMIT = 50; - public HybridQueryPhaseSearcher() { super(); } @@ -71,7 +71,7 @@ public boolean searchWith( query = extractHybridQuery(searchContext, query); return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - validateQuery(query); + validateQuery(searchContext, query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } @@ -100,17 +100,18 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte // } // } // } + // TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression if (query instanceof BooleanQuery == false) { return false; } return ((BooleanQuery) query).clauses() .stream() .filter(clause -> clause.getQuery() instanceof HybridQuery == false) - .allMatch( - clause -> clause.getOccur() == BooleanClause.Occur.FILTER + .allMatch(clause -> { + return clause.getOccur() == BooleanClause.Occur.FILTER && clause.getQuery() instanceof FieldExistsQuery - && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()) - ); + && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()); + }); } return false; } @@ -130,20 +131,38 @@ && mightBeWrappedHybridQuery(query) && ((BooleanQuery) query).clauses().size() > 0) { // extract hybrid query and replace bool with hybrid query List booleanClauses = ((BooleanQuery) query).clauses(); - Query 
hybridQuery = booleanClauses.stream().findFirst().get().getQuery(); - if (!(hybridQuery instanceof HybridQuery)) { - throw new IllegalStateException("cannot find hybrid type query in expected location"); + if (booleanClauses.isEmpty() || booleanClauses.get(0).getQuery() instanceof HybridQuery == false) { + throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level bool query"); } - return hybridQuery; + return booleanClauses.get(0).getQuery(); } return query; } - private void validateQuery(final Query query) { + /** + * Validate the query from neural-search plugin point of view. Current main goal for validation is to block cases + * when hybrid query is wrapped into other compound queries. + * For example, if we have Bool query like below we need to throw an error + * bool: { + * should: [ + * match: {}, + * hybrid: { + * sub_query1 {} + * sub_query2 {} + * } + * ] + * } + * TODO add similar validation for other compound type queries like dis_max, constant_score etc. + * + * @param query query to validate + */ + private void validateQuery(final SearchContext searchContext, final Query query) { if (query instanceof BooleanQuery) { + Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); + int maxDepthLimit = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); List booleanClauses = ((BooleanQuery) query).clauses(); for (BooleanClause booleanClause : booleanClauses) { - validateNestedBooleanQuery(booleanClause.getQuery(), 1); + validateNestedBooleanQuery(booleanClause.getQuery(), maxDepthLimit); } } } @@ -152,12 +171,12 @@ private void validateNestedBooleanQuery(final Query query, int level) { if (query instanceof HybridQuery) { throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); } - if (level >= MAX_NESTED_SUBQUERY_LIMIT) { + if (level <= 0) { throw new IllegalStateException("reached max nested query limit, cannot process query"); } if (query instanceof BooleanQuery) { for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) { - validateNestedBooleanQuery(booleanClause.getQuery(), level + 1); + validateNestedBooleanQuery(booleanClause.getQuery(), level - 1); } } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 4e00cf0f5..52f81f4d4 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -16,7 +16,6 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAX_NESTED_SUBQUERY_LIMIT; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; @@ -25,6 +24,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Set; +import java.util.UUID; import lombok.SneakyThrows; @@ -44,10 +44,13 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.action.OriginalIndices; +import org.opensearch.cluster.metadata.IndexMetadata; import 
org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.settings.Settings; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoolQueryBuilder; @@ -82,6 +85,8 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final int K = 10; private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); + private static final UUID INDEX_UUID = UUID.randomUUID(); + private static final String TEST_INDEX = "index"; @SneakyThrows public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { @@ -473,6 +478,13 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -509,6 +521,114 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { releaseResources(directory, w, reader); } + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + MapperService mapperService = createMapperService(mapping(b -> { + b.startObject("field"); + b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject(); + b.endObject(); + b.startObject("user"); + b.field("type", "nested"); + b.endObject(); + })); + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER) + .add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD); + Query query = builder.build(); + + when(searchContext.query()).thenReturn(query); + + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("cannot process hybrid query due to incorrect structure of top level bool query") + ); + + releaseResources(directory, w, reader); + } + @SneakyThrows public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_thenSuccess() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); @@ -677,6 +797,13 @@ public void testBoolQuery_whenTooManyNestedLevels_thenFail() { 
when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean();
@@ -685,7 +812,7 @@ public void testBoolQuery_whenTooManyNestedLevels_thenFail() { Query query = createNestedBoolQuery( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2).toQuery(mockQueryShardContext), - MAX_NESTED_SUBQUERY_LIMIT + 1 + (int) (MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.getDefault(null) + 1) ); when(searchContext.query()).thenReturn(query);
From a1f977472c98307e5f71938941037ff45aeb6982 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 29 Nov 2023 15:47:37 -0800 Subject: [PATCH 6/8] Changing exception to log message if we reach max number of nested levels
Signed-off-by: Martin Gaievski --- .../query/HybridQueryPhaseSearcher.java | 15 ++++++++--- .../query/HybridQueryPhaseSearcherTests.java | 27 ++++++++----------- 2 files changed, 22 insertions(+), 20 deletions(-)
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index cf0c9c85e..fd29aec29 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -158,11 +158,9 @@ && mightBeWrappedHybridQuery(query) */ private void validateQuery(final SearchContext searchContext, final Query query) { if (query instanceof BooleanQuery) { - Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); - int maxDepthLimit = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); List booleanClauses = ((BooleanQuery) query).clauses(); for (BooleanClause booleanClause : booleanClauses) { - validateNestedBooleanQuery(booleanClause.getQuery(), maxDepthLimit); + validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext)); } } }
@@ -172,7 +170,11 @@ private void validateNestedBooleanQuery(final Query query, int level) { if (query instanceof HybridQuery) { throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); } if (level <= 0) { - throw new IllegalStateException("reached max nested query limit, cannot process query"); + // ideally we should throw an error here but this code is on the main search workflow path and that might block + // execution of some queries. 
Instead, we silently exit and allow such a query to execute and potentially produce incorrect + // results in case a hybrid query is wrapped into such a bool query + log.error("reached max nested query limit, cannot process bool query with that many nested clauses"); + return; } if (query instanceof BooleanQuery) { for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) {
@@ -324,4 +326,9 @@ private float getMaxScore(final List topDocs) { private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { return sortAndFormats == null ? null : sortAndFormats.formats; } + + private int getMaxDepthLimit(final SearchContext searchContext) { + Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); + return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); + } }
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 52f81f4d4..c4f3f4a3e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
@@ -746,7 +746,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then }
@SneakyThrows - public void testBoolQuery_whenTooManyNestedLevels_thenFail() { + public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex);
@@ -817,22 +817,17 @@ public void testBoolQuery_whenTooManyNestedLevels_thenFail() { when(searchContext.query()).thenReturn(query);
- IllegalStateException exception = expectThrows( - IllegalStateException.class, - () -> hybridQueryPhaseSearcher.searchWith( - searchContext, - contextIndexSearcher, - query, - collectors, - hasFilterCollector, - hasTimeout - ) - ); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
- org.hamcrest.MatcherAssert.assertThat( - exception.getMessage(), - containsString("reached max nested query limit, cannot process quer") - ); + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertTrue(topDocs.totalHits.value > 0); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + assertNotNull(scoreDocs); + assertTrue(scoreDocs.length > 0); + assertFalse(isHybridQueryStartStopElement(scoreDocs[0])); + assertFalse(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); releaseResources(directory, w, reader); }
From dbac80efa585bdf4562f06a5d3b75336ab6faaee Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 30 Nov 2023 15:42:13 -0800 Subject: [PATCH 7/8] Addressing code comments
Signed-off-by: Martin Gaievski --- CHANGELOG.md | 6 +-- .../query/HybridQueryPhaseSearcher.java | 52 +++++++++---------- 2 files changed, 29 insertions(+), 29 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c68a924f..093dbff70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md
@@ -7,9 +7,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes -Fix async actions are left in neural_sparse query 
([#438](https://github.com/opensearch-project/neural-search/pull/438)) -Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)) -Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498)) +- Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438)) +- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)) +- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498)) ### Infrastructure ### Documentation ### Maintenance
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index fd29aec29..ecfab7611 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -78,29 +78,29 @@ public boolean searchWith( private boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; - } else if (hasNestedFieldOrNestedDocs(query, searchContext) && mightBeWrappedHybridQuery(query)) { - // checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. - // main reason for that is performance optimization, at time of writing we are ok with loosing on performance if that's unblocks - // hybrid query for indexes with nested field types. - // in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for - // this search request. - // below is sample structure of such query: - // - // Boolean { - // should: { - // hybrid: { - // sub_query1 {} - // sub_query2 {} - // } - // } - // filter: { - // exists: { - // field: "_primary_term" - // } - // } - // } - // TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression + } else if (hasNestedFieldOrNestedDocs(query, searchContext) && isWrappedHybridQuery(query)) { + /* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code + https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. + main reason for that is performance optimization, at time of writing we are ok with losing on performance if that unblocks + hybrid query for indexes with nested field types. + in such a case we consider the query a valid hybrid query. Later in the code we will extract it and execute it as the main query for + this search request. 
+ below is sample structure of such query: + + Boolean { + should: { + hybrid: { + sub_query1 {} + sub_query2 {} + } + } + filter: { + exists: { + field: "_primary_term" + } + } + } + TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */ if (query instanceof BooleanQuery == false) { return false; } @@ -120,14 +120,14 @@ private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContex return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } - private boolean mightBeWrappedHybridQuery(final Query query) { + private boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } private Query extractHybridQuery(final SearchContext searchContext, final Query query) { if (hasNestedFieldOrNestedDocs(query, searchContext) - && mightBeWrappedHybridQuery(query) + && isWrappedHybridQuery(query) && ((BooleanQuery) query).clauses().size() > 0) { // extract hybrid query and replace bool with hybrid query List booleanClauses = ((BooleanQuery) query).clauses(); @@ -165,7 +165,7 @@ private void validateQuery(final SearchContext searchContext, final Query query) } } - private void validateNestedBooleanQuery(final Query query, int level) { + private void validateNestedBooleanQuery(final Query query, final int level) { if (query instanceof HybridQuery) { throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); } From 379478aa4b507dcb9bb1a09a43436fbfdaae4bcb Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 30 Nov 2023 17:01:45 -0800 Subject: [PATCH 8/8] Restore final for the main method query argument Signed-off-by: Martin Gaievski --- .../search/query/HybridQueryPhaseSearcher.java | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index ecfab7611..26f580364 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -62,14 +62,14 @@ public HybridQueryPhaseSearcher() { public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, - Query query, + final Query query, final LinkedList collectors, final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { if (isHybridQuery(query, searchContext)) { - query = extractHybridQuery(searchContext, query); - return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + Query hybridQuery = extractHybridQuery(searchContext, query); + return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } validateQuery(searchContext, query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); @@ -78,7 +78,7 @@ public boolean searchWith( private boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; - } else if (hasNestedFieldOrNestedDocs(query, searchContext) && isWrappedHybridQuery(query)) { + } else if 
(isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { /* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. main reason for that is performance optimization, at time of writing we are ok with losing on performance if that unblocks
@@ -101,9 +101,7 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte } } TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */ - if (query instanceof BooleanQuery == false) { - return false; - } + // we have already checked that query is an instance of BooleanQuery in the higher level else if condition return ((BooleanQuery) query).clauses() .stream() .filter(clause -> clause.getQuery() instanceof HybridQuery == false)