diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java index df530776d..071c620ed 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java @@ -6,6 +6,7 @@ package org.opensearch.knn.bwc; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.lucene.util.VectorUtil; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -49,6 +50,38 @@ public void testKNNL2ScriptScore() throws Exception { } } + // KNN script scoring for space_type "cosine" + public void testKNNCosineScriptScore() throws Exception { + float[] indexVector1 = { 1.1f, 2.1f, 3.3f }; + float[] indexVector2 = { 8.1f, 9.1f, 0.3f }; + float[] queryVector = { 3.0f, 4.0f, 5.5f }; + if (isRunningAgainstOldCluster()) { + createKnnIndex(testIndex, createKNNDefaultScriptScoreSettings(), createKnnIndexMapping(TEST_FIELD, 3)); + addKnnDoc(testIndex, "1", TEST_FIELD, indexVector1); + validateScore(1, queryVector, new float[] { cosineSimilarity(queryVector, indexVector1) }); + } else { + addKnnDoc(testIndex, "2", TEST_FIELD, indexVector2); + validateScore( + 2, + queryVector, + new float[] { cosineSimilarity(queryVector, indexVector1), cosineSimilarity(queryVector, indexVector2) } + ); + } + } + + private float cosineSimilarity(float[] vectorA, float[] vectorB) { + return 1 + VectorUtil.cosine(vectorA, vectorB); + } + + private void validateScore(int k, float[] queryVector, float[] expectedScore) throws Exception { + final Response responseBody = executeKNNScriptScoreRequest(testIndex, TEST_FIELD, k, SpaceType.COSINESIMIL, queryVector); + List actualScores = parseSearchResponseScore(EntityUtils.toString(responseBody.getEntity()), TEST_FIELD); + assertEquals(expectedScore.length, actualScores.size()); + for (int i = 0; i < expectedScore.length; i++) { + assertEquals(expectedScore[i], actualScores.get(i), 0.001); + } + } + // KNN script scoring for space_type "l1" public void testKNNL1ScriptScore() throws Exception { if (isRunningAgainstOldCluster()) { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java index 190257112..3283827ad 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/ScriptScoringIT.java @@ -5,8 +5,13 @@ package org.opensearch.knn.bwc; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.lucene.util.VectorUtil; +import org.opensearch.client.Response; import org.opensearch.knn.index.SpaceType; +import java.util.List; + import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; public class ScriptScoringIT extends AbstractRollingUpgradeTestCase { @@ -54,4 +59,65 @@ public void validateKNNL2ScriptScoreOnUpgrade(int totalDocsCount, int docId) thr validateKNNScriptScoreSearch(testIndex, TEST_FIELD, DIMENSIONS, totalDocsCount, K, SpaceType.L2); } + // KNN script scoring for space_type "cosine" + public void testKNNCosineScriptScore() throws Exception { + float[] indexVector1 = { 1.1f, 2.1f, 3.3f }; + float[] indexVector2 = { 8.1f, 9.1f, 10.3f }; + float[] indexVector3 = { 9.1f, 10.1f, 11.3f }; + float[] indexVector4 = { 10.1f, 11.1f, 12.3f }; + float[] queryVector = { 3.0f, 4.0f, 13.5f }; + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + int k = 10; + switch (getClusterType()) { + case OLD: + createKnnIndex(testIndex, createKNNDefaultScriptScoreSettings(), createKnnIndexMapping(TEST_FIELD, 3)); + addKnnDoc(testIndex, "1", TEST_FIELD, indexVector1); + validateScore(k, queryVector, new float[] { cosineSimilarity(queryVector, indexVector1) }); + break; + case MIXED: + if (isFirstMixedRound()) { + addKnnDoc(testIndex, "2", TEST_FIELD, indexVector2); + validateScore( + k, + queryVector, + new float[] { cosineSimilarity(queryVector, indexVector1), cosineSimilarity(queryVector, indexVector2) } + ); + } else { + addKnnDoc(testIndex, "3", TEST_FIELD, indexVector3); + validateScore( + k, + queryVector, + new float[] { + cosineSimilarity(queryVector, indexVector1), + cosineSimilarity(queryVector, indexVector2), + cosineSimilarity(queryVector, indexVector3) } + ); + } + break; + case UPGRADED: + addKnnDoc(testIndex, "4", TEST_FIELD, indexVector3); + validateScore( + k, + queryVector, + new float[] { + cosineSimilarity(queryVector, indexVector1), + cosineSimilarity(queryVector, indexVector2), + cosineSimilarity(queryVector, indexVector3), + cosineSimilarity(queryVector, indexVector4) } + ); + } + } + + private float cosineSimilarity(float[] vectorA, float[] vectorB) { + return 1 + VectorUtil.cosine(vectorA, vectorB); + } + + private void validateScore(int k, float[] queryVector, float[] expectedScores) throws Exception { + final Response responseBody = executeKNNScriptScoreRequest(testIndex, TEST_FIELD, k, SpaceType.COSINESIMIL, queryVector); + List actualScores = parseSearchResponseScore(EntityUtils.toString(responseBody.getEntity()), TEST_FIELD); + assertEquals(expectedScores.length, actualScores.size()); + for (int i = 0; i < expectedScores.length; i++) { + assertEquals(expectedScores[i], actualScores.get(i), 0.01f); + } + } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 896674a18..0712b6267 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1511,6 +1511,19 @@ protected void validateKNNScriptScoreSearch(String testIndex, String testField, IDVectorProducer idVectorProducer = new IDVectorProducer(dimension, numDocs); float[] queryVector = idVectorProducer.getVector(numDocs); + final Response response = executeKNNScriptScoreRequest(testIndex, testField, k, spaceType, queryVector); + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField); + assertEquals(k, results.size()); + PriorityQueue pq = computeGroundTruthValues(k, spaceType, idVectorProducer); + for (int i = k - 1; i >= 0; i--) { + int expDocID = Integer.parseInt(pq.poll().getDocID()); + int actualDocID = Integer.parseInt(results.get(i).getDocId()); + assertEquals(expDocID, actualDocID); + } + } + + protected Response executeKNNScriptScoreRequest(String testIndex, String testField, int k, SpaceType spaceType, float[] queryVector) + throws Exception { QueryBuilder qb = new MatchAllQueryBuilder(); Map params = new HashMap<>(); params.put(FIELD, testField); @@ -1520,17 +1533,7 @@ protected void validateKNNScriptScoreSearch(String testIndex, String testField, Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k, Collections.emptyMap()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField); - assertEquals(k, results.size()); - - PriorityQueue pq = computeGroundTruthValues(k, spaceType, idVectorProducer); - - for (int i = k - 1; i >= 0; i--) { - int expDocID = Integer.parseInt(pq.poll().getDocID()); - int actualDocID = Integer.parseInt(results.get(i).getDocId()); - assertEquals(expDocID, actualDocID); - } + return response; } // validate KNN painless script score search for the space_types : "l2", "l1"