From 20fab1501e8418f2bea3eb0446ea37181f935e90 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 3 Sep 2024 14:25:17 -0500 Subject: [PATCH] fix: fixes #522 with bound check before computing twiddles when domain has no precompute set (#523) --- ecc/bls12-377/fr/fft/fft.go | 20 +- ecc/bls12-377/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bls12-378/fr/fft/fft.go | 20 +- ecc/bls12-378/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bls12-381/fr/fft/fft.go | 20 +- ecc/bls12-381/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bls24-315/fr/fft/fft.go | 20 +- ecc/bls24-315/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bls24-317/fr/fft/fft.go | 20 +- ecc/bls24-317/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bn254/fr/fft/fft.go | 20 +- ecc/bn254/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bw6-633/fr/fft/fft.go | 20 +- ecc/bw6-633/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bw6-756/fr/fft/fft.go | 20 +- ecc/bw6-756/fr/fft/fft_test.go | 302 ++++++++--------- ecc/bw6-761/fr/fft/fft.go | 20 +- ecc/bw6-761/fr/fft/fft_test.go | 302 ++++++++--------- internal/generator/fft/template/fft.go.tmpl | 20 +- .../generator/fft/template/tests/fft.go.tmpl | 306 +++++++++--------- 20 files changed, 1643 insertions(+), 1581 deletions(-) diff --git a/ecc/bls12-377/fr/fft/fft.go b/ecc/bls12-377/fr/fft/fft.go index 0e181cfc43..634dc309bc 100644 --- a/ecc/bls12-377/fr/fft/fft.go +++ b/ecc/bls12-377/fr/fft/fft.go @@ -95,10 +95,12 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) if !domain.withPrecompute { twiddlesStartStage = 3 nbStages := int(bits.TrailingZeros64(domain.Cardinality)) - twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) - w := domain.Generator - w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1< 0 { + twiddles = make([][]fr.Element, nbStages - twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage))) + buildTwiddles(twiddles, w, uint64(nbStages - twiddlesStartStage)) + } // else, we don't need twiddles } switch decimation { @@ -118,10 +120,12 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... if !domain.withPrecompute { twiddlesStartStage = 3 nbStages := int(bits.TrailingZeros64(domain.Cardinality)) - twiddlesInv = make([][]fr.Element, nbStages - twiddlesStartStage) - w := domain.GeneratorInv - w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage))) - buildTwiddles(twiddlesInv, w, uint64(nbStages - twiddlesStartStage)) + if nbStages - twiddlesStartStage > 0 { + twiddlesInv = make([][]fr.Element, nbStages - twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage))) + buildTwiddles(twiddlesInv, w, uint64(nbStages - twiddlesStartStage)) + } // else, we don't need twiddles } switch decimation { diff --git a/internal/generator/fft/template/tests/fft.go.tmpl b/internal/generator/fft/template/tests/fft.go.tmpl index 6fb2885570..73b22993b2 100644 --- a/internal/generator/fft/template/tests/fft.go.tmpl +++ b/internal/generator/fft/template/tests/fft.go.tmpl @@ -9,224 +9,228 @@ import ( "github.com/leanovate/gopter/prop" "github.com/leanovate/gopter/gen" + "fmt" + ) func TestFFT(t *testing.T) { - const maxSize = 1 << 10 - - nbCosets := 3 - domainWithPrecompute := NewDomain(maxSize) - domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) - parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 - properties := gopter.NewProperties(parameters) - for domainName, domain := range map[string]*Domain{ - "with precompute": domainWithPrecompute, - "without precompute": domainWithoutPrecompute, - } { - domainName := domainName - domain := domain - t.Logf("domain: %s", domainName) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { - - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) - - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) - - domain.FFT(pol, DIF) - BitReverse(pol) - - sample := domain.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + for maxSize := 2; maxSize <= 1 << 10; maxSize <<= 1 { - eval := evaluatePolynomial(backupPol, sample) + domainWithPrecompute := NewDomain(uint64(maxSize)) + domainWithoutPrecompute := NewDomain(uint64(maxSize), WithoutPrecompute()) - return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + domain.FFT(pol, DIF) + BitReverse(pol) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - domain.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + eval := evaluatePolynomial(backupPol, sample) - sample := domain.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domain.FrMultiplicativeGen) + return eval.Equal(&pol[ithpower]) - eval := evaluatePolynomial(backupPol, sample) + }, + gen.IntRange(0, maxSize-1), + )) - return eval.Equal(&pol[ithpower]) + - }, - gen.IntRange(0, maxSize-1), - )) + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - BitReverse(pol) - domain.FFT(pol, DIT) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - sample := domain.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + eval := evaluatePolynomial(backupPol, sample) - eval := evaluatePolynomial(backupPol, sample) + return eval.Equal(&pol[ithpower]) - return eval.Equal(&pol[ithpower]) + }, + gen.IntRange(0, maxSize-1), + )) - }, - gen.IntRange(0, maxSize-1), - )) + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - func() bool { + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + BitReverse(pol) + domain.FFT(pol, DIT) - BitReverse(pol) - domain.FFT(pol, DIT) - domain.FFTInverse(pol, DIF) - BitReverse(pol) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + eval := evaluatePolynomial(backupPol, sample) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + return eval.Equal(&pol[ithpower]) - func() bool { + }, + gen.IntRange(0, maxSize-1), + )) - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + func() bool { - check := true + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 1; i <= nbCosets; i++ { + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) BitReverse(pol) - domain.FFT(pol, DIT, OnCoset()) - domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) BitReverse(pol) + check := true for i := 0; i < len(pol); i++ { check = check && pol[i].Equal(&backupPol[i]) } - } + return check + }, + )) - return check - }, - )) + for nbCosets := 2; nbCosets < 5; nbCosets++ { + properties.Property(fmt.Sprintf("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on %d cosets", nbCosets), prop.ForAll( - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + func() bool { - func() bool { + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + check := true - domain.FFTInverse(pol, DIF) - domain.FFT(pol, DIT) + for i := 1; i <= nbCosets; i++ { - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + } - func() bool { + return check + }, + )) + } - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + func() bool { - domain.FFTInverse(pol, DIF, OnCoset()) - domain.FFT(pol, DIT, OnCoset()) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() } - } + copy(backupPol, pol) - // compute with nbTasks == 1 - domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) } - } + return check + }, + )) + + + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + + func() bool { - return true - }, - )) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) + + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + return true + }, + )) + } properties.TestingRun(t, gopter.ConsoleReporter(false)) }