Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

imp: rm app upgrade interface from IBCModule and use type assertions for app callback routing #5375

Merged
merged 8 commits into from
Dec 12, 2023
4 changes: 2 additions & 2 deletions e2e/tests/wasm/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
"testing"
"time"

testifysuite "github.com/stretchr/testify/suite"

"github.com/strangelove-ventures/interchaintest/v8/chain/cosmos"
"github.com/strangelove-ventures/interchaintest/v8/ibc"
"github.com/strangelove-ventures/interchaintest/v8/testutil"
testifysuite "github.com/stretchr/testify/suite"

upgradetypes "cosmossdk.io/x/upgrade/types"

authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"

"github.com/cosmos/ibc-go/e2e/testsuite"
"github.com/cosmos/ibc-go/e2e/testvalues"
wasmtypes "github.com/cosmos/ibc-go/modules/light-clients/08-wasm/types"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ func (im IBCMiddleware) OnTimeoutPacket(

// OnChanUpgradeInit implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) (string, error) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return "", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
damiannolan marked this conversation as resolved.
Show resolved Hide resolved
}

if !im.keeper.GetParams(ctx).ControllerEnabled {
return "", types.ErrControllerSubModuleDisabled
}
Expand All @@ -248,7 +253,7 @@ func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID str
}

if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, connectionID) {
return im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, version)
return cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, version)
}

return version, nil
Expand Down
43 changes: 34 additions & 9 deletions modules/apps/29-fee/ibc_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,23 @@ func (im IBCMiddleware) OnChanUpgradeInit(
connectionHops []string,
upgradeVersion string,
) (string, error) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return "", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
}

versionMetadata, err := types.MetadataFromVersion(upgradeVersion)
if err != nil {
// since it is valid for fee version to not be specified, the upgrade version may be for a middleware
// or application further down in the stack. Thus, passthrough to next middleware or application in callstack.
return im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, upgradeVersion)
return cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, upgradeVersion)
}

if versionMetadata.FeeVersion != types.Version {
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, versionMetadata.FeeVersion)
}

appVersion, err := im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
appVersion, err := cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
if err != nil {
return "", err
}
Expand All @@ -362,18 +367,23 @@ func (im IBCMiddleware) OnChanUpgradeInit(

// OnChanUpgradeTry implement s the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string) (string, error) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return "", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
}

versionMetadata, err := types.MetadataFromVersion(counterpartyVersion)
if err != nil {
// since it is valid for fee version to not be specified, the counterparty upgrade version may be for a middleware
// or application further down in the stack. Thus, passthrough to next middleware or application in callstack.
return im.app.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, counterpartyVersion)
return cbs.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, counterpartyVersion)
}

if versionMetadata.FeeVersion != types.Version {
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, versionMetadata.FeeVersion)
}

appVersion, err := im.app.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
appVersion, err := cbs.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
if err != nil {
return "", err
}
Expand All @@ -389,40 +399,55 @@ func (im IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID stri

// OnChanUpgradeAck implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeAck(ctx sdk.Context, portID, channelID, counterpartyVersion string) error {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
}

versionMetadata, err := types.MetadataFromVersion(counterpartyVersion)
if err != nil {
// since it is valid for fee version to not be specified, the counterparty upgrade version may be for a middleware
// or application further down in the stack. Thus, passthrough to next middleware or application in callstack.
return im.app.OnChanUpgradeAck(ctx, portID, channelID, counterpartyVersion)
return cbs.OnChanUpgradeAck(ctx, portID, channelID, counterpartyVersion)
}

if versionMetadata.FeeVersion != types.Version {
return errorsmod.Wrapf(types.ErrInvalidVersion, "expected counterparty fee version: %s, got: %s", types.Version, versionMetadata.FeeVersion)
}

// call underlying app's OnChanUpgradeAck callback with the counterparty app version.
return im.app.OnChanUpgradeAck(ctx, portID, channelID, versionMetadata.AppVersion)
return cbs.OnChanUpgradeAck(ctx, portID, channelID, versionMetadata.AppVersion)
}

// OnChanUpgradeOpen implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeOpen(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
panic(errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack"))
}

// discard the version metadata returned as upgrade fields have already been validated in previous handshake steps.
_, err := types.MetadataFromVersion(version)
if err != nil {
// set fee disabled and passthrough to the next middleware or application in callstack.
im.keeper.DeleteFeeEnabled(ctx, portID, channelID)
im.app.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
cbs.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
return
}

// set fee enabled and passthrough to the next middleware of application in callstack.
im.keeper.SetFeeEnabled(ctx, portID, channelID)
im.app.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
cbs.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
}

// OnChanUpgradeRestore implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeRestore(ctx sdk.Context, portID, channelID string) {
im.app.OnChanUpgradeRestore(ctx, portID, channelID)
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
panic(errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack"))
}

cbs.OnChanUpgradeRestore(ctx, portID, channelID)
}

