Skip to content

Commit

Permalink
refactor: remove storage of chain/connection in context for ibc callb…
Browse files Browse the repository at this point in the history
…acks
  • Loading branch information
Duong Minh Ngoc committed Feb 29, 2024
1 parent e3184a0 commit a1289ba
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 46 deletions.
6 changes: 1 addition & 5 deletions x/interchainstaking/ibc_module.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package interchainstaking

import (
"context"
"errors"
"fmt"

Expand All @@ -13,7 +12,6 @@ import (
host "github.com/cosmos/ibc-go/v5/modules/core/24-host"
ibcexported "github.com/cosmos/ibc-go/v5/modules/core/exported"

"github.com/quicksilver-zone/quicksilver/utils"
"github.com/quicksilver-zone/quicksilver/x/interchainstaking/keeper"
)

Expand Down Expand Up @@ -126,9 +124,7 @@ func (im IBCModule) OnAcknowledgementPacket(
ctx.Logger().Error(err.Error())
return err
}
ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), connectionID))

err = im.keeper.HandleAcknowledgement(ctx, packet, acknowledgement)
err = im.keeper.HandleAcknowledgement(ctx, packet, acknowledgement, connectionID)
if err != nil {
im.keeper.Logger(ctx).Error("CALLBACK ERROR:", "error", err.Error())
}
Expand Down
30 changes: 15 additions & 15 deletions x/interchainstaking/keeper/ibc_packet_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func DeserializeCosmosTxTyped(cdc codec.BinaryCodec, data []byte) ([]TypedMsg, e
return msgs, nil
}

