Skip to content

Commit

Permalink
Merge pull request #570 from CosmWasm/561-better-ibc-contract-interface
Browse files Browse the repository at this point in the history
Better ibc contract interface
  • Loading branch information
ethanfrey authored Jul 29, 2021
2 parents 3d1ff09 + 0fe0b62 commit 0f6f437
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 84 deletions.
71 changes: 51 additions & 20 deletions x/wasm/ibc.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,18 @@ func (i IBCHandler) OnChanOpenInit(
return sdkerrors.Wrapf(err, "contract port id")
}

err = i.keeper.OnOpenChannel(ctx, contractAddr, wasmvmtypes.IBCChannel{
Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID},
CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId},
Order: order.String(),
Version: version,
ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported.
}, "")
msg := wasmvmtypes.IBCChannelOpenMsg{
OpenInit: &wasmvmtypes.IBCOpenInit{
Channel: wasmvmtypes.IBCChannel{
Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID},
CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId},
Order: order.String(),
Version: version,
ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported.
},
},
}
err = i.keeper.OnOpenChannel(ctx, contractAddr, msg)
if err != nil {
return err
}
Expand Down Expand Up @@ -80,13 +85,20 @@ func (i IBCHandler) OnChanOpenTry(
return sdkerrors.Wrapf(err, "contract port id")
}

err = i.keeper.OnOpenChannel(ctx, contractAddr, wasmvmtypes.IBCChannel{
Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID},
CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId},
Order: order.String(),
Version: version,
ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported.
}, counterpartyVersion)
msg := wasmvmtypes.IBCChannelOpenMsg{
OpenTry: &wasmvmtypes.IBCOpenTry{
Channel: wasmvmtypes.IBCChannel{
Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID},
CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId},
Order: order.String(),
Version: version,
ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported.
},
CounterpartyVersion: counterpartyVersion,
},
}

err = i.keeper.OnOpenChannel(ctx, contractAddr, msg)
if err != nil {
return err
}
Expand Down Expand Up @@ -117,7 +129,13 @@ func (i IBCHandler) OnChanOpenAck(
if !ok {
return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}
return i.keeper.OnConnectChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), counterpartyVersion)
msg := wasmvmtypes.IBCChannelConnectMsg{
OpenAck: &wasmvmtypes.IBCOpenAck{
Channel: toWasmVMChannel(portID, channelID, channelInfo),
CounterpartyVersion: counterpartyVersion,
},
}
return i.keeper.OnConnectChannel(ctx, contractAddr, msg)
}

// OnChanOpenConfirm implements the IBCModule interface
Expand All @@ -130,7 +148,12 @@ func (i IBCHandler) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string)
if !ok {
return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}
return i.keeper.OnConnectChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), "")
msg := wasmvmtypes.IBCChannelConnectMsg{
OpenConfirm: &wasmvmtypes.IBCOpenConfirm{
Channel: toWasmVMChannel(portID, channelID, channelInfo),
},
}
return i.keeper.OnConnectChannel(ctx, contractAddr, msg)
}

// OnChanCloseInit implements the IBCModule interface
Expand All @@ -144,7 +167,10 @@ func (i IBCHandler) OnChanCloseInit(ctx sdk.Context, portID, channelID string) e
return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

err = i.keeper.OnCloseChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), false)
msg := wasmvmtypes.IBCChannelCloseMsg{
CloseInit: &wasmvmtypes.IBCCloseInit{Channel: toWasmVMChannel(portID, channelID, channelInfo)},
}
err = i.keeper.OnCloseChannel(ctx, contractAddr, msg)
if err != nil {
return err
}
Expand All @@ -165,7 +191,10 @@ func (i IBCHandler) OnChanCloseConfirm(ctx sdk.Context, portID, channelID string
return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

err = i.keeper.OnCloseChannel(ctx, contractAddr, toWasmVMChannel(portID, channelID, channelInfo), true)
msg := wasmvmtypes.IBCChannelCloseMsg{
CloseConfirm: &wasmvmtypes.IBCCloseConfirm{Channel: toWasmVMChannel(portID, channelID, channelInfo)},
}
err = i.keeper.OnCloseChannel(ctx, contractAddr, msg)
if err != nil {
return err
}
Expand Down Expand Up @@ -193,7 +222,8 @@ func (i IBCHandler) OnRecvPacket(
if err != nil {
return nil, nil, sdkerrors.Wrapf(err, "contract port id")
}
ack, err := i.keeper.OnRecvPacket(ctx, contractAddr, newIBCPacket(packet))
msg := wasmvmtypes.IBCPacketReceiveMsg{Packet: newIBCPacket(packet)}
ack, err := i.keeper.OnRecvPacket(ctx, contractAddr, msg)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -230,7 +260,8 @@ func (i IBCHandler) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet)
if err != nil {
return nil, sdkerrors.Wrapf(err, "contract port id")
}
err = i.keeper.OnTimeoutPacket(ctx, contractAddr, newIBCPacket(packet))
msg := wasmvmtypes.IBCPacketTimeoutMsg{Packet: newIBCPacket(packet)}
err = i.keeper.OnTimeoutPacket(ctx, contractAddr, msg)
if err != nil {
return nil, err
}
Expand Down
55 changes: 7 additions & 48 deletions x/wasm/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"time"
)

