Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Integration Tests and Unit test for Efficient Filtering for Faiss Engine #934

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
}
docIdsToScoreMap.putAll(annResults);
}
if (docIdsToScoreMap.isEmpty()) {
return KNNScorer.emptyScorer(this);
}
return convertSearchResponseToScorer(docIdsToScoreMap);
}

Expand All @@ -134,12 +137,7 @@ private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight f
return new FixedBitSet(0);
}

final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
// TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
// distance for K vectors. This can avoid calls to native layer and save some latency.
final int cost = acceptDocs.cardinality();
log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost);
return acceptDocs;
return createBitSet(scorer.iterator(), liveDocs, maxDoc);
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
Expand All @@ -165,7 +163,7 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (true) {
while (docId < filteredDocsBitSet.length()) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
Expand Down Expand Up @@ -291,23 +289,22 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.name);
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
//Creating min heap and init with MAX DocID and Score as -INF.
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
for (int filterId : filterIdsArray) {
int docId = values.advance(filterId);
final BytesRef value = values.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset,
value.length);
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
if(score > topDoc.score) {
if (score > topDoc.score) {
topDoc.score = score;
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
Expand All @@ -329,7 +326,7 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont

return docToScore;
} catch (Exception e) {
log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery);
log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery, e);
}
return Collections.emptyMap();
}
Expand Down
99 changes: 99 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
package org.opensearch.knn.index;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import org.apache.http.util.EntityUtils;
import org.junit.BeforeClass;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -26,6 +30,7 @@
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.plugin.script.KNNScoringUtil;
import org.opensearch.rest.RestStatus;

import java.io.IOException;
import java.net.URL;
Expand All @@ -44,6 +49,14 @@

public class FaissIT extends KNNRestTestCase {

private static final String INDEX_NAME = "test-index-1";
private static final String FIELD_NAME = "test-field-1";
private static final String DOC_ID_1 = "doc1";
private static final String DOC_ID_2 = "doc2";
private static final String DOC_ID_3 = "doc3";
private static final String COLOR_FIELD_NAME = "color";
private static final String TASTE_FIELD_NAME = "taste";

static TestUtils.TestData testData;

@BeforeClass
Expand Down Expand Up @@ -280,4 +293,90 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}
}

@SneakyThrows
public void testQueryWithFilter_withDifferentCombination_thenSuccess() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add different tests for exact and ann searchers? I on a fence if we need it in IT as we already have unit test for this and it might be too low level for IT.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we should have integration tests. Even when we have Unit tests. The reason why I am combining both is, the setup for the tests are same. Hence I combined them at 1 place. Given that we already have so many integration tests and setting up each integration tests add more time. Hence, using a single setup.

setupKNNIndexForFilterQuery();
final float[] searchVector = { 6.0f, 6.0f, 4.1f };
// K > filteredResults
int kGreaterThanFilterResult = 5;
List<String> expectedDocIds = Arrays.asList(DOC_ID_1, DOC_ID_3);
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kGreaterThanFilterResult
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(expectedDocIds.size(), knnResults.size());
assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds));

// K Limits Filter results
int kLimitsFilterResult = 1;
List<String> expectedDocIdsKLimitsFilterResult = List.of(DOC_ID_1);
final Response responseKLimitsFilterResult = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kLimitsFilterResult
);
final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity());
final List<KNNResult> knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME);

assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
assertTrue(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dup of 313 - should we just make this its own statement?

knnResultsKLimitsFilterResult.stream()
.map(KNNResult::getDocId)
.collect(Collectors.toList())
.containsAll(expectedDocIdsKLimitsFilterResult)
);

// Empty filter docIds
int k = 10;
final Response emptyFilterResponse = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(
FIELD_NAME,
searchVector,
kLimitsFilterResult,
QueryBuilders.termQuery(COLOR_FIELD_NAME, "color_not_present")
),
k
);
final String responseBodyForEmptyDocIds = EntityUtils.toString(emptyFilterResponse.getEntity());
final List<KNNResult> emptyKNNFilteredResultsFromResponse = parseSearchResponse(responseBodyForEmptyDocIds, FIELD_NAME);

assertEquals(0, emptyKNNFilteredResultsFromResponse.size());
}

protected void setupKNNIndexForFilterQuery() throws Exception {
// Create Mappings
XContentBuilder builder = XContentFactory.jsonBuilder()
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", 3)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2)
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.endObject()
.endObject()
.endObject()
.endObject();
final String mapping = Strings.toString(builder);

createKnnIndex(INDEX_NAME, mapping);

addKnnDocWithAttributes(
DOC_ID_1,
new float[] { 6.0f, 7.9f, 3.1f },
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can generate vector values using random, seems we're not testing exact doc ids in result. This way it's easy to generate more docs and then vary between ann and exact searches

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use random generator. But I don't see any point of using random generator. Atleast this way we know what we are ingesting and in future we can compare scores also.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can have both exact vector values and random generated, first can be used in case we're checking score values and second one in cases when we do need more massive corpus. I think we doing random generation for other test in k-NN. Although not a blocker for this PR.

ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet")
);
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

refreshIndex(INDEX_NAME);
}
}
16 changes: 0 additions & 16 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -361,22 +361,6 @@ public void testIndexReopening() throws Exception {
assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray());
}

private void addKnnDocWithAttributes(String docId, float[] vector, Map<String, String> fieldValues) throws IOException {
Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true");

XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector);
for (String fieldName : fieldValues.keySet()) {
builder.field(fieldName, fieldValues.get(fieldName));
}
builder.endObject();
request.setJsonEntity(Strings.toString(builder));
client().performRequest(request);

request = new Request("POST", "/" + INDEX_NAME + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.mockito.Mockito;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -23,6 +26,10 @@
import static org.mockito.Mockito.when;

public class KNNQueryFactoryTests extends KNNTestCase {
private static final String FILTER_FILED_NAME = "foo";
private static final String FILTER_FILED_VALUE = "fooval";
private static final QueryBuilder FILTER_QUERY_BUILDER = new TermQueryBuilder(FILTER_FILED_NAME, FILTER_FILED_VALUE);
private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE));
private final int testQueryDimension = 17;
private final float[] testQueryVector = new float[testQueryDimension];
private final String testIndexName = "test-index";
Expand Down Expand Up @@ -59,18 +66,42 @@ public void testCreateLuceneQueryWithFilter() {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
MappedFieldType testMapper = mock(MappedFieldType.class);
when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
QueryBuilder filter = new TermQueryBuilder("foo", "fooval");
final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
.fieldName(testFieldName)
.vector(testQueryVector)
.k(testK)
.context(mockQueryShardContext)
.filter(filter)
.filter(FILTER_QUERY_BUILDER)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
}
}

public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() {
final KNNEngine knnEngine = KNNEngine.FAISS;
final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
MappedFieldType testMapper = mock(MappedFieldType.class);
when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY);
final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
.fieldName(testFieldName)
.vector(testQueryVector)
.k(testK)
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
final Query query = KNNQueryFactory.create(createQueryRequest);
assertTrue(query instanceof KNNQuery);

assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
assertEquals(testFieldName, ((KNNQuery) query).getField());
assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector());
assertEquals(testK, ((KNNQuery) query).getK());
assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery());
}
}
Loading