Skip to content

Commit

Permalink
use fast log and pow approximations (LeelaChessZero#580)
Browse files Browse the repository at this point in the history
Also add -ffast-math to project arguments
  • Loading branch information
borg323 authored Dec 20, 2018
1 parent e4cfc87 commit 1a5f95f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
1 change: 1 addition & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ endif
if cc.get_id() == 'clang' or cc.get_id() == 'gcc'
add_project_arguments('-Wextra', language : 'cpp')
add_project_arguments('-pedantic', language : 'cpp')
add_project_arguments('-ffast-math', language : 'cpp')

if get_option('buildtype') == 'release'
add_project_arguments('-march=native', language : 'cpp')
Expand Down
8 changes: 6 additions & 2 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "mcts/node.h"
#include "neural/cache.h"
#include "neural/encoder.h"
#include "utils/fastmath.h"
#include "utils/random.h"

namespace lczero {
Expand Down Expand Up @@ -198,7 +199,7 @@ inline float ComputeCpuct(const SearchParams& params, uint32_t N) {
const float init = params.GetCpuct();
const float k = params.GetCpuctFactor();
const float base = params.GetCpuctBase();
return init + (k ? k * std::log((N + base) / base) : 0.0f);
return init + (k ? k * FastLog((N + base) / base) : 0.0f);
}
} // namespace

Expand Down Expand Up @@ -1139,7 +1140,10 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process,
float p =
computation_->GetPVal(idx_in_computation, edge.GetMove().as_nn_index());
if (params_.GetPolicySoftmaxTemp() != 1.0f) {
p = pow(p, 1 / params_.GetPolicySoftmaxTemp());
// Flush denormals to zero.
p = p < 1.17549435E-38
? 0.0
: FastPow2(FastLog2(p) / params_.GetPolicySoftmaxTemp());
}
edge.edge()->SetP(p);
// Edge::SetP does some rounding, so only add to the total after rounding.
Expand Down
68 changes: 68 additions & 0 deletions src/utils/fastmath.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018 The LCZero Authors
Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
Additional permission under GNU GPL version 3 section 7
If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/

#pragma once

#include <cmath>
#include <cstdint>
#include <cstring>

namespace lczero {
// These stunts are performed by trained professionals, do not try this at home.

// Fast approximate base-2 logarithm. Performs no range checking.
// Writing x = 2^N * (1 + f) with integer N and fraction f in [0, 1), the
// result is N + f * (1.342671 - 0.342671 * f). The polynomial below is the
// same expression rewritten in terms of m = 1 + f, with the exponent bias
// (127) folded into the constant term.
inline float FastLog2(const float a) {
  constexpr int32_t kMantissaMask = 0x7fffff;
  constexpr int32_t kExponentOfOne = 0x7f << 23;

  // Reinterpret the float as its raw 32-bit pattern.
  int32_t bits;
  std::memcpy(&bits, &a, sizeof(float));

  // Everything above the 23 mantissa bits: the biased exponent.
  const int biased_exponent = bits >> 23;

  // Keep the mantissa and force the exponent field to that of 1.0, which
  // yields m in [1, 2).
  const int32_t mantissa_bits = (bits & kMantissaMask) | kExponentOfOne;
  float m;
  std::memcpy(&m, &mantissa_bits, sizeof(float));

  return m * (2.028011f - 0.342671f * m) - 128.68534f + biased_exponent;
}

// Fast approximate 2^x. Does only limited range checking.
// The approximation used here is 2^(N+f) ~ 2^N*(1+f*(0.656366+0.343634*f))
// where N is the integer and f the fractional part, f>=0.
//
// Inputs below -126 (where the true result would be a denormal) return 0.
// There is no upper-bound or NaN check: the caller must keep `a` below 128,
// otherwise the exponent adjustment no longer yields a meaningful value.
inline float FastPow2(const float a) {
  if (a < -126) return 0.0f;
  // Split a into integer part N (rounded towards -inf) and fraction f >= 0.
  const int exponent = static_cast<int>(std::floor(a));
  const float frac = a - exponent;
  // Polynomial approximation of 2^f, a value in [1, 2].
  float out = 1.0f + frac * (0.656366f + 0.343634f * frac);
  // Scale by 2^N by adding N directly to the IEEE-754 exponent field. This is
  // done on an unsigned bit pattern: left-shifting a negative signed integer
  // is undefined behaviour before C++20, while unsigned arithmetic wraps and
  // produces the intended two's-complement result.
  uint32_t bits;
  std::memcpy(&bits, &out, sizeof(float));
  bits += static_cast<uint32_t>(exponent) << 23;
  std::memcpy(&out, &bits, sizeof(float));
  return out;
}

// Fast approximate natural logarithm, computed as ln(2) * FastLog2(x).
// Does no range checking.
inline float FastLog(const float a) {
  constexpr float kLn2 = 0.6931471805599453f;
  return kLn2 * FastLog2(a);
}

} // namespace lczero

0 comments on commit 1a5f95f

Please sign in to comment.