Skip to content

Commit

Permalink
Implement NNUE Evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
primenumber committed Apr 30, 2024
1 parent a789296 commit a2b663c
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod eval;
pub mod hand;
pub mod last_cache;
pub mod midgame;
pub mod nnue_eval;
pub mod pattern_eval;
pub mod search;
pub mod table;
Expand Down
32 changes: 32 additions & 0 deletions src/engine/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,38 @@
mod test;
use crate::engine::board::*;

pub fn pow3(x: i8) -> usize {
if x == 0 {
1
} else {
3 * pow3(x - 1)
}
}

// interprete base-2 number as base-3 number
// base_2_to_3(x) := radix_parse(radix_fmt(x, 2), 3)
const fn base_2_to_3(mut x: usize) -> usize {
let mut base3 = 0;
let mut pow3 = 1;
while x > 0 {
base3 += (x % 2) * pow3;
pow3 *= 3;
x /= 2;
}
base3
}

const BASE_2_TO_3_TABLE_BITS: usize = 13;
pub const BASE_2_TO_3: [usize; 1 << BASE_2_TO_3_TABLE_BITS] = {
let mut table = [0usize; 1 << BASE_2_TO_3_TABLE_BITS];
let mut i = 0;
while i < table.len() {
table[i] = base_2_to_3(i);
i += 1;
}
table
};

pub trait Evaluator: Send + Sync {
fn eval(&self, board: Board) -> i16;
fn score_scale(&self) -> i16;
Expand Down
230 changes: 230 additions & 0 deletions src/engine/nnue_eval.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
use crate::engine::bits::*;
use crate::engine::board::*;
use crate::engine::eval::*;
use serde::Deserialize;
use std::fs::File;
use std::io::{BufReader, Read};
use std::mem;
use std::path::Path;

#[derive(Deserialize, Debug)]
struct NNUEConfig {
#[serde(with = "pattern_format")]
patterns: Vec<u64>,
front: usize,
middle: usize,
back: usize,
}

mod pattern_format {
use serde::{Deserialize, Deserializer};
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u64>, D::Error>
where
D: Deserializer<'de>,
{
let v = Vec::<String>::deserialize(deserializer)?;
v.into_iter()
.map(|s| u64::from_str_radix(&s, 16))
.try_collect()
.map_err(|e| serde::de::Error::custom(format!("Failed to parse pattern: {:?}", e)))
}
}

impl NNUEConfig {
fn from_file(path: &Path) -> Option<NNUEConfig> {
let file = File::open(path).unwrap();
let reader = BufReader::new(file);
Some(serde_json::from_reader(reader).unwrap())
}
}

pub struct NNUEEvaluator {
config: NNUEConfig,
embedding: Vec<f32>,
offsets: Vec<usize>,
pos_to_indices: Vec<[usize; 64]>,
layer1_weight: Vec<f32>,
layer1_bias: Vec<f32>,
layer2_weight: Vec<f32>,
layer2_bias: Vec<f32>,
layer3_weight: Vec<f32>,
layer3_bias: Vec<f32>,
}

impl NNUEEvaluator {
fn load_param(path: &Path, length: usize) -> Option<Vec<f32>> {
let mut value_file = File::open(path).ok()?;
let mut buf = vec![0u8; length * 4];
value_file.read_exact(&mut buf).ok()?;
let mut v = Vec::with_capacity(length);
for ary in buf.as_chunks().0 {
let raw_weight = unsafe { mem::transmute::<[u8; 4], f32>(*ary) };
v.push(raw_weight);
}
Some(v)
}

fn normalize_vec(v: &mut [f32]) {
let mut sum = 0.;
for e in v.iter() {
sum += *e * *e;
}
if sum > 1.0 {
let scale = 1.0 / sum.sqrt();
for e in v {
*e *= scale;
}
}
}

fn generate_pos_to_indices(pattern: u64) -> [usize; 64] {
let mut count = 0;
let mut result = [0; 64];
for pos in 0..64 {
if (pattern >> pos) & 1 == 1 {
result[pos] = pow3(count);
count += 1;
}
}
result
}

fn transpose_mat(v: &mut [f32], row: usize, col: usize) {
let mut tmp = vec![0.; row * col];
for i in 0..row {
for j in 0..col {
tmp[i * col + j] = v[i + j * row];
}
}
v.copy_from_slice(&tmp);
}

pub fn load(path: &Path) -> Option<Self> {
let config = NNUEConfig::from_file(&path.join("config.json"))?;
let mut offsets = Vec::new();
let mut offset = 0;
for pattern_bits in &config.patterns {
offsets.push(offset);
offset += pow3(pattern_bits.count_ones() as i8);
}
let mut embedding = Self::load_param(&path.join("embedding.weight"), offset * config.front)?;
for chunk in embedding.chunks_mut(config.front) {
Self::normalize_vec(chunk);
}
let embedding = embedding;
let mut layer1_weight = Self::load_param(
&path.join("backend_block.0.weight"),
config.front * config.middle,
)?;
Self::transpose_mat(&mut layer1_weight, config.front, config.middle);
let mut layer2_weight = Self::load_param(
&path.join("backend_block.2.weight"),
config.middle * config.back,
)?;
Self::transpose_mat(&mut layer2_weight, config.middle, config.back);
let layer3_weight = Self::load_param(&path.join("backend_block.4.weight"), config.back)?;
let layer1_bias = Self::load_param(&path.join("backend_block.0.bias"), config.middle)?;
let layer2_bias = Self::load_param(&path.join("backend_block.2.bias"), config.back)?;
let layer3_bias = Self::load_param(&path.join("backend_block.4.bias"), 1)?;
let pos_to_indices = config
.patterns
.iter()
.map(|pattern| Self::generate_pos_to_indices(*pattern))
.collect();
Some(Self {
config,
embedding,
offsets,
pos_to_indices,
layer1_weight,
layer1_bias,
layer2_weight,
layer2_bias,
layer3_weight,
layer3_bias,
})
}

fn lookup_vec(&self, index: usize) -> &[f32] {
let first = index * self.config.front;
let last = first + self.config.front;
&self.embedding[first..last]
}

fn score_scale() -> i16 {
256
}

fn score_min() -> i16 {
-Self::score_scale() * BOARD_SIZE as i16
}

fn score_max() -> i16 {
Self::score_scale() * BOARD_SIZE as i16
}
}

impl Evaluator for NNUEEvaluator {
fn eval(&self, mut board: Board) -> i16 {
let mut front_vec = vec![0.0; self.config.front];
for _ in 0..4 {
for (pattern, offset) in self.config.patterns.iter().zip(self.offsets.iter()) {
let pbits = board.player.pext(*pattern) as usize;
let obits = board.opponent.pext(*pattern) as usize;
let index = BASE_2_TO_3[pbits] + 2 * BASE_2_TO_3[obits] + offset;
for (e, f) in self.lookup_vec(index).into_iter().zip(front_vec.iter_mut()) {
*f += *e;
}
}
let board_flip = board.flip_diag();
for (pattern, offset) in self.config.patterns.iter().zip(self.offsets.iter()) {
let pbits = board_flip.player.pext(*pattern) as usize;
let obits = board_flip.opponent.pext(*pattern) as usize;
let index = BASE_2_TO_3[pbits] + 2 * BASE_2_TO_3[obits] + offset;
for (e, f) in self.lookup_vec(index).into_iter().zip(front_vec.iter_mut()) {
*f += *e;
}
}
board = board.rot90();
}
let mut middle_vec = self.layer1_bias.clone();
for (j, fe) in front_vec.iter().enumerate() {
for (i, me) in middle_vec.iter_mut().enumerate() {
*me += *fe * unsafe { *self.layer1_weight.get_unchecked(i + self.config.middle * j) };
}
}
for me in middle_vec.iter_mut() {
if *me < 0. {
*me = 0.;
}
}
let mut back_vec = self.layer2_bias.clone();
for (j, me) in middle_vec.iter().enumerate() {
for (i, be) in back_vec.iter_mut().enumerate() {
*be += *me * unsafe { *self.layer2_weight.get_unchecked(i + self.config.back * j) };
}
}
for be in back_vec.iter_mut() {
if *be < 0. {
*be = 0.;
}
}
let mut result = self.layer3_bias[0];
for (be, w) in back_vec.iter().zip(self.layer3_weight.iter()) {
result += *be * *w;
}
(Self::score_scale() as f32 * result.clamp(-64., 64.)).round() as i16
}

fn score_scale(&self) -> i16 {
Self::score_scale()
}

fn score_min(&self) -> i16 {
Self::score_min()
}

fn score_max(&self) -> i16 {
Self::score_max()
}
}
32 changes: 0 additions & 32 deletions src/engine/pattern_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,32 +171,8 @@ impl FoldedEvaluator {
}
}

