Skip to content

Commit

Permalink
optimize encoder interpolation latency (Layr-Labs#309)
Browse files Browse the repository at this point in the history
Co-authored-by: Bowen Xue <[email protected]>
Co-authored-by: Daniel Mancia <[email protected]>
  • Loading branch information
3 people authored Mar 22, 2024
1 parent 1d17604 commit 992e0f2
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 40 deletions.
1 change: 1 addition & 0 deletions disperser/cmd/encoder/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"

"github.com/Layr-Labs/eigenda/common"

"github.com/Layr-Labs/eigenda/disperser/cmd/encoder/flags"
"github.com/urfave/cli"
)
Expand Down
8 changes: 4 additions & 4 deletions encoding/kzg/prover/parametrized_prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ func (p *ParametrizedProver) proofWorker(
points: nil,
err: err,
}
}

for i := 0; i < len(coeffs); i++ {
coeffStore[i][j] = coeffs[i]
} else {
for i := 0; i < len(coeffs); i++ {
coeffStore[i][j] = coeffs[i]
}
}
}

Expand Down
87 changes: 67 additions & 20 deletions encoding/rs/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rs

import (
"errors"
"fmt"
"log"
"time"

Expand Down Expand Up @@ -56,8 +57,8 @@ func (g *Encoder) Encode(inputFr []fr.Element) (*GlobalPoly, []Frame, []uint32,
return nil, nil, nil, err
}

log.Printf(" SUMMARY: Encode %v byte among %v numNode takes %v\n",
len(inputFr)*encoding.BYTES_PER_COEFFICIENT, g.NumChunks, time.Since(start))
log.Printf(" SUMMARY: RSEncode %v byte among %v numChunks with chunkLength %v takes %v\n",
len(inputFr)*encoding.BYTES_PER_COEFFICIENT, g.NumChunks, g.ChunkLength, time.Since(start))

return poly, frames, indices, nil
}
Expand All @@ -72,34 +73,47 @@ func (g *Encoder) MakeFrames(
if err != nil {
return nil, nil, err
}
k := uint64(0)

indices := make([]uint32, 0)
frames := make([]Frame, g.NumChunks)

for i := uint64(0); i < uint64(g.NumChunks); i++ {
numWorker := uint64(g.NumRSWorker)

// finds out which coset leader i-th node is having
j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i))
if numWorker > g.NumChunks {
numWorker = g.NumChunks
}

// mutltiprover return proof in butterfly order
frame := Frame{}
indices = append(indices, j)
jobChan := make(chan JobRequest, numWorker)
results := make(chan error, numWorker)

ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)]
err := rb.ReverseBitOrderFr(ys)
if err != nil {
return nil, nil, err
}
coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j))
if err != nil {
return nil, nil, err
for w := uint64(0); w < numWorker; w++ {
go g.interpolyWorker(
polyEvals,
jobChan,
results,
frames,
)
}

for i := uint64(0); i < g.NumChunks; i++ {
j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i))
jr := JobRequest{
Index: i,
}
jobChan <- jr
indices = append(indices, j)
}
close(jobChan)

frame.Coeffs = coeffs
for w := uint64(0); w < numWorker; w++ {
interPolyErr := <-results
if interPolyErr != nil {
err = interPolyErr
}
}

frames[k] = frame
k++
if err != nil {
return nil, nil, fmt.Errorf("proof worker error: %v", err)
}

return frames, indices, nil
Expand Down Expand Up @@ -127,3 +141,36 @@ func (g *Encoder) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, []fr.Elemen

return evals, pdCoeffs, nil
}

type JobRequest struct {
Index uint64
}

func (g *Encoder) interpolyWorker(
polyEvals []fr.Element,
jobChan <-chan JobRequest,
results chan<- error,
frames []Frame,
) {

for jr := range jobChan {
i := jr.Index
j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i))
ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)]
err := rb.ReverseBitOrderFr(ys)
if err != nil {
results <- err
continue
}
coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j))
if err != nil {
results <- err
continue
}

frames[i].Coeffs = coeffs
}

results <- nil

}
4 changes: 4 additions & 0 deletions encoding/rs/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rs

import (
"math"
"runtime"

"github.com/Layr-Labs/eigenda/encoding"
"github.com/Layr-Labs/eigenda/encoding/fft"
Expand All @@ -13,6 +14,8 @@ type Encoder struct {
Fs *fft.FFTSettings

verbose bool

NumRSWorker int
}

// The function creates a high level struct that determines the encoding the a data of a
Expand All @@ -37,6 +40,7 @@ func NewEncoder(params encoding.EncodingParams, verbose bool) (*Encoder, error)
EncodingParams: params,
Fs: fs,
verbose: verbose,
NumRSWorker: runtime.GOMAXPROCS(0),
}, nil

}
23 changes: 7 additions & 16 deletions encoding/rs/interpolation.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ func (g *Encoder) GetInterpolationPolyEval(
//var tmp, tmp2 fr.Element
for i := 0; i < len(interpolationPoly); i++ {
shiftedInterpolationPoly[i].Mul(&interpolationPoly[i], &wPow)

wPow.Mul(&wPow, &w)

}

err := g.Fs.InplaceFFT(shiftedInterpolationPoly, evals, false)
Expand All @@ -66,26 +64,19 @@ func (g *Encoder) GetInterpolationPolyEval(
// Since both F W are invertible, c = W^-1 F^-1 d, convert it back. F W W^-1 F^-1 d = c
func (g *Encoder) GetInterpolationPolyCoeff(chunk []fr.Element, k uint32) ([]fr.Element, error) {
coeffs := make([]fr.Element, g.ChunkLength)
w := g.Fs.ExpandedRootsOfUnity[uint64(k)]
shiftedInterpolationPoly := make([]fr.Element, len(chunk))
err := g.Fs.InplaceFFT(chunk, shiftedInterpolationPoly, true)
if err != nil {
return coeffs, err
}
var wPow fr.Element
wPow.SetOne()

var tmp, tmp2 fr.Element

mod := int32(len(g.Fs.ExpandedRootsOfUnity) - 1)

for i := 0; i < len(chunk); i++ {
tmp.Inverse(&wPow)

tmp2.Mul(&shiftedInterpolationPoly[i], &tmp)

coeffs[i].Set(&tmp2)

tmp.Mul(&wPow, &w)

wPow.Set(&tmp)
// We can lookup the inverse power by counting RootOfUnity backward
j := (-int32(k)*int32(i))%mod + mod
coeffs[i].Mul(&shiftedInterpolationPoly[i], &g.Fs.ExpandedRootsOfUnity[j])
}

return coeffs, nil
}

0 comments on commit 992e0f2

Please sign in to comment.