diff --git a/modules/core/04-channel/keeper/keeper.go b/modules/core/04-channel/keeper/keeper.go index e498e2a8d18..8474f42b706 100644 --- a/modules/core/04-channel/keeper/keeper.go +++ b/modules/core/04-channel/keeper/keeper.go @@ -547,6 +547,12 @@ func (k Keeper) deleteUpgrade(ctx sdk.Context, portID, channelID string) { store.Delete(host.ChannelUpgradeKey(portID, channelID)) } +// hasCounterpartyUpgrade returns true if a counterparty upgrade exists in store +func (k Keeper) hasCounterpartyUpgrade(ctx sdk.Context, portID, channelID string) bool { + store := ctx.KVStore(k.storeKey) + return store.Has(host.ChannelCounterpartyUpgradeKey(portID, channelID)) +} + // GetCounterpartyUpgrade gets the counterparty upgrade from the store. func (k Keeper) GetCounterpartyUpgrade(ctx sdk.Context, portID, channelID string) (types.Upgrade, bool) { store := ctx.KVStore(k.storeKey) diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index b6fbcf61b01..1dae7f7a862 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -55,6 +55,7 @@ func (k Keeper) WriteUpgradeInitChannel(ctx sdk.Context, portID, channelID strin if k.hasUpgrade(ctx, portID, channelID) { // invalidating previous upgrade + k.deleteUpgradeInfo(ctx, portID, channelID) k.WriteErrorReceipt(ctx, portID, channelID, types.NewUpgradeError(channel.UpgradeSequence, types.ErrInvalidUpgrade)) } @@ -313,6 +314,9 @@ func (k Keeper) ChanUpgradeAck( return errorsmod.Wrap(err, "failed to verify counterparty upgrade") } + // if we have cancelled our upgrade after performing UpgradeInit + // or UpgradeTry, the lack of a stored upgrade will prevent us from + // continuing the upgrade handshake upgrade, found := k.GetUpgrade(ctx, portID, channelID) if !found { return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID) @@ -786,6 +790,14 @@ func (k Keeper) ChanUpgradeTimeout( return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "counterparty channel upgrade sequence (%d) must be greater than or equal to current upgrade sequence (%d)", counterpartyChannel.UpgradeSequence, channel.UpgradeSequence) } + // NOTE: The counterpartyChannel upgrade fields are not checked in the case + // the counterpartyChannel is in FLUSHING. This is not required because + // we prove that the upgrade timeout has elapsed on the counterparty, + // thus no historical proofs can be submitted. It is not possible for the + // counterparty to have upgraded if they were in FLUSHING and the upgrade + // timeout elapsed. Do not make use of the relayer provided fields without + // verifying them. + // verify the counterparty channel state if err := k.connectionKeeper.VerifyChannelState( ctx, @@ -964,7 +976,9 @@ func (k Keeper) MustAbortUpgrade(ctx sdk.Context, portID, channelID string, err } // abortUpgrade will restore the channel state to its pre-upgrade state so that upgrade is aborted. -// Any unnecessary state is delete and an error receipt is written. +// All upgrade information associated with the upgrade attempt is deleted and an upgrade error +// receipt is written for that upgrade attempt. This prevents the upgrade handshake from continuing +// on our side and provides proof for the counterparty to safely abort the upgrade. func (k Keeper) abortUpgrade(ctx sdk.Context, portID, channelID string, err error) error { if err == nil { return errorsmod.Wrap(types.ErrInvalidUpgradeError, "cannot abort upgrade handshake with nil error") @@ -991,6 +1005,8 @@ func (k Keeper) abortUpgrade(ctx sdk.Context, portID, channelID string, err erro } // restoreChannel will restore the channel state to its pre-upgrade state so that upgrade is aborted. +// When an upgrade attempt is aborted, the upgrade information must be deleted. This prevents us +// from continuing an upgrade handshake after we cancel an upgrade attempt. func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgradeSequence uint64, channel types.Channel) types.Channel { channel.State = types.OPEN channel.UpgradeSequence = upgradeSequence @@ -1017,6 +1033,17 @@ func (k Keeper) WriteErrorReceipt(ctx sdk.Context, portID, channelID string, upg panic(errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than existing error receipt sequence (%d)", errorReceiptToWrite.Sequence, existingErrorReceipt.Sequence)) } + // Ensure that no upgrade attempt exists for the same sequence we are + // writing an error receipt for. This could lead to divergent behaviour + // on the counterparty. + if channel.UpgradeSequence <= errorReceiptToWrite.Sequence { + upgradeFound := k.hasUpgrade(ctx, portID, channelID) + counterpartyUpgradeFound := k.hasCounterpartyUpgrade(ctx, portID, channelID) + if upgradeFound || counterpartyUpgradeFound { + panic(errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "attempting to write error receipt at sequence (%d) while upgrade information exists at the same sequence", errorReceiptToWrite.Sequence)) + } + } + k.setUpgradeErrorReceipt(ctx, portID, channelID, errorReceiptToWrite) EmitErrorReceiptEvent(ctx, portID, channelID, channel, upgradeError) } diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index 91982422373..a67ffa12682 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -2810,6 +2810,18 @@ func (suite *KeeperTestSuite) TestWriteErrorReceipt() { }, errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "error receipt sequence (10) must be greater than existing error receipt sequence (11)"), }, + { + "failure: upgrade exists for error receipt being written", + func() { + // attempt to write error receipt for existing upgrade without deleting upgrade info + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + ch := path.EndpointA.GetChannel() + upgradeError = types.NewUpgradeError(ch.UpgradeSequence, types.ErrInvalidUpgrade) + }, + errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "attempting to write error receipt at sequence (1) while upgrade information exists at the same sequence"), + }, { "failure: channel not found", func() {