Skip to content

Commit

Permalink
lnaddr: add feature bit support to invoices
Browse files Browse the repository at this point in the history
  • Loading branch information
SomberNight authored and sidhujag committed Apr 2, 2020
1 parent 5d1d41a commit dc9dae4
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 23 deletions.
12 changes: 2 additions & 10 deletions electrum/channel_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .logging import Logger
from .lnutil import LN_FEATURES_IMPLEMENTED, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
UnknownEvenFeatureBits, validate_features)
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg

Expand All @@ -47,15 +48,6 @@
from .lnchannel import Channel


class UnknownEvenFeatureBits(Exception): pass

def validate_features(features: int) -> None:
enabled_features = list_enabled_bits(features)
for fbit in enabled_features:
if (1 << fbit) & LN_FEATURES_IMPLEMENTED == 0 and fbit % 2 == 0:
raise UnknownEvenFeatureBits()


FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0

Expand Down
37 changes: 27 additions & 10 deletions electrum/lnaddr.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def tagged(char, l):
def tagged_bytes(char, l):
return tagged(char, bitstring.BitArray(l))

def trim_to_min_length(bits):
# Get minimal length by trimming leading 5 bits at a time.
while bits.startswith('0b00000'):
if len(bits) == 5:
break # v == 0
bits = bits[5:]
return bits

# Discard trailing bits, convert to bytes.
def trim_to_bytes(barr):
# Adds a byte if necessary.
Expand All @@ -155,7 +163,7 @@ def pull_tagged(stream):
length = stream.read(5).uint * 32 + stream.read(5).uint
return (CHARSET[tag], stream.read(length * 5), stream)

def lnencode(addr, privkey):
def lnencode(addr: 'LnAddr', privkey):
if addr.amount:
amount = Decimal(str(addr.amount))
# We can only send down to millisatoshi.
Expand Down Expand Up @@ -196,23 +204,24 @@ def lnencode(addr, privkey):
elif k == 'd':
data += tagged_bytes('d', v.encode())
elif k == 'x':
# Get minimal length by trimming leading 5 bits at a time.
expirybits = bitstring.pack('intbe:64', v)[4:64]
while expirybits.startswith('0b00000'):
if len(expirybits) == 5:
break # v == 0
expirybits = expirybits[5:]
expirybits = trim_to_min_length(expirybits)
data += tagged('x', expirybits)
elif k == 'h':
data += tagged_bytes('h', sha256(v.encode('utf-8')).digest())
elif k == 'n':
data += tagged_bytes('n', v)
elif k == 'c':
# Get minimal length by trimming leading 5 bits at a time.
finalcltvbits = bitstring.pack('intbe:64', v)[4:64]
while finalcltvbits.startswith('0b00000'):
finalcltvbits = finalcltvbits[5:]
finalcltvbits = trim_to_min_length(finalcltvbits)
data += tagged('c', finalcltvbits)
elif k == '9':
if v == 0:
continue
feature_bits = bitstring.BitArray(uint=v, length=v.bit_length())
while feature_bits.len % 5 != 0:
feature_bits.prepend('0b0')
data += tagged('9', feature_bits)
else:
# FIXME: Support unknown tags?
raise ValueError("Unknown tag {}".format(k))
Expand Down Expand Up @@ -247,7 +256,7 @@ def __init__(self, paymenthash: bytes = None, amount=None, currency=None, tags=N
self.signature = None
self.pubkey = None
self.currency = constants.net.SEGWIT_HRP if currency is None else currency
self.amount = amount
self.amount = amount # in bitcoins
self._min_final_cltv_expiry = 9

def __str__(self):
Expand Down Expand Up @@ -389,8 +398,16 @@ def lndecode(invoice: str, *, verbose=False, expected_hrp=None) -> LnAddr:
continue
pubkeybytes = trim_to_bytes(tagdata)
addr.pubkey = pubkeybytes

elif tag == 'c':
addr._min_final_cltv_expiry = tagdata.int

elif tag == '9':
features = tagdata.uint
addr.tags.append(('9', features))
from .lnutil import validate_features
validate_features(features)

else:
addr.unknown_tags.append((tag, tagdata))

Expand Down
2 changes: 1 addition & 1 deletion electrum/lnonion.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def from_bytes(cls, b: bytes) -> 'LegacyHopDataPayload':
@classmethod
def from_tlv_dict(cls, d: dict) -> 'LegacyHopDataPayload':
return LegacyHopDataPayload(
short_channel_id=d["short_channel_id"]["short_channel_id"],
short_channel_id=d["short_channel_id"]["short_channel_id"] if "short_channel_id" in d else b"\x00" * 8,
amt_to_forward=d["amt_to_forward"]["amt_to_forward"],
outgoing_cltv_value=d["outgoing_cltv_value"]["outgoing_cltv_value"],
)
Expand Down
11 changes: 11 additions & 0 deletions electrum/lnutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class HandshakeFailed(LightningError): pass
class ConnStringFormatError(LightningError): pass
class UnknownPaymentHash(LightningError): pass
class RemoteMisbehaving(LightningError): pass
class UnknownEvenFeatureBits(Exception): pass

class NotFoundChanAnnouncementForUpdate(Exception): pass

Expand Down Expand Up @@ -882,6 +883,16 @@ def ln_compare_features(our_features: 'LnFeatures', their_features: int) -> 'LnF
return our_features


def validate_features(features: int) -> None:
"""Raises UnknownEvenFeatureBits if there is an unimplemented
mandatory feature.
"""
enabled_features = list_enabled_bits(features)
for fbit in enabled_features:
if (1 << fbit) & LN_FEATURES_IMPLEMENTED == 0 and fbit % 2 == 0:
raise UnknownEvenFeatureBits(fbit)


class LNPeerAddr:

def __init__(self, host: str, port: int, pubkey: bytes):
Expand Down
25 changes: 23 additions & 2 deletions electrum/tests/test_bolt11.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode, u5_to_bitarray, bitarray_to_u5
from electrum.segwit_addr import bech32_encode, bech32_decode
from electrum.lnutil import UnknownEvenFeatureBits

from . import ElectrumTestCase

Expand Down Expand Up @@ -66,11 +67,22 @@ def test_roundtrip(self):
LnAddr(RHASH, amount=Decimal('1'), tags=[('h', longdescription)]),
LnAddr(RHASH, currency='tb', tags=[('f', 'mk2QpYatsKicvFVuTAQLBryyccRXMUaGHP'), ('h', longdescription)]),
LnAddr(RHASH, amount=24, tags=[
('r', [(unhexlify('029e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('0102030405060708'), 1, 20, 3), (unhexlify('039e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('030405060708090a'), 2, 30, 4)]), ('f', '1RustyRX2oai4EYYDpQGWvEL62BBGqN9T'), ('h', longdescription)]),
('r', [(unhexlify('029e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('0102030405060708'), 1, 20, 3),
(unhexlify('039e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('030405060708090a'), 2, 30, 4)]),
('f', '1RustyRX2oai4EYYDpQGWvEL62BBGqN9T'),
('h', longdescription)]),
LnAddr(RHASH, amount=24, tags=[('f', '3EktnHQD7RiAE6uzMj2ZifT9YgRrkSgzQX'), ('h', longdescription)]),
LnAddr(RHASH, amount=24, tags=[('f', 'bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4'), ('h', longdescription)]),
LnAddr(RHASH, amount=24, tags=[('f', 'bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3'), ('h', longdescription)]),
LnAddr(RHASH, amount=24, tags=[('n', PUBKEY), ('h', longdescription)]),
LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 514)]),
LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 8))]),
LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 9))]),
LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 11))]),
LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 12))]),
LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 13))]),
#LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 14))]),
LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 15))]),
]

