Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

FEAT: support cosine similarity #90

Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
99cc16c
MNT: pass in spaceType in jni cpp
chenqi0805 Apr 12, 2020
6375d49
MNT: pass in spaceType in jni java
chenqi0805 Apr 12, 2020
28f33f9
Regenerate header file
chenqi0805 Apr 12, 2020
784e67f
Regenerate jnilib
chenqi0805 Apr 12, 2020
d29cabd
ADD: spacetype into KNN settings
chenqi0805 Apr 12, 2020
bde898c
Update vector field mapper
chenqi0805 Apr 12, 2020
0bdf53e
Pass in spaceType in consumer
chenqi0805 Apr 12, 2020
25cafe9
ADD: spaceType param into constant
chenqi0805 Apr 12, 2020
2920f5b
fix loadIndex in cache
chenqi0805 Apr 12, 2020
a9388a9
FIX: existing test cases
chenqi0805 Apr 12, 2020
8a58bf6
Update jni external library
chenqi0805 Apr 12, 2020
de8271d
FIX: release jstring and char array
chenqi0805 Apr 14, 2020
c49efaf
Rename test file and added a consinesimil test
chenqi0805 Apr 14, 2020
ce2660f
update .jnilib
chenqi0805 Apr 14, 2020
2c10607
Rename file
chenqi0805 Apr 14, 2020
65052c8
Rename file
chenqi0805 Apr 14, 2020
3eb7a47
Compile .so
chenqi0805 Apr 14, 2020
aa26cfb
Compile .so on linux
chenqi0805 Apr 14, 2020
18a8406
TST: IT on invalid spaceType settings
chenqi0805 Apr 15, 2020
5ebb930
Merge branch 'feat/28-cosine-similarity' of github.com:chenqi0805/k-N…
chenqi0805 Apr 15, 2020
635c7b1
Unused import
chenqi0805 Apr 15, 2020
83e4a90
Update README
chenqi0805 Apr 15, 2020
e858038
MNT: null check on spacetype
chenqi0805 Apr 17, 2020
6cbfc79
REF: SpaceTypes
chenqi0805 Apr 17, 2020
66cbbe8
MNT: arg order in jni
chenqi0805 Apr 17, 2020
c852d84
Compile .so
chenqi0805 Apr 17, 2020
0394e93
MNT: update readme on settings and experimental section
chenqi0805 Apr 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified buildSrc/libKNNIndexV1_7_3_6.jnilib
Binary file not shown.
2 changes: 1 addition & 1 deletion jni/external/nmslib
Submodule nmslib updated 329 files
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ using similarity::KNNQueue;
extern "C"

struct IndexWrapper {
IndexWrapper() {
space.reset(SpaceFactoryRegistry<float>::Instance().CreateSpace("l2", AnyParams()));
index.reset(MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", "l2", *space, data));
IndexWrapper(string spaceType) {
space.reset(SpaceFactoryRegistry<float>::Instance().CreateSpace(spaceType, AnyParams()));
index.reset(MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceType, *space, data));
}
std::unique_ptr<Space<float>> space;
std::unique_ptr<Index<float>> index;
Expand Down Expand Up @@ -85,15 +85,17 @@ void catch_cpp_exception_and_throw_java(JNIEnv* env)
}
}

JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_saveIndex(JNIEnv* env, jclass cls, jintArray ids, jobjectArray vectors, jstring indexPath, jobjectArray algoParams)
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_saveIndex(JNIEnv* env, jclass cls, jintArray ids, jobjectArray vectors, jstring indexPath, jstring spaceType, jobjectArray algoParams)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Please add the new function arguments to the end of the existing parameters list. Please take care in other places as well.

{
Space<float>* space = NULL;
ObjectVector dataset;
Index<float>* index = NULL;
int* object_ids = NULL;

try {
space = SpaceFactoryRegistry<float>::Instance().CreateSpace("l2", AnyParams());
const char *spaceTypeCStr = env->GetStringUTFChars(spaceType, 0);
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
string spaceTypeString(spaceTypeCStr);
space = SpaceFactoryRegistry<float>::Instance().CreateSpace(spaceTypeString, AnyParams());
object_ids = env->GetIntArrayElements(ids, 0);
for (int i = 0; i < env->GetArrayLength(vectors); i++) {
jfloatArray vectorArray = (jfloatArray)env->GetObjectArrayElement(vectors, i);
Expand All @@ -103,7 +105,7 @@ JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v173
}
// free up memory
env->ReleaseIntArrayElements(ids, object_ids, 0);
index = MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", "l2", *space, dataset);
index = MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceTypeString, *space, dataset);

int paramsCount = env->GetArrayLength(algoParams);
vector<string> paramsList;
Expand Down Expand Up @@ -171,17 +173,19 @@ JNIEXPORT jobjectArray JNICALL Java_com_amazon_opendistroforelasticsearch_knn_in
return NULL;
}

