Skip to content

Commit

Permalink
Merge branch 'main' into add_model_version
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanbogan committed Sep 5, 2024
2 parents 28ddef8 + cb9ba71 commit dda1f0d
Show file tree
Hide file tree
Showing 16 changed files with 292 additions and 175 deletions.
3 changes: 2 additions & 1 deletion release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Compatible with OpenSearch 2.17.0
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
* Add support for byte vector with Faiss Engine IVF algorithm [#2002](https://github.com/opensearch-project/k-NN/pull/2002)
* Add mode/compression configuration support for disk-based vector search [#2034](https://github.com/opensearch-project/k-NN/pull/2034)
* Add spaceType as a top level optional parameter while creating vector field. [#2044](https://github.com/opensearch-project/k-NN/pull/2044)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
* Add model version to model metadata and change model metadata reads to be from cluster metadata [#2005](https://github.com/opensearch-project/k-NN/pull/2005)
Expand All @@ -33,4 +34,4 @@ Compatible with OpenSearch 2.17.0
* Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960)
* Add quantization state reader and writer [#1997](https://github.com/opensearch-project/k-NN/pull/1997)
* Add quantization state reader and writer [#1997](https://github.com/opensearch-project/k-NN/pull/1997)
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public class KNNConstants {
public static final String METHOD_IVF = "ivf";
public static final String METHOD_PARAMETER_NLIST = "nlist";
public static final String METHOD_PARAMETER_SPACE_TYPE = "space_type"; // used for mapping parameter
// used for defining toplevel parameter
public static final String TOP_LEVEL_PARAMETER_SPACE_TYPE = METHOD_PARAMETER_SPACE_TYPE;
public static final String COMPOUND_EXTENSION = "c";
public static final String MODEL = "model";
public static final String MODELS = "models";
Expand Down Expand Up @@ -72,6 +74,7 @@ public class KNNConstants {
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;
public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature";
public static final String TOP_LEVEL_SPACE_TYPE_FEATURE = "top_level_space_type_feature";

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String MODEL_VERSION = "model_version";
Expand Down
12 changes: 11 additions & 1 deletion src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

package org.opensearch.knn.index;

import java.util.Arrays;
import java.util.Locale;

import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector;

Expand Down Expand Up @@ -149,6 +151,12 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
public static SpaceType DEFAULT = L2;
public static SpaceType DEFAULT_BINARY = HAMMING;

private static final String[] VALID_VALUES = Arrays.stream(SpaceType.values())
.filter(space -> space != SpaceType.UNDEFINED)
.map(SpaceType::getValue)
.collect(Collectors.toList())
.toArray(new String[0]);

private final String value;

SpaceType(String value) {
Expand Down Expand Up @@ -221,7 +229,9 @@ public static SpaceType getSpace(String spaceTypeName) {
return currentSpaceType;
}
}
throw new IllegalArgumentException("Unable to find space: " + spaceTypeName);
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Unable to find space: %s . Valid values are: %s", spaceTypeName, Arrays.toString(VALID_VALUES))
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
* Reads quantization states
Expand All @@ -32,7 +29,7 @@
public final class KNN990QuantizationStateReader {

/**
* Read quantization states and return list of fieldNames and bytes
* Reads an individual quantization state for a given field
* File format:
* Header
* QS1 state bytes
Expand All @@ -48,37 +45,6 @@ public final class KNN990QuantizationStateReader {
* -1 (marker)
* Footer
*
* @param state the read state to read from
*/
public static Map<String, byte[]> read(SegmentReadState state) throws IOException {
String quantizationStateFileName = getQuantizationStateFileName(state);
Map<String, byte[]> readQuantizationStateInfos = null;

try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) {
CodecUtil.retrieveChecksum(input);

int numFields = getNumFields(input);

readQuantizationStateInfos = new HashMap<>();

// Read each field's metadata from the index section and then read bytes
for (int i = 0; i < numFields; i++) {
int fieldNumber = input.readInt();
int length = input.readInt();
long position = input.readVLong();
byte[] stateBytes = readStateBytes(input, position, length);
String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName();
readQuantizationStateInfos.put(fieldName, stateBytes);
}
} catch (Exception e) {
log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e);
return Collections.emptyMap();
}
return readQuantizationStateInfos;
}

/**
* Reads an individual quantization state for a given field
* @param readConfig a config class that contains necessary information for reading the state
* @return quantization state
*/
Expand All @@ -88,41 +54,43 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr
String quantizationStateFileName = getQuantizationStateFileName(segmentReadState);
int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber();

IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ);
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);
try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) {

long position = -1;
int length = 0;
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);

// Read each field's metadata from the index section, break when correct field is found
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
}
long position = -1;
int length = 0;

if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}
// Read each field's metadata from the index section, break when correct field is found
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
}

