Skip to content

Commit

Permalink
Implement binary format support
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jul 8, 2024
1 parent 5139b16 commit e62fb76
Show file tree
Hide file tree
Showing 70 changed files with 2,105 additions and 316 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
* Implement binary format support [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
2 changes: 1 addition & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ namespace knn_jni {
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
void Free(jlong indexPointer, jboolean isBinaryIndexJ);

// Free shared index state in memory at shareIndexStatePointerJ
void FreeSharedIndexState(jlong shareIndexStatePointerJ);
Expand Down
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
* Signature: (J)V
* Signature: (JZ)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free
(JNIEnv *, jclass, jlong);
(JNIEnv *, jclass, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
13 changes: 10 additions & 3 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,16 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti
return results;
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
void knn_jni::faiss_wrapper::Free(jlong indexPointer, jboolean isBinaryIndexJ) {
bool isBinaryIndex = static_cast<bool>(isBinaryIndexJ);
if (isBinaryIndex) {
auto *indexWrapper = reinterpret_cast<faiss::IndexBinary*>(indexPointer);
delete indexWrapper;
}
else {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
}
}

void knn_jni::faiss_wrapper::FreeSharedIndexState(jlong shareIndexStatePointerJ) {
Expand Down
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin

}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ, jboolean isBinaryIndexJ)
{
try {
return knn_jni::faiss_wrapper::Free(indexPointerJ);
return knn_jni::faiss_wrapper::Free(indexPointerJ, isBinaryIndexJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
16 changes: 15 additions & 1 deletion jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,21 @@ TEST(FaissFreeTest, BasicAssertions) {
test_util::FaissCreateIndex(dim, method, metricType));

// Free created index --> memory check should catch failure
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex));
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex), JNI_FALSE);
}


TEST(FaissBinaryFreeTest, BasicAssertions) {
// Define the data
int dim = 8;
std::string method = "BHNSW32";

// Create the index
faiss::IndexBinary *createdIndex(
test_util::FaissCreateBinaryIndex(dim, method));

// Free created index --> memory check should catch failure
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex), JNI_TRUE);
}

