-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement an iterative FFT algorithm
This change adds an alternative algorithm for computing the discrete Fourier transform. It also adds a new module, benchmarked, for components of the crate that we want to benchmark, but don't want to expose in the public API. Finally, it adds a benchmark for comparing the speed of iterative FFT and recursive FFT on various input lengths.
- Loading branch information
Showing
9 changed files
with
452 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
use criterion::{criterion_group, criterion_main, Criterion}; | ||
|
||
use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft}; | ||
use prio::finite_field::{Field, FieldElement}; | ||
|
||
pub fn fft(c: &mut Criterion) { | ||
let test_sizes = [16, 256, 1024, 4096]; | ||
for size in test_sizes.iter() { | ||
let mut rng = rand::thread_rng(); | ||
let mut inp = vec![Field::zero(); *size]; | ||
let mut outp = vec![Field::zero(); *size]; | ||
for i in 0..*size { | ||
inp[i] = Field::rand(&mut rng); | ||
} | ||
|
||
c.bench_function(&format!("iterative/{}", *size), |b| { | ||
b.iter(|| { | ||
benchmarked_iterative_fft(&mut outp, &inp); | ||
}) | ||
}); | ||
|
||
c.bench_function(&format!("recursive/{}", *size), |b| { | ||
b.iter(|| { | ||
benchmarked_recursive_fft(&mut outp, &inp); | ||
}) | ||
}); | ||
} | ||
} | ||
|
||
criterion_group!(benches, fft); | ||
criterion_main!(benches); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
//! This package provides wrappers around internal components of this crate that we want to | ||
//! benchmark, but which we don't want to expose in the public API. | ||
use crate::fft::discrete_fourier_transform; | ||
use crate::finite_field::{Field, FieldElement}; | ||
use crate::polynomial::{poly_fft, PolyAuxMemory}; | ||
|
||
/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. | ||
pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) { | ||
discrete_fourier_transform(outp, inp).expect("encountered unexpected error"); | ||
} | ||
|
||
/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. | ||
pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) { | ||
let mut mem = PolyAuxMemory::new(inp.len() / 2); | ||
poly_fft( | ||
outp, | ||
inp, | ||
&mem.roots_2n, | ||
inp.len(), | ||
false, | ||
&mut mem.fft_memory, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier | ||
//! Transform (DFT) over a slice of field elements. | ||
use crate::finite_field::FieldElement; | ||
use crate::fp::{log2, MAX_ROOTS}; | ||
|
||
use std::convert::TryFrom; | ||
|
||
/// An error returned by DFT or DFT inverse computation. | ||
#[derive(Debug, thiserror::Error)] | ||
pub enum FftError { | ||
/// The output is too small. | ||
#[error("output slice is smaller than the input")] | ||
OutputTooSmall, | ||
/// The input is too large. | ||
#[error("input slice is larger than than maximum permitted")] | ||
InputTooLarge, | ||
/// The input length is not a power of 2. | ||
#[error("input size is not a power of 2")] | ||
InputSizeInvalid, | ||
} | ||
|
||
/// Sets `outp` to the DFT of `inp`. | ||
pub fn discrete_fourier_transform<F: FieldElement>( | ||
outp: &mut [F], | ||
inp: &[F], | ||
) -> Result<(), FftError> { | ||
let n = inp.len(); | ||
let d = usize::try_from(log2(n as u128)).unwrap(); | ||
|
||
if n > outp.len() { | ||
return Err(FftError::OutputTooSmall); | ||
} | ||
|
||
if n > 1 << MAX_ROOTS { | ||
return Err(FftError::InputTooLarge); | ||
} | ||
|
||
if n != 1 << d { | ||
return Err(FftError::InputSizeInvalid); | ||
} | ||
|
||
for i in 0..n { | ||
outp[i] = inp[bitrev(d, i)]; | ||
} | ||
|
||
let mut w: F; | ||
for l in 1..d + 1 { | ||
w = F::root(0).unwrap(); // one | ||
let r = F::root(l).unwrap(); | ||
let y = 1 << (l - 1); | ||
for i in 0..y { | ||
for j in 0..(n / y) >> 1 { | ||
let x = (1 << l) * j + i; | ||
let u = outp[x]; | ||
let v = w * outp[x + y]; | ||
outp[x] = u + v; | ||
outp[x + y] = u - v; | ||
} | ||
w *= r; | ||
} | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Sets `outp` to the inverse of the DFT of `inp`. | ||
#[allow(dead_code)] | ||
pub fn discrete_fourier_transform_inv<F: FieldElement>( | ||
outp: &mut [F], | ||
inp: &[F], | ||
) -> Result<(), FftError> { | ||
discrete_fourier_transform(outp, inp)?; | ||
let n = inp.len(); | ||
let m = F::from(F::Integer::try_from(n).unwrap()).inv(); | ||
let mut tmp: F; | ||
|
||
outp[0] *= m; | ||
outp[n >> 1] *= m; | ||
for i in 1..n >> 1 { | ||
tmp = outp[i] * m; | ||
outp[i] = outp[n - i] * m; | ||
outp[n - i] = tmp; | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109) | ||
fn bitrev(d: usize, x: usize) -> usize { | ||
let mut y = 0; | ||
for i in 0..d { | ||
y += ((x >> i) & 1) << (d - i); | ||
} | ||
y >> 1 | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::finite_field::{Field, Field126, Field64, Field80}; | ||
use crate::polynomial::{poly_fft, PolyAuxMemory}; | ||
|
||
fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> { | ||
let mut rng = rand::thread_rng(); | ||
let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; | ||
|
||
for size in test_sizes.iter() { | ||
let mut want = vec![F::zero(); *size]; | ||
let mut tmp = vec![F::zero(); *size]; | ||
let mut got = vec![F::zero(); *size]; | ||
for i in 0..*size { | ||
want[i] = F::rand(&mut rng); | ||
} | ||
|
||
discrete_fourier_transform(&mut tmp, &want)?; | ||
discrete_fourier_transform_inv(&mut got, &tmp)?; | ||
assert_eq!(got, want); | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_field32() { | ||
discrete_fourier_transform_then_inv_test::<Field>().expect("unexpected error"); | ||
} | ||
|
||
#[test] | ||
fn test_field64() { | ||
discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error"); | ||
} | ||
|
||
#[test] | ||
fn test_field80() { | ||
discrete_fourier_transform_then_inv_test::<Field80>().expect("unexpected error"); | ||
} | ||
|
||
#[test] | ||
fn test_field126() { | ||
discrete_fourier_transform_then_inv_test::<Field126>().expect("unexpected error"); | ||
} | ||
|
||
#[test] | ||
fn test_recursive_fft() { | ||
let size = 128; | ||
let mut rng = rand::thread_rng(); | ||
let mut mem = PolyAuxMemory::new(size / 2); | ||
|
||
let mut inp = vec![Field::zero(); size]; | ||
let mut want = vec![Field::zero(); size]; | ||
let mut got = vec![Field::zero(); size]; | ||
for i in 0..size { | ||
inp[i] = Field::rand(&mut rng); | ||
} | ||
|
||
discrete_fourier_transform::<Field>(&mut want, &inp).expect("unexpected error"); | ||
|
||
poly_fft( | ||
&mut got, | ||
&inp, | ||
&mem.roots_2n, | ||
size, | ||
false, | ||
&mut mem.fft_memory, | ||
); | ||
|
||
assert_eq!(got, want); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.