# Roundtrip
Expand Down Expand Up @@ -98,9 +110,18 @@ def test_n_decoding(self):
assert lnaddr.pubkey.serialize() == PUBKEY

def test_min_final_cltv_expiry_decoding(self):
self.assertEqual(144, lndecode("lnsb500u1pdsgyf3pp5nmrqejdsdgs4n9ukgxcp2kcq265yhrxd4k5dyue58rxtp5y83s3qdqqcqzystrggccm9yvkr5yqx83jxll0qjpmgfg9ywmcd8g33msfgmqgyfyvqhku80qmqm8q6v35zvck2y5ccxsz5avtrauz8hgjj3uahppyq20qp6dvwxe", expected_hrp="sb").get_min_final_cltv_expiry())
lnaddr = lndecode("lnsb500u1pdsgyf3pp5nmrqejdsdgs4n9ukgxcp2kcq265yhrxd4k5dyue58rxtp5y83s3qdqqcqzystrggccm9yvkr5yqx83jxll0qjpmgfg9ywmcd8g33msfgmqgyfyvqhku80qmqm8q6v35zvck2y5ccxsz5avtrauz8hgjj3uahppyq20qp6dvwxe",
expected_hrp="sb")
self.assertEqual(144, lnaddr.get_min_final_cltv_expiry())

def test_min_final_cltv_expiry_roundtrip(self):
lnaddr = LnAddr(RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60), ('c', 150)])
invoice = lnencode(lnaddr, PRIVKEY)
self.assertEqual(150, lndecode(invoice).get_min_final_cltv_expiry())

def test_features(self):
lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9qzsze992adudgku8p05pstl6zh7av6rx2f297pv89gu5q93a0hf3g7lynl3xq56t23dpvah6u7y9qey9lccrdml3gaqwc6nxsl5ktzm464sq73t7cl")
self.assertEqual(514, lnaddr.get_tag('9'))

with self.assertRaises(UnknownEvenFeatureBits):
lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9q4pqqqqqqqqqqqqqqqqqqszk3ed62snp73037h4py4gry05eltlp0uezm2w9ajnerhmxzhzhsu40g9mgyx5v3ad4aqwkmvyftzk4k9zenz90mhjcy9hcevc7r3lx2sphzfxz7")

0 comments on commit dc9dae4

Please sign in to comment.