From 3fc88f30b1717ef4ce4e0b70086ddbbbc18fad3a Mon Sep 17 00:00:00 2001 From: Jack Mazanec Date: Mon, 9 Mar 2020 14:17:41 -0400 Subject: [PATCH] Integration test update (#61) --- build.gradle | 6 +- .../knn/plugin/KNNPlugin.java | 25 +- .../knn/plugin/stats/KNNStatsConfig.java | 47 +++ .../knn/index/BaseKNNIntegTestIT.java | 338 ++++++++++++++---- ...wIndexIT.java => KNN80HnswIndexTests.java} | 5 +- .../knn/index/KNNESIT.java | 206 +++++------ .../knn/index/KNNESSettingsTestIT.java | 110 +----- .../index/{KNNJNIIT.java => KNNJNITests.java} | 6 +- .../knn/index/KNNMapperSearcherIT.java | 240 +++++-------- .../knn/index/KNNResult.java | 34 ++ .../plugin/action/RestKNNStatsHandlerIT.java | 333 +++-------------- 11 files changed, 615 insertions(+), 735 deletions(-) create mode 100644 src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/stats/KNNStatsConfig.java rename src/test/java/com/amazon/opendistroforelasticsearch/knn/index/{KNN80HnswIndexIT.java => KNN80HnswIndexTests.java} (97%) rename src/test/java/com/amazon/opendistroforelasticsearch/knn/index/{KNNJNIIT.java => KNNJNITests.java} (98%) create mode 100644 src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNResult.java diff --git a/build.gradle b/build.gradle index ba7921be..e0ca04a4 100644 --- a/build.gradle +++ b/build.gradle @@ -61,8 +61,11 @@ allprojects { } apply plugin: 'elasticsearch.esplugin' +apply plugin: 'idea' apply plugin: 'jacoco' -apply from: 'build-tools/knnplugin-coverage.gradle' +if (!System.properties.containsKey('tests.rest.cluster') && !System.properties.containsKey('tests.cluster')) { + apply from: 'build-tools/knnplugin-coverage.gradle' +} jacoco { toolVersion = "0.8.3" @@ -108,6 +111,7 @@ es_tmp_dir.mkdirs() test { systemProperty 'tests.security.manager', 'false' + systemProperty "java.library.path", "$rootDir/buildSrc" } integTestRunner { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/KNNPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/KNNPlugin.java index 728e0bb5..07e32922 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/KNNPlugin.java @@ -71,6 +71,7 @@ import java.util.Optional; import java.util.function.Supplier; +import static com.amazon.opendistroforelasticsearch.knn.plugin.stats.KNNStatsConfig.KNN_STATS; import static java.util.Collections.singletonList; /** @@ -127,29 +128,7 @@ public Collection createComponents(Client client, ClusterService cluster KNNIndexCache.setResourceWatcherService(resourceWatcherService); KNNSettings.state().initialize(client, clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); - - Map> stats = ImmutableMap.>builder() - .put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) - .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::missCount))) - .put(StatNames.LOAD_SUCCESS_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::loadSuccessCount))) - .put(StatNames.LOAD_EXCEPTION_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::loadExceptionCount))) - .put(StatNames.TOTAL_LOAD_TIME.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::totalLoadTime))) - .put(StatNames.EVICTION_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::evictionCount))) - .put(StatNames.GRAPH_MEMORY_USAGE.getName(), new KNNStat<>(false, - new KNNCacheSupplier<>(KNNIndexCache::getWeightInKilobytes))) - .put(StatNames.CACHE_CAPACITY_REACHED.getName(), new KNNStat<>(false, - new KNNCacheSupplier<>(KNNIndexCache::isCacheCapacityReached))) - .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, - new KNNCircuitBreakerSupplier())).build(); - - knnStats = new KNNStats(stats); - + knnStats = new KNNStats(KNN_STATS); return ImmutableList.of(knnStats); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/stats/KNNStatsConfig.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/stats/KNNStatsConfig.java new file mode 100644 index 00000000..1ae6b40f --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/stats/KNNStatsConfig.java @@ -0,0 +1,47 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.knn.plugin.stats; + +import com.amazon.opendistroforelasticsearch.knn.index.KNNIndexCache; +import com.amazon.opendistroforelasticsearch.knn.plugin.stats.suppliers.KNNCacheSupplier; +import com.amazon.opendistroforelasticsearch.knn.plugin.stats.suppliers.KNNCircuitBreakerSupplier; +import com.amazon.opendistroforelasticsearch.knn.plugin.stats.suppliers.KNNInnerCacheStatsSupplier; +import com.google.common.cache.CacheStats; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +public class KNNStatsConfig { + public static Map> KNN_STATS = ImmutableMap.>builder() + .put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, + new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) + .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, + new KNNInnerCacheStatsSupplier(CacheStats::missCount))) + .put(StatNames.LOAD_SUCCESS_COUNT.getName(), new KNNStat<>(false, + new KNNInnerCacheStatsSupplier(CacheStats::loadSuccessCount))) + .put(StatNames.LOAD_EXCEPTION_COUNT.getName(), new KNNStat<>(false, + new KNNInnerCacheStatsSupplier(CacheStats::loadExceptionCount))) + .put(StatNames.TOTAL_LOAD_TIME.getName(), new KNNStat<>(false, + new KNNInnerCacheStatsSupplier(CacheStats::totalLoadTime))) + .put(StatNames.EVICTION_COUNT.getName(), new KNNStat<>(false, + new KNNInnerCacheStatsSupplier(CacheStats::evictionCount))) + .put(StatNames.GRAPH_MEMORY_USAGE.getName(), new KNNStat<>(false, + new KNNCacheSupplier<>(KNNIndexCache::getWeightInKilobytes))) + .put(StatNames.CACHE_CAPACITY_REACHED.getName(), new KNNStat<>(false, + new KNNCacheSupplier<>(KNNIndexCache::isCacheCapacityReached))) + .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, + new KNNCircuitBreakerSupplier())).build(); +} \ No newline at end of file diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/BaseKNNIntegTestIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/BaseKNNIntegTestIT.java index cff3ba52..37e85dba 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/BaseKNNIntegTestIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/BaseKNNIntegTestIT.java @@ -15,82 +15,298 @@ package com.amazon.opendistroforelasticsearch.knn.index; -import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.SearchType; -import org.elasticsearch.action.support.WriteRequest; +import com.amazon.opendistroforelasticsearch.knn.plugin.KNNPlugin; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.test.rest.ESRestTestCase; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; -public class BaseKNNIntegTestIT extends ESIntegTestCase { - protected void addKnnDoc(String index, String docId, Object[] vector) throws IOException { - IndexResponse response = client().prepareIndex(index, "_doc", docId) - .setSource(XContentFactory.jsonBuilder() - .startObject() - .array("my_vector", vector) - .field("price", 10) - .endObject()) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get(); - if(!response.status().equals(RestStatus.OK) && !response.status().equals(RestStatus.CREATED)) { - fail("Bad response while adding doc"); - } +/** + * Base class for integration tests for KNN plugin. Contains several methods for testing KNN ES functionality. + */ +public class BaseKNNIntegTestIT extends ESRestTestCase { + public static final String INDEX_NAME = "test_index"; + public static final String FIELD_NAME = "test_field"; + + /** + * Create KNN Index with default settings + */ + protected void createKnnIndex(String index, String mapping) throws IOException { + createIndex(index, getKNNDefaultIndexSettings()); + putMappingRequest(index, mapping); } - protected void addKnnDocWithField(String index, String docId, Object[] vector, String fieldname) throws IOException { - IndexResponse response = client().prepareIndex(index, "_doc", docId) - .setSource(XContentFactory.jsonBuilder() - .startObject() - .array(fieldname, vector) - .field("price", 10) - .endObject()) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get(); - if(!response.status().equals(RestStatus.OK) && !response.status().equals(RestStatus.CREATED)) { - fail("Bad response while adding doc"); - } + /** + * Create KNN Index + */ + protected void createKnnIndex(String index, Settings settings, String mapping) throws IOException { + createIndex(index, settings); + putMappingRequest(index, mapping); } - protected SearchResponse searchKNNIndex(String index, int resultSize, KNNQueryBuilder knnQueryBuilder) { - logger.info("Searching KNN index " + index ); - SearchResponse searchResponse = client().prepareSearch(index) - .setSearchType(SearchType.QUERY_THEN_FETCH) - .setQuery(knnQueryBuilder) // Query - .setSize(resultSize) - .setExplain(true) - .get(); - assertEquals(searchResponse.status(), RestStatus.OK); - return searchResponse; + /** + * Run KNN Search on Index + */ + protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws + IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject().endObject(); + + Request request = new Request( + "POST", + "/" + index + "/_search" + ); + + request.addParameter("size", Integer.toString(resultSize)); + request.addParameter("explain", Boolean.toString(true)); + request.addParameter("search_type", "query_then_fetch"); + request.setJsonEntity(Strings.toString(builder)); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + return response; } - protected void createKnnIndex(String index, Settings settings) { - createIndex(index, settings==null ? indexSettings() : settings); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); - request.source("my_vector", "type=knn_vector,dimension=2"); - ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); + /** + * Parse the response of KNN search into a List of KNNResults + */ + protected List parseSearchResponse(String responseBody, String fieldName) throws IOException { + @SuppressWarnings("unchecked") + List hits = (List) ((Map)createParser(XContentType.JSON.xContent(), + responseBody).map().get("hits")).get("hits"); + + @SuppressWarnings("unchecked") + List knnSearchResponses = hits.stream().map(hit -> { + @SuppressWarnings("unchecked") + Float[] vector = Arrays.stream( + ((ArrayList) ((Map) + ((Map) hit).get("_source")).get(fieldName)).toArray()) + .map(Object::toString) + .map(Float::valueOf) + .toArray(Float[]::new); + return new KNNResult((String) ((Map) hit).get("_id"), vector); + } + ).collect(Collectors.toList()); + + return knnSearchResponses; } - protected Settings createIndexDefaultSettings() { - Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) - .build(); - return settings; + /** + * Delete KNN index + */ + protected void deleteKNNIndex(String index) throws IOException { + Request request = new Request( + "DELETE", + "/" + index + ); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); } - protected void createIndexAndFieldWithDefaultSettings() { - Settings settings = createIndexDefaultSettings(); - String index = "testindex"; - createIndex(index, settings); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); - request.source("my_vector", "type=knn_vector,dimension=4"); - ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); + /** + * For a given index, make a mapping request + */ + protected void putMappingRequest(String index, String mapping) throws IOException { + // Put KNN mapping + Request request = new Request( + "PUT", + "/" + index + "/_mapping" + ); + + request.setJsonEntity(mapping); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Utility to create a Knn Index Mapping + */ + protected String createKnnIndexMapping(String fieldName, Integer dimensions) throws IOException { + return Strings.toString(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimensions.toString()) + .endObject() + .endObject() + .endObject()); + } + + /** + * Force merge KNN index segments + */ + protected void forceMergeKnnIndex(String index) throws Exception { + Request request = new Request( + "POST", + "/" + index + "/_refresh" + ); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + request = new Request( + "POST", + "/" + index + "/_forcemerge" + ); + + request.addParameter("max_num_segments", "1"); + request.addParameter("flush", "true"); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + TimeUnit.SECONDS.sleep(5); // To make sure force merge is completed + } + + /** + * Add a single KNN Doc to an index + */ + protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { + Request request = new Request( + "POST", + "/" + index + "/_doc/" + docId + "?refresh=true" + ); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject() + .field(fieldName, vector) + .endObject(); + + request.setJsonEntity(Strings.toString(builder)); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Update a KNN Doc with a new vector for the given fieldName + */ + protected void updateKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { + Request request = new Request( + "POST", + "/" + index + "/_doc/" + docId + "?refresh=true" + ); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject() + .field(fieldName, vector) + .endObject(); + + request.setJsonEntity(Strings.toString(builder)); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Delete Knn Doc + */ + protected void deleteKnnDoc(String index, String docId) throws IOException { + // Put KNN mapping + Request request = new Request( + "DELETE", + "/" + index + "/_doc/" + docId + "?refresh" + ); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Utility to update settings + */ + protected void updateClusterSettings(String settingKey, Object value) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("persistent") + .field(settingKey, value) + .endObject() + .endObject(); + Request request = new Request("PUT", "_cluster/settings"); + request.setJsonEntity(Strings.toString(builder)); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + /** + * Return default index settings for index creation + */ + protected Settings getKNNDefaultIndexSettings() { + return Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put("index.knn", true) + .build(); } -} + /** + * Get Stats from KNN Plugin + */ + protected Response getKnnStats(List nodeIds, List stats) throws IOException { + String nodePrefix = ""; + if (!nodeIds.isEmpty()) { + nodePrefix = "/" + String.join(",", nodeIds); + } + + String statsSuffix = ""; + if (!stats.isEmpty()) { + statsSuffix = "/" + String.join(",", stats); + } + + Request request = new Request( + "GET", + KNNPlugin.KNN_BASE_URI + nodePrefix + "/stats" + statsSuffix + ); + + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return response; + } + + /** + * Parse KNN Cluster stats from response + */ + protected Map parseClusterStatsResponse(String responseBody) throws IOException { + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + responseMap.remove("cluster_name"); + responseMap.remove("_nodes"); + responseMap.remove("nodes"); + return responseMap; + } + + /** + * Parse KNN Node stats from response + */ + protected List> parseNodeStatsResponse(String responseBody) throws IOException { + @SuppressWarnings("unchecked") + Map responseMap = (Map)createParser(XContentType.JSON.xContent(), + responseBody).map().get("nodes"); + + @SuppressWarnings("unchecked") + List> nodeResponses = responseMap.keySet().stream().map(key -> + (Map) responseMap.get(key) + ).collect(Collectors.toList()); + + return nodeResponses; + } +} \ No newline at end of file diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexTests.java similarity index 97% rename from src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexIT.java rename to src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexTests.java index a8c4c096..a14a24db 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexTests.java @@ -32,15 +32,14 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.watcher.ResourceWatcherService; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; -@ESIntegTestCase.ClusterScope(scope=ESIntegTestCase.Scope.SUITE, numDataNodes=1) -public class KNN80HnswIndexIT extends ESIntegTestCase { +public class KNN80HnswIndexTests extends ESTestCase { public void testFooter() throws Exception { Directory dir = newFSDirectory(createTempDir()); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESIT.java index 115b6bae..c3ee60d2 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESIT.java @@ -15,190 +15,155 @@ package com.amazon.opendistroforelasticsearch.knn.index; -import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.elasticsearch.action.delete.DeleteResponse; -import org.elasticsearch.action.support.WriteRequest; +import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.mapper.MapperParsingException; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.common.xcontent.XContentFactory; + +import java.io.IOException; import static org.hamcrest.Matchers.containsString; -@ESIntegTestCase.ClusterScope(scope=ESIntegTestCase.Scope.SUITE, numDataNodes=1) public class KNNESIT extends BaseKNNIntegTestIT { - - @Override - public Settings indexSettings() { - return Settings.builder() - .put(super.indexSettings()) - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build(); - } - /** * Able to add docs to KNN index */ public void testAddKNNDoc() throws Exception { - createKnnIndex("testindex", null); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); Float[] vector = {6.0f, 6.0f}; - addKnnDoc("testindex", "1", vector); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); } /** * Able to update docs in KNN index */ public void testUpdateKNNDoc() throws Exception { - createKnnIndex("testindex", null); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); Float[] vector = {6.0f, 6.0f}; - addKnnDoc("testindex", "1", vector); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // update Float[] updatedVector = {8.0f, 8.0f}; - addKnnDoc("testindex", "1", vector); + updateKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); } /** * Able to delete docs in KNN index */ public void testDeleteKNNDoc() throws Exception { - createKnnIndex("testindex", null); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); Float[] vector = {6.0f, 6.0f}; - addKnnDoc("testindex", "1", vector); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // delete knn doc - DeleteResponse response = client().prepareDelete("testindex", "_doc", "1") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get(); - assertEquals(RestStatus.OK, response.status()); + deleteKnnDoc(INDEX_NAME, "1"); } /** * Create knn index with valid index algo params */ - public void testCreateIndexWithValidAlgoParams() throws Exception { + public void testCreateIndexWithValidAlgoParams() { try { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .put("index.knn.algo_param.m", 32) .put("index.knn.algo_param.ef_construction", 400) .build(); - createKnnIndex("testindex", settings); + createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)); Float[] vector = {6.0f, 6.0f}; - addKnnDoc("testindex", "1", vector); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); } catch (Exception ex) { - fail("Exception not expected as valid index arguements passed"); + fail("Exception not expected as valid index arguements passed: " + ex); } } /** * Create knn index with valid query algo params */ - public void testQueryIndexWithValidQueryAlgoParams() throws Exception { - try { - Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) - .put("index.knn.algo_param.ef_search", 300) - .build(); - createKnnIndex("testindex", settings); - Float[] vector = {6.0f, 6.0f}; - addKnnDoc("testindex", "1", vector); - - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 1; // nearest 1 neighbor - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("my_vector", queryVector, k); - searchKNNIndex("testindex", k, knnQueryBuilder); - } catch (Exception ex) { - fail("Exception not expected as valid index arguements passed"); - } + public void testQueryIndexWithValidQueryAlgoParams() throws IOException { + Settings settings = Settings.builder() + .put(getKNNDefaultIndexSettings()) + .put("index.knn.algo_param.ef_search", 300) + .build(); + createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)); + Float[] vector = {6.0f, 6.0f}; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); + + float[] queryVector = {1.0f, 1.0f}; // vector to be queried + int k = 1; // nearest 1 neighbor + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); + searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); } public void testIndexingVectorValidationDifferentSizes() throws Exception { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .build(); - String index = "testindex"; - createIndex(index, settings); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); - - request.source("my_vector", "type=knn_vector,dimension=4"); - ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); + createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 4)); /** * valid case with 4 dimension */ Float[] vector = {6.0f, 7.0f, 8.0f, 9.0f}; - addKnnDoc(index, "1", vector); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); /** * invalid case with lesser dimension than original (3 < 4) */ Float[] vector1 = {6.0f, 7.0f, 8.0f}; - Exception ex = expectThrows(MapperParsingException.class, () -> addKnnDoc(index, "2", vector1)); - assertThat(ex.getCause().getMessage(), containsString("Vector dimension mismatch. Expected: 4, Given: 3")); + ResponseException ex = expectThrows(ResponseException.class, () -> + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector1)); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), + containsString("Vector dimension mismatch. Expected: 4, Given: 3")); /** * invalid case with more dimension than original (5 > 4) */ Float[] vector2 = {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}; - ex = expectThrows(MapperParsingException.class, () -> addKnnDoc(index, "3", vector2)); - assertThat(ex.getCause().getMessage(), containsString("Vector dimension mismatch. Expected: 4, Given: 5")); + ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector2)); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), + containsString("Vector dimension mismatch. Expected: 4, Given: 5")); } public void testVectorMappingValidationNoDimension() throws Exception { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .build(); - String index = "testindex"; - createIndex(index, settings); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); + String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .endObject() + .endObject() + .endObject()); - request.source("my_vector", "type=knn_vector"); - Exception ex = expectThrows(MapperParsingException.class, () -> ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet())); - assertThat(ex.getMessage(), containsString("Dimension value missing for vector: my_vector")); + Exception ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, mapping)); + assertThat(ex.getMessage(), containsString("Dimension value missing for vector: " + FIELD_NAME)); } - public void testVectorMappingValidationInvalidDimension() throws Exception { + public void testVectorMappingValidationInvalidDimension() { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .build(); - String index = "testindex"; - createIndex(index, settings); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); - - request.source("my_vector", "type=knn_vector,dimension=" + (KNNVectorFieldMapper.MAX_DIMENSION + 1)); - Exception ex = expectThrows(MapperParsingException.class, () -> ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet())); + Exception ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, + createKnnIndexMapping(FIELD_NAME, KNNVectorFieldMapper.MAX_DIMENSION + 1))); assertThat(ex.getMessage(), containsString("Dimension value cannot be greater than " + - KNNVectorFieldMapper.MAX_DIMENSION + " for vector: my_vector")); + KNNVectorFieldMapper.MAX_DIMENSION + " for vector: " + FIELD_NAME)); } public void testVectorMappingValidationUpdateDimension() throws Exception { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .build(); - String index = "testindex"; - createIndex(index, settings); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); + createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 4)); - request.source("my_vector", "type=knn_vector,dimension=4"); - ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); - - - request.source("my_vector", "type=knn_vector,dimension=5"); - Exception ex = expectThrows(MapperParsingException.class, () -> ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet())); + Exception ex = expectThrows(ResponseException.class, () -> + putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 5))); assertThat(ex.getMessage(), containsString("Dimension value cannot be updated. Previous value: 4, Current value: 5")); } @@ -207,51 +172,56 @@ public void testVectorMappingValidationUpdateDimension() throws Exception { */ public void testVectorMappingValidationMultiFieldsDifferentDimension() throws Exception { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .build(); - String index = "testindex"; - createIndex(index, settings); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); - - request.source("my_vector", "type=knn_vector,dimension=4"); - request.source("my_vector1", "type=knn_vector,dimension=5"); - - ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); + String f4 = FIELD_NAME + "-4"; + String f5 = FIELD_NAME + "-5"; + String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(f4) + .field("type", "knn_vector") + .field("dimension", "4") + .endObject() + .startObject(f5) + .field("type", "knn_vector") + .field("dimension", "5") + .endObject() + .endObject() + .endObject()); + + createKnnIndex(INDEX_NAME, settings, mapping); /** * valid case with 4 dimension */ Float[] vector = {6.0f, 7.0f, 8.0f, 9.0f}; - addKnnDoc(index, "1", vector); + addKnnDoc(INDEX_NAME, "1", f4, vector); /** * valid case with 5 dimension */ Float[] vector1 = {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}; - addKnnDocWithField(index, "1", vector1, "my_vector1"); + updateKnnDoc(INDEX_NAME, "1", f5, vector1); } - public void testInvalidIndexHnswAlgoParams() throws Exception { - String index = "testindex"; + public void testInvalidIndexHnswAlgoParams() { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .put("index.knn.algo_param.m", "-1") .build(); - Exception ex = expectThrows(IllegalArgumentException.class, () -> createIndex(index, settings)); + Exception ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, + createKnnIndexMapping(FIELD_NAME, 2))); assertThat(ex.getMessage(), containsString("Failed to parse value [-1] for setting [index.knn.algo_param.m]")); } - public void testInvalidQueryHnswAlgoParams() throws Exception { - String index = "testindex"; + public void testInvalidQueryHnswAlgoParams() { Settings settings = Settings.builder() - .put(super.indexSettings()) - .put("index.knn", true) + .put(getKNNDefaultIndexSettings()) .put("index.knn.algo_param.ef_search", "-1") .build(); - Exception ex = expectThrows(IllegalArgumentException.class, () -> createIndex(index, settings)); + Exception ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, + createKnnIndexMapping(FIELD_NAME, 2))); assertThat(ex.getMessage(), containsString("Failed to parse value [-1] for setting [index.knn.algo_param.ef_search]")); } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESSettingsTestIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESSettingsTestIT.java index adac8bc1..0121715e 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESSettingsTestIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNESSettingsTestIT.java @@ -15,136 +15,60 @@ package com.amazon.opendistroforelasticsearch.knn.index; -import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.rest.ESRestTestCase; - -import java.util.Map; import static org.hamcrest.Matchers.containsString; -public class KNNESSettingsTestIT extends ESRestTestCase { - - public void createKNNIndex(String indexName) throws Exception { - Settings settings = Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build(); - String mapping = "\"properties\":{\"my_vector\":{\"type\":\"knn_vector\",\"dimension\":\"2\"}}"; - createIndex(indexName, settings, mapping); - } - - public void indexKNNDoc(String indexName, float[] vector) throws Exception { - Request indexRequest = new Request("PUT", "/" + indexName + "/_doc/1"); - - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.field("my_vector", vector); - builder.endObject(); - indexRequest.setJsonEntity(Strings.toString(builder)); - client().performRequest(indexRequest); - } - - public Response makeGenericKnnQuery(String index, float[] vector, int k) throws Exception { - Request request = new Request("POST", "/" + index + "/_search" - ); - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .startObject("query") - .startObject("knn") - .startObject("my_vector") - .field("vector", vector) - .field("k", k) - .endObject() - .endObject() - .endObject() - .endObject(); - - request.setJsonEntity(Strings.toString(builder)); - return client().performRequest(request); - } - - public void updateSettings(String settingKey, Object value) throws Exception { - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("persistent") - .field(settingKey, value) - .endObject() - .endObject(); - Request request = new Request("PUT", "_cluster/settings"); - request.setJsonEntity(Strings.toString(builder)); - Response response = client().performRequest(request); - assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - } - - public void getClusterSettings() throws Exception { - Map getResponse = entityAsMap(adminClient().performRequest(new Request("GET", "/_cluster/settings"))); - Response response = client().performRequest(new Request("GET", "/_cluster/settings")); - XContentType.fromMediaTypeOrFormat(response.getEntity().getContentType().getValue()); - XContentParser xcp = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, response.getEntity().getContent()); - Map mp = xcp.map(); - } - +public class KNNESSettingsTestIT extends BaseKNNIntegTestIT { /** * KNN Index writes should be blocked when the plugin disabled * @throws Exception Exception from test */ public void testIndexWritesPluginDisabled() throws Exception { - String indexName = "testindex"; - createKNNIndex(indexName); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - float[] vector = {6.0f, 6.0f}; - indexKNNDoc(indexName, vector); + Float[] vector = {6.0f, 6.0f}; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); float[] qvector = {1.0f, 2.0f}; - Response response = makeGenericKnnQuery(indexName, qvector, 1); + Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); assertEquals("knn query failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); //disable plugin - updateSettings(KNNSettings.KNN_PLUGIN_ENABLED, false); + updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, false); // indexing should be blocked Exception ex = expectThrows(ResponseException.class, - () -> indexKNNDoc(indexName, vector)); + () -> addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector)); assertThat(ex.getMessage(), containsString("KNN plugin is disabled")); //enable plugin - updateSettings(KNNSettings.KNN_PLUGIN_ENABLED, true); - indexKNNDoc(indexName, vector); + updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, true); + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector); } public void testQueriesPluginDisabled() throws Exception { - String indexName = "testindex"; - createKNNIndex(indexName); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - float[] vector = {6.0f, 6.0f}; - indexKNNDoc(indexName, vector); + Float[] vector = {6.0f, 6.0f}; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); float[] qvector = {1.0f, 2.0f}; - Response response = makeGenericKnnQuery(indexName, qvector, 1); + Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); assertEquals("knn query failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); //update settings - updateSettings(KNNSettings.KNN_PLUGIN_ENABLED, false); + updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, false); // indexing should be blocked Exception ex = expectThrows(ResponseException.class, - () -> makeGenericKnnQuery(indexName, qvector, 1)); + () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1)); assertThat(ex.getMessage(), containsString("KNN plugin is disabled")); //enable plugin - updateSettings(KNNSettings.KNN_PLUGIN_ENABLED, true); - makeGenericKnnQuery(indexName, qvector, 1); + updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, true); + searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNIIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNITests.java similarity index 98% rename from src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNIIT.java rename to src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNITests.java index 0e8b6473..08a37a9c 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNIIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNITests.java @@ -22,7 +22,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; -import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; import java.nio.file.Paths; import java.security.AccessController; @@ -31,8 +31,8 @@ import java.util.Map; import java.util.stream.Collectors; -public class KNNJNIIT extends ESIntegTestCase { - private static final Logger logger = LogManager.getLogger(KNNJNIIT.class); +public class KNNJNITests extends ESTestCase { + private static final Logger logger = LogManager.getLogger(KNNJNITests.class); public void testCreateHnswIndex() throws Exception { int[] docs = {0, 1, 2}; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java index ed17993c..a3823e7d 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNMapperSearcherIT.java @@ -15,116 +15,38 @@ package com.amazon.opendistroforelasticsearch.knn.index; +import org.apache.http.util.EntityUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.admin.indices.forcemerge.ForceMergeRequest; -import org.elasticsearch.action.admin.indices.forcemerge.ForceMergeResponse; -import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; -import org.elasticsearch.action.delete.DeleteResponse; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.SearchType; -import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; - -import java.io.IOException; +import org.elasticsearch.client.Response; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.TimeUnit; -@ESIntegTestCase.ClusterScope(scope=ESIntegTestCase.Scope.SUITE, numDataNodes=1) -public class KNNMapperSearcherIT extends ESIntegTestCase { +public class KNNMapperSearcherIT extends BaseKNNIntegTestIT { private static final Logger logger = LogManager.getLogger(KNNMapperSearcherIT.class); - @Override - public Settings indexSettings() { - return Settings.builder() - .put(super.indexSettings()) - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build(); - } - /** * Test Data set */ private void addTestData() throws Exception { - Float[] f1 = {6.0f, 6.0f}; - addKnnDoc("testindex", "1", f1); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); Float[] f2 = {2.0f, 2.0f}; - addKnnDoc("testindex", "2", f2); + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); Float[] f3 = {4.0f, 4.0f}; - addKnnDoc("testindex", "3", f3); + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); Float[] f4 = {3.0f, 3.0f}; - addKnnDoc("testindex", "4", f4); - } - - private void addKnnDoc(String index, String docId, Object[] vector) throws IOException { - IndexResponse response = client().prepareIndex(index, "_doc", docId) - .setSource(XContentFactory.jsonBuilder() - .startObject() - .array("my_vector", vector) - .field("price", 10) - .endObject()) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get(); - if(!response.status().equals(RestStatus.OK) && !response.status().equals(RestStatus.CREATED)) { - fail("Bad response while adding doc"); - } - } - - - private void createKnnIndex(String index) { - createIndex(index, indexSettings()); - PutMappingRequest request = new PutMappingRequest(index).type("_doc"); - request.source("my_vector", "type=knn_vector,dimension=2"); - ElasticsearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); - } - - private SearchResponse searchKNNIndex(String index, int resultSize, KNNQueryBuilder knnQueryBuilder) { - logger.info("Searching KNN index " + index ); - SearchResponse searchResponse = client().prepareSearch(index) - .setSearchType(SearchType.QUERY_THEN_FETCH) - .setQuery(knnQueryBuilder) // Query - .setSize(resultSize) - .setExplain(true) - .get(); - assertEquals(searchResponse.status(), RestStatus.OK); - return searchResponse; - } - - private void forceMergeKnnIndex(String index) throws Exception { - client().admin().indices().refresh(new RefreshRequest(index)).actionGet(); - ForceMergeRequest forceMergeRequest = new ForceMergeRequest(index); - forceMergeRequest.maxNumSegments(1); - forceMergeRequest.flush(true); - ForceMergeResponse forceMergeResponse =client().admin().indices().forceMerge(forceMergeRequest).actionGet(); - assertEquals(forceMergeResponse.getStatus(), RestStatus.OK); - TimeUnit.SECONDS.sleep(5); // To make sure force merge is completed - } - - private SearchResponse doQueryKnn(float[] queryVector, int k) throws Exception { - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("my_vector", queryVector, k); - forceMergeKnnIndex("testindex"); - return searchKNNIndex("testindex", k, knnQueryBuilder); + addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); } - public void testKNNResultsWithForceMerge() throws Exception { - createKnnIndex("testindex"); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); + forceMergeKnnIndex(INDEX_NAME); /** * Query params @@ -132,21 +54,19 @@ public void testKNNResultsWithForceMerge() throws Exception { float[] queryVector = {1.0f, 1.0f}; // vector to be queried int k = 1; // nearest 1 neighbor - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("my_vector", queryVector, k); - - forceMergeKnnIndex("testindex"); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); - SearchResponse searchResponse; - searchResponse = searchKNNIndex("testindex", 10, knnQueryBuilder); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "2"); //Vector of DocId 2 is closest to the query + assertEquals(k, results.size()); + for(KNNResult result : results) { + assertEquals("2", result.getDocId()); } - ElasticsearchAssertions.assertHitCount(searchResponse, k); } public void testKNNResultsWithoutForceMerge() throws Exception { - createKnnIndex("testindex"); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); /** @@ -154,16 +74,15 @@ public void testKNNResultsWithoutForceMerge() throws Exception { */ float[] queryVector = {2.0f, 2.0f}; // vector to be queried int k = 3; //nearest 3 neighbors - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("my_vector", queryVector, k); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); - SearchResponse searchResponse; - - searchResponse = searchKNNIndex("testindex", k, knnQueryBuilder); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); List expectedDocids = Arrays.asList("2", "4", "3"); List actualDocids = new ArrayList<>(); - for(SearchHit hit : searchResponse.getHits()) { - actualDocids.add(hit.getId()); + for(KNNResult result : results) { + actualDocids.add(result.getDocId()); } assertEquals(actualDocids.size(), k); @@ -171,97 +90,107 @@ public void testKNNResultsWithoutForceMerge() throws Exception { } public void testKNNResultsWithNewDoc() throws Exception { - createKnnIndex("testindex"); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); float[] queryVector = {1.0f, 1.0f}; // vector to be queried int k = 1; // nearest 1 neighbor - SearchResponse searchResponse; - searchResponse = doQueryKnn(queryVector, k); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "2"); //Vector of DocId 2 is closest to the query + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + assertEquals(results.size(), k); + for(KNNResult result : results) { + assertEquals("2", result.getDocId()); //Vector of DocId 2 is closest to the query } - ElasticsearchAssertions.assertHitCount(searchResponse, k); /** * Add new doc with vector not nearest than doc 2 */ Float[] newVector = {6.0f, 6.0f}; - addKnnDoc("testindex", "6", newVector); - searchResponse = doQueryKnn(queryVector, k); + addKnnDoc(INDEX_NAME, "6", FIELD_NAME, newVector); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "2"); //Vector of DocId 2 is closest to the query + assertEquals(results.size(), k); + for(KNNResult result : results) { + assertEquals("2", result.getDocId()); } - ElasticsearchAssertions.assertHitCount(searchResponse, k); + /** * Add new doc with vector nearest than doc 2 to queryVector */ Float[] newVector1 = {0.5f, 0.5f}; - addKnnDoc("testindex", "7", newVector1); - searchResponse = doQueryKnn(queryVector, k); + addKnnDoc(INDEX_NAME, "7", FIELD_NAME, newVector1); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "7"); //Vector of DocId 7 is closest to the query + assertEquals(results.size(), k); + for(KNNResult result : results) { + assertEquals("7", result.getDocId()); } - ElasticsearchAssertions.assertHitCount(searchResponse, k); } public void testKNNResultsWithUpdateDoc() throws Exception { - createKnnIndex("testindex"); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); float[] queryVector = {1.0f, 1.0f}; // vector to be queried int k = 1; // nearest 1 neighbor - SearchResponse searchResponse; - searchResponse = doQueryKnn(queryVector, k); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "2"); //Vector of DocId 2 is closest to the query + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + assertEquals(results.size(), k); + for(KNNResult result : results) { + assertEquals("2", result.getDocId()); //Vector of DocId 2 is closest to the query } - ElasticsearchAssertions.assertHitCount(searchResponse, k); /** * update doc 3 to the nearest */ Float[] updatedVector = {0.1f, 0.1f}; - addKnnDoc("testindex", "3", updatedVector); - searchResponse = doQueryKnn(queryVector, k); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "3"); //Vector of DocId 3 is closest to the query + updateKnnDoc(INDEX_NAME, "3", FIELD_NAME, updatedVector); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertEquals(results.size(), k); + for(KNNResult result : results) { + assertEquals("3", result.getDocId()); //Vector of DocId 3 is closest to the query } - ElasticsearchAssertions.assertHitCount(searchResponse, k); } public void testKNNResultsWithDeleteDoc() throws Exception { - createKnnIndex("testindex"); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); float[] queryVector = {1.0f, 1.0f}; // vector to be queried int k = 1; // nearest 1 neighbor - SearchResponse searchResponse; - searchResponse = doQueryKnn(queryVector, k); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "2"); //Vector of DocId 2 is closest to the query + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + assertEquals(results.size(), k); + for(KNNResult result : results) { + assertEquals("2", result.getDocId()); //Vector of DocId 2 is closest to the query } - ElasticsearchAssertions.assertHitCount(searchResponse, k); + /** * delete the nearest doc (doc2) */ - DeleteResponse response = client().prepareDelete("testindex", "_doc", "2") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get(); - assertEquals(RestStatus.OK, response.status()); - - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("my_vector", queryVector, k+1); - searchResponse = searchKNNIndex("testindex", k, knnQueryBuilder); - for(SearchHit hit : searchResponse.getHits()) { - assertEquals(hit.getId(), "4"); //Vector of DocId 4 is closest to the query + deleteKnnDoc(INDEX_NAME, "2"); + + knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k+1); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + + assertEquals(results.size(), k); + for(KNNResult result : results) { + assertEquals("4", result.getDocId()); //Vector of DocId 4 is closest to the query } - ElasticsearchAssertions.assertHitCount(searchResponse, k); } /** @@ -269,33 +198,30 @@ public void testKNNResultsWithDeleteDoc() throws Exception { */ public void testNegativeK() { float[] vector = {1.0f, 2.0f}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", vector, -1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, vector, -1)); } /** * For zero K, query builder should throw Exception */ - public void testZeroK() throws Exception { - createKnnIndex("testindex"); - addTestData(); - - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 0; // nearest 1 neighbor - expectThrows(IllegalArgumentException.class, () -> doQueryKnn(queryVector, k)); + public void testZeroK() { + float[] vector = {1.0f, 2.0f}; + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, vector, 0)); } /** * K > > number of docs */ public void testLargeK() throws Exception { - createKnnIndex("testindex"); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); float[] queryVector = {1.0f, 1.0f}; // vector to be queried int k = KNNQueryBuilder.K_MAX; // nearest 1 neighbor - SearchResponse searchResponse; - searchResponse = doQueryKnn(queryVector, k); - ElasticsearchAssertions.assertHitCount(searchResponse, 4); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertEquals(results.size(), 4); } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNResult.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNResult.java new file mode 100644 index 00000000..d3178520 --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNResult.java @@ -0,0 +1,34 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.knn.index; + +public class KNNResult { + private String docId; + private Float[] vector; + + public KNNResult(String docId, Float[] vector) { + this.docId = docId; + this.vector = vector; + } + + public String getDocId() { + return docId; + } + + public Float[] getVector() { + return vector; + } +} \ No newline at end of file diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 5f0b5a6a..78eb1bf0 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -15,44 +15,34 @@ package com.amazon.opendistroforelasticsearch.knn.plugin.action; -import com.amazon.opendistroforelasticsearch.knn.index.KNNIndexCache; -import com.amazon.opendistroforelasticsearch.knn.plugin.KNNPlugin; -import com.amazon.opendistroforelasticsearch.knn.plugin.stats.KNNStat; +import com.amazon.opendistroforelasticsearch.knn.index.BaseKNNIntegTestIT; +import com.amazon.opendistroforelasticsearch.knn.index.KNNQueryBuilder; import com.amazon.opendistroforelasticsearch.knn.plugin.stats.KNNStats; import com.amazon.opendistroforelasticsearch.knn.plugin.stats.StatNames; -import com.amazon.opendistroforelasticsearch.knn.plugin.stats.suppliers.KNNCacheSupplier; -import com.amazon.opendistroforelasticsearch.knn.plugin.stats.suppliers.KNNCircuitBreakerSupplier; -import com.amazon.opendistroforelasticsearch.knn.plugin.stats.suppliers.KNNInnerCacheStatsSupplier; -import com.google.common.cache.CacheStats; -import com.google.common.collect.ImmutableMap; import org.apache.http.util.EntityUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.rest.ESRestTestCase; import org.junit.rules.DisableOnDebug; import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.concurrent.Callable; import org.junit.Before; +import static com.amazon.opendistroforelasticsearch.knn.plugin.stats.KNNStatsConfig.KNN_STATS; + /** * Integration tests to check the correctness of RestKNNStatsHandler */ -public class RestKNNStatsHandlerIT extends ESRestTestCase { +public class RestKNNStatsHandlerIT extends BaseKNNIntegTestIT { private static final Logger logger = LogManager.getLogger(RestKNNStatsHandlerIT.class); private boolean isDebuggingTest = new DisableOnDebug(null).isDebugging(); @@ -62,27 +52,7 @@ public class RestKNNStatsHandlerIT extends ESRestTestCase { @Before public void setup() { - Map> stats = ImmutableMap.>builder() - .put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) - .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::missCount))) - .put(StatNames.LOAD_SUCCESS_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::loadSuccessCount))) - .put(StatNames.LOAD_EXCEPTION_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::loadExceptionCount))) - .put(StatNames.TOTAL_LOAD_TIME.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::totalLoadTime))) - .put(StatNames.EVICTION_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::evictionCount))) - .put(StatNames.GRAPH_MEMORY_USAGE.getName(), new KNNStat<>(false, - new KNNCacheSupplier<>(KNNIndexCache::getWeightInKilobytes))) - .put(StatNames.CACHE_CAPACITY_REACHED.getName(), new KNNStat<>(false, - new KNNCacheSupplier<>(KNNIndexCache::isCacheCapacityReached))) - .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, - new KNNCircuitBreakerSupplier())).build(); - - knnStats = new KNNStats(stats); + knnStats = new KNNStats(KNN_STATS); } /** @@ -90,35 +60,12 @@ public void setup() { * @throws IOException throws IOException */ public void testCorrectStatsReturned() throws IOException { - - Request statsRequest = new Request( - "GET", - KNNPlugin.KNN_BASE_URI + "/stats" - ); - - // Check that all of the cluster level metrics are returned - String statsResponseBody = makeRequestAndReturnResponseBody(statsRequest); - Map responseMap = createParser(XContentType.JSON.xContent(), statsResponseBody).map(); - for (String metric : knnStats.getClusterStats().keySet()) { - assertTrue("Cluster metric is not in response: " + metric, responseMap.containsKey(metric)); - } - - // Check node level metrics - @SuppressWarnings("unchecked") - Map nodesResponseMap = (Map)responseMap.get("nodes"); - - // The key associated with the node that made the request - String key = (String)nodesResponseMap.keySet().toArray()[0]; - - @SuppressWarnings("unchecked") - Map metricMap = (Map) nodesResponseMap.get(key); - - // Confirm that all node level metrics are returned - Map> nodeStats = knnStats.getNodeStats(); - assertEquals("Incorrect number of metrics returned", nodeStats.size(), metricMap.size()); - for (String metric : nodeStats.keySet()) { - assertTrue("Metric should not be in response: " + metric, metricMap.containsKey(metric)); - } + Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + String responseBody = EntityUtils.toString(response.getEntity()); + Map clusterStats = parseClusterStatsResponse(responseBody); + assertEquals(knnStats.getClusterStats().keySet(), clusterStats.keySet()); + List> nodeStats = parseNodeStatsResponse(responseBody); + assertEquals(knnStats.getNodeStats().keySet(), nodeStats.get(0).keySet()); } /** @@ -126,112 +73,46 @@ public void testCorrectStatsReturned() throws IOException { * @throws IOException throws IOException */ public void testStatsValueCheck() throws IOException { - // Setup request for stat calls - Request statsRequest = new Request( - "GET", - KNNPlugin.KNN_BASE_URI + "/stats" - ); - - // Get initial stats as baseline - String statsResponseBody = makeRequestAndReturnResponseBody(statsRequest); - Map responseMap0 = createParser(XContentType.JSON.xContent(), statsResponseBody).map(); - assertNotNull("Stats response 0 is null", responseMap0); - - @SuppressWarnings("unchecked") - Map nodesResponseMap0 = (Map)responseMap0.get("nodes"); - assertNotNull("Stats node response 0 is null", nodesResponseMap0); + Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + String responseBody = EntityUtils.toString(response.getEntity()); - Object[] keys = nodesResponseMap0.keySet().toArray(); - assertTrue("No node keys returned", keys.length > 0); - String key = (String) keys[0]; - - @SuppressWarnings("unchecked") - Map metricMap0 = (Map) nodesResponseMap0.get(key); - Integer initialHitCount = (Integer) metricMap0.get(StatNames.HIT_COUNT.getName()); - Integer initialMissCount = (Integer) metricMap0.get(StatNames.MISS_COUNT.getName()); + Map nodeStats0 = parseNodeStatsResponse(responseBody).get(0); + Integer hitCount0 = (Integer) nodeStats0.get(StatNames.HIT_COUNT.getName()); + Integer missCount0 = (Integer) nodeStats0.get(StatNames.MISS_COUNT.getName()); // Setup index - Settings settings = Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build(); - String index = "testindex"; - createIndex(index, settings); - - // Put KNN mapping - Request mappingRequest = new Request( - "PUT", - "/" + index + "/_mapping" - ); - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("my_vector") - .field("type", "knn_vector") - .field("dimension", "2") - .endObject() - .endObject() - .endObject(); - - mappingRequest.setJsonEntity(Strings.toString(builder)); - Response response = client().performRequest(mappingRequest); - assertEquals(mappingRequest.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); // Index test document - Request indexRequest = new Request( - "POST", - "/" + index + "/_doc/1?refresh=true" // refresh=true ensures document is searchable immediately after index - ); - - float[] vector = {6.0f, 6.0f}; - - builder = XContentFactory.jsonBuilder().startObject() - .field("my_vector", vector) - .endObject(); - - indexRequest.setJsonEntity(Strings.toString(builder)); - - response = client().performRequest(indexRequest); - assertEquals(indexRequest.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Float[] vector = {6.0f, 6.0f}; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // First search: Ensure that misses=1 - response = makeGenericKnnQuery(index, vector, 1); - assertEquals("knn query failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + float[] qvector = {6.0f, 6.0f}; + searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); - statsResponseBody = makeRequestAndReturnResponseBody(statsRequest); - Map responseMap1 = createParser(XContentType.JSON.xContent(), statsResponseBody).map(); - assertNotNull("Stats response 1 is null", responseMap1); + response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + responseBody = EntityUtils.toString(response.getEntity()); - @SuppressWarnings("unchecked") - Map nodesResponseMap1 = (Map)responseMap1.get("nodes"); - assertNotNull("Stats node response 1 is null", nodesResponseMap1); + Map nodeStats1 = parseNodeStatsResponse(responseBody).get(0); + Integer hitCount1 = (Integer) nodeStats1.get(StatNames.HIT_COUNT.getName()); + Integer missCount1 = (Integer) nodeStats1.get(StatNames.MISS_COUNT.getName()); - @SuppressWarnings("unchecked") - Map metricMap1 = (Map) nodesResponseMap1.get(key); - assertNotNull("Stats metric map response 1 is null", metricMap1); - assertTrue("Miss and hit count does not return expected", - (Integer) metricMap1.get(StatNames.MISS_COUNT.getName()) == initialMissCount + 1 && - metricMap1.get(StatNames.HIT_COUNT.getName()) == initialHitCount); + assertEquals(hitCount0, hitCount1); + assertEquals((Integer) (missCount0 + 1), missCount1); // Second search: Ensure that hits=1 - response = makeGenericKnnQuery(index, vector, 1); - assertEquals("knn query failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); - statsResponseBody = makeRequestAndReturnResponseBody(statsRequest); - Map responseMap2 = createParser(XContentType.JSON.xContent(), statsResponseBody).map(); - assertNotNull("Stats response 2 is null", responseMap2); + response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + responseBody = EntityUtils.toString(response.getEntity()); - @SuppressWarnings("unchecked") - Map nodesResponseMap2 = (Map)responseMap2.get("nodes"); - assertNotNull("Stats node response 2 is null", nodesResponseMap2); + Map nodeStats2 = parseNodeStatsResponse(responseBody).get(0); + Integer hitCount2 = (Integer) nodeStats2.get(StatNames.HIT_COUNT.getName()); + Integer missCount2 = (Integer) nodeStats2.get(StatNames.MISS_COUNT.getName()); - @SuppressWarnings("unchecked") - Map metricMap2 = (Map) nodesResponseMap2.get(key); - assertNotNull("Stats metric map response 2 is null", metricMap2); - assertTrue("Miss and hit count does not return expected", - (Integer) metricMap2.get(StatNames.HIT_COUNT.getName()) == initialHitCount + 1 && - (Integer) metricMap2.get(StatNames.MISS_COUNT.getName()) == initialMissCount + 1); + assertEquals(missCount1, missCount2); + assertEquals((Integer) (hitCount1 + 1), hitCount2); } /** @@ -242,44 +123,23 @@ public void testValidMetricsStats() throws IOException { // Create request that only grabs two of the possible metrics String metric1 = StatNames.HIT_COUNT.getName(); String metric2 = StatNames.MISS_COUNT.getName(); - Request request = new Request( - "GET", - KNNPlugin.KNN_BASE_URI + "/stats/" + metric1 + "," + metric2 - ); - - Response response = client().performRequest(request); - // Check that the call succeeded - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Response response = getKnnStats(Collections.emptyList(), Arrays.asList(metric1, metric2)); + String responseBody = EntityUtils.toString(response.getEntity()); + Map nodeStats = parseNodeStatsResponse(responseBody).get(0); // Check that metric 1 and 2 are the only metrics in the response - String responseBody = EntityUtils.toString(response.getEntity());; - - @SuppressWarnings("unchecked") - Map nodesResponseMap = (Map)createParser(XContentType.JSON.xContent(), - responseBody).map().get("nodes"); - - String key = (String)nodesResponseMap.keySet().toArray()[0]; - - @SuppressWarnings("unchecked") - Set metricSet = ((Map) nodesResponseMap.get(key)).keySet(); - - assertEquals("Incorrect number of metrics returned", 2, metricSet.size()); - assertTrue("does not contain correct metric: " + metric1, metricSet.contains(metric1)); - assertTrue("does not contain correct metrics: " + metric2, metricSet.contains(metric2)); + assertEquals("Incorrect number of metrics returned", 2, nodeStats.size()); + assertTrue("does not contain correct metric: " + metric1, nodeStats.keySet().contains(metric1)); + assertTrue("does not contain correct metric: " + metric2, nodeStats.keySet().contains(metric2)); } /** * Test checks that handler correctly returns failure on an invalid metric - * @throws Exception throws exception */ - public void testInvalidMetricsStats() throws Exception { - Request request = new Request( - "GET", - KNNPlugin.KNN_BASE_URI + "/stats/invalid_metric" - ); - - assertFailWith(ResponseException.class, null, () -> client().performRequest(request)); + public void testInvalidMetricsStats() { + expectThrows(ResponseException.class, () -> getKnnStats(Collections.emptyList(), + Collections.singletonList("invalid_metric"))); } /** @@ -287,14 +147,10 @@ public void testInvalidMetricsStats() throws Exception { * @throws IOException throws IOException */ public void testValidNodeIdStats() throws IOException { - Request request = new Request( - "GET", - KNNPlugin.KNN_BASE_URI + "/_local/stats" - ); - - Response response = client().performRequest(request); - - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Response response = getKnnStats(Collections.singletonList("_local"), Collections.emptyList()); + String responseBody = EntityUtils.toString(response.getEntity()); + List> nodeStats = parseNodeStatsResponse(responseBody); + assertEquals(1, nodeStats.size()); } /** @@ -302,85 +158,10 @@ public void testValidNodeIdStats() throws IOException { * @throws Exception throws Exception */ public void testInvalidNodeIdStats() throws Exception { - Request request = new Request( - "GET", - KNNPlugin.KNN_BASE_URI + "/invalid_nodeid/stats" - ); - - Response response = client().performRequest(request); - - // Check that the call succeeded, but had no nodes return values - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - String responseBody = EntityUtils.toString(response.getEntity());; - - @SuppressWarnings("unchecked") - Map nodesResponseMap = (Map)createParser(XContentType.JSON.xContent(), - responseBody).map().get("nodes"); - - assertEquals("Incorrect number of metrics returned", 0, nodesResponseMap.keySet().size()); - } - - /** - * Assertion checks to see if callable fails as expected - * @param clazz Exception class expected - * @param message Message thrown on failure to correctly fail - * @param callable Lambda to call - * @param Class template - * @param Callable template - * @throws Exception throws exception - */ - private static void assertFailWith(Class clazz, String message, Callable callable) throws Exception { - try { - callable.call(); - } catch (Throwable e) { - if (e.getClass() != clazz) { - throw e; - } - if (message != null && !e.getMessage().contains(message)) { - throw e; - } - } - } - - /** - * Helper method to make a request, assert that it is valid, and return the response body - * @param request request to be executed - * @return response body - * @throws IOException throws IO exception - */ - private String makeRequestAndReturnResponseBody(Request request) throws IOException { - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - return EntityUtils.toString(response.getEntity()); - } - - /** - * Helper method to generate a generic knn query for testing purposes - * @param index index name - * @param vector vector to be searched for - * @param k k nearest neighbors - * @throws IOException throws IO exception - */ - private Response makeGenericKnnQuery(String index, float[] vector, int k) throws IOException { - Request request = new Request( - "POST", - "/" + index + "/_search" - ); - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .startObject("query") - .startObject("knn") - .startObject("my_vector") - .field("vector", vector) - .field("k",k) - .endObject() - .endObject() - .endObject() - .endObject(); - - request.setJsonEntity(Strings.toString(builder)); - return client().performRequest(request); + Response response = getKnnStats(Collections.singletonList("invalid_node"), Collections.emptyList()); + String responseBody = EntityUtils.toString(response.getEntity()); + List> nodeStats = parseNodeStatsResponse(responseBody); + assertEquals(0, nodeStats.size()); } // Useful settings when debugging to prevent timeouts