From 92c0379280a3fde2c7f8dce877eb18c52266b75c Mon Sep 17 00:00:00 2001
From: Cosmo Bobak <56003038+cosmobobak@users.noreply.github.com>
Date: Fri, 16 Aug 2024 00:04:44 +0100
Subject: [PATCH] Significantly improve SIMD code quality (#185)

Bench: 13292315
---
 Cargo.lock                 |  2 +-
 Cargo.toml                 |  2 +-
 src/nnue/network/layers.rs | 85 ++++++++++++++++++++++++++------------
 3 files changed, 61 insertions(+), 28 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 609dacaf..d6e9fce9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -726,7 +726,7 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
 
 [[package]]
 name = "viridithas"
-version = "14.0.1"
+version = "15.0.0"
 dependencies = [
  "anyhow",
  "arrayvec",
diff --git a/Cargo.toml b/Cargo.toml
index 83511910..d2ac880b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "viridithas"
-version = "14.0.1"
+version = "15.0.0"
 edition = "2021"
 description = "A superhuman chess engine."
 license = "MIT"
diff --git a/src/nnue/network/layers.rs b/src/nnue/network/layers.rs
index 9c05a34a..1052d84f 100644
--- a/src/nnue/network/layers.rs
+++ b/src/nnue/network/layers.rs
@@ -117,7 +117,7 @@ mod x86simd {
     use super::super::{Align64, L1_SIZE, L2_SIZE, L3_SIZE, QA, QB};
     use crate::nnue::{
         network::L1_CHUNK_PER_32,
-        simd::{self, VecF32, VecI32, VecI8, F32_CHUNK_SIZE, I16_CHUNK_SIZE, I32_CHUNK_SIZE, S},
+        simd::{self, VecI32, F32_CHUNK_SIZE, I16_CHUNK_SIZE, I32_CHUNK_SIZE, S, U8_CHUNK_SIZE},
     };
     use std::mem::MaybeUninit;
 
@@ -152,10 +152,20 @@
         NNZTable { table }
     };
 
-    unsafe fn find_nnz(
+    // used in only one place, separate function for clarity.
+    unsafe fn reinterpret_as_i32s(ptr: &Align64<[MaybeUninit<u8>; L1_SIZE]>) -> &Align64<[i32; L1_SIZE / 4]> {
+        let ptr = std::ptr::from_ref(ptr);
+        // check that the reference is aligned to the register alignment
+        debug_assert!((ptr as usize) % std::mem::align_of::<VecI32>() == 0);
+        debug_assert!((ptr as usize) % std::mem::align_of::<Align64<[i32; L1_SIZE / 4]>>() == 0);
+        // cast:
+        &*ptr.cast::<Align64<[i32; L1_SIZE / 4]>>()
+    }
+
+    unsafe fn find_nnz<'a>(
         input: &Align64<[i32; L1_SIZE / L1_CHUNK_PER_32]>,
-        out: &mut Align64<[MaybeUninit<u16>; L1_SIZE / L1_CHUNK_PER_32]>,
-    ) -> usize {
+        out: &'a mut Align64<[MaybeUninit<u16>; L1_SIZE / L1_CHUNK_PER_32]>,
+    ) -> &'a [u16] {
         use std::arch::x86_64::_mm_add_epi16 as vec128_add;
         use std::arch::x86_64::_mm_load_si128 as vec128_load;
         use std::arch::x86_64::_mm_set1_epi16 as vec128_set_16;
@@ -187,7 +197,8 @@
             }
         }
 
-        count
+        // SAFETY: we have initialised this region of the array.
+        std::slice::from_raw_parts(out.get_unchecked(0).as_ptr().cast(), count)
     }
 
     #[allow(
@@ -251,24 +262,31 @@
                 }
             }
 
-            let input32 = &*std::ptr::from_ref(&ft_outputs.0).cast::<Align64<[i32; L1_SIZE / 4]>>();
+            // (&Align64<[MaybeUninit<u8>; L1_SIZE]>) -> &Align64<[i32; L1_SIZE / 4]>
+            let input32 = reinterpret_as_i32s(&ft_outputs);
 
+            // Compute the non-zero indices.
             let mut nnz: Align64<[MaybeUninit<u16>; L1_SIZE / L1_CHUNK_PER_32]> = MaybeUninit::uninit().assume_init();
+            let nnz_slice = find_nnz(input32, &mut nnz);
 
-            let nnz_count = find_nnz(input32, &mut nnz);
+            let mut sums = [0; L2_SIZE];
 
-            let mut sums = [simd::zero_i32(); L2_SIZE / F32_CHUNK_SIZE];
-
-            for &i in nnz.get_unchecked(..nnz_count) {
-                let i = i.assume_init();
+            for &i in nnz_slice {
+                // load the non-zero activation, and splat it into a SIMD register.
                 let input = simd::splat_i32(*input32.get_unchecked(i as usize));
-                let i_col = i as usize * L2_SIZE * L1_CHUNK_PER_32;
-                let col = std::ptr::from_ref(weights.get_unchecked(i_col)).cast::<VecI8>();
+                // compute the index into the weights matrix.
+                let w_offset = i as usize * L2_SIZE * L1_CHUNK_PER_32;
+                // for each SIMD-block in the row, compute the product
+                // of the non-zero activation with the corresponding
+                // weight, and add it to the accumulator.
                 for k in 0..L2_SIZE / F32_CHUNK_SIZE {
-                    *sums.get_unchecked_mut(k) = simd::mul_add_u8_to_i32(
-                        *sums.get_unchecked(k),
-                        simd::reinterpret_i32s_as_i8s(input),
-                        *col.add(k),
+                    simd::store_i32(
+                        sums.get_unchecked_mut(k * F32_CHUNK_SIZE),
+                        simd::mul_add_u8_to_i32(
+                            simd::load_i32(sums.get_unchecked(k * F32_CHUNK_SIZE)),
+                            simd::reinterpret_i32s_as_i8s(input),
+                            simd::load_i8(weights.get_unchecked(w_offset + k * U8_CHUNK_SIZE)),
+                        ),
                     );
                 }
             }
@@ -278,9 +296,13 @@
             let sum_mul = simd::splat_f32(L1_MUL);
             for i in 0..L2_SIZE / F32_CHUNK_SIZE {
                 // Convert into floats, and activate L1
-                let bias_vec = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
-                let sum_ps = simd::mul_add_f32(simd::i32_to_f32(*sums.get_unchecked(i)), sum_mul, bias_vec);
-                let clipped = simd::min_f32(simd::max_f32(sum_ps, zero), one);
+                let bias = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
+                let sum = simd::mul_add_f32(
+                    simd::i32_to_f32(simd::load_i32(sums.get_unchecked(i * F32_CHUNK_SIZE))),
+                    sum_mul,
+                    bias,
+                );
+                let clipped = simd::min_f32(simd::max_f32(sum, zero), one);
                 let squared = simd::mul_f32(clipped, clipped);
                 simd::store_f32(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
             }
@@ -296,25 +318,36 @@
     ) {
         // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
         unsafe {
-            let mut sum_vecs = [simd::zero_f32(); L3_SIZE / F32_CHUNK_SIZE];
+            let mut sums = [0.0; L3_SIZE];
             for i in 0..L3_SIZE / F32_CHUNK_SIZE {
-                *sum_vecs.get_unchecked_mut(i) = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
+                simd::store_f32(
+                    sums.get_unchecked_mut(i * F32_CHUNK_SIZE),
+                    simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE)),
+                );
             }
 
             for i in 0..L2_SIZE {
                 let input_vec = simd::splat_f32(*inputs.get_unchecked(i));
-                let weight = std::ptr::from_ref(weights.get_unchecked(i * L3_SIZE)).cast::<VecF32>();
                 for j in 0..L3_SIZE / F32_CHUNK_SIZE {
-                    *sum_vecs.get_unchecked_mut(j) =
-                        simd::mul_add_f32(input_vec, *weight.add(j), *sum_vecs.get_unchecked(j));
+                    simd::store_f32(
+                        sums.get_unchecked_mut(j * F32_CHUNK_SIZE),
+                        simd::mul_add_f32(
+                            input_vec,
+                            simd::load_f32(weights.get_unchecked(i * L3_SIZE + j * F32_CHUNK_SIZE)),
+                            simd::load_f32(sums.get_unchecked(j * F32_CHUNK_SIZE)),
+                        ),
+                    );
                }
            }

            // Activate L2
            let one = simd::splat_f32(1.0);
            for i in 0..L3_SIZE / F32_CHUNK_SIZE {
-                let clipped = simd::min_f32(simd::max_f32(*sum_vecs.get_unchecked(i), simd::zero_f32()), one);
+                let clipped = simd::min_f32(
+                    simd::max_f32(simd::load_f32(sums.get_unchecked(i * F32_CHUNK_SIZE)), simd::zero_f32()),
+                    one,
+                );
                 let squared = simd::mul_f32(clipped, clipped);
                 simd::store_f32(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
             }
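
A scalar sketch of the sparse L1 propagation exercised by the layers.rs hunks above: the clipped feature-transformer outputs are u8 activations that are mostly zero, find_nnz records which 4-wide chunks contain any non-zero value, and only the weight columns belonging to those chunks are accumulated. The sizes, names, and row-major weight layout below are illustrative assumptions for the sketch, not the engine's real constants, permuted weight layout, or SIMD intrinsics.

// Scalar reference for the sparse L1 pass: group u8 activations into
// 4-wide chunks, keep only chunks with at least one non-zero value, and
// accumulate weights for those chunks alone. Sizes and layout are made up.
const L1: usize = 16; // hypothetical; the real L1_SIZE is much larger
const L2: usize = 4; // hypothetical
const CHUNK: usize = 4; // four u8 activations viewed as one i32

fn find_nnz_scalar(inputs: &[u8; L1]) -> Vec<usize> {
    // indices of chunks that contain at least one non-zero activation
    (0..L1 / CHUNK)
        .filter(|&c| inputs[c * CHUNK..(c + 1) * CHUNK].iter().any(|&v| v != 0))
        .collect()
}

fn propagate_l1_scalar(inputs: &[u8; L1], weights: &[i8; L1 * L2]) -> [i32; L2] {
    let mut sums = [0i32; L2];
    for c in find_nnz_scalar(inputs) {
        for out in 0..L2 {
            for k in 0..CHUNK {
                let i = c * CHUNK + k;
                // plain row-major layout for clarity; the SIMD code above uses a
                // permuted layout so a whole chunk feeds one u8 dot-product.
                sums[out] += i32::from(inputs[i]) * i32::from(weights[i * L2 + out]);
            }
        }
    }
    sums
}

The win over a dense matrix-vector product comes from skipping whole all-zero chunks; the SIMD version in the diff does the same walk but splats each non-zero chunk into a register and folds an entire block of weights per iteration via simd::mul_add_u8_to_i32.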