Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

FEAT: support cosine similarity #90

Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
99cc16c
MNT: pass in spaceType in jni cpp
chenqi0805 Apr 12, 2020
6375d49
MNT: pass in spaceType in jni java
chenqi0805 Apr 12, 2020
28f33f9
Regenerate header file
chenqi0805 Apr 12, 2020
784e67f
Regenerate jnilib
chenqi0805 Apr 12, 2020
d29cabd
ADD: spacetype into KNN settings
chenqi0805 Apr 12, 2020
bde898c
Update vector field mapper
chenqi0805 Apr 12, 2020
0bdf53e
Pass in spaceType in consumer
chenqi0805 Apr 12, 2020
25cafe9
ADD: spaceType param into constant
chenqi0805 Apr 12, 2020
2920f5b
fix loadIndex in cache
chenqi0805 Apr 12, 2020
a9388a9
FIX: existing test cases
chenqi0805 Apr 12, 2020
8a58bf6
Update jni external library
chenqi0805 Apr 12, 2020
de8271d
FIX: release jstring and char array
chenqi0805 Apr 14, 2020
c49efaf
Rename test file and added a consinesimil test
chenqi0805 Apr 14, 2020
ce2660f
update .jnilib
chenqi0805 Apr 14, 2020
2c10607
Rename file
chenqi0805 Apr 14, 2020
65052c8
Rename file
chenqi0805 Apr 14, 2020
3eb7a47
Compile .so
chenqi0805 Apr 14, 2020
aa26cfb
Compile .so on linux
chenqi0805 Apr 14, 2020
18a8406
TST: IT on invalid spaceType settings
chenqi0805 Apr 15, 2020
5ebb930
Merge branch 'feat/28-cosine-similarity' of github.com:chenqi0805/k-N…
chenqi0805 Apr 15, 2020
635c7b1
Unused import
chenqi0805 Apr 15, 2020
83e4a90
Update README
chenqi0805 Apr 15, 2020
e858038
MNT: null check on spacetype
chenqi0805 Apr 17, 2020
6cbfc79
REF: SpaceTypes
chenqi0805 Apr 17, 2020
66cbbe8
MNT: arg order in jni
chenqi0805 Apr 17, 2020
c852d84
Compile .so
chenqi0805 Apr 17, 2020
0394e93
MNT: update readme on settings and experimental section
chenqi0805 Apr 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ You must provide index-level settings when you create the index. If you don't pr
##### index.knn
This setting indicates whether the index uses the KNN Codec or not. Possible values are *true*, *false*. Default value is *false*.

##### index.knn.space_type
This setting indicates the similarity metrics between vectors. Supported values are *l2*, *cosinesimil*. *l2* refers to euclidean distance metric; *cosinesimil* refers to cosine similarity. Default value is *l2*.

##### index.knn.algo_param.m
This setting is an HNSW parameter that represents "the number of bi-directional links created for every new element during construction. Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic dimensionality and/or high recall, while low M work better for datasets with low intrinsic dimensionality and/or low recalls. The parameter also determines the algorithm's memory consumption, which is roughly M * 8-10 bytes per stored element." [nmslib/hnswlib](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) The default value is *16*.

Expand Down
Binary file modified buildSrc/libKNNIndexV1_7_3_6.jnilib
Binary file not shown.
Binary file modified buildSrc/libKNNIndexV1_7_3_6.so
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ using similarity::KNNQueue;
extern "C"

struct IndexWrapper {
IndexWrapper() {
space.reset(SpaceFactoryRegistry<float>::Instance().CreateSpace("l2", AnyParams()));
index.reset(MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", "l2", *space, data));
IndexWrapper(string spaceType) {
space.reset(SpaceFactoryRegistry<float>::Instance().CreateSpace(spaceType, AnyParams()));
index.reset(MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceType, *space, data));
}
std::unique_ptr<Space<float>> space;
std::unique_ptr<Index<float>> index;
Expand Down Expand Up @@ -85,15 +85,19 @@ void catch_cpp_exception_and_throw_java(JNIEnv* env)
}
}

JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_saveIndex(JNIEnv* env, jclass cls, jintArray ids, jobjectArray vectors, jstring indexPath, jobjectArray algoParams)
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_saveIndex(JNIEnv* env, jclass cls, jintArray ids, jobjectArray vectors, jstring indexPath, jstring spaceType, jobjectArray algoParams)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Please add the new function arguments to the end of the existing parameters list. Please take care in other places as well.

{
Space<float>* space = NULL;
ObjectVector dataset;
Index<float>* index = NULL;
int* object_ids = NULL;

try {
space = SpaceFactoryRegistry<float>::Instance().CreateSpace("l2", AnyParams());
const char *spaceTypeCStr = env->GetStringUTFChars(spaceType, 0);
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
string spaceTypeString(spaceTypeCStr);
env->ReleaseStringUTFChars(spaceType, spaceTypeCStr);
has_exception_in_stack(env);
space = SpaceFactoryRegistry<float>::Instance().CreateSpace(spaceTypeString, AnyParams());
object_ids = env->GetIntArrayElements(ids, 0);
for (int i = 0; i < env->GetArrayLength(vectors); i++) {
jfloatArray vectorArray = (jfloatArray)env->GetObjectArrayElement(vectors, i);
Expand All @@ -103,7 +107,7 @@ JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v173
}
// free up memory
env->ReleaseIntArrayElements(ids, object_ids, 0);
index = MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", "l2", *space, dataset);
index = MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceTypeString, *space, dataset);

int paramsCount = env->GetArrayLength(algoParams);
vector<string> paramsList;
Expand Down Expand Up @@ -171,7 +175,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_amazon_opendistroforelasticsearch_knn_in
return NULL;
}

JNIEXPORT jlong JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_init(JNIEnv* env, jclass cls, jstring indexPath, jobjectArray algoParams)
JNIEXPORT jlong JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v1736_KNNIndex_init(JNIEnv* env, jclass cls, jstring indexPath, jstring spaceType, jobjectArray algoParams)
{
IndexWrapper *indexWrapper = NULL;
try {
Expand All @@ -181,7 +185,11 @@ JNIEXPORT jlong JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_v17
has_exception_in_stack(env);

// Load index from file (may throw)
IndexWrapper *indexWrapper = new IndexWrapper();
const char *spaceTypeCStr = env->GetStringUTFChars(spaceType, 0);
string spaceTypeString(spaceTypeCStr);
env->ReleaseStringUTFChars(spaceType, spaceTypeCStr);
has_exception_in_stack(env);
IndexWrapper *indexWrapper = new IndexWrapper(spaceTypeString);
indexWrapper->index->LoadIndex(indexPathString);

// Parse and set query params
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public KNNIndexCacheEntry loadIndex(String indexPathUrl, String indexName) throw
// the entry
fileWatcher.init();

final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, getQueryParams(indexName));
final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, KNNSettings.getSpaceType(indexName), getQueryParams(indexName));

// TODO verify that this is safe - ideally we'd explicitly ensure that the FileWatcher is only checked
// after the guava cache has finished loading the key to avoid a race condition where the watcher
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.monitor.os.OsProbe;

