Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom channel bug fixes #936

Merged
merged 8 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ func ParseUniversePublicAccessStatus(
case "w":
return UniversePublicAccessStatusWrite, nil

case "":
return UniversePublicAccessStatusNone, nil

default:
// This default case returns an error. It will capture the case
// where the CLI argument is present but unset (empty value).
Expand Down
234 changes: 205 additions & 29 deletions rfq/order.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ type Policy interface {
// Expiry returns the policy's expiry time as a unix timestamp.
Expiry() uint64

// HasExpired returns true if the policy has expired.
HasExpired() bool

// Scid returns the serialised short channel ID (SCID) of the channel to
// which the policy applies.
Scid() uint64
Expand Down Expand Up @@ -141,6 +144,13 @@ func (c *AssetSalePolicy) Expiry() uint64 {
return c.expiry
}

// HasExpired returns true if the policy has expired.
func (c *AssetSalePolicy) HasExpired() bool {
expireTime := time.Unix(int64(c.expiry), 0).UTC()

return time.Now().UTC().After(expireTime)
}

// Scid returns the serialised short channel ID (SCID) of the channel to which
// the policy applies.
func (c *AssetSalePolicy) Scid() uint64 {
Expand Down Expand Up @@ -260,6 +270,13 @@ func (c *AssetPurchasePolicy) Expiry() uint64 {
return c.expiry
}

// HasExpired returns true if the policy has expired.
func (c *AssetPurchasePolicy) HasExpired() bool {
expireTime := time.Unix(int64(c.expiry), 0).UTC()

return time.Now().UTC().After(expireTime)
}

// Scid returns the serialised short channel ID (SCID) of the channel to which
// the policy applies.
func (c *AssetPurchasePolicy) Scid() uint64 {
Expand All @@ -268,10 +285,24 @@ func (c *AssetPurchasePolicy) Scid() uint64 {

// GenerateInterceptorResponse generates an interceptor response for the policy.
func (c *AssetPurchasePolicy) GenerateInterceptorResponse(
_ lndclient.InterceptedHtlc) (*lndclient.InterceptedHtlcResponse,
htlc lndclient.InterceptedHtlc) (*lndclient.InterceptedHtlcResponse,
error) {

incomingValue := lnwire.MilliSatoshi(c.AssetAmount) * c.BidPrice
htlcRecord, err := parseHtlcCustomRecords(htlc.WireCustomRecords)
if err != nil {
return nil, fmt.Errorf("parsing HTLC custom records failed: %w",
err)
}

// The incoming amount is just to signal to the fee logic in lnd that
// we have received enough to pay for the routing fees and the asset
// amount. Due to rounding errors, we may slightly underreport the
// incoming value of the asset. So we increase it by exactly one asset
// unit to ensure that the fee logic in lnd does not reject the HTLC.
const roundingCorrection = 1
htlcAssetAmount := htlcRecord.Amounts.Val.Sum() + roundingCorrection
incomingValue := lnwire.MilliSatoshi(htlcAssetAmount) * c.BidPrice

return &lndclient.InterceptedHtlcResponse{
Action: lndclient.InterceptorActionResumeModified,
IncomingAmount: incomingValue,
Expand All @@ -281,6 +312,118 @@ func (c *AssetPurchasePolicy) GenerateInterceptorResponse(
// Ensure that AssetPurchasePolicy implements the Policy interface.
var _ Policy = (*AssetPurchasePolicy)(nil)

// AssetForwardPolicy is a struct that holds the terms which determine whether a
// channel HTLC for an asset-to-asset forward is accepted or rejected.
type AssetForwardPolicy struct {
incomingPolicy *AssetPurchasePolicy
outgoingPolicy *AssetSalePolicy
}

// NewAssetForwardPolicy creates a new asset forward policy.
func NewAssetForwardPolicy(incoming, outgoing Policy) (*AssetForwardPolicy,
error) {

incomingPolicy, ok := incoming.(*AssetPurchasePolicy)
if !ok {
return nil, fmt.Errorf("incoming policy is not an asset "+
"purchase policy, but %T", incoming)
}

outgoingPolicy, ok := outgoing.(*AssetSalePolicy)
if !ok {
return nil, fmt.Errorf("outgoing policy is not an asset "+
"sale policy, but %T", outgoing)
}

return &AssetForwardPolicy{
incomingPolicy: incomingPolicy,
outgoingPolicy: outgoingPolicy,
}, nil
}

// CheckHtlcCompliance returns an error if the given HTLC intercept descriptor
// does not satisfy the subject policy.
func (a *AssetForwardPolicy) CheckHtlcCompliance(
htlc lndclient.InterceptedHtlc) error {

if err := a.incomingPolicy.CheckHtlcCompliance(htlc); err != nil {
return fmt.Errorf("error checking forward policy, inbound "+
"HTLC does not comply with policy: %w", err)
}

if err := a.outgoingPolicy.CheckHtlcCompliance(htlc); err != nil {
return fmt.Errorf("error checking forward policy, outbound "+
"HTLC does not comply with policy: %w", err)
}

return nil
}

// Expiry returns the policy's expiry time as a unix timestamp in seconds. The
// returned expiry time is the earliest expiry time of the incoming and outgoing
// policies.
func (a *AssetForwardPolicy) Expiry() uint64 {
if a.incomingPolicy.Expiry() < a.outgoingPolicy.Expiry() {
return a.incomingPolicy.Expiry()
}

return a.outgoingPolicy.Expiry()
}

// HasExpired returns true if the policy has expired.
func (a *AssetForwardPolicy) HasExpired() bool {
expireTime := time.Unix(int64(a.Expiry()), 0).UTC()

return time.Now().UTC().After(expireTime)
}

// Scid returns the serialised short channel ID (SCID) of the channel to which
// the policy applies. This is the SCID of the incoming policy.
func (a *AssetForwardPolicy) Scid() uint64 {
return a.incomingPolicy.Scid()
}

// GenerateInterceptorResponse generates an interceptor response for the policy.
func (a *AssetForwardPolicy) GenerateInterceptorResponse(
htlc lndclient.InterceptedHtlc) (*lndclient.InterceptedHtlcResponse,
error) {

incomingResponse, err := a.incomingPolicy.GenerateInterceptorResponse(
htlc,
)
if err != nil {
return nil, fmt.Errorf("error generating incoming interceptor "+
"response: %w", err)
}

outgoingResponse, err := a.outgoingPolicy.GenerateInterceptorResponse(
htlc,
)
if err != nil {
return nil, fmt.Errorf("error generating outgoing interceptor "+
"response: %w", err)
}

return &lndclient.InterceptedHtlcResponse{
// Both incoming and outgoing policies will resume with
// modifications.
Action: lndclient.InterceptorActionResumeModified,

// The incoming policy will modify the incoming amount in order
// to satisfy the fee check in `lnd`.
IncomingAmount: incomingResponse.IncomingAmount,

// The outgoing policy will modify the outgoing amount and add
// custom records in order to satisfy the terms of the receiving
// node.
OutgoingAmount: outgoingResponse.OutgoingAmount,
CustomRecords: outgoingResponse.CustomRecords,
}, nil
}

// Ensure that AssetForwardPolicy implements the Policy interface.
var _ Policy = (*AssetForwardPolicy)(nil)

// OrderHandlerCfg is a struct that holds the configuration parameters for the
// order handler service.
type OrderHandlerCfg struct {
Expand Down Expand Up @@ -335,7 +478,11 @@ func (h *OrderHandler) handleIncomingHtlc(_ context.Context,
htlc lndclient.InterceptedHtlc) (*lndclient.InterceptedHtlcResponse,
error) {

log.Debug("Handling incoming HTLC")
log.Debugf("Handling incoming HTLC, incoming channel ID: %v, "+
"outgoing channel ID: %v (incoming amount: %v, outgoing "+
"amount: %v)", htlc.IncomingCircuitKey.ChanID.ToUint64(),
htlc.OutgoingChannelID.ToUint64(), htlc.AmountInMsat,
htlc.AmountOutMsat)

// Look up a policy for the HTLC. If a policy does not exist, we resume
// the HTLC. This is because the HTLC may be relevant to another
Expand All @@ -353,7 +500,8 @@ func (h *OrderHandler) handleIncomingHtlc(_ context.Context,
}, nil
}

log.Debugf("Fetched policy with SCID %v", policy.Scid())
log.Debugf("Fetched policy with SCID %v of type %T", policy.Scid(),
policy)

// At this point, we know that a policy exists and has not expired
// whilst sitting in the local cache. We can now check that the HTLC
Expand All @@ -369,8 +517,7 @@ func (h *OrderHandler) handleIncomingHtlc(_ context.Context,
}

log.Debug("HTLC complies with policy. Broadcasting accept event.")
acceptHtlcEvent := NewAcceptHtlcEvent(htlc, policy)
h.cfg.AcceptHtlcEvents <- acceptHtlcEvent
h.cfg.AcceptHtlcEvents <- NewAcceptHtlcEvent(htlc, policy)

return policy.GenerateInterceptorResponse(htlc)
}
Expand Down Expand Up @@ -474,6 +621,17 @@ func (h *OrderHandler) RegisterAssetPurchasePolicy(
func (h *OrderHandler) fetchPolicy(htlc lndclient.InterceptedHtlc) (Policy,
bool, error) {

outScid := SerialisedScid(htlc.OutgoingChannelID.ToUint64())
outPolicy, haveOutPolicy := h.policies.Load(outScid)

inScid := SerialisedScid(htlc.IncomingCircuitKey.ChanID.ToUint64())
inPolicy, haveInPolicy := h.policies.Load(inScid)

log.Tracef("Have inbound policy: %v: %v", haveInPolicy,
spew.Sdump(inPolicy))
log.Tracef("Have outbound policy: %v: %v", haveOutPolicy,
spew.Sdump(outPolicy))

var (
foundPolicy *Policy
foundScid *SerialisedScid
Expand Down Expand Up @@ -506,26 +664,51 @@ func (h *OrderHandler) fetchPolicy(htlc lndclient.InterceptedHtlc) (Policy,
})
}

// Here we handle a special case where we both have an incoming and
// outgoing policy. In this case, we need to create a forward policy.
if foundPolicy != nil && haveOutPolicy {
incomingPolicy := *foundPolicy
outgoingPolicy := outPolicy

if incomingPolicy.HasExpired() {
scid := incomingPolicy.Scid()
h.policies.Delete(SerialisedScid(scid))
}
if outgoingPolicy.HasExpired() {
scid := outgoingPolicy.Scid()
h.policies.Delete(SerialisedScid(scid))
}

// If either the incoming or outgoing policy has expired, we
// return false, as if we didn't find a policy.
if incomingPolicy.HasExpired() || outgoingPolicy.HasExpired() {
return nil, false, nil
}

forwardPolicy, err := NewAssetForwardPolicy(
incomingPolicy, outgoingPolicy,
)
if err != nil {
return nil, false, fmt.Errorf("error creating forward "+
"policy: %w", err)
}

return forwardPolicy, true, nil

}

// If no policy has been found so far, we attempt to look up a policy by
// the outgoing channel SCID.
if foundPolicy == nil {
scid := SerialisedScid(htlc.OutgoingChannelID.ToUint64())
policy, ok := h.policies.Load(scid)
if ok {
foundPolicy = &policy
foundScid = &scid
}
if foundPolicy == nil && haveOutPolicy {
foundPolicy = &outPolicy
foundScid = &outScid
}

// If no policy has been found so far, we attempt to look up a policy by
// the incoming channel SCID.
if foundPolicy == nil {
scid := SerialisedScid(htlc.IncomingCircuitKey.ChanID.ToUint64())
policy, ok := h.policies.Load(scid)
if ok {
foundPolicy = &policy
foundScid = &scid
}
if foundPolicy == nil && haveInPolicy {
foundPolicy = &inPolicy
foundScid = &inScid
}

// If no policy has been found, we return false.
Expand All @@ -536,11 +719,7 @@ func (h *OrderHandler) fetchPolicy(htlc lndclient.InterceptedHtlc) (Policy,
policy := *foundPolicy
scid := *foundScid

// If the policy has expired, return false and clear it from the cache.
expireTime := time.Unix(int64(policy.Expiry()), 0).UTC()
currentTime := time.Now().UTC()

if currentTime.After(expireTime) {
if policy.HasExpired() {
h.policies.Delete(scid)
return nil, false, nil
}
Expand All @@ -555,10 +734,7 @@ func (h *OrderHandler) cleanupStalePolicies() {

h.policies.ForEach(
func(scid SerialisedScid, policy Policy) error {
expireTime := time.Unix(int64(policy.Expiry()), 0).UTC()
currentTime := time.Now().UTC()

if currentTime.After(expireTime) {
if policy.HasExpired() {
staleCounter++
h.policies.Delete(scid)
}
Expand Down
4 changes: 4 additions & 0 deletions rfqmsg/accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ type acceptWireMsgData struct {
// Sig is a signature over the serialized contents of the message.
Sig tlv.RecordT[tlv.TlvType3, [64]byte]

// InOutRateTick is the tick rate for the accept, defined in
// in_asset/out_asset. This is only set in a buy accept message.
InOutRateTick acceptInOutRateTick

// OutInRateTick is the tick rate for the accept, defined in
// out_asset/in_asset. This is only set in a sell accept message.
OutInRateTick acceptOutInRateTick
}

Expand Down
9 changes: 7 additions & 2 deletions tapchannel/aux_closer.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ func (a *AuxChanCloser) AuxCloseOutputs(
// anchor amt, then we'll just drop this allocation, and modify
// our asset allocation to match this value.
if amtAfterAnchor <= o.DustLimit {
localAlloc.BtcAmount = btcAmt
if localAlloc != nil {
localAlloc.BtcAmount = btcAmt
}
ffranr marked this conversation as resolved.
Show resolved Hide resolved
return
}

Expand All @@ -285,7 +287,10 @@ func (a *AuxChanCloser) AuxCloseOutputs(
// anchor amt, then we'll just drop this allocation, and modify
// our asset allocation to match this value.
if amtAfterAnchor <= o.DustLimit {
remoteAlloc.BtcAmount = btcAmt
if remoteAlloc != nil {
remoteAlloc.BtcAmount = btcAmt
}

return
}

Expand Down
Loading
Loading