Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gpu encoder no blob batching #642

Open
wants to merge 6 commits into
base: gpu-encode
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions encoding/kzg/prover/cpu/multiframe_proof.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package cpu

import (
"fmt"
"math"
"time"

"github.com/Layr-Labs/eigenda/encoding/fft"
"github.com/Layr-Labs/eigenda/encoding/kzg"
"github.com/Layr-Labs/eigenda/encoding/utils/toeplitz"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bn254"
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
)

type WorkerResult struct {
points []bn254.G1Affine
err error
}

type CpuComputer struct {
*kzg.KzgConfig
Fs *fft.FFTSettings
FFTPointsT [][]bn254.G1Affine // transpose of FFTPoints
SFs *fft.FFTSettings
Srs *kzg.SRS
G2Trailing []bn254.G2Affine
}

func (p *CpuComputer) ComputeLengthProof(coeffs []fr.Element) (*bn254.G2Affine, error) {
inputLength := uint64(len(coeffs))
shiftedSecret := p.G2Trailing[p.KzgConfig.SRSNumberToLoad-inputLength:]
config := ecc.MultiExpConfig{}
//The proof of low degree is commitment of the polynomial shifted to the largest srs degree
var lengthProof bn254.G2Affine
_, err := lengthProof.MultiExp(shiftedSecret, coeffs, config)
if err != nil {
return nil, err
}
return &lengthProof, nil
}

func (p *CpuComputer) ComputeCommitment(coeffs []fr.Element) (*bn254.G1Affine, error) {
// compute commit for the full poly
config := ecc.MultiExpConfig{}
var commitment bn254.G1Affine
_, err := commitment.MultiExp(p.Srs.G1[:len(coeffs)], coeffs, config)
if err != nil {
return nil, err
}
return &commitment, nil
}

func (p *CpuComputer) ComputeLengthCommitment(coeffs []fr.Element) (*bn254.G2Affine, error) {
config := ecc.MultiExpConfig{}

var lengthCommitment bn254.G2Affine
_, err := lengthCommitment.MultiExp(p.Srs.G2[:len(coeffs)], coeffs, config)
if err != nil {
return nil, err
}
return &lengthCommitment, nil
}

func (p *CpuComputer) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) {
begin := time.Now()
// Robert: Standardizing this to use the same math used in precomputeSRS
dimE := numChunks
l := chunkLen

sumVec := make([]bn254.G1Affine, dimE*2)

jobChan := make(chan uint64, numWorker)
results := make(chan WorkerResult, numWorker)

// create storage for intermediate fft outputs
coeffStore := make([][]fr.Element, dimE*2)
for i := range coeffStore {
coeffStore[i] = make([]fr.Element, l)
}

for w := uint64(0); w < numWorker; w++ {
go p.proofWorker(polyFr, jobChan, l, dimE, coeffStore, results)
}

for j := uint64(0); j < l; j++ {
jobChan <- j
}
close(jobChan)

// return last error
var err error
for w := uint64(0); w < numWorker; w++ {
wr := <-results
if wr.err != nil {
err = wr.err
}
}

if err != nil {
return nil, fmt.Errorf("proof worker error: %v", err)
}

t0 := time.Now()

// compute proof by multi scaler multiplication
msmErrors := make(chan error, dimE*2)
for i := uint64(0); i < dimE*2; i++ {

go func(k uint64) {
_, err := sumVec[k].MultiExp(p.FFTPointsT[k], coeffStore[k], ecc.MultiExpConfig{})
// handle error
msmErrors <- err
}(i)
}

for i := uint64(0); i < dimE*2; i++ {
err := <-msmErrors
if err != nil {
fmt.Println("Error. MSM while adding points", err)
return nil, err
}
}

t1 := time.Now()

// only 1 ifft is needed
sumVecInv, err := p.Fs.FFTG1(sumVec, true)
if err != nil {
return nil, fmt.Errorf("fft error: %v", err)
}

t2 := time.Now()

// outputs is out of order - buttefly
proofs, err := p.Fs.FFTG1(sumVecInv[:dimE], false)
if err != nil {
return nil, err
}

t3 := time.Now()

fmt.Printf("mult-th %v, msm %v,fft1 %v, fft2 %v,\n", t0.Sub(begin), t1.Sub(t0), t2.Sub(t1), t3.Sub(t2))

return proofs, nil
}

