diff --git a/.changelog/2665.internal.md b/.changelog/2665.internal.md new file mode 100644 index 00000000000..dd590723cf4 --- /dev/null +++ b/.changelog/2665.internal.md @@ -0,0 +1 @@ +Add sanity checks for stake accumulator state integrity diff --git a/go/consensus/tendermint/apps/registry/registry.go b/go/consensus/tendermint/apps/registry/registry.go index 9679d28cfb7..7635694eaa7 100644 --- a/go/consensus/tendermint/apps/registry/registry.go +++ b/go/consensus/tendermint/apps/registry/registry.go @@ -189,7 +189,7 @@ func (app *registryApplication) onRegistryEpochChanged(ctx *abci.Context, regist // Remove the stake claim for the given node. if !params.DebugBypassStake { - if err = stakeAcc.RemoveStakeClaim(node.EntityID, stakeClaimForNode(node.ID)); err != nil { + if err = stakeAcc.RemoveStakeClaim(node.EntityID, registry.StakeClaimForNode(node.ID)); err != nil { return fmt.Errorf("registry: onRegistryEpochChanged: couldn't remove stake claim: %w", err) } } diff --git a/go/consensus/tendermint/apps/registry/transactions.go b/go/consensus/tendermint/apps/registry/transactions.go index cfe08634d90..c03bf813845 100644 --- a/go/consensus/tendermint/apps/registry/transactions.go +++ b/go/consensus/tendermint/apps/registry/transactions.go @@ -3,9 +3,7 @@ package registry import ( "fmt" - "github.com/oasislabs/oasis-core/go/common" "github.com/oasislabs/oasis-core/go/common/cbor" - "github.com/oasislabs/oasis-core/go/common/crypto/signature" "github.com/oasislabs/oasis-core/go/common/entity" "github.com/oasislabs/oasis-core/go/common/node" "github.com/oasislabs/oasis-core/go/consensus/tendermint/abci" @@ -16,20 +14,6 @@ import ( staking "github.com/oasislabs/oasis-core/go/staking/api" ) -const ( - claimRegisterEntity = "registry.RegisterEntity" - claimRegisterNode = "registry.RegisterNode.%s" - claimRegisterRuntime = "registry.RegisterRuntime.%s" -) - -func stakeClaimForNode(id signature.PublicKey) staking.StakeClaim { - return staking.StakeClaim(fmt.Sprintf(claimRegisterNode, id)) -} - -func stakeClaimForRuntime(id common.Namespace) staking.StakeClaim { - return staking.StakeClaim(fmt.Sprintf(claimRegisterRuntime, id)) -} - func (app *registryApplication) registerEntity( ctx *abci.Context, state *registryState.MutableState, @@ -67,7 +51,7 @@ func (app *registryApplication) registerEntity( } if !params.DebugBypassStake { - if err = stakingState.AddStakeClaim(ctx, ent.ID, claimRegisterEntity, []staking.ThresholdKind{staking.KindEntity}); err != nil { + if err = stakingState.AddStakeClaim(ctx, ent.ID, registry.StakeClaimRegisterEntity, []staking.ThresholdKind{staking.KindEntity}); err != nil { ctx.Logger().Error("RegisterEntity: Insufficent stake", "err", err, "id", ent.ID, @@ -141,7 +125,7 @@ func (app *registryApplication) deregisterEntity(ctx *abci.Context, state *regis } if !params.DebugBypassStake { - if err = stakingState.RemoveStakeClaim(ctx, id, claimRegisterEntity); err != nil { + if err = stakingState.RemoveStakeClaim(ctx, id, registry.StakeClaimRegisterEntity); err != nil { panic(fmt.Errorf("DeregisterEntity: failed to remove stake claim: %w", err)) } } @@ -320,20 +304,8 @@ func (app *registryApplication) registerNode( // nolint: gocyclo return fmt.Errorf("failed to create stake accumulator cache: %w", err) } - claim := stakeClaimForNode(newNode.ID) - var thresholds []staking.ThresholdKind - if newNode.HasRoles(node.RoleKeyManager) { - thresholds = append(thresholds, staking.KindNodeKeyManager) - } - if newNode.HasRoles(node.RoleComputeWorker) { - thresholds = append(thresholds, staking.KindNodeCompute) - } - if newNode.HasRoles(node.RoleStorageWorker) { - thresholds = append(thresholds, staking.KindNodeStorage) - } - if newNode.HasRoles(node.RoleValidator) { - thresholds = append(thresholds, staking.KindNodeValidator) - } + claim := registry.StakeClaimForNode(newNode.ID) + thresholds := registry.StakeThresholdsForNode(newNode) if err = stakeAcc.AddStakeClaim(newNode.EntityID, claim, thresholds); err != nil { ctx.Logger().Error("RegisterNode: insufficient stake for new node", @@ -609,20 +581,8 @@ func (app *registryApplication) registerRuntime( // nolint: gocyclo // Make sure that the entity has enough stake. if !params.DebugBypassStake { - claim := stakeClaimForRuntime(rt.ID) - var thresholds []staking.ThresholdKind - switch rt.Kind { - case registry.KindCompute: - thresholds = append(thresholds, staking.KindRuntimeCompute) - case registry.KindKeyManager: - thresholds = append(thresholds, staking.KindRuntimeKeyManager) - default: - ctx.Logger().Error("RegisterRuntime: unknown runtime kind", - "runtime_id", rt.ID, - "kind", rt.Kind, - ) - return fmt.Errorf("registry: unknown runtime kind (%d)", rt.Kind) - } + claim := registry.StakeClaimForRuntime(rt.ID) + thresholds := registry.StakeThresholdsForRuntime(rt) if err = stakingState.AddStakeClaim(ctx, rt.EntityID, claim, thresholds); err != nil { ctx.Logger().Error("RegisterRuntime: Insufficent stake", diff --git a/go/consensus/tendermint/apps/staking/genesis.go b/go/consensus/tendermint/apps/staking/genesis.go index f1a091bf970..d3a087872b9 100644 --- a/go/consensus/tendermint/apps/staking/genesis.go +++ b/go/consensus/tendermint/apps/staking/genesis.go @@ -72,6 +72,15 @@ func (app *stakingApplication) initLedger(ctx *abci.Context, state *stakingState return errors.New("staking/tendermint: invalid genesis debonding escrow balance") } + // Make sure that the stake accumulator is empty as otherwise it could be inconsistent with + // what is registered in the genesis block. + if len(v.Escrow.StakeAccumulator.Claims) > 0 { + ctx.Logger().Error("InitChain: non-empty stake accumulator", + "id", id, + ) + return errors.New("staking/tendermint: non-empty stake accumulator in genesis") + } + ups = append(ups, ledgerUpdate{id, v}) if err := totalSupply.Add(&v.General.Balance); err != nil { ctx.Logger().Error("InitChain: failed to add general balance", @@ -261,6 +270,9 @@ func (sq *stakingQuerier) Genesis(ctx context.Context) (*staking.Genesis, error) ledger := make(map[signature.PublicKey]*staking.Account) for _, acctID := range accounts { acct := sq.state.Account(acctID) + // Make sure that export resets the stake accumulator state as that should be re-initialized + // during genesis (a genesis document with non-empty stake accumulator is invalid). + acct.Escrow.StakeAccumulator = staking.StakeAccumulator{} ledger[acctID] = acct } diff --git a/go/consensus/tendermint/apps/supplementarysanity/checks.go b/go/consensus/tendermint/apps/supplementarysanity/checks.go index 3900c8dafd3..cfb26e5c962 100644 --- a/go/consensus/tendermint/apps/supplementarysanity/checks.go +++ b/go/consensus/tendermint/apps/supplementarysanity/checks.go @@ -6,6 +6,7 @@ import ( "github.com/tendermint/iavl" "github.com/oasislabs/oasis-core/go/common" + "github.com/oasislabs/oasis-core/go/common/crypto/signature" "github.com/oasislabs/oasis-core/go/common/quantity" keymanagerState "github.com/oasislabs/oasis-core/go/consensus/tendermint/apps/keymanager/state" registryState "github.com/oasislabs/oasis-core/go/consensus/tendermint/apps/registry/state" @@ -212,3 +213,89 @@ func checkHalt(*iavl.MutableTree, epochtime.EpochTime) error { // nothing to check yet return nil } + +func checkStakeClaims(state *iavl.MutableTree, now epochtime.EpochTime) error { + regSt := registryState.NewMutableState(state) + stakeSt := stakingState.NewMutableState(state) + + params, err := regSt.ConsensusParameters() + if err != nil { + return fmt.Errorf("failed to get consensus parameters: %w", err) + } + + // Skip checks if stake is being bypassed. + if params.DebugBypassStake { + return nil + } + + // Claims in the stake accumulators should be consistent with general state. + claims := make(map[signature.PublicKey]map[staking.StakeClaim][]staking.ThresholdKind) + // Entity registrations. + entities, err := regSt.Entities() + if err != nil { + return fmt.Errorf("failed to get entities: %w", err) + } + for _, entity := range entities { + claims[entity.ID] = map[staking.StakeClaim][]staking.ThresholdKind{ + registry.StakeClaimRegisterEntity: []staking.ThresholdKind{staking.KindEntity}, + } + } + // Node registrations. + nodes, err := regSt.Nodes() + if err != nil { + return fmt.Errorf("failed to get node registrations: %w", err) + } + for _, node := range nodes { + claims[node.EntityID][registry.StakeClaimForNode(node.ID)] = registry.StakeThresholdsForNode(node) + } + // Runtime registrations. + runtimes, err := regSt.AllRuntimes() + if err != nil { + return fmt.Errorf("failed to get runtime registrations: %w", err) + } + for _, rt := range runtimes { + claims[rt.EntityID][registry.StakeClaimForRuntime(rt.ID)] = registry.StakeThresholdsForRuntime(rt) + } + + // Compare with actual accumulator state. + for _, entity := range entities { + acct := stakeSt.Account(entity.ID) + expectedClaims := claims[entity.ID] + actualClaims := acct.Escrow.StakeAccumulator.Claims + if len(expectedClaims) != len(actualClaims) { + return fmt.Errorf("incorrect number of stake claims for account %s (expected: %d got: %d)", + entity.ID, + len(expectedClaims), + len(actualClaims), + ) + } + for claim, expectedThresholds := range expectedClaims { + thresholds, ok := actualClaims[claim] + if !ok { + return fmt.Errorf("missing claim %s for account %s", claim, entity.ID) + } + if len(thresholds) != len(expectedThresholds) { + return fmt.Errorf("incorrect number of thresholds for claim %s for account %s (expected: %d got: %d)", + claim, + entity.ID, + len(expectedThresholds), + len(thresholds), + ) + } + for i, expectedThreshold := range expectedThresholds { + threshold := thresholds[i] + if threshold != expectedThreshold { + return fmt.Errorf("incorrect threshold in position %d for claim %s for account %s (expected: %s got: %s)", + i, + claim, + entity.ID, + expectedThreshold, + threshold, + ) + } + } + } + } + + return nil +} diff --git a/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go b/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go index 5f1d4e4437c..131a5718eff 100644 --- a/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go +++ b/go/consensus/tendermint/apps/supplementarysanity/supplementarysanity.go @@ -125,6 +125,7 @@ func (app *supplementarySanityApplication) endBlockImpl(ctx *abci.Context, reque {"checkBeacon", checkBeacon}, {"checkConsensus", checkConsensus}, {"checkHalt", checkHalt}, + {"checkStakeClaims", checkStakeClaims}, } { if err := tt.checker(state, now); err != nil { return errors.Wrap(err, tt.name) diff --git a/go/registry/api/api.go b/go/registry/api/api.go index f7c2a7f60e5..35b41203b9e 100644 --- a/go/registry/api/api.go +++ b/go/registry/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/oasislabs/oasis-core/go/common/sgx/ias" "github.com/oasislabs/oasis-core/go/consensus/api/transaction" epochtime "github.com/oasislabs/oasis-core/go/epochtime/api" + staking "github.com/oasislabs/oasis-core/go/staking/api" ) // ModuleName is a unique module name for the registry module. @@ -1285,3 +1286,52 @@ var DefaultGasCosts = transaction.Costs{ GasOpRuntimeEpochMaintenance: 1000, GasOpUpdateKeyManager: 1000, } + +const ( + // StakeClaimRegisterEntity is the stake claim identifier used for registering an entity. + StakeClaimRegisterEntity = "registry.RegisterEntity" + // StakeClaimRegisterNode is the stake claim template used for registering nodes. + StakeClaimRegisterNode = "registry.RegisterNode.%s" + // StakeClaimRegisterRuntime is the stake claim template used for registering runtimes. + StakeClaimRegisterRuntime = "registry.RegisterRuntime.%s" +) + +// StakeClaimForNode generates a new stake claim identifier for a specific node registration. +func StakeClaimForNode(id signature.PublicKey) staking.StakeClaim { + return staking.StakeClaim(fmt.Sprintf(StakeClaimRegisterNode, id)) +} + +// StakeClaimForRuntime generates a new stake claim for a specific runtime registration. +func StakeClaimForRuntime(id common.Namespace) staking.StakeClaim { + return staking.StakeClaim(fmt.Sprintf(StakeClaimRegisterRuntime, id)) +} + +// StakeThresholdsForNode returns the staking thresholds for the given node. +func StakeThresholdsForNode(n *node.Node) (thresholds []staking.ThresholdKind) { + if n.HasRoles(node.RoleKeyManager) { + thresholds = append(thresholds, staking.KindNodeKeyManager) + } + if n.HasRoles(node.RoleComputeWorker) { + thresholds = append(thresholds, staking.KindNodeCompute) + } + if n.HasRoles(node.RoleStorageWorker) { + thresholds = append(thresholds, staking.KindNodeStorage) + } + if n.HasRoles(node.RoleValidator) { + thresholds = append(thresholds, staking.KindNodeValidator) + } + return +} + +// StakeThresholdsForRuntime returns the staking thresholds for the given runtime. +func StakeThresholdsForRuntime(rt *Runtime) (thresholds []staking.ThresholdKind) { + switch rt.Kind { + case KindCompute: + thresholds = append(thresholds, staking.KindRuntimeCompute) + case KindKeyManager: + thresholds = append(thresholds, staking.KindRuntimeKeyManager) + default: + panic(fmt.Errorf("registry: unknown runtime kind: %s", rt.Kind)) + } + return +} diff --git a/go/staking/api/sanity_check.go b/go/staking/api/sanity_check.go index 2db16a55cb4..8f71e1d9da8 100644 --- a/go/staking/api/sanity_check.go +++ b/go/staking/api/sanity_check.go @@ -176,6 +176,12 @@ func (g *Genesis) SanityCheck(now epochtime.EpochTime) error { // nolint: gocycl if err != nil { return err } + + // Make sure that the stake accumulator is empty as otherwise it could be inconsistent with + // what is registered in the genesis block. + if len(acct.Escrow.StakeAccumulator.Claims) > 0 { + return fmt.Errorf("staking: non-empty stake accumulator in genesis") + } } _ = total.Add(&g.CommonPool) if total.Cmp(&g.TotalSupply) != 0 {