diff --git a/.changelog/2665.internal.md b/.changelog/2665.internal.md new file mode 100644 index 00000000000..dd590723cf4 --- /dev/null +++ b/.changelog/2665.internal.md @@ -0,0 +1 @@ +Add sanity checks for stake accumulator state integrity diff --git a/go/consensus/tendermint/apps/staking/genesis.go b/go/consensus/tendermint/apps/staking/genesis.go index f1a091bf970..d3a087872b9 100644 --- a/go/consensus/tendermint/apps/staking/genesis.go +++ b/go/consensus/tendermint/apps/staking/genesis.go @@ -72,6 +72,15 @@ func (app *stakingApplication) initLedger(ctx *abci.Context, state *stakingState return errors.New("staking/tendermint: invalid genesis debonding escrow balance") } + // Make sure that the stake accumulator is empty as otherwise it could be inconsistent with + // what is registered in the genesis block. + if len(v.Escrow.StakeAccumulator.Claims) > 0 { + ctx.Logger().Error("InitChain: non-empty stake accumulator", + "id", id, + ) + return errors.New("staking/tendermint: non-empty stake accumulator in genesis") + } + ups = append(ups, ledgerUpdate{id, v}) if err := totalSupply.Add(&v.General.Balance); err != nil { ctx.Logger().Error("InitChain: failed to add general balance", @@ -261,6 +270,9 @@ func (sq *stakingQuerier) Genesis(ctx context.Context) (*staking.Genesis, error) ledger := make(map[signature.PublicKey]*staking.Account) for _, acctID := range accounts { acct := sq.state.Account(acctID) + // Make sure that export resets the stake accumulator state as that should be re-initialized + // during genesis (a genesis document with non-empty stake accumulator is invalid). + acct.Escrow.StakeAccumulator = staking.StakeAccumulator{} ledger[acctID] = acct } diff --git a/go/consensus/tendermint/apps/supplementarysanity/checks.go b/go/consensus/tendermint/apps/supplementarysanity/checks.go index 3900c8dafd3..a3b537576c9 100644 --- a/go/consensus/tendermint/apps/supplementarysanity/checks.go +++ b/go/consensus/tendermint/apps/supplementarysanity/checks.go @@ -6,6 +6,7 @@ import ( "github.com/tendermint/iavl" "github.com/oasislabs/oasis-core/go/common" + "github.com/oasislabs/oasis-core/go/common/crypto/signature" "github.com/oasislabs/oasis-core/go/common/quantity" keymanagerState "github.com/oasislabs/oasis-core/go/consensus/tendermint/apps/keymanager/state" registryState "github.com/oasislabs/oasis-core/go/consensus/tendermint/apps/registry/state" @@ -212,3 +213,79 @@ func checkHalt(*iavl.MutableTree, epochtime.EpochTime) error { // nothing to check yet return nil } + +func checkStakeClaims(state *iavl.MutableTree, now epochtime.EpochTime) error { + regSt := registryState.NewMutableState(state) + stakeSt := stakingState.NewMutableState(state) + + // Claims in the stake accumulators should be consistent with general state. + claims := make(map[signature.PublicKey]map[staking.StakeClaim][]staking.ThresholdKind) + // Entity registrations. + entities, err := regSt.Entities() + if err != nil { + return fmt.Errorf("failed to get entities: %w", err) + } + for _, entity := range entities { + claims[entity.ID] = map[staking.StakeClaim][]staking.ThresholdKind{ + registry.StakeClaimRegisterEntity: []staking.ThresholdKind{staking.KindEntity}, + } + } + // Node registrations. + nodes, err := regSt.Nodes() + if err != nil { + return fmt.Errorf("failed to get node registrations: %w", err) + } + for _, node := range nodes { + claims[node.EntityID][registry.StakeClaimForNode(node.ID)] = registry.StakeThresholdsForNode(node) + } + // Runtime registrations. + runtimes, err := regSt.AllRuntimes() + if err != nil { + return fmt.Errorf("failed to get runtime registrations: %w", err) + } + for _, rt := range runtimes { + claims[rt.EntityID][registry.StakeClaimForRuntime(rt.ID)] = registry.StakeThresholdsForRuntime(rt) + } + + // Compare with actual accumulator state. + for _, entity := range entities { + acct := stakeSt.Account(entity.ID) + expectedClaims := claims[entity.ID] + actualClaims := acct.Escrow.StakeAccumulator.Claims + if len(expectedClaims) != len(actualClaims) { + return fmt.Errorf("incorrect number of stake claims for account %s (expected: %d got: %d)", + entity.ID, + len(expectedClaims), + len(actualClaims), + ) + } + for claim, expectedThresholds := range expectedClaims { + thresholds, ok := actualClaims[claim] + if !ok { + return fmt.Errorf("missing claim %s for account %s", claim, entity.ID) + } + if len(thresholds) != len(expectedThresholds) { + return fmt.Errorf("incorrect number of thresholds for claim %s for account %s (expected: %d got: %d)", + claim, + entity.ID, + len(expectedThresholds), + len(thresholds), + ) + } + for i, expectedThreshold := range expectedThresholds { + threshold := thresholds[i] + if threshold != expectedThreshold { + return fmt.Errorf("incorrect threshold in position %d for claim %s for account %s (expected: %s got: %s)", + i, + claim, + entity.ID, + expectedThreshold, + threshold, + ) + } + } + } + } + + return nil +} diff --git a/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go b/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go index 5f1d4e4437c..131a5718eff 100644 --- a/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go +++ b/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go @@ -125,6 +125,7 @@ func (app *supplementarySanityApplication) endBlockImpl(ctx *abci.Context, reque {"checkBeacon", checkBeacon}, {"checkConsensus", checkConsensus}, {"checkHalt", checkHalt}, + {"checkStakeClaims", checkStakeClaims}, } { if err := tt.checker(state, now); err != nil { return errors.Wrap(err, tt.name) diff --git a/go/staking/api/sanity_check.go b/go/staking/api/sanity_check.go index 2db16a55cb4..8f71e1d9da8 100644 --- a/go/staking/api/sanity_check.go +++ b/go/staking/api/sanity_check.go @@ -176,6 +176,12 @@ func (g *Genesis) SanityCheck(now epochtime.EpochTime) error { // nolint: gocycl if err != nil { return err } + + // Make sure that the stake accumulator is empty as otherwise it could be inconsistent with + // what is registered in the genesis block. + if len(acct.Escrow.StakeAccumulator.Claims) > 0 { + return fmt.Errorf("staking: non-empty stake accumulator in genesis") + } } _ = total.Add(&g.CommonPool) if total.Cmp(&g.TotalSupply) != 0 {