Skip to content

Commit

Permalink
asyncio.wait_for() is too buggy. use util.wait_for2() instead
Browse files Browse the repository at this point in the history
wasted some time because asyncio.wait_for() was suppressing cancellations. [0][1][2]
deja vu... [3]

Looks like this is finally getting fixed in cpython 3.12 [4]
So far away...
In attempt to avoid encountering this again, let's try using
asyncio.timeout in 3.11, which is how upstream reimplemented wait_for in 3.12 [4], and
aiorpcx.timeout_after in 3.8-3.10.

[0] python/cpython#86296
[1] https://bugs.python.org/issue42130
[2] https://bugs.python.org/issue45098
[3] kyuupichan/aiorpcX#44
[4] python/cpython#98518
  • Loading branch information
SomberNight committed Aug 4, 2023
1 parent 20f4d44 commit d51f00e
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 31 deletions.
2 changes: 1 addition & 1 deletion electrum/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def send_request(self, *args, timeout=None, **kwargs):
try:
# note: RPCSession.send_request raises TaskTimeout in case of a timeout.
# TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
response = await asyncio.wait_for(
response = await util.wait_for2(
super().send_request(*args, **kwargs),
timeout)
except (TaskTimeout, asyncio.TimeoutError) as e:
Expand Down
11 changes: 6 additions & 5 deletions electrum/lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import aiorpcx
from aiorpcx import ignore_after
from async_timeout import timeout

from .crypto import sha256, sha256d
from . import bitcoin, util
Expand Down Expand Up @@ -331,7 +332,7 @@ def on_pong(self, payload):

async def wait_for_message(self, expected_name: str, channel_id: bytes):
q = self.ordered_message_queues[channel_id]
name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT)
name, payload = await util.wait_for2(q.get(), LN_P2P_NETWORK_TIMEOUT)
# raise exceptions for errors, so that the caller sees them
if (err_bytes := payload.get("error")) is not None:
err_text = error_text_bytes_to_safe_str(err_bytes)
Expand Down Expand Up @@ -460,12 +461,12 @@ async def process_gossip(self):

async def query_gossip(self):
try:
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT)
except Exception as e:
raise GracefulDisconnect(f"Failed to initialize: {e!r}") from e
if self.lnworker == self.lnworker.network.lngossip:
try:
ids, complete = await asyncio.wait_for(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT)
ids, complete = await util.wait_for2(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT)
except asyncio.TimeoutError as e:
raise GracefulDisconnect("query_channel_range timed out") from e
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
Expand Down Expand Up @@ -575,7 +576,7 @@ def query_short_channel_ids(self, ids, compressed=True):

