From 7dba93a346e3b5e37cf7b7535aa8a2d6859b6ddb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 5 Feb 2024 19:08:51 +0000 Subject: [PATCH] fix based on feedbacks --- core/encoding/cli.go | 2 +- .../kzgEncoder/batchCommitEquivalence_test.go | 4 +- pkg/encoding/kzgEncoder/degree_test.go | 4 +- pkg/encoding/utils/pointsIO.go | 70 ++++++++++++------- 4 files changed, 50 insertions(+), 30 deletions(-) diff --git a/core/encoding/cli.go b/core/encoding/cli.go index ac9bf18760..3df33493fd 100644 --- a/core/encoding/cli.go +++ b/core/encoding/cli.go @@ -48,7 +48,7 @@ func CLIFlags(envPrefix string) []cli.Flag { }, cli.Uint64Flag{ Name: SRSLoadingNumberFlagName, - Usage: "Number of the SRS to load into memory", + Usage: "Number of SRS points to load into memory", Required: true, EnvVar: common.PrefixEnvVar(envPrefix, "SRS_LOAD"), }, diff --git a/pkg/encoding/kzgEncoder/batchCommitEquivalence_test.go b/pkg/encoding/kzgEncoder/batchCommitEquivalence_test.go index 18e6962ef2..bbfebaa63f 100644 --- a/pkg/encoding/kzgEncoder/batchCommitEquivalence_test.go +++ b/pkg/encoding/kzgEncoder/batchCommitEquivalence_test.go @@ -32,7 +32,7 @@ func TestBatchEquivalence(t *testing.T) { } } - assert.True(t, group.BatchVerifyCommitEquivalence(commitPairs) == nil, "batch equivalence test failed\n") + assert.Error(t, group.BatchVerifyCommitEquivalence(commitPairs), "batch equivalence negative test failed\n") var modifiedCommit bn254.G1Point bn254.AddG1(&modifiedCommit, commit, commit) @@ -54,5 +54,5 @@ func TestBatchEquivalence(t *testing.T) { bn254.AddG1(&commitPairs[numBlob/2].Commitment, &commitPairs[numBlob/2].Commitment, &commitPairs[numBlob/2].Commitment) - assert.False(t, group.BatchVerifyCommitEquivalence(commitPairs) == nil, "batch equivalence negative test failed in outer loop\n") + assert.Error(t, group.BatchVerifyCommitEquivalence(commitPairs), "batch equivalence negative test failed in outer loo\n") } diff --git a/pkg/encoding/kzgEncoder/degree_test.go b/pkg/encoding/kzgEncoder/degree_test.go index de42e80fde..6454a821d1 100644 --- a/pkg/encoding/kzgEncoder/degree_test.go +++ b/pkg/encoding/kzgEncoder/degree_test.go @@ -28,9 +28,9 @@ func TestLengthProof(t *testing.T) { require.Nil(t, err) length := len(inputFr) - assert.True(t, group.VerifyCommit(lowDegreeCommitment, lowDegreeProof, uint64(length)) == nil, "low degree verification failed\n") + assert.NoError(t, group.VerifyCommit(lowDegreeCommitment, lowDegreeProof, uint64(length)), "low degree verification failed\n") length = len(inputFr) - 10 - assert.False(t, group.VerifyCommit(lowDegreeCommitment, lowDegreeProof, uint64(length)) == nil, "low degree verification failed\n") + assert.Error(t, group.VerifyCommit(lowDegreeCommitment, lowDegreeProof, uint64(length)), "low degree verification failed\n") } } diff --git a/pkg/encoding/utils/pointsIO.go b/pkg/encoding/utils/pointsIO.go index 412c563770..fc90cb2f3e 100644 --- a/pkg/encoding/utils/pointsIO.go +++ b/pkg/encoding/utils/pointsIO.go @@ -6,7 +6,6 @@ import ( "io" "log" "os" - "sync" "time" bls "github.com/Layr-Labs/eigenda/pkg/kzg/bn254" @@ -71,13 +70,12 @@ func ReadG1Points(filepath string, n uint64, numWorker uint64) ([]bls.G1Point, e s1Outs := make([]bls.G1Point, n) - var wg sync.WaitGroup - wg.Add(int(numWorker)) - start := uint64(0) end := uint64(0) size := n / numWorker + results := make(chan error, numWorker) + for i := uint64(0); i < numWorker; i++ { start = i * size @@ -88,9 +86,15 @@ func ReadG1Points(filepath string, n uint64, numWorker uint64) ([]bls.G1Point, e } //fmt.Printf("worker %v start %v end %v. size %v\n", i, start, end, end - start) //todo: handle error? - go readG1Worker(buf, s1Outs, start, end, G1PointBytes, &wg) + go readG1Worker(buf, s1Outs, start, end, G1PointBytes, results) + } + + for w := uint64(0); w < numWorker; w++ { + err := <-results + if err != nil { + return nil, err + } } - wg.Wait() // measure parsing time t = time.Now() @@ -144,13 +148,12 @@ func ReadG1PointSection(filepath string, from, to uint64, numWorker uint64) ([]b s1Outs := make([]bls.G1Point, n) - var wg sync.WaitGroup - wg.Add(int(numWorker)) - start := uint64(0) end := uint64(0) size := n / numWorker + results := make(chan error, numWorker) + for i := uint64(0); i < numWorker; i++ { start = i * size @@ -159,10 +162,16 @@ func ReadG1PointSection(filepath string, from, to uint64, numWorker uint64) ([]b } else { end = (i + 1) * size } - //todo: handle error? - go readG1Worker(buf, s1Outs, start, end, G1PointBytes, &wg) + + go readG1Worker(buf, s1Outs, start, end, G1PointBytes, results) + } + + for w := uint64(0); w < numWorker; w++ { + err := <-results + if err != nil { + return nil, err + } } - wg.Wait() // measure parsing time t = time.Now() @@ -177,16 +186,17 @@ func readG1Worker( start uint64, // in element, not in byte end uint64, step uint64, - wg *sync.WaitGroup, + results chan<- error, ) { for i := start; i < end; i++ { g1 := buf[i*step : (i+1)*step] err := outs[i].UnmarshalText(g1[:]) if err != nil { + results <- err panic(err) } } - wg.Done() + results <- nil } func readG2Worker( @@ -195,16 +205,18 @@ func readG2Worker( start uint64, // in element, not in byte end uint64, step uint64, - wg *sync.WaitGroup, + results chan<- error, ) { for i := start; i < end; i++ { g1 := buf[i*step : (i+1)*step] err := outs[i].UnmarshalText(g1[:]) if err != nil { + results <- err log.Println("Unmarshalling error:", err) + panic(err) } } - wg.Done() + results <- nil } func ReadG2Points(filepath string, n uint64, numWorker uint64) ([]bls.G2Point, error) { @@ -242,8 +254,7 @@ func ReadG2Points(filepath string, n uint64, numWorker uint64) ([]bls.G2Point, e s2Outs := make([]bls.G2Point, n) - var wg sync.WaitGroup - wg.Add(int(numWorker)) + results := make(chan error, numWorker) start := uint64(0) end := uint64(0) @@ -256,11 +267,16 @@ func ReadG2Points(filepath string, n uint64, numWorker uint64) ([]bls.G2Point, e } else { end = (i + 1) * size } - //todo: handle error? - go readG2Worker(buf, s2Outs, start, end, G2PointBytes, &wg) + go readG2Worker(buf, s2Outs, start, end, G2PointBytes, results) + } + + for w := uint64(0); w < numWorker; w++ { + err := <-results + if err != nil { + return nil, err + } } - wg.Wait() // measure parsing time t = time.Now() @@ -314,8 +330,7 @@ func ReadG2PointSection(filepath string, from, to uint64, numWorker uint64) ([]b s2Outs := make([]bls.G2Point, n) - var wg sync.WaitGroup - wg.Add(int(numWorker)) + results := make(chan error, numWorker) start := uint64(0) end := uint64(0) @@ -330,9 +345,14 @@ func ReadG2PointSection(filepath string, from, to uint64, numWorker uint64) ([]b end = (i + 1) * size } //todo: handle error? - go readG2Worker(buf, s2Outs, start, end, G2PointBytes, &wg) + go readG2Worker(buf, s2Outs, start, end, G2PointBytes, results) + } + for w := uint64(0); w < numWorker; w++ { + err := <-results + if err != nil { + return nil, err + } } - wg.Wait() // measure parsing time t = time.Now()