diff --git a/x/wasm/ibc.go b/x/wasm/ibc.go index 0d7599e3cd..2d00e7ee06 100644 --- a/x/wasm/ibc.go +++ b/x/wasm/ibc.go @@ -43,13 +43,18 @@ func (i IBCHandler) OnChanOpenInit( return sdkerrors.Wrapf(err, "contract port id") } - err = i.keeper.OnOpenChannel(ctx, contractAddr, wasmvmtypes.IBCChannel{ - Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID}, - CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId}, - Order: order.String(), - Version: version, - ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported. - }, "") + msg := wasmvmtypes.IBCChannelOpenMsg{ + OpenInit: &wasmvmtypes.IBCOpenInit{ + Channel: wasmvmtypes.IBCChannel{ + Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID}, + CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId}, + Order: order.String(), + Version: version, + ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported. + }, + }, + } + err = i.keeper.OnOpenChannel(ctx, contractAddr, msg) if err != nil { return err } @@ -80,13 +85,20 @@ func (i IBCHandler) OnChanOpenTry( return sdkerrors.Wrapf(err, "contract port id") } - err = i.keeper.OnOpenChannel(ctx, contractAddr, wasmvmtypes.IBCChannel{ - Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID}, - CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId}, - Order: order.String(), - Version: version, - ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported. - }, counterpartyVersion) + msg := wasmvmtypes.IBCChannelOpenMsg{ + OpenTry: &wasmvmtypes.IBCOpenTry{ + Channel: wasmvmtypes.IBCChannel{ + Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID}, + CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId}, + Order: order.String(), + Version: version, + ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported. + }, + CounterpartyVersion: counterpartyVersion, + }, + } + + err = i.keeper.OnOpenChannel(ctx, contractAddr, msg) if err != nil { return err } @@ -117,7 +129,13 @@ func (i IBCHandler) OnChanOpenAck( if !ok { return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } - return i.keeper.OnConnectChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), counterpartyVersion) + msg := wasmvmtypes.IBCChannelConnectMsg{ + OpenAck: &wasmvmtypes.IBCOpenAck{ + Channel: toWasmVMChannel(portID, channelID, channelInfo), + CounterpartyVersion: counterpartyVersion, + }, + } + return i.keeper.OnConnectChannel(ctx, contractAddr, msg) } // OnChanOpenConfirm implements the IBCModule interface @@ -130,7 +148,12 @@ func (i IBCHandler) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string) if !ok { return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } - return i.keeper.OnConnectChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), "") + msg := wasmvmtypes.IBCChannelConnectMsg{ + OpenConfirm: &wasmvmtypes.IBCOpenConfirm{ + Channel: toWasmVMChannel(portID, channelID, channelInfo), + }, + } + return i.keeper.OnConnectChannel(ctx, contractAddr, msg) } // OnChanCloseInit implements the IBCModule interface @@ -144,7 +167,10 @@ func (i IBCHandler) OnChanCloseInit(ctx sdk.Context, portID, channelID string) e return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } - err = i.keeper.OnCloseChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), false) + msg := wasmvmtypes.IBCChannelCloseMsg{ + CloseInit: &wasmvmtypes.IBCCloseInit{Channel: toWasmVMChannel(portID, channelID, channelInfo)}, + } + err = i.keeper.OnCloseChannel(ctx, contractAddr, msg) if err != nil { return err } @@ -165,7 +191,10 @@ func (i IBCHandler) OnChanCloseConfirm(ctx sdk.Context, portID, channelID string return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } - err = i.keeper.OnCloseChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), true) + msg := wasmvmtypes.IBCChannelCloseMsg{ + CloseConfirm: &wasmvmtypes.IBCCloseConfirm{Channel: toWasmVMChannel(portID, channelID, channelInfo)}, + } + err = i.keeper.OnCloseChannel(ctx, contractAddr, msg) if err != nil { return err } @@ -193,7 +222,8 @@ func (i IBCHandler) OnRecvPacket( if err != nil { return nil, nil, sdkerrors.Wrapf(err, "contract port id") } - ack, err := i.keeper.OnRecvPacket(ctx, contractAddr, newIBCPacket(packet)) + msg := wasmvmtypes.IBCPacketReceiveMsg{Packet: newIBCPacket(packet)} + ack, err := i.keeper.OnRecvPacket(ctx, contractAddr, msg) if err != nil { return nil, nil, err } @@ -230,7 +260,8 @@ func (i IBCHandler) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet) if err != nil { return nil, sdkerrors.Wrapf(err, "contract port id") } - err = i.keeper.OnTimeoutPacket(ctx, contractAddr, newIBCPacket(packet)) + msg := wasmvmtypes.IBCPacketTimeoutMsg{Packet: newIBCPacket(packet)} + err = i.keeper.OnTimeoutPacket(ctx, contractAddr, msg) if err != nil { return nil, err } diff --git a/x/wasm/keeper/relay.go b/x/wasm/keeper/relay.go index 19a0e6fab5..8d7f72c9f3 100644 --- a/x/wasm/keeper/relay.go +++ b/x/wasm/keeper/relay.go @@ -9,6 +9,8 @@ import ( "time" ) +var _ types.IBCContractKeeper = (*Keeper)(nil) + // OnOpenChannel calls the contract to participate in the IBC channel handshake step. // In the IBC protocol this is either the `Channel Open Init` event on the initiating chain or // `Channel Open Try` on the counterparty chain. @@ -17,9 +19,7 @@ import ( func (k Keeper) OnOpenChannel( ctx sdk.Context, contractAddr sdk.AccAddress, - channel wasmvmtypes.IBCChannel, - // this is unset on init, set on try - counterpartyVersion string, + msg wasmvmtypes.IBCChannelOpenMsg, ) error { defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-open-channel") @@ -31,18 +31,6 @@ func (k Keeper) OnOpenChannel( env := types.NewEnv(ctx, contractAddr) querier := k.newQueryHandler(ctx, contractAddr) - msg := wasmvmtypes.IBCChannelOpenMsg{} - if counterpartyVersion == "" { - msg.OpenInit = &wasmvmtypes.IBCOpenInit{ - Channel: channel, - } - } else { - msg.OpenTry = &wasmvmtypes.IBCOpenTry{ - Channel: channel, - CounterpartyVersion: counterpartyVersion, - } - } - gas := k.runtimeGasForContract(ctx) gasUsed, execErr := k.wasmVM.IBCChannelOpen(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization) k.consumeRuntimeGas(ctx, gasUsed) @@ -63,9 +51,7 @@ func (k Keeper) OnOpenChannel( func (k Keeper) OnConnectChannel( ctx sdk.Context, contractAddr sdk.AccAddress, - channel wasmvmtypes.IBCChannel, - // this is set on ack, unset on confirm - counterpartyVersion string, + msg wasmvmtypes.IBCChannelConnectMsg, ) error { defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-connect-channel") contractInfo, codeInfo, prefixStore, err := k.contractInstance(ctx, contractAddr) @@ -76,18 +62,6 @@ func (k Keeper) OnConnectChannel( env := types.NewEnv(ctx, contractAddr) querier := k.newQueryHandler(ctx, contractAddr) - msg := wasmvmtypes.IBCChannelConnectMsg{} - if counterpartyVersion == "" { - msg.OpenConfirm = &wasmvmtypes.IBCOpenConfirm{ - Channel: channel, - } - } else { - msg.OpenAck = &wasmvmtypes.IBCOpenAck{ - Channel: channel, - CounterpartyVersion: counterpartyVersion, - } - } - gas := k.runtimeGasForContract(ctx) res, gasUsed, execErr := k.wasmVM.IBCChannelConnect(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization) k.consumeRuntimeGas(ctx, gasUsed) @@ -107,9 +81,7 @@ func (k Keeper) OnConnectChannel( func (k Keeper) OnCloseChannel( ctx sdk.Context, contractAddr sdk.AccAddress, - channel wasmvmtypes.IBCChannel, - // false for init, true for confirm - confirm bool, + msg wasmvmtypes.IBCChannelCloseMsg, ) error { defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-close-channel") @@ -121,17 +93,6 @@ func (k Keeper) OnCloseChannel( params := types.NewEnv(ctx, contractAddr) querier := k.newQueryHandler(ctx, contractAddr) - msg := wasmvmtypes.IBCChannelCloseMsg{} - if confirm { - msg.CloseConfirm = &wasmvmtypes.IBCCloseConfirm{ - Channel: channel, - } - } else { - msg.CloseInit = &wasmvmtypes.IBCCloseInit{ - Channel: channel, - } - } - gas := k.runtimeGasForContract(ctx) res, gasUsed, execErr := k.wasmVM.IBCChannelClose(codeInfo.CodeHash, params, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization) k.consumeRuntimeGas(ctx, gasUsed) @@ -151,7 +112,7 @@ func (k Keeper) OnCloseChannel( func (k Keeper) OnRecvPacket( ctx sdk.Context, contractAddr sdk.AccAddress, - packet wasmvmtypes.IBCPacket, + msg wasmvmtypes.IBCPacketReceiveMsg, ) ([]byte, error) { defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-recv-packet") contractInfo, codeInfo, prefixStore, err := k.contractInstance(ctx, contractAddr) @@ -161,7 +122,6 @@ func (k Keeper) OnRecvPacket( env := types.NewEnv(ctx, contractAddr) querier := k.newQueryHandler(ctx, contractAddr) - msg := wasmvmtypes.IBCPacketReceiveMsg{Packet: packet} gas := k.runtimeGasForContract(ctx) res, gasUsed, execErr := k.wasmVM.IBCPacketReceive(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization) @@ -209,7 +169,7 @@ func (k Keeper) OnAckPacket( func (k Keeper) OnTimeoutPacket( ctx sdk.Context, contractAddr sdk.AccAddress, - packet wasmvmtypes.IBCPacket, + msg wasmvmtypes.IBCPacketTimeoutMsg, ) error { defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-timeout-packet") @@ -220,7 +180,6 @@ func (k Keeper) OnTimeoutPacket( env := types.NewEnv(ctx, contractAddr) querier := k.newQueryHandler(ctx, contractAddr) - msg := wasmvmtypes.IBCPacketTimeoutMsg{Packet: packet} gas := k.runtimeGasForContract(ctx) res, gasUsed, execErr := k.wasmVM.IBCPacketTimeout(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization) diff --git a/x/wasm/keeper/relay_test.go b/x/wasm/keeper/relay_test.go index 8f00067e14..ef313c2e3a 100644 --- a/x/wasm/keeper/relay_test.go +++ b/x/wasm/keeper/relay_test.go @@ -62,7 +62,13 @@ func TestOnOpenChannel(t *testing.T) { before := ctx.GasMeter().GasConsumed() // when - err := keepers.WasmKeeper.OnOpenChannel(ctx, spec.contractAddr, myChannel, "foo") + msg := wasmvmtypes.IBCChannelOpenMsg{ + OpenTry: &wasmvmtypes.IBCOpenTry{ + Channel: myChannel, + CounterpartyVersion: "foo", + }, + } + err := keepers.WasmKeeper.OnOpenChannel(ctx, spec.contractAddr, msg) // then if spec.expErr { @@ -163,7 +169,12 @@ func TestOnConnectChannel(t *testing.T) { } // when - err := keepers.WasmKeeper.OnConnectChannel(ctx, spec.contractAddr, myChannel, "") + msg := wasmvmtypes.IBCChannelConnectMsg{ + OpenConfirm: &wasmvmtypes.IBCOpenConfirm{ + Channel: myChannel, + }, + } + err := keepers.WasmKeeper.OnConnectChannel(ctx, spec.contractAddr, msg) // then events := ctx.EventManager().Events() @@ -279,7 +290,12 @@ func TestOnCloseChannel(t *testing.T) { } // when - err := keepers.WasmKeeper.OnCloseChannel(ctx, spec.contractAddr, myChannel, false) + msg := wasmvmtypes.IBCChannelCloseMsg{ + CloseInit: &wasmvmtypes.IBCCloseInit{ + Channel: myChannel, + }, + } + err := keepers.WasmKeeper.OnCloseChannel(ctx, spec.contractAddr, msg) // then events := ctx.EventManager().Events() @@ -453,7 +469,8 @@ func TestOnRecvPacket(t *testing.T) { } // when - gotAck, err := keepers.WasmKeeper.OnRecvPacket(ctx, spec.contractAddr, myPacket) + msg := wasmvmtypes.IBCPacketReceiveMsg{Packet: myPacket} + gotAck, err := keepers.WasmKeeper.OnRecvPacket(ctx, spec.contractAddr, msg) // then events := ctx.EventManager().Events() @@ -704,7 +721,8 @@ func TestOnTimeoutPacket(t *testing.T) { } // when - err := keepers.WasmKeeper.OnTimeoutPacket(ctx, spec.contractAddr, myPacket) + msg := wasmvmtypes.IBCPacketTimeoutMsg{Packet: myPacket} + err := keepers.WasmKeeper.OnTimeoutPacket(ctx, spec.contractAddr, msg) // then events := ctx.EventManager().Events() diff --git a/x/wasm/types/exported_keepers.go b/x/wasm/types/exported_keepers.go index 3a5ef173ff..345cbf09df 100644 --- a/x/wasm/types/exported_keepers.go +++ b/x/wasm/types/exported_keepers.go @@ -58,28 +58,22 @@ type IBCContractKeeper interface { OnOpenChannel( ctx sdk.Context, contractAddr sdk.AccAddress, - channel wasmvmtypes.IBCChannel, - // this is unset on init, set on try - counterpartyVersion string, + msg wasmvmtypes.IBCChannelOpenMsg, ) error OnConnectChannel( ctx sdk.Context, contractAddr sdk.AccAddress, - channel wasmvmtypes.IBCChannel, - // this is set on ack, unset on confirm - counterpartyVersion string, + msg wasmvmtypes.IBCChannelConnectMsg, ) error OnCloseChannel( ctx sdk.Context, contractAddr sdk.AccAddress, - channel wasmvmtypes.IBCChannel, - // false for init, true for confirm - confirm bool, + msg wasmvmtypes.IBCChannelCloseMsg, ) error OnRecvPacket( ctx sdk.Context, contractAddr sdk.AccAddress, - packet wasmvmtypes.IBCPacket, + msg wasmvmtypes.IBCPacketReceiveMsg, ) ([]byte, error) OnAckPacket( ctx sdk.Context, @@ -89,7 +83,7 @@ type IBCContractKeeper interface { OnTimeoutPacket( ctx sdk.Context, contractAddr sdk.AccAddress, - packet wasmvmtypes.IBCPacket, + msg wasmvmtypes.IBCPacketTimeoutMsg, ) error // ClaimCapability allows the transfer module to claim a capability //that IBC module passes to it