Skip to content

Commit

Permalink
Significantly improve SIMD code quality (#185)
Browse files Browse the repository at this point in the history
Bench: 13292315
  • Loading branch information
cosmobobak authored Aug 15, 2024
1 parent e066da5 commit 92c0379
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "viridithas"
version = "14.0.1"
version = "15.0.0"
edition = "2021"
description = "A superhuman chess engine."
license = "MIT"
Expand Down
85 changes: 59 additions & 26 deletions src/nnue/network/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ mod x86simd {
use super::super::{Align64, L1_SIZE, L2_SIZE, L3_SIZE, QA, QB};
use crate::nnue::{
network::L1_CHUNK_PER_32,
simd::{self, VecF32, VecI32, VecI8, F32_CHUNK_SIZE, I16_CHUNK_SIZE, I32_CHUNK_SIZE, S},
simd::{self, VecI32, F32_CHUNK_SIZE, I16_CHUNK_SIZE, I32_CHUNK_SIZE, S, U8_CHUNK_SIZE},
};
use std::mem::MaybeUninit;

Expand Down Expand Up @@ -152,10 +152,20 @@ mod x86simd {
NNZTable { table }
};

unsafe fn find_nnz(
// used in only one place, separate function for clarity.
unsafe fn reinterpret_as_i32s(ptr: &Align64<[MaybeUninit<u8>; L1_SIZE]>) -> &Align64<[i32; L1_SIZE / 4]> {
let ptr = std::ptr::from_ref(ptr);
// check that the reference is aligned to the register alignment
debug_assert!((ptr as usize) % std::mem::align_of::<i32>() == 0);
debug_assert!((ptr as usize) % std::mem::align_of::<Align64<[i32; L1_SIZE / 4]>>() == 0);
// cast:
&*ptr.cast::<Align64<[i32; L1_SIZE / 4]>>()
}

unsafe fn find_nnz<'a>(
input: &Align64<[i32; L1_SIZE / L1_CHUNK_PER_32]>,
out: &mut Align64<[MaybeUninit<u16>; L1_SIZE / L1_CHUNK_PER_32]>,
) -> usize {
out: &'a mut Align64<[MaybeUninit<u16>; L1_SIZE / L1_CHUNK_PER_32]>,
) -> &'a [u16] {
use std::arch::x86_64::_mm_add_epi16 as vec128_add;
use std::arch::x86_64::_mm_load_si128 as vec128_load;
use std::arch::x86_64::_mm_set1_epi16 as vec128_set_16;
Expand Down Expand Up @@ -187,7 +197,8 @@ mod x86simd {
}
}

count
// SAFETY: we have initialised this region of the array.
std::slice::from_raw_parts(out.get_unchecked(0).as_ptr().cast(), count)
}

#[allow(
Expand Down Expand Up @@ -251,24 +262,31 @@ mod x86simd {
}
}

let input32 = &*std::ptr::from_ref(&ft_outputs.0).cast::<Align64<[i32; L1_SIZE / L1_CHUNK_PER_32]>>();
// &Align64<[MaybeUninit<u8>; L1_SIZE]>) -> &Align64<[i32; L1_SIZE / 4]>
let input32 = reinterpret_as_i32s(&ft_outputs);

// Compute the non-zero indices.
let mut nnz: Align64<[MaybeUninit<u16>; L1_SIZE / L1_CHUNK_PER_32]> = MaybeUninit::uninit().assume_init();
let nnz_slice = find_nnz(input32, &mut nnz);

let nnz_count = find_nnz(input32, &mut nnz);
let mut sums = [0; L2_SIZE];

let mut sums = [simd::zero_i32(); L2_SIZE / F32_CHUNK_SIZE];

