diff --git a/chains/manager.go b/chains/manager.go index b8b051c0a05..f92ef51d3b1 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -762,7 +762,7 @@ func (m *manager) createAvalancheChain( if err != nil { return nil, fmt.Errorf("error creating peer tracker: %w", err) } - vdrs.RegisterCallbackListener(ctx.SubnetID, connectedValidators) + vdrs.RegisterSetCallbackListener(ctx.SubnetID, connectedValidators) peerTracker, err := p2p.NewPeerTracker( ctx.Log, @@ -794,7 +794,7 @@ func (m *manager) createAvalancheChain( connectedBeacons := tracker.NewPeers() startupTracker := tracker.NewStartup(connectedBeacons, (3*bootstrapWeight+3)/4) - vdrs.RegisterCallbackListener(ctx.SubnetID, startupTracker) + vdrs.RegisterSetCallbackListener(ctx.SubnetID, startupTracker) snowGetHandler, err := snowgetter.New( vmWrappingProposerVM, @@ -1107,7 +1107,7 @@ func (m *manager) createSnowmanChain( if err != nil { return nil, fmt.Errorf("error creating peer tracker: %w", err) } - vdrs.RegisterCallbackListener(ctx.SubnetID, connectedValidators) + vdrs.RegisterSetCallbackListener(ctx.SubnetID, connectedValidators) peerTracker, err := p2p.NewPeerTracker( ctx.Log, @@ -1139,7 +1139,7 @@ func (m *manager) createSnowmanChain( connectedBeacons := tracker.NewPeers() startupTracker := tracker.NewStartup(connectedBeacons, (3*bootstrapWeight+3)/4) - beacons.RegisterCallbackListener(ctx.SubnetID, startupTracker) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startupTracker) snowGetHandler, err := snowgetter.New( vm, diff --git a/network/network.go b/network/network.go index 4e462e933a4..506e5190858 100644 --- a/network/network.go +++ b/network/network.go @@ -239,7 +239,7 @@ func NewNetwork( if err != nil { return nil, fmt.Errorf("initializing ip tracker failed with: %w", err) } - config.Validators.RegisterCallbackListener(constants.PrimaryNetworkID, ipTracker) + config.Validators.RegisterSetCallbackListener(constants.PrimaryNetworkID, ipTracker) // Track all default bootstrappers to ensure their current IPs are gossiped // like validator IPs. diff --git a/node/overridden_manager.go b/node/overridden_manager.go index 4dd49b65eab..484fe05da75 100644 --- a/node/overridden_manager.go +++ b/node/overridden_manager.go @@ -72,8 +72,12 @@ func (o *overriddenManager) GetMap(ids.ID) map[ids.NodeID]*validators.GetValidat return o.manager.GetMap(o.subnetID) } -func (o *overriddenManager) RegisterCallbackListener(_ ids.ID, listener validators.SetCallbackListener) { - o.manager.RegisterCallbackListener(o.subnetID, listener) +func (o *overriddenManager) RegisterCallbackListener(listener validators.ManagerCallbackListener) { + o.manager.RegisterCallbackListener(listener) +} + +func (o *overriddenManager) RegisterSetCallbackListener(_ ids.ID, listener validators.SetCallbackListener) { + o.manager.RegisterSetCallbackListener(o.subnetID, listener) } func (o *overriddenManager) String() string { diff --git a/snow/engine/avalanche/bootstrap/bootstrapper_test.go b/snow/engine/avalanche/bootstrap/bootstrapper_test.go index 8d8c107383e..812a343dd17 100644 --- a/snow/engine/avalanche/bootstrap/bootstrapper_test.go +++ b/snow/engine/avalanche/bootstrap/bootstrapper_test.go @@ -81,7 +81,7 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *common.SenderTest, *vertex.Te totalWeight, err := vdrs.TotalWeight(constants.PrimaryNetworkID) require.NoError(err) startupTracker := tracker.NewStartup(peerTracker, totalWeight/2+1) - vdrs.RegisterCallbackListener(constants.PrimaryNetworkID, startupTracker) + vdrs.RegisterSetCallbackListener(constants.PrimaryNetworkID, startupTracker) avaGetHandler, err := getter.New(manager, sender, ctx.Log, time.Second, 2000, ctx.AvalancheRegisterer) require.NoError(err) diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index 66b71d69bbb..fd7fc0f1283 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -73,7 +73,7 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *common.SenderTest, *block.Tes totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) require.NoError(err) startupTracker := tracker.NewStartup(tracker.NewPeers(), totalWeight/2+1) - vdrs.RegisterCallbackListener(ctx.SubnetID, startupTracker) + vdrs.RegisterSetCallbackListener(ctx.SubnetID, startupTracker) require.NoError(startupTracker.Connected(context.Background(), peer, version.CurrentApp)) @@ -126,7 +126,7 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { startupAlpha := alpha startupTracker := tracker.NewStartup(tracker.NewPeers(), startupAlpha) - peers.RegisterCallbackListener(ctx.SubnetID, startupTracker) + peers.RegisterSetCallbackListener(ctx.SubnetID, startupTracker) snowGetHandler, err := getter.New(vm, sender, ctx.Log, time.Second, 2000, ctx.Registerer) require.NoError(err) @@ -650,7 +650,7 @@ func TestBootstrapNoParseOnNew(t *testing.T) { totalWeight, err := peers.TotalWeight(ctx.SubnetID) require.NoError(err) startupTracker := tracker.NewStartup(tracker.NewPeers(), totalWeight/2+1) - peers.RegisterCallbackListener(ctx.SubnetID, startupTracker) + peers.RegisterSetCallbackListener(ctx.SubnetID, startupTracker) require.NoError(startupTracker.Connected(context.Background(), peer, version.CurrentApp)) snowGetHandler, err := getter.New(vm, sender, ctx.Log, time.Second, 2000, ctx.Registerer) diff --git a/snow/engine/snowman/syncer/state_syncer_test.go b/snow/engine/snowman/syncer/state_syncer_test.go index 1ec1e67021b..2ee745bb5c6 100644 --- a/snow/engine/snowman/syncer/state_syncer_test.go +++ b/snow/engine/snowman/syncer/state_syncer_test.go @@ -116,7 +116,7 @@ func TestStateSyncingStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, _, sender := buildTestsObjects(t, ctx, startup, beacons, alpha) @@ -159,7 +159,7 @@ func TestStateSyncLocalSummaryIsIncludedAmongFrontiersIfAvailable(t *testing.T) peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, _ := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -197,7 +197,7 @@ func TestStateSyncNotFoundOngoingSummaryIsNotIncludedAmongFrontiers(t *testing.T peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, _ := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -228,7 +228,7 @@ func TestBeaconsAreReachedForFrontiersUponStartup(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, _, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -267,7 +267,7 @@ func TestUnRequestedStateSummaryFrontiersAreDropped(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -357,7 +357,7 @@ func TestMalformedStateSummaryFrontiersAreDropped(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -426,7 +426,7 @@ func TestLateResponsesFromUnresponsiveFrontiersAreNotRecorded(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -509,7 +509,7 @@ func TestStateSyncIsRestartedIfTooManyFrontierSeedersTimeout(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -598,7 +598,7 @@ func TestVoteRequestsAreSentAsAllFrontierBeaconsResponded(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -669,7 +669,7 @@ func TestUnRequestedVotesAreDropped(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -786,7 +786,7 @@ func TestVotesForUnknownSummariesAreDropped(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) @@ -890,7 +890,7 @@ func TestStateSummaryIsPassedToVMAsMajorityOfVotesIsCastedForIt(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, alpha) @@ -1035,7 +1035,7 @@ func TestVotingIsRestartedIfMajorityIsNotReachedDueToTimeouts(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, alpha) @@ -1141,7 +1141,7 @@ func TestStateSyncIsStoppedIfEnoughVotesAreCastedWithNoClearMajority(t *testing. peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, fullVM, sender := buildTestsObjects(t, ctx, startup, beacons, alpha) @@ -1286,7 +1286,7 @@ func TestStateSyncIsDoneOnceVMNotifies(t *testing.T) { peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) syncer, _, _ := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go index 43a00dbcb9c..07cffdca565 100644 --- a/snow/engine/snowman/transitive.go +++ b/snow/engine/snowman/transitive.go @@ -109,7 +109,7 @@ func New(config Config) (*Transitive, error) { } acceptedFrontiers := tracker.NewAccepted() - config.Validators.RegisterCallbackListener(config.Ctx.SubnetID, acceptedFrontiers) + config.Validators.RegisterSetCallbackListener(config.Ctx.SubnetID, acceptedFrontiers) factory := poll.NewEarlyTermNoTraversalFactory( config.Params.AlphaPreference, diff --git a/snow/engine/snowman/transitive_test.go b/snow/engine/snowman/transitive_test.go index bd5966f797b..f7e74412075 100644 --- a/snow/engine/snowman/transitive_test.go +++ b/snow/engine/snowman/transitive_test.go @@ -43,7 +43,7 @@ func setup(t *testing.T, engCfg Config) (ids.NodeID, validators.Manager, *common require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) require.NoError(engCfg.ConnectedValidators.Connected(context.Background(), vdr, version.CurrentApp)) - vals.RegisterCallbackListener(engCfg.Ctx.SubnetID, engCfg.ConnectedValidators) + vals.RegisterSetCallbackListener(engCfg.Ctx.SubnetID, engCfg.ConnectedValidators) sender := &common.SenderTest{T: t} engCfg.Sender = sender diff --git a/snow/networking/handler/health_test.go b/snow/networking/handler/health_test.go index 256613827cf..163332735ea 100644 --- a/snow/networking/handler/health_test.go +++ b/snow/networking/handler/health_test.go @@ -64,7 +64,7 @@ func TestHealthCheckSubnet(t *testing.T) { require.NoError(err) peerTracker := commontracker.NewPeers() - vdrs.RegisterCallbackListener(ctx.SubnetID, peerTracker) + vdrs.RegisterSetCallbackListener(ctx.SubnetID, peerTracker) sb := subnets.New( ctx.NodeID, diff --git a/snow/validators/manager.go b/snow/validators/manager.go index 5844c1e7f18..fcf37934112 100644 --- a/snow/validators/manager.go +++ b/snow/validators/manager.go @@ -24,6 +24,12 @@ var ( ErrMissingValidators = errors.New("missing validators") ) +type ManagerCallbackListener interface { + OnValidatorAdded(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) + OnValidatorRemoved(subnetID ids.ID, nodeID ids.NodeID, weight uint64) + OnValidatorWeightChanged(subnetID ids.ID, nodeID ids.NodeID, oldWeight, newWeight uint64) +} + type SetCallbackListener interface { OnValidatorAdded(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) OnValidatorRemoved(nodeID ids.NodeID, weight uint64) @@ -88,9 +94,13 @@ type Manager interface { // Map of the validators in this subnet GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput - // When a validator's weight changes, or a validator is added/removed, - // this listener is called. - RegisterCallbackListener(subnetID ids.ID, listener SetCallbackListener) + // When a validator is added, removed, or its weight changes, the listener + // will be notified of the event. + RegisterCallbackListener(listener ManagerCallbackListener) + + // When a validator is added, removed, or its weight changes on [subnetID], + // the listener will be notified of the event. + RegisterSetCallbackListener(subnetID ids.ID, listener SetCallbackListener) } // NewManager returns a new, empty manager @@ -105,7 +115,8 @@ type manager struct { // Key: Subnet ID // Value: The validators that validate the subnet - subnetToVdrs map[ids.ID]*vdrSet + subnetToVdrs map[ids.ID]*vdrSet + callbackListeners []ManagerCallbackListener } func (m *manager) AddStaker(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { @@ -118,7 +129,7 @@ func (m *manager) AddStaker(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKe set, exists := m.subnetToVdrs[subnetID] if !exists { - set = newSet() + set = newSet(subnetID, m.callbackListeners) m.subnetToVdrs[subnetID] = set } @@ -264,13 +275,23 @@ func (m *manager) GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput { return set.Map() } -func (m *manager) RegisterCallbackListener(subnetID ids.ID, listener SetCallbackListener) { +func (m *manager) RegisterCallbackListener(listener ManagerCallbackListener) { + m.lock.Lock() + defer m.lock.Unlock() + + m.callbackListeners = append(m.callbackListeners, listener) + for _, set := range m.subnetToVdrs { + set.RegisterManagerCallbackListener(listener) + } +} + +func (m *manager) RegisterSetCallbackListener(subnetID ids.ID, listener SetCallbackListener) { m.lock.Lock() defer m.lock.Unlock() set, exists := m.subnetToVdrs[subnetID] if !exists { - set = newSet() + set = newSet(subnetID, m.callbackListeners) m.subnetToVdrs[subnetID] = set } diff --git a/snow/validators/manager_test.go b/snow/validators/manager_test.go index 781d2e784e1..cf23d49d39b 100644 --- a/snow/validators/manager_test.go +++ b/snow/validators/manager_test.go @@ -17,6 +17,39 @@ import ( safemath "github.com/ava-labs/avalanchego/utils/math" ) +var _ ManagerCallbackListener = (*managerCallbackListener)(nil) + +type managerCallbackListener struct { + t *testing.T + onAdd func(ids.ID, ids.NodeID, *bls.PublicKey, ids.ID, uint64) + onWeight func(ids.ID, ids.NodeID, uint64, uint64) + onRemoved func(ids.ID, ids.NodeID, uint64) +} + +func (c *managerCallbackListener) OnValidatorAdded(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + if c.onAdd != nil { + c.onAdd(subnetID, nodeID, pk, txID, weight) + } else { + c.t.Fail() + } +} + +func (c *managerCallbackListener) OnValidatorRemoved(subnetID ids.ID, nodeID ids.NodeID, weight uint64) { + if c.onRemoved != nil { + c.onRemoved(subnetID, nodeID, weight) + } else { + c.t.Fail() + } +} + +func (c *managerCallbackListener) OnValidatorWeightChanged(subnetID ids.ID, nodeID ids.NodeID, oldWeight, newWeight uint64) { + if c.onWeight != nil { + c.onWeight(subnetID, nodeID, oldWeight, newWeight) + } else { + c.t.Fail() + } +} + func TestAddZeroWeight(t *testing.T) { require := require.New(t) @@ -411,142 +444,292 @@ func TestString(t *testing.T) { func TestAddCallback(t *testing.T) { require := require.New(t) - nodeID0 := ids.BuildTestNodeID([]byte{1}) - sk0, err := bls.NewSecretKey() + expectedSK, err := bls.NewSecretKey() require.NoError(err) - pk0 := bls.PublicFromSecretKey(sk0) - txID0 := ids.GenerateTestID() - weight0 := uint64(1) - m := NewManager() - subnetID := ids.GenerateTestID() - callCount := 0 - m.RegisterCallbackListener(subnetID, &callbackListener{ + var ( + expectedNodeID = ids.GenerateTestNodeID() + expectedPK = bls.PublicFromSecretKey(expectedSK) + expectedTxID = ids.GenerateTestID() + expectedWeight uint64 = 1 + expectedSubnetID0 = ids.GenerateTestID() + expectedSubnetID1 = ids.GenerateTestID() + + m = NewManager() + managerCallCount = 0 + setCallCount = 0 + ) + m.RegisterCallbackListener(&managerCallbackListener{ + t: t, + onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID) + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedWeight, weight) + managerCallCount++ + }, + }) + m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { - require.Equal(nodeID0, nodeID) - require.Equal(pk0, pk) - require.Equal(txID0, txID) - require.Equal(weight0, weight) - callCount++ + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedWeight, weight) + setCallCount++ }, }) - require.NoError(m.AddStaker(subnetID, nodeID0, pk0, txID0, weight0)) - // setup another subnetID - subnetID2 := ids.GenerateTestID() - require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) - // should not be called for subnetID2 - require.Equal(1, callCount) + require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedWeight)) + require.Equal(1, managerCallCount) // should be called for expectedSubnetID0 + require.Equal(1, setCallCount) // should be called for expectedSubnetID0 + + require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedWeight)) + require.Equal(2, managerCallCount) // should be called for expectedSubnetID1 + require.Equal(1, setCallCount) // should not be called for expectedSubnetID1 } func TestAddWeightCallback(t *testing.T) { require := require.New(t) - nodeID0 := ids.BuildTestNodeID([]byte{1}) - txID0 := ids.GenerateTestID() - weight0 := uint64(1) - weight1 := uint64(93) - - m := NewManager() - subnetID := ids.GenerateTestID() - require.NoError(m.AddStaker(subnetID, nodeID0, nil, txID0, weight0)) + expectedSK, err := bls.NewSecretKey() + require.NoError(err) - callCount := 0 - m.RegisterCallbackListener(subnetID, &callbackListener{ + var ( + expectedNodeID = ids.GenerateTestNodeID() + expectedPK = bls.PublicFromSecretKey(expectedSK) + expectedTxID = ids.GenerateTestID() + expectedOldWeight uint64 = 1 + expectedAddedWeight uint64 = 10 + expectedNewWeight = expectedOldWeight + expectedAddedWeight + expectedSubnetID0 = ids.GenerateTestID() + expectedSubnetID1 = ids.GenerateTestID() + + m = NewManager() + managerAddCallCount = 0 + managerChangeCallCount = 0 + setAddCallCount = 0 + setChangeCallCount = 0 + ) + + require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight)) + + m.RegisterCallbackListener(&managerCallbackListener{ + t: t, + onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID) + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedOldWeight, weight) + managerAddCallCount++ + }, + onWeight: func(subnetID ids.ID, nodeID ids.NodeID, oldWeight, newWeight uint64) { + require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID) + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedOldWeight, oldWeight) + require.Equal(expectedNewWeight, newWeight) + managerChangeCallCount++ + }, + }) + m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { - require.Equal(nodeID0, nodeID) - require.Nil(pk) - require.Equal(txID0, txID) - require.Equal(weight0, weight) - callCount++ + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedOldWeight, weight) + setAddCallCount++ }, onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) { - require.Equal(nodeID0, nodeID) - require.Equal(weight0, oldWeight) - require.Equal(weight0+weight1, newWeight) - callCount++ + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedOldWeight, oldWeight) + require.Equal(expectedNewWeight, newWeight) + setChangeCallCount++ }, }) - require.NoError(m.AddWeight(subnetID, nodeID0, weight1)) - // setup another subnetID - subnetID2 := ids.GenerateTestID() - require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) - require.NoError(m.AddWeight(subnetID2, nodeID0, weight1)) - // should not be called for subnetID2 - require.Equal(2, callCount) + require.Equal(1, managerAddCallCount) + require.Zero(managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Zero(setChangeCallCount) + + require.NoError(m.AddWeight(expectedSubnetID0, expectedNodeID, expectedAddedWeight)) + require.Equal(1, managerAddCallCount) + require.Equal(1, managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setChangeCallCount) + + require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight)) + require.Equal(2, managerAddCallCount) + require.Equal(1, managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setChangeCallCount) + + require.NoError(m.AddWeight(expectedSubnetID1, expectedNodeID, expectedAddedWeight)) + require.Equal(2, managerAddCallCount) + require.Equal(2, managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setChangeCallCount) } func TestRemoveWeightCallback(t *testing.T) { require := require.New(t) - nodeID0 := ids.BuildTestNodeID([]byte{1}) - txID0 := ids.GenerateTestID() - weight0 := uint64(93) - weight1 := uint64(92) - - m := NewManager() - subnetID := ids.GenerateTestID() - require.NoError(m.AddStaker(subnetID, nodeID0, nil, txID0, weight0)) + expectedSK, err := bls.NewSecretKey() + require.NoError(err) - callCount := 0 - m.RegisterCallbackListener(subnetID, &callbackListener{ + var ( + expectedNodeID = ids.GenerateTestNodeID() + expectedPK = bls.PublicFromSecretKey(expectedSK) + expectedTxID = ids.GenerateTestID() + expectedNewWeight uint64 = 1 + expectedRemovedWeight uint64 = 10 + expectedOldWeight = expectedNewWeight + expectedRemovedWeight + expectedSubnetID0 = ids.GenerateTestID() + expectedSubnetID1 = ids.GenerateTestID() + + m = NewManager() + managerAddCallCount = 0 + managerChangeCallCount = 0 + setAddCallCount = 0 + setChangeCallCount = 0 + ) + + require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight)) + + m.RegisterCallbackListener(&managerCallbackListener{ + t: t, + onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID) + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedOldWeight, weight) + managerAddCallCount++ + }, + onWeight: func(subnetID ids.ID, nodeID ids.NodeID, oldWeight, newWeight uint64) { + require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID) + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedOldWeight, oldWeight) + require.Equal(expectedNewWeight, newWeight) + managerChangeCallCount++ + }, + }) + m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { - require.Equal(nodeID0, nodeID) - require.Nil(pk) - require.Equal(txID0, txID) - require.Equal(weight0, weight) - callCount++ + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedOldWeight, weight) + setAddCallCount++ }, onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) { - require.Equal(nodeID0, nodeID) - require.Equal(weight0, oldWeight) - require.Equal(weight0-weight1, newWeight) - callCount++ + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedOldWeight, oldWeight) + require.Equal(expectedNewWeight, newWeight) + setChangeCallCount++ }, }) - require.NoError(m.RemoveWeight(subnetID, nodeID0, weight1)) - // setup another subnetID - subnetID2 := ids.GenerateTestID() - require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) - require.NoError(m.RemoveWeight(subnetID2, nodeID0, weight1)) - // should not be called for subnetID2 - require.Equal(2, callCount) + require.Equal(1, managerAddCallCount) + require.Zero(managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Zero(setChangeCallCount) + + require.NoError(m.RemoveWeight(expectedSubnetID0, expectedNodeID, expectedRemovedWeight)) + require.Equal(1, managerAddCallCount) + require.Equal(1, managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setChangeCallCount) + + require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight)) + require.Equal(2, managerAddCallCount) + require.Equal(1, managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setChangeCallCount) + + require.NoError(m.RemoveWeight(expectedSubnetID1, expectedNodeID, expectedRemovedWeight)) + require.Equal(2, managerAddCallCount) + require.Equal(2, managerChangeCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setChangeCallCount) } -func TestValidatorRemovedCallback(t *testing.T) { +func TestRemoveCallback(t *testing.T) { require := require.New(t) - nodeID0 := ids.BuildTestNodeID([]byte{1}) - txID0 := ids.GenerateTestID() - weight0 := uint64(93) + expectedSK, err := bls.NewSecretKey() + require.NoError(err) - m := NewManager() - subnetID := ids.GenerateTestID() - require.NoError(m.AddStaker(subnetID, nodeID0, nil, txID0, weight0)) + var ( + expectedNodeID = ids.GenerateTestNodeID() + expectedPK = bls.PublicFromSecretKey(expectedSK) + expectedTxID = ids.GenerateTestID() + expectedWeight uint64 = 1 + expectedSubnetID0 = ids.GenerateTestID() + expectedSubnetID1 = ids.GenerateTestID() + + m = NewManager() + managerAddCallCount = 0 + managerRemoveCallCount = 0 + setAddCallCount = 0 + setRemoveCallCount = 0 + ) - callCount := 0 - m.RegisterCallbackListener(subnetID, &callbackListener{ + require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedWeight)) + + m.RegisterCallbackListener(&managerCallbackListener{ + t: t, + onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID) + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedWeight, weight) + managerAddCallCount++ + }, + onRemoved: func(subnetID ids.ID, nodeID ids.NodeID, weight uint64) { + require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID) + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedWeight, weight) + managerRemoveCallCount++ + }, + }) + m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { - require.Equal(nodeID0, nodeID) - require.Nil(pk) - require.Equal(txID0, txID) - require.Equal(weight0, weight) - callCount++ + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedPK, pk) + require.Equal(expectedTxID, txID) + require.Equal(expectedWeight, weight) + setAddCallCount++ }, onRemoved: func(nodeID ids.NodeID, weight uint64) { - require.Equal(nodeID0, nodeID) - require.Equal(weight0, weight) - callCount++ + require.Equal(expectedNodeID, nodeID) + require.Equal(expectedWeight, weight) + setRemoveCallCount++ }, }) - require.NoError(m.RemoveWeight(subnetID, nodeID0, weight0)) - // setup another subnetID - subnetID2 := ids.GenerateTestID() - require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) - require.NoError(m.AddWeight(subnetID2, nodeID0, weight0)) - // should not be called for subnetID2 - require.Equal(2, callCount) + require.Equal(1, managerAddCallCount) + require.Zero(managerRemoveCallCount) + require.Equal(1, setAddCallCount) + require.Zero(setRemoveCallCount) + + require.NoError(m.RemoveWeight(expectedSubnetID0, expectedNodeID, expectedWeight)) + require.Equal(1, managerAddCallCount) + require.Equal(1, managerRemoveCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setRemoveCallCount) + + require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedWeight)) + require.Equal(2, managerAddCallCount) + require.Equal(1, managerRemoveCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setRemoveCallCount) + + require.NoError(m.RemoveWeight(expectedSubnetID1, expectedNodeID, expectedWeight)) + require.Equal(2, managerAddCallCount) + require.Equal(2, managerRemoveCallCount) + require.Equal(1, setAddCallCount) + require.Equal(1, setRemoveCallCount) } diff --git a/snow/validators/set.go b/snow/validators/set.go index 5e7c81a2310..e9bb235f995 100644 --- a/snow/validators/set.go +++ b/snow/validators/set.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "math/big" + "slices" "strings" "sync" @@ -25,15 +26,19 @@ var ( ) // newSet returns a new, empty set of validators. -func newSet() *vdrSet { +func newSet(subnetID ids.ID, callbackListeners []ManagerCallbackListener) *vdrSet { return &vdrSet{ - vdrs: make(map[ids.NodeID]*Validator), - sampler: sampler.NewWeightedWithoutReplacement(), - totalWeight: new(big.Int), + subnetID: subnetID, + vdrs: make(map[ids.NodeID]*Validator), + totalWeight: new(big.Int), + sampler: sampler.NewWeightedWithoutReplacement(), + managerCallbackListeners: slices.Clone(callbackListeners), } } type vdrSet struct { + subnetID ids.ID + lock sync.RWMutex vdrs map[ids.NodeID]*Validator vdrSlice []*Validator @@ -43,7 +48,8 @@ type vdrSet struct { samplerInitialized bool sampler sampler.WeightedWithoutReplacement - callbackListeners []SetCallbackListener + managerCallbackListeners []ManagerCallbackListener + setCallbackListeners []SetCallbackListener } func (s *vdrSet) Add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { @@ -218,7 +224,7 @@ func (s *vdrSet) HasCallbackRegistered() bool { s.lock.RLock() defer s.lock.RUnlock() - return len(s.callbackListeners) > 0 + return len(s.setCallbackListeners) > 0 } func (s *vdrSet) Map() map[ids.NodeID]*GetValidatorOutput { @@ -305,11 +311,21 @@ func (s *vdrSet) prefixedString(prefix string) string { return sb.String() } +func (s *vdrSet) RegisterManagerCallbackListener(callbackListener ManagerCallbackListener) { + s.lock.Lock() + defer s.lock.Unlock() + + s.managerCallbackListeners = append(s.managerCallbackListeners, callbackListener) + for _, vdr := range s.vdrSlice { + callbackListener.OnValidatorAdded(s.subnetID, vdr.NodeID, vdr.PublicKey, vdr.TxID, vdr.Weight) + } +} + func (s *vdrSet) RegisterCallbackListener(callbackListener SetCallbackListener) { s.lock.Lock() defer s.lock.Unlock() - s.callbackListeners = append(s.callbackListeners, callbackListener) + s.setCallbackListeners = append(s.setCallbackListeners, callbackListener) for _, vdr := range s.vdrSlice { callbackListener.OnValidatorAdded(vdr.NodeID, vdr.PublicKey, vdr.TxID, vdr.Weight) } @@ -317,21 +333,30 @@ func (s *vdrSet) RegisterCallbackListener(callbackListener SetCallbackListener) // Assumes [s.lock] is held func (s *vdrSet) callWeightChangeCallbacks(node ids.NodeID, oldWeight, newWeight uint64) { - for _, callbackListener := range s.callbackListeners { + for _, callbackListener := range s.managerCallbackListeners { + callbackListener.OnValidatorWeightChanged(s.subnetID, node, oldWeight, newWeight) + } + for _, callbackListener := range s.setCallbackListeners { callbackListener.OnValidatorWeightChanged(node, oldWeight, newWeight) } } // Assumes [s.lock] is held func (s *vdrSet) callValidatorAddedCallbacks(node ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { - for _, callbackListener := range s.callbackListeners { + for _, callbackListener := range s.managerCallbackListeners { + callbackListener.OnValidatorAdded(s.subnetID, node, pk, txID, weight) + } + for _, callbackListener := range s.setCallbackListeners { callbackListener.OnValidatorAdded(node, pk, txID, weight) } } // Assumes [s.lock] is held func (s *vdrSet) callValidatorRemovedCallbacks(node ids.NodeID, weight uint64) { - for _, callbackListener := range s.callbackListeners { + for _, callbackListener := range s.managerCallbackListeners { + callbackListener.OnValidatorRemoved(s.subnetID, node, weight) + } + for _, callbackListener := range s.setCallbackListeners { callbackListener.OnValidatorRemoved(node, weight) } } diff --git a/snow/validators/set_test.go b/snow/validators/set_test.go index 4554f930fa3..480f9dba4f8 100644 --- a/snow/validators/set_test.go +++ b/snow/validators/set_test.go @@ -17,10 +17,43 @@ import ( safemath "github.com/ava-labs/avalanchego/utils/math" ) +var _ SetCallbackListener = (*setCallbackListener)(nil) + +type setCallbackListener struct { + t *testing.T + onAdd func(ids.NodeID, *bls.PublicKey, ids.ID, uint64) + onWeight func(ids.NodeID, uint64, uint64) + onRemoved func(ids.NodeID, uint64) +} + +func (c *setCallbackListener) OnValidatorAdded(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + if c.onAdd != nil { + c.onAdd(nodeID, pk, txID, weight) + } else { + c.t.Fail() + } +} + +func (c *setCallbackListener) OnValidatorRemoved(nodeID ids.NodeID, weight uint64) { + if c.onRemoved != nil { + c.onRemoved(nodeID, weight) + } else { + c.t.Fail() + } +} + +func (c *setCallbackListener) OnValidatorWeightChanged(nodeID ids.NodeID, oldWeight, newWeight uint64) { + if c.onWeight != nil { + c.onWeight(nodeID, oldWeight, newWeight) + } else { + c.t.Fail() + } +} + func TestSetAddDuplicate(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) nodeID := ids.GenerateTestNodeID() require.NoError(s.Add(nodeID, nil, ids.Empty, 1)) @@ -32,7 +65,7 @@ func TestSetAddDuplicate(t *testing.T) { func TestSetAddOverflow(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, math.MaxUint64)) @@ -44,7 +77,7 @@ func TestSetAddOverflow(t *testing.T) { func TestSetAddWeightOverflow(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) @@ -60,7 +93,7 @@ func TestSetAddWeightOverflow(t *testing.T) { func TestSetGetWeight(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) nodeID := ids.GenerateTestNodeID() require.Zero(s.GetWeight(nodeID)) @@ -83,7 +116,7 @@ func TestSetSubsetWeight(t *testing.T) { subset := set.Of(nodeID0, nodeID1) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(nodeID0, nil, ids.Empty, weight0)) require.NoError(s.Add(nodeID1, nil, ids.Empty, weight1)) @@ -98,7 +131,7 @@ func TestSetSubsetWeight(t *testing.T) { func TestSetRemoveWeightMissingValidator(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) @@ -109,7 +142,7 @@ func TestSetRemoveWeightMissingValidator(t *testing.T) { func TestSetRemoveWeightUnderflow(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) @@ -127,7 +160,7 @@ func TestSetRemoveWeightUnderflow(t *testing.T) { func TestSetGet(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) nodeID := ids.GenerateTestNodeID() _, ok := s.Get(nodeID) @@ -164,7 +197,7 @@ func TestSetGet(t *testing.T) { func TestSetLen(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) setLen := s.Len() require.Zero(setLen) @@ -195,7 +228,7 @@ func TestSetLen(t *testing.T) { func TestSetMap(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) m := s.Map() require.Empty(m) @@ -278,7 +311,7 @@ func TestSetWeight(t *testing.T) { vdr1 := ids.BuildTestNodeID([]byte{2}) weight1 := uint64(123) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(vdr0, nil, ids.Empty, weight0)) require.NoError(s.Add(vdr1, nil, ids.Empty, weight1)) @@ -292,7 +325,7 @@ func TestSetWeight(t *testing.T) { func TestSetSample(t *testing.T) { require := require.New(t) - s := newSet() + s := newSet(ids.Empty, nil) sampled, err := s.Sample(0) require.NoError(err) @@ -337,7 +370,7 @@ func TestSetString(t *testing.T) { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, }) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(nodeID0, nil, ids.Empty, 1)) require.NoError(s.Add(nodeID1, nil, ids.Empty, math.MaxInt64-1)) @@ -349,39 +382,6 @@ func TestSetString(t *testing.T) { require.Equal(expected, result) } -var _ SetCallbackListener = (*callbackListener)(nil) - -type callbackListener struct { - t *testing.T - onAdd func(ids.NodeID, *bls.PublicKey, ids.ID, uint64) - onWeight func(ids.NodeID, uint64, uint64) - onRemoved func(ids.NodeID, uint64) -} - -func (c *callbackListener) OnValidatorAdded(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { - if c.onAdd != nil { - c.onAdd(nodeID, pk, txID, weight) - } else { - c.t.Fail() - } -} - -func (c *callbackListener) OnValidatorRemoved(nodeID ids.NodeID, weight uint64) { - if c.onRemoved != nil { - c.onRemoved(nodeID, weight) - } else { - c.t.Fail() - } -} - -func (c *callbackListener) OnValidatorWeightChanged(nodeID ids.NodeID, oldWeight, newWeight uint64) { - if c.onWeight != nil { - c.onWeight(nodeID, oldWeight, newWeight) - } else { - c.t.Fail() - } -} - func TestSetAddCallback(t *testing.T) { require := require.New(t) @@ -392,10 +392,10 @@ func TestSetAddCallback(t *testing.T) { txID0 := ids.GenerateTestID() weight0 := uint64(1) - s := newSet() + s := newSet(ids.Empty, nil) callCount := 0 require.False(s.HasCallbackRegistered()) - s.RegisterCallbackListener(&callbackListener{ + s.RegisterCallbackListener(&setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { require.Equal(nodeID0, nodeID) @@ -418,12 +418,12 @@ func TestSetAddWeightCallback(t *testing.T) { weight0 := uint64(1) weight1 := uint64(93) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(nodeID0, nil, txID0, weight0)) callCount := 0 require.False(s.HasCallbackRegistered()) - s.RegisterCallbackListener(&callbackListener{ + s.RegisterCallbackListener(&setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { require.Equal(nodeID0, nodeID) @@ -452,12 +452,12 @@ func TestSetRemoveWeightCallback(t *testing.T) { weight0 := uint64(93) weight1 := uint64(92) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(nodeID0, nil, txID0, weight0)) callCount := 0 require.False(s.HasCallbackRegistered()) - s.RegisterCallbackListener(&callbackListener{ + s.RegisterCallbackListener(&setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { require.Equal(nodeID0, nodeID) @@ -485,12 +485,12 @@ func TestSetValidatorRemovedCallback(t *testing.T) { txID0 := ids.GenerateTestID() weight0 := uint64(93) - s := newSet() + s := newSet(ids.Empty, nil) require.NoError(s.Add(nodeID0, nil, txID0, weight0)) callCount := 0 require.False(s.HasCallbackRegistered()) - s.RegisterCallbackListener(&callbackListener{ + s.RegisterCallbackListener(&setCallbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { require.Equal(nodeID0, nodeID) diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go index 214e7246ce3..4458259eace 100644 --- a/vms/platformvm/vm.go +++ b/vms/platformvm/vm.go @@ -353,7 +353,7 @@ func (vm *VM) onNormalOperationsStarted() error { } vl := validators.NewLogger(vm.ctx.Log, constants.PrimaryNetworkID, vm.ctx.NodeID) - vm.Validators.RegisterCallbackListener(constants.PrimaryNetworkID, vl) + vm.Validators.RegisterSetCallbackListener(constants.PrimaryNetworkID, vl) for subnetID := range vm.TrackedSubnets { vdrIDs := vm.Validators.GetValidatorIDs(subnetID) @@ -362,7 +362,7 @@ func (vm *VM) onNormalOperationsStarted() error { } vl := validators.NewLogger(vm.ctx.Log, subnetID, vm.ctx.NodeID) - vm.Validators.RegisterCallbackListener(subnetID, vl) + vm.Validators.RegisterSetCallbackListener(subnetID, vl) } if err := vm.state.Commit(); err != nil { diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index 66a939f0a86..479788b4709 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -1460,7 +1460,7 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { totalWeight, err := beacons.TotalWeight(ctx.SubnetID) require.NoError(err) startup := tracker.NewStartup(peers, (totalWeight+1)/2) - beacons.RegisterCallbackListener(ctx.SubnetID, startup) + beacons.RegisterSetCallbackListener(ctx.SubnetID, startup) // The engine handles consensus snowGetHandler, err := snowgetter.New(