diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index be98f9751..45aa87b0f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -55,19 +55,19 @@ jobs: - run: cargo build --target thumbv7em-none-eabi --release - run: cargo build --target thumbv7em-none-eabi --release --features serde - build-simd-nightly: - name: Build simd backend (nightly) + test-simd-native: + name: Test simd backend (native) runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@nightly - # Build with AVX2 features, then with AVX512 features - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo build --target x86_64-unknown-linux-gnu - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx512ifma' - run: cargo build --target x86_64-unknown-linux-gnu + # This will: + # 1) build all of the x86_64 SIMD code, + # 2) run all of the SIMD-specific tests that the test runner supports, + # 3) run all of the normal tests using the best available SIMD backend. + RUSTFLAGS: '-C target_cpu=native' + run: cargo test --features simd --target x86_64-unknown-linux-gnu test-simd-avx2: name: Test simd backend (avx2) @@ -76,8 +76,10 @@ jobs: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo test --target x86_64-unknown-linux-gnu + # This will run AVX2-specific tests and run all of the normal tests + # with the AVX2 backend, even if the runner supports AVX512. 
+ RUSTFLAGS: '-C target_feature=+avx2' + run: cargo test --no-default-features --features alloc,precomputed-tables,zeroize,simd_avx2 --target x86_64-unknown-linux-gnu build-docs: name: Build docs @@ -131,12 +133,7 @@ jobs: - uses: dtolnay/rust-toolchain@nightly with: components: clippy - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo clippy --target x86_64-unknown-linux-gnu - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx512ifma' - run: cargo clippy --target x86_64-unknown-linux-gnu + - run: cargo clippy --target x86_64-unknown-linux-gnu rustfmt: name: Check formatting @@ -162,9 +159,7 @@ jobs: - uses: dtolnay/rust-toolchain@1.60.0 - run: cargo build --no-default-features --features serde # Also make sure the AVX2 build works - - env: - RUSTFLAGS: '--cfg curve25519_dalek_backend="simd" -C target_feature=+avx2' - run: cargo build --target x86_64-unknown-linux-gnu + - run: cargo build --target x86_64-unknown-linux-gnu bench: name: Check that benchmarks compile diff --git a/Cargo.toml b/Cargo.toml index a1dafcb24..33bbb29a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ rustdoc-args = [ "--html-in-header", "docs/assets/rustdoc-include-katex-header.html", "--cfg", "docsrs", ] -rustc-args = ["--cfg", "curve25519_dalek_backend=\"simd\""] features = ["serde", "rand_core", "digest", "legacy_compatibility"] [dev-dependencies] @@ -54,15 +53,29 @@ digest = { version = "0.10", default-features = false, optional = true } subtle = { version = "2.3.0", default-features = false } serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] } zeroize = { version = "1", default-features = false, optional = true } +unsafe_target_feature = { version = "= 0.1.1", optional = true } + +[target.'cfg(target_arch = "x86_64")'.dependencies] +cpufeatures = "0.2.6" [target.'cfg(curve25519_dalek_backend = "fiat")'.dependencies] fiat-crypto = "0.1.19" [features] 
-default = ["alloc", "precomputed-tables", "zeroize"] +default = ["alloc", "precomputed-tables", "zeroize", "simd"] alloc = ["zeroize?/alloc"] precomputed-tables = [] legacy_compatibility = [] +# Whether to allow the use of the AVX2 SIMD backend. +simd_avx2 = ["unsafe_target_feature"] + +# Whether to allow the use of the AVX512 SIMD backend. +# (Note: This requires Rust nightly; on Rust stable this feature will be ignored.) +simd_avx512 = ["unsafe_target_feature"] + +# A meta-feature to allow all SIMD backends to be used. +simd = ["simd_avx2", "simd_avx512"] + [profile.dev] opt-level = 2 diff --git a/Makefile b/Makefile index 3b41b1756..bb61cc844 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,5 @@ FEATURES := serde rand_core digest legacy_compatibility -export RUSTFLAGS := --cfg=curve25519_dalek_backend="simd" export RUSTDOCFLAGS := \ --cfg docsrs \ --html-in-header docs/assets/rustdoc-include-katex-header.html diff --git a/README.md b/README.md index 2735dbdf4..12100691d 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,9 @@ curve25519-dalek = "4.0.0-rc.2" | `alloc` | ✓ | Enables Edwards and Ristretto multiscalar multiplication, batch scalar inversion, and batch Ristretto double-and-compress. Also enables `zeroize`. | | `zeroize` | ✓ | Enables [`Zeroize`][zeroize-trait] for all scalar and curve point types. | | `precomputed-tables` | ✓ | Includes precomputed basepoint multiplication tables. This speeds up `EdwardsPoint::mul_base` and `RistrettoPoint::mul_base` by ~4x, at the cost of ~30KB added to the code size. | +| `simd_avx2` | ✓ | Allows the AVX2 SIMD backend to be used, if available. | +| `simd_avx512` | ✓ | Allows the AVX512 SIMD backend to be used, if available. | +| `simd` | ✓ | Allows every SIMD backend to be used, if available. | | `rand_core` | | Enables `Scalar::random` and `RistrettoPoint::random`. This is an optional dependency whose version is not subject to SemVer. See [below](#public-api-semver-exemptions) for more details. 
| | `digest` | | Enables `RistrettoPoint::{from_hash, hash_from_bytes}` and `Scalar::{from_hash, hash_from_bytes}`. This is an optional dependency whose version is not subject to SemVer. See [below](#public-api-semver-exemptions) for more details. | | `serde` | | Enables `serde` serialization/deserialization for all the point and scalar types. | @@ -95,18 +98,17 @@ See tracking issue: [curve25519-dalek/issues/521](https://github.com/dalek-crypt Curve arithmetic is implemented and used by selecting one of the following backends: -| Backend | Implementation | Target backends | -| :--- | :--- | :--- | -| `[default]` | Serial formulas | `u32`
`u64` | -| `simd` | [Parallel][parallel_doc], using Advanced Vector Extensions | `avx2`
`avx512ifma` | -| `fiat` | Formally verified field arithmetic from [fiat-crypto] | `fiat_u32`
`fiat_u64` | +| Backend | Implementation | Target backends | +| :--- | :--- | :--- | +| `[default]` | Automatic runtime backend selection (either serial or SIMD) | `u32`
`u64`
`avx2`
`avx512` | +| `fiat` | Formally verified field arithmetic from [fiat-crypto] | `fiat_u32`
`fiat_u64` | -To choose a backend other than the `[default]` serial backend, set the +To choose a backend other than the `[default]` backend, set the environment variable: ```sh RUSTFLAGS='--cfg curve25519_dalek_backend="BACKEND"' ``` -where `BACKEND` is `simd` or `fiat`. Equivalently, you can write to +where `BACKEND` is `fiat`. Equivalently, you can write to `~/.cargo/config`: ```toml [build] @@ -114,11 +116,8 @@ rustflags = ['--cfg=curve25519_dalek_backend="BACKEND"'] ``` More info [here](https://doc.rust-lang.org/cargo/reference/config.html#buildrustflags). -The `simd` backend requires extra configuration. See [the SIMD -section](#simd-target-backends). - Note for contributors: The target backends are not entirely independent of each -other. The `simd` backend directly depends on parts of the the `u64` backend to +other. The SIMD backend directly depends on parts of the `u64` backend to function. ## Word size for serial backends @@ -137,7 +136,7 @@ RUSTFLAGS='--cfg curve25519_dalek_bits="SIZE"' where `SIZE` is `32` or `64`. As in the above section, this can also be placed in `~/.cargo/config`. -**NOTE:** The `simd` backend CANNOT be used with word size 32. +**NOTE:** Using a word size of 32 will automatically disable SIMD support. ### Cross-compilation @@ -152,18 +151,19 @@ $ cargo build --target i686-unknown-linux-gnu ## SIMD target backends -Target backend selection within `simd` must be done manually by setting the -`RUSTFLAGS` environment variable to one of the below options: +The SIMD target backend selection is done automatically at runtime depending +on the available CPU features, provided the appropriate feature flag is enabled. -| CPU feature | `RUSTFLAGS` | Requires nightly? | -| :--- | :--- | :--- | -| avx2 | `-C target_feature=+avx2` | no | -| avx512ifma | `-C target_feature=+avx512ifma` | yes | +You can also specify an appropriate `-C target_feature` to build a binary +which assumes the required SIMD instructions are always available. 
-Or you can use `-C target_cpu=native` if you don't know what to set. +| Backend | Feature flag | `RUSTFLAGS` | Requires nightly? | +| :--- | :--- | :--- | :--- | +| avx2 | `simd_avx2` | `-C target_feature=+avx2` | no | +| avx512 | `simd_avx512` | `-C target_feature=+avx512ifma,+avx512vl` | yes | -The AVX512 backend requires Rust nightly. If enabled and when compiled on a non-nightly -compiler it will fall back to using the AVX2 backend. +The AVX512 backend requires Rust nightly. When compiled on a non-nightly +compiler it will always be disabled. # Documentation @@ -243,7 +243,8 @@ The implementation is memory-safe, and contains no significant `unsafe` code. The SIMD backend uses `unsafe` internally to call SIMD intrinsics. These are marked `unsafe` only because invoking them on an inappropriate CPU would cause `SIGILL`, but the entire backend is only -compiled with appropriate `target_feature`s, so this cannot occur. +invoked when the appropriate CPU features are detected at runtime, or +when the whole program is compiled with the appropriate `target_feature`s. # Performance @@ -251,8 +252,7 @@ Benchmarks are run using [`criterion.rs`][criterion]: ```sh cargo bench --features "rand_core" -# Uses avx2 or ifma only if compiled for an appropriate target. -export RUSTFLAGS='--cfg curve25519_dalek_backend="simd" -C target_cpu=native' +export RUSTFLAGS='-C target_cpu=native' cargo +nightly bench --features "rand_core" ``` @@ -294,7 +294,7 @@ universe's beauty, but also his deep hatred of the Daleks. Rusty destroys the other Daleks and departs the ship, determined to track down and bring an end to the Dalek race.* -`curve25519-dalek` is authored by Isis Agora Lovecruft and Henry de Valence. +`curve25519-dalek` is authored by Isis Agora Lovecruft and Henry de Valence. 
Portions of this library were originally a port of [Adam Langley's Golang ed25519 library](https://github.com/agl/ed25519), which was in diff --git a/build.rs b/build.rs index 80c0eb1fb..04f4d9ca3 100644 --- a/build.rs +++ b/build.rs @@ -27,6 +27,13 @@ fn main() { { println!("cargo:rustc-cfg=nightly"); } + + let rustc_version = rustc_version::version().expect("failed to detect rustc version"); + if rustc_version.major == 1 && rustc_version.minor <= 64 { + // Old versions of Rust complain when you have an `unsafe fn` and you use `unsafe {}` inside, + // so for those we want to apply the `#[allow(unused_unsafe)]` attribute to get rid of that warning. + println!("cargo:rustc-cfg=allow_unused_unsafe"); + } } // Deterministic cfg(curve25519_dalek_bits) when this is not explicitly set. diff --git a/src/backend/mod.rs b/src/backend/mod.rs index b6cea7ebf..18c8c2251 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -34,7 +34,308 @@ //! The [`vector`] backend is selected by the `simd_backend` cargo //! feature; it uses the [`serial`] backend for non-vectorized operations. 
+use crate::EdwardsPoint; +use crate::Scalar; + pub mod serial; -#[cfg(any(curve25519_dalek_backend = "simd", docsrs))] +#[cfg(all( + target_arch = "x86_64", + any(feature = "simd_avx2", all(feature = "simd_avx512", nightly)), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") +))] pub mod vector; + +#[derive(Copy, Clone)] +enum BackendKind { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + Avx2, + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + Avx512, + Serial, +} + +#[inline] +fn get_selected_backend() -> BackendKind { + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + { + cpufeatures::new!(cpuid_avx512, "avx512ifma", "avx512vl"); + let token_avx512: cpuid_avx512::InitToken = cpuid_avx512::init(); + if token_avx512.get() { + return BackendKind::Avx512; + } + } + + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + { + cpufeatures::new!(cpuid_avx2, "avx2"); + let token_avx2: cpuid_avx2::InitToken = cpuid_avx2::init(); + if token_avx2.get() { + return BackendKind::Avx2; + } + } + + BackendKind::Serial +} + +#[allow(missing_docs)] +#[cfg(feature = "alloc")] +pub fn pippenger_optional_multiscalar_mul(scalars: I, points: J) -> Option +where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator>, +{ + use crate::traits::VartimeMultiscalarMul; + + match get_selected_backend() { + #[cfg(all(target_arch = "x86_64", feature = "simd_avx2", curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx2 => + self::vector::scalar_mul::pippenger::spec_avx2::Pippenger::optional_multiscalar_mul::(scalars, 
points), + #[cfg(all(target_arch = "x86_64", all(feature = "simd_avx512", nightly), curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx512 => + self::vector::scalar_mul::pippenger::spec_avx512ifma_avx512vl::Pippenger::optional_multiscalar_mul::(scalars, points), + BackendKind::Serial => + self::serial::scalar_mul::pippenger::Pippenger::optional_multiscalar_mul::(scalars, points), + } +} + +#[cfg(feature = "alloc")] +pub(crate) enum VartimePrecomputedStraus { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + Avx2(self::vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + Avx512ifma( + self::vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus, + ), + Scalar(self::serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus), +} + +#[cfg(feature = "alloc")] +impl VartimePrecomputedStraus { + pub fn new(static_points: I) -> Self + where + I: IntoIterator, + I::Item: core::borrow::Borrow, + { + use crate::traits::VartimePrecomputedMultiscalarMul; + + match get_selected_backend() { + #[cfg(all(target_arch = "x86_64", feature = "simd_avx2", curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx2 => + VartimePrecomputedStraus::Avx2(self::vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus::new(static_points)), + #[cfg(all(target_arch = "x86_64", all(feature = "simd_avx512", nightly), curve25519_dalek_bits = "64", not(curve25519_dalek_backend = "fiat")))] + BackendKind::Avx512 => + VartimePrecomputedStraus::Avx512ifma(self::vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus::new(static_points)), + BackendKind::Serial => 
+ VartimePrecomputedStraus::Scalar(self::serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(static_points)) + } + } + + pub fn optional_mixed_multiscalar_mul( + &self, + static_scalars: I, + dynamic_scalars: J, + dynamic_points: K, + ) -> Option + where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator, + J::Item: core::borrow::Borrow, + K: IntoIterator>, + { + use crate::traits::VartimePrecomputedMultiscalarMul; + + match self { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + VartimePrecomputedStraus::Avx2(inner) => inner.optional_mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points, + ), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + VartimePrecomputedStraus::Avx512ifma(inner) => inner.optional_mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points, + ), + VartimePrecomputedStraus::Scalar(inner) => inner.optional_mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points, + ), + } + } +} + +#[allow(missing_docs)] +#[cfg(feature = "alloc")] +pub fn straus_multiscalar_mul(scalars: I, points: J) -> EdwardsPoint +where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator, + J::Item: core::borrow::Borrow, +{ + use crate::traits::MultiscalarMul; + + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => { + self::vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::( + scalars, points, + ) + } + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + 
self::vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::< + I, + J, + >(scalars, points) + } + BackendKind::Serial => { + self::serial::scalar_mul::straus::Straus::multiscalar_mul::(scalars, points) + } + } +} + +#[allow(missing_docs)] +#[cfg(feature = "alloc")] +pub fn straus_optional_multiscalar_mul(scalars: I, points: J) -> Option +where + I: IntoIterator, + I::Item: core::borrow::Borrow, + J: IntoIterator>, +{ + use crate::traits::VartimeMultiscalarMul; + + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => { + self::vector::scalar_mul::straus::spec_avx2::Straus::optional_multiscalar_mul::( + scalars, points, + ) + } + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + self::vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::optional_multiscalar_mul::< + I, + J, + >(scalars, points) + } + BackendKind::Serial => { + self::serial::scalar_mul::straus::Straus::optional_multiscalar_mul::( + scalars, points, + ) + } + } +} + +/// Perform constant-time, variable-base scalar multiplication. 
+pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint { + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => self::vector::scalar_mul::variable_base::spec_avx2::mul(point, scalar), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + self::vector::scalar_mul::variable_base::spec_avx512ifma_avx512vl::mul(point, scalar) + } + BackendKind::Serial => self::serial::scalar_mul::variable_base::mul(point, scalar), + } +} + +/// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the Ed25519 basepoint. +#[allow(non_snake_case)] +pub fn vartime_double_base_mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { + match get_selected_backend() { + #[cfg(all( + target_arch = "x86_64", + feature = "simd_avx2", + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx2 => self::vector::scalar_mul::vartime_double_base::spec_avx2::mul(a, A, b), + #[cfg(all( + target_arch = "x86_64", + all(feature = "simd_avx512", nightly), + curve25519_dalek_bits = "64", + not(curve25519_dalek_backend = "fiat") + ))] + BackendKind::Avx512 => { + self::vector::scalar_mul::vartime_double_base::spec_avx512ifma_avx512vl::mul(a, A, b) + } + BackendKind::Serial => self::serial::scalar_mul::vartime_double_base::mul(a, A, b), + } +} diff --git a/src/backend/serial/mod.rs b/src/backend/serial/mod.rs index 933bb88de..13fef5c63 100644 --- a/src/backend/serial/mod.rs +++ b/src/backend/serial/mod.rs @@ -42,8 +42,4 @@ cfg_if! 
{ pub mod curve_models; -#[cfg(not(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") -)))] pub mod scalar_mul; diff --git a/src/backend/serial/u64/constants.rs b/src/backend/serial/u64/constants.rs index 1aaed3109..67d51492d 100644 --- a/src/backend/serial/u64/constants.rs +++ b/src/backend/serial/u64/constants.rs @@ -327,7 +327,7 @@ pub const EIGHT_TORSION_INNER_DOC_HIDDEN: [EdwardsPoint; 8] = [ /// Table containing precomputed multiples of the Ed25519 basepoint \\(B = (x, 4/5)\\). #[cfg(feature = "precomputed-tables")] -pub static ED25519_BASEPOINT_TABLE: &'static EdwardsBasepointTable = +pub static ED25519_BASEPOINT_TABLE: &EdwardsBasepointTable = &ED25519_BASEPOINT_TABLE_INNER_DOC_HIDDEN; /// Inner constant, used to avoid filling the docs with precomputed points. diff --git a/src/backend/vector/avx2/edwards.rs b/src/backend/vector/avx2/edwards.rs index 032265069..7bb58b1ee 100644 --- a/src/backend/vector/avx2/edwards.rs +++ b/src/backend/vector/avx2/edwards.rs @@ -41,8 +41,13 @@ use core::ops::{Add, Neg, Sub}; use subtle::Choice; use subtle::ConditionallySelectable; +use unsafe_target_feature::unsafe_target_feature; + use crate::edwards; -use crate::window::{LookupTable, NafLookupTable5, NafLookupTable8}; +use crate::window::{LookupTable, NafLookupTable5}; + +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +use crate::window::NafLookupTable8; use crate::traits::Identity; @@ -59,12 +64,14 @@ use super::field::{FieldElement2625x4, Lanes, Shuffle}; #[derive(Copy, Clone, Debug)] pub struct ExtendedPoint(pub(super) FieldElement2625x4); +#[unsafe_target_feature("avx2")] impl From for ExtendedPoint { fn from(P: edwards::EdwardsPoint) -> ExtendedPoint { ExtendedPoint(FieldElement2625x4::new(&P.X, &P.Y, &P.Z, &P.T)) } } +#[unsafe_target_feature("avx2")] impl From for edwards::EdwardsPoint { fn from(P: ExtendedPoint) -> edwards::EdwardsPoint { let tmp = P.0.split(); @@ -77,6 +84,7 @@ impl From for 
edwards::EdwardsPoint { } } +#[unsafe_target_feature("avx2")] impl ConditionallySelectable for ExtendedPoint { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { ExtendedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice)) @@ -87,18 +95,21 @@ impl ConditionallySelectable for ExtendedPoint { } } +#[unsafe_target_feature("avx2")] impl Default for ExtendedPoint { fn default() -> ExtendedPoint { ExtendedPoint::identity() } } +#[unsafe_target_feature("avx2")] impl Identity for ExtendedPoint { fn identity() -> ExtendedPoint { constants::EXTENDEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx2")] impl ExtendedPoint { /// Compute the double of this point. pub fn double(&self) -> ExtendedPoint { @@ -184,6 +195,7 @@ impl ExtendedPoint { #[derive(Copy, Clone, Debug)] pub struct CachedPoint(pub(super) FieldElement2625x4); +#[unsafe_target_feature("avx2")] impl From for CachedPoint { fn from(P: ExtendedPoint) -> CachedPoint { let mut x = P.0; @@ -202,18 +214,21 @@ impl From for CachedPoint { } } +#[unsafe_target_feature("avx2")] impl Default for CachedPoint { fn default() -> CachedPoint { CachedPoint::identity() } } +#[unsafe_target_feature("avx2")] impl Identity for CachedPoint { fn identity() -> CachedPoint { constants::CACHEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx2")] impl ConditionallySelectable for CachedPoint { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { CachedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice)) @@ -224,6 +239,7 @@ impl ConditionallySelectable for CachedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a> Neg for &'a CachedPoint { type Output = CachedPoint; /// Lazily negate the point. 
@@ -238,6 +254,7 @@ impl<'a> Neg for &'a CachedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -275,6 +292,7 @@ impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -288,6 +306,7 @@ impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx2")] impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { fn from(point: &'a edwards::EdwardsPoint) -> Self { let P = ExtendedPoint::from(*point); @@ -299,6 +318,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { } } +#[unsafe_target_feature("avx2")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -312,6 +332,8 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { } } +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +#[unsafe_target_feature("avx2")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -325,6 +347,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { } } +#[cfg(target_feature = "avx2")] #[cfg(test)] mod test { use super::*; @@ -524,6 +547,7 @@ mod test { doubling_test_helper(P); } + #[cfg(any(feature = "precomputed-tables", feature = "alloc"))] #[test] fn basepoint_odd_lookup_table_verify() { use crate::backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE; diff --git a/src/backend/vector/avx2/field.rs b/src/backend/vector/avx2/field.rs index 9f278723d..bdb55efa5 100644 --- a/src/backend/vector/avx2/field.rs +++ b/src/backend/vector/avx2/field.rs @@ -48,6 +48,8 @@ use crate::backend::vector::avx2::constants::{ P_TIMES_16_HI, P_TIMES_16_LO, P_TIMES_2_HI, P_TIMES_2_LO, }; +use 
unsafe_target_feature::unsafe_target_feature; + /// Unpack 32-bit lanes into 64-bit lanes: /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) @@ -57,6 +59,7 @@ use crate::backend::vector::avx2::constants::{ /// (a0, 0, b0, 0, c0, 0, d0, 0) /// (a1, 0, b1, 0, c1, 0, d1, 0) /// ``` +#[unsafe_target_feature("avx2")] #[inline(always)] fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { let a: u32x8; @@ -80,6 +83,7 @@ fn unpack_pair(src: u32x8) -> (u32x8, u32x8) { /// ```ascii,no_run /// (a0, b0, a1, b1, c0, d0, c1, d1) /// ``` +#[unsafe_target_feature("avx2")] #[inline(always)] fn repack_pair(x: u32x8, y: u32x8) -> u32x8 { unsafe { @@ -151,6 +155,7 @@ pub struct FieldElement2625x4(pub(crate) [u32x8; 5]); use subtle::Choice; use subtle::ConditionallySelectable; +#[unsafe_target_feature("avx2")] impl ConditionallySelectable for FieldElement2625x4 { fn conditional_select( a: &FieldElement2625x4, @@ -179,6 +184,7 @@ impl ConditionallySelectable for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl FieldElement2625x4 { pub const ZERO: FieldElement2625x4 = FieldElement2625x4([u32x8::splat_const::<0>(); 5]); @@ -675,6 +681,7 @@ impl FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl Neg for FieldElement2625x4 { type Output = FieldElement2625x4; @@ -703,6 +710,7 @@ impl Neg for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl Add for FieldElement2625x4 { type Output = FieldElement2625x4; /// Add two `FieldElement2625x4`s, without performing a reduction. @@ -718,6 +726,7 @@ impl Add for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { type Output = FieldElement2625x4; /// Perform a multiplication by a vector of small constants. @@ -750,6 +759,7 @@ impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 { } } +#[unsafe_target_feature("avx2")] impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { type Output = FieldElement2625x4; /// Multiply `self` by `rhs`. 
@@ -765,6 +775,7 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { /// The coefficients of the result are bounded with \\( b < 0.007 \\). /// #[rustfmt::skip] // keep alignment of z* calculations + #[inline] fn mul(self, rhs: &'b FieldElement2625x4) -> FieldElement2625x4 { #[inline(always)] fn m(x: u32x8, y: u32x8) -> u64x4 { @@ -859,6 +870,7 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 { } } +#[cfg(target_feature = "avx2")] #[cfg(test)] mod test { use super::*; diff --git a/src/backend/vector/avx2/mod.rs b/src/backend/vector/avx2/mod.rs index b3e2d14ea..fba39f05c 100644 --- a/src/backend/vector/avx2/mod.rs +++ b/src/backend/vector/avx2/mod.rs @@ -16,3 +16,5 @@ pub(crate) mod field; pub(crate) mod edwards; pub(crate) mod constants; + +pub(crate) use self::edwards::{CachedPoint, ExtendedPoint}; diff --git a/src/backend/vector/ifma/edwards.rs b/src/backend/vector/ifma/edwards.rs index 5bdc3ce07..ccfe092c8 100644 --- a/src/backend/vector/ifma/edwards.rs +++ b/src/backend/vector/ifma/edwards.rs @@ -16,8 +16,13 @@ use core::ops::{Add, Neg, Sub}; use subtle::Choice; use subtle::ConditionallySelectable; +use unsafe_target_feature::unsafe_target_feature; + use crate::edwards; -use crate::window::{LookupTable, NafLookupTable5, NafLookupTable8}; +use crate::window::{LookupTable, NafLookupTable5}; + +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +use crate::window::NafLookupTable8; use super::constants; use super::field::{F51x4Reduced, F51x4Unreduced, Lanes, Shuffle}; @@ -28,12 +33,14 @@ pub struct ExtendedPoint(pub(super) F51x4Unreduced); #[derive(Copy, Clone, Debug)] pub struct CachedPoint(pub(super) F51x4Reduced); +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for ExtendedPoint { fn from(P: edwards::EdwardsPoint) -> ExtendedPoint { ExtendedPoint(F51x4Unreduced::new(&P.X, &P.Y, &P.Z, &P.T)) } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for edwards::EdwardsPoint { fn from(P: 
ExtendedPoint) -> edwards::EdwardsPoint { let reduced = F51x4Reduced::from(P.0); @@ -47,6 +54,7 @@ impl From for edwards::EdwardsPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for CachedPoint { fn from(P: ExtendedPoint) -> CachedPoint { let mut x = P.0; @@ -59,18 +67,21 @@ impl From for CachedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Default for ExtendedPoint { fn default() -> ExtendedPoint { ExtendedPoint::identity() } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Identity for ExtendedPoint { fn identity() -> ExtendedPoint { constants::EXTENDEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl ExtendedPoint { pub fn double(&self) -> ExtendedPoint { // (Y1 X1 T1 Z1) -- uses vpshufd (1c latency @ 1/c) @@ -122,6 +133,7 @@ impl ExtendedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -151,18 +163,21 @@ impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Default for CachedPoint { fn default() -> CachedPoint { CachedPoint::identity() } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Identity for CachedPoint { fn identity() -> CachedPoint { constants::CACHEDPOINT_IDENTITY } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl ConditionallySelectable for CachedPoint { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { CachedPoint(F51x4Reduced::conditional_select(&a.0, &b.0, choice)) @@ -173,6 +188,7 @@ impl ConditionallySelectable for CachedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> Neg for &'a CachedPoint { type Output = CachedPoint; @@ -182,6 +198,7 @@ impl<'a> Neg for &'a CachedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint { type Output = ExtendedPoint; @@ -191,6 +208,7 @@ impl<'a, 'b> Sub<&'b 
CachedPoint> for &'a ExtendedPoint { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { fn from(point: &'a edwards::EdwardsPoint) -> Self { let P = ExtendedPoint::from(*point); @@ -202,6 +220,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -215,6 +234,8 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5 { } } +#[cfg(any(feature = "precomputed-tables", feature = "alloc"))] +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { fn from(point: &'a edwards::EdwardsPoint) -> Self { let A = ExtendedPoint::from(*point); @@ -228,6 +249,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8 { } } +#[cfg(all(target_feature = "avx512ifma", target_feature = "avx512vl"))] #[cfg(test)] mod test { use super::*; diff --git a/src/backend/vector/ifma/field.rs b/src/backend/vector/ifma/field.rs index fd1955315..5928e14a2 100644 --- a/src/backend/vector/ifma/field.rs +++ b/src/backend/vector/ifma/field.rs @@ -16,15 +16,19 @@ use core::ops::{Add, Mul, Neg}; use crate::backend::serial::u64::field::FieldElement51; +use unsafe_target_feature::unsafe_target_feature; + /// A wrapper around `vpmadd52luq` that works on `u64x4`. -#[inline(always)] +#[unsafe_target_feature("avx512ifma,avx512vl")] +#[inline] unsafe fn madd52lo(z: u64x4, x: u64x4, y: u64x4) -> u64x4 { use core::arch::x86_64::_mm256_madd52lo_epu64; _mm256_madd52lo_epu64(z.into(), x.into(), y.into()).into() } /// A wrapper around `vpmadd52huq` that works on `u64x4`. 
-#[inline(always)] +#[unsafe_target_feature("avx512ifma,avx512vl")] +#[inline] unsafe fn madd52hi(z: u64x4, x: u64x4, y: u64x4) -> u64x4 { use core::arch::x86_64::_mm256_madd52hi_epu64; _mm256_madd52hi_epu64(z.into(), x.into(), y.into()).into() @@ -53,6 +57,7 @@ pub enum Shuffle { CACA, } +#[unsafe_target_feature("avx512ifma,avx512vl")] #[inline(always)] fn shuffle_lanes(x: u64x4, control: Shuffle) -> u64x4 { unsafe { @@ -84,6 +89,7 @@ pub enum Lanes { BCD, } +#[unsafe_target_feature("avx512ifma,avx512vl")] #[inline] fn blend_lanes(x: u64x4, y: u64x4, control: Lanes) -> u64x4 { unsafe { @@ -100,6 +106,7 @@ fn blend_lanes(x: u64x4, y: u64x4, control: Lanes) -> u64x4 { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl F51x4Unreduced { pub const ZERO: F51x4Unreduced = F51x4Unreduced([u64x4::splat_const::<0>(); 5]); @@ -198,6 +205,7 @@ impl F51x4Unreduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Neg for F51x4Reduced { type Output = F51x4Reduced; @@ -209,6 +217,7 @@ impl Neg for F51x4Reduced { use subtle::Choice; use subtle::ConditionallySelectable; +#[unsafe_target_feature("avx512ifma,avx512vl")] impl ConditionallySelectable for F51x4Reduced { #[inline] fn conditional_select(a: &F51x4Reduced, b: &F51x4Reduced, choice: Choice) -> F51x4Reduced { @@ -235,6 +244,7 @@ impl ConditionallySelectable for F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl F51x4Reduced { #[inline] pub fn shuffle(&self, control: Shuffle) -> F51x4Reduced { @@ -373,6 +383,7 @@ impl F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for F51x4Unreduced { #[inline] fn from(x: F51x4Reduced) -> F51x4Unreduced { @@ -380,6 +391,7 @@ impl From for F51x4Unreduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl From for F51x4Reduced { #[inline] fn from(x: F51x4Unreduced) -> F51x4Reduced { @@ -405,6 +417,7 @@ impl From for F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl Add for F51x4Unreduced { type 
Output = F51x4Unreduced; #[inline] @@ -419,6 +432,7 @@ impl Add for F51x4Unreduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a> Mul<(u32, u32, u32, u32)> for &'a F51x4Reduced { type Output = F51x4Unreduced; #[inline] @@ -470,6 +484,7 @@ impl<'a> Mul<(u32, u32, u32, u32)> for &'a F51x4Reduced { } } +#[unsafe_target_feature("avx512ifma,avx512vl")] impl<'a, 'b> Mul<&'b F51x4Reduced> for &'a F51x4Reduced { type Output = F51x4Unreduced; #[inline] @@ -614,6 +629,7 @@ impl<'a, 'b> Mul<&'b F51x4Reduced> for &'a F51x4Reduced { } } +#[cfg(all(target_feature = "avx512ifma", target_feature = "avx512vl"))] #[cfg(test)] mod test { use super::*; diff --git a/src/backend/vector/ifma/mod.rs b/src/backend/vector/ifma/mod.rs index 79a61ff3b..f48748d21 100644 --- a/src/backend/vector/ifma/mod.rs +++ b/src/backend/vector/ifma/mod.rs @@ -16,3 +16,5 @@ pub mod field; pub mod edwards; pub mod constants; + +pub(crate) use self::edwards::{CachedPoint, ExtendedPoint}; diff --git a/src/backend/vector/mod.rs b/src/backend/vector/mod.rs index 51c9e81e3..d720f4acb 100644 --- a/src/backend/vector/mod.rs +++ b/src/backend/vector/mod.rs @@ -11,60 +11,13 @@ #![doc = include_str!("../../../docs/parallel-formulas.md")] -#[cfg(not(any( - target_feature = "avx2", - all(target_feature = "avx512ifma", nightly), - docsrs -)))] -compile_error!("'simd' backend selected without target_feature=+avx2 or +avx512ifma"); #[allow(missing_docs)] pub mod packed_simd; -#[cfg(any( - all( - target_feature = "avx2", - not(all(target_feature = "avx512ifma", nightly)) - ), - all(docsrs, target_arch = "x86_64") -))] +#[cfg(feature = "simd_avx2")] pub mod avx2; -#[cfg(any( - all( - target_feature = "avx2", - not(all(target_feature = "avx512ifma", nightly)) - ), - all(docsrs, target_arch = "x86_64") -))] -pub(crate) use self::avx2::{edwards::CachedPoint, edwards::ExtendedPoint}; -#[cfg(any( - all(target_feature = "avx512ifma", nightly), - all(docsrs, target_arch = "x86_64") -))] +#[cfg(all(feature = "simd_avx512", nightly))] pub 
mod ifma; -#[cfg(all(target_feature = "avx512ifma", nightly))] -pub(crate) use self::ifma::{edwards::CachedPoint, edwards::ExtendedPoint}; -#[cfg(any( - target_feature = "avx2", - all(target_feature = "avx512ifma", nightly), - all(docsrs, target_arch = "x86_64") -))] -#[allow(missing_docs)] pub mod scalar_mul; - -// Precomputed table re-exports - -#[cfg(any( - all( - target_feature = "avx2", - not(all(target_feature = "avx512ifma", nightly)), - feature = "precomputed-tables" - ), - all(docsrs, target_arch = "x86_64") -))] -pub(crate) use self::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE; - -#[cfg(all(target_feature = "avx512ifma", nightly, feature = "precomputed-tables"))] -pub(crate) use self::ifma::constants::BASEPOINT_ODD_LOOKUP_TABLE; diff --git a/src/backend/vector/packed_simd.rs b/src/backend/vector/packed_simd.rs index 6a3484d72..6ab5dcc9c 100644 --- a/src/backend/vector/packed_simd.rs +++ b/src/backend/vector/packed_simd.rs @@ -3,14 +3,16 @@ // This file is part of curve25519-dalek. // See LICENSE for licensing information. -///! This module defines wrappers over platform-specific SIMD types to make them -///! more convenient to use. -///! -///! UNSAFETY: Everything in this module assumes that we're running on hardware -///! which supports at least AVX2. This invariant *must* be enforced -///! by the callers of this code. +//! This module defines wrappers over platform-specific SIMD types to make them +//! more convenient to use. +//! +//! UNSAFETY: Everything in this module assumes that we're running on hardware +//! which supports at least AVX2. This invariant *must* be enforced +//! by the callers of this code. use core::ops::{Add, AddAssign, BitAnd, BitAndAssign, BitXor, BitXorAssign, Sub}; +use unsafe_target_feature::unsafe_target_feature; + macro_rules! impl_shared { ( $ty:ident, @@ -26,6 +28,7 @@ macro_rules! 
impl_shared { #[repr(transparent)] pub struct $ty(core::arch::x86_64::__m256i); + #[unsafe_target_feature("avx2")] impl From<$ty> for core::arch::x86_64::__m256i { #[inline] fn from(value: $ty) -> core::arch::x86_64::__m256i { @@ -33,6 +36,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl From for $ty { #[inline] fn from(value: core::arch::x86_64::__m256i) -> $ty { @@ -40,6 +44,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl PartialEq for $ty { #[inline] fn eq(&self, rhs: &$ty) -> bool { @@ -72,6 +77,7 @@ macro_rules! impl_shared { impl Eq for $ty {} + #[unsafe_target_feature("avx2")] impl Add for $ty { type Output = Self; @@ -81,6 +87,8 @@ macro_rules! impl_shared { } } + #[allow(clippy::assign_op_pattern)] + #[unsafe_target_feature("avx2")] impl AddAssign for $ty { #[inline] fn add_assign(&mut self, rhs: $ty) { @@ -88,6 +96,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl Sub for $ty { type Output = Self; @@ -97,6 +106,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl BitAnd for $ty { type Output = Self; @@ -106,6 +116,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] impl BitXor for $ty { type Output = Self; @@ -115,6 +126,8 @@ macro_rules! impl_shared { } } + #[allow(clippy::assign_op_pattern)] + #[unsafe_target_feature("avx2")] impl BitAndAssign for $ty { #[inline] fn bitand_assign(&mut self, rhs: $ty) { @@ -122,6 +135,8 @@ macro_rules! impl_shared { } } + #[allow(clippy::assign_op_pattern)] + #[unsafe_target_feature("avx2")] impl BitXorAssign for $ty { #[inline] fn bitxor_assign(&mut self, rhs: $ty) { @@ -129,6 +144,7 @@ macro_rules! impl_shared { } } + #[unsafe_target_feature("avx2")] #[allow(dead_code)] impl $ty { #[inline] @@ -152,6 +168,7 @@ macro_rules! impl_shared { macro_rules! 
impl_conv { ($src:ident => $($dst:ident),+) => { $( + #[unsafe_target_feature("avx2")] impl From<$src> for $dst { #[inline] fn from(value: $src) -> $dst { @@ -235,20 +252,22 @@ impl u64x4 { } /// Constructs a new instance. + #[unsafe_target_feature("avx2")] #[inline] - pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> Self { + pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 { unsafe { // _mm256_set_epi64 sets the underlying vector in reverse order of the args - Self(core::arch::x86_64::_mm256_set_epi64x( + u64x4(core::arch::x86_64::_mm256_set_epi64x( x3 as i64, x2 as i64, x1 as i64, x0 as i64, )) } } /// Constructs a new instance with all of the elements initialized to the given value. + #[unsafe_target_feature("avx2")] #[inline] - pub fn splat(x: u64) -> Self { - unsafe { Self(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) } + pub fn splat(x: u64) -> u64x4 { + unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) } } } @@ -257,6 +276,7 @@ impl u32x8 { /// A constified variant of `new`. /// /// Should only be called from `const` contexts. At runtime `new` is going to be faster. + #[allow(clippy::too_many_arguments)] #[inline] pub const fn new_const( x0: u32, @@ -282,11 +302,13 @@ impl u32x8 { } /// Constructs a new instance. + #[allow(clippy::too_many_arguments)] + #[unsafe_target_feature("avx2")] #[inline] - pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> Self { + pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 { unsafe { // _mm256_set_epi32 sets the underlying vector in reverse order of the args - Self(core::arch::x86_64::_mm256_set_epi32( + u32x8(core::arch::x86_64::_mm256_set_epi32( x7 as i32, x6 as i32, x5 as i32, x4 as i32, x3 as i32, x2 as i32, x1 as i32, x0 as i32, )) @@ -294,11 +316,15 @@ impl u32x8 { } /// Constructs a new instance with all of the elements initialized to the given value. 
+ #[unsafe_target_feature("avx2")] #[inline] - pub fn splat(x: u32) -> Self { - unsafe { Self(core::arch::x86_64::_mm256_set1_epi32(x as i32)) } + pub fn splat(x: u32) -> u32x8 { + unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) } } +} +#[unsafe_target_feature("avx2")] +impl u32x8 { /// Multiplies the low unsigned 32-bits from each packed 64-bit element /// and returns the unsigned 64-bit results. /// diff --git a/src/backend/vector/scalar_mul/mod.rs b/src/backend/vector/scalar_mul/mod.rs index 36a7047a2..fed3470e7 100644 --- a/src/backend/vector/scalar_mul/mod.rs +++ b/src/backend/vector/scalar_mul/mod.rs @@ -9,15 +9,22 @@ // - isis agora lovecruft // - Henry de Valence +//! Implementations of various multiplication algorithms for the SIMD backends. + +#[allow(missing_docs)] pub mod variable_base; +#[allow(missing_docs)] pub mod vartime_double_base; +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub mod straus; +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub mod precomputed_straus; +#[allow(missing_docs)] #[cfg(feature = "alloc")] pub mod pippenger; diff --git a/src/backend/vector/scalar_mul/pippenger.rs b/src/backend/vector/scalar_mul/pippenger.rs index f7c161620..b00cb87c5 100644 --- a/src/backend/vector/scalar_mul/pippenger.rs +++ b/src/backend/vector/scalar_mul/pippenger.rs @@ -9,157 +9,169 @@ #![allow(non_snake_case)] -use alloc::vec::Vec; - -use core::borrow::Borrow; -use core::cmp::Ordering; - -use crate::backend::vector::{CachedPoint, ExtendedPoint}; -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::{Identity, VartimeMultiscalarMul}; - -/// Implements a version of Pippenger's algorithm. -/// -/// See the documentation in the serial `scalar_mul::pippenger` module for details. 
-pub struct Pippenger; - -impl VartimeMultiscalarMul for Pippenger { - type Point = EdwardsPoint; - - fn optional_multiscalar_mul(scalars: I, points: J) -> Option - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator>, - { - let mut scalars = scalars.into_iter(); - let size = scalars.by_ref().size_hint().0; - let w = if size < 500 { - 6 - } else if size < 800 { - 7 - } else { - 8 - }; - - let max_digit: usize = 1 << w; - let digits_count: usize = Scalar::to_radix_2w_size_hint(w); - let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket - - // Collect optimized scalars and points in a buffer for repeated access - // (scanning the whole collection per each digit position). - let scalars = scalars.into_iter().map(|s| s.borrow().as_radix_2w(w)); - - let points = points - .into_iter() - .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P)))); - - let scalars_points = scalars - .zip(points) - .map(|(s, maybe_p)| maybe_p.map(|p| (s, p))) - .collect::>>()?; - - // Prepare 2^w/2 buckets. - // buckets[i] corresponds to a multiplication factor (i+1). - let mut buckets: Vec = (0..buckets_count) - .map(|_| ExtendedPoint::identity()) - .collect(); - - let mut columns = (0..digits_count).rev().map(|digit_index| { - // Clear the buckets when processing another digit. - for bucket in &mut buckets { - *bucket = ExtendedPoint::identity(); - } +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { - // Iterate over pairs of (point, scalar) - // and add/sub the point to the corresponding bucket. - // Note: if we add support for precomputed lookup tables, - // we'll be adding/subtractiong point premultiplied by `digits[i]` to buckets[0]. - for (digits, pt) in scalars_points.iter() { - // Widen digit so that we don't run into edge cases when w=8. 
- let digit = digits[digit_index] as i16; - match digit.cmp(&0) { - Ordering::Greater => { - let b = (digit - 1) as usize; - buckets[b] = &buckets[b] + pt; - } - Ordering::Less => { - let b = (-digit - 1) as usize; - buckets[b] = &buckets[b] - pt; + use alloc::vec::Vec; + + use core::borrow::Borrow; + use core::cmp::Ordering; + + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::{Identity, VartimeMultiscalarMul}; + + /// Implements a version of Pippenger's algorithm. + /// + /// See the documentation in the serial `scalar_mul::pippenger` module for details. + pub struct Pippenger; + + impl VartimeMultiscalarMul for Pippenger { + type Point = EdwardsPoint; + + fn optional_multiscalar_mul(scalars: I, points: J) -> Option + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator>, + { + let mut scalars = scalars.into_iter(); + let size = scalars.by_ref().size_hint().0; + let w = if size < 500 { + 6 + } else if size < 800 { + 7 + } else { + 8 + }; + + let max_digit: usize = 1 << w; + let digits_count: usize = Scalar::to_radix_2w_size_hint(w); + let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket + + // Collect optimized scalars and points in a buffer for repeated access + // (scanning the whole collection per each digit position). + let scalars = scalars.map(|s| s.borrow().as_radix_2w(w)); + + let points = points + .into_iter() + .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P)))); + + let scalars_points = scalars + .zip(points) + .map(|(s, maybe_p)| maybe_p.map(|p| (s, p))) + .collect::>>()?; + + // Prepare 2^w/2 buckets. + // buckets[i] corresponds to a multiplication factor (i+1). 
+ let mut buckets: Vec = (0..buckets_count) + .map(|_| ExtendedPoint::identity()) + .collect(); + + let mut columns = (0..digits_count).rev().map(|digit_index| { + // Clear the buckets when processing another digit. + for bucket in &mut buckets { + *bucket = ExtendedPoint::identity(); + } + + // Iterate over pairs of (point, scalar) + // and add/sub the point to the corresponding bucket. + // Note: if we add support for precomputed lookup tables, + // we'll be adding/subtracting point premultiplied by `digits[i]` to buckets[0]. + for (digits, pt) in scalars_points.iter() { + // Widen digit so that we don't run into edge cases when w=8. + let digit = digits[digit_index] as i16; + match digit.cmp(&0) { + Ordering::Greater => { + let b = (digit - 1) as usize; + buckets[b] = &buckets[b] + pt; + } + Ordering::Less => { + let b = (-digit - 1) as usize; + buckets[b] = &buckets[b] - pt; + } + Ordering::Equal => {} } } - } - // Add the buckets applying the multiplication factor to each bucket. - // The most efficient way to do that is to have a single sum with two running sums: - // an intermediate sum from last bucket to the first, and a sum of intermediate sums. - // - // For example, to add buckets 1*A, 2*B, 3*C we need to add these points: - // C - // C B - // C B A Sum = C + (C+B) + (C+B+A) - let mut buckets_intermediate_sum = buckets[buckets_count - 1]; - let mut buckets_sum = buckets[buckets_count - 1]; - for i in (0..(buckets_count - 1)).rev() { - buckets_intermediate_sum = - &buckets_intermediate_sum + &CachedPoint::from(buckets[i]); - buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum); - } + // Add the buckets applying the multiplication factor to each bucket. + // The most efficient way to do that is to have a single sum with two running sums: + // an intermediate sum from last bucket to the first, and a sum of intermediate sums. 
+ // + // For example, to add buckets 1*A, 2*B, 3*C we need to add these points: + // C + // C B + // C B A Sum = C + (C+B) + (C+B+A) + let mut buckets_intermediate_sum = buckets[buckets_count - 1]; + let mut buckets_sum = buckets[buckets_count - 1]; + for i in (0..(buckets_count - 1)).rev() { + buckets_intermediate_sum = + &buckets_intermediate_sum + &CachedPoint::from(buckets[i]); + buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum); + } - buckets_sum - }); + buckets_sum + }); - // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`. - // `unwrap()` always succeeds because we know we have more than zero digits. - let hi_column = columns.next().unwrap(); + // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`. + // `unwrap()` always succeeds because we know we have more than zero digits. + let hi_column = columns.next().unwrap(); - Some( - columns - .fold(hi_column, |total, p| { - &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p) - }) - .into(), - ) + Some( + columns + .fold(hi_column, |total, p| { + &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p) + }) + .into(), + ) + } } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::constants; - use crate::scalar::Scalar; - #[test] - fn test_vartime_pippenger() { - // Reuse points across different tests - let mut n = 512; - let x = Scalar::from(2128506u64).invert(); - let y = Scalar::from(4443282u64).invert(); - let points: Vec<_> = (0..n) - .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64)) - .collect(); - let scalars: Vec<_> = (0..n) - .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars - .collect(); - - let premultiplied: Vec = scalars - .iter() - .zip(points.iter()) - .map(|(sc, pt)| sc * pt) - .collect(); - - while n > 0 { - let scalars = &scalars[0..n].to_vec(); - let points = &points[0..n].to_vec(); - 
let control: EdwardsPoint = premultiplied[0..n].iter().sum(); - - let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone()); - - assert_eq!(subject.compress(), control.compress()); - - n = n / 2; + #[cfg(test)] + mod test { + #[test] + fn test_vartime_pippenger() { + use super::*; + use crate::constants; + use crate::scalar::Scalar; + + // Reuse points across different tests + let mut n = 512; + let x = Scalar::from(2128506u64).invert(); + let y = Scalar::from(4443282u64).invert(); + let points: Vec<_> = (0..n) + .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64)) + .collect(); + let scalars: Vec<_> = (0..n) + .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars + .collect(); + + let premultiplied: Vec = scalars + .iter() + .zip(points.iter()) + .map(|(sc, pt)| sc * pt) + .collect(); + + while n > 0 { + let scalars = &scalars[0..n].to_vec(); + let points = &points[0..n].to_vec(); + let control: EdwardsPoint = premultiplied[0..n].iter().sum(); + + let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone()); + + assert_eq!(subject.compress(), control.compress()); + + n = n / 2; + } } } } diff --git a/src/backend/vector/scalar_mul/precomputed_straus.rs b/src/backend/vector/scalar_mul/precomputed_straus.rs index 359846173..8c45c29cf 100644 --- a/src/backend/vector/scalar_mul/precomputed_straus.rs +++ b/src/backend/vector/scalar_mul/precomputed_straus.rs @@ -11,105 +11,117 @@ #![allow(non_snake_case)] -use alloc::vec::Vec; +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { -use core::borrow::Borrow; -use core::cmp::Ordering; + use alloc::vec::Vec; -use crate::backend::vector::{CachedPoint, ExtendedPoint}; -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::Identity; -use 
crate::traits::VartimePrecomputedMultiscalarMul; -use crate::window::{NafLookupTable5, NafLookupTable8}; + use core::borrow::Borrow; + use core::cmp::Ordering; -pub struct VartimePrecomputedStraus { - static_lookup_tables: Vec>, -} + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; -impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { - type Point = EdwardsPoint; + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; - fn new(static_points: I) -> Self - where - I: IntoIterator, - I::Item: Borrow, - { - Self { - static_lookup_tables: static_points - .into_iter() - .map(|P| NafLookupTable8::::from(P.borrow())) - .collect(), - } + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::Identity; + use crate::traits::VartimePrecomputedMultiscalarMul; + use crate::window::{NafLookupTable5, NafLookupTable8}; + + pub struct VartimePrecomputedStraus { + static_lookup_tables: Vec>, } - fn optional_mixed_multiscalar_mul( - &self, - static_scalars: I, - dynamic_scalars: J, - dynamic_points: K, - ) -> Option - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - K: IntoIterator>, - { - let static_nafs = static_scalars - .into_iter() - .map(|c| c.borrow().non_adjacent_form(5)) - .collect::>(); - let dynamic_nafs: Vec<_> = dynamic_scalars - .into_iter() - .map(|c| c.borrow().non_adjacent_form(5)) - .collect::>(); - - let dynamic_lookup_tables = dynamic_points - .into_iter() - .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) - .collect::>>()?; - - let sp = self.static_lookup_tables.len(); - let dp = dynamic_lookup_tables.len(); - assert_eq!(sp, static_nafs.len()); - assert_eq!(dp, dynamic_nafs.len()); - - // We could save some doublings by looking for the highest - // nonzero NAF coefficient, but since we might have a lot of - // them to search, it's not clear it's worthwhile to check. 
- let mut R = ExtendedPoint::identity(); - for j in (0..256).rev() { - R = R.double(); - - for i in 0..dp { - let t_ij = dynamic_nafs[i][j]; - match t_ij.cmp(&0) { - Ordering::Greater => { - R = &R + &dynamic_lookup_tables[i].select(t_ij as usize); - } - Ordering::Less => { - R = &R - &dynamic_lookup_tables[i].select(-t_ij as usize); - } - Ordering::Equal => {} - } + impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { + type Point = EdwardsPoint; + + fn new(static_points: I) -> Self + where + I: IntoIterator, + I::Item: Borrow, + { + Self { + static_lookup_tables: static_points + .into_iter() + .map(|P| NafLookupTable8::::from(P.borrow())) + .collect(), } + } + + fn optional_mixed_multiscalar_mul( + &self, + static_scalars: I, + dynamic_scalars: J, + dynamic_points: K, + ) -> Option + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator, + J::Item: Borrow, + K: IntoIterator>, + { + let static_nafs = static_scalars + .into_iter() + .map(|c| c.borrow().non_adjacent_form(5)) + .collect::>(); + let dynamic_nafs: Vec<_> = dynamic_scalars + .into_iter() + .map(|c| c.borrow().non_adjacent_form(5)) + .collect::>(); + + let dynamic_lookup_tables = dynamic_points + .into_iter() + .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) + .collect::>>()?; + + let sp = self.static_lookup_tables.len(); + let dp = dynamic_lookup_tables.len(); + assert_eq!(sp, static_nafs.len()); + assert_eq!(dp, dynamic_nafs.len()); - #[allow(clippy::needless_range_loop)] - for i in 0..sp { - let t_ij = static_nafs[i][j]; - match t_ij.cmp(&0) { - Ordering::Greater => { - R = &R + &self.static_lookup_tables[i].select(t_ij as usize); + // We could save some doublings by looking for the highest + // nonzero NAF coefficient, but since we might have a lot of + // them to search, it's not clear it's worthwhile to check. 
+ let mut R = ExtendedPoint::identity(); + for j in (0..256).rev() { + R = R.double(); + + for i in 0..dp { + let t_ij = dynamic_nafs[i][j]; + match t_ij.cmp(&0) { + Ordering::Greater => { + R = &R + &dynamic_lookup_tables[i].select(t_ij as usize); + } + Ordering::Less => { + R = &R - &dynamic_lookup_tables[i].select(-t_ij as usize); + } + Ordering::Equal => {} } - Ordering::Less => { - R = &R - &self.static_lookup_tables[i].select(-t_ij as usize); + } + + #[allow(clippy::needless_range_loop)] + for i in 0..sp { + let t_ij = static_nafs[i][j]; + match t_ij.cmp(&0) { + Ordering::Greater => { + R = &R + &self.static_lookup_tables[i].select(t_ij as usize); + } + Ordering::Less => { + R = &R - &self.static_lookup_tables[i].select(-t_ij as usize); + } + Ordering::Equal => {} } - Ordering::Equal => {} } } - } - Some(R.into()) + Some(R.into()) + } } } diff --git a/src/backend/vector/scalar_mul/straus.rs b/src/backend/vector/scalar_mul/straus.rs index 693415361..046bcd14c 100644 --- a/src/backend/vector/scalar_mul/straus.rs +++ b/src/backend/vector/scalar_mul/straus.rs @@ -11,102 +11,114 @@ #![allow(non_snake_case)] -use alloc::vec::Vec; - -use core::borrow::Borrow; -use core::cmp::Ordering; - -use zeroize::Zeroizing; - -use crate::backend::vector::{CachedPoint, ExtendedPoint}; -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul}; -use crate::window::{LookupTable, NafLookupTable5}; - -/// Multiscalar multiplication using interleaved window / Straus' -/// method. See the `Straus` struct in the serial backend for more -/// details. -/// -/// This exists as a seperate implementation from that one because the -/// AVX2 code uses different curve models (it does not pass between -/// multiple models during scalar mul), and it has to convert the -/// point representation on the fly. 
-pub struct Straus {} - -impl MultiscalarMul for Straus { - type Point = EdwardsPoint; - - fn multiscalar_mul(scalars: I, points: J) -> EdwardsPoint - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - { - // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] - // for each input point P - let lookup_tables: Vec<_> = points - .into_iter() - .map(|point| LookupTable::::from(point.borrow())) - .collect(); - - let scalar_digits_vec: Vec<_> = scalars - .into_iter() - .map(|s| s.borrow().as_radix_16()) - .collect(); - // Pass ownership to a `Zeroizing` wrapper - let scalar_digits = Zeroizing::new(scalar_digits_vec); - - let mut Q = ExtendedPoint::identity(); - for j in (0..64).rev() { - Q = Q.mul_by_pow_2(4); - let it = scalar_digits.iter().zip(lookup_tables.iter()); - for (s_i, lookup_table_i) in it { - // Q = Q + s_{i,j} * P_i - Q = &Q + &lookup_table_i.select(s_i[j]); +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { + + use alloc::vec::Vec; + + use core::borrow::Borrow; + use core::cmp::Ordering; + + use zeroize::Zeroizing; + + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul}; + use crate::window::{LookupTable, NafLookupTable5}; + + /// Multiscalar multiplication using interleaved window / Straus' + /// method. See the `Straus` struct in the serial backend for more + /// details. 
+ /// + /// This exists as a seperate implementation from that one because the + /// AVX2 code uses different curve models (it does not pass between + /// multiple models during scalar mul), and it has to convert the + /// point representation on the fly. + pub struct Straus {} + + impl MultiscalarMul for Straus { + type Point = EdwardsPoint; + + fn multiscalar_mul(scalars: I, points: J) -> EdwardsPoint + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator, + J::Item: Borrow, + { + // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] + // for each input point P + let lookup_tables: Vec<_> = points + .into_iter() + .map(|point| LookupTable::::from(point.borrow())) + .collect(); + + let scalar_digits_vec: Vec<_> = scalars + .into_iter() + .map(|s| s.borrow().as_radix_16()) + .collect(); + // Pass ownership to a `Zeroizing` wrapper + let scalar_digits = Zeroizing::new(scalar_digits_vec); + + let mut Q = ExtendedPoint::identity(); + for j in (0..64).rev() { + Q = Q.mul_by_pow_2(4); + let it = scalar_digits.iter().zip(lookup_tables.iter()); + for (s_i, lookup_table_i) in it { + // Q = Q + s_{i,j} * P_i + Q = &Q + &lookup_table_i.select(s_i[j]); + } } + Q.into() } - Q.into() } -} -impl VartimeMultiscalarMul for Straus { - type Point = EdwardsPoint; - - fn optional_multiscalar_mul(scalars: I, points: J) -> Option - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator>, - { - let nafs: Vec<_> = scalars - .into_iter() - .map(|c| c.borrow().non_adjacent_form(5)) - .collect(); - let lookup_tables: Vec<_> = points - .into_iter() - .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) - .collect::>>()?; - - let mut Q = ExtendedPoint::identity(); - - for i in (0..256).rev() { - Q = Q.double(); - - for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) { - match naf[i].cmp(&0) { - Ordering::Greater => { - Q = &Q + &lookup_table.select(naf[i] as usize); - } - Ordering::Less => { - Q = &Q - &lookup_table.select(-naf[i] as usize); + impl 
VartimeMultiscalarMul for Straus { + type Point = EdwardsPoint; + + fn optional_multiscalar_mul(scalars: I, points: J) -> Option + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator>, + { + let nafs: Vec<_> = scalars + .into_iter() + .map(|c| c.borrow().non_adjacent_form(5)) + .collect(); + let lookup_tables: Vec<_> = points + .into_iter() + .map(|P_opt| P_opt.map(|P| NafLookupTable5::::from(&P))) + .collect::>>()?; + + let mut Q = ExtendedPoint::identity(); + + for i in (0..256).rev() { + Q = Q.double(); + + for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) { + match naf[i].cmp(&0) { + Ordering::Greater => { + Q = &Q + &lookup_table.select(naf[i] as usize); + } + Ordering::Less => { + Q = &Q - &lookup_table.select(-naf[i] as usize); + } + Ordering::Equal => {} } - Ordering::Equal => {} } } - } - Some(Q.into()) + Some(Q.into()) + } } } diff --git a/src/backend/vector/scalar_mul/variable_base.rs b/src/backend/vector/scalar_mul/variable_base.rs index 52e855dd1..2da479926 100644 --- a/src/backend/vector/scalar_mul/variable_base.rs +++ b/src/backend/vector/scalar_mul/variable_base.rs @@ -1,32 +1,44 @@ #![allow(non_snake_case)] -use crate::backend::vector::{CachedPoint, ExtendedPoint}; -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::Identity; -use crate::window::LookupTable; +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { -/// Perform constant-time, variable-base scalar multiplication. -pub fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint { - // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] - let lookup_table = LookupTable::::from(point); - // Setting s = scalar, compute - // - // s = s_0 + s_1*16^1 + ... + s_63*16^63, - // - // with `-8 ≤ s_i < 8` for `0 ≤ i < 63` and `-8 ≤ s_63 ≤ 8`. 
- let scalar_digits = scalar.as_radix_16(); - // Compute s*P as - // - // s*P = P*(s_0 + s_1*16^1 + s_2*16^2 + ... + s_63*16^63) - // s*P = P*s_0 + P*s_1*16^1 + P*s_2*16^2 + ... + P*s_63*16^63 - // s*P = P*s_0 + 16*(P*s_1 + 16*(P*s_2 + 16*( ... + P*s_63)...)) - // - // We sum right-to-left. - let mut Q = ExtendedPoint::identity(); - for i in (0..64).rev() { - Q = Q.mul_by_pow_2(4); - Q = &Q + &lookup_table.select(scalar_digits[i]); + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; + + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; + + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::Identity; + use crate::window::LookupTable; + + /// Perform constant-time, variable-base scalar multiplication. + pub fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint { + // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] + let lookup_table = LookupTable::::from(point); + // Setting s = scalar, compute + // + // s = s_0 + s_1*16^1 + ... + s_63*16^63, + // + // with `-8 ≤ s_i < 8` for `0 ≤ i < 63` and `-8 ≤ s_63 ≤ 8`. + let scalar_digits = scalar.as_radix_16(); + // Compute s*P as + // + // s*P = P*(s_0 + s_1*16^1 + s_2*16^2 + ... + s_63*16^63) + // s*P = P*s_0 + P*s_1*16^1 + P*s_2*16^2 + ... + P*s_63*16^63 + // s*P = P*s_0 + 16*(P*s_1 + 16*(P*s_2 + 16*( ... + P*s_63)...)) + // + // We sum right-to-left. 
+ let mut Q = ExtendedPoint::identity(); + for i in (0..64).rev() { + Q = Q.mul_by_pow_2(4); + Q = &Q + &lookup_table.select(scalar_digits[i]); + } + Q.into() } - Q.into() } diff --git a/src/backend/vector/scalar_mul/vartime_double_base.rs b/src/backend/vector/scalar_mul/vartime_double_base.rs index 5ec69ed52..191572bb1 100644 --- a/src/backend/vector/scalar_mul/vartime_double_base.rs +++ b/src/backend/vector/scalar_mul/vartime_double_base.rs @@ -11,69 +11,91 @@ #![allow(non_snake_case)] -use core::cmp::Ordering; +#[unsafe_target_feature::unsafe_target_feature_specialize( + conditional("avx2", feature = "simd_avx2"), + conditional("avx512ifma,avx512vl", all(feature = "simd_avx512", nightly)) +)] +pub mod spec { -use crate::backend::vector::{CachedPoint, ExtendedPoint}; -use crate::edwards::EdwardsPoint; -use crate::scalar::Scalar; -use crate::traits::Identity; -use crate::window::NafLookupTable5; + use core::cmp::Ordering; -/// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the Ed25519 basepoint. 
-pub fn mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { - let a_naf = a.non_adjacent_form(5); + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; - #[cfg(feature = "precomputed-tables")] - let b_naf = b.non_adjacent_form(8); - #[cfg(not(feature = "precomputed-tables"))] - let b_naf = b.non_adjacent_form(5); - - // Find starting index - let mut i: usize = 255; - for j in (0..256).rev() { - i = j; - if a_naf[i] != 0 || b_naf[i] != 0 { - break; - } - } + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint}; - let table_A = NafLookupTable5::::from(A); + #[cfg(feature = "precomputed-tables")] + #[for_target_feature("avx2")] + use crate::backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE; #[cfg(feature = "precomputed-tables")] - let table_B = &crate::backend::vector::BASEPOINT_ODD_LOOKUP_TABLE; - #[cfg(not(feature = "precomputed-tables"))] - let table_B = &NafLookupTable5::::from(&crate::constants::ED25519_BASEPOINT_POINT); + #[for_target_feature("avx512ifma")] + use crate::backend::vector::ifma::constants::BASEPOINT_ODD_LOOKUP_TABLE; - let mut Q = ExtendedPoint::identity(); + use crate::edwards::EdwardsPoint; + use crate::scalar::Scalar; + use crate::traits::Identity; + use crate::window::NafLookupTable5; - loop { - Q = Q.double(); + /// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the Ed25519 basepoint. 
+ pub fn mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint { + let a_naf = a.non_adjacent_form(5); - match a_naf[i].cmp(&0) { - Ordering::Greater => { - Q = &Q + &table_A.select(a_naf[i] as usize); - } - Ordering::Less => { - Q = &Q - &table_A.select(-a_naf[i] as usize); + #[cfg(feature = "precomputed-tables")] + let b_naf = b.non_adjacent_form(8); + #[cfg(not(feature = "precomputed-tables"))] + let b_naf = b.non_adjacent_form(5); + + // Find starting index + let mut i: usize = 255; + for j in (0..256).rev() { + i = j; + if a_naf[i] != 0 || b_naf[i] != 0 { + break; } - Ordering::Equal => {} } - match b_naf[i].cmp(&0) { - Ordering::Greater => { - Q = &Q + &table_B.select(b_naf[i] as usize); + let table_A = NafLookupTable5::::from(A); + + #[cfg(feature = "precomputed-tables")] + let table_B = &BASEPOINT_ODD_LOOKUP_TABLE; + + #[cfg(not(feature = "precomputed-tables"))] + let table_B = + &NafLookupTable5::::from(&crate::constants::ED25519_BASEPOINT_POINT); + + let mut Q = ExtendedPoint::identity(); + + loop { + Q = Q.double(); + + match a_naf[i].cmp(&0) { + Ordering::Greater => { + Q = &Q + &table_A.select(a_naf[i] as usize); + } + Ordering::Less => { + Q = &Q - &table_A.select(-a_naf[i] as usize); + } + Ordering::Equal => {} + } + + match b_naf[i].cmp(&0) { + Ordering::Greater => { + Q = &Q + &table_B.select(b_naf[i] as usize); + } + Ordering::Less => { + Q = &Q - &table_B.select(-b_naf[i] as usize); + } + Ordering::Equal => {} } - Ordering::Less => { - Q = &Q - &table_B.select(-b_naf[i] as usize); + + if i == 0 { + break; } - Ordering::Equal => {} + i -= 1; } - if i == 0 { - break; - } - i -= 1; + Q.into() } - - Q.into() } diff --git a/src/constants.rs b/src/constants.rs index 344d608f1..caef33a98 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -99,7 +99,7 @@ use crate::ristretto::RistrettoBasepointTable; /// The Ristretto basepoint, as a `RistrettoBasepointTable` for scalar multiplication. 
#[cfg(feature = "precomputed-tables")] -pub static RISTRETTO_BASEPOINT_TABLE: &'static RistrettoBasepointTable = unsafe { +pub static RISTRETTO_BASEPOINT_TABLE: &RistrettoBasepointTable = unsafe { // SAFETY: `RistrettoBasepointTable` is a `#[repr(transparent)]` newtype of // `EdwardsBasepointTable` &*(ED25519_BASEPOINT_TABLE as *const EdwardsBasepointTable as *const RistrettoBasepointTable) diff --git a/src/edwards.rs b/src/edwards.rs index 6db96341e..b52a58862 100644 --- a/src/edwards.rs +++ b/src/edwards.rs @@ -144,17 +144,6 @@ use crate::traits::MultiscalarMul; #[cfg(feature = "alloc")] use crate::traits::{VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul}; -#[cfg(not(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") -)))] -use crate::backend::serial::scalar_mul; -#[cfg(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") -))] -use crate::backend::vector::scalar_mul; - // ------------------------------------------------------------------------ // Compressed points // ------------------------------------------------------------------------ @@ -698,7 +687,7 @@ impl<'a, 'b> Mul<&'b Scalar> for &'a EdwardsPoint { /// For scalar multiplication of a basepoint, /// `EdwardsBasepointTable` is approximately 4x faster. fn mul(self, scalar: &'b Scalar) -> EdwardsPoint { - scalar_mul::variable_base::mul(self, scalar) + crate::backend::variable_base_mul(self, scalar) } } @@ -795,7 +784,7 @@ impl MultiscalarMul for EdwardsPoint { // size-dependent algorithm dispatch, use this as the hint. 
let _size = s_lo; - scalar_mul::straus::Straus::multiscalar_mul(scalars, points) + crate::backend::straus_multiscalar_mul(scalars, points) } } @@ -827,9 +816,9 @@ impl VartimeMultiscalarMul for EdwardsPoint { let size = s_lo; if size < 190 { - scalar_mul::straus::Straus::optional_multiscalar_mul(scalars, points) + crate::backend::straus_optional_multiscalar_mul(scalars, points) } else { - scalar_mul::pippenger::Pippenger::optional_multiscalar_mul(scalars, points) + crate::backend::pippenger_optional_multiscalar_mul(scalars, points) } } } @@ -839,7 +828,7 @@ impl VartimeMultiscalarMul for EdwardsPoint { // decouple stability of the inner type from the stability of the // outer type. #[cfg(feature = "alloc")] -pub struct VartimeEdwardsPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus); +pub struct VartimeEdwardsPrecomputation(crate::backend::VartimePrecomputedStraus); #[cfg(feature = "alloc")] impl VartimePrecomputedMultiscalarMul for VartimeEdwardsPrecomputation { @@ -850,7 +839,7 @@ impl VartimePrecomputedMultiscalarMul for VartimeEdwardsPrecomputation { I: IntoIterator, I::Item: Borrow, { - Self(scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(static_points)) + Self(crate::backend::VartimePrecomputedStraus::new(static_points)) } fn optional_mixed_multiscalar_mul( @@ -878,7 +867,7 @@ impl EdwardsPoint { A: &EdwardsPoint, b: &Scalar, ) -> EdwardsPoint { - scalar_mul::vartime_double_base::mul(a, A, b) + crate::backend::vartime_double_base_mul(a, A, b) } } diff --git a/src/lib.rs b/src/lib.rs index 83ccdadd4..f4d1d8223 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,15 +11,16 @@ #![no_std] #![cfg_attr( - all( - curve25519_dalek_backend = "simd", - target_feature = "avx512ifma", - nightly - ), + all(target_arch = "x86_64", feature = "simd_avx512", nightly), feature(stdsimd) )] +#![cfg_attr( + all(target_arch = "x86_64", feature = "simd_avx512", nightly), + feature(avx512_target_feature) +)] #![cfg_attr(docsrs, feature(doc_auto_cfg, 
doc_cfg, doc_cfg_hide))] #![cfg_attr(docsrs, doc(cfg_hide(docsrs)))] +#![cfg_attr(allow_unused_unsafe, allow(unused_unsafe))] //------------------------------------------------------------------------ // Documentation: //------------------------------------------------------------------------ diff --git a/src/ristretto.rs b/src/ristretto.rs index 6f3a69c20..03fa343f9 100644 --- a/src/ristretto.rs +++ b/src/ristretto.rs @@ -180,9 +180,6 @@ use digest::Digest; use crate::constants; use crate::field::FieldElement; -#[cfg(feature = "alloc")] -use cfg_if::cfg_if; - use subtle::Choice; use subtle::ConditionallyNegatable; use subtle::ConditionallySelectable; @@ -203,18 +200,6 @@ use crate::traits::Identity; #[cfg(feature = "alloc")] use crate::traits::{MultiscalarMul, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul}; -#[cfg(feature = "alloc")] -cfg_if! { - if #[cfg(all( - curve25519_dalek_backend = "simd", - any(target_feature = "avx2", target_feature = "avx512ifma") - ))] { - use crate::backend::vector::scalar_mul; - } else { - use crate::backend::serial::scalar_mul; - } -} - // ------------------------------------------------------------------------ // Compressed points // ------------------------------------------------------------------------ @@ -996,7 +981,7 @@ impl VartimeMultiscalarMul for RistrettoPoint { // decouple stability of the inner type from the stability of the // outer type. 
#[cfg(feature = "alloc")] -pub struct VartimeRistrettoPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus); +pub struct VartimeRistrettoPrecomputation(crate::backend::VartimePrecomputedStraus); #[cfg(feature = "alloc")] impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation { @@ -1007,11 +992,9 @@ impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation { I: IntoIterator, I::Item: Borrow, { - Self( - scalar_mul::precomputed_straus::VartimePrecomputedStraus::new( - static_points.into_iter().map(|P| P.borrow().0), - ), - ) + Self(crate::backend::VartimePrecomputedStraus::new( + static_points.into_iter().map(|P| P.borrow().0), + )) } fn optional_mixed_multiscalar_mul( diff --git a/src/scalar.rs b/src/scalar.rs index 025e8cbed..6634c88fc 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -184,7 +184,7 @@ cfg_if! { } /// The `Scalar` struct holds an element of \\(\mathbb Z / \ell\mathbb Z \\). -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Copy, Clone, Hash)] pub struct Scalar { /// `bytes` is a little-endian byte encoding of an integer representing a scalar modulo the