Commit
Rename context class, adjust lucene IT
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Oct 24, 2022
1 parent 52e2b6b commit 60a27bf
Showing 6 changed files with 45 additions and 46 deletions.
@@ -5,6 +5,7 @@

package org.opensearch.knn.bwc;

import org.hamcrest.MatcherAssert;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
@@ -19,7 +20,11 @@

import java.io.IOException;

import static org.hamcrest.CoreMatchers.anyOf;
import static org.hamcrest.CoreMatchers.containsString;
import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

/**
* Tests scenarios specific to filtering functionality in k-NN when Lucene is set as the engine
@@ -36,7 +41,11 @@ public void testLuceneFiltering() throws Exception {
float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0];
switch (getClusterType()) {
case OLD:
createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMappingWithLuceneField(TEST_FIELD, DIMENSIONS));
createKnnIndex(
testIndex,
getKNNDefaultIndexSettings(),
createKnnIndexMapping(TEST_FIELD, DIMENSIONS, METHOD_HNSW, LUCENE_NAME)
);
bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS);
validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K);
break;
@@ -50,25 +59,6 @@
}
}

protected String createKnnIndexMappingWithLuceneField(final String fieldName, int dimension) throws IOException {
return Strings.toString(
XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field("dimension", Integer.toString(dimension))
.startObject("method")
.field("name", "hnsw")
.field("engine", "lucene")
.field("space_type", "l2")
.endObject()
.endObject()
.endObject()
.endObject()
);
}

private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query");
knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -81,6 +71,15 @@ private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQuery
request.addParameter("search_type", "query_then_fetch");
request.setJsonEntity(Strings.toString(builder));

expectThrows(ResponseException.class, () -> client().performRequest(request));
Exception exception = expectThrows(ResponseException.class, () -> client().performRequest(request));
// Assert on two possible exception messages: the first one can come from a current-version node
// when the serialized request is coming from a lower-version node; the second is the reverse,
// when a lower-version node receives a request with the filter field from a higher-version node.
MatcherAssert.assertThat(
exception.getMessage(),
anyOf(
containsString("filter field is supported from version"),
containsString("[knn] unknown token [START_OBJECT] after [filter]")
)
);
}
}
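
The replacement call above uses the shared createKnnIndexMapping(fieldName, dimension, method, engine) helper (presumably provided by the common test base class) instead of the removed Lucene-specific builder. A minimal sketch of what such a generalized helper could look like, assuming it simply parameterizes the method name and engine of the old mapping; the actual base-class implementation may differ:

protected String createKnnIndexMapping(final String fieldName, final int dimension, final String method, final String engine)
    throws IOException {
    // Same knn_vector mapping as the removed helper, with the method name and engine passed in
    // instead of the hard-coded "hnsw"/"lucene" values.
    return Strings.toString(
        XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject(fieldName)
            .field("type", "knn_vector")
            .field("dimension", Integer.toString(dimension))
            .startObject("method")
            .field("name", method)
            .field("engine", engine)
            .field("space_type", "l2")
            .endObject()
            .endObject()
            .endObject()
            .endObject()
    );
}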
@@ -16,18 +16,18 @@
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class KNNClusterContext {
public class KNNClusterUtil {

private ClusterService clusterService;
private static KNNClusterContext instance;
private static KNNClusterUtil instance;

/**
* Return instance of the cluster util, must be initialized first for proper usage
* @return instance of KNNClusterUtil
*/
public static synchronized KNNClusterContext instance() {
public static synchronized KNNClusterUtil instance() {
if (instance == null) {
instance = new KNNClusterContext();
instance = new KNNClusterUtil();
}
return instance;
}
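
Only the class and field names change in this hunk. For context, the behaviour exercised by the new KNNClusterUtilTests further down suggests the remaining methods of the renamed singleton look roughly like the sketch below (an assumption, not part of this diff): getClusterMinVersion() reads the minimum node version from the cluster state and falls back to Version.CURRENT when the state cannot be read.

public void initialize(final ClusterService clusterService) {
    this.clusterService = clusterService;
}

public Version getClusterMinVersion() {
    try {
        // Minimum version across all nodes currently in the cluster.
        return this.clusterService.state().getNodes().getMinNodeVersion();
    } catch (Exception exception) {
        // Cluster state not available; fall back to the local node version,
        // as asserted in testWhenErrorOnClusterStateDiscover.
        return Version.CURRENT;
    }
}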
@@ -9,7 +9,7 @@
import org.opensearch.Version;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNClusterContext;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
@@ -322,6 +322,6 @@ public String getWriteableName() {
}

private static boolean isClusterOnOrAfterMinRequiredVersion() {
return KNNClusterContext.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
}
}
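
The renamed KNNClusterUtil backs the version gate above. A hedged sketch of how such a gate is typically applied when the optional filter clause is handled, matching the "filter field is supported from version" message asserted in the BWC IT; the exact check in KNNQueryBuilder is not shown in this diff, and filter is an assumed local name for the parsed filter QueryBuilder:

// Illustration only: reject the filter when any node in the cluster is older than the
// first version that understands it, so mixed-version clusters fail fast with a clear message.
if (filter != null && !isClusterOnOrAfterMinRequiredVersion()) {
    throw new IllegalArgumentException(
        String.format(
            Locale.ROOT,
            "filter field is supported from version %s",
            MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
        )
    );
}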
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
@@ -11,7 +11,7 @@
import org.opensearch.index.codec.CodecServiceFactory;
import org.opensearch.index.engine.EngineFactory;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNClusterContext;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
@@ -180,7 +180,7 @@ public Collection<Object> createComponents(
NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader);

KNNSettings.state().initialize(client, clusterService);
KNNClusterContext.instance().initialize(clusterService);
KNNClusterUtil.instance().initialize(clusterService);
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
@@ -13,26 +13,26 @@
import static org.mockito.Mockito.when;
import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService;

public class KNNClusterContextTests extends KNNTestCase {
public class KNNClusterUtilTests extends KNNTestCase {

public void testSingleNodeCluster() {
ClusterService clusterService = mockClusterService(Version.V_2_4_0);

final KNNClusterContext knnClusterContext = KNNClusterContext.instance();
knnClusterContext.initialize(clusterService);
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

final Version minVersion = knnClusterContext.getClusterMinVersion();
final Version minVersion = knnClusterUtil.getClusterMinVersion();

assertTrue(Version.V_2_4_0.equals(minVersion));
}

public void testMultipleNodesCluster() {
ClusterService clusterService = mockClusterService(Version.V_2_3_0);

final KNNClusterContext knnClusterContext = KNNClusterContext.instance();
knnClusterContext.initialize(clusterService);
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

final Version minVersion = knnClusterContext.getClusterMinVersion();
final Version minVersion = knnClusterUtil.getClusterMinVersion();

assertTrue(Version.V_2_3_0.equals(minVersion));
}
@@ -41,10 +41,10 @@ public void testWhenErrorOnClusterStateDiscover() {
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready"));

final KNNClusterContext knnClusterContext = KNNClusterContext.instance();
knnClusterContext.initialize(clusterService);
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

final Version minVersion = knnClusterContext.getClusterMinVersion();
final Version minVersion = knnClusterUtil.getClusterMinVersion();

assertTrue(Version.CURRENT.equals(minVersion));
}
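
These tests rely on the mockClusterService(Version) helper from KNNClusterTestUtils. A minimal sketch of such a helper, assuming it stubs ClusterService, ClusterState and DiscoveryNodes (org.opensearch.cluster.ClusterState, org.opensearch.cluster.node.DiscoveryNodes) so that the given version is reported as the minimum node version; the real helper may be constructed differently:

static ClusterService mockClusterService(final Version version) {
    // Each mock in the chain returns the next one, so
    // clusterService.state().getNodes().getMinNodeVersion() yields the requested version.
    ClusterService clusterService = mock(ClusterService.class);
    ClusterState clusterState = mock(ClusterState.class);
    DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class);
    when(discoveryNodes.getMinNodeVersion()).thenReturn(version);
    when(clusterState.getNodes()).thenReturn(discoveryNodes);
    when(clusterService.state()).thenReturn(clusterState);
    return clusterService;
}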
@@ -26,7 +26,7 @@
import org.opensearch.index.Index;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNClusterContext;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
@@ -105,8 +105,8 @@ public void testFromXcontent() throws Exception {
public void testFromXcontent_WithFilter() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterContext knnClusterContext = KNNClusterContext.instance();
knnClusterContext.initialize(clusterService);
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
@@ -127,8 +127,8 @@ public void testFromXcontent_WithFilter() throws Exception {
public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Exception {
final ClusterService clusterService = mockClusterService(Version.V_2_3_0);

final KNNClusterContext knnClusterContext = KNNClusterContext.instance();
knnClusterContext.initialize(clusterService);
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
@@ -268,8 +268,8 @@ private void assertSerialization(final Version version, final Optional<QueryBuil

final ClusterService clusterService = mockClusterService(version);

final KNNClusterContext knnClusterContext = KNNClusterContext.instance();
knnClusterContext.initialize(clusterService);
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(version);
output.writeNamedWriteable(knnQueryBuilder);
