Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NUMA memory replication for NNUE weights #5285

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/ci/libcxx17.imp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
{ include: [ "<__fwd/sstream.h>", private, "<iosfwd>", public ] },
{ include: [ "<__fwd/streambuf.h>", private, "<iosfwd>", public ] },
{ include: [ "<__fwd/string_view.h>", private, "<string_view>", public ] },
{ include: [ "<__system_error/errc.h>", private, "<system_error>", public ] },

# Mappings for includes between public headers
{ include: [ "<ios>", public, "<iostream>", public ] },
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \
nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \
nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \
search.h syzygy/tbprobe.h thread.h thread_win32_osx.h timeman.h \
tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h engine.h score.h
tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h engine.h score.h numa.h

OBJS = $(notdir $(SRCS:.cpp=.o))

Expand Down
88 changes: 70 additions & 18 deletions src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@

#include "engine.h"

#include <cassert>
#include <deque>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <sstream>
#include <string_view>
#include <utility>
#include <vector>
#include <sstream>
#include <iosfwd>
#include <cassert>

#include "evaluate.h"
#include "misc.h"
Expand All @@ -48,10 +48,14 @@ constexpr auto StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -

// Constructs the engine. The member initializers derive the binary directory
// from `path`, initialize the NUMA context from NumaConfig::from_system(),
// allocate a StateInfo deque holding one element, and build the big and small
// networks from their default eval-file names. The body then sets `pos` to
// StartFEN and resets capSq to SQ_NONE.
// NOTE(review): both the direct NN::Networks(...) initializer and the
// numaContext-based one appear in this span; only one can be in effect.
Engine::Engine(std::string path) :
binaryDirectory(CommandLine::get_binary_directory(path)),
numaContext(NumaConfig::from_system()),
states(new std::deque<StateInfo>(1)),
networks(NN::Networks(
NN::NetworkBig({EvalFileDefaultNameBig, "None", ""}, NN::EmbeddedNNUEType::BIG),
NN::NetworkSmall({EvalFileDefaultNameSmall, "None", ""}, NN::EmbeddedNNUEType::SMALL))) {
threads(),
networks(
numaContext,
NN::Networks(
NN::NetworkBig({EvalFileDefaultNameBig, "None", ""}, NN::EmbeddedNNUEType::BIG),
NN::NetworkSmall({EvalFileDefaultNameSmall, "None", ""}, NN::EmbeddedNNUEType::SMALL))) {
pos.set(StartFEN, false, &states->back());
capSq = SQ_NONE;
}
Expand All @@ -74,7 +78,7 @@ void Engine::stop() { threads.stop = true; }
void Engine::search_clear() {
wait_for_search_finished();

tt.clear(options["Threads"]);
tt.clear(threads);
threads.clear();

// @TODO won't work with multiple instances
Expand Down Expand Up @@ -124,40 +128,71 @@ void Engine::set_position(const std::string& fen, const std::vector<std::string>

// modifiers

void Engine::resize_threads() { threads.set({options, threads, tt, networks}, updateContext); }
// Installs a new NUMA configuration chosen by the option string `o`:
//   "none"            -> a default-constructed NumaConfig
//   "auto" / "system" -> NumaConfig::from_system()
//   anything else     -> NumaConfig::from_string(o)
// Afterwards the thread pool is rebuilt via resize_threads().
void Engine::set_numa_config_from_option(const std::string& o) {
    if (o == "none")
        numaContext.set_numa_config(NumaConfig{});
    else if (o == "auto" || o == "system")
        numaContext.set_numa_config(NumaConfig::from_system());
    else
        numaContext.set_numa_config(NumaConfig::from_string(o));

    // Thread affinities may differ under the new configuration, so the
    // threads are reallocated unconditionally.
    resize_threads();
}

// Rebuilds the thread pool against the current NUMA configuration.
// Waits for any running search to finish, calls threads.set() with the
// configuration held by numaContext, and then re-applies the "Hash" option
// through set_tt_size().
void Engine::resize_threads() {
// The pool must be idle before its threads are replaced.
threads.wait_for_search_finished();
threads.set(numaContext.get_numa_config(), {options, threads, tt, networks}, updateContext);

// Reallocate the hash with the new threadpool size
set_tt_size(options["Hash"]);
}

// Resizes the transposition table to `mb` (presumably megabytes — confirm
// against TranspositionTable::resize) after waiting for any running search
// to finish.
// NOTE(review): both tt.resize() calls below are present in this span (the
// options["Threads"] form and the thread-pool form); only one should remain.
void Engine::set_tt_size(size_t mb) {
wait_for_search_finished();
tt.resize(mb, options["Threads"]);
tt.resize(mb, threads);
}

// Writes `b` into the `ponder` field of the thread pool's main manager.
void Engine::set_ponderhit(bool b) {
    threads.main_manager()->ponder = b;
}

// network related

// Calls verify() on the big network with the "EvalFile" option value and on
// the small network with the "EvalFileSmall" option value.
// NOTE(review): this span contains both the direct-member form (networks.big)
// and the pointer-like form (networks->big) of each call.
void Engine::verify_networks() const {
networks.big.verify(options["EvalFile"]);
networks.small.verify(options["EvalFileSmall"]);
networks->big.verify(options["EvalFile"]);
networks->small.verify(options["EvalFileSmall"]);
}

// Loads both networks using the "EvalFile" and "EvalFileSmall" option values,
// resolving files relative to binaryDirectory, then calls threads.clear().
// NOTE(review): this span contains both the form that delegates to
// load_big_network()/load_small_network() and the form that loads both inside
// a single modify_and_replicate() callback.
void Engine::load_networks() {
load_big_network(options["EvalFile"]);
load_small_network(options["EvalFileSmall"]);
networks.modify_and_replicate([this](NN::Networks& networks_) {
networks_.big.load(binaryDirectory, options["EvalFile"]);
networks_.small.load(binaryDirectory, options["EvalFileSmall"]);
});
threads.clear();
}

// Loads the big network from `file` (passed to load() together with
// binaryDirectory), then calls threads.clear().
// NOTE(review): this span contains both the direct networks.big.load() call
// and the modify_and_replicate() form of the same load.
void Engine::load_big_network(const std::string& file) {
networks.big.load(binaryDirectory, file);
networks.modify_and_replicate(
[this, &file](NN::Networks& networks_) { networks_.big.load(binaryDirectory, file); });
threads.clear();
}

// Loads the small network from `file` (passed to load() together with
// binaryDirectory), then calls threads.clear().
// NOTE(review): this span contains both the direct networks.small.load() call
// and the modify_and_replicate() form of the same load.
void Engine::load_small_network(const std::string& file) {
networks.small.load(binaryDirectory, file);
networks.modify_and_replicate(
[this, &file](NN::Networks& networks_) { networks_.small.load(binaryDirectory, file); });
threads.clear();
}

// Saves the big network using files[0].first and the small network using
// files[1].first; the `.second` members are not read here.
// NOTE(review): this span contains both the direct save() calls and a form
// that runs them inside modify_and_replicate(). Confirm whether going through
// modify_and_replicate() is intended for an operation that only saves.
void Engine::save_network(const std::pair<std::optional<std::string>, std::string> files[2]) {
networks.big.save(files[0].first);
networks.small.save(files[1].first);
networks.modify_and_replicate([&files](NN::Networks& networks_) {
networks_.big.save(files[0].first);
networks_.small.save(files[1].first);
});
}

// utility functions
Expand All @@ -169,7 +204,7 @@ void Engine::trace_eval() const {

verify_networks();

sync_cout << "\n" << Eval::trace(p, networks) << sync_endl;
sync_cout << "\n" << Eval::trace(p, *networks) << sync_endl;
}

// Gives callers mutable access to the engine's option map.
OptionsMap& Engine::get_options() {
    return options;
}
Expand All @@ -184,4 +219,21 @@ std::string Engine::visualize() const {
return ss.str();
}

std::vector<std::pair<size_t, size_t>> Engine::get_bound_thread_count_by_numa_node() const {
auto counts = threads.get_bound_thread_count_by_numa_node();
const NumaConfig& cfg = numaContext.get_numa_config();
std::vector<std::pair<size_t, size_t>> ratios;
NumaIndex n = 0;
for (; n < counts.size(); ++n)
ratios.emplace_back(counts[n], cfg.num_cpus_in_numa_node(n));
if (!counts.empty())
for (; n < cfg.num_numa_nodes(); ++n)
ratios.emplace_back(0, cfg.num_cpus_in_numa_node(n));
return ratios;
}

std::string Engine::get_numa_config_as_string() const {
return numaContext.get_numa_config().to_string();
}

}
31 changes: 22 additions & 9 deletions src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "thread.h"
#include "tt.h"
#include "ucioption.h"
#include "numa.h"

namespace Stockfish {

Expand All @@ -47,6 +48,13 @@ class Engine {
using InfoIter = Search::InfoIteration;

Engine(std::string path = "");

// Can't be movable due to components holding backreferences to fields
Engine(const Engine&) = delete;
Engine(Engine&&) = delete;
Engine& operator=(const Engine&) = delete;
Engine& operator=(Engine&&) = delete;

// Waits for any running search to finish before the engine's members are
// destroyed.
~Engine() { wait_for_search_finished(); }

std::uint64_t perft(const std::string& fen, Depth depth, bool isChess960);
Expand All @@ -63,6 +71,7 @@ class Engine {

// modifiers

void set_numa_config_from_option(const std::string& o);
void resize_threads();
void set_tt_size(size_t mb);
void set_ponderhit(bool);
Expand All @@ -83,23 +92,27 @@ class Engine {

// utility functions

void trace_eval() const;
OptionsMap& get_options();
std::string fen() const;
void flip();
std::string visualize() const;
void trace_eval() const;
OptionsMap& get_options();
std::string fen() const;
void flip();
std::string visualize() const;
std::vector<std::pair<size_t, size_t>> get_bound_thread_count_by_numa_node() const;
std::string get_numa_config_as_string() const;

private:
const std::string binaryDirectory;

NumaReplicationContext numaContext;

Position pos;
StateListPtr states;
Square capSq;

OptionsMap options;
ThreadPool threads;
TranspositionTable tt;
Eval::NNUE::Networks networks;
OptionsMap options;
ThreadPool threads;
TranspositionTable tt;
NumaReplicated<Eval::NNUE::Networks> networks;

Search::SearchManager::UpdateContext updateContext;
};
Expand Down
Loading