for &i in nnz.get_unchecked(..nnz_count) {
let i = i.assume_init();
for &i in nnz_slice {
// load the non-zero activation, and splat it into a SIMD register.
let input = simd::splat_i32(*input32.get_unchecked(i as usize));
let i_col = i as usize * L2_SIZE * L1_CHUNK_PER_32;
let col = std::ptr::from_ref(weights.get_unchecked(i_col)).cast::<VecI8>();
// compute the index into the weights matrix.
let w_offset = i as usize * L2_SIZE * L1_CHUNK_PER_32;
// for each SIMD-block in the row, compute the product
// of the non-zero activation with the corresponding
// weight, and add it to the accumulator.
for k in 0..L2_SIZE / F32_CHUNK_SIZE {
*sums.get_unchecked_mut(k) = simd::mul_add_u8_to_i32(
*sums.get_unchecked(k),
simd::reinterpret_i32s_as_i8s(input),
*col.add(k),
simd::store_i32(
sums.get_unchecked_mut(k * F32_CHUNK_SIZE),
simd::mul_add_u8_to_i32(
simd::load_i32(sums.get_unchecked(k * F32_CHUNK_SIZE)),
simd::reinterpret_i32s_as_i8s(input),
simd::load_i8(weights.get_unchecked(w_offset + k * U8_CHUNK_SIZE)),
),
);
}
}
Expand All @@ -278,9 +296,13 @@ mod x86simd {
let sum_mul = simd::splat_f32(L1_MUL);
for i in 0..L2_SIZE / F32_CHUNK_SIZE {
// Convert into floats, and activate L1
let bias_vec = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
let sum_ps = simd::mul_add_f32(simd::i32_to_f32(*sums.get_unchecked(i)), sum_mul, bias_vec);
let clipped = simd::min_f32(simd::max_f32(sum_ps, zero), one);
let bias = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
let sum = simd::mul_add_f32(
simd::i32_to_f32(simd::load_i32(sums.get_unchecked(i * F32_CHUNK_SIZE))),
sum_mul,
bias,
);
let clipped = simd::min_f32(simd::max_f32(sum, zero), one);
let squared = simd::mul_f32(clipped, clipped);
simd::store_f32(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
}
Expand All @@ -296,25 +318,36 @@ mod x86simd {
) {
// SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
unsafe {
let mut sum_vecs = [simd::zero_f32(); L3_SIZE / F32_CHUNK_SIZE];
let mut sums = [0.0; L3_SIZE];

for i in 0..L3_SIZE / F32_CHUNK_SIZE {
*sum_vecs.get_unchecked_mut(i) = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
simd::store_f32(
sums.get_unchecked_mut(i * F32_CHUNK_SIZE),
simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE)),
);
}

for i in 0..L2_SIZE {
let input_vec = simd::splat_f32(*inputs.get_unchecked(i));
let weight = std::ptr::from_ref(weights.get_unchecked(i * L3_SIZE)).cast::<VecF32>();
for j in 0..L3_SIZE / F32_CHUNK_SIZE {
*sum_vecs.get_unchecked_mut(j) =
simd::mul_add_f32(input_vec, *weight.add(j), *sum_vecs.get_unchecked(j));
simd::store_f32(
sums.get_unchecked_mut(j * F32_CHUNK_SIZE),
simd::mul_add_f32(
input_vec,
simd::load_f32(weights.get_unchecked(i * L3_SIZE + j * F32_CHUNK_SIZE)),
simd::load_f32(sums.get_unchecked(j * F32_CHUNK_SIZE)),
),
);
}
}

// Activate L2
let one = simd::splat_f32(1.0);
for i in 0..L3_SIZE / F32_CHUNK_SIZE {
let clipped = simd::min_f32(simd::max_f32(*sum_vecs.get_unchecked(i), simd::zero_f32()), one);
let clipped = simd::min_f32(
simd::max_f32(simd::load_f32(sums.get_unchecked(i * F32_CHUNK_SIZE)), simd::zero_f32()),
one,
);
let squared = simd::mul_f32(clipped, clipped);
simd::store_f32(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
}
Expand Down

0 comments on commit 92c0379

Please sign in to comment.