Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into feature/text-embedding-pr…
Browse files Browse the repository at this point in the history
…ocessor-nested-fields
  • Loading branch information
Sanjana679 authored Nov 22, 2023
2 parents 0e38fea + b3c73bd commit 2177ba2
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
Fix async actions are left in neural_sparse query ([438](https://github.com/opensearch-project/neural-search/pull/438))
Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)
### Infrastructure
### Documentation
### Maintenance
Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ dependencies {
runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
runtimeOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
runtimeOnly group: 'org.json', name: 'json', version: '20230227'
runtimeOnly group: 'org.json', name: 'json', version: '20231013'
}

// In order to add the jar to the classpath, we need to unzip the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public DocIdSetIterator iterator() {
*/
@Override
public float getMaxScore(int upTo) throws IOException {
return subScorers.stream().filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
return subScorers.stream().filter(Objects::nonNull).filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
try {
return scorer.getMaxScore(upTo);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -84,13 +85,23 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
if (in.readBoolean()) {
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
}
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeString(queryText);
out.writeString(modelId);
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -256,16 +267,25 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
@Override
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false;
if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId);
if (!Objects.isNull(queryTokensSupplier)) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
}

@Override
protected int doHashCode() {
return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode();
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId);
if (!Objects.isNull(queryTokensSupplier)) {
builder.append(queryTokensSupplier.get());
}
return builder.toHashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Map;
import java.util.Optional;

import org.opensearch.indices.IndicesService;
import org.opensearch.ingest.IngestService;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
Expand Down Expand Up @@ -66,7 +67,8 @@ public void testProcessors() {
null,
mock(IngestService.class),
null,
null
null,
mock(IndicesService.class)
);
Map<String, Processor.Factory> processors = plugin.getProcessors(processorParams);
assertNotNull(processors);
Expand Down
88 changes: 66 additions & 22 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.junit.After;
import org.junit.Before;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -33,6 +34,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index";
private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index";
private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index";
private static final String TEST_QUERY_TEXT = "greetings";
private static final String TEST_QUERY_TEXT2 = "salute";
private static final String TEST_QUERY_TEXT3 = "hello";
Expand Down Expand Up @@ -188,6 +190,35 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult(
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);

MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
MatchQueryBuilder matchQuery2Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder();
hybridQueryBuilderOnlyTerm.add(matchQueryBuilder);
hybridQueryBuilderOnlyTerm.add(matchQuery2Builder);
MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder);

Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
boolQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertTrue(getHitCount(searchResponseAsMap) > 0);
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f);

Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertTrue((int) total.get("value") > 0);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
prepareKnnIndex(
Expand Down Expand Up @@ -242,32 +273,45 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
TEST_MULTI_DOC_INDEX_NAME,
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE))
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"1",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector1).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1)
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"2",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector2).toArray())
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"3",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector3).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT2)
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME);
}

if (TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD)) {
prepareKnnIndex(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
1
);
assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME));
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
}
}

private void addDocsToIndex(final String testMultiDocIndexName) {
addKnnDoc(
testMultiDocIndexName,
"1",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector1).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1)
);
addKnnDoc(
testMultiDocIndexName,
"2",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector2).toArray())
);
addKnnDoc(
testMultiDocIndexName,
"3",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector3).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT2)
);
assertEquals(3, getDocCount(testMultiDocIndexName));
}

private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.neuralsearch.query;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -21,6 +23,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.TestUtil;

Expand Down Expand Up @@ -169,6 +172,63 @@ public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenR
testWithQuery(docs, scores, hybridQueryScorer);
}

@SneakyThrows
public void testMaxScore_whenMultipleScorers_thenSuccessful() {
int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
int[] docs = docsAndScores.getLeft();
float[] scores = docsAndScores.getRight();

Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(
weight,
Arrays.asList(
scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())),
scorer(docs, scores, fakeWeight(new MatchNoDocsQuery()))
)
);

float maxScore = hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE);
assertTrue(maxScore > 0.0f);

HybridQueryScorer hybridQueryScorerWithSomeNullSubScorers = new HybridQueryScorer(
weight,
Arrays.asList(null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null)
);

maxScore = hybridQueryScorerWithSomeNullSubScorers.getMaxScore(Integer.MAX_VALUE);
assertTrue(maxScore > 0.0f);

HybridQueryScorer hybridQueryScorerWithAllNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(null, null));

maxScore = hybridQueryScorerWithAllNullSubScorers.getMaxScore(Integer.MAX_VALUE);
assertEquals(0.0f, maxScore, 0.0f);
}

@SneakyThrows
public void testMaxScoreFailures_whenScorerThrowsException_thenFail() {
int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
int[] docs = docsAndScores.getLeft();
float[] scores = docsAndScores.getRight();

Weight weight = mock(Weight.class);

Scorer scorer = mock(Scorer.class);
when(scorer.getWeight()).thenReturn(fakeWeight(new MatchAllDocsQuery()));
when(scorer.iterator()).thenReturn(iterator(docs));
when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception"));

HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer));

RuntimeException runtimeException = expectThrows(
RuntimeException.class,
() -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE)
);
assertTrue(runtimeException.getMessage().contains("Test exception"));
}

private Pair<int[], float[]> generateDocuments(int maxDocId) {
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() {
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
IndexSearcher searcher = newSearcher(reader);
Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f);
Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);

assertNotNull(weight);

LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
Scorer scorer = weight.scorer(leafReaderContext);

assertNotNull(scorer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import lombok.SneakyThrows;

import org.opensearch.client.Client;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -262,6 +263,23 @@ public void testStreams() {

NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);

SetOnce<Map<String, Float>> queryTokensSetOnce = new SetOnce<>();
queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
original.queryTokensSupplier(queryTokensSetOnce::get);

BytesStreamOutput streamOutput2 = new BytesStreamOutput();
original.writeTo(streamOutput2);

filterStreamInput = new NamedWriteableAwareStreamInput(
streamOutput2.bytes().streamInput(),
new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
)
);

copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}

public void testHashAndEquals() {
Expand All @@ -275,6 +293,8 @@ public void testHashAndEquals() {
float boost2 = 3.8f;
String queryName1 = "query-1";
String queryName2 = "query-2";
Map<String, Float> queryTokens1 = Map.of("hello", 1.0f, "world", 2.0f);
Map<String, Float> queryTokens2 = Map.of("hello", 1.0f, "world", 2.2f);

NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
Expand Down Expand Up @@ -329,6 +349,22 @@ public void testHashAndEquals() {
.boost(boost1)
.queryName(queryName2);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens2);

assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());

Expand All @@ -352,6 +388,12 @@ public void testHashAndEquals() {

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode());
}

@SneakyThrows
Expand Down

0 comments on commit 2177ba2

Please sign in to comment.