diff --git a/electrum/channel_db.py b/electrum/channel_db.py index 59de812b992a..398790566056 100644 --- a/electrum/channel_db.py +++ b/electrum/channel_db.py @@ -38,7 +38,8 @@ from . import constants from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .logging import Logger -from .lnutil import LN_FEATURES_IMPLEMENTED, LNPeerAddr, format_short_channel_id, ShortChannelID +from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID, + UnknownEvenFeatureBits, validate_features) from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update from .lnmsg import decode_msg @@ -47,15 +48,6 @@ from .lnchannel import Channel -class UnknownEvenFeatureBits(Exception): pass - -def validate_features(features: int) -> None: - enabled_features = list_enabled_bits(features) - for fbit in enabled_features: - if (1 << fbit) & LN_FEATURES_IMPLEMENTED == 0 and fbit % 2 == 0: - raise UnknownEvenFeatureBits() - - FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 diff --git a/electrum/lnaddr.py b/electrum/lnaddr.py index 0547af5af590..bb201e9e165e 100644 --- a/electrum/lnaddr.py +++ b/electrum/lnaddr.py @@ -141,6 +141,14 @@ def tagged(char, l): def tagged_bytes(char, l): return tagged(char, bitstring.BitArray(l)) +def trim_to_min_length(bits): + # Get minimal length by trimming leading 5 bits at a time. + while bits.startswith('0b00000'): + if len(bits) == 5: + break # v == 0 + bits = bits[5:] + return bits + # Discard trailing bits, convert to bytes. def trim_to_bytes(barr): # Adds a byte if necessary. @@ -155,7 +163,7 @@ def pull_tagged(stream): length = stream.read(5).uint * 32 + stream.read(5).uint return (CHARSET[tag], stream.read(length * 5), stream) -def lnencode(addr, privkey): +def lnencode(addr: 'LnAddr', privkey): if addr.amount: amount = Decimal(str(addr.amount)) # We can only send down to millisatoshi. @@ -196,23 +204,24 @@ def lnencode(addr, privkey): elif k == 'd': data += tagged_bytes('d', v.encode()) elif k == 'x': - # Get minimal length by trimming leading 5 bits at a time. expirybits = bitstring.pack('intbe:64', v)[4:64] - while expirybits.startswith('0b00000'): - if len(expirybits) == 5: - break # v == 0 - expirybits = expirybits[5:] + expirybits = trim_to_min_length(expirybits) data += tagged('x', expirybits) elif k == 'h': data += tagged_bytes('h', sha256(v.encode('utf-8')).digest()) elif k == 'n': data += tagged_bytes('n', v) elif k == 'c': - # Get minimal length by trimming leading 5 bits at a time. finalcltvbits = bitstring.pack('intbe:64', v)[4:64] - while finalcltvbits.startswith('0b00000'): - finalcltvbits = finalcltvbits[5:] + finalcltvbits = trim_to_min_length(finalcltvbits) data += tagged('c', finalcltvbits) + elif k == '9': + if v == 0: + continue + feature_bits = bitstring.BitArray(uint=v, length=v.bit_length()) + while feature_bits.len % 5 != 0: + feature_bits.prepend('0b0') + data += tagged('9', feature_bits) else: # FIXME: Support unknown tags? raise ValueError("Unknown tag {}".format(k)) @@ -247,7 +256,7 @@ def __init__(self, paymenthash: bytes = None, amount=None, currency=None, tags=N self.signature = None self.pubkey = None self.currency = constants.net.SEGWIT_HRP if currency is None else currency - self.amount = amount + self.amount = amount # in bitcoins self._min_final_cltv_expiry = 9 def __str__(self): @@ -389,8 +398,16 @@ def lndecode(invoice: str, *, verbose=False, expected_hrp=None) -> LnAddr: continue pubkeybytes = trim_to_bytes(tagdata) addr.pubkey = pubkeybytes + elif tag == 'c': addr._min_final_cltv_expiry = tagdata.int + + elif tag == '9': + features = tagdata.uint + addr.tags.append(('9', features)) + from .lnutil import validate_features + validate_features(features) + else: addr.unknown_tags.append((tag, tagdata)) diff --git a/electrum/lnonion.py b/electrum/lnonion.py index 51bab2817cb4..15199320396d 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -87,7 +87,7 @@ def from_bytes(cls, b: bytes) -> 'LegacyHopDataPayload': @classmethod def from_tlv_dict(cls, d: dict) -> 'LegacyHopDataPayload': return LegacyHopDataPayload( - short_channel_id=d["short_channel_id"]["short_channel_id"], + short_channel_id=d["short_channel_id"]["short_channel_id"] if "short_channel_id" in d else b"\x00" * 8, amt_to_forward=d["amt_to_forward"]["amt_to_forward"], outgoing_cltv_value=d["outgoing_cltv_value"]["outgoing_cltv_value"], ) diff --git a/electrum/lnutil.py b/electrum/lnutil.py index d237bf50b91d..9038b5e92a14 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -165,6 +165,7 @@ class HandshakeFailed(LightningError): pass class ConnStringFormatError(LightningError): pass class UnknownPaymentHash(LightningError): pass class RemoteMisbehaving(LightningError): pass +class UnknownEvenFeatureBits(Exception): pass class NotFoundChanAnnouncementForUpdate(Exception): pass @@ -882,6 +883,16 @@ def ln_compare_features(our_features: 'LnFeatures', their_features: int) -> 'LnF return our_features +def validate_features(features: int) -> None: + """Raises UnknownEvenFeatureBits if there is an unimplemented + mandatory feature. + """ + enabled_features = list_enabled_bits(features) + for fbit in enabled_features: + if (1 << fbit) & LN_FEATURES_IMPLEMENTED == 0 and fbit % 2 == 0: + raise UnknownEvenFeatureBits(fbit) + + class LNPeerAddr: def __init__(self, host: str, port: int, pubkey: bytes): diff --git a/electrum/tests/test_bolt11.py b/electrum/tests/test_bolt11.py index efe8105b44dc..88b18768407f 100644 --- a/electrum/tests/test_bolt11.py +++ b/electrum/tests/test_bolt11.py @@ -6,6 +6,7 @@ from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode, u5_to_bitarray, bitarray_to_u5 from electrum.segwit_addr import bech32_encode, bech32_decode +from electrum.lnutil import UnknownEvenFeatureBits from . import ElectrumTestCase @@ -66,11 +67,22 @@ def test_roundtrip(self): LnAddr(RHASH, amount=Decimal('1'), tags=[('h', longdescription)]), LnAddr(RHASH, currency='tb', tags=[('f', 'mk2QpYatsKicvFVuTAQLBryyccRXMUaGHP'), ('h', longdescription)]), LnAddr(RHASH, amount=24, tags=[ - ('r', [(unhexlify('029e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('0102030405060708'), 1, 20, 3), (unhexlify('039e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('030405060708090a'), 2, 30, 4)]), ('f', '1RustyRX2oai4EYYDpQGWvEL62BBGqN9T'), ('h', longdescription)]), + ('r', [(unhexlify('029e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('0102030405060708'), 1, 20, 3), + (unhexlify('039e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('030405060708090a'), 2, 30, 4)]), + ('f', '1RustyRX2oai4EYYDpQGWvEL62BBGqN9T'), + ('h', longdescription)]), LnAddr(RHASH, amount=24, tags=[('f', '3EktnHQD7RiAE6uzMj2ZifT9YgRrkSgzQX'), ('h', longdescription)]), LnAddr(RHASH, amount=24, tags=[('f', 'bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4'), ('h', longdescription)]), LnAddr(RHASH, amount=24, tags=[('f', 'bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3'), ('h', longdescription)]), LnAddr(RHASH, amount=24, tags=[('n', PUBKEY), ('h', longdescription)]), + LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 514)]), + LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 8))]), + LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 9))]), + LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 11))]), + LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 12))]), + LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 13))]), + #LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 14))]), + LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 15))]), ] # Roundtrip @@ -98,9 +110,18 @@ def test_n_decoding(self): assert lnaddr.pubkey.serialize() == PUBKEY def test_min_final_cltv_expiry_decoding(self): - self.assertEqual(144, lndecode("lnsb500u1pdsgyf3pp5nmrqejdsdgs4n9ukgxcp2kcq265yhrxd4k5dyue58rxtp5y83s3qdqqcqzystrggccm9yvkr5yqx83jxll0qjpmgfg9ywmcd8g33msfgmqgyfyvqhku80qmqm8q6v35zvck2y5ccxsz5avtrauz8hgjj3uahppyq20qp6dvwxe", expected_hrp="sb").get_min_final_cltv_expiry()) + lnaddr = lndecode("lnsb500u1pdsgyf3pp5nmrqejdsdgs4n9ukgxcp2kcq265yhrxd4k5dyue58rxtp5y83s3qdqqcqzystrggccm9yvkr5yqx83jxll0qjpmgfg9ywmcd8g33msfgmqgyfyvqhku80qmqm8q6v35zvck2y5ccxsz5avtrauz8hgjj3uahppyq20qp6dvwxe", + expected_hrp="sb") + self.assertEqual(144, lnaddr.get_min_final_cltv_expiry()) def test_min_final_cltv_expiry_roundtrip(self): lnaddr = LnAddr(RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60), ('c', 150)]) invoice = lnencode(lnaddr, PRIVKEY) self.assertEqual(150, lndecode(invoice).get_min_final_cltv_expiry()) + + def test_features(self): + lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9qzsze992adudgku8p05pstl6zh7av6rx2f297pv89gu5q93a0hf3g7lynl3xq56t23dpvah6u7y9qey9lccrdml3gaqwc6nxsl5ktzm464sq73t7cl") + self.assertEqual(514, lnaddr.get_tag('9')) + + with self.assertRaises(UnknownEvenFeatureBits): + lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9q4pqqqqqqqqqqqqqqqqqqszk3ed62snp73037h4py4gry05eltlp0uezm2w9ajnerhmxzhzhsu40g9mgyx5v3ad4aqwkmvyftzk4k9zenz90mhjcy9hcevc7r3lx2sphzfxz7")