Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr committed Mar 28, 2024
1 parent b110b5f commit 46e9883
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 304 deletions.
6 changes: 3 additions & 3 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down Expand Up @@ -258,7 +258,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down Expand Up @@ -828,7 +828,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed(

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down
9 changes: 4 additions & 5 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import org.apache.commons.lang.math.RandomUtils;
import org.apache.hc.core5.http.io.entity.EntityUtils;
Expand Down Expand Up @@ -307,14 +306,14 @@ public void testIndexReopening() throws Exception {
final float[] searchVector = TEST_QUERY_VECTORS[0];
final int k = 1 + RandomUtils.nextInt(TEST_INDEX_VECTORS.length);

final List<Float[]> knnResultsBeforeIndexClosure = queryResults(searchVector, k);
final List<float[]> knnResultsBeforeIndexClosure = queryResults(searchVector, k);

closeIndex(INDEX_NAME);
openIndex(INDEX_NAME);

ensureGreen(INDEX_NAME);

final List<Float[]> knnResultsAfterIndexClosure = queryResults(searchVector, k);
final List<float[]> knnResultsAfterIndexClosure = queryResults(searchVector, k);

assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray());
}
Expand Down Expand Up @@ -365,15 +364,15 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, queryVector);
float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance);
assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(j), 0.0001);
}
}
}

private List<Float[]> queryResults(final float[] searchVector, final int k) throws Exception {
private List<float[]> queryResults(final float[] searchVector, final int k) throws Exception {
final String responseBody = EntityUtils.toString(
searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity()
);
Expand Down
4 changes: 1 addition & 3 deletions src/test/java/org/opensearch/knn/index/NmslibIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;

Expand Down Expand Up @@ -115,7 +113,7 @@ public void testEndToEnd() throws Exception {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down
5 changes: 2 additions & 3 deletions src/test/java/org/opensearch/knn/index/OpenSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;

Expand Down Expand Up @@ -143,7 +142,7 @@ public void testEndToEnd() throws Exception {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName1);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
knnEngine1.score(1 - KNNScoringUtil.cosinesimil(testData.queries[i], primitiveArray), spaceType1),
actualScores.get(j),
Expand All @@ -159,7 +158,7 @@ public void testEndToEnd() throws Exception {

actualScores = parseSearchResponseScore(responseBody, fieldName2);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
knnEngine2.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType2),
actualScores.get(j),
Expand Down
Loading

0 comments on commit 46e9883

Please sign in to comment.