Skip to content

Commit

Permalink
Use packed_simd::shuffle instead of vqtbx1q_u8
Browse files Browse the repository at this point in the history
  • Loading branch information
rubdos committed Dec 8, 2022
1 parent 3ab954c commit 26c494f
Showing 1 changed file with 35 additions and 40 deletions.
75 changes: 35 additions & 40 deletions src/backend/vector/neon/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
// - Henry de Valence <[email protected]>
// - Robrecht Blancquaert <[email protected]>

//! More details on the algorithms can be found in the `avx2`
//! module. Here comments are mostly added only when needed
//! to explain differenes between the 'base' avx2 version and
//! More details on the algorithms can be found in the `avx2`
//! module. Here comments are mostly added only when needed
//! to explain differenes between the 'base' avx2 version and
//! this re-implementation for arm neon.
//! The most major difference is the split of one vector of 8
Expand Down Expand Up @@ -61,10 +61,10 @@ fn repack_pair(x: (u32x4, u32x4), y: (u32x4, u32x4)) -> (u32x4, u32x4) {
use core::arch::aarch64::vgetq_lane_u32;

(vcombine_u32(
vset_lane_u32(vgetq_lane_u32(x.0.into_bits(), 2) , vget_low_u32(x.0.into_bits()), 1),
vset_lane_u32(vgetq_lane_u32(y.0.into_bits(), 2) , vget_low_u32(y.0.into_bits()), 1)).into_bits(),
vset_lane_u32(vgetq_lane_u32(x.0.into_bits(), 2) , vget_low_u32(x.0.into_bits()), 1),
vset_lane_u32(vgetq_lane_u32(y.0.into_bits(), 2) , vget_low_u32(y.0.into_bits()), 1)).into_bits(),
vcombine_u32(
vset_lane_u32(vgetq_lane_u32(x.1.into_bits(), 2) , vget_low_u32(x.1.into_bits()), 1),
vset_lane_u32(vgetq_lane_u32(x.1.into_bits(), 2) , vget_low_u32(x.1.into_bits()), 1),
vset_lane_u32(vgetq_lane_u32(y.1.into_bits(), 2) , vget_low_u32(y.1.into_bits()), 1)).into_bits())
}
}
Expand Down Expand Up @@ -100,16 +100,16 @@ macro_rules! lane_shuffle {
unsafe {
use core::arch::aarch64::vgetq_lane_u32;
const c: [i32; 8] = [$l0, $l1, $l2, $l3, $l4, $l5, $l6, $l7];
(u32x4::new(if c[0] < 4 { vgetq_lane_u32($x.0.into_bits(), c[0]) } else { vgetq_lane_u32($x.1.into_bits(), c[0] - 4) },
if c[1] < 4 { vgetq_lane_u32($x.0.into_bits(), c[1]) } else { vgetq_lane_u32($x.1.into_bits(), c[1] - 4) },
if c[2] < 4 { vgetq_lane_u32($x.0.into_bits(), c[2]) } else { vgetq_lane_u32($x.1.into_bits(), c[2] - 4) },
(u32x4::new(if c[0] < 4 { vgetq_lane_u32($x.0.into_bits(), c[0]) } else { vgetq_lane_u32($x.1.into_bits(), c[0] - 4) },
if c[1] < 4 { vgetq_lane_u32($x.0.into_bits(), c[1]) } else { vgetq_lane_u32($x.1.into_bits(), c[1] - 4) },
if c[2] < 4 { vgetq_lane_u32($x.0.into_bits(), c[2]) } else { vgetq_lane_u32($x.1.into_bits(), c[2] - 4) },
if c[3] < 4 { vgetq_lane_u32($x.0.into_bits(), c[3]) } else { vgetq_lane_u32($x.1.into_bits(), c[3] - 4) }),
u32x4::new(if c[4] < 4 { vgetq_lane_u32($x.0.into_bits(), c[4]) } else { vgetq_lane_u32($x.1.into_bits(), c[4] - 4) },
if c[5] < 4 { vgetq_lane_u32($x.0.into_bits(), c[5]) } else { vgetq_lane_u32($x.1.into_bits(), c[5] - 4) },
if c[6] < 4 { vgetq_lane_u32($x.0.into_bits(), c[6]) } else { vgetq_lane_u32($x.1.into_bits(), c[6] - 4) },
u32x4::new(if c[4] < 4 { vgetq_lane_u32($x.0.into_bits(), c[4]) } else { vgetq_lane_u32($x.1.into_bits(), c[4] - 4) },
if c[5] < 4 { vgetq_lane_u32($x.0.into_bits(), c[5]) } else { vgetq_lane_u32($x.1.into_bits(), c[5] - 4) },
if c[6] < 4 { vgetq_lane_u32($x.0.into_bits(), c[6]) } else { vgetq_lane_u32($x.1.into_bits(), c[6] - 4) },
if c[7] < 4 { vgetq_lane_u32($x.0.into_bits(), c[7]) } else { vgetq_lane_u32($x.1.into_bits(), c[7] - 4) }))
}

}
}

Expand Down Expand Up @@ -161,14 +161,14 @@ impl FieldElement2625x4 {
pub fn split(&self) -> [FieldElement51; 4] {
let mut out = [FieldElement51::zero(); 4];
for i in 0..5 {
let a_2i = self.0[i].0.extract(0) as u64;
let b_2i = self.0[i].0.extract(1) as u64;
let a_2i_1 = self.0[i].0.extract(2) as u64;
let a_2i = self.0[i].0.extract(0) as u64;
let b_2i = self.0[i].0.extract(1) as u64;
let a_2i_1 = self.0[i].0.extract(2) as u64;
let b_2i_1 = self.0[i].0.extract(3) as u64;
let c_2i = self.0[i].1.extract(0) as u64;
let d_2i = self.0[i].1.extract(1) as u64;
let c_2i_1 = self.0[i].1.extract(2) as u64;
let d_2i_1 = self.0[i].1.extract(3) as u64;
let d_2i = self.0[i].1.extract(1) as u64;
let c_2i_1 = self.0[i].1.extract(2) as u64;
let d_2i_1 = self.0[i].1.extract(3) as u64;

out[0].0[i] = a_2i + (a_2i_1 << 26);
out[1].0[i] = b_2i + (b_2i_1 << 26);
Expand Down Expand Up @@ -212,33 +212,28 @@ impl FieldElement2625x4 {
#[inline(always)]
fn blend_lanes(x: (u32x4, u32x4), y: (u32x4, u32x4), control: Lanes) -> (u32x4, u32x4) {
unsafe {
use core::arch::aarch64::vqtbx1q_u8;
use packed_simd::shuffle;
match control {
Lanes::C => {
(x.0,
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits())
(x.0, shuffle!(y.1, x.1, [0, 5, 2, 7]))
}
Lanes::D => {
(x.0,
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits()).into_bits())
(x.0, shuffle!(y.1, x.1, [4, 1, 6, 3]))
}
Lanes::AD => {
(vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits() ).into_bits(),
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits() ).into_bits())
(shuffle!(y.0, x.0, [0, 5, 2, 7]), shuffle!(y.1, x.1, [4, 1, 6, 3]))
}
Lanes::AB => {
(y.0, x.1)
}
Lanes::AC => {
(vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits(),
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits())
(shuffle!(y.0, x.0, [0, 5, 2, 7]), shuffle!(y.1, x.1, [0, 5, 2, 7]))
}
Lanes::CD => {
(x.0, y.1)
(x.0, y.1)
}
Lanes::BC => {
(vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits() ).into_bits(),
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits() ).into_bits())
(shuffle!(y.0, x.0, [4, 1, 6, 3]), shuffle!(y.1, x.1, [0, 5, 2, 7]))
}
Lanes::ABCD => {
y
Expand Down Expand Up @@ -333,7 +328,7 @@ impl FieldElement2625x4 {
use core::arch::aarch64::vget_high_u32;
use core::arch::aarch64::vcombine_u32;

let c: (u32x4, u32x4) = (vqshlq_u32(v.0.into_bits(), shifts.0.into_bits()).into_bits(),
let c: (u32x4, u32x4) = (vqshlq_u32(v.0.into_bits(), shifts.0.into_bits()).into_bits(),
vqshlq_u32(v.1.into_bits(), shifts.1.into_bits()).into_bits());
(vcombine_u32(vget_high_u32(c.0.into_bits()), vget_low_u32(c.0.into_bits())).into_bits(),
vcombine_u32(vget_high_u32(c.1.into_bits()), vget_low_u32(c.1.into_bits())).into_bits())
Expand Down Expand Up @@ -377,7 +372,7 @@ impl FieldElement2625x4 {
use core::arch::aarch64::vmulq_n_u32;
use core::arch::aarch64::vget_low_u32;
use core::arch::aarch64::vcombine_u32;

let c9_19_spread: (u32x4, u32x4) = (vmulq_n_u32(c98.0.into_bits(), 19).into_bits(), vmulq_n_u32(c98.1.into_bits(), 19).into_bits());

(vcombine_u32(vget_low_u32(c9_19_spread.0.into_bits()), u32x2::splat(0).into_bits()).into_bits(),
Expand Down Expand Up @@ -423,9 +418,9 @@ impl FieldElement2625x4 {
unsafe {
use core::arch::aarch64::vmulq_n_u32;

c0 = (vmulq_n_u32(c0.0.into_bits(), 19).into_bits(),
c0 = (vmulq_n_u32(c0.0.into_bits(), 19).into_bits(),
vmulq_n_u32(c0.1.into_bits(), 19).into_bits());
c1 = (vmulq_n_u32(c1.0.into_bits(), 19).into_bits(),
c1 = (vmulq_n_u32(c1.0.into_bits(), 19).into_bits(),
vmulq_n_u32(c1.1.into_bits(), 19).into_bits());
}

Expand Down Expand Up @@ -457,8 +452,8 @@ impl FieldElement2625x4 {
#[inline(always)]
fn m_lo(x: (u32x2, u32x2), y: (u32x2, u32x2)) -> (u32x2, u32x2) {
use core::arch::aarch64::vmull_u32;
unsafe {
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
unsafe {
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
vmull_u32(x.1.into_bits(), y.1.into_bits()).into_bits());
(u32x2::new(x.0.extract(0), x.0.extract(2)), u32x2::new(x.1.extract(0), x.1.extract(2)))
}
Expand Down Expand Up @@ -497,7 +492,7 @@ impl FieldElement2625x4 {
let mut z7 = m(x0_2,x7) + m(x1_2,x6) + m(x2_2,x5) + m(x3_2,x4) + ((m(x8,x9_19)) << 1);
let mut z8 = m(x0_2,x8) + m(x1_2,x7_2) + m(x2_2,x6) + m(x3_2,x5_2) + m(x4,x4) + ((m(x9,x9_19)) << 1);
let mut z9 = m(x0_2,x9) + m(x1_2,x8) + m(x2_2,x7) + m(x3_2,x6) + m(x4_2,x5);


let low__p37 = u64x4::splat(0x3ffffed << 37);
let even_p37 = u64x4::splat(0x3ffffff << 37);
Expand Down Expand Up @@ -609,8 +604,8 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 {
#[inline(always)]
fn m_lo(x: (u32x2, u32x2), y: (u32x2, u32x2)) -> (u32x2, u32x2) {
use core::arch::aarch64::vmull_u32;
unsafe {
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
unsafe {
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
vmull_u32(x.1.into_bits(), y.1.into_bits()).into_bits());
(u32x2::new(x.0.extract(0), x.0.extract(2)), u32x2::new(x.1.extract(0), x.1.extract(2)))
}
Expand Down

0 comments on commit 26c494f

Please sign in to comment.