var _ types.IBCContractKeeper = (*Keeper)(nil)

// OnOpenChannel calls the contract to participate in the IBC channel handshake step.
// In the IBC protocol this is either the `Channel Open Init` event on the initiating chain or
// `Channel Open Try` on the counterparty chain.
Expand All @@ -17,9 +19,7 @@ import (
func (k Keeper) OnOpenChannel(
ctx sdk.Context,
contractAddr sdk.AccAddress,
channel wasmvmtypes.IBCChannel,
// this is unset on init, set on try
counterpartyVersion string,
msg wasmvmtypes.IBCChannelOpenMsg,
) error {
defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-open-channel")

Expand All @@ -31,18 +31,6 @@ func (k Keeper) OnOpenChannel(
env := types.NewEnv(ctx, contractAddr)
querier := k.newQueryHandler(ctx, contractAddr)

msg := wasmvmtypes.IBCChannelOpenMsg{}
if counterpartyVersion == "" {
msg.OpenInit = &wasmvmtypes.IBCOpenInit{
Channel: channel,
}
} else {
msg.OpenTry = &wasmvmtypes.IBCOpenTry{
Channel: channel,
CounterpartyVersion: counterpartyVersion,
}
}

gas := k.runtimeGasForContract(ctx)
gasUsed, execErr := k.wasmVM.IBCChannelOpen(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization)
k.consumeRuntimeGas(ctx, gasUsed)
Expand All @@ -63,9 +51,7 @@ func (k Keeper) OnOpenChannel(
func (k Keeper) OnConnectChannel(
ctx sdk.Context,
contractAddr sdk.AccAddress,
channel wasmvmtypes.IBCChannel,
// this is set on ack, unset on confirm
counterpartyVersion string,
msg wasmvmtypes.IBCChannelConnectMsg,
) error {
defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-connect-channel")
contractInfo, codeInfo, prefixStore, err := k.contractInstance(ctx, contractAddr)
Expand All @@ -76,18 +62,6 @@ func (k Keeper) OnConnectChannel(
env := types.NewEnv(ctx, contractAddr)
querier := k.newQueryHandler(ctx, contractAddr)

msg := wasmvmtypes.IBCChannelConnectMsg{}
if counterpartyVersion == "" {
msg.OpenConfirm = &wasmvmtypes.IBCOpenConfirm{
Channel: channel,
}
} else {
msg.OpenAck = &wasmvmtypes.IBCOpenAck{
Channel: channel,
CounterpartyVersion: counterpartyVersion,
}
}

gas := k.runtimeGasForContract(ctx)
res, gasUsed, execErr := k.wasmVM.IBCChannelConnect(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization)
k.consumeRuntimeGas(ctx, gasUsed)
Expand All @@ -107,9 +81,7 @@ func (k Keeper) OnConnectChannel(
func (k Keeper) OnCloseChannel(
ctx sdk.Context,
contractAddr sdk.AccAddress,
channel wasmvmtypes.IBCChannel,
// false for init, true for confirm
confirm bool,
msg wasmvmtypes.IBCChannelCloseMsg,
) error {
defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-close-channel")

Expand All @@ -121,17 +93,6 @@ func (k Keeper) OnCloseChannel(
params := types.NewEnv(ctx, contractAddr)
querier := k.newQueryHandler(ctx, contractAddr)

msg := wasmvmtypes.IBCChannelCloseMsg{}
if confirm {
msg.CloseConfirm = &wasmvmtypes.IBCCloseConfirm{
Channel: channel,
}
} else {
msg.CloseInit = &wasmvmtypes.IBCCloseInit{
Channel: channel,
}
}

gas := k.runtimeGasForContract(ctx)
res, gasUsed, execErr := k.wasmVM.IBCChannelClose(codeInfo.CodeHash, params, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization)
k.consumeRuntimeGas(ctx, gasUsed)
Expand All @@ -151,7 +112,7 @@ func (k Keeper) OnCloseChannel(
func (k Keeper) OnRecvPacket(
ctx sdk.Context,
contractAddr sdk.AccAddress,
packet wasmvmtypes.IBCPacket,
msg wasmvmtypes.IBCPacketReceiveMsg,
) ([]byte, error) {
defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-recv-packet")
contractInfo, codeInfo, prefixStore, err := k.contractInstance(ctx, contractAddr)
Expand All @@ -161,7 +122,6 @@ func (k Keeper) OnRecvPacket(

env := types.NewEnv(ctx, contractAddr)
querier := k.newQueryHandler(ctx, contractAddr)
msg := wasmvmtypes.IBCPacketReceiveMsg{Packet: packet}

gas := k.runtimeGasForContract(ctx)
res, gasUsed, execErr := k.wasmVM.IBCPacketReceive(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization)
Expand Down Expand Up @@ -209,7 +169,7 @@ func (k Keeper) OnAckPacket(
func (k Keeper) OnTimeoutPacket(
ctx sdk.Context,
contractAddr sdk.AccAddress,
packet wasmvmtypes.IBCPacket,
msg wasmvmtypes.IBCPacketTimeoutMsg,
) error {
defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "ibc-timeout-packet")

Expand All @@ -220,7 +180,6 @@ func (k Keeper) OnTimeoutPacket(

env := types.NewEnv(ctx, contractAddr)
querier := k.newQueryHandler(ctx, contractAddr)
msg := wasmvmtypes.IBCPacketTimeoutMsg{Packet: packet}

gas := k.runtimeGasForContract(ctx)
res, gasUsed, execErr := k.wasmVM.IBCPacketTimeout(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas, costJsonDeserialization)
Expand Down
28 changes: 23 additions & 5 deletions x/wasm/keeper/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ func TestOnOpenChannel(t *testing.T) {
before := ctx.GasMeter().GasConsumed()

// when
err := keepers.WasmKeeper.OnOpenChannel(ctx, spec.contractAddr, myChannel, "foo")
msg := wasmvmtypes.IBCChannelOpenMsg{
OpenTry: &wasmvmtypes.IBCOpenTry{
Channel: myChannel,
CounterpartyVersion: "foo",
},
}
err := keepers.WasmKeeper.OnOpenChannel(ctx, spec.contractAddr, msg)

// then
if spec.expErr {
Expand Down Expand Up @@ -163,7 +169,12 @@ func TestOnConnectChannel(t *testing.T) {
}

// when
err := keepers.WasmKeeper.OnConnectChannel(ctx, spec.contractAddr, myChannel, "")
msg := wasmvmtypes.IBCChannelConnectMsg{
OpenConfirm: &wasmvmtypes.IBCOpenConfirm{
Channel: myChannel,
},
}
err := keepers.WasmKeeper.OnConnectChannel(ctx, spec.contractAddr, msg)

// then
events := ctx.EventManager().Events()
Expand Down Expand Up @@ -279,7 +290,12 @@ func TestOnCloseChannel(t *testing.T) {
}

// when
err := keepers.WasmKeeper.OnCloseChannel(ctx, spec.contractAddr, myChannel, false)
msg := wasmvmtypes.IBCChannelCloseMsg{
CloseInit: &wasmvmtypes.IBCCloseInit{
Channel: myChannel,
},
}
err := keepers.WasmKeeper.OnCloseChannel(ctx, spec.contractAddr, msg)

// then
events := ctx.EventManager().Events()
Expand Down Expand Up @@ -453,7 +469,8 @@ func TestOnRecvPacket(t *testing.T) {
}

// when
gotAck, err := keepers.WasmKeeper.OnRecvPacket(ctx, spec.contractAddr, myPacket)
msg := wasmvmtypes.IBCPacketReceiveMsg{Packet: myPacket}
gotAck, err := keepers.WasmKeeper.OnRecvPacket(ctx, spec.contractAddr, msg)

// then
events := ctx.EventManager().Events()
Expand Down Expand Up @@ -704,7 +721,8 @@ func TestOnTimeoutPacket(t *testing.T) {
}

// when
err := keepers.WasmKeeper.OnTimeoutPacket(ctx, spec.contractAddr, myPacket)
msg := wasmvmtypes.IBCPacketTimeoutMsg{Packet: myPacket}
err := keepers.WasmKeeper.OnTimeoutPacket(ctx, spec.contractAddr, msg)

// then
events := ctx.EventManager().Events()
Expand Down
16 changes: 5 additions & 11 deletions x/wasm/types/exported_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,28 +58,22 @@ type IBCContractKeeper interface {
OnOpenChannel(
ctx sdk.Context,
contractAddr sdk.AccAddress,
channel wasmvmtypes.IBCChannel,
// this is unset on init, set on try
counterpartyVersion string,
msg wasmvmtypes.IBCChannelOpenMsg,
) error
OnConnectChannel(
ctx sdk.Context,
contractAddr sdk.AccAddress,
channel wasmvmtypes.IBCChannel,
// this is set on ack, unset on confirm
counterpartyVersion string,
msg wasmvmtypes.IBCChannelConnectMsg,
) error
OnCloseChannel(
ctx sdk.Context,
contractAddr sdk.AccAddress,
channel wasmvmtypes.IBCChannel,
// false for init, true for confirm
confirm bool,
msg wasmvmtypes.IBCChannelCloseMsg,
) error
OnRecvPacket(
ctx sdk.Context,
contractAddr sdk.AccAddress,
packet wasmvmtypes.IBCPacket,
msg wasmvmtypes.IBCPacketReceiveMsg,
) ([]byte, error)
OnAckPacket(
ctx sdk.Context,
Expand All @@ -89,7 +83,7 @@ type IBCContractKeeper interface {
OnTimeoutPacket(
ctx sdk.Context,
contractAddr sdk.AccAddress,
packet wasmvmtypes.IBCPacket,
msg wasmvmtypes.IBCPacketTimeoutMsg,
) error
// ClaimCapability allows the transfer module to claim a capability
//that IBC module passes to it
Expand Down

0 comments on commit 0f6f437

Please sign in to comment.