diff --git a/x/interchainstaking/keeper/ibc_packet_handlers.go b/x/interchainstaking/keeper/ibc_packet_handlers.go index a23c987da..19ef0abe7 100644 --- a/x/interchainstaking/keeper/ibc_packet_handlers.go +++ b/x/interchainstaking/keeper/ibc_packet_handlers.go @@ -5,26 +5,23 @@ import ( "errors" "fmt" "reflect" - "strconv" "strings" "time" "cosmossdk.io/math" - "github.com/golang/protobuf/proto" //nolint:staticcheck - "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/telemetry" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/bech32" - icatypes "github.com/cosmos/ibc-go/v5/modules/apps/27-interchain-accounts/types" - ibctransfertypes "github.com/cosmos/ibc-go/v5/modules/apps/transfer/types" - clienttypes "github.com/cosmos/ibc-go/v5/modules/core/02-client/types" - channeltypes "github.com/cosmos/ibc-go/v5/modules/core/04-channel/types" - - "github.com/cosmos/cosmos-sdk/telemetry" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" distrtypes "github.com/cosmos/cosmos-sdk/x/distribution/types" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + icatypes "github.com/cosmos/ibc-go/v5/modules/apps/27-interchain-accounts/types" + ibctransfertypes "github.com/cosmos/ibc-go/v5/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v5/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v5/modules/core/04-channel/types" + "github.com/golang/protobuf/proto" //nolint:staticcheck lsmstakingtypes "github.com/iqlusioninc/liquidity-staking-module/x/staking/types" "github.com/ingenuity-build/quicksilver/utils" @@ -34,7 +31,6 @@ import ( const ( transferPort = "transfer" - withdrawal = "withdrawal" ) type TypedMsg struct { @@ -578,14 +574,10 @@ func (k *Keeper) HandleBeginRedelegate(ctx sdk.Context, msg sdk.Msg, completion if completion.IsZero() { return errors.New("invalid zero nil completion time") } - parts := strings.Split(memo, "/") - if len(parts) != 2 || parts[0] != "rebalance" { - return errors.New("unexpected epoch rebalance memo format") - } - epochNumber, err := strconv.ParseInt(parts[1], 10, 64) + epochNumber, err := types.ParseMsgMemo(memo, types.MsgTypeRebalance) if err != nil { - return errors.New("unexpected epoch rebalance memo format (2)") + return err } k.Logger(ctx).Info("Received MsgBeginRedelegate acknowledgement") @@ -615,14 +607,9 @@ func (k *Keeper) HandleBeginRedelegate(ctx sdk.Context, msg sdk.Msg, completion } func (k *Keeper) HandleFailedBeginRedelegate(ctx sdk.Context, msg sdk.Msg, memo string) error { - parts := strings.Split(memo, "/") - if len(parts) != 2 || parts[0] != "rebalance" { - return errors.New("unexpected epoch rebalance memo format") - } - - epochNumber, err := strconv.ParseInt(parts[1], 10, 64) + epochNumber, err := types.ParseMsgMemo(memo, types.MsgTypeRebalance) if err != nil { - return errors.New("unexpected epoch rebalance memo format (2)") + return err } k.Logger(ctx).Error("Received MsgBeginRedelegate acknowledgement error") @@ -649,15 +636,12 @@ func (k *Keeper) HandleUndelegate(ctx sdk.Context, msg sdk.Msg, completion time. k.Logger(ctx).Error("unable to cast source message to MsgUndelegate") return errors.New("unable to cast source message to MsgUndelegate") } - memoParts := strings.Split(memo, "/") - if len(memoParts) != 2 || memoParts[0] != withdrawal { - return errors.New("unexpected memo form") - } - epochNumber, err := strconv.ParseInt(memoParts[1], 10, 64) + epochNumber, err := types.ParseMsgMemo(memo, types.MsgTypeWithdrawal) if err != nil { return err } + zone := k.GetZoneForDelegateAccount(ctx, undelegateMsg.DelegatorAddress) ubr, found := k.GetUnbondingRecord(ctx, zone.ChainId, undelegateMsg.ValidatorAddress, epochNumber) @@ -707,14 +691,9 @@ func (k *Keeper) HandleUndelegate(ctx sdk.Context, msg sdk.Msg, completion time. } func (k *Keeper) HandleFailedUndelegate(ctx sdk.Context, msg sdk.Msg, memo string) error { - parts := strings.Split(memo, "/") - if len(parts) != 2 || parts[0] != withdrawal { - return errors.New("unexpected epoch undelegate memo format") - } - - epochNumber, err := strconv.ParseInt(parts[1], 10, 64) + epochNumber, err := types.ParseMsgMemo(memo, types.MsgTypeWithdrawal) if err != nil { - return errors.New("unexpected epoch undelegate memo format (2)") + return err } k.Logger(ctx).Error("Received MsgUndelegate acknowledgement error") diff --git a/x/interchainstaking/types/ibc_packet.go b/x/interchainstaking/types/ibc_packet.go new file mode 100644 index 000000000..9bc5a5369 --- /dev/null +++ b/x/interchainstaking/types/ibc_packet.go @@ -0,0 +1,26 @@ +package types + +import ( + "fmt" + "strconv" + "strings" +) + +const ( + MsgTypeWithdrawal = "withdrawal" + MsgTypeRebalance = "rebalance" +) + +func ParseMsgMemo(memo, msgType string) (epochNumber int64, err error) { + parts := strings.Split(memo, "/") + if len(parts) != 2 || parts[0] != msgType { + return 0, fmt.Errorf("unexpected epoch %s memo format", msgType) + } + + epochNumber, err = strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return 0, fmt.Errorf("unexpected epoch %s memo format: %w", msgType, err) + } + + return +} diff --git a/x/interchainstaking/types/ibc_packet_test.go b/x/interchainstaking/types/ibc_packet_test.go new file mode 100644 index 000000000..66ba073b4 --- /dev/null +++ b/x/interchainstaking/types/ibc_packet_test.go @@ -0,0 +1,58 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ingenuity-build/quicksilver/x/interchainstaking/types" +) + +func TestParseMsgMemo(t *testing.T) { + tests := []struct { + name string + memo string + msgType string + wantErr bool + expectedEpochNumber int64 + }{ + { + name: "valid rebalance", + memo: types.MsgTypeRebalance + "/" + "10", + msgType: types.MsgTypeRebalance, + wantErr: false, + expectedEpochNumber: 10, + }, + { + name: "valid withdrawal", + memo: types.MsgTypeWithdrawal + "/" + "10", + msgType: types.MsgTypeWithdrawal, + wantErr: false, + expectedEpochNumber: 10, + }, + { + name: "invalid msg type", + memo: "invalid" + "/" + "10", + msgType: types.MsgTypeWithdrawal, + wantErr: true, + }, + { + name: "invalid epoch number", + memo: types.MsgTypeWithdrawal + "/" + "A", + msgType: types.MsgTypeWithdrawal, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + epochNumber, err := types.ParseMsgMemo(tt.memo, tt.msgType) + if tt.wantErr { + t.Logf("Error:\n%v\n", err) + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.expectedEpochNumber, epochNumber) + }) + } +}