func (p *CpuComputer) proofWorker(
polyFr []fr.Element,
jobChan <-chan uint64,
l uint64,
dimE uint64,
coeffStore [][]fr.Element,
results chan<- WorkerResult,
) {

for j := range jobChan {
coeffs, err := p.GetSlicesCoeff(polyFr, dimE, j, l)
if err != nil {
results <- WorkerResult{
points: nil,
err: err,
}
} else {
for i := 0; i < len(coeffs); i++ {
coeffStore[i][j] = coeffs[i]
}
}
}

results <- WorkerResult{
err: nil,
}
}

// output is in the form see primeField toeplitz
//
// phi ^ (coset size ) = 1
//
// implicitly pad slices to power of 2
func (p *CpuComputer) GetSlicesCoeff(polyFr []fr.Element, dimE, j, l uint64) ([]fr.Element, error) {
// there is a constant term
m := uint64(len(polyFr)) - 1
dim := (m - j) / l

toeV := make([]fr.Element, 2*dimE-1)
for i := uint64(0); i < dim; i++ {

toeV[i].Set(&polyFr[m-(j+i*l)])
}

// use precompute table
tm, err := toeplitz.NewToeplitz(toeV, p.SFs)
if err != nil {
return nil, err
}
return tm.GetFFTCoeff()
}

/*
returns the power of 2 which is immediately bigger than the input
*/
func CeilIntPowerOf2Num(d uint64) uint64 {
nextPower := math.Ceil(math.Log2(float64(d)))
return uint64(math.Pow(2.0, nextPower))
}
42 changes: 42 additions & 0 deletions encoding/kzg/prover/gpu/ecntt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package gpu

import (
"fmt"

"github.com/Layr-Labs/eigenda/encoding/utils/gpu_utils"
"github.com/consensys/gnark-crypto/ecc/bn254"
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
icicle_bn254 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254"
ecntt "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/ecntt"

"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
)

func (c *GpuComputeDevice) ECNttToGnark(batchPoints core.HostOrDeviceSlice, isInverse bool, totalSize int) ([]bn254.G1Affine, error) {
output, err := c.ECNtt(batchPoints, isInverse, totalSize)
if err != nil {
return nil, err
}

// convert icicle projective to gnark affine
gpuFFTBatch := gpu_utils.HostSliceIcicleProjectiveToGnarkAffine(output, int(c.NumWorker))

return gpuFFTBatch, nil
}

func (c *GpuComputeDevice) ECNtt(batchPoints core.HostOrDeviceSlice, isInverse bool, totalSize int) (core.HostSlice[icicle_bn254.Projective], error) {
output := make(core.HostSlice[icicle_bn254.Projective], totalSize)

if isInverse {
err := ecntt.ECNtt(batchPoints, core.KInverse, &c.NttCfg, output)
if err.CudaErrorCode != cr.CudaSuccess || err.IcicleErrorCode != core.IcicleSuccess {
return nil, fmt.Errorf("inverse ecntt failed")
}
} else {
err := ecntt.ECNtt(batchPoints, core.KForward, &c.NttCfg, output)
if err.CudaErrorCode != cr.CudaSuccess || err.IcicleErrorCode != core.IcicleSuccess {
return nil, fmt.Errorf("forward ecntt failed")
}
}
return output, nil
}
32 changes: 32 additions & 0 deletions encoding/kzg/prover/gpu/msm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package gpu

import (
"fmt"

"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
icicle_bn254 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254"
icicle_bn254_msm "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/msm"
)

// MsmBatch function supports batch across blobs.
// totalSize is the number of output points, which equals to numPoly * 2 * dimE , dimE is number of chunks
func (c *GpuComputeDevice) MsmBatch(rowsFrIcicleCopy core.HostOrDeviceSlice, rowsG1Icicle []icicle_bn254.Affine, totalSize int) (core.DeviceSlice, error) {
msmCfg := icicle_bn254_msm.GetDefaultMSMConfig()

rowsG1IcicleCopy := core.HostSliceFromElements[icicle_bn254.Affine](rowsG1Icicle)

var p icicle_bn254.Projective
var out core.DeviceSlice

_, err := out.Malloc(totalSize*p.Size(), p.Size())
if err != cr.CudaSuccess {
return out, fmt.Errorf("%v", "Allocating bytes on device for Projective results failed")
}

err = icicle_bn254_msm.Msm(rowsFrIcicleCopy, rowsG1IcicleCopy, &msmCfg, out)
if err != cr.CudaSuccess {
return out, fmt.Errorf("%v", "Msm failed")
}
return out, nil
}
Loading
Loading