func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Packet, acknowledgement []byte) error {
func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Packet, acknowledgement []byte, connectionID string) error {
var (
ack channeltypes.Acknowledgement
success bool
Expand Down Expand Up @@ -134,7 +134,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack
return nil
}
k.Logger(ctx).Info("Rewards withdrawn")
if err := k.HandleWithdrawRewards(ctx, msg.Msg); err != nil {
if err := k.HandleWithdrawRewards(ctx, msg.Msg, connectionID); err != nil {
return err
}
continue
Expand All @@ -155,7 +155,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack

k.Logger(ctx).Info("Tokens redeemed for shares", "response", response)
// we should update delegation records here.
if err := k.HandleRedeemTokens(ctx, msg.Msg, response.Amount, packetData.Memo); err != nil {
if err := k.HandleRedeemTokens(ctx, msg.Msg, response.Amount, packetData.Memo, connectionID); err != nil {
return err
}
continue
Expand Down Expand Up @@ -239,7 +239,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack

case "/cosmos.bank.v1beta1.MsgSend":
if !success {
if err := k.HandleFailedBankSend(ctx, msg.Msg, packetData.Memo); err != nil {
if err := k.HandleFailedBankSend(ctx, msg.Msg, packetData.Memo, connectionID); err != nil {
k.Logger(ctx).Error("unable to handle failed MsgSend", "error", err)
return err
}
Expand All @@ -254,7 +254,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack

k.Logger(ctx).Info("Funds Transferred", "response", response)
// check tokenTransfers - if end user unescrow and burn txs
if err := k.HandleCompleteSend(ctx, msg.Msg, packetData.Memo); err != nil {
if err := k.HandleCompleteSend(ctx, msg.Msg, packetData.Memo, connectionID); err != nil {
return err
}
case "/cosmos.distribution.v1beta1.MsgSetWithdrawAddress":
Expand Down Expand Up @@ -332,7 +332,7 @@ func (k *Keeper) HandleMsgTransfer(ctx sdk.Context, msg ibctransfertypes.Fungibl
return k.BankKeeper.SendCoinsFromModuleToModule(ctx, types.ModuleName, authtypes.FeeCollectorName, balance)
}

func (k *Keeper) HandleCompleteSend(ctx sdk.Context, msg sdk.Msg, memo string) error {
func (k *Keeper) HandleCompleteSend(ctx sdk.Context, msg sdk.Msg, memo string, connectionID string) error {
k.Logger(ctx).Info("Received MsgSend acknowledgement")
// first, type assertion. we should have banktypes.MsgSend
sMsg, ok := msg.(*banktypes.MsgSend)
Expand All @@ -343,7 +343,7 @@ func (k *Keeper) HandleCompleteSend(ctx sdk.Context, msg sdk.Msg, memo string) e
}

// get zone
zone, err := k.GetZoneFromContext(ctx)
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
err = fmt.Errorf("2: %w", err)
k.Logger(ctx).Error(err.Error())
Expand Down Expand Up @@ -768,7 +768,7 @@ func (k *Keeper) HandleUndelegate(ctx sdk.Context, msg sdk.Msg, completion time.
return nil
}

func (k *Keeper) HandleFailedBankSend(ctx sdk.Context, msg sdk.Msg, memo string) error {
func (k *Keeper) HandleFailedBankSend(ctx sdk.Context, msg sdk.Msg, memo string, connectionID string) error {
sMsg, ok := msg.(*banktypes.MsgSend)
if !ok {
err := errors.New("unable to cast source message to MsgSend")
Expand All @@ -777,7 +777,7 @@ func (k *Keeper) HandleFailedBankSend(ctx sdk.Context, msg sdk.Msg, memo string)
}

// get zone
zone, err := k.GetZoneFromContext(ctx)
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
k.Logger(ctx).Error(err.Error())
return err
Expand Down Expand Up @@ -907,15 +907,15 @@ func (k *Keeper) HandleFailedUndelegate(ctx sdk.Context, msg sdk.Msg, memo strin
return nil
}

func (k *Keeper) HandleRedeemTokens(ctx sdk.Context, msg sdk.Msg, amount sdk.Coin, memo string) error {
func (k *Keeper) HandleRedeemTokens(ctx sdk.Context, msg sdk.Msg, amount sdk.Coin, memo string, connectionID string) error {
k.Logger(ctx).Info("Received MsgRedeemTokensforShares acknowledgement")
// first, type assertion. we should have stakingtypes.MsgRedeemTokensforShares
redeemMsg, ok := msg.(*lsmstakingtypes.MsgRedeemTokensForShares)
if !ok {
k.Logger(ctx).Error("unable to cast source message to MsgRedeemTokensforShares")
return errors.New("unable to cast source message to MsgRedeemTokensforShares")
}
validatorAddress, err := k.GetValidatorForToken(ctx, redeemMsg.Amount)
validatorAddress, err := k.GetValidatorForToken(ctx, redeemMsg.Amount, connectionID)
if err != nil {
return err
}
Expand Down Expand Up @@ -1119,8 +1119,8 @@ func (k *Keeper) HandleUpdatedWithdrawAddress(ctx sdk.Context, msg sdk.Msg) erro
return nil
}

func (k *Keeper) GetValidatorForToken(ctx sdk.Context, amount sdk.Coin) (string, error) {
zone, err := k.GetZoneFromContext(ctx)
func (k *Keeper) GetValidatorForToken(ctx sdk.Context, amount sdk.Coin, connectionID string) (string, error) {
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
err = fmt.Errorf("3: %w", err)
k.Logger(ctx).Error(err.Error())
Expand Down Expand Up @@ -1291,14 +1291,14 @@ func (k *Keeper) UpdateDelegationRecordForAddress(
return nil
}

func (k *Keeper) HandleWithdrawRewards(ctx sdk.Context, msg sdk.Msg) error {
func (k *Keeper) HandleWithdrawRewards(ctx sdk.Context, msg sdk.Msg, connectionID string) error {
withdrawalMsg, ok := msg.(*distrtypes.MsgWithdrawDelegatorReward)
if !ok {
k.Logger(ctx).Error("unable to cast source message to MsgWithdrawDelegatorReward")
return errors.New("unable to cast source message to MsgWithdrawDelegatorReward")
}

zone, err := k.GetZoneFromContext(ctx)
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
err = fmt.Errorf("4: %w", err)
k.Logger(ctx).Error(err.Error())
Expand Down
47 changes: 21 additions & 26 deletions x/interchainstaking/keeper/ibc_packet_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,8 +909,7 @@ func (suite *KeeperTestSuite) TestHandleWithdrawRewards() {
}
}

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), zone.ConnectionId))
err := quicksilver.InterchainstakingKeeper.HandleWithdrawRewards(ctx, test.msg(&zone))
err := quicksilver.InterchainstakingKeeper.HandleWithdrawRewards(ctx, test.msg(&zone), zone.ConnectionId)
if test.err {
suite.Error(err)
} else {
Expand Down Expand Up @@ -1122,7 +1121,7 @@ func (suite *KeeperTestSuite) TestReceiveAckErrForBeginRedelegate() {
_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
suite.True(found)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
Expand Down Expand Up @@ -1493,7 +1492,7 @@ func (suite *KeeperTestSuite) TestReceiveAckErrForBeginUndelegate() {
suite.True(found)
}

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

for _, ubr := range test.unbondingRecords(ctx, quicksilver, zone) {
Expand Down Expand Up @@ -1892,7 +1891,9 @@ func (suite *KeeperTestSuite) Test_v045Callback() {
Data: packetBytes,
}
ctx = suite.chainA.GetContext()
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement)))
zone, found := quicksilver.InterchainstakingKeeper.GetZone(ctx, suite.chainB.ChainID)
suite.True(found)
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement), zone.ConnectionId))

suite.True(test.assertStatements(ctx, quicksilver))
})
Expand Down Expand Up @@ -2000,8 +2001,7 @@ func (suite *KeeperTestSuite) Test_v046Callback() {
Data: packetBytes,
}

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), "connection-0"))
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement)))
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement), "connection-0"))

