Skip to content

Commit

Permalink
Refactor JobQueue and Scheduler
Browse files Browse the repository at this point in the history
Initialisation of `JobQueue` and `Scheduler` is moved to where they are
used. This means that the current event loop is the same during
initialisation and use.

Changes:
- `add_realization`: We can simply pass realisations as a list to the
  given executor.
- `set_ee_info`: We can simply pass this information to the queues.
- `CONCURRENT_INITIALIZATION`: No longer passed to the queue as it's a
  constant.
- `timeout_callback`: Now passed to the queue, rather than with each realisation.
- `queue_evaluators`: Concept removed. Only used for
  "min_required_realizations", which we can call directly.
  • Loading branch information
pinkwah committed Dec 21, 2023
1 parent 0039478 commit 4a3723b
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 174 deletions.
7 changes: 7 additions & 0 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def create_local_copy(self) -> QueueConfig:
self.queue_options,
)

@property
def max_running(self) -> int:
for key, val in self.queue_options.get(self.queue_system, []):
if key == "MAX_RUNNING":
return int(val)
return 0


def _check_for_overwritten_queue_system_options(
selected_queue_system: QueueSystem,
Expand Down
101 changes: 56 additions & 45 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,18 @@
import logging
import threading
import uuid
from functools import partial, partialmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from functools import partialmethod
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Protocol,
Tuple,
)

from cloudevents.http.event import CloudEvent

Expand All @@ -31,6 +41,11 @@
event_logger = logging.getLogger("ert.event_log")


class _KillAllJobs(Protocol):
def kill_all_jobs(self) -> None:
...


class LegacyEnsemble(Ensemble):
def __init__(
self,
Expand All @@ -42,16 +57,9 @@ def __init__(
id_: str,
) -> None:
super().__init__(reals, metadata, id_)
if not queue_config:
raise ValueError(f"{self} needs queue_config")

if FeatureToggling.is_enabled("scheduler"):
if queue_config.queue_system != QueueSystem.LOCAL:
raise NotImplementedError()
driver = create_driver(queue_config)
self._job_queue = Scheduler(driver)
else:
self._job_queue = JobQueue(queue_config)

self._queue_config = queue_config
self._job_queue: Optional[_KillAllJobs] = None
self.stop_long_running = stop_long_running
self.min_required_realizations = min_required_realizations
self._config: Optional[EvaluatorServerConfig] = None
Expand Down Expand Up @@ -180,43 +188,45 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
raise ValueError("no config") # mypy

try:
if FeatureToggling.is_enabled("scheduler"):
if self._queue_config.queue_system != QueueSystem.LOCAL:
raise NotImplementedError()
driver = create_driver(self._queue_config)
queue = Scheduler(
driver,
self.active_reals,
max_submit=self._queue_config.max_submit,
max_running=self._queue_config.max_running,
ens_id=self.id_,
ee_uri=self._config.dispatch_uri,
ee_cert=self._config.cert,
ee_token=self._config.token,
)
else:
queue = JobQueue(
self._queue_config,
self.active_reals,
ens_id=self.id_,
ee_uri=self._config.dispatch_uri,
ee_cert=self._config.cert,
ee_token=self._config.token,
on_timeout=on_timeout,
)
self._job_queue = queue

await cloudevent_unary_send(
event_creator(identifiers.EVTYPE_ENSEMBLE_STARTED, None)
)

for real in self.active_reals:
self._job_queue.add_realization(real, callback_timeout=on_timeout)

# TODO: this is sort of a callback being preemptively called.
# It should be lifted out of the queue/evaluate, into the evaluator. If
# something is long running, the evaluator will know and should send
# commands to the task in order to have it killed/retried.
# See https://github.com/equinor/ert/issues/1229
queue_evaluators = None
if self.stop_long_running and self.min_required_realizations > 0:
queue_evaluators = [
partial(
self._job_queue.stop_long_running_jobs,
self.min_required_realizations,
)
]

self._job_queue.set_ee_info(
ee_uri=self._config.dispatch_uri,
ens_id=self.id_,
ee_cert=self._config.cert,
ee_token=self._config.token,
)

# Tell queue to pass info to the jobs-file
# NOTE: This touches files on disk...
self._job_queue.add_dispatch_information_to_jobs_file()
if isinstance(queue, Scheduler):
result = await queue.execute()
elif isinstance(queue, JobQueue):
min_required_realizations = (
self.min_required_realizations if self.stop_long_running else 0
)
queue.add_dispatch_information_to_jobs_file()

sema = threading.BoundedSemaphore(value=CONCURRENT_INTERNALIZATION)
result: str = await self._job_queue.execute(
sema,
queue_evaluators,
)
result = await queue.execute(min_required_realizations)

except Exception:
logger.exception(
Expand All @@ -236,5 +246,6 @@ def cancellable(self) -> bool:
return True

def cancel(self) -> None:
self._job_queue.kill_all_jobs()
if self._job_queue is not None:
self._job_queue.kill_all_jobs()
logger.debug("evaluator cancelled")
72 changes: 32 additions & 40 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@
import json
import logging
import ssl
import threading
import time
from collections import deque
from threading import BoundedSemaphore, Semaphore
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)
Expand Down Expand Up @@ -102,7 +100,18 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return self.__repr__()

def __init__(self, queue_config: QueueConfig):
def __init__(
self,
queue_config: QueueConfig,
realizations: Optional[Sequence[Realization]] = None,
*,
ens_id: Optional[str] = None,
ee_uri: Optional[str] = None,
ee_cert: Optional[str] = None,
ee_token: Optional[str] = None,
on_timeout: Optional[Callable[[int], None]] = None,
verify_token: bool = True,
) -> None:
self.job_list: List[JobQueueNode] = []
self._stopped = False
self.driver: Driver = Driver.create_driver(queue_config)
Expand All @@ -112,17 +121,28 @@ def __init__(self, queue_config: QueueConfig):
self._differ = QueueDiffer()
self._max_submit = queue_config.max_submit
self._pool_sema = BoundedSemaphore(value=CONCURRENT_INTERNALIZATION)
self._on_timeout = on_timeout

self._ens_id = ens_id
self._ee_uri = ee_uri
self._ee_cert = ee_cert
self._ee_token = ee_token

self._ens_id: Optional[str] = None
self._ee_uri: Optional[str] = None
self._ee_cert: Optional[Union[str, bytes]] = None
self._ee_token: Optional[str] = None
self._ee_ssl_context: Optional[Union[ssl.SSLContext, bool]] = None
if ee_cert is not None:
self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if verify_token:
self._ee_ssl_context.load_verify_locations(cadata=ee_cert)
else:
self._ee_ssl_context = True if ee_uri and ee_uri.startswith("wss") else None

self._changes_to_publish: Optional[
asyncio.Queue[Union[Dict[int, str], object]]
] = None

for real in realizations or []:
self.add_realization(real)

def get_max_running(self) -> int:
return self.driver.get_max_running()

Expand Down Expand Up @@ -215,27 +235,6 @@ def launch_jobs(self, pool_sema: Semaphore) -> None:
max_submit=self.max_submit,
)