TEST(FaissInitLibraryTest, BasicAssertions) {
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNFaissUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

import java.util.Map;

import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;

public class KNNFaissUtil {
public boolean isBinaryIndex(Map<String, Object> parameters) {
return parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null
&& parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX);
}
}
19 changes: 13 additions & 6 deletions src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ public static void validateFloatVectorValue(float value) {
*
* @param value float value in byte range
*/
public static void validateByteVectorValue(float value) {
public static void validateByteVectorValue(float value, final VectorDataType dataType) {
validateFloatVectorValue(value);
if (value % 1 != 0) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue()
dataType.getValue()
)

);
Expand All @@ -60,7 +60,7 @@ public static void validateByteVectorValue(float value) {
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue(),
dataType.getValue(),
Byte.MIN_VALUE,
Byte.MAX_VALUE
)
Expand All @@ -73,10 +73,17 @@ public static void validateByteVectorValue(float value) {
*
* @param dimension dimension of vector
* @param vectorSize size of the vector
* @param dataType vector data type
*/
public static void validateVectorDimension(int dimension, int vectorSize) {
if (dimension != vectorSize) {
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
public static void validateVectorDimension(final int dimension, final int vectorSize, final VectorDataType dataType) {
int actualDimension = VectorDataType.BINARY == dataType ? vectorSize * Byte.SIZE : vectorSize;
if (dimension != actualDimension) {
String errorMessage = String.format(
Locale.ROOT,
"Vector dimension mismatch. Expected: %d, Given: %d",
dimension,
actualDimension
);
throw new IllegalArgumentException(errorMessage);
}
}
Expand Down
15 changes: 8 additions & 7 deletions src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.common.ValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -50,19 +52,18 @@ public class KNNMethodContext implements ToXContentFragment, Writeable {

public static synchronized KNNMethodContext getDefault() {
if (defaultInstance == null) {
defaultInstance = new KNNMethodContext(
KNNEngine.DEFAULT,
SpaceType.DEFAULT,
new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
);
MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap());
methodComponentContext.setIndexVersion(Version.CURRENT);
defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext);
}
return defaultInstance;
}

@NonNull
private final KNNEngine knnEngine;
@NonNull
private final SpaceType spaceType;
@Setter
private SpaceType spaceType;
@NonNull
private final MethodComponentContext methodComponentContext;

Expand Down Expand Up @@ -131,7 +132,7 @@ public static KNNMethodContext parse(Object in) {
Map<String, Object> methodMap = (Map<String, Object>) in;

KNNEngine engine = KNNEngine.DEFAULT; // Get or default
SpaceType spaceType = SpaceType.DEFAULT; // Get or default
SpaceType spaceType = SpaceType.UNDEFINED; // Get or default
String name = "";
Map<String, Object> parameters = new HashMap<>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.knn.plugin.script.KNNScoringUtil;

/**
* Wrapper class of VectorSimilarityFunction to support more function than what Lucene provides
*/
public enum KNNVectorSimilarityFunction {
EUCLIDEAN(VectorSimilarityFunction.EUCLIDEAN),
DOT_PRODUCT(VectorSimilarityFunction.DOT_PRODUCT),
COSINE(VectorSimilarityFunction.COSINE),
MAXIMUM_INNER_PRODUCT(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT),
HAMMING(null) {
@Override
public float compare(float[] v1, float[] v2) {
throw new IllegalStateException("Hamming space is not supported with float vectors");
}

@Override
public float compare(byte[] v1, byte[] v2) {
return 1.0f / (1 + KNNScoringUtil.calculateHammingBit(v1, v2));
}

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
}
};

private final VectorSimilarityFunction vectorSimilarityFunction;

KNNVectorSimilarityFunction(final VectorSimilarityFunction vectorSimilarityFunction) {
this.vectorSimilarityFunction = vectorSimilarityFunction;
}

public VectorSimilarityFunction getVectorSimilarityFunction() {
return vectorSimilarityFunction;
}

public float compare(float[] var1, float[] var2) {
return vectorSimilarityFunction.compare(var1, var2);
}

public float compare(byte[] var1, byte[] var2) {
return vectorSimilarityFunction.compare(var1, var2);
}
}
55 changes: 44 additions & 11 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
package org.opensearch.knn.index;

import java.util.Locale;
import org.apache.lucene.index.VectorSimilarityFunction;

import java.util.HashSet;
import java.util.Set;
Expand All @@ -26,15 +25,28 @@
* nmslib calls the inner_product space "negdotprod". This translation should take place in the nmslib's jni layer.
*/
public enum SpaceType {
// This undefined space type is used to indicate that space type is not provided by user
// Later, we need to assign a default value based on data type
UNDEFINED("undefined") {
@Override
public float scoreTranslation(final float rawScore) {
throw new IllegalStateException("This method should not be called with UNDEFINED space type");
}

@Override
public boolean isSupported(VectorDataType vectorDataType) {
throw new IllegalStateException("This method should not be called with UNDEFINED space type");
}
},
L2("l2") {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN;
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.EUCLIDEAN;
}

@Override
Expand All @@ -52,8 +64,8 @@ public float scoreTranslation(float rawScore) {
}

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.COSINE;
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
}

@Override
Expand Down Expand Up @@ -104,18 +116,29 @@ public float scoreTranslation(float rawScore) {
}

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
}
},
HAMMING_BIT("hammingbit") {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

@Override
public boolean isSupported(VectorDataType vectorDataType) {
return true;
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.HAMMING;
}
};

public static SpaceType DEFAULT = L2;
public static SpaceType DEFAULT_BINARY = HAMMING_BIT;

private final String value;

Expand All @@ -126,12 +149,12 @@ public float scoreTranslation(float rawScore) {
public abstract float scoreTranslation(float rawScore);

/**
* Get VectorSimilarityFunction that maps to this SpaceType
* Get KNNVectorSimilarityFunction that maps to this SpaceType
*
* @return VectorSimilarityFunction
* @return KNNVectorSimilarityFunction
*/
public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a vector similarity function", getValue()));
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a knn vector similarity function", getValue()));
}

/**
Expand All @@ -152,6 +175,10 @@ public void validateVector(float[] vector) {
// do nothing
}

public boolean isSupported(VectorDataType vectorDataType) {
return VectorDataType.FLOAT == vectorDataType || VectorDataType.BYTE == vectorDataType;
}

/**
* Get space type name in engine
*
Expand All @@ -172,6 +199,12 @@ public static Set<String> getValues() {

public static SpaceType getSpace(String spaceTypeName) {
for (SpaceType currentSpaceType : SpaceType.values()) {
// UNDEFINED space type is a temporary value to be used only until we set proper default value when
// space type is not provided by user
// Therefore, we do not allow converting space type name to UNDEFINED space type
if (SpaceType.UNDEFINED == currentSpaceType) {
continue;
}
if (currentSpaceType.getValue().equalsIgnoreCase(spaceTypeName)) {
return currentSpaceType;
}
Expand Down
18 changes: 16 additions & 2 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,25 @@
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

/**
* Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin.
* We have two vector data_types, one is float (default) and the other one is byte.
* Enum contains data_type of vectors
* Lucene supports byte and float data type
* NMSLib supports only float data type
* Faiss supports binary and float data type
*/
@AllArgsConstructor
public enum VectorDataType {
BINARY("binary") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
throw new IllegalArgumentException("This is not supported");
}

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
throw new IllegalArgumentException("This is not supported");
}
},
BYTE("byte") {

@Override
Expand Down
Loading

0 comments on commit e62fb76

Please sign in to comment.