diff --git a/portalnetwork/history/accumulator.go b/portalnetwork/history/accumulator.go index 9dafee13c231..56f7028f9956 100644 --- a/portalnetwork/history/accumulator.go +++ b/portalnetwork/history/accumulator.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/portalnetwork/utils" "github.com/ethereum/go-ethereum/rlp" ssz "github.com/ferranbt/fastssz" "github.com/holiman/uint256" @@ -43,11 +44,15 @@ func newEpoch() *epoch { func (e *epoch) add(header types.Header) error { blockHash := header.Hash().Bytes() - difficulty := uint256.MustFromBig(header.Number) + difficulty := uint256.MustFromBig(header.Difficulty) e.difficulty = uint256.NewInt(0).Add(e.difficulty, difficulty) + // big-endian + difficultyBytes := e.difficulty.Bytes32() + utils.ReverseBytesInPlace(difficultyBytes[:]) + record := HeaderRecord{ BlockHash: blockHash, - TotalDifficulty: e.difficulty.Bytes(), + TotalDifficulty: difficultyBytes[:], } sszBytes, err := record.MarshalSSZ() if err != nil { diff --git a/portalnetwork/history/accumulator_test.go b/portalnetwork/history/accumulator_test.go index 0df94fe78ccb..9d66a06b3a1e 100644 --- a/portalnetwork/history/accumulator_test.go +++ b/portalnetwork/history/accumulator_test.go @@ -4,13 +4,16 @@ import ( "bytes" "encoding/json" "fmt" + "math/big" "os" "strconv" "testing" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/portalnetwork/utils" "github.com/ethereum/go-ethereum/rlp" + "github.com/holiman/uint256" "github.com/stretchr/testify/assert" ) @@ -51,6 +54,35 @@ func TestBuildAndVerifyProof(t *testing.T) { } } +func TestUpdate(t *testing.T) { + epochAcc, err := getEpochAccu("0xcddbda3fd6f764602c06803ff083dbfc73f2bb396df17a31e5457329b9a0f38d") + assert.NoError(t, err) + + startNumber := 1000000 + epochRecordIndex := GetHeaderRecordIndex(uint64(startNumber)) + + newEpochAcc := NewAccumulator() + + for i := 0; i <= int(epochRecordIndex); i++ { + tmp := make([]byte, 64) + copy(tmp, epochAcc.HeaderRecords[i]) + newEpochAcc.currentEpoch.records = append(newEpochAcc.currentEpoch.records, tmp) + } + + startDifficulty := bytesToUint256(epochAcc.HeaderRecords[epochRecordIndex][32:]) + + newEpochAcc.currentEpoch.difficulty = startDifficulty + + for i := startNumber + 1; i <= 1000010; i++ { + header, err := getHeader(uint64(i)) + assert.NoError(t, err) + err = newEpochAcc.Update(*header) + assert.NoError(t, err) + currIndex := GetHeaderRecordIndex(uint64(i)) + assert.True(t, bytes.Equal(newEpochAcc.currentEpoch.records[currIndex], epochAcc.HeaderRecords[currIndex])) + } +} + // all test blocks are in the same epoch func parseHeaderWithProof() ([]BlockHeaderWithProof, error) { headWithProofBytes, err := os.ReadFile("./testdata/header_with_proofs.json") @@ -112,3 +144,8 @@ func getHeader(number uint64) (*types.Header, error) { err = rlp.Decode(reader, head) return head, err } + +func bytesToUint256(input []byte) *uint256.Int { + res := utils.ReverseBytes(input) + return uint256.MustFromBig(big.NewInt(0).SetBytes(res)) +} diff --git a/portalnetwork/storage/content_storage.go b/portalnetwork/storage/content_storage.go index 8b9b9a389bb4..9f24fb08dfb9 100644 --- a/portalnetwork/storage/content_storage.go +++ b/portalnetwork/storage/content_storage.go @@ -11,6 +11,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/portalnetwork/utils" "github.com/holiman/uint256" sqlite3 "github.com/mattn/go-sqlite3" ) @@ -285,7 +286,7 @@ func (p *ContentStorage) GetLargestDistance() (*uint256.Int, error) { return nil, err } // reverse the distance, because big.SetBytes is big-endian - reverseBytes(distance) + utils.ReverseBytesInPlace(distance) bigNum := new(big.Int).SetBytes(distance) res := uint256.MustFromBig(bigNum) return res, nil @@ -385,9 +386,3 @@ func (p *ContentStorage) deleteContentOutOfRadius(radius *uint256.Int) error { func (p *ContentStorage) ForcePrune(radius *uint256.Int) error { return p.deleteContentOutOfRadius(radius) } - -func reverseBytes(src []byte) { - for i := 0; i < len(src)/2; i++ { - src[i], src[len(src)-i-1] = src[len(src)-i-1], src[i] - } -} diff --git a/portalnetwork/utils/bytes.go b/portalnetwork/utils/bytes.go new file mode 100644 index 000000000000..0205d31ca725 --- /dev/null +++ b/portalnetwork/utils/bytes.go @@ -0,0 +1,16 @@ +package utils + +func ReverseBytesInPlace(src []byte) { + for i := 0; i < len(src)/2; i++ { + src[i], src[len(src)-i-1] = src[len(src)-i-1], src[i] + } +} + +func ReverseBytes(src []byte) []byte { + lenth := len(src) + dst := make([]byte, lenth) + for i := 0; i < len(src); i++ { + dst[lenth-1-i] = src[i] + } + return dst +}