Skip to content

Commit

Permalink
Consistent params validation (#394)
Browse files Browse the repository at this point in the history
* changes

* Update shared_params.go

* Create params_test.go

* corrections

* remove println

* Update x/ccv/consumer/types/params_test.go

Co-authored-by: Simon Noetzlin <[email protected]>

* Update x/ccv/consumer/types/params_test.go

Co-authored-by: Simon Noetzlin <[email protected]>

* fixing TestValidateParams

Co-authored-by: Marius Poke <[email protected]>
Co-authored-by: Simon Noetzlin <[email protected]>
  • Loading branch information
3 people authored Oct 18, 2022
1 parent e5f4ea5 commit 27e55b4
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 34 deletions.
13 changes: 7 additions & 6 deletions x/ccv/consumer/keeper/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)

// TestParams tests the default params set for a consumer chain, and related getters/setters
// TestParams tests getters/setters for consumer params
func TestParams(t *testing.T) {
consumerKeeper, ctx, ctrl, _ := testkeeper.GetConsumerKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
defer ctrl.Finish()
Expand All @@ -21,7 +21,8 @@ func TestParams(t *testing.T) {
params := consumerKeeper.GetParams(ctx)
require.Equal(t, expParams, params)

newParams := types.NewParams(false, 1000, "abc", "def", 7*24*time.Hour)
newParams := types.NewParams(false, 1000,
"channel-2", "cosmos19pe9pg5dv9k5fzgzmsrgnw9rl9asf7ddwhu7lm", 7*24*time.Hour)
consumerKeeper.SetParams(ctx, newParams)
params = consumerKeeper.GetParams(ctx)
require.Equal(t, newParams, params)
Expand All @@ -30,12 +31,12 @@ func TestParams(t *testing.T) {
gotBPDT := consumerKeeper.GetBlocksPerDistributionTransmission(ctx)
require.Equal(t, gotBPDT, int64(10))

consumerKeeper.SetDistributionTransmissionChannel(ctx, "foobarbaz")
consumerKeeper.SetDistributionTransmissionChannel(ctx, "channel-7")
gotChan := consumerKeeper.GetDistributionTransmissionChannel(ctx)
require.Equal(t, gotChan, "foobarbaz")
require.Equal(t, gotChan, "channel-7")

consumerKeeper.SetProviderFeePoolAddrStr(ctx, "foobar")
consumerKeeper.SetProviderFeePoolAddrStr(ctx, "cosmos1dkas8mu4kyhl5jrh4nzvm65qz588hy9qcz08la")
gotAddr := consumerKeeper.
GetProviderFeePoolAddrStr(ctx)
require.Equal(t, gotAddr, "foobar")
require.Equal(t, gotAddr, "cosmos1dkas8mu4kyhl5jrh4nzvm65qz588hy9qcz08la")
}
6 changes: 3 additions & 3 deletions x/ccv/consumer/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (k Keeper) SendVSCMaturedPackets(ctx sdk.Context) error {
channelID, // source channel id
ccv.ConsumerPortID, // source port id
packetData.GetBytes(),
k.GetParams(ctx).CcvTimeoutPeriod,
k.GetCCVTimeoutPeriod(ctx),
)
if err != nil {
return err
Expand Down Expand Up @@ -154,7 +154,7 @@ func (k Keeper) SendSlashPacket(ctx sdk.Context, validator abci.Validator, valse
channelID, // source channel id
ccv.ConsumerPortID, // source port id
packetData.GetBytes(),
k.GetParams(ctx).CcvTimeoutPeriod,
k.GetCCVTimeoutPeriod(ctx),
)
if err != nil {
panic(err)
Expand Down Expand Up @@ -191,7 +191,7 @@ func (k Keeper) SendPendingSlashRequests(ctx sdk.Context) {
channelID, // source channel id
ccv.ConsumerPortID, // source port id
slashReq.Packet.GetBytes(),
k.GetParams(ctx).CcvTimeoutPeriod,
k.GetCCVTimeoutPeriod(ctx),
)
if err != nil {
panic(err)
Expand Down
3 changes: 3 additions & 0 deletions x/ccv/consumer/types/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ func (gs GenesisState) Validate() error {
if len(gs.InitialValSet) == 0 {
return sdkerrors.Wrap(ccv.ErrInvalidGenesis, "initial validator set is empty")
}
if err := gs.Params.Validate(); err != nil {
return err
}

if gs.NewChain {
if gs.ProviderClientState == nil {
Expand Down
24 changes: 24 additions & 0 deletions x/ccv/consumer/types/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,18 @@ func TestValidateInitialGenesisState(t *testing.T) {
valUpdates, types.SlashRequests{}, params),
true,
},
{
"invalid new consumer genesis state: invalid params",
types.NewInitialGenesisState(cs, consensusState, valUpdates, types.SlashRequests{},
types.NewParams(
true,
types.DefaultBlocksPerDistributionTransmission,
"",
"",
0, // CCV timeout period cannot be 0
)),
true,
},
}

for _, c := range cases {
Expand Down Expand Up @@ -257,6 +269,18 @@ func TestValidateRestartGenesisState(t *testing.T) {
types.NewRestartGenesisState("ccvclient", "ccvchannel", nil, nil, nil, nil, params),
true,
},
{
"invalid restart consumer genesis state: invalid params",
types.NewRestartGenesisState("ccvclient", "ccvchannel", nil, valUpdates, nil, nil,
types.NewParams(
true,
types.DefaultBlocksPerDistributionTransmission,
"",
"",
0, // CCV timeout period cannot be 0
)),
true,
},
}

for _, c := range cases {
Expand Down
53 changes: 32 additions & 21 deletions x/ccv/consumer/types/params.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package types

import (
"fmt"
"time"

paramtypes "github.com/cosmos/cosmos-sdk/x/params/types"
Expand Down Expand Up @@ -52,41 +51,53 @@ func DefaultParams() Params {

// Validate all ccv-consumer module parameters
func (p Params) Validate() error {
if err := ccvtypes.ValidateBool(p.Enabled); err != nil {
return err
}
if err := ccvtypes.ValidatePositiveInt64(p.BlocksPerDistributionTransmission); err != nil {
return err
}
if err := validateDistributionTransmissionChannel(p.DistributionTransmissionChannel); err != nil {
return err
}
if err := validateProviderFeePoolAddrStr(p.ProviderFeePoolAddrStr); err != nil {
return err
}
if err := ccvtypes.ValidateDuration(p.CcvTimeoutPeriod); err != nil {
return err
}
return nil
}

// ParamSetPairs implements params.ParamSet
func (p *Params) ParamSetPairs() paramtypes.ParamSetPairs {
return paramtypes.ParamSetPairs{
paramtypes.NewParamSetPair(KeyEnabled, p.Enabled, validateBool),
paramtypes.NewParamSetPair(KeyEnabled, p.Enabled, ccvtypes.ValidateBool),
paramtypes.NewParamSetPair(KeyBlocksPerDistributionTransmission,
p.BlocksPerDistributionTransmission, validateInt64),
p.BlocksPerDistributionTransmission, ccvtypes.ValidatePositiveInt64),
paramtypes.NewParamSetPair(KeyDistributionTransmissionChannel,
p.DistributionTransmissionChannel, validateString),
p.DistributionTransmissionChannel, validateDistributionTransmissionChannel),
paramtypes.NewParamSetPair(KeyProviderFeePoolAddrStr,
p.ProviderFeePoolAddrStr, validateString),
p.ProviderFeePoolAddrStr, validateProviderFeePoolAddrStr),
paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod,
p.CcvTimeoutPeriod, ccvtypes.ValidateCCVTimeoutPeriod),
p.CcvTimeoutPeriod, ccvtypes.ValidateDuration),
}
}

func validateBool(i interface{}) error {
if _, ok := i.(bool); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
func validateDistributionTransmissionChannel(i interface{}) error {
// Accept empty string as valid, since this will be the default value on genesis
if i == "" {
return nil
}
return nil
// Otherwise validate as usual for a channelID
return ccvtypes.ValidateChannelIdentifier(i)
}

func validateInt64(i interface{}) error {
if _, ok := i.(int64); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
func validateProviderFeePoolAddrStr(i interface{}) error {
// Accept empty string as valid, since this will be the default value on genesis
if i == "" {
return nil
}
return nil
}

func validateString(i interface{}) error {
if _, ok := i.(string); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
// Otherwise validate as usual for a bech32 address
return ccvtypes.ValidateBech32(i)
}
35 changes: 35 additions & 0 deletions x/ccv/consumer/types/params_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package types_test

import (
"testing"

"github.com/stretchr/testify/require"

consumertypes "github.com/cosmos/interchain-security/x/ccv/consumer/types"
)

// Tests the validation of consumer params that happens at genesis
func TestValidateParams(t *testing.T) {

testCases := []struct {
name string
params consumertypes.Params
expPass bool
}{
{"default params", consumertypes.DefaultParams(), true},
{"custom valid params", consumertypes.NewParams(true, 5, "", "", 5), true},
{"custom invalid params, block per dist transmission", consumertypes.NewParams(true, -5, "", "", 5), false},
{"custom invalid params, dist transmission channel", consumertypes.NewParams(true, 5, "badchannel/", "", 5), false},
{"custom invalid params, provider fee pool addr string", consumertypes.NewParams(true, 5, "", "imabadaddress", 5), false},
{"custom invalid params, ccv timeout", consumertypes.NewParams(true, 5, "", "", -5), false},
}

for _, tc := range testCases {
err := tc.params.Validate()
if tc.expPass {
require.Nil(t, err, "expected error to be nil for test case: %s", tc.name)
} else {
require.NotNil(t, err, "expected error but got nil for test case: %s", tc.name)
}
}
}
4 changes: 2 additions & 2 deletions x/ccv/provider/types/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (p Params) Validate() error {
if p.TemplateClient == nil {
return fmt.Errorf("template client is nil")
}
if ccvtypes.ValidateCCVTimeoutPeriod(p.CcvTimeoutPeriod) != nil {
if ccvtypes.ValidateDuration(p.CcvTimeoutPeriod) != nil {
return fmt.Errorf("ccv timeout period is invalid")
}
return validateTemplateClient(*p.TemplateClient)
Expand All @@ -67,7 +67,7 @@ func (p Params) Validate() error {
func (p *Params) ParamSetPairs() paramtypes.ParamSetPairs {
return paramtypes.ParamSetPairs{
paramtypes.NewParamSetPair(KeyTemplateClient, p.TemplateClient, validateTemplateClient),
paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod, p.CcvTimeoutPeriod, ccvtypes.ValidateCCVTimeoutPeriod),
paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod, p.CcvTimeoutPeriod, ccvtypes.ValidateDuration),
}
}

Expand Down
55 changes: 53 additions & 2 deletions x/ccv/types/shared_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package types
import (
fmt "fmt"
"time"

sdktypes "github.com/cosmos/cosmos-sdk/types"
ibchost "github.com/cosmos/ibc-go/v3/modules/core/24-host"
)

const (
Expand All @@ -14,13 +17,61 @@ var (
KeyCCVTimeoutPeriod = []byte("CcvTimeoutPeriod")
)

func ValidateCCVTimeoutPeriod(i interface{}) error {
func ValidateDuration(i interface{}) error {
period, ok := i.(time.Duration)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
if period <= time.Duration(0) {
return fmt.Errorf("ibc timeout period is not positive")
return fmt.Errorf("duration must be positive")
}
return nil
}

func ValidateBool(i interface{}) error {
if _, ok := i.(bool); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
}

func ValidateInt64(i interface{}) error {
if _, ok := i.(int64); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
}

func ValidatePositiveInt64(i interface{}) error {
if err := ValidateInt64(i); err != nil {
return err
}
if i.(int64) <= int64(0) {
return fmt.Errorf("int must be positive")
}
return nil
}

func ValidateString(i interface{}) error {
if _, ok := i.(string); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
}

func ValidateChannelIdentifier(i interface{}) error {
value, ok := i.(string)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return ibchost.ChannelIdentifierValidator(value)
}

func ValidateBech32(i interface{}) error {
value, ok := i.(string)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
_, err := sdktypes.AccAddressFromBech32(value)
return err
}

0 comments on commit 27e55b4

Please sign in to comment.