Skip to content

Commit

Permalink
Merge pull request #776 from materialsproject/enhancement/rabbitmq_support
Browse files Browse the repository at this point in the history

Overhaul distributed framework and add RabbitMQ support
  • Loading branch information
Jason Munro authored Feb 16, 2023
2 parents fc61170 + 618d87a commit 6f6a2f2
Show file tree
Hide file tree
Showing 6 changed files with 427 additions and 91 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[run]
omit = *test*
omit = *test*, rabbitmq*
71 changes: 47 additions & 24 deletions src/maggma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import click
from monty.serialization import loadfn

from maggma.cli.distributed import find_port, manager, worker
from maggma.cli.distributed import find_port
from maggma.cli.multiprocessing import multi
from maggma.cli.serial import serial
from maggma.cli.source_loader import ScriptFinder, load_builder_from_source
Expand Down Expand Up @@ -44,17 +44,14 @@
help="Store in JSON/YAML form to send reporting data to",
type=click.Path(exists=True),
)
@click.option(
"-u", "--url", "url", default=None, type=str, help="URL for the distributed manager"
)
@click.option("-u", "--url", "url", default=None, type=str, help="URL for the distributed manager")
@click.option(
"-p",
"--port",
"port",
default=None,
type=int,
help="Port for distributed communication."
" mrun will find an open port if None is provided to the manager",
help="Port for distributed communication." " mrun will find an open port if None is provided to the manager",
)
@click.option(
"-N",
Expand All @@ -72,8 +69,15 @@
type=int,
help="Number of distributed workers to process chunks",
)
@click.option("--no_bars", is_flag=True, help="Turns of Progress Bars for headless operations")
@click.option("--rabbitmq", is_flag=True, help="Enables the use of RabbitMQ as the work broker")
@click.option(
"--no_bars", is_flag=True, help="Turns of Progress Bars for headless operations"
"-q",
"--queue_prefix",
"queue_prefix",
default="builder",
type=str,
help="Prefix to use in queue names when RabbitMQ is select as the broker",
)
def run(
builders,
Expand All @@ -85,17 +89,22 @@ def run(
num_chunks,
no_bars,
num_processes,
rabbitmq,
queue_prefix,
):
# Import proper manager and worker
if rabbitmq:
from maggma.cli.rabbitmq import manager, worker
else:
from maggma.cli.distributed import manager, worker

# Set Logging
levels = [logging.WARNING, logging.INFO, logging.DEBUG]
level = levels[min(len(levels) - 1, verbosity)] # capped to number of levels
root = logging.getLogger()
root.setLevel(level)
ch = TqdmLoggingHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
root.addHandler(ch)

Expand All @@ -121,27 +130,41 @@ def run(
port = find_port()
root.critical(f"Using random port for mrun manager: {port}")

manager(
url=url,
port=port,
builders=builder_objects,
num_chunks=num_chunks,
num_workers=num_workers,
)
if rabbitmq:
manager(
url=url,
port=port,
builders=builder_objects,
num_chunks=num_chunks,
num_workers=num_workers,
queue_prefix=queue_prefix,
)
else:
manager(
url=url,
port=port,
builders=builder_objects,
num_chunks=num_chunks,
num_workers=num_workers,
)

else:
# worker
loop = asyncio.get_event_loop()
loop.run_until_complete(
# Worker
if rabbitmq:
worker(
url=url,
port=port,
num_processes=num_processes,
no_bars=no_bars,
queue_prefix=queue_prefix,
)
else:
worker(url=url, port=port, num_processes=num_processes, no_bars=no_bars)
)
else:
if num_processes == 1:
for builder in builder_objects:
serial(builder, no_bars)
else:
loop = asyncio.get_event_loop()
for builder in builder_objects:
loop.run_until_complete(
multi(builder=builder, num_processes=num_processes, no_bars=no_bars)
)
loop.run_until_complete(multi(builder=builder, num_processes=num_processes, no_bars=no_bars))
78 changes: 47 additions & 31 deletions src/maggma/cli/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def find_port():
return sock.getsockname()[1]


def manager(
url: str, port: int, builders: List[Builder], num_chunks: int, num_workers: int
):
def manager(url: str, port: int, builders: List[Builder], num_chunks: int, num_workers: int):
"""
Really simple manager for distributed processing that uses a builder prechunk to modify
the builder and send out modified builders for each worker to run.
Expand Down Expand Up @@ -65,10 +63,7 @@ def manager(

try:
builder.connect()
chunk_dicts = [
{"chunk": d, "distributed": False, "completed": False}
for d in builder.prechunk(num_chunks)
]
chunk_dicts = [{"chunk": d, "distributed": False, "completed": False} for d in builder.prechunk(num_chunks)]
pbar_distributed = tqdm(
total=len(chunk_dicts),
desc="Distributed chunks for {}".format(builder.__class__.__name__),
Expand All @@ -83,14 +78,11 @@ def manager(

except NotImplementedError:
attempt_graceful_shutdown(workers, socket)
raise RuntimeError(
f"Can't distribute process {builder.__class__.__name__} as no prechunk method exists."
)
raise RuntimeError(f"Can't distribute process {builder.__class__.__name__} as no prechunk method exists.")

completed = False

while not completed:

completed = all(chunk["completed"] for chunk in chunk_dicts)

if num_workers <= 0:
Expand Down Expand Up @@ -125,19 +117,15 @@ def manager(

# If everything is distributed, send EXIT to the worker
if all(chunk["distributed"] for chunk in chunk_dicts):
logger.debug(
f"Sending exit signal to worker: {msg.split('_')[1]}"
)
logger.debug(f"Sending exit signal to worker: {msg.split('_')[1]}")
socket.send_multipart([identity, b"", b"EXIT"])
workers.pop(identity)

elif "ERROR" in msg:
# Remove worker and requeue work sent to it
attempt_graceful_shutdown(workers, socket)
raise RuntimeError(
"At least one worker has stopped with error message: {}".format(
msg.split("_")[1]
)
"At least one worker has stopped with error message: {}".format(msg.split("_")[1])
)

elif msg == "PING":
Expand All @@ -146,20 +134,20 @@ def manager(
workers[identity]["last_ping"] = perf_counter()
workers[identity]["heartbeats"] += 1

print(workers)

# Decide if any workers are dead and need to be removed
handle_dead_workers(workers, socket)

for work_index, chunk_dict in enumerate(chunk_dicts):
if not chunk_dict["distributed"]:

temp_builder_dict = dict(**builder_dict)
temp_builder_dict.update(chunk_dict["chunk"]) # type: ignore
temp_builder_dict = jsanitize(temp_builder_dict)

# Send work for available workers
for identity in workers:
if not workers[identity]["working"]:

# Send out a chunk to idle worker
socket.send_multipart(
[
Expand Down Expand Up @@ -213,12 +201,10 @@ def handle_dead_workers(workers, socket):
z_score = 0.6745 * (workers[identity]["heartbeats"] - median) / mad
if z_score <= -3.5:
attempt_graceful_shutdown(workers, socket)
raise RuntimeError(
"At least one worker has timed out. Stopping distributed build."
)
raise RuntimeError("At least one worker has timed out. Stopping distributed build.")


async def worker(url: str, port: int, num_processes: int, no_bars: bool):
def worker(url: str, port: int, num_processes: int, no_bars: bool):
"""
Simple distributed worker that connects to a manager asks for work and deploys
using multiprocessing
Expand All @@ -227,38 +213,68 @@ async def worker(url: str, port: int, num_processes: int, no_bars: bool):
logger = getLogger(f"Worker {identity}")

logger.info(f"Connnecting to Manager at {url}:{port}")
context = azmq.Context()
socket = context.socket(zmq.REQ)
context = zmq.Context()
socket: zmq.Socket = context.socket(zmq.REQ)

socket.setsockopt_string(zmq.IDENTITY, identity)
socket.connect(f"{url}:{port}")

poller = zmq.Poller()
poller.register(socket, zmq.POLLIN)

# Initial message package
hostname = pysocket.gethostname()

try:
running = True
while running:
await socket.send("READY_{}".format(hostname).encode("utf-8"))
try:
bmessage: bytes = await asyncio.wait_for(socket.recv(), timeout=MANAGER_TIMEOUT) # type: ignore
except asyncio.TimeoutError:
socket.send("READY_{}".format(hostname).encode("utf-8"))

# Poll for MANAGER_TIMEOUT seconds, if nothing is given then assume manager is dead and timeout
connections = dict(poller.poll(MANAGER_TIMEOUT * 1000))
if not connections:
socket.close()
raise RuntimeError("Stopping work as manager timed out.")

bmessage: bytes = socket.recv()

message = bmessage.decode("utf-8")
if "@class" in message and "@module" in message:
# We have a valid builder
work = json.loads(message)
builder = MontyDecoder().process_decoded(work)
await multi(builder, num_processes, socket=socket, no_bars=no_bars)

asyncio.run(
multi(
builder,
num_processes,
no_bars=no_bars,
heartbeat_func=ping_manager,
heartbeat_func_kwargs={"socket": socket, "poller": poller},
)
)
elif message == "EXIT":
# End the worker
running = False

except Exception as e:
logger.error(f"A worker failed with error: {e}")
await socket.send("ERROR_{}".format(e).encode("utf-8"))
socket.send("ERROR_{}".format(e).encode("utf-8"))
socket.close()

socket.close()


def ping_manager(socket, poller):
    """Send a heartbeat PING to the manager and verify it answers with PONG.

    Closes the socket and raises RuntimeError if the manager does not respond
    within MANAGER_TIMEOUT seconds, or responds with anything other than PONG.
    """
    socket.send_string("PING")

    # Wait up to MANAGER_TIMEOUT seconds for any activity on the socket;
    # an empty poll result means the manager never answered.
    if not dict(poller.poll(MANAGER_TIMEOUT * 1000)):
        socket.close()
        raise RuntimeError("Stopping work as manager timed out.")

    reply: bytes = socket.recv()
    if reply.decode("utf-8") != "PONG":
        socket.close()
        raise RuntimeError("Stopping work as manager did not respond to heartbeat from worker.")
36 changes: 13 additions & 23 deletions src/maggma/cli/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@
Queue,
gather,
get_event_loop,
wait_for,
TimeoutError,
)
from concurrent.futures import ProcessPoolExecutor
from logging import getLogger
from types import GeneratorType
from typing import Any, Awaitable, Callable, Dict, Optional

from aioitertools import enumerate
from tqdm import tqdm

from maggma.utils import primed

MANAGER_TIMEOUT = 5400 # max timeout in seconds for manager
MANAGER_TIMEOUT = 300 # max timeout in seconds for manager

logger = getLogger("MultiProcessor")

Expand Down Expand Up @@ -153,8 +152,13 @@ def safe_dispatch(val):
return None


async def multi(builder, num_processes, no_bars=False, socket=None):

async def multi(
builder,
num_processes,
no_bars=False,
heartbeat_func: Optional[Callable[..., Any]] = None,
heartbeat_func_kwargs: Dict[Any, Any] = {},
):
builder.connect()
cursor = builder.get_items()
executor = ProcessPoolExecutor(num_processes)
Expand Down Expand Up @@ -205,15 +209,14 @@ async def multi(builder, num_processes, no_bars=False, socket=None):
disable=no_bars,
)

