Skip to content

Commit

Permalink
Added Integration Tests and Unit test for Efficient Filtering for Fai…
Browse files Browse the repository at this point in the history
…ss Engine

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Jun 13, 2023
1 parent f5ff953 commit 1bde71e
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 15 deletions.
23 changes: 10 additions & 13 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
}
docIdsToScoreMap.putAll(annResults);
}
if (docIdsToScoreMap.isEmpty()) {
return KNNScorer.emptyScorer(this);
}
return convertSearchResponseToScorer(docIdsToScoreMap);
}

Expand All @@ -134,12 +137,7 @@ private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight f
return new FixedBitSet(0);
}

final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
// TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the
// distance for K vectors. This can avoid calls to native layer and save some latency.
final int cost = acceptDocs.cardinality();
log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost);
return acceptDocs;
return createBitSet(scorer.iterator(), liveDocs, maxDoc);
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
Expand All @@ -165,7 +163,7 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (true) {
while (docId < filteredDocsBitSet.length()) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
Expand Down Expand Up @@ -291,23 +289,22 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.name);
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
//Creating min heap and init with MAX DocID and Score as -INF.
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
for (int filterId : filterIdsArray) {
int docId = values.advance(filterId);
final BytesRef value = values.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset,
value.length);
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
if(score > topDoc.score) {
if (score > topDoc.score) {
topDoc.score = score;
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
Expand All @@ -329,7 +326,7 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont

return docToScore;
} catch (Exception e) {
log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery);
log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery, e);
}
return Collections.emptyMap();
}
Expand Down
115 changes: 115 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
package org.opensearch.knn.index;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import org.apache.http.util.EntityUtils;
import org.junit.BeforeClass;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -26,6 +30,7 @@
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.plugin.script.KNNScoringUtil;
import org.opensearch.rest.RestStatus;

import java.io.IOException;
import java.net.URL;
Expand All @@ -44,6 +49,14 @@

public class FaissIT extends KNNRestTestCase {

private static final String INDEX_NAME = "test-index-1";
private static final String FIELD_NAME = "test-field-1";
private static final String DOC_ID = "doc1";
private static final String DOC_ID_2 = "doc2";
private static final String DOC_ID_3 = "doc3";
private static final String COLOR_FIELD_NAME = "color";
private static final String TASTE_FIELD_NAME = "taste";

static TestUtils.TestData testData;

@BeforeClass
Expand Down Expand Up @@ -280,4 +293,106 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}
}

@SneakyThrows
public void testQueryWithFilter_withDifferentCombination_thenSuccess() {
setupKNNIndexForFilterQuery();
final float[] searchVector = { 6.0f, 6.0f, 4.1f };
// K > filteredResults
int kGreaterThanFilterResult = 5;
List<String> expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3);
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kGreaterThanFilterResult
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(expectedDocIds.size(), knnResults.size());
assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds));

// K Limits Filter results
int kLimitsFilterResult = 1;
List<String> expectedDocIdsKLimitsFilterResult = List.of(DOC_ID);
final Response responseKLimitsFilterResult = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kLimitsFilterResult
);
final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity());
final List<KNNResult> knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME);

assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
assertTrue(
knnResultsKLimitsFilterResult.stream()
.map(KNNResult::getDocId)
.collect(Collectors.toList())
.containsAll(expectedDocIdsKLimitsFilterResult)
);

// Empty filter docIds
int k = 10;
final Response emptyFilterResponse = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(
FIELD_NAME,
searchVector,
kLimitsFilterResult,
QueryBuilders.termQuery(COLOR_FIELD_NAME, "color_not_present")
),
k
);
final String responseBodyForEmptyDocIds = EntityUtils.toString(emptyFilterResponse.getEntity());
final List<KNNResult> emptyKNNFilteredResultsFromResponse = parseSearchResponse(responseBodyForEmptyDocIds, FIELD_NAME);

assertEquals(0, emptyKNNFilteredResultsFromResponse.size());
}

private void setupKNNIndexForFilterQuery() throws Exception {
// Create Mappings
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", 3)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2)
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.endObject()
.endObject()
.endObject()
.endObject();
final String mapping = Strings.toString(builder);

createKnnIndex(INDEX_NAME, mapping);

addKnnDocWithAttributes(
DOC_ID,
new float[] { 6.0f, 7.9f, 3.1f },
ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet")
);
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

refreshIndex(INDEX_NAME);
}

private void addKnnDocWithAttributes(String docId, float[] vector, Map<String, String> fieldValues) throws IOException {
Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/" + docId + "?refresh=true");

final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(FIELD_NAME, vector);
for (String fieldName : fieldValues.keySet()) {
builder.field(fieldName, fieldValues.get(fieldName));
}
builder.endObject();
request.setJsonEntity(Strings.toString(builder));
client().performRequest(request);

request = new Request("POST", "/" + INDEX_NAME + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.mockito.Mockito;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -23,6 +26,10 @@
import static org.mockito.Mockito.when;

public class KNNQueryFactoryTests extends KNNTestCase {
private static final String FILTER_FILED_NAME = "foo";
private static final String FILTER_FILED_VALUE = "fooval";
private static final QueryBuilder FILTER_QUERY_BUILDER = new TermQueryBuilder(FILTER_FILED_NAME, FILTER_FILED_VALUE);
private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE));
private final int testQueryDimension = 17;
private final float[] testQueryVector = new float[testQueryDimension];
private final String testIndexName = "test-index";
Expand Down Expand Up @@ -59,18 +66,42 @@ public void testCreateLuceneQueryWithFilter() {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
MappedFieldType testMapper = mock(MappedFieldType.class);
when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
QueryBuilder filter = new TermQueryBuilder("foo", "fooval");
final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
.fieldName(testFieldName)
.vector(testQueryVector)
.k(testK)
.context(mockQueryShardContext)
.filter(filter)
.filter(FILTER_QUERY_BUILDER)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
}
}

public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() {
final KNNEngine knnEngine = KNNEngine.FAISS;
final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
MappedFieldType testMapper = mock(MappedFieldType.class);
when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY);
final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
.fieldName(testFieldName)
.vector(testQueryVector)
.k(testK)
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
final Query query = KNNQueryFactory.create(createQueryRequest);
assertTrue(query instanceof KNNQuery);

assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
assertEquals(testFieldName, ((KNNQuery) query).getField());
assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector());
assertEquals(testK, ((KNNQuery) query).getK());
assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery());
}
}
Loading

0 comments on commit 1bde71e

Please sign in to comment.