Skip to content

Commit

Permalink
refactor: remove storage of chain/connection in context for ibc callb…
Browse files Browse the repository at this point in the history
…acks (#1209)

* refactor: remove storage of chain/connection in context for ibc callbacks

* fix error message

* refactor: remove unused code

---------

Co-authored-by: Joe Bowman <joe@ingenuity.build>
minhngoc274 and Joe Bowman authored Mar 1, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 2c92595 commit 7966553
Showing 6 changed files with 42 additions and 113 deletions.
6 changes: 1 addition & 5 deletions x/interchainstaking/ibc_module.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package interchainstaking

import (
"context"
"errors"
"fmt"

@@ -13,7 +12,6 @@ import (
host "github.com/cosmos/ibc-go/v5/modules/core/24-host"
ibcexported "github.com/cosmos/ibc-go/v5/modules/core/exported"

"github.com/quicksilver-zone/quicksilver/utils"
"github.com/quicksilver-zone/quicksilver/x/interchainstaking/keeper"
)

@@ -126,9 +124,7 @@ func (im IBCModule) OnAcknowledgementPacket(
ctx.Logger().Error(err.Error())
return err
}
ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), connectionID))

err = im.keeper.HandleAcknowledgement(ctx, packet, acknowledgement)
err = im.keeper.HandleAcknowledgement(ctx, packet, acknowledgement, connectionID)
if err != nil {
im.keeper.Logger(ctx).Error("CALLBACK ERROR:", "error", err.Error())
}
30 changes: 15 additions & 15 deletions x/interchainstaking/keeper/ibc_packet_handlers.go
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ func DeserializeCosmosTxTyped(cdc codec.BinaryCodec, data []byte) ([]TypedMsg, e
return msgs, nil
}