def set_ee_info(
self,
ee_uri: str,
ens_id: str,
ee_cert: Optional[Union[str, bytes]] = None,
ee_token: Optional[str] = None,
verify_context: bool = True,
) -> None:
self._ens_id = ens_id
self._ee_token = ee_token

self._ee_uri = ee_uri
if ee_cert is not None:
self._ee_cert = ee_cert
self._ee_token = ee_token
self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if verify_context:
self._ee_ssl_context.load_verify_locations(cadata=ee_cert)
else:
self._ee_ssl_context = True if ee_uri.startswith("wss") else None

@staticmethod
def _translate_change_to_cloudevent(
ens_id: str, real_id: int, status: str
Expand Down Expand Up @@ -304,14 +303,8 @@ async def _jobqueue_publisher(self) -> None:

async def execute(
self,
pool_sema: Optional[threading.BoundedSemaphore] = None,
evaluators: Optional[Iterable[Callable[..., Any]]] = None,
min_required_realizations: int = 0,
) -> str:
if pool_sema is not None:
self._pool_sema = pool_sema
if evaluators is None:
evaluators = []

self._changes_to_publish = asyncio.Queue()
asyncio.create_task(self._jobqueue_publisher())

Expand All @@ -322,8 +315,8 @@ async def execute(

await asyncio.sleep(1)

for func in evaluators:
func()
if min_required_realizations > 0:
self.stop_long_running_jobs(min_required_realizations)

changes, new_state = self.changes_without_transition()
if len(changes) > 0:
Expand Down Expand Up @@ -379,14 +372,13 @@ def add_job_from_run_arg(
def add_realization(
self,
real: Realization,
callback_timeout: Optional[Callable[[int], None]] = None,
) -> None:
job = JobQueueNode(
job_script=real.job_script,
num_cpu=real.num_cpu,
run_arg=real.run_arg,
max_runtime=real.max_runtime,
callback_timeout=callback_timeout,
callback_timeout=self._on_timeout,
)
if job is None:
raise ValueError("JobQueueNode constructor created None job")
Expand Down
3 changes: 0 additions & 3 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,4 @@ async def _send(self, state: State) -> None:
"queue_event_type": status,
},
)
if self._scheduler._events is None:
await self._scheduler.ainit()
assert self._scheduler._events is not None
await self._scheduler._events.put(to_json(event))
Loading

0 comments on commit 4a3723b

Please sign in to comment.