-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
QINCo implementation in CPU Faiss (#3608)
Summary: Pull Request resolved: #3608 This is a straightforward implementation of QINCo in CPU Faiss, with encoding and decoding capabilities (not training). For this, we translate a simplified version of some torch classes: - tensors, restricted to 2D and int32 + float32 - Linear and Embedding layer Then the QINCoStep and QINCo can just be defined as C++ objects that are copy-constructable. There is some plumbing required in the wrapping layers to support the integration. Pytroch tensors are converted to numpy for getting / setting them in C++. Reviewed By: asadoughi Differential Revision: D59132952
- Loading branch information
1 parent
67fc053
commit 7279eda
Showing
11 changed files
with
1,213 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
This demonstrates how to reproduce the QINCo paper results using the Faiss | ||
QINCo implementation. The code loads the reference model because training | ||
is not implemented in Faiss. | ||
Prepare the data with | ||
cd /tmp | ||
# get the reference qinco code | ||
git clone https://github.com/facebookresearch/Qinco.git | ||
# get the data | ||
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs | ||
# get the model | ||
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt | ||
""" | ||
|
||
import numpy as np | ||
from faiss.contrib.vecs_io import bvecs_mmap | ||
import sys | ||
import time | ||
import torch | ||
import faiss | ||
|
||
# make sure pickle deserialization will work | ||
sys.path.append("/tmp/Qinco") | ||
import model_qinco | ||
|
||
with torch.no_grad(): | ||
|
||
qinco = torch.load("/tmp/bigann_8x8_L2.pt") | ||
qinco.eval() | ||
# print(qinco) | ||
if True: | ||
torch.set_num_threads(1) | ||
faiss.omp_set_num_threads(1) | ||
|
||
x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32') | ||
x_scaled = torch.from_numpy(x_base) / qinco.db_scale | ||
|
||
t0 = time.time() | ||
codes, _ = qinco.encode(x_scaled) | ||
x_decoded_scaled = qinco.decode(codes) | ||
print(f"Pytorch encode {time.time() - t0:.3f} s") | ||
# multi-thread: 1.13s, single-thread: 7.744 | ||
|
||
x_decoded = x_decoded_scaled.numpy() * qinco.db_scale | ||
|
||
err = ((x_decoded - x_base) ** 2).sum(1).mean() | ||
print("MSE=", err) # = 14211.956, near the L=2 result in Fig 4 of the paper | ||
|
||
qinco2 = faiss.QINCo(qinco) | ||
t0 = time.time() | ||
codes2 = qinco2.encode(faiss.Tensor2D(x_scaled)) | ||
x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale | ||
print(f"Faiss encode {time.time() - t0:.3f} s") | ||
# multi-thread: 3.2s, single thread: 7.019 | ||
|
||
# these tests don't work because there are outlier encodings | ||
# np.testing.assert_array_equal(codes.numpy(), codes2.numpy()) | ||
# np.testing.assert_allclose(x_decoded, x_decoded2) | ||
|
||
ndiff = (codes.numpy() != codes2.numpy()).sum() / codes.numel() | ||
assert ndiff < 0.01 | ||
ndiff = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum() | ||
assert ndiff / len(x_base) < 0.01 | ||
|
||
err = ((x_decoded2 - x_base) ** 2).sum(1).mean() | ||
print("MSE=", err) # = 14213.551 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/** | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <faiss/IndexNeuralNetCodec.h> | ||
#include <faiss/impl/FaissAssert.h> | ||
#include <faiss/utils/hamming.h> | ||
|
||
namespace faiss { | ||
|
||
/********************************************************* | ||
* IndexNeuralNetCodec implementation | ||
*********************************************************/ | ||
|
||
IndexNeuralNetCodec::IndexNeuralNetCodec( | ||
int d, | ||
int M, | ||
int nbits, | ||
MetricType metric) | ||
: IndexFlatCodes((M * nbits + 7) / 8, d, metric), M(M), nbits(nbits) { | ||
is_trained = false; | ||
} | ||
|
||
void IndexNeuralNetCodec::train(idx_t n, const float* x) { | ||
FAISS_THROW_MSG("Training not implemented in C++, use Pytorch"); | ||
} | ||
|
||
void IndexNeuralNetCodec::sa_encode(idx_t n, const float* x, uint8_t* codes) | ||
const { | ||
nn::Tensor2D x_tensor(n, d, x); | ||
nn::Int32Tensor2D codes_tensor = net->encode(x_tensor); | ||
pack_bitstrings(n, M, nbits, codes_tensor.data(), codes, code_size); | ||
} | ||
|
||
void IndexNeuralNetCodec::sa_decode(idx_t n, const uint8_t* codes, float* x) | ||
const { | ||
nn::Int32Tensor2D codes_tensor(n, M); | ||
unpack_bitstrings(n, M, nbits, codes, code_size, codes_tensor.data()); | ||
nn::Tensor2D x_tensor = net->decode(codes_tensor); | ||
memcpy(x, x_tensor.data(), d * n * sizeof(float)); | ||
} | ||
|
||
/********************************************************* | ||
* IndexQINeuralNetCodec implementation | ||
*********************************************************/ | ||
|
||
IndexQINCo::IndexQINCo(int d, int M, int nbits, int L, int h, MetricType metric) | ||
: IndexNeuralNetCodec(d, M, nbits, metric), | ||
qinco(d, 1 << nbits, L, M, h) { | ||
net = &qinco; | ||
} | ||
|
||
} // namespace faiss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/** | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
|
||
#include <faiss/IndexFlatCodes.h> | ||
#include <faiss/utils/NeuralNet.h> | ||
|
||
namespace faiss { | ||
|
||
struct IndexNeuralNetCodec : IndexFlatCodes { | ||
NeuralNetCodec* net = nullptr; | ||
size_t M, nbits; | ||
|
||
explicit IndexNeuralNetCodec( | ||
int d = 0, | ||
int M = 0, | ||
int nbits = 0, | ||
MetricType metric = METRIC_L2); | ||
|
||
void train(idx_t n, const float* x) override; | ||
|
||
void sa_encode(idx_t n, const float* x, uint8_t* codes) const override; | ||
void sa_decode(idx_t n, const uint8_t* codes, float* x) const override; | ||
|
||
~IndexNeuralNetCodec() {} | ||
}; | ||
|
||
struct IndexQINCo : IndexNeuralNetCodec { | ||
QINCo qinco; | ||
|
||
IndexQINCo( | ||
int d, | ||
int M, | ||
int nbits, | ||
int L, | ||
int h, | ||
MetricType metric = METRIC_L2); | ||
|
||
~IndexQINCo() {} | ||
}; | ||
|
||
|
||
} // namespace faiss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.