Skip to content

Commit

Permalink
Optimize L2 norm computation with optional vectorization. Right now o…
Browse files Browse the repository at this point in the history
…nly float

is vectorized; other specializations will be added in a subsequent
commit.

To benchmark the query, I've modified flann_example_cpp to run 1000
query loops instead of just one.

Before this change, computing the norm is 36.4% of the execution. Even
though the loop is unrolled by a factor of 4, loads, additions and
multiplications are still scalar. After the change, the loop is
vectorized. When the max distance is given, we still have to reduce
at every iteration to compare against it. Otherwise, we only need a single
reduce at the end.
In the former case (worst_dist > 0), computing the norm becomes 29.8% of execution. Execution time drops from 35.1s to 31.8s (10% improvement).
In the latter case (worst_dist <= 0), computing the norm becomes 24.2% of execution. Execution time drops from 35.1s to 28.2s (20% improvement).

Before:
```
ROUTINE ======================== flann::L2::operator()
47564889135 47564889135 (flat, cum) 36.44% of Total
  35092264   35092264      1f830: lea    (%rsi,%rcx,4),%rcx               ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:154
 328058643  328058643      1f834: movaps %xmm0,%xmm5                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:150
  20314866   20314866      1f837: lea    -0xc(%rcx),%rax                  ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:155
    940905     940905      1f83b: cmp    %rax,%rsi                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:158
         .          .      1f83e: jae    1f8e1 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0xb1>
  18337652   18337652      1f844: pxor   %xmm6,%xmm6                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:152
 241505735  241505735      1f848: movaps %xmm6,%xmm0
   8351632    8351632      1f84b: nopl   0x0(%rax,%rax,1)
2483873273 2483873273      1f850: movss  (%rsi),%xmm1                     ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:159
 220146242  220146242      1f854: movss  0x4(%rsi),%xmm4                  ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:160
   3650929    3650929      1f859: add    $0x10,%rdx                       ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:165
3336460614 3336460614      1f85d: add    $0x10,%rsi                       ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:164
 393866511  393866511      1f861: subss  -0x10(%rdx),%xmm1                ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:159
 366575932  366575932      1f866: subss  -0xc(%rdx),%xmm4                 ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:160
   3634832    3634832      1f86b: movss  -0x8(%rsi),%xmm3                 ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:161
9095107561 9095107561      1f870: subss  -0x8(%rdx),%xmm3
 256805437  256805437      1f875: movss  -0x4(%rsi),%xmm2                 ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:162
 297745047  297745047      1f87a: subss  -0x4(%rdx),%xmm2
   3712756    3712756      1f87f: comiss %xmm6,%xmm5                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:167
7525622349 7525622349      1f882: mulss  %xmm1,%xmm1                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:163
 617814475  617814475      1f886: mulss  %xmm4,%xmm4
 111702640  111702640      1f88a: mulss  %xmm3,%xmm3
 261873539  261873539      1f88e: mulss  %xmm2,%xmm2
7422685724 7422685724      1f892: addss  %xmm4,%xmm1
 970925099  970925099      1f896: addss  %xmm3,%xmm1
5221453445 5221453445      1f89a: addss  %xmm2,%xmm1
4987699692 4987699692      1f89e: addss  %xmm1,%xmm0
                                                                          ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:167
1526384608 1526384608      1f8a2: jbe    1f8a9 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x79>
         .          .      1f8a4: comiss %xmm5,%xmm0
         .          .      1f8a7: ja     1f8d8 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0xa8>
         .          .      1f8a9: cmp    %rsi,%rax                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:158
1558693597 1558693597      1f8ac: ja     1f850 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x20>
         .          .      1f8ae: cmp    %rsi,%rcx                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:172
   8300423    8300423      1f8b1: jbe    1f8e0 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0xb0>
         .          .      1f8b3: add    $0x4,%rdx                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:173
         .          .      1f8b7: movss  (%rsi),%xmm1
         .          .      1f8bb: add    $0x4,%rsi
         .          .      1f8bf: subss  -0x4(%rdx),%xmm1
         .          .      1f8c4: mulss  %xmm1,%xmm1                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:174
         .          .      1f8c8: addss  %xmm1,%xmm0
         .          .      1f8cc: cmp    %rsi,%rcx                        ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:172
         .          .      1f8cf: ja     1f8b3 <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x83>
         .          .      1f8d1: retq                                    ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:177
         .          .      1f8d2: nopw   0x0(%rax,%rax,1)
         .          .      1f8d8: retq
         .          .      1f8d9: nopl   0x0(%rax)
 237552713  237552713      1f8e0: retq
         .          .      1f8e1: pxor   %xmm0,%xmm0                      ;_ZNK5flann2L2IfEclIPfPKfEEfT_T0_mf dist.h:152
         .          .      1f8e5: jmp    1f8ae <float flann::L2<float>::operator()<float*, float const*>(float*, float const*, unsigned long, float) const+0x7e>
         .          .      1f8e7: nopw   0x0(%rax,%rax,1)
```

