From 194e377a6a41f564f9708cc921c377dadfc153ee Mon Sep 17 00:00:00 2001 From: Marius Poke Date: Fri, 7 Jun 2024 16:13:15 +0200 Subject: [PATCH] fix!: Replace GetAllConsumerChains with lightweight version (#1946) * add GetAllConsumerChainIDs * replace GetAllConsumerChains with GetAllRegisteredConsumerChainIDs * add changelog entry * move HasToValidate to grpc_query.go as it's used only there * apply review suggestions --- .../provider/1946-get-consumer-chains.md | 3 + tests/mbt/driver/core.go | 34 +++--- tests/mbt/driver/mbt_test.go | 72 +++++------ x/ccv/provider/keeper/distribution.go | 16 +-- x/ccv/provider/keeper/genesis.go | 36 +++--- x/ccv/provider/keeper/grpc_query.go | 115 ++++++++++++++++-- x/ccv/provider/keeper/grpc_query_test.go | 98 ++++++++++++++- x/ccv/provider/keeper/hooks.go | 7 +- x/ccv/provider/keeper/keeper.go | 105 ++-------------- x/ccv/provider/keeper/keeper_test.go | 94 ++------------ x/ccv/provider/keeper/partial_set_security.go | 6 +- x/ccv/provider/keeper/relay.go | 31 ++--- x/ccv/provider/migrations/v3/migrations.go | 8 +- x/ccv/provider/migrations/v5/migrations.go | 7 +- 14 files changed, 327 insertions(+), 305 deletions(-) create mode 100644 .changelog/unreleased/bug-fixes/provider/1946-get-consumer-chains.md diff --git a/.changelog/unreleased/bug-fixes/provider/1946-get-consumer-chains.md b/.changelog/unreleased/bug-fixes/provider/1946-get-consumer-chains.md new file mode 100644 index 0000000000..eae373c390 --- /dev/null +++ b/.changelog/unreleased/bug-fixes/provider/1946-get-consumer-chains.md @@ -0,0 +1,3 @@ +- Replace `GetAllConsumerChains` with lightweight version + (`GetAllRegisteredConsumerChainIDs`) that doesn't call into the staking module + ([\#1946](https://github.com/cosmos/interchain-security/pull/1946)) \ No newline at end of file diff --git a/tests/mbt/driver/core.go b/tests/mbt/driver/core.go index 62ed11671c..2cd618096f 100644 --- a/tests/mbt/driver/core.go +++ b/tests/mbt/driver/core.go @@ -3,12 +3,11 @@ package main import ( "fmt" "log" + gomath "math" "strings" "testing" "time" - gomath "math" - "cosmossdk.io/math" channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" @@ -33,7 +32,6 @@ import ( consumerkeeper "github.com/cosmos/interchain-security/v5/x/ccv/consumer/keeper" consumertypes "github.com/cosmos/interchain-security/v5/x/ccv/consumer/types" providerkeeper "github.com/cosmos/interchain-security/v5/x/ccv/provider/keeper" - providertypes "github.com/cosmos/interchain-security/v5/x/ccv/provider/types" "github.com/cosmos/interchain-security/v5/x/ccv/types" ) @@ -222,11 +220,7 @@ func (s *Driver) getStateString() string { state.WriteString("\n") state.WriteString("Consumers Chains:\n") - consumerChains := s.providerKeeper().GetAllConsumerChains(s.providerCtx()) - chainIds := make([]string, len(consumerChains)) - for i, consumerChain := range consumerChains { - chainIds[i] = consumerChain.ChainId - } + chainIds := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx()) state.WriteString(strings.Join(chainIds, ", ")) state.WriteString("\n\n") @@ -264,11 +258,11 @@ func (s *Driver) getChainStateString(chain ChainId) string { if !s.isProviderChain(chain) { // Check whether the chain is in the consumer chains on the provider - consumerChains := s.providerKeeper().GetAllConsumerChains(s.providerCtx()) + consumerChainIDs := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx()) found := false - for _, consumerChain := range consumerChains { - if consumerChain.ChainId == string(chain) { + for _, consumerChainID := range consumerChainIDs { + if consumerChainID == string(chain) { found = true } } @@ -371,16 +365,16 @@ func (s *Driver) endAndBeginBlock(chain ChainId, timeAdvancement time.Duration) return header } -func (s *Driver) runningConsumers() []providertypes.Chain { - consumersOnProvider := s.providerKeeper().GetAllConsumerChains(s.providerCtx()) +func (s *Driver) runningConsumerChainIDs() []ChainId { + consumerIDsOnProvider := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx()) - consumersWithIntactChannel := make([]providertypes.Chain, 0) - for _, consumer := range consumersOnProvider { - if s.path(ChainId(consumer.ChainId)).Path.EndpointA.GetChannel().State == channeltypes.CLOSED || - s.path(ChainId(consumer.ChainId)).Path.EndpointB.GetChannel().State == channeltypes.CLOSED { + consumersWithIntactChannel := make([]ChainId, 0) + for _, consumerChainID := range consumerIDsOnProvider { + if s.path(ChainId(consumerChainID)).Path.EndpointA.GetChannel().State == channeltypes.CLOSED || + s.path(ChainId(consumerChainID)).Path.EndpointB.GetChannel().State == channeltypes.CLOSED { continue } - consumersWithIntactChannel = append(consumersWithIntactChannel, consumer) + consumersWithIntactChannel = append(consumersWithIntactChannel, ChainId(consumerChainID)) } return consumersWithIntactChannel } @@ -448,8 +442,8 @@ func (s *Driver) RequestSlash( // DeliverAcks delivers, for each path, // all possible acks (up to math.MaxInt many per path). func (s *Driver) DeliverAcks() { - for _, chain := range s.runningConsumers() { - path := s.path(ChainId(chain.ChainId)) + for _, chainID := range s.runningConsumerChainIDs() { + path := s.path(chainID) path.DeliverAcks(path.Path.EndpointA.Chain.ChainID, gomath.MaxInt) path.DeliverAcks(path.Path.EndpointB.Chain.ChainID, gomath.MaxInt) } diff --git a/tests/mbt/driver/mbt_test.go b/tests/mbt/driver/mbt_test.go index 3e2cfe7901..70fded614a 100644 --- a/tests/mbt/driver/mbt_test.go +++ b/tests/mbt/driver/mbt_test.go @@ -306,21 +306,21 @@ func RunItfTrace(t *testing.T, path string) { // needs a header of height H+1 to accept the packet // so, we do two blocks, one with a very small increment, // and then another to increment the rest of the time - runningConsumersBefore := driver.runningConsumers() + runningConsumerChainIDsBefore := driver.runningConsumerChainIDs() driver.endAndBeginBlock("provider", 1*time.Nanosecond) - for _, consumer := range driver.runningConsumers() { - UpdateProviderClientOnConsumer(t, driver, consumer.ChainId) + for _, consumerChainID := range driver.runningConsumerChainIDs() { + UpdateProviderClientOnConsumer(t, driver, string(consumerChainID)) } driver.endAndBeginBlock("provider", time.Duration(timeAdvancement)*time.Second-1*time.Nanosecond) - runningConsumersAfter := driver.runningConsumers() + runningConsumerChainIDsAfter := driver.runningConsumerChainIDs() // the consumers that were running before but not after must have timed out - for _, consumer := range runningConsumersBefore { + for _, consumerChainID := range runningConsumerChainIDsBefore { found := false - for _, consumerAfter := range runningConsumersAfter { - if consumerAfter.ChainId == consumer.ChainId { + for _, consumerChainIDAfter := range runningConsumerChainIDsAfter { + if consumerChainIDAfter == consumerChainID { found = true break } @@ -334,8 +334,8 @@ func RunItfTrace(t *testing.T, path string) { // because setting up chains will modify timestamps // when the coordinator is starting chains lastTimestamps := make(map[ChainId]time.Time, len(consumers)) - for _, consumer := range driver.runningConsumers() { - lastTimestamps[ChainId(consumer.ChainId)] = driver.runningTime(ChainId(consumer.ChainId)) + for _, consumerChainID := range driver.runningConsumerChainIDs() { + lastTimestamps[consumerChainID] = driver.runningTime(consumerChainID) } driver.coordinator.CurrentTime = driver.runningTime("provider") @@ -366,12 +366,12 @@ func RunItfTrace(t *testing.T, path string) { // for all connected consumers, update the clients... // unless it was the last consumer to be started, in which case it already has the header // as we called driver.setupConsumer - for _, consumer := range driver.runningConsumers() { - if len(consumersToStart) > 0 && consumer.ChainId == consumersToStart[len(consumersToStart)-1].Value.(string) { + for _, consumerChainID := range driver.runningConsumerChainIDs() { + if len(consumersToStart) > 0 && string(consumerChainID) == consumersToStart[len(consumersToStart)-1].Value.(string) { continue } - UpdateProviderClientOnConsumer(t, driver, consumer.ChainId) + UpdateProviderClientOnConsumer(t, driver, string(consumerChainID)) } case "EndAndBeginBlockForConsumer": @@ -492,33 +492,33 @@ func RunItfTrace(t *testing.T, path string) { t.Logf("Comparing model state to actual state...") // compare the running consumers - modelRunningConsumers := RunningConsumers(currentModelState) + modelRunningConsumerChainIDs := RunningConsumers(currentModelState) - systemRunningConsumers := driver.runningConsumers() - actualRunningConsumers := make([]string, len(systemRunningConsumers)) - for i, chain := range systemRunningConsumers { - actualRunningConsumers[i] = chain.ChainId + systemRunningConsumerChainIDs := driver.runningConsumerChainIDs() + actualRunningConsumerChainIDs := make([]string, len(systemRunningConsumerChainIDs)) + for i, chainID := range systemRunningConsumerChainIDs { + actualRunningConsumerChainIDs[i] = string(chainID) } // sort the slices so that we can compare them - sort.Slice(modelRunningConsumers, func(i, j int) bool { - return modelRunningConsumers[i] < modelRunningConsumers[j] + sort.Slice(modelRunningConsumerChainIDs, func(i, j int) bool { + return modelRunningConsumerChainIDs[i] < modelRunningConsumerChainIDs[j] }) - sort.Slice(actualRunningConsumers, func(i, j int) bool { - return actualRunningConsumers[i] < actualRunningConsumers[j] + sort.Slice(actualRunningConsumerChainIDs, func(i, j int) bool { + return actualRunningConsumerChainIDs[i] < actualRunningConsumerChainIDs[j] }) - require.Equal(t, modelRunningConsumers, actualRunningConsumers, "Running consumers do not match") + require.Equal(t, modelRunningConsumerChainIDs, actualRunningConsumerChainIDs, "Running consumers do not match") // check validator sets - provider current validator set should be the one from the staking keeper - CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers, realAddrsToModelConsAddrs) + CompareValidatorSets(t, driver, currentModelState, actualRunningConsumerChainIDs, realAddrsToModelConsAddrs) // check times - sanity check that the block times match the ones from the model - CompareTimes(driver, actualRunningConsumers, currentModelState, timeOffset) + CompareTimes(driver, actualRunningConsumerChainIDs, currentModelState, timeOffset) // check sent packets: we check that the package queues in the model and the system have the same length. - for _, consumer := range actualRunningConsumers { - ComparePacketQueues(t, driver, currentModelState, consumer, timeOffset) + for _, consumerChainID := range actualRunningConsumerChainIDs { + ComparePacketQueues(t, driver, currentModelState, consumerChainID, timeOffset) } // compare that the sent packets on the proider match the model CompareSentPacketsOnProvider(driver, currentModelState, timeOffset) @@ -528,8 +528,8 @@ func RunItfTrace(t *testing.T, path string) { CompareJailedValidators(driver, currentModelState, timeOffset, addressMap) // for all newly sent vsc packets, figure out which vsc id in the model they correspond to - for _, consumer := range actualRunningConsumers { - actualPackets := driver.packetQueue(PROVIDER, ChainId(consumer)) + for _, consumerChainID := range actualRunningConsumerChainIDs { + actualPackets := driver.packetQueue(PROVIDER, ChainId(consumerChainID)) actualNewPackets := make([]types.ValidatorSetChangePacketData, 0) for _, packet := range actualPackets { @@ -545,7 +545,7 @@ func RunItfTrace(t *testing.T, path string) { actualNewPackets = append(actualNewPackets, packetData) } - modelPackets := PacketQueue(currentModelState, PROVIDER, consumer) + modelPackets := PacketQueue(currentModelState, PROVIDER, consumerChainID) newModelVscIds := make([]uint64, 0) for _, packet := range modelPackets { modelVscId := uint64(packet.Value.(itf.MapExprType)["value"].Value.(itf.MapExprType)["id"].Value.(int64)) @@ -785,15 +785,15 @@ func CompareValSet(modelValSet map[string]itf.Expr, systemValSet map[string]int6 } func CompareSentPacketsOnProvider(driver *Driver, currentModelState map[string]itf.Expr, timeOffset time.Time) { - for _, consumer := range driver.runningConsumers() { - vscSendTimestamps := driver.providerKeeper().GetAllVscSendTimestamps(driver.providerCtx(), consumer.ChainId) + for _, consumerChainID := range driver.runningConsumerChainIDs() { + vscSendTimestamps := driver.providerKeeper().GetAllVscSendTimestamps(driver.providerCtx(), string(consumerChainID)) actualVscSendTimestamps := make([]time.Time, 0) for _, vscSendTimestamp := range vscSendTimestamps { actualVscSendTimestamps = append(actualVscSendTimestamps, vscSendTimestamp.Timestamp) } - modelVscSendTimestamps := VscSendTimestamps(currentModelState, consumer.ChainId) + modelVscSendTimestamps := VscSendTimestamps(currentModelState, string(consumerChainID)) for i, modelVscSendTimestamp := range modelVscSendTimestamps { actualTimeWithOffset := actualVscSendTimestamps[i].Unix() - timeOffset.Unix() @@ -802,7 +802,7 @@ func CompareSentPacketsOnProvider(driver *Driver, currentModelState map[string]i modelVscSendTimestamp, actualTimeWithOffset, "Vsc send timestamps do not match for consumer %v", - consumer.ChainId, + consumerChainID, ) } } @@ -860,9 +860,9 @@ func (s *Stats) EnterStats(driver *Driver) { // max number of in-flight packets inFlightPackets := 0 - for _, consumer := range driver.runningConsumers() { - inFlightPackets += len(driver.packetQueue(PROVIDER, ChainId(consumer.ChainId))) - inFlightPackets += len(driver.packetQueue(ChainId(consumer.ChainId), PROVIDER)) + for _, consumerChainID := range driver.runningConsumerChainIDs() { + inFlightPackets += len(driver.packetQueue(PROVIDER, consumerChainID)) + inFlightPackets += len(driver.packetQueue(consumerChainID, PROVIDER)) } if inFlightPackets > s.maxNumInFlightPackets { s.maxNumInFlightPackets = inFlightPackets diff --git a/x/ccv/provider/keeper/distribution.go b/x/ccv/provider/keeper/distribution.go index 52d5440344..37571be6f0 100644 --- a/x/ccv/provider/keeper/distribution.go +++ b/x/ccv/provider/keeper/distribution.go @@ -75,14 +75,14 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) { } // Iterate over all registered consumer chains - for _, consumer := range k.GetAllConsumerChains(ctx) { + for _, consumerChainID := range k.GetAllRegisteredConsumerChainIDs(ctx) { // transfer the consumer rewards to the distribution module account // note that the rewards transferred are only consumer whitelisted denoms - rewardsCollected, err := k.TransferConsumerRewardsToDistributionModule(ctx, consumer.ChainId) + rewardsCollected, err := k.TransferConsumerRewardsToDistributionModule(ctx, consumerChainID) if err != nil { k.Logger(ctx).Error( "fail to transfer rewards to distribution module for chain %s: %s", - consumer.ChainId, + consumerChainID, err, ) continue @@ -97,12 +97,12 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) { // temporary workaround to keep CanWithdrawInvariant happy // general discussions here: https://github.com/cosmos/cosmos-sdk/issues/2906#issuecomment-441867634 - if k.ComputeConsumerTotalVotingPower(ctx, consumer.ChainId) == 0 { + if k.ComputeConsumerTotalVotingPower(ctx, consumerChainID) == 0 { err := k.distributionKeeper.FundCommunityPool(context.Context(ctx), rewardsCollected, k.accountKeeper.GetModuleAccount(ctx, types.ConsumerRewardsPool).GetAddress()) if err != nil { k.Logger(ctx).Error( "fail to allocate rewards from consumer chain %s to community pool: %s", - consumer.ChainId, + consumerChainID, err, ) } @@ -116,7 +116,7 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) { if err != nil { k.Logger(ctx).Error( "cannot get community tax while allocating rewards from consumer chain %s: %s", - consumer.ChainId, + consumerChainID, err, ) continue @@ -127,7 +127,7 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) { // allocate tokens to consumer validators feeAllocated := k.AllocateTokensToConsumerValidators( ctx, - consumer.ChainId, + consumerChainID, feeMultiplier, ) remaining = remaining.Sub(feeAllocated) @@ -138,7 +138,7 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) { if err != nil { k.Logger(ctx).Error( "fail to allocate rewards from consumer chain %s to community pool: %s", - consumer.ChainId, + consumerChainID, err, ) continue diff --git a/x/ccv/provider/keeper/genesis.go b/x/ccv/provider/keeper/genesis.go index 66895233a7..b8d6d179fc 100644 --- a/x/ccv/provider/keeper/genesis.go +++ b/x/ccv/provider/keeper/genesis.go @@ -108,47 +108,51 @@ func (k Keeper) InitGenesis(ctx sdk.Context, genState *types.GenesisState) { // ExportGenesis returns the CCV provider module's exported genesis func (k Keeper) ExportGenesis(ctx sdk.Context) *types.GenesisState { // get a list of all registered consumer chains - registeredChains := k.GetAllConsumerChains(ctx) + registeredChainIDs := k.GetAllRegisteredConsumerChainIDs(ctx) var exportedVscSendTimestamps []types.ExportedVscSendTimestamp // export states for each consumer chains var consumerStates []types.ConsumerState - for _, chain := range registeredChains { - gen, found := k.GetConsumerGenesis(ctx, chain.ChainId) + for _, chainID := range registeredChainIDs { + // no need for the second return value of GetConsumerClientId + // as GetAllRegisteredConsumerChainIDs already iterated through + // the entire prefix range + clientID, _ := k.GetConsumerClientId(ctx, chainID) + gen, found := k.GetConsumerGenesis(ctx, chainID) if !found { - panic(fmt.Errorf("cannot find genesis for consumer chain %s with client %s", chain.ChainId, chain.ClientId)) + panic(fmt.Errorf("cannot find genesis for consumer chain %s with client %s", chainID, clientID)) } // initial consumer chain states cs := types.ConsumerState{ - ChainId: chain.ChainId, - ClientId: chain.ClientId, + ChainId: chainID, + ClientId: clientID, ConsumerGenesis: gen, - UnbondingOpsIndex: k.GetAllUnbondingOpIndexes(ctx, chain.ChainId), + UnbondingOpsIndex: k.GetAllUnbondingOpIndexes(ctx, chainID), } // try to find channel id for the current consumer chain - channelId, found := k.GetChainToChannel(ctx, chain.ChainId) + channelId, found := k.GetChainToChannel(ctx, chainID) if found { cs.ChannelId = channelId - cs.InitialHeight, found = k.GetInitChainHeight(ctx, chain.ChainId) + cs.InitialHeight, found = k.GetInitChainHeight(ctx, chainID) if !found { - panic(fmt.Errorf("cannot find init height for consumer chain %s", chain.ChainId)) + panic(fmt.Errorf("cannot find init height for consumer chain %s", chainID)) } - cs.SlashDowntimeAck = k.GetSlashAcks(ctx, chain.ChainId) + cs.SlashDowntimeAck = k.GetSlashAcks(ctx, chainID) } - cs.PendingValsetChanges = k.GetPendingVSCPackets(ctx, chain.ChainId) + cs.PendingValsetChanges = k.GetPendingVSCPackets(ctx, chainID) consumerStates = append(consumerStates, cs) - vscSendTimestamps := k.GetAllVscSendTimestamps(ctx, chain.ChainId) - exportedVscSendTimestamps = append(exportedVscSendTimestamps, types.ExportedVscSendTimestamp{ChainId: chain.ChainId, VscSendTimestamps: vscSendTimestamps}) + vscSendTimestamps := k.GetAllVscSendTimestamps(ctx, chainID) + exportedVscSendTimestamps = append(exportedVscSendTimestamps, types.ExportedVscSendTimestamp{ChainId: chainID, VscSendTimestamps: vscSendTimestamps}) } // ConsumerAddrsToPrune are added only for registered consumer chains consumerAddrsToPrune := []types.ConsumerAddrsToPrune{} - for _, chain := range registeredChains { - consumerAddrsToPrune = append(consumerAddrsToPrune, k.GetAllConsumerAddrsToPrune(ctx, chain.ChainId)...) + for _, chainID := range registeredChainIDs { + consumerAddrsToPrune = append(consumerAddrsToPrune, k.GetAllConsumerAddrsToPrune(ctx, chainID)...) } params := k.GetParams(ctx) diff --git a/x/ccv/provider/keeper/grpc_query.go b/x/ccv/provider/keeper/grpc_query.go index 4ff0a2c74c..32042da5c7 100644 --- a/x/ccv/provider/keeper/grpc_query.go +++ b/x/ccv/provider/keeper/grpc_query.go @@ -10,6 +10,7 @@ import ( errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" "github.com/cosmos/interchain-security/v5/x/ccv/provider/types" ccvtypes "github.com/cosmos/interchain-security/v5/x/ccv/types" @@ -47,15 +48,70 @@ func (k Keeper) QueryConsumerChains(goCtx context.Context, req *types.QueryConsu ctx := sdk.UnwrapSDKContext(goCtx) chains := []*types.Chain{} - for _, chain := range k.GetAllConsumerChains(ctx) { - // prevent implicit memory aliasing - c := chain + for _, chainID := range k.GetAllRegisteredConsumerChainIDs(ctx) { + c, err := k.GetConsumerChain(ctx, chainID) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } chains = append(chains, &c) } return &types.QueryConsumerChainsResponse{Chains: chains}, nil } +// GetConsumerChain returns a Chain data structure with all the necessary fields +func (k Keeper) GetConsumerChain(ctx sdk.Context, chainID string) (types.Chain, error) { + clientID, found := k.GetConsumerClientId(ctx, chainID) + if !found { + return types.Chain{}, fmt.Errorf("cannot find clientID for consumer (%s)", chainID) + } + + topN, found := k.GetTopN(ctx, chainID) + + // Get MinPowerInTop_N + var minPowerInTopN int64 + if found && topN > 0 { + bondedValidators, err := k.stakingKeeper.GetLastValidators(ctx) + if err != nil { + return types.Chain{}, err + } + res, err := k.ComputeMinPowerToOptIn(ctx, bondedValidators, topN) + if err != nil { + return types.Chain{}, fmt.Errorf("failed to compute min power to opt in for chain (%s): %w", chainID, err) + } + minPowerInTopN = res + } else { + minPowerInTopN = -1 + } + + validatorSetCap, _ := k.GetValidatorSetCap(ctx, chainID) + + validatorsPowerCap, _ := k.GetValidatorsPowerCap(ctx, chainID) + + allowlist := k.GetAllowList(ctx, chainID) + strAllowlist := make([]string, len(allowlist)) + for i, addr := range allowlist { + strAllowlist[i] = addr.String() + } + + denylist := k.GetDenyList(ctx, chainID) + strDenylist := make([]string, len(denylist)) + for i, addr := range denylist { + strDenylist[i] = addr.String() + } + + return types.Chain{ + ChainId: chainID, + ClientId: clientID, + Top_N: topN, + MinPowerInTop_N: minPowerInTopN, + ValidatorSetCap: validatorSetCap, + ValidatorsPowerCap: validatorsPowerCap, + Allowlist: strAllowlist, + Denylist: strDenylist, + }, nil +} + func (k Keeper) QueryConsumerChainStarts(goCtx context.Context, req *types.QueryConsumerChainStartProposalsRequest) (*types.QueryConsumerChainStartProposalsResponse, error) { if req == nil { return nil, status.Error(codes.InvalidArgument, "empty request") @@ -311,11 +367,9 @@ func (k Keeper) QueryConsumerChainsValidatorHasToValidate(goCtx context.Context, // get all the consumer chains for which the validator is either already // opted-in, currently a consumer validator or if its voting power is within the TopN validators consumersToValidate := []string{} - for _, consumer := range k.GetAllConsumerChains(ctx) { - chainID := consumer.ChainId - - if hasToValidate, err := k.HasToValidate(ctx, provAddr, chainID); err == nil && hasToValidate { - consumersToValidate = append(consumersToValidate, chainID) + for _, consumerChainID := range k.GetAllRegisteredConsumerChainIDs(ctx) { + if hasToValidate, err := k.hasToValidate(ctx, provAddr, consumerChainID); err == nil && hasToValidate { + consumersToValidate = append(consumersToValidate, consumerChainID) } } @@ -324,6 +378,51 @@ func (k Keeper) QueryConsumerChainsValidatorHasToValidate(goCtx context.Context, }, nil } +// hasToValidate checks if a validator needs to validate on a consumer chain +func (k Keeper) hasToValidate( + ctx sdk.Context, + provAddr types.ProviderConsAddress, + chainID string, +) (bool, error) { + // if the validator was sent as part of the packet in the last epoch, it has to validate + if k.IsConsumerValidator(ctx, chainID, provAddr) { + return true, nil + } + + // if the validator was not part of the last epoch, check if the validator is going to be part of te next epoch + bondedValidators, err := k.stakingKeeper.GetLastValidators(ctx) + if err != nil { + return false, errorsmod.Wrapf(stakingtypes.ErrNoValidatorFound, "error getting last bonded validators: %s", err) + } + if topN, found := k.GetTopN(ctx, chainID); found && topN > 0 { + // in a Top-N chain, we automatically opt in all validators that belong to the top N + minPower, err := k.ComputeMinPowerToOptIn(ctx, bondedValidators, topN) + if err == nil { + k.OptInTopNValidators(ctx, chainID, bondedValidators, minPower) + } else { + k.Logger(ctx).Error("failed to compute min power to opt in for chain", "chain", chainID, "error", err) + } + } + + // if the validator is opted in and belongs to the validators of the next epoch, then if nothing changes + // the validator would have to validate in the next epoch + if k.IsOptedIn(ctx, chainID, provAddr) { + lastVals, err := k.stakingKeeper.GetLastValidators(ctx) + if err != nil { + return false, err + } + nextValidators := k.ComputeNextValidators(ctx, chainID, lastVals) + for _, v := range nextValidators { + consAddr := sdk.ConsAddress(v.ProviderConsAddr) + if provAddr.ToSdkConsAddr().Equals(consAddr) { + return true, nil + } + } + } + + return false, nil +} + // QueryValidatorConsumerCommissionRate returns the commission rate a given // validator charges on a given consumer chain func (k Keeper) QueryValidatorConsumerCommissionRate(goCtx context.Context, req *types.QueryValidatorConsumerCommissionRateRequest) (*types.QueryValidatorConsumerCommissionRateResponse, error) { diff --git a/x/ccv/provider/keeper/grpc_query_test.go b/x/ccv/provider/keeper/grpc_query_test.go index c5d9520c6b..2bc9304882 100644 --- a/x/ccv/provider/keeper/grpc_query_test.go +++ b/x/ccv/provider/keeper/grpc_query_test.go @@ -1,20 +1,22 @@ package keeper_test import ( + "fmt" "testing" "time" "cosmossdk.io/math" "github.com/cometbft/cometbft/proto/tendermint/crypto" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - + sdk "github.com/cosmos/cosmos-sdk/types" sdktypes "github.com/cosmos/cosmos-sdk/types" cryptotestutil "github.com/cosmos/interchain-security/v5/testutil/crypto" testkeeper "github.com/cosmos/interchain-security/v5/testutil/keeper" "github.com/cosmos/interchain-security/v5/x/ccv/provider/types" ccvtypes "github.com/cosmos/interchain-security/v5/x/ccv/types" + "github.com/stretchr/testify/require" ) func TestQueryAllPairsValConAddrByConsumerChainID(t *testing.T) { @@ -250,3 +252,95 @@ func TestQueryValidatorConsumerCommissionRate(t *testing.T) { res, _ = pk.QueryValidatorConsumerCommissionRate(ctx, &req) require.Equal(t, expectedCommissionRate, res.Rate) } + +// TestGetConsumerChain tests GetConsumerChain behaviour correctness +func TestGetConsumerChain(t *testing.T) { + pk, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) + defer ctrl.Finish() + + chainIDs := []string{"chain-1", "chain-2", "chain-3", "chain-4"} + + // mock the validator set + vals := []stakingtypes.Validator{ + {OperatorAddress: "cosmosvaloper1c4k24jzduc365kywrsvf5ujz4ya6mwympnc4en"}, // 50 power + {OperatorAddress: "cosmosvaloper196ax4vc0lwpxndu9dyhvca7jhxp70rmcvrj90c"}, // 150 power + {OperatorAddress: "cosmosvaloper1clpqr4nrk4khgkxj78fcwwh6dl3uw4epsluffn"}, // 300 power + {OperatorAddress: "cosmosvaloper1tflk30mq5vgqjdly92kkhhq3raev2hnz6eete3"}, // 500 power + } + powers := []int64{50, 150, 300, 500} // sum = 1000 + mocks.MockStakingKeeper.EXPECT().GetLastValidators(gomock.Any()).Return(vals, nil).AnyTimes() + + for i, val := range vals { + valAddr, err := sdk.ValAddressFromBech32(val.GetOperator()) + require.NoError(t, err) + mocks.MockStakingKeeper.EXPECT().GetLastValidatorPower(gomock.Any(), valAddr).Return(powers[i], nil).AnyTimes() + } + + // set Top N parameters, client ids and expected result + topNs := []uint32{0, 70, 90, 100} + expectedMinPowerInTopNs := []int64{ + -1, // Top N is 0, so not a Top N chain + 300, // 500 and 300 are in Top 70% + 150, // 150 is also in the top 90%, + 50, // everyone is in the top 100% + } + + validatorSetCaps := []uint32{0, 5, 10, 20} + validatorPowerCaps := []uint32{0, 5, 10, 33} + allowlists := [][]types.ProviderConsAddress{ + {}, + {types.NewProviderConsAddress([]byte("providerAddr1")), types.NewProviderConsAddress([]byte("providerAddr2"))}, + {types.NewProviderConsAddress([]byte("providerAddr3"))}, + {}, + } + + denylists := [][]types.ProviderConsAddress{ + {types.NewProviderConsAddress([]byte("providerAddr4")), types.NewProviderConsAddress([]byte("providerAddr5"))}, + {}, + {types.NewProviderConsAddress([]byte("providerAddr6"))}, + {}, + } + + expectedGetAllOrder := []types.Chain{} + for i, chainID := range chainIDs { + clientID := fmt.Sprintf("client-%d", len(chainIDs)-i) + topN := topNs[i] + pk.SetConsumerClientId(ctx, chainID, clientID) + pk.SetTopN(ctx, chainID, topN) + pk.SetValidatorSetCap(ctx, chainID, validatorSetCaps[i]) + pk.SetValidatorsPowerCap(ctx, chainID, validatorPowerCaps[i]) + for _, addr := range allowlists[i] { + pk.SetAllowlist(ctx, chainID, addr) + } + for _, addr := range denylists[i] { + pk.SetDenylist(ctx, chainID, addr) + } + strAllowlist := make([]string, len(allowlists[i])) + for j, addr := range allowlists[i] { + strAllowlist[j] = addr.String() + } + + strDenylist := make([]string, len(denylists[i])) + for j, addr := range denylists[i] { + strDenylist[j] = addr.String() + } + + expectedGetAllOrder = append(expectedGetAllOrder, + types.Chain{ + ChainId: chainID, + ClientId: clientID, + Top_N: topN, + MinPowerInTop_N: expectedMinPowerInTopNs[i], + ValidatorSetCap: validatorSetCaps[i], + ValidatorsPowerCap: validatorPowerCaps[i], + Allowlist: strAllowlist, + Denylist: strDenylist, + }) + } + + for i, chainID := range pk.GetAllRegisteredAndProposedChainIDs(ctx) { + c, err := pk.GetConsumerChain(ctx, chainID) + require.NoError(t, err) + require.Equal(t, expectedGetAllOrder[i], c) + } +} diff --git a/x/ccv/provider/keeper/hooks.go b/x/ccv/provider/keeper/hooks.go index 2131a974b0..4c4672ce4c 100644 --- a/x/ccv/provider/keeper/hooks.go +++ b/x/ccv/provider/keeper/hooks.go @@ -10,6 +10,7 @@ import ( v1 "github.com/cosmos/cosmos-sdk/x/gov/types/v1" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/cosmos/interchain-security/v5/x/ccv/provider/types" providertypes "github.com/cosmos/interchain-security/v5/x/ccv/provider/types" ccvtypes "github.com/cosmos/interchain-security/v5/x/ccv/types" ) @@ -93,9 +94,9 @@ func (h Hooks) AfterUnbondingInitiated(goCtx context.Context, id uint64) error { } // get all consumers where the validator is in the validator set - for _, chain := range h.k.GetAllConsumerChains(ctx) { - if h.k.IsConsumerValidator(ctx, chain.ChainId, providertypes.NewProviderConsAddress(consAddr)) { - consumerChainIDS = append(consumerChainIDS, chain.ChainId) + for _, chainID := range h.k.GetAllRegisteredConsumerChainIDs(ctx) { + if h.k.IsConsumerValidator(ctx, chainID, types.NewProviderConsAddress(consAddr)) { + consumerChainIDS = append(consumerChainIDS, chainID) } } diff --git a/x/ccv/provider/keeper/keeper.go b/x/ccv/provider/keeper/keeper.go index 042f05a285..cc057ba123 100644 --- a/x/ccv/provider/keeper/keeper.go +++ b/x/ccv/provider/keeper/keeper.go @@ -266,13 +266,15 @@ func (k Keeper) GetAllPendingConsumerChainIDs(ctx sdk.Context) []string { return chainIDs } -// GetAllConsumerChains gets all of the consumer chains, for which the provider module +// GetAllRegisteredConsumerChainIDs gets all of the consumer chain IDs, for which the provider module // created IBC clients. Consumer chains with created clients are also referred to as registered. // // Note that the registered consumer chains are stored under keys with the following format: // ChainToClientBytePrefix | chainID // Thus, the returned array is in ascending order of chainIDs. -func (k Keeper) GetAllConsumerChains(ctx sdk.Context) (chains []types.Chain) { +func (k Keeper) GetAllRegisteredConsumerChainIDs(ctx sdk.Context) []string { + chainIDs := []string{} + store := ctx.KVStore(k.storeKey) iterator := storetypes.KVStorePrefixIterator(store, []byte{types.ChainToClientBytePrefix}) defer iterator.Close() @@ -280,56 +282,10 @@ func (k Keeper) GetAllConsumerChains(ctx sdk.Context) (chains []types.Chain) { for ; iterator.Valid(); iterator.Next() { // remove 1 byte prefix from key to retrieve chainID chainID := string(iterator.Key()[1:]) - clientID := string(iterator.Value()) - - topN, found := k.GetTopN(ctx, chainID) - - var minPowerInTopN int64 - if found && topN > 0 { - lastVals, err := k.stakingKeeper.GetLastValidators(ctx) - if err != nil { - k.Logger(ctx).Error("failed to get last validators", "chain", chainID, "error", err) - minPowerInTopN = -1 - } else { - res, err := k.ComputeMinPowerToOptIn(ctx, lastVals, topN) - if err != nil { - k.Logger(ctx).Error("failed to compute min power to opt in for chain", "chain", chainID, "error", err) - minPowerInTopN = -1 - } else { - minPowerInTopN = res - } - } - } else { - minPowerInTopN = -1 - } - - validatorSetCap, _ := k.GetValidatorSetCap(ctx, chainID) - validatorsPowerCap, _ := k.GetValidatorsPowerCap(ctx, chainID) - allowlist := k.GetAllowList(ctx, chainID) - strAllowlist := make([]string, len(allowlist)) - for i, addr := range allowlist { - strAllowlist[i] = addr.String() - } - - denylist := k.GetDenyList(ctx, chainID) - strDenylist := make([]string, len(denylist)) - for i, addr := range denylist { - strDenylist[i] = addr.String() - } - - chains = append(chains, types.Chain{ - ChainId: chainID, - ClientId: clientID, - Top_N: topN, - MinPowerInTop_N: minPowerInTopN, - ValidatorSetCap: validatorSetCap, - ValidatorsPowerCap: validatorsPowerCap, - Allowlist: strAllowlist, - Denylist: strDenylist, - }) + chainIDs = append(chainIDs, chainID) } - return chains + return chainIDs } // SetChannelToChain sets the mapping from the CCV channel ID to the consumer chainID. @@ -1195,10 +1151,7 @@ func (k Keeper) BondDenom(ctx sdk.Context) (string, error) { func (k Keeper) GetAllRegisteredAndProposedChainIDs(ctx sdk.Context) []string { allConsumerChains := []string{} - consumerChains := k.GetAllConsumerChains(ctx) - for _, consumerChain := range consumerChains { - allConsumerChains = append(allConsumerChains, consumerChain.ChainId) - } + allConsumerChains = append(allConsumerChains, k.GetAllRegisteredConsumerChainIDs(ctx)...) proposedChains := k.GetAllProposedConsumerChainIDs(ctx) for _, proposedChain := range proposedChains { allConsumerChains = append(allConsumerChains, proposedChain.ChainID) @@ -1320,50 +1273,6 @@ func (k Keeper) DeleteAllOptedIn( } } -func (k Keeper) HasToValidate( - ctx sdk.Context, - provAddr types.ProviderConsAddress, - chainID string, -) (bool, error) { - // if the validator was sent as part of the packet in the last epoch, it has to validate - if k.IsConsumerValidator(ctx, chainID, provAddr) { - return true, nil - } - - // if the validator was not part of the last epoch, check if the validator is going to be part of te next epoch - bondedValidators, err := k.stakingKeeper.GetLastValidators(ctx) - if err != nil { - return false, err - } - if topN, found := k.GetTopN(ctx, chainID); found && topN > 0 { - // in a Top-N chain, we automatically opt in all validators that belong to the top N - minPower, err := k.ComputeMinPowerToOptIn(ctx, bondedValidators, topN) - if err == nil { - k.OptInTopNValidators(ctx, chainID, bondedValidators, minPower) - } else { - k.Logger(ctx).Error("failed to compute min power to opt in for chain", "chain", chainID, "error", err) - } - } - - // if the validator is opted in and belongs to the validators of the next epoch, then if nothing changes - // the validator would have to validate in the next epoch - if k.IsOptedIn(ctx, chainID, provAddr) { - lastVals, err := k.stakingKeeper.GetLastValidators(ctx) - if err != nil { - return false, err - } - nextValidators := k.ComputeNextValidators(ctx, chainID, lastVals) - for _, v := range nextValidators { - consAddr := sdk.ConsAddress(v.ProviderConsAddr) - if provAddr.ToSdkConsAddr().Equals(consAddr) { - return true, nil - } - } - } - - return false, nil -} - // SetConsumerCommissionRate sets a per-consumer chain commission rate // for the given validator address func (k Keeper) SetConsumerCommissionRate( diff --git a/x/ccv/provider/keeper/keeper_test.go b/x/ccv/provider/keeper/keeper_test.go index 23f3cc00ef..fd850c6e13 100644 --- a/x/ccv/provider/keeper/keeper_test.go +++ b/x/ccv/provider/keeper/keeper_test.go @@ -9,11 +9,9 @@ import ( "cosmossdk.io/math" ibctesting "github.com/cosmos/ibc-go/v8/testing" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" - sdk "github.com/cosmos/cosmos-sdk/types" abci "github.com/cometbft/cometbft/abci/types" tmprotocrypto "github.com/cometbft/cometbft/proto/tendermint/crypto" @@ -22,8 +20,6 @@ import ( testkeeper "github.com/cosmos/interchain-security/v5/testutil/keeper" "github.com/cosmos/interchain-security/v5/x/ccv/provider/types" ccv "github.com/cosmos/interchain-security/v5/x/ccv/types" - - stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" ) const consumer = "consumer" @@ -397,98 +393,22 @@ func TestVscSendTimestamp(t *testing.T) { require.Empty(t, providerKeeper.GetAllVscSendTimestamps(ctx, chainID)) } -// TestGetAllConsumerChains tests GetAllConsumerChains behaviour correctness -func TestGetAllConsumerChains(t *testing.T) { - pk, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) +func TestGetAllRegisteredConsumerChainIDs(t *testing.T) { + pk, ctx, ctrl, _ := testkeeper.GetProviderKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) defer ctrl.Finish() chainIDs := []string{"chain-2", "chain-1", "chain-4", "chain-3"} + // GetAllRegisteredConsumerChainIDs iterates over chainID in lexicographical order + expectedChainIDs := []string{"chain-1", "chain-2", "chain-3", "chain-4"} - // mock the validator set - vals := []stakingtypes.Validator{ - {OperatorAddress: "cosmosvaloper1c4k24jzduc365kywrsvf5ujz4ya6mwympnc4en"}, // 50 power - {OperatorAddress: "cosmosvaloper196ax4vc0lwpxndu9dyhvca7jhxp70rmcvrj90c"}, // 150 power - {OperatorAddress: "cosmosvaloper1clpqr4nrk4khgkxj78fcwwh6dl3uw4epsluffn"}, // 300 power - {OperatorAddress: "cosmosvaloper1tflk30mq5vgqjdly92kkhhq3raev2hnz6eete3"}, // 500 power - } - powers := []int64{50, 150, 300, 500} // sum = 1000 - mocks.MockStakingKeeper.EXPECT().GetLastValidators(gomock.Any()).Return(vals, nil).AnyTimes() - - for i, val := range vals { - valAddr, err := sdk.ValAddressFromBech32(val.GetOperator()) - require.NoError(t, err) - mocks.MockStakingKeeper.EXPECT().GetLastValidatorPower(gomock.Any(), valAddr).Return(powers[i], nil).AnyTimes() - } - - // set Top N parameters, client ids and expected result - topNs := []uint32{0, 70, 90, 100} - expectedMinPowerInTopNs := []int64{ - -1, // Top N is 0, so not a Top N chain - 300, // 500 and 300 are in Top 70% - 150, // 150 is also in the top 90%, - 50, // everyone is in the top 100% - } - - validatorSetCaps := []uint32{0, 5, 10, 20} - validatorPowerCaps := []uint32{0, 5, 10, 33} - allowlists := [][]types.ProviderConsAddress{ - {}, - {types.NewProviderConsAddress([]byte("providerAddr1")), types.NewProviderConsAddress([]byte("providerAddr2"))}, - {types.NewProviderConsAddress([]byte("providerAddr3"))}, - {}, - } - - denylists := [][]types.ProviderConsAddress{ - {types.NewProviderConsAddress([]byte("providerAddr4")), types.NewProviderConsAddress([]byte("providerAddr5"))}, - {}, - {types.NewProviderConsAddress([]byte("providerAddr6"))}, - {}, - } - - expectedGetAllOrder := []types.Chain{} for i, chainID := range chainIDs { clientID := fmt.Sprintf("client-%d", len(chainIDs)-i) - topN := topNs[i] pk.SetConsumerClientId(ctx, chainID, clientID) - pk.SetTopN(ctx, chainID, topN) - pk.SetValidatorSetCap(ctx, chainID, validatorSetCaps[i]) - pk.SetValidatorsPowerCap(ctx, chainID, validatorPowerCaps[i]) - for _, addr := range allowlists[i] { - pk.SetAllowlist(ctx, chainID, addr) - } - for _, addr := range denylists[i] { - pk.SetDenylist(ctx, chainID, addr) - } - strAllowlist := make([]string, len(allowlists[i])) - for j, addr := range allowlists[i] { - strAllowlist[j] = addr.String() - } - - strDenylist := make([]string, len(denylists[i])) - for j, addr := range denylists[i] { - strDenylist[j] = addr.String() - } - - expectedGetAllOrder = append(expectedGetAllOrder, - types.Chain{ - ChainId: chainID, - ClientId: clientID, - Top_N: topN, - MinPowerInTop_N: expectedMinPowerInTopNs[i], - ValidatorSetCap: validatorSetCaps[i], - ValidatorsPowerCap: validatorPowerCaps[i], - Allowlist: strAllowlist, - Denylist: strDenylist, - }) - } - // sorting by chainID - sort.Slice(expectedGetAllOrder, func(i, j int) bool { - return expectedGetAllOrder[i].ChainId < expectedGetAllOrder[j].ChainId - }) + } - result := pk.GetAllConsumerChains(ctx) + result := pk.GetAllRegisteredConsumerChainIDs(ctx) require.Len(t, result, len(chainIDs)) - require.Equal(t, expectedGetAllOrder, result) + require.Equal(t, expectedChainIDs, result) } // TestGetAllChannelToChains tests GetAllChannelToChains behaviour correctness diff --git a/x/ccv/provider/keeper/partial_set_security.go b/x/ccv/provider/keeper/partial_set_security.go index 1edff9b56b..7a5c1ec33d 100644 --- a/x/ccv/provider/keeper/partial_set_security.go +++ b/x/ccv/provider/keeper/partial_set_security.go @@ -69,11 +69,11 @@ func (k Keeper) HandleOptOut(ctx sdk.Context, chainID string, providerAddr types if err != nil { return err } - lastVals, err := k.stakingKeeper.GetLastValidators(ctx) + bondedValidators, err := k.stakingKeeper.GetLastValidators(ctx) if err != nil { - return err + return errorsmod.Wrapf(stakingtypes.ErrNoValidatorFound, "error getting last bonded validators: %s", err) } - minPowerToOptIn, err := k.ComputeMinPowerToOptIn(ctx, lastVals, topN) + minPowerToOptIn, err := k.ComputeMinPowerToOptIn(ctx, bondedValidators, topN) if err != nil { k.Logger(ctx).Error("failed to compute min power to opt in for chain", "chain", chainID, "error", err) return errorsmod.Wrapf( diff --git a/x/ccv/provider/keeper/relay.go b/x/ccv/provider/keeper/relay.go index 5c4217e6ae..ff7a671559 100644 --- a/x/ccv/provider/keeper/relay.go +++ b/x/ccv/provider/keeper/relay.go @@ -167,10 +167,10 @@ func (k Keeper) EndBlockVSU(ctx sdk.Context) { // If the CCV channel is not established for a consumer chain, // the updates will remain queued until the channel is established func (k Keeper) SendVSCPackets(ctx sdk.Context) { - for _, chain := range k.GetAllConsumerChains(ctx) { + for _, chainID := range k.GetAllRegisteredConsumerChainIDs(ctx) { // check if CCV channel is established and send - if channelID, found := k.GetChainToChannel(ctx, chain.ChainId); found { - k.SendVSCPacketsToChain(ctx, chain.ChainId, channelID) + if channelID, found := k.GetChainToChannel(ctx, chainID); found { + k.SendVSCPacketsToChain(ctx, chainID, channelID) } } } @@ -225,35 +225,36 @@ func (k Keeper) QueueVSCPackets(ctx sdk.Context) { panic(fmt.Errorf("failed to get last validators: %w", err)) } - for _, chain := range k.GetAllConsumerChains(ctx) { - currentValidators := k.GetConsumerValSet(ctx, chain.ChainId) + for _, chainID := range k.GetAllRegisteredConsumerChainIDs(ctx) { + currentValidators := k.GetConsumerValSet(ctx, chainID) + topN, _ := k.GetTopN(ctx, chainID) - if chain.Top_N > 0 { + if topN > 0 { // in a Top-N chain, we automatically opt in all validators that belong to the top N - minPower, err := k.ComputeMinPowerToOptIn(ctx, bondedValidators, chain.Top_N) + minPower, err := k.ComputeMinPowerToOptIn(ctx, bondedValidators, topN) if err == nil { - k.OptInTopNValidators(ctx, chain.ChainId, bondedValidators, minPower) + k.OptInTopNValidators(ctx, chainID, bondedValidators, minPower) } else { // we just log here and do not panic because panic-ing would halt the provider chain - k.Logger(ctx).Error("failed to compute min power to opt in for chain", "chain", chain.ChainId, "error", err) + k.Logger(ctx).Error("failed to compute min power to opt in for chain", "chain", chainID, "error", err) } } - nextValidators := k.ComputeNextValidators(ctx, chain.ChainId, bondedValidators) + nextValidators := k.ComputeNextValidators(ctx, chainID, bondedValidators) valUpdates := DiffValidators(currentValidators, nextValidators) - k.SetConsumerValSet(ctx, chain.ChainId, nextValidators) + k.SetConsumerValSet(ctx, chainID, nextValidators) // check whether there are changes in the validator set; // note that this also entails unbonding operations // w/o changes in the voting power of the validators in the validator set - unbondingOps := k.GetUnbondingOpsFromIndex(ctx, chain.ChainId, valUpdateID) + unbondingOps := k.GetUnbondingOpsFromIndex(ctx, chainID, valUpdateID) if len(valUpdates) != 0 || len(unbondingOps) != 0 { // construct validator set change packet data - packet := ccv.NewValidatorSetChangePacketData(valUpdates, valUpdateID, k.ConsumeSlashAcks(ctx, chain.ChainId)) - k.AppendPendingVSCPackets(ctx, chain.ChainId, packet) + packet := ccv.NewValidatorSetChangePacketData(valUpdates, valUpdateID, k.ConsumeSlashAcks(ctx, chainID)) + k.AppendPendingVSCPackets(ctx, chainID, packet) k.Logger(ctx).Info("VSCPacket enqueued:", - "chainID", chain.ChainId, + "chainID", chainID, "vscID", valUpdateID, "len updates", len(valUpdates), "len unbonding ops", len(unbondingOps), diff --git a/x/ccv/provider/migrations/v3/migrations.go b/x/ccv/provider/migrations/v3/migrations.go index 2ffd1e6f25..8c17000b0b 100644 --- a/x/ccv/provider/migrations/v3/migrations.go +++ b/x/ccv/provider/migrations/v3/migrations.go @@ -11,15 +11,15 @@ import ( // MigrateQueuedPackets processes all queued packet data for all consumer chains that were stored // on the provider in the v2 consensus version (jail throttling v1). func MigrateQueuedPackets(ctx sdk.Context, k providerkeeper.Keeper) error { - for _, consumer := range k.GetAllConsumerChains(ctx) { - slashData, vscmData := k.LegacyGetAllThrottledPacketData(ctx, consumer.ChainId) + for _, consumerChainID := range k.GetAllRegisteredConsumerChainIDs(ctx) { + slashData, vscmData := k.LegacyGetAllThrottledPacketData(ctx, consumerChainID) if len(slashData) > 0 { k.Logger(ctx).Error(fmt.Sprintf("slash data being dropped: %v", slashData)) } for _, data := range vscmData { - k.HandleVSCMaturedPacket(ctx, consumer.ChainId, data) + k.HandleVSCMaturedPacket(ctx, consumerChainID, data) } - k.LegacyDeleteThrottledPacketDataForConsumer(ctx, consumer.ChainId) + k.LegacyDeleteThrottledPacketDataForConsumer(ctx, consumerChainID) } return nil } diff --git a/x/ccv/provider/migrations/v5/migrations.go b/x/ccv/provider/migrations/v5/migrations.go index aa228b6a09..411efd49e1 100644 --- a/x/ccv/provider/migrations/v5/migrations.go +++ b/x/ccv/provider/migrations/v5/migrations.go @@ -10,12 +10,9 @@ import ( // If a chain is in voting while the upgrade happens, this is not sufficient, // and a migration to rewrite the proposal is needed. func MigrateTopNForRegisteredChains(ctx sdk.Context, providerKeeper providerkeeper.Keeper) { - // get all consumer chains - registeredConsumerChains := providerKeeper.GetAllConsumerChains(ctx) - // Set the topN of each chain to 95 - for _, chain := range registeredConsumerChains { - providerKeeper.SetTopN(ctx, chain.ChainId, 95) + for _, chainID := range providerKeeper.GetAllRegisteredConsumerChainIDs(ctx) { + providerKeeper.SetTopN(ctx, chainID, 95) } }