Skip to content

Commit

Permalink
Make global_futures and futures public
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Mar 3, 2023
1 parent f0ca3cc commit 188dbf4
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def test_start(transfer: Transfer):

assert transfer.started
assert transfer.peer in transfer.container
assert len(transfer.task_group._futures) == 2
assert len(transfer.task_group.futures) == 2


async def test_double_start(transfer: Transfer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def coro():
def total_coro_count():
count = 0
for endpoint in child_endpoints + [root_endpoint]:
count += len(endpoint.async_group._futures)
count += len(endpoint.async_group.futures)
return count

assert total_coro_count() == 3
Expand Down
21 changes: 11 additions & 10 deletions src/tribler/core/utilities/async_group/async_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

def done_callback(group_ref):
def actual_callback(future):
AsyncGroup._global_futures.discard(future)
AsyncGroup.global_futures.discard(future)
group: Optional[AsyncGroup] = group_ref()
if group is not None:
group._futures.discard(future)
group.futures.discard(future)

return actual_callback


Expand Down Expand Up @@ -46,12 +47,12 @@ class AsyncGroup:
# But theoretically, some async groups can be garbage collected too early.
#
# To prevent this problem all futures stores in the global set.
_global_futures: Set[Future] = set()
global_futures: Set[Future] = set()

def __init__(self):
self._logger = logging.getLogger(self.__class__.__name__)
self.ref = ref(self)
self._futures: Set[Future] = set()
self.futures: Set[Future] = set()
self._canceled = False

def add_task(self, coroutine: Coroutine) -> Task:
Expand All @@ -63,17 +64,17 @@ def add_task(self, coroutine: Coroutine) -> Task:
task.cancel()
raise CanceledException()

self._futures.add(task)
self._global_futures.add(task)
self.futures.add(task)
self.global_futures.add(task)

task.add_done_callback(done_callback(self.ref))
return task

async def wait(self):
""" Wait for completion of all futures
"""
if self._futures:
await asyncio.wait(self._futures)
if self.futures:
await asyncio.wait(self.futures)

async def cancel(self) -> List[Future]:
"""Cancel the group.
Expand All @@ -85,7 +86,7 @@ async def cancel(self) -> List[Future]:

self._canceled = True

active = list(self._active(self._futures))
active = list(self._active(self.futures))
for future in active:
future.cancel()

Expand All @@ -103,7 +104,7 @@ def _active(futures: Iterable[Future]) -> Iterable[Future]:
return (future for future in futures if not future.done())

def __del__(self):
if active := list(self._active(self._futures)):
if active := list(self._active(self.futures)):
self._logger.error(f'AsyncGroup is destroying but {len(active)} futures are active')
for future in active:
future.cancel()
35 changes: 17 additions & 18 deletions src/tribler/core/utilities/async_group/tests/test_async_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import gc
from asyncio import gather
from contextlib import suppress
from weakref import ref

Expand All @@ -17,7 +16,7 @@
async def group():
# When test is just started, the global set of futures should be empty.
# If not, they are the futures leaked from the previous test
assert not AsyncGroup._global_futures
assert not AsyncGroup.global_futures

g = AsyncGroup()

Expand All @@ -26,14 +25,14 @@ async def group():
if not g._canceled:
await g.cancel()

if AsyncGroup._global_futures:
if AsyncGroup.global_futures:
# It is possible that after the group was canceled, some of its futures were canceled as well,
# but their done_callbacks were not executed yet. Here we give these futures a chance to execute
# their done_callbacks and remove themselves from the global set of futures
await asyncio.sleep(0.01)

# The test should not leave unfinished futures at the end
assert not AsyncGroup._global_futures
assert not AsyncGroup.global_futures


async def void():
Expand All @@ -51,7 +50,7 @@ async def raise_exception():
async def test_add_task(group: AsyncGroup):
task = group.add_task(void())

assert len(group._futures) == 1
assert len(group.futures) == 1
assert task


Expand Down Expand Up @@ -79,13 +78,13 @@ async def test_wait(group: AsyncGroup):
group.add_task(sleep_1s())

await group.wait()
assert not group._futures
assert not group.futures


async def test_wait_no_futures(group: AsyncGroup):
"""Ensure that awe can wait for the futures even there are no futures"""
await group.wait()
assert not group._futures
assert not group.futures


async def test_double_cancel(group: AsyncGroup):
Expand All @@ -108,7 +107,7 @@ async def test_cancel_completed_task(group: AsyncGroup):
await asyncio.gather(*completed)

active = asyncio.create_task(void())
group._futures = completed + [active]
group.futures = completed + [active]

cancelled = await group.cancel()

Expand All @@ -126,12 +125,12 @@ async def test_auto_cleanup(group: AsyncGroup):
for f in functions:
for _ in range(100):
group.add_task(f())
assert len(group._futures) == 300
assert len(group.futures) == 300

with suppress(ValueError):
await asyncio.gather(*group._futures, return_exceptions=True)
await asyncio.gather(*group.futures, return_exceptions=True)

assert not group._futures
assert not group.futures


async def test_del_error(group: AsyncGroup, caplog: LogCaptureFixture):
Expand All @@ -141,7 +140,7 @@ async def test_del_error(group: AsyncGroup, caplog: LogCaptureFixture):
"""
group.add_task(void())
group.__del__()
assert f'AsyncGroup is destroying but 1 futures are active' in caplog.text
assert 'AsyncGroup is destroying but 1 futures are active' in caplog.text


async def test_del_no_error(group: AsyncGroup, caplog: LogCaptureFixture):
Expand All @@ -152,11 +151,11 @@ async def test_del_no_error(group: AsyncGroup, caplog: LogCaptureFixture):
group.add_task(void())
await group.wait()
group.__del__()
assert f'AsyncGroup is destroying but 1 futures are active' not in caplog.text
assert 'AsyncGroup is destroying but 1 futures are active' not in caplog.text


async def test_gc_error(caplog: LogCaptureFixture):
assert not AsyncGroup._global_futures
assert not AsyncGroup.global_futures

async def infinite_loop():
while True:
Expand All @@ -168,7 +167,7 @@ async def infinite_loop():
group = AsyncGroup()
group.add_task(task1)
group.add_task(task2)
assert len(AsyncGroup._global_futures) == 2
assert len(AsyncGroup.global_futures) == 2

group_ref = ref(group)
del group
Expand All @@ -180,12 +179,12 @@ async def infinite_loop():
assert 'AsyncGroup is destroying but 2 futures are active' in text

await asyncio.sleep(0.01)
assert not AsyncGroup._global_futures
assert not AsyncGroup.global_futures


async def test_group_fixture():
# There should be no active futures before the test
assert not AsyncGroup._global_futures
assert not AsyncGroup.global_futures

# We want to test the `group` fixture itself. Pytest does not allow to call fixture functions directly,
# so we access fixture implementation using a Pytest internal attribute for that
Expand All @@ -207,4 +206,4 @@ async def test_group_fixture():
await asyncio.sleep(0)

# There should be no active futures after the test
assert not AsyncGroup._global_futures
assert not AsyncGroup.global_futures

0 comments on commit 188dbf4

Please sign in to comment.