-
Notifications
You must be signed in to change notification settings - Fork 137
/
Copy pathKNNEngine.java
204 lines (169 loc) · 5.92 KB
/
KNNEngine.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.knn.index.util;
import com.google.common.collect.ImmutableSet;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.KNNMethod;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
/**
* KNNEngine provides the functionality to validate and transform user defined indices into information that can be
* passed to the respective k-NN library's JNI layer.
*/
public enum KNNEngine implements KNNLibrary {
    NMSLIB(NMSLIB_NAME, Nmslib.INSTANCE),
    FAISS(FAISS_NAME, Faiss.INSTANCE),
    LUCENE(LUCENE_NAME, Lucene.INSTANCE);

    /** Engine used when none is specified. */
    public static final KNNEngine DEFAULT = NMSLIB;

    private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
    private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
    private static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE);

    // Declared final so the lookup table cannot be reassigned at runtime; Map.of already makes its contents immutable.
    private static final Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
        KNNEngine.NMSLIB,
        16_000,
        KNNEngine.FAISS,
        16_000,
        KNNEngine.LUCENE,
        16_000
    );

    private final String name;
    private final KNNLibrary knnLibrary;

    /**
     * Constructor for KNNEngine
     *
     * @param name name of engine
     * @param knnLibrary library the engine uses; every {@link KNNLibrary} method on this enum delegates to it
     */
    KNNEngine(String name, KNNLibrary knnLibrary) {
        this.name = name;
        this.knnLibrary = knnLibrary;
    }

    /**
     * Get the engine whose name matches the given name, ignoring case.
     *
     * @param name of engine to be fetched; may be {@code null}, which never matches
     * @return KNNEngine corresponding to name
     * @throws IllegalArgumentException if no engine has the given name
     */
    public static KNNEngine getEngine(String name) {
        // values() is in declaration order, so the lookup order matches the previous hand-written checks.
        for (KNNEngine engine : values()) {
            if (engine.getName().equalsIgnoreCase(name)) {
                return engine;
            }
        }
        throw new IllegalArgumentException(String.format("Invalid engine type: %s", name));
    }

    /**
     * Get the engine from the path, based on the file's extension.
     *
     * Only engines that write their own segment files (NMSLIB and FAISS) are considered. Each engine's plain
     * extension and compound extension are both accepted.
     *
     * @param path to be checked; must not be {@code null}
     * @return KNNEngine corresponding to path
     * @throws IllegalArgumentException if the path's suffix matches no engine
     * @throws NullPointerException if {@code path} is {@code null}
     */
    public static KNNEngine getEngineNameFromPath(String path) {
        if (path.endsWith(KNNEngine.NMSLIB.getExtension()) || path.endsWith(KNNEngine.NMSLIB.getCompoundExtension())) {
            return KNNEngine.NMSLIB;
        }
        if (path.endsWith(KNNEngine.FAISS.getExtension()) || path.endsWith(KNNEngine.FAISS.getCompoundExtension())) {
            return KNNEngine.FAISS;
        }
        throw new IllegalArgumentException("No engine matches the path's suffix");
    }

    /**
     * Returns all engines that create custom segment files.
     *
     * @return immutable set of all engines that create custom segment files.
     */
    public static Set<KNNEngine> getEnginesThatCreateCustomSegmentFiles() {
        return CUSTOM_SEGMENT_FILE_ENGINES;
    }

    /**
     * Returns all engines that support filtered search.
     *
     * @return immutable set of all engines that support filters.
     */
    public static Set<KNNEngine> getEnginesThatSupportsFilters() {
        return ENGINES_SUPPORTING_FILTERS;
    }

    /**
     * Returns all engines that support radial search.
     *
     * @return immutable set of all engines that support radial search.
     */
    public static Set<KNNEngine> getEnginesThatSupportsRadialSearch() {
        return ENGINES_SUPPORTING_RADIAL_SEARCH;
    }

    /**
     * Return number of max allowed dimensions per single vector based on the knn engine.
     *
     * @param knnEngine knn engine to check max dimensions value; {@code null} falls back to {@link #DEFAULT}
     * @return maximum number of dimensions allowed for a single vector with the given engine
     */
    public static int getMaxDimensionByEngine(KNNEngine knnEngine) {
        final int defaultMax = MAX_DIMENSIONS_BY_ENGINE.get(KNNEngine.DEFAULT);
        // Map.of maps reject null keys (NPE), so guard explicitly to make the intended fallback reachable.
        if (knnEngine == null) {
            return defaultMax;
        }
        return MAX_DIMENSIONS_BY_ENGINE.getOrDefault(knnEngine, defaultMax);
    }

    /**
     * Get the name of the engine
     *
     * @return name of the engine
     */
    public String getName() {
        return name;
    }

    // The KNNLibrary methods below all delegate directly to the engine's backing library.

    @Override
    public String getVersion() {
        return knnLibrary.getVersion();
    }

    @Override
    public String getExtension() {
        return knnLibrary.getExtension();
    }

    @Override
    public String getCompoundExtension() {
        return knnLibrary.getCompoundExtension();
    }

    @Override
    public KNNMethod getMethod(String methodName) {
        return knnLibrary.getMethod(methodName);
    }

    @Override
    public float score(float rawScore, SpaceType spaceType) {
        return knnLibrary.score(rawScore, spaceType);
    }

    @Override
    public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
        return knnLibrary.distanceToRadialThreshold(distance, spaceType);
    }

    @Override
    public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
        return knnLibrary.validateMethod(knnMethodContext);
    }

    @Override
    public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
        return knnLibrary.isTrainingRequired(knnMethodContext);
    }

    @Override
    public Map<String, Object> getMethodAsMap(KNNMethodContext knnMethodContext) {
        return knnLibrary.getMethodAsMap(knnMethodContext);
    }

    @Override
    public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) {
        return knnLibrary.estimateOverheadInKB(knnMethodContext, dimension);
    }

    @Override
    public Boolean isInitialized() {
        return knnLibrary.isInitialized();
    }

    @Override
    public void setInitialized(Boolean isInitialized) {
        knnLibrary.setInitialized(isInitialized);
    }

    @Override
    public List<String> mmapFileExtensions() {
        return knnLibrary.mmapFileExtensions();
    }
}