Skip to content

Commit

Permalink
Merge pull request #543 from materialsproject/zero_mq
Browse files Browse the repository at this point in the history
Replace `pynng` functionality with `pyzmq`
  • Loading branch information
Jason Munro authored Jan 26, 2022
2 parents d7d93f0 + 250d25a commit 309df0e
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 113 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pydantic==1.9.0
fastapi==0.73.0
numpy==1.19.5;python_version<"3.7"
numpy==1.21.0;python_version>"3.6"
pynng==0.5.0
pyzmq==22.3.0
dnspython==2.1.0
uvicorn==0.13.4
sshtunnel==0.4.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"numpy>=1.17.3",
"pydantic>=0.32.2",
"fastapi>=0.42.0",
"pynng>=0.5.0",
"pyzmq==22.3.0",
"dnspython>=1.16.0",
"sshtunnel>=0.1.5",
"msgpack>=0.5.6",
Expand Down
47 changes: 38 additions & 9 deletions src/maggma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
)
@click.option(
"-n",
"--num-workers",
"num_workers",
help="Number of worker processes. Defaults to single processing",
"--num-processes",
"num_processes",
help="Number of processes to spawn for each worker. Defaults to single processing",
default=1,
type=click.IntRange(1),
)
Expand All @@ -56,12 +56,35 @@
help="Port for distributed communication."
" mrun will find an open port if None is provided to the manager",
)
@click.option("-N", "--num-chunks", "num_chunks", default=0, type=int)
@click.option(
"-N",
"--num-chunks",
"num_chunks",
default=0,
type=int,
help="Number of chunks to distribute to workers",
)
@click.option(
"-w",
"--num-workers",
"num_workers",
default=0,
type=int,
help="Number of distributed workers to process chunks",
)
@click.option(
"--no_bars", is_flag=True, help="Turns of Progress Bars for headless operations"
)
def run(
builders, verbosity, reporting_store, num_workers, url, port, num_chunks, no_bars
builders,
verbosity,
reporting_store,
num_workers,
url,
port,
num_chunks,
no_bars,
num_processes,
):

