diff --git a/x/staking/keeper/invariants.go b/x/staking/keeper/invariants.go index 3516316f6c8b..52a45504dbf8 100644 --- a/x/staking/keeper/invariants.go +++ b/x/staking/keeper/invariants.go @@ -166,20 +166,30 @@ func DelegatorSharesInvariant(k Keeper) sdk.Invariant { ) validators := k.GetAllValidators(ctx) + validatorsDelegationShares := map[string]sdk.Dec{} + + // initialize a map: validator -> its delegation shares for _, validator := range validators { - valTotalDelShares := validator.GetDelegatorShares() - totalDelShares := sdk.ZeroDec() + validatorsDelegationShares[validator.GetOperator().String()] = sdk.ZeroDec() + } - delegations := k.GetValidatorDelegations(ctx, validator.GetOperator()) - for _, delegation := range delegations { - totalDelShares = totalDelShares.Add(delegation.Shares) - } + // iterate through all the delegations to calculate the total delegation shares for each validator + delegations := k.GetAllDelegations(ctx) + for _, delegation := range delegations { + delegationValidatorAddr := delegation.GetValidatorAddr().String() + validatorDelegationShares := validatorsDelegationShares[delegationValidatorAddr] + validatorsDelegationShares[delegationValidatorAddr] = validatorDelegationShares.Add(delegation.Shares) + } - if !valTotalDelShares.Equal(totalDelShares) { + // for each validator, check if its total delegation shares calculated from the step above equals to its expected delegation shares + for _, validator := range validators { + expValTotalDelShares := validator.GetDelegatorShares() + calculatedValTotalDelShares := validatorsDelegationShares[validator.GetOperator().String()] + if !calculatedValTotalDelShares.Equal(expValTotalDelShares) { broken = true msg += fmt.Sprintf("broken delegator shares invariance:\n"+ "\tvalidator.DelegatorShares: %v\n"+ - "\tsum of Delegator.Shares: %v\n", valTotalDelShares, totalDelShares) + "\tsum of Delegator.Shares: %v\n", expValTotalDelShares, calculatedValTotalDelShares) } }