From 09e0036436d85fd4b558aa473792e6a9e58256dc Mon Sep 17 00:00:00 2001 From: Adam Tucker Date: Sun, 29 Oct 2023 21:25:35 -0600 Subject: [PATCH] feat: add clawback vesting account (#178) --- tests/e2e/auth/vesting/suite.go | 230 +++++++++++++ x/auth/vesting/client/cli/tx.go | 170 ++++++++++ .../vesting/client/testutil/testdata/badjson | 1 + .../client/testutil/testdata/badperiod.json | 9 + .../client/testutil/testdata/periods1.json | 13 + x/auth/vesting/exported/exported.go | 33 ++ x/auth/vesting/msg_server.go | 167 ++++++++++ .../testutil/expected_keepers_mocks.go | 78 +++++ x/auth/vesting/types/codec.go | 5 + x/auth/vesting/types/expected_keepers.go | 10 + x/auth/vesting/types/period.go | 265 +++++++++++++++ x/auth/vesting/types/vesting_account.go | 303 ++++++++++++++++++ .../types/vesting_account_internal_test.go | 65 ++++ x/bank/keeper/keeper.go | 7 + 14 files changed, 1356 insertions(+) create mode 100644 x/auth/vesting/client/testutil/testdata/badjson create mode 100644 x/auth/vesting/client/testutil/testdata/badperiod.json create mode 100644 x/auth/vesting/client/testutil/testdata/periods1.json create mode 100644 x/auth/vesting/types/vesting_account_internal_test.go diff --git a/tests/e2e/auth/vesting/suite.go b/tests/e2e/auth/vesting/suite.go index e2bbd722ea8a..b9d91fa626c8 100644 --- a/tests/e2e/auth/vesting/suite.go +++ b/tests/e2e/auth/vesting/suite.go @@ -219,3 +219,233 @@ func (s *E2ETestSuite) TestNewMsgCreatePermanentLockedAccountCmd() { s.T().Logf("Height now: %d", height) } } + +func (s *E2ETestSuite) TestNewMsgCreateClawbackVestingAccountCmd() { + val := s.network.Validators[0] + for _, tc := range []struct { + name string + args []string + expectErr bool + expectedCode uint32 + respType proto.Message + }{ + { + name: "basic", + args: []string{ + sdk.AccAddress("addr10______________").String(), + fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address), + fmt.Sprintf("--%s=%s", cli.FlagLockup, "testdata/periods1.json"), + fmt.Sprintf("--%s=%s", cli.FlagVesting, "testdata/periods1.json"), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(10))).String()), + }, + expectErr: false, + expectedCode: 0, + respType: &sdk.TxResponse{}, + }, + { + name: "defaultLockup", + args: []string{ + sdk.AccAddress("addr11______________").String(), + fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address), + fmt.Sprintf("--%s=%s", cli.FlagVesting, "testdata/periods1.json"), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(10))).String()), + }, + expectErr: false, + expectedCode: 0, + respType: &sdk.TxResponse{}, + }, + { + name: "defaultVesting", + args: []string{ + sdk.AccAddress("addr12______________").String(), + fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address), + fmt.Sprintf("--%s=%s", cli.FlagLockup, "testdata/periods1.json"), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(10))).String()), + }, + expectErr: false, + expectedCode: 0, + respType: &sdk.TxResponse{}, + }, + { + name: "merge", + args: []string{ + sdk.AccAddress("addr10______________").String(), + fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address), + fmt.Sprintf("--%s=%s", cli.FlagLockup, "testdata/periods1.json"), + fmt.Sprintf("--%s=%s", cli.FlagVesting, "testdata/periods1.json"), + fmt.Sprintf("--%s=%s", cli.FlagMerge, "true"), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(10))).String()), + }, + expectErr: false, + expectedCode: 0, + respType: &sdk.TxResponse{}, + }, + { + name: "bad vesting addr", + args: []string{ + "foo", + }, + expectErr: true, + }, + { + name: "no files", + args: []string{ + sdk.AccAddress("addr13______________").String(), + }, + expectErr: true, + }, + { + name: "bad lockup filename", + args: []string{ + sdk.AccAddress("addr13______________").String(), + fmt.Sprintf("--%s=%s", cli.FlagLockup, "testdata/noexist"), + }, + expectErr: true, + }, + { + name: "bad lockup json", + args: []string{ + sdk.AccAddress("addr13______________").String(), + fmt.Sprintf("--%s=%s", cli.FlagLockup, "testdata/badjson"), + }, + expectErr: true, + }, + { + name: "bad lockup periods", + args: []string{ + sdk.AccAddress("addr13______________").String(), + fmt.Sprintf("--%s=%s", cli.FlagLockup, "testdata/badperiod.json"), + }, + expectErr: true, + }, + { + name: "bad vesting filename", + args: []string{ + sdk.AccAddress("addr13______________").String(), + fmt.Sprintf("--%s=%s", cli.FlagVesting, "testdata/noexist"), + }, + expectErr: true, + }, + { + name: "bad vesting json", + args: []string{ + sdk.AccAddress("addr13______________").String(), + fmt.Sprintf("--%s=%s", cli.FlagVesting, "testdata/badjson"), + }, + expectErr: true, + }, + { + name: "bad vesting periods", + args: []string{ + sdk.AccAddress("addr13______________").String(), + fmt.Sprintf("--%s=%s", cli.FlagVesting, "testdata/badperiod.json"), + }, + expectErr: true, + }, + } { + s.Run(tc.name, func() { + clientCtx := val.ClientCtx + + bw, err := clitestutil.ExecTestCLICmd(clientCtx, cli.NewMsgCreateClawbackVestingAccountCmd(), tc.args) + if tc.expectErr { + s.Require().Error(err) + } else { + s.Require().NoError(err) + s.Require().NoError(clientCtx.Codec.UnmarshalJSON(bw.Bytes(), tc.respType), bw.String()) + + txResp := tc.respType.(*sdk.TxResponse) + s.Require().Equal(tc.expectedCode, txResp.Code) + } + }) + } +} + +func (s *E2ETestSuite) TestNewMsgClawbackCmd() { + val := s.network.Validators[0] + addr := sdk.AccAddress("addr30______________") + + _, err := clitestutil.ExecTestCLICmd(val.ClientCtx, cli.NewMsgCreateClawbackVestingAccountCmd(), []string{ + addr.String(), + fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address), + fmt.Sprintf("--%s=%s", cli.FlagLockup, "testdata/periods1.json"), + fmt.Sprintf("--%s=%s", cli.FlagVesting, "testdata/periods1.json"), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(10))).String()), + }) + s.Require().NoError(err) + + for _, tc := range []struct { + name string + args []string + expectErr bool + expectedCode uint32 + respType proto.Message + }{ + { + name: "basic", + args: []string{ + addr.String(), + fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address), + fmt.Sprintf("--%s=%s", cli.FlagDest, sdk.AccAddress("addr32______________").String()), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(10))).String()), + }, + expectErr: false, + expectedCode: 0, + respType: &sdk.TxResponse{}, + }, + { + name: "bad vesting addr", + args: []string{ + "foo", + }, + expectErr: true, + }, + { + name: "bad dest addr", + args: []string{ + addr.String(), + fmt.Sprintf("--%s=%s", cli.FlagDest, "bar"), + }, + expectErr: true, + }, + { + name: "default dest", + args: []string{ + addr.String(), + fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(10))).String()), + }, + expectErr: false, + expectedCode: 0, + respType: &sdk.TxResponse{}, + }, + } { + s.Run(tc.name, func() { + clientCtx := val.ClientCtx + + bw, err := clitestutil.ExecTestCLICmd(clientCtx, cli.NewMsgClawbackCmd(), tc.args) + if tc.expectErr { + s.Require().Error(err) + } else { + s.Require().NoError(err) + s.Require().NoError(clientCtx.Codec.UnmarshalJSON(bw.Bytes(), tc.respType), bw.String()) + + txResp := tc.respType.(*sdk.TxResponse) + s.Require().Equal(tc.expectedCode, txResp.Code) + } + }) + } +} diff --git a/x/auth/vesting/client/cli/tx.go b/x/auth/vesting/client/cli/tx.go index f2ef3ba67fe7..c0f5189f7562 100644 --- a/x/auth/vesting/client/cli/tx.go +++ b/x/auth/vesting/client/cli/tx.go @@ -3,6 +3,7 @@ package cli import ( "encoding/json" "fmt" + "io/ioutil" "os" "strconv" @@ -18,6 +19,10 @@ import ( // Transaction command flags const ( FlagDelayed = "delayed" + FlagDest = "dest" + FlagLockup = "lockup" + FlagMerge = "merge" + FlagVesting = "vesting" ) // GetTxCmd returns vesting module's transaction commands. @@ -34,6 +39,8 @@ func GetTxCmd() *cobra.Command { NewMsgCreateVestingAccountCmd(), NewMsgCreatePermanentLockedAccountCmd(), NewMsgCreatePeriodicVestingAccountCmd(), + NewMsgCreateClawbackVestingAccountCmd(), + NewMsgClawbackCmd(), ) return txCmd @@ -207,3 +214,166 @@ func NewMsgCreatePeriodicVestingAccountCmd() *cobra.Command { return cmd } + +// readScheduleFile reads the file at path and unmarshals it to get the schedule. +// Returns start time, periods, and error. +func readScheduleFile(path string) (int64, []types.Period, error) { + contents, err := ioutil.ReadFile(path) + if err != nil { + return 0, nil, err + } + + var data VestingData + if err := json.Unmarshal(contents, &data); err != nil { + return 0, nil, err + } + + startTime := data.StartTime + periods := make([]types.Period, len(data.Periods)) + + for i, p := range data.Periods { + amount, err := sdk.ParseCoinsNormalized(p.Coins) + if err != nil { + return 0, nil, err + } + if p.Length < 1 { + return 0, nil, fmt.Errorf("invalid period length of %d in period %d, length must be greater than 0", p.Length, i) + } + + periods[i] = types.Period{Length: p.Length, Amount: amount} + } + + return startTime, periods, nil +} + +// NewMsgCreateClawbackVestingAccountCmd returns a CLI command handler for creating a +// MsgCreateClawbackVestingAccount transaction. +func NewMsgCreateClawbackVestingAccountCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "create-clawback-vesting-account [to_address]", + Short: "Create a new vesting account funded with an allocation of tokens, subject to clawback.", + Long: `Must provide a lockup periods file (--lockup), a vesting periods file (--vesting), or both. +If both files are given, they must describe schedules for the same total amount. +If one file is omitted, it will default to a schedule that immediately unlocks or vests the entire amount. +The described amount of coins will be transferred from the --from address to the vesting account. +Unvested coins may be "clawed back" by the funder with the clawback command. +Coins may not be transferred out of the account if they are locked or unvested, but may be staked. +Staking rewards are subject to a proportional vesting encumbrance. + +A periods file is a JSON object describing a sequence of unlocking or vesting events, +with a start time and an array of coins strings and durations relative to the start or previous event.`, + Example: `Sample period file contents: +{ + "start_time": 1625204910, + "periods": [ + { + "coins": "10test", + "length": 2592000 // 30 days + }, + { + "coins": "10test", + "length": 2592000 // 30 days + } + ] +} +`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientTxContext(cmd) + if err != nil { + return err + } + + toAddr, err := sdk.AccAddressFromBech32(args[0]) + if err != nil { + return err + } + + lockupFile, _ := cmd.Flags().GetString(FlagLockup) + vestingFile, _ := cmd.Flags().GetString(FlagVesting) + if lockupFile == "" && vestingFile == "" { + return fmt.Errorf("must specify at least one of %s or %s", FlagLockup, FlagVesting) + } + + var ( + lockupStart, vestingStart int64 + lockupPeriods, vestingPeriods []types.Period + ) + if lockupFile != "" { + lockupStart, lockupPeriods, err = readScheduleFile(lockupFile) + if err != nil { + return err + } + } + if vestingFile != "" { + vestingStart, vestingPeriods, err = readScheduleFile(vestingFile) + if err != nil { + return err + } + } + + commonStart, _ := types.AlignSchedules(lockupStart, vestingStart, lockupPeriods, vestingPeriods) + + merge, _ := cmd.Flags().GetBool(FlagMerge) + + msg := types.NewMsgCreateClawbackVestingAccount(clientCtx.GetFromAddress(), toAddr, commonStart, lockupPeriods, vestingPeriods, merge) + if err := msg.ValidateBasic(); err != nil { + return err + } + + return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg) + }, + } + + cmd.Flags().Bool(FlagMerge, false, "Merge new amount and schedule with existing ClawbackVestingAccount, if any") + cmd.Flags().String(FlagLockup, "", "Path to file containing unlocking periods") + cmd.Flags().String(FlagVesting, "", "Path to file containing vesting periods") + flags.AddTxFlagsToCmd(cmd) + return cmd +} + +// NewMsgClawbackCmd returns a CLI command handler for creating a +// MsgClawback transaction. +func NewMsgClawbackCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "clawback [address]", + Short: "Transfer unvested amount out of a ClawbackVestingAccount.", + Long: `Must be requested by the original funder address (--from). + May provide a destination address (--dest), otherwise the coins return to the funder. + Delegated or undelegating staking tokens will be transferred in the delegated (undelegating) state. + The recipient is vulnerable to slashing, and must act to unbond the tokens if desired. + `, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientTxContext(cmd) + if err != nil { + return err + } + + addr, err := sdk.AccAddressFromBech32(args[0]) + if err != nil { + return err + } + + var dest sdk.AccAddress + destString, _ := cmd.Flags().GetString(FlagDest) + if destString != "" { + dest, err = sdk.AccAddressFromBech32(destString) + if err != nil { + return fmt.Errorf("invalid destination address: %w", err) + } + } + + msg := types.NewMsgClawback(clientCtx.GetFromAddress(), addr, dest) + if err := msg.ValidateBasic(); err != nil { + return err + } + + return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg) + }, + } + + cmd.Flags().String(FlagDest, "", "Address of destination (defaults to funder)") + flags.AddTxFlagsToCmd(cmd) + return cmd +} diff --git a/x/auth/vesting/client/testutil/testdata/badjson b/x/auth/vesting/client/testutil/testdata/badjson new file mode 100644 index 000000000000..aad60f4a0e85 --- /dev/null +++ b/x/auth/vesting/client/testutil/testdata/badjson @@ -0,0 +1 @@ +Not JSON data. diff --git a/x/auth/vesting/client/testutil/testdata/badperiod.json b/x/auth/vesting/client/testutil/testdata/badperiod.json new file mode 100644 index 000000000000..e580e831438b --- /dev/null +++ b/x/auth/vesting/client/testutil/testdata/badperiod.json @@ -0,0 +1,9 @@ +{ + "start_time": 1625204910, + "periods": [ + { + "coins": "10test", + "length_seconds": -500 + } + ] +} diff --git a/x/auth/vesting/client/testutil/testdata/periods1.json b/x/auth/vesting/client/testutil/testdata/periods1.json new file mode 100644 index 000000000000..82255f074c99 --- /dev/null +++ b/x/auth/vesting/client/testutil/testdata/periods1.json @@ -0,0 +1,13 @@ +{ + "start_time": 1625204910, + "periods": [ + { + "coins": "10stake", + "length": 2592000 + }, + { + "coins": "10stake", + "length":2592000 + } + ] +} diff --git a/x/auth/vesting/exported/exported.go b/x/auth/vesting/exported/exported.go index 858e53ed4f98..04369110438c 100644 --- a/x/auth/vesting/exported/exported.go +++ b/x/auth/vesting/exported/exported.go @@ -40,3 +40,36 @@ type VestingAccount interface { GetDelegatedFree() sdk.Coins GetDelegatedVesting() sdk.Coins } + +// AddGrantAction encapsulates the data needed to add a grant to an account. +type AddGrantAction interface { + // AddToAccount adds the grant to the specified account. + // The rawAccount should bypass any account wrappers. + AddToAccount(ctx sdk.Context, rawAccount VestingAccount) error +} + +// ClawbackAction encapsulates the data needed to perform clawback. +type ClawbackAction interface { + // TakeFromAccount removes unvested tokens from the specified account. + // The rawAccount should bypass any account wrappers. + TakeFromAccount(ctx sdk.Context, rawAccount VestingAccount) error +} + +// ClawbackVestingAccountI is an interface for the methods of a clawback account. +type ClawbackVestingAccountI interface { + VestingAccount + + // GetUnlockedOnly returns the sum of all unlocking events up to and including + // the blockTime. + GetUnlockedOnly(blockTime time.Time) sdk.Coins + + // GetVestedOnly returns the sum of all vesting events up to and including + // the blockTime. + GetVestedOnly(blockTime time.Time) sdk.Coins + + // AddGrant adds the specified grant to the account. + AddGrant(ctx sdk.Context, action AddGrantAction) error + + // Clawback performs the clawback described by action. + Clawback(ctx sdk.Context, action ClawbackAction) error +} diff --git a/x/auth/vesting/msg_server.go b/x/auth/vesting/msg_server.go index a823916e23bb..a4461980f93b 100644 --- a/x/auth/vesting/msg_server.go +++ b/x/auth/vesting/msg_server.go @@ -194,10 +194,177 @@ func (s msgServer) CreatePeriodicVestingAccount(goCtx context.Context, msg *type return &types.MsgCreatePeriodicVestingAccountResponse{}, nil } +// CreateClawbackVestingAccount creates a new ClawbackVestingAccount, or merges a grant into an existing one. func (s msgServer) CreateClawbackVestingAccount(goCtx context.Context, msg *types.MsgCreateClawbackVestingAccount) (*types.MsgCreateClawbackVestingAccountResponse, error) { + ctx := sdk.UnwrapSDKContext(goCtx) + ak := s.AccountKeeper + bk := s.BankKeeper + + from, err := sdk.AccAddressFromBech32(msg.FromAddress) + if err != nil { + return nil, err + } + to, err := sdk.AccAddressFromBech32(msg.ToAddress) + if err != nil { + return nil, err + } + + if bk.BlockedAddr(to) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", msg.ToAddress) + } + + vestingCoins := sdk.NewCoins() + for _, period := range msg.VestingPeriods { + vestingCoins = vestingCoins.Add(period.Amount...) + } + + lockupCoins := sdk.NewCoins() + for _, period := range msg.LockupPeriods { + lockupCoins = lockupCoins.Add(period.Amount...) + } + + // if lockup absent, default to an instant unlock schedule + lockupPeriods := msg.LockupPeriods + vestingPeriods := msg.VestingPeriods + if !vestingCoins.IsZero() && len(msg.LockupPeriods) == 0 { + lockupPeriods = []types.Period{ + {Length: 0, Amount: vestingCoins}, + } + lockupCoins = vestingCoins + } + + if !lockupCoins.IsZero() && len(msg.VestingPeriods) == 0 { + // If vesting absent, default to an instant vesting schedule + vestingPeriods = []types.Period{ + {Length: 0, Amount: lockupCoins}, + } + vestingCoins = lockupCoins + } + + if !types.CoinEq(lockupCoins, vestingCoins) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrInvalidRequest, "lockup and vesting amounts must be equal") + } + + var ( + madeNewAcc bool + vestingAcc *types.ClawbackVestingAccount + ) + + acc := ak.GetAccount(ctx, to) + + // a grant can be added only if toAddress exists, msg.Merge && isClawback && to.FunderAddress == msg.FromAddress + if acc != nil { + var isClawback bool + vestingAcc, isClawback = acc.(*types.ClawbackVestingAccount) + switch { + case !isClawback: + return nil, sdkerrors.Wrapf(sdkerrors.ErrNotSupported, "account %s must be a clawback vesting account", msg.ToAddress) + case !msg.Merge && isClawback: + return nil, sdkerrors.Wrapf(sdkerrors.ErrInvalidRequest, "account %s already exists; consider setting 'merge' to 'true'", msg.ToAddress) + case msg.FromAddress != vestingAcc.FunderAddress: + return nil, sdkerrors.Wrapf(sdkerrors.ErrInvalidRequest, "account %s can only accept grants from account %s", msg.ToAddress, vestingAcc.FunderAddress) + } + grantAction := types.NewClawbackGrantAction(msg.FromAddress, msg.StartTime, msg.GetLockupPeriods(), msg.GetVestingPeriods(), vestingCoins) + err := vestingAcc.AddGrant(ctx, grantAction) + if err != nil { + return nil, err + } + ak.SetAccount(ctx, vestingAcc) + } else { + baseAcc := authtypes.NewBaseAccountWithAddress(to) + vestingAcc = types.NewClawbackVestingAccount( + baseAcc, + from, + vestingCoins, + msg.StartTime, + lockupPeriods, + vestingPeriods, + ) + acc := ak.NewAccount(ctx, vestingAcc) + madeNewAcc = true + ak.SetAccount(ctx, acc) + } + + if madeNewAcc { + defer func() { + telemetry.IncrCounter(1, "new", "account") + + for _, a := range vestingCoins { + if a.Amount.IsInt64() { + telemetry.SetGaugeWithLabels( + []string{"tx", "msg", "create_clawback_vesting_account"}, + float32(a.Amount.Int64()), + []metrics.Label{telemetry.NewLabel("denom", a.Denom)}, + ) + } + } + }() + } + + err = bk.SendCoins(ctx, from, to, vestingCoins) + if err != nil { + return nil, err + } + return &types.MsgCreateClawbackVestingAccountResponse{}, nil } +// Clawback removes the unvested amount from a ClawbackVestingAccount. +// The destination defaults to the funder address, but can be overridden. func (s msgServer) Clawback(goCtx context.Context, msg *types.MsgClawback) (*types.MsgClawbackResponse, error) { + ctx := sdk.UnwrapSDKContext(goCtx) + accountKeeper := s.AccountKeeper + bankKeeper := s.BankKeeper + + funder, err := sdk.AccAddressFromBech32(msg.GetFunderAddress()) + if err != nil { + return nil, err + } + + addr, err := sdk.AccAddressFromBech32(msg.GetAddress()) + if err != nil { + return nil, err + } + + dest := funder + if msg.GetDestAddress() != "" { + dest, err = sdk.AccAddressFromBech32(msg.GetDestAddress()) + if err != nil { + return nil, err + } + } + + if bankKeeper.BlockedAddr(dest) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, + "%s is not allowed to receive funds", msg.DestAddress, + ) + } + + // Check if account exists + account := accountKeeper.GetAccount(ctx, addr) + if account == nil { + return nil, sdkerrors.Wrapf(sdkerrors.ErrNotFound, "account %s does not exist", msg.Address) + } + + // Check if account has a clawback account + vestingAccount, ok := account.(*types.ClawbackVestingAccount) + if !ok { + return nil, sdkerrors.Wrapf(sdkerrors.ErrInvalidRequest, "account not subject to clawback: %s", msg.Address) + } + + // Check if account funder is same as in msg + if vestingAccount.FunderAddress != msg.FunderAddress { + return nil, sdkerrors.Wrapf(sdkerrors.ErrInvalidRequest, "clawback can only be requested by original funder %s", vestingAccount.FunderAddress) + } + + clawbackAction := types.NewClawbackAction(funder, dest, accountKeeper, bankKeeper) + + // Perform clawback transfer, + // this updates state for both the vesting account and the destination account. + err = vestingAccount.Clawback(ctx, clawbackAction) + if err != nil { + return nil, err + } + return &types.MsgClawbackResponse{}, nil } diff --git a/x/auth/vesting/testutil/expected_keepers_mocks.go b/x/auth/vesting/testutil/expected_keepers_mocks.go index fd1fe6140d9c..9261ee255bb6 100644 --- a/x/auth/vesting/testutil/expected_keepers_mocks.go +++ b/x/auth/vesting/testutil/expected_keepers_mocks.go @@ -8,9 +8,59 @@ import ( reflect "reflect" types "github.com/cosmos/cosmos-sdk/types" + types0 "github.com/cosmos/cosmos-sdk/x/auth/types" gomock "github.com/golang/mock/gomock" ) +// MockAccountKeeper is a mock of AccountKeeper interface. +type MockAccountKeeper struct { + ctrl *gomock.Controller + recorder *MockAccountKeeperMockRecorder +} + +// MockAccountKeeperMockRecorder is the mock recorder for MockAccountKeeper. +type MockAccountKeeperMockRecorder struct { + mock *MockAccountKeeper +} + +// NewMockAccountKeeper creates a new mock instance. +func NewMockAccountKeeper(ctrl *gomock.Controller) *MockAccountKeeper { + mock := &MockAccountKeeper{ctrl: ctrl} + mock.recorder = &MockAccountKeeperMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAccountKeeper) EXPECT() *MockAccountKeeperMockRecorder { + return m.recorder +} + +// GetAccount mocks base method. +func (m *MockAccountKeeper) GetAccount(arg0 types.Context, arg1 types.AccAddress) types0.AccountI { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccount", arg0, arg1) + ret0, _ := ret[0].(types0.AccountI) + return ret0 +} + +// GetAccount indicates an expected call of GetAccount. +func (mr *MockAccountKeeperMockRecorder) GetAccount(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccount", reflect.TypeOf((*MockAccountKeeper)(nil).GetAccount), arg0, arg1) +} + +// SetAccount mocks base method. +func (m *MockAccountKeeper) SetAccount(arg0 types.Context, arg1 types0.AccountI) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccount", arg0, arg1) +} + +// SetAccount indicates an expected call of SetAccount. +func (mr *MockAccountKeeperMockRecorder) SetAccount(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccount", reflect.TypeOf((*MockAccountKeeper)(nil).SetAccount), arg0, arg1) +} + // MockBankKeeper is a mock of BankKeeper interface. type MockBankKeeper struct { ctrl *gomock.Controller @@ -48,6 +98,20 @@ func (mr *MockBankKeeperMockRecorder) BlockedAddr(addr interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BlockedAddr", reflect.TypeOf((*MockBankKeeper)(nil).BlockedAddr), addr) } +// GetAllBalances mocks base method. +func (m *MockBankKeeper) GetAllBalances(ctx types.Context, addr types.AccAddress) types.Coins { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllBalances", ctx, addr) + ret0, _ := ret[0].(types.Coins) + return ret0 +} + +// GetAllBalances indicates an expected call of GetAllBalances. +func (mr *MockBankKeeperMockRecorder) GetAllBalances(ctx, addr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllBalances", reflect.TypeOf((*MockBankKeeper)(nil).GetAllBalances), ctx, addr) +} + // IsSendEnabledCoins mocks base method. func (m *MockBankKeeper) IsSendEnabledCoins(ctx types.Context, coins ...types.Coin) error { m.ctrl.T.Helper() @@ -80,3 +144,17 @@ func (mr *MockBankKeeperMockRecorder) SendCoins(ctx, fromAddr, toAddr, amt inter mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendCoins", reflect.TypeOf((*MockBankKeeper)(nil).SendCoins), ctx, fromAddr, toAddr, amt) } + +// SpendableCoins mocks base method. +func (m *MockBankKeeper) SpendableCoins(ctx types.Context, addr types.AccAddress) types.Coins { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SpendableCoins", ctx, addr) + ret0, _ := ret[0].(types.Coins) + return ret0 +} + +// SpendableCoins indicates an expected call of SpendableCoins. +func (mr *MockBankKeeperMockRecorder) SpendableCoins(ctx, addr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpendableCoins", reflect.TypeOf((*MockBankKeeper)(nil).SpendableCoins), ctx, addr) +} diff --git a/x/auth/vesting/types/codec.go b/x/auth/vesting/types/codec.go index 2dd2c5a8196f..89dccf621e0f 100644 --- a/x/auth/vesting/types/codec.go +++ b/x/auth/vesting/types/codec.go @@ -26,6 +26,7 @@ func RegisterLegacyAminoCodec(cdc *codec.LegacyAmino) { legacy.RegisterAminoMsg(cdc, &MsgCreateVestingAccount{}, "cosmos-sdk/MsgCreateVestingAccount") legacy.RegisterAminoMsg(cdc, &MsgCreatePermanentLockedAccount{}, "cosmos-sdk/MsgCreatePermLockedAccount") legacy.RegisterAminoMsg(cdc, &MsgCreatePeriodicVestingAccount{}, "cosmos-sdk/MsgCreatePeriodVestAccount") + cdc.RegisterConcrete(&ClawbackVestingAccount{}, "cosmos-sdk/ClawbackVestingAccount", nil) } // RegisterInterface associates protoName with AccountI and VestingAccount @@ -38,6 +39,7 @@ func RegisterInterfaces(registry types.InterfaceRegistry) { &DelayedVestingAccount{}, &PeriodicVestingAccount{}, &PermanentLockedAccount{}, + &ClawbackVestingAccount{}, ) registry.RegisterImplementations( @@ -47,6 +49,7 @@ func RegisterInterfaces(registry types.InterfaceRegistry) { &ContinuousVestingAccount{}, &PeriodicVestingAccount{}, &PermanentLockedAccount{}, + &ClawbackVestingAccount{}, ) registry.RegisterImplementations( @@ -56,12 +59,14 @@ func RegisterInterfaces(registry types.InterfaceRegistry) { &ContinuousVestingAccount{}, &PeriodicVestingAccount{}, &PermanentLockedAccount{}, + &ClawbackVestingAccount{}, ) registry.RegisterImplementations( (*sdk.Msg)(nil), &MsgCreateVestingAccount{}, &MsgCreatePermanentLockedAccount{}, + &MsgCreateClawbackVestingAccount{}, ) msgservice.RegisterMsgServiceDesc(registry, &_Msg_serviceDesc) diff --git a/x/auth/vesting/types/expected_keepers.go b/x/auth/vesting/types/expected_keepers.go index 5705eea30baf..821881aa37cc 100644 --- a/x/auth/vesting/types/expected_keepers.go +++ b/x/auth/vesting/types/expected_keepers.go @@ -2,12 +2,22 @@ package types import ( sdk "github.com/cosmos/cosmos-sdk/types" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" ) +// AccountKeeper defines the expected interface contract that is required by the +// vesting module for storing accounts. +type AccountKeeper interface { + GetAccount(sdk.Context, sdk.AccAddress) authtypes.AccountI + SetAccount(sdk.Context, authtypes.AccountI) +} + // BankKeeper defines the expected interface contract the vesting module requires // for creating vesting accounts with funds. type BankKeeper interface { + GetAllBalances(ctx sdk.Context, addr sdk.AccAddress) sdk.Coins IsSendEnabledCoins(ctx sdk.Context, coins ...sdk.Coin) error SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error + SpendableCoins(ctx sdk.Context, addr sdk.AccAddress) sdk.Coins BlockedAddr(addr sdk.AccAddress) bool } diff --git a/x/auth/vesting/types/period.go b/x/auth/vesting/types/period.go index 2d64fd6199db..85193037b6fb 100644 --- a/x/auth/vesting/types/period.go +++ b/x/auth/vesting/types/period.go @@ -58,3 +58,268 @@ func (p Periods) String() string { return strings.TrimSpace(fmt.Sprintf(`Vesting Periods: %s`, strings.Join(periodsListString, ", "))) } + +// A "schedule" is an increasing step function of Coins over time. +// It's specified with an absolute start time and a sequence of relative +// periods, with each step at the end of a period. A schedule may also +// give the time and total value at the last step, which can speed +// evaluation of the step function after the last step. + +// ReadSchedule returns the value of a schedule at the current provided time. +func ReadSchedule(startTime, endTime int64, periods []Period, totalCoins sdk.Coins, currTime int64) sdk.Coins { + if currTime <= startTime { + return sdk.NewCoins() + } + if currTime >= endTime { + return totalCoins + } + + coins := sdk.NewCoins() // sum of amounts for events before currTime + time := startTime + + for _, period := range periods { + if currTime < time+period.Length { + // we're reading before the next event + break + } + coins = coins.Add(period.Amount...) + time += period.Length + } + + return coins +} + +// max64 returns the maximum of its inputs. +func max64(i, j int64) int64 { + if i > j { + return i + } + return j +} + +// min64 returns the minimum of its inputs. +func min64(i, j int64) int64 { + if i < j { + return i + } + return j +} + +// coinsMin returns the minimum of its inputs for all denominations. +func coinsMin(a, b sdk.Coins) sdk.Coins { + min := sdk.NewCoins() + for _, coinA := range a { + denom := coinA.Denom + bAmt := b.AmountOfNoDenomValidation(denom) + minAmt := coinA.Amount + if minAmt.GT(bAmt) { + minAmt = bAmt + } + if minAmt.IsPositive() { + min = min.Add(sdk.NewCoin(denom, minAmt)) + } + } + return min +} + +// DisjunctPeriods returns the union of two vesting period schedules. +// The returned schedule is the union of the vesting events, with simultaneous +// events combined into a single event. +// Input schedules P and Q are defined by their start times and periods. +// Returns new start time, new end time, and merged vesting events, relative to +// the new start time. +func DisjunctPeriods(startP, startQ int64, periodsP, periodsQ []Period) (int64, int64, []Period) { + timeP := startP // time of last merged p event, next p event is relative to this time + timeQ := startQ // time of last merged q event, next q event is relative to this time + iP := 0 // p indexes before this have been merged + iQ := 0 // q indexes before this have been merged + lenP := len(periodsP) + lenQ := len(periodsQ) + startTime := min64(startP, startQ) // we pick the earlier time + time := startTime // time of last merged event, or the start time + merged := []Period{} + + // emit adds an output period and updates the last event time + emit := func(nextTime int64, amount sdk.Coins) { + period := Period{ + Length: nextTime - time, + Amount: amount, + } + merged = append(merged, period) + time = nextTime + } + + // consumeP emits the next period from p, updating indexes + consumeP := func(nextP int64) { + emit(nextP, periodsP[iP].Amount) + timeP = nextP + iP++ + } + + // consumeQ emits the next period from q, updating indexes + consumeQ := func(nextQ int64) { + emit(nextQ, periodsQ[iQ].Amount) + timeQ = nextQ + iQ++ + } + + // consumeBoth emits a merge of the next periods from p and q, updating indexes + consumeBoth := func(nextTime int64) { + emit(nextTime, periodsP[iP].Amount.Add(periodsQ[iQ].Amount...)) + timeP = nextTime + timeQ = nextTime + iP++ + iQ++ + } + + // while there are more events in both schedules, handle the next one, merge if concurrent + for iP < lenP && iQ < lenQ { + nextP := timeP + periodsP[iP].Length // next p event in absolute time + nextQ := timeQ + periodsQ[iQ].Length // next q event in absolute time + if nextP < nextQ { + consumeP(nextP) + } else if nextP > nextQ { + consumeQ(nextQ) + } else { + consumeBoth(nextP) + } + } + // consume remaining events in schedule P + for iP < lenP { + nextP := timeP + periodsP[iP].Length + consumeP(nextP) + } + // consume remaining events in schedule Q + for iQ < lenQ { + nextQ := timeQ + periodsQ[iQ].Length + consumeQ(nextQ) + } + return startTime, time, merged +} + +// ConjunctPeriods returns the combination of two period schedules where the result is the minimum of the two schedules. +func ConjunctPeriods(startP, startQ int64, periodsP, periodsQ []Period) (startTime int64, endTime int64, merged []Period) { + timeP := startP + timeQ := startQ + iP := 0 + iQ := 0 + lenP := len(periodsP) + lenQ := len(periodsQ) + startTime = min64(startP, startQ) + time := startTime + merged = []Period{} + amount := sdk.NewCoins() + amountP := amount + amountQ := amount + + // emit adds an output period and updates the last event time + emit := func(nextTime int64, coins sdk.Coins) { + period := Period{ + Length: nextTime - time, + Amount: coins, + } + merged = append(merged, period) + time = nextTime + amount = amount.Add(coins...) + } + + // consumeP processes the next event in P and emits an event + // if the minimum of P and Q changes + consumeP := func(nextTime int64) { + amountP = amountP.Add(periodsP[iP].Amount...) + min := coinsMin(amountP, amountQ) + if amount.IsAllLTE(min) { + diff := min.Sub(amount...) + if !diff.IsZero() { + emit(nextTime, diff) + } + } + timeP = nextTime + iP++ + } + + // consumeQ processes the next event in Q and emits an event + // if the minimum of P and Q changes + consumeQ := func(nextTime int64) { + amountQ = amountQ.Add(periodsQ[iQ].Amount...) + min := coinsMin(amountP, amountQ) + if amount.IsAllLTE(min) { + diff := min.Sub(amount...) + if !diff.IsZero() { + emit(nextTime, diff) + } + } + timeQ = nextTime + iQ++ + } + + // consumeBoth processes simultaneous events in P and Q and emits an + // event if the minumum of P and Q changes + consumeBoth := func(nextTime int64) { + amountP = amountP.Add(periodsP[iP].Amount...) + amountQ = amountQ.Add(periodsQ[iQ].Amount...) + min := coinsMin(amountP, amountQ) + if amount.IsAllLTE(min) { + diff := min.Sub(amount...) + if !diff.IsZero() { + emit(nextTime, diff) + } + } + timeP = nextTime + timeQ = nextTime + iP++ + iQ++ + } + + // while there are events left in both schedules, process the next one + for iP < lenP && iQ < lenQ { + nextP := timeP + periodsP[iP].Length // next p event in absolute time + nextQ := timeQ + periodsQ[iQ].Length // next q event in absolute time + if nextP < nextQ { + consumeP(nextP) + } else if nextP > nextQ { + consumeQ(nextQ) + } else { + consumeBoth(nextP) + } + } + + // consume remaining events in schedule P + for iP < lenP { + nextP := timeP + periodsP[iP].Length + consumeP(nextP) + } + + // consume remaining events in schedule Q + for iQ < lenQ { + nextQ := timeQ + periodsQ[iQ].Length + consumeQ(nextQ) + } + + endTime = time + return +} + +// AlignSchedules rewrites the first period length to align the two arguments to the same start time, +// returning the earliest start time and the latest end time +func AlignSchedules(startP, startQ int64, p, q []Period) (startTime, endTime int64) { + startTime = min64(startP, startQ) + + if len(p) > 0 { + p[0].Length += startP - startTime + } + if len(q) > 0 { + q[0].Length += startQ - startTime + } + + endP := startTime + for _, period := range p { + endP += period.Length + } + endQ := startTime + for _, period := range q { + endQ += period.Length + } + endTime = max64(endP, endQ) + return +} diff --git a/x/auth/vesting/types/vesting_account.go b/x/auth/vesting/types/vesting_account.go index 892794f669cc..ac12810155e0 100644 --- a/x/auth/vesting/types/vesting_account.go +++ b/x/auth/vesting/types/vesting_account.go @@ -2,6 +2,7 @@ package types import ( "errors" + "fmt" "time" "cosmossdk.io/math" @@ -9,7 +10,9 @@ import ( cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + "github.com/cosmos/cosmos-sdk/x/auth/vesting/exported" vestexported "github.com/cosmos/cosmos-sdk/x/auth/vesting/exported" ) @@ -19,6 +22,7 @@ var ( _ vestexported.VestingAccount = (*ContinuousVestingAccount)(nil) _ vestexported.VestingAccount = (*PeriodicVestingAccount)(nil) _ vestexported.VestingAccount = (*DelayedVestingAccount)(nil) + _ vestexported.VestingAccount = (*ClawbackVestingAccount)(nil) ) // Base Vesting Account @@ -610,3 +614,302 @@ func marshalYaml(i interface{}) (interface{}, error) { } return string(bz), nil } + +// Clawback Vesting Account + +var _ vestexported.VestingAccount = (*ClawbackVestingAccount)(nil) +var _ authtypes.GenesisAccount = (*ClawbackVestingAccount)(nil) + +// NewClawbackVestingAccount returns a new ClawbackVestingAccount +func NewClawbackVestingAccount( + baseAcc *authtypes.BaseAccount, + funder sdk.AccAddress, + originalVesting sdk.Coins, + startTime int64, + lockupPeriods, + vestingPeriods Periods, +) *ClawbackVestingAccount { + // copy and align schedules to avoid mutating inputs + lockupPeriod := make(Periods, len(lockupPeriods)) + copy(lockupPeriod, lockupPeriods) + vp := make(Periods, len(vestingPeriods)) + copy(vp, vestingPeriods) + _, endTime := AlignSchedules(startTime, startTime, lockupPeriod, vp) + baseVestingAcc := &BaseVestingAccount{ + BaseAccount: baseAcc, + OriginalVesting: originalVesting, + EndTime: endTime, + } + + return &ClawbackVestingAccount{ + BaseVestingAccount: baseVestingAcc, + FunderAddress: funder.String(), + StartTime: startTime, + LockupPeriods: lockupPeriod, + VestingPeriods: vp, + } +} + +// GetVestedCoins returns the total number of vested coins. If no coins are vested, +// nil is returned. +func (va ClawbackVestingAccount) GetVestedCoins(blockTime time.Time) sdk.Coins { + // It's likely that one or the other schedule will be nearly trivial, + // so there should be little overhead in recomputing the conjunction each time. + coins := coinsMin(va.GetUnlockedOnly(blockTime), va.GetVestedOnly(blockTime)) + if coins.IsZero() { + return nil + } + return coins +} + +// GetVestingCoins returns the total number of vesting coins. If no coins are +// vesting, nil is returned. +func (va ClawbackVestingAccount) GetVestingCoins(blockTime time.Time) sdk.Coins { + return va.OriginalVesting.Sub(va.GetVestedCoins(blockTime)...) +} + +// LockedCoins returns the set of coins that are not spendable (i.e. locked), +// defined as the vesting coins that are not delegated. +func (va ClawbackVestingAccount) LockedCoins(blockTime time.Time) sdk.Coins { + return va.BaseVestingAccount.LockedCoinsFromVesting(va.GetVestingCoins(blockTime)) +} + +// TrackDelegation tracks a desired delegation amount by setting the appropriate +// values for the amount of delegated vesting, delegated free, and reducing the +// overall amount of base coins. +func (va *ClawbackVestingAccount) TrackDelegation(blockTime time.Time, balance, amount sdk.Coins) { + va.BaseVestingAccount.TrackDelegation(balance, va.GetVestingCoins(blockTime), amount) +} + +// GetStartTime returns the time when vesting starts for a periodic vesting +// account. +func (va ClawbackVestingAccount) GetStartTime() int64 { + return va.StartTime +} + +// GetVestingPeriods returns vesting periods associated with periodic vesting account. +func (va ClawbackVestingAccount) GetVestingPeriods() Periods { + return va.VestingPeriods +} + +// coinEq returns whether two Coins are equal. +// The IsEqual() method can panic. +func CoinEq(a, b sdk.Coins) bool { + return a.IsAllLTE(b) && b.IsAllLTE(a) +} + +// Validate checks for errors on the account fields +func (va ClawbackVestingAccount) Validate() error { + if va.GetStartTime() >= va.GetEndTime() { + return errors.New("vesting start-time must be before end-time") + + } + + lockupEnd := va.GetStartTime() + lockupCoins := sdk.NewCoins() + + for _, p := range va.LockupPeriods { + lockupEnd += p.Length + lockupCoins = lockupCoins.Add(p.Amount...) + } + + if lockupEnd > va.EndTime { + return errors.New("lockup schedule extends beyond account end time") + } + + // use coinEq to prevent panic + if !CoinEq(lockupCoins, va.OriginalVesting) { + return errors.New("original vesting coins does not match the sum of all coins in lockup periods") + } + + vestingEnd := va.GetStartTime() + vestingCoins := sdk.NewCoins() + + for _, p := range va.VestingPeriods { + vestingEnd += p.Length + vestingCoins = vestingCoins.Add(p.Amount...) + } + + if vestingEnd > va.EndTime { + return errors.New("vesting schedule exteds beyond account end time") + } + + if !CoinEq(vestingCoins, va.OriginalVesting) { + return errors.New("original vesting coins does not match the sum of all coins in vesting periods") + } + + return va.BaseVestingAccount.Validate() +} + +type clawbackGrantAction struct { + funderAddress string + grantStartTime int64 + grantLockupPeriods []Period + grantVestingPeriods []Period + grantCoins sdk.Coins +} + +func NewClawbackGrantAction( + funderAddress string, + grantStartTime int64, + grantLockupPeriods, grantVestingPeriods []Period, + grantCoins sdk.Coins, +) exported.AddGrantAction { + return clawbackGrantAction{ + funderAddress: funderAddress, + grantStartTime: grantStartTime, + grantLockupPeriods: grantLockupPeriods, + grantVestingPeriods: grantVestingPeriods, + grantCoins: grantCoins, + } +} + +func (cga clawbackGrantAction) AddToAccount(ctx sdk.Context, rawAccount exported.VestingAccount) error { + cva, ok := rawAccount.(*ClawbackVestingAccount) + if !ok { + return fmt.Errorf("expected *ClawbackVestingAccount, got %T", rawAccount) + } + if cga.funderAddress != cva.FunderAddress { + return sdkerrors.Wrapf( + sdkerrors.ErrInvalidRequest, + "account %s can only accept grants from account %s", + rawAccount.GetAddress(), cva.FunderAddress, + ) + } + cva.addGrant(ctx, cga.grantStartTime, cga.grantLockupPeriods, cga.grantVestingPeriods, cga.grantCoins) + return nil + +} + +func (va *ClawbackVestingAccount) AddGrant(ctx sdk.Context, action exported.AddGrantAction) error { + return action.AddToAccount(ctx, va) +} + +func (va *ClawbackVestingAccount) addGrant(ctx sdk.Context, grantStartTime int64, grantLockupPeriods, grantVestingPeriods []Period, grantCoins sdk.Coins) { + // modify schedules for the new grant + newLockupStart, newLockupEnd, newLockupPeriods := DisjunctPeriods(va.GetStartTime(), grantStartTime, va.LockupPeriods, grantLockupPeriods) + newVestingStart, newVestingEnd, newVestingPeriods := DisjunctPeriods(va.GetStartTime(), grantStartTime, + va.GetVestingPeriods(), grantVestingPeriods) + if newLockupStart != newVestingStart { + panic("bad start time calculation") + } + va.StartTime = newLockupStart + va.EndTime = max64(newLockupEnd, newVestingEnd) + va.LockupPeriods = newLockupPeriods + va.VestingPeriods = newVestingPeriods + va.OriginalVesting = va.OriginalVesting.Add(grantCoins...) +} + +// GetUnlockedOnly returns the unlocking schedule at blockTIme. +func (va ClawbackVestingAccount) GetUnlockedOnly(blockTime time.Time) sdk.Coins { + return ReadSchedule(va.GetStartTime(), va.EndTime, va.LockupPeriods, va.OriginalVesting, blockTime.Unix()) +} + +// GetVestedOnly returns the vesting schedule at blockTime. +func (va ClawbackVestingAccount) GetVestedOnly(blockTime time.Time) sdk.Coins { + return ReadSchedule(va.GetStartTime(), va.EndTime, va.VestingPeriods, va.OriginalVesting, blockTime.Unix()) +} + +// computeClawback removes all future vesting events from the account, +// returns the total sum of these events. When removing the future vesting events, +// the lockup schedule will also have to be capped to keep the total sums the same. +// (But future unlocking events might be preserved if they unlock currently vested coins.) +// If the amount returned is zero, then the returned account should be unchanged. +// Note that this method althers the struct itself +// Does not adjust DelegatedVesting +func (va *ClawbackVestingAccount) computeClawback(clawbackTime int64) sdk.Coins { + // Compute the truncated vesting schedule and amounts. + // Work with the schedule as the primary data and recompute derived fields, e.g. OriginalVesting. + vestTime := va.GetStartTime() + totalVested := sdk.NewCoins() + totalUnvested := sdk.NewCoins() + unvestedIdx := 0 + for i, period := range va.VestingPeriods { + // this period vests at time t, if this occurred before clawback time, + // then its already vested. + vestTime += period.Length + // tie in time gets clawed back + if vestTime < clawbackTime { + totalVested = totalVested.Add(period.Amount...) + unvestedIdx = i + 1 + } else { + totalUnvested = totalUnvested.Add(period.Amount...) + } + } + lastVestTime := vestTime + + newVestingPeriods := va.VestingPeriods[:unvestedIdx] + + // To cap the unlocking schedule to the new total vested, conjunct with a limiting schedule + capPeriods := []Period{ + { + Length: 0, + Amount: totalVested, + }, + } + _, lastLockTime, newLockupPeriods := ConjunctPeriods(va.StartTime, va.StartTime, va.LockupPeriods, capPeriods) + + // Now construct the new account state + va.OriginalVesting = totalVested + va.EndTime = max64(lastVestTime, lastLockTime) + va.LockupPeriods = newLockupPeriods + va.VestingPeriods = newVestingPeriods + + return totalUnvested +} + +type clawbackAction struct { + requestor sdk.AccAddress + dest sdk.AccAddress + ak AccountKeeper + bk BankKeeper +} + +func NewClawbackAction(requestor, dest sdk.AccAddress, ak AccountKeeper, bk BankKeeper) exported.ClawbackAction { + return clawbackAction{ + requestor: requestor, + dest: dest, + ak: ak, + bk: bk, + } +} + +func (ca clawbackAction) TakeFromAccount(ctx sdk.Context, rawAccount exported.VestingAccount) error { + cva, ok := rawAccount.(*ClawbackVestingAccount) + if !ok { + return fmt.Errorf("clawback expects *ClawbackVestingAccount, got %T", rawAccount) + } + if ca.requestor.String() != cva.FunderAddress { + return sdkerrors.Wrapf(sdkerrors.ErrInvalidRequest, "clawback can only be requested by original funder %s", cva.FunderAddress) + } + return cva.clawback(ctx, ca.dest, ca.ak, ca.bk) +} + +func (va *ClawbackVestingAccount) Clawback(ctx sdk.Context, action exported.ClawbackAction) error { + return action.TakeFromAccount(ctx, va) +} + +// Clawback transfers unvested tokens in a ClawbackVestingAccount to dest. +// Future vesting events are removed. Unstaked tokens are simply sent. +// Unbonding and staked tokens are transferred with their staking state +// intact. Account state is updated to reflect the removals. +func (va *ClawbackVestingAccount) clawback(ctx sdk.Context, dest sdk.AccAddress, ak AccountKeeper, bk BankKeeper) error { + // Compute the clawback based on the account state only, and update account + toClawBack := va.computeClawback(ctx.BlockTime().Unix()) + if toClawBack.IsZero() { + return nil + } + addr := va.GetAddress() + + // update the account's vesting settings + ak.SetAccount(ctx, va) + + // Now that future vesting events (and associated lockup) are removed, + // the balance of the account is unlocked and can be freely transferred. + err := bk.SendCoins(ctx, addr, dest, toClawBack) + if err != nil { + // shouldn't happen, we have a correctness issue in toClawBack in this case + return err + } + return nil +} diff --git a/x/auth/vesting/types/vesting_account_internal_test.go b/x/auth/vesting/types/vesting_account_internal_test.go new file mode 100644 index 000000000000..520bfef50bcf --- /dev/null +++ b/x/auth/vesting/types/vesting_account_internal_test.go @@ -0,0 +1,65 @@ +package types + +import ( + "testing" + "time" + + tmtime "github.com/cometbft/cometbft/types/time" + "github.com/cosmos/cosmos-sdk/testutil/testdata" + sdk "github.com/cosmos/cosmos-sdk/types" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + "github.com/stretchr/testify/require" +) + +var ( + stakeDenom = "stake" + feeDenom = "fee" +) + +func initBaseAccount() (*authtypes.BaseAccount, sdk.Coins) { + _, _, addr := testdata.KeyTestPubAddr() + origCoins := sdk.Coins{sdk.NewInt64Coin(feeDenom, 1000), sdk.NewInt64Coin(stakeDenom, 100)} + bacc := authtypes.NewBaseAccountWithAddress(addr) + + return bacc, origCoins +} + +func TestComputeClawback(t *testing.T) { + c := sdk.NewCoins + fee := func(x int64) sdk.Coin { return sdk.NewInt64Coin(feeDenom, x) } + stake := func(x int64) sdk.Coin { return sdk.NewInt64Coin(stakeDenom, x) } + now := tmtime.Now() + lockupPeriods := Periods{ + {Length: int64(12 * 3600), Amount: c(fee(1000), stake(100))}, // noon + } + vestingPeriods := Periods{ + {Length: int64(8 * 3600), Amount: c(fee(200))}, // 8am + {Length: int64(1 * 3600), Amount: c(fee(200), stake(50))}, // 9am + {Length: int64(6 * 3600), Amount: c(fee(200), stake(50))}, // 3pm + {Length: int64(2 * 3600), Amount: c(fee(200))}, // 5pm + {Length: int64(1 * 3600), Amount: c(fee(200))}, // 6pm + } + + bacc, origCoins := initBaseAccount() + va := NewClawbackVestingAccount(bacc, sdk.AccAddress([]byte("funder")), origCoins, now.Unix(), lockupPeriods, vestingPeriods) + + amt := va.computeClawback(now.Unix()) + require.Equal(t, c(fee(1000), stake(100)), amt) + require.Equal(t, c(), va.OriginalVesting) + require.Equal(t, 0, len(va.LockupPeriods)) + require.Equal(t, 0, len(va.VestingPeriods)) + + va2 := NewClawbackVestingAccount(bacc, sdk.AccAddress([]byte("funder")), origCoins, now.Unix(), lockupPeriods, vestingPeriods) + amt = va2.computeClawback(now.Add(11 * time.Hour).Unix()) + require.Equal(t, c(fee(600), stake(50)), amt) + require.Equal(t, c(fee(400), stake(50)), va2.OriginalVesting) + require.Equal(t, []Period{{Length: int64(12 * 3600), Amount: c(fee(400), stake(50))}}, va2.LockupPeriods) + require.Equal(t, []Period{ + {Length: int64(8 * 3600), Amount: c(fee(200))}, // 8am + {Length: int64(1 * 3600), Amount: c(fee(200), stake(50))}, // 9am + }, va2.VestingPeriods) + + va3 := NewClawbackVestingAccount(bacc, sdk.AccAddress([]byte("funder")), origCoins, now.Unix(), lockupPeriods, vestingPeriods) + amt = va3.computeClawback(now.Add(23 * time.Hour).Unix()) + require.Equal(t, c(), amt) +} diff --git a/x/bank/keeper/keeper.go b/x/bank/keeper/keeper.go index 31251b344c37..79449189c023 100644 --- a/x/bank/keeper/keeper.go +++ b/x/bank/keeper/keeper.go @@ -12,6 +12,7 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/query" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + vestexported "github.com/cosmos/cosmos-sdk/x/auth/vesting/exported" "github.com/cosmos/cosmos-sdk/x/bank/types" ) @@ -150,6 +151,12 @@ func (k BaseKeeper) DelegateCoins(ctx sdk.Context, delegatorAddr, moduleAccAddr return sdkerrors.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", moduleAccAddr) } + // do not allow delegation if clawback vesting account + acc := k.ak.GetAccount(ctx, delegatorAddr) + if _, ok := acc.(vestexported.ClawbackVestingAccountI); ok { + return fmt.Errorf("clawback vesting account (%s) is restricted for delegation", delegatorAddr) + } + if !amt.IsValid() { return sdkerrors.Wrap(sdkerrors.ErrInvalidCoins, amt.String()) }