import java.security.InvalidParameterException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -63,6 +66,7 @@ public class KNNSettings {
/**
* Settings name
*/
public static final String KNN_SPACE_TYPE = "index.knn.space_type";
public static final String KNN_ALGO_PARAM_M = "index.knn.algo_param.m";
public static final String KNN_ALGO_PARAM_EF_CONSTRUCTION = "index.knn.algo_param.ef_construction";
public static final String KNN_ALGO_PARAM_EF_SEARCH = "index.knn.algo_param.ef_search";
Expand All @@ -80,6 +84,11 @@ public class KNNSettings {
* Settings Definition
*/

public static final Setting<String> INDEX_KNN_SPACE_TYPE = Setting.simpleString(KNN_SPACE_TYPE,
"l2",
new SpaceTypeValidator(),
IndexScope);

/**
* M - the number of bi-directional links created for every new element during construction.
* Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic
Expand Down Expand Up @@ -252,7 +261,8 @@ public Setting<?> getSetting(String key) {
}

public List<Setting<?>> getSettings() {
List<Setting<?>> settings = Arrays.asList(INDEX_KNN_ALGO_PARAM_M_SETTING,
List<Setting<?>> settings = Arrays.asList(INDEX_KNN_SPACE_TYPE,
INDEX_KNN_ALGO_PARAM_M_SETTING,
INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING,
INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING,
KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING,
Expand Down Expand Up @@ -357,6 +367,16 @@ public static int getEfSearchParam(String index) {
return getIndexSettingValue(index, KNN_ALGO_PARAM_EF_SEARCH, 512);
}

/**
*
* @param index Name of the index
* @return spaceType value
*/
public static String getSpaceType(String index) {
return KNNSettings.state().clusterService.state().getMetaData()
.index(index).getSettings().get(KNN_SPACE_TYPE, SpaceTypes.l2.getValue());
}

public static int getIndexSettingValue(String index, String settingName, int defaultValue) {
return KNNSettings.state().clusterService.state().getMetaData()
.index(index).getSettings()
Expand All @@ -367,4 +387,48 @@ public void setClusterService(ClusterService clusterService) {
this.clusterService = clusterService;
}

static class SpaceTypeValidator implements Setting.Validator<String> {

private Set<String> types = SpaceTypes.getValues();

@Override public void validate(String value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

null check on value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vamshin I put a null check here, but it still fails to guard against null value, i.e. for some reason, if index.knn.space_type == null, it will not go through the validate. Similar bug happens to other parameters, e.g.

PUT /myindex
{
    "settings" : {
        "index": {
            "knn": true,
            "knn.algo_param.m": null
        }
    }
}
PUT /myindex/_doc/2?refresh=true
{
    "my_vector1" : [1.5, 2.5],
    "price":10
}
POST /myindex/_search
{
    "size" : 10,
    "query": {
        "knn": {
            "my_vector1": {
                "vector": [15, 25],
                "k": 2
            }
        }
    }
}

yields

{
  "took" : 0,
  "timed_out" : false,
  "_shards" : {
    "total" : 1,
    "successful" : 1,
    "skipped" : 0,
    "failed" : 0
  },
  "hits" : {
    "total" : {
      "value" : 0,
      "relation" : "eq"
    },
    "max_score" : null,
    "hits" : [ ]
  }
}

How about we open a separate issue to track this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it. I think we can probably ignore this case.

if (!types.contains(value.toLowerCase())){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this case insensitive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intended to set it case insensitive, e.g. L2, Cosinesimil is also valid string. But I can remove this if not necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think case insensitive is good

throw new InvalidParameterException(String.format("Unsupported space type: %s", value));
}
}
}

/**
* Enum contains space types for k-NN similarity search
*/
public enum SpaceTypes {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Move this enum class to a dedicated file

l2("l2"),
cosinesimil("cosinesimil");

private String value;

SpaceTypes(String value) { this.value = value; }

/**
* Get space type
*
* @return name
*/
public String getValue() { return value; }

/**
* Get all space types
*
* @return set of all stat names
*/
public static Set<String> getValues() {
Set<String> values = new HashSet<>();

for (SpaceTypes spaceType : SpaceTypes.values()) {
values.add(spaceType.getValue());
}
return values;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ public Builder ignoreMalformed(boolean ignoreMalformed) {
return builder;
}

public Builder spaceTypeParam(String key, String paramValue) {
Defaults.FIELD_TYPE.putAttribute(key, paramValue.toLowerCase());
return builder;
}

public Builder algoParams(String key, int paramValue) {
Defaults.FIELD_TYPE.putAttribute(key, String.valueOf(paramValue));
return builder;
Expand Down Expand Up @@ -135,6 +140,8 @@ public static class TypeParser implements Mapper.TypeParser {
public Mapper.Builder<?, ?> parse(String name, Map<String, Object> node, ParserContext parserContext)
throws MapperParsingException {
Builder builder = new KNNVectorFieldMapper.Builder(name);
builder.spaceTypeParam(KNNConstants.SPACE_TYPE, parserContext.mapperService().getIndexSettings().getValue(
KNNSettings.INDEX_KNN_SPACE_TYPE));
builder.algoParams(KNNConstants.HNSW_ALGO_M, parserContext.mapperService().getIndexSettings().getValue(
KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING));
builder.algoParams(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, parserContext.mapperService().getIndexSettings()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)

// Pass the path for the nms library to save the file
String tempIndexPath = indexPath + TEMP_SUFFIX;
String[] algoParams = getKNNIndexParams(field.attributes());
Map<String, String> fieldAttributes = field.attributes();
String spaceType = fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, KNNSettings.SpaceTypes.l2.getValue());
String[] algoParams = getKNNIndexParams(fieldAttributes);
AccessController.doPrivileged(
new PrivilegedAction<Void>() {
public Void run() {
KNNIndex.saveIndex(pair.docs, pair.vectors, tempIndexPath, algoParams);
KNNIndex.saveIndex(pair.docs, pair.vectors, tempIndexPath, spaceType, algoParams);
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.amazon.opendistroforelasticsearch.knn.index.util;

public class KNNConstants {
public static final String SPACE_TYPE = "spaceType";
public static final String HNSW_ALGO_M = "M";
public static final String HNSW_ALGO_EF_CONSTRUCTION = "efConstruction";
public static final String HNSW_ALGO_EF_SEARCH = "efSearch";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ public void close() {
* Loads the knn index to memory for querying the neighbours
*
* @param indexPath path where the hnsw index is stored
* @param spaceType space type of the index
* @param algoParams hnsw algorithm parameters
* @return knn index that can be queried for k nearest neighbours
*/
public static KNNIndex loadIndex(String indexPath, final String[] algoParams) {
public static KNNIndex loadIndex(String indexPath, String spaceType, final String[] algoParams) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: final String spaceType?

long fileSize = computeFileSize(indexPath);
long indexPointer = init(indexPath, algoParams);
long indexPointer = init(indexPath, spaceType, algoParams);
return new KNNIndex(indexPointer, fileSize);
}

Expand All @@ -137,13 +138,13 @@ private static long computeFileSize(String indexPath) {
}

// Builds index and writes to disk (no index pointer escapes).
public static native void saveIndex(int[] ids, float[][] data, String indexPath, String[] algoParams);
public static native void saveIndex(int[] ids, float[][] data, String indexPath, String spaceType, String[] algoParams);

// Queries index (thread safe with other readers, blocked by write lock)
private static native KNNQueryResult[] queryIndex(long indexPointer, float[] query, int k);

// Loads index and returns pointer to index
private static native long init(String indexPath, String[] algoParams);
private static native long init(String indexPath, String spaceType, String[] algoParams);

// Deletes memory pointed to by index pointer (needs write lock)
private static native void gc(long indexPointer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.rest.RestStatus;

import java.io.IOException;

import static org.hamcrest.Matchers.containsString;

public class KNNESSettingsTestIT extends KNNRestTestCase {
Expand Down Expand Up @@ -70,5 +73,18 @@ public void testQueriesPluginDisabled() throws Exception {
updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, true);
searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1);
}

public void testCreateIndexWithInvalidSpaceType() throws IOException {
String invalidSpaceType = "bar";
Settings invalidSettings = Settings.builder()
.put("number_of_shards", 1)
.put("number_of_replicas", 0)
.put("index.knn", true)
.put("index.knn.space_type", invalidSpaceType)
.build();
Exception ex = expectThrows(ResponseException.class,
() -> createKnnIndex(INDEX_NAME, invalidSettings, createKnnIndexMapping(FIELD_NAME, 2)));
assertThat(ex.getMessage(), containsString(String.format("Unsupported space type: %s", invalidSpaceType)));
}
}

Loading