Skip to content

Commit

Permalink
Add search functionality to FlatCodes
Browse files Browse the repository at this point in the history
Summary: Using the new dispatcher functions, add search functionality to flat codes. To test it, make IndexLattice a subclass of FlatCodes and check the reconstruction there.

Differential Revision: D59367989
  • Loading branch information
mdouze authored and facebook-github-bot committed Jul 10, 2024
1 parent a53965d commit 67fc053
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 49 deletions.
6 changes: 6 additions & 0 deletions contrib/inspect_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def get_flat_data(index):
return xb.reshape(index.ntotal, index.d)


def get_flat_codes(index_flat):
    """Return the codes stored in an IndexFlatCodes as a 2D array.

    The flat code buffer ``index_flat.codes`` is converted with
    ``faiss.vector_to_array`` and reshaped to
    ``(index_flat.ntotal, index_flat.code_size)``, i.e. one row per stored
    vector.
    """
    n_rows, n_cols = index_flat.ntotal, index_flat.code_size
    flat = faiss.vector_to_array(index_flat.codes)
    return flat.reshape(n_rows, n_cols)


def get_NSG_neighbors(nsg):
""" get the neighbor list for the vectors stored in the NSG structure, as
a N-by-K matrix of indices """
Expand Down
164 changes: 159 additions & 5 deletions faiss/IndexFlatCodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/extra_distances.h>

namespace faiss {

Expand Down Expand Up @@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
reconstruct_n(key, 1, recons);
}

FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
const {
FAISS_THROW_MSG("not implemented");
}

void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
// minimal sanity checks
const IndexFlatCodes* other =
Expand Down Expand Up @@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
std::swap(codes, new_codes);
}

namespace {

template <class VD>
struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
const IndexFlatCodes& codec;
const VD vd;
// temp buffers
std::vector<uint8_t> code_buffer;
std::vector<float> vec_buffer;
const float* query = nullptr;

GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
: FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
codec(*codec),
vd(vd),
code_buffer(codec->code_size * 4),
vec_buffer(codec->d * 4) {}

void set_query(const float* x) override {
query = x;
}

float operator()(idx_t i) override {
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
return vd(query, vec_buffer.data());
}

float distance_to_code(const uint8_t* code) override {
codec.sa_decode(1, code, vec_buffer.data());
return vd(query, vec_buffer.data());
}

float symmetric_dis(idx_t i, idx_t j) override {
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
}

void distances_batch_4(
const idx_t idx0,
const idx_t idx1,
const idx_t idx2,
const idx_t idx3,
float& dis0,
float& dis1,
float& dis2,
float& dis3) override {
uint8_t* cp = code_buffer.data();
for (idx_t i : {idx0, idx1, idx2, idx3}) {
memcpy(cp, codes + i * code_size, code_size);
cp += code_size;
}
// potential benefit is if batch decoding is more efficient than 1 by 1
// decoding
codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
dis0 = vd(query, vec_buffer.data());
dis1 = vd(query, vec_buffer.data() + vd.d);
dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
}
};

struct Run_get_distance_computer {
using T = FlatCodesDistanceComputer*;

template <class VD>
FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
}
};

template <class BlockResultHandler>
struct Run_search_with_decompress {
using T = void;

template <class VectorDistance>
void f(VectorDistance& vd,
const IndexFlatCodes* index_ptr,
const float* xq,
BlockResultHandler& res) {
// Note that there seems to be a clang (?) bug that "sometimes" passes
// the const Index & parameters by value, so to be on the safe side,
// it's better to use pointers.
const IndexFlatCodes& index = *index_ptr;
size_t ntotal = index.ntotal;
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
#pragma omp parallel // if (res.nq > 100)
{
std::unique_ptr<DC> dc(new DC(&index, vd));
SingleResultHandler resi(res);
#pragma omp for
for (int64_t q = 0; q < res.nq; q++) {
resi.begin(q);
dc->set_query(xq + vd.d * q);
for (size_t i = 0; i < ntotal; i++) {
if (res.is_in_selection(i)) {
float dis = (*dc)(i);
resi.add_result(dis, i);
}
}
resi.end();
}
}
}
};

struct Run_search_with_decompress_res {
using T = void;

template <class ResultHandler>
void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
Run_search_with_decompress<ResultHandler> r;
dispatch_VectorDistance(
index->d,
index->metric_type,
index->metric_arg,
r,
index,
xq,
res);
}
};

} // anonymous namespace

/// Default implementation: a generic distance computer that decodes the
/// codes, selected from this index's d, metric_type and metric_arg.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
        const {
    Run_get_distance_computer factory;
    FlatCodesDistanceComputer* computer =
            dispatch_VectorDistance(d, metric_type, metric_arg, factory, this);
    return computer;
}

