diff --git a/qcodes/monitor/monitor.py b/qcodes/monitor/monitor.py index f3e76798262..4f9d48132c5 100644 --- a/qcodes/monitor/monitor.py +++ b/qcodes/monitor/monitor.py @@ -17,7 +17,7 @@ ``monitor = qcodes.Monitor(param1, param2, param3, ...)`` """ - +from __future__ import annotations import asyncio import json @@ -30,28 +30,11 @@ from collections import defaultdict from contextlib import suppress from threading import Event, Thread -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Dict, - Optional, - Sequence, - Union, -) +from typing import Any, Awaitable, Callable, Sequence import websockets - -try: - from websockets.legacy.server import serve -except ImportError: - # fallback for websockets < 9 - # for the same reason we only support typechecking with websockets 9 - from websockets import serve # type:ignore[attr-defined,no-redef] - -if TYPE_CHECKING: - from websockets.legacy.server import WebSocketServerProtocol, WebSocketServer +import websockets.exceptions +import websockets.server from qcodes.parameters import Parameter @@ -63,18 +46,18 @@ def _get_metadata( *parameters: Parameter, use_root_instrument: bool = True -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Return a dictionary that contains the parameter metadata grouped by the instrument it belongs to. """ metadata_timestamp = time.time() # group metadata by instrument - metas: Dict[Any, Any] = defaultdict(list) + metas: dict[Any, Any] = defaultdict(list) for parameter in parameters: # Get the latest value from the parameter, # respecting the max_val_age parameter - meta: Dict[str, Optional[Union[float, str]]] = {} + meta: dict[str, float | str | None] = {} meta["value"] = str(parameter.get_latest()) timestamp = parameter.get_latest.get_timestamp() if timestamp is not None: @@ -106,11 +89,14 @@ def _get_metadata( def _handler( parameters: Sequence[Parameter], interval: float, use_root_instrument: bool = True -) -> Callable[["WebSocketServerProtocol", str], Awaitable[None]]: +) -> Callable[[websockets.server.WebSocketServerProtocol, str], Awaitable[None]]: """ Return the websockets server handler. """ - async def server_func(websocket: "WebSocketServerProtocol", _: str) -> None: + + async def server_func( + websocket: websockets.server.WebSocketServerProtocol, _: str + ) -> None: """ Create a websockets handler that sends parameter values to a listener every "interval" seconds. @@ -167,15 +153,14 @@ def __init__( raise TypeError(f"We can only monitor QCodes " f"Parameters, not {type(parameter)}") - self.loop: Optional[asyncio.AbstractEventLoop] = None - self.server: Optional["WebSocketServer"] = None + self.loop: asyncio.AbstractEventLoop | None = None + self._stop_loop_future: asyncio.Future | None = None self._parameters = parameters self.loop_is_closed = Event() self.server_is_started = Event() self.handler = _handler( parameters, interval=interval, use_root_instrument=use_root_instrument ) - log.debug("Start monitoring thread") if Monitor.running: # stop the old server @@ -194,24 +179,23 @@ def run(self) -> None: Start the event loop and run forever. """ log.debug("Running Websocket server") - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + + async def run_loop() -> None: + self.loop = asyncio.get_running_loop() + self._stop_loop_future = self.loop.create_future() + + async with websockets.server.serve( + self.handler, "127.0.0.1", WEBSOCKET_PORT, close_timeout=1 + ): + self.server_is_started.set() + try: + await self._stop_loop_future + except asyncio.CancelledError: + log.debug("Websocket server thread shutting down") + try: - server_start = serve(self.handler, '127.0.0.1', - WEBSOCKET_PORT, close_timeout=1) - self.server = self.loop.run_until_complete(server_start) - self.server_is_started.set() - self.loop.run_forever() - except OSError: - # The code above may throw an OSError - # if the socket cannot be bound - log.exception("Server could not be started") + asyncio.run(run_loop()) finally: - log.debug("loop stopped") - log.debug("Pending tasks at close: %r", - asyncio.all_tasks(self.loop)) - self.loop.close() - log.debug("loop closed") self.loop_is_closed.set() def update_all(self) -> None: @@ -231,20 +215,7 @@ def stop(self) -> None: self.join() Monitor.running = None - async def __stop_server(self) -> None: - log.debug("asking server %r to close", self.server) - if self.server is not None: - self.server.close() - log.debug("waiting for server to close") - if self.loop is not None and self.server is not None: - await self.loop.create_task(self.server.wait_closed()) - log.debug("stopping loop") - if self.loop is not None: - log.debug("Pending tasks at stop: %r", - asyncio.all_tasks(self.loop)) - self.loop.stop() - - def join(self, timeout: Optional[float] = None) -> None: + def join(self, timeout: float | None = None) -> None: """ Overwrite ``Thread.join`` to make sure server is stopped before joining avoiding a potential deadlock. @@ -256,9 +227,11 @@ def join(self, timeout: Optional[float] = None) -> None: log.debug("monitor is dead") return try: - if self.loop is not None: - asyncio.run_coroutine_threadsafe(self.__stop_server(), - self.loop) + if self.loop is not None and self._stop_loop_future is not None: + log.debug("Instructing server to stop event loop.") + self.loop.call_soon_threadsafe(self._stop_loop_future.cancel) + else: + log.debug("No event loop found. Cannot stop event loop.") except RuntimeError: # the above may throw a runtime error if the loop is already # stopped in which case there is nothing more to do