After (worst_dist > 0):

```
ROUTINE ======================== flann::L2::Compute
34754208773 34754208773 (flat, cum) 29.84% of Total
  48641080   48641080      13d20: pxor   %xmm3,%xmm3                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
 333401185  333401185      13d24: lea    (%rdi,%rdx,4),%r8                ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:189
  11949349   11949349      13d28: comiss %xmm3,%xmm0                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
   1843594    1843594      13d2b: lea    -0xc(%r8),%rcx
  26596752   26596752      13d2f: ja     13d75 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x55>
 263415728  263415728      13d31: jmp    13da8 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x88>

                           [...non-taken branch...]

  31882241   31882241      13d99: cmp    %rdi,%r8                         ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:200
                                                                          ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf
         .          .      13d9c: ja     13d80 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x60>
  34863218   34863218      13d9e: movaps %xmm3,%xmm0                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:205
 157914789  157914789      13da1: retq
         .          .      13da2: nopw   0x0(%rax,%rax,1)                 ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf
   7396959    7396959      13da8: mov    %rsi,%rdx                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
   3658216    3658216      13dab: mov    %rdi,%rax
  45037071   45037071      13dae: cmp    %rcx,%rdi
                                                                          ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff
         .          .      13db1: jae    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>
 294492579  294492579      13db3: nopl   0x0(%rax,%rax,1)                 ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
2729755072 2729755072      13db8: movups (%rax),%xmm0                     ;_Z10_mm_sub_psDv4_fS_ dist.h:196
1834682910 1834682910      13dbb: movups (%rdx),%xmm5
 645605525  645605525      13dbe: add    $0x10,%rax                       ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
 903030920  903030920      13dc2: add    $0x10,%rdx
2299620094 2299620094      13dc6: subps  %xmm5,%xmm0                      ;_Z10_mm_sub_psDv4_fS_ dist.h:196
2684229688 2684229688      13dc9: mulps  %xmm0,%xmm0                      ;_Z10_mm_mul_psDv4_fS_ dist.h:196
3352147473 3352147473      13dcc: movaps %xmm0,%xmm2                      ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rff dist.h:196
1097632435 1097632435      13dcf: movaps %xmm0,%xmm1
1530297598 1530297598      13dd2: unpckhps %xmm0,%xmm2
1760337677 1760337677      13dd5: shufps $0xff,%xmm0,%xmm1
3750657802 3750657802      13dd9: addss  %xmm2,%xmm1
 769888830  769888830      13ddd: movaps %xmm0,%xmm2
 956681123  956681123      13de0: shufps $0x55,%xmm0,%xmm2
2762013893 2762013893      13de4: addss  %xmm2,%xmm1
3084482538 3084482538      13de8: addss  %xmm1,%xmm0
1995863847 1995863847      13dec: addss  %xmm0,%xmm3
   3697571    3697571      13df0: cmp    %rcx,%rax
 991839621  991839621      13df3: jb     13db8 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x98>
  21978923   21978923      13df5: mov    %r8,%rax
  27542392   27542392      13df8: sub    %rdi,%rax
  36804713   36804713      13dfb: sub    $0xd,%rax
   5487291    5487291      13dff: and    $0xfffffffffffffff0,%rax
  32055141   32055141      13e03: add    $0x10,%rax
  32198989   32198989      13e07: add    %rax,%rdi
 176303775  176303775      13e0a: add    %rax,%rsi
   8280171    8280171      13e0d: jmp    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>
```

