Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support score type threshold in radial search #1589

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498)
* Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546)
* Support score type threshold in radial search [#1589](https://github.com/opensearch-project/k-NN/pull/1589)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN;
}

@Override
public float scoreToDistanceTranslation(float score) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we adding the score to distance translation for L2 only? not for cosine and dot product?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@navneet1v IMO no need for cosine and dot product here. The SpaceType Enum class is used for common translation for different engines:

  • Cosine's score to distance function doesn't needed because Lucene's radial search api just need score parameter, in k-NN faiss, it doesn't support Cosine's type.
  • Dot product's score to distance function is only need in Faiss engine and in Faiss it has different translation from Lucene, so I just put the function inside Lucene.java where it needed.

if (score == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue()));

Check warning on line 43 in src/main/java/org/opensearch/knn/index/SpaceType.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/SpaceType.java#L43

Added line #L43 was not covered by tests
}
return 1 / score - 1;
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
}
},
COSINESIMIL("cosinesimil") {
@Override
Expand Down Expand Up @@ -170,4 +178,14 @@
}
throw new IllegalArgumentException("Unable to find space: " + spaceTypeName);
}

/**
* Translate a score to a distance for this space type
*
* @param score score to translate
* @return translated distance
*/
public float scoreToDistanceTranslation(float score) {
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
throw new UnsupportedOperationException(String.format("Space [%s] does not have a score to distance translation", getValue()));

Check warning on line 189 in src/main/java/org/opensearch/knn/index/SpaceType.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/SpaceType.java#L189

Added line #L189 was not covered by tests
}
}
90 changes: 76 additions & 14 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField DISTANCE_FIELD = new ParseField("distance");
public static final ParseField SCORE_FIELD = new ParseField("score");
public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -64,6 +65,7 @@
private final float[] vector;
private int k = 0;
private Float distance = null;
private Float score = null;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

Expand Down Expand Up @@ -92,13 +94,14 @@
*
* @param k K nearest neighbours for the given vector
*/
public KNNQueryBuilder k(int k) {
public KNNQueryBuilder k(Integer k) {
if (k == null) {
throw new IllegalArgumentException("[" + NAME + "] requires k to be set");

Check warning on line 99 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L99

Added line #L99 was not covered by tests
}
validateSingleQueryType(k, distance, score);
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX);
}
if (distance != null) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
this.k = k;
return this;
}
Expand All @@ -112,13 +115,28 @@
if (distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be set");
}
if (k != 0) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validateSingleQueryType(k, distance, score);
this.distance = distance;
return this;
}

/**
* Builder method for score
*
* @param score the score threshold for the nearest neighbours
*/
public KNNQueryBuilder score(Float score) {
if (score == null) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be set");
}
validateSingleQueryType(k, distance, score);
if (score <= 0) {
throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0");
}
this.score = score;
return this;
}

