Skip to content

Commit

Permalink
Mixed encoder test passes
Browse files Browse the repository at this point in the history
  • Loading branch information
mooselumph committed Nov 28, 2023
1 parent 11c85b9 commit 02c202f
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 49 deletions.
23 changes: 17 additions & 6 deletions pkg/encoding/encoder/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ import (
// maxInputSize is the upper bound of the original data size. This is needed because
// the frames and indices don't encode the length of the original data. If maxInputSize
// is smaller than the original input size, decoded data will be trimmed to fit the maxInputSize.
func (g *Encoder) Decode(frames []Frame, indices []uint64, maxInputSize uint64) ([]byte, error) {
numSys := GetNumSys(maxInputSize, g.ChunkLen)

if uint64(len(frames)) < numSys {
return nil, errors.New("number of frame must be sufficient")
}
func (g *Encoder) Decode(frames []Frame, indices []uint64) ([]bls.Fr, error) {

samples := make([]*bls.Fr, g.NumEvaluations())
// copy evals based on frame coeffs into samples
Expand Down Expand Up @@ -70,6 +65,22 @@ func (g *Encoder) Decode(frames []Frame, indices []uint64, maxInputSize uint64)
return nil, err
}

return reconstructedPoly, nil
}

func (g *Encoder) DecodeBytes(frames []Frame, indices []uint64, maxInputSize uint64) ([]byte, error) {

numSys := GetNumSys(maxInputSize, g.ChunkLen)

if uint64(len(frames)) < numSys {
return nil, errors.New("number of frame must be sufficient")
}

reconstructedPoly, err := g.Decode(frames, indices)
if err != nil {
return nil, err
}

data := ToByteArray(reconstructedPoly, maxInputSize)

return data, nil
Expand Down
6 changes: 3 additions & 3 deletions pkg/encoding/encoder/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestEncodeDecode_InvertsWhenSamplingAllFrames(t *testing.T) {

// sample some frames
samples, indices := sampleFrames(frames, uint64(len(frames)))
data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)))
data, err := enc.DecodeBytes(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)))

require.Nil(t, err)
require.NotNil(t, data)
Expand All @@ -47,7 +47,7 @@ func TestEncodeDecode_InvertsWhenSamplingMissingFrame(t *testing.T) {

// sample some frames
samples, indices := sampleFrames(frames, uint64(len(frames)-1))
data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)))
data, err := enc.DecodeBytes(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)))

require.Nil(t, err)
require.NotNil(t, data)
Expand All @@ -71,7 +71,7 @@ func TestEncodeDecode_ErrorsWhenNotEnoughSampledFrames(t *testing.T) {

// sample some frames
samples, indices := sampleFrames(frames, uint64(len(frames)-2))
data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)))
data, err := enc.DecodeBytes(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)))

