Skip to content

Commit

Permalink
build hotshot payload and add unitests for it
Browse files Browse the repository at this point in the history
  • Loading branch information
ImJeremyHe committed Dec 17, 2024
1 parent 9000e9c commit e0af17e
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 80 deletions.
5 changes: 5 additions & 0 deletions arbnode/batch_poster.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ type BatchPosterConfig struct {
UseEscapeHatch bool `koanf:"use-escape-hatch"`
EspressoTxnsPollingInterval time.Duration `koanf:"espresso-txns-polling-interval"`
EspressoSwitchDelayThreshold uint64 `koanf:"espresso-switch-delay-threshold"`
EspressoMaxTransactioSize uint64 `koanf:"espresso-max-transaction-size"`
}

func (c *BatchPosterConfig) Validate() error {
Expand Down Expand Up @@ -244,6 +245,7 @@ func BatchPosterConfigAddOptions(prefix string, f *pflag.FlagSet) {
redislock.AddConfigOptions(prefix+".redis-lock", f)
dataposter.DataPosterConfigAddOptions(prefix+".data-poster", f, dataposter.DefaultDataPosterConfig)
genericconf.WalletConfigAddOptions(prefix+".parent-chain-wallet", f, DefaultBatchPosterConfig.ParentChainWallet.Pathname)
f.Uint64(prefix+".espresso-max-transaction-size", DefaultBatchPosterConfig.EspressoSwitchDelayThreshold, "specifies the max size of a espresso transasction")
}

var DefaultBatchPosterConfig = BatchPosterConfig{
Expand Down Expand Up @@ -277,6 +279,7 @@ var DefaultBatchPosterConfig = BatchPosterConfig{
EspressoSwitchDelayThreshold: 350,
LightClientAddress: "",
HotShotUrl: "",
EspressoMaxTransactioSize: 900 * 1024,
}

var DefaultBatchPosterL1WalletConfig = genericconf.WalletConfig{
Expand Down Expand Up @@ -313,6 +316,7 @@ var TestBatchPosterConfig = BatchPosterConfig{
EspressoSwitchDelayThreshold: 10,
LightClientAddress: "",
HotShotUrl: "",
EspressoMaxTransactioSize: 10 * 1024,
}

type BatchPosterOpts struct {
Expand Down Expand Up @@ -377,6 +381,7 @@ func NewBatchPoster(ctx context.Context, opts *BatchPosterOpts) (*BatchPoster, e
opts.Streamer.UseEscapeHatch = opts.Config().UseEscapeHatch
opts.Streamer.espressoTxnsPollingInterval = opts.Config().EspressoTxnsPollingInterval
opts.Streamer.espressoSwitchDelayThreshold = opts.Config().EspressoSwitchDelayThreshold
opts.Streamer.espressoMaxTransactionSize = opts.Config().EspressoMaxTransactioSize
}

b := &BatchPoster{
Expand Down
126 changes: 126 additions & 0 deletions arbnode/espresso_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package arbnode

import (
"bytes"
"encoding/binary"
"errors"

espressoTypes "github.com/EspressoSystems/espresso-sequencer-go/types"
"github.com/ethereum/go-ethereum/log"
"github.com/offchainlabs/nitro/arbutil"
)

const MAX_ATTESTATION_QUOTE_SIZE int = 4 * 1024
const LEN_SIZE int = 8
const INDEX_SIZE int = 8

func buildRawHotShotPayload(
msgPositions []arbutil.MessageIndex,
msgFetcher func(arbutil.MessageIndex) ([]byte, error),
maxSize uint64,
) ([]byte, int) {

payload := []byte{}
msgCnt := 0

for _, p := range msgPositions {
sizeBuf := make([]byte, LEN_SIZE)
positionBuf := make([]byte, INDEX_SIZE)
msg, err := msgFetcher(p)
if err != nil {
log.Warn("failed to fetch the message", "pos", p)
break
}
binary.BigEndian.PutUint64(sizeBuf, uint64(len(msg)))
binary.BigEndian.PutUint64(positionBuf, uint64(p))

if len(payload)+len(sizeBuf)+len(msg)+len(positionBuf)+MAX_ATTESTATION_QUOTE_SIZE > int(maxSize) {
break
}
// Add the submitted txn position and the size of the message along with the message
payload = append(payload, positionBuf...)
payload = append(payload, sizeBuf...)
payload = append(payload, msg...)
msgCnt += 1
}
return payload, msgCnt
}

func signHotShotPayload(
unsigned []byte,
signer func([]byte) ([]byte, error),
) ([]byte, error) {
quote, err := signer(unsigned)
if err != nil {
return nil, err
}

quoteSizeBuf := make([]byte, LEN_SIZE)
binary.BigEndian.PutUint64(quoteSizeBuf, uint64(len(quote)))
// Put the signature first. That would help easier parsing.
result := quoteSizeBuf
result = append(result, quote...)
result = append(result, unsigned...)

return result, nil
}

func validateIfPayloadIsInBlock(p []byte, payloads []espressoTypes.Bytes) bool {
validated := false
for _, payload := range payloads {
if bytes.Equal(p, payload) {
validated = true
break
}
}
return validated
}

func ParsePayload(rawPayload []byte) (signature []byte, indices []uint64, messages [][]byte, err error) {
if len(rawPayload) < LEN_SIZE {
return nil, nil, nil, errors.New("payload too short to parse signature size")
}

// Extract the signature size
signatureSize := binary.BigEndian.Uint64(rawPayload[:LEN_SIZE])
currentPos := LEN_SIZE

if len(rawPayload[currentPos:]) < int(signatureSize) {
return nil, nil, nil, errors.New("payload too short for signature")
}

// Extract the signature
signature = rawPayload[currentPos : currentPos+int(signatureSize)]
currentPos += int(signatureSize)

indices = []uint64{}
messages = [][]byte{}

// Parse messages
for {
if len(rawPayload[currentPos:]) < LEN_SIZE+INDEX_SIZE {
break // No more messages to parse
}

// Extract the index
index := binary.BigEndian.Uint64(rawPayload[currentPos : currentPos+INDEX_SIZE])
currentPos += INDEX_SIZE

// Extract the message size
messageSize := binary.BigEndian.Uint64(rawPayload[currentPos : currentPos+LEN_SIZE])
currentPos += LEN_SIZE

if len(rawPayload[currentPos:]) < int(messageSize) {
return nil, nil, nil, errors.New("message size mismatch")
}

// Extract the message
message := rawPayload[currentPos : currentPos+int(messageSize)]
currentPos += int(messageSize)

indices = append(indices, index)
messages = append(messages, message)
}

return signature, indices, messages, nil
}
128 changes: 128 additions & 0 deletions arbnode/espresso_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package arbnode

import (
"bytes"
"encoding/binary"
"fmt"
"testing"

espressoTypes "github.com/EspressoSystems/espresso-sequencer-go/types"
"github.com/offchainlabs/nitro/arbutil"
)

func mockMsgFetcher(index arbutil.MessageIndex) ([]byte, error) {
return []byte("message" + fmt.Sprint(index)), nil
}

func TestParsePayload(t *testing.T) {
msgPositions := []arbutil.MessageIndex{1, 2, 10, 24, 100}

rawPayload, cnt := buildRawHotShotPayload(msgPositions, mockMsgFetcher, 200*1024)
if cnt != len(msgPositions) {
t.Fatal("exceed transactions")
}

mockSignature := []byte("fake_signature")
fakeSigner := func(payload []byte) ([]byte, error) {
return mockSignature, nil
}
signedPayload, err := signHotShotPayload(rawPayload, fakeSigner)
if err != nil {
t.Fatalf("failed to sign payload: %v", err)
}

// Parse the signed payload
signature, indices, messages, err := ParsePayload(signedPayload)
if err != nil {
t.Fatalf("failed to parse payload: %v", err)
}

// Validate parsed data
if !bytes.Equal(signature, mockSignature) {
t.Errorf("expected signature 'fake_signature', got %v", mockSignature)
}

for i, index := range indices {
if arbutil.MessageIndex(index) != msgPositions[i] {
t.Errorf("expected index %d, got %d", msgPositions[i], index)
}
}

expectedMessages := [][]byte{
[]byte("message1"),
[]byte("message2"),
[]byte("message10"),
[]byte("message24"),
[]byte("message100"),
}
for i, message := range messages {
if !bytes.Equal(message, expectedMessages[i]) {
t.Errorf("expected message %s, got %s", expectedMessages[i], message)
}
}
}

func TestValidateIfPayloadIsInBlock(t *testing.T) {
msgPositions := []arbutil.MessageIndex{1, 2}

rawPayload, _ := buildRawHotShotPayload(msgPositions, mockMsgFetcher, 200*1024)
fakeSigner := func(payload []byte) ([]byte, error) {
return []byte("fake_signature"), nil
}
signedPayload, err := signHotShotPayload(rawPayload, fakeSigner)
if err != nil {
t.Fatalf("failed to sign payload: %v", err)
}

// Validate payload in a block
blockPayloads := []espressoTypes.Bytes{
signedPayload,
[]byte("other_payload"),
}

if !validateIfPayloadIsInBlock(signedPayload, blockPayloads) {
t.Error("expected payload to be validated in block")
}

if validateIfPayloadIsInBlock([]byte("invalid_payload"), blockPayloads) {
t.Error("did not expect invalid payload to be validated in block")
}
}

func TestParsePayloadInvalidCases(t *testing.T) {
invalidPayloads := []struct {
description string
payload []byte
}{
{
description: "Empty payload",
payload: []byte{},
},
{
description: "Signature size exceeds payload",
payload: append(make([]byte, 8), []byte("short")...),
},
{
description: "Message size exceeds remaining payload",
payload: func() []byte {
var payload []byte
sigSizeBuf := make([]byte, 8)
binary.BigEndian.PutUint64(sigSizeBuf, 0)
payload = append(payload, sigSizeBuf...)
msgSizeBuf := make([]byte, 8)
binary.BigEndian.PutUint64(msgSizeBuf, 100)
payload = append(payload, msgSizeBuf...)
return payload
}(),
},
}

for _, tc := range invalidPayloads {
t.Run(tc.description, func(t *testing.T) {
_, _, _, err := ParsePayload(tc.payload)
if err == nil {
t.Errorf("expected error for case '%s', but got none", tc.description)
}
})
}
}
1 change: 1 addition & 0 deletions arbnode/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var (
dbSchemaVersion []byte = []byte("_schemaVersion") // contains a uint64 representing the database schema version
espressoSubmittedPos []byte = []byte("_espressoSubmittedPos") // contains the current message indices of the last submitted txns
espressoSubmittedHash []byte = []byte("_espressoSubmittedHash") // contains the hash of the last submitted txn
espressoSubmittedPayload []byte = []byte("_espressoSubmittedPayload") // contains the payload of the last submitted espresso txn
espressoPendingTxnsPositions []byte = []byte("_espressoPendingTxnsPositions") // contains the index of the pending txns that need to be submitted to espresso
espressoLastConfirmedPos []byte = []byte("_espressoLastConfirmedPos") // contains the position of the last confirmed message
espressoSkipVerificationPos []byte = []byte("_espressoSkipVerificationPos") // contains the position of the latest message that should skip the validation due to hotshot liveness failure
Expand Down
Loading

0 comments on commit e0af17e

Please sign in to comment.