Skip to content

Commit

Permalink
Support score type threshold in radial search (#1589)
Browse files Browse the repository at this point in the history
* Support score type threshold in radial search

Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 11, 2024
1 parent f6a58a7 commit 5e9fc7d
Show file tree
Hide file tree
Showing 14 changed files with 675 additions and 87 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498)
* Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546)
* Support score type threshold in radial search [#1589](https://github.com/opensearch-project/k-NN/pull/1589)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ public float scoreTranslation(float rawScore) {
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN;
}

@Override
public float scoreToDistanceTranslation(float score) {
if (score == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue()));
}
return 1 / score - 1;
}
},
COSINESIMIL("cosinesimil") {
@Override
Expand Down Expand Up @@ -170,4 +178,14 @@ public static SpaceType getSpace(String spaceTypeName) {
}
throw new IllegalArgumentException("Unable to find space: " + spaceTypeName);
}

/**
* Translate a score to a distance for this space type
*
* @param score score to translate
* @return translated distance
*/
public float scoreToDistanceTranslation(float score) {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a score to distance translation", getValue()));
}
}
90 changes: 76 additions & 14 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField DISTANCE_FIELD = new ParseField("distance");
public static final ParseField SCORE_FIELD = new ParseField("score");
public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -64,6 +65,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final float[] vector;
private int k = 0;
private Float distance = null;
private Float score = null;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

Expand Down Expand Up @@ -92,13 +94,14 @@ public KNNQueryBuilder(String fieldName, float[] vector) {
*
* @param k K nearest neighbours for the given vector
*/
public KNNQueryBuilder k(int k) {
public KNNQueryBuilder k(Integer k) {
if (k == null) {
throw new IllegalArgumentException("[" + NAME + "] requires k to be set");
}
validateSingleQueryType(k, distance, score);
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX);
}
if (distance != null) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
this.k = k;
return this;
}
Expand All @@ -112,13 +115,28 @@ public KNNQueryBuilder distance(Float distance) {
if (distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be set");
}
if (k != 0) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validateSingleQueryType(k, distance, score);
this.distance = distance;
return this;
}

/**
* Builder method for score
*
* @param score the score threshold for the nearest neighbours
*/
public KNNQueryBuilder score(Float score) {
if (score == null) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be set");
}
validateSingleQueryType(k, distance, score);
if (score <= 0) {
throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0");
}
this.score = score;
return this;
}

