From a2b663caf75155da9ab1e305313f82fc8dd5cdab Mon Sep 17 00:00:00 2001 From: primenumber Date: Wed, 1 May 2024 07:42:34 +0900 Subject: [PATCH] Implement NNUE Evaluator --- src/engine.rs | 1 + src/engine/eval.rs | 32 ++++++ src/engine/nnue_eval.rs | 230 +++++++++++++++++++++++++++++++++++++ src/engine/pattern_eval.rs | 32 ------ src/main.rs | 1 + src/setup.rs | 6 +- 6 files changed, 267 insertions(+), 35 deletions(-) create mode 100644 src/engine/nnue_eval.rs diff --git a/src/engine.rs b/src/engine.rs index 2323d56..c465b68 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -5,6 +5,7 @@ pub mod eval; pub mod hand; pub mod last_cache; pub mod midgame; +pub mod nnue_eval; pub mod pattern_eval; pub mod search; pub mod table; diff --git a/src/engine/eval.rs b/src/engine/eval.rs index 19e6374..937c40c 100644 --- a/src/engine/eval.rs +++ b/src/engine/eval.rs @@ -2,6 +2,38 @@ mod test; use crate::engine::board::*; +pub fn pow3(x: i8) -> usize { + if x == 0 { + 1 + } else { + 3 * pow3(x - 1) + } +} + +// interprete base-2 number as base-3 number +// base_2_to_3(x) := radix_parse(radix_fmt(x, 2), 3) +const fn base_2_to_3(mut x: usize) -> usize { + let mut base3 = 0; + let mut pow3 = 1; + while x > 0 { + base3 += (x % 2) * pow3; + pow3 *= 3; + x /= 2; + } + base3 +} + +const BASE_2_TO_3_TABLE_BITS: usize = 13; +pub const BASE_2_TO_3: [usize; 1 << BASE_2_TO_3_TABLE_BITS] = { + let mut table = [0usize; 1 << BASE_2_TO_3_TABLE_BITS]; + let mut i = 0; + while i < table.len() { + table[i] = base_2_to_3(i); + i += 1; + } + table +}; + pub trait Evaluator: Send + Sync { fn eval(&self, board: Board) -> i16; fn score_scale(&self) -> i16; diff --git a/src/engine/nnue_eval.rs b/src/engine/nnue_eval.rs new file mode 100644 index 0000000..a69c7bc --- /dev/null +++ b/src/engine/nnue_eval.rs @@ -0,0 +1,230 @@ +use crate::engine::bits::*; +use crate::engine::board::*; +use crate::engine::eval::*; +use serde::Deserialize; +use std::fs::File; +use std::io::{BufReader, Read}; +use std::mem; +use std::path::Path; + +#[derive(Deserialize, Debug)] +struct NNUEConfig { + #[serde(with = "pattern_format")] + patterns: Vec, + front: usize, + middle: usize, + back: usize, +} + +mod pattern_format { + use serde::{Deserialize, Deserializer}; + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let v = Vec::::deserialize(deserializer)?; + v.into_iter() + .map(|s| u64::from_str_radix(&s, 16)) + .try_collect() + .map_err(|e| serde::de::Error::custom(format!("Failed to parse pattern: {:?}", e))) + } +} + +impl NNUEConfig { + fn from_file(path: &Path) -> Option { + let file = File::open(path).unwrap(); + let reader = BufReader::new(file); + Some(serde_json::from_reader(reader).unwrap()) + } +} + +pub struct NNUEEvaluator { + config: NNUEConfig, + embedding: Vec, + offsets: Vec, + pos_to_indices: Vec<[usize; 64]>, + layer1_weight: Vec, + layer1_bias: Vec, + layer2_weight: Vec, + layer2_bias: Vec, + layer3_weight: Vec, + layer3_bias: Vec, +} + +impl NNUEEvaluator { + fn load_param(path: &Path, length: usize) -> Option> { + let mut value_file = File::open(path).ok()?; + let mut buf = vec![0u8; length * 4]; + value_file.read_exact(&mut buf).ok()?; + let mut v = Vec::with_capacity(length); + for ary in buf.as_chunks().0 { + let raw_weight = unsafe { mem::transmute::<[u8; 4], f32>(*ary) }; + v.push(raw_weight); + } + Some(v) + } + + fn normalize_vec(v: &mut [f32]) { + let mut sum = 0.; + for e in v.iter() { + sum += *e * *e; + } + if sum > 1.0 { + let scale = 1.0 / sum.sqrt(); + for e in v { + *e *= scale; + } + } + } + + fn generate_pos_to_indices(pattern: u64) -> [usize; 64] { + let mut count = 0; + let mut result = [0; 64]; + for pos in 0..64 { + if (pattern >> pos) & 1 == 1 { + result[pos] = pow3(count); + count += 1; + } + } + result + } + + fn transpose_mat(v: &mut [f32], row: usize, col: usize) { + let mut tmp = vec![0.; row * col]; + for i in 0..row { + for j in 0..col { + tmp[i * col + j] = v[i + j * row]; + } + } + v.copy_from_slice(&tmp); + } + + pub fn load(path: &Path) -> Option { + let config = NNUEConfig::from_file(&path.join("config.json"))?; + let mut offsets = Vec::new(); + let mut offset = 0; + for pattern_bits in &config.patterns { + offsets.push(offset); + offset += pow3(pattern_bits.count_ones() as i8); + } + let mut embedding = Self::load_param(&path.join("embedding.weight"), offset * config.front)?; + for chunk in embedding.chunks_mut(config.front) { + Self::normalize_vec(chunk); + } + let embedding = embedding; + let mut layer1_weight = Self::load_param( + &path.join("backend_block.0.weight"), + config.front * config.middle, + )?; + Self::transpose_mat(&mut layer1_weight, config.front, config.middle); + let mut layer2_weight = Self::load_param( + &path.join("backend_block.2.weight"), + config.middle * config.back, + )?; + Self::transpose_mat(&mut layer2_weight, config.middle, config.back); + let layer3_weight = Self::load_param(&path.join("backend_block.4.weight"), config.back)?; + let layer1_bias = Self::load_param(&path.join("backend_block.0.bias"), config.middle)?; + let layer2_bias = Self::load_param(&path.join("backend_block.2.bias"), config.back)?; + let layer3_bias = Self::load_param(&path.join("backend_block.4.bias"), 1)?; + let pos_to_indices = config + .patterns + .iter() + .map(|pattern| Self::generate_pos_to_indices(*pattern)) + .collect(); + Some(Self { + config, + embedding, + offsets, + pos_to_indices, + layer1_weight, + layer1_bias, + layer2_weight, + layer2_bias, + layer3_weight, + layer3_bias, + }) + } + + fn lookup_vec(&self, index: usize) -> &[f32] { + let first = index * self.config.front; + let last = first + self.config.front; + &self.embedding[first..last] + } + + fn score_scale() -> i16 { + 256 + } + + fn score_min() -> i16 { + -Self::score_scale() * BOARD_SIZE as i16 + } + + fn score_max() -> i16 { + Self::score_scale() * BOARD_SIZE as i16 + } +} + +impl Evaluator for NNUEEvaluator { + fn eval(&self, mut board: Board) -> i16 { + let mut front_vec = vec![0.0; self.config.front]; + for _ in 0..4 { + for (pattern, offset) in self.config.patterns.iter().zip(self.offsets.iter()) { + let pbits = board.player.pext(*pattern) as usize; + let obits = board.opponent.pext(*pattern) as usize; + let index = BASE_2_TO_3[pbits] + 2 * BASE_2_TO_3[obits] + offset; + for (e, f) in self.lookup_vec(index).into_iter().zip(front_vec.iter_mut()) { + *f += *e; + } + } + let board_flip = board.flip_diag(); + for (pattern, offset) in self.config.patterns.iter().zip(self.offsets.iter()) { + let pbits = board_flip.player.pext(*pattern) as usize; + let obits = board_flip.opponent.pext(*pattern) as usize; + let index = BASE_2_TO_3[pbits] + 2 * BASE_2_TO_3[obits] + offset; + for (e, f) in self.lookup_vec(index).into_iter().zip(front_vec.iter_mut()) { + *f += *e; + } + } + board = board.rot90(); + } + let mut middle_vec = self.layer1_bias.clone(); + for (j, fe) in front_vec.iter().enumerate() { + for (i, me) in middle_vec.iter_mut().enumerate() { + *me += *fe * unsafe { *self.layer1_weight.get_unchecked(i + self.config.middle * j) }; + } + } + for me in middle_vec.iter_mut() { + if *me < 0. { + *me = 0.; + } + } + let mut back_vec = self.layer2_bias.clone(); + for (j, me) in middle_vec.iter().enumerate() { + for (i, be) in back_vec.iter_mut().enumerate() { + *be += *me * unsafe { *self.layer2_weight.get_unchecked(i + self.config.back * j) }; + } + } + for be in back_vec.iter_mut() { + if *be < 0. { + *be = 0.; + } + } + let mut result = self.layer3_bias[0]; + for (be, w) in back_vec.iter().zip(self.layer3_weight.iter()) { + result += *be * *w; + } + (Self::score_scale() as f32 * result.clamp(-64., 64.)).round() as i16 + } + + fn score_scale(&self) -> i16 { + Self::score_scale() + } + + fn score_min(&self) -> i16 { + Self::score_min() + } + + fn score_max(&self) -> i16 { + Self::score_max() + } +} diff --git a/src/engine/pattern_eval.rs b/src/engine/pattern_eval.rs index 91131b7..fba1c25 100644 --- a/src/engine/pattern_eval.rs +++ b/src/engine/pattern_eval.rs @@ -171,32 +171,8 @@ impl FoldedEvaluator { } } -// interprete base-2 number as base-3 number -// base_2_to_3(x) := radix_parse(radix_fmt(x, 2), 3) -const fn base_2_to_3(mut x: usize) -> usize { - let mut base3 = 0; - let mut pow3 = 1; - while x > 0 { - base3 += (x % 2) * pow3; - pow3 *= 3; - x /= 2; - } - base3 -} - const NON_PATTERN_SCORES: usize = 4; -const BASE_2_TO_3_TABLE_BITS: usize = 13; -const BASE_2_TO_3: [usize; 1 << BASE_2_TO_3_TABLE_BITS] = { - let mut table = [0usize; 1 << BASE_2_TO_3_TABLE_BITS]; - let mut i = 0; - while i < table.len() { - table[i] = base_2_to_3(i); - i += 1; - } - table -}; - // x = R^rot M ^ mirror struct SquareGroup { rot: u8, @@ -438,14 +414,6 @@ pub struct PatternLinearEvaluator { vectorizer: IndicesVectorizer, } -fn pow3(x: i8) -> usize { - if x == 0 { - 1 - } else { - 3 * pow3(x - 1) - } -} - const SCALE: i16 = 256; const EVAL_SCORE_MAX: i16 = BOARD_SIZE as i16 * SCALE; const EVAL_SCORE_MIN: i16 = -EVAL_SCORE_MAX; diff --git a/src/main.rs b/src/main.rs index 1750971..1e53ff8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ #![feature(const_option)] #![feature(portable_simd)] #![feature(iterator_try_collect)] +#![feature(slice_as_chunks)] #![feature(test)] mod book; mod compression; diff --git a/src/setup.rs b/src/setup.rs index 4ad7338..5d0d845 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -1,13 +1,13 @@ -use crate::engine::pattern_eval::*; +use crate::engine::nnue_eval::*; use crate::engine::search::*; use crate::engine::table::*; use std::path::Path; use std::sync::Arc; -pub fn setup_default() -> SolveObj { +pub fn setup_default() -> SolveObj { let res_cache = Arc::new(ResCacheTable::new(2048, 16384)); let eval_cache = Arc::new(EvalCacheTable::new(2048, 16384)); - let evaluator = Arc::new(PatternLinearEvaluator::load(Path::new("table-220710")).unwrap()); + let evaluator = Arc::new(NNUEEvaluator::load(Path::new("nnue_32x64x32_240429")).unwrap()); let search_params = SearchParams { reduce: false, parallel_depth_limit: 16,