diff --git a/cmd/tx.go b/cmd/tx.go index 59400686..7046e08b 100644 --- a/cmd/tx.go +++ b/cmd/tx.go @@ -254,16 +254,9 @@ func channelUpgradeInitCmd(ctx *config.Context) *cobra.Command { } // check cp state - if unsafe, err := cmd.Flags().GetBool(flagUnsafe); err != nil { + permitUnsafe, err := cmd.Flags().GetBool(flagUnsafe) + if err != nil { return err - } else if !unsafe { - if height, err := cp.LatestHeight(); err != nil { - return err - } else if chann, err := cp.QueryChannel(core.NewQueryContext(cmd.Context(), height)); err != nil { - return err - } else if state := chann.Channel.State; state >= chantypes.FLUSHING && state <= chantypes.FLUSHCOMPLETE { - return fmt.Errorf("stop channel upgrade initialization because the counterparty is in %v state", state) - } } // get ordering from flags @@ -286,11 +279,16 @@ func channelUpgradeInitCmd(ctx *config.Context) *cobra.Command { return err } - return core.InitChannelUpgrade(chain, chantypes.UpgradeFields{ - Ordering: ordering, - ConnectionHops: connHops, - Version: version, - }) + return core.InitChannelUpgrade( + chain, + cp, + chantypes.UpgradeFields{ + Ordering: ordering, + ConnectionHops: connHops, + Version: version, + }, + permitUnsafe, + ) }, } diff --git a/core/channel-upgrade.go b/core/channel-upgrade.go index 7930a393..eb61b859 100644 --- a/core/channel-upgrade.go +++ b/core/channel-upgrade.go @@ -75,10 +75,40 @@ func (action UpgradeAction) String() string { } // InitChannelUpgrade builds `MsgChannelUpgradeInit` based on the specified UpgradeFields and sends it to the specified chain. -func InitChannelUpgrade(chain *ProvableChain, upgradeFields chantypes.UpgradeFields) error { +func InitChannelUpgrade(chain, cp *ProvableChain, upgradeFields chantypes.UpgradeFields, permitUnsafe bool) error { logger := GetChannelLogger(chain.Chain) defer logger.TimeTrack(time.Now(), "InitChannelUpgrade") + if h, err := chain.LatestHeight(); err != nil { + logger.Error("failed to get the latest height", err) + return err + } else if cpH, err := cp.LatestHeight(); err != nil { + logger.Error("failed to get the latest height of the counterparty chain", err) + return err + } else if chann, cpChann, err := QueryChannelPair( + NewQueryContext(context.TODO(), h), + NewQueryContext(context.TODO(), cpH), + chain, + cp, + false, + ); err != nil { + logger.Error("failed to query for the channel pair", err) + return err + } else if chann.Channel.State != chantypes.OPEN || cpChann.Channel.State != chantypes.OPEN { + logger = &log.RelayLogger{Logger: logger.With( + "channel_state", chann.Channel.State, + "cp_channel_state", cpChann.Channel.State, + )} + + if permitUnsafe { + logger.Info("unsafe channel upgrade is permitted") + } else { + err := errors.New("unsafe channel upgrade initialization") + logger.Error("unsafe channel upgrade is not permitted", err) + return err + } + } + addr, err := chain.GetAddress() if err != nil { logger.Error("failed to get address", err)