Skip to content

Commit

Permalink
refactor(auth,vesting): move ValidateBasic logic to msgServer (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
julienrbrt authored Apr 12, 2023
1 parent 1641bb9 commit e4dbf1b
Show file tree
Hide file tree
Showing 17 changed files with 229 additions and 329 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Ref: https://keepachangelog.com/en/1.0.0/

### API Breaking Changes

* (x/*all*) [#15648](https://github.com/cosmos/cosmos-sdk/issues/15648) Make `SetParams` consistent across all modules and validate the params at the message handling instead of `SetParams` method.
* (x/genutil) [#15679](https://github.com/cosmos/cosmos-sdk/pull/15679) `MigrateGenesisCmd` now takes a `MigrationMap` instead of having the SDK genesis migration hardcoded.
* (client) [#15673](https://github.com/cosmos/cosmos-sdk/pull/15673) Move `client/keys.OutputFormatJSON` and `client/keys.OutputFormatText` to `client/flags` package.
* (x/nft) [#15588](https://github.com/cosmos/cosmos-sdk/pull/15588) `NewKeeper` now takes a `KVStoreService` instead of a `StoreKey` and methods in the `Keeper` now take a `context.Context` instead of a `sdk.Context`.
Expand Down
4 changes: 1 addition & 3 deletions x/auth/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,8 @@ func (ak AccountKeeper) getBech32Prefix() (string, error) {
}

// SetParams sets the auth module's parameters.
// CONTRACT: This method performs no validation of the parameters.
func (ak AccountKeeper) SetParams(ctx context.Context, params types.Params) error {
if err := params.Validate(); err != nil {
return err
}
return ak.ParamsState.Set(ctx, params)
}

Expand Down
12 changes: 8 additions & 4 deletions x/auth/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ func NewMsgServerImpl(ak AccountKeeper) types.MsgServer {
}
}

func (ms msgServer) UpdateParams(goCtx context.Context, req *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) {
if ms.authority != req.Authority {
return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", ms.authority, req.Authority)
func (ms msgServer) UpdateParams(goCtx context.Context, msg *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) {
if ms.authority != msg.Authority {
return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", ms.authority, msg.Authority)
}

if err := msg.Params.Validate(); err != nil {
return nil, err
}

ctx := sdk.UnwrapSDKContext(goCtx)
if err := ms.SetParams(ctx, req.Params); err != nil {
if err := ms.SetParams(ctx, msg.Params); err != nil {
return nil, err
}

Expand Down
15 changes: 0 additions & 15 deletions x/auth/types/msgs.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package types

import (
"cosmossdk.io/errors"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
)
Expand All @@ -22,16 +20,3 @@ func (msg MsgUpdateParams) GetSigners() []sdk.AccAddress {
addr, _ := sdk.AccAddressFromBech32(msg.Authority)
return []sdk.AccAddress{addr}
}

// ValidateBasic does a sanity check on the provided data.
func (msg MsgUpdateParams) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil {
return errors.Wrap(err, "invalid authority address")
}

if err := msg.Params.Validate(); err != nil {
return err
}

return nil
}
7 changes: 1 addition & 6 deletions x/auth/vesting/client/cli/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ timestamp.`,
delayed, _ := cmd.Flags().GetBool(FlagDelayed)

msg := types.NewMsgCreateVestingAccount(clientCtx.GetFromAddress(), toAddr, amount, endTime, delayed)

return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg)
},
}
Expand Down Expand Up @@ -111,7 +110,6 @@ tokens.`,
}

msg := types.NewMsgCreatePermanentLockedAccount(clientCtx.GetFromAddress(), toAddr, amount)

return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg)
},
}
Expand Down Expand Up @@ -190,15 +188,12 @@ func NewMsgCreatePeriodicVestingAccountCmd() *cobra.Command {
if p.Length < 0 {
return fmt.Errorf("invalid period length of %d in period %d, length must be greater than 0", p.Length, i)
}

period := types.Period{Length: p.Length, Amount: amount}
periods = append(periods, period)
}

msg := types.NewMsgCreatePeriodicVestingAccount(clientCtx.GetFromAddress(), toAddr, vestingData.StartTime, periods)
if err := msg.ValidateBasic(); err != nil {
return err
}

return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg)
},
}
Expand Down
109 changes: 67 additions & 42 deletions x/auth/vesting/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,39 @@ func NewMsgServerImpl(k keeper.AccountKeeper, bk types.BankKeeper) types.MsgServ
var _ types.MsgServer = msgServer{}

func (s msgServer) CreateVestingAccount(goCtx context.Context, msg *types.MsgCreateVestingAccount) (*types.MsgCreateVestingAccountResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
ak := s.AccountKeeper
bk := s.BankKeeper

if err := bk.IsSendEnabledCoins(ctx, msg.Amount...); err != nil {
return nil, err
}

from, err := sdk.AccAddressFromBech32(msg.FromAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid 'from' address: %s", err)
}

to, err := sdk.AccAddressFromBech32(msg.ToAddress)
if err != nil {
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid 'to' address: %s", err)
}

if err := validateAmount(msg.Amount); err != nil {
return nil, err
}

if msg.EndTime <= 0 {
return nil, errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "invalid end time")
}

ctx := sdk.UnwrapSDKContext(goCtx)
if err := s.BankKeeper.IsSendEnabledCoins(ctx, msg.Amount...); err != nil {
return nil, err
}

if bk.BlockedAddr(to) {
if s.BankKeeper.BlockedAddr(to) {
return nil, errorsmod.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", msg.ToAddress)
}

if acc := ak.GetAccount(ctx, to); acc != nil {
if acc := s.AccountKeeper.GetAccount(ctx, to); acc != nil {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "account %s already exists", msg.ToAddress)
}

baseAccount := authtypes.NewBaseAccountWithAddress(to)
baseAccount = ak.NewAccount(ctx, baseAccount).(*authtypes.BaseAccount)
baseAccount = s.AccountKeeper.NewAccount(ctx, baseAccount).(*authtypes.BaseAccount)
baseVestingAccount := types.NewBaseVestingAccount(baseAccount, msg.Amount.Sort(), msg.EndTime)

var vestingAccount sdk.AccountI
Expand All @@ -65,7 +71,7 @@ func (s msgServer) CreateVestingAccount(goCtx context.Context, msg *types.MsgCre
vestingAccount = types.NewContinuousVestingAccountRaw(baseVestingAccount, ctx.BlockTime().Unix())
}

ak.SetAccount(ctx, vestingAccount)
s.AccountKeeper.SetAccount(ctx, vestingAccount)

defer func() {
telemetry.IncrCounter(1, "new", "account")
Expand All @@ -81,44 +87,46 @@ func (s msgServer) CreateVestingAccount(goCtx context.Context, msg *types.MsgCre
}
}()

if err = bk.SendCoins(ctx, from, to, msg.Amount); err != nil {
if err = s.BankKeeper.SendCoins(ctx, from, to, msg.Amount); err != nil {
return nil, err
}

return &types.MsgCreateVestingAccountResponse{}, nil
}

func (s msgServer) CreatePermanentLockedAccount(goCtx context.Context, msg *types.MsgCreatePermanentLockedAccount) (*types.MsgCreatePermanentLockedAccountResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
ak := s.AccountKeeper
bk := s.BankKeeper

if err := bk.IsSendEnabledCoins(ctx, msg.Amount...); err != nil {
return nil, err
}

from, err := sdk.AccAddressFromBech32(msg.FromAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid 'from' address: %s", err)
}

to, err := sdk.AccAddressFromBech32(msg.ToAddress)
if err != nil {
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid 'to' address: %s", err)
}

if err := validateAmount(msg.Amount); err != nil {
return nil, err
}

if bk.BlockedAddr(to) {
ctx := sdk.UnwrapSDKContext(goCtx)
if err := s.BankKeeper.IsSendEnabledCoins(ctx, msg.Amount...); err != nil {
return nil, err
}

if s.BankKeeper.BlockedAddr(to) {
return nil, errorsmod.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", msg.ToAddress)
}

if acc := ak.GetAccount(ctx, to); acc != nil {
if acc := s.AccountKeeper.GetAccount(ctx, to); acc != nil {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "account %s already exists", msg.ToAddress)
}

baseAccount := authtypes.NewBaseAccountWithAddress(to)
baseAccount = ak.NewAccount(ctx, baseAccount).(*authtypes.BaseAccount)
baseAccount = s.AccountKeeper.NewAccount(ctx, baseAccount).(*authtypes.BaseAccount)
vestingAccount := types.NewPermanentLockedAccount(baseAccount, msg.Amount)

ak.SetAccount(ctx, vestingAccount)
s.AccountKeeper.SetAccount(ctx, vestingAccount)

defer func() {
telemetry.IncrCounter(1, "new", "account")
Expand All @@ -134,46 +142,51 @@ func (s msgServer) CreatePermanentLockedAccount(goCtx context.Context, msg *type
}
}()

if err = bk.SendCoins(ctx, from, to, msg.Amount); err != nil {
if err = s.BankKeeper.SendCoins(ctx, from, to, msg.Amount); err != nil {
return nil, err
}

return &types.MsgCreatePermanentLockedAccountResponse{}, nil
}

func (s msgServer) CreatePeriodicVestingAccount(goCtx context.Context, msg *types.MsgCreatePeriodicVestingAccount) (*types.MsgCreatePeriodicVestingAccountResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

ak := s.AccountKeeper
bk := s.BankKeeper

from, err := sdk.AccAddressFromBech32(msg.FromAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid 'from' address: %s", err)
}

to, err := sdk.AccAddressFromBech32(msg.ToAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid 'to' address: %s", err)
}

if acc := ak.GetAccount(ctx, to); acc != nil {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "account %s already exists", msg.ToAddress)
if msg.StartTime < 1 {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "invalid start time of %d, length must be greater than 0", msg.StartTime)
}

var totalCoins sdk.Coins
for _, period := range msg.VestingPeriods {
for i, period := range msg.VestingPeriods {
if period.Length < 1 {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "invalid period length of %d in period %d, length must be greater than 0", period.Length, i)
}

totalCoins = totalCoins.Add(period.Amount...)
}

if err := bk.IsSendEnabledCoins(ctx, totalCoins...); err != nil {
ctx := sdk.UnwrapSDKContext(goCtx)
if acc := s.AccountKeeper.GetAccount(ctx, to); acc != nil {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "account %s already exists", msg.ToAddress)
}

if err := s.BankKeeper.IsSendEnabledCoins(ctx, totalCoins...); err != nil {
return nil, err
}

baseAccount := authtypes.NewBaseAccountWithAddress(to)
baseAccount = ak.NewAccount(ctx, baseAccount).(*authtypes.BaseAccount)
baseAccount = s.AccountKeeper.NewAccount(ctx, baseAccount).(*authtypes.BaseAccount)
vestingAccount := types.NewPeriodicVestingAccount(baseAccount, totalCoins.Sort(), msg.StartTime, msg.VestingPeriods)

ak.SetAccount(ctx, vestingAccount)
s.AccountKeeper.SetAccount(ctx, vestingAccount)

defer func() {
telemetry.IncrCounter(1, "new", "account")
Expand All @@ -189,9 +202,21 @@ func (s msgServer) CreatePeriodicVestingAccount(goCtx context.Context, msg *type
}
}()

if err = bk.SendCoins(ctx, from, to, totalCoins); err != nil {
if err = s.BankKeeper.SendCoins(ctx, from, to, totalCoins); err != nil {
return nil, err
}

return &types.MsgCreatePeriodicVestingAccountResponse{}, nil
}

func validateAmount(amount sdk.Coins) error {
if !amount.IsValid() {
return sdkerrors.ErrInvalidCoins.Wrap(amount.String())
}

if !amount.IsAllPositive() {
return sdkerrors.ErrInvalidCoins.Wrap(amount.String())
}

return nil
}
Loading

0 comments on commit e4dbf1b

Please sign in to comment.