Skip to content

Commit

Permalink
Speed up dense/sparse vector stats (#111729)
Browse files Browse the repository at this point in the history
This change ensures that we don't try to compute stats on mappings that don't have dense or sparse vector fields. We don't need to go through all the fields on every segment; instead, we can extract the vector fields up front and limit the work to only the indices that define these types.

Closes #111715
  • Loading branch information
jimczi authored Aug 12, 2024
1 parent fd916c2 commit 59cf661
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 44 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/111729.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 111729
summary: Speed up dense/sparse vector stats
area: Vector Search
type: bug
issues:
- 111715
82 changes: 46 additions & 36 deletions server/src/main/java/org/elasticsearch/index/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.VersionType;
import org.elasticsearch.index.mapper.DocumentParser;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.FieldNamesFieldMapper;
import org.elasticsearch.index.mapper.LuceneDocument;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.Mapping;
import org.elasticsearch.index.mapper.MappingLookup;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.mapper.Uid;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.merge.MergeStats;
import org.elasticsearch.index.seqno.SeqNoStats;
Expand Down Expand Up @@ -242,29 +242,43 @@ protected final DocsStats docsStats(IndexReader indexReader) {
/**
* Returns the {@link DenseVectorStats} for this engine
*/
// NOTE(review): this span is rendered diff output — removed and added lines appear
// back to back without +/- markers. The first signature below is the removed
// zero-arg variant; the MappingLookup overload replaced it. The same applies to
// the paired return statements inside the try block.
public DenseVectorStats denseVectorStats() {
public DenseVectorStats denseVectorStats(MappingLookup mappingLookup) {
// No mapping information available -> report zero dense vectors.
if (mappingLookup == null) {
return new DenseVectorStats(0);
}

// Extract the full paths of all dense vector fields up front so the
// per-segment work below is limited to exactly those fields.
List<String> fields = new ArrayList<>();
for (Mapper mapper : mappingLookup.fieldMappers()) {
if (mapper instanceof DenseVectorFieldMapper) {
fields.add(mapper.fullPath());
}
}
// Fast path: no dense vector fields in the mapping -> skip acquiring a searcher.
if (fields.isEmpty()) {
return new DenseVectorStats(0);
}
try (Searcher searcher = acquireSearcher(DOC_STATS_SOURCE, SearcherScope.INTERNAL)) {
return denseVectorStats(searcher.getIndexReader());
return denseVectorStats(searcher.getIndexReader(), fields);
}
}

// Diff pair: the single-argument signature was removed in favor of the overload
// that takes the pre-extracted list of dense vector field paths.
protected final DenseVectorStats denseVectorStats(IndexReader indexReader) {
protected final DenseVectorStats denseVectorStats(IndexReader indexReader, List<String> fields) {
long valueCount = 0;
// we don't wait for pending refreshes here since it's a stats call; instead we mark it as accessed only, which will cause
// the next scheduled refresh to go through and refresh the stats as well
for (LeafReaderContext readerContext : indexReader.leaves()) {
try {
valueCount += getDenseVectorValueCount(readerContext.reader());
valueCount += getDenseVectorValueCount(readerContext.reader(), fields);
} catch (IOException e) {
// Best effort: a failing segment is logged at trace level and skipped.
logger.trace(() -> "failed to get dense vector stats for [" + readerContext + "]", e);
}
}
return new DenseVectorStats(valueCount);
}

private long getDenseVectorValueCount(final LeafReader atomicReader) throws IOException {
private long getDenseVectorValueCount(final LeafReader atomicReader, List<String> fields) throws IOException {
long count = 0;
for (FieldInfo info : atomicReader.getFieldInfos()) {
for (var field : fields) {
var info = atomicReader.getFieldInfos().fieldInfo(field);
if (info.getVectorDimension() > 0) {
switch (info.getVectorEncoding()) {
case FLOAT32 -> {
Expand All @@ -285,52 +299,48 @@ private long getDenseVectorValueCount(final LeafReader atomicReader) throws IOEx
* Returns the {@link SparseVectorStats} for this engine
*/
public SparseVectorStats sparseVectorStats(MappingLookup mappingLookup) {
// No mapping information available -> report zero sparse vectors.
if (mappingLookup == null) {
return new SparseVectorStats(0);
}
// Extract the names of all sparse vector fields up front, as BytesRef — the
// form used for term seeks — so per-segment work is limited to those fields.
List<BytesRef> fields = new ArrayList<>();
for (Mapper mapper : mappingLookup.fieldMappers()) {
if (mapper instanceof SparseVectorFieldMapper) {
fields.add(new BytesRef(mapper.fullPath()));
}
}
// Fast path: no sparse vector fields in the mapping -> skip acquiring a searcher.
if (fields.isEmpty()) {
return new SparseVectorStats(0);
}
// Sorted so the per-segment TermsEnum seeks advance in term order —
// presumably to keep seekExact moving forward; confirm against the count loop.
Collections.sort(fields);
try (Searcher searcher = acquireSearcher(DOC_STATS_SOURCE, SearcherScope.INTERNAL)) {
// Diff pair: old call passed the MappingLookup; new call passes the sorted field list.
return sparseVectorStats(searcher.getIndexReader(), mappingLookup);
return sparseVectorStats(searcher.getIndexReader(), fields);
}
}

// Diff pair: the MappingLookup-based signature (and its null check just below)
// was removed in favor of the overload taking the pre-extracted field names.
protected final SparseVectorStats sparseVectorStats(IndexReader indexReader, MappingLookup mappingLookup) {
protected final SparseVectorStats sparseVectorStats(IndexReader indexReader, List<BytesRef> fields) {
long valueCount = 0;

if (mappingLookup == null) {
return new SparseVectorStats(valueCount);
}

// we don't wait for pending refreshes here since it's a stats call; instead we mark it as accessed only, which will cause
// the next scheduled refresh to go through and refresh the stats as well
for (LeafReaderContext readerContext : indexReader.leaves()) {
try {
valueCount += getSparseVectorValueCount(readerContext.reader(), mappingLookup);
valueCount += getSparseVectorValueCount(readerContext.reader(), fields);
} catch (IOException e) {
// Best effort: a failing segment is logged at trace level and skipped.
logger.trace(() -> "failed to get sparse vector stats for [" + readerContext + "]", e);
}
}
return new SparseVectorStats(valueCount);
}

private long getSparseVectorValueCount(final LeafReader atomicReader, MappingLookup mappingLookup) throws IOException {
private long getSparseVectorValueCount(final LeafReader atomicReader, List<BytesRef> fields) throws IOException {
long count = 0;

Map<String, FieldMapper> mappers = new HashMap<>();
for (Mapper mapper : mappingLookup.fieldMappers()) {
if (mapper instanceof FieldMapper fieldMapper) {
if (fieldMapper.fieldType() instanceof SparseVectorFieldMapper.SparseVectorFieldType) {
mappers.put(fieldMapper.fullPath(), fieldMapper);
}
}
}

for (FieldInfo info : atomicReader.getFieldInfos()) {
String name = info.name;
if (mappers.containsKey(name)) {
Terms terms = atomicReader.terms(FieldNamesFieldMapper.NAME);
if (terms != null) {
TermsEnum termsEnum = terms.iterator();
if (termsEnum.seekExact(new BytesRef(name))) {
count += termsEnum.docFreq();
}
}
Terms terms = atomicReader.terms(FieldNamesFieldMapper.NAME);
if (terms == null) {
return count;
}
TermsEnum termsEnum = terms.iterator();
for (var fieldName : fields) {
if (termsEnum.seekExact(fieldName)) {
count += termsEnum.docFreq();
}
}
return count;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1428,7 +1428,8 @@ public CompletionStats completionStats(String... fields) {

public DenseVectorStats denseVectorStats() {
readAllowed();
// Diff pair: the engine now needs the shard's MappingLookup to pre-filter
// dense vector fields; null is passed when no mapper service is available.
return getEngine().denseVectorStats();
MappingLookup mappingLookup = mapperService != null ? mapperService.mappingLookup() : null;
return getEngine().denseVectorStats(mappingLookup);
}

public SparseVectorStats sparseVectorStats() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ public final class FrozenEngine extends ReadOnlyEngine {
);
private final SegmentsStats segmentsStats;
private final DocsStats docsStats;
private final DenseVectorStats denseVectorStats;
private final SparseVectorStats sparseVectorStats;
private volatile ElasticsearchDirectoryReader lastOpenedReader;
private final ElasticsearchDirectoryReader canMatchReader;
private final Object cacheIdentity = new Object();
Expand Down Expand Up @@ -95,8 +93,6 @@ public FrozenEngine(
fillSegmentStats(segmentReader, true, segmentsStats);
}
this.docsStats = docsStats(reader);
this.denseVectorStats = denseVectorStats(reader);
this.sparseVectorStats = sparseVectorStats(reader, null);
canMatchReader = ElasticsearchDirectoryReader.wrap(
new RewriteCachingDirectoryReader(directory, reader.leaves(), null),
config.getShardId()
Expand Down Expand Up @@ -334,13 +330,17 @@ public DocsStats docStats() {
}

@Override
// Diff pair: the cached-field variant was removed; the override now always
// reports zero rather than computing stats for frozen shards.
public DenseVectorStats denseVectorStats() {
return denseVectorStats;
public DenseVectorStats denseVectorStats(MappingLookup mappingLookup) {
// We could cache the result on first call but dense vectors on frozen tier
// are very unlikely, so we just don't count them in the stats.
return new DenseVectorStats(0);
}

@Override
public SparseVectorStats sparseVectorStats(MappingLookup mappingLookup) {
// Diff pair: the cached-field return was removed; the override now always
// reports zero rather than computing stats for frozen shards.
return sparseVectorStats;
// We could cache the result on first call but sparse vectors on frozen tier
// are very unlikely, so we just don't count them in the stats.
return new SparseVectorStats(0);
}

synchronized boolean isReaderOpen() {
Expand Down

0 comments on commit 59cf661

Please sign in to comment.