Skip to content

Commit

Permalink
Ruff updates
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 11, 2024
1 parent 5787811 commit 1da0207
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 17 deletions.
9 changes: 4 additions & 5 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import asyncio
import logging
import uuid
from typing import Any, AnyStr, Self
from typing import Any, Self

import zmq
import zmq.asyncio
from typing_extensions import Self

from _ert.async_utils import new_event_loop

Expand All @@ -25,7 +24,7 @@ class ClientConnectionClosedOK(Exception):
class Client:
DEFAULT_MAX_RETRIES = 5
DEFAULT_ACK_TIMEOUT = 5
_receiver_task: Optional[asyncio.Task[None]]
_receiver_task: asyncio.Task[None] | None

def __enter__(self) -> Self:
self.loop.run_until_complete(self.__aenter__())
Expand Down Expand Up @@ -103,7 +102,7 @@ async def connect(self) -> None:
self.term()
raise

def send(self, message: str, retries: Optional[int] = None) -> None:
def send(self, message: str, retries: int | None = None) -> None:
self.loop.run_until_complete(self._send(message, retries))

async def process_message(self, msg: str) -> None:
Expand Down Expand Up @@ -137,7 +136,7 @@ async def _send(self, message: str, retries: int | None = None) -> None:
self._ack_event.wait(), timeout=self._ack_timeout
)
return
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
f"{self.dealer_id} failed to get acknowledgment on the {message}. Resending."
)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from functools import partialmethod
from typing import Any, Awaitable, Callable, Protocol, Sequence
from typing import Any, Protocol

from _ert.events import (
Event,
Expand Down
2 changes: 1 addition & 1 deletion src/ert/ensemble_evaluator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid
import warnings
from base64 import b64encode
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta

import zmq
from cryptography import x509
Expand Down
6 changes: 3 additions & 3 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._events_to_send: asyncio.Queue[Event] = asyncio.Queue()
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()

self._ee_tasks: List[asyncio.Task[None]] = []
self._ee_tasks: list[asyncio.Task[None]] = []
self._server_done: asyncio.Event = asyncio.Event()

# batching section
Expand Down Expand Up @@ -292,7 +292,7 @@ async def _server(self) -> None:
await self._server_done.wait()
try:
await asyncio.wait_for(self._dispatchers_empty.wait(), timeout=5)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"Not all dispatchers were disconnected when closing zmq server!"
)
Expand All @@ -304,7 +304,7 @@ async def _server(self) -> None:
await self._events_to_send.join()
try:
await asyncio.wait_for(self._clients_empty.wait(), timeout=5)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"Not all clients were disconnected when closing zmq server!"
)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class EventSentinel:
class Monitor(Client):
_sentinel: Final = EventSentinel()

def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None:
def __init__(self, ee_con_info: EvaluatorConnectionInfo) -> None:
self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0]
self._event_queue: asyncio.Queue[Event | EventSentinel] = asyncio.Queue()
self._receiver_timeout: float = 60.0
Expand Down
8 changes: 4 additions & 4 deletions tests/ert/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,16 @@ def copy_heat_equation(copy_case):
pytest.param(0, marks=pytest.mark.xdist_group(name="snake_oil_case_storage"))
],
)
def fixture_copy_snake_oil_case_storage(_shared_snake_oil_case, tmp_path, monkeypatch):
def fixture_copy_snake_oil_case_storage(shared_snake_oil_case, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
shutil.copytree(_shared_snake_oil_case, "test_data")
shutil.copytree(shared_snake_oil_case, "test_data")
monkeypatch.chdir("test_data")


@pytest.fixture
def copy_heat_equation_storage(_shared_heat_equation, tmp_path, monkeypatch):
def copy_heat_equation_storage(shared_heat_equation, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
shutil.copytree(_shared_heat_equation, "heat_equation")
shutil.copytree(shared_heat_equation, "heat_equation")
monkeypatch.chdir("heat_equation")


Expand Down
4 changes: 2 additions & 2 deletions tests/ert/ui_tests/gui/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ def _evaluate(coeffs, x):


@pytest.fixture
def esmda_has_run(_esmda_run, tmp_path, monkeypatch):
def esmda_has_run(esmda_run, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
shutil.copytree(_esmda_run, tmp_path, dirs_exist_ok=True)
shutil.copytree(esmda_run, tmp_path, dirs_exist_ok=True)
with (
_open_main_window(tmp_path / "poly.ert") as (
gui,
Expand Down

0 comments on commit 1da0207

Please sign in to comment.