Skip to content

Commit

Permalink
Fix dim validation for bit element_type
Browse files (browse the repository at this point in the history)
  • Loading branch information
benwtrent committed Oct 10, 2024
1 parent dbdf22f commit 5fbf53e
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

public class ES813FlatVectorFormat extends KnnVectorsFormat {

static final String NAME = "ES813FlatVectorFormat";
Expand All @@ -55,6 +57,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
return new ES813FlatVectorReader(format.fieldsReader(state));
}

/**
 * Returns the maximum number of vector dimensions this format reports for a field.
 *
 * The value is the statically imported {@code DenseVectorFieldMapper.MAX_DIMS_COUNT};
 * {@code fieldName} is not consulted, so every field gets the same limit.
 * NOTE(review): presumably this keeps the codec's limit aligned with the mapper's own
 * dims validation -- confirm against DenseVectorFieldMapper.
 *
 * @param fieldName name of the vector field (unused)
 * @return {@code MAX_DIMS_COUNT}
 */
@Override
public int getMaxDimensions(String fieldName) {
return MAX_DIMS_COUNT;
}

static class ES813FlatVectorWriter extends KnnVectorsWriter {

private final FlatVectorsWriter writer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

public class ES813Int8FlatVectorFormat extends KnnVectorsFormat {

static final String NAME = "ES813Int8FlatVectorFormat";
Expand Down Expand Up @@ -58,6 +60,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
return new ES813FlatVectorReader(format.fieldsReader(state));
}

/**
 * Returns the maximum number of vector dimensions this format reports for a field.
 *
 * The value is the statically imported {@code DenseVectorFieldMapper.MAX_DIMS_COUNT};
 * {@code fieldName} is not consulted, so every field gets the same limit.
 * NOTE(review): presumably this keeps the codec's limit aligned with the mapper's own
 * dims validation -- confirm against DenseVectorFieldMapper.
 *
 * @param fieldName name of the vector field (unused)
 * @return {@code MAX_DIMS_COUNT}
 */
@Override
public int getMaxDimensions(String fieldName) {
return MAX_DIMS_COUNT;
}

@Override
public String toString() {
return NAME + "(name=" + NAME + ", innerFormat=" + format + ")";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

public final class ES814HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {

Expand Down Expand Up @@ -70,7 +71,7 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
return MAX_DIMS_COUNT;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

public class ES815BitFlatVectorFormat extends KnnVectorsFormat {

static final String NAME = "ES815BitFlatVectorFormat";
Expand Down Expand Up @@ -45,4 +47,9 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
public String toString() {
return NAME;
}

/**
 * Returns the maximum number of vector dimensions this format reports for a field.
 *
 * The value is the statically imported {@code DenseVectorFieldMapper.MAX_DIMS_COUNT};
 * {@code fieldName} is not consulted, so every field gets the same limit.
 * NOTE(review): presumably this keeps the codec's limit aligned with the mapper's own
 * dims validation -- confirm against DenseVectorFieldMapper.
 *
 * @param fieldName name of the vector field (unused)
 * @return {@code MAX_DIMS_COUNT}
 */
@Override
public int getMaxDimensions(String fieldName) {
return MAX_DIMS_COUNT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {

static final String NAME = "ES815HnswBitVectorsFormat";
Expand Down Expand Up @@ -72,4 +74,9 @@ public String toString() {
+ flatVectorsFormat
+ ")";
}

/**
 * Returns the maximum number of vector dimensions this format reports for a field.
 *
 * The value is the statically imported {@code DenseVectorFieldMapper.MAX_DIMS_COUNT};
 * {@code fieldName} is not consulted, so every field gets the same limit.
 * NOTE(review): presumably this keeps the codec's limit aligned with the mapper's own
 * dims validation -- confirm against DenseVectorFieldMapper.
 *
 * @param fieldName name of the vector field (unused)
 * @return {@code MAX_DIMS_COUNT}
 */
@Override
public int getMaxDimensions(String fieldName) {
return MAX_DIMS_COUNT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

/**
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
*/
Expand Down Expand Up @@ -68,6 +70,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException
return new ES816BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer);
}

/**
 * Returns the maximum number of vector dimensions reported for a field.
 *
 * The value is the statically imported {@code DenseVectorFieldMapper.MAX_DIMS_COUNT};
 * {@code fieldName} is not consulted, so every field gets the same limit.
 * NOTE(review): presumably this keeps the codec's limit aligned with the mapper's own
 * dims validation -- confirm against DenseVectorFieldMapper.
 *
 * @param fieldName name of the vector field (unused)
 * @return {@code MAX_DIMS_COUNT}
 */
@Override
public int getMaxDimensions(String fieldName) {
return MAX_DIMS_COUNT;
}

@Override
public String toString() {
return "ES816BinaryQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

/**
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
Expand Down Expand Up @@ -128,7 +129,7 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
return MAX_DIMS_COUNT;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,32 +139,27 @@ public static class Builder extends FieldMapper.Builder {
if (o instanceof Integer == false) {
throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]");
}
int dims = XContentMapValues.nodeIntegerValue(o);
int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
if (dims < minDims || dims > maxDims) {
throw new MapperParsingException(
"The number of dimensions for field ["
+ n
+ "] should be in the range ["
+ minDims
+ ", "
+ maxDims
+ "] but was ["
+ dims
+ "]"
);
}
if (elementType.getValue() == ElementType.BIT) {
if (dims % Byte.SIZE != 0) {

return XContentMapValues.nodeIntegerValue(o);
}, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
.setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current))
.addValidator(dims -> {
if (dims == null) {
return;
}
int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
if (dims < minDims || dims > maxDims) {
throw new MapperParsingException(
"The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]"
"The number of dimensions should be in the range [" + minDims + ", " + maxDims + "] but was [" + dims + "]"
);
}
}
return dims;
}, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
.setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current));
// For the bit element type, dims must be a whole multiple of Byte.SIZE (8).
// The message previously read "dimensions for should be" -- a dangling "for" left
// behind when the field name was dropped; it now matches the range message's wording.
// NOTE(review): any test or REST spec outside this view that pins the old text needs updating.
if (elementType.getValue() == ElementType.BIT && dims % Byte.SIZE != 0) {
    throw new MapperParsingException("The number of dimensions should be a multiple of 8 but was [" + dims + "]");
}
});
private final Parameter<VectorSimilarity> similarity;

private final Parameter<IndexOptions> indexOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
),
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("dims", dims * 8)
.field("index", true)
.field("similarity", "l2_norm")
.field("element_type", "bit")
Expand All @@ -192,7 +192,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
),
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("dims", dims * 8)
.field("index", true)
.field("similarity", "l2_norm")
.field("element_type", "bit")
Expand Down Expand Up @@ -891,9 +891,7 @@ public void testDims() {
})));
assertThat(
e.getMessage(),
equalTo(
"Failed to parse mapping: " + "The number of dimensions for field [field] should be in the range [1, 4096] but was [0]"
)
equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [0]")
);
}
// test max limit for non-indexed vectors
Expand All @@ -904,10 +902,7 @@ public void testDims() {
})));
assertThat(
e.getMessage(),
equalTo(
"Failed to parse mapping: "
+ "The number of dimensions for field [field] should be in the range [1, 4096] but was [5000]"
)
equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [5000]")
);
}
// test max limit for indexed vectors
Expand All @@ -919,10 +914,7 @@ public void testDims() {
})));
assertThat(
e.getMessage(),
equalTo(
"Failed to parse mapping: "
+ "The number of dimensions for field [field] should be in the range [1, 4096] but was [5000]"
)
equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [5000]")
);
}
}
Expand Down Expand Up @@ -955,6 +947,14 @@ public void testMergeDims() throws IOException {
);
}

/**
 * Checks that a {@code dense_vector} mapping with the bit element type and
 * {@code 1024 * Byte.SIZE} dimensions can be created.
 *
 * There is no explicit assertion: the test passes when {@code createMapperService}
 * returns without throwing.
 */
public void testLargeDimsBit() throws IOException {
    createMapperService(
        fieldMapping(
            b -> b.field("type", "dense_vector")
                .field("dims", 1024 * Byte.SIZE)
                .field("element_type", ElementType.BIT.toString())
        )
    );
}

public void testDefaults() throws Exception {
DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dims", 3)));

Expand Down

0 comments on commit 5fbf53e

Please sign in to comment.