forked from Enet4/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Heap.cpp
122 lines (101 loc) · 2.96 KB
/
Heap.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
/* Out-of-line HeapArray<C> member functions (see Heap.h) and their explicit instantiations */
#include "Heap.h"
namespace faiss {
template <typename C>
void HeapArray<C>::heapify ()
{
#pragma omp parallel for
for (size_t j = 0; j < nh; j++)
heap_heapify<C> (k, val + j * k, ids + j * k);
}
template <typename C>
void HeapArray<C>::reorder ()
{
#pragma omp parallel for
for (size_t j = 0; j < nh; j++)
heap_reorder<C> (k, val + j * k, ids + j * k);
}
template <typename C>
void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
size_t i0, long ni)
{
if (ni == -1) ni = nh;
assert (i0 >= 0 && i0 + ni <= nh);
#pragma omp parallel for
for (size_t i = i0; i < i0 + ni; i++) {
T * __restrict simi = get_val(i);
TI * __restrict idxi = get_ids (i);
const T *ip_line = vin + (i - i0) * nj;
for (size_t j = 0; j < nj; j++) {
T ip = ip_line [j];
if (C::cmp(simi[0], ip)) {
heap_pop<C> (k, simi, idxi);
heap_push<C> (k, simi, idxi, ip, j + j0);
}
}
}
}
template <typename C>
void HeapArray<C>::addn_with_ids (
size_t nj, const T *vin, const TI *id_in,
long id_stride, size_t i0, long ni)
{
if (id_in == nullptr) {
addn (nj, vin, 0, i0, ni);
return;
}
if (ni == -1) ni = nh;
assert (i0 >= 0 && i0 + ni <= nh);
#pragma omp parallel for
for (size_t i = i0; i < i0 + ni; i++) {
T * __restrict simi = get_val(i);
TI * __restrict idxi = get_ids (i);
const T *ip_line = vin + (i - i0) * nj;
const TI *id_line = id_in + (i - i0) * id_stride;
for (size_t j = 0; j < nj; j++) {
T ip = ip_line [j];
if (C::cmp(simi[0], ip)) {
heap_pop<C> (k, simi, idxi);
heap_push<C> (k, simi, idxi, ip, id_line [j]);
}
}
}
}
template <typename C>
void HeapArray<C>::per_line_extrema (
T * out_val,
TI * out_ids) const
{
#pragma omp parallel for
for (size_t j = 0; j < nh; j++) {
long imin = -1;
typename C::T xval = C::Crev::neutral ();
const typename C::T * x_ = val + j * k;
for (size_t i = 0; i < k; i++)
if (C::cmp (x_[i], xval)) {
xval = x_[i];
imin = i;
}
if (out_val)
out_val[j] = xval;
if (out_ids) {
if (ids && imin != -1)
out_ids[j] = ids [j * k + imin];
else
out_ids[j] = imin;
}
}
}
// explicit instanciations
template struct HeapArray<CMin <float, long> >;
template struct HeapArray<CMax <float, long> >;
template struct HeapArray<CMin <int, long> >;
template struct HeapArray<CMax <int, long> >;
} // END namespace fasis