async def _message_loop(self):
try:
await asyncio.wait_for(self.initialize(), LN_P2P_NETWORK_TIMEOUT)
await util.wait_for2(self.initialize(), LN_P2P_NETWORK_TIMEOUT)
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
raise GracefulDisconnect(f'initialize failed: {repr(e)}') from e
async for msg in self.transport.read_messages():
Expand Down Expand Up @@ -699,7 +700,7 @@ async def channel_establishment_flow(
Channel configurations are initialized in this method.
"""
# will raise if init fails
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT)
# trampoline is not yet in features
if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.pubkey):
raise Exception('Not a trampoline node: ' + str(self.their_features))
Expand Down
2 changes: 1 addition & 1 deletion electrum/lnworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,7 @@ async def _open_channel_coroutine(
funding_sat=funding_sat,
push_msat=push_sat * 1000,
temp_channel_id=os.urandom(32))
chan, funding_tx = await asyncio.wait_for(coro, LN_P2P_NETWORK_TIMEOUT)
chan, funding_tx = await util.wait_for2(coro, LN_P2P_NETWORK_TIMEOUT)
util.trigger_callback('channels_updated', self.wallet)
self.wallet.adb.add_transaction(funding_tx) # save tx as local into the wallet
self.wallet.sign_transaction(funding_tx, password)
Expand Down
4 changes: 2 additions & 2 deletions electrum/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ async def _run_new_interface(self, server: ServerAddr):
# note: using longer timeouts here as DNS can sometimes be slow!
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
try:
await asyncio.wait_for(interface.ready, timeout)
await util.wait_for2(interface.ready, timeout)
except BaseException as e:
self.logger.info(f"couldn't launch iface {server} -- {repr(e)}")
await interface.close()
Expand Down Expand Up @@ -1401,7 +1401,7 @@ async def send_multiple_requests(
async def get_response(server: ServerAddr):
interface = Interface(network=self, server=server, proxy=self.proxy)
try:
await asyncio.wait_for(interface.ready, timeout)
await util.wait_for2(interface.ready, timeout)
except BaseException as e:
await interface.close()
return
Expand Down
3 changes: 2 additions & 1 deletion electrum/plugins/payserver/payserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from aiohttp import web
from aiorpcx import NetAddress

from electrum import util
from electrum.util import log_exceptions, ignore_exceptions
from electrum.plugin import BasePlugin, hook
from electrum.logging import Logger
Expand Down Expand Up @@ -173,7 +174,7 @@ async def get_status(self, request):
return ws
while True:
try:
await asyncio.wait_for(self.pending[key].wait(), 1)
await util.wait_for2(self.pending[key].wait(), 1)
break
except asyncio.TimeoutError:
# send data on the websocket, to keep it alive
Expand Down
4 changes: 2 additions & 2 deletions electrum/scripts/ln_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from electrum.logging import get_logger, configure_logging
from electrum.simple_config import SimpleConfig
from electrum import constants
from electrum import constants, util
from electrum.daemon import Daemon
from electrum.wallet import create_new_wallet
from electrum.util import create_and_start_event_loop, log_exceptions, bfh
Expand Down Expand Up @@ -84,7 +84,7 @@ async def worker(work_queue: asyncio.Queue, results_queue: asyncio.Queue, flag):
print(f"worker connecting to {connect_str}")
try:
peer = await wallet.lnworker.add_peer(connect_str)
res = await asyncio.wait_for(peer.initialized, TIMEOUT)
res = await util.wait_for2(peer.initialized, TIMEOUT)
if res:
if peer.features & flag == work['features'] & flag:
await results_queue.put(True)
Expand Down
36 changes: 18 additions & 18 deletions electrum/tests/test_lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,8 +824,8 @@ async def test_payment_race(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def pay():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
# prep
_maybe_send_commitment1 = p1.maybe_send_commitment
_maybe_send_commitment2 = p2.maybe_send_commitment
Expand Down Expand Up @@ -1374,8 +1374,8 @@ async def _test_shutdown(self, alice_fee, bob_fee, alice_fee_range=None, bob_fee
w2.enable_htlc_settle = False
lnaddr, pay_req = self.prepare_invoice(w2)
async def pay():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
# alice sends htlc
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
p1.pay(route=route,
Expand All @@ -1401,8 +1401,8 @@ async def test_warning(self):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

async def action():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(GracefulDisconnect):
Expand All @@ -1414,8 +1414,8 @@ async def test_error(self):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

async def action():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
assert alice_channel.is_closed()
gath.cancel()
Expand Down Expand Up @@ -1447,8 +1447,8 @@ async def test_close_upfront_shutdown_script(self):

async def test():
async def close():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
# bob closes channel with different shutdown script
await p1.close_channel(alice_channel.channel_id)
gath.cancel()
Expand Down Expand Up @@ -1477,8 +1477,8 @@ async def main_loop(peer):

async def test():
async def close():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
await p1.close_channel(alice_channel.channel_id)
gath.cancel()

Expand Down Expand Up @@ -1538,8 +1538,8 @@ async def test_sending_weird_messages_that_should_be_ignored(self):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

async def send_weird_messages():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
# peer1 sends known message with trailing garbage
# BOLT-01 says peer2 should ignore trailing garbage
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4) + bytes(range(55))
Expand Down Expand Up @@ -1570,8 +1570,8 @@ async def test_sending_weird_messages__unknown_even_type(self):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

async def send_weird_messages():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
# peer1 sends unknown 'even-type' message
# BOLT-01 says peer2 should close the connection
raw_msg2 = (43334).to_bytes(length=2, byteorder="big") + bytes(range(55))
Expand Down Expand Up @@ -1600,8 +1600,8 @@ async def test_sending_weird_messages__known_msg_with_insufficient_length(self):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)

async def send_weird_messages():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
# peer1 sends known message with insufficient length for the contents
# BOLT-01 says peer2 should fail the connection
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4)[:-1]
Expand Down
32 changes: 31 additions & 1 deletion electrum/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import os, sys, re, json
from collections import defaultdict, OrderedDict
from typing import (NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any,
Sequence, Dict, Generic, TypeVar, List, Iterable, Set)
Sequence, Dict, Generic, TypeVar, List, Iterable, Set, Awaitable)
from datetime import datetime
import decimal
from decimal import Decimal
Expand Down Expand Up @@ -1371,6 +1371,36 @@ def _aiorpcx_monkeypatched_unset_task_deadline(task):
aiorpcx.curio._unset_task_deadline = _aiorpcx_monkeypatched_unset_task_deadline


async def wait_for2(fut: Awaitable, timeout: Union[int, float, None]):
"""Replacement for asyncio.wait_for,
due to bugs: https://bugs.python.org/issue42130 and https://github.com/python/cpython/issues/86296 ,
which are only fixed in python 3.12+.
"""
if sys.version_info[:3] >= (3, 12):
return await asyncio.wait_for(fut, timeout)
else:
async with async_timeout(timeout):
return await asyncio.ensure_future(fut, loop=get_running_loop())


if hasattr(asyncio, 'timeout'): # python 3.11+
async_timeout = asyncio.timeout
else:
class TimeoutAfterAsynciolike(aiorpcx.curio.TimeoutAfter):
async def __aexit__(self, exc_type, exc_value, traceback):
try:
await super().__aexit__(exc_type, exc_value, traceback)
except (aiorpcx.TaskTimeout, aiorpcx.UncaughtTimeoutError):
raise asyncio.TimeoutError from None
except aiorpcx.TimeoutCancellationError:
raise asyncio.CancelledError from None

def async_timeout(delay: Union[int, float, None]):
if delay is None:
return nullcontext()
return TimeoutAfterAsynciolike(delay)


class NetworkJobOnDefaultServer(Logger, ABC):
"""An abstract base class for a job that runs on the main network
interface. Every time the main interface changes, the job is
Expand Down

0 comments on commit d51f00e

Please sign in to comment.