From 7fe13c82eb3ee2ce64dd08ee38eae39eabc4d12d Mon Sep 17 00:00:00 2001 From: NathanBSC Date: Thu, 6 Jun 2024 11:22:52 +0800 Subject: [PATCH] consensu/parlia: optimize prepareTurnTerm and verifyTurnTerm --- consensus/parlia/bohrFork.go | 35 +++++++++++++++++++++++-- consensus/parlia/parlia.go | 50 ++++++++---------------------------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/consensus/parlia/bohrFork.go b/consensus/parlia/bohrFork.go index 2b712b3216..714b57afd4 100644 --- a/consensus/parlia/bohrFork.go +++ b/consensus/parlia/bohrFork.go @@ -2,12 +2,14 @@ package parlia import ( "context" + "errors" "math/big" mrand "math/rand" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core/systemcontracts" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/internal/ethapi" @@ -16,9 +18,38 @@ import ( "github.com/ethereum/go-ethereum/rpc" ) -func (p *Parlia) getTurnTerm(header *types.Header) (turnTerm *big.Int, err error) { +func (p *Parlia) getTurnTerm(chain consensus.ChainHeaderReader, header *types.Header) (*uint8, error) { + if header.Number.Uint64()%p.config.Epoch != 0 || + !p.chainConfig.IsBohr(header.Number, header.Time) { + return nil, nil + } + + parent := chain.GetHeaderByHash(header.ParentHash) + if parent == nil { + return nil, errors.New("parent not found") + } + + var turnTerm uint8 + if p.chainConfig.IsBohr(parent.Number, parent.Time) { + turnTermFromContract, err := p.getTurnTermFromContract(parent) + if err != nil { + return nil, err + } + if turnTermFromContract == nil { + return nil, errors.New("unexpected error when getTurnTermFromContract") + } + turnTerm = uint8(turnTermFromContract.Int64()) + } else { + turnTerm = uint8(defaultTurnTerm) + } + log.Debug("getTurnTerm", "turnTerm", turnTerm) + + return &turnTerm, nil +} + +func (p *Parlia) getTurnTermFromContract(header *types.Header) (turnTerm *big.Int, err error) { if params.UseRandTurnTerm { - return p.getRandTurnTerm(header) + return p.getRandTurnTerm(header) // used as a mock to get turnTerm from the contract } ctx, cancel := context.WithCancel(context.Background()) diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index 68e56c3d50..10d614e97f 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -903,29 +903,13 @@ func (p *Parlia) prepareValidators(header *types.Header) error { } func (p *Parlia) prepareTurnTerm(chain consensus.ChainHeaderReader, header *types.Header) error { - if header.Number.Uint64()%p.config.Epoch != 0 || - !p.chainConfig.IsBohr(header.Number, header.Time) { - return nil + turnTerm, err := p.getTurnTerm(chain, header) + if err != nil { + return err } - parent := chain.GetHeaderByHash(header.ParentHash) - if parent == nil { - return errors.New("parent not found") - } - - if p.chainConfig.IsBohr(parent.Number, parent.Time) { - turnTerm, err := p.getTurnTerm(parent) - if err != nil { - return err - } - if turnTerm == nil { - return errors.New("unexpected error when getTurnTerm") - } - header.Extra = append(header.Extra, byte(int(turnTerm.Int64()))) - log.Debug("prepareTurnTerm", "turnTerm", turnTerm.Int64()) - } else { - log.Debug("prepareTurnTerm", "turnTerm", "defaultTurnTerm") - header.Extra = append(header.Extra, byte(int(defaultTurnTerm))) + if turnTerm != nil { + header.Extra = append(header.Extra, *turnTerm) } return nil @@ -1126,25 +1110,13 @@ func (p *Parlia) verifyTurnTerm(chain consensus.ChainHeaderReader, header *types return err } if turnTermFromHeader != nil { - parent := chain.GetHeaderByHash(header.ParentHash) - if parent == nil { - return errors.New("parent not found") + turnTerm, err := p.getTurnTerm(chain, header) + if err != nil { + return err } - - if p.chainConfig.IsBohr(parent.Number, parent.Time) { - turnTermFromContract, err := p.getTurnTerm(parent) - if err != nil { - return err - } - if turnTermFromContract != nil && uint8(turnTermFromContract.Int64()) == *turnTermFromHeader { - log.Debug("verifyTurnTerm", "turnTerm", turnTermFromContract.Int64()) - return nil - } - } else { - if uint8(defaultTurnTerm) == *turnTermFromHeader { - log.Debug("verifyTurnTerm", "turnTerm", "defaultTurnTerm") - return nil - } + if turnTerm != nil && *turnTerm == *turnTermFromHeader { + log.Debug("verifyTurnTerm", "turnTerm", *turnTerm) + return nil } }