From 3bbe4038523113fb811eb43ed3d6ca4137239855 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Thu, 6 Oct 2022 09:46:40 +0200 Subject: [PATCH 1/4] rewrite LNTransport using aiorpcx - prepare for use in ElectrumX: - use create_server and create_connection - subclass aiorpcx.RSTransport - use random, non-persisted privkey --- electrum/interface.py | 38 ++++-- electrum/lnpeer.py | 15 +-- electrum/lntransport.py | 252 +++++++++++++++++++++++++--------------- electrum/lnworker.py | 38 +++--- 4 files changed, 218 insertions(+), 125 deletions(-) diff --git a/electrum/interface.py b/electrum/interface.py index 9edbd41dca5d..c6dbad5b406c 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -58,6 +58,8 @@ from .i18n import _ from .logging import Logger from .transaction import Transaction +from .lnutil import LNPeerAddr +from .lntransport import LNClient if TYPE_CHECKING: from .network import Network @@ -70,7 +72,7 @@ MAX_INCOMING_MSG_SIZE = 1_000_000 # in bytes -_KNOWN_NETWORK_PROTOCOLS = {'t', 's'} +_KNOWN_NETWORK_PROTOCOLS = {'t', 's', 'b'} PREFERRED_NETWORK_PROTOCOL = 's' assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS @@ -271,7 +273,7 @@ async def create_connection(self): class ServerAddr: - def __init__(self, host: str, port: Union[int, str], *, protocol: str = None): + def __init__(self, host: str, port: Union[int, str], *, protocol: str = None, pubkey: str = None): assert isinstance(host, str), repr(host) if protocol is None: protocol = 's' @@ -288,13 +290,20 @@ def __init__(self, host: str, port: Union[int, str], *, protocol: str = None): self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address) self.port = int(net_addr.port) self.protocol = protocol + self.pubkey = pubkey self._net_addr_str = str(net_addr) @classmethod def from_str(cls, s: str) -> 'ServerAddr': # host might be IPv6 address, hence do rsplit: - host, port, protocol = str(s).rsplit(':', 2) - return ServerAddr(host=host, port=port, protocol=protocol) + s = str(s).rsplit(':', 3) + if len(s) == 4: + host, port, protocol, pubkey = s + assert protocol == 'b' + elif len(s) == 3: + host, port, protocol = s + pubkey = None + return ServerAddr(host=host, port=port, protocol=protocol, pubkey=pubkey) @classmethod def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']: @@ -654,11 +663,26 @@ def is_main_server(self) -> bool: return (self.network.interface == self or self.network.interface is None and self.network.default_server == self.server) + #@log_exceptions async def open_session(self, sslc, exit_early=False): session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface) - async with _RSClient(session_factory=session_factory, - host=self.host, port=self.port, - ssl=sslc, proxy=self.proxy) as session: + + def create_client(): + if self.protocol == 'b': + peer_addr = LNPeerAddr(self.host, self.port, bytes.fromhex(self.server.pubkey)) + bolt8_privkey = os.urandom(32) + return LNClient( + privkey=bolt8_privkey, + session_factory=session_factory, + peer_addr=peer_addr, + proxy=self.proxy) + else: + return _RSClient( + session_factory=session_factory, + host=self.host, port=self.port, + ssl=sslc, proxy=self.proxy) + + async with create_client() as session: self.session = session # type: NotificationSession self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic)) try: diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index e393ecc22bd5..49570d68fae0 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -44,7 +44,7 @@ IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage, ChannelType, LNProtocolWarning) from .lnutil import FeeUpdate, channel_id_from_funding_tx -from .lntransport import LNTransport, LNTransportBase +from .lntransport import LNTransport from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType from .interface import GracefulDisconnect from .lnrouter import fee_for_edge_msat @@ -76,7 +76,7 @@ def __init__( self, lnworker: Union['LNGossip', 'LNWallet'], pubkey: bytes, - transport: LNTransportBase, + transport: LNTransport, *, is_channel_backup= False): self.lnworker = lnworker @@ -90,7 +90,7 @@ def __init__( self.querying = asyncio.Event() self.transport = transport self.pubkey = pubkey # remote pubkey - self.privkey = self.transport.privkey # local privkey + self.privkey = self.transport._privkey # local privkey self.features = self.lnworker.features # type: LnFeatures self.their_features = LnFeatures(0) # type: LnFeatures self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)] @@ -155,10 +155,7 @@ def is_initialized(self) -> bool: and self.initialized.result() is True) async def initialize(self): - # If outgoing transport, do handshake now. For incoming, it has already been done. - if isinstance(self.transport, LNTransport): - await self.transport.handshake() - self.logger.info(f"handshake done for {self.transport.peer_addr or self.pubkey.hex()}") + assert self.transport.handshake_done.is_set() features = self.features.for_init_message() b = int.bit_length(features) flen = b // 8 + int(bool(b % 8)) @@ -847,7 +844,7 @@ async def channel_establishment_flow( ) chan.storage['funding_inputs'] = [txin.prevout.to_json() for txin in funding_tx.inputs()] chan.storage['has_onchain_backup'] = has_onchain_backup - if isinstance(self.transport, LNTransport): + if not self.transport.is_listener(): chan.add_or_update_peer_addr(self.transport.peer_addr) sig_64, _ = chan.sign_next_commitment() self.temp_id_to_id[temp_channel_id] = channel_id @@ -1024,7 +1021,7 @@ async def on_open_channel(self, payload): initial_feerate=feerate ) chan.storage['init_timestamp'] = int(time.time()) - if isinstance(self.transport, LNTransport): + if not self.transport.is_listener(): chan.add_or_update_peer_addr(self.transport.peer_addr) remote_sig = funding_created['signature'] try: diff --git a/electrum/lntransport.py b/electrum/lntransport.py index a919688c5b62..2eb658b1fef3 100644 --- a/electrum/lntransport.py +++ b/electrum/lntransport.py @@ -7,15 +7,43 @@ import hashlib import asyncio -from asyncio import StreamReader, StreamWriter +from asyncio import Queue from typing import Optional -from functools import cached_property +from functools import cached_property, partial from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt -from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed, - HandshakeFailed, LNPeerAddr) +from .lnutil import (get_ecdh, privkey_to_pubkey, HandshakeFailed) from . import ecc -from .util import bh2u, MySocksProxy +from .util import MySocksProxy, log_exceptions +from .logging import Logger + +from aiorpcx.util import NetAddress +from aiorpcx.session import SessionKind, SessionBase +from aiorpcx.framing import FramerBase +from aiorpcx.rawsocket import RSTransport, ConnectionLostError + + +class QueueFramer(FramerBase): + + def __init__(self): + self.queue = Queue() + + def frame(self, message): + raise NotImplementedError + + def received_message(self, msg): + self.queue.put_nowait(msg) + + async def receive_message(self): + msg = await self.queue.get() + return msg + + def fail(self, exception): + self.exception = exception + + +class LNSession(SessionBase): + pass class HandshakeState(object): @@ -89,22 +117,39 @@ def create_ephemeral_key() -> (bytes, bytes): return privkey.get_secret_bytes(), privkey.get_public_key_bytes() -class LNTransportBase: - reader: StreamReader - writer: StreamWriter - privkey: bytes - peer_addr: Optional[LNPeerAddr] = None +class LNTransport(RSTransport, Logger): - def name(self) -> str: - pubkey = self.remote_pubkey() - pubkey_hex = pubkey.hex() if pubkey else pubkey - return f"{pubkey_hex[:10]}-{self._id_hash[:8]}" + _privkey: bytes + _remote_pubkey: bytes - @cached_property - def _id_hash(self) -> str: - id_int = id(self) - id_bytes = id_int.to_bytes((id_int.bit_length() + 7) // 8, byteorder='big') - return sha256(id_bytes).hex() + def __init__(self, session_factory, privkey, peer_addr=None): + framer = QueueFramer() + kind = SessionKind.SERVER if peer_addr is None else SessionKind.CLIENT + self.peer_addr = peer_addr # todo: remove this, pass only pubkey + self._remote_pubkey = peer_addr.pubkey if peer_addr else None + + Logger.__init__(self) + RSTransport.__init__(self, session_factory, framer, kind) + assert type(privkey) is bytes and len(privkey) == 32 + self._privkey = privkey + self._data = bytearray() + self._data_received = asyncio.Event() + self.handshake_done = asyncio.Event() + + def is_listener(self): + return self.kind == SessionKind.SERVER + + @log_exceptions + async def read_data(self, len): + await self._data_received.wait() + chunk = self._data[0:len] + self._data = self._data[len:] + if not self._data: + self._data_received.clear() + return chunk + + async def write(self, message) -> None: + self.send_bytes(message) def send_bytes(self, msg: bytes) -> None: l = len(msg).to_bytes(2, 'big') @@ -112,32 +157,36 @@ def send_bytes(self, msg: bytes) -> None: c = aead_encrypt(self.sk, self.sn(), b'', msg) assert len(lc) == 18 assert len(c) == len(msg) + 16 - self.writer.write(lc+c) + self._asyncio_transport.write(lc+c) - async def read_messages(self): - buffer = bytearray() + @log_exceptions + async def decrypt_messages(self): + if self.is_listener(): + await self.listener_handshake() + else: + await self.handshake() while True: rn_l, rk_l = self.rn() rn_m, rk_m = self.rn() while True: - if len(buffer) >= 18: - lc = bytes(buffer[:18]) + if len(self._data) >= 18: + lc = bytes(self._data[:18]) l = aead_decrypt(rk_l, rn_l, b'', lc) length = int.from_bytes(l, 'big') offset = 18 + length + 16 - if len(buffer) >= offset: - c = bytes(buffer[18:offset]) - del buffer[:offset] # much faster than: buffer=buffer[offset:] + if len(self._data) >= offset: + c = bytes(self._data[18:offset]) + del self._data[:offset] # much faster than: buffer=buffer[offset:] msg = aead_decrypt(rk_m, rn_m, b'', c) - yield msg + self._framer.received_message(msg) break - try: - s = await self.reader.read(2**10) - except Exception: - s = None - if not s: - raise LightningPeerConnectionClosed() - buffer += s + await self._data_received.wait() + self._data_received.clear() + + async def read_messages(self): + while True: + msg = await self.receive_message() + yield msg def rn(self): o = self._rn, self.rk @@ -156,37 +205,16 @@ def sn(self): return o def init_counters(self, ck): - # init counters self._sn = 0 self._rn = 0 self.r_ck = ck self.s_ck = ck - def close(self): - self.writer.close() - - def remote_pubkey(self) -> Optional[bytes]: - raise NotImplementedError() - - -class LNResponderTransport(LNTransportBase): - """Transport initiated by remote party.""" - - def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter): - LNTransportBase.__init__(self) - self.reader = reader - self.writer = writer - self.privkey = privkey - self._pubkey = None # remote pubkey - - def name(self) -> str: - return f"{super().name()}(in)" - - async def handshake(self, **kwargs): - hs = HandshakeState(privkey_to_pubkey(self.privkey)) + async def listener_handshake(self, **kwargs): + hs = HandshakeState(privkey_to_pubkey(self._privkey)) act1 = b'' while len(act1) < 50: - buf = await self.reader.read(50 - len(act1)) + buf = await self.read_data(50 - len(act1)) if not buf: raise HandshakeFailed('responder disconnected') act1 += buf @@ -197,11 +225,10 @@ async def handshake(self, **kwargs): c = act1[-16:] re = act1[1:34] h = hs.update(re) - ss = get_ecdh(self.privkey, re) + ss = get_ecdh(self._privkey, re) ck, temp_k1 = get_bolt8_hkdf(sha256(HandshakeState.protocol_name), ss) _p = aead_decrypt(temp_k1, 0, h, c) hs.update(c) - # act 2 if 'epriv' not in kwargs: epriv, epub = create_ephemeral_key() @@ -210,14 +237,12 @@ async def handshake(self, **kwargs): epub = ecc.ECPrivkey(epriv).get_public_key_bytes() hs.ck = ck hs.responder_pub = re - msg, temp_k2 = act1_initiator_message(hs, epriv, epub) - self.writer.write(msg) - + self._asyncio_transport.write(msg) # act 3 act3 = b'' while len(act3) < 66: - buf = await self.reader.read(66 - len(act3)) + buf = await self.read_data(66 - len(act3)) if not buf: raise HandshakeFailed('responder disconnected') act3 += buf @@ -233,40 +258,36 @@ async def handshake(self, **kwargs): _p = aead_decrypt(temp_k3, 0, hs.update(c), t) self.rk, self.sk = get_bolt8_hkdf(ck, b'') self.init_counters(ck) - self._pubkey = rs + self._remote_pubkey = rs + self.handshake_done.set() return rs - def remote_pubkey(self) -> Optional[bytes]: - return self._pubkey + def connection_made(self, transport): + RSTransport.connection_made(self, transport) + self._decrypt_messages_task = self.loop.create_task(self.decrypt_messages()) + def connection_lost(self, exc): + RSTransport.connection_lost(self, exc) + self._process_messages_task.cancel() # fixme: this should be done in parent class + self._decrypt_messages_task.cancel() -class LNTransport(LNTransportBase): - """Transport initiated by local party.""" - - def __init__(self, privkey: bytes, peer_addr: LNPeerAddr, *, - proxy: Optional[dict]): - LNTransportBase.__init__(self) - assert type(privkey) is bytes and len(privkey) == 32 - self.privkey = privkey - self.peer_addr = peer_addr - self.proxy = MySocksProxy.from_proxy_dict(proxy) + def data_received(self, chunk): + self._data += chunk + self._data_received.set() + self.session.data_received(chunk) async def handshake(self): - if not self.proxy: - self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port) - else: - self.reader, self.writer = await self.proxy.open_connection(self.peer_addr.host, self.peer_addr.port) - hs = HandshakeState(self.peer_addr.pubkey) + assert self._remote_pubkey is not None + hs = HandshakeState(self._remote_pubkey) # Get a new ephemeral key epriv, epub = create_ephemeral_key() - msg, _temp_k1 = act1_initiator_message(hs, epriv, epub) # act 1 - self.writer.write(msg) - rspns = await self.reader.read(2**10) + self._asyncio_transport.write(msg) + rspns = await self.read_data(2**10) if len(rspns) != 50: raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, " - f"are you sure this is the right pubkey? {self.peer_addr}") + f"are you sure this is the right pubkey? {self._remote_pubkey.hex()}") hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:] if bytes([hver]) != hs.handshake_version: raise HandshakeFailed("unexpected handshake version: {}".format(hver)) @@ -278,17 +299,66 @@ async def handshake(self): p = aead_decrypt(temp_k2, 0, hs.h, tag) hs.update(tag) # act 3 - my_pubkey = privkey_to_pubkey(self.privkey) + my_pubkey = privkey_to_pubkey(self._privkey) c = aead_encrypt(temp_k2, 1, hs.h, my_pubkey) hs.update(c) - ss = get_ecdh(self.privkey[:32], alice_epub) + ss = get_ecdh(self._privkey[:32], alice_epub) ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss) hs.ck = ck t = aead_encrypt(temp_k3, 0, hs.h, b'') msg = hs.handshake_version + c + t - self.writer.write(msg) + self._asyncio_transport.write(msg) self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'') self.init_counters(ck) + self.handshake_done.set() + + @cached_property + def _id_hash(self) -> str: + id_int = id(self) + id_bytes = id_int.to_bytes((id_int.bit_length() + 7) // 8, byteorder='big') + return sha256(id_bytes).hex() + + def name(self) -> str: + pubkey = self.remote_pubkey() + if pubkey: + pubkey_hex = pubkey.hex() if pubkey else pubkey + return f"{pubkey_hex[:10]}-{self._id_hash[:8]}" + else: + return '' def remote_pubkey(self) -> Optional[bytes]: - return self.peer_addr.pubkey + return self._remote_pubkey + + + +class LNClient: + + def __init__(self, privkey, session_factory, peer_addr, proxy=None, loop=None): + assert type(privkey) is bytes and len(privkey) == 32 + self.privkey = privkey + self.peer_addr = peer_addr + self.proxy = MySocksProxy.from_proxy_dict(proxy) if proxy else None + self.loop = loop or asyncio.get_running_loop() + self.session_factory = session_factory + self.protocol_factory = partial(LNTransport, self.session_factory, self.privkey, peer_addr=self.peer_addr) + + @log_exceptions + async def create_connection(self): + connector = self.proxy or self.loop + return await connector.create_connection(self.protocol_factory, self.peer_addr.host, self.peer_addr.port) + + async def __aenter__(self): + _transport, protocol = await self.create_connection() + self.session = protocol.session + assert isinstance(self.session, SessionBase) + await protocol.handshake_done.wait() + return self.session + + async def __aexit__(self, exc_type, exc_value, traceback): await self.session.close() + + + +async def create_bolt8_server(privkey, session_factory, host=None, port=None, *, loop=None, **kwargs): + loop = loop or asyncio.get_event_loop() + protocol_factory = partial(LNTransport, session_factory, privkey) + return await loop.create_server(protocol_factory, host, port) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 0c95d1045fbc..a7d982753767 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -46,7 +46,7 @@ from .util import timestamp_to_datetime, random_shuffled_copy from .util import MyEncoder, is_private_netaddress, UnrelatedTransactionException from .logging import Logger -from .lntransport import LNTransport, LNResponderTransport, LNTransportBase +from .lntransport import LNClient, LNTransport, LNSession, create_bolt8_server from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT from .lnaddr import lnencode, LnAddr, lndecode from .ecc import der_sig_from_sig_string @@ -259,17 +259,14 @@ async def maybe_listen(self): except Exception as e: self.logger.error(f"failed to parse config key 'lightning_listen'. got: {e!r}") return - addr = str(netaddr.host) - async def cb(reader, writer): - transport = LNResponderTransport(self.node_keypair.privkey, reader, writer) - try: - node_id = await transport.handshake() - except Exception as e: - self.logger.info(f'handshake failure from incoming connection: {e!r}') - return - await self._add_peer_from_transport(node_id=node_id, transport=transport) + def session_factory(transport): + async def coro(): + await transport.handshake_done.wait() + await self._add_peer_from_transport(node_id=transport._remote_pubkey, transport=transport) + asyncio.run_coroutine_threadsafe(coro(), self.network.asyncio_loop) + return LNSession(transport) try: - self.listen_server = await asyncio.start_server(cb, addr, netaddr.port) + self.listen_server = await create_bolt8_server(self.node_keypair.privkey, session_factory, str(netaddr.host), netaddr.port) except OSError as e: self.logger.error(f"cannot listen for lightning p2p. error: {e!r}") @@ -309,12 +306,14 @@ async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer: self.logger.info(f"adding peer {peer_addr}") if node_id == self.node_keypair.pubkey: raise ErrorAddingPeer("cannot connect to self") - transport = LNTransport(self.node_keypair.privkey, peer_addr, - proxy=self.network.proxy) - peer = await self._add_peer_from_transport(node_id=node_id, transport=transport) + connector = self.network.proxy or self.network.asyncio_loop + protocol_factory = partial(LNTransport, LNSession, self.node_keypair.privkey, peer_addr=peer_addr) + asyncio_transport, protocol = await connector.create_connection(protocol_factory, peer_addr.host, peer_addr.port) + await protocol.handshake_done.wait() + peer = await self._add_peer_from_transport(node_id=node_id, transport=protocol) return peer - async def _add_peer_from_transport(self, *, node_id: bytes, transport: LNTransportBase) -> Peer: + async def _add_peer_from_transport(self, *, node_id: bytes, transport: LNTransport) -> Peer: peer = Peer(self, node_id, transport) with self.lock: existing_peer = self._peers.get(node_id) @@ -370,7 +369,7 @@ def is_good_peer(self, peer: LNPeerAddr) -> bool: return True def on_peer_successfully_established(self, peer: Peer) -> None: - if isinstance(peer.transport, LNTransport): + if not peer.transport.is_listener(): peer_addr = peer.transport.peer_addr # reset connection attempt count self._on_connection_successfully_established(peer_addr) @@ -2515,8 +2514,11 @@ async def _request_force_close_from_backup(self, channel_id: bytes): async def _request_fclose(addresses): for host, port, timestamp in addresses: peer_addr = LNPeerAddr(host, port, node_id) - transport = LNTransport(privkey, peer_addr, proxy=self.network.proxy) - peer = Peer(self, node_id, transport, is_channel_backup=True) + connector = self.proxy or self.loop + protocol_factory = partial(LNTransport, LNSession, privkey, peer_addr=peer_addr) + asyncio_transport, protocol = await connector.create_connection(protocol_factory, peer_addr.host, peer_addr.port) + await protocol.handshake_done.wait() + peer = Peer(self, node_id, transport=client._protocol, is_channel_backup=True) try: async with OldTaskGroup(wait=any) as group: await group.spawn(peer._message_loop()) From f1a9d9074a5ef8df0eef8916622e1fd62b7b7220 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Tue, 11 Oct 2022 11:48:46 +0200 Subject: [PATCH 2/4] lntransport: parameterize message size length --- electrum/interface.py | 1 + electrum/lntransport.py | 38 +++++++++++++++++++++-------------- electrum/lnworker.py | 4 ++-- electrum/tests/test_lnpeer.py | 2 +- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/electrum/interface.py b/electrum/interface.py index c6dbad5b406c..2e35f21c97bd 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -672,6 +672,7 @@ def create_client(): peer_addr = LNPeerAddr(self.host, self.port, bytes.fromhex(self.server.pubkey)) bolt8_privkey = os.urandom(32) return LNClient( + prologue=b'electrum', privkey=bolt8_privkey, session_factory=session_factory, peer_addr=peer_addr, diff --git a/electrum/lntransport.py b/electrum/lntransport.py index 2eb658b1fef3..983581c013e2 100644 --- a/electrum/lntransport.py +++ b/electrum/lntransport.py @@ -47,11 +47,11 @@ class LNSession(SessionBase): class HandshakeState(object): - prologue = b"lightning" protocol_name = b"Noise_XK_secp256k1_ChaChaPoly_SHA256" handshake_version = b"\x00" - def __init__(self, responder_pub): + def __init__(self, prologue, responder_pub): + self.prologue = prologue self.responder_pub = responder_pub self.h = sha256(self.protocol_name) self.ck = self.h @@ -117,16 +117,23 @@ def create_ephemeral_key() -> (bytes, bytes): return privkey.get_secret_bytes(), privkey.get_public_key_bytes() +MSG_SIZE_LEN = { + b'lightning': 2, + b'electrum': 4, +} + class LNTransport(RSTransport, Logger): _privkey: bytes _remote_pubkey: bytes - def __init__(self, session_factory, privkey, peer_addr=None): + def __init__(self, prologue, session_factory, privkey, peer_addr=None): framer = QueueFramer() kind = SessionKind.SERVER if peer_addr is None else SessionKind.CLIENT self.peer_addr = peer_addr # todo: remove this, pass only pubkey self._remote_pubkey = peer_addr.pubkey if peer_addr else None + self.prologue = prologue + self.msg_size_len = MSG_SIZE_LEN[prologue] Logger.__init__(self) RSTransport.__init__(self, session_factory, framer, kind) @@ -152,10 +159,10 @@ async def write(self, message) -> None: self.send_bytes(message) def send_bytes(self, msg: bytes) -> None: - l = len(msg).to_bytes(2, 'big') + l = len(msg).to_bytes(self.msg_size_len, 'big') lc = aead_encrypt(self.sk, self.sn(), b'', l) c = aead_encrypt(self.sk, self.sn(), b'', msg) - assert len(lc) == 18 + assert len(lc) == 16 + self.msg_size_len assert len(c) == len(msg) + 16 self._asyncio_transport.write(lc+c) @@ -165,17 +172,18 @@ async def decrypt_messages(self): await self.listener_handshake() else: await self.handshake() + header_length = 16 + self.msg_size_len while True: rn_l, rk_l = self.rn() rn_m, rk_m = self.rn() while True: - if len(self._data) >= 18: - lc = bytes(self._data[:18]) + if len(self._data) >= header_length: + lc = bytes(self._data[:header_length]) l = aead_decrypt(rk_l, rn_l, b'', lc) length = int.from_bytes(l, 'big') - offset = 18 + length + 16 + offset = header_length + length + 16 if len(self._data) >= offset: - c = bytes(self._data[18:offset]) + c = bytes(self._data[header_length:offset]) del self._data[:offset] # much faster than: buffer=buffer[offset:] msg = aead_decrypt(rk_m, rn_m, b'', c) self._framer.received_message(msg) @@ -211,7 +219,7 @@ def init_counters(self, ck): self.s_ck = ck async def listener_handshake(self, **kwargs): - hs = HandshakeState(privkey_to_pubkey(self._privkey)) + hs = HandshakeState(self.prologue, privkey_to_pubkey(self._privkey)) act1 = b'' while len(act1) < 50: buf = await self.read_data(50 - len(act1)) @@ -278,7 +286,7 @@ def data_received(self, chunk): async def handshake(self): assert self._remote_pubkey is not None - hs = HandshakeState(self._remote_pubkey) + hs = HandshakeState(self.prologue, self._remote_pubkey) # Get a new ephemeral key epriv, epub = create_ephemeral_key() msg, _temp_k1 = act1_initiator_message(hs, epriv, epub) @@ -333,14 +341,14 @@ def remote_pubkey(self) -> Optional[bytes]: class LNClient: - def __init__(self, privkey, session_factory, peer_addr, proxy=None, loop=None): + def __init__(self, prologue, privkey, session_factory, peer_addr, proxy=None, loop=None): assert type(privkey) is bytes and len(privkey) == 32 self.privkey = privkey self.peer_addr = peer_addr self.proxy = MySocksProxy.from_proxy_dict(proxy) if proxy else None self.loop = loop or asyncio.get_running_loop() self.session_factory = session_factory - self.protocol_factory = partial(LNTransport, self.session_factory, self.privkey, peer_addr=self.peer_addr) + self.protocol_factory = partial(LNTransport, prologue, self.session_factory, self.privkey, peer_addr=self.peer_addr) @log_exceptions async def create_connection(self): @@ -358,7 +366,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): -async def create_bolt8_server(privkey, session_factory, host=None, port=None, *, loop=None, **kwargs): +async def create_bolt8_server(prologue, privkey, session_factory, host=None, port=None, *, loop=None, **kwargs): loop = loop or asyncio.get_event_loop() - protocol_factory = partial(LNTransport, session_factory, privkey) + protocol_factory = partial(LNTransport, prologue, session_factory, privkey) return await loop.create_server(protocol_factory, host, port) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index a7d982753767..e99daf408d7c 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -266,7 +266,7 @@ async def coro(): asyncio.run_coroutine_threadsafe(coro(), self.network.asyncio_loop) return LNSession(transport) try: - self.listen_server = await create_bolt8_server(self.node_keypair.privkey, session_factory, str(netaddr.host), netaddr.port) + self.listen_server = await create_bolt8_server(b'lightning', self.node_keypair.privkey, session_factory, str(netaddr.host), netaddr.port) except OSError as e: self.logger.error(f"cannot listen for lightning p2p. error: {e!r}") @@ -307,7 +307,7 @@ async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer: if node_id == self.node_keypair.pubkey: raise ErrorAddingPeer("cannot connect to self") connector = self.network.proxy or self.network.asyncio_loop - protocol_factory = partial(LNTransport, LNSession, self.node_keypair.privkey, peer_addr=peer_addr) + protocol_factory = partial(LNTransport, b'lightning', LNSession, self.node_keypair.privkey, peer_addr=peer_addr) asyncio_transport, protocol = await connector.create_connection(protocol_factory, peer_addr.host, peer_addr.port) await protocol.handshake_done.wait() peer = await self._add_peer_from_transport(node_id=node_id, transport=protocol) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 47df5149c1e0..4e978e5a650c 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -300,7 +300,7 @@ class PutIntoOthersQueueTransport(MockTransport): def __init__(self, keypair, name): super().__init__(name) self.other_mock_transport = None - self.privkey = keypair.privkey + self._privkey = keypair.privkey def send_bytes(self, data): self.other_mock_transport.queue.put_nowait(data) From 5832c80576bd2f2cf380cbbecc7abb161ce5a0a2 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Tue, 11 Oct 2022 13:00:30 +0200 Subject: [PATCH 3/4] LNTransport: add authentication during handshake --- electrum/commands.py | 8 ++++++++ electrum/interface.py | 5 +++-- electrum/lntransport.py | 10 +++++++--- electrum/lnworker.py | 2 +- electrum/simple_config.py | 8 ++++++++ 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 62f1cab057b9..0a36ee961d57 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -327,6 +327,14 @@ async def setconfig(self, key, value): self.config.set_key(key, value) return True + @command('') + async def create_key_for_server(self, server_pubkey:str) -> str: + " returns a pubkey to add to electrumx, using the 'add_user' RPC" + privkey = os.urandom(32) + pubkey = ecc.ECPrivkey(privkey).get_public_key_bytes(compressed=True) + self.config.set_bolt8_privkey_for_server(server_pubkey, privkey.hex()) + return pubkey.hex() + @command('') async def get_ssl_domain(self): """Check and return the SSL domain set in ssl_keyfile and ssl_certfile diff --git a/electrum/interface.py b/electrum/interface.py index 2e35f21c97bd..c657f5a7419e 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -670,10 +670,11 @@ async def open_session(self, sslc, exit_early=False): def create_client(): if self.protocol == 'b': peer_addr = LNPeerAddr(self.host, self.port, bytes.fromhex(self.server.pubkey)) - bolt8_privkey = os.urandom(32) + privkey = self.network.config.get_bolt8_privkey_for_server(self.server.pubkey) + privkey = bytes.fromhex(privkey) if privkey else os.urandom(32) return LNClient( prologue=b'electrum', - privkey=bolt8_privkey, + privkey=privkey, session_factory=session_factory, peer_addr=peer_addr, proxy=self.proxy) diff --git a/electrum/lntransport.py b/electrum/lntransport.py index 983581c013e2..42b127d060a6 100644 --- a/electrum/lntransport.py +++ b/electrum/lntransport.py @@ -127,13 +127,14 @@ class LNTransport(RSTransport, Logger): _privkey: bytes _remote_pubkey: bytes - def __init__(self, prologue, session_factory, privkey, peer_addr=None): + def __init__(self, prologue, session_factory, privkey, peer_addr=None, whitelist=None): framer = QueueFramer() kind = SessionKind.SERVER if peer_addr is None else SessionKind.CLIENT self.peer_addr = peer_addr # todo: remove this, pass only pubkey self._remote_pubkey = peer_addr.pubkey if peer_addr else None self.prologue = prologue self.msg_size_len = MSG_SIZE_LEN[prologue] + self.whitelist = whitelist Logger.__init__(self) RSTransport.__init__(self, session_factory, framer, kind) @@ -267,9 +268,12 @@ async def listener_handshake(self, **kwargs): self.rk, self.sk = get_bolt8_hkdf(ck, b'') self.init_counters(ck) self._remote_pubkey = rs + if self.whitelist is not None and rs not in self.whitelist: + raise HandshakeFailed(f'Not authorised {rs.hex()}') self.handshake_done.set() return rs + def connection_made(self, transport): RSTransport.connection_made(self, transport) self._decrypt_messages_task = self.loop.create_task(self.decrypt_messages()) @@ -366,7 +370,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): -async def create_bolt8_server(prologue, privkey, session_factory, host=None, port=None, *, loop=None, **kwargs): +async def create_bolt8_server(prologue, privkey, whitelist, session_factory, host=None, port=None, *, loop=None, **kwargs): loop = loop or asyncio.get_event_loop() - protocol_factory = partial(LNTransport, prologue, session_factory, privkey) + protocol_factory = partial(LNTransport, prologue, session_factory, privkey, whitelist=whitelist) return await loop.create_server(protocol_factory, host, port) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index e99daf408d7c..81a30b6c442f 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -266,7 +266,7 @@ async def coro(): asyncio.run_coroutine_threadsafe(coro(), self.network.asyncio_loop) return LNSession(transport) try: - self.listen_server = await create_bolt8_server(b'lightning', self.node_keypair.privkey, session_factory, str(netaddr.host), netaddr.port) + self.listen_server = await create_bolt8_server(b'lightning', self.node_keypair.privkey, None, session_factory, str(netaddr.host), netaddr.port) except OSError as e: self.logger.error(f"cannot listen for lightning p2p. error: {e!r}") diff --git a/electrum/simple_config.py b/electrum/simple_config.py index 0ecdfa6b77b8..02e669bbe13f 100644 --- a/electrum/simple_config.py +++ b/electrum/simple_config.py @@ -705,6 +705,14 @@ def set_base_unit(self, unit): def get_decimal_point(self): return self.decimal_point + def get_bolt8_privkey_for_server(self, server_pubkey): + return self.get('bolt8_privkeys', {}).get(server_pubkey) + + def set_bolt8_privkey_for_server(self, server_pubkey:str, privkey:str): + d = self.get('bolt8_privkeys', {}) + d[server_pubkey] = privkey + self.set_key('bolt8_privkeys', d) + def read_user_config(path): """Parse and store the user config settings in electrum.conf into user_config[].""" From 7f7c453f186f31843ae2800965d1e5c1d15a1e5d Mon Sep 17 00:00:00 2001 From: ThomasV Date: Fri, 30 Sep 2022 14:16:37 +0200 Subject: [PATCH 4/4] watchtower: delegate watchtower to electrum server --- electrum/commands.py | 4 ++-- electrum/daemon.py | 34 ------------------------------- electrum/gui/qt/__init__.py | 7 ------- electrum/gui/qt/main_window.py | 2 -- electrum/lnwatcher.py | 2 +- electrum/lnworker.py | 34 +++++++------------------------ electrum/network.py | 15 ++++++-------- electrum/tests/regtest.py | 4 ---- electrum/tests/regtest/regtest.sh | 7 +------ 9 files changed, 17 insertions(+), 92 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 0a36ee961d57..c9d15741b121 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1249,8 +1249,8 @@ async def get_channel_ctx(self, channel_point, iknowwhatimdoing=False, wallet: A @command('wnl') async def get_watchtower_ctn(self, channel_point, wallet: Abstract_Wallet = None): - """ return the local watchtower's ctn of channel. used in regtests """ - return await self.network.local_watchtower.sweepstore.get_ctn(channel_point, None) + """ return the remote watchtower's ctn of channel. used in regtests """ + return await self.network.watchtower_get_ctn(channel_point, None) @command('wnl') async def rebalance_channels(self, from_scid, dest_scid, amount, wallet: Abstract_Wallet = None): diff --git a/electrum/daemon.py b/electrum/daemon.py index c5aa7e6d51f4..5459d29da672 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -340,34 +340,6 @@ async def run_cmdline(self, config_options): return result -class WatchTowerServer(AuthenticatedServer): - - def __init__(self, network, netaddress): - self.addr = netaddress - self.config = network.config - self.network = network - watchtower_user = self.config.get('watchtower_user', '') - watchtower_password = self.config.get('watchtower_password', '') - AuthenticatedServer.__init__(self, watchtower_user, watchtower_password) - self.lnwatcher = network.local_watchtower - self.app = web.Application() - self.app.router.add_post("/", self.handle) - self.register_method(self.get_ctn) - self.register_method(self.add_sweep_tx) - - async def run(self): - self.runner = web.AppRunner(self.app) - await self.runner.setup() - site = web.TCPSite(self.runner, host=str(self.addr.host), port=self.addr.port, ssl_context=self.config.get_ssl_context()) - await site.start() - self.logger.info(f"now running and listening. addr={self.addr}") - - async def get_ctn(self, *args): - return await self.lnwatcher.get_ctn(*args) - - async def add_sweep_tx(self, *args): - return await self.lnwatcher.sweepstore.add_sweep_tx(*args) - @@ -403,12 +375,6 @@ def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True): if listen_jsonrpc: self.commands_server = CommandsServer(self, fd) daemon_jobs.append(self.commands_server.run()) - # server-side watchtower - self.watchtower = None - watchtower_address = self.config.get_netaddress('watchtower_address') - if not config.get('offline') and watchtower_address: - self.watchtower = WatchTowerServer(self.network, watchtower_address) - daemon_jobs.append(self.watchtower.run) if self.network: self.network.start(jobs=[self.fx.run]) # prepare lightning functionality, also load channel db early diff --git a/electrum/gui/qt/__init__.py b/electrum/gui/qt/__init__.py index 2ff8ce6692fb..313ee4f65dfa 100644 --- a/electrum/gui/qt/__init__.py +++ b/electrum/gui/qt/__init__.py @@ -199,8 +199,6 @@ def build_tray_menu(self): m.addAction(_("Network"), self.show_network_dialog) if network and network.lngossip: m.addAction(_("Lightning Network"), self.show_lightning_dialog) - if network and network.local_watchtower: - m.addAction(_("Local Watchtower"), self.show_watchtower_dialog) for window in self.windows: name = window.wallet.basename() submenu = m.addMenu(name) @@ -285,11 +283,6 @@ def show_lightning_dialog(self): self.lightning_dialog = LightningDialog(self) self.lightning_dialog.bring_to_top() - def show_watchtower_dialog(self): - if not self.watchtower_dialog: - self.watchtower_dialog = WatchtowerDialog(self) - self.watchtower_dialog.bring_to_top() - def show_network_dialog(self): if self.network_dialog: self.network_dialog.on_event_network_updated() diff --git a/electrum/gui/qt/main_window.py b/electrum/gui/qt/main_window.py index 3019968a6715..b12686de49ea 100644 --- a/electrum/gui/qt/main_window.py +++ b/electrum/gui/qt/main_window.py @@ -738,8 +738,6 @@ def add_toggle_action(view_menu, tab): tools_menu.addAction(_("Electrum preferences"), self.settings_dialog) tools_menu.addAction(_("&Network"), self.gui_object.show_network_dialog).setEnabled(bool(self.network)) - if self.network and self.network.local_watchtower: - tools_menu.addAction(_("Local &Watchtower"), self.gui_object.show_watchtower_dialog) tools_menu.addAction(_("&Plugins"), self.plugins_dialog) tools_menu.addSeparator() tools_menu.addAction(_("&Sign/verify message"), self.sign_verify_message) diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index 5e27438438c8..db166384aa38 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -374,7 +374,7 @@ async def broadcast_or_log(self, funding_outpoint: str, tx: Transaction): return txid async def get_ctn(self, outpoint, addr): - if addr not in self.callbacks.keys(): + if addr and addr not in self.callbacks.keys(): self.logger.info(f'watching new channel: {outpoint} {addr}') self.add_channel(outpoint, addr) return await self.sweepstore.get_ctn(outpoint, addr) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 81a30b6c442f..add5109274c4 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -709,48 +709,29 @@ def diagnostic_name(self): @ignore_exceptions @log_exceptions - async def sync_with_local_watchtower(self): - watchtower = self.network.local_watchtower - if watchtower: - while True: - for chan in self.channels.values(): - await self.sync_channel_with_watchtower(chan, watchtower.sweepstore) - await asyncio.sleep(5) - - @ignore_exceptions - @log_exceptions - async def sync_with_remote_watchtower(self): + async def sync_with_watchtower(self): while True: # periodically poll if the user updated 'watchtower_url' await asyncio.sleep(5) watchtower_url = self.config.get('watchtower_url') if not watchtower_url: continue - parsed_url = urllib.parse.urlparse(watchtower_url) - if not (parsed_url.scheme == 'https' or is_private_netaddress(parsed_url.hostname)): - self.logger.warning(f"got watchtower URL for remote tower but we won't use it! " - f"can only use HTTPS (except if private IP): not using {watchtower_url!r}") - continue # try to sync with the remote watchtower try: - async with make_aiohttp_session(proxy=self.network.proxy) as session: - watchtower = JsonRPCClient(session, watchtower_url) - watchtower.add_method('get_ctn') - watchtower.add_method('add_sweep_tx') - for chan in self.channels.values(): - await self.sync_channel_with_watchtower(chan, watchtower) + for chan in self.channels.values(): + await self.sync_channel_with_watchtower(chan) except aiohttp.client_exceptions.ClientConnectorError: self.logger.info(f'could not contact remote watchtower {watchtower_url}') - async def sync_channel_with_watchtower(self, chan: Channel, watchtower): + async def sync_channel_with_watchtower(self, chan: Channel): outpoint = chan.funding_outpoint.to_str() addr = chan.get_funding_address() current_ctn = chan.get_oldest_unrevoked_ctn(REMOTE) - watchtower_ctn = await watchtower.get_ctn(outpoint, addr) + watchtower_ctn = await self.network.watchtower_get_ctn(outpoint, addr) for ctn in range(watchtower_ctn + 1, current_ctn): sweeptxs = chan.create_sweeptxs(ctn) for tx in sweeptxs: - await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize()) + await self.network.watchtower_add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize()) def start_network(self, network: 'Network'): super().start_network(network) @@ -769,8 +750,7 @@ def start_network(self, network: 'Network'): self.maybe_listen(), self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified self.reestablish_peers_and_channels(), - self.sync_with_local_watchtower(), - self.sync_with_remote_watchtower(), + self.sync_with_watchtower(), ]: tg_coro = self.taskgroup.spawn(coro) asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) diff --git a/electrum/network.py b/electrum/network.py index ddbdfdb174de..71b27e699657 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -69,7 +69,6 @@ from .channel_db import ChannelDB from .lnrouter import LNPathFinder from .lnworker import LNGossip - from .lnwatcher import WatchTower from .daemon import Daemon @@ -260,7 +259,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): channel_db: Optional['ChannelDB'] = None lngossip: Optional['LNGossip'] = None - local_watchtower: Optional['WatchTower'] = None path_finder: Optional['LNPathFinder'] = None def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None): @@ -348,13 +346,6 @@ def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None): self._set_status('disconnected') self._has_ever_managed_to_connect_to_server = False - # lightning network - if self.config.get('run_watchtower', False): - from . import lnwatcher - self.local_watchtower = lnwatcher.WatchTower(self) - self.local_watchtower.adb.start_network(self) - asyncio.ensure_future(self.local_watchtower.start_watching()) - def has_internet_connection(self) -> bool: """Our guess whether the device has Internet-connectivity.""" return self._has_ever_managed_to_connect_to_server @@ -482,6 +473,12 @@ async def _request_fee_estimates(self, interface): self.logger.info(f'fee_histogram {histogram}') self.notify('fee_histogram') + async def watchtower_get_ctn(self, *args): + return await self.interface.session.send_request('watchtower.get_ctn', args) + + async def watchtower_add_sweep_tx(self, *args): + return await self.interface.session.send_request('watchtower.add_sweep_tx', args) + def get_status_value(self, key): if key == 'status': value = self.connection_status diff --git a/electrum/tests/regtest.py b/electrum/tests/regtest.py index 53eaa3e11391..307fa1e9f41c 100644 --- a/electrum/tests/regtest.py +++ b/electrum/tests/regtest.py @@ -65,9 +65,5 @@ def test_breach_with_unspent_htlc(self): def test_breach_with_spent_htlc(self): self.run_shell(['breach_with_spent_htlc']) - -class TestLightningABC(TestLightning): - agents = ['alice', 'bob', 'carol'] - def test_watchtower(self): self.run_shell(['watchtower']) diff --git a/electrum/tests/regtest/regtest.sh b/electrum/tests/regtest/regtest.sh index 38df9650949f..80b57bfcdff8 100755 --- a/electrum/tests/regtest/regtest.sh +++ b/electrum/tests/regtest/regtest.sh @@ -330,11 +330,6 @@ fi if [[ $1 == "configure_test_watchtower" ]]; then - # carol is the watchtower of bob - $carol setconfig -o run_watchtower true - $carol setconfig -o watchtower_user wtuser - $carol setconfig -o watchtower_password wtpassword - $carol setconfig -o watchtower_address 127.0.0.1:12345 $bob setconfig -o watchtower_url http://wtuser:wtpassword@127.0.0.1:12345 fi @@ -356,7 +351,7 @@ if [[ $1 == "watchtower" ]]; then alice_ctn=$($alice list_channels | jq '.[0].local_ctn') msg="waiting until watchtower is synchronized" # watchtower needs to be at latest revoked ctn - while watchtower_ctn=$($carol get_watchtower_ctn $channel) && [[ $watchtower_ctn != $((alice_ctn-1)) ]]; do + while watchtower_ctn=$($bob get_watchtower_ctn $channel) && [[ $watchtower_ctn != $((alice_ctn-1)) ]]; do sleep 0.1 printf "$msg $alice_ctn $watchtower_ctn\r" done