After (worst_dist <= 0):
```
ROUTINE ======================== flann::L2::Compute
25758491778 25758491778 (flat, cum) 24.17% of Total
 240673135  240673135      13d20: pxor   %xmm3,%xmm3                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
 137020253  137020253      13d24: lea    (%rdi,%rdx,4),%r8                ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:189
  11939237   11939237      13d28: comiss %xmm3,%xmm0                      ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:191
  58564715   58564715      13d2b: lea    -0xc(%r8),%rcx
 111559377  111559377      13d2f: ja     13d75 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x55>
 118986257  118986257      13d31: jmp    13da8 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x88>

                           [...non-taken branch...]

 70503557   70503557      13da8: cmp    %rcx,%rdi                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
                                                                          ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13dab: jae    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>
    876666     876666      13dad: mov    %rsi,%rdx                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
 124448567  124448567      13db0: mov    %rdi,%rax
 136135059  136135059      13db3: pxor   %xmm1,%xmm1
  79507393   79507393      13db7: nopw   0x0(%rax,%rax,1)
2822080992 2822080992      13dc0: movups (%rax),%xmm0                     ;_Z10_mm_sub_psDv4_fS_ dist.h:196
  31839811   31839811      13dc3: movups (%rdx),%xmm6
 363014116  363014116      13dc6: add    $0x10,%rax                       ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
  22856826   22856826      13dca: add    $0x10,%rdx
1031249812 1031249812      13dce: subps  %xmm6,%xmm0                      ;_Z10_mm_sub_psDv4_fS_ dist.h:196
1682412937 1682412937      13dd1: mulps  %xmm0,%xmm0                      ;_Z10_mm_mul_psDv4_fS_ dist.h:196
17546296134 17546296134      13dd4: addps  %xmm0,%xmm1                      ;_Z10_mm_add_psDv4_fS_ dist.h:196
   1780455    1780455      13dd7: cmp    %rcx,%rax                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
  36550095   36550095      13dda: jb     13dc0 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0xa0>
    891637     891637      13ddc: movaps %xmm1,%xmm3
  24502724   24502724      13ddf: movaps %xmm1,%xmm0
    905376     905376      13de2: mov    %r8,%rax
  12790369   12790369      13de5: shufps $0x55,%xmm1,%xmm0
 342813117  342813117      13de9: addss  %xmm0,%xmm3
    918634     918634      13ded: movaps %xmm1,%xmm0
         .          .      13df0: sub    %rdi,%rax                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13df3: unpckhps %xmm1,%xmm0
   2675031    2675031      13df6: sub    $0xd,%rax                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
         .          .      13dfa: shufps $0xff,%xmm1,%xmm1                ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13dfe: and    $0xfffffffffffffff0,%rax
 449641847  449641847      13e02: addss  %xmm0,%xmm3                      ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
   2732463    2732463      13e06: add    $0x10,%rax
         .          .      13e0a: add    %rax,%rdi                        ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf
         .          .      13e0d: add    %rax,%rsi
 290497884  290497884      13e10: addss  %xmm1,%xmm3                      ;_ZN5flann2L2IfE14VectorizedLoopERPKfS3_S4_Rf dist.h:196
                                                                          ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf dist.h:200
   1827302    1827302      13e14: jmp    13d99 <float flann::L2<float>::Compute<float const*, float const*>(float const*, float const*, unsigned long, float) const [clone .constprop.0]+0x79>
         .          .      13e16: nopw   %cs:0x0(%rax,%rax,1)             ;_ZNK5flann2L2IfE7ComputeIPKfS4_EEfT_T0_mf.constprop.0
```
  • Loading branch information
legrosbuffle committed Jun 19, 2020
1 parent 1d04523 commit 802bddb
Showing 1 changed file with 103 additions and 25 deletions.
128 changes: 103 additions & 25 deletions src/cpp/flann/algorithms/dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ typedef unsigned __int64 uint64_t;

#include "flann/defines.h"

#ifdef __SSE2__
#include <xmmintrin.h>
#endif


namespace flann
{
Expand Down Expand Up @@ -137,6 +141,11 @@ struct L2
typedef T ElementType;
typedef typename Accumulator<T>::Type ResultType;

/* Maps an iterator type to the type used for read-only traversal.
 * Any type is passed through unchanged by the primary template. */
template <typename It>
struct ConstIterator
{
    typedef It type;
};
/* Raw pointers are the exception: Pointee* maps to const Pointee*. */
template <typename Pointee>
struct ConstIterator<Pointee*>
{
    typedef const Pointee* type;
};

/**
* Compute the squared Euclidean distance between two vectors.
*
Expand All @@ -149,31 +158,13 @@ struct L2
template <typename Iterator1, typename Iterator2>
ResultType operator()(Iterator1 a, Iterator2 b, size_t size, ResultType worst_dist = -1) const
{
    /* Parameters:
     *   a, b       - iterators to the first element of each vector.
     *   size       - number of elements to read from each vector.
     *   worst_dist - when positive, Compute may stop early and return the
     *                partial sum as soon as it exceeds this bound; the
     *                default of -1 disables early termination.
     *
     * All the work is delegated to Compute. We only ever read from a and b,
     * so we pass const versions of the iterators. This ensures that in the
     * case of pointers, the const* version is selected, and avoids having to
     * write const and non-const overloads of VectorizedLoop. */
    return Compute(static_cast<typename ConstIterator<Iterator1>::type>(a),
                   static_cast<typename ConstIterator<Iterator2>::type>(b),
                   size, worst_dist);
}

/**
Expand All @@ -187,6 +178,93 @@ struct L2
{
return (a-b)*(a-b);
}

private:
/* Compile-time checks that ConstIterator turns both float* and const float*
 * into const float*, which is the pointer type the non-template
 * VectorizedLoop overloads take. */
