Separate EP heartbeat processing from results
To ensure endpoint heartbeats are always processed in a timely manner, a
dedicated multiprocessing queue and a dedicated RabbitMQ channel, exchange,
and queue are now used. The `EndpointInterchange` also uses a separate
thread to process pending heartbeats.
rjmello committed Jan 15, 2025
1 parent 1cb7257 commit 6abe216
Showing 17 changed files with 331 additions and 64 deletions.
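
As orientation before the diffs, here is a minimal sketch of the pattern this change introduces: an engine puts each packed status report on a dedicated heartbeat passthrough queue, and a dedicated interchange thread drains that queue and publishes the report. The names `publish_heartbeat` and `quiesce_event` are simplified stand-ins, not the project's actual API; see the diffs below for the real implementation.

    import queue
    import threading

    # Dedicated passthrough queue so heartbeats never wait behind task results.
    heartbeats_passthrough: "queue.Queue[dict[str, bytes]]" = queue.Queue()
    quiesce_event = threading.Event()


    def publish_heartbeat(packed_message: bytes) -> None:
        # Hypothetical stand-in for ResultPublisher.publish() on the dedicated
        # heartbeat channel, exchange, and queue.
        print(f"publishing {len(packed_message)} heartbeat bytes")


    def process_pending_heartbeats() -> None:
        # Mirrors the new interchange thread: drain the passthrough queue until
        # quiesced, forwarding each packed status report to the publisher.
        while not quiesce_event.is_set():
            try:
                msg = heartbeats_passthrough.get(timeout=1)
            except queue.Empty:
                continue
            publish_heartbeat(msg["message"])


    threading.Thread(
        target=process_pending_heartbeats, daemon=True, name="Pending Heartbeat Handler"
    ).start()

    # Engine-side producer, as in GlobusComputeEngineBase.report_status():
    heartbeats_passthrough.put({"message": b"packed EPStatusReport"})

The payoff is isolation: a backlog of task results on the result channel can no longer delay heartbeat delivery.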
@@ -274,9 +274,9 @@ def __init__(
             cq_info = reg_info["command_queue_info"]
             _ = cq_info["connection_url"], cq_info["queue"]
 
-            rq_info = reg_info["result_queue_info"]
-            _ = rq_info["connection_url"], rq_info["queue"]
-            _ = rq_info["queue_publish_kwargs"]
+            hbq_info = reg_info["heartbeat_queue_info"]
+            _ = hbq_info["connection_url"], hbq_info["queue"]
+            _ = hbq_info["queue_publish_kwargs"]
         except Exception as e:
             log_reg_info = _redact_url_creds(str(reg_info))
             log.debug("%s", log_reg_info)
@@ -316,7 +316,7 @@ def __init__(
             stop_event=self._command_stop_event,
             thread_name="CQS",
         )
-        self._heartbeat_publisher = ResultPublisher(queue_info=rq_info)
+        self._heartbeat_publisher = ResultPublisher(queue_info=hbq_info)
 
     @staticmethod
     def get_metadata(config: ManagerEndpointConfig, conf_dir: pathlib.Path) -> dict:
compute_endpoint/globus_compute_endpoint/endpoint/interchange.py (92 changes: 75 additions & 17 deletions)
@@ -41,7 +41,11 @@
 
 class _ResultPassthroughType(t.TypedDict):
     message: bytes
-    task_id: str | None
+    task_id: str
+
+
+class _HeartbeatPassthroughType(t.TypedDict):
+    message: bytes
 
 
 class EndpointInterchange:
@@ -74,8 +78,8 @@ def __init__(
             Globus Compute config object describing how compute should be provisioned
         reg_info : dict[str, dict]
-            Dictionary containing connection information for both the task and
-            result queues. The required data structure is returned from the
+            Dictionary containing connection information for the task, result and
+            heartbeat queues. The required data structure is returned from the
             Endpoint registration API call, encapsulated in the SDK by
             `Client.register_endpoint()`.
@@ -100,6 +104,7 @@ def __init__(
 
         self.task_q_info = reg_info["task_queue_info"]
         self.result_q_info = reg_info["result_queue_info"]
+        self.heartbeat_q_info = reg_info["heartbeat_queue_info"]
 
         self.time_to_quit = False
         self.heartbeat_period = self.config.heartbeat_period
@@ -132,6 +137,9 @@ def __init__(
         log.info(f"Platform info: {self.current_platform}")
 
         self.results_passthrough: queue.Queue[_ResultPassthroughType] = queue.Queue()
+        self.heartbeats_passthrough: queue.Queue[_HeartbeatPassthroughType] = (
+            queue.Queue()
+        )
         # Rename self.executor -> self.engine in second round
         self.executor: GlobusComputeEngineBase = self.config.executors[0]
         self._test_start = False
@@ -140,6 +148,7 @@ def start_engine(self):
         log.info("Starting Engine")
         self.executor.start(
             results_passthrough=self.results_passthrough,
+            heartbeats_passthrough=self.heartbeats_passthrough,
             endpoint_id=self.endpoint_id,
             run_dir=self.logdir,
         )
@@ -294,6 +303,9 @@ def _main_loop(self):
         results_publisher = ResultPublisher(queue_info=self.result_q_info)
         results_publisher.start()
 
+        heartbeat_publisher = ResultPublisher(queue_info=self.heartbeat_q_info)
+        heartbeat_publisher.start()
+
         executor = self.executor
 
         num_tasks_forwarded = 0
@@ -338,8 +350,8 @@ def process_pending_tasks() -> None:
                 d_tag, prop_headers, body = self.pending_task_queue.get(timeout=1)
                 task_q_subscriber.ack(d_tag)
 
-                fid: str = prop_headers.get("function_uuid")
-                tid: str = prop_headers.get("task_uuid")
+                fid: str | None = prop_headers.get("function_uuid")
+                tid: str | None = prop_headers.get("task_uuid")
 
                 if not fid or not tid:
                     raise InvalidMessageError(
@@ -407,15 +419,14 @@ def process_pending_results() -> None:
             # iterating the loop regardless.
             nonlocal num_results_forwarded
 
-            def _create_done_cb(mq_msg: bytes, tid: str | None):
+            def _create_done_cb(mq_msg: bytes, tid: str):
                 def _done_cb(pub_fut: Future):
                     _exc = pub_fut.exception()
                     if _exc:
                         # Publishing didn't work -- quiesce and see if a simple
                         # restart fixes the issue.
-                        if tid:
-                            log.info(f"Storing result for later: {tid}")
-                            self.result_store[tid] = mq_msg
+                        log.info(f"Storing result for later: {tid}")
+                        self.result_store[tid] = mq_msg
 
                         self._quiesce_event.set()
                         log.error("Failed to publish results", exc_info=_exc)
@@ -426,7 +437,7 @@ def _done_cb(pub_fut: Future):
                 try:
                     msg = self.results_passthrough.get(timeout=1)
                     packed_message: bytes = msg["message"]
-                    task_id: str | None = msg.get("task_id")
+                    task_id: str = msg["task_id"]
 
                 except queue.Empty:
                     continue
@@ -438,9 +449,8 @@ def _done_cb(pub_fut: Future):
                     )
                     continue
 
-                if task_id:
-                    num_results_forwarded += 1
-                    log.debug("Forwarding result for task: %s", task_id)
+                num_results_forwarded += 1
+                log.debug("Forwarding result for task: %s", task_id)
 
                 try:
                     f = results_publisher.publish(packed_message)
@@ -455,13 +465,53 @@ def _done_cb(pub_fut: Future):
                         "Something broke while forwarding results; setting quiesce"
                         " event"
                     )
-                    if task_id:
-                        log.info("Storing result for later: %s", task_id)
-                        self.result_store[task_id] = packed_message
+                    log.info("Storing result for later: %s", task_id)
+                    self.result_store[task_id] = packed_message
                     continue  # just be explicit
 
         log.debug("Exit process-pending-results thread.")
 
+        def process_pending_heartbeats() -> None:
+            def _done_cb(pub_fut: Future):
+                _exc = pub_fut.exception()
+                if _exc:
+                    # Publishing didn't work -- quiesce and see if a simple
+                    # restart fixes the issue.
+                    self._quiesce_event.set()
+                    log.error("Failed to publish heartbeat", exc_info=_exc)
+
+            while not self._quiesce_event.is_set():
+                try:
+                    msg = self.heartbeats_passthrough.get(timeout=1)
+                    packed_message: bytes = msg["message"]
+
+                except queue.Empty:
+                    continue
+
+                except Exception as exc:
+                    log.warning(
+                        "Invalid message received. Ignoring."
+                        f" ([{type(exc).__name__}] {exc})"
+                    )
+                    continue
+
+                try:
+                    f = heartbeat_publisher.publish(packed_message)
+                    f.add_done_callback(_done_cb)
+
+                except Exception:
+                    # Publishing didn't work -- quiesce and see if a simple restart
+                    # fixes the issue.
+                    self._quiesce_event.set()
+
+                    log.exception(
+                        "Something broke while forwarding heartbeats; setting quiesce"
+                        " event"
+                    )
+                    continue  # just be explicit
+
+            log.debug("Exit process-pending-heartbeats thread.")
+
         stored_processor_thread = threading.Thread(
             target=process_stored_results, daemon=True, name="Stored Result Handler"
         )
@@ -471,9 +521,15 @@ def _done_cb(pub_fut: Future):
         result_processor_thread = threading.Thread(
             target=process_pending_results, daemon=True, name="Pending Result Handler"
         )
+        heartbeat_processor_thread = threading.Thread(
+            target=process_pending_heartbeats,
+            daemon=True,
+            name="Pending Heartbeat Handler",
+        )
         stored_processor_thread.start()
         task_processor_thread.start()
         result_processor_thread.start()
+        heartbeat_processor_thread.start()
 
         connection_stable_hearbeats = 0
         last_t, last_r = 0, 0
@@ -574,6 +630,7 @@ def _done_cb(pub_fut: Future):
         stored_processor_thread.join(timeout=5)
         task_processor_thread.join(timeout=5)
         result_processor_thread.join(timeout=5)
+        heartbeat_processor_thread.join(timeout=5)
 
         # let higher-level error handling take over if the following excepts
         message = EPStatusReport(
@@ -589,7 +646,7 @@ def _done_cb(pub_fut: Future):
             task_statuses={},
         )
         try:
-            f = results_publisher.publish(pack(message))
+            f = heartbeat_publisher.publish(pack(message))
             f.result(timeout=5)
         except concurrent.futures.TimeoutError:
             log.warning(
@@ -598,5 +655,6 @@
 
         task_q_subscriber.stop()
         results_publisher.stop()
+        heartbeat_publisher.stop()
 
         log.debug("_main_loop exits")
compute_endpoint/globus_compute_endpoint/engines/base.py (3 changes: 2 additions & 1 deletion)
@@ -119,6 +119,7 @@ def __init__(
         self.results_passthrough: queue.Queue[dict[str, bytes | str | None]] = (
             queue.Queue()
         )
+        self.heartbeats_passthrough: queue.Queue[dict[str, bytes]] = queue.Queue()
         self._engine_ready: bool = False
 
     @abstractmethod
@@ -141,7 +142,7 @@ def set_working_dir(self, run_dir: str | None = None):
     def report_status(self) -> None:
         status_report = self.get_status_report()
         packed: bytes = messagepack.pack(status_report)
-        self.results_passthrough.put({"message": packed})
+        self.heartbeats_passthrough.put({"message": packed})
 
     def _handle_task_exception(
         self,
@@ -235,6 +235,7 @@ def start(
         endpoint_id: t.Optional[uuid.UUID] = None,
         run_dir: t.Optional[str] = None,
         results_passthrough: t.Optional[queue.Queue] = None,
+        heartbeats_passthrough: t.Optional[queue.Queue] = None,
         **kwargs,
     ):
         assert endpoint_id, "GCExecutor requires kwarg:endpoint_id at start"
@@ -258,6 +259,8 @@
             # Only update the default queue in GCExecutorBase if
             # a queue is passed in
             self.results_passthrough = results_passthrough
+        if heartbeats_passthrough:
+            self.heartbeats_passthrough = heartbeats_passthrough
         self.executor.start()
         self._status_report_thread.start()
         # Add executor to poller *after* executor has started
@@ -36,6 +36,7 @@ def start(
         endpoint_id: t.Optional[uuid.UUID] = None,
         run_dir: t.Optional[str] = None,
         results_passthrough: t.Optional[queue.Queue] = None,
+        heartbeats_passthrough: t.Optional[queue.Queue] = None,
         **kwargs,
     ) -> None:
         """
@@ -44,6 +45,7 @@ def start(
         endpoint_id: Endpoint UUID
         run_dir: endpoint run directory
         results_passthrough: Queue to which packed results will be posted
+        heartbeats_passthrough: Queue to which packed status reports are posted
         Returns
         -------
         """
@@ -60,6 +62,8 @@
         self.endpoint_id = endpoint_id
         if results_passthrough:
             self.results_passthrough = results_passthrough
+        if heartbeats_passthrough:
+            self.heartbeats_passthrough = heartbeats_passthrough
         assert self.results_passthrough
         self.set_working_dir(run_dir=run_dir)
 
@@ -34,6 +34,7 @@ def start(
         endpoint_id: t.Optional[uuid.UUID] = None,
         run_dir: t.Optional[str] = None,
         results_passthrough: t.Optional[queue.Queue] = None,
+        heartbeats_passthrough: t.Optional[queue.Queue] = None,
         **kwargs,
     ) -> None:
         """
@@ -42,13 +43,16 @@ def start(
         endpoint_id: Endpoint UUID
         run_dir: endpoint run directory
         results_passthrough: Queue to which packed results will be posted
+        heartbeats_passthrough: Queue to which packed status reports are posted
         Returns
         -------
         """
         assert endpoint_id, "ThreadPoolEngine requires kwarg:endpoint_id at start"
         self.endpoint_id = endpoint_id
         if results_passthrough:
             self.results_passthrough = results_passthrough
+        if heartbeats_passthrough:
+            self.heartbeats_passthrough = heartbeats_passthrough
         assert self.results_passthrough
 
         self.set_working_dir(run_dir=run_dir)