From 6c4af5a9f7160b43927a078e0fe048ab7859b700 Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Fri, 11 Aug 2023 13:50:30 -0700 Subject: [PATCH] Matrix transpose bitorder (#27) * add naive test * Change algorithm to work for LSB0 encoding instead of MSB0 * add test --------- Co-authored-by: th4s Co-authored-by: themighty1 --- matrix-transpose/Cargo.toml | 1 + matrix-transpose/src/lib.rs | 63 ++++++++++++++++++++++++++++++---- matrix-transpose/src/scalar.rs | 5 +-- matrix-transpose/src/simd.rs | 10 +++++- 4 files changed, 69 insertions(+), 10 deletions(-) diff --git a/matrix-transpose/Cargo.toml b/matrix-transpose/Cargo.toml index 75112ae6..c8768f80 100644 --- a/matrix-transpose/Cargo.toml +++ b/matrix-transpose/Cargo.toml @@ -11,6 +11,7 @@ thiserror.workspace = true [dev-dependencies] rand.workspace = true criterion.workspace = true +itybity.workspace = true [features] simd-transpose = [] diff --git a/matrix-transpose/src/lib.rs b/matrix-transpose/src/lib.rs index c08ac618..1479df66 100644 --- a/matrix-transpose/src/lib.rs +++ b/matrix-transpose/src/lib.rs @@ -32,6 +32,7 @@ use thiserror::Error; /// This function transposes a matrix on the bit-level. /// +/// Assumes an LSB0 bit encoding of the matrix. /// This implementation requires that the number of rows is a power of 2 /// and that the number of columns is a multiple of 8 pub fn transpose_bits(matrix: &mut [u8], rows: usize) -> Result<(), TransposeError> { @@ -87,17 +88,65 @@ mod tests { (0..elements).map(|_| rng.gen::()).collect() } + fn transpose_naive(data: &[u8], row_width: usize) -> Vec { + use itybity::*; + + let bits: Vec> = data.chunks(row_width).map(|x| x.to_lsb0_vec()).collect(); + let col_count = bits[0].len(); + let row_count = bits.len(); + + let mut bits_: Vec> = vec![vec![false; row_count]; col_count]; + for j in 0..row_count { + for i in 0..col_count { + bits_[i][j] = bits[j][i]; + } + } + + bits_.into_iter().flat_map(Vec::::from_lsb0).collect() + } + #[test] fn test_transpose_bits() { - let mut rows = 512; + let rows = 512; let columns = 256; let mut matrix: Vec = random_vec::(columns * rows); - let original = matrix.clone(); + let naive = transpose_naive(&matrix, columns); + transpose_bits(&mut matrix, rows).unwrap(); - rows = columns; - transpose_bits(&mut matrix, 8 * rows).unwrap(); - assert_eq!(original, matrix); + + assert_eq!(naive, matrix); + } + + #[test] + fn test_transpose_naive() { + let matrix = [ + // ------- bits in lsb0 + 3u8, // 1 1 0 0 0 0 0 0 + 76u8, // 0 0 1 1 0 0 1 0 + 120u8, // 0 0 0 1 1 1 1 0 + 9u8, // 1 0 0 1 0 0 0 0 + 17u8, // 1 0 0 0 1 0 0 0 + 102u8, // 0 1 1 0 0 1 1 0 + 53u8, // 1 0 1 0 1 1 0 0 + 125u8, // 1 0 1 1 1 1 1 0 + ]; + + let expected = [ + // ------- bits in lsb0 + 217u8, // 1 0 0 1 1 0 1 1 + 33u8, // 1 0 0 0 0 1 0 0 + 226u8, // 0 1 0 0 0 1 1 1 + 142u8, // 0 1 1 1 0 0 0 1 + 212u8, // 0 0 1 0 1 0 1 1 + 228u8, // 0 0 1 0 0 1 1 1 + 166u8, // 0 1 1 0 0 1 0 1 + 0u8, // 0 0 0 0 0 0 0 0 + ]; + + let naive = transpose_naive(&matrix, 1); + + assert_eq!(naive, expected); } #[test] @@ -141,12 +190,12 @@ mod tests { for k in 0..8 { for (l, chunk) in row.chunks(8).enumerate() { let expected: u8 = chunk.iter().enumerate().fold(0, |acc, (m, element)| { - acc + (element >> 7) * 2_u8.pow(7_u32 - m as u32) + acc + (element & 1) * 2_u8.pow(m as u32) }); let actual = matrix[row_index * columns + columns / 8 * k + l]; assert_eq!(expected, actual); } - let shifted_row = row.iter_mut().map(|el| *el << 1).collect::>(); + let shifted_row = row.iter_mut().map(|el| *el >> 1).collect::>(); row.copy_from_slice(&shifted_row); } } diff --git a/matrix-transpose/src/scalar.rs b/matrix-transpose/src/scalar.rs index 854431fa..75d2754a 100644 --- a/matrix-transpose/src/scalar.rs +++ b/matrix-transpose/src/scalar.rs @@ -37,6 +37,7 @@ where /// Single-row bit-mask shift /// +/// Assumes an LSB0 bit encoding of the matrix. /// This function is an implementation of the bit-level transpose in /// https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html /// Caller has to make sure that columns is a multiple of 8 @@ -48,8 +49,8 @@ pub fn bitmask_shift(matrix: &mut [u8], columns: usize) { for bytes in row.chunks_mut(8) { let mut high_bits: u8 = 0b00000000; bytes.iter_mut().enumerate().for_each(|(k, b)| { - high_bits |= (0b10000000 & *b) >> k; - *b <<= 1; + high_bits |= (0b00000001 & *b) << k; + *b >>= 1; }); shifted_row.push(high_bits); } diff --git a/matrix-transpose/src/simd.rs b/matrix-transpose/src/simd.rs index 65c8549e..f6264211 100644 --- a/matrix-transpose/src/simd.rs +++ b/matrix-transpose/src/simd.rs @@ -6,6 +6,7 @@ use std::{ /// SIMD version for bit-level transposition /// +/// Assumes an LSB0 bit encoding of the matrix. /// This SIMD implementation additionally requires that the matrix has at least /// 16 (WASM) or 32 (x86_64) columns and rows #[cfg(any(target_arch = "x86_64", target_arch = "wasm32"))] @@ -73,6 +74,7 @@ where /// Unsafe single-row bit-mask shift /// +/// Assumes an LSB0 bit encoding of the matrix. /// This function is an implementation of the bit-level transpose in /// https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html /// Caller has to make sure that columns is a multiple of 16 or 32 @@ -84,6 +86,7 @@ pub unsafe fn bitmask_shift_unchecked(matrix: &mut [u8], columns: usize) { #[cfg(target_arch = "x86_64")] use std::arch::x86_64::_mm256_movemask_epi8; + matrix.iter_mut().for_each(|b| *b = b.reverse_bits()); let simd_one = Simd::::splat(1); let mut s: Simd; for row in matrix.chunks_mut(columns) { @@ -95,7 +98,12 @@ pub unsafe fn bitmask_shift_unchecked(matrix: &mut [u8], columns: usize) { let high_bits = _mm256_movemask_epi8(s.reverse().into()); #[cfg(target_arch = "wasm32")] let high_bits = u8x16_bitmask(s.reverse().into()); - shifted_row.extend_from_slice(&high_bits.to_be_bytes()); + let high_bits: Vec = high_bits + .to_be_bytes() + .into_iter() + .map(|b| b.reverse_bits()) + .collect(); + shifted_row.extend_from_slice(&high_bits); s.shl_assign(simd_one); *chunk = s.to_array(); }