func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Packet, acknowledgement []byte) error {
func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Packet, acknowledgement []byte, connectionID string) error {
var (
ack channeltypes.Acknowledgement
success bool
@@ -134,7 +134,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack
return nil
}
k.Logger(ctx).Info("Rewards withdrawn")
if err := k.HandleWithdrawRewards(ctx, msg.Msg); err != nil {
if err := k.HandleWithdrawRewards(ctx, msg.Msg, connectionID); err != nil {
return err
}
continue
@@ -155,7 +155,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack

k.Logger(ctx).Info("Tokens redeemed for shares", "response", response)
// we should update delegation records here.
if err := k.HandleRedeemTokens(ctx, msg.Msg, response.Amount, packetData.Memo); err != nil {
if err := k.HandleRedeemTokens(ctx, msg.Msg, response.Amount, packetData.Memo, connectionID); err != nil {
return err
}
continue
@@ -239,7 +239,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack

case "/cosmos.bank.v1beta1.MsgSend":
if !success {
if err := k.HandleFailedBankSend(ctx, msg.Msg, packetData.Memo); err != nil {
if err := k.HandleFailedBankSend(ctx, msg.Msg, packetData.Memo, connectionID); err != nil {
k.Logger(ctx).Error("unable to handle failed MsgSend", "error", err)
return err
}
@@ -254,7 +254,7 @@ func (k *Keeper) HandleAcknowledgement(ctx sdk.Context, packet channeltypes.Pack

k.Logger(ctx).Info("Funds Transferred", "response", response)
// check tokenTransfers - if end user unescrow and burn txs
if err := k.HandleCompleteSend(ctx, msg.Msg, packetData.Memo); err != nil {
if err := k.HandleCompleteSend(ctx, msg.Msg, packetData.Memo, connectionID); err != nil {
return err
}
case "/cosmos.distribution.v1beta1.MsgSetWithdrawAddress":
@@ -332,7 +332,7 @@ func (k *Keeper) HandleMsgTransfer(ctx sdk.Context, msg ibctransfertypes.Fungibl
return k.BankKeeper.SendCoinsFromModuleToModule(ctx, types.ModuleName, authtypes.FeeCollectorName, balance)
}

func (k *Keeper) HandleCompleteSend(ctx sdk.Context, msg sdk.Msg, memo string) error {
func (k *Keeper) HandleCompleteSend(ctx sdk.Context, msg sdk.Msg, memo string, connectionID string) error {
k.Logger(ctx).Info("Received MsgSend acknowledgement")
// first, type assertion. we should have banktypes.MsgSend
sMsg, ok := msg.(*banktypes.MsgSend)
@@ -343,7 +343,7 @@ func (k *Keeper) HandleCompleteSend(ctx sdk.Context, msg sdk.Msg, memo string) e
}

// get zone
zone, err := k.GetZoneFromContext(ctx)
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
err = fmt.Errorf("2: %w", err)
k.Logger(ctx).Error(err.Error())
@@ -768,7 +768,7 @@ func (k *Keeper) HandleUndelegate(ctx sdk.Context, msg sdk.Msg, completion time.
return nil
}

func (k *Keeper) HandleFailedBankSend(ctx sdk.Context, msg sdk.Msg, memo string) error {
func (k *Keeper) HandleFailedBankSend(ctx sdk.Context, msg sdk.Msg, memo string, connectionID string) error {
sMsg, ok := msg.(*banktypes.MsgSend)
if !ok {
err := errors.New("unable to cast source message to MsgSend")
@@ -777,7 +777,7 @@ func (k *Keeper) HandleFailedBankSend(ctx sdk.Context, msg sdk.Msg, memo string)
}

// get zone
zone, err := k.GetZoneFromContext(ctx)
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
k.Logger(ctx).Error(err.Error())
return err
@@ -907,15 +907,15 @@ func (k *Keeper) HandleFailedUndelegate(ctx sdk.Context, msg sdk.Msg, memo strin
return nil
}

func (k *Keeper) HandleRedeemTokens(ctx sdk.Context, msg sdk.Msg, amount sdk.Coin, memo string) error {
func (k *Keeper) HandleRedeemTokens(ctx sdk.Context, msg sdk.Msg, amount sdk.Coin, memo string, connectionID string) error {
k.Logger(ctx).Info("Received MsgRedeemTokensforShares acknowledgement")
// first, type assertion. we should have stakingtypes.MsgRedeemTokensforShares
redeemMsg, ok := msg.(*lsmstakingtypes.MsgRedeemTokensForShares)
if !ok {
k.Logger(ctx).Error("unable to cast source message to MsgRedeemTokensforShares")
return errors.New("unable to cast source message to MsgRedeemTokensforShares")
}
validatorAddress, err := k.GetValidatorForToken(ctx, redeemMsg.Amount)
validatorAddress, err := k.GetValidatorForToken(ctx, redeemMsg.Amount, connectionID)
if err != nil {
return err
}
@@ -1119,8 +1119,8 @@ func (k *Keeper) HandleUpdatedWithdrawAddress(ctx sdk.Context, msg sdk.Msg) erro
return nil
}

func (k *Keeper) GetValidatorForToken(ctx sdk.Context, amount sdk.Coin) (string, error) {
zone, err := k.GetZoneFromContext(ctx)
func (k *Keeper) GetValidatorForToken(ctx sdk.Context, amount sdk.Coin, connectionID string) (string, error) {
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
err = fmt.Errorf("3: %w", err)
k.Logger(ctx).Error(err.Error())
@@ -1291,14 +1291,14 @@ func (k *Keeper) UpdateDelegationRecordForAddress(
return nil
}

func (k *Keeper) HandleWithdrawRewards(ctx sdk.Context, msg sdk.Msg) error {
func (k *Keeper) HandleWithdrawRewards(ctx sdk.Context, msg sdk.Msg, connectionID string) error {
withdrawalMsg, ok := msg.(*distrtypes.MsgWithdrawDelegatorReward)
if !ok {
k.Logger(ctx).Error("unable to cast source message to MsgWithdrawDelegatorReward")
return errors.New("unable to cast source message to MsgWithdrawDelegatorReward")
}

zone, err := k.GetZoneFromContext(ctx)
zone, err := k.GetZoneFromConnectionID(ctx, connectionID)
if err != nil {
err = fmt.Errorf("4: %w", err)
k.Logger(ctx).Error(err.Error())
47 changes: 21 additions & 26 deletions x/interchainstaking/keeper/ibc_packet_handlers_test.go
Original file line number Diff line number Diff line change
@@ -909,8 +909,7 @@ func (suite *KeeperTestSuite) TestHandleWithdrawRewards() {
}
}

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), zone.ConnectionId))
err := quicksilver.InterchainstakingKeeper.HandleWithdrawRewards(ctx, test.msg(&zone))
err := quicksilver.InterchainstakingKeeper.HandleWithdrawRewards(ctx, test.msg(&zone), zone.ConnectionId)
if test.err {
suite.Error(err)
} else {
@@ -1122,7 +1121,7 @@ func (suite *KeeperTestSuite) TestReceiveAckErrForBeginRedelegate() {
_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
suite.True(found)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
@@ -1493,7 +1492,7 @@ func (suite *KeeperTestSuite) TestReceiveAckErrForBeginUndelegate() {
suite.True(found)
}

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

for _, ubr := range test.unbondingRecords(ctx, quicksilver, zone) {
@@ -1892,7 +1891,9 @@ func (suite *KeeperTestSuite) Test_v045Callback() {
Data: packetBytes,
}
ctx = suite.chainA.GetContext()
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement)))
zone, found := quicksilver.InterchainstakingKeeper.GetZone(ctx, suite.chainB.ChainID)
suite.True(found)
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement), zone.ConnectionId))

suite.True(test.assertStatements(ctx, quicksilver))
})
@@ -2000,8 +2001,7 @@ func (suite *KeeperTestSuite) Test_v046Callback() {
Data: packetBytes,
}

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), "connection-0"))
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement)))
suite.NoError(quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, icatypes.ModuleCdc.MustMarshalJSON(&acknowledgement), "connection-0"))