suite.True(test.assertStatements(ctx, quicksilver))
})
Expand Down Expand Up @@ -2462,7 +2462,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginUndelegate() {
suite.True(found)
}

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

for idx, ewdr := range test.expectedWithdrawalRecords(ctx, quicksilver, zone) {
Expand Down Expand Up @@ -2547,7 +2547,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginRedelegateNonNilCompletion()
_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
suite.True(found)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

afterRedelegation, found := quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
Expand Down Expand Up @@ -2641,7 +2641,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginRedelegateNilCompletion() {
_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, epoch)
suite.True(found)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, epoch)
Expand Down Expand Up @@ -2720,7 +2720,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginRedelegateNoExistingRecord()

// call handler

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

createdRecord, found := quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, epoch)
Expand Down Expand Up @@ -2797,8 +2797,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForWithdrawReward() {
}
}

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), zone.ConnectionId))
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

allBalancesQueryCnt := 0
Expand Down Expand Up @@ -2874,7 +2873,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForRedeemTokens() {
suite.NoError(err)

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), suite.path.EndpointA.ConnectionID))
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

delegationRecord, found = quicksilver.InterchainstakingKeeper.GetDelegation(ctx, zone.ChainId, zone.DelegationAddress.Address, vals[0])
Expand Down Expand Up @@ -2953,7 +2952,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForTokenizedShares() {
ackBytes, err := icatypes.ModuleCdc.MarshalJSON(&acknowledgement)
suite.NoError(err)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

_, found = quicksilver.InterchainstakingKeeper.GetWithdrawalRecord(ctx, zone.ChainId, txHash, types.WithdrawStatusTokenize)
Expand Down Expand Up @@ -3025,7 +3024,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForDelegate() {
ackBytes, err := icatypes.ModuleCdc.MarshalJSON(&acknowledgement)
suite.NoError(err)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

newCompleted := ctx.BlockTime()
Expand Down Expand Up @@ -3089,7 +3088,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBankSend() {
ackBytes, err := icatypes.ModuleCdc.MarshalJSON(&acknowledgement)
suite.NoError(err)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)
}

Expand Down Expand Up @@ -3144,8 +3143,7 @@ func (suite *KeeperTestSuite) TestReceiveAckErrForBankSend() {

ackBytes := []byte("{\"error\":\"ABCI code: 32: error handling packet on host chain: see events for details\"}")

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), suite.path.EndpointA.ConnectionID))
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