/**
* Builder method for filter
*
Expand Down Expand Up @@ -163,6 +181,7 @@
this.filter = filter;
this.ignoreUnmapped = false;
this.distance = null;
this.score = null;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -200,6 +219,9 @@
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
distance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
score = in.readOptionalFloat();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -211,6 +233,7 @@
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float distance = null;
Float score = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
Expand Down Expand Up @@ -241,6 +264,8 @@
queryName = parser.text();
} else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -270,9 +295,7 @@
}
}

if ((k != null && distance != null) || (k == null && distance == null)) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validateSingleQueryType(k, distance, score);

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand All @@ -281,8 +304,10 @@

if (k != null) {
knnQueryBuilder.k(k);
} else {
} else if (distance != null) {
knnQueryBuilder.distance(distance);
} else if (score != null) {
knnQueryBuilder.score(score);
}

return knnQueryBuilder;
Expand All @@ -300,6 +325,9 @@
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(distance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(score);
}
}

/**
Expand All @@ -324,6 +352,10 @@
return this.distance;
}

public float getScore() {
return this.score;
}

public QueryBuilder getFilter() {
return this.filter;
}
Expand Down Expand Up @@ -358,6 +390,9 @@
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (score != null) {
builder.field(SCORE_FIELD.getPreferredName(), score);

Check warning on line 394 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L394

Added line #L394 was not covered by tests
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -397,8 +432,8 @@
spaceType = knnMethodContext.getSpaceType();
}

// Currently, k-NN supports distance type radius search.
// We need transform distance radius to right type of engine required radius.
// Currently, k-NN supports distance and score types radial search
// We need transform distance/score to right type of engine required radius.
Float radius = null;
if (this.distance != null) {
if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
Expand All @@ -407,6 +442,13 @@
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
}

if (this.score != null) {
if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType);
}
radius = knnEngine.scoreToRadialThreshold(this.score, spaceType);
}

if (fieldDimension != vector.length) {
throw new IllegalArgumentException(
String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vector.length, fieldDimension)
Expand Down Expand Up @@ -464,7 +506,7 @@
.build();
return RNNQueryFactory.create(createQueryRequest);
}
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set");

Check warning on line 509 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L509

Added line #L509 was not covered by tests
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
Expand Down Expand Up @@ -499,4 +541,24 @@
public String getWriteableName() {
return NAME;
}

private static void validateSingleQueryType(Integer k, Float distance, Float score) {
int countSetFields = 0;

if (k != null && k != 0) {
countSetFields++;
}
if (distance != null) {
countSetFields++;
}
if (score != null) {
countSetFields++;
}

if (countSetFields != 1) {
throw new IllegalArgumentException(

Check warning on line 559 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L559

Added line #L559 was not covered by tests
"[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score"
);
}
}
}
28 changes: 26 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
* Implements NativeLibrary for the faiss native library
*/
class Faiss extends NativeLibrary {
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
// about the compatibility version the file is created with. In the future, we should refactor this so that it
Expand All @@ -68,6 +69,12 @@ class Faiss extends NativeLibrary {
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
);

// Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation:
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();

// Define encoders supported by faiss
private final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext(
KNNConstants.ENCODER_FLAT,
Expand Down Expand Up @@ -301,7 +308,13 @@ class Faiss extends NativeLibrary {
).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build()
);

final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION);
final static Faiss INSTANCE = new Faiss(
METHODS,
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
);

/**
* Constructor for Faiss
Expand All @@ -315,9 +328,11 @@ private Faiss(
Map<String, KNNMethod> methods,
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
}

@Override
Expand All @@ -326,6 +341,15 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// Faiss engine uses distance as is and need transformation
if (this.scoreTransform.containsKey(spaceType)) {
return this.scoreTransform.get(spaceType).apply(score);
}
return spaceType.scoreToDistanceTranslation(score);
}

/**
* MethodAsMap builder is used to create the map that will be passed to the jni to create the faiss index.
* Faiss's index factory takes an "index description" that it uses to build the index. In this description,
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return knnLibrary.distanceToRadialThreshold(distance, spaceType);
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
return knnLibrary.scoreToRadialThreshold(score, spaceType);
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
return knnLibrary.validateMethod(knnMethodContext);
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ public interface KNNLibrary {
*/
Float distanceToRadialThreshold(Float distance, SpaceType spaceType);

/**
* Translate the score threshold input from end user to the engine's threshold.
*
* @param score score threshold input from end user
* @param spaceType spaceType used to compute the threshold
*
* @return transformed score for the library
*/
Float scoreToRadialThreshold(Float score, SpaceType spaceType);

/**
* Validate the knnMethodContext for the given library. A ValidationException should be thrown if the method is
* deemed invalid.
Expand Down
15 changes: 13 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ public class Lucene extends JVMLibrary {
).addSpaces(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build()
);

// Map that overrides the default distance translations for Lucene, check more details in knn documentation:
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder()
.put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2)
.put(SpaceType.L2, distance -> 1 / (1 + distance))
.put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 1 / (1 - distance) : distance + 1)
.build();

final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS);
Expand Down Expand Up @@ -90,7 +92,16 @@ public float score(float rawScore, SpaceType spaceType) {
@Override
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
// Lucene requires score threshold to be parameterized when calling the radius search.
return this.distanceTransform.get(spaceType).apply(distance);
if (this.distanceTransform.containsKey(spaceType)) {
return this.distanceTransform.get(spaceType).apply(distance);
}
return spaceType.scoreTranslation(distance);
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// Lucene engine uses distance as is and does not need transformation
return score;
}

@Override
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Nmslib.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
return score;

Check warning on line 78 in src/main/java/org/opensearch/knn/index/util/Nmslib.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/util/Nmslib.java#L78

Added line #L78 was not covered by tests
}
}
Loading
Loading