Skip to content

Commit

Permalink
QINCo implementation in CPU Faiss (#3608)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3608

This is a straightforward implementation of QINCo in CPU Faiss, with encoding and decoding capabilities (not training).

For this, we translate a simplified version of some torch classes:

- tensors, restricted to 2D and int32 + float32

- Linear and Embedding layer

Then the QINCoStep and QINCo can just be defined as C++ objects that are copy-constructible.

There is some plumbing required in the wrapping layers to support the integration. PyTorch tensors are converted to numpy for getting / setting them in C++.

Reviewed By: asadoughi

Differential Revision: D59132952
  • Loading branch information
mdouze authored and facebook-github-bot committed Jul 10, 2024
1 parent 67fc053 commit 7279eda
Show file tree
Hide file tree
Showing 11 changed files with 1,213 additions and 1 deletion.
77 changes: 77 additions & 0 deletions demos/demo_qinco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This demonstrates how to reproduce the QINCo paper results using the Faiss
QINCo implementation. The code loads the reference model because training
is not implemented in Faiss.
Prepare the data with
cd /tmp
# get the reference qinco code
git clone https://github.com/facebookresearch/Qinco.git
# get the data
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs
# get the model
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt
"""

import numpy as np
from faiss.contrib.vecs_io import bvecs_mmap
import sys
import time
import torch
import faiss

# make sure pickle deserialization will work
sys.path.append("/tmp/Qinco")
import model_qinco

with torch.no_grad():
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # code stored in the file -- only run this on a model file you trust.
    qinco = torch.load("/tmp/bigann_8x8_L2.pt")
    qinco.eval()
    # print(qinco)

    # pin both libraries to a single thread (the timings quoted below give
    # multi-thread and single-thread figures)
    torch.set_num_threads(1)
    faiss.omp_set_num_threads(1)

    # first 1000 database vectors, brought to the scale the model works in
    x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
    x_scaled = torch.from_numpy(x_base) / qinco.db_scale

    # --- reference Pytorch model: encode + decode round trip ---
    t0 = time.time()
    codes, _ = qinco.encode(x_scaled)
    x_decoded_scaled = qinco.decode(codes)
    print(f"Pytorch encode {time.time() - t0:.3f} s")
    # multi-thread: 1.13s, single-thread: 7.744

    x_decoded = x_decoded_scaled.numpy() * qinco.db_scale

    torch_mse = ((x_decoded - x_base) ** 2).sum(1).mean()
    # = 14211.956, near the L=2 result in Fig 4 of the paper
    print("MSE=", torch_mse)

    # --- same round trip through the Faiss port of the model ---
    qinco2 = faiss.QINCo(qinco)
    t0 = time.time()
    codes2 = qinco2.encode(faiss.Tensor2D(x_scaled))
    x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale
    print(f"Faiss encode {time.time() - t0:.3f} s")
    # multi-thread: 3.2s, single thread: 7.019

    # these tests don't work because there are outlier encodings
    # np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
    # np.testing.assert_allclose(x_decoded, x_decoded2)

    # instead, tolerate up to 1% of differing code entries ...
    code_mismatch_rate = (
        (codes.numpy() != codes2.numpy()).sum() / codes.numel()
    )
    assert code_mismatch_rate < 0.01
    # ... and up to 1% of vectors whose reconstructions differ noticeably
    n_far_apart = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum()
    assert n_far_apart / len(x_base) < 0.01

    faiss_mse = ((x_decoded2 - x_base) ** 2).sum(1).mean()
    print("MSE=", faiss_mse)  # = 14213.551
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ set(FAISS_SRC
IndexScalarQuantizer.cpp
IndexShards.cpp
IndexShardsIVF.cpp
IndexNeuralNetCodec.cpp
MatrixStats.cpp
MetaIndexes.cpp
VectorTransform.cpp
Expand Down Expand Up @@ -81,6 +82,7 @@ set(FAISS_SRC
invlists/InvertedLists.cpp
invlists/InvertedListsIOHook.cpp
utils/Heap.cpp
utils/NeuralNet.cpp
utils/WorkerThread.cpp
utils/distances.cpp
utils/distances_simd.cpp
Expand Down
56 changes: 56 additions & 0 deletions faiss/IndexNeuralNetCodec.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <faiss/IndexNeuralNetCodec.h>

#include <cstring>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/hamming.h>

namespace faiss {

/*********************************************************
* IndexNeuralNetCodec implementation
*********************************************************/

/// Construct the index for d-dimensional vectors whose codes take
/// M * nbits bits each.
IndexNeuralNetCodec::IndexNeuralNetCodec(
        int d,
        int M,
        int nbits,
        MetricType metric)
        // code size in bytes: M * nbits bits, rounded up to a whole byte
        : IndexFlatCodes((M * nbits + 7) / 8, d, metric), M(M), nbits(nbits) {
    // the index starts untrained; train() in this file always throws because
    // training is not implemented in C++
    is_trained = false;
}

/// Always throws: training is not implemented on the C++ side. Both
/// arguments are ignored.
void IndexNeuralNetCodec::train(idx_t n, const float* x) {
    FAISS_THROW_MSG("Training not implemented in C++, use Pytorch");
}

/// Encode n vectors with the codec network and pack the result into `codes`.
///
/// @param n      number of input vectors
/// @param x      input vectors, read as an n-by-d tensor
/// @param codes  output buffer; each vector's M codes of nbits bits are
///               packed into code_size bytes by pack_bitstrings
/// @throws FaissException if no codec network has been attached (`net` is
///         null)
void IndexNeuralNetCodec::sa_encode(idx_t n, const float* x, uint8_t* codes)
        const {
    // `net` defaults to nullptr in the class definition, so a plain
    // IndexNeuralNetCodec would otherwise dereference a null pointer here.
    FAISS_THROW_IF_NOT_MSG(
            net, "IndexNeuralNetCodec: no codec network has been set");
    nn::Tensor2D x_tensor(n, d, x);
    nn::Int32Tensor2D codes_tensor = net->encode(x_tensor);
    pack_bitstrings(n, M, nbits, codes_tensor.data(), codes, code_size);
}

/// Unpack n codes and decode them back to vectors with the codec network.
///
/// @param n      number of codes
/// @param codes  packed codes, code_size bytes per vector
/// @param x      output buffer, receives d * n floats
/// @throws FaissException if no codec network has been attached (`net` is
///         null)
void IndexNeuralNetCodec::sa_decode(idx_t n, const uint8_t* codes, float* x)
        const {
    // `net` defaults to nullptr in the class definition, so a plain
    // IndexNeuralNetCodec would otherwise dereference a null pointer here.
    FAISS_THROW_IF_NOT_MSG(
            net, "IndexNeuralNetCodec: no codec network has been set");
    nn::Int32Tensor2D codes_tensor(n, M);
    unpack_bitstrings(n, M, nbits, codes, code_size, codes_tensor.data());
    nn::Tensor2D x_tensor = net->decode(codes_tensor);
    // copy the decoded tensor into the caller's buffer
    std::memcpy(x, x_tensor.data(), d * n * sizeof(float));
}

/*********************************************************
* IndexQINCo implementation
*********************************************************/

/// Build the index together with the QINCo network it embeds.
IndexQINCo::IndexQINCo(int d, int M, int nbits, int L, int h, MetricType metric)
        : IndexNeuralNetCodec(d, M, nbits, metric),
          // the network's second constructor argument (K in the Python
          // wrapper) is 2^nbits
          qinco(d, 1 << nbits, L, M, h) {
    // the base-class sa_encode / sa_decode go through `net`, so point it at
    // the network stored in this object
    net = &qinco;
}

} // namespace faiss
50 changes: 50 additions & 0 deletions faiss/IndexNeuralNetCodec.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <vector>

#include <faiss/IndexFlatCodes.h>
#include <faiss/utils/NeuralNet.h>

namespace faiss {

/// Flat-codes index whose sa_encode / sa_decode are delegated to a
/// NeuralNetCodec.
struct IndexNeuralNetCodec : IndexFlatCodes {
    /// codec used by sa_encode / sa_decode; nullptr until it is set.
    /// NOTE(review): looks non-owning (the destructor below does not free
    /// it) -- confirm.
    NeuralNetCodec* net = nullptr;
    /// passed to pack_bitstrings / unpack_bitstrings; the code size is
    /// (M * nbits + 7) / 8 bytes
    size_t M, nbits;

    explicit IndexNeuralNetCodec(
            int d = 0,
            int M = 0,
            int nbits = 0,
            MetricType metric = METRIC_L2);

    /// not implemented in C++: the definition always throws
    void train(idx_t n, const float* x) override;

    /// encode n vectors through `net` and pack the codes
    void sa_encode(idx_t n, const float* x, uint8_t* codes) const override;
    /// unpack n codes and decode them through `net`
    void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;

    ~IndexNeuralNetCodec() {}
};

struct IndexQINCo : IndexNeuralNetCodec {
QINCo qinco;

IndexQINCo(
int d,
int M,
int nbits,
int L,
int h,
MetricType metric = METRIC_L2);

~IndexQINCo() {}
};


} // namespace faiss
1 change: 1 addition & 0 deletions faiss/impl/ResultHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/partitioning.h>

#include <algorithm>
#include <iostream>

Expand Down
8 changes: 8 additions & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@
class_wrappers.handle_IDSelectorSubset(IDSelectorBitmap, class_owns=False, force_int64=False)
class_wrappers.handle_CodeSet(CodeSet)

# attach the Python-side helpers defined in class_wrappers to the NeuralNet
# classes; handle_Tensor2D is applied to both tensor classes
class_wrappers.handle_Tensor2D(Tensor2D)
class_wrappers.handle_Tensor2D(Int32Tensor2D)
class_wrappers.handle_Embedding(Embedding)
class_wrappers.handle_Linear(Linear)
class_wrappers.handle_QINCo(QINCo)
class_wrappers.handle_QINCoStep(QINCoStep)


this_module = sys.modules[__name__]

# handle sub-classes
Expand Down
149 changes: 149 additions & 0 deletions faiss/python/class_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,3 +1247,152 @@ def replacement_insert(self, codes, inserted=None):
return inserted

replace_method(the_class, 'insert', replacement_insert)

######################################################
# Syntactic sugar for NeuralNet classes
######################################################


def handle_Tensor2D(the_class):
    """Add array-based construction and a `numpy()` accessor to a tensor class.

    The wrapped constructor stays reachable as `original_init`. A single
    argument that is not already an instance of `the_class` is treated as a
    2D array whose contents are copied into the tensor; every other call is
    forwarded unchanged to the wrapped constructor.
    """
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        # Same dispatch rule as the other NeuralNet wrappers in this file:
        # an instance of the class itself is handed to the wrapped
        # constructor instead of being mistaken for an array.
        if len(args) != 1 or args[0].__class__ == the_class:
            self.original_init(*args)
            return
        array, = args
        n, d = array.shape
        self.original_init(n, d)
        # copy the data in row-major order into the tensor's vector
        faiss.copy_array_to_vector(
            np.ascontiguousarray(array).ravel(), self.v)

    def numpy(self):
        """Return the tensor contents as a 2D numpy array."""
        # read the two entries of the `shape` member into a numpy buffer
        shape = np.zeros(2, dtype=np.int64)
        faiss.memcpy(faiss.swig_ptr(shape), self.shape, shape.nbytes)
        return faiss.vector_to_array(self.v).reshape(shape[0], shape[1])

    the_class.__init__ = replacement_init
    the_class.numpy = numpy


def handle_Embedding(the_class):
    """Let an Embedding wrapper be built from, or loaded with, external weights.

    The wrapped constructor stays reachable as `original_init`; `from_torch`
    and `from_array` are attached as weight loaders.
    """
    the_class.original_init = the_class.__init__

    def from_array(self, array):
        """ copy weights from numpy array """
        assert array.shape == (self.num_embeddings, self.embedding_dim)
        flat_weights = np.ascontiguousarray(array).ravel()
        faiss.copy_array_to_vector(flat_weights, self.weight)

    def from_torch(self, emb):
        """ copy weights from torch.Embedding """
        assert emb.weight.shape == (self.num_embeddings, self.embedding_dim)
        flat_weights = np.ascontiguousarray(emb.weight.data).ravel()
        faiss.copy_array_to_vector(flat_weights, self.weight)

    def replacement_init(self, *args):
        if len(args) == 1 and not (args[0].__class__ == the_class):
            # a lone argument of another type: assume it's a torch.Embedding
            emb = args[0]
            self.original_init(emb.num_embeddings, emb.embedding_dim)
            self.from_torch(emb)
        else:
            # anything else goes straight to the wrapped constructor
            self.original_init(*args)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch
    the_class.from_array = from_array


def handle_Linear(the_class):
    """Let a Linear wrapper be built from, or loaded with, external weights.

    The wrapped constructor stays reachable as `original_init`; `from_torch`
    and `from_array` are attached as weight (and optional bias) loaders.
    """
    the_class.original_init = the_class.__init__

    def from_array(self, array, bias=None):
        """ copy weights from numpy array """
        assert array.shape == (self.out_features, self.in_features)
        flat_weights = np.ascontiguousarray(array).ravel()
        faiss.copy_array_to_vector(flat_weights, self.weight)
        if bias is None:
            return
        assert bias.shape == (self.out_features,)
        faiss.copy_array_to_vector(bias, self.bias)

    def from_torch(self, linear):
        """ copy weights from torch.Linear """
        assert linear.weight.shape == (self.out_features, self.in_features)
        flat_weights = linear.weight.data.numpy().ravel()
        faiss.copy_array_to_vector(flat_weights, self.weight)
        if linear.bias is None:
            return
        assert linear.bias.shape == (self.out_features,)
        faiss.copy_array_to_vector(linear.bias.data.numpy(), self.bias)

    def replacement_init(self, *args):
        if len(args) == 1 and not (args[0].__class__ == the_class):
            # a lone argument of another type: assume it's a torch.Linear
            linear = args[0]
            has_bias = linear.bias is not None
            self.original_init(
                linear.in_features, linear.out_features, has_bias)
            self.from_torch(linear)
        else:
            # anything else goes straight to the wrapped constructor
            self.original_init(*args)

    the_class.__init__ = replacement_init
    the_class.from_array = from_array
    the_class.from_torch = from_torch

######################################################
# Syntactic sugar for QINCo and QINCoStep
######################################################

def handle_QINCoStep(the_class):
    """Let a QINCoStep wrapper be built from, or loaded with, a Torch step.

    The wrapped constructor stays reachable as `original_init`; `from_torch`
    is attached as the weight loader.
    """
    the_class.original_init = the_class.__init__

    def from_torch(self, step):
        """ copy weights from torch.QINCoStep """
        source_dims = (step.d, step.K, step.L, step.h)
        assert source_dims == (self.d, self.K, self.L, self.h)
        self.codebook.from_torch(step.codebook)
        self.MLPconcat.from_torch(step.MLPconcat)

        for block_no in range(step.L):
            torch_block = step.residual_blocks[block_no]
            own_block = self.get_residual_block(block_no)
            # entries 0 and 2 of the Torch block feed linear1 and linear2
            own_block.linear1.from_torch(torch_block[0])
            own_block.linear2.from_torch(torch_block[2])

    def replacement_init(self, *args):
        if len(args) == 1 and not (args[0].__class__ == the_class):
            # a lone argument of another type: assume it's a Torch QINCoStep
            step = args[0]
            self.original_init(step.d, step.K, step.L, step.h)
            self.from_torch(step)
        else:
            # anything else goes straight to the wrapped constructor
            self.original_init(*args)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch


def handle_QINCo(the_class):
    """Let a QINCo wrapper be built from, or loaded with, a Torch QINCo model.

    The wrapped constructor stays reachable as `original_init`; `from_torch`
    is attached as the weight loader.
    """
    the_class.original_init = the_class.__init__

    def from_torch(self, qinco):
        """ copy weights from torch.QINCo """
        source_dims = (qinco.d, qinco.K, qinco.L, qinco.M, qinco.h)
        assert source_dims == (self.d, self.K, self.L, self.M, self.h)
        self.codebook0.from_torch(qinco.codebook0)
        # the model carries M - 1 steps next to its first codebook
        for step_no in range(qinco.M - 1):
            own_step = self.get_step(step_no)
            own_step.from_torch(qinco.steps[step_no])

    def replacement_init(self, *args):
        if len(args) == 1 and not (args[0].__class__ == the_class):
            # a lone argument of another type: assume it's a Torch QINCo
            qinco = args[0]
            self.original_init(qinco.d, qinco.K, qinco.L, qinco.M, qinco.h)
            self.from_torch(qinco)
        else:
            # anything else goes straight to the wrapped constructor
            self.original_init(*args)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch
Loading

0 comments on commit 7279eda

Please sign in to comment.