diff --git a/Makefile b/Makefile index c62408d8..131a3e72 100644 --- a/Makefile +++ b/Makefile @@ -100,7 +100,8 @@ test-bitcoin-lnd: test-bins 'Test_LndLnd_Bitcoin_SwapOut|'\ 'Test_LndLnd_Bitcoin_SwapIn|'\ 'Test_LndCln_Bitcoin_SwapOut|'\ - 'Test_LndCln_Bitcoin_SwapIn)'\ + 'Test_LndCln_Bitcoin_SwapIn|'\ + 'Test_LndLnd_ExcessiveAmount)'\ ./test ${INTEGRATION_TEST_ENV} go test $(INTEGRATION_TEST_OPTS) ./lnd .PHONY: test-bitcoin-lnd diff --git a/lnd/client.go b/lnd/client.go index 45bc979a..2a9bbc87 100644 --- a/lnd/client.go +++ b/lnd/client.go @@ -77,6 +77,32 @@ func (l *Client) CanSpend(amtMsat uint64) error { return nil } +// getMaxHtlcAmtMsat returns the maximum htlc amount in msat for a channel. +// If for some reason it cannot be retrieved, return 0. +func (l *Client) getMaxHtlcAmtMsat(chanId uint64, pubkey string) (uint64, error) { + var maxHtlcAmtMsat uint64 = 0 + r, err := l.lndClient.GetChanInfo(context.Background(), &lnrpc.ChanInfoRequest{ + ChanId: chanId, + }) + if err != nil { + // Ignore err because channel graph information is not always set. + return maxHtlcAmtMsat, nil + } + if r.Node1Pub == pubkey { + maxHtlcAmtMsat = r.GetNode1Policy().GetMaxHtlcMsat() + } else if r.Node2Pub == pubkey { + maxHtlcAmtMsat = r.GetNode2Policy().GetMaxHtlcMsat() + } + return maxHtlcAmtMsat, nil +} + +func min(x, y uint64) uint64 { + if x < y { + return x + } + return y +} + // SpendableMsat returns an estimate of the total we could send through the // channel with given scid. func (l *Client) SpendableMsat(scid string) (uint64, error) { @@ -96,7 +122,18 @@ func (l *Client) SpendableMsat(scid string) (uint64, error) { if err = l.checkChannel(ch); err != nil { return 0, err } - return uint64(ch.LocalBalance * 1000), nil + maxHtlcAmtMsat, err := l.getMaxHtlcAmtMsat(ch.ChanId, l.pubkey) + if err != nil { + return 0, err + } + spendable := uint64(ch.LocalBalance * 1000) + // since the max htlc limit is not always set reliably, + // the check is skipped if it is not set. + if maxHtlcAmtMsat == 0 { + return spendable, nil + } + return min(maxHtlcAmtMsat, spendable), nil + } } return 0, fmt.Errorf("could not find a channel with scid: %s", scid) diff --git a/swap/actions.go b/swap/actions.go index d8b5e463..6d40b46c 100644 --- a/swap/actions.go +++ b/swap/actions.go @@ -579,6 +579,15 @@ func (r *PayFeeInvoiceAction) Execute(services *SwapServices, swap *SwapData) Ev return swap.HandleError(err) } + sp, err := ll.SpendableMsat(swap.SwapOutRequest.Scid) + if err != nil { + return swap.HandleError(err) + } + + if sp <= swap.SwapOutRequest.Amount*1000 { + return swap.HandleError(err) + } + swap.OpeningTxFee = msatAmt / 1000 expectedFee, err := wallet.GetFlatSwapOutFee() diff --git a/test/bitcoin_lnd_test.go b/test/bitcoin_lnd_test.go index 93f49e90..b7dabbb2 100644 --- a/test/bitcoin_lnd_test.go +++ b/test/bitcoin_lnd_test.go @@ -1099,3 +1099,89 @@ func Test_LndCln_Bitcoin_SwapOut(t *testing.T) { csvClaimTest(t, params) }) } + +func Test_LndLnd_ExcessiveAmount(t *testing.T) { + IsIntegrationTest(t) + t.Parallel() + t.Run("exceed_maxhtlc", func(t *testing.T) { + t.Parallel() + require := require.New(t) + + bitcoind, lightningds, peerswapds, scid := lndlndSetup(t, uint64(math.Pow10(9))) + defer func() { + if t.Failed() { + pprintFail( + tailableProcess{ + p: bitcoind.DaemonProcess, + lines: defaultLines, + }, + tailableProcess{ + p: lightningds[0].DaemonProcess, + lines: defaultLines, + }, + tailableProcess{ + p: lightningds[1].DaemonProcess, + lines: defaultLines, + }, + tailableProcess{ + p: peerswapds[0].DaemonProcess, + lines: defaultLines, + }, + tailableProcess{ + p: peerswapds[1].DaemonProcess, + lines: defaultLines, + }, + ) + } + }() + + var channelBalances []uint64 + var walletBalances []uint64 + for _, lightningd := range lightningds { + b, err := lightningd.GetBtcBalanceSat() + require.NoError(err) + walletBalances = append(walletBalances, b) + + b, err = lightningd.GetChannelBalanceSat(scid) + require.NoError(err) + channelBalances = append(channelBalances, b) + } + + lcid, err := lightningds[0].ChanIdFromScid(scid) + if err != nil { + t.Fatalf("lightingds[0].ChanIdFromScid() %v", err) + } + + params := &testParams{ + swapAmt: channelBalances[0] / 2, + scid: scid, + origTakerWallet: walletBalances[0], + origMakerWallet: walletBalances[1], + origTakerBalance: channelBalances[0], + origMakerBalance: channelBalances[1], + takerNode: lightningds[0], + makerNode: lightningds[1], + takerPeerswap: peerswapds[0].DaemonProcess, + makerPeerswap: peerswapds[1].DaemonProcess, + chainRpc: bitcoind.RpcProxy, + chaind: bitcoind, + confirms: BitcoinConfirms, + csv: BitcoinCsv, + swapType: swap.SWAPTYPE_OUT, + } + asset := "btc" + + _, err = lightningds[0].SetHtlcMaximumMilliSatoshis(scid, channelBalances[0]*1000/2-1) + assert.NoError(t, err) + // Swap out should fail as the swap_amt is to high. + // Do swap. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err = peerswapds[0].PeerswapClient.SwapOut(ctx, &peerswaprpc.SwapOutRequest{ + ChannelId: lcid, + SwapAmount: params.swapAmt, + Asset: asset, + }) + assert.Error(t, err) + }) +} diff --git a/testframework/clightning.go b/testframework/clightning.go index 7493999c..a11c2f24 100644 --- a/testframework/clightning.go +++ b/testframework/clightning.go @@ -11,6 +11,7 @@ import ( "time" "github.com/elementsproject/glightning/glightning" + "github.com/elementsproject/peerswap/lightning" ) type CLightningNode struct { @@ -519,3 +520,38 @@ func (n *CLightningNode) GetFeeInvoiceAmtSat() (sat uint64, err error) { } return feeInvoiceAmt, nil } + +type SetChannel struct { + Id string `json:"id"` + HtlcMaximumMilliSatoshis string `json:"htlcmax,omitempty"` +} + +type ChannelInfo struct { + PeerID string `json:"peer_id"` + ChannelID string `json:"channel_id"` + ShortChannelID string `json:"short_channel_id"` + FeeBaseMsat glightning.Amount `json:"fee_base_msat"` + FeeProportionalMillionths glightning.Amount `json:"fee_proportional_millionths"` + MinimumHtlcOutMsat glightning.Amount `json:"minimum_htlc_out_msat"` + MaximumHtlcOutMsat glightning.Amount `json:"maximum_htlc_out_msat"` +} + +type SetChannelResponse struct { + Channels []ChannelInfo `json:"channels"` +} + +func (r *SetChannel) Name() string { + return "setchannel" +} + +func (n *CLightningNode) SetHtlcMaximumMilliSatoshis(scid string, maxHtlcMsat uint64) (msat uint64, err error) { + var res SetChannelResponse + err = n.Rpc.Request(&SetChannel{ + Id: lightning.Scid(scid).ClnStyle(), + HtlcMaximumMilliSatoshis: fmt.Sprint(maxHtlcMsat), + }, &res) + if err != nil { + return 0, err + } + return maxHtlcMsat, err +} diff --git a/testframework/lightning.go b/testframework/lightning.go index 1ab15e52..444fd4ce 100644 --- a/testframework/lightning.go +++ b/testframework/lightning.go @@ -33,5 +33,6 @@ type LightningNode interface { GetFeeInvoiceAmtSat() (sat uint64, err error) Run(waitForReady, swaitForBitcoinSynced bool) error + SetHtlcMaximumMilliSatoshis(scid string, maxHtlcMsat uint64) (msat uint64, err error) Stop() error } diff --git a/testframework/lnd.go b/testframework/lnd.go index 7be94494..d87ad0bb 100644 --- a/testframework/lnd.go +++ b/testframework/lnd.go @@ -9,7 +9,10 @@ import ( "os" "path/filepath" "regexp" + "strconv" + "strings" + "github.com/elementsproject/peerswap/lightning" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" ) @@ -529,3 +532,50 @@ func (n *LndNode) GetFeeInvoiceAmtSat() (sat uint64, err error) { } return feeInvoiceAmt, nil } + +func (n *LndNode) SetHtlcMaximumMilliSatoshis(scid string, maxHtlcMsat uint64) (msat uint64, err error) { + s := lightning.Scid(scid) + res, err := n.Rpc.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{}) + if err != nil { + return 0, fmt.Errorf("ListChannels() %w", err) + } + for _, ch := range res.GetChannels() { + channelShortId := lnwire.NewShortChanIDFromInt(ch.ChanId) + if channelShortId.String() == s.LndStyle() { + r, err := n.Rpc.GetChanInfo(context.Background(), &lnrpc.ChanInfoRequest{ + ChanId: ch.ChanId, + }) + if err != nil { + return 0, err + } + parts := strings.Split(r.ChanPoint, ":") + if len(parts) != 2 { + return 0, fmt.Errorf("expected scid to be composed of 3 blocks") + } + txPosition, err := strconv.Atoi(parts[1]) + if err != nil { + return 0, err + } + _, err = n.Rpc.UpdateChannelPolicy(context.Background(), &lnrpc.PolicyUpdateRequest{ + Scope: &lnrpc.PolicyUpdateRequest_ChanPoint{ChanPoint: &lnrpc.ChannelPoint{ + FundingTxid: &lnrpc.ChannelPoint_FundingTxidStr{ + FundingTxidStr: parts[0], + }, + OutputIndex: uint32(txPosition), + }}, + BaseFeeMsat: 1000, + FeeRate: 1, + FeeRatePpm: 0, + TimeLockDelta: 40, + MaxHtlcMsat: maxHtlcMsat, + MinHtlcMsat: msat, + MinHtlcMsatSpecified: false, + }) + if err != nil { + return 0, err + } + return maxHtlcMsat, err + } + } + return 0, fmt.Errorf("could not find a channel with scid: %s", scid) +}