diff --git a/cl/beacon/synced_data/interface.go b/cl/beacon/synced_data/interface.go index 356f51405f7..ee35df33aad 100644 --- a/cl/beacon/synced_data/interface.go +++ b/cl/beacon/synced_data/interface.go @@ -6,6 +6,7 @@ import "github.com/ledgerwatch/erigon/cl/phase1/core/state" type SyncedData interface { OnHeadState(newState *state.CachingBeaconState) (err error) HeadState() *state.CachingBeaconState + HeadStateReader() state.BeaconStateReader Syncing() bool HeadSlot() uint64 } diff --git a/cl/beacon/synced_data/mock_services/synced_data_mock.go b/cl/beacon/synced_data/mock_services/synced_data_mock.go index 672add792eb..33f8e1b7cc7 100644 --- a/cl/beacon/synced_data/mock_services/synced_data_mock.go +++ b/cl/beacon/synced_data/mock_services/synced_data_mock.go @@ -67,6 +67,20 @@ func (mr *MockSyncedDataMockRecorder) HeadState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadState", reflect.TypeOf((*MockSyncedData)(nil).HeadState)) } +// HeadStateReader mocks base method. +func (m *MockSyncedData) HeadStateReader() state.BeaconStateReader { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HeadStateReader") + ret0, _ := ret[0].(state.BeaconStateReader) + return ret0 +} + +// HeadStateReader indicates an expected call of HeadStateReader. +func (mr *MockSyncedDataMockRecorder) HeadStateReader() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadStateReader", reflect.TypeOf((*MockSyncedData)(nil).HeadStateReader)) +} + // OnHeadState mocks base method. func (m *MockSyncedData) OnHeadState(arg0 *state.CachingBeaconState) error { m.ctrl.T.Helper() diff --git a/cl/beacon/synced_data/synced_data.go b/cl/beacon/synced_data/synced_data.go index 7b633af1436..6d982b4df3e 100644 --- a/cl/beacon/synced_data/synced_data.go +++ b/cl/beacon/synced_data/synced_data.go @@ -43,6 +43,10 @@ func (s *SyncedDataManager) HeadState() *state.CachingBeaconState { return nil } +func (s *SyncedDataManager) HeadStateReader() state.BeaconStateReader { + return s.HeadState() +} + func (s *SyncedDataManager) Syncing() bool { if !s.enabled { return false diff --git a/cl/phase1/core/state/interface.go b/cl/phase1/core/state/interface.go new file mode 100644 index 00000000000..8da7dee7897 --- /dev/null +++ b/cl/phase1/core/state/interface.go @@ -0,0 +1,12 @@ +package state + +import libcommon "github.com/ledgerwatch/erigon-lib/common" + +// BeaconStateReader is an interface for reading the beacon state. +// +//go:generate mockgen -destination=./mock_services/beacon_state_reader.go -package=mock_services . BeaconStateReader +type BeaconStateReader interface { + ValidatorPublicKey(index int) (libcommon.Bytes48, error) + GetDomain(domainType [4]byte, epoch uint64) ([]byte, error) + CommitteeCount(epoch uint64) uint64 +} diff --git a/cl/phase1/core/state/mock_services/beacon_state_reader.go b/cl/phase1/core/state/mock_services/beacon_state_reader.go new file mode 100644 index 00000000000..4efc825d91d --- /dev/null +++ b/cl/phase1/core/state/mock_services/beacon_state_reader.go @@ -0,0 +1,84 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ledgerwatch/erigon/cl/phase1/core/state (interfaces: BeaconStateReader) +// +// Generated by this command: +// +// mockgen -destination=./mock_services/beacon_state_reader.go -package=mock_services . BeaconStateReader +// + +// Package mock_services is a generated GoMock package. +package mock_services + +import ( + reflect "reflect" + + common "github.com/ledgerwatch/erigon-lib/common" + gomock "go.uber.org/mock/gomock" +) + +// MockBeaconStateReader is a mock of BeaconStateReader interface. +type MockBeaconStateReader struct { + ctrl *gomock.Controller + recorder *MockBeaconStateReaderMockRecorder +} + +// MockBeaconStateReaderMockRecorder is the mock recorder for MockBeaconStateReader. +type MockBeaconStateReaderMockRecorder struct { + mock *MockBeaconStateReader +} + +// NewMockBeaconStateReader creates a new mock instance. +func NewMockBeaconStateReader(ctrl *gomock.Controller) *MockBeaconStateReader { + mock := &MockBeaconStateReader{ctrl: ctrl} + mock.recorder = &MockBeaconStateReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBeaconStateReader) EXPECT() *MockBeaconStateReaderMockRecorder { + return m.recorder +} + +// CommitteeCount mocks base method. +func (m *MockBeaconStateReader) CommitteeCount(arg0 uint64) uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CommitteeCount", arg0) + ret0, _ := ret[0].(uint64) + return ret0 +} + +// CommitteeCount indicates an expected call of CommitteeCount. +func (mr *MockBeaconStateReaderMockRecorder) CommitteeCount(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitteeCount", reflect.TypeOf((*MockBeaconStateReader)(nil).CommitteeCount), arg0) +} + +// GetDomain mocks base method. +func (m *MockBeaconStateReader) GetDomain(arg0 [4]byte, arg1 uint64) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDomain", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDomain indicates an expected call of GetDomain. +func (mr *MockBeaconStateReaderMockRecorder) GetDomain(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDomain", reflect.TypeOf((*MockBeaconStateReader)(nil).GetDomain), arg0, arg1) +} + +// ValidatorPublicKey mocks base method. +func (m *MockBeaconStateReader) ValidatorPublicKey(arg0 int) (common.Bytes48, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidatorPublicKey", arg0) + ret0, _ := ret[0].(common.Bytes48) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ValidatorPublicKey indicates an expected call of ValidatorPublicKey. +func (mr *MockBeaconStateReaderMockRecorder) ValidatorPublicKey(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidatorPublicKey", reflect.TypeOf((*MockBeaconStateReader)(nil).ValidatorPublicKey), arg0) +} diff --git a/cl/phase1/network/services/attestation_service.go b/cl/phase1/network/services/attestation_service.go index f13439ae926..3e8cc3ce101 100644 --- a/cl/phase1/network/services/attestation_service.go +++ b/cl/phase1/network/services/attestation_service.go @@ -20,6 +20,8 @@ import ( var ( computeSubnetForAttestation = subnets.ComputeSubnetForAttestation computeCommitteeCountPerSlot = subnets.ComputeCommitteeCountPerSlot + computeSigningRoot = fork.ComputeSigningRoot + blsVerify = bls.Verify ) type attestationService struct { @@ -60,7 +62,7 @@ func (s *attestationService) ProcessMessage(ctx context.Context, subnet *uint64, committeeIndex = att.AttestantionData().CommitteeIndex() targetEpoch = att.AttestantionData().Target().Epoch() ) - headState := s.syncedDataManager.HeadState() + headState := s.syncedDataManager.HeadStateReader() if headState == nil { return ErrIgnore } @@ -148,11 +150,11 @@ func (s *attestationService) ProcessMessage(ctx context.Context, subnet *uint64, if err != nil { return fmt.Errorf("unable to get the domain: %v", err) } - signingRoot, err := fork.ComputeSigningRoot(att.AttestantionData(), domain) + signingRoot, err := computeSigningRoot(att.AttestantionData(), domain) if err != nil { return fmt.Errorf("unable to get signing root: %v", err) } - if valid, err := bls.Verify(signature[:], signingRoot[:], pubKey[:]); err != nil { + if valid, err := blsVerify(signature[:], signingRoot[:], pubKey[:]); err != nil { return err } else if !valid { return fmt.Errorf("invalid signature") diff --git a/cl/phase1/network/services/attestation_service_test.go b/cl/phase1/network/services/attestation_service_test.go index 58425d72ea0..ff15fd444b3 100644 --- a/cl/phase1/network/services/attestation_service_test.go +++ b/cl/phase1/network/services/attestation_service_test.go @@ -6,11 +6,13 @@ import ( "testing" "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/types/ssz" mockSync "github.com/ledgerwatch/erigon/cl/beacon/synced_data/mock_services" "github.com/ledgerwatch/erigon/cl/clparams" "github.com/ledgerwatch/erigon/cl/cltypes" "github.com/ledgerwatch/erigon/cl/cltypes/solid" "github.com/ledgerwatch/erigon/cl/phase1/core/state" + mockState "github.com/ledgerwatch/erigon/cl/phase1/core/state/mock_services" "github.com/ledgerwatch/erigon/cl/phase1/forkchoice" "github.com/ledgerwatch/erigon/cl/utils/eth_clock" mockCommittee "github.com/ledgerwatch/erigon/cl/validator/committee_subscription/mock_services" @@ -38,20 +40,25 @@ type attestationTestSuite struct { gomockCtrl *gomock.Controller mockForkChoice *forkchoice.ForkChoiceStorageMock syncedData *mockSync.MockSyncedData + beaconStateReader *mockState.MockBeaconStateReader committeeSubscibe *mockCommittee.MockCommitteeSubscribe ethClock *eth_clock.MockEthereumClock attService AttestationService + beaconConfig *clparams.BeaconChainConfig } func (t *attestationTestSuite) SetupTest() { t.gomockCtrl = gomock.NewController(t.T()) t.mockForkChoice = &forkchoice.ForkChoiceStorageMock{} t.syncedData = mockSync.NewMockSyncedData(t.gomockCtrl) + t.beaconStateReader = mockState.NewMockBeaconStateReader(t.gomockCtrl) t.committeeSubscibe = mockCommittee.NewMockCommitteeSubscribe(t.gomockCtrl) t.ethClock = eth_clock.NewMockEthereumClock(t.gomockCtrl) - beaconConfig := &clparams.BeaconChainConfig{SlotsPerEpoch: mockSlotsPerEpoch} + t.beaconConfig = &clparams.BeaconChainConfig{SlotsPerEpoch: mockSlotsPerEpoch} netConfig := &clparams.NetworkConfig{} - t.attService = NewAttestationService(t.mockForkChoice, t.committeeSubscibe, t.ethClock, t.syncedData, beaconConfig, netConfig) + computeSigningRoot = func(obj ssz.HashableSSZ, domain []byte) ([32]byte, error) { return [32]byte{}, nil } + blsVerify = func(sig []byte, msg []byte, pubKeys []byte) (bool, error) { return true, nil } + t.attService = NewAttestationService(t.mockForkChoice, t.committeeSubscibe, t.ethClock, t.syncedData, t.beaconConfig, netConfig) } func (t *attestationTestSuite) TearDownTest() { @@ -73,8 +80,8 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "Test attestation with committee index out of range", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 1 } }, @@ -88,8 +95,8 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "Test attestation with wrong subnet", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 5 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { @@ -106,8 +113,8 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "Test attestation with wrong slot (current_slot < slot)", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 5 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { @@ -125,8 +132,8 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "Attestation is aggregated", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 5 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { @@ -148,8 +155,8 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "Attestation is empty", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 5 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { @@ -169,16 +176,51 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { wantErr: true, }, { - name: "block header not found", + name: "invalid signature", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 5 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { return 1 } t.ethClock.EXPECT().GetCurrentSlot().Return(mockSlot).Times(1) + t.beaconStateReader.EXPECT().ValidatorPublicKey(gomock.Any()).Return(common.Bytes48{}, nil).Times(1) + t.beaconStateReader.EXPECT().GetDomain(t.beaconConfig.DomainBeaconAttester, att.AttestantionData().Target().Epoch()).Return([]byte{}, nil).Times(1) + computeSigningRoot = func(obj ssz.HashableSSZ, domain []byte) ([32]byte, error) { + return [32]byte{}, nil + } + blsVerify = func(sig []byte, msg []byte, pubKeys []byte) (bool, error) { + return false, nil + } + }, + args: args{ + ctx: context.Background(), + subnet: uint64Ptr(1), + msg: att, + }, + wantErr: true, + }, + { + name: "block header not found", + mock: func() { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { + return 8 + } + computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { + return 1 + } + t.ethClock.EXPECT().GetCurrentSlot().Return(mockSlot).Times(1) + t.beaconStateReader.EXPECT().ValidatorPublicKey(gomock.Any()).Return(common.Bytes48{}, nil).Times(1) + t.beaconStateReader.EXPECT().GetDomain(t.beaconConfig.DomainBeaconAttester, att.AttestantionData().Target().Epoch()).Return([]byte{}, nil).Times(1) + computeSigningRoot = func(obj ssz.HashableSSZ, domain []byte) ([32]byte, error) { + return [32]byte{}, nil + } + blsVerify = func(sig []byte, msg []byte, pubKeys []byte) (bool, error) { + return true, nil + } }, args: args{ ctx: context.Background(), @@ -190,14 +232,22 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "invalid target block", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 8 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { return 1 } t.ethClock.EXPECT().GetCurrentSlot().Return(mockSlot).Times(1) + t.beaconStateReader.EXPECT().ValidatorPublicKey(gomock.Any()).Return(common.Bytes48{}, nil).Times(1) + t.beaconStateReader.EXPECT().GetDomain(t.beaconConfig.DomainBeaconAttester, att.AttestantionData().Target().Epoch()).Return([]byte{}, nil).Times(1) + computeSigningRoot = func(obj ssz.HashableSSZ, domain []byte) ([32]byte, error) { + return [32]byte{}, nil + } + blsVerify = func(sig []byte, msg []byte, pubKeys []byte) (bool, error) { + return true, nil + } t.mockForkChoice.Headers = map[common.Hash]*cltypes.BeaconBlockHeader{ att.AttestantionData().BeaconBlockRoot(): {}, // wrong block root } @@ -212,14 +262,22 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "invalid finality checkpoint", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 8 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { return 1 } t.ethClock.EXPECT().GetCurrentSlot().Return(mockSlot).Times(1) + t.beaconStateReader.EXPECT().ValidatorPublicKey(gomock.Any()).Return(common.Bytes48{}, nil).Times(1) + t.beaconStateReader.EXPECT().GetDomain(t.beaconConfig.DomainBeaconAttester, att.AttestantionData().Target().Epoch()).Return([]byte{}, nil).Times(1) + computeSigningRoot = func(obj ssz.HashableSSZ, domain []byte) ([32]byte, error) { + return [32]byte{}, nil + } + blsVerify = func(sig []byte, msg []byte, pubKeys []byte) (bool, error) { + return true, nil + } t.mockForkChoice.Headers = map[common.Hash]*cltypes.BeaconBlockHeader{ att.AttestantionData().BeaconBlockRoot(): {}, } @@ -242,14 +300,22 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { { name: "success", mock: func() { - t.syncedData.EXPECT().HeadState().Return(&state.CachingBeaconState{}).Times(1) - computeCommitteeCountPerSlot = func(_ *state.CachingBeaconState, _, _ uint64) uint64 { + t.syncedData.EXPECT().HeadStateReader().Return(t.beaconStateReader).Times(1) + computeCommitteeCountPerSlot = func(_ state.BeaconStateReader, _, _ uint64) uint64 { return 8 } computeSubnetForAttestation = func(_, _, _, _, _ uint64) uint64 { return 1 } t.ethClock.EXPECT().GetCurrentSlot().Return(mockSlot).Times(1) + t.beaconStateReader.EXPECT().ValidatorPublicKey(gomock.Any()).Return(common.Bytes48{}, nil).Times(1) + t.beaconStateReader.EXPECT().GetDomain(t.beaconConfig.DomainBeaconAttester, att.AttestantionData().Target().Epoch()).Return([]byte{}, nil).Times(1) + computeSigningRoot = func(obj ssz.HashableSSZ, domain []byte) ([32]byte, error) { + return [32]byte{}, nil + } + blsVerify = func(sig []byte, msg []byte, pubKeys []byte) (bool, error) { + return true, nil + } t.mockForkChoice.Headers = map[common.Hash]*cltypes.BeaconBlockHeader{ att.AttestantionData().BeaconBlockRoot(): {}, } @@ -278,8 +344,8 @@ func (t *attestationTestSuite) TestAttestationProcessMessage() { tt.mock() err := t.attService.ProcessMessage(tt.args.ctx, tt.args.subnet, tt.args.msg) if tt.wantErr { - log.Printf("%v", err) - t.Require().Error(err) + log.Printf("err msg: %v", err) + t.Require().Error(err, err.Error()) } else { t.Require().NoError(err) } diff --git a/cl/phase1/network/subnets/subnets.go b/cl/phase1/network/subnets/subnets.go index 17a9a515724..1fba86c09ee 100644 --- a/cl/phase1/network/subnets/subnets.go +++ b/cl/phase1/network/subnets/subnets.go @@ -64,7 +64,7 @@ func ComputeSubnetForAttestation(committeePerSlot, slot, committeeIndex, slotsPe return (committeesSinceEpochStart + committeeIndex) % attSubnetCount } -func ComputeCommitteeCountPerSlot(s *state.CachingBeaconState, slot uint64, slotsPerEpoch uint64) uint64 { +func ComputeCommitteeCountPerSlot(s state.BeaconStateReader, slot uint64, slotsPerEpoch uint64) uint64 { epoch := slot / slotsPerEpoch return s.CommitteeCount(epoch) }