JNIEXPORT jlong JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_init(JNIEnv* env, jclass cls, jstring indexPath, jobjectArray algoParams)
JNIEXPORT jlong JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_init(JNIEnv* env, jclass cls, jstring indexPath, jstring spaceType, jobjectArray algoParams)
{
IndexWrapper *indexWrapper = NULL;
try {
const char *indexPathCStr = env->GetStringUTFChars(indexPath, 0);
string indexPathString(indexPathCStr);
const char *spaceTypeCStr = env->GetStringUTFChars(spaceType, 0);
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
string spaceTypeString(spaceTypeCStr);
env->ReleaseStringUTFChars(indexPath, indexPathCStr);
has_exception_in_stack(env);

// Load index from file (may throw)
IndexWrapper *indexWrapper = new IndexWrapper();
IndexWrapper *indexWrapper = new IndexWrapper(spaceTypeString);
indexWrapper->index->LoadIndex(indexPathString);

// Parse and set query params
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public KNNIndexCacheEntry loadIndex(String indexPathUrl, String indexName) throw
// the entry
fileWatcher.init();

final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, getQueryParams(indexName));
final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, KNNSettings.getSpaceType(indexName), getQueryParams(indexName));

// TODO verify that this is safe - ideally we'd explicitly ensure that the FileWatcher is only checked
// after the guava cache has finished loading the key to avoid a race condition where the watcher
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.monitor.os.OsProbe;