/**
* Builder method for filter
*
Expand Down Expand Up @@ -163,6 +181,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.filter = filter;
this.ignoreUnmapped = false;
this.distance = null;
this.score = null;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -200,6 +219,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
distance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
score = in.readOptionalFloat();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -211,6 +233,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float distance = null;
Float score = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
Expand Down Expand Up @@ -241,6 +264,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
queryName = parser.text();
} else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -270,9 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

if ((k != null && distance != null) || (k == null && distance == null)) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validateSingleQueryType(k, distance, score);

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand All @@ -281,8 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep

if (k != null) {
knnQueryBuilder.k(k);
} else {
} else if (distance != null) {
knnQueryBuilder.distance(distance);
} else if (score != null) {
knnQueryBuilder.score(score);
}

return knnQueryBuilder;
Expand All @@ -300,6 +325,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(distance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(score);
}
}

/**
Expand All @@ -324,6 +352,10 @@ public float getDistance() {
return this.distance;
}

public float getScore() {
return this.score;
}

public QueryBuilder getFilter() {
return this.filter;
}
Expand Down Expand Up @@ -358,6 +390,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (score != null) {
builder.field(SCORE_FIELD.getPreferredName(), score);
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -397,8 +432,8 @@ protected Query doToQuery(QueryShardContext context) {
spaceType = knnMethodContext.getSpaceType();
}

// Currently, k-NN supports distance type radius search.
// We need transform distance radius to right type of engine required radius.
// Currently, k-NN supports distance and score types radial search
// We need transform distance/score to right type of engine required radius.
Float radius = null;
if (this.distance != null) {
if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
Expand All @@ -407,6 +442,13 @@ protected Query doToQuery(QueryShardContext context) {
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
}

if (this.score != null) {
if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType);
}
radius = knnEngine.scoreToRadialThreshold(this.score, spaceType);
}

if (fieldDimension != vector.length) {
throw new IllegalArgumentException(
String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vector.length, fieldDimension)
Expand Down Expand Up @@ -464,7 +506,7 @@ protected Query doToQuery(QueryShardContext context) {
.build();
return RNNQueryFactory.create(createQueryRequest);
}
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set");
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
Expand Down Expand Up @@ -499,4 +541,24 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

private static void validateSingleQueryType(Integer k, Float distance, Float score) {
int countSetFields = 0;

if (k != null && k != 0) {
countSetFields++;
}
if (distance != null) {
countSetFields++;
}
if (score != null) {
countSetFields++;
}

if (countSetFields != 1) {
throw new IllegalArgumentException(
"[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score"
);
}
}
}
28 changes: 26 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
* Implements NativeLibrary for the faiss native library
*/
class Faiss extends NativeLibrary {
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
// about the compatibility version the file is created with. In the future, we should refactor this so that it
Expand All @@ -68,6 +69,12 @@ class Faiss extends NativeLibrary {
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
);

// Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation:
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();

// Define encoders supported by faiss
private final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext(
KNNConstants.ENCODER_FLAT,
Expand Down Expand Up @@ -301,7 +308,13 @@ class Faiss extends NativeLibrary {
).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build()
);

final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION);
final static Faiss INSTANCE = new Faiss(
METHODS,
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
);

/**
* Constructor for Faiss
Expand All @@ -315,9 +328,11 @@ private Faiss(
Map<String, KNNMethod> methods,
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
}

@Override
Expand All @@ -326,6 +341,15 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// Faiss engine uses distance as is and need transformation
if (this.scoreTransform.containsKey(spaceType)) {
return this.scoreTransform.get(spaceType).apply(score);
}
return spaceType.scoreToDistanceTranslation(score);
}

/**
* MethodAsMap builder is used to create the map that will be passed to the jni to create the faiss index.
* Faiss's index factory takes an "index description" that it uses to build the index. In this description,
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return knnLibrary.distanceToRadialThreshold(distance, spaceType);
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
return knnLibrary.scoreToRadialThreshold(score, spaceType);
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
return knnLibrary.validateMethod(knnMethodContext);
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ public interface KNNLibrary {
*/
Float distanceToRadialThreshold(Float distance, SpaceType spaceType);

/**
* Translate the score threshold input from end user to the engine's threshold.
*
* @param score score threshold input from end user
* @param spaceType spaceType used to compute the threshold
*
* @return transformed score for the library
*/
Float scoreToRadialThreshold(Float score, SpaceType spaceType);

/**
* Validate the knnMethodContext for the given library. A ValidationException should be thrown if the method is
* deemed invalid.
Expand Down
15 changes: 13 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ public class Lucene extends JVMLibrary {
).addSpaces(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build()
);

// Map that overrides the default distance translations for Lucene, check more details in knn documentation:
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder()
.put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2)
.put(SpaceType.L2, distance -> 1 / (1 + distance))
.put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 1 / (1 - distance) : distance + 1)
.build();

final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS);
Expand Down Expand Up @@ -90,7 +92,16 @@ public float score(float rawScore, SpaceType spaceType) {
@Override
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
// Lucene requires score threshold to be parameterized when calling the radius search.
return this.distanceTransform.get(spaceType).apply(distance);
if (this.distanceTransform.containsKey(spaceType)) {
return this.distanceTransform.get(spaceType).apply(distance);
}
return spaceType.scoreTranslation(distance);
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// Lucene engine uses distance as is and does not need transformation
return score;
}

@Override
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Nmslib.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ private Nmslib(
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
return score;
}
}
Loading

0 comments on commit 5e9fc7d

Please sign in to comment.