diff --git a/test/testcases.go b/test/testcases.go index d0bc13e1..21a093c9 100644 --- a/test/testcases.go +++ b/test/testcases.go @@ -44,6 +44,9 @@ func coopClaimTest(t *testing.T, params *testParams) { // // Move local balance from taker to maker so that the taker does not // have enough balance to pay the invoice and cancels the swap coop. + feeInvoiceAmt, err := params.makerNode.GetFeeInvoiceAmtSat() + require.NoError(err) + moveAmt := (params.origTakerBalance - params.swapAmt) + 100 inv, err := params.makerNode.AddInvoice(moveAmt, "shift balance", "") require.NoError(err) @@ -56,8 +59,10 @@ func coopClaimTest(t *testing.T, params *testParams) { err = testframework.WaitFor(func() bool { setTakerFunds, err = params.takerNode.GetChannelBalanceSat(params.scid) require.NoError(err) - return params.origTakerBalance-moveAmt-10 < setTakerFunds && setTakerFunds < params.origTakerBalance-moveAmt+10 + return params.origTakerBalance-moveAmt-10-feeInvoiceAmt < setTakerFunds && + setTakerFunds < params.origTakerBalance-moveAmt+10-feeInvoiceAmt }, testframework.TIMEOUT) + require.NoError(err) // // STEP 3: Confirm opening tx diff --git a/testframework/clightning.go b/testframework/clightning.go index f5c91949..7b62578a 100644 --- a/testframework/clightning.go +++ b/testframework/clightning.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "regexp" + "strings" "time" "github.com/elementsproject/glightning/glightning" @@ -504,3 +505,17 @@ func (n *CLightningNode) GetMemoFromPayreq(bolt11 string) (string, error) { return r.Description, nil } + +func (n *CLightningNode) GetFeeInvoiceAmtSat() (sat uint64, err error) { + var feeInvoiceAmt uint64 + r, err := n.Rpc.ListInvoices() + if err != nil { + return 0, err + } + for _, i := range r { + if strings.Contains(i.Description, "fee") { + feeInvoiceAmt += i.AmountMilliSatoshi.MSat() / 1000 + } + } + return feeInvoiceAmt, nil +} diff --git a/testframework/lightning.go b/testframework/lightning.go index 341ec7ba..1ab15e52 100644 --- a/testframework/lightning.go +++ b/testframework/lightning.go @@ -30,6 +30,7 @@ type LightningNode interface { // invoices. GetLatestInvoice() (payreq string, err error) GetMemoFromPayreq(payreq string) (memo string, err error) + GetFeeInvoiceAmtSat() (sat uint64, err error) Run(waitForReady, swaitForBitcoinSynced bool) error Stop() error diff --git a/testframework/lnd.go b/testframework/lnd.go index c9ae7f52..1fc90812 100644 --- a/testframework/lnd.go +++ b/testframework/lnd.go @@ -8,6 +8,7 @@ import ( "math" "os" "path/filepath" + "strings" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" @@ -516,3 +517,18 @@ func ScidFromLndChanId(id uint64) string { lndScid := lnwire.NewShortChanIDFromInt(id) return fmt.Sprintf("%dx%dx%d", lndScid.BlockHeight, lndScid.TxIndex, lndScid.TxPosition) } + +func (n *LndNode) GetFeeInvoiceAmtSat() (sat uint64, err error) { + var feeInvoiceAmt uint64 + r, err := n.Rpc.ListInvoices(context.Background(), &lnrpc.ListInvoiceRequest{}) + if err != nil { + return 0, err + } + + for _, i := range r.Invoices { + if strings.Contains(i.GetMemo(), "fee") { + feeInvoiceAmt += uint64(i.GetValue()) + } + } + return feeInvoiceAmt, nil +}