From 46e98836fd55e09e20e9c531b74755f2d5c18e58 Mon Sep 17 00:00:00 2001 From: panguixin Date: Fri, 29 Mar 2024 00:25:46 +0800 Subject: [PATCH] fix test Signed-off-by: panguixin --- .../org/opensearch/knn/index/FaissIT.java | 6 +- .../opensearch/knn/index/LuceneEngineIT.java | 9 +- .../org/opensearch/knn/index/NmslibIT.java | 4 +- .../opensearch/knn/index/OpenSearchIT.java | 5 +- .../knn/plugin/script/KNNScriptScoringIT.java | 410 ++++++------------ .../org/opensearch/knn/KNNRestTestCase.java | 19 +- .../java/org/opensearch/knn/KNNResult.java | 27 +- 7 files changed, 176 insertions(+), 304 deletions(-) diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 3fafae9ba..0cec3810e 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -148,7 +148,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), @@ -258,7 +258,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), @@ -828,7 +828,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 8919519d1..b17155704 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -7,7 +7,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.primitives.Floats; import lombok.SneakyThrows; import org.apache.commons.lang.math.RandomUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; @@ -307,14 +306,14 @@ public void testIndexReopening() throws Exception { final float[] searchVector = TEST_QUERY_VECTORS[0]; final int k = 1 + RandomUtils.nextInt(TEST_INDEX_VECTORS.length); - final List knnResultsBeforeIndexClosure = queryResults(searchVector, k); + final List knnResultsBeforeIndexClosure = queryResults(searchVector, k); closeIndex(INDEX_NAME); openIndex(INDEX_NAME); ensureGreen(INDEX_NAME); - final List knnResultsAfterIndexClosure = queryResults(searchVector, k); + final List knnResultsAfterIndexClosure = queryResults(searchVector, k); assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } @@ -365,7 +364,7 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, queryVector); float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance); assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(j), 0.0001); @@ -373,7 +372,7 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep } } - private List queryResults(final float[] searchVector, final int k) throws Exception { + private List queryResults(final float[] searchVector, final int k) throws Exception { final String responseBody = EntityUtils.toString( searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity() ); diff --git a/src/test/java/org/opensearch/knn/index/NmslibIT.java b/src/test/java/org/opensearch/knn/index/NmslibIT.java index 8007504cf..86745ab13 100644 --- a/src/test/java/org/opensearch/knn/index/NmslibIT.java +++ b/src/test/java/org/opensearch/knn/index/NmslibIT.java @@ -30,11 +30,9 @@ import java.io.IOException; import java.net.URL; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.TreeMap; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; @@ -115,7 +113,7 @@ public void testEndToEnd() throws Exception { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 2e37e26c4..d82a7f98c 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -39,7 +39,6 @@ import java.util.List; import java.util.Map; import java.util.TreeMap; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; @@ -143,7 +142,7 @@ public void testEndToEnd() throws Exception { List actualScores = parseSearchResponseScore(responseBody, fieldName1); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( knnEngine1.score(1 - KNNScoringUtil.cosinesimil(testData.queries[i], primitiveArray), spaceType1), actualScores.get(j), @@ -159,7 +158,7 @@ public void testEndToEnd() throws Exception { actualScores = parseSearchResponseScore(responseBody, fieldName2); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( knnEngine2.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType2), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index a5c39fce2..8d014afec 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -5,7 +5,8 @@ package org.opensearch.knn.plugin.script; -import java.io.IOException; +import java.util.function.BiFunction; +import java.util.function.Function; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; @@ -23,6 +24,8 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.script.Script; @@ -40,221 +43,26 @@ public class KNNScriptScoringIT extends KNNRestTestCase { public void testKNNL2ScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - randomCreateKNNIndex(); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 2.0f, 2.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 4.0f, 4.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [2.0, 2.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.L2.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("2", "4", "3", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("2", results.get(0).getDocId()); - assertEquals("4", results.get(1).getDocId()); - assertEquals("3", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.L2); } public void testKNNL1ScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - randomCreateKNNIndex(); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 4.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 5.0f, 5.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [1.0, 1.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.L1); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("2", "4", "3", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("2", results.get(0).getDocId()); - assertEquals("3", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.L1); } public void testKNNLInfScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - randomCreateKNNIndex(); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 4.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 5.0f, 5.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [1.0, 1.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.LINF.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("3", "2", "4", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("3", results.get(0).getDocId()); - assertEquals("2", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.LINF); } public void testKNNCosineScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - randomCreateKNNIndex(); - Float[] f1 = { 1.0f, -1.0f }; - addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); - - Float[] f2 = { 1.0f, 0.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f2); - - Float[] f3 = { 1.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f3); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "query_value": [2.0, 2.0], - * "space_type": "L2" - * } - * - * - */ - float[] queryVector = { 2.0f, -2.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.COSINESIMIL.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("0", "1", "2"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(3, results.size()); - - // assert document order - assertEquals("0", results.get(0).getDocId()); - assertEquals("1", results.get(1).getDocId()); - assertEquals("2", results.get(2).getDocId()); + testKNNScriptScore(SpaceType.COSINESIMIL); } public void testKNNInvalidSourceScript() throws Exception { /* * Create knn index and populate data */ - randomCreateKNNIndex(); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); /** * Construct Search Request @@ -296,7 +104,7 @@ public void testInvalidSpace() throws Exception { /* * Create knn index and populate data */ - randomCreateKNNIndex(); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); /** * Construct Search Request @@ -319,7 +127,7 @@ public void testMissingParamsInScript() throws Exception { /* * Create knn index and populate data */ - randomCreateKNNIndex(); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); /** * Construct Search Request @@ -352,7 +160,7 @@ public void testUnequalDimensions() throws Exception { /* * Create knn index and populate data */ - randomCreateKNNIndex(); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); Float[] f1 = { 1.0f, -1.0f }; addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); @@ -375,7 +183,7 @@ public void testKNNScoreforNonVectorDocument() throws Exception { /* * Create knn index and populate data */ - randomCreateKNNIndex(); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); Float[] f1 = { 1.0f, 1.0f }; addDocWithNumericField(INDEX_NAME, "0", "price", 10); addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); @@ -399,10 +207,7 @@ public void testKNNScoreforNonVectorDocument() throws Exception { responseBody ).map().get("hits")).get("hits"); - List docIds = hits.stream().map(hit -> { - String id = ((String) ((Map) hit).get("_id")); - return id; - }).collect(Collectors.toList()); + List docIds = hits.stream().map(hit -> ((String) ((Map) hit).get("_id"))).collect(Collectors.toList()); // assert document order assertEquals("1", docIds.get(0)); assertEquals("0", docIds.get(1)); @@ -636,64 +441,14 @@ public void testHammingScriptScore_Base64() throws Exception { } public void testKNNInnerProdScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - randomCreateKNNIndex(); - Float[] f1 = { -2.0f, -2.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 1.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 2.0f, 2.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 2.0f, -2.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "query_value": [1.0, 1.0], - * "space_type": "innerproduct", - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.INNER_PRODUCT.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("3", "2", "4", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("3", results.get(0).getDocId()); - assertEquals("2", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.INNER_PRODUCT); } public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { /* * Create knn index and populate data */ - randomCreateKNNIndex(); + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); @@ -795,25 +550,120 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { assertEquals(1, secondQueryCacheMap.get("hit_count")); } - /** - * Create native knn index or Lucene knn index with/without doc values randomly - * @throws IOException - */ - private void randomCreateKNNIndex() throws IOException { - if (randomBoolean()) { - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - } else { - createKnnIndex( - INDEX_NAME, - createKnnIndexMapping( - FIELD_NAME, - 2, - KNNConstants.METHOD_HNSW, - KNNEngine.LUCENE.getName(), - SpaceType.DEFAULT.getValue(), - randomBoolean() - ) - ); + private List createMappers(int dimensions) throws Exception { + return List.of( + createKnnIndexMapping(FIELD_NAME, dimensions), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + true + ), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + false + ) + ); + } + + private float[] randomVector(int dimensions) { + final float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = randomFloat(); } + return vector; + } + + private Map createDataset(Function scoreFunction, int dimensions, int numDocs) { + final Map dataset = new HashMap<>(numDocs); + for (int i = 0; i < numDocs; i++) { + final float[] vector = randomVector(dimensions); + final float score = scoreFunction.apply(vector); + dataset.put(Integer.toString(i), new KNNResult(Integer.toString(i), vector, score)); + } + return dataset; + } + + private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + FIELD_NAME, + Collections.emptyMap(), + queryVector.length, + VectorDataType.FLOAT, + null + ); + List target = new ArrayList<>(queryVector.length); + for (float f : queryVector) { + target.add(f); + } + KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create(spaceType.getValue(), target, knnVectorFieldType); + switch (spaceType) { + case L1: + return ((KNNScoringSpace.L1) knnScoringSpace).scoringMethod; + case L2: + return ((KNNScoringSpace.L2) knnScoringSpace).scoringMethod; + case LINF: + return ((KNNScoringSpace.LInf) knnScoringSpace).scoringMethod; + case COSINESIMIL: + return ((KNNScoringSpace.CosineSimilarity) knnScoringSpace).scoringMethod; + case INNER_PRODUCT: + return ((KNNScoringSpace.InnerProd) knnScoringSpace).scoringMethod; + default: + throw new IllegalArgumentException(); + } + } + + private void testKNNScriptScore(SpaceType spaceType) throws Exception { + final int dims = randomIntBetween(2, 10); + final float[] queryVector = randomVector(dims); + final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + for (String mapper : createMappers(dims)) { + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector); + } + } + + private void createIndexAndAssertScriptScore( + String mapper, + SpaceType spaceType, + BiFunction scoreFunction, + int dimensions, + float[] queryVector + ) throws Exception { + /* + * Create knn index and populate data + */ + createKnnIndex(INDEX_NAME, mapper); + Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, randomIntBetween(4, 10)); + for (Map.Entry entry : dataset.entrySet()) { + addKnnDoc(INDEX_NAME, entry.getKey(), FIELD_NAME, entry.getValue().getVector()); + } + + /** + * Construct Search Request + */ + QueryBuilder qb = new MatchAllQueryBuilder(); + Map params = new HashMap<>(); + /* + * params": { + * "field": FIELD_NAME, + * "vector": queryVector + * } + */ + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType.getValue()); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertTrue(results.stream().allMatch(r -> dataset.get(r.getDocId()).equals(r))); + deleteKNNIndex(INDEX_NAME); } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 68900102b..68255388b 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -245,10 +245,16 @@ protected List parseSearchResponse(String responseBody, String fieldN @SuppressWarnings("unchecked") List knnSearchResponses = hits.stream().map(hit -> { @SuppressWarnings("unchecked") - Float[] vector = Arrays.stream( - ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() - ).map(Object::toString).map(Float::valueOf).toArray(Float[]::new); - return new KNNResult((String) ((Map) hit).get("_id"), vector); + final float[] vector = Floats.toArray( + Arrays.stream( + ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() + ).map(Object::toString).map(Float::valueOf).collect(Collectors.toList()) + ); + return new KNNResult( + (String) ((Map) hit).get("_id"), + vector, + ((Double) ((Map) hit).get("_score")).floatValue() + ); }).collect(Collectors.toList()); return knnSearchResponses; @@ -482,7 +488,7 @@ protected void forceMergeKnnIndex(String index, int maxSegments) throws Exceptio /** * Add a single KNN Doc to an index */ - protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { + protected void addKnnDoc(String index, String docId, String fieldName, T vector) throws IOException { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject(); @@ -1043,8 +1049,7 @@ public float[][] getIndexVectorsFromIndex(String testIndex, String testField, in int i = 0; for (KNNResult result : results) { - float[] primitiveArray = Floats.toArray(Arrays.stream(result.getVector()).collect(Collectors.toList())); - vectors[i++] = primitiveArray; + vectors[i++] = result.getVector(); } return vectors; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNResult.java b/src/testFixtures/java/org/opensearch/knn/KNNResult.java index 803c2ae72..ee2ba39f7 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNResult.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNResult.java @@ -5,20 +5,41 @@ package org.opensearch.knn; +import java.util.Arrays; +import java.util.Objects; + public class KNNResult { + private final static float delta = 1e-3f; + private String docId; - private Float[] vector; + private float[] vector; + private Float score; - public KNNResult(String docId, Float[] vector) { + public KNNResult(String docId, float[] vector, Float score) { this.docId = docId; this.vector = vector; + this.score = score; } public String getDocId() { return docId; } - public Float[] getVector() { + public float[] getVector() { return vector; } + + public Float getScore() { + return score; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KNNResult knnResult = (KNNResult) o; + return Objects.equals(docId, knnResult.docId) + && Arrays.equals(vector, knnResult.vector) + && (Float.compare(score, knnResult.score) == 0 || Math.abs(score - knnResult.score) <= delta); + } }