suite.True(test.assertStatements(ctx, quicksilver))
})
@@ -2462,7 +2462,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginUndelegate() {
suite.True(found)
}

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

for idx, ewdr := range test.expectedWithdrawalRecords(ctx, quicksilver, zone) {
@@ -2547,7 +2547,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginRedelegateNonNilCompletion()
_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
suite.True(found)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

afterRedelegation, found := quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, 1)
@@ -2641,7 +2641,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginRedelegateNilCompletion() {
_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, epoch)
suite.True(found)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

_, found = quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, epoch)
@@ -2720,7 +2720,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBeginRedelegateNoExistingRecord()

// call handler

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

createdRecord, found := quicksilver.InterchainstakingKeeper.GetRedelegationRecord(ctx, zone.ChainId, validators[0].ValoperAddress, validators[1].ValoperAddress, epoch)
@@ -2797,8 +2797,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForWithdrawReward() {
}
}

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), zone.ConnectionId))
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

allBalancesQueryCnt := 0
@@ -2874,7 +2873,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForRedeemTokens() {
suite.NoError(err)

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), suite.path.EndpointA.ConnectionID))
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

delegationRecord, found = quicksilver.InterchainstakingKeeper.GetDelegation(ctx, zone.ChainId, zone.DelegationAddress.Address, vals[0])
@@ -2953,7 +2952,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForTokenizedShares() {
ackBytes, err := icatypes.ModuleCdc.MarshalJSON(&acknowledgement)
suite.NoError(err)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

_, found = quicksilver.InterchainstakingKeeper.GetWithdrawalRecord(ctx, zone.ChainId, txHash, types.WithdrawStatusTokenize)
@@ -3025,7 +3024,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForDelegate() {
ackBytes, err := icatypes.ModuleCdc.MarshalJSON(&acknowledgement)
suite.NoError(err)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

newCompleted := ctx.BlockTime()
@@ -3089,7 +3088,7 @@ func (suite *KeeperTestSuite) TestReceiveAckForBankSend() {
ackBytes, err := icatypes.ModuleCdc.MarshalJSON(&acknowledgement)
suite.NoError(err)

err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)
}

@@ -3144,8 +3143,7 @@ func (suite *KeeperTestSuite) TestReceiveAckErrForBankSend() {

ackBytes := []byte("{\"error\":\"ABCI code: 32: error handling packet on host chain: see events for details\"}")

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), suite.path.EndpointA.ConnectionID))
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes)
err = quicksilver.InterchainstakingKeeper.HandleAcknowledgement(ctx, packet, ackBytes, zone.ConnectionId)
suite.NoError(err)