static_assert(std::is_same<typename ConstIterator<float*>::type, const float*>::value,
              "ConstIterator<float*>::type must be const float*");
static_assert(std::is_same<typename ConstIterator<const float*>::type, const float*>::value,
              "ConstIterator<const float*>::type must be const float*");

/* Sums the squared element-wise differences of two ranges of `size`
 * elements starting at a and b.
 *
 * The bulk of the range is handed to VectorizedLoop, which advances a and b
 * past the elements it consumed; whatever is left is handled one element at
 * a time. When worst_dist is positive and VectorizedLoop reports that the
 * running sum exceeded it, that partial sum is returned immediately. */
template <typename ConstIterator1, typename ConstIterator2>
ResultType Compute(ConstIterator1 a, ConstIterator2 b, size_t size, ResultType worst_dist) const
{
    ResultType total = ResultType();
    ConstIterator1 end = a + size;

    /* Bulk phase: several elements per iteration. */
    const bool bounded = worst_dist > 0;
    if (!bounded) {
        VectorizedLoop(a, end, b, total);
    }
    else if (VectorizedLoop(a, end, b, total, worst_dist)) {
        return total;
    }

    /* Tail phase: leftover elements. Not needed for standard vector lengths. */
    while (a < end) {
        const ResultType delta = (ResultType)(*a - *b);
        total += delta * delta;
        ++a;
        ++b;
    }
    return total;
}

/* Default loop implementation.. */
template <typename ConstIterator1, typename ConstIterator2>
static inline bool VectorizedLoop(ConstIterator1& a, ConstIterator1 last, ConstIterator2& b, ResultType& result, ResultType worst_dist = 0) {
ConstIterator1 lastgroup = last - 3;
/* Process 4 items with each loop for efficiency. */
while (a < lastgroup) {
ResultType diff0 = (ResultType)(a[0] - b[0]);
ResultType diff1 = (ResultType)(a[1] - b[1]);
ResultType diff2 = (ResultType)(a[2] - b[2]);
ResultType diff3 = (ResultType)(a[3] - b[3]);
result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
a += 4;
b += 4;

if ((worst_dist>0) &&(result>worst_dist)) {
return true;
}
}
return false;
};

#ifdef __SSE2__
/* A more efficient loop for (const float*, const float*) -> float, with an
 * early-exit bound.
 *
 * Consumes [a, last) four floats at a time using unaligned SSE loads, adds
 * the squared differences to `result`, and advances a and b past the
 * consumed elements. Up to three trailing elements are left for the caller.
 *
 * Because the running sum must be compared with worst_dist after every
 * group, the four lanes are reduced to a scalar at every iteration.
 *
 * Returns true if worst_dist is positive and `result` exceeded it (the loop
 * stops at that point), false once fewer than four elements remain. */
static inline bool VectorizedLoop(const float*& a, const float* last, const float*& b, float& result, float worst_dist) {
    /* Test the remaining element count rather than comparing against
     * `last - 3`, which would point before the start of the array for
     * ranges shorter than three elements (undefined pointer arithmetic). */
    while (last - a >= 4) {
        const __m128 diff = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
        const __m128 sqr = _mm_mul_ps(diff, diff);
        /* Spill the four squared differences so they can be summed. */
        float elements[4];
        _mm_storeu_ps(elements, sqr);
        result += elements[3] + elements[2] + elements[1] + elements[0];
        a += 4;
        b += 4;

        if ((worst_dist > 0) && (result > worst_dist)) {
            return true;
        }
    }
    return false;
}
/* A more efficient loop for (const float*, const float*) -> float when no
 * early-exit bound is given.
 *
 * Consumes [a, last) four floats at a time using unaligned SSE loads and
 * advances a and b past the consumed elements. Up to three trailing
 * elements are left for the caller. The squared differences are kept in
 * four independent lanes and reduced to a scalar only once, after the loop;
 * that sum is then added to `result`, so the accumulator semantics match
 * the other VectorizedLoop overloads. */
static inline void VectorizedLoop(const float*& a, const float* last, const float*& b, float& result) {
    __m128 v_result = _mm_set1_ps(0.0f);
    /* Test the remaining element count rather than comparing against
     * `last - 3`, which would point before the start of the array for
     * ranges shorter than three elements (undefined pointer arithmetic). */
    while (last - a >= 4) {
        const __m128 diff = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
        const __m128 sqr = _mm_mul_ps(diff, diff);
        v_result = _mm_add_ps(v_result, sqr);
        a += 4;
        b += 4;
    }
    /* Single horizontal reduction of the four lane sums. */
    float elements[4];
    _mm_storeu_ps(elements, v_result);
    result += elements[0] + elements[1] + elements[2] + elements[3];
}
#endif

};


Expand Down

0 comments on commit 802bddb

Please sign in to comment.