import java.security.InvalidParameterException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -63,6 +66,7 @@ public class KNNSettings {
/**
* Settings name
*/
public static final String KNN_SPACE_TYPE = "index.knn.space_type";
public static final String KNN_ALGO_PARAM_M = "index.knn.algo_param.m";
public static final String KNN_ALGO_PARAM_EF_CONSTRUCTION = "index.knn.algo_param.ef_construction";
public static final String KNN_ALGO_PARAM_EF_SEARCH = "index.knn.algo_param.ef_search";
Expand All @@ -80,6 +84,11 @@ public class KNNSettings {
* Settings Definition
*/

/**
 * Space type used for k-NN similarity search on an index.
 * Defaults to the l2 space type. Values are checked by SpaceTypeValidator, which rejects
 * any value whose lower-cased form is not one of the SpaceTypes values.
 */
public static final Setting<String> INDEX_KNN_SPACE_TYPE = Setting.simpleString(KNN_SPACE_TYPE,
        // Take the default from the enum so it cannot drift from the other l2 defaults in the plugin.
        SpaceTypes.l2.getValue(),
        new SpaceTypeValidator(),
        IndexScope);

/**
* M - the number of bi-directional links created for every new element during construction.
* Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic
Expand Down Expand Up @@ -252,7 +261,8 @@ public Setting<?> getSetting(String key) {
}

public List<Setting<?>> getSettings() {
List<Setting<?>> settings = Arrays.asList(INDEX_KNN_ALGO_PARAM_M_SETTING,
List<Setting<?>> settings = Arrays.asList(INDEX_KNN_SPACE_TYPE,
INDEX_KNN_ALGO_PARAM_M_SETTING,
INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING,
INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING,
KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING,
Expand Down Expand Up @@ -357,6 +367,16 @@ public static int getEfSearchParam(String index) {
return getIndexSettingValue(index, KNN_ALGO_PARAM_EF_SEARCH, 512);
}

/**
 * Reads the space type configured for an index from the cluster state.
 *
 * The value is lower-cased before it is returned: SpaceTypeValidator accepts any casing
 * (for example "L2"), and the vector field mapper stores the lower-cased form when an
 * index is built, so the load path must hand the same spelling to the native library.
 *
 * @param index Name of the index
 * @return lower-cased spaceType value, or the l2 space type when the setting is absent
 */
public static String getSpaceType(String index) {
    final String configured = KNNSettings.state().clusterService.state().getMetaData()
            .index(index).getSettings().get(KNN_SPACE_TYPE, SpaceTypes.l2.getValue());
    return configured.toLowerCase();
}

public static int getIndexSettingValue(String index, String settingName, int defaultValue) {
return KNNSettings.state().clusterService.state().getMetaData()
.index(index).getSettings()
Expand All @@ -367,4 +387,48 @@ public void setClusterService(ClusterService clusterService) {
this.clusterService = clusterService;
}

static class SpaceTypeValidator implements Setting.Validator<String> {

private Set<String> types = SpaceTypes.getValues();

@Override public void validate(String value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

null check on value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vamshin I put a null check here, but it still fails to guard against a null value: for some reason, when index.knn.space_type is null, the setting never goes through validate(). A similar bug affects other parameters, e.g.

PUT /myindex
{
    "settings" : {
        "index": {
            "knn": true,
            "knn.algo_param.m": null
        }
    }
}
PUT /myindex/_doc/2?refresh=true
{
    "my_vector1" : [1.5, 2.5],
    "price":10
}
POST /myindex/_search
{
    "size" : 10,
    "query": {
        "knn": {
            "my_vector1": {
                "vector": [15, 25],
                "k": 2
            }
        }
    }
}

yields

{
  "took" : 0,
  "timed_out" : false,
  "_shards" : {
    "total" : 1,
    "successful" : 1,
    "skipped" : 0,
    "failed" : 0
  },
  "hits" : {
    "total" : {
      "value" : 0,
      "relation" : "eq"
    },
    "max_score" : null,
    "hits" : [ ]
  }
}

How about we open a separate issue to track this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it. I think we can probably ignore this case.

if (!types.contains(value.toLowerCase())){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this case insensitive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intended it to be case-insensitive, so that e.g. L2 and Cosinesimil are also valid strings. But I can remove this if it is not necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think case insensitive is good

throw new InvalidParameterException(String.format("Unsupported space type: %s", value));
}
}
}

/**
* Enum contains space types for k-NN similarity search
*/
public enum SpaceTypes {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Move this enum class to a dedicated file

l2("l2"),
cosinesimil("cosinesimil");

private String value;

SpaceTypes(String value) { this.value = value; }

/**
* Get the string value of this space type, as passed to the enum constructor
* (for example "l2" or "cosinesimil").
*
* @return string value of this space type
*/
public String getValue() { return value; }

/**
 * Collect the string values of every declared space type.
 *
 * @return a new mutable set holding the value of each SpaceTypes constant
 */
public static Set<String> getValues() {
    return Arrays.stream(SpaceTypes.values())
            .map(SpaceTypes::getValue)
            .collect(Collectors.toCollection(HashSet::new));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ public Builder ignoreMalformed(boolean ignoreMalformed) {
return builder;
}

/**
 * Stores the space type as an attribute on Defaults.FIELD_TYPE.
 *
 * @param key attribute name under which the space type is stored
 * @param paramValue space type; it is lower-cased before being stored
 * @return the builder
 */
public Builder spaceTypeParam(String key, String paramValue) {
    final String normalizedSpaceType = paramValue.toLowerCase();
    Defaults.FIELD_TYPE.putAttribute(key, normalizedSpaceType);
    return builder;
}

public Builder algoParams(String key, int paramValue) {
Defaults.FIELD_TYPE.putAttribute(key, String.valueOf(paramValue));
return builder;
Expand Down Expand Up @@ -135,6 +140,8 @@ public static class TypeParser implements Mapper.TypeParser {
public Mapper.Builder<?, ?> parse(String name, Map<String, Object> node, ParserContext parserContext)
throws MapperParsingException {
Builder builder = new KNNVectorFieldMapper.Builder(name);
builder.spaceTypeParam(KNNConstants.SPACE_TYPE, parserContext.mapperService().getIndexSettings().getValue(
KNNSettings.INDEX_KNN_SPACE_TYPE));
builder.algoParams(KNNConstants.HNSW_ALGO_M, parserContext.mapperService().getIndexSettings().getValue(
KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING));
builder.algoParams(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, parserContext.mapperService().getIndexSettings()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)

// Pass the path for the nms library to save the file
String tempIndexPath = indexPath + TEMP_SUFFIX;
String[] algoParams = getKNNIndexParams(field.attributes());
Map<String, String> fieldAttributes = field.attributes();
String spaceType = fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, KNNSettings.SpaceTypes.l2.getValue());
String[] algoParams = getKNNIndexParams(fieldAttributes);
AccessController.doPrivileged(
new PrivilegedAction<Void>() {
public Void run() {
KNNIndex.saveIndex(pair.docs, pair.vectors, tempIndexPath, algoParams);
KNNIndex.saveIndex(pair.docs, pair.vectors, tempIndexPath, spaceType, algoParams);
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.amazon.opendistroforelasticsearch.knn.index.util;

public class KNNConstants {
public static final String SPACE_TYPE = "spaceType";
public static final String HNSW_ALGO_M = "M";
public static final String HNSW_ALGO_EF_CONSTRUCTION = "efConstruction";
public static final String HNSW_ALGO_EF_SEARCH = "efSearch";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ public void close() {
* Loads the knn index to memory for querying the neighbours
*
* @param indexPath path where the hnsw index is stored
* @param spaceType space type of the index
* @param algoParams hnsw algorithm parameters
* @return knn index that can be queried for k nearest neighbours
*/
public static KNNIndex loadIndex(String indexPath, final String[] algoParams) {
public static KNNIndex loadIndex(String indexPath, String spaceType, final String[] algoParams) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: final String spaceType?

long fileSize = computeFileSize(indexPath);
long indexPointer = init(indexPath, algoParams);
long indexPointer = init(indexPath, spaceType, algoParams);
return new KNNIndex(indexPointer, fileSize);
}

Expand All @@ -137,13 +138,13 @@ private static long computeFileSize(String indexPath) {
}

// Builds index and writes to disk (no index pointer escapes).
public static native void saveIndex(int[] ids, float[][] data, String indexPath, String[] algoParams);
public static native void saveIndex(int[] ids, float[][] data, String indexPath, String spaceType, String[] algoParams);

// Queries index (thread safe with other readers, blocked by write lock)
private static native KNNQueryResult[] queryIndex(long indexPointer, float[] query, int k);

// Loads index and returns pointer to index
private static native long init(String indexPath, String[] algoParams);
private static native long init(String indexPath, String spaceType, String[] algoParams);

// Deletes memory pointed to by index pointer (needs write lock)
private static native void gc(long indexPointer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
import java.util.Map;
import java.util.stream.Collectors;

public class KNNJNITests extends ESTestCase {
private static final Logger logger = LogManager.getLogger(KNNJNITests.class);
public class KNNJNITestsIT extends ESTestCase {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why switch this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. When I ran the tests locally in IntelliJ, it failed to detect the integration tests therein because the class name lacked the "IT" suffix, so I renamed the class for easier debugging. I will change it back, but I think it might be worth double-checking whether CI also skips these tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests as opposed to IT signals that it is a unit test. These tests run during ./gradlew build, which is what we use for the PR testing workflow. You can also run them individually with this command:

./gradlew ':test' --tests "com.amazon.opendistroforelasticsearch.knn.index.KNNJNITests.testCreateHnswIndex"

We consider these unit tests because they are testing an isolated portion of code. Interesting that IntelliJ does not pick them up. Will need to look into it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. Sorry, I misidentified it as IT.

private static final Logger logger = LogManager.getLogger(KNNJNITestsIT.class);

public void testCreateHnswIndex() throws Exception {
int[] docs = {0, 1, 2};
Expand All @@ -52,7 +52,7 @@ public void testCreateHnswIndex() throws Exception {
AccessController.doPrivileged(
new PrivilegedAction<Void>() {
public Void run() {
KNNIndex.saveIndex(docs, vectors, indexPath, algoParams);
KNNIndex.saveIndex(docs, vectors, indexPath, "l2", algoParams);
return null;
}
}
Expand Down Expand Up @@ -80,7 +80,7 @@ public void testQueryHnswIndex() throws Exception {
AccessController.doPrivileged(
new PrivilegedAction<Void>() {
public Void run() {
KNNIndex.saveIndex(docs, vectors, indexPath, algoParams);
KNNIndex.saveIndex(docs, vectors, indexPath, "l2", algoParams);
return null;
}
}
Expand All @@ -91,7 +91,7 @@ public Void run() {
float[] queryVector = {1.0f, 1.0f, 1.0f, 1.0f};
String[] algoQueryParams = {"efSearch=20"};

final KNNIndex knnIndex = KNNIndex.loadIndex(indexPath, algoQueryParams);
final KNNIndex knnIndex = KNNIndex.loadIndex(indexPath, "l2", algoQueryParams);
final KNNQueryResult[] results = knnIndex.queryIndex(queryVector, 30);

Map<Integer, Float> scores = Arrays.stream(results).collect(
Expand Down Expand Up @@ -126,7 +126,7 @@ public void testAssertExceptionFromJni() throws Exception {
AccessController.doPrivileged(
new PrivilegedAction<Void>() {
public Void run() {
KNNIndex index = KNNIndex.loadIndex(indexPath.toString(), new String[] {});
KNNIndex index = KNNIndex.loadIndex(indexPath.toString(), "l2", new String[] {});
return null;
}
}
Expand All @@ -153,7 +153,7 @@ public void testQueryHnswIndexWithValidAlgoParams() throws Exception {
AccessController.doPrivileged(
new PrivilegedAction<Void>() {
public Void run() {
KNNIndex.saveIndex(docs, vectors, indexPath, algoIndexParams);
KNNIndex.saveIndex(docs, vectors, indexPath, "l2", algoIndexParams);
return null;
}
}
Expand All @@ -165,7 +165,7 @@ public Void run() {
float[] queryVector = {1.0f, 1.0f, 1.0f, 1.0f};
String[] algoQueryParams = {"efSearch=200"};

final KNNIndex index = KNNIndex.loadIndex(indexPath, algoQueryParams);
final KNNIndex index = KNNIndex.loadIndex(indexPath, "l2", algoQueryParams);
final KNNQueryResult[] results = index.queryIndex(queryVector, 30);

Map<Integer, Float> scores = Arrays.stream(results).collect(
Expand Down