Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader #13763

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ Optimizations
This should speed up query evaluation when the child query has a specialized bulk scorer, such as disjunctive queries.
(Mike Pellegrini)

* GITHUB#13763: Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (Pan Guixin)

Changes in runtime behavior
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -34,6 +32,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
Expand All @@ -53,14 +52,16 @@
*/
public final class Lucene90HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final long checksumSeed;
private final FieldInfos fieldInfos;

Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
long[] checksumRef = new long[1];
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -161,7 +162,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce

FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -221,13 +222,18 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return getOffHeapVectorValues(fieldEntry);
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
}

@Override
Expand All @@ -238,8 +244,7 @@ public ByteVectorValues getByteVectorValues(String field) {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -35,6 +33,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
Expand All @@ -56,13 +55,15 @@
*/
public final class Lucene91HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -155,7 +156,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -215,13 +216,18 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return getOffHeapVectorValues(fieldEntry);
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
}

@Override
Expand All @@ -232,8 +238,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
Expand All @@ -34,6 +32,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand All @@ -53,13 +52,15 @@
*/
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -152,7 +153,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -212,13 +213,18 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return OffHeapFloatVectorValues.load(getFieldEntry(field), vectorData);
}

@Override
Expand All @@ -229,8 +235,7 @@ public ByteVectorValues getByteVectorValues(String field) {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
Expand All @@ -35,6 +33,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand All @@ -54,13 +53,15 @@
*/
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -153,7 +154,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -230,48 +231,41 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ expectedEncoding);
}
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
}

@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}

Expand All @@ -289,9 +283,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}

Expand Down
Loading