Skip to content

Commit

Permalink
Use const array instead of lazy_static
Browse files Browse the repository at this point in the history
  • Loading branch information
primenumber committed Feb 10, 2024
1 parent d41c290 commit 4b50030
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 56 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.2.0"
authors = ["prime <[email protected]>"]

[dependencies]
lazy_static = "1.4"
rand = { version = "0.8", features = ["small_rng"] }
rand_xoshiro = "0.6"
futures = { version = "0.3", features = ["std"] }
Expand Down
33 changes: 17 additions & 16 deletions src/engine/bits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use core::arch::x86_64::{_pdep_u64, _pext_u64};
use lazy_static::lazy_static;

pub fn popcnt(x: u64) -> i8 {
x.count_ones() as i8
Expand Down Expand Up @@ -53,23 +52,25 @@ pub fn pdep(x: u64, mask: u64) -> u64 {
unsafe { _pdep_u64(x, mask) }
}

lazy_static! {
pub static ref BASE3: [usize; 256] = {
let mut res = [0usize; 256];
for x in 0..256 {
let mut pow3 = 1;
let mut sum = 0;
for i in 0..8 {
if ((x >> i) & 1) == 1 {
sum += pow3;
}
pow3 *= 3;
pub const BASE3: [usize; 256] = {
let mut res = [0usize; 256];
let mut x = 0;
while x < 256 {
let mut pow3 = 1;
let mut sum = 0;
let mut i = 0;
while i < 8 {
if ((x >> i) & 1) == 1 {
sum += pow3;
}
res[x] = sum;
pow3 *= 3;
i += 1;
}
res
};
}
res[x] = sum;
x += 1;
}
res
};

#[cfg(test)]
mod tests {
Expand Down
135 changes: 100 additions & 35 deletions src/engine/board.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::engine::hand::*;
use anyhow::Result;
use clap::ArgMatches;
use core::arch::x86_64::*;
use lazy_static::lazy_static;
use std::cmp::min;
use std::fmt;
use std::io::{BufWriter, Write};
Expand Down Expand Up @@ -33,7 +32,7 @@ pub struct PlayIterator {

pub const BOARD_SIZE: usize = 64;

#[cfg(all(target_feature = "avx512cd", target_feature="avx512vl"))]
#[cfg(all(target_feature = "avx512cd", target_feature = "avx512vl"))]
unsafe fn smart_upper_bit(x: __m256i) -> __m256i {
let y = _mm256_lzcnt_epi64(x);
_mm256_srlv_epi64(_mm256_set1_epi64x(0x8000_0000_0000_0000u64 as i64), y)
Expand All @@ -48,6 +47,14 @@ unsafe fn smart_upper_bit(mut x: __m256i) -> __m256i {
_mm256_andnot_si256(lowers, x)
}

const fn smart_upper_bit_scalar(mut x: u64, lane: usize) -> u64 {
x |= x >> [8, 1, 7, 9][lane];
x |= x >> [16, 2, 14, 18][lane];
x |= x >> [32, 4, 28, 36][lane];
let lowers = x >> [8, 1, 7, 9][lane];
!lowers & x
}

#[allow(dead_code)]
unsafe fn upper_bit(mut x: __m256i) -> __m256i {
x = _mm256_or_si256(x, _mm256_srli_epi64(x, 1));
Expand Down Expand Up @@ -125,10 +132,51 @@ impl Board {
reduce_or(flipped)
}

#[cfg(all(target_feature = "avx2"))]
pub fn flip_unchecked(&self, pos: usize) -> u64 {
unsafe { self.flip_simd(pos) }
}

pub const fn flip_naive(&self, pos: usize) -> u64 {
let o_mask = 0x7E7E_7E7E_7E7E_7E7Eu64;
let om = [
self.opponent,
self.opponent & o_mask,
self.opponent & o_mask,
self.opponent & o_mask,
];
let mask1 = [
0x0080808080808080u64,
0x7f00000000000000u64,
0x0102040810204000u64,
0x0040201008040201u64,
];
let mask2 = [
0x0101010101010100u64,
0x00000000000000feu64,
0x0002040810204080u64,
0x8040201008040200u64,
];
let mut flipped = 0;
let mut i = 0;
while i < 4 {
let mask = mask1[i] >> (63 - pos);
let outflank = smart_upper_bit_scalar(!om[i] & mask, i) & self.player;
flipped |= (outflank.wrapping_neg() << 1) & mask;
let mask = mask2[i] << pos;
let outflank = !((!om[i] & mask).wrapping_sub(1)) & mask & self.player;
flipped |= !((if outflank == 0 {
0xFFFF_FFFF_FFFF_FFFFu64
} else {
0
})
.wrapping_sub(outflank))
& mask;
i += 1;
}
flipped
}

pub fn flip(&self, pos: usize) -> u64 {
if ((self.empty() >> pos) & 1) == 0 {
0
Expand All @@ -137,6 +185,14 @@ impl Board {
}
}

pub const fn flip_const(&self, pos: usize) -> u64 {
if ((self.empty() >> pos) & 1) == 0 {
0
} else {
self.flip_naive(pos)
}
}

pub fn is_movable(&self, pos: usize) -> bool {
if pos >= BOARD_SIZE {
return false;
Expand Down Expand Up @@ -171,7 +227,7 @@ impl Board {
}
}

pub fn pass_unchecked(&self) -> Board {
pub const fn pass_unchecked(&self) -> Board {
Board {
player: self.opponent,
opponent: self.player,
Expand All @@ -189,7 +245,7 @@ impl Board {
}
}

pub fn empty(&self) -> u64 {
pub const fn empty(&self) -> u64 {
!(self.player | self.opponent)
}

Expand Down Expand Up @@ -605,17 +661,19 @@ pub fn weighted_mobility(board: &Board) -> i8 {
popcnt(b) + popcnt(b & corner)
}

fn stable_bits_8(board: Board, passed: bool, memo: &mut [Option<u64>]) -> u64 {
const fn stable_bits_8(board: Board, passed: bool, memo: &[Option<u64>]) -> u64 {
let index = BASE3[board.player as usize] + 2 * BASE3[board.opponent as usize];
if let Some(res) = memo[index] {
return res;
}
let mut res = 0xFF;
for pos in 0..8 {
let mut pos = 0;
while pos < 8 {
if ((board.empty() >> pos) & 1) != 1 {
pos += 1;
continue;
}
let flip = board.flip(pos);
let flip = board.flip_const(pos);
let pos_bit = 1 << pos;
let next = Board {
player: board.opponent ^ flip,
Expand All @@ -624,11 +682,12 @@ fn stable_bits_8(board: Board, passed: bool, memo: &mut [Option<u64>]) -> u64 {
res &= !flip;
res &= !pos_bit;
res &= stable_bits_8(next, false, memo);
pos += 1;
}
if !passed {
let next = board.pass_unchecked();
res &= stable_bits_8(next, true, memo);
memo[index] = Some(res);
//memo[index] = Some(res);
}
res
}
Expand All @@ -643,32 +702,38 @@ pub fn parse_board(matches: &ArgMatches) {
println!("{}", board);
}

lazy_static! {
static ref STABLE: [u64; 6561] = {
let mut memo = [None; 6561];
for i in 0..6561 {
let mut me = 0;
let mut op = 0;
let mut tmp = i;
for j in 0..8 {
let state = tmp % 3;
match state {
1 => me |= 1 << j,
2 => op |= 1 << j,
_ => (),
}
tmp /= 3;
const STABLE: [u64; 6561] = {
let mut memo = [None; 6561];
let mut ri = 0;
while ri < 6561 {
let i = 6561 - ri - 1;
let mut me = 0;
let mut op = 0;
let mut tmp = i;
let mut j = 0;
while j < 8 {
let state = tmp % 3;
match state {
1 => me |= 1 << j,
2 => op |= 1 << j,
_ => (),
}
let board = Board {
player: me,
opponent: op,
};
stable_bits_8(board, false, &mut memo);
}
let mut res = [0; 6561];
for i in 0..6561 {
res[i] = memo[i].unwrap() & 0xFF;
tmp /= 3;
j += 1;
}
res
};
}
let board = Board {
player: me,
opponent: op,
};
let res = stable_bits_8(board, false, &memo);
memo[i] = Some(res);
ri += 1;
}
let mut res = [0; 6561];
let mut i = 0;
while i < 6561 {
res[i] = memo[i].unwrap() & 0xFF;
i += 1;
}
res
};
6 changes: 2 additions & 4 deletions src/engine/board/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ impl From<NaiveBoard> for Board {
_ => (),
}
}
Board {
player,
opponent,
}
Board { player, opponent }
}
}

Expand Down Expand Up @@ -202,6 +199,7 @@ fn test_ops() {
assert_eq!(board, Board::from(naive_board.clone()));
for i in 0..BOARD_SIZE {
assert_eq!(board.flip(i), naive_board.flip(i));
assert_eq!(board.flip_const(i), naive_board.flip(i));
assert_eq!(board.is_movable(i), naive_board.is_movable(i));
if board.is_movable(i) {
assert_eq!(
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![feature(const_option)]
#![feature(test)]
mod book;
mod engine;
Expand Down

0 comments on commit 4b50030

Please sign in to comment.