diff --git a/dot/state/epoch.go b/dot/state/epoch.go index d8f02ed2c7..0b5ab5db83 100644 --- a/dot/state/epoch.go +++ b/dot/state/epoch.go @@ -12,14 +12,16 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" ) var ( + ErrConfigNotFound = errors.New("config data not found") ErrEpochNotInMemory = errors.New("epoch not found in memory map") errHashNotInMemory = errors.New("hash not found in memory map") - errEpochDataNotFound = errors.New("epoch data not found in the database") + errEpochNotInDatabase = errors.New("epoch data not found in the database") errHashNotPersisted = errors.New("hash with next epoch not found in database") errNoPreRuntimeDigest = errors.New("header does not contain pre-runtime digest") ) @@ -58,11 +60,11 @@ type EpochState struct { nextEpochDataLock sync.RWMutex // nextEpochData follows the format map[epoch]map[block hash]next epoch data - nextEpochData map[uint64]map[common.Hash]types.NextEpochData + nextEpochData nextEpochMap[types.NextEpochData] nextConfigDataLock sync.RWMutex // nextConfigData follows the format map[epoch]map[block hash]next config data - nextConfigData map[uint64]map[common.Hash]types.NextConfigData + nextConfigData nextEpochMap[types.NextConfigData] } // NewEpochStateFromGenesis returns a new EpochState given information for the first epoch, fetched from the runtime @@ -90,8 +92,8 @@ func NewEpochStateFromGenesis(db chaindb.Database, blockState *BlockState, blockState: blockState, db: epochDB, epochLength: genesisConfig.EpochLength, - nextEpochData: make(map[uint64]map[common.Hash]types.NextEpochData), - nextConfigData: make(map[uint64]map[common.Hash]types.NextConfigData), + nextEpochData: make(nextEpochMap[types.NextEpochData]), + nextConfigData: make(nextEpochMap[types.NextConfigData]), } auths, err := types.BABEAuthorityRawToAuthority(genesisConfig.GenesisAuthorities) @@ -151,8 +153,8 @@ func NewEpochState(db chaindb.Database, blockState *BlockState) (*EpochState, er db: chaindb.NewTable(db, epochPrefix), epochLength: epochLength, skipToEpoch: skipToEpoch, - nextEpochData: make(map[uint64]map[common.Hash]types.NextEpochData), - nextConfigData: make(map[uint64]map[common.Hash]types.NextConfigData), + nextEpochData: make(nextEpochMap[types.NextEpochData]), + nextConfigData: make(nextEpochMap[types.NextConfigData]), }, nil } @@ -247,25 +249,29 @@ func (s *EpochState) SetEpochData(epoch uint64, info *types.EpochData) error { // if the header params is nil then it will search only in database func (s *EpochState) GetEpochData(epoch uint64, header *types.Header) (*types.EpochData, error) { epochData, err := s.getEpochDataFromDatabase(epoch) - if err == nil && epochData != nil { + if err != nil && !errors.Is(err, chaindb.ErrKeyNotFound) { + return nil, fmt.Errorf("failed to retrieve epoch data from database: %w", err) + } + + if epochData != nil { return epochData, nil } - if err != nil && !errors.Is(err, chaindb.ErrKeyNotFound) { - return nil, fmt.Errorf("failed to get epoch data from database: %w", err) + if header == nil { + return nil, errEpochNotInDatabase } - // lookup in-memory only if header is given - if header != nil && errors.Is(err, chaindb.ErrKeyNotFound) { - epochData, err = s.getEpochDataFromMemory(epoch, header) - if err != nil { - return nil, fmt.Errorf("failed to get epoch data from memory: %w", err) - } + s.nextEpochDataLock.RLock() + defer s.nextEpochDataLock.RUnlock() + + inMemoryEpochData, err := s.nextEpochData.Retrieve(s.blockState, epoch, header) + if err != nil { + return nil, fmt.Errorf("failed to get epoch data from memory: %w", err) } - if epochData == nil { - return nil, fmt.Errorf("%w: for epoch %d and header with hash %s", - errEpochDataNotFound, epoch, header.Hash()) + epochData, err = inMemoryEpochData.ToEpochData() + if err != nil { + return nil, fmt.Errorf("cannot transform into epoch data: %w", err) } return epochData, nil @@ -287,32 +293,6 @@ func (s *EpochState) getEpochDataFromDatabase(epoch uint64) (*types.EpochData, e return raw.ToEpochData() } -// getEpochDataFromMemory retrieves the right epoch data that belongs to the header parameter -func (s *EpochState) getEpochDataFromMemory(epoch uint64, header *types.Header) (*types.EpochData, error) { - s.nextEpochDataLock.RLock() - defer s.nextEpochDataLock.RUnlock() - - atEpoch, has := s.nextEpochData[epoch] - if !has { - return nil, fmt.Errorf("%w: %d", ErrEpochNotInMemory, epoch) - } - - headerHash := header.Hash() - - for hash, value := range atEpoch { - isDescendant, err := s.blockState.IsDescendantOf(hash, headerHash) - if err != nil { - return nil, fmt.Errorf("cannot verify the ancestry: %w", err) - } - - if isDescendant { - return value.ToEpochData() - } - } - - return nil, fmt.Errorf("%w: %s", errHashNotInMemory, headerHash) -} - // GetLatestEpochData returns the EpochData for the current epoch func (s *EpochState) GetLatestEpochData() (*types.EpochData, error) { curr, err := s.GetCurrentEpoch() @@ -323,26 +303,6 @@ func (s *EpochState) GetLatestEpochData() (*types.EpochData, error) { return s.GetEpochData(curr, nil) } -// HasEpochData returns whether epoch data exists for a given epoch -func (s *EpochState) HasEpochData(epoch uint64) (bool, error) { - has, err := s.db.Has(epochDataKey(epoch)) - if err == nil && has { - return has, nil - } - - // we can have `has == false` and `err == nil` - // so ensure the error is not nil in the condition below. - if err != nil && !errors.Is(chaindb.ErrKeyNotFound, err) { - return false, fmt.Errorf("cannot check database for epoch key %d: %w", epoch, err) - } - - s.nextEpochDataLock.Lock() - defer s.nextEpochDataLock.Unlock() - - _, has = s.nextEpochData[epoch] - return has, nil -} - // SetConfigData sets the BABE config data for a given epoch func (s *EpochState) SetConfigData(epoch uint64, info *types.ConfigData) error { enc, err := scale.Marshal(*info) @@ -364,28 +324,44 @@ func (s *EpochState) setLatestConfigData(epoch uint64) error { return s.db.Put(latestConfigDataKey, buf) } -// GetConfigData returns the config data for a given epoch persisted in database -// otherwise tries to get the data from the in-memory map using the header. -// If the header params is nil then it will search only in the database -func (s *EpochState) GetConfigData(epoch uint64, header *types.Header) (*types.ConfigData, error) { - configData, err := s.getConfigDataFromDatabase(epoch) - if err == nil && configData != nil { - return configData, nil - } +// GetConfigData returns the newest config data for a given epoch persisted in database +// otherwise tries to get the data from the in-memory map using the header. If we don't +// find any config data for the current epoch we lookup in the previous epochs, as the spec says: +// - The supplied configuration data are intended to be used from the next epoch onwards. +// If the header params is nil then it will search only in the database. +func (s *EpochState) GetConfigData(epoch uint64, header *types.Header) (configData *types.ConfigData, err error) { + for tryEpoch := int(epoch); tryEpoch >= 0; tryEpoch-- { + configData, err = s.getConfigDataFromDatabase(uint64(tryEpoch)) + if err != nil && !errors.Is(err, chaindb.ErrKeyNotFound) { + return nil, fmt.Errorf("failed to retrieve config epoch from database: %w", err) + } - if err != nil && !errors.Is(err, chaindb.ErrKeyNotFound) { - return nil, fmt.Errorf("failed to get config data from database: %w", err) - } else if header == nil { - // if no header is given then skip the lookup in-memory - return configData, nil - } + if configData != nil { + return configData, nil + } - configData, err = s.getConfigDataFromMemory(epoch, header) - if err != nil { - return nil, fmt.Errorf("failed to get config data from memory: %w", err) + // there is no config data for the `tryEpoch` on database and we don't have a + // header to lookup in the memory map, so let's go retrieve the previous epoch + if header == nil { + continue + } + + // we will check in the memory map and if we don't find the data + // then we continue searching through the previous epoch + s.nextConfigDataLock.RLock() + inMemoryConfigData, err := s.nextConfigData.Retrieve(s.blockState, uint64(tryEpoch), header) + s.nextConfigDataLock.RUnlock() + + if errors.Is(err, ErrEpochNotInMemory) { + continue + } else if err != nil { + return nil, fmt.Errorf("failed to get config data from memory: %w", err) + } + + return inMemoryConfigData.ToConfigData(), err } - return configData, nil + return nil, fmt.Errorf("%w: epoch %d", ErrConfigNotFound, epoch) } // getConfigDataFromDatabase returns the BABE config data for a given epoch persisted in database @@ -404,26 +380,36 @@ func (s *EpochState) getConfigDataFromDatabase(epoch uint64) (*types.ConfigData, return info, nil } -// getConfigDataFromMemory retrieves the BABE config data for a given epoch that belongs to the header parameter -func (s *EpochState) getConfigDataFromMemory(epoch uint64, header *types.Header) (*types.ConfigData, error) { - s.nextConfigDataLock.RLock() - defer s.nextConfigDataLock.RUnlock() +type nextEpochMap[T types.NextEpochData | types.NextConfigData] map[uint64]map[common.Hash]T - atEpoch, has := s.nextConfigData[epoch] +func (nem nextEpochMap[T]) Retrieve(blockState *BlockState, epoch uint64, header *types.Header) (*T, error) { + atEpoch, has := nem[epoch] if !has { return nil, fmt.Errorf("%w: %d", ErrEpochNotInMemory, epoch) } headerHash := header.Hash() - for hash, value := range atEpoch { - isDescendant, err := s.blockState.IsDescendantOf(hash, headerHash) + isDescendant, err := blockState.IsDescendantOf(hash, headerHash) + + // sometimes while moving to the next epoch is possible the header + // is not fully imported by the blocktree, in this case we will use + // its parent header which migth be already imported. + if errors.Is(err, blocktree.ErrEndNodeNotFound) { + parentHeader, err := blockState.GetHeader(header.ParentHash) + if err != nil { + return nil, fmt.Errorf("cannot get parent header: %w", err) + } + + return nem.Retrieve(blockState, epoch, parentHeader) + } + if err != nil { return nil, fmt.Errorf("cannot verify the ancestry: %w", err) } if isDescendant { - return value.ToConfigData(), nil + return &value, nil } } @@ -441,24 +427,6 @@ func (s *EpochState) GetLatestConfigData() (*types.ConfigData, error) { return s.GetConfigData(epoch, nil) } -// HasConfigData returns whether config data exists for a given epoch -func (s *EpochState) HasConfigData(epoch uint64) (bool, error) { - has, err := s.db.Has(configDataKey(epoch)) - if err == nil && has { - return has, nil - } - - if err != nil && !errors.Is(chaindb.ErrKeyNotFound, err) { - return false, fmt.Errorf("cannot check database for epoch key %d: %w", epoch, err) - } - - s.nextConfigDataLock.Lock() - defer s.nextConfigDataLock.Unlock() - - _, has = s.nextConfigData[epoch] - return has, nil -} - // GetStartSlotForEpoch returns the first slot in the given epoch. // If 0 is passed as the epoch, it returns the start slot for the current epoch. func (s *EpochState) GetStartSlotForEpoch(epoch uint64) (uint64, error) { diff --git a/dot/state/epoch_test.go b/dot/state/epoch_test.go index 234e22ca55..75b4736560 100644 --- a/dot/state/epoch_test.go +++ b/dot/state/epoch_test.go @@ -54,9 +54,6 @@ func TestEpochState_CurrentEpoch(t *testing.T) { func TestEpochState_EpochData(t *testing.T) { s := newEpochStateFromGenesis(t) - has, err := s.HasEpochData(0) - require.NoError(t, err) - require.True(t, has) keyring, err := keystore.NewSr25519Keyring() require.NoError(t, err) diff --git a/lib/babe/epoch.go b/lib/babe/epoch.go index 1e524094fb..7cf51873be 100644 --- a/lib/babe/epoch.go +++ b/lib/babe/epoch.go @@ -24,15 +24,25 @@ func (b *Service) initiateEpoch(epoch uint64) (*epochData, error) { } } - epochData, startSlot, err := b.getEpochDataAndStartSlot(epoch) + bestBlockHeader, err := b.blockState.BestBlockHeader() + if err != nil { + return nil, fmt.Errorf("cannot get the best block header: %w", err) + } + + epochData, err := b.getEpochData(epoch, bestBlockHeader) if err != nil { return nil, fmt.Errorf("cannot get epoch data and start slot: %w", err) } + startSlot, err := b.epochState.GetStartSlotForEpoch(epoch) + if err != nil { + return nil, fmt.Errorf("cannot get start slot for epoch %d: %w", epoch, err) + } + // if we're at genesis, we need to determine when the first slot of the network will be // by checking when we will be able to produce block 1. // note that this assumes there will only be one producer of block 1 - if b.blockState.BestBlockHash() == b.blockState.GenesisHash() { + if bestBlockHeader.Hash() == b.blockState.GenesisHash() { startSlot, err = b.getFirstAuthoringSlot(epoch, epochData) if err != nil { return nil, fmt.Errorf("cannot get first authoring slot: %w", err) @@ -75,78 +85,43 @@ func (b *Service) checkAndSetFirstSlot() error { return nil } -func (b *Service) getEpochDataAndStartSlot(epoch uint64) (*epochData, uint64, error) { +func (b *Service) getEpochData(epoch uint64, bestBlock *types.Header) (*epochData, error) { if epoch == 0 { - startSlot, err := b.epochState.GetStartSlotForEpoch(epoch) - if err != nil { - return nil, 0, fmt.Errorf("cannot get start slot for epoch %d: %w", epoch, err) - } - epochData, err := b.getLatestEpochData() if err != nil { - return nil, 0, fmt.Errorf("cannot get latest epoch data: %w", err) + return nil, fmt.Errorf("cannot get latest epoch data: %w", err) } - return epochData, startSlot, nil - } - - has, err := b.epochState.HasEpochData(epoch) - if err != nil { - return nil, 0, fmt.Errorf("cannot check epoch state: %w", err) - } - - if !has { - logger.Criticalf("%s number=%d", errNoEpochData, epoch) - return nil, 0, fmt.Errorf("%w: for epoch %d", errNoEpochData, epoch) + return epochData, nil } - data, err := b.epochState.GetEpochData(epoch, nil) + currEpochData, err := b.epochState.GetEpochData(epoch, bestBlock) if err != nil { - return nil, 0, fmt.Errorf("cannot get epoch data for epoch %d: %w", epoch, err) + return nil, fmt.Errorf("cannot get epoch data for epoch %d: %w", epoch, err) } - idx, err := b.getAuthorityIndex(data.Authorities) + currentConfigData, err := b.epochState.GetConfigData(epoch, bestBlock) if err != nil { - return nil, 0, fmt.Errorf("cannot get authority index: %w", err) + return nil, fmt.Errorf("cannot get config data for epoch %d: %w", epoch, err) } - has, err = b.epochState.HasConfigData(epoch) + threshold, err := CalculateThreshold(currentConfigData.C1, currentConfigData.C2, len(currEpochData.Authorities)) if err != nil { - return nil, 0, fmt.Errorf("cannot check for config data for epoch %d: %w", epoch, err) - } - - var cfgData *types.ConfigData - if has { - cfgData, err = b.epochState.GetConfigData(epoch, nil) - if err != nil { - return nil, 0, fmt.Errorf("cannot get config data for epoch %d: %w", epoch, err) - } - } else { - cfgData, err = b.epochState.GetLatestConfigData() - if err != nil { - return nil, 0, fmt.Errorf("cannot get latest config data from epoch state: %w", err) - } + return nil, fmt.Errorf("cannot calculate threshold: %w", err) } - threshold, err := CalculateThreshold(cfgData.C1, cfgData.C2, len(data.Authorities)) + idx, err := b.getAuthorityIndex(currEpochData.Authorities) if err != nil { - return nil, 0, fmt.Errorf("cannot calculate threshold: %w", err) + return nil, fmt.Errorf("cannot get authority index: %w", err) } - ed := &epochData{ - randomness: data.Randomness, - authorities: data.Authorities, + return &epochData{ + randomness: currEpochData.Randomness, + authorities: currEpochData.Authorities, authorityIndex: idx, threshold: threshold, - allowedSlots: types.AllowedSlots(cfgData.SecondarySlots), - } - - startSlot, err := b.epochState.GetStartSlotForEpoch(epoch) - if err != nil { - return nil, 0, fmt.Errorf("cannot get start slot for epoch %d: %w", epoch, err) - } - - return ed, startSlot, nil + allowedSlots: types.AllowedSlots(currentConfigData.SecondarySlots), + }, nil } func (b *Service) getLatestEpochData() (resEpochData *epochData, error error) { diff --git a/lib/babe/epoch_test.go b/lib/babe/epoch_test.go index 0e4d62ccfe..1580dea7eb 100644 --- a/lib/babe/epoch_test.go +++ b/lib/babe/epoch_test.go @@ -71,19 +71,6 @@ func TestBabeService_checkAndSetFirstSlot(t *testing.T) { } func TestBabeService_getEpochDataAndStartSlot(t *testing.T) { - ctrl := gomock.NewController(t) - mockBlockState := NewMockBlockState(ctrl) - mockEpochState0 := NewMockEpochState(ctrl) - mockEpochState1 := NewMockEpochState(ctrl) - mockEpochState2 := NewMockEpochState(ctrl) - - mockEpochState0.EXPECT().GetStartSlotForEpoch(uint64(0)).Return(uint64(1), nil) - mockEpochState1.EXPECT().GetStartSlotForEpoch(uint64(1)).Return(uint64(201), nil) - mockEpochState2.EXPECT().GetStartSlotForEpoch(uint64(1)).Return(uint64(201), nil) - - mockEpochState1.EXPECT().HasEpochData(uint64(1)).Return(true, nil) - mockEpochState2.EXPECT().HasEpochData(uint64(1)).Return(true, nil) - kp := keyring.Alice().(*sr25519.Keypair) authority := types.NewAuthority(kp.Public(), uint64(1)) testEpochData := &types.EpochData{ @@ -91,55 +78,21 @@ func TestBabeService_getEpochDataAndStartSlot(t *testing.T) { Authorities: []types.Authority{*authority}, } - mockEpochState1.EXPECT().GetEpochData(uint64(1), nil).Return(testEpochData, nil) - mockEpochState2.EXPECT().GetEpochData(uint64(1), nil).Return(testEpochData, nil) - - mockEpochState1.EXPECT().HasConfigData(uint64(1)).Return(true, nil) - mockEpochState2.EXPECT().HasConfigData(uint64(1)).Return(false, nil) - testConfigData := &types.ConfigData{ C1: 1, C2: 1, } - mockEpochState1.EXPECT().GetConfigData(uint64(1), nil).Return(testConfigData, nil) - testLatestConfigData := &types.ConfigData{ C1: 1, C2: 2, } - mockEpochState2.EXPECT().GetLatestConfigData().Return(testLatestConfigData, nil) - testEpochDataEpoch0 := &types.EpochData{ Randomness: [32]byte{9}, Authorities: []types.Authority{*authority}, } - mockEpochState0.EXPECT().GetLatestEpochData().Return(testEpochDataEpoch0, nil) - mockEpochState0.EXPECT().GetLatestConfigData().Return(testConfigData, nil) - - bs0 := &Service{ - authority: true, - keypair: kp, - epochState: mockEpochState0, - blockState: mockBlockState, - } - - bs1 := &Service{ - authority: true, - keypair: kp, - epochState: mockEpochState1, - blockState: mockBlockState, - } - - bs2 := &Service{ - authority: true, - keypair: kp, - epochState: mockEpochState2, - blockState: mockBlockState, - } - threshold0, err := CalculateThreshold(testConfigData.C1, testConfigData.C2, 1) require.NoError(t, err) @@ -147,16 +100,27 @@ func TestBabeService_getEpochDataAndStartSlot(t *testing.T) { require.NoError(t, err) cases := []struct { + service func(*gomock.Controller) *Service name string - service *Service epoch uint64 expected *epochData expectedStartSlot uint64 }{ { - name: "should get epoch data for epoch 0", - service: bs0, - epoch: 0, + name: "should get epoch data for epoch 0", + service: func(ctrl *gomock.Controller) *Service { + mockEpochState := NewMockEpochState(ctrl) + + mockEpochState.EXPECT().GetLatestEpochData().Return(testEpochDataEpoch0, nil) + mockEpochState.EXPECT().GetLatestConfigData().Return(testConfigData, nil) + + return &Service{ + authority: true, + keypair: kp, + epochState: mockEpochState, + } + }, + epoch: 0, expected: &epochData{ randomness: testEpochDataEpoch0.Randomness, authorities: testEpochDataEpoch0.Authorities, @@ -166,9 +130,20 @@ func TestBabeService_getEpochDataAndStartSlot(t *testing.T) { expectedStartSlot: 1, }, { - name: "should get epoch data for epoch 1 with config data from epoch 1", - service: bs1, - epoch: 1, + name: "should get epoch data for epoch 1 with config data from epoch 1", + service: func(ctrl *gomock.Controller) *Service { + mockEpochState := NewMockEpochState(ctrl) + + mockEpochState.EXPECT().GetEpochData(uint64(1), nil).Return(testEpochData, nil) + mockEpochState.EXPECT().GetConfigData(uint64(1), nil).Return(testConfigData, nil) + + return &Service{ + authority: true, + keypair: kp, + epochState: mockEpochState, + } + }, + epoch: 1, expected: &epochData{ randomness: testEpochData.Randomness, authorities: testEpochData.Authorities, @@ -178,9 +153,20 @@ func TestBabeService_getEpochDataAndStartSlot(t *testing.T) { expectedStartSlot: 201, }, { - name: "should get epoch data for epoch 1 and config data for epoch 0", - service: bs2, - epoch: 1, + name: "should get epoch data for epoch 1 and config data for epoch 0", + service: func(ctrl *gomock.Controller) *Service { + mockEpochState := NewMockEpochState(ctrl) + + mockEpochState.EXPECT().GetEpochData(uint64(1), nil).Return(testEpochData, nil) + mockEpochState.EXPECT().GetConfigData(uint64(1), nil).Return(testLatestConfigData, nil) + + return &Service{ + authority: true, + keypair: kp, + epochState: mockEpochState, + } + }, + epoch: 1, expected: &epochData{ randomness: testEpochData.Randomness, authorities: testEpochData.Authorities, @@ -191,10 +177,15 @@ func TestBabeService_getEpochDataAndStartSlot(t *testing.T) { }, } - for _, tc := range cases { - res, startSlot, err := tc.service.getEpochDataAndStartSlot(tc.epoch) - require.NoError(t, err) - require.Equal(t, tc.expected, res) - require.Equal(t, tc.expectedStartSlot, startSlot) + for _, tt := range cases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + service := tt.service(ctrl) + + res, err := service.getEpochData(tt.epoch, nil) + require.NoError(t, err) + require.Equal(t, tt.expected, res) + }) } } diff --git a/lib/babe/errors.go b/lib/babe/errors.go index 527a0a8915..b851f2c05a 100644 --- a/lib/babe/errors.go +++ b/lib/babe/errors.go @@ -69,12 +69,10 @@ var ( errNilStorageState = errors.New("storage state is nil") errNilParentHeader = errors.New("parent header is nil") errInvalidResult = errors.New("invalid error value") - errNoEpochData = errors.New("no epoch data found for upcoming epoch") errFirstBlockTimeout = errors.New("timed out waiting for first block") errChannelClosed = errors.New("block notifier channel was closed") errOverPrimarySlotThreshold = errors.New("cannot claim slot, over primary threshold") errNotOurTurnToPropose = errors.New("cannot claim slot, not our turn to propose a block") - errNoConfigData = errors.New("cannot find ConfigData for epoch") errGetEpochData = errors.New("get epochData error") errFailedFinalisation = errors.New("failed to check finalisation") errMissingDigest = errors.New("chain head missing digest") diff --git a/lib/babe/mock_state_test.go b/lib/babe/mock_state_test.go index f5bdfd2d91..38815094a9 100644 --- a/lib/babe/mock_state_test.go +++ b/lib/babe/mock_state_test.go @@ -678,36 +678,6 @@ func (mr *MockEpochStateMockRecorder) GetStartSlotForEpoch(arg0 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStartSlotForEpoch", reflect.TypeOf((*MockEpochState)(nil).GetStartSlotForEpoch), arg0) } -// HasConfigData mocks base method. -func (m *MockEpochState) HasConfigData(arg0 uint64) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasConfigData", arg0) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// HasConfigData indicates an expected call of HasConfigData. -func (mr *MockEpochStateMockRecorder) HasConfigData(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasConfigData", reflect.TypeOf((*MockEpochState)(nil).HasConfigData), arg0) -} - -// HasEpochData mocks base method. -func (m *MockEpochState) HasEpochData(arg0 uint64) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasEpochData", arg0) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// HasEpochData indicates an expected call of HasEpochData. -func (mr *MockEpochStateMockRecorder) HasEpochData(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasEpochData", reflect.TypeOf((*MockEpochState)(nil).HasEpochData), arg0) -} - // SetCurrentEpoch mocks base method. func (m *MockEpochState) SetCurrentEpoch(arg0 uint64) error { m.ctrl.T.Helper() diff --git a/lib/babe/state.go b/lib/babe/state.go index 5b168e850a..897f06bd35 100644 --- a/lib/babe/state.go +++ b/lib/babe/state.go @@ -66,12 +66,9 @@ type EpochState interface { GetCurrentEpoch() (uint64, error) SetEpochData(uint64, *types.EpochData) error - HasEpochData(epoch uint64) (bool, error) - GetEpochData(epoch uint64, header *types.Header) (*types.EpochData, error) GetConfigData(epoch uint64, header *types.Header) (*types.ConfigData, error) - HasConfigData(epoch uint64) (bool, error) GetLatestConfigData() (*types.ConfigData, error) GetStartSlotForEpoch(epoch uint64) (uint64, error) GetEpochForBlock(header *types.Header) (uint64, error) diff --git a/lib/babe/verify.go b/lib/babe/verify.go index 671114ef2d..50b59c2211 100644 --- a/lib/babe/verify.go +++ b/lib/babe/verify.go @@ -202,7 +202,7 @@ func (v *VerificationManager) getVerifierInfo(epoch uint64, header *types.Header return nil, fmt.Errorf("failed to get epoch data for epoch %d: %w", epoch, err) } - configData, err := v.getConfigData(epoch, header) + configData, err := v.epochState.GetConfigData(epoch, header) if err != nil { return nil, fmt.Errorf("failed to get config data: %w", err) } @@ -220,21 +220,6 @@ func (v *VerificationManager) getVerifierInfo(epoch uint64, header *types.Header }, nil } -func (v *VerificationManager) getConfigData(epoch uint64, header *types.Header) (*types.ConfigData, error) { - for i := int(epoch); i >= 0; i-- { - has, err := v.epochState.HasConfigData(uint64(i)) - if err != nil { - return nil, err - } else if !has { - continue - } - - return v.epochState.GetConfigData(uint64(i), header) - } - - return nil, errNoConfigData -} - // verifier is a BABE verifier for a specific authority set, randomness, and threshold type verifier struct { blockState BlockState diff --git a/lib/babe/verify_test.go b/lib/babe/verify_test.go index f6c106ca7c..d8c8432ebf 100644 --- a/lib/babe/verify_test.go +++ b/lib/babe/verify_test.go @@ -8,6 +8,7 @@ import ( "fmt" "testing" + "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto/sr25519" @@ -730,63 +731,6 @@ func Test_verifier_verifyAuthorshipRight(t *testing.T) { } } -func TestVerificationManager_getConfigData(t *testing.T) { - ctrl := gomock.NewController(t) - mockBlockState := NewMockBlockState(ctrl) - mockEpochStateEmpty := NewMockEpochState(ctrl) - mockEpochStateHasErr := NewMockEpochState(ctrl) - mockEpochStateGetErr := NewMockEpochState(ctrl) - - testHeader := types.NewEmptyHeader() - - mockEpochStateEmpty.EXPECT().HasConfigData(uint64(0)).Return(false, nil) - mockEpochStateHasErr.EXPECT().HasConfigData(uint64(0)).Return(false, errNoConfigData) - mockEpochStateGetErr.EXPECT().HasConfigData(uint64(0)).Return(true, nil) - mockEpochStateGetErr.EXPECT().GetConfigData(uint64(0), testHeader).Return(nil, errNoConfigData) - - vm0, err := NewVerificationManager(mockBlockState, mockEpochStateEmpty) - assert.NoError(t, err) - vm1, err := NewVerificationManager(mockBlockState, mockEpochStateHasErr) - assert.NoError(t, err) - vm2, err := NewVerificationManager(mockBlockState, mockEpochStateGetErr) - assert.NoError(t, err) - tests := []struct { - name string - vm *VerificationManager - epoch uint64 - exp *types.ConfigData - expErr error - }{ - { - name: "cant find ConfigData", - vm: vm0, - expErr: errNoConfigData, - }, - { - name: "hasConfigData error", - vm: vm1, - expErr: errNoConfigData, - }, - { - name: "getConfigData error", - vm: vm2, - expErr: errNoConfigData, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - v := tt.vm - res, err := v.getConfigData(tt.epoch, testHeader) - if tt.expErr != nil { - assert.EqualError(t, err, tt.expErr.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tt.exp, res) - }) - } -} - func TestVerificationManager_getVerifierInfo(t *testing.T) { ctrl := gomock.NewController(t) mockBlockState := NewMockBlockState(ctrl) @@ -797,13 +741,12 @@ func TestVerificationManager_getVerifierInfo(t *testing.T) { testHeader := types.NewEmptyHeader() - mockEpochStateGetErr.EXPECT().GetEpochData(uint64(0), testHeader).Return(nil, errNoConfigData) + mockEpochStateGetErr.EXPECT().GetEpochData(uint64(0), testHeader).Return(nil, state.ErrEpochNotInMemory) mockEpochStateHasErr.EXPECT().GetEpochData(uint64(0), testHeader).Return(&types.EpochData{}, nil) - mockEpochStateHasErr.EXPECT().HasConfigData(uint64(0)).Return(false, errNoConfigData) + mockEpochStateHasErr.EXPECT().GetConfigData(uint64(0), testHeader).Return(&types.ConfigData{}, state.ErrConfigNotFound) mockEpochStateThresholdErr.EXPECT().GetEpochData(uint64(0), testHeader).Return(&types.EpochData{}, nil) - mockEpochStateThresholdErr.EXPECT().HasConfigData(uint64(0)).Return(true, nil) mockEpochStateThresholdErr.EXPECT().GetConfigData(uint64(0), testHeader). Return(&types.ConfigData{ C1: 3, @@ -811,7 +754,6 @@ func TestVerificationManager_getVerifierInfo(t *testing.T) { }, nil) mockEpochStateOk.EXPECT().GetEpochData(uint64(0), testHeader).Return(&types.EpochData{}, nil) - mockEpochStateOk.EXPECT().HasConfigData(uint64(0)).Return(true, nil) mockEpochStateOk.EXPECT().GetConfigData(uint64(0), testHeader). Return(&types.ConfigData{ C1: 1, @@ -837,12 +779,12 @@ func TestVerificationManager_getVerifierInfo(t *testing.T) { { name: "getEpochData error", vm: vm0, - expErr: fmt.Errorf("failed to get epoch data for epoch %d: %w", 0, errNoConfigData), + expErr: fmt.Errorf("failed to get epoch data for epoch %d: %w", 0, state.ErrEpochNotInMemory), }, { name: "getConfigData error", vm: vm1, - expErr: fmt.Errorf("failed to get config data: %w", errNoConfigData), + expErr: fmt.Errorf("failed to get config data: %w", state.ErrConfigNotFound), }, { name: "calculate threshold error",