diff --git a/pyproject.toml b/pyproject.toml index e0590821..f87234a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "hat-util" -version = "0.6.13" +version = "0.6.14" description = "Hat utility library" readme = "README.rst" requires-python = ">=3.10" @@ -16,10 +16,10 @@ Repository = "https://github.com/hat-open/hat-util.git" Documentation = "http://hat-util.hat-open.com" [project.optional-dependencies] -dev = ["hat-doit ~=0.15.11"] +dev = ["hat-doit ~=0.15.12"] [build-system] -requires = ["hat-doit ~=0.15.11"] +requires = ["hat-doit ~=0.15.12"] build-backend = "hat.doit.pep517" [tool.hat-doit] diff --git a/requirements.pip.txt b/requirements.pip.txt index d014a835..3e8cf719 100644 --- a/requirements.pip.txt +++ b/requirements.pip.txt @@ -1 +1 @@ -hat-doit ~=0.15.11 +hat-doit ~=0.15.12 diff --git a/src_py/hat/util.py b/src_py/hat/util.py deleted file mode 100644 index c50abdaf..00000000 --- a/src_py/hat/util.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Common utility functions""" - -import collections -import contextlib -import datetime -import inspect -import socket -import sqlite3 -import typing -import warnings - - -T = typing.TypeVar('T') - -Bytes: typing.TypeAlias = bytes | bytearray | memoryview - - -def register_type_alias(name: str): - """Register type alias - - This function is temporary hack replacement for typing.TypeAlias. - - It is expected that calling location will have `name` in local namespace - with type value. This function will wrap that type inside `typing.TypeVar` - and update annotations. - - """ - warnings.warn("use typing.TypeAlias", DeprecationWarning) - frame = inspect.stack()[1][0] - f_locals = frame.f_locals - t = f_locals[name] - f_locals[name] = typing.TypeVar(name, t, t) - f_locals.setdefault('__annotations__', {})[name] = typing.Type[t] - - -def first(xs: typing.Iterable[T], - fn: typing.Callable[[T], typing.Any] = lambda _: True, - default: T | None = None - ) -> T | None: - """Return the first element from iterable that satisfies predicate `fn`, - or `default` if no such element exists. - - Result of predicate `fn` can be of any type. Predicate is satisfied if it's - return value is truthy. - - Args: - xs: collection - fn: predicate - default: default value - - Example:: - - assert first(range(3)) == 0 - assert first(range(3), lambda x: x > 1) == 2 - assert first(range(3), lambda x: x > 2) is None - assert first(range(3), lambda x: x > 2, 123) == 123 - assert first({1: 'a', 2: 'b', 3: 'c'}) == 1 - assert first([], default=123) == 123 - - """ - return next((i for i in xs if fn(i)), default) - - -class RegisterCallbackHandle(typing.NamedTuple): - """Handle for canceling callback registration.""" - - cancel: typing.Callable[[], None] - """cancel callback registration""" - - def __enter__(self): - return self - - def __exit__(self, *args): - self.cancel() - - -ExceptionCb: typing.TypeAlias = typing.Callable[[Exception], None] -"""Exception callback""" - - -class CallbackRegistry: - """Registry that enables callback registration and notification. - - Callbacks in the registry are notified sequentially with - `CallbackRegistry.notify`. If a callback raises an exception, the - exception is caught and `exception_cb` handler is called. Notification of - subsequent callbacks is not interrupted. If handler is `None`, the - exception is reraised and no subsequent callback is notified. - - Example:: - - x = [] - y = [] - registry = CallbackRegistry() - - registry.register(x.append) - registry.notify(1) - - with registry.register(y.append): - registry.notify(2) - - registry.notify(3) - - assert x == [1, 2, 3] - assert y == [2] - - """ - - def __init__(self, - exception_cb: ExceptionCb | None = None): - self._exception_cb = exception_cb - self._cbs = [] # type: list[Callable] - - def register(self, - cb: typing.Callable - ) -> RegisterCallbackHandle: - """Register a callback.""" - self._cbs.append(cb) - return RegisterCallbackHandle(lambda: self._cbs.remove(cb)) - - def notify(self, *args, **kwargs): - """Notify all registered callbacks.""" - for cb in self._cbs: - try: - cb(*args, **kwargs) - except Exception as e: - if self._exception_cb: - self._exception_cb(e) - else: - raise - - -def get_unused_tcp_port(host: str = '127.0.0.1') -> int: - """Search for unused TCP port""" - with contextlib.closing(socket.socket()) as sock: - sock.bind((host, 0)) - return sock.getsockname()[1] - - -def get_unused_udp_port(host: str = '127.0.0.1') -> int: - """Search for unused UDP port""" - with contextlib.closing(socket.socket(type=socket.SOCK_DGRAM)) as sock: - sock.bind((host, 0)) - return sock.getsockname()[1] - - -class BytesBuffer: - """Bytes buffer - - All data added to BytesBuffer is considered immutable - it's content - (including size) should not be modified. - - """ - - def __init__(self): - self._data = collections.deque() - self._data_len = 0 - - def __len__(self): - return self._data_len - - def add(self, data: Bytes): - """Add data""" - if not data: - return - - self._data.append(data) - self._data_len += len(data) - - def read(self, n: int = -1) -> Bytes: - """Read up to `n` bytes - - If ``n < 0``, read all data. - - """ - if n == 0: - return b'' - - if n < 0 or n >= self._data_len: - data, self._data = self._data, collections.deque() - data_len, self._data_len = self._data_len, 0 - - else: - data = collections.deque() - data_len = 0 - - while data_len < n: - head = self._data.popleft() - self._data_len -= len(head) - - if data_len + len(head) <= n: - data.append(head) - data_len += len(head) - - else: - head = memoryview(head) - head1, head2 = head[:n-data_len], head[n-data_len:] - - data.append(head1) - data_len += len(head1) - - self._data.appendleft(head2) - self._data_len += len(head2) - - if len(data) < 1: - return b'' - - if len(data) < 2: - return data[0] - - data_bytes = bytearray(data_len) - data_bytes_len = 0 - - while data: - head = data.popleft() - data_bytes[data_bytes_len:data_bytes_len+len(head)] = head - data_bytes_len += len(head) - - return data_bytes - - def clear(self) -> int: - """Clear data and return number of bytes cleared""" - self._data.clear() - data_len, self._data_len = self._data_len, 0 - return data_len - - -def register_sqlite3_timestamp_converter(): - """Register modified timestamp converter - - This converter is modification of standard library convertor taking into - account possible timezone info. - - """ - - def convert_timestamp(val: bytes) -> datetime.datetime: - datepart, timetzpart = val.split(b" ") - if b"+" in timetzpart: - tzsign = 1 - timepart, tzpart = timetzpart.split(b"+") - elif b"-" in timetzpart: - tzsign = -1 - timepart, tzpart = timetzpart.split(b"-") - else: - timepart, tzpart = timetzpart, None - year, month, day = map(int, datepart.split(b"-")) - timepart_full = timepart.split(b".") - hours, minutes, seconds = map(int, timepart_full[0].split(b":")) - if len(timepart_full) == 2: - microseconds = int('{:0<6.6}'.format(timepart_full[1].decode())) - else: - microseconds = 0 - if tzpart: - tzhours, tzminutes = map(int, tzpart.split(b":")) - tz = datetime.timezone( - tzsign * datetime.timedelta(hours=tzhours, minutes=tzminutes)) - else: - tz = None - - val = datetime.datetime(year, month, day, hours, minutes, seconds, - microseconds, tz) - return val - - sqlite3.register_converter("timestamp", convert_timestamp) diff --git a/src_py/hat/util/__init__.py b/src_py/hat/util/__init__.py new file mode 100644 index 00000000..3c94ccc9 --- /dev/null +++ b/src_py/hat/util/__init__.py @@ -0,0 +1,24 @@ +"""Common utility functions""" + +from hat.util import cron +from hat.util.bytes import (Bytes, + BytesBuffer) +from hat.util.callback import (RegisterCallbackHandle, + ExceptionCb, + CallbackRegistry) +from hat.util.first import first +from hat.util.socket import (get_unused_tcp_port, + get_unused_udp_port) +from hat.util.sqlite3 import register_sqlite3_timestamp_converter + + +__all__ = ['cron', + 'Bytes', + 'BytesBuffer', + 'RegisterCallbackHandle', + 'ExceptionCb', + 'CallbackRegistry', + 'first', + 'get_unused_tcp_port', + 'get_unused_udp_port', + 'register_sqlite3_timestamp_converter'] diff --git a/src_py/hat/util/bytes.py b/src_py/hat/util/bytes.py new file mode 100644 index 00000000..caa33771 --- /dev/null +++ b/src_py/hat/util/bytes.py @@ -0,0 +1,86 @@ +import collections +import typing + + +Bytes: typing.TypeAlias = bytes | bytearray | memoryview + + +class BytesBuffer: + """Bytes buffer + + All data added to BytesBuffer is considered immutable - it's content + (including size) should not be modified. + + """ + + def __init__(self): + self._data = collections.deque() + self._data_len = 0 + + def __len__(self): + return self._data_len + + def add(self, data: Bytes): + """Add data""" + if not data: + return + + self._data.append(data) + self._data_len += len(data) + + def read(self, n: int = -1) -> Bytes: + """Read up to `n` bytes + + If ``n < 0``, read all data. + + """ + if n == 0: + return b'' + + if n < 0 or n >= self._data_len: + data, self._data = self._data, collections.deque() + data_len, self._data_len = self._data_len, 0 + + else: + data = collections.deque() + data_len = 0 + + while data_len < n: + head = self._data.popleft() + self._data_len -= len(head) + + if data_len + len(head) <= n: + data.append(head) + data_len += len(head) + + else: + head = memoryview(head) + head1, head2 = head[:n-data_len], head[n-data_len:] + + data.append(head1) + data_len += len(head1) + + self._data.appendleft(head2) + self._data_len += len(head2) + + if len(data) < 1: + return b'' + + if len(data) < 2: + return data[0] + + data_bytes = bytearray(data_len) + data_bytes_len = 0 + + while data: + head = data.popleft() + data_bytes[data_bytes_len:data_bytes_len+len(head)] = head + data_bytes_len += len(head) + + return data_bytes + + def clear(self) -> int: + """Clear data and return number of bytes cleared""" + self._data.clear() + data_len, self._data_len = self._data_len, 0 + return data_len diff --git a/src_py/hat/util/callback.py b/src_py/hat/util/callback.py new file mode 100644 index 00000000..ae6a74f6 --- /dev/null +++ b/src_py/hat/util/callback.py @@ -0,0 +1,71 @@ +from collections.abc import Callable +import typing + + +class RegisterCallbackHandle(typing.NamedTuple): + """Handle for canceling callback registration.""" + + cancel: Callable[[], None] + """cancel callback registration""" + + def __enter__(self): + return self + + def __exit__(self, *args): + self.cancel() + + +ExceptionCb: typing.TypeAlias = Callable[[Exception], None] +"""Exception callback""" + + +class CallbackRegistry: + """Registry that enables callback registration and notification. + + Callbacks in the registry are notified sequentially with + `CallbackRegistry.notify`. If a callback raises an exception, the + exception is caught and `exception_cb` handler is called. Notification of + subsequent callbacks is not interrupted. If handler is `None`, the + exception is reraised and no subsequent callback is notified. + + Example:: + + x = [] + y = [] + registry = CallbackRegistry() + + registry.register(x.append) + registry.notify(1) + + with registry.register(y.append): + registry.notify(2) + + registry.notify(3) + + assert x == [1, 2, 3] + assert y == [2] + + """ + + def __init__(self, + exception_cb: ExceptionCb | None = None): + self._exception_cb = exception_cb + self._cbs = [] # type: list[Callable] + + def register(self, + cb: Callable + ) -> RegisterCallbackHandle: + """Register a callback.""" + self._cbs.append(cb) + return RegisterCallbackHandle(lambda: self._cbs.remove(cb)) + + def notify(self, *args, **kwargs): + """Notify all registered callbacks.""" + for cb in self._cbs: + try: + cb(*args, **kwargs) + except Exception as e: + if self._exception_cb: + self._exception_cb(e) + else: + raise diff --git a/src_py/hat/util/cron.py b/src_py/hat/util/cron.py new file mode 100644 index 00000000..fdd8b46c --- /dev/null +++ b/src_py/hat/util/cron.py @@ -0,0 +1,103 @@ +import datetime +import typing + + +class AllSubExpr(typing.NamedTuple): + pass + + +class ValueSubExpr(typing.NamedTuple): + value: int + + +class RangeSubExpr(typing.NamedTuple): + from_: ValueSubExpr + to: ValueSubExpr + + +class ListSubExpr(typing.NamedTuple): + subexprs: list[ValueSubExpr] + + +SubExpr: typing.TypeAlias = (AllSubExpr | + ValueSubExpr | + RangeSubExpr | + ListSubExpr) + + +class Expr(typing.NamedTuple): + minute: SubExpr + hour: SubExpr + day: SubExpr + month: SubExpr + day_of_week: SubExpr + + +def parse(expr_str: str) -> Expr: + return Expr(*(_parse_subexpr(i) for i in expr_str.split(' '))) + + +def next(expr: Expr, + t: datetime.datetime + ) -> datetime.datetime: + t = t.replace(second=0, microsecond=0) + + while True: + t = t + datetime.timedelta(minutes=1) + + if match(expr, t): + return t + + +def match(expr: Expr, + t: datetime.datetime + ) -> bool: + if t.second or t.microsecond: + return False + + if not _match_subexpr(expr.minute, t.minute): + return False + + if not _match_subexpr(expr.hour, t.hour): + return False + + if not _match_subexpr(expr.day, t.day): + return False + + if not _match_subexpr(expr.month, t.month): + return False + + if not _match_subexpr(expr.day_of_week, t.isoweekday() % 7): + return False + + return True + + +def _parse_subexpr(subexpr_str): + if subexpr_str == '*': + return AllSubExpr() + + if '-' in subexpr_str: + from_str, to_str = subexpr_str.split('-') + return RangeSubExpr(int(from_str), int(to_str)) + + if ',' in subexpr_str: + return ListSubExpr([int(i) for i in subexpr_str.split(',')]) + + return ValueSubExpr(int(subexpr_str)) + + +def _match_subexpr(subexpr, value): + if isinstance(subexpr, AllSubExpr): + return True + + if isinstance(subexpr, ValueSubExpr): + return value == subexpr.value + + if isinstance(subexpr, RangeSubExpr): + return subexpr.from_ <= value <= subexpr.to + + if isinstance(subexpr, ListSubExpr): + return value in subexpr.subexprs + + raise ValueError('unsupported subexpression') diff --git a/src_py/hat/util/first.py b/src_py/hat/util/first.py new file mode 100644 index 00000000..89115c70 --- /dev/null +++ b/src_py/hat/util/first.py @@ -0,0 +1,33 @@ +from collections.abc import Callable, Iterable +import typing + + +T = typing.TypeVar('T') + + +def first(xs: Iterable[T], + fn: Callable[[T], typing.Any] = lambda _: True, + default: T | None = None + ) -> T | None: + """Return the first element from iterable that satisfies predicate `fn`, + or `default` if no such element exists. + + Result of predicate `fn` can be of any type. Predicate is satisfied if it's + return value is truthy. + + Args: + xs: collection + fn: predicate + default: default value + + Example:: + + assert first(range(3)) == 0 + assert first(range(3), lambda x: x > 1) == 2 + assert first(range(3), lambda x: x > 2) is None + assert first(range(3), lambda x: x > 2, 123) == 123 + assert first({1: 'a', 2: 'b', 3: 'c'}) == 1 + assert first([], default=123) == 123 + + """ + return next((i for i in xs if fn(i)), default) diff --git a/src_py/hat/util/socket.py b/src_py/hat/util/socket.py new file mode 100644 index 00000000..6d42fbe9 --- /dev/null +++ b/src_py/hat/util/socket.py @@ -0,0 +1,16 @@ +import contextlib +import socket + + +def get_unused_tcp_port(host: str = '127.0.0.1') -> int: + """Search for unused TCP port""" + with contextlib.closing(socket.socket()) as sock: + sock.bind((host, 0)) + return sock.getsockname()[1] + + +def get_unused_udp_port(host: str = '127.0.0.1') -> int: + """Search for unused UDP port""" + with contextlib.closing(socket.socket(type=socket.SOCK_DGRAM)) as sock: + sock.bind((host, 0)) + return sock.getsockname()[1] diff --git a/src_py/hat/util/sqlite3.py b/src_py/hat/util/sqlite3.py new file mode 100644 index 00000000..61d7eba9 --- /dev/null +++ b/src_py/hat/util/sqlite3.py @@ -0,0 +1,41 @@ +import datetime +import sqlite3 + + +def register_sqlite3_timestamp_converter(): + """Register modified timestamp converter + + This converter is modification of standard library convertor taking into + account possible timezone info. + + """ + + def convert_timestamp(val: bytes) -> datetime.datetime: + datepart, timetzpart = val.split(b" ") + if b"+" in timetzpart: + tzsign = 1 + timepart, tzpart = timetzpart.split(b"+") + elif b"-" in timetzpart: + tzsign = -1 + timepart, tzpart = timetzpart.split(b"-") + else: + timepart, tzpart = timetzpart, None + year, month, day = map(int, datepart.split(b"-")) + timepart_full = timepart.split(b".") + hours, minutes, seconds = map(int, timepart_full[0].split(b":")) + if len(timepart_full) == 2: + microseconds = int('{:0<6.6}'.format(timepart_full[1].decode())) + else: + microseconds = 0 + if tzpart: + tzhours, tzminutes = map(int, tzpart.split(b":")) + tz = datetime.timezone( + tzsign * datetime.timedelta(hours=tzhours, minutes=tzminutes)) + else: + tz = None + + val = datetime.datetime(year, month, day, hours, minutes, seconds, + microseconds, tz) + return val + + sqlite3.register_converter("timestamp", convert_timestamp) diff --git a/test_pytest/test_bytes.py b/test_pytest/test_bytes.py new file mode 100644 index 00000000..087d5006 --- /dev/null +++ b/test_pytest/test_bytes.py @@ -0,0 +1,34 @@ +from hat import util + + +def test_bytes_buffer(): + buff = util.BytesBuffer() + + assert len(buff) == 0 + data = buff.read() + assert bytes(data) == b'' + assert len(buff) == 0 + + buff.add(b'a') + buff.add(b'b') + buff.add(b'c') + assert len(buff) == 3 + data = buff.read() + assert bytes(data) == b'abc' + assert len(buff) == 0 + + buff.add(b'12') + buff.add(b'34') + buff.add(b'56') + assert len(buff) == 6 + data = buff.read(3) + assert bytes(data) == b'123' + assert len(buff) == 3 + data = buff.read(6) + assert bytes(data) == b'456' + assert len(buff) == 0 + + buff.add(b'123') + assert len(buff) == 3 + assert buff.clear() == 3 + assert len(buff) == 0 diff --git a/test_pytest/test_callback.py b/test_pytest/test_callback.py new file mode 100644 index 00000000..cf3a7b5b --- /dev/null +++ b/test_pytest/test_callback.py @@ -0,0 +1,88 @@ +import pytest + +from hat import util + + +def test_callback_registry(): + counter = 0 + + def on_event(): + nonlocal counter + counter = counter + 1 + + registry = util.CallbackRegistry() + + assert counter == 0 + + with registry.register(on_event): + registry.notify() + + assert counter == 1 + + registry.notify() + + assert counter == 1 + + +def test_callback_registry_example(): + x = [] + y = [] + registry = util.CallbackRegistry() + + registry.register(x.append) + registry.notify(1) + with registry.register(y.append): + registry.notify(2) + registry.notify(3) + + assert x == [1, 2, 3] + assert y == [2] + + +@pytest.mark.parametrize('value_count', [1, 2, 10]) +@pytest.mark.parametrize('cb_count', [0, 1, 2, 10]) +def test_callback_registry_with_exception_cb(value_count, cb_count): + + def exception_cb(e): + assert isinstance(e, Exception) + raised.append(str(e)) + + def cb(value): + raise Exception(value) + + registry = util.CallbackRegistry(exception_cb) + handlers = [registry.register(cb) for _ in range(cb_count)] + + raised = [] + expected = [] + for value in range(value_count): + registry.notify(str(value)) + expected.extend(str(value) for _ in range(cb_count)) + assert raised == expected + + for handler in handlers: + handler.cancel() + + raised = [] + expected = [] + for value in range(value_count): + registry.notify(str(value)) + assert raised == expected + + +@pytest.mark.parametrize('cb_count', [1, 2, 10]) +def test_callback_registry_without_exception_cb(cb_count): + + def cb(): + nonlocal call_count + call_count += 1 + raise Exception() + + registry = util.CallbackRegistry() + for _ in range(cb_count): + registry.register(cb) + + call_count = 0 + with pytest.raises(Exception): + registry.notify() + assert call_count == 1 diff --git a/test_pytest/test_first.py b/test_pytest/test_first.py new file mode 100644 index 00000000..4f58e645 --- /dev/null +++ b/test_pytest/test_first.py @@ -0,0 +1,19 @@ +from hat import util + + +def test_first(): + x = [1, 2, 3] + assert util.first(x) == 1 + assert util.first([]) is None + assert util.first(x, lambda x: x > 1) == 2 + assert util.first(x, lambda x: x > 3) is None + assert util.first([], default=4) == 4 + + +def test_first_example(): + assert util.first(range(3)) == 0 + assert util.first(range(3), lambda x: x > 1) == 2 + assert util.first(range(3), lambda x: x > 2) is None + assert util.first(range(3), lambda x: x > 2, 123) == 123 + assert util.first({1: 'a', 2: 'b', 3: 'c'}) == 1 + assert util.first([], default=123) == 123 diff --git a/test_pytest/test_socket.py b/test_pytest/test_socket.py new file mode 100644 index 00000000..91ac9cf5 --- /dev/null +++ b/test_pytest/test_socket.py @@ -0,0 +1,13 @@ +from hat import util + + +def test_get_unused_tcp_port(): + port = util.get_unused_tcp_port() + assert isinstance(port, int) + assert 0 < port <= 0xFFFF + + +def test_get_unused_udp_port(): + port = util.get_unused_udp_port() + assert isinstance(port, int) + assert 0 < port <= 0xFFFF diff --git a/test_pytest/test_sqlite3.py b/test_pytest/test_sqlite3.py new file mode 100644 index 00000000..7e6c873e --- /dev/null +++ b/test_pytest/test_sqlite3.py @@ -0,0 +1,35 @@ +import datetime +import sqlite3 + +import pytest + +from hat import util + + +@pytest.fixture(scope='session', autouse=True) +def register_sqlite3_timestamp_converter(): + util.register_sqlite3_timestamp_converter() + + +@pytest.mark.parametrize("t", [ + datetime.datetime.now(), + datetime.datetime(2000, 1, 1), + datetime.datetime(2000, 1, 2, 3, 4, 5, 123456), + datetime.datetime(2000, 1, 2, 3, 4, 5, 123456, + tzinfo=datetime.timezone.utc), + datetime.datetime(2000, 1, 2, 3, 4, 5, 123456, + tzinfo=datetime.timezone(datetime.timedelta(hours=1, + minutes=2))), + datetime.datetime(2000, 1, 2, 3, 4, 5, 123456, + tzinfo=datetime.timezone(-datetime.timedelta(hours=1, + minutes=2))) +]) +def test_sqlite3_timestamp_converter(t): + with sqlite3.connect(':memory:', + isolation_level=None, + detect_types=sqlite3.PARSE_DECLTYPES) as conn: + conn.execute("CREATE TABLE test (t TIMESTAMP)") + conn.execute("INSERT INTO test VALUES (:t)", {'t': t}) + + result = conn.execute("SELECT t FROM test").fetchone()[0] + assert result == t diff --git a/test_pytest/test_util.py b/test_pytest/test_util.py deleted file mode 100644 index 9b838a4a..00000000 --- a/test_pytest/test_util.py +++ /dev/null @@ -1,183 +0,0 @@ -import datetime -import sqlite3 - -import pytest - -from hat import util - - -@pytest.fixture(scope='session', autouse=True) -def register_sqlite3_timestamp_converter(): - util.register_sqlite3_timestamp_converter() - - -def test_first(): - x = [1, 2, 3] - assert util.first(x) == 1 - assert util.first([]) is None - assert util.first(x, lambda x: x > 1) == 2 - assert util.first(x, lambda x: x > 3) is None - assert util.first([], default=4) == 4 - - -def test_first_example(): - assert util.first(range(3)) == 0 - assert util.first(range(3), lambda x: x > 1) == 2 - assert util.first(range(3), lambda x: x > 2) is None - assert util.first(range(3), lambda x: x > 2, 123) == 123 - assert util.first({1: 'a', 2: 'b', 3: 'c'}) == 1 - assert util.first([], default=123) == 123 - - -def test_callback_registry(): - counter = 0 - - def on_event(): - nonlocal counter - counter = counter + 1 - - registry = util.CallbackRegistry() - - assert counter == 0 - - with registry.register(on_event): - registry.notify() - - assert counter == 1 - - registry.notify() - - assert counter == 1 - - -def test_callback_registry_example(): - x = [] - y = [] - registry = util.CallbackRegistry() - - registry.register(x.append) - registry.notify(1) - with registry.register(y.append): - registry.notify(2) - registry.notify(3) - - assert x == [1, 2, 3] - assert y == [2] - - -@pytest.mark.parametrize('value_count', [1, 2, 10]) -@pytest.mark.parametrize('cb_count', [0, 1, 2, 10]) -def test_callback_registry_with_exception_cb(value_count, cb_count): - - def exception_cb(e): - assert isinstance(e, Exception) - raised.append(str(e)) - - def cb(value): - raise Exception(value) - - registry = util.CallbackRegistry(exception_cb) - handlers = [registry.register(cb) for _ in range(cb_count)] - - raised = [] - expected = [] - for value in range(value_count): - registry.notify(str(value)) - expected.extend(str(value) for _ in range(cb_count)) - assert raised == expected - - for handler in handlers: - handler.cancel() - - raised = [] - expected = [] - for value in range(value_count): - registry.notify(str(value)) - assert raised == expected - - -@pytest.mark.parametrize('cb_count', [1, 2, 10]) -def test_callback_registry_without_exception_cb(cb_count): - - def cb(): - nonlocal call_count - call_count += 1 - raise Exception() - - registry = util.CallbackRegistry() - for _ in range(cb_count): - registry.register(cb) - - call_count = 0 - with pytest.raises(Exception): - registry.notify() - assert call_count == 1 - - -def test_get_unused_tcp_port(): - port = util.get_unused_tcp_port() - assert isinstance(port, int) - assert 0 < port <= 0xFFFF - - -def test_get_unused_udp_port(): - port = util.get_unused_udp_port() - assert isinstance(port, int) - assert 0 < port <= 0xFFFF - - -def test_bytes_buffer(): - buff = util.BytesBuffer() - - assert len(buff) == 0 - data = buff.read() - assert bytes(data) == b'' - assert len(buff) == 0 - - buff.add(b'a') - buff.add(b'b') - buff.add(b'c') - assert len(buff) == 3 - data = buff.read() - assert bytes(data) == b'abc' - assert len(buff) == 0 - - buff.add(b'12') - buff.add(b'34') - buff.add(b'56') - assert len(buff) == 6 - data = buff.read(3) - assert bytes(data) == b'123' - assert len(buff) == 3 - data = buff.read(6) - assert bytes(data) == b'456' - assert len(buff) == 0 - - buff.add(b'123') - assert len(buff) == 3 - assert buff.clear() == 3 - assert len(buff) == 0 - - -@pytest.mark.parametrize("t", [ - datetime.datetime.now(), - datetime.datetime(2000, 1, 1), - datetime.datetime(2000, 1, 2, 3, 4, 5, 123456), - datetime.datetime(2000, 1, 2, 3, 4, 5, 123456, - tzinfo=datetime.timezone.utc), - datetime.datetime(2000, 1, 2, 3, 4, 5, 123456, - tzinfo=datetime.timezone(datetime.timedelta(hours=1, - minutes=2))), - datetime.datetime(2000, 1, 2, 3, 4, 5, 123456, - tzinfo=datetime.timezone(-datetime.timedelta(hours=1, - minutes=2))) -]) -def test_sqlite3_timestamp_converter(t): - with sqlite3.connect(':memory:', - isolation_level=None, - detect_types=sqlite3.PARSE_DECLTYPES) as conn: - conn.execute("CREATE TABLE test (t TIMESTAMP)") - conn.execute("INSERT INTO test VALUES (:t)", {'t': t}) - - result = conn.execute("SELECT t FROM test").fetchone()[0] - assert result == t