Skip to content

Commit

Permalink
Add BWC test for cosine similarity for knn script scoring
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 6, 2025
1 parent 84cfa8e commit c2cd8c8
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.bwc;

import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.index.query.MatchAllQueryBuilder;
Expand Down Expand Up @@ -49,6 +50,38 @@ public void testKNNL2ScriptScore() throws Exception {
}
}

// KNN script scoring for space_type "cosine"
public void testKNNCosineScriptScore() throws Exception {
float[] indexVector1 = { 1.1f, 2.1f, 3.3f };
float[] indexVector2 = { 8.1f, 9.1f, 0.3f };
float[] queryVector = { 3.0f, 4.0f, 5.5f };
if (isRunningAgainstOldCluster()) {
createKnnIndex(testIndex, createKNNDefaultScriptScoreSettings(), createKnnIndexMapping(TEST_FIELD, 3));
addKnnDoc(testIndex, "1", TEST_FIELD, indexVector1);
validateScore(1, queryVector, new float[] { cosineSimilarity(queryVector, indexVector1) });
} else {
addKnnDoc(testIndex, "2", TEST_FIELD, indexVector2);
validateScore(
2,
queryVector,
new float[] { cosineSimilarity(queryVector, indexVector1), cosineSimilarity(queryVector, indexVector2) }
);
}
}

private float cosineSimilarity(float[] vectorA, float[] vectorB) {
return 1 + VectorUtil.cosine(vectorA, vectorB);
}

private void validateScore(int k, float[] queryVector, float[] expectedScore) throws Exception {
final Response responseBody = executeKNNScriptScoreRequest(testIndex, TEST_FIELD, k, SpaceType.COSINESIMIL, queryVector);
List<Float> actualScores = parseSearchResponseScore(EntityUtils.toString(responseBody.getEntity()), TEST_FIELD);
assertEquals(expectedScore.length, actualScores.size());
for (int i = 0; i < expectedScore.length; i++) {
assertEquals(expectedScore[i], actualScores.get(i), 0.001);
}
}

// KNN script scoring for space_type "l1"
public void testKNNL1ScriptScore() throws Exception {
if (isRunningAgainstOldCluster()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

package org.opensearch.knn.bwc;

import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.client.Response;
import org.opensearch.knn.index.SpaceType;

import java.util.List;

import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;

public class ScriptScoringIT extends AbstractRollingUpgradeTestCase {
Expand Down Expand Up @@ -54,4 +59,65 @@ public void validateKNNL2ScriptScoreOnUpgrade(int totalDocsCount, int docId) thr
validateKNNScriptScoreSearch(testIndex, TEST_FIELD, DIMENSIONS, totalDocsCount, K, SpaceType.L2);
}

// KNN script scoring for space_type "cosine"
public void testKNNCosineScriptScore() throws Exception {
float[] indexVector1 = { 1.1f, 2.1f, 3.3f };
float[] indexVector2 = { 8.1f, 9.1f, 10.3f };
float[] indexVector3 = { 9.1f, 10.1f, 11.3f };
float[] indexVector4 = { 10.1f, 11.1f, 12.3f };
float[] queryVector = { 3.0f, 4.0f, 13.5f };
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
int k = 10;
switch (getClusterType()) {
case OLD:
createKnnIndex(testIndex, createKNNDefaultScriptScoreSettings(), createKnnIndexMapping(TEST_FIELD, 3));
addKnnDoc(testIndex, "1", TEST_FIELD, indexVector1);
validateScore(k, queryVector, new float[] { cosineSimilarity(queryVector, indexVector1) });
break;
case MIXED:
if (isFirstMixedRound()) {
addKnnDoc(testIndex, "2", TEST_FIELD, indexVector2);
validateScore(
k,
queryVector,
new float[] { cosineSimilarity(queryVector, indexVector1), cosineSimilarity(queryVector, indexVector2) }
);
} else {
addKnnDoc(testIndex, "3", TEST_FIELD, indexVector3);
validateScore(
k,
queryVector,
new float[] {
cosineSimilarity(queryVector, indexVector1),
cosineSimilarity(queryVector, indexVector2),
cosineSimilarity(queryVector, indexVector3) }
);
}
break;
case UPGRADED:
addKnnDoc(testIndex, "4", TEST_FIELD, indexVector3);
validateScore(
k,
queryVector,
new float[] {
cosineSimilarity(queryVector, indexVector1),
cosineSimilarity(queryVector, indexVector2),
cosineSimilarity(queryVector, indexVector3),
cosineSimilarity(queryVector, indexVector4) }
);
}
}

private float cosineSimilarity(float[] vectorA, float[] vectorB) {
return 1 + VectorUtil.cosine(vectorA, vectorB);
}

private void validateScore(int k, float[] queryVector, float[] expectedScores) throws Exception {
final Response responseBody = executeKNNScriptScoreRequest(testIndex, TEST_FIELD, k, SpaceType.COSINESIMIL, queryVector);
List<Float> actualScores = parseSearchResponseScore(EntityUtils.toString(responseBody.getEntity()), TEST_FIELD);
assertEquals(expectedScores.length, actualScores.size());
for (int i = 0; i < expectedScores.length; i++) {
assertEquals(expectedScores[i], actualScores.get(i), 0.01f);
}
}
}
25 changes: 14 additions & 11 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,19 @@ protected void validateKNNScriptScoreSearch(String testIndex, String testField,
IDVectorProducer idVectorProducer = new IDVectorProducer(dimension, numDocs);
float[] queryVector = idVectorProducer.getVector(numDocs);

final Response response = executeKNNScriptScoreRequest(testIndex, testField, k, spaceType, queryVector);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField);
assertEquals(k, results.size());
PriorityQueue<DistVector> pq = computeGroundTruthValues(k, spaceType, idVectorProducer);
for (int i = k - 1; i >= 0; i--) {
int expDocID = Integer.parseInt(pq.poll().getDocID());
int actualDocID = Integer.parseInt(results.get(i).getDocId());
assertEquals(expDocID, actualDocID);
}
}

protected Response executeKNNScriptScoreRequest(String testIndex, String testField, int k, SpaceType spaceType, float[] queryVector)
throws Exception {
QueryBuilder qb = new MatchAllQueryBuilder();
Map<String, Object> params = new HashMap<>();
params.put(FIELD, testField);
Expand All @@ -1520,17 +1533,7 @@ protected void validateKNNScriptScoreSearch(String testIndex, String testField,
Request request = constructKNNScriptQueryRequest(testIndex, qb, params, k, Collections.emptyMap());
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField);
assertEquals(k, results.size());

PriorityQueue<DistVector> pq = computeGroundTruthValues(k, spaceType, idVectorProducer);

for (int i = k - 1; i >= 0; i--) {
int expDocID = Integer.parseInt(pq.poll().getDocID());
int actualDocID = Integer.parseInt(results.get(i).getDocId());
assertEquals(expDocID, actualDocID);
}
return response;
}

// validate KNN painless script score search for the space_types : "l2", "l1"
Expand Down

0 comments on commit c2cd8c8

Please sign in to comment.