newRecord, found := quicksilver.InterchainstakingKeeper.GetWithdrawalRecord(ctx, zone.ChainId, "7C8B95EEE82CB63771E02EBEB05E6A80076D70B2E0A1C457F1FD1A0EF2EA961D", types.WithdrawStatusUnbond)
@@ -4060,17 +4058,14 @@ func (suite *KeeperTestSuite) TestGetValidatorForToken() {

quicksilver := suite.GetQuicksilverApp(suite.chainA)
ctx := suite.chainA.GetContext()
if test.setupConnection {
ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), suite.path.EndpointA.ConnectionID))
}

zone, found := quicksilver.InterchainstakingKeeper.GetZone(ctx, suite.chainB.ChainID)

if !found {
suite.Fail("unable to retrieve zone for test")
}
amount := test.amount(ctx, quicksilver, zone)
resVal, err := quicksilver.InterchainstakingKeeper.GetValidatorForToken(ctx, amount)
resVal, err := quicksilver.InterchainstakingKeeper.GetValidatorForToken(ctx, amount, suite.path.EndpointA.ConnectionID)

if test.err {
suite.Error(err)
@@ -4178,7 +4173,7 @@ func (suite *KeeperTestSuite) TestHandleCompleteSend() {

msg := tc.message(&zone)

err := quicksilver.InterchainstakingKeeper.HandleCompleteSend(ctx, msg, tc.memo)
err := quicksilver.InterchainstakingKeeper.HandleCompleteSend(ctx, msg, tc.memo, zone.ConnectionId)
if tc.expectedError != nil {
suite.Equal(tc.expectedError, err)
} else {
@@ -4374,7 +4369,7 @@ func (suite *KeeperTestSuite) TestHandleFailedBankSend() {
// set address for zone mapping
quicksilver.InterchainstakingKeeper.SetAddressZoneMapping(ctx, user, zone.ChainId)
msg := test.message(&zone)
err := quicksilver.InterchainstakingKeeper.HandleFailedBankSend(ctx, msg, test.memo)
err := quicksilver.InterchainstakingKeeper.HandleFailedBankSend(ctx, msg, test.memo, zone.ConnectionId)

if test.err {
suite.Error(err)
@@ -4562,7 +4557,7 @@ func (suite *KeeperTestSuite) TestHandleRedeemTokens() {
FirstSeen: &t,
})

err := quicksilver.InterchainstakingKeeper.HandleRedeemTokens(ctx, msg, sdk.NewCoin(zone.BaseDenom, lsmMsg.Amount.Amount), txHash)
err := quicksilver.InterchainstakingKeeper.HandleRedeemTokens(ctx, msg, sdk.NewCoin(zone.BaseDenom, lsmMsg.Amount.Amount), txHash, zone.ConnectionId)
if test.errs[idx] {
suite.Error(err)
} else {
10 changes: 0 additions & 10 deletions x/interchainstaking/keeper/keeper.go
Original file line number Diff line number Diff line change
@@ -32,7 +32,6 @@ import (
ibckeeper "github.com/cosmos/ibc-go/v5/modules/core/keeper"
ibctmtypes "github.com/cosmos/ibc-go/v5/modules/light-clients/07-tendermint/types"

"github.com/quicksilver-zone/quicksilver/utils"
"github.com/quicksilver-zone/quicksilver/utils/addressutils"
epochskeeper "github.com/quicksilver-zone/quicksilver/x/epochs/keeper"
interchainquerykeeper "github.com/quicksilver-zone/quicksilver/x/interchainquery/keeper"
@@ -541,15 +540,6 @@ func (k *Keeper) GetChainID(ctx sdk.Context, connectionID string) (string, error
return client.ChainId, nil
}

func (k *Keeper) GetChainIDFromContext(ctx sdk.Context) (string, error) {
connectionID := ctx.Context().Value(utils.ContextKey("connectionID"))
if connectionID == nil {
return "", errors.New("connectionID not in context")
}

return k.GetChainID(ctx, connectionID.(string))
}

func (k *Keeper) EmitPerformanceBalanceQuery(ctx sdk.Context, zone *types.Zone) error {
_, addr, err := bech32.DecodeAndConvert(zone.PerformanceAddress.Address)
if err != nil {
52 changes: 0 additions & 52 deletions x/interchainstaking/keeper/keeper_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package keeper_test

import (
"context"
"errors"
"testing"
"time"

@@ -22,11 +20,9 @@ import (
ibctesting "github.com/cosmos/ibc-go/v5/testing"

"github.com/quicksilver-zone/quicksilver/app"
"github.com/quicksilver-zone/quicksilver/utils"
"github.com/quicksilver-zone/quicksilver/utils/addressutils"
"github.com/quicksilver-zone/quicksilver/utils/randomutils"
ics "github.com/quicksilver-zone/quicksilver/x/interchainstaking"
interchainstakingkeeper "github.com/quicksilver-zone/quicksilver/x/interchainstaking/keeper"
icstypes "github.com/quicksilver-zone/quicksilver/x/interchainstaking/types"
)

@@ -669,54 +665,6 @@ func (suite *KeeperTestSuite) TestOverrideRedemptionRateNoCap() {
suite.Equal(sdk.NewDecWithPrec(676666666666666667, 18), zone.RedemptionRate)
}

func (suite *KeeperTestSuite) TestGetChainIDFromContext() {
testCase := []struct {
name string
setup func() (*interchainstakingkeeper.Keeper, sdk.Context)
wantErr bool
expectedErr error
expectedChainID string
}{
{
name: "connectionID not in context",
setup: func() (*interchainstakingkeeper.Keeper, sdk.Context) {
suite.SetupTest()
suite.setupTestZones()
return suite.GetQuicksilverApp(suite.chainA).InterchainstakingKeeper, suite.chainA.GetContext()
},
wantErr: true,
expectedErr: errors.New("connectionID not in context"),
},
{
name: "get chainID success",
setup: func() (*interchainstakingkeeper.Keeper, sdk.Context) {
suite.SetupTest()
suite.setupTestZones()
ctx := suite.chainA.GetContext()

ctx = ctx.WithContext(context.WithValue(ctx.Context(), utils.ContextKey("connectionID"), suite.path.EndpointA.ConnectionID))
return suite.GetQuicksilverApp(suite.chainA).InterchainstakingKeeper, ctx
},
wantErr: false,
expectedErr: nil,
expectedChainID: "testchain2",
},
}
for _, tc := range testCase {
suite.Run(tc.name, func() {
keeper, ctx := tc.setup()

chainID, err := keeper.GetChainIDFromContext(ctx)
if tc.wantErr {
suite.Equal(tc.expectedErr, err)
return
}
suite.NoError(err)
suite.Equal(tc.expectedChainID, chainID)
})
}
}

func (suite *KeeperTestSuite) TestIteratePortConnection() {
suite.SetupTest()
suite.setupTestZones()
10 changes: 5 additions & 5 deletions x/interchainstaking/keeper/zones.go
Original file line number Diff line number Diff line change
@@ -117,15 +117,15 @@ func (k *Keeper) AllZones(ctx sdk.Context) []types.Zone {
return zones
}

// GetZoneFromContext determines the zone from the current context.
func (k *Keeper) GetZoneFromContext(ctx sdk.Context) (*types.Zone, error) {
chainID, err := k.GetChainIDFromContext(ctx)
// GetZoneFromConnectionID determines the zone from the connection ID
func (k *Keeper) GetZoneFromConnectionID(ctx sdk.Context, connectionID string) (*types.Zone, error) {
chainID, err := k.GetChainID(ctx, connectionID)
if err != nil {
return nil, fmt.Errorf("unable to fetch zone from context: %w", err)
return nil, fmt.Errorf("unable to fetch zone from connection id: %w", err)
}
zone, found := k.GetZone(ctx, chainID)
if !found {
err := fmt.Errorf("unable to fetch zone from context: not found for chainID %s", chainID)
err := fmt.Errorf("unable to fetch zone from connection id: not found for chainID %s", chainID)
k.Logger(ctx).Error(err.Error())
return nil, err
}

0 comments on commit 7966553

Please sign in to comment.