From 4bf08d7373da4645db3b774a4ed2edb5a24fef1d Mon Sep 17 00:00:00 2001 From: ImJeremyHe Date: Tue, 17 Dec 2024 14:22:01 +0800 Subject: [PATCH] build hotshot payload and add unitests for it --- arbnode/batch_poster.go | 5 + arbnode/espresso_utils.go | 126 +++++++++++++++++++++++ arbnode/espresso_utils_test.go | 128 ++++++++++++++++++++++++ arbnode/schema.go | 1 + arbnode/transaction_streamer.go | 171 +++++++++++++++++--------------- 5 files changed, 351 insertions(+), 80 deletions(-) create mode 100644 arbnode/espresso_utils.go create mode 100644 arbnode/espresso_utils_test.go diff --git a/arbnode/batch_poster.go b/arbnode/batch_poster.go index 65e41fdeb0..d1d6776e74 100644 --- a/arbnode/batch_poster.go +++ b/arbnode/batch_poster.go @@ -182,6 +182,7 @@ type BatchPosterConfig struct { UseEscapeHatch bool `koanf:"use-escape-hatch"` EspressoTxnsPollingInterval time.Duration `koanf:"espresso-txns-polling-interval"` EspressoSwitchDelayThreshold uint64 `koanf:"espresso-switch-delay-threshold"` + EspressoMaxTransactioSize uint64 `koanf:"espresso-max-transaction-size"` } func (c *BatchPosterConfig) Validate() error { @@ -244,6 +245,7 @@ func BatchPosterConfigAddOptions(prefix string, f *pflag.FlagSet) { redislock.AddConfigOptions(prefix+".redis-lock", f) dataposter.DataPosterConfigAddOptions(prefix+".data-poster", f, dataposter.DefaultDataPosterConfig) genericconf.WalletConfigAddOptions(prefix+".parent-chain-wallet", f, DefaultBatchPosterConfig.ParentChainWallet.Pathname) + f.Uint64(prefix+".espresso-max-transaction-size", DefaultBatchPosterConfig.EspressoSwitchDelayThreshold, "specifies the max size of a espresso transasction") } var DefaultBatchPosterConfig = BatchPosterConfig{ @@ -277,6 +279,7 @@ var DefaultBatchPosterConfig = BatchPosterConfig{ EspressoSwitchDelayThreshold: 350, LightClientAddress: "", HotShotUrl: "", + EspressoMaxTransactioSize: 900 * 1024, } var DefaultBatchPosterL1WalletConfig = genericconf.WalletConfig{ @@ -313,6 +316,7 @@ var TestBatchPosterConfig = BatchPosterConfig{ EspressoSwitchDelayThreshold: 10, LightClientAddress: "", HotShotUrl: "", + EspressoMaxTransactioSize: 25 * 1024, } type BatchPosterOpts struct { @@ -377,6 +381,7 @@ func NewBatchPoster(ctx context.Context, opts *BatchPosterOpts) (*BatchPoster, e opts.Streamer.UseEscapeHatch = opts.Config().UseEscapeHatch opts.Streamer.espressoTxnsPollingInterval = opts.Config().EspressoTxnsPollingInterval opts.Streamer.espressoSwitchDelayThreshold = opts.Config().EspressoSwitchDelayThreshold + opts.Streamer.espressoMaxTransactionSize = opts.Config().EspressoMaxTransactioSize } b := &BatchPoster{ diff --git a/arbnode/espresso_utils.go b/arbnode/espresso_utils.go new file mode 100644 index 0000000000..829665b375 --- /dev/null +++ b/arbnode/espresso_utils.go @@ -0,0 +1,126 @@ +package arbnode + +import ( + "bytes" + "encoding/binary" + "errors" + + espressoTypes "github.com/EspressoSystems/espresso-sequencer-go/types" + "github.com/ethereum/go-ethereum/log" + "github.com/offchainlabs/nitro/arbutil" +) + +const MAX_ATTESTATION_QUOTE_SIZE int = 4 * 1024 +const LEN_SIZE int = 8 +const INDEX_SIZE int = 8 + +func buildRawHotShotPayload( + msgPositions []arbutil.MessageIndex, + msgFetcher func(arbutil.MessageIndex) ([]byte, error), + maxSize uint64, +) ([]byte, int) { + + payload := []byte{} + msgCnt := 0 + + for _, p := range msgPositions { + sizeBuf := make([]byte, LEN_SIZE) + positionBuf := make([]byte, INDEX_SIZE) + msg, err := msgFetcher(p) + if err != nil { + log.Warn("failed to fetch the message", "pos", p) + break + } + binary.BigEndian.PutUint64(sizeBuf, uint64(len(msg))) + binary.BigEndian.PutUint64(positionBuf, uint64(p)) + + if len(payload)+len(sizeBuf)+len(msg)+len(positionBuf)+MAX_ATTESTATION_QUOTE_SIZE > int(maxSize) { + break + } + // Add the submitted txn position and the size of the message along with the message + payload = append(payload, positionBuf...) + payload = append(payload, sizeBuf...) + payload = append(payload, msg...) + msgCnt += 1 + } + return payload, msgCnt +} + +func signHotShotPayload( + unsigned []byte, + signer func([]byte) ([]byte, error), +) ([]byte, error) { + quote, err := signer(unsigned) + if err != nil { + return nil, err + } + + quoteSizeBuf := make([]byte, LEN_SIZE) + binary.BigEndian.PutUint64(quoteSizeBuf, uint64(len(quote))) + // Put the signature first. That would help easier parsing. + result := quoteSizeBuf + result = append(result, quote...) + result = append(result, unsigned...) + + return result, nil +} + +func validateIfPayloadIsInBlock(p []byte, payloads []espressoTypes.Bytes) bool { + validated := false + for _, payload := range payloads { + if bytes.Equal(p, payload) { + validated = true + break + } + } + return validated +} + +func ParsePayload(rawPayload []byte) (signature []byte, indices []uint64, messages [][]byte, err error) { + if len(rawPayload) < LEN_SIZE { + return nil, nil, nil, errors.New("payload too short to parse signature size") + } + + // Extract the signature size + signatureSize := binary.BigEndian.Uint64(rawPayload[:LEN_SIZE]) + currentPos := LEN_SIZE + + if len(rawPayload[currentPos:]) < int(signatureSize) { + return nil, nil, nil, errors.New("payload too short for signature") + } + + // Extract the signature + signature = rawPayload[currentPos : currentPos+int(signatureSize)] + currentPos += int(signatureSize) + + indices = []uint64{} + messages = [][]byte{} + + // Parse messages + for { + if len(rawPayload[currentPos:]) < LEN_SIZE+INDEX_SIZE { + break // No more messages to parse + } + + // Extract the index + index := binary.BigEndian.Uint64(rawPayload[currentPos : currentPos+INDEX_SIZE]) + currentPos += INDEX_SIZE + + // Extract the message size + messageSize := binary.BigEndian.Uint64(rawPayload[currentPos : currentPos+LEN_SIZE]) + currentPos += LEN_SIZE + + if len(rawPayload[currentPos:]) < int(messageSize) { + return nil, nil, nil, errors.New("message size mismatch") + } + + // Extract the message + message := rawPayload[currentPos : currentPos+int(messageSize)] + currentPos += int(messageSize) + + indices = append(indices, index) + messages = append(messages, message) + } + + return signature, indices, messages, nil +} diff --git a/arbnode/espresso_utils_test.go b/arbnode/espresso_utils_test.go new file mode 100644 index 0000000000..1bd476ab4f --- /dev/null +++ b/arbnode/espresso_utils_test.go @@ -0,0 +1,128 @@ +package arbnode + +import ( + "bytes" + "encoding/binary" + "fmt" + "testing" + + espressoTypes "github.com/EspressoSystems/espresso-sequencer-go/types" + "github.com/offchainlabs/nitro/arbutil" +) + +func mockMsgFetcher(index arbutil.MessageIndex) ([]byte, error) { + return []byte("message" + fmt.Sprint(index)), nil +} + +func TestParsePayload(t *testing.T) { + msgPositions := []arbutil.MessageIndex{1, 2, 10, 24, 100} + + rawPayload, cnt := buildRawHotShotPayload(msgPositions, mockMsgFetcher, 200*1024) + if cnt != len(msgPositions) { + t.Fatal("exceed transactions") + } + + mockSignature := []byte("fake_signature") + fakeSigner := func(payload []byte) ([]byte, error) { + return mockSignature, nil + } + signedPayload, err := signHotShotPayload(rawPayload, fakeSigner) + if err != nil { + t.Fatalf("failed to sign payload: %v", err) + } + + // Parse the signed payload + signature, indices, messages, err := ParsePayload(signedPayload) + if err != nil { + t.Fatalf("failed to parse payload: %v", err) + } + + // Validate parsed data + if !bytes.Equal(signature, mockSignature) { + t.Errorf("expected signature 'fake_signature', got %v", mockSignature) + } + + for i, index := range indices { + if arbutil.MessageIndex(index) != msgPositions[i] { + t.Errorf("expected index %d, got %d", msgPositions[i], index) + } + } + + expectedMessages := [][]byte{ + []byte("message1"), + []byte("message2"), + []byte("message10"), + []byte("message24"), + []byte("message100"), + } + for i, message := range messages { + if !bytes.Equal(message, expectedMessages[i]) { + t.Errorf("expected message %s, got %s", expectedMessages[i], message) + } + } +} + +func TestValidateIfPayloadIsInBlock(t *testing.T) { + msgPositions := []arbutil.MessageIndex{1, 2} + + rawPayload, _ := buildRawHotShotPayload(msgPositions, mockMsgFetcher, 200*1024) + fakeSigner := func(payload []byte) ([]byte, error) { + return []byte("fake_signature"), nil + } + signedPayload, err := signHotShotPayload(rawPayload, fakeSigner) + if err != nil { + t.Fatalf("failed to sign payload: %v", err) + } + + // Validate payload in a block + blockPayloads := []espressoTypes.Bytes{ + signedPayload, + []byte("other_payload"), + } + + if !validateIfPayloadIsInBlock(signedPayload, blockPayloads) { + t.Error("expected payload to be validated in block") + } + + if validateIfPayloadIsInBlock([]byte("invalid_payload"), blockPayloads) { + t.Error("did not expect invalid payload to be validated in block") + } +} + +func TestParsePayloadInvalidCases(t *testing.T) { + invalidPayloads := []struct { + description string + payload []byte + }{ + { + description: "Empty payload", + payload: []byte{}, + }, + { + description: "Signature size exceeds payload", + payload: append(make([]byte, 8), []byte("short")...), + }, + { + description: "Message size exceeds remaining payload", + payload: func() []byte { + var payload []byte + sigSizeBuf := make([]byte, 8) + binary.BigEndian.PutUint64(sigSizeBuf, 0) + payload = append(payload, sigSizeBuf...) + msgSizeBuf := make([]byte, 8) + binary.BigEndian.PutUint64(msgSizeBuf, 100) + payload = append(payload, msgSizeBuf...) + return payload + }(), + }, + } + + for _, tc := range invalidPayloads { + t.Run(tc.description, func(t *testing.T) { + _, _, _, err := ParsePayload(tc.payload) + if err == nil { + t.Errorf("expected error for case '%s', but got none", tc.description) + } + }) + } +} diff --git a/arbnode/schema.go b/arbnode/schema.go index 3ce1a3e871..fc5308ee4e 100644 --- a/arbnode/schema.go +++ b/arbnode/schema.go @@ -19,6 +19,7 @@ var ( dbSchemaVersion []byte = []byte("_schemaVersion") // contains a uint64 representing the database schema version espressoSubmittedPos []byte = []byte("_espressoSubmittedPos") // contains the current message indices of the last submitted txns espressoSubmittedHash []byte = []byte("_espressoSubmittedHash") // contains the hash of the last submitted txn + espressoSubmittedPayload []byte = []byte("_espressoSubmittedPayload") // contains the payload of the last submitted espresso txn espressoPendingTxnsPositions []byte = []byte("_espressoPendingTxnsPositions") // contains the index of the pending txns that need to be submitted to espresso espressoLastConfirmedPos []byte = []byte("_espressoLastConfirmedPos") // contains the position of the last confirmed message espressoSkipVerificationPos []byte = []byte("_espressoSkipVerificationPos") // contains the position of the latest message that should skip the validation due to hotshot liveness failure diff --git a/arbnode/transaction_streamer.go b/arbnode/transaction_streamer.go index 61af1ca4b6..10e1b277b6 100644 --- a/arbnode/transaction_streamer.go +++ b/arbnode/transaction_streamer.go @@ -85,6 +85,7 @@ type TransactionStreamer struct { lightClientReader lightclient.LightClientReaderInterface espressoTxnsPollingInterval time.Duration espressoSwitchDelayThreshold uint64 + espressoMaxTransactionSize uint64 // Public these fields for testing HotshotDown bool UseEscapeHatch bool @@ -1311,6 +1312,16 @@ func (s *TransactionStreamer) pollSubmittedTransactionForFinality(ctx context.Co return fmt.Errorf("error validating namespace proof (height: %d)", height) } + submittedPayload, err := s.getEspressoSubmittedPayload() + if err != nil { + return fmt.Errorf("submitted payload not found: %w", err) + } + + validated := validateIfPayloadIsInBlock(submittedPayload, resp.Transactions) + if !validated { + return fmt.Errorf("transactions fetched from HotShot doesn't contain the submitted payload") + } + snapshot, err := s.lightClientReader.FetchMerkleRoot(height, nil) if err != nil { return fmt.Errorf("%w (height: %d): %w", EspressoFetchMerkleRootErr, height, err) @@ -1346,11 +1357,8 @@ func (s *TransactionStreamer) pollSubmittedTransactionForFinality(ctx context.Co defer s.espressoTxnsStateInsertionMutex.Unlock() batch := s.db.NewBatch() - if err := s.setEspressoSubmittedPos(batch, nil); err != nil { - return fmt.Errorf("failed to set the submitted pos to nil: %w", err) - } - if err := s.setEspressoSubmittedHash(batch, nil); err != nil { - return fmt.Errorf("failed to set the submitted hash to nil: %w", err) + if err := s.cleanEspressoSubmittedData(batch); err != nil { + return nil } lastConfirmedPos := submittedTxnPos[len(submittedTxnPos)-1] if err := s.setEspressoLastConfirmedPos(batch, &lastConfirmedPos); err != nil { @@ -1403,6 +1411,17 @@ func (s *TransactionStreamer) getEspressoSubmittedHash() (*espressoTypes.TaggedB return hashParsed, nil } +func (s *TransactionStreamer) getEspressoSubmittedPayload() ([]byte, error) { + bytes, err := s.db.Get(espressoSubmittedHash) + if err != nil { + if dbutil.IsErrNotFound(err) { + return nil, nil + } + return nil, err + } + return bytes, nil +} + func (s *TransactionStreamer) getLastConfirmedPos() (*arbutil.MessageIndex, error) { lastConfirmedBytes, err := s.db.Get(espressoLastConfirmedPos) if err != nil { @@ -1484,6 +1503,33 @@ func (s *TransactionStreamer) setEspressoLastConfirmedPos(batch ethdb.KeyValueWr return nil } +func (s *TransactionStreamer) cleanEspressoSubmittedData(batch ethdb.Batch) error { + if err := s.setEspressoSubmittedPos(batch, nil); err != nil { + return fmt.Errorf("failed to set the submitted pos to nil: %w", err) + } + if err := s.setEspressoSubmittedPayload(batch, nil); err != nil { + return fmt.Errorf("failed to set the submitted pos to nil: %w", err) + } + if err := s.setEspressoSubmittedHash(batch, nil); err != nil { + return fmt.Errorf("failed to set the submitted hash to nil: %w", err) + } + return nil + +} + +func (s *TransactionStreamer) setEspressoSubmittedPayload(batch ethdb.KeyValueWriter, payload []byte) error { + if payload == nil { + err := batch.Delete(espressoSubmittedHash) + return err + } + err := batch.Put(espressoSubmittedPayload, payload) + if err != nil { + return err + + } + return nil +} + func (s *TransactionStreamer) setSkipVerificationPos(batch ethdb.KeyValueWriter, pos *arbutil.MessageIndex) error { posBytes, err := rlp.EncodeToBytes(pos) if err != nil { @@ -1600,24 +1646,29 @@ func (s *TransactionStreamer) submitEspressoTransactions(ctx context.Context) ti } if len(pendingTxnsPos) > 0 { - // get the message at the pending txn position - msgs := []arbostypes.MessageWithMetadata{} pendingTxnsPosToBeSubmitted := []arbutil.MessageIndex{} - for _, pos := range pendingTxnsPos { + + fetcher := func(pos arbutil.MessageIndex) ([]byte, error) { msg, err := s.GetMessage(pos) if err != nil { - log.Error("failed to get espresso submitted pos", "err", err) - return s.espressoTxnsPollingInterval + return nil, err } - if msg != nil { - msgs = append(msgs, *msg) - pendingTxnsPosToBeSubmitted = append(pendingTxnsPosToBeSubmitted, pos) + b, err := rlp.EncodeToBytes(msg) + if err != nil { + return nil, err } + return b, nil } - payload, msgCnt := s.buildHotShotPayload(&msgs, pendingTxnsPosToBeSubmitted) + payload, msgCnt := buildRawHotShotPayload(pendingTxnsPosToBeSubmitted, fetcher, s.espressoMaxTransactionSize) if msgCnt == 0 { - log.Error("failed to build the hotshot transaction: a large message has exceeded the size limit") + log.Error("failed to build the hotshot transaction: a large message has exceeded the size limit or failed to get a message from storage") + return s.espressoTxnsPollingInterval + } + + payload, err = signHotShotPayload(payload, s.getAttestationQuote) + if err != nil { + log.Error("failed to sign the hotshot payload", "err", err) return s.espressoTxnsPollingInterval } @@ -1655,6 +1706,11 @@ func (s *TransactionStreamer) submitEspressoTransactions(ctx context.Context) ti log.Error("failed to set the submitted hash", "err", err) return s.espressoTxnsPollingInterval } + err = s.setEspressoSubmittedPayload(batch, payload) + if err != nil { + log.Error("failed to set the espresso payload", "err", err) + return s.espressoTxnsPollingInterval + } err = batch.Write() if err != nil { @@ -1666,7 +1722,7 @@ func (s *TransactionStreamer) submitEspressoTransactions(ctx context.Context) ti return s.espressoTxnsPollingInterval } -func (s *TransactionStreamer) toggleEscapeHatch(ctx context.Context) error { +func (s *TransactionStreamer) checkEspressoLiveness(ctx context.Context) error { live, err := s.lightClientReader.IsHotShotLive(s.espressoSwitchDelayThreshold) if err != nil { return err @@ -1680,19 +1736,22 @@ func (s *TransactionStreamer) toggleEscapeHatch(ctx context.Context) error { return nil } - // If hotshot is up, escape hatch is disabled - // - check if escape hatch should be activated - // - check if the submitted transaction should be skipped from espresso verification + // If hotshot was previously up, now it is down if !live { log.Warn("enabling the escape hatch, hotshot is down") s.HotshotDown = true } + if !s.UseEscapeHatch { + return nil + } + submittedHash, err := s.getEspressoSubmittedHash() if err != nil { return err } + // No transaction is waiting for espresso finalization if submittedHash == nil { return nil } @@ -1715,6 +1774,7 @@ func (s *TransactionStreamer) toggleEscapeHatch(ctx context.Context) error { return err } if hotshotLive { + // This transaction will be still finalized return nil } submitted, err := s.getEspressoSubmittedPos() @@ -1731,27 +1791,17 @@ func (s *TransactionStreamer) toggleEscapeHatch(ctx context.Context) error { defer s.espressoTxnsStateInsertionMutex.Unlock() batch := s.db.NewBatch() - if s.UseEscapeHatch { - // If escape hatch is used, write down the allowed skip position - // to the database. Batch poster will read this and circumvent the espresso validation - // for certain messages - err = s.setEspressoSubmittedHash(batch, nil) - if err != nil { - return err - } - err = s.setEspressoSubmittedPos(batch, nil) - if err != nil { - return err - } - err = s.setEspressoPendingTxnsPos(batch, nil) - if err != nil { - return err - } - log.Warn("setting last skip verification position", "pos", last) - err = s.setSkipVerificationPos(batch, &last) - if err != nil { - return err - } + // If escape hatch is used, write down the allowed skip position + // to the database. Batch poster will read this and circumvent the espresso validation + // for certain messages + err = s.cleanEspressoSubmittedData(batch) + if err != nil { + return err + } + log.Warn("setting last skip verification position", "pos", last) + err = s.setSkipVerificationPos(batch, &last) + if err != nil { + return err } err = batch.Write() if err != nil { @@ -1778,7 +1828,7 @@ func (s *TransactionStreamer) espressoSwitch(ctx context.Context, ignored struct return retryRate } if enabledEspresso { - err := s.toggleEscapeHatch(ctx) + err := s.checkEspressoLiveness(ctx) if err != nil { if ctx.Err() != nil { return 0 @@ -1829,45 +1879,6 @@ func (s *TransactionStreamer) Start(ctxIn context.Context) error { return stopwaiter.CallIterativelyWith[struct{}](&s.StopWaiterSafe, s.executeMessages, s.newMessageNotifier) } -const ESPRESSO_TRANSACTION_SIZE_LIMIT int = 900 * 1024 -const MAX_ATTESTATION_QUOTE_SIZE int = 4 * 1024 - -func (t *TransactionStreamer) buildHotShotPayload(msgs *[]arbostypes.MessageWithMetadata, submittedTxnPos []arbutil.MessageIndex) (espressoTypes.Bytes, int) { - payload := []byte{} - msgCnt := 0 - - for i, msg := range *msgs { - sizeBuf := make([]byte, 8) - positionBuf := make([]byte, 8) - msgBytes, err := rlp.EncodeToBytes(msg) - if err != nil { - return nil, 0 - } - binary.BigEndian.PutUint64(sizeBuf, uint64(len(msgBytes))) - binary.BigEndian.PutUint64(positionBuf, uint64(submittedTxnPos[i])) - - if len(payload)+len(sizeBuf)+len(msgBytes)+len(positionBuf)+MAX_ATTESTATION_QUOTE_SIZE > ESPRESSO_TRANSACTION_SIZE_LIMIT { - break - } - // Add the submitted txn position and the size of the message along with the message - payload = append(payload, positionBuf...) - payload = append(payload, sizeBuf...) - payload = append(payload, msgBytes...) - msgCnt += 1 - } - - // Also add the attestation quote - quote, err := t.getAttestationQuote(payload) - if err != nil { - return nil, 0 - } - - // append the quote to the payload, this is important for making sure that only payload sent by TEE is accepted as valid - payload = append(payload, quote...) - - return payload, msgCnt -} - /** * This function generates the attestation quote for the user data. * The user data is hashed using keccak256 and then 32 bytes of padding is added to the hash.