Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for configuring HNSW parameters #79193

Merged
merged 4 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ public CodecService(@Nullable MapperService mapperService) {
codecs.put(BEST_COMPRESSION_CODEC, new Lucene90Codec(Lucene90Codec.Mode.BEST_COMPRESSION));
} else {
codecs.put(DEFAULT_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
new PerFieldMapperCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
codecs.put(BEST_COMPRESSION_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
new PerFieldMapperCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
}
codecs.put(LUCENE_DEFAULT_CODEC, Codec.getDefault());
for (String codec : Codec.availableCodecs()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,32 @@

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.mapper.MapperService;

/**
* {@link PerFieldMappingPostingFormatCodec This postings format} is the default
* {@link PostingsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} per field. This
* allows users to change the low level postings format for individual fields
* per index in real time via the mapping API. If no specific postings format is
* configured for a specific field the default postings format is used.
* {@link PerFieldMapperCodec This Lucene codec} provides the default
* {@link PostingsFormat} and {@link KnnVectorsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} and {@link KnnVectorsFormat} per field. This
* allows users to change the low level postings format and vectors format for individual fields
* per index in real time via the mapping API. If no specific postings format or vector format is
* configured for a specific field the default postings or vector format is used.
*/
public class PerFieldMappingPostingFormatCodec extends Lucene90Codec {
public class PerFieldMapperCodec extends Lucene90Codec {
private final MapperService mapperService;

private final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat();

static {
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class) :
"PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMapperCodec.class) :
"PerFieldMapperCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
}

public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService) {
public PerFieldMapperCodec(Mode compressionMode, MapperService mapperService) {
super(compressionMode);
this.mapperService = mapperService;
}
Expand All @@ -48,6 +49,15 @@ public PostingsFormat getPostingsFormatForField(String field) {
return format;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
KnnVectorsFormat format = mapperService.mappingLookup().getKnnVectorsFormatForField(field);
if (format == null) {
return super.getKnnVectorsFormatForField(field);
}
return format;
}

@Override
public DocValuesFormat getDocValuesFormatForField(String field) {
return docValuesFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.index.IndexSettings;
Expand Down Expand Up @@ -228,6 +229,20 @@ public PostingsFormat getPostingsFormat(String field) {
return completionFields.contains(field) ? CompletionFieldMapper.postingsFormat() : null;
}

/**
* Returns the knn vectors format for a particular field
* @param field the field to retrieve a knn vectors format for
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
Mapper fieldMapper = fieldMappers.get(field);
if (fieldMapper instanceof PerFieldKnnVectorsFormatFieldMapper) {
return ((PerFieldKnnVectorsFormatFieldMapper) fieldMapper).getKnnVectorsFormatForField();
} else {
return null;
}
}

void checkLimits(IndexSettings settings) {
checkFieldLimit(settings.getMappingTotalFieldsLimit());
checkObjectDepthLimit(settings.getMappingDepthLimit());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;

/**
* Field mapper used for the only purpose to provide a custom knn vectors format.
* For internal use only.
*/

public interface PerFieldKnnVectorsFormatFieldMapper {

/**
* Returns the knn vectors format that is customly set up for this field
* or {@code null} if the format is not set up.
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
KnnVectorsFormat getKnnVectorsFormatForField();
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class CodecTests extends ESTestCase {

public void testResolveDefaultCodecs() throws Exception {
CodecService codecService = createCodecService();
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingPostingFormatCodec.class));
assertThat(codecService.codec("default"), instanceOf(PerFieldMapperCodec.class));
assertThat(codecService.codec("default"), instanceOf(Lucene90Codec.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.unit.Fuzziness;
import org.elasticsearch.index.codec.PerFieldMapperCodec;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
Expand All @@ -38,7 +39,6 @@
import org.elasticsearch.index.analysis.IndexAnalyzers;
import org.elasticsearch.index.analysis.NamedAnalyzer;
import org.elasticsearch.index.codec.CodecService;
import org.elasticsearch.index.codec.PerFieldMappingPostingFormatCodec;
import org.hamcrest.FeatureMatcher;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
Expand Down Expand Up @@ -122,8 +122,8 @@ public void testPostingsFormat() throws IOException {
MapperService mapperService = createMapperService(fieldMapping(this::minimalMapping));
CodecService codecService = new CodecService(mapperService);
Codec codec = codecService.codec("default");
assertThat(codec, instanceOf(PerFieldMappingPostingFormatCodec.class));
PerFieldMappingPostingFormatCodec perFieldCodec = (PerFieldMappingPostingFormatCodec) codec;
assertThat(codec, instanceOf(PerFieldMapperCodec.class));
PerFieldMapperCodec perFieldCodec = (PerFieldMapperCodec) codec;
assertThat(perFieldCodec.getPostingsFormatForField("field"), instanceOf(Completion90PostingsFormat.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ setup:
dims: 5
index: true
similarity: dot_product

- do:
index:
index: test-index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ setup:
dims: 3
index: true
similarity: l2_norm
index_options:
type: hnsw
m: 15
ef_construction: 50

---
"Indexing of Dense vectors should error when dims don't match defined in the mapping":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package org.elasticsearch.xpack.vectors.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
Expand All @@ -16,6 +18,10 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.PerFieldKnnVectorsFormatFieldMapper;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
Expand All @@ -40,14 +46,15 @@
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* A {@link FieldMapper} for indexing a dense vector of floats.
*/
public class DenseVectorFieldMapper extends FieldMapper {
public class DenseVectorFieldMapper extends FieldMapper implements PerFieldKnnVectorsFormatFieldMapper {

public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; //maximum allowed number of dimensions
Expand All @@ -73,6 +80,8 @@ public static class Builder extends FieldMapper.Builder {
private final Parameter<Boolean> indexed = Parameter.indexParam(m -> toType(m).indexed, false);
private final Parameter<VectorSimilarity> similarity = Parameter.enumParam(
"similarity", false, m -> toType(m).similarity, null, VectorSimilarity.class);
private final Parameter<IndexOptions> indexOptions = new Parameter<>("index_options", false, () -> null,
(n, c, o) -> o == null ? null : parseIndexOptions(n, o), m -> toType(m).indexOptions);
private final Parameter<Map<String, String>> meta = Parameter.metaParam();

final Version indexVersionCreated;
Expand All @@ -84,11 +93,13 @@ public Builder(String name, Version indexVersionCreated) {
this.indexed.requiresParameters(similarity);
this.similarity.setSerializerCheck((id, ic, v) -> v != null);
this.similarity.requiresParameters(indexed);
this.indexOptions.requiresParameters(indexed);
this.indexOptions.setSerializerCheck((id, ic, v) -> v != null);
}

@Override
protected List<Parameter<?>> getParameters() {
return List.of(dims, indexed, similarity, meta);
return List.of(dims, indexed, similarity, indexOptions, meta);
}

@Override
Expand All @@ -100,6 +111,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
dims.getValue(),
indexed.getValue(),
similarity.getValue(),
indexOptions.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, context),
copyTo.build());
Expand All @@ -116,6 +128,67 @@ enum VectorSimilarity {
}
}

private abstract static class IndexOptions implements ToXContent {
final String type;
IndexOptions(String type) {
this.type = type;
}
}

private static class HnswIndexOptions extends IndexOptions {
private final int m;
private final int efConstruction;

static IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
Object mNode = indexOptionsMap.remove("m");
Object efConstructionNode = indexOptionsMap.remove("ef_construction");
if (mNode == null) {
throw new MapperParsingException("[index_options] of type [hnsw] requires field [m] to be configured");
}
if (efConstructionNode == null) {
throw new MapperParsingException("[index_options] of type [hnsw] requires field [ef_construction] to be configured");
}
int m = XContentMapValues.nodeIntegerValue(mNode);
int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode);
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
return new HnswIndexOptions(m, efConstruction);
}

private HnswIndexOptions(int m, int efConstruction) {
super("hnsw");
this.m = m;
this.efConstruction = efConstruction;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("type", type);
builder.field("m", m);
builder.field("ef_construction", efConstruction);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
HnswIndexOptions that = (HnswIndexOptions) o;
return m == that.m && efConstruction == that.efConstruction;
}

@Override
public int hashCode() {
return Objects.hash(type, m, efConstruction);
}

@Override
public String toString() {
return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + " }";
}
}

public static final TypeParser PARSER
= new TypeParser((n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE));

Expand Down Expand Up @@ -185,15 +258,17 @@ public Query termQuery(Object value, SearchExecutionContext context) {
private final int dims;
private final boolean indexed;
private final VectorSimilarity similarity;
private final IndexOptions indexOptions;
private final Version indexCreatedVersion;

private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims,
boolean indexed, VectorSimilarity similarity,
private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims, boolean indexed,
VectorSimilarity similarity, IndexOptions indexOptions,
Version indexCreatedVersion, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, mappedFieldType, multiFields, copyTo);
this.dims = dims;
this.indexed = indexed;
this.similarity = similarity;
this.indexOptions = indexOptions;
this.indexCreatedVersion = indexCreatedVersion;
}

Expand Down Expand Up @@ -289,4 +364,29 @@ protected String contentType() {
public FieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName(), indexCreatedVersion).init(this);
}

private static IndexOptions parseIndexOptions(String fieldName, Object propNode) {
@SuppressWarnings("unchecked")
Map<String, ?> indexOptionsMap = (Map<String, ?>) propNode;
Object typeNode = indexOptionsMap.remove("type");
if (typeNode == null) {
throw new MapperParsingException("[index_options] requires field [type] to be configured");
}
String type = XContentMapValues.nodeStringValue(typeNode);
if (type.equals("hnsw")) {
return HnswIndexOptions.parseIndexOptions(fieldName, indexOptionsMap);
} else {
throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]");
}
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField() {
if (indexOptions == null) {
return null; // use default format
} else {
HnswIndexOptions hnswIndexOptions = (HnswIndexOptions) indexOptions;
return new Lucene90HnswVectorsFormat(hnswIndexOptions.m, hnswIndexOptions.efConstruction);
}
}
}
Loading