newRecord, found := quicksilver.InterchainstakingKeeper.GetWithdrawalRecord(ctx, zone.ChainId, "7C8B95EEE82CB63771E02EBEB05E6A80076D70B2E0A1C457F1FD1A0EF2EA961D", types.WithdrawStatusUnbond)
Expand Down Expand Up @@ -4060,17 +4058,14 @@ func (suite *KeeperTestSuite) TestGetValidatorForToken() {

quicksilver := suite.GetQuicksilverApp(suite.chainA)
ctx := suite.chainA.GetContext()
if test.setupConnection {
ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), suite.path.EndpointA.ConnectionID))
}

zone, found := quicksilver.InterchainstakingKeeper.GetZone(ctx, suite.chainB.ChainID)

if !found {
suite.Fail("unable to retrieve zone for test")
}
amount := test.amount(ctx, quicksilver, zone)
resVal, err := quicksilver.InterchainstakingKeeper.GetValidatorForToken(ctx, amount)
resVal, err := quicksilver.InterchainstakingKeeper.GetValidatorForToken(ctx, amount, suite.path.EndpointA.ConnectionID)

if test.err {
suite.Error(err)
Expand Down Expand Up @@ -4178,7 +4173,7 @@ func (suite *KeeperTestSuite) TestHandleCompleteSend() {

msg := tc.message(&zone)

err := quicksilver.InterchainstakingKeeper.HandleCompleteSend(ctx, msg, tc.memo)
err := quicksilver.InterchainstakingKeeper.HandleCompleteSend(ctx, msg, tc.memo, zone.ConnectionId)
if tc.expectedError != nil {
suite.Equal(tc.expectedError, err)
} else {
Expand Down Expand Up @@ -4374,7 +4369,7 @@ func (suite *KeeperTestSuite) TestHandleFailedBankSend() {
// set address for zone mapping
quicksilver.InterchainstakingKeeper.SetAddressZoneMapping(ctx, user, zone.ChainId)
msg := test.message(&zone)
err := quicksilver.InterchainstakingKeeper.HandleFailedBankSend(ctx, msg, test.memo)
err := quicksilver.InterchainstakingKeeper.HandleFailedBankSend(ctx, msg, test.memo, zone.ConnectionId)

if test.err {
suite.Error(err)
Expand Down Expand Up @@ -4562,7 +4557,7 @@ func (suite *KeeperTestSuite) TestHandleRedeemTokens() {
FirstSeen: &t,
})

err := quicksilver.InterchainstakingKeeper.HandleRedeemTokens(ctx, msg, sdk.NewCoin(zone.BaseDenom, lsmMsg.Amount.Amount), txHash)
err := quicksilver.InterchainstakingKeeper.HandleRedeemTokens(ctx, msg, sdk.NewCoin(zone.BaseDenom, lsmMsg.Amount.Amount), txHash, zone.ConnectionId)
if test.errs[idx] {
suite.Error(err)
} else {
Expand Down
15 changes: 15 additions & 0 deletions x/interchainstaking/keeper/zones.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,21 @@ func (k *Keeper) GetZoneFromContext(ctx sdk.Context) (*types.Zone, error) {
return &zone, nil
}

// GetZoneFromConnectionID determines the zone from the connection ID
func (k *Keeper) GetZoneFromConnectionID(ctx sdk.Context, connectionID string) (*types.Zone, error) {
chainID, err := k.GetChainID(ctx, connectionID)
if err != nil {
return nil, fmt.Errorf("unable to fetch zone from context: %w", err)
}
zone, found := k.GetZone(ctx, chainID)
if !found {
err := fmt.Errorf("unable to fetch zone from context: not found for chainID %s", chainID)
k.Logger(ctx).Error(err.Error())
return nil, err

Check warning on line 145 in x/interchainstaking/keeper/zones.go

View check run for this annotation

Codecov / codecov/patch

x/interchainstaking/keeper/zones.go#L143-L145

Added lines #L143 - L145 were not covered by tests
}
return &zone, nil
}

func (k *Keeper) GetZoneForAccount(ctx sdk.Context, address string) (*types.Zone, bool) {
chainID, found := k.GetAddressZoneMapping(ctx, address)
if !found {
Expand Down

0 comments on commit a1289ba

Please sign in to comment.