diff --git a/consensus/polybft/blockchain_wrapper.go b/consensus/polybft/blockchain_wrapper.go index f73157edaf..e6d86578d4 100644 --- a/consensus/polybft/blockchain_wrapper.go +++ b/consensus/polybft/blockchain_wrapper.go @@ -38,8 +38,7 @@ type blockchainBackend interface { txPool txPoolInterface, blockTime time.Duration, logger hclog.Logger) (blockBuilder, error) // ProcessBlock builds a final block from given 'block' on top of 'parent'. - ProcessBlock(parent *types.Header, block *types.Block, - callback func(*state.Transition) error) (*types.FullBlock, error) + ProcessBlock(parent *types.Header, block *types.Block) (*types.FullBlock, error) // GetStateProviderForBlock returns a reference to make queries to the state at 'block'. GetStateProviderForBlock(block *types.Header) (contract.Provider, error) @@ -83,8 +82,7 @@ func (p *blockchainWrapper) CommitBlock(block *types.FullBlock) error { } // ProcessBlock builds a final block from given 'block' on top of 'parent' -func (p *blockchainWrapper) ProcessBlock(parent *types.Header, block *types.Block, - callback func(*state.Transition) error) (*types.FullBlock, error) { +func (p *blockchainWrapper) ProcessBlock(parent *types.Header, block *types.Block) (*types.FullBlock, error) { header := block.Header.Copy() start := time.Now().UTC() @@ -100,12 +98,6 @@ func (p *blockchainWrapper) ProcessBlock(parent *types.Header, block *types.Bloc } } - if callback != nil { - if err := callback(transition); err != nil { - return nil, err - } - } - _, root, err := transition.Commit() if err != nil { return nil, fmt.Errorf("failed to commit the state changes: %w", err) diff --git a/consensus/polybft/extra.go b/consensus/polybft/extra.go index 2f6a6b5f0e..42ba463738 100644 --- a/consensus/polybft/extra.go +++ b/consensus/polybft/extra.go @@ -426,7 +426,8 @@ func (c *CheckpointData) ValidateBasic(parentCheckpoint *CheckpointData) error { // Validate encapsulates validation logic for checkpoint data // (with regards to current and next epoch validators) func (c *CheckpointData) Validate(parentCheckpoint *CheckpointData, - currentValidators validator.AccountSet, nextValidators validator.AccountSet) error { + currentValidators validator.AccountSet, nextValidators validator.AccountSet, + exitRootHash types.Hash) error { if err := c.ValidateBasic(parentCheckpoint); err != nil { return err } @@ -459,6 +460,12 @@ func (c *CheckpointData) Validate(parentCheckpoint *CheckpointData, return fmt.Errorf("epoch number should not change for epoch-ending block") } + // exit root hash of proposer and + // validator that validates proposal have to match + if exitRootHash != c.EventRoot { + return fmt.Errorf("exit root hash not as expected") + } + return nil } diff --git a/consensus/polybft/extra_test.go b/consensus/polybft/extra_test.go index 4f33798291..53bac6ca83 100644 --- a/consensus/polybft/extra_test.go +++ b/consensus/polybft/extra_test.go @@ -646,6 +646,7 @@ func TestCheckpointData_Validate(t *testing.T) { nextValidators validator.AccountSet currentValidatorsHash types.Hash nextValidatorsHash types.Hash + exitRootHash types.Hash errString string }{ { @@ -713,6 +714,17 @@ func TestCheckpointData_Validate(t *testing.T) { nextValidatorsHash: nextValidatorsHash, errString: "epoch number should not change for epoch-ending block", }, + { + name: "Invalid exit root hash", + parentEpochNumber: 2, + epochNumber: 2, + currentValidators: currentValidators, + nextValidators: currentValidators, + currentValidatorsHash: currentValidatorsHash, + nextValidatorsHash: currentValidatorsHash, + exitRootHash: types.BytesToHash([]byte{0, 1, 2, 3, 4, 5, 6, 7}), + errString: "exit root hash not as expected", + }, } for _, c := range cases { @@ -723,9 +735,10 @@ func TestCheckpointData_Validate(t *testing.T) { EpochNumber: c.epochNumber, CurrentValidatorsHash: c.currentValidatorsHash, NextValidatorsHash: c.nextValidatorsHash, + EventRoot: c.exitRootHash, } parentCheckpoint := &CheckpointData{EpochNumber: c.parentEpochNumber} - err := checkpoint.Validate(parentCheckpoint, c.currentValidators, c.nextValidators) + err := checkpoint.Validate(parentCheckpoint, c.currentValidators, c.nextValidators, types.ZeroHash) if c.errString != "" { require.ErrorContains(t, err, c.errString) diff --git a/consensus/polybft/fsm.go b/consensus/polybft/fsm.go index 6c3f79a295..90fe2eb5b8 100644 --- a/consensus/polybft/fsm.go +++ b/consensus/polybft/fsm.go @@ -334,24 +334,26 @@ func (f *fsm) Validate(proposal []byte) error { } currentValidators := f.validators.Accounts() - nextValidators := f.validators.Accounts() - validateExtraData := func(transition *state.Transition) error { - if f.isEndOfEpoch { - if !extra.Validators.Equals(f.newValidatorsDelta) { - return errValidatorSetDeltaMismatch - } - } else if !extra.Validators.IsEmpty() { - // delta should be empty in non epoch ending blocks - return errValidatorsUpdateInNonEpochEnding + // validate validators delta + if f.isEndOfEpoch { + if !extra.Validators.Equals(f.newValidatorsDelta) { + return errValidatorSetDeltaMismatch } + } else if !extra.Validators.IsEmpty() { + // delta should be empty in non epoch ending blocks + return errValidatorsUpdateInNonEpochEnding + } - nextValidators, err = f.getValidatorsTransition(extra.Validators) - if err != nil { - return err - } + nextValidators, err := f.getValidatorsTransition(extra.Validators) + if err != nil { + return err + } - return extra.Checkpoint.Validate(parentExtra.Checkpoint, currentValidators, nextValidators) + // validate checkpoint data + if err := extra.Checkpoint.Validate(parentExtra.Checkpoint, + currentValidators, nextValidators, f.exitEventRootHash); err != nil { + return err } if f.logger.IsTrace() && block.Number() > 1 { @@ -363,7 +365,7 @@ func (f *fsm) Validate(proposal []byte) error { f.logger.Trace("[FSM Validate]", "Block", block.Number(), "parent validators", validators) } - stateBlock, err := f.backend.ProcessBlock(f.parent, &block, validateExtraData) + stateBlock, err := f.backend.ProcessBlock(f.parent, &block) if err != nil { return err } diff --git a/consensus/polybft/fsm_test.go b/consensus/polybft/fsm_test.go index 0e83ba27ba..0d71d15413 100644 --- a/consensus/polybft/fsm_test.go +++ b/consensus/polybft/fsm_test.go @@ -710,6 +710,69 @@ func TestFSM_ValidateCommit_Good(t *testing.T) { require.NoError(t, err) } +func TestFSM_Validate_ExitEventRootNotExpected(t *testing.T) { + t.Parallel() + + const ( + accountsCount = 5 + parentBlockNumber = 25 + signaturesCount = 3 + ) + + validators := validator.NewTestValidators(t, accountsCount) + parentExtra := createTestExtraObject(validators.GetPublicIdentities(), validator.AccountSet{}, 4, signaturesCount, signaturesCount) + parentExtra.Validators = nil + + parent := &types.Header{ + Number: parentBlockNumber, + ExtraData: parentExtra.MarshalRLPTo(nil), + } + parent.ComputeHash() + + polybftBackendMock := new(polybftBackendMock) + polybftBackendMock.On("GetValidators", mock.Anything, mock.Anything).Return(validators.GetPublicIdentities(), nil).Once() + + extra := createTestExtraObject(validators.GetPublicIdentities(), validator.AccountSet{}, 4, signaturesCount, signaturesCount) + extra.Validators = nil + parentCheckpointHash, err := extra.Checkpoint.Hash(0, parentBlockNumber, parent.Hash) + require.NoError(t, err) + + currentValSetHash, err := validators.GetPublicIdentities().Hash() + require.NoError(t, err) + + extra.Parent = createSignature(t, validators.GetPrivateIdentities(), parentCheckpointHash, bls.DomainCheckpointManager) + extra.Checkpoint.EpochNumber = 1 + extra.Checkpoint.CurrentValidatorsHash = currentValSetHash + extra.Checkpoint.NextValidatorsHash = currentValSetHash + + stateBlock := createDummyStateBlock(parent.Number+1, types.Hash{100, 15}, extra.MarshalRLPTo(nil)) + + proposalHash, err := extra.Checkpoint.Hash(0, stateBlock.Block.Number(), stateBlock.Block.Hash()) + require.NoError(t, err) + + stateBlock.Block.Header.Hash = proposalHash + stateBlock.Block.Header.ParentHash = parent.Hash + stateBlock.Block.Header.Timestamp = uint64(time.Now().UTC().Unix()) + stateBlock.Block.Transactions = []*types.Transaction{} + + proposal := stateBlock.Block.MarshalRLP() + + fsm := &fsm{ + parent: parent, + backend: new(blockchainMock), + validators: validators.ToValidatorSet(), + logger: hclog.NewNullLogger(), + polybftBackend: polybftBackendMock, + config: &PolyBFTConfig{BlockTimeDrift: 1}, + exitEventRootHash: types.BytesToHash([]byte{0, 1, 2, 3, 4}), // expect this to be in proposal extra + } + + err = fsm.Validate(proposal) + require.ErrorContains(t, err, "exit root hash not as expected") + + polybftBackendMock.AssertExpectations(t) +} + func TestFSM_Validate_EpochEndingBlock_MismatchInDeltas(t *testing.T) { t.Parallel() @@ -763,7 +826,7 @@ func TestFSM_Validate_EpochEndingBlock_MismatchInDeltas(t *testing.T) { proposal := stateBlock.Block.MarshalRLP() blockchainMock := new(blockchainMock) - blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything, mock.Anything). + blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything). Return(stateBlock, error(nil)). Maybe() @@ -854,7 +917,7 @@ func TestFSM_Validate_EpochEndingBlock_UpdatingValidatorSetInNonEpochEndingBlock proposal := stateBlock.Block.MarshalRLP() blockchainMock := new(blockchainMock) - blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything, mock.Anything). + blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything). Return(stateBlock, error(nil)). Maybe() @@ -1062,7 +1125,7 @@ func TestFSM_Insert_Good(t *testing.T) { builderMock := newBlockBuilderMock(builtBlock) chainMock := &blockchainMock{} chainMock.On("CommitBlock", mock.Anything).Return(error(nil)).Once() - chainMock.On("ProcessBlock", mock.Anything, mock.Anything, mock.Anything). + chainMock.On("ProcessBlock", mock.Anything, mock.Anything). Return(builtBlock, error(nil)). Maybe() diff --git a/consensus/polybft/mocks_test.go b/consensus/polybft/mocks_test.go index 85305b5735..750067ba92 100644 --- a/consensus/polybft/mocks_test.go +++ b/consensus/polybft/mocks_test.go @@ -40,14 +40,8 @@ func (m *blockchainMock) NewBlockBuilder(parent *types.Header, coinbase types.Ad return args.Get(0).(blockBuilder), args.Error(1) //nolint:forcetypeassert } -func (m *blockchainMock) ProcessBlock(parent *types.Header, block *types.Block, callback func(*state.Transition) error) (*types.FullBlock, error) { - args := m.Called(parent, block, callback) - - if callback != nil { - if err := callback(nil); err != nil { - return nil, err - } - } +func (m *blockchainMock) ProcessBlock(parent *types.Header, block *types.Block) (*types.FullBlock, error) { + args := m.Called(parent, block) return args.Get(0).(*types.FullBlock), args.Error(1) //nolint:forcetypeassert }