/// k-NN search by decoding the stored codes. Only the IDSelector (if any) is
/// read from params.
/// NOTE(review): k is forwarded without validation -- confirm that
/// dispatch_knn_ResultHandler copes with k <= 0.
void IndexFlatCodes::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    // a null params means "no selector"
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res runner;
    dispatch_knn_ResultHandler(
            n, distances, labels, k, metric_type, sel, runner, this, x);
}

/// Range search by decoding the stored codes. Only the IDSelector (if any)
/// is read from params.
void IndexFlatCodes::range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result,
        const SearchParameters* params) const {
    // a null params means "no selector"
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res runner;
    dispatch_range_ResultHandler(
            result, radius, metric_type, sel, runner, this, x);
}

} // namespace faiss
23 changes: 20 additions & 3 deletions faiss/IndexFlatCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#pragma once

#include <faiss/Index.h>
Expand Down Expand Up @@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
* different from the usual ones: the new ids are shifted */
size_t remove_ids(const IDSelector& sel) override;

/** a FlatCodesDistanceComputer offers a distance_to_code method */
/** a FlatCodesDistanceComputer offers a distance_to_code method
*
* The default implementation explicitly decodes the vector with sa_decode.
*/
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;

/// Generic Index entry point: returns the object built by
/// get_FlatCodesDistanceComputer() through a DistanceComputer pointer.
DistanceComputer* get_distance_computer() const override {
return get_FlatCodesDistanceComputer();
}

/** k-nearest-neighbor search implemented by decoding: the stored codes are
 * decoded with sa_decode and compared to the queries with the distance
 * selected by metric_type / metric_arg.
 *
 * Only params->sel is read from the search parameters (when params is
 * non-null).
 */
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const override;

/** Range search implemented by decoding, using the same decode-and-compare
 * scan as search().
 *
 * Only params->sel is read from the search parameters (when params is
 * non-null).
 */
void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// returns a new instance of a CodePacker
CodePacker* get_CodePacker() const;

Expand Down
20 changes: 1 addition & 19 deletions faiss/IndexLattice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace faiss {

IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
: Index(d),
: IndexFlatCodes(0, d, METRIC_L2),
nsq(nsq),
dsq(d / nsq),
zn_sphere_codec(dsq, r2),
Expand Down Expand Up @@ -114,22 +114,4 @@ void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
}
}

void IndexLattice::add(idx_t, const float*) {
FAISS_THROW_MSG("not implemented");
}

void IndexLattice::search(
idx_t,
const float*,
idx_t,
float*,
idx_t*,
const SearchParameters*) const {
FAISS_THROW_MSG("not implemented");
}

void IndexLattice::reset() {
FAISS_THROW_MSG("not implemented");
}

} // namespace faiss
25 changes: 3 additions & 22 deletions faiss/IndexLattice.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,18 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#ifndef FAISS_INDEX_LATTICE_H
#define FAISS_INDEX_LATTICE_H
#pragma once

#include <vector>

#include <faiss/IndexIVF.h>
#include <faiss/IndexFlatCodes.h>
#include <faiss/impl/lattice_Zn.h>

namespace faiss {

/** Index that encodes a vector with a series of Zn lattice quantizers
*/
struct IndexLattice : Index {
struct IndexLattice : IndexFlatCodes {
/// number of sub-vectors
int nsq;
/// dimension of sub-vectors
Expand All @@ -30,8 +27,6 @@ struct IndexLattice : Index {

/// nb bits used to encode the scale, per subvector
int scale_nbit, lattice_nbit;
/// total, in bytes
size_t code_size;

/// mins and maxes of the vector norms, per subquantizer
std::vector<float> trained;
Expand All @@ -46,20 +41,6 @@ struct IndexLattice : Index {
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;

void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;

/// not implemented
void add(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void reset() override;
};

} // namespace faiss

#endif
14 changes: 14 additions & 0 deletions tests/test_standalone_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks


class TestEncodeDecode(unittest.TestCase):

def do_encode_twice(self, factory_key):
Expand Down Expand Up @@ -263,6 +264,19 @@ def test_ZnSphereCodecAlt32(self):
def test_ZnSphereCodecAlt24(self):
self.run_ZnSphereCodecAlt(24, 14)

def test_lattice_index(self):
index = faiss.index_factory(96, "ZnLattice3x10_4")
rs = np.random.RandomState(123)
xq = rs.randn(10, 96).astype('float32')
xb = rs.randn(20, 96).astype('float32')
index.train(xb)
index.add(xb)
D, I = index.search(xq, 5)
for i in range(10):
recons = index.reconstruct_batch(I[i, :])
ref_dis = ((recons - xq[i]) ** 2).sum(1)
np.testing.assert_allclose(D[i, :], ref_dis, atol=1e-4)


class TestBitstring(unittest.TestCase):

Expand Down

0 comments on commit 67fc053

Please sign in to comment.