-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
IndexFlatCodes.cpp
117 lines (99 loc) · 3.31 KB
/
IndexFlatCodes.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/IndexFlatCodes.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/CodePacker.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
namespace faiss {
IndexFlatCodes::IndexFlatCodes(size_t code_size, idx_t d, MetricType metric)
: Index(d, metric), code_size(code_size) {}
IndexFlatCodes::IndexFlatCodes() : code_size(0) {}
/// Encodes `n` vectors from `x` with sa_encode() and appends the codes after
/// the `ntotal` entries already stored. Throws if the index is not trained.
void IndexFlatCodes::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT(is_trained);
    if (n != 0) {
        // Grow the buffer first so sa_encode writes directly into the tail
        // that begins right after the currently stored codes.
        codes.resize((ntotal + n) * code_size);
        auto* tail = codes.data() + ntotal * code_size;
        sa_encode(n, x, tail);
        ntotal += n;
    }
}
/// Removes every stored entry. `codes` is cleared with clear(), which leaves
/// the vector's capacity in place.
void IndexFlatCodes::reset() {
    ntotal = 0;
    codes.clear();
}
size_t IndexFlatCodes::sa_code_size() const {
return code_size;
}
size_t IndexFlatCodes::remove_ids(const IDSelector& sel) {
idx_t j = 0;
for (idx_t i = 0; i < ntotal; i++) {
if (sel.is_member(i)) {
// should be removed
} else {
if (i > j) {
memmove(&codes[code_size * j],
&codes[code_size * i],
code_size);
}
j++;
}
}
size_t nremove = ntotal - j;
if (nremove > 0) {
ntotal = j;
codes.resize(ntotal * code_size);
}
return nremove;
}
/// Decodes the `ni` consecutive stored entries starting at `i0` into `recons`
/// via sa_decode(). `ni == 0` is a no-op. Throws if `ni` is negative or if
/// the range is out of bounds (it requires 0 <= i0 and i0 + ni <= ntotal).
void IndexFlatCodes::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
    FAISS_THROW_IF_NOT(ni >= 0);
    if (ni == 0) {
        // Nothing to decode: return before forming a pointer from an i0 that
        // has not been range-checked.
        return;
    }
    // `ni <= ntotal - i0` is the overflow-free form of `i0 + ni <= ntotal`
    // (both ntotal and i0 are non-negative at this point).
    FAISS_THROW_IF_NOT(i0 >= 0 && ni <= ntotal - i0);
    sa_decode(ni, codes.data() + i0 * code_size, recons);
}
/// Decodes the single stored entry at position `key` into `recons`. Bounds
/// checking is delegated to reconstruct_n().
void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
    const idx_t count = 1;
    this->reconstruct_n(key, count, recons);
}
/// Base implementation: IndexFlatCodes itself provides no
/// FlatCodesDistanceComputer, so this always throws ("not implemented") and
/// never returns normally.
/// NOTE(review): presumably overridden by subclasses that support flat
/// distance computation — confirm against the class declaration.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
const {
FAISS_THROW_MSG("not implemented");
}
/// Throws unless `otherIndex` can be merged into this index. Requirements:
/// it is an IndexFlatCodes (via dynamic_cast), it has the same dimension `d`
/// and the same `code_size`, and it has exactly the same dynamic type as
/// `*this`. The checks run in that order, so the first failing one determines
/// the error message.
void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
// minimal sanity checks
// The cast is needed to read `code_size`; a null result means otherIndex is
// not a flat-codes index at all.
const IndexFlatCodes* other =
dynamic_cast<const IndexFlatCodes*>(&otherIndex);
FAISS_THROW_IF_NOT(other);
FAISS_THROW_IF_NOT(other->d == d);
FAISS_THROW_IF_NOT(other->code_size == code_size);
// Same code_size is not enough: different subclasses may interpret the bytes
// differently, so the concrete types must match exactly.
FAISS_THROW_IF_NOT_MSG(
typeid(*this) == typeid(*other),
"can only merge indexes of the same type");
}
/// Appends every entry of `otherIndex` after this index's entries, then
/// empties `otherIndex` by calling reset() on it.
/// @param otherIndex  compatible index (see check_compatible_for_merge); it
///                    must not be this index itself.
/// @param add_id      must be 0: flat-codes indexes store no explicit ids.
void IndexFlatCodes::merge_from(Index& otherIndex, idx_t add_id) {
    FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatCodes index");
    // The source is reset after copying. For a self-merge that reset would
    // wipe every entry of this index, so reject it explicitly.
    FAISS_THROW_IF_NOT_MSG(
            &otherIndex != this, "cannot merge an index into itself");
    check_compatible_for_merge(otherIndex);
    IndexFlatCodes* other = static_cast<IndexFlatCodes*>(&otherIndex);
    if (other->ntotal > 0) {
        // Skip the copy entirely when there is nothing to add, so memcpy is
        // never handed the (possibly null) data() of an empty vector.
        const size_t old_bytes = ntotal * code_size;
        const size_t added_bytes = other->ntotal * code_size;
        codes.resize(old_bytes + added_bytes);
        memcpy(codes.data() + old_bytes, other->codes.data(), added_bytes);
        ntotal += other->ntotal;
    }
    other->reset();
}
/// Returns a CodePackerFlat sized for this index's `code_size`. The packer is
/// heap-allocated and not retained here, so the caller must delete it.
CodePacker* IndexFlatCodes::get_CodePacker() const {
    CodePacker* packer = new CodePackerFlat(code_size);
    return packer;
}
void IndexFlatCodes::permute_entries(const idx_t* perm) {
std::vector<uint8_t> new_codes(codes.size());
for (idx_t i = 0; i < ntotal; i++) {
memcpy(new_codes.data() + i * code_size,
codes.data() + perm[i] * code_size,
code_size);
}
std::swap(codes, new_codes);
}
} // namespace faiss