Skip to content

Commit

Permalink
Add close() and wait_closed()
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Jun 10, 2015
1 parent 0187334 commit b77ca59
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 15 deletions.
54 changes: 39 additions & 15 deletions mixedqueue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import asyncio

from heapq import heappush, heappop
import logging
import threading

from asyncio import QueueEmpty as AsyncQueueEmpty
from asyncio import QueueFull as AsyncQueueFull
from collections import deque
from queue import Empty as SyncQueueEmpty, Full as SyncQueueFull
from asyncio import (QueueEmpty as AsyncQueueEmpty, QueueFull as
AsyncQueueFull)
from heapq import heappop, heappush
from queue import Empty as SyncQueueEmpty
from queue import Full as SyncQueueFull
from time import monotonic

__version__ = '0.0.1'

log = logging.getLogger(__package__)


class Queue:
def __init__(self, maxsize=0, *, loop=None):
Expand Down Expand Up @@ -38,6 +40,15 @@ def __init__(self, maxsize=0, *, loop=None):
self._sync_queue = SyncQueue(self)
self._async_queue = AsyncQueue(self)

self._pending = set()

def close(self):
pass

@asyncio.coroutine
def wait_closed(self):
yield from asyncio.wait(self._pending, loop=self._loop)

@property
def maxsize(self):
return self._maxsize
Expand Down Expand Up @@ -77,37 +88,49 @@ def _notify_sync_not_empty(self):
def f():
with self._sync_not_empty:
self._sync_not_empty.notify()

self._loop.run_in_executor(None, f)

def _notify_sync_not_full(self):
def f():
with self._sync_not_full:
self._sync_not_full.notify()
self._loop.run_in_executor(None, f)

fut = self._loop.run_in_executor(None, f)
fut.add_done_callback(self._pending.discard)
self._pending.add(fut)

def _notify_async_not_empty(self, *, threadsafe):
@asyncio.coroutine
def f():
with (yield from self._async_not_empty):
self._async_not_empty.notify()

task = asyncio.async(f(), loop=self._loop)
def task_maker():
task = asyncio.async(f(), loop=self._loop)
task.add_done_callback(self._pending.discard)
self._pending.add(task)

if threadsafe:
self._loop.call_soon_threadsafe(task)
self._loop.call_soon_threadsafe(task_maker)
else:
self._loop.call_soon(task)
self._loop.call_soon(task_maker)

def _notify_async_not_full(self, *, threadsafe):
@asyncio.coroutine
def f():
with (yield from self._async_not_full):
self._async_not_full.notify()

task = asyncio.async(f(), loop=self._loop)
def task_maker():
task = asyncio.async(f(), loop=self._loop)
task.add_done_callback(self._pending.discard)
self._pending.add(task)

if threadsafe:
self._loop.call_soon_threadsafe(task)
self._loop.call_soon_threadsafe(task_maker)
else:
self._loop.call_soon(task)
self._loop.call_soon(task_maker)


class SyncQueue:
Expand Down Expand Up @@ -316,8 +339,9 @@ def put(self, item):
if self._parent._maxsize > 0:
do_wait = True
while do_wait:
do_wait = (self._parent._qsize() >=
self._parent._maxsize)
do_wait = (
self._parent._qsize() >= self._parent._maxsize
)
if do_wait:
locked = False
self._parent._sync_mutex.release()
Expand Down
78 changes: 78 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def test_empty(self):
self.assertEqual(1, q.get_nowait())
self.assertTrue(q.empty())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_full(self):
_q = mixedqueue.Queue(loop=self.loop)
q = _q.async_queue
Expand All @@ -95,6 +98,9 @@ def test_full(self):
q.put_nowait(1)
self.assertTrue(q.full())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_order(self):
_q = mixedqueue.Queue(loop=self.loop)
q = _q.async_queue
Expand All @@ -104,6 +110,9 @@ def test_order(self):
items = [q.get_nowait() for _ in range(3)]
self.assertEqual([1, 3, 2], items)

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_maxsize(self):
def gen():
when = yield
Expand Down Expand Up @@ -147,6 +156,9 @@ def test():
loop.run_until_complete(test())
self.assertAlmostEqual(0.02, loop.time())

_q.close()
self.loop.run_until_complete(_q.wait_closed())


class QueueGetTests(_QueueTestBase):
def test_blocking_get(self):
Expand Down Expand Up @@ -183,6 +195,9 @@ def put():
self.loop.run_until_complete(t)
self.assertEqual(1, q.qsize())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_blocking_get_wait(self):
def gen():
when = yield
Expand Down Expand Up @@ -218,16 +233,25 @@ def queue_put():
self.assertEqual(1, res)
self.assertAlmostEqual(0.01, loop.time())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_nonblocking_get(self):
_q = mixedqueue.Queue(loop=self.loop)
q = _q.async_queue
q.put_nowait(1)
self.assertEqual(1, q.get_nowait())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_nonblocking_get_exception(self):
_q = mixedqueue.Queue(loop=self.loop)
self.assertRaises(asyncio.QueueEmpty, _q.async_queue.get_nowait)

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_get_cancelled(self):
def gen():
when = yield
Expand Down Expand Up @@ -255,6 +279,9 @@ def test():
self.assertEqual(1, loop.run_until_complete(test()))
self.assertAlmostEqual(0.06, loop.time())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_get_cancelled_race(self):
_q = mixedqueue.Queue(loop=self.loop)
q = _q.async_queue
Expand All @@ -270,6 +297,9 @@ def test_get_cancelled_race(self):
test_utils.run_briefly(self.loop)
self.assertEqual(t2.result(), 'a')

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_get_with_waiting_putters(self):
_q = mixedqueue.Queue(loop=self.loop, maxsize=1)
q = _q.async_queue
Expand All @@ -280,6 +310,9 @@ def test_get_with_waiting_putters(self):
self.assertEqual(self.loop.run_until_complete(q.get()), 'a')
self.assertEqual(self.loop.run_until_complete(q.get()), 'b')

_q.close()
self.loop.run_until_complete(_q.wait_closed())


class QueuePutTests(_QueueTestBase):
def test_blocking_put(self):
Expand All @@ -293,6 +326,9 @@ def queue_put():

self.loop.run_until_complete(queue_put())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_blocking_put_wait(self):
def gen():
when = yield
Expand Down Expand Up @@ -326,18 +362,27 @@ def queue_get():
loop.run_until_complete(queue_get())
self.assertAlmostEqual(0.01, loop.time())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_nonblocking_put(self):
_q = mixedqueue.Queue(loop=self.loop)
q = _q.async_queue
q.put_nowait(1)
self.assertEqual(1, q.get_nowait())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_nonblocking_put_exception(self):
_q = mixedqueue.Queue(maxsize=1, loop=self.loop)
q = _q.async_queue
q.put_nowait(1)
self.assertRaises(asyncio.QueueFull, q.put_nowait, 2)

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_float_maxsize(self):
_q = mixedqueue.Queue(maxsize=1.3, loop=self.loop)
q = _q.async_queue
Expand All @@ -346,6 +391,9 @@ def test_float_maxsize(self):
self.assertTrue(q.full())
self.assertRaises(asyncio.QueueFull, q.put_nowait, 3)

_q.close()
self.loop.run_until_complete(_q.wait_closed())

_q = mixedqueue.Queue(maxsize=1.3, loop=self.loop)
q = _q.async_queue

Expand All @@ -357,6 +405,9 @@ def queue_put():

self.loop.run_until_complete(queue_put())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_put_cancelled(self):
_q = mixedqueue.Queue(loop=self.loop)
q = _q.async_queue
Expand All @@ -375,6 +426,9 @@ def test():
self.assertTrue(t.done())
self.assertTrue(t.result())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_put_cancelled_race(self):
_q = mixedqueue.Queue(loop=self.loop, maxsize=1)
q = _q.async_queue
Expand All @@ -395,6 +449,9 @@ def test_put_cancelled_race(self):

self.loop.run_until_complete(put_b)

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_put_with_waiting_getters(self):
fut = asyncio.Future(loop=self.loop)

Expand All @@ -415,6 +472,9 @@ def put():
self.loop.run_until_complete(put())
self.assertEqual(self.loop.run_until_complete(t), 'a')

_q.close()
self.loop.run_until_complete(_q.wait_closed())


class LifoQueueTests(_QueueTestBase):
def test_order(self):
Expand All @@ -426,6 +486,9 @@ def test_order(self):
items = [q.get_nowait() for _ in range(3)]
self.assertEqual([2, 3, 1], items)

_q.close()
self.loop.run_until_complete(_q.wait_closed())


class PriorityQueueTests(_QueueTestBase):
def test_order(self):
Expand All @@ -437,6 +500,9 @@ def test_order(self):
items = [q.get_nowait() for _ in range(3)]
self.assertEqual([1, 2, 3], items)

_q.close()
self.loop.run_until_complete(_q.wait_closed())


class _QueueJoinTestMixin:

Expand All @@ -447,6 +513,9 @@ def test_task_done_underflow(self):
q = _q.async_queue
self.assertRaises(ValueError, q.task_done)

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_task_done(self):
_q = self.q_class(loop=self.loop)
q = _q.async_queue
Expand Down Expand Up @@ -485,6 +554,9 @@ def test():
q.put_nowait(0)
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))

_q.close()
self.loop.run_until_complete(_q.wait_closed())

def test_join_empty_queue(self):
_q = self.q_class(loop=self.loop)
q = _q.async_queue
Expand All @@ -499,6 +571,9 @@ def join():

self.loop.run_until_complete(join())

_q.close()
self.loop.run_until_complete(_q.wait_closed())

@unittest.expectedFailure
def test_format(self):
_q = self.q_class(loop=self.loop)
Expand All @@ -508,6 +583,9 @@ def test_format(self):
q._unfinished_tasks = 2
self.assertEqual(q._format(), 'maxsize=0 tasks=2')

_q.close()
self.loop.run_until_complete(_q.wait_closed())


class QueueJoinTests(_QueueJoinTestMixin, _QueueTestBase):
q_class = mixedqueue.Queue
Expand Down

0 comments on commit b77ca59

Please sign in to comment.