Skip to content

Commit

Permalink
Implement an iterative FFT algorithm
Browse files Browse the repository at this point in the history
This change adds an alternative algorithm for computing the discrete
Fourier transform. It also adds a new module, benchmarked, for
components of the crate that we want to benchmark, but don't want to
expose in the public API. Finally, it adds a benchmark for comparing the
speed of iterative FFT and recursive FFT on various input lengths.
  • Loading branch information
cjpatton committed Apr 5, 2021
1 parent 17cde85 commit af2026b
Show file tree
Hide file tree
Showing 9 changed files with 452 additions and 83 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ thiserror = "1.0"

[dev-dependencies]
assert_matches = "1.5.0"
criterion = "0.3"
modinverse = "0.1.0"
num-bigint = "0.4.0"

[[bench]]
name = "fft"
harness = false

[[example]]
name = "sum"
33 changes: 33 additions & 0 deletions benches/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MPL-2.0

use criterion::{criterion_group, criterion_main, Criterion};

use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft};
use prio::finite_field::{Field, FieldElement};

pub fn fft(c: &mut Criterion) {
let test_sizes = [16, 256, 1024, 4096];
for size in test_sizes.iter() {
let mut rng = rand::thread_rng();
let mut inp = vec![Field::zero(); *size];
let mut outp = vec![Field::zero(); *size];
for i in 0..*size {
inp[i] = Field::rand(&mut rng);
}

c.bench_function(&format!("iterative/{}", *size), |b| {
b.iter(|| {
benchmarked_iterative_fft(&mut outp, &inp);
})
});

c.bench_function(&format!("recursive/{}", *size), |b| {
b.iter(|| {
benchmarked_recursive_fft(&mut outp, &inp);
})
});
}
}

criterion_group!(benches, fft);
criterion_main!(benches);
26 changes: 26 additions & 0 deletions src/benchmarked.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MPL-2.0

//! This package provides wrappers around internal components of this crate that we want to
//! benchmark, but which we don't want to expose in the public API.
use crate::fft::discrete_fourier_transform;
use crate::finite_field::{Field, FieldElement};
use crate::polynomial::{poly_fft, PolyAuxMemory};

/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
discrete_fourier_transform(outp, inp).expect("encountered unexpected error");
}

/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) {
let mut mem = PolyAuxMemory::new(inp.len() / 2);
poly_fft(
outp,
inp,
&mem.roots_2n,
inp.len(),
false,
&mut mem.fft_memory,
)
}
2 changes: 1 addition & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl Client {
pub fn new(dimension: usize, public_key1: PublicKey, public_key2: PublicKey) -> Option<Self> {
let n = (dimension + 1).next_power_of_two();

if 2 * n > Field::num_roots() as usize {
if 2 * n > Field::generator_order() as usize {
// too many elements for this field, not enough roots of unity
return None;
}
Expand Down
172 changes: 172 additions & 0 deletions src/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// SPDX-License-Identifier: MPL-2.0

//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
//! Transform (DFT) over a slice of field elements.
use crate::finite_field::FieldElement;
use crate::fp::{log2, MAX_ROOTS};

use std::convert::TryFrom;

/// An error returned by DFT or DFT inverse computation.
#[derive(Debug, thiserror::Error)]
pub enum FftError {
/// The output is too small.
#[error("output slice is smaller than the input")]
OutputTooSmall,
/// The input is too large.
#[error("input slice is larger than than maximum permitted")]
InputTooLarge,
/// The input length is not a power of 2.
#[error("input size is not a power of 2")]
InputSizeInvalid,
}

/// Sets `outp` to the DFT of `inp`.
pub fn discrete_fourier_transform<F: FieldElement>(
outp: &mut [F],
inp: &[F],
) -> Result<(), FftError> {
let n = inp.len();
let d = usize::try_from(log2(n as u128)).unwrap();

if n > outp.len() {
return Err(FftError::OutputTooSmall);
}

if n > 1 << MAX_ROOTS {
return Err(FftError::InputTooLarge);
}

if n != 1 << d {
return Err(FftError::InputSizeInvalid);
}

for i in 0..n {
outp[i] = inp[bitrev(d, i)];
}

let mut w: F;
for l in 1..d + 1 {
w = F::root(0).unwrap(); // one
let r = F::root(l).unwrap();
let y = 1 << (l - 1);
for i in 0..y {
for j in 0..(n / y) >> 1 {
let x = (1 << l) * j + i;
let u = outp[x];
let v = w * outp[x + y];
outp[x] = u + v;
outp[x + y] = u - v;
}
w *= r;
}
}

Ok(())
}

/// Sets `outp` to the inverse of the DFT of `inp`.
#[allow(dead_code)]
pub fn discrete_fourier_transform_inv<F: FieldElement>(
outp: &mut [F],
inp: &[F],
) -> Result<(), FftError> {
discrete_fourier_transform(outp, inp)?;
let n = inp.len();
let m = F::from(F::Integer::try_from(n).unwrap()).inv();
let mut tmp: F;

outp[0] *= m;
outp[n >> 1] *= m;
for i in 1..n >> 1 {
tmp = outp[i] * m;
outp[i] = outp[n - i] * m;
outp[n - i] = tmp;
}

Ok(())
}

// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109)
fn bitrev(d: usize, x: usize) -> usize {
let mut y = 0;
for i in 0..d {
y += ((x >> i) & 1) << (d - i);
}
y >> 1
}

#[cfg(test)]
mod tests {
use super::*;
use crate::finite_field::{Field, Field126, Field64, Field80};
use crate::polynomial::{poly_fft, PolyAuxMemory};

fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
let mut rng = rand::thread_rng();
let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];

for size in test_sizes.iter() {
let mut want = vec![F::zero(); *size];
let mut tmp = vec![F::zero(); *size];
let mut got = vec![F::zero(); *size];
for i in 0..*size {
want[i] = F::rand(&mut rng);
}

discrete_fourier_transform(&mut tmp, &want)?;
discrete_fourier_transform_inv(&mut got, &tmp)?;
assert_eq!(got, want);
}

Ok(())
}

#[test]
fn test_field32() {
discrete_fourier_transform_then_inv_test::<Field>().expect("unexpected error");
}

#[test]
fn test_field64() {
discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
}

#[test]
fn test_field80() {
discrete_fourier_transform_then_inv_test::<Field80>().expect("unexpected error");
}

#[test]
fn test_field126() {
discrete_fourier_transform_then_inv_test::<Field126>().expect("unexpected error");
}

#[test]
fn test_recursive_fft() {
let size = 128;
let mut rng = rand::thread_rng();
let mut mem = PolyAuxMemory::new(size / 2);

let mut inp = vec![Field::zero(); size];
let mut want = vec![Field::zero(); size];
let mut got = vec![Field::zero(); size];
for i in 0..size {
inp[i] = Field::rand(&mut rng);
}

discrete_fourier_transform::<Field>(&mut want, &inp).expect("unexpected error");

poly_fft(
&mut got,
&inp,
&mem.roots_2n,
size,
false,
&mut mem.fft_memory,
);

assert_eq!(got, want);
}
}
67 changes: 53 additions & 14 deletions src/finite_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
use crate::fp::{FP126, FP32, FP64, FP80};
use std::{
cmp::min,
convert::TryFrom,
fmt::{Display, Formatter},
fmt::{Debug, Display, Formatter},
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};

use rand::Rng;

/// Possible errors from finite field operations.
#[derive(Debug, thiserror::Error)]
pub enum FiniteFieldError {
Expand All @@ -21,36 +24,55 @@ pub enum FiniteFieldError {
/// Objects with this trait represent an element of `GF(p)` for some prime `p`.
pub trait FieldElement:
Sized
+ Debug
+ Copy
+ PartialEq
+ Eq
+ Add
+ Add<Output = Self>
+ AddAssign
+ Sub
+ Sub<Output = Self>
+ SubAssign
+ Mul
+ Mul<Output = Self>
+ MulAssign
+ Div
+ Div<Output = Self>
+ DivAssign
+ Neg
+ Neg<Output = Self>
+ Display
+ From<<Self as FieldElement>::Integer>
{
/// The error returned if converting `usize` to an `Int` fails.
type IntegerTryFromError: std::fmt::Debug;

/// The integer representation of the field element.
type Integer;
type Integer: Copy
+ Debug
+ Sub<Output = <Self as FieldElement>::Integer>
+ TryFrom<usize, Error = Self::IntegerTryFromError>;

/// Modular exponentation, i.e., `self^exp (mod p)`.
fn pow(&self, exp: Self) -> Self;
fn pow(&self, exp: Self) -> Self; // TODO(cjpatton) exp should have type Self::Integer

/// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
fn inv(&self) -> Self;

/// Returns the prime modulus `p`.
fn modulus() -> Self::Integer;

/// Returns a generator of the multiplicative subgroup of size `FieldElement::num_roots()`.
/// Returns the size of the multiplicative subgroup generated by `generator()`.
fn generator_order() -> Self::Integer;

/// Returns the generator of the multiplicative subgroup of size `generator_order()`.
fn generator() -> Self;

/// Returns the size of the multiplicative subgroup generated by `FieldElement::generator()`.
fn num_roots() -> Self::Integer;
/// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
/// prinicpal root of unity is 1 by definition.
fn root(l: usize) -> Option<Self>;

/// Returns a random field element distributed uniformly over all field elements.
fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self;

/// Returns the additive identity.
fn zero() -> Self;
}

macro_rules! make_field {
Expand Down Expand Up @@ -190,6 +212,7 @@ macro_rules! make_field {

impl FieldElement for $elem {
type Integer = $int;
type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;

fn pow(&self, exp: Self) -> Self {
Self($fp.pow(self.0, $fp.from_elem(exp.0)))
Expand All @@ -207,8 +230,24 @@ macro_rules! make_field {
Self($fp.g)
}

fn num_roots() -> Self::Integer {
$fp.num_roots as Self::Integer
fn generator_order() -> Self::Integer {
1 << (Self::Integer::try_from($fp.num_roots).unwrap())
}

fn root(l: usize) -> Option<Self> {
if l < min($fp.roots.len(), $fp.num_roots+1) {
Some(Self($fp.roots[l]))
} else {
None
}
}

fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self {
Self($fp.rand_elem(rng))
}

fn zero() -> Self {
Self(0)
}
}
};
Expand Down Expand Up @@ -245,7 +284,7 @@ make_field!(

#[test]
fn test_arithmetic() {
// TODO(cjpatton) Add tests for Field64, Field80, and Field126.
// TODO(cjpatton) Add tests for the other fields.
use rand::prelude::*;

let modulus = Field::modulus();
Expand Down
Loading

0 comments on commit af2026b

Please sign in to comment.