Skip to content

Commit

Permalink
Refactor ChanUpgradeInit to use new upgrade type (#3456)
Browse files Browse the repository at this point in the history
  • Loading branch information
chatton authored Apr 19, 2023
1 parent e62634d commit 81a709b
Show file tree
Hide file tree
Showing 21 changed files with 341 additions and 514 deletions.
11 changes: 7 additions & 4 deletions e2e/tests/core/04-channel/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,21 @@ func (s *ChannelUpgradeTestSuite) TestChannelUpgrade() {

rlyWallet := s.CreateUserOnChainA(ctx, testvalues.StartingTokenAmount)

counterParty := channeltypes.NewCounterparty(channelA.Counterparty.PortID, channelA.Counterparty.ChannelID)
proposedUpgradeChannel := channeltypes.NewChannel(channeltypes.INITUPGRADE, channeltypes.UNORDERED, counterParty, channelA.ConnectionHops, `{"fee_version":"ics29-1","app_version":"ics20-1"}`)

t.Run("channel upgrade init", func(t *testing.T) {
upgradeTimeout := channeltypes.NewUpgradeTimeout(clienttypes.NewHeight(0, 10000), 0)
upgradeFields := channeltypes.NewUpgradeFields(channeltypes.UNORDERED, channelA.ConnectionHops, `{"fee_version":"ics29-1","app_version":"ics20-1"}`)
msgChanUpgradeInit := channeltypes.NewMsgChannelUpgradeInit(
channelA.PortID, channelA.ChannelID, proposedUpgradeChannel, clienttypes.NewHeight(0, 10000), 0, rlyWallet.FormattedAddress(),
channelA.PortID, channelA.ChannelID, upgradeFields, upgradeTimeout, rlyWallet.FormattedAddress(),
)

s.Require().NoError(msgChanUpgradeInit.ValidateBasic())

txResp, err := s.BroadcastMessages(ctx, chainA, rlyWallet, msgChanUpgradeInit)
s.Require().NoError(err)
s.AssertValidTxResponse(txResp)

channel, err := s.QueryChannel(ctx, chainA, channelA.PortID, channelA.ChannelID)
s.Require().NoError(err)
s.Require().Equal(channeltypes.INITUPGRADE, channel.State)
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (im IBCMiddleware) OnTimeoutPacket(
}

// OnChanUpgradeInit implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, counterparty channeltypes.Counterparty, version, previousVersion string) (string, error) {
func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, version, previousVersion string) (string, error) {
return icatypes.Version, nil
}

Expand Down
2 changes: 1 addition & 1 deletion modules/apps/27-interchain-accounts/host/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (im IBCModule) OnTimeoutPacket(
}

// OnChanUpgradeInit implements the IBCModule interface
func (im IBCModule) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, counterparty channeltypes.Counterparty, version, previousVersion string) (string, error) {
func (im IBCModule) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, version, previousVersion string) (string, error) {
return "", errorsmod.Wrap(icatypes.ErrInvalidChannelFlow, "channel handshake must be initiated by controller chain")
}

Expand Down
4 changes: 2 additions & 2 deletions modules/apps/29-fee/ibc_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ func (im IBCMiddleware) OnTimeoutPacket(
}

// OnChanUpgradeInit implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, counterparty channeltypes.Counterparty, version, previousVersion string) (string, error) {
return im.app.OnChanUpgradeInit(ctx, order, connectionHops, portID, channelID, sequence, counterparty, version, previousVersion)
func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, version, previousVersion string) (string, error) {
return im.app.OnChanUpgradeInit(ctx, order, connectionHops, portID, channelID, sequence, version, previousVersion)
}

// OnChanUpgradeTry implement s the IBCModule interface
Expand Down
2 changes: 1 addition & 1 deletion modules/apps/transfer/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func (im IBCModule) OnTimeoutPacket(
}

// OnChanUpgradeInit implements the IBCModule interface
func (im IBCModule) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, counterparty channeltypes.Counterparty, version, previousVersion string) (string, error) {
func (im IBCModule) OnChanUpgradeInit(ctx sdk.Context, order channeltypes.Order, connectionHops []string, portID, channelID string, sequence uint64, version, previousVersion string) (string, error) {
return types.Version, nil
}

Expand Down
13 changes: 7 additions & 6 deletions modules/core/04-channel/keeper/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,18 @@ func emitChannelClosedEvent(ctx sdk.Context, packet exported.PacketI, channel ty
}

// emitChannelUpgradeInitEvent emits a channel upgrade init event
func emitChannelUpgradeInitEvent(ctx sdk.Context, portID string, channelID string, upgradeSequence uint64, channel types.Channel) {
func emitChannelUpgradeInitEvent(ctx sdk.Context, portID string, channelID string, currentChannel types.Channel, upgrade types.Upgrade) {
ctx.EventManager().EmitEvents(sdk.Events{
sdk.NewEvent(
types.EventTypeChannelUpgradeInit,
sdk.NewAttribute(types.AttributeKeyPortID, portID),
sdk.NewAttribute(types.AttributeKeyChannelID, channelID),
sdk.NewAttribute(types.AttributeCounterpartyPortID, channel.Counterparty.PortId),
sdk.NewAttribute(types.AttributeCounterpartyChannelID, channel.Counterparty.ChannelId),
sdk.NewAttribute(types.AttributeKeyConnectionID, channel.ConnectionHops[0]),
sdk.NewAttribute(types.AttributeVersion, channel.Version),
sdk.NewAttribute(types.AttributeKeyUpgradeSequence, fmt.Sprintf("%d", upgradeSequence)),
sdk.NewAttribute(types.AttributeCounterpartyPortID, currentChannel.Counterparty.PortId),
sdk.NewAttribute(types.AttributeCounterpartyChannelID, currentChannel.Counterparty.ChannelId),
sdk.NewAttribute(types.AttributeKeyUpgradeConnectionHops, upgrade.Fields.ConnectionHops[0]),
sdk.NewAttribute(types.AttributeKeyUpgradeVersion, upgrade.Fields.Version),
sdk.NewAttribute(types.AttributeKeyUpgradeSequence, fmt.Sprintf("%d", currentChannel.UpgradeSequence)),
sdk.NewAttribute(types.AttributeKeyUpgradeOrdering, upgrade.Fields.Ordering.String()),
),
sdk.NewEvent(
sdk.EventTypeMessage,
Expand Down
2 changes: 1 addition & 1 deletion modules/core/04-channel/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ func (suite *KeeperTestSuite) TestUpgradeTimeoutAccessors() {
suite.coordinator.SetupConnections(path)
suite.coordinator.CreateChannels(path)

expUpgradeTimeout := types.UpgradeTimeout{TimeoutHeight: clienttypes.NewHeight(1, 10), TimeoutTimestamp: uint64(suite.coordinator.CurrentTime.UnixNano())}
expUpgradeTimeout := types.UpgradeTimeout{Height: clienttypes.NewHeight(1, 10), Timestamp: uint64(suite.coordinator.CurrentTime.UnixNano())}

suite.Run("set upgrade timeout", func() {
upgradeTimeout, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetUpgradeTimeout(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
Expand Down
122 changes: 28 additions & 94 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
package keeper

import (
"reflect"

errorsmod "cosmossdk.io/errors"
"github.com/cosmos/cosmos-sdk/telemetry"
sdk "github.com/cosmos/cosmos-sdk/types"
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"

clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
connectiontypes "github.com/cosmos/ibc-go/v7/modules/core/03-connection/types"
"github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
portkeeper "github.com/cosmos/ibc-go/v7/modules/core/05-port/keeper"
porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types"
host "github.com/cosmos/ibc-go/v7/modules/core/24-host"
)

// ChanUpgradeInit is called by a module to initiate a channel upgrade handshake with
Expand All @@ -22,115 +14,57 @@ func (k Keeper) ChanUpgradeInit(
ctx sdk.Context,
portID string,
channelID string,
chanCap *capabilitytypes.Capability,
proposedUpgradeChannel types.Channel,
counterpartyTimeoutHeight clienttypes.Height,
counterpartyTimeoutTimestamp uint64,
) (upgradeSequence uint64, previousVersion string, err error) {
upgradeFields types.UpgradeFields,
upgradeTimeout types.UpgradeTimeout,
) (types.Upgrade, error) {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return 0, "", errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
return types.Upgrade{}, errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

if channel.State != types.OPEN {
return 0, "", errorsmod.Wrapf(types.ErrInvalidChannelState, "expected %s, got %s", types.OPEN, channel.State)
}

if !k.scopedKeeper.AuthenticateCapability(ctx, chanCap, host.ChannelCapabilityPath(portID, channelID)) {
return 0, "", errorsmod.Wrapf(types.ErrChannelCapabilityNotFound, "caller does not own capability for channel, port ID (%s) channel ID (%s)", portID, channelID)
return types.Upgrade{}, errorsmod.Wrapf(types.ErrInvalidChannelState, "expected %s, got %s", types.OPEN, channel.State)
}

// set the restore channel to the current channel and reassign channel state to INITUPGRADE,
// if the channel == proposedUpgradeChannel then fail fast as no upgradable fields have been modified.
restoreChannel := channel
channel.State = types.INITUPGRADE
if reflect.DeepEqual(channel, proposedUpgradeChannel) {
return 0, "", errorsmod.Wrap(types.ErrChannelExists, "existing channel end is identical to proposed upgrade channel end")
if err := k.ValidateUpgradeFields(ctx, upgradeFields, channel); err != nil {
return types.Upgrade{}, err
}

connectionEnd, err := k.GetConnection(ctx, proposedUpgradeChannel.ConnectionHops[0])
proposedUpgrade, err := k.constructProposedUpgrade(ctx, portID, channelID, upgradeFields, upgradeTimeout)
if err != nil {
return 0, "", err
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return 0, "", errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
}

if proposedUpgradeChannel.Counterparty.PortId != channel.Counterparty.PortId ||
proposedUpgradeChannel.Counterparty.ChannelId != channel.Counterparty.ChannelId {
return 0, "", errorsmod.Wrap(types.ErrInvalidCounterparty, "counterparty port ID and channel ID cannot be upgraded")
return types.Upgrade{}, errorsmod.Wrap(err, "failed to construct proposed upgrade")
}

if !proposedUpgradeChannel.Ordering.SubsetOf(channel.Ordering) {
return 0, "", errorsmod.Wrap(types.ErrInvalidChannelOrdering, "channel ordering must be a subset of the new ordering")
}

upgradeSequence = uint64(1)
if seq, found := k.GetUpgradeSequence(ctx, portID, channelID); found {
upgradeSequence = seq + 1
}

upgradeTimeout := types.UpgradeTimeout{
TimeoutHeight: counterpartyTimeoutHeight,
TimeoutTimestamp: counterpartyTimeoutTimestamp,
}

k.SetUpgradeRestoreChannel(ctx, portID, channelID, restoreChannel)
k.SetUpgradeSequence(ctx, portID, channelID, upgradeSequence)
k.SetUpgradeTimeout(ctx, portID, channelID, upgradeTimeout)
channel.UpgradeSequence++
k.SetChannel(ctx, portID, channelID, channel)

return upgradeSequence, channel.Version, nil
return proposedUpgrade, nil
}

// WriteUpgradeInitChannel writes a channel which has successfully passed the UpgradeInit handshake step.
// An event is emitted for the handshake step.
func (k Keeper) WriteUpgradeInitChannel(
ctx sdk.Context,
portID,
channelID string,
upgradeSequence uint64,
channelUpgrade types.Channel,
) {
func (k Keeper) WriteUpgradeInitChannel(ctx sdk.Context, portID, channelID string, currentChannel types.Channel, upgrade types.Upgrade) {
defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-init")

k.SetChannel(ctx, portID, channelID, channelUpgrade)
k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", types.OPEN.String(), "new-state", types.INITUPGRADE.String())

emitChannelUpgradeInitEvent(ctx, portID, channelID, upgradeSequence, channelUpgrade)
}

// RestoreChannel restores the given channel to the state prior to upgrade.
func (k Keeper) RestoreChannel(ctx sdk.Context, portID, channelID string, upgradeSequence uint64, err error) error {
errorReceipt := types.NewErrorReceipt(upgradeSequence, err)
k.SetUpgradeErrorReceipt(ctx, portID, channelID, errorReceipt)
currentChannel.State = types.INITUPGRADE

channel, found := k.GetUpgradeRestoreChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "channel-id: %s", channelID)
}
k.SetChannel(ctx, portID, channelID, currentChannel)
k.SetUpgrade(ctx, portID, channelID, upgrade)

k.SetChannel(ctx, portID, channelID, channel)
k.DeleteUpgradeRestoreChannel(ctx, portID, channelID)
k.DeleteUpgradeTimeout(ctx, portID, channelID)

module, _, err := k.LookupModuleByChannel(ctx, portID, channelID)
if err != nil {
return errorsmod.Wrap(err, "could not retrieve module from port-id")
}
k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", types.OPEN.String(), "new-state", types.INITUPGRADE.String())

portKeeper, ok := k.portKeeper.(*portkeeper.Keeper)
if !ok {
panic("todo: handle this situation")
}
emitChannelUpgradeInitEvent(ctx, portID, channelID, currentChannel, upgrade)
}

cbs, found := portKeeper.Router.GetRoute(module)
// constructProposedUpgrade returns the proposed upgrade from the provided arguments.
func (k Keeper) constructProposedUpgrade(ctx sdk.Context, portID, channelID string, fields types.UpgradeFields, timeout types.UpgradeTimeout) (types.Upgrade, error) {
seq, found := k.GetNextSequenceSend(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
return types.Upgrade{}, types.ErrSequenceSendNotFound
}

return cbs.OnChanUpgradeRestore(ctx, portID, channelID)
return types.Upgrade{
Fields: fields,
Timeout: timeout,
LatestSequenceSend: seq - 1,
}, nil
}
Loading

0 comments on commit 81a709b

Please sign in to comment.