require.Nil(t, data)
require.NotNil(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/encoding/encoder/encoder_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func FuzzOnlySystematic(f *testing.F) {
//sample the correct systematic frames
samples, indices := sampleFrames(frames, uint64(len(frames)))

data, err := enc.Decode(samples, indices, uint64(len(input)))
data, err := enc.DecodeBytes(samples, indices, uint64(len(input)))
if err != nil {
t.Errorf("Error Decoding:\n Data:\n %q \n Err: %q", input, err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/encoding/kzgEncoder/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ func (g *KzgEncoder) Decode(frames []Frame, indices []uint64, maxInputSize uint6
rsFrames[ind] = rs.Frame{Coeffs: frame.Coeffs}
}

return g.Encoder.Decode(rsFrames, indices, maxInputSize)
return g.Encoder.DecodeBytes(rsFrames, indices, maxInputSize)
}
57 changes: 37 additions & 20 deletions pkg/encoding/kzgEncoder/shiftedencoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestShiftedEncoding(t *testing.T) {

group, _ := kzgRs.NewKzgEncoderGroup(kzgConfig)

blobSize := 256
blobSize := 32
blob := make([]byte, blobSize*31)
_, err := rand.Read(blob)
assert.NoError(t, err)
Expand All @@ -29,15 +29,13 @@ func TestShiftedEncoding(t *testing.T) {

params := rs.EncodingParams{
NumChunks: 4,
ChunkLen: 128,
ChunkLen: 16,
}
enc, err := group.NewKzgEncoder(params)
if err != nil {
t.Errorf("Error making rs: %q", err)
}

origCommit := enc.Commit(input)

n := uint8(math.Log2(float64(enc.NumEvaluations()))) + 1
fs := kzg.NewFFTSettings(n)

Expand All @@ -51,33 +49,52 @@ func TestShiftedEncoding(t *testing.T) {
}

//encode the data
commit, _, frames, indices, err := enc.Encode(input)
_, _, frames, indices_, err := enc.Encode(input)
if err != nil {
t.Errorf("Error Encoding:\n Data:\n %q \n Err: %q", input, err)
}

shiftedCommit := bls.G1Point{}
bls.MulG1(&shiftedCommit, origCommit, factor)
fmt.Println("indices_", indices_)

assert.True(t, bls.EqualG1(commit, &shiftedCommit), "commitment mismatch")
// for _, frame := range frames {
// assert.NotEqual(t, len(frame.Coeffs), 0)
// }

for _, frame := range frames {
assert.NotEqual(t, len(frame.Coeffs), 0)
}
// for i := 0; i < len(frames); i++ {
// f := frames[i]
// j := indices[i]

// q, err := rs.GetLeadingCosetIndex(uint64(i), params.NumChunks)
// assert.Nil(t, err)

// assert.Equal(t, j, q, "leading coset inconsistency")

for i := 0; i < len(frames); i++ {
f := frames[i]
j := indices[i]
// fmt.Printf("frame %v leading coset %v\n", i, j)
// lc := enc.Fs.ExpandedRootsOfUnity[uint64(q)]

q, err := rs.GetLeadingCosetIndex(uint64(i), params.NumChunks)
assert.Nil(t, err)
// assert.True(t, f.Verify(enc.Ks, &shiftedCommit, &lc), "Proof %v failed\n", i)
// }

assert.Equal(t, j, q, "leading coset inconsistency")
samples_, indices := sampleFrames(frames, uint64(len(frames)))

fmt.Printf("frame %v leading coset %v\n", i, j)
lc := enc.Fs.ExpandedRootsOfUnity[uint64(q)]
samples := make([]rs.Frame, len(frames))
for i, frame := range samples_ {
samples[i] = rs.Frame{
Coeffs: frame.Coeffs,
}
}

fmt.Println("len(samples)", len(samples), "len(frames)", len(frames), "len(indices)", len(indices))

recoveredCoeffs, err := enc.Encoder.Decode(samples, indices)
assert.NoError(t, err)

assert.True(t, f.Verify(enc.Ks, &shiftedCommit, &lc), "Proof %v failed\n", i)
notEqual := make([]int, 0)
for i := 0; i < len(input); i++ {
if !bls.EqualFr(&input[i], &recoveredCoeffs[i]) {
notEqual = append(notEqual, i)
}
}
assert.Equal(t, []int{}, notEqual)

}
2 changes: 1 addition & 1 deletion pkg/encoding/mixedencoder/allocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func AddOffsets(allocations []*Allocation) error {
// Sort allocations by number of evaluations
sorted := make([]*Allocation, len(allocations))
copy(sorted, allocations)
sort.Slice(sorted, func(i, j int) bool {
sort.SliceStable(sorted, func(i, j int) bool {
return sorted[i].NumEvaluations > sorted[j].NumEvaluations
})

Expand Down
1 change: 1 addition & 0 deletions pkg/encoding/mixedencoder/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func (e *MixedEncoder) Encode(input []byte, params []rs.EncodingParams) (*bls.G1

// Condition the input
shiftedPolyCoeffs := ShiftPoly(coeffs, allocations[ind].Offset)
fmt.Println("Offset", allocations[ind].Offset, "RootIndex", allocations[ind].RootIndex)

// Encode
shiftedCommit, _, frames, indices, err := encoder.Encode(shiftedPolyCoeffs)
Expand Down
74 changes: 57 additions & 17 deletions pkg/encoding/mixedencoder/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import (
rs "github.com/Layr-Labs/eigenda/pkg/encoding/encoder"
kzgrs "github.com/Layr-Labs/eigenda/pkg/encoding/kzgEncoder"
"github.com/Layr-Labs/eigenda/pkg/encoding/mixedencoder"
"github.com/Layr-Labs/eigenda/pkg/kzg/bn254"
"github.com/stretchr/testify/assert"

"github.com/Layr-Labs/eigenda/pkg/kzg/bn254"
)

var (
Expand Down Expand Up @@ -78,8 +79,30 @@ func TestMixedEncoding(t *testing.T) {
// }

// Decode
inputs := make([]*mixedencoder.MixedDecoderInput, len(outputs))
for i, output := range outputs {
inputs := sampleInputs(outputs)

// for _, input := range inputs {
// testDecode(t, blob, input)
// }

numEvaluations := 0
for _, input := range inputs {
numEvaluations += input.Allocation.NumEvaluations
}
numEvaluations = int(rs.NextPowerOf2(uint64(numEvaluations)))

decoded, err := encoder.Decode(numEvaluations, len(blob), inputs)

assert.NoError(t, err)
assert.Equal(t, string(blob), string(decoded))

}

func sampleInputs(outputs []*mixedencoder.MixedEncodingOutput) []*mixedencoder.MixedDecoderInput {

inputs := make([]*mixedencoder.MixedDecoderInput, 0)

for _, output := range outputs {

frames := make([]rs.Frame, len(output.Frames))
for j, frame := range output.Frames {
Expand All @@ -88,31 +111,48 @@ func TestMixedEncoding(t *testing.T) {
}
}

inputs[i] = &mixedencoder.MixedDecoderInput{
indices := make([]uint32, len(output.Indices))
for j := range output.Indices {
indices[j] = uint32(j)
}

inputs = append(inputs, &mixedencoder.MixedDecoderInput{
EncodingParams: output.Param,
Allocation: output.Allocation,
Frames: frames,
Indices: output.Indices,
}
Indices: indices,
})
}

numEvaluations := 0
for _, input := range inputs {
numEvaluations += input.Allocation.NumEvaluations
}
numEvaluations = int(rs.NextPowerOf2(uint64(numEvaluations)))
return inputs

tInputs := []*mixedencoder.MixedDecoderInput{
inputs[1], inputs[2],
}

func testDecode(t *testing.T, blob []byte, input *mixedencoder.MixedDecoderInput) {

enc, err := rs.NewEncoder(input.EncodingParams, false)
if err != nil {
t.Fatal(err)
}

decoded, err := encoder.Decode(numEvaluations, len(blob), tInputs)
indices := make([]uint64, len(input.Indices))
for i := range input.Indices {
indices[i] = uint64(i)
}

recoveredCoeffs, err := enc.Decode(input.Frames, indices)
assert.NoError(t, err)
assert.Equal(t, string(blob), string(decoded))

fmt.Println("Offset", tInputs[0].Allocation.RootIndex)
fmt.Println(string(decoded))
origCoeffs := rs.ToFrArray(blob)
shiftedCoeffs := mixedencoder.ShiftPoly(origCoeffs, input.Allocation.Offset)

notEqual := make([]int, 0)
for i := 0; i < len(shiftedCoeffs); i++ {
if !bn254.EqualFr(&shiftedCoeffs[i], &recoveredCoeffs[i]) {
notEqual = append(notEqual, i)
}
}
assert.Equal(t, []int{}, notEqual)

}

Expand Down

0 comments on commit 02c202f

Please sign in to comment.