From 27e55b4f5cb2a087ce002d42dc8d99cf1154f0c5 Mon Sep 17 00:00:00 2001 From: Shawn Marshall-Spitzbart <44221603+smarshall-spitzbart@users.noreply.github.com> Date: Tue, 18 Oct 2022 04:48:26 -0700 Subject: [PATCH] Consistent params validation (#394) * changes * Update shared_params.go * Create params_test.go * corrections * remove println * Update x/ccv/consumer/types/params_test.go Co-authored-by: Simon Noetzlin * Update x/ccv/consumer/types/params_test.go Co-authored-by: Simon Noetzlin * fixing TestValidateParams Co-authored-by: Marius Poke Co-authored-by: Simon Noetzlin --- x/ccv/consumer/keeper/params_test.go | 13 ++++--- x/ccv/consumer/keeper/relay.go | 6 +-- x/ccv/consumer/types/genesis.go | 3 ++ x/ccv/consumer/types/genesis_test.go | 24 ++++++++++++ x/ccv/consumer/types/params.go | 53 ++++++++++++++++----------- x/ccv/consumer/types/params_test.go | 35 ++++++++++++++++++ x/ccv/provider/types/params.go | 4 +- x/ccv/types/shared_params.go | 55 +++++++++++++++++++++++++++- 8 files changed, 159 insertions(+), 34 deletions(-) create mode 100644 x/ccv/consumer/types/params_test.go diff --git a/x/ccv/consumer/keeper/params_test.go b/x/ccv/consumer/keeper/params_test.go index 7a9b0dcd94..e7f653a533 100644 --- a/x/ccv/consumer/keeper/params_test.go +++ b/x/ccv/consumer/keeper/params_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -// TestParams tests the default params set for a consumer chain, and related getters/setters +// TestParams tests getters/setters for consumer params func TestParams(t *testing.T) { consumerKeeper, ctx, ctrl, _ := testkeeper.GetConsumerKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t)) defer ctrl.Finish() @@ -21,7 +21,8 @@ func TestParams(t *testing.T) { params := consumerKeeper.GetParams(ctx) require.Equal(t, expParams, params) - newParams := types.NewParams(false, 1000, "abc", "def", 7*24*time.Hour) + newParams := types.NewParams(false, 1000, + "channel-2", "cosmos19pe9pg5dv9k5fzgzmsrgnw9rl9asf7ddwhu7lm", 7*24*time.Hour) consumerKeeper.SetParams(ctx, newParams) params = consumerKeeper.GetParams(ctx) require.Equal(t, newParams, params) @@ -30,12 +31,12 @@ func TestParams(t *testing.T) { gotBPDT := consumerKeeper.GetBlocksPerDistributionTransmission(ctx) require.Equal(t, gotBPDT, int64(10)) - consumerKeeper.SetDistributionTransmissionChannel(ctx, "foobarbaz") + consumerKeeper.SetDistributionTransmissionChannel(ctx, "channel-7") gotChan := consumerKeeper.GetDistributionTransmissionChannel(ctx) - require.Equal(t, gotChan, "foobarbaz") + require.Equal(t, gotChan, "channel-7") - consumerKeeper.SetProviderFeePoolAddrStr(ctx, "foobar") + consumerKeeper.SetProviderFeePoolAddrStr(ctx, "cosmos1dkas8mu4kyhl5jrh4nzvm65qz588hy9qcz08la") gotAddr := consumerKeeper. GetProviderFeePoolAddrStr(ctx) - require.Equal(t, gotAddr, "foobar") + require.Equal(t, gotAddr, "cosmos1dkas8mu4kyhl5jrh4nzvm65qz588hy9qcz08la") } diff --git a/x/ccv/consumer/keeper/relay.go b/x/ccv/consumer/keeper/relay.go index 4d5b085937..e15343b1bb 100644 --- a/x/ccv/consumer/keeper/relay.go +++ b/x/ccv/consumer/keeper/relay.go @@ -108,7 +108,7 @@ func (k Keeper) SendVSCMaturedPackets(ctx sdk.Context) error { channelID, // source channel id ccv.ConsumerPortID, // source port id packetData.GetBytes(), - k.GetParams(ctx).CcvTimeoutPeriod, + k.GetCCVTimeoutPeriod(ctx), ) if err != nil { return err @@ -154,7 +154,7 @@ func (k Keeper) SendSlashPacket(ctx sdk.Context, validator abci.Validator, valse channelID, // source channel id ccv.ConsumerPortID, // source port id packetData.GetBytes(), - k.GetParams(ctx).CcvTimeoutPeriod, + k.GetCCVTimeoutPeriod(ctx), ) if err != nil { panic(err) @@ -191,7 +191,7 @@ func (k Keeper) SendPendingSlashRequests(ctx sdk.Context) { channelID, // source channel id ccv.ConsumerPortID, // source port id slashReq.Packet.GetBytes(), - k.GetParams(ctx).CcvTimeoutPeriod, + k.GetCCVTimeoutPeriod(ctx), ) if err != nil { panic(err) diff --git a/x/ccv/consumer/types/genesis.go b/x/ccv/consumer/types/genesis.go index 0e2a0b2557..a2fdbe38a9 100644 --- a/x/ccv/consumer/types/genesis.go +++ b/x/ccv/consumer/types/genesis.go @@ -62,6 +62,9 @@ func (gs GenesisState) Validate() error { if len(gs.InitialValSet) == 0 { return sdkerrors.Wrap(ccv.ErrInvalidGenesis, "initial validator set is empty") } + if err := gs.Params.Validate(); err != nil { + return err + } if gs.NewChain { if gs.ProviderClientState == nil { diff --git a/x/ccv/consumer/types/genesis_test.go b/x/ccv/consumer/types/genesis_test.go index 3408c81ccb..9dfcf9146d 100644 --- a/x/ccv/consumer/types/genesis_test.go +++ b/x/ccv/consumer/types/genesis_test.go @@ -144,6 +144,18 @@ func TestValidateInitialGenesisState(t *testing.T) { valUpdates, types.SlashRequests{}, params), true, }, + { + "invalid new consumer genesis state: invalid params", + types.NewInitialGenesisState(cs, consensusState, valUpdates, types.SlashRequests{}, + types.NewParams( + true, + types.DefaultBlocksPerDistributionTransmission, + "", + "", + 0, // CCV timeout period cannot be 0 + )), + true, + }, } for _, c := range cases { @@ -257,6 +269,18 @@ func TestValidateRestartGenesisState(t *testing.T) { types.NewRestartGenesisState("ccvclient", "ccvchannel", nil, nil, nil, nil, params), true, }, + { + "invalid restart consumer genesis state: invalid params", + types.NewRestartGenesisState("ccvclient", "ccvchannel", nil, valUpdates, nil, nil, + types.NewParams( + true, + types.DefaultBlocksPerDistributionTransmission, + "", + "", + 0, // CCV timeout period cannot be 0 + )), + true, + }, } for _, c := range cases { diff --git a/x/ccv/consumer/types/params.go b/x/ccv/consumer/types/params.go index dc5d7d5432..d714eaef25 100644 --- a/x/ccv/consumer/types/params.go +++ b/x/ccv/consumer/types/params.go @@ -1,7 +1,6 @@ package types import ( - "fmt" "time" paramtypes "github.com/cosmos/cosmos-sdk/x/params/types" @@ -52,41 +51,53 @@ func DefaultParams() Params { // Validate all ccv-consumer module parameters func (p Params) Validate() error { + if err := ccvtypes.ValidateBool(p.Enabled); err != nil { + return err + } + if err := ccvtypes.ValidatePositiveInt64(p.BlocksPerDistributionTransmission); err != nil { + return err + } + if err := validateDistributionTransmissionChannel(p.DistributionTransmissionChannel); err != nil { + return err + } + if err := validateProviderFeePoolAddrStr(p.ProviderFeePoolAddrStr); err != nil { + return err + } + if err := ccvtypes.ValidateDuration(p.CcvTimeoutPeriod); err != nil { + return err + } return nil } // ParamSetPairs implements params.ParamSet func (p *Params) ParamSetPairs() paramtypes.ParamSetPairs { return paramtypes.ParamSetPairs{ - paramtypes.NewParamSetPair(KeyEnabled, p.Enabled, validateBool), + paramtypes.NewParamSetPair(KeyEnabled, p.Enabled, ccvtypes.ValidateBool), paramtypes.NewParamSetPair(KeyBlocksPerDistributionTransmission, - p.BlocksPerDistributionTransmission, validateInt64), + p.BlocksPerDistributionTransmission, ccvtypes.ValidatePositiveInt64), paramtypes.NewParamSetPair(KeyDistributionTransmissionChannel, - p.DistributionTransmissionChannel, validateString), + p.DistributionTransmissionChannel, validateDistributionTransmissionChannel), paramtypes.NewParamSetPair(KeyProviderFeePoolAddrStr, - p.ProviderFeePoolAddrStr, validateString), + p.ProviderFeePoolAddrStr, validateProviderFeePoolAddrStr), paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod, - p.CcvTimeoutPeriod, ccvtypes.ValidateCCVTimeoutPeriod), + p.CcvTimeoutPeriod, ccvtypes.ValidateDuration), } } -func validateBool(i interface{}) error { - if _, ok := i.(bool); !ok { - return fmt.Errorf("invalid parameter type: %T", i) +func validateDistributionTransmissionChannel(i interface{}) error { + // Accept empty string as valid, since this will be the default value on genesis + if i == "" { + return nil } - return nil + // Otherwise validate as usual for a channelID + return ccvtypes.ValidateChannelIdentifier(i) } -func validateInt64(i interface{}) error { - if _, ok := i.(int64); !ok { - return fmt.Errorf("invalid parameter type: %T", i) +func validateProviderFeePoolAddrStr(i interface{}) error { + // Accept empty string as valid, since this will be the default value on genesis + if i == "" { + return nil } - return nil -} - -func validateString(i interface{}) error { - if _, ok := i.(string); !ok { - return fmt.Errorf("invalid parameter type: %T", i) - } - return nil + // Otherwise validate as usual for a bech32 address + return ccvtypes.ValidateBech32(i) } diff --git a/x/ccv/consumer/types/params_test.go b/x/ccv/consumer/types/params_test.go new file mode 100644 index 0000000000..b7e8cbf937 --- /dev/null +++ b/x/ccv/consumer/types/params_test.go @@ -0,0 +1,35 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + consumertypes "github.com/cosmos/interchain-security/x/ccv/consumer/types" +) + +// Tests the validation of consumer params that happens at genesis +func TestValidateParams(t *testing.T) { + + testCases := []struct { + name string + params consumertypes.Params + expPass bool + }{ + {"default params", consumertypes.DefaultParams(), true}, + {"custom valid params", consumertypes.NewParams(true, 5, "", "", 5), true}, + {"custom invalid params, block per dist transmission", consumertypes.NewParams(true, -5, "", "", 5), false}, + {"custom invalid params, dist transmission channel", consumertypes.NewParams(true, 5, "badchannel/", "", 5), false}, + {"custom invalid params, provider fee pool addr string", consumertypes.NewParams(true, 5, "", "imabadaddress", 5), false}, + {"custom invalid params, ccv timeout", consumertypes.NewParams(true, 5, "", "", -5), false}, + } + + for _, tc := range testCases { + err := tc.params.Validate() + if tc.expPass { + require.Nil(t, err, "expected error to be nil for test case: %s", tc.name) + } else { + require.NotNil(t, err, "expected error but got nil for test case: %s", tc.name) + } + } +} diff --git a/x/ccv/provider/types/params.go b/x/ccv/provider/types/params.go index 0058a2cbe4..5a51a4eb51 100644 --- a/x/ccv/provider/types/params.go +++ b/x/ccv/provider/types/params.go @@ -57,7 +57,7 @@ func (p Params) Validate() error { if p.TemplateClient == nil { return fmt.Errorf("template client is nil") } - if ccvtypes.ValidateCCVTimeoutPeriod(p.CcvTimeoutPeriod) != nil { + if ccvtypes.ValidateDuration(p.CcvTimeoutPeriod) != nil { return fmt.Errorf("ccv timeout period is invalid") } return validateTemplateClient(*p.TemplateClient) @@ -67,7 +67,7 @@ func (p Params) Validate() error { func (p *Params) ParamSetPairs() paramtypes.ParamSetPairs { return paramtypes.ParamSetPairs{ paramtypes.NewParamSetPair(KeyTemplateClient, p.TemplateClient, validateTemplateClient), - paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod, p.CcvTimeoutPeriod, ccvtypes.ValidateCCVTimeoutPeriod), + paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod, p.CcvTimeoutPeriod, ccvtypes.ValidateDuration), } } diff --git a/x/ccv/types/shared_params.go b/x/ccv/types/shared_params.go index 5cd02e0c25..15e2dd1acb 100644 --- a/x/ccv/types/shared_params.go +++ b/x/ccv/types/shared_params.go @@ -3,6 +3,9 @@ package types import ( fmt "fmt" "time" + + sdktypes "github.com/cosmos/cosmos-sdk/types" + ibchost "github.com/cosmos/ibc-go/v3/modules/core/24-host" ) const ( @@ -14,13 +17,61 @@ var ( KeyCCVTimeoutPeriod = []byte("CcvTimeoutPeriod") ) -func ValidateCCVTimeoutPeriod(i interface{}) error { +func ValidateDuration(i interface{}) error { period, ok := i.(time.Duration) if !ok { return fmt.Errorf("invalid parameter type: %T", i) } if period <= time.Duration(0) { - return fmt.Errorf("ibc timeout period is not positive") + return fmt.Errorf("duration must be positive") + } + return nil +} + +func ValidateBool(i interface{}) error { + if _, ok := i.(bool); !ok { + return fmt.Errorf("invalid parameter type: %T", i) + } + return nil +} + +func ValidateInt64(i interface{}) error { + if _, ok := i.(int64); !ok { + return fmt.Errorf("invalid parameter type: %T", i) + } + return nil +} + +func ValidatePositiveInt64(i interface{}) error { + if err := ValidateInt64(i); err != nil { + return err + } + if i.(int64) <= int64(0) { + return fmt.Errorf("int must be positive") } return nil } + +func ValidateString(i interface{}) error { + if _, ok := i.(string); !ok { + return fmt.Errorf("invalid parameter type: %T", i) + } + return nil +} + +func ValidateChannelIdentifier(i interface{}) error { + value, ok := i.(string) + if !ok { + return fmt.Errorf("invalid parameter type: %T", i) + } + return ibchost.ChannelIdentifierValidator(value) +} + +func ValidateBech32(i interface{}) error { + value, ok := i.(string) + if !ok { + return fmt.Errorf("invalid parameter type: %T", i) + } + _, err := sdktypes.AccAddressFromBech32(value) + return err +}