byte[] stateBytes = readStateBytes(input, position, length);
if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}

// Deserialize the byte array to a quantization state object
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
switch (scalarQuantizationType) {
case ONE_BIT:
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
byte[] stateBytes = readStateBytes(input, position, length);

// Deserialize the byte array to a quantization state object
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
switch (scalarQuantizationType) {
case ONE_BIT:
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader {

private final FlatVectorsReader flatVectorsReader;
private final SegmentReadState segmentReadState;
private final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance();
private Map<String, String> quantizationStateCacheKeyPerField;

public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException {
this.segmentReadState = state;
this.flatVectorsReader = flatVectorsReader;
primeQuantizationStateCache();
loadCacheKeyMap();
}

/**
Expand Down Expand Up @@ -178,8 +177,10 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
@Override
public void close() throws IOException {
IOUtils.close(flatVectorsReader);
for (String cacheKey : quantizationStateCacheKeyPerField.values()) {
QuantizationStateCacheManager.getInstance().evict(cacheKey);
if (quantizationStateCacheKeyPerField != null) {
for (String cacheKey : quantizationStateCacheKeyPerField.values()) {
QuantizationStateCacheManager.getInstance().evict(cacheKey);
}
}
}

Expand All @@ -191,7 +192,7 @@ public long ramBytesUsed() {
return flatVectorsReader.ramBytesUsed();
}

private void primeQuantizationStateCache() throws IOException {
private void loadCacheKeyMap() throws IOException {
quantizationStateCacheKeyPerField = new HashMap<>();
for (FieldInfo fieldInfo : segmentReadState.fieldInfos) {
String cacheKey = UUIDs.base64UUID();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
CompressionLevel.NAMES_ARRAY
).acceptsNull();

// A top level space Type field.
protected final Parameter<String> topLevelSpaceType = Parameter.stringParam(
KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE,
false,
m -> toType(m).originalMappingParameters.getTopLevelSpaceType(),
SpaceType.UNDEFINED.getValue()
).setValidator(SpaceType::getSpace);

protected final Parameter<Map<String, String>> meta = Parameter.metaParam();

protected ModelDao modelDao;
Expand All @@ -187,7 +195,18 @@ public Builder(

@Override
protected List<Parameter<?>> getParameters() {
return Arrays.asList(stored, hasDocValues, dimension, vectorDataType, meta, knnMethodContext, modelId, mode, compressionLevel);
return Arrays.asList(
stored,
hasDocValues,
dimension,
vectorDataType,
meta,
knnMethodContext,
modelId,
mode,
compressionLevel,
topLevelSpaceType
);
}

protected Explicit<Boolean> ignoreMalformed(BuilderContext context) {
Expand Down Expand Up @@ -346,13 +365,31 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
validateFromModel(builder);
} else {
validateMode(builder);
validateSpaceType(builder);
resolveKNNMethodComponents(builder, parserContext);
validateFromKNNMethod(builder);
}

return builder;
}

private void validateSpaceType(KNNVectorFieldMapper.Builder builder) {
final KNNMethodContext knnMethodContext = builder.knnMethodContext.get();
// if context is defined
if (knnMethodContext != null) {
// now ensure both space types are same.
final SpaceType knnMethodContextSpaceType = knnMethodContext.getSpaceType();
final SpaceType topLevelSpaceType = SpaceType.getSpace(builder.topLevelSpaceType.get());
if (topLevelSpaceType != SpaceType.UNDEFINED
&& topLevelSpaceType != knnMethodContextSpaceType
&& knnMethodContextSpaceType != SpaceType.UNDEFINED) {
throw new MapperParsingException(
"Space type in \"method\" and top level space type should be same or one of them should be defined"
);
}
}
}

private void validateMode(KNNVectorFieldMapper.Builder builder) {
boolean isKNNMethodContextConfigured = builder.originalParameters.getKnnMethodContext() != null;
boolean isModeConfigured = builder.mode.isConfigured() || builder.compressionLevel.isConfigured();
Expand Down Expand Up @@ -386,6 +423,11 @@ private void validateFromModel(KNNVectorFieldMapper.Builder builder) {
if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER && builder.modelId.get() == null) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", builder.name()));
}
// ensure model and top level spaceType is not defined
if (builder.modelId.get() != null && SpaceType.getSpace(builder.topLevelSpaceType.get()) != SpaceType.UNDEFINED) {
throw new IllegalArgumentException("TopLevel Space type and model can not be both specified in the " + "mapping");
}

validateCompressionAndModeNotSet(builder, builder.name(), "model");
}

Expand Down Expand Up @@ -439,36 +481,64 @@ private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, Pa
// Configure method from map or legacy
if (builder.originalParameters.isLegacyMapping()) {
builder.originalParameters.setResolvedKnnMethodContext(
createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated())
createKNNMethodContextFromLegacy(
parserContext.getSettings(),
parserContext.indexVersionCreated(),
SpaceType.getSpace(builder.topLevelSpaceType.get())
)
);
} else if (Mode.isConfigured(Mode.fromName(builder.mode.get()))
|| CompressionLevel.isConfigured(CompressionLevel.fromName(builder.compressionLevel.get()))) {
// we need don't need to resolve the space type, whatever default we are using will be passed down to
// while resolving KNNMethodContext for the mode and compression. and then when we resolve the spaceType
// we will set the correct spaceType.
builder.originalParameters.setResolvedKnnMethodContext(
ModeBasedResolver.INSTANCE.resolveKNNMethodContext(
builder.knnMethodConfigContext.getMode(),
builder.knnMethodConfigContext.getCompressionLevel(),
false
false,
SpaceType.getSpace(builder.originalParameters.getTopLevelSpaceType())
)
);
}
setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(), builder.originalParameters.getVectorDataType());
// this function should now correct the space type for the above resolved context too, if spaceType was
// not provided.
setSpaceType(
builder.originalParameters.getResolvedKnnMethodContext(),
builder.originalParameters.getVectorDataType(),
builder.topLevelSpaceType.get()
);
}

private boolean isKNNDisabled(Settings settings) {
boolean isSettingPresent = KNNSettings.IS_KNN_INDEX_SETTING.exists(settings);
return !isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(settings);
}

private void setDefaultSpaceType(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) {
private void setSpaceType(
final KNNMethodContext knnMethodContext,
final VectorDataType vectorDataType,
final String topLevelSpaceType
) {
// Now KNNMethodContext should never be null. Because only case it could be null is flatMapper which is
// already handled
if (knnMethodContext == null) {
return;
throw new IllegalArgumentException("KNNMethodContext cannot be null");
}

final SpaceType topLevelSpaceTypeEnum = SpaceType.getSpace(topLevelSpaceType);
// Now set the spaceSpaceType for KNNMethodContext
if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) {
if (VectorDataType.BINARY == vectorDataType) {
knnMethodContext.setSpaceType(SpaceType.DEFAULT_BINARY);
// We are handling the case when top level space type is defined but method level spaceType is not
// defined.
if (topLevelSpaceTypeEnum != SpaceType.UNDEFINED) {
knnMethodContext.setSpaceType(topLevelSpaceTypeEnum);
} else {
knnMethodContext.setSpaceType(SpaceType.DEFAULT);
// If both spaceTypes are undefined then put the default spaceType based on datatype
if (VectorDataType.BINARY == vectorDataType) {
knnMethodContext.setSpaceType(SpaceType.DEFAULT_BINARY);
} else {
knnMethodContext.setSpaceType(SpaceType.DEFAULT);
}
}
}
}
Expand Down
Loading

0 comments on commit dda1f0d

Please sign in to comment.