-
Notifications
You must be signed in to change notification settings - Fork 56
FEAT: support cosine similarity #90
Changes from 11 commits
99cc16c
6375d49
28f33f9
784e67f
d29cabd
bde898c
0bdf53e
25cafe9
2920f5b
a9388a9
8a58bf6
de8271d
c49efaf
ce2660f
2c10607
65052c8
3eb7a47
aa26cfb
18a8406
5ebb930
635c7b1
83e4a90
e858038
6cbfc79
66cbbe8
c852d84
0394e93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,11 +31,14 @@ | |
import org.elasticsearch.monitor.jvm.JvmInfo; | ||
import org.elasticsearch.monitor.os.OsProbe; | ||
|
||
import java.security.InvalidParameterException; | ||
import java.util.Arrays; | ||
import java.util.HashMap; | ||
import java.util.HashSet; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.Set; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.Stream; | ||
|
@@ -63,6 +66,7 @@ public class KNNSettings { | |
/** | ||
* Settings name | ||
*/ | ||
public static final String KNN_SPACE_TYPE = "index.knn.space_type"; | ||
public static final String KNN_ALGO_PARAM_M = "index.knn.algo_param.m"; | ||
public static final String KNN_ALGO_PARAM_EF_CONSTRUCTION = "index.knn.algo_param.ef_construction"; | ||
public static final String KNN_ALGO_PARAM_EF_SEARCH = "index.knn.algo_param.ef_search"; | ||
|
@@ -80,6 +84,11 @@ public class KNNSettings { | |
* Settings Definition | ||
*/ | ||
|
||
public static final Setting<String> INDEX_KNN_SPACE_TYPE = Setting.simpleString(KNN_SPACE_TYPE, | ||
"l2", | ||
new SpaceTypeValidator(), | ||
IndexScope); | ||
|
||
/** | ||
* M - the number of bi-directional links created for every new element during construction. | ||
* Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic | ||
|
@@ -252,7 +261,8 @@ public Setting<?> getSetting(String key) { | |
} | ||
|
||
public List<Setting<?>> getSettings() { | ||
List<Setting<?>> settings = Arrays.asList(INDEX_KNN_ALGO_PARAM_M_SETTING, | ||
List<Setting<?>> settings = Arrays.asList(INDEX_KNN_SPACE_TYPE, | ||
INDEX_KNN_ALGO_PARAM_M_SETTING, | ||
INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING, | ||
INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, | ||
KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING, | ||
|
@@ -357,6 +367,16 @@ public static int getEfSearchParam(String index) { | |
return getIndexSettingValue(index, KNN_ALGO_PARAM_EF_SEARCH, 512); | ||
} | ||
|
||
/** | ||
* | ||
* @param index Name of the index | ||
* @return spaceType value | ||
*/ | ||
public static String getSpaceType(String index) { | ||
return KNNSettings.state().clusterService.state().getMetaData() | ||
.index(index).getSettings().get(KNN_SPACE_TYPE, SpaceTypes.l2.getValue()); | ||
} | ||
|
||
public static int getIndexSettingValue(String index, String settingName, int defaultValue) { | ||
return KNNSettings.state().clusterService.state().getMetaData() | ||
.index(index).getSettings() | ||
|
@@ -367,4 +387,48 @@ public void setClusterService(ClusterService clusterService) { | |
this.clusterService = clusterService; | ||
} | ||
|
||
static class SpaceTypeValidator implements Setting.Validator<String> { | ||
|
||
private Set<String> types = SpaceTypes.getValues(); | ||
|
||
@Override public void validate(String value) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. null check on `value`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @vamshin I put a null check here, but it still fails to guard against a null value, i.e. for some reason, if index.knn.space_type == null, the value never goes through the validator. A similar bug happens with other parameters, e.g.
yields
How about we open a separate issue to track this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah got it. I think we can probably ignore this case. |
||
if (!types.contains(value.toLowerCase())){ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this case-insensitive? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I intended to make it case-insensitive, e.g. L2 and Cosinesimil are also valid strings. But I can remove this if it is not necessary. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think case-insensitive is good |
||
throw new InvalidParameterException(String.format("Unsupported space type: %s", value)); | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Enum contains space types for k-NN similarity search | ||
*/ | ||
public enum SpaceTypes { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Move this enum class to a dedicated file |
||
l2("l2"), | ||
cosinesimil("cosinesimil"); | ||
|
||
private String value; | ||
|
||
SpaceTypes(String value) { this.value = value; } | ||
|
||
/** | ||
* Get space type | ||
* | ||
* @return name | ||
*/ | ||
public String getValue() { return value; } | ||
|
||
/** | ||
* Get all space types | ||
* | ||
* @return set of all stat names | ||
*/ | ||
public static Set<String> getValues() { | ||
Set<String> values = new HashSet<>(); | ||
|
||
for (SpaceTypes spaceType : SpaceTypes.values()) { | ||
values.add(spaceType.getValue()); | ||
} | ||
return values; | ||
} | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -110,12 +110,13 @@ public void close() { | |
* Loads the knn index to memory for querying the neighbours | ||
* | ||
* @param indexPath path where the hnsw index is stored | ||
* @param spaceType space type of the index | ||
* @param algoParams hnsw algorithm parameters | ||
* @return knn index that can be queried for k nearest neighbours | ||
*/ | ||
public static KNNIndex loadIndex(String indexPath, final String[] algoParams) { | ||
public static KNNIndex loadIndex(String indexPath, String spaceType, final String[] algoParams) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: final String spaceType? |
||
long fileSize = computeFileSize(indexPath); | ||
long indexPointer = init(indexPath, algoParams); | ||
long indexPointer = init(indexPath, spaceType, algoParams); | ||
return new KNNIndex(indexPointer, fileSize); | ||
} | ||
|
||
|
@@ -137,13 +138,13 @@ private static long computeFileSize(String indexPath) { | |
} | ||
|
||
// Builds index and writes to disk (no index pointer escapes). | ||
public static native void saveIndex(int[] ids, float[][] data, String indexPath, String[] algoParams); | ||
public static native void saveIndex(int[] ids, float[][] data, String indexPath, String spaceType, String[] algoParams); | ||
|
||
// Queries index (thread safe with other readers, blocked by write lock) | ||
private static native KNNQueryResult[] queryIndex(long indexPointer, float[] query, int k); | ||
|
||
// Loads index and returns pointer to index | ||
private static native long init(String indexPath, String[] algoParams); | ||
private static native long init(String indexPath, String spaceType, String[] algoParams); | ||
|
||
// Deletes memory pointed to by index pointer (needs write lock) | ||
private static native void gc(long indexPointer); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,8 +31,8 @@ | |
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
public class KNNJNITests extends ESTestCase { | ||
private static final Logger logger = LogManager.getLogger(KNNJNITests.class); | ||
public class KNNJNITestsIT extends ESTestCase { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why switch this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah. When I ran the tests locally in IntelliJ, it failed to detect the integration tests therein because the class name lacks the "IT" suffix. So I switched that for easier debugging. I will change it back, but I think it might be worth double-checking whether CI also skips these tests. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The "Tests" suffix, as opposed to "IT", signals that it is a unit test. These tests run during
We consider these unit tests because they are testing an isolated portion of code. Interesting that IntelliJ does not pick them up. Will need to look into it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah. Sorry, I misidentified it as IT. |
||
private static final Logger logger = LogManager.getLogger(KNNJNITestsIT.class); | ||
|
||
public void testCreateHnswIndex() throws Exception { | ||
int[] docs = {0, 1, 2}; | ||
|
@@ -52,7 +52,7 @@ public void testCreateHnswIndex() throws Exception { | |
AccessController.doPrivileged( | ||
new PrivilegedAction<Void>() { | ||
public Void run() { | ||
KNNIndex.saveIndex(docs, vectors, indexPath, algoParams); | ||
KNNIndex.saveIndex(docs, vectors, indexPath, "l2", algoParams); | ||
return null; | ||
} | ||
} | ||
|
@@ -80,7 +80,7 @@ public void testQueryHnswIndex() throws Exception { | |
AccessController.doPrivileged( | ||
new PrivilegedAction<Void>() { | ||
public Void run() { | ||
KNNIndex.saveIndex(docs, vectors, indexPath, algoParams); | ||
KNNIndex.saveIndex(docs, vectors, indexPath, "l2", algoParams); | ||
return null; | ||
} | ||
} | ||
|
@@ -91,7 +91,7 @@ public Void run() { | |
float[] queryVector = {1.0f, 1.0f, 1.0f, 1.0f}; | ||
String[] algoQueryParams = {"efSearch=20"}; | ||
|
||
final KNNIndex knnIndex = KNNIndex.loadIndex(indexPath, algoQueryParams); | ||
final KNNIndex knnIndex = KNNIndex.loadIndex(indexPath, "l2", algoQueryParams); | ||
final KNNQueryResult[] results = knnIndex.queryIndex(queryVector, 30); | ||
|
||
Map<Integer, Float> scores = Arrays.stream(results).collect( | ||
|
@@ -126,7 +126,7 @@ public void testAssertExceptionFromJni() throws Exception { | |
AccessController.doPrivileged( | ||
new PrivilegedAction<Void>() { | ||
public Void run() { | ||
KNNIndex index = KNNIndex.loadIndex(indexPath.toString(), new String[] {}); | ||
KNNIndex index = KNNIndex.loadIndex(indexPath.toString(), "l2", new String[] {}); | ||
return null; | ||
} | ||
} | ||
|
@@ -153,7 +153,7 @@ public void testQueryHnswIndexWithValidAlgoParams() throws Exception { | |
AccessController.doPrivileged( | ||
new PrivilegedAction<Void>() { | ||
public Void run() { | ||
KNNIndex.saveIndex(docs, vectors, indexPath, algoIndexParams); | ||
KNNIndex.saveIndex(docs, vectors, indexPath, "l2", algoIndexParams); | ||
return null; | ||
} | ||
} | ||
|
@@ -165,7 +165,7 @@ public Void run() { | |
float[] queryVector = {1.0f, 1.0f, 1.0f, 1.0f}; | ||
String[] algoQueryParams = {"efSearch=200"}; | ||
|
||
final KNNIndex index = KNNIndex.loadIndex(indexPath, algoQueryParams); | ||
final KNNIndex index = KNNIndex.loadIndex(indexPath, "l2", algoQueryParams); | ||
final KNNQueryResult[] results = index.queryIndex(queryVector, 30); | ||
|
||
Map<Integer, Float> scores = Arrays.stream(results).collect( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Please add the new function arguments to the end of the existing parameters list. Please take care in other places as well.