Skip to content

Commit

Permalink
Fixed LeafReaders casting errors to SegmentReaders when segment repli…
Browse files Browse the repository at this point in the history
…cation is enabled during search

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Jul 10, 2024
1 parent ac56da8 commit efcfc58
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
su `id -un 1000` -c "whoami && java -version && ./gradlew build"
else
echo "avx2 not available on system"
su `id -un 1000` -c "whoami && java -version && ./gradlew build -Dsimd.enabled=false"
su `id -un 1000` -c "whoami && java -version && ./gradlew build -Dsimd.enabled=false -PnumNodes=2"
fi
Expand Down Expand Up @@ -107,7 +107,7 @@ jobs:
./gradlew build
else
echo "avx2 not available on system"
./gradlew build -Dsimd.enabled=false
./gradlew build -Dsimd.enabled=false -PnumNodes=2
fi
Build-k-NN-Windows:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_security.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ jobs:
# switching the user, as OpenSearch cluster can only be started as root/Administrator on linux-deb/linux-rpm/windows-zip.
run: |
chown -R 1000:1000 `pwd`
su `id -un 1000` -c "whoami && java -version && ./gradlew integTest -Dsecurity.enabled=true -Dsimd.enabled=true"
su `id -un 1000` -c "whoami && java -version && ./gradlew integTest -PnumNodes=2 -Dsecurity.enabled=true -Dsimd.enabled=true"
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand Down Expand Up @@ -160,7 +161,7 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
List<EngineFileContext> engineFiles = new ArrayList<>();

for (LeafReaderContext leafReaderContext : indexReader.leaves()) {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
Path shardPath = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory();
String fileExtension = reader.getSegmentInfo().info.getUseCompoundFile()
? knnEngine.getCompoundExtension()
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -28,6 +27,7 @@
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -197,7 +197,7 @@ private int[] bitSetToIntArray(final BitSet bitSet) {

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality)
throws IOException {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
final SegmentReader reader = Lucene.segmentReader(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
Expand Down Expand Up @@ -369,7 +369,7 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont

private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet)
throws IOException {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = getSpaceType(fieldInfo);
Expand Down
92 changes: 92 additions & 0 deletions src/test/java/org/opensearch/knn/index/SegmentReplicationIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.Assert;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;

import java.util.List;

/**
* This IT class contains will contain special cases of IT for segment replication behavior.
* All the index created in this test will have replication type SEGMENT, number of replicas: 1 and should be run on
* at-least 2 node configuration.
*/
@Log4j2
public class SegmentReplicationIT extends KNNRestTestCase {
private static final String INDEX_NAME = "segment-replicated-knn-index";

@SneakyThrows
public void testSearchOnReplicas_whenIndexHasDeletedDocs_thenSuccess() {
if(ensureMinDataNodesCountForSegmentReplication() == false) {
return;
}
createKnnIndex(INDEX_NAME, getKNNSegmentReplicatedIndexSettings(),
createKNNIndexMethodFieldMapping(FIELD_NAME, 2));

Float[] vector = { 1.3f, 2.2f };
int docsInIndex = 10;

for (int i = 0; i < docsInIndex; i++) {
addKnnDoc(INDEX_NAME, Integer.toString(i), FIELD_NAME, vector);
}
refreshIndex(INDEX_NAME);
int deleteDocs = 5;
for (int i = 0; i < deleteDocs; i++) {
deleteKnnDoc(INDEX_NAME, Integer.toString(i));
}
refreshIndex(INDEX_NAME);
// sleep for 5sec to ensure data is replicated. I don't have a better way here to know if segments has been
// replicated.
Thread.sleep(5000);
// validate warmup is successful or not.
doKnnWarmup(List.of(INDEX_NAME));

XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query");
queryBuilder.startObject("knn");
queryBuilder.startObject(FIELD_NAME);
queryBuilder.field("vector", vector);
queryBuilder.field("k", docsInIndex);
queryBuilder.endObject().endObject().endObject().endObject();
// validate replicas are working
Response searchResponse = performSearch(INDEX_NAME, queryBuilder.toString(), "preference=_replica");
String responseBody = EntityUtils.toString(searchResponse.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);
assertEquals(docsInIndex - deleteDocs, knnResults.size());

// validate primaries are working
searchResponse = performSearch(INDEX_NAME, queryBuilder.toString(), "preference=_primary");
responseBody = EntityUtils.toString(searchResponse.getEntity());
knnResults = parseSearchResponse(responseBody, FIELD_NAME);
assertEquals(docsInIndex - deleteDocs, knnResults.size());
}

private boolean ensureMinDataNodesCountForSegmentReplication() {
int dataNodeCount = getDataNodeCount();
if(dataNodeCount <= 1) {
log.warn("Not running segment replication tests named: " +
"testSearchOnReplicas_whenIndexHasDeletedDocs_thenSuccess, as data nodes count is not atleast 2. " +
"Actual datanode count : {}", dataNodeCount);
Assert.assertTrue(true);
// making the test successful because we don't want to break already running tests.
return false;
}
return true;
}
}
32 changes: 31 additions & 1 deletion src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuil
}

protected Response performSearch(final String indexName, final String query) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_search");
return performSearch(indexName, query, "");
}

protected Response performSearch(final String indexName, final String query, final String urlParameters) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_search?" + urlParameters);
request.setJsonEntity(query);

Response response = client().performRequest(request);
Expand Down Expand Up @@ -667,6 +671,32 @@ protected Settings getKNNDefaultIndexSettings() {
return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build();
}

protected Settings getKNNSegmentReplicatedIndexSettings() {
return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 1).put("index.knn", true)
.put("index.replication.type", "SEGMENT").build();
}


@SneakyThrows
protected int getDataNodeCount() {
Request request = new Request("GET", "_nodes/stats?filter_path=nodes.*.roles");

Response response = client().performRequest(request);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
String responseBody = EntityUtils.toString(response.getEntity());

Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> nodesInfo = (Map<String, Object>)responseMap.get("nodes");
int dataNodeCount = 0;
for(String key : nodesInfo.keySet()) {
Map<String,List<String>> nodeRoles = (Map<String,List<String>>)nodesInfo.get(key);
if(nodeRoles.get("roles").contains("data")) {
dataNodeCount++;
}
}
return dataNodeCount;
}

/**
* Get Stats from KNN Plugin
*/
Expand Down

0 comments on commit efcfc58

Please sign in to comment.