From 154f046a2f48a11393fa8dfad88e3c2cee4947b4 Mon Sep 17 00:00:00 2001 From: "rust.dev" <102041955+RustNinja@users.noreply.github.com> Date: Wed, 22 May 2024 12:44:17 +0100 Subject: [PATCH] clone the logic from original version to custom pfm to introduce fees. (#518) all test passed. --- app/keepers/keepers.go | 1 + custom/custompfm/keeper/keeper.go | 222 +++++++++++++++++++++++++++++- 2 files changed, 222 insertions(+), 1 deletion(-) diff --git a/app/keepers/keepers.go b/app/keepers/keepers.go index 613ecf54..211b340e 100644 --- a/app/keepers/keepers.go +++ b/app/keepers/keepers.go @@ -374,6 +374,7 @@ func (appKeepers *AppKeepers) InitNormalKeepers( 0, routerkeeper.DefaultForwardTransferPacketTimeoutTimestamp, routerkeeper.DefaultRefundTransferPacketTimeoutTimestamp, + appKeepers.TransferMiddlewareKeeper, ) ratelimitMiddlewareStack := ratelimitmodule.NewIBCMiddleware(appKeepers.RatelimitKeeper, ibcMiddlewareStack) hooksTransferMiddleware := ibc_hooks.NewIBCMiddleware(ratelimitMiddlewareStack, &appKeepers.HooksICS4Wrapper) diff --git a/custom/custompfm/keeper/keeper.go b/custom/custompfm/keeper/keeper.go index e38e19b1..b9b7476f 100644 --- a/custom/custompfm/keeper/keeper.go +++ b/custom/custompfm/keeper/keeper.go @@ -1,14 +1,22 @@ package keeper import ( + "encoding/json" + "fmt" + "strings" "time" + "github.com/armon/go-metrics" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/address" router "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v7/packetforward" "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v7/packetforward/keeper" + "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v7/packetforward/types" + transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" + ibctransfermiddlewarekeeper "github.com/notional-labs/composable/v6/x/transfermiddleware/keeper" ) var _ porttypes.Middleware = &IBCMiddleware{} @@ -17,6 +25,14 @@ var _ porttypes.Middleware = &IBCMiddleware{} // forward keeper and the underlying application. type IBCMiddleware struct { router.IBCMiddleware + ibcfeekeeper ibctransfermiddlewarekeeper.Keeper + + app1 porttypes.IBCModule + keeper1 *keeper.Keeper + + retriesOnTimeout1 uint8 + forwardTimeout1 time.Duration + refundTimeout1 time.Duration } func NewIBCMiddleware( @@ -25,9 +41,18 @@ func NewIBCMiddleware( retriesOnTimeout uint8, forwardTimeout time.Duration, refundTimeout time.Duration, + ibcfeekeeper ibctransfermiddlewarekeeper.Keeper, ) IBCMiddleware { return IBCMiddleware{ IBCMiddleware: router.NewIBCMiddleware(app, k, retriesOnTimeout, forwardTimeout, refundTimeout), + ibcfeekeeper: ibcfeekeeper, + + //we need this bz this field is not exported in the parent struct + app1: app, + keeper1: k, + retriesOnTimeout1: retriesOnTimeout, + forwardTimeout1: forwardTimeout, + refundTimeout1: refundTimeout, } } @@ -36,5 +61,200 @@ func (im IBCMiddleware) OnRecvPacket( packet channeltypes.Packet, relayer sdk.AccAddress, ) ibcexported.Acknowledgement { - return im.IBCMiddleware.OnRecvPacket(ctx, packet, relayer) + logger := im.keeper1.Logger(ctx) + + im.ibcfeekeeper.GetParams(ctx) + + var data transfertypes.FungibleTokenPacketData + if err := transfertypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil { + logger.Debug(fmt.Sprintf("packetForwardMiddleware OnRecvPacket payload is not a FungibleTokenPacketData: %s", err.Error())) + return im.IBCMiddleware.OnRecvPacket(ctx, packet, relayer) + } + + logger.Debug("packetForwardMiddleware OnRecvPacket", + "sequence", packet.Sequence, + "src-channel", packet.SourceChannel, "src-port", packet.SourcePort, + "dst-channel", packet.DestinationChannel, "dst-port", packet.DestinationPort, + "amount", data.Amount, "denom", data.Denom, "memo", data.Memo, + ) + + d := make(map[string]interface{}) + err := json.Unmarshal([]byte(data.Memo), &d) + if err != nil || d["forward"] == nil { + // not a packet that should be forwarded + logger.Debug("packetForwardMiddleware OnRecvPacket forward metadata does not exist") + return im.app1.OnRecvPacket(ctx, packet, relayer) + } + m := &types.PacketMetadata{} + err = json.Unmarshal([]byte(data.Memo), m) + if err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error parsing forward metadata", "error", err) + return newErrorAcknowledgement(fmt.Errorf("error parsing forward metadata: %w", err)) + } + + metadata := m.Forward + + goCtx := ctx.Context() + processed := getBoolFromAny(goCtx.Value(types.ProcessedKey{})) + nonrefundable := getBoolFromAny(goCtx.Value(types.NonrefundableKey{})) + disableDenomComposition := getBoolFromAny(goCtx.Value(types.DisableDenomCompositionKey{})) + + if err := metadata.Validate(); err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket forward metadata is invalid", "error", err) + return newErrorAcknowledgement(err) + } + + // override the receiver so that senders cannot move funds through arbitrary addresses. + overrideReceiver, err := getReceiver(packet.DestinationChannel, data.Sender) + if err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket failed to construct override receiver", "error", err) + return newErrorAcknowledgement(fmt.Errorf("failed to construct override receiver: %w", err)) + } + + // if this packet has been handled by another middleware in the stack there may be no need to call into the + // underlying app, otherwise the transfer module's OnRecvPacket callback could be invoked more than once + // which would mint/burn vouchers more than once + if !processed { + if err := im.receiveFunds(ctx, packet, data, overrideReceiver, relayer); err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error receiving packet", "error", err) + return newErrorAcknowledgement(fmt.Errorf("error receiving packet: %w", err)) + } + } + + // if this packet's token denom is already the base denom for some native token on this chain, + // we do not need to do any further composition of the denom before forwarding the packet + denomOnThisChain := data.Denom + // Check if the packet was sent from Picasso + paraChainIBCTokenInfo, found := im.keeper1.GetParachainTokenInfoByAssetID(ctx, data.Denom) + if found && (paraChainIBCTokenInfo.ChannelID == packet.DestinationChannel) { + disableDenomComposition = true + denomOnThisChain = paraChainIBCTokenInfo.NativeDenom + } + + if !disableDenomComposition { + denomOnThisChain = getDenomForThisChain( + packet.DestinationPort, packet.DestinationChannel, + packet.SourcePort, packet.SourceChannel, + data.Denom, + ) + } + + amountInt, ok := sdk.NewIntFromString(data.Amount) + if !ok { + logger.Error("packetForwardMiddleware OnRecvPacket error parsing amount for forward", "amount", data.Amount) + return newErrorAcknowledgement(fmt.Errorf("error parsing amount for forward: %s", data.Amount)) + } + + token := sdk.NewCoin(denomOnThisChain, amountInt) + + timeout := time.Duration(metadata.Timeout) + + if timeout.Nanoseconds() <= 0 { + timeout = im.forwardTimeout1 + } + + var retries uint8 + if metadata.Retries != nil { + retries = *metadata.Retries + } else { + retries = im.retriesOnTimeout1 + } + + err = im.keeper1.ForwardTransferPacket(ctx, nil, packet, data.Sender, overrideReceiver, metadata, token, retries, timeout, []metrics.Label{}, nonrefundable) + if err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error forwarding packet", "error", err) + return newErrorAcknowledgement(err) + } + + // returning nil ack will prevent WriteAcknowledgement from occurring for forwarded packet. + // This is intentional so that the acknowledgement will be written later based on the ack/timeout of the forwarded packet. + return nil + + // charge_coin := sdk.NewCoin(packet.Token.Denom, sdk.ZeroInt()) + // return channeltypes.NewErrorAcknowledgement(fmt.Errorf("error parsing forward metadata")) + // return im.IBCMiddleware.OnRecvPacket(ctx, packet, relayer) +} + +func newErrorAcknowledgement(err error) channeltypes.Acknowledgement { + return channeltypes.Acknowledgement{ + Response: &channeltypes.Acknowledgement_Error{ + Error: fmt.Sprintf("packet-forward-middleware error: %s", err.Error()), + }, + } +} + +func getBoolFromAny(value any) bool { + if value == nil { + return false + } + boolVal, ok := value.(bool) + if !ok { + return false + } + return boolVal +} + +func getReceiver(channel string, originalSender string) (string, error) { + senderStr := fmt.Sprintf("%s/%s", channel, originalSender) + senderHash32 := address.Hash(types.ModuleName, []byte(senderStr)) + sender := sdk.AccAddress(senderHash32[:20]) + bech32Prefix := sdk.GetConfig().GetBech32AccountAddrPrefix() + return sdk.Bech32ifyAddressBytes(bech32Prefix, sender) +} + +func (im IBCMiddleware) receiveFunds( + ctx sdk.Context, + packet channeltypes.Packet, + data transfertypes.FungibleTokenPacketData, + overrideReceiver string, + relayer sdk.AccAddress, +) error { + overrideData := transfertypes.FungibleTokenPacketData{ + Denom: data.Denom, + Amount: data.Amount, + Sender: data.Sender, + Receiver: overrideReceiver, // override receiver + // Memo explicitly zeroed + } + overrideDataBz := transfertypes.ModuleCdc.MustMarshalJSON(&overrideData) + overridePacket := channeltypes.Packet{ + Sequence: packet.Sequence, + SourcePort: packet.SourcePort, + SourceChannel: packet.SourceChannel, + DestinationPort: packet.DestinationPort, + DestinationChannel: packet.DestinationChannel, + Data: overrideDataBz, // override data + TimeoutHeight: packet.TimeoutHeight, + TimeoutTimestamp: packet.TimeoutTimestamp, + } + + ack := im.app1.OnRecvPacket(ctx, overridePacket, relayer) + + if ack == nil { + return fmt.Errorf("ack is nil") + } + + if !ack.Success() { + return fmt.Errorf("ack error: %s", string(ack.Acknowledgement())) + } + + return nil +} + +func getDenomForThisChain(port, channel, counterpartyPort, counterpartyChannel, denom string) string { + counterpartyPrefix := transfertypes.GetDenomPrefix(counterpartyPort, counterpartyChannel) + if strings.HasPrefix(denom, counterpartyPrefix) { + // unwind denom + unwoundDenom := denom[len(counterpartyPrefix):] + denomTrace := transfertypes.ParseDenomTrace(unwoundDenom) + if denomTrace.Path == "" { + // denom is now unwound back to native denom + return unwoundDenom + } + // denom is still IBC denom + return denomTrace.IBCDenom() + } + // append port and channel from this chain to denom + prefixedDenom := transfertypes.GetDenomPrefix(port, channel) + denom + return transfertypes.ParseDenomTrace(prefixedDenom).IBCDenom() }