# Set Logging
Expand Down Expand Up @@ -100,19 +123,25 @@ def run(
root.critical(f"Using random port for mrun manager: {port}")
loop.run_until_complete(
manager(
url=url, port=port, builders=builder_objects, num_chunks=num_chunks
url=url,
port=port,
builders=builder_objects,
num_chunks=num_chunks,
num_workers=num_workers,
)
)
else:
# worker
loop.run_until_complete(worker(url=url, port=port, num_workers=num_workers))
loop.run_until_complete(
worker(url=url, port=port, num_processes=num_processes)
)
else:
if num_workers == 1:
if num_processes == 1:
for builder in builder_objects:
serial(builder, no_bars)
else:
loop = asyncio.get_event_loop()
for builder in builder_objects:
loop.run_until_complete(
multi(builder=builder, num_workers=num_workers, no_bars=no_bars)
multi(builder=builder, num_processes=num_processes, no_bars=no_bars)
)
116 changes: 75 additions & 41 deletions src/maggma/cli/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,94 @@
# coding: utf-8

import json
from asyncio import wait
from logging import getLogger
from socket import socket
import socket as pysocket
from typing import List

from monty.json import jsanitize
from monty.serialization import MontyDecoder
from pynng import Pair1

from maggma.cli.multiprocessing import multi
from maggma.core import Builder
from maggma.utils import tqdm

from zmq import REP, REQ
import zmq.asyncio as zmq


def find_port():
    """
    Ask the operating system for a currently unused port.

    A throwaway socket is bound to port 0 so that the OS picks a free port;
    the assigned port number is read back and the socket is closed again.

    Returns:
        int: the port number the OS assigned to the throwaway socket.
    """
    # The context manager closes the probe socket deterministically instead of
    # leaving it to garbage collection (which emits a ResourceWarning).
    # NOTE: the port is only known to be free at this moment; another process
    # can still claim it before the caller binds to it.
    with pysocket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


async def manager(
    url: str, port: int, builders: List[Builder], num_chunks: int, num_workers: int
):
    """
    Really simple manager for distributed processing that uses a builder prechunk to modify
    the builder and send out modified builders for each worker to run

    Args:
        url: URL the manager's REP socket binds to (combined with port as "{url}:{port}")
        port: port the manager's REP socket binds to
        builders: builders to split into chunks and hand out to workers
        num_chunks: number of chunks requested from each builder's prechunk
        num_workers: number of workers expected to request work

    Raises:
        ValueError: if num_chunks or num_workers is zero
        RuntimeError: if every worker has reported an error while chunks remain
    """
    logger = getLogger("Manager")

    if not (num_chunks and num_workers):
        raise ValueError("Both num_chunks and num_workers must be non-zero")

    logger.info(f"Binding to Manager URL {url}:{port}")
    context = zmq.Context()
    socket = context.socket(REP)
    socket.bind(f"{url}:{port}")

    try:
        for builder in builders:
            logger.info(f"Working on {builder.__class__.__name__}")
            builder_dict = builder.as_dict()

            try:
                builder.connect()
                chunk_dicts = list(builder.prechunk(num_chunks))
            except NotImplementedError:
                logger.error(
                    f"Can't distributed process {builder.__class__.__name__}. Skipping for now"
                )
                continue

            logger.info(f"Distributing {len(chunk_dicts)} chunks to workers")

            for chunk_dict in tqdm(chunk_dicts, desc="Chunks"):
                # The payload does not depend on which worker picks it up, so
                # build it once per chunk rather than once per attempt.
                temp_builder_dict = dict(**builder_dict)
                temp_builder_dict.update(chunk_dict)
                payload = json.dumps(jsanitize(temp_builder_dict)).encode("utf-8")

                distributed = False
                while not distributed:
                    if num_workers <= 0:
                        raise RuntimeError("No workers left to distribute chunks to")

                    # Wait for a worker to announce itself and say it is ready to do work
                    logger.debug("Waiting for a worker")
                    message = (await socket.recv()).decode("utf-8")

                    if message == "ERROR":
                        num_workers -= 1
                        # A REP socket must answer every request before it may
                        # receive again; without this reply the next recv()
                        # fails. Tell the failed worker to shut down.
                        await socket.send_json("EXIT")
                    else:
                        logger.debug(f"Got connection from worker: {message}")
                        # Send out the next chunk
                        await socket.send(payload)
                        distributed = True

        # Only shut workers down once every builder has been distributed;
        # doing it per builder would leave later builders without workers.
        logger.info("Sending exit messages to workers")
        for _ in range(num_workers):
            await socket.recv()
            await socket.send_json("EXIT")
    finally:
        # Release the socket on every exit path, including errors
        socket.close()


async def worker(url: str, port: int, num_processes: int):
"""
Simple distributed worker that connects to a manager asks for work and deploys
using multiprocessing
Expand All @@ -77,18 +98,31 @@ async def worker(url: str, port: int, num_workers: int):
logger = getLogger("Worker")

logger.info(f"Connnecting to Manager at {url}:{port}")
with Pair1(dial=f"{url}:{port}", polyamorous=True) as manager:
logger.info(f"Connected to Manager at {url}:{port}")
context = zmq.Context()
socket = context.socket(REQ)
socket.connect(f"{url}:{port}")

# Initial message package
hostname = pysocket.gethostname()

try:
running = True
while running:
await manager.asend(b"Ready")
message = await manager.arecv()
await socket.send(hostname.encode("utf-8"))
message = await socket.recv()
work = json.loads(message.decode("utf-8"))
if "@class" in work and "@module" in work:
# We have a valid builder
builder = MontyDecoder().process_decoded(work)
await multi(builder, num_workers)
else:
await multi(builder, num_processes)
elif work == "EXIT":
# End the worker
# This should look for a specific message ?
running = False

except Exception as e:
logger.error(f"A worker failed with error: {e}")
await socket.send("ERROR".encode("utf-8"))

socket.close()

socket.close()
4 changes: 2 additions & 2 deletions src/maggma/cli/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ def safe_dispatch(val):
return None


async def multi(builder, num_workers, no_bars=False):
async def multi(builder, num_processes, no_bars=False):

builder.connect()
cursor = builder.get_items()
executor = ProcessPoolExecutor(num_workers)
executor = ProcessPoolExecutor(num_processes)

# Gets the total number of items to process by priming
# the cursor
Expand Down
Loading

0 comments on commit 309df0e

Please sign in to comment.