if socket:
await ping_manager(socket)
if heartbeat_func:
heartbeat_func(**heartbeat_func_kwargs)

back_pressure_relief = back_pressured_get.release(processed_items)

update_items = tqdm(total=total, desc="Update Targets", disable=no_bars)

async for chunk in grouper(back_pressure_relief, n=builder.chunk_size):

logger.info(
"Processed batch of {} items".format(builder.chunk_size),
extra={
Expand All @@ -230,8 +233,8 @@ async def multi(builder, num_processes, no_bars=False, socket=None):
builder.update_targets(processed_items)
update_items.update(len(processed_items))

if socket:
await ping_manager(socket)
if heartbeat_func:
heartbeat_func(**heartbeat_func_kwargs)

logger.info(
f"Ended multiprocessing: {builder.__class__.__name__}",
Expand All @@ -247,16 +250,3 @@ async def multi(builder, num_processes, no_bars=False, socket=None):

update_items.close()
builder.finalize()


async def ping_manager(socket):
    """Send a heartbeat PING to the manager over *socket* and await a PONG reply.

    Closes the socket and raises RuntimeError if the manager does not reply
    within MANAGER_TIMEOUT seconds, or replies with anything other than PONG.
    """
    await socket.send_string("PING")
    try:
        # Bound the wait so a dead manager cannot hang the worker forever.
        message = await wait_for(socket.recv(), timeout=MANAGER_TIMEOUT)
        if message.decode("utf-8") != "PONG":
            socket.close()
            raise RuntimeError("Stopping work as manager did not respond to heartbeat from worker.")

    except TimeoutError:
        socket.close()
        raise RuntimeError("Stopping work as manager is not responding.")
Loading

0 comments on commit 6f6a2f2

Please sign in to comment.