// interprete base-2 number as base-3 number
// base_2_to_3(x) := radix_parse(radix_fmt(x, 2), 3)
const fn base_2_to_3(mut x: usize) -> usize {
let mut base3 = 0;
let mut pow3 = 1;
while x > 0 {
base3 += (x % 2) * pow3;
pow3 *= 3;
x /= 2;
}
base3
}

const NON_PATTERN_SCORES: usize = 4;

const BASE_2_TO_3_TABLE_BITS: usize = 13;
const BASE_2_TO_3: [usize; 1 << BASE_2_TO_3_TABLE_BITS] = {
let mut table = [0usize; 1 << BASE_2_TO_3_TABLE_BITS];
let mut i = 0;
while i < table.len() {
table[i] = base_2_to_3(i);
i += 1;
}
table
};

// x = R^rot M ^ mirror
struct SquareGroup {
rot: u8,
Expand Down Expand Up @@ -438,14 +414,6 @@ pub struct PatternLinearEvaluator {
vectorizer: IndicesVectorizer<INDICES_VECTORIZER_PACK_SIZE>,
}

fn pow3(x: i8) -> usize {
if x == 0 {
1
} else {
3 * pow3(x - 1)
}
}

const SCALE: i16 = 256;
const EVAL_SCORE_MAX: i16 = BOARD_SIZE as i16 * SCALE;
const EVAL_SCORE_MIN: i16 = -EVAL_SCORE_MAX;
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(const_option)]
#![feature(portable_simd)]
#![feature(iterator_try_collect)]
#![feature(slice_as_chunks)]
#![feature(test)]
mod book;
mod compression;
Expand Down
6 changes: 3 additions & 3 deletions src/setup.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::engine::pattern_eval::*;
use crate::engine::nnue_eval::*;
use crate::engine::search::*;
use crate::engine::table::*;
use std::path::Path;
use std::sync::Arc;

pub fn setup_default() -> SolveObj<PatternLinearEvaluator> {
pub fn setup_default() -> SolveObj<NNUEEvaluator> {
let res_cache = Arc::new(ResCacheTable::new(2048, 16384));
let eval_cache = Arc::new(EvalCacheTable::new(2048, 16384));
let evaluator = Arc::new(PatternLinearEvaluator::load(Path::new("table-220710")).unwrap());
let evaluator = Arc::new(NNUEEvaluator::load(Path::new("nnue_32x64x32_240429")).unwrap());
let search_params = SearchParams {
reduce: false,
parallel_depth_limit: 16,
Expand Down

0 comments on commit a2b663c

Please sign in to comment.