diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index e9aee2905d5..9000fa304ba 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -95,6 +95,13 @@ def create_local_copy(self) -> QueueConfig: self.queue_options, ) + @property + def max_running(self) -> int: + for key, val in self.queue_options.get(self.queue_system, []): + if key == "MAX_RUNNING": + return int(val) + return 0 + def _check_for_overwritten_queue_system_options( selected_queue_system: QueueSystem, diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index d8b3ff662d2..1859224be62 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -4,8 +4,18 @@ import logging import threading import uuid -from functools import partial, partialmethod -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple +from functools import partialmethod +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Protocol, + Tuple, +) from cloudevents.http.event import CloudEvent @@ -31,6 +41,11 @@ event_logger = logging.getLogger("ert.event_log") +class _KillAllJobs(Protocol): + def kill_all_jobs(self) -> None: + ... 
+ + class LegacyEnsemble(Ensemble): def __init__( self, @@ -42,16 +57,9 @@ def __init__( id_: str, ) -> None: super().__init__(reals, metadata, id_) - if not queue_config: - raise ValueError(f"{self} needs queue_config") - - if FeatureToggling.is_enabled("scheduler"): - if queue_config.queue_system != QueueSystem.LOCAL: - raise NotImplementedError() - driver = create_driver(queue_config) - self._job_queue = Scheduler(driver) - else: - self._job_queue = JobQueue(queue_config) + + self._queue_config = queue_config + self._job_queue: Optional[_KillAllJobs] = None self.stop_long_running = stop_long_running self.min_required_realizations = min_required_realizations self._config: Optional[EvaluatorServerConfig] = None @@ -180,43 +188,45 @@ async def _evaluate_inner( # pylint: disable=too-many-branches raise ValueError("no config") # mypy try: + if FeatureToggling.is_enabled("scheduler"): + if self._queue_config.queue_system != QueueSystem.LOCAL: + raise NotImplementedError() + driver = create_driver(self._queue_config) + queue = Scheduler( + driver, + self.active_reals, + max_submit=self._queue_config.max_submit, + max_running=self._queue_config.max_running, + ens_id=self.id_, + ee_uri=self._config.dispatch_uri, + ee_cert=self._config.cert, + ee_token=self._config.token, + ) + else: + queue = JobQueue( + self._queue_config, + self.active_reals, + ens_id=self.id_, + ee_uri=self._config.dispatch_uri, + ee_cert=self._config.cert, + ee_token=self._config.token, + on_timeout=on_timeout, + ) + self._job_queue = queue + await cloudevent_unary_send( event_creator(identifiers.EVTYPE_ENSEMBLE_STARTED, None) ) - for real in self.active_reals: - self._job_queue.add_realization(real, callback_timeout=on_timeout) - - # TODO: this is sort of a callback being preemptively called. - # It should be lifted out of the queue/evaluate, into the evaluator. If - # something is long running, the evaluator will know and should send - # commands to the task in order to have it killed/retried. 
- # See https://github.com/equinor/ert/issues/1229 - queue_evaluators = None - if self.stop_long_running and self.min_required_realizations > 0: - queue_evaluators = [ - partial( - self._job_queue.stop_long_running_jobs, - self.min_required_realizations, - ) - ] - - self._job_queue.set_ee_info( - ee_uri=self._config.dispatch_uri, - ens_id=self.id_, - ee_cert=self._config.cert, - ee_token=self._config.token, - ) - - # Tell queue to pass info to the jobs-file - # NOTE: This touches files on disk... - self._job_queue.add_dispatch_information_to_jobs_file() + if isinstance(queue, Scheduler): + result = await queue.execute() + elif isinstance(queue, JobQueue): + min_required_realizations = ( + self.min_required_realizations if self.stop_long_running else 0 + ) + queue.add_dispatch_information_to_jobs_file() - sema = threading.BoundedSemaphore(value=CONCURRENT_INTERNALIZATION) - result: str = await self._job_queue.execute( - sema, - queue_evaluators, - ) + result = await queue.execute(min_required_realizations) except Exception: logger.exception( @@ -236,5 +246,6 @@ def cancellable(self) -> bool: return True def cancel(self) -> None: - self._job_queue.kill_all_jobs() + if self._job_queue is not None: + self._job_queue.kill_all_jobs() logger.debug("evaluator cancelled") diff --git a/src/ert/job_queue/queue.py b/src/ert/job_queue/queue.py index eaf53904cd2..4fb484732df 100644 --- a/src/ert/job_queue/queue.py +++ b/src/ert/job_queue/queue.py @@ -8,18 +8,16 @@ import json import logging import ssl -import threading import time from collections import deque from threading import BoundedSemaphore, Semaphore from typing import ( TYPE_CHECKING, - Any, Callable, Dict, - Iterable, List, Optional, + Sequence, Tuple, Union, ) @@ -102,7 +100,18 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def __init__(self, queue_config: QueueConfig): + def __init__( + self, + queue_config: QueueConfig, + realizations: Optional[Sequence[Realization]] = None, + *, + 
ens_id: Optional[str] = None, + ee_uri: Optional[str] = None, + ee_cert: Optional[str] = None, + ee_token: Optional[str] = None, + on_timeout: Optional[Callable[[int], None]] = None, + verify_token: bool = True, + ) -> None: self.job_list: List[JobQueueNode] = [] self._stopped = False self.driver: Driver = Driver.create_driver(queue_config) @@ -112,17 +121,28 @@ def __init__(self, queue_config: QueueConfig): self._differ = QueueDiffer() self._max_submit = queue_config.max_submit self._pool_sema = BoundedSemaphore(value=CONCURRENT_INTERNALIZATION) + self._on_timeout = on_timeout + + self._ens_id = ens_id + self._ee_uri = ee_uri + self._ee_cert = ee_cert + self._ee_token = ee_token - self._ens_id: Optional[str] = None - self._ee_uri: Optional[str] = None - self._ee_cert: Optional[Union[str, bytes]] = None - self._ee_token: Optional[str] = None self._ee_ssl_context: Optional[Union[ssl.SSLContext, bool]] = None + if ee_cert is not None: + self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if verify_token: + self._ee_ssl_context.load_verify_locations(cadata=ee_cert) + else: + self._ee_ssl_context = True if ee_uri and ee_uri.startswith("wss") else None self._changes_to_publish: Optional[ asyncio.Queue[Union[Dict[int, str], object]] ] = None + for real in realizations or []: + self.add_realization(real) + def get_max_running(self) -> int: return self.driver.get_max_running() @@ -215,27 +235,6 @@ def launch_jobs(self, pool_sema: Semaphore) -> None: max_submit=self.max_submit, ) - def set_ee_info( - self, - ee_uri: str, - ens_id: str, - ee_cert: Optional[Union[str, bytes]] = None, - ee_token: Optional[str] = None, - verify_context: bool = True, - ) -> None: - self._ens_id = ens_id - self._ee_token = ee_token - - self._ee_uri = ee_uri - if ee_cert is not None: - self._ee_cert = ee_cert - self._ee_token = ee_token - self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - if verify_context: - self._ee_ssl_context.load_verify_locations(cadata=ee_cert) - 
else: - self._ee_ssl_context = True if ee_uri.startswith("wss") else None - @staticmethod def _translate_change_to_cloudevent( ens_id: str, real_id: int, status: str @@ -304,14 +303,8 @@ async def _jobqueue_publisher(self) -> None: async def execute( self, - pool_sema: Optional[threading.BoundedSemaphore] = None, - evaluators: Optional[Iterable[Callable[..., Any]]] = None, + min_required_realizations: int = 0, ) -> str: - if pool_sema is not None: - self._pool_sema = pool_sema - if evaluators is None: - evaluators = [] - self._changes_to_publish = asyncio.Queue() asyncio.create_task(self._jobqueue_publisher()) @@ -322,8 +315,8 @@ async def execute( await asyncio.sleep(1) - for func in evaluators: - func() + if min_required_realizations > 0: + self.stop_long_running_jobs(min_required_realizations) changes, new_state = self.changes_without_transition() if len(changes) > 0: @@ -379,14 +372,13 @@ def add_job_from_run_arg( def add_realization( self, real: Realization, - callback_timeout: Optional[Callable[[int], None]] = None, ) -> None: job = JobQueueNode( job_script=real.job_script, num_cpu=real.num_cpu, run_arg=real.run_arg, max_runtime=real.max_runtime, - callback_timeout=callback_timeout, + callback_timeout=self._on_timeout, ) if job is None: raise ValueError("JobQueueNode constructor created None job") diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index b1a2c2f33d5..7aa17ea1b25 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -133,7 +133,4 @@ async def _send(self, state: State) -> None: "queue_event_type": status, }, ) - if self._scheduler._events is None: - await self._scheduler.ainit() - assert self._scheduler._events is not None await self._scheduler._events.put(to_json(event)) diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index a49f94d938c..deadf123935 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -5,9 +5,15 @@ import logging import os import ssl 
-import threading from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Callable, Iterable, MutableMapping, Optional +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sequence, +) from pydantic.dataclasses import dataclass from websockets import Headers @@ -28,39 +34,45 @@ @dataclass class _JobsJson: - ens_id: str + ens_id: Optional[str] real_id: str - dispatch_url: str + dispatch_url: Optional[str] ee_token: Optional[str] ee_cert_path: Optional[str] experiment_id: str class Scheduler: - def __init__(self, driver: Optional[Driver] = None) -> None: + def __init__( + self, + driver: Optional[Driver] = None, + realizations: Optional[Sequence[Realization]] = None, + *, + max_submit: int = 1, + max_running: int = 1, + ens_id: Optional[str] = None, + ee_uri: Optional[str] = None, + ee_cert: Optional[str] = None, + ee_token: Optional[str] = None, + ) -> None: if driver is None: driver = LocalDriver() self.driver = driver - self._jobs: MutableMapping[int, Job] = {} self._tasks: MutableMapping[int, asyncio.Task[None]] = {} - self._events: Optional[asyncio.Queue[Any]] = None - self._cancelled = False - # will be read from QueueConfig - self._max_submit: int = 2 + self._jobs: Mapping[int, Job] = { + real.iens: Job(self, real) for real in (realizations or []) + } - self._ee_uri = "" - self._ens_id = "" - self._ee_cert: Optional[str] = None - self._ee_token: Optional[str] = None - - async def ainit(self) -> None: - # While supporting Python 3.8, this statement must be delayed. 
- if self._events is None: - self._events = asyncio.Queue() + self._events: asyncio.Queue[Any] = asyncio.Queue() + self._cancelled = False + self._max_submit = max_submit + self._max_running = max_running - def add_realization(self, real: Realization, callback_timeout: Any = None) -> None: - self._jobs[real.iens] = Job(self, real) + self._ee_uri = ee_uri + self._ens_id = ens_id + self._ee_cert = ee_cert + self._ee_token = ee_token def kill_all_jobs(self) -> None: self._cancelled = True @@ -70,14 +82,6 @@ def kill_all_jobs(self) -> None: def stop_long_running_jobs(self, minimum_required_realizations: int) -> None: pass - def set_ee_info( - self, ee_uri: str, ens_id: str, ee_cert: Optional[str], ee_token: Optional[str] - ) -> None: - self._ee_uri = ee_uri - self._ens_id = ens_id - self._ee_cert = ee_cert - self._ee_token = ee_token - async def _publisher(self) -> None: if not self._ee_uri: return @@ -89,10 +93,6 @@ async def _publisher(self) -> None: if self._ee_token: headers["token"] = self._ee_token - if self._events is None: - await self.ainit() - assert self._events is not None - async with connect( self._ee_uri, ssl=tls, @@ -112,21 +112,14 @@ def add_dispatch_information_to_jobs_file(self) -> None: async def execute( self, - semaphore: Optional[threading.BoundedSemaphore] = None, - queue_evaluators: Optional[Iterable[Callable[..., Any]]] = None, ) -> str: - if queue_evaluators is not None: - logger.warning(f"Ignoring queue_evaluators: {queue_evaluators}") - async with background_tasks() as cancel_when_execute_is_done: cancel_when_execute_is_done(self._publisher()) cancel_when_execute_is_done(self._process_event_queue()) cancel_when_execute_is_done(self.driver.poll()) start = asyncio.Event() - sem = asyncio.BoundedSemaphore( - semaphore._initial_value if semaphore else 10 # type: ignore - ) + sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs)) for iens, job in self._jobs.items(): self._tasks[iens] = asyncio.create_task( job(start, sem, self._max_submit) @@ -144,10 
+137,6 @@ async def execute( return EVTYPE_ENSEMBLE_STOPPED async def _process_event_queue(self) -> None: - if self.driver.event_queue is None: - await self.driver.ainit() - assert self.driver.event_queue is not None - while True: iens, event = await self.driver.event_queue.get() if event == JobEvent.STARTED: diff --git a/src/ert/simulator/simulation_context.py b/src/ert/simulator/simulation_context.py index 45bfeba9d77..69b1930756a 100644 --- a/src/ert/simulator/simulation_context.py +++ b/src/ert/simulator/simulation_context.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from functools import partial from threading import Thread from time import sleep from typing import TYPE_CHECKING, Any, List, Optional, Tuple @@ -61,19 +60,12 @@ def _run_forward_model( ert.ert_config.preferred_num_cpu, ) - queue_evaluators = None - if ( - ert.ert_config.analysis_config.stop_long_running - and ert.ert_config.analysis_config.minimum_required_realizations > 0 - ): - queue_evaluators = [ - partial( - job_queue.stop_long_running_jobs, - ert.ert_config.analysis_config.minimum_required_realizations, - ) - ] - - asyncio.run(job_queue.execute(evaluators=queue_evaluators)) + required_realizations = 0 + if ert.ert_config.analysis_config.stop_long_running: + required_realizations = ( + ert.ert_config.analysis_config.minimum_required_realizations + ) + asyncio.run(job_queue.execute(required_realizations)) run_context.sim_fs.sync() diff --git a/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py b/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py index 26c243cebac..dc256832cb8 100644 --- a/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py +++ b/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py @@ -51,12 +51,9 @@ async def test_happy_path( await wait_for_evaluator(base_url=url, timeout=5) ensemble = make_ensemble_builder(monkeypatch, tmpdir, 1, 1).build() - queue = JobQueue(queue_config) - for real in 
ensemble.reals: - queue.add_realization(real, callback_timeout=None) + queue = JobQueue(queue_config, ensemble.reals, ee_uri=url, ens_id="ee_0") - queue.set_ee_info(ee_uri=url, ens_id="ee_0") - await queue.execute(pool_sema=threading.BoundedSemaphore(value=10)) + await queue.execute() done.set_result(None) await mock_ws_task diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py index ea156dfe2b9..ddc814d2cd5 100644 --- a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py +++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py @@ -10,6 +10,7 @@ from ert.ensemble_evaluator.config import EvaluatorServerConfig from ert.ensemble_evaluator.evaluator import EnsembleEvaluator from ert.ensemble_evaluator.monitor import Monitor +from ert.job_queue.queue import JobQueue from ert.shared.feature_toggling import FeatureToggling @@ -111,7 +112,7 @@ def test_run_legacy_ensemble_exception(tmpdir, make_ensemble_builder, monkeypatc ) evaluator = EnsembleEvaluator(ensemble, config, 0) - with patch.object(ensemble._job_queue, "add_realization") as faulty_queue: + with patch.object(JobQueue, "add_realization") as faulty_queue: faulty_queue.side_effect = RuntimeError() evaluator._start_running() with Monitor(config) as monitor: diff --git a/tests/unit_tests/job_queue/test_job_queue.py b/tests/unit_tests/job_queue/test_job_queue.py index e5cabca2f47..8d32a4db2a7 100644 --- a/tests/unit_tests/job_queue/test_job_queue.py +++ b/tests/unit_tests/job_queue/test_job_queue.py @@ -65,11 +65,21 @@ def create_local_queue( num_realizations: int = 10, max_runtime: Optional[int] = None, callback_timeout: Optional["Callable[[int], None]"] = None, + *, + ens_id: Optional[str] = None, + ee_uri: Optional[str] = None, + ee_cert: Optional[str] = None, + ee_token: Optional[str] = None, ): job_queue = JobQueue( QueueConfig.from_dict( {"driver_type": QueueSystem.LOCAL, "MAX_SUBMIT": max_submit} - ) + ), + 
ens_id=ens_id, + ee_uri=ee_uri, + ee_cert=ee_cert, + ee_token=ee_token, + verify_token=False, ) for iens in range(num_realizations): @@ -223,22 +233,21 @@ def test_timeout_jobs(tmpdir, monkeypatch, never_ending_script): def test_add_dispatch_info(tmpdir, monkeypatch, simple_script): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(simple_script) ens_id = "some_id" cert = "My very nice cert" token = "my_super_secret_token" dispatch_url = "wss://example.org" cert_file = ".ee.pem" - runpaths = [Path(DUMMY_CONFIG["run_path"].format(iens)) for iens in range(10)] - for runpath in runpaths: - (runpath / "jobs.json").write_text(json.dumps({}), encoding="utf-8") - job_queue.set_ee_info( - ee_uri=dispatch_url, + job_queue = create_local_queue( + simple_script, ens_id=ens_id, + ee_uri=dispatch_url, ee_cert=cert, ee_token=token, - verify_context=False, ) + runpaths = [Path(DUMMY_CONFIG["run_path"].format(iens)) for iens in range(10)] + for runpath in runpaths: + (runpath / "jobs.json").write_text(json.dumps({}), encoding="utf-8") job_queue.add_dispatch_information_to_jobs_file( experiment_id="experiment_id", ) @@ -256,18 +265,17 @@ def test_add_dispatch_info(tmpdir, monkeypatch, simple_script): def test_add_dispatch_info_cert_none(tmpdir, monkeypatch, simple_script): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(simple_script) ens_id = "some_id" dispatch_url = "wss://example.org" cert = None token = None cert_file = ".ee.pem" + job_queue = create_local_queue( + simple_script, ee_uri=dispatch_url, ens_id=ens_id, ee_cert=cert, ee_token=token + ) runpaths = [Path(DUMMY_CONFIG["run_path"].format(iens)) for iens in range(10)] for runpath in runpaths: (runpath / "jobs.json").write_text(json.dumps({}), encoding="utf-8") - job_queue.set_ee_info( - ee_uri=dispatch_url, ens_id=ens_id, ee_cert=cert, ee_token=token - ) job_queue.add_dispatch_information_to_jobs_file() for runpath in runpaths: diff --git a/tests/unit_tests/scheduler/test_scheduler.py 
b/tests/unit_tests/scheduler/test_scheduler.py index e50fa4536cf..33bfc08f252 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -75,8 +75,7 @@ async def init(iens, *args, **kwargs): driver = mock_driver(init=init) - sch = scheduler.Scheduler(driver) - sch.add_realization(realization) + sch = scheduler.Scheduler(driver, [realization]) assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED assert await future == realization.iens @@ -97,8 +96,7 @@ async def kill(): killed = True driver = mock_driver(wait=wait, kill=kill) - sch = scheduler.Scheduler(driver) - sch.add_realization(realization) + sch = scheduler.Scheduler(driver, [realization]) scheduler_task = asyncio.create_task(sch.execute()) @@ -130,11 +128,15 @@ async def test_add_dispatch_information_to_jobs_file(storage, tmp_path: Path): for iens in range(ensemble_size) ] - sch = scheduler.Scheduler() - sch.set_ee_info(test_ee_uri, test_ens_id, test_ee_cert, test_ee_token) + sch = scheduler.Scheduler( + realizations=realizations, + ens_id=test_ens_id, + ee_uri=test_ee_uri, + ee_cert=test_ee_cert, + ee_token=test_ee_token, + ) for realization in realizations: - sch.add_realization(realization) create_jobs_json(realization) sch.add_dispatch_information_to_jobs_file() @@ -169,10 +171,9 @@ async def wait(): return False driver = mock_driver(init=init, wait=wait) - sch = scheduler.Scheduler(driver) + sch = scheduler.Scheduler(driver, [realization]) sch._max_submit = max_submit - sch.add_realization(realization, callback_timeout=lambda _: None) assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED assert retries == max_submit