From e5332a150a515b32bebfc48e073d1d5861afe5e5 Mon Sep 17 00:00:00 2001 From: Masanori Yoshida Date: Mon, 19 Aug 2024 15:26:52 +0900 Subject: [PATCH] WIP --- cmd/tx.go | 69 +++++- core/channel-upgrade.go | 519 ++++++++++++++++++++++++---------------- 2 files changed, 367 insertions(+), 221 deletions(-) diff --git a/cmd/tx.go b/cmd/tx.go index ef5d1157..310a2180 100644 --- a/cmd/tx.go +++ b/cmd/tx.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "errors" "fmt" "strings" "time" @@ -13,6 +14,7 @@ import ( "github.com/hyperledger-labs/yui-relayer/config" "github.com/hyperledger-labs/yui-relayer/core" "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/spf13/viper" ) @@ -241,13 +243,11 @@ func channelUpgradeInitCmd(ctx *config.Context) *cobra.Command { } // get ordering from flags - var ordering chantypes.Order - if s, err := cmd.Flags().GetString(flagOrdering); err != nil { + ordering, err := getOrderFromFlags(cmd.Flags(), flagOrdering) + if err != nil { return err - } else if n, ok := chantypes.Order_value[s]; !ok || n == int32(chantypes.NONE) { - return fmt.Errorf("invalid ordering flag: %s", s) - } else { - ordering = chantypes.Order(n) + } else if ordering == chantypes.NONE { + return errors.New("NONE is unacceptable channel ordering") } // get connection hops from flags @@ -279,12 +279,14 @@ func channelUpgradeInitCmd(ctx *config.Context) *cobra.Command { func channelUpgradeExecuteCmd(ctx *config.Context) *cobra.Command { const ( - flagInterval = "interval" - flagUntilFlushing = "until-flushing" + flagInterval = "interval" + flagTargetSrcState = "target-src-state" + flagTargetDstState = "target-dst-state" ) const ( - defaultInterval = time.Second + defaultInterval = time.Second + defaultTargetState = "OPEN" ) cmd := cobra.Command{ @@ -313,17 +315,23 @@ func channelUpgradeExecuteCmd(ctx *config.Context) *cobra.Command { return err } - untilFlushing, err := cmd.Flags().GetBool(flagUntilFlushing) + targetSrcState, err := getUpgradeStateFromFlags(cmd.Flags(), flagTargetSrcState) if err != nil { return err } - return core.ExecuteChannelUpgrade(pathName, src, dst, interval, untilFlushing) + targetDstState, err := getUpgradeStateFromFlags(cmd.Flags(), flagTargetDstState) + if err != nil { + return err + } + + return core.ExecuteChannelUpgrade(pathName, src, dst, interval, targetSrcState, targetDstState) }, } - cmd.Flags().Duration(flagInterval, defaultInterval, "interval between attempts to proceed channel upgrade steps") - cmd.Flags().Bool(flagUntilFlushing, false, "the process exits when both chains have started flushing") + cmd.Flags().Duration(flagInterval, defaultInterval, "the interval between attempts to proceed the upgrade handshake") + cmd.Flags().String(flagTargetSrcState, defaultTargetState, "the source channel's upgrade state to be reached") + cmd.Flags().String(flagTargetDstState, defaultTargetState, "the destination channel's upgrade state to be reached") return &cmd } @@ -542,3 +550,38 @@ func getUint64Slice(key string) []uint64 { } return ret } + +func getUpgradeStateFromFlags(flags *pflag.FlagSet, flagName string) (core.UpgradeState, error) { + s, err := flags.GetString(flagName) + if err != nil { + return 0, err + } + + switch strings.ToUpper(s) { + case "UNINIT": + return core.UPGRADE_STATE_UNINIT, nil + case "INIT": + return core.UPGRADE_STATE_INIT, nil + case "FLUSHING": + return core.UPGRADE_STATE_FLUSHING, nil + case "FLUSHCOMPLETE": + return core.UPGRADE_STATE_FLUSHCOMPLETE, nil + default: + return 0, fmt.Errorf("invalid upgrade state specified: %s", s) + } +} + +func getOrderFromFlags(flags *pflag.FlagSet, flagName string) (chantypes.Order, error) { + s, err := flags.GetString(flagName) + if err != nil { + return 0, err + } + + s = "ORDER_" + strings.ToUpper(s) + value, ok := chantypes.Order_value[s] + if !ok { + return 0, fmt.Errorf("invalid channel order specified: %s", s) + } + + return chantypes.Order(value), nil +} diff --git a/core/channel-upgrade.go b/core/channel-upgrade.go index 7613f933..403ff808 100644 --- a/core/channel-upgrade.go +++ b/core/channel-upgrade.go @@ -3,7 +3,7 @@ package core import ( "errors" "fmt" - "math" + "log/slog" "time" retry "github.com/avast/retry-go" @@ -12,6 +12,66 @@ import ( chantypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" ) +type UpgradeState int + +const ( + UPGRADE_STATE_UNINIT UpgradeState = iota + UPGRADE_STATE_INIT + UPGRADE_STATE_FLUSHING + UPGRADE_STATE_FLUSHCOMPLETE +) + +func (state UpgradeState) String() string { + switch state { + case UPGRADE_STATE_UNINIT: + return "UNINIT" + case UPGRADE_STATE_INIT: + return "INIT" + case UPGRADE_STATE_FLUSHING: + return "FLUSHING" + case UPGRADE_STATE_FLUSHCOMPLETE: + return "FLUSHCOMPLETE" + default: + panic(fmt.Errorf("unexpected UpgradeState: %d", state)) + } +} + +type UpgradeAction int + +const ( + UPGRADE_ACTION_NONE UpgradeAction = iota + UPGRADE_ACTION_TRY + UPGRADE_ACTION_ACK + UPGRADE_ACTION_CONFIRM + UPGRADE_ACTION_OPEN + UPGRADE_ACTION_CANCEL + UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE + UPGRADE_ACTION_TIMEOUT +) + +func (action UpgradeAction) String() string { + switch action { + case UPGRADE_ACTION_NONE: + return "NONE" + case UPGRADE_ACTION_TRY: + return "TRY" + case UPGRADE_ACTION_ACK: + return "ACK" + case UPGRADE_ACTION_CONFIRM: + return "CONFIRM" + case UPGRADE_ACTION_OPEN: + return "OPEN" + case UPGRADE_ACTION_CANCEL: + return "CANCEL" + case UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE: + return "CANCEL_FLUSHCOMPLETE" + case UPGRADE_ACTION_TIMEOUT: + return "TIMEOUT" + default: + panic(fmt.Errorf("unexpected UpgradeAction: %d", action)) + } +} + // InitChannelUpgrade builds `MsgChannelUpgradeInit` based on the specified UpgradeFields and sends it to the specified chain. func InitChannelUpgrade(chain *ProvableChain, upgradeFields chantypes.UpgradeFields) error { logger := GetChannelLogger(chain.Chain) @@ -46,19 +106,34 @@ func InitChannelUpgrade(chain *ProvableChain, upgradeFields chantypes.UpgradeFie // ExecuteChannelUpgrade carries out channel upgrade handshake until both chains transition to the OPEN state. // This function repeatedly checks the states of both chains and decides the next action. -func ExecuteChannelUpgrade(pathName string, src, dst *ProvableChain, interval time.Duration, untilFlushing bool) error { +func ExecuteChannelUpgrade(pathName string, src, dst *ProvableChain, interval time.Duration, targetSrcState, targetDstState UpgradeState) error { logger := GetChannelPairLogger(src, dst) defer logger.TimeTrack(time.Now(), "ExecuteChannelUpgrade") + // UNINIT, INIT and FLUSHING are only supported for the target state + if targetSrcState == UPGRADE_STATE_FLUSHCOMPLETE { + return errors.New("FLUSHCOMPLETE as a target src state is not supported") + } else if targetDstState == UPGRADE_STATE_FLUSHCOMPLETE { + return errors.New("FLUSHCOMPLETE as a target dst state is not supported") + } + tick := time.Tick(interval) failures := 0 + firstCall := true for { <-tick - steps, err := upgradeChannelStep(src, dst, untilFlushing) + steps, err := upgradeChannelStep(src, dst, targetSrcState, targetDstState, firstCall) if err != nil { logger.Error("failed to create channel upgrade step", err) return err + } else { + firstCall = false + } + + if steps.Last { + logger.Info("Channel upgrade completed") + return nil } if !steps.Ready() { @@ -73,11 +148,6 @@ func ExecuteChannelUpgrade(pathName string, src, dst *ProvableChain, interval ti return err } - if steps.Last { - logger.Info("Channel upgrade completed") - return nil - } - failures = 0 } else { if failures++; failures > 2 { @@ -155,7 +225,24 @@ func CancelChannelUpgrade(chain, cp *ProvableChain) error { return nil } -func upgradeChannelStep(src, dst *ProvableChain, untilFlushing bool) (*RelayMsgs, error) { +func NewUpgradeState(chanState chantypes.State, upgradeExists bool) (UpgradeState, error) { + switch chanState { + case chantypes.OPEN: + if upgradeExists { + return UPGRADE_STATE_INIT, nil + } else { + return UPGRADE_STATE_UNINIT, nil + } + case chantypes.FLUSHING: + return UPGRADE_STATE_FLUSHING, nil + case chantypes.FLUSHCOMPLETE: + return UPGRADE_STATE_FLUSHCOMPLETE, nil + default: + return 0, fmt.Errorf("channel not opened yet: state=%s", chanState) + } +} + +func upgradeChannelStep(src, dst *ProvableChain, targetSrcState, targetDstState UpgradeState, firstCall bool) (*RelayMsgs, error) { logger := GetChannelPairLogger(src, dst) out := NewRelayMsgs() @@ -231,302 +318,272 @@ func upgradeChannelStep(src, dst *ProvableChain, untilFlushing bool) (*RelayMsgs return out, nil } - // translate channel state to channel upgrade state - type UpgradeState chantypes.State - const ( - UPGRADEUNINIT = UpgradeState(chantypes.OPEN) - UPGRADEINIT = UpgradeState(math.MaxInt32) - FLUSHING = UpgradeState(chantypes.FLUSHING) - FLUSHCOMPLETE = UpgradeState(chantypes.FLUSHCOMPLETE) - ) - srcState := UpgradeState(srcChan.Channel.State) - if srcState == UPGRADEUNINIT && srcChanUpg != nil { - srcState = UPGRADEINIT - } - dstState := UpgradeState(dstChan.Channel.State) - if dstState == UPGRADEUNINIT && dstChanUpg != nil { - dstState = UPGRADEINIT - } - - doTry := func(chain *ProvableChain, cpCtx QueryContext, cp *ProvableChain, cpHeaders []Header, cpChan *chantypes.QueryChannelResponse, cpChanUpg *chantypes.QueryUpgradeResponse) ([]sdk.Msg, error) { - proposedConnectionID, err := queryProposedConnectionID(cpCtx, cp, cpChanUpg) - if err != nil { - return nil, err - } - var msgs []sdk.Msg - addr := mustGetAddress(chain) - if len(cpHeaders) > 0 { - msgs = append(msgs, chain.Path().UpdateClients(cpHeaders, addr)...) - } - msgs = append(msgs, chain.Path().ChanUpgradeTry(proposedConnectionID, cpChan, cpChanUpg, addr)) - return msgs, nil - } - - doAck := func(chain *ProvableChain, cpHeaders []Header, cpChan *chantypes.QueryChannelResponse, cpChanUpg *chantypes.QueryUpgradeResponse) []sdk.Msg { - var msgs []sdk.Msg - addr := mustGetAddress(chain) - if len(cpHeaders) > 0 { - msgs = append(msgs, chain.Path().UpdateClients(cpHeaders, addr)...) - } - msgs = append(msgs, chain.Path().ChanUpgradeAck(cpChan, cpChanUpg, addr)) - return msgs - } - - doConfirm := func(chain *ProvableChain, cpHeaders []Header, cpChan *chantypes.QueryChannelResponse, cpChanUpg *chantypes.QueryUpgradeResponse) []sdk.Msg { - var msgs []sdk.Msg - addr := mustGetAddress(chain) - if len(cpHeaders) > 0 { - msgs = append(msgs, chain.Path().UpdateClients(cpHeaders, addr)...) - } - msgs = append(msgs, chain.Path().ChanUpgradeConfirm(cpChan, cpChanUpg, addr)) - return msgs - } - - doOpen := func(chain *ProvableChain, cpHeaders []Header, cpChan *chantypes.QueryChannelResponse) []sdk.Msg { - var msgs []sdk.Msg - addr := mustGetAddress(chain) - if len(cpHeaders) > 0 { - msgs = append(msgs, chain.Path().UpdateClients(cpHeaders, addr)...) - } - msgs = append(msgs, chain.Path().ChanUpgradeOpen(cpChan, addr)) - return msgs + // determine upgrade states + srcState, err := NewUpgradeState(srcChan.Channel.State, srcChanUpg != nil) + if err != nil { + return nil, err } - - doCancel := func(chain *ProvableChain, cpCtx QueryContext, cp *ProvableChain, cpHeaders []Header, upgradeSequence uint64) ([]sdk.Msg, error) { - cpChanUpgErr, err := QueryChannelUpgradeError(cpCtx, cp, upgradeSequence, true) - if err != nil { - return nil, err - } - - if cpChanUpgErr == nil { - logger.Warn("error receipt not found", "seq", upgradeSequence, "chain_id", cp.ChainID()) - } - - var msgs []sdk.Msg - addr := mustGetAddress(chain) - if len(cpHeaders) > 0 { - msgs = append(msgs, chain.Path().UpdateClients(cpHeaders, addr)...) - } - msgs = append(msgs, chain.Path().ChanUpgradeCancel(cpChanUpgErr, addr)) - return msgs, nil + dstState, err := NewUpgradeState(dstChan.Channel.State, dstChanUpg != nil) + if err != nil { + return nil, err } - doTimeout := func(chain *ProvableChain, cpHeaders []Header, cpChan *chantypes.QueryChannelResponse) []sdk.Msg { - var msgs []sdk.Msg - addr := mustGetAddress(chain) - if len(cpHeaders) > 0 { - msgs = append(msgs, chain.Path().UpdateClients(cpHeaders, addr)...) - } - msgs = append(msgs, chain.Path().ChanUpgradeTimeout(cpChan, addr)) - return msgs - } + // check if both chains have reached the target states + if firstCall && srcState == UPGRADE_STATE_INIT && dstState == UPGRADE_STATE_INIT { + return nil, errors.New("channel upgrade is not initialized") + } else if checkIfStateReached(srcState, targetSrcState) && checkIfStateReached(dstState, targetDstState) { + out.Last = true + return nil, nil + } + + // add info to logger + logger = logger.With( + slog.Group("src", + "state", srcState, + "seq", srcChan.Channel.UpgradeSequence, + ), + slog.Group("dst", + "state", dstState, + "seq", dstChan.Channel.UpgradeSequence, + ), + ) + // determine next actions for src/dst chains + srcAction := UPGRADE_ACTION_NONE + dstAction := UPGRADE_ACTION_NONE switch { - case srcState == UPGRADEUNINIT && dstState == UPGRADEUNINIT: - return nil, errors.New("channel upgrade is not initialized") - case srcState == UPGRADEINIT && dstState == UPGRADEUNINIT: + case srcState == UPGRADE_STATE_UNINIT && dstState == UPGRADE_STATE_UNINIT: + if firstCall { + return nil, errors.New("channel upgrade is not initialized") + } + case srcState == UPGRADE_STATE_INIT && dstState == UPGRADE_STATE_UNINIT: if dstChan.Channel.UpgradeSequence >= srcChan.Channel.UpgradeSequence { - logger.Warn("the initialized channel upgrade is outdated", - "src_seq", srcChan.Channel.UpgradeSequence, - "dst_seq", dstChan.Channel.UpgradeSequence, - ) - if out.Src, err = doCancel(src, dstCtxFinalized, dst, dstUpdateHeaders, 0); err != nil { - return nil, err - } - out.Last = true + srcAction = UPGRADE_ACTION_CANCEL } else { - if out.Dst, err = doTry(dst, srcCtxFinalized, src, srcUpdateHeaders, srcChan, srcChanUpg); err != nil { - return nil, err - } + dstAction = UPGRADE_ACTION_TRY } - case srcState == UPGRADEUNINIT && dstState == UPGRADEINIT: + case srcState == UPGRADE_STATE_UNINIT && dstState == UPGRADE_STATE_INIT: if srcChan.Channel.UpgradeSequence >= dstChan.Channel.UpgradeSequence { - logger.Warn("the initialized channel upgrade is outdated", - "src_seq", srcChan.Channel.UpgradeSequence, - "dst_seq", dstChan.Channel.UpgradeSequence, - ) - if out.Dst, err = doCancel(dst, srcCtxFinalized, src, srcUpdateHeaders, 0); err != nil { - return nil, err - } - out.Last = true + dstAction = UPGRADE_ACTION_CANCEL } else { - if out.Src, err = doTry(src, dstCtxFinalized, dst, dstUpdateHeaders, dstChan, dstChanUpg); err != nil { - return nil, err - } - } - case srcState == UPGRADEUNINIT && dstState == FLUSHING: - if out.Dst, err = doCancel(dst, srcCtxFinalized, src, srcUpdateHeaders, 0); err != nil { - return nil, err - } - out.Last = true - case srcState == FLUSHING && dstState == UPGRADEUNINIT: - if out.Src, err = doCancel(src, dstCtxFinalized, dst, dstUpdateHeaders, 0); err != nil { - return nil, err + srcAction = UPGRADE_ACTION_TRY } - out.Last = true - case srcState == UPGRADEUNINIT && dstState == FLUSHCOMPLETE: + case srcState == UPGRADE_STATE_UNINIT && dstState == UPGRADE_STATE_FLUSHING: + dstAction = UPGRADE_ACTION_CANCEL + case srcState == UPGRADE_STATE_FLUSHING && dstState == UPGRADE_STATE_UNINIT: + srcAction = UPGRADE_ACTION_CANCEL + case srcState == UPGRADE_STATE_UNINIT && dstState == UPGRADE_STATE_FLUSHCOMPLETE: if complete, err := upgradeAlreadyComplete(srcChan, dstCtxFinalized, dst, dstChanUpg); err != nil { return nil, err } else if complete { - out.Dst = doOpen(dst, srcUpdateHeaders, srcChan) + dstAction = UPGRADE_ACTION_OPEN } else if timedout, err := upgradeAlreadyTimedOut(srcCtxFinalized, src, dstChanUpg); err != nil { return nil, err } else if timedout { - out.Dst = doTimeout(dst, srcUpdateHeaders, srcChan) + dstAction = UPGRADE_ACTION_TIMEOUT } else { - if out.Dst, err = doCancel(dst, srcCtxFinalized, src, srcUpdateHeaders, dstChan.Channel.UpgradeSequence); err != nil { - return nil, err - } + dstAction = UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE } - out.Last = true - case srcState == FLUSHCOMPLETE && dstState == UPGRADEUNINIT: + case srcState == UPGRADE_STATE_FLUSHCOMPLETE && dstState == UPGRADE_STATE_UNINIT: if complete, err := upgradeAlreadyComplete(dstChan, srcCtxFinalized, src, srcChanUpg); err != nil { return nil, err } else if complete { - out.Src = doOpen(src, dstUpdateHeaders, dstChan) + srcAction = UPGRADE_ACTION_OPEN } else if timedout, err := upgradeAlreadyTimedOut(dstCtxFinalized, dst, srcChanUpg); err != nil { return nil, err } else if timedout { - out.Src = doTimeout(src, dstUpdateHeaders, dstChan) + srcAction = UPGRADE_ACTION_TIMEOUT } else { - if out.Src, err = doCancel(src, dstCtxFinalized, dst, dstUpdateHeaders, srcChan.Channel.UpgradeSequence); err != nil { - return nil, err - } + srcAction = UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE } - out.Last = true - case srcState == UPGRADEINIT && dstState == UPGRADEINIT: // crossing hellos - if srcChan.Channel.UpgradeSequence > dstChan.Channel.UpgradeSequence { - if out.Dst, err = doTry(dst, srcCtxFinalized, src, srcUpdateHeaders, srcChan, srcChanUpg); err != nil { - return nil, err - } - } else if srcChan.Channel.UpgradeSequence < dstChan.Channel.UpgradeSequence { - if out.Src, err = doTry(src, dstCtxFinalized, dst, dstUpdateHeaders, dstChan, dstChanUpg); err != nil { - return nil, err - } - } else { - // it is intentional to execute chanUpgradeTry on both sides if upgrade sequences - // are identical to each other. this is for testing purpose. - if out.Dst, err = doTry(dst, srcCtxFinalized, src, srcUpdateHeaders, srcChan, srcChanUpg); err != nil { - return nil, err - } - if out.Src, err = doTry(src, dstCtxFinalized, dst, dstUpdateHeaders, dstChan, dstChanUpg); err != nil { - return nil, err - } - out.Last = untilFlushing + case srcState == UPGRADE_STATE_INIT && dstState == UPGRADE_STATE_INIT: // crossing hellos + // it is intentional to execute chanUpgradeTry on both sides if upgrade sequences + // are identical to each other. this is for testing purpose. + if srcChan.Channel.UpgradeSequence >= dstChan.Channel.UpgradeSequence { + dstAction = UPGRADE_ACTION_TRY + } + if srcChan.Channel.UpgradeSequence <= dstChan.Channel.UpgradeSequence { + srcAction = UPGRADE_ACTION_TRY } - case srcState == UPGRADEINIT && dstState == FLUSHING: + case srcState == UPGRADE_STATE_INIT && dstState == UPGRADE_STATE_FLUSHING: if srcChan.Channel.UpgradeSequence != dstChan.Channel.UpgradeSequence { - if out.Dst, err = doCancel(dst, srcCtxFinalized, src, srcUpdateHeaders, 0); err != nil { - return nil, err - } + dstAction = UPGRADE_ACTION_CANCEL } else { // chanUpgradeAck checks if counterparty-specified timeout has exceeded. // if it has, chanUpgradeAck aborts the upgrade handshake. // Therefore the relayer need not check timeout by itself. - out.Src = doAck(src, dstUpdateHeaders, dstChan, dstChanUpg) - out.Last = untilFlushing + srcAction = UPGRADE_ACTION_ACK } - case srcState == FLUSHING && dstState == UPGRADEINIT: + case srcState == UPGRADE_STATE_FLUSHING && dstState == UPGRADE_STATE_INIT: if srcChan.Channel.UpgradeSequence != dstChan.Channel.UpgradeSequence { - if out.Src, err = doCancel(src, dstCtxFinalized, dst, dstUpdateHeaders, 0); err != nil { - return nil, err - } + srcAction = UPGRADE_ACTION_CANCEL } else { // chanUpgradeAck checks if counterparty-specified timeout has exceeded. // if it has, chanUpgradeAck aborts the upgrade handshake. // Therefore the relayer need not check timeout by itself. - out.Dst = doAck(dst, srcUpdateHeaders, srcChan, srcChanUpg) - out.Last = true + dstAction = UPGRADE_ACTION_ACK } - case srcState == UPGRADEINIT && dstState == FLUSHCOMPLETE: + case srcState == UPGRADE_STATE_INIT && dstState == UPGRADE_STATE_FLUSHCOMPLETE: if complete, err := upgradeAlreadyComplete(srcChan, dstCtxFinalized, dst, dstChanUpg); err != nil { return nil, err } else if complete { - out.Dst = doOpen(dst, srcUpdateHeaders, srcChan) + dstAction = UPGRADE_ACTION_OPEN } else if timedout, err := upgradeAlreadyTimedOut(srcCtxFinalized, src, dstChanUpg); err != nil { return nil, err } else if timedout { - out.Dst = doTimeout(dst, srcUpdateHeaders, srcChan) + dstAction = UPGRADE_ACTION_TIMEOUT } else { - if out.Dst, err = doCancel(dst, srcCtxFinalized, src, srcUpdateHeaders, dstChan.Channel.UpgradeSequence); err != nil { - return nil, err - } + dstAction = UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE } - case srcState == FLUSHCOMPLETE && dstState == UPGRADEINIT: + case srcState == UPGRADE_STATE_FLUSHCOMPLETE && dstState == UPGRADE_STATE_INIT: if complete, err := upgradeAlreadyComplete(dstChan, srcCtxFinalized, src, srcChanUpg); err != nil { return nil, err } else if complete { - out.Src = doOpen(src, dstUpdateHeaders, dstChan) + srcAction = UPGRADE_ACTION_OPEN } else if timedout, err := upgradeAlreadyTimedOut(dstCtxFinalized, dst, srcChanUpg); err != nil { return nil, err } else if timedout { - out.Src = doTimeout(src, dstUpdateHeaders, dstChan) + srcAction = UPGRADE_ACTION_TIMEOUT } else { - if out.Src, err = doCancel(src, dstCtxFinalized, dst, dstUpdateHeaders, srcChan.Channel.UpgradeSequence); err != nil { - return nil, err - } + srcAction = UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE } - case srcState == FLUSHING && dstState == FLUSHING: - nTimedout := 0 + case srcState == UPGRADE_STATE_FLUSHING && dstState == UPGRADE_STATE_FLUSHING: if timedout, err := upgradeAlreadyTimedOut(srcCtxFinalized, src, dstChanUpg); err != nil { return nil, err } else if timedout { - nTimedout += 1 - out.Dst = doTimeout(dst, srcUpdateHeaders, srcChan) + dstAction = UPGRADE_ACTION_TIMEOUT } if timedout, err := upgradeAlreadyTimedOut(dstCtxFinalized, dst, srcChanUpg); err != nil { return nil, err } else if timedout { - nTimedout += 1 - out.Src = doTimeout(src, dstUpdateHeaders, dstChan) + srcAction = UPGRADE_ACTION_TIMEOUT } - // if any chains have exceeded timeout, never execute chanUpgradeConfirm - if nTimedout > 0 { + // if either chain has already timed out, never execute chanUpgradeConfirm + if srcAction == UPGRADE_ACTION_TIMEOUT || dstAction == UPGRADE_ACTION_TIMEOUT { break } if completable, err := src.QueryCanTransitionToFlushComplete(srcCtxFinalized); err != nil { return nil, err } else if completable { - out.Src = doConfirm(src, dstUpdateHeaders, dstChan, dstChanUpg) + srcAction = UPGRADE_ACTION_CONFIRM } if completable, err := dst.QueryCanTransitionToFlushComplete(dstCtxFinalized); err != nil { return nil, err } else if completable { - out.Dst = doConfirm(dst, srcUpdateHeaders, srcChan, srcChanUpg) + dstAction = UPGRADE_ACTION_CONFIRM } - case srcState == FLUSHING && dstState == FLUSHCOMPLETE: + case srcState == UPGRADE_STATE_FLUSHING && dstState == UPGRADE_STATE_FLUSHCOMPLETE: if timedout, err := upgradeAlreadyTimedOut(srcCtxFinalized, src, dstChanUpg); err != nil { return nil, err } else if timedout { - out.Dst = doTimeout(dst, srcUpdateHeaders, srcChan) + dstAction = UPGRADE_ACTION_TIMEOUT } else if completable, err := src.QueryCanTransitionToFlushComplete(srcCtxFinalized); err != nil { return nil, err } else if completable { - out.Src = doConfirm(src, dstUpdateHeaders, dstChan, dstChanUpg) + srcAction = UPGRADE_ACTION_CONFIRM } - case srcState == FLUSHCOMPLETE && dstState == FLUSHING: + case srcState == UPGRADE_STATE_FLUSHCOMPLETE && dstState == UPGRADE_STATE_FLUSHING: if timedout, err := upgradeAlreadyTimedOut(dstCtxFinalized, dst, srcChanUpg); err != nil { return nil, err } else if timedout { - out.Src = doTimeout(src, dstUpdateHeaders, dstChan) + srcAction = UPGRADE_ACTION_TIMEOUT } else if completable, err := dst.QueryCanTransitionToFlushComplete(dstCtxFinalized); err != nil { return nil, err } else if completable { - out.Dst = doConfirm(dst, srcUpdateHeaders, srcChan, srcChanUpg) + dstAction = UPGRADE_ACTION_CONFIRM } - case srcState == FLUSHCOMPLETE && dstState == FLUSHCOMPLETE: - out.Src = doOpen(src, dstUpdateHeaders, dstChan) - out.Dst = doOpen(dst, srcUpdateHeaders, srcChan) - out.Last = true + case srcState == UPGRADE_STATE_FLUSHCOMPLETE && dstState == UPGRADE_STATE_FLUSHCOMPLETE: + srcAction = UPGRADE_ACTION_OPEN + dstAction = UPGRADE_ACTION_OPEN default: return nil, errors.New("unexpected state") } + if srcAction != UPGRADE_ACTION_NONE { + addr := mustGetAddress(src) + + if len(dstUpdateHeaders) > 0 { + out.Src = append(out.Src, src.Path().UpdateClients(dstUpdateHeaders, addr)...) + } + + msg, err := buildActionMsg( + src, + srcAction, + srcChan, + addr, + dstCtxFinalized, + dst, + dstChan, + dstChanUpg, + ) + if err != nil { + return nil, err + } + + out.Src = append(out.Src, msg) + } + + if dstAction != UPGRADE_ACTION_NONE { + addr := mustGetAddress(dst) + + if len(srcUpdateHeaders) > 0 { + out.Dst = append(out.Dst, dst.Path().UpdateClients(srcUpdateHeaders, addr)...) + } + + msg, err := buildActionMsg( + dst, + dstAction, + dstChan, + addr, + srcCtxFinalized, + src, + srcChan, + srcChanUpg, + ) + if err != nil { + return nil, err + } + + out.Dst = append(out.Dst, msg) + } + + // determine whether this turn is "Last" or not + out.Last = checkIfTargetStateReached(targetSrcState, srcState, srcAction) && + checkIfTargetStateReached(targetDstState, dstState, dstAction) + return out, nil } +// checkIfTargetStateReached returns true if: +// - The next upgrade state is `UNINIT` +// - The next upgrade state is greater than or equal to `target`. +func checkIfTargetStateReached(target, current UpgradeState, action UpgradeAction) bool { + var nextMinState UpgradeState + switch action { + case UPGRADE_ACTION_NONE: + nextMinState = current + case UPGRADE_ACTION_TRY: + nextMinState = UPGRADE_STATE_FLUSHING + case UPGRADE_ACTION_ACK: + nextMinState = UPGRADE_STATE_FLUSHING + case UPGRADE_ACTION_CONFIRM: + // CONFIRM is executed only when `QueryCanTransitionToFlushComplete` returns true + nextMinState = UPGRADE_STATE_FLUSHCOMPLETE + case UPGRADE_ACTION_OPEN: + nextMinState = UPGRADE_STATE_UNINIT + case UPGRADE_ACTION_CANCEL: + nextMinState = UPGRADE_STATE_UNINIT + case UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE: + nextMinState = UPGRADE_STATE_UNINIT + case UPGRADE_ACTION_TIMEOUT: + nextMinState = UPGRADE_STATE_UNINIT + } + + return nextMinState == UPGRADE_STATE_UNINIT || nextMinState >= target +} + func queryProposedConnectionID(cpCtx QueryContext, cp *ProvableChain, cpChanUpg *chantypes.QueryUpgradeResponse) (string, error) { if cpConn, err := cp.QueryConnection( cpCtx, @@ -624,3 +681,49 @@ func upgradeAlreadyTimedOut( } return cpChanUpg.Upgrade.Timeout.Elapsed(height, uint64(timestamp.UnixNano())), nil } + +// buildActionMsg builds and returns a MsgChannelUpgradeXXX message corresponding to `action`. +// This function also returns `UpgradeState` to which the channel will transition after the message is processed. +func buildActionMsg( + chain *ProvableChain, + action UpgradeAction, + selfChan *chantypes.QueryChannelResponse, + addr sdk.AccAddress, + cpCtx QueryContext, + cp *ProvableChain, + cpChan *chantypes.QueryChannelResponse, + cpUpg *chantypes.QueryUpgradeResponse, +) (sdk.Msg, error) { + pathEnd := chain.Path() + + switch action { + case UPGRADE_ACTION_TRY: + proposedConnectionID, err := queryProposedConnectionID(cpCtx, cp, cpUpg) + if err != nil { + return nil, err + } + return pathEnd.ChanUpgradeTry(proposedConnectionID, cpChan, cpUpg, addr), nil + case UPGRADE_ACTION_ACK: + return pathEnd.ChanUpgradeAck(cpChan, cpUpg, addr), nil + case UPGRADE_ACTION_CONFIRM: + return pathEnd.ChanUpgradeConfirm(cpChan, cpUpg, addr), nil + case UPGRADE_ACTION_OPEN: + return pathEnd.ChanUpgradeOpen(cpChan, addr), nil + case UPGRADE_ACTION_CANCEL: + upgErr, err := QueryChannelUpgradeError(cpCtx, cp, 0, true) + if err != nil { + return nil, err + } + return pathEnd.ChanUpgradeCancel(upgErr, addr), nil + case UPGRADE_ACTION_CANCEL_FLUSHCOMPLETE: + upgErr, err := QueryChannelUpgradeError(cpCtx, cp, selfChan.Channel.UpgradeSequence, true) + if err != nil { + return nil, err + } + return pathEnd.ChanUpgradeCancel(upgErr, addr), nil + case UPGRADE_ACTION_TIMEOUT: + return pathEnd.ChanUpgradeTimeout(cpChan, addr), nil + default: + panic(fmt.Errorf("unexpected action: %s", action)) + } +}