diff --git a/rfq/manager.go b/rfq/manager.go index 629c2c0fc..e88448f5d 100644 --- a/rfq/manager.go +++ b/rfq/manager.go @@ -359,8 +359,8 @@ func (m *Manager) handleIncomingMessage(incomingMsg rfqmsg.IncomingMsg) error { // payment by the SCID alias through which it comes in // and compare it to the one in the invoice. err := m.addScidAlias( - uint64(msg.ShortChannelId()), *msg.AssetID, - msg.Peer, + uint64(msg.ShortChannelId()), + *msg.Request.AssetID, msg.Peer, ) if err != nil { m.cfg.ErrChan <- fmt.Errorf("error adding "+ @@ -450,7 +450,8 @@ func (m *Manager) handleOutgoingMessage(outgoingMsg rfqmsg.OutgoingMsg) error { // make sure we can identify the forwarded asset payment by the // outgoing SCID alias within the onion packet. err := m.addScidAlias( - uint64(msg.ShortChannelId()), *msg.AssetID, msg.Peer, + uint64(msg.ShortChannelId()), *msg.Request.AssetID, + msg.Peer, ) if err != nil { return fmt.Errorf("error adding local alias: %w", err) diff --git a/rfq/negotiator.go b/rfq/negotiator.go index 4e4b9fb16..ee52468d3 100644 --- a/rfq/negotiator.go +++ b/rfq/negotiator.go @@ -609,7 +609,7 @@ func (n *Negotiator) HandleIncomingBuyAccept(msg rfqmsg.BuyAccept, // for an ask price. We will then compare the ask price returned // by the price oracle with the ask price provided by the peer. oraclePrice, _, err := n.queryAskFromPriceOracle( - &msg.Peer, msg.AssetID, nil, + &msg.Peer, msg.Request.AssetID, nil, msg.AssetAmount, nil, ) if err != nil { @@ -730,7 +730,7 @@ func (n *Negotiator) HandleIncomingSellAccept(msg rfqmsg.SellAccept, // for a bid price. We will then compare the bid price returned // by the price oracle with the bid price provided by the peer. oraclePrice, _, err := n.queryBidFromPriceOracle( - msg.Peer, msg.AssetID, nil, msg.AssetAmount, + msg.Peer, msg.Request.AssetID, nil, msg.AssetAmount, ) if err != nil { // The price oracle returned an error. We will return diff --git a/rfq/order.go b/rfq/order.go index 6a83ec266..c55b02568 100644 --- a/rfq/order.go +++ b/rfq/order.go @@ -101,7 +101,7 @@ func NewAssetSalePolicy(quote rfqmsg.BuyAccept) *AssetSalePolicy { MaxAssetAmount: quote.AssetAmount, AskPrice: quote.AskPrice, expiry: quote.Expiry, - assetID: quote.AssetID, + assetID: quote.Request.AssetID, } } diff --git a/rfq/stream.go b/rfq/stream.go index 6d3dd95a1..3a23332aa 100644 --- a/rfq/stream.go +++ b/rfq/stream.go @@ -9,6 +9,7 @@ import ( "github.com/lightninglabs/lndclient" "github.com/lightninglabs/taproot-assets/fn" "github.com/lightninglabs/taproot-assets/rfqmsg" + "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" ) @@ -47,6 +48,14 @@ type StreamHandler struct { // the peer raw messages subscription. errRecvRawMessages <-chan error + // outgoingRequests is a map of request IDs to outgoing requests. + // This map is used to match incoming accept messages to outgoing + // requests. + // + // TODO(ffranr): Periodically remove expired outgoing requests from + // this map. + outgoingRequests lnutils.SyncMap[rfqmsg.ID, rfqmsg.OutgoingMsg] + // ContextGuard provides a wait group and main quit channel that can be // used to create guarded contexts. *fn.ContextGuard @@ -100,6 +109,68 @@ func (h *StreamHandler) handleIncomingWireMessage( log.Debugf("Stream handling incoming message: %s", msg) + // If the incoming message is an accept message, lookup the + // corresponding outgoing request message. Assign the outgoing request + // to a field on the accept message. This step allows us to easily + // access the request that the accept message is responding to. Some of + // the request fields are not present in the accept message. + // + // If the incoming message is a reject message, remove the corresponding + // outgoing request from the store. + switch typedMsg := msg.(type) { + case *rfqmsg.Reject: + // Delete the corresponding outgoing request from the store. + h.outgoingRequests.Delete(typedMsg.ID) + + case *rfqmsg.BuyAccept: + // Load and delete the corresponding outgoing request from the + // store. + outgoingRequest, found := h.outgoingRequests.LoadAndDelete( + typedMsg.ID, + ) + + // Ensure that we have an outgoing request to match the incoming + // accept message. + if !found { + return fmt.Errorf("no outgoing request found for "+ + "incoming accept message: %s", typedMsg.ID) + } + + // Type cast the outgoing message to a BuyRequest (the request + // type that corresponds to a buy accept message). + buyReq, ok := outgoingRequest.(*rfqmsg.BuyRequest) + if !ok { + return fmt.Errorf("expected BuyRequest, got %T", + outgoingRequest) + } + + typedMsg.Request = *buyReq + + case *rfqmsg.SellAccept: + // Load and delete the corresponding outgoing request from the + // store. + outgoingRequest, found := h.outgoingRequests.LoadAndDelete( + typedMsg.ID, + ) + + // Ensure that we have an outgoing request to match the incoming + // accept message. + if !found { + return fmt.Errorf("no outgoing request found for "+ + "incoming accept message: %s", typedMsg.ID) + } + + // Type cast the outgoing message to a SellRequest (the request + // type that corresponds to a sell accept message). + req, ok := outgoingRequest.(*rfqmsg.SellRequest) + if !ok { + return fmt.Errorf("expected SellRequest, got %T", + outgoingRequest) + } + + typedMsg.Request = *req + } + // Send the incoming message to the RFQ manager. sendSuccess := fn.SendOrQuit(h.cfg.IncomingMessages, msg, h.Quit) if !sendSuccess { @@ -137,6 +208,15 @@ func (h *StreamHandler) HandleOutgoingMessage( err) } + // Store outgoing requests. + switch msg := outgoingMsg.(type) { + case *rfqmsg.BuyRequest: + h.outgoingRequests.Store(msg.ID, msg) + + case *rfqmsg.SellRequest: + h.outgoingRequests.Store(msg.ID, msg) + } + return nil } diff --git a/rfqmsg/buy_accept.go b/rfqmsg/buy_accept.go index 30c87e65b..c4122676b 100644 --- a/rfqmsg/buy_accept.go +++ b/rfqmsg/buy_accept.go @@ -2,11 +2,9 @@ package rfqmsg import ( "bytes" - "crypto/sha256" "fmt" "io" - "github.com/lightninglabs/taproot-assets/asset" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -20,7 +18,6 @@ const ( TypeBuyAcceptAskPrice tlv.Type = 4 TypeBuyAcceptExpiry tlv.Type = 6 TypeBuyAcceptSignature tlv.Type = 8 - TypeBuyAcceptAssetID tlv.Type = 10 ) func TypeRecordBuyAcceptVersion(version *WireMsgDataVersion) tlv.Record { @@ -81,15 +78,6 @@ func TypeRecordBuyAcceptSig(sig *[64]byte) tlv.Record { return tlv.MakePrimitiveRecord(TypeBuyAcceptSignature, sig) } -func TypeRecordBuyAcceptAssetID(assetID **asset.ID) tlv.Record { - const recordSize = sha256.Size - - return tlv.MakeStaticRecord( - TypeBuyAcceptAssetID, assetID, recordSize, - AssetIdEncoder, AssetIdDecoder, - ) -} - const ( // latestBuyAcceptVersion is the latest supported buy accept wire // message data field version. @@ -115,28 +103,17 @@ type buyAcceptMsgData struct { // sig is a signature over the serialized contents of the message. sig [64]byte - - // AssetID is the asset ID of the asset that the accept message is for. - AssetID *asset.ID } // encodeRecords provides all TLV records for encoding. func (q *buyAcceptMsgData) encodeRecords() []tlv.Record { - records := []tlv.Record{ + return []tlv.Record{ TypeRecordBuyAcceptVersion(&q.Version), TypeRecordBuyAcceptID(&q.ID), TypeRecordBuyAcceptAskPrice(&q.AskPrice), TypeRecordBuyAcceptExpiry(&q.Expiry), TypeRecordBuyAcceptSig(&q.sig), } - - if q.AssetID != nil { - records = append( - records, TypeRecordBuyAcceptAssetID(&q.AssetID), - ) - } - - return records } // decodeRecords provides all TLV records for decoding. @@ -147,7 +124,6 @@ func (q *buyAcceptMsgData) decodeRecords() []tlv.Record { TypeRecordBuyAcceptAskPrice(&q.AskPrice), TypeRecordBuyAcceptExpiry(&q.Expiry), TypeRecordBuyAcceptSig(&q.sig), - TypeRecordBuyAcceptAssetID(&q.AssetID), } } @@ -185,6 +161,10 @@ type BuyAccept struct { // Peer is the peer that sent the quote request. Peer route.Vertex + // Request is the quote request message that this message responds to. + // This field is not included in the wire message. + Request BuyRequest + // AssetAmount is the amount of the asset that the accept message // is for. AssetAmount uint64 @@ -201,12 +181,12 @@ func NewBuyAcceptFromRequest(request BuyRequest, askPrice lnwire.MilliSatoshi, return &BuyAccept{ Peer: request.Peer, AssetAmount: request.AssetAmount, + Request: request, buyAcceptMsgData: buyAcceptMsgData{ Version: latestBuyAcceptVersion, ID: request.ID, AskPrice: askPrice, Expiry: expiry, - AssetID: request.AssetID, }, } } diff --git a/rfqmsg/sell_accept.go b/rfqmsg/sell_accept.go index de71278f5..b1d1d382f 100644 --- a/rfqmsg/sell_accept.go +++ b/rfqmsg/sell_accept.go @@ -2,12 +2,10 @@ package rfqmsg import ( "bytes" - "crypto/sha256" "encoding/binary" "fmt" "io" - "github.com/lightninglabs/taproot-assets/asset" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -21,7 +19,6 @@ const ( TypeSellAcceptBidPrice tlv.Type = 4 TypeSellAcceptExpiry tlv.Type = 6 TypeSellAcceptSignature tlv.Type = 8 - TypeSellAcceptAssetID tlv.Type = 10 ) func TypeRecordSellAcceptVersion(version *WireMsgDataVersion) tlv.Record { @@ -56,15 +53,6 @@ func TypeRecordSellAcceptSig(sig *[64]byte) tlv.Record { return tlv.MakePrimitiveRecord(TypeSellAcceptSignature, sig) } -func TypeRecordSellAcceptAssetID(assetID **asset.ID) tlv.Record { - const recordSize = sha256.Size - - return tlv.MakeStaticRecord( - TypeSellAcceptAssetID, assetID, recordSize, - AssetIdEncoder, AssetIdDecoder, - ) -} - const ( // latestSellAcceptVersion is the latest supported sell accept wire // message data field version. @@ -90,28 +78,17 @@ type sellAcceptMsgData struct { // sig is a signature over the serialized contents of the message. sig [64]byte - - // AssetID is the asset ID of the asset that the accept message is for. - AssetID *asset.ID } // encodeRecords provides all TLV records for encoding. func (q *sellAcceptMsgData) encodeRecords() []tlv.Record { - records := []tlv.Record{ + return []tlv.Record{ TypeRecordSellAcceptVersion(&q.Version), TypeRecordSellAcceptID(&q.ID), TypeRecordSellAcceptBidPrice(&q.BidPrice), TypeRecordSellAcceptExpiry(&q.Expiry), TypeRecordSellAcceptSig(&q.sig), } - - if q.AssetID != nil { - records = append( - records, TypeRecordSellAcceptAssetID(&q.AssetID), - ) - } - - return records } // decodeRecords provides all TLV records for decoding. @@ -122,7 +99,6 @@ func (q *sellAcceptMsgData) decodeRecords() []tlv.Record { TypeRecordSellAcceptBidPrice(&q.BidPrice), TypeRecordSellAcceptExpiry(&q.Expiry), TypeRecordSellAcceptSig(&q.sig), - TypeRecordSellAcceptAssetID(&q.AssetID), } } @@ -160,6 +136,10 @@ type SellAccept struct { // Peer is the peer that sent the quote request. Peer route.Vertex + // Request is the quote request message that this message responds to. + // This field is not included in the wire message. + Request SellRequest + // AssetAmount is the amount of the asset that the accept message // is for. AssetAmount uint64 @@ -176,12 +156,12 @@ func NewSellAcceptFromRequest(request SellRequest, bidPrice lnwire.MilliSatoshi, return &SellAccept{ Peer: request.Peer, AssetAmount: request.AssetAmount, + Request: request, sellAcceptMsgData: sellAcceptMsgData{ Version: latestSellAcceptVersion, ID: request.ID, BidPrice: bidPrice, Expiry: expiry, - AssetID: request.AssetID, }, } } diff --git a/tapchannel/aux_traffic_shaper.go b/tapchannel/aux_traffic_shaper.go index 40c4365c4..0827b63f2 100644 --- a/tapchannel/aux_traffic_shaper.go +++ b/tapchannel/aux_traffic_shaper.go @@ -244,16 +244,16 @@ func (s *AuxTrafficShaper) ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, // We now know how many units we need. We take the asset ID from the // RFQ so the recipient can match it back to the quote. - if quote.AssetID == nil { + if quote.Request.AssetID == nil { return 0, nil, fmt.Errorf("quote has no asset ID") } log.Debugf("Producing HTLC extra data for RFQ ID %x (SCID %d): "+ "asset ID %x, asset amount %d", rfqID[:], rfqID.Scid(), - quote.AssetID[:], numAssetUnits) + quote.Request.AssetID[:], numAssetUnits) htlc.Amounts.Val.Balances = []*rfqmsg.AssetBalance{ - rfqmsg.NewAssetBalance(*quote.AssetID, numAssetUnits), + rfqmsg.NewAssetBalance(*quote.Request.AssetID, numAssetUnits), } // Encode the updated HTLC TLV back into a blob and return it with the