// SendPacket implements the ICS4 Wrapper interface
Expand Down
10 changes: 8 additions & 2 deletions modules/apps/29-fee/ibc_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,10 @@ func (suite *FeeTestSuite) TestOnChanUpgradeAck() {
module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort)
suite.Require().NoError(err)

cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

err = cbs.OnChanUpgradeAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, counterpartyUpgrade.Fields.Version)
Expand Down Expand Up @@ -1403,7 +1406,10 @@ func (suite *FeeTestSuite) TestOnChanUpgradeOpen() {
module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort)
suite.Require().NoError(err)

cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

upgrade := path.EndpointA.GetChannelUpgrade()
Expand Down
11 changes: 9 additions & 2 deletions modules/apps/transfer/ibc_module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"
connectiontypes "github.com/cosmos/ibc-go/v8/modules/core/03-connection/types"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types"
host "github.com/cosmos/ibc-go/v8/modules/core/24-host"
ibctesting "github.com/cosmos/ibc-go/v8/testing"
)
Expand Down Expand Up @@ -371,7 +372,10 @@ func (suite *TransferTestSuite) TestOnChanUpgradeTry() {
module, _, err := suite.chainB.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainB.GetContext(), types.PortID)
suite.Require().NoError(err)

cbs, ok := suite.chainB.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainB.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

version, err := cbs.OnChanUpgradeTry(
Expand Down Expand Up @@ -439,7 +443,10 @@ func (suite *TransferTestSuite) TestOnChanUpgradeAck() {
module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), types.PortID)
suite.Require().NoError(err)

cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

err = cbs.OnChanUpgradeAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.Version)
Expand Down
1 change: 0 additions & 1 deletion modules/core/05-port/types/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
// IBCModule defines an interface that implements all the callbacks
// that modules must define as specified in ICS-26
type IBCModule interface {
UpgradableModule
// OnChanOpenInit will verify that the relayer-chosen parameters
// are valid and perform any custom INIT logic.
// It may return an error if the chosen parameters are invalid
Expand Down
56 changes: 49 additions & 7 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,12 +748,18 @@ func (k Keeper) ChannelUpgradeInit(goCtx context.Context, msg *channeltypes.MsgC
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade init failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade init failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

upgrade, err := k.ChannelKeeper.ChanUpgradeInit(ctx, msg.PortId, msg.ChannelId, msg.Fields)
if err != nil {
ctx.Logger().Error("channel upgrade init failed", "error", errorsmod.Wrap(err, "channel upgrade init failed"))
Expand Down Expand Up @@ -794,12 +800,18 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
DimitrisJim marked this conversation as resolved.
Show resolved Hide resolved
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

channel, upgrade, err := k.ChannelKeeper.ChanUpgradeTry(ctx, msg.PortId, msg.ChannelId, msg.ProposedUpgradeConnectionHops, msg.CounterpartyUpgradeFields, msg.CounterpartyUpgradeSequence, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed"))
Expand Down Expand Up @@ -844,12 +856,18 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade ack failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade ack failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

err = k.ChannelKeeper.ChanUpgradeAck(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed"))
Expand Down Expand Up @@ -896,12 +914,18 @@ func (k Keeper) ChannelUpgradeConfirm(goCtx context.Context, msg *channeltypes.M
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade confirm failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade confirm failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

err = k.ChannelKeeper.ChanUpgradeConfirm(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannelState, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade confirm failed", "error", errorsmod.Wrap(err, "channel upgrade confirm failed"))
Expand Down Expand Up @@ -950,12 +974,18 @@ func (k Keeper) ChannelUpgradeOpen(goCtx context.Context, msg *channeltypes.MsgC
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

if err = k.ChannelKeeper.ChanUpgradeOpen(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannelState, msg.ProofChannel, msg.ProofHeight); err != nil {
ctx.Logger().Error("channel upgrade open failed", "error", errorsmod.Wrap(err, "channel upgrade open failed"))
return nil, errorsmod.Wrap(err, "channel upgrade open failed")
Expand Down Expand Up @@ -985,12 +1015,18 @@ func (k Keeper) ChannelUpgradeTimeout(goCtx context.Context, msg *channeltypes.M
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade timeout failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade timeout failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

err = k.ChannelKeeper.ChanUpgradeTimeout(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannel, msg.ProofChannel, msg.ProofHeight)
if err != nil {
return nil, errorsmod.Wrapf(err, "could not timeout upgrade for channel: %s", msg.ChannelId)
Expand Down Expand Up @@ -1018,12 +1054,18 @@ func (k Keeper) ChannelUpgradeCancel(goCtx context.Context, msg *channeltypes.Ms
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

isAuthority := k.GetAuthority() == msg.Signer
if err := k.ChannelKeeper.ChanUpgradeCancel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt, msg.ProofErrorReceipt, msg.ProofHeight, isAuthority); err != nil {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", err.Error())
Expand Down
Loading