Skip to content

Commit

Permalink
Merge pull request #21 from cjpatton/iter-fft
Browse files Browse the repository at this point in the history
Implement an iterative FFT algorithm
  • Loading branch information
tgeoghegan authored Apr 5, 2021
2 parents 55c2ebe + af2026b commit d4f97df
Show file tree
Hide file tree
Showing 9 changed files with 452 additions and 83 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ thiserror = "1.0"

[dev-dependencies]
assert_matches = "1.5.0"
criterion = "0.3"
modinverse = "0.1.0"
num-bigint = "0.4.0"

[[bench]]
name = "fft"
harness = false

[[example]]
name = "sum"
33 changes: 33 additions & 0 deletions benches/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MPL-2.0

use criterion::{criterion_group, criterion_main, Criterion};

use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft};
use prio::finite_field::{Field, FieldElement};

pub fn fft(c: &mut Criterion) {
let test_sizes = [16, 256, 1024, 4096];
for size in test_sizes.iter() {
let mut rng = rand::thread_rng();
let mut inp = vec![Field::zero(); *size];
let mut outp = vec![Field::zero(); *size];
for i in 0..*size {
inp[i] = Field::rand(&mut rng);
}

c.bench_function(&format!("iterative/{}", *size), |b| {
b.iter(|| {
benchmarked_iterative_fft(&mut outp, &inp);
})
});

c.bench_function(&format!("recursive/{}", *size), |b| {
b.iter(|| {
benchmarked_recursive_fft(&mut outp, &inp);
})
});
}
}

criterion_group!(benches, fft);
criterion_main!(benches);
26 changes: 26 additions & 0 deletions src/benchmarked.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MPL-2.0

//! This package provides wrappers around internal components of this crate that we want to
//! benchmark, but which we don't want to expose in the public API.
use crate::fft::discrete_fourier_transform;
use crate::finite_field::{Field, FieldElement};
use crate::polynomial::{poly_fft, PolyAuxMemory};

/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
discrete_fourier_transform(outp, inp).expect("encountered unexpected error");
}

/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) {
let mut mem = PolyAuxMemory::new(inp.len() / 2);
poly_fft(
outp,
inp,
&mem.roots_2n,
inp.len(),
false,
&mut mem.fft_memory,
)
}
2 changes: 1 addition & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl Client {
pub fn new(dimension: usize, public_key1: PublicKey, public_key2: PublicKey) -> Option<Self> {
let n = (dimension + 1).next_power_of_two();

if 2 * n > Field::num_roots() as usize {
if 2 * n > Field::generator_order() as usize {
// too many elements for this field, not enough roots of unity
return None;
}
Expand Down
172 changes: 172 additions & 0 deletions src/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// SPDX-License-Identifier: MPL-2.0

//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
//! Transform (DFT) over a slice of field elements.
use crate::finite_field::FieldElement;
use crate::fp::{log2, MAX_ROOTS};

use std::convert::TryFrom;

/// An error returned by DFT or DFT inverse computation.
#[derive(Debug, thiserror::Error)]
pub enum FftError {
/// The output is too small.
#[error("output slice is smaller than the input")]
OutputTooSmall,
/// The input is too large.
#[error("input slice is larger than than maximum permitted")]
InputTooLarge,
/// The input length is not a power of 2.
#[error("input size is not a power of 2")]
InputSizeInvalid,
}

/// Sets `outp` to the DFT of `inp`.
pub fn discrete_fourier_transform<F: FieldElement>(
outp: &mut [F],
inp: &[F],
) -> Result<(), FftError> {
let n = inp.len();
let d = usize::try_from(log2(n as u128)).unwrap();

if n > outp.len() {
return Err(FftError::OutputTooSmall);
}

if n > 1 << MAX_ROOTS {
return Err(FftError::InputTooLarge);
}

if n != 1 << d {
return Err(FftError::InputSizeInvalid);
}

for i in 0..n {
outp[i] = inp[bitrev(d, i)];
}

let mut w: F;
for l in 1..d + 1 {
w = F::root(0).unwrap(); // one
let r = F::root(l).unwrap();
let y = 1 << (l - 1);
for i in 0..y {
for j in 0..(n / y) >> 1 {
let x = (1 << l) * j + i;
let u = outp[x];
let v = w * outp[x + y];
outp[x] = u + v;
outp[x + y] = u - v;
}
w *= r;
}
}

Ok(())
}

/// Sets `outp` to the inverse of the DFT of `inp`.
#[allow(dead_code)]
pub fn discrete_fourier_transform_inv<F: FieldElement>(
outp: &mut [F],
inp: &[F],
) -> Result<(), FftError> {
discrete_fourier_transform(outp, inp)?;
let n = inp.len();
let m = F::from(F::Integer::try_from(n).unwrap()).inv();
let mut tmp: F;

outp[0] *= m;
outp[n >> 1] *= m;
for i in 1..n >> 1 {
tmp = outp[i] * m;
outp[i] = outp[n - i] * m;
outp[n - i] = tmp;
}

Ok(())
}

// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109)
fn bitrev(d: usize, x: usize) -> usize {
let mut y = 0;
for i in 0..d {
y += ((x >> i) & 1) << (d - i);
}
y >> 1
}

#[cfg(test)]
mod tests {
use super::*;
use crate::finite_field::{Field, Field126, Field64, Field80};
use crate::polynomial::{poly_fft, PolyAuxMemory};

fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
let mut rng = rand::thread_rng();
let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];

for size in test_sizes.iter() {
let mut want = vec![F::zero(); *size];
let mut tmp = vec![F::zero(); *size];
let mut got = vec![F::zero(); *size];
for i in 0..*size {
want[i] = F::rand(&mut rng);
}

discrete_fourier_transform(&mut tmp, &want)?;
discrete_fourier_transform_inv(&mut got, &tmp)?;
assert_eq!(got, want);
}

Ok(())
}

#[test]
fn test_field32() {
discrete_fourier_transform_then_inv_test::<Field>().expect("unexpected error");
}

#[test]
fn test_field64() {
discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
}

#[test]
fn test_field80() {
discrete_fourier_transform_then_inv_test::<Field80>().expect("unexpected error");
}

#[test]
fn test_field126() {
discrete_fourier_transform_then_inv_test::<Field126>().expect("unexpected error");
}

#[test]
fn test_recursive_fft() {
let size = 128;
let mut rng = rand::thread_rng();
let mut mem = PolyAuxMemory::new(size / 2);

let mut inp = vec![Field::zero(); size];
let mut want = vec![Field::zero(); size];
let mut got = vec![Field::zero(); size];
for i in 0..size {
inp[i] = Field::rand(&mut rng);
}

discrete_fourier_transform::<Field>(&mut want, &inp).expect("unexpected error");

poly_fft(
&mut got,
&inp,
&mem.roots_2n,
size,
false,
&mut mem.fft_memory,
);

assert_eq!(got, want);
}
}
67 changes: 53 additions & 14 deletions src/finite_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
use crate::fp::{FP126, FP32, FP64, FP80};
use std::{
cmp::min,
convert::TryFrom,
fmt::{Display, Formatter},
fmt::{Debug, Display, Formatter},
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};

use rand::Rng;

/// Possible errors from finite field operations.
#[derive(Debug, thiserror::Error)]
pub enum FiniteFieldError {
Expand All @@ -21,36 +24,55 @@ pub enum FiniteFieldError {
/// Objects with this trait represent an element of `GF(p)` for some prime `p`.
pub trait FieldElement:
Sized
+ Debug
+ Copy
+ PartialEq
+ Eq
+ Add
+ Add<Output = Self>
+ AddAssign
+ Sub
+ Sub<Output = Self>
+ SubAssign
+ Mul
+ Mul<Output = Self>
+ MulAssign
+ Div
+ Div<Output = Self>
+ DivAssign
+ Neg
+ Neg<Output = Self>
+ Display
+ From<<Self as FieldElement>::Integer>
{
/// The error returned if converting `usize` to an `Int` fails.
type IntegerTryFromError: std::fmt::Debug;

/// The integer representation of the field element.
type Integer;
type Integer: Copy
+ Debug
+ Sub<Output = <Self as FieldElement>::Integer>
+ TryFrom<usize, Error = Self::IntegerTryFromError>;

/// Modular exponentation, i.e., `self^exp (mod p)`.
fn pow(&self, exp: Self) -> Self;
fn pow(&self, exp: Self) -> Self; // TODO(cjpatton) exp should have type Self::Integer

/// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
fn inv(&self) -> Self;

/// Returns the prime modulus `p`.
fn modulus() -> Self::Integer;

/// Returns a generator of the multiplicative subgroup of size `FieldElement::num_roots()`.
/// Returns the size of the multiplicative subgroup generated by `generator()`.
fn generator_order() -> Self::Integer;

/// Returns the generator of the multiplicative subgroup of size `generator_order()`.
fn generator() -> Self;

/// Returns the size of the multiplicative subgroup generated by `FieldElement::generator()`.
fn num_roots() -> Self::Integer;
/// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
/// prinicpal root of unity is 1 by definition.
fn root(l: usize) -> Option<Self>;

/// Returns a random field element distributed uniformly over all field elements.
fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self;

/// Returns the additive identity.
fn zero() -> Self;
}

macro_rules! make_field {
Expand Down Expand Up @@ -190,6 +212,7 @@ macro_rules! make_field {

impl FieldElement for $elem {
type Integer = $int;
type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;

fn pow(&self, exp: Self) -> Self {
Self($fp.pow(self.0, $fp.from_elem(exp.0)))
Expand All @@ -207,8 +230,24 @@ macro_rules! make_field {
Self($fp.g)
}

fn num_roots() -> Self::Integer {
$fp.num_roots as Self::Integer
fn generator_order() -> Self::Integer {
1 << (Self::Integer::try_from($fp.num_roots).unwrap())
}

fn root(l: usize) -> Option<Self> {
if l < min($fp.roots.len(), $fp.num_roots+1) {
Some(Self($fp.roots[l]))
} else {
None
}
}

fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self {
Self($fp.rand_elem(rng))
}

fn zero() -> Self {
Self(0)
}
}
};
Expand Down Expand Up @@ -245,7 +284,7 @@ make_field!(

#[test]
fn test_arithmetic() {
// TODO(cjpatton) Add tests for Field64, Field80, and Field126.
// TODO(cjpatton) Add tests for the other fields.
use rand::prelude::*;

let modulus = Field::modulus();
Expand Down
Loading

0 comments on commit d4f97df

Please sign in to comment.