diff --git a/src/backend/vector/neon/field.rs b/src/backend/vector/neon/field.rs index 8c93afef0..f87ede8bf 100644 --- a/src/backend/vector/neon/field.rs +++ b/src/backend/vector/neon/field.rs @@ -10,9 +10,9 @@ // - Henry de Valence // - Robrecht Blancquaert -//! More details on the algorithms can be found in the `avx2` -//! module. Here comments are mostly added only when needed -//! to explain differenes between the 'base' avx2 version and +//! More details on the algorithms can be found in the `avx2` +//! module. Here comments are mostly added only when needed +//! to explain differenes between the 'base' avx2 version and //! this re-implementation for arm neon. //! The most major difference is the split of one vector of 8 @@ -61,10 +61,10 @@ fn repack_pair(x: (u32x4, u32x4), y: (u32x4, u32x4)) -> (u32x4, u32x4) { use core::arch::aarch64::vgetq_lane_u32; (vcombine_u32( - vset_lane_u32(vgetq_lane_u32(x.0.into_bits(), 2) , vget_low_u32(x.0.into_bits()), 1), - vset_lane_u32(vgetq_lane_u32(y.0.into_bits(), 2) , vget_low_u32(y.0.into_bits()), 1)).into_bits(), + vset_lane_u32(vgetq_lane_u32(x.0.into_bits(), 2) , vget_low_u32(x.0.into_bits()), 1), + vset_lane_u32(vgetq_lane_u32(y.0.into_bits(), 2) , vget_low_u32(y.0.into_bits()), 1)).into_bits(), vcombine_u32( - vset_lane_u32(vgetq_lane_u32(x.1.into_bits(), 2) , vget_low_u32(x.1.into_bits()), 1), + vset_lane_u32(vgetq_lane_u32(x.1.into_bits(), 2) , vget_low_u32(x.1.into_bits()), 1), vset_lane_u32(vgetq_lane_u32(y.1.into_bits(), 2) , vget_low_u32(y.1.into_bits()), 1)).into_bits()) } } @@ -100,16 +100,16 @@ macro_rules! lane_shuffle { unsafe { use core::arch::aarch64::vgetq_lane_u32; const c: [i32; 8] = [$l0, $l1, $l2, $l3, $l4, $l5, $l6, $l7]; - (u32x4::new(if c[0] < 4 { vgetq_lane_u32($x.0.into_bits(), c[0]) } else { vgetq_lane_u32($x.1.into_bits(), c[0] - 4) }, - if c[1] < 4 { vgetq_lane_u32($x.0.into_bits(), c[1]) } else { vgetq_lane_u32($x.1.into_bits(), c[1] - 4) }, - if c[2] < 4 { vgetq_lane_u32($x.0.into_bits(), c[2]) } else { vgetq_lane_u32($x.1.into_bits(), c[2] - 4) }, + (u32x4::new(if c[0] < 4 { vgetq_lane_u32($x.0.into_bits(), c[0]) } else { vgetq_lane_u32($x.1.into_bits(), c[0] - 4) }, + if c[1] < 4 { vgetq_lane_u32($x.0.into_bits(), c[1]) } else { vgetq_lane_u32($x.1.into_bits(), c[1] - 4) }, + if c[2] < 4 { vgetq_lane_u32($x.0.into_bits(), c[2]) } else { vgetq_lane_u32($x.1.into_bits(), c[2] - 4) }, if c[3] < 4 { vgetq_lane_u32($x.0.into_bits(), c[3]) } else { vgetq_lane_u32($x.1.into_bits(), c[3] - 4) }), - u32x4::new(if c[4] < 4 { vgetq_lane_u32($x.0.into_bits(), c[4]) } else { vgetq_lane_u32($x.1.into_bits(), c[4] - 4) }, - if c[5] < 4 { vgetq_lane_u32($x.0.into_bits(), c[5]) } else { vgetq_lane_u32($x.1.into_bits(), c[5] - 4) }, - if c[6] < 4 { vgetq_lane_u32($x.0.into_bits(), c[6]) } else { vgetq_lane_u32($x.1.into_bits(), c[6] - 4) }, + u32x4::new(if c[4] < 4 { vgetq_lane_u32($x.0.into_bits(), c[4]) } else { vgetq_lane_u32($x.1.into_bits(), c[4] - 4) }, + if c[5] < 4 { vgetq_lane_u32($x.0.into_bits(), c[5]) } else { vgetq_lane_u32($x.1.into_bits(), c[5] - 4) }, + if c[6] < 4 { vgetq_lane_u32($x.0.into_bits(), c[6]) } else { vgetq_lane_u32($x.1.into_bits(), c[6] - 4) }, if c[7] < 4 { vgetq_lane_u32($x.0.into_bits(), c[7]) } else { vgetq_lane_u32($x.1.into_bits(), c[7] - 4) })) } - + } } @@ -161,14 +161,14 @@ impl FieldElement2625x4 { pub fn split(&self) -> [FieldElement51; 4] { let mut out = [FieldElement51::zero(); 4]; for i in 0..5 { - let a_2i = self.0[i].0.extract(0) as u64; - let b_2i = self.0[i].0.extract(1) as u64; - let a_2i_1 = self.0[i].0.extract(2) as u64; + let a_2i = self.0[i].0.extract(0) as u64; + let b_2i = self.0[i].0.extract(1) as u64; + let a_2i_1 = self.0[i].0.extract(2) as u64; let b_2i_1 = self.0[i].0.extract(3) as u64; let c_2i = self.0[i].1.extract(0) as u64; - let d_2i = self.0[i].1.extract(1) as u64; - let c_2i_1 = self.0[i].1.extract(2) as u64; - let d_2i_1 = self.0[i].1.extract(3) as u64; + let d_2i = self.0[i].1.extract(1) as u64; + let c_2i_1 = self.0[i].1.extract(2) as u64; + let d_2i_1 = self.0[i].1.extract(3) as u64; out[0].0[i] = a_2i + (a_2i_1 << 26); out[1].0[i] = b_2i + (b_2i_1 << 26); @@ -211,40 +211,33 @@ impl FieldElement2625x4 { pub fn blend(&self, other: FieldElement2625x4, control: Lanes) -> FieldElement2625x4 { #[inline(always)] fn blend_lanes(x: (u32x4, u32x4), y: (u32x4, u32x4), control: Lanes) -> (u32x4, u32x4) { - unsafe { - use core::arch::aarch64::vqtbx1q_u8; - match control { - Lanes::C => { - (x.0, - vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits()) - } - Lanes::D => { - (x.0, - vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits()).into_bits()) - } - Lanes::AD => { - (vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits() ).into_bits(), - vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits() ).into_bits()) - } - Lanes::AB => { - (y.0, x.1) - } - Lanes::AC => { - (vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits(), - vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits()) - } - Lanes::CD => { - (x.0, y.1) - } - Lanes::BC => { - (vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits() ).into_bits(), - vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits() ).into_bits()) - } - Lanes::ABCD => { - y - } - + use packed_simd::shuffle; + match control { + Lanes::C => { + (x.0, shuffle!(y.1, x.1, [0, 5, 2, 7])) + } + Lanes::D => { + (x.0, shuffle!(y.1, x.1, [4, 1, 6, 3])) + } + Lanes::AD => { + (shuffle!(y.0, x.0, [0, 5, 2, 7]), shuffle!(y.1, x.1, [4, 1, 6, 3])) + } + Lanes::AB => { + (y.0, x.1) + } + Lanes::AC => { + (shuffle!(y.0, x.0, [0, 5, 2, 7]), shuffle!(y.1, x.1, [0, 5, 2, 7])) } + Lanes::CD => { + (x.0, y.1) + } + Lanes::BC => { + (shuffle!(y.0, x.0, [4, 1, 6, 3]), shuffle!(y.1, x.1, [0, 5, 2, 7])) + } + Lanes::ABCD => { + y + } + } } @@ -333,7 +326,7 @@ impl FieldElement2625x4 { use core::arch::aarch64::vget_high_u32; use core::arch::aarch64::vcombine_u32; - let c: (u32x4, u32x4) = (vqshlq_u32(v.0.into_bits(), shifts.0.into_bits()).into_bits(), + let c: (u32x4, u32x4) = (vqshlq_u32(v.0.into_bits(), shifts.0.into_bits()).into_bits(), vqshlq_u32(v.1.into_bits(), shifts.1.into_bits()).into_bits()); (vcombine_u32(vget_high_u32(c.0.into_bits()), vget_low_u32(c.0.into_bits())).into_bits(), vcombine_u32(vget_high_u32(c.1.into_bits()), vget_low_u32(c.1.into_bits())).into_bits()) @@ -377,7 +370,7 @@ impl FieldElement2625x4 { use core::arch::aarch64::vmulq_n_u32; use core::arch::aarch64::vget_low_u32; use core::arch::aarch64::vcombine_u32; - + let c9_19_spread: (u32x4, u32x4) = (vmulq_n_u32(c98.0.into_bits(), 19).into_bits(), vmulq_n_u32(c98.1.into_bits(), 19).into_bits()); (vcombine_u32(vget_low_u32(c9_19_spread.0.into_bits()), u32x2::splat(0).into_bits()).into_bits(), @@ -423,9 +416,9 @@ impl FieldElement2625x4 { unsafe { use core::arch::aarch64::vmulq_n_u32; - c0 = (vmulq_n_u32(c0.0.into_bits(), 19).into_bits(), + c0 = (vmulq_n_u32(c0.0.into_bits(), 19).into_bits(), vmulq_n_u32(c0.1.into_bits(), 19).into_bits()); - c1 = (vmulq_n_u32(c1.0.into_bits(), 19).into_bits(), + c1 = (vmulq_n_u32(c1.0.into_bits(), 19).into_bits(), vmulq_n_u32(c1.1.into_bits(), 19).into_bits()); } @@ -457,8 +450,8 @@ impl FieldElement2625x4 { #[inline(always)] fn m_lo(x: (u32x2, u32x2), y: (u32x2, u32x2)) -> (u32x2, u32x2) { use core::arch::aarch64::vmull_u32; - unsafe { - let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(), + unsafe { + let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(), vmull_u32(x.1.into_bits(), y.1.into_bits()).into_bits()); (u32x2::new(x.0.extract(0), x.0.extract(2)), u32x2::new(x.1.extract(0), x.1.extract(2))) } @@ -497,7 +490,7 @@ impl FieldElement2625x4 { let mut z7 = m(x0_2,x7) + m(x1_2,x6) + m(x2_2,x5) + m(x3_2,x4) + ((m(x8,x9_19)) << 1); let mut z8 = m(x0_2,x8) + m(x1_2,x7_2) + m(x2_2,x6) + m(x3_2,x5_2) + m(x4,x4) + ((m(x9,x9_19)) << 1); let mut z9 = m(x0_2,x9) + m(x1_2,x8) + m(x2_2,x7) + m(x3_2,x6) + m(x4_2,x5); - + let low__p37 = u64x4::splat(0x3ffffed << 37); let even_p37 = u64x4::splat(0x3ffffff << 37); @@ -609,8 +602,8 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { #[inline(always)] fn m_lo(x: (u32x2, u32x2), y: (u32x2, u32x2)) -> (u32x2, u32x2) { use core::arch::aarch64::vmull_u32; - unsafe { - let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(), + unsafe { + let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(), vmull_u32(x.1.into_bits(), y.1.into_bits()).into_bits()); (u32x2::new(x.0.extract(0), x.0.extract(2)), u32x2::new(x.1.extract(0), x.1.extract(2))) }