diff --git a/asyncua/client/ha/common.py b/asyncua/client/ha/common.py index 29e732419..e11a2eb20 100644 --- a/asyncua/client/ha/common.py +++ b/asyncua/client/ha/common.py @@ -5,6 +5,7 @@ from itertools import chain, islice +from asyncua.common.utils import wait_for _logger = logging.getLogger(__name__) @@ -15,7 +16,7 @@ class ClientNotFound(Exception): async def event_wait(evt, timeout) -> bool: try: - await asyncio.wait_for(evt.wait(), timeout) + await wait_for(evt.wait(), timeout) except asyncio.TimeoutError: pass return evt.is_set() diff --git a/asyncua/client/ua_client.py b/asyncua/client/ua_client.py index e2d39d264..a846fe9ec 100644 --- a/asyncua/client/ua_client.py +++ b/asyncua/client/ua_client.py @@ -8,6 +8,7 @@ from asyncua import ua from asyncua.common.session_interface import AbstractSession +from ..common.utils import wait_for from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary from ..ua.uaerrors import BadTimeout, BadNoSubscription, BadSessionClosed, BadUserAccessDenied, UaStructParsingError from ..ua.uaprotocol_auto import OpenSecureChannelResult, SubscriptionAcknowledgement @@ -165,7 +166,7 @@ async def send_request(self, request, timeout: Optional[float] = None, message_t # time out then. await self.pre_request_hook() try: - data = await asyncio.wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None) + data = await wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None) except Exception: if self.state != self.OPEN: raise ConnectionError("Connection is closed") from None @@ -221,7 +222,7 @@ async def send_hello(self, url, max_messagesize: int = 0, max_chunkcount: int = self._callbackmap[0] = ack if self.transport is not None: self.transport.write(uatcp_to_binary(ua.MessageType.Hello, hello)) - return await asyncio.wait_for(ack, self.timeout) + return await wait_for(ack, self.timeout) async def open_secure_channel(self, params) -> OpenSecureChannelResult: self.logger.info("open_secure_channel") @@ -230,7 +231,7 @@ async def open_secure_channel(self, params) -> OpenSecureChannelResult: if self._open_secure_channel_exchange is not None: raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. ' 'The response must be processed and returned before the next request can be sent.') self._open_secure_channel_exchange = params - await asyncio.wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout) + await wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout) _return = self._open_secure_channel_exchange.Parameters # type: ignore[union-attr] self._open_secure_channel_exchange = None return _return diff --git a/asyncua/common/utils.py b/asyncua/common/utils.py index 4d87e3c23..cb112588d 100644 --- a/asyncua/common/utils.py +++ b/asyncua/common/utils.py @@ -2,12 +2,13 @@ Helper function and classes that do not rely on asyncua library. Helper function and classes depending on ua object are in ua_utils.py """ - -import os +import asyncio import logging +import os import sys from dataclasses import Field, fields -from typing import get_type_hints, Dict, Tuple, Any, Optional +from typing import Any, Awaitable, Dict, get_type_hints, Optional, Tuple, TypeVar, Union + from ..ua.uaerrors import UaError _logger = logging.getLogger(__name__) @@ -132,3 +133,22 @@ def fields_with_resolved_types( pass return fields_ + + +_T = TypeVar('_T') + + +async def wait_for(aw: Awaitable[_T], timeout: Union[int, float, None]) -> _T: + """ + Wrapped version of asyncio.wait_for that does not swallow cancellations + + There is a bug in asyncio.wait_for before Python version 3.12 that prevents the inner awaitable from being cancelled + when the task is cancelled from the outside. + + See https://github.com/python/cpython/issues/87555 and https://github.com/python/cpython/issues/86296 + """ + if sys.version_info >= (3, 12): + return await asyncio.wait_for(aw, timeout) + + import wait_for2 + return await wait_for2.wait_for(aw, timeout) diff --git a/setup.cfg b/setup.cfg index aeebdd92c..990476d4e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,6 +11,7 @@ max-line-length = 160 disable_error_code = misc, arg-type, assignment, var-annotated show_error_codes = True check_untyped_defs = False +mypy_path = ./stubs [mypy-asyncua.ua.uaprotocol_auto.*] # Autogenerated file disable_error_code = literal-required diff --git a/setup.py b/setup.py index 92bab3182..6e4776fb9 100644 --- a/setup.py +++ b/setup.py @@ -20,10 +20,10 @@ author="Olivier Roulet-Dubonnet", author_email="olivier.roulet@gmail.com", url='http://freeopcua.github.io/', - packages=find_packages(exclude=["tests"]), + packages=find_packages(exclude=["tests", "stubs"]), provides=["asyncua"], license="GNU Lesser General Public License v3 or later", - install_requires=["aiofiles", "aiosqlite", "python-dateutil", "pytz", "cryptography>42.0.0", "sortedcontainers", "importlib-metadata;python_version<'3.8'", "pyOpenSSL>23.2.0", "typing-extensions"], + install_requires=["aiofiles", "aiosqlite", "python-dateutil", "pytz", "cryptography>42.0.0", "sortedcontainers", "importlib-metadata;python_version<'3.8'", "pyOpenSSL>23.2.0", "typing-extensions", 'wait_for2==0.3.2'], classifiers=[ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", diff --git a/stubs/wait_for2/__init__.pyi b/stubs/wait_for2/__init__.pyi new file mode 100644 index 000000000..cede48a67 --- /dev/null +++ b/stubs/wait_for2/__init__.pyi @@ -0,0 +1,14 @@ +import asyncio +from typing import Any, Awaitable, Callable, TypeVar, Union + +_T = TypeVar('_T') + + +async def wait_for( + fut: Awaitable[_T], + timeout: Union[int, float, None], + *, + loop: asyncio.AbstractEventLoop = None, + race_handler: Callable[[Union[_T, BaseException], bool], Any] = None, +): + ...