From be77103ccd9ea6202a6f62e5604e27f0c14a934c Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 24 Jan 2022 16:16:09 -0800 Subject: [PATCH 01/15] Use zeromq for distributed communication --- src/maggma/cli/distributed.py | 116 ++++++++++++++++++++-------------- 1 file changed, 70 insertions(+), 46 deletions(-) diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index e8b6b320e..39565c3a6 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -2,27 +2,30 @@ # coding utf-8 import json -from asyncio import wait from logging import getLogger -from socket import socket +import socket as pysocket from typing import List from monty.json import jsanitize from monty.serialization import MontyDecoder -from pynng import Pair1 from maggma.cli.multiprocessing import multi from maggma.core import Builder from maggma.utils import tqdm +from zmq import REP, REQ +import zmq.asyncio as zmq + def find_port(): - sock = socket() + sock = pysocket.socket() sock.bind(("", 0)) return sock.getsockname()[1] -async def manager(url: str, port: int, builders: List[Builder], num_chunks: int): +async def manager( + url: str, port: int, builders: List[Builder], num_chunks: int, num_workers: int +): """ Really simple manager for distributed processing that uses a builder prechunk to modify the builder and send out modified builders for each worker to run @@ -30,42 +33,48 @@ async def manager(url: str, port: int, builders: List[Builder], num_chunks: int) logger = getLogger("Manager") logger.info(f"Binding to Manager URL {url}:{port}") - with Pair1(listen=f"{url}:{port}", polyamorous=True) as workers: - - for builder in builders: - logger.info(f"Working on {builder.__class__.__name__}") - builder_dict = builder.as_dict() - - try: - - builder.connect() - chunks_dicts = list(builder.prechunk(num_chunks)) - - logger.info(f"Distributing {len(chunks_dicts)} chunks to workers") - for chunk_dict in tqdm(chunks_dicts, desc="Chunks"): - temp_builder_dict = dict(**builder_dict) - temp_builder_dict.update(chunk_dict) - temp_builder_dict = jsanitize(temp_builder_dict) - - # Wait for client connection that announces client and says it is ready to do work - logger.debug("Waiting for a worker") - worker = await workers.arecv_msg() - logger.debug( - f"Got connection from worker: {worker.pipe.remote_address}" - ) - # Send out the next chunk - await worker.pipe.asend( - json.dumps(temp_builder_dict).encode("utf-8") - ) - except NotImplementedError: - logger.error( - f"Can't distributed process {builder.__class__.__name__}. Skipping for now" - ) - - # Clean up and tell workers to shut down - await wait( - [pipe.asend(json.dumps({}).encode("utf-8")) for pipe in workers.pipes] - ) + context = zmq.Context() + socket = context.socket(REP) + socket.bind(f"{url}:{port}") + + for builder in builders: + logger.info(f"Working on {builder.__class__.__name__}") + builder_dict = builder.as_dict() + + try: + + builder.connect() + chunks_dicts = list(builder.prechunk(num_chunks)) + + logger.info(f"Distributing {len(chunks_dicts)} chunks to workers") + for chunk_dict in tqdm(chunks_dicts, desc="Chunks"): + temp_builder_dict = dict(**builder_dict) + temp_builder_dict.update(chunk_dict) + temp_builder_dict = jsanitize(temp_builder_dict) + + # Wait for client connection that announces client and says it is ready to do work + logger.debug("Waiting for a worker") + + worker = await socket.recv() + + if worker.decode("utf-8") == "ERROR": + num_workers -= 1 + + logger.debug(f"Got connection from worker: {worker.decode('utf-8')}") + # Send out the next chunk + await socket.send(json.dumps(temp_builder_dict).encode("utf-8")) + + logger.info("Sending exit messages to workers") + for _ in range(num_workers): + await socket.recv() + await socket.send_json("EXIT") + + except NotImplementedError: + logger.error( + f"Can't distributed process {builder.__class__.__name__}. Skipping for now" + ) + + socket.close() async def worker(url: str, port: int, num_workers: int): @@ -77,18 +86,33 @@ async def worker(url: str, port: int, num_workers: int): logger = getLogger("Worker") logger.info(f"Connnecting to Manager at {url}:{port}") - with Pair1(dial=f"{url}:{port}", polyamorous=True) as manager: - logger.info(f"Connected to Manager at {url}:{port}") + context = zmq.Context() + socket = context.socket(REQ) + socket.connect(f"{url}:{port}") + + # Initial message package + hostname = pysocket.gethostname() + + try: running = True while running: - await manager.asend(b"Ready") - message = await manager.arecv() + await socket.send(hostname.encode("utf-8")) + message = await socket.recv() work = json.loads(message.decode("utf-8")) if "@class" in work and "@module" in work: # We have a valid builder builder = MontyDecoder().process_decoded(work) await multi(builder, num_workers) - else: + elif work == "EXIT": # End the worker # This should look for a specific message ? running = False + + except Exception as e: + logger.error(f"A worker failed with error: {e}") + await socket.send("ERROR".encode("utf-8")) + message = await socket.recv() + + socket.close() + + socket.close() From 39a0b5f6e3a520f796ace6106838e3642e8cf5b2 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 24 Jan 2022 16:26:49 -0800 Subject: [PATCH 02/15] Update tests --- tests/cli/test_distributed.py | 101 +++++++++++++++------------------- 1 file changed, 45 insertions(+), 56 deletions(-) diff --git a/tests/cli/test_distributed.py b/tests/cli/test_distributed.py index defcecabd..6cbefedbc 100644 --- a/tests/cli/test_distributed.py +++ b/tests/cli/test_distributed.py @@ -2,12 +2,18 @@ import json import pytest -from pynng import Pair1 -from pynng.exceptions import Timeout from maggma.cli.distributed import find_port, manager, worker from maggma.core import Builder +from zmq import REP, REQ +import zmq.asyncio as zmq +import socket as pysocket + +# TODO: Timeout errors? + +HOSTNAME = pysocket.gethostname() + class DummyBuilderWithNoPrechunk(Builder): def __init__(self, dummy_prechunk: bool, val: int = -1, **kwargs): @@ -44,7 +50,11 @@ async def manager_server(event_loop, log_to_stdout): task = asyncio.create_task( manager( - SERVER_URL, SERVER_PORT, [DummyBuilder(dummy_prechunk=False)], num_chunks=10 + SERVER_URL, + SERVER_PORT, + [DummyBuilder(dummy_prechunk=False)], + num_chunks=10, + num_workers=10, ) ) yield task @@ -52,70 +62,48 @@ async def manager_server(event_loop, log_to_stdout): @pytest.mark.asyncio -async def test_manager_wait_for_ready(manager_server): - with Pair1( - dial=f"{SERVER_URL}:{SERVER_PORT}", polyamorous=True, recv_timeout=100 - ) as manager: - with pytest.raises(Timeout): - manager.recv() +async def test_manager_give_out_chunks(manager_server, log_to_stdout): + context = zmq.Context() + socket = context.socket(REQ) + socket.connect(f"{SERVER_URL}:{SERVER_PORT}") + + for i in range(0, 10): + log_to_stdout.debug(f"Going to ask Manager for work: {i}") + await socket.send(b"Ready") + message = await socket.recv() -@pytest.mark.asyncio -async def test_manager_give_out_chunks(manager_server, log_to_stdout): - with Pair1( - dial=f"{SERVER_URL}:{SERVER_PORT}", polyamorous=True, recv_timeout=500 - ) as manager_socket: - - for i in range(0, 10): - log_to_stdout.debug(f"Going to ask Manager for work: {i}") - await manager_socket.asend(b"Ready") - message = await manager_socket.arecv() - print(message) - work = json.loads(message.decode("utf-8")) - - assert work["@class"] == "DummyBuilder" - assert work["@module"] == "tests.cli.test_distributed" - assert work["val"] == i - - await manager_socket.asend(b"Ready") - message = await manager_socket.arecv() work = json.loads(message.decode("utf-8")) - assert work == {} + + assert work["@class"] == "DummyBuilder" + assert work["@module"] == "tests.cli.test_distributed" + assert work["val"] == i @pytest.mark.asyncio async def test_worker(): - with Pair1( - listen=f"{SERVER_URL}:{SERVER_PORT}", polyamorous=True, recv_timeout=500 - ) as worker_socket: - - worker_task = asyncio.create_task( - worker(SERVER_URL, SERVER_PORT, num_workers=1) - ) - - message = await worker_socket.arecv() - assert message == b"Ready" + context = zmq.Context() + socket = context.socket(REP) + socket.bind(f"{SERVER_URL}:{SERVER_PORT}") - dummy_work = { - "@module": "tests.cli.test_distributed", - "@class": "DummyBuilder", - "@version": None, - "dummy_prechunk": False, - "val": 0, - } - for i in range(2): - await worker_socket.asend(json.dumps(dummy_work).encode("utf-8")) - await asyncio.sleep(1) - message = await worker_socket.arecv() - assert message == b"Ready" + worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_workers=1)) - await worker_socket.asend(json.dumps({}).encode("utf-8")) - with pytest.raises(Timeout): - await worker_socket.arecv() + message = await socket.recv() - assert len(worker_socket.pipes) == 0 + dummy_work = { + "@module": "tests.cli.test_distributed", + "@class": "DummyBuilder", + "@version": None, + "dummy_prechunk": False, + "val": 0, + } + for i in range(2): + await socket.send(json.dumps(dummy_work).encode("utf-8")) + await asyncio.sleep(1) + message = await socket.recv() + assert message == HOSTNAME.encode("utf-8") - worker_task.cancel() + worker_task.cancel() @pytest.mark.asyncio @@ -127,6 +115,7 @@ async def test_no_prechunk(caplog): SERVER_PORT, [DummyBuilderWithNoPrechunk(dummy_prechunk=False)], num_chunks=10, + num_workers=10, ) ) await asyncio.sleep(1) From b800d24ae0d495c6e4be914378ea5cfd1dca8e72 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 24 Jan 2022 17:12:24 -0800 Subject: [PATCH 03/15] Update reqs --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8edeae78e..1fa042ecd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ pydantic==1.9.0 fastapi==0.65.2 numpy==1.19.5;python_version<"3.7" numpy==1.21.0;python_version>"3.6" -pynng==0.5.0 +pyzmq==22.3.0 dnspython==2.1.0 uvicorn==0.13.4 sshtunnel==0.4.0 diff --git a/setup.py b/setup.py index 02ca54b5a..0845798f8 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ "numpy>=1.17.3", "pydantic>=0.32.2", "fastapi>=0.42.0", - "pynng>=0.5.0", + "pyzmq==22.3.0", "dnspython>=1.16.0", "sshtunnel>=0.1.5", "msgpack>=0.5.6", From b234f0727490a81b5d45e54faa1abc6520232868 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 24 Jan 2022 20:25:19 -0800 Subject: [PATCH 04/15] Client bug fix --- src/maggma/cli/__init__.py | 3 ++- src/maggma/cli/distributed.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/maggma/cli/__init__.py b/src/maggma/cli/__init__.py index 3bc633864..c51b82f12 100644 --- a/src/maggma/cli/__init__.py +++ b/src/maggma/cli/__init__.py @@ -57,6 +57,7 @@ " mrun will find an open port if None is provided to the manager", ) @click.option("-N", "--num-chunks", "num_chunks", default=0, type=int) +@click.option("-w", "--num-workers", "num_workers", default=0, type=int) @click.option( "--no_bars", is_flag=True, help="Turns of Progress Bars for headless operations" ) @@ -100,7 +101,7 @@ def run( root.critical(f"Using random port for mrun manager: {port}") loop.run_until_complete( manager( - url=url, port=port, builders=builder_objects, num_chunks=num_chunks + url=url, port=port, builders=builder_objects, num_chunks=num_chunks, num_workers=num_workers ) ) else: diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index 39565c3a6..8bed9c747 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -32,6 +32,10 @@ async def manager( """ logger = getLogger("Manager") + if not num_chunks and num_workers: + raise ValueError("Both num_chunks and num_workers must be non-zero") + + logger.info(f"Binding to Manager URL {url}:{port}") context = zmq.Context() socket = context.socket(REP) From dd58756a2cb55d1f7e191519463ec99890a82b00 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Mon, 24 Jan 2022 20:49:31 -0800 Subject: [PATCH 05/15] Fix unreferenced variable --- src/maggma/cli/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index 8bed9c747..4e003eaeb 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -115,7 +115,7 @@ async def worker(url: str, port: int, num_workers: int): except Exception as e: logger.error(f"A worker failed with error: {e}") await socket.send("ERROR".encode("utf-8")) - message = await socket.recv() + await socket.recv() socket.close() From 2bd945f803bec3769668729a89c8532c330affd7 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 11:53:30 -0800 Subject: [PATCH 06/15] Add worker error test --- tests/cli/test_distributed.py | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/cli/test_distributed.py b/tests/cli/test_distributed.py index 6cbefedbc..f9708400a 100644 --- a/tests/cli/test_distributed.py +++ b/tests/cli/test_distributed.py @@ -41,6 +41,17 @@ def prechunk(self, num_chunks): return [{"val": i} for i in range(num_chunks)] +class DummyBuilderError(DummyBuilderWithNoPrechunk): + def prechunk(self, num_chunks): + return [{"val": i} for i in range(num_chunks)] + + def get_items(self): + raise ValueError("Dummy error") + + def process_items(self, items): + raise ValueError("Dummy error") + + SERVER_URL = "tcp://127.0.0.1" SERVER_PORT = 8234 @@ -106,6 +117,33 @@ async def test_worker(): worker_task.cancel() +@pytest.mark.asyncio +async def test_worker_error(): + context = zmq.Context() + socket = context.socket(REP) + socket.bind(f"{SERVER_URL}:{SERVER_PORT}") + + worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_workers=1)) + + message = await socket.recv() + assert message == HOSTNAME.encode("utf-8") + + dummy_work = { + "@module": "tests.cli.test_distributed", + "@class": "DummyBuilderError", + "@version": None, + "dummy_prechunk": False, + "val": 0, + } + + await socket.send(json.dumps(dummy_work).encode("utf-8")) + await asyncio.sleep(1) + message = await socket.recv() + assert message.decode("utf-8") == "ERROR" + + worker_task.cancel() + + @pytest.mark.asyncio async def test_no_prechunk(caplog): From c27da9747500e6e20cf0e3c9aba32d36ecdf3734 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 11:53:43 -0800 Subject: [PATCH 07/15] Remove comment --- src/maggma/cli/distributed.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index 4e003eaeb..baba10240 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -35,7 +35,6 @@ async def manager( if not num_chunks and num_workers: raise ValueError("Both num_chunks and num_workers must be non-zero") - logger.info(f"Binding to Manager URL {url}:{port}") context = zmq.Context() socket = context.socket(REP) @@ -109,7 +108,6 @@ async def worker(url: str, port: int, num_workers: int): await multi(builder, num_workers) elif work == "EXIT": # End the worker - # This should look for a specific message ? running = False except Exception as e: From 410433590696f77fa0ba18d27648239d7e3cf4a8 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 12:12:37 -0800 Subject: [PATCH 08/15] Rename num_workers --- src/maggma/cli/__init__.py | 44 ++++++++++++++++++++++++++++------- src/maggma/cli/distributed.py | 4 ++-- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/maggma/cli/__init__.py b/src/maggma/cli/__init__.py index c51b82f12..83e43d443 100644 --- a/src/maggma/cli/__init__.py +++ b/src/maggma/cli/__init__.py @@ -31,9 +31,9 @@ ) @click.option( "-n", - "--num-workers", - "num_workers", - help="Number of worker processes. Defaults to single processing", + "--num-processes", + "num_processes", + help="Number of processes to spawn for each worker. Defaults to single processing", default=1, type=click.IntRange(1), ) @@ -56,13 +56,35 @@ help="Port for distributed communication." " mrun will find an open port if None is provided to the manager", ) -@click.option("-N", "--num-chunks", "num_chunks", default=0, type=int) -@click.option("-w", "--num-workers", "num_workers", default=0, type=int) +@click.option( + "-N", + "--num-chunks", + "num_chunks", + default=0, + type=int, + help="Number of chunks to distribute to workers", +) +@click.option( + "-d", + "--num-workers", + "num_workers", + default=0, + type=int, + help="Number of distributed workers to process chunks", +) @click.option( "--no_bars", is_flag=True, help="Turns of Progress Bars for headless operations" ) def run( - builders, verbosity, reporting_store, num_workers, url, port, num_chunks, no_bars + builders, + verbosity, + reporting_store, + num_workers, + url, + port, + num_chunks, + no_bars, + num_processes, ): # Set Logging @@ -101,12 +123,18 @@ def run( root.critical(f"Using random port for mrun manager: {port}") loop.run_until_complete( manager( - url=url, port=port, builders=builder_objects, num_chunks=num_chunks, num_workers=num_workers + url=url, + port=port, + builders=builder_objects, + num_chunks=num_chunks, + num_workers=num_workers, ) ) else: # worker - loop.run_until_complete(worker(url=url, port=port, num_workers=num_workers)) + loop.run_until_complete( + worker(url=url, port=port, num_processes=num_processes) + ) else: if num_workers == 1: for builder in builder_objects: diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index baba10240..623b9ea6e 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -80,7 +80,7 @@ async def manager( socket.close() -async def worker(url: str, port: int, num_workers: int): +async def worker(url: str, port: int, num_processes: int): """ Simple distributed worker that connects to a manager asks for work and deploys using multiprocessing @@ -105,7 +105,7 @@ async def worker(url: str, port: int, num_workers: int): if "@class" in work and "@module" in work: # We have a valid builder builder = MontyDecoder().process_decoded(work) - await multi(builder, num_workers) + await multi(builder, num_processes) elif work == "EXIT": # End the worker running = False From 2864e5b741c5e1896b50feaba3e37745f38719b7 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 12:16:59 -0800 Subject: [PATCH 09/15] Proagated num_processes --- src/maggma/cli/__init__.py | 2 +- src/maggma/cli/multiprocessing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maggma/cli/__init__.py b/src/maggma/cli/__init__.py index 83e43d443..87c41482b 100644 --- a/src/maggma/cli/__init__.py +++ b/src/maggma/cli/__init__.py @@ -143,5 +143,5 @@ def run( loop = asyncio.get_event_loop() for builder in builder_objects: loop.run_until_complete( - multi(builder=builder, num_workers=num_workers, no_bars=no_bars) + multi(builder=builder, num_processes=num_processes, no_bars=no_bars) ) diff --git a/src/maggma/cli/multiprocessing.py b/src/maggma/cli/multiprocessing.py index 0fe1dd1f7..fc7111726 100644 --- a/src/maggma/cli/multiprocessing.py +++ b/src/maggma/cli/multiprocessing.py @@ -146,11 +146,11 @@ def safe_dispatch(val): return None -async def multi(builder, num_workers, no_bars=False): +async def multi(builder, num_processes, no_bars=False): builder.connect() cursor = builder.get_items() - executor = ProcessPoolExecutor(num_workers) + executor = ProcessPoolExecutor(num_processes) # Gets the total number of items to process by priming # the cursor From ca1201efc48f09972e2a06f956d7f64bf8f4911d Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 12:22:07 -0800 Subject: [PATCH 10/15] Missed num_processes --- src/maggma/cli/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maggma/cli/__init__.py b/src/maggma/cli/__init__.py index 87c41482b..e8fb9bf59 100644 --- a/src/maggma/cli/__init__.py +++ b/src/maggma/cli/__init__.py @@ -136,7 +136,7 @@ def run( worker(url=url, port=port, num_processes=num_processes) ) else: - if num_workers == 1: + if num_processes == 1: for builder in builder_objects: serial(builder, no_bars) else: From 6c4f5fdc5b3506e07d9f8d2972cab0b0386b075f Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 12:43:06 -0800 Subject: [PATCH 11/15] Distributed test fix --- tests/cli/test_distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cli/test_distributed.py b/tests/cli/test_distributed.py index f9708400a..0506aa1d0 100644 --- a/tests/cli/test_distributed.py +++ b/tests/cli/test_distributed.py @@ -97,7 +97,7 @@ async def test_worker(): socket = context.socket(REP) socket.bind(f"{SERVER_URL}:{SERVER_PORT}") - worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_workers=1)) + worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1)) message = await socket.recv() @@ -123,7 +123,7 @@ async def test_worker_error(): socket = context.socket(REP) socket.bind(f"{SERVER_URL}:{SERVER_PORT}") - worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_workers=1)) + worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1)) message = await socket.recv() assert message == HOSTNAME.encode("utf-8") From 3fceea7a7fe4d6e35d04e80095ce8e6191add0fd Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 13:12:07 -0800 Subject: [PATCH 12/15] Test worker exit signal --- src/maggma/cli/__init__.py | 2 +- tests/cli/test_distributed.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/maggma/cli/__init__.py b/src/maggma/cli/__init__.py index e8fb9bf59..d12c5afcd 100644 --- a/src/maggma/cli/__init__.py +++ b/src/maggma/cli/__init__.py @@ -65,7 +65,7 @@ help="Number of chunks to distribute to workers", ) @click.option( - "-d", + "-w", "--num-workers", "num_workers", default=0, diff --git a/tests/cli/test_distributed.py b/tests/cli/test_distributed.py index 0506aa1d0..53c3ddd1b 100644 --- a/tests/cli/test_distributed.py +++ b/tests/cli/test_distributed.py @@ -90,6 +90,11 @@ async def test_manager_give_out_chunks(manager_server, log_to_stdout): assert work["@module"] == "tests.cli.test_distributed" assert work["val"] == i + for i in range(0, 10): + await socket.send(b"Ready") + message = await socket.recv() + assert message == b'"EXIT"' + @pytest.mark.asyncio async def test_worker(): From dc05fbf0478c5a939f9217283ea79182b8043606 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 14:36:07 -0800 Subject: [PATCH 13/15] Fix distribution if workers error out --- src/maggma/cli/distributed.py | 39 +++++++++++++++---------- tests/cli/test_distributed.py | 55 +++++++++++++++++++++++++++++------ 2 files changed, 70 insertions(+), 24 deletions(-) diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index 623b9ea6e..d29fbe72d 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -3,6 +3,7 @@ import json from logging import getLogger +from multiprocessing.sharedctypes import Value import socket as pysocket from typing import List @@ -47,25 +48,34 @@ async def manager( try: builder.connect() - chunks_dicts = list(builder.prechunk(num_chunks)) + chunks_tuples = [(d, False) for d in builder.prechunk(num_chunks)] - logger.info(f"Distributing {len(chunks_dicts)} chunks to workers") - for chunk_dict in tqdm(chunks_dicts, desc="Chunks"): - temp_builder_dict = dict(**builder_dict) - temp_builder_dict.update(chunk_dict) - temp_builder_dict = jsanitize(temp_builder_dict) + logger.info(f"Distributing {len(chunks_tuples)} chunks to workers") - # Wait for client connection that announces client and says it is ready to do work - logger.debug("Waiting for a worker") + for chunk_dict, distributed in tqdm(chunks_tuples, desc="Chunks"): + while not distributed: + if num_workers <= 0: + socket.close() + raise RuntimeError("No workers left to distribute chunks to") - worker = await socket.recv() + temp_builder_dict = dict(**builder_dict) + temp_builder_dict.update(chunk_dict) + temp_builder_dict = jsanitize(temp_builder_dict) - if worker.decode("utf-8") == "ERROR": - num_workers -= 1 + # Wait for client connection that announces client and says it is ready to do work + logger.debug("Waiting for a worker") - logger.debug(f"Got connection from worker: {worker.decode('utf-8')}") - # Send out the next chunk - await socket.send(json.dumps(temp_builder_dict).encode("utf-8")) + worker = await socket.recv() + + if worker.decode("utf-8") == "ERROR": + num_workers -= 1 + else: + logger.debug( + f"Got connection from worker: {worker.decode('utf-8')}" + ) + # Send out the next chunk + await socket.send(json.dumps(temp_builder_dict).encode("utf-8")) + distributed = True logger.info("Sending exit messages to workers") for _ in range(num_workers): @@ -113,7 +123,6 @@ async def worker(url: str, port: int, num_processes: int): except Exception as e: logger.error(f"A worker failed with error: {e}") await socket.send("ERROR".encode("utf-8")) - await socket.recv() socket.close() diff --git a/tests/cli/test_distributed.py b/tests/cli/test_distributed.py index 53c3ddd1b..26b9bfb51 100644 --- a/tests/cli/test_distributed.py +++ b/tests/cli/test_distributed.py @@ -56,10 +56,10 @@ def process_items(self, items): SERVER_PORT = 8234 -@pytest.fixture(scope="function") -async def manager_server(event_loop, log_to_stdout): +@pytest.mark.asyncio +async def test_manager_give_out_chunks(log_to_stdout): - task = asyncio.create_task( + manager_server = asyncio.create_task( manager( SERVER_URL, SERVER_PORT, @@ -68,12 +68,6 @@ async def manager_server(event_loop, log_to_stdout): num_workers=10, ) ) - yield task - task.cancel() - - -@pytest.mark.asyncio -async def test_manager_give_out_chunks(manager_server, log_to_stdout): context = zmq.Context() socket = context.socket(REQ) @@ -95,6 +89,31 @@ async def test_manager_give_out_chunks(manager_server, log_to_stdout): message = await socket.recv() assert message == b'"EXIT"' + manager_server.cancel() + + +@pytest.mark.asyncio +async def test_manager_worker_error(log_to_stdout): + + manager_server = asyncio.create_task( + manager( + SERVER_URL, + SERVER_PORT, + [DummyBuilder(dummy_prechunk=False)], + num_chunks=10, + num_workers=1, + ) + ) + + context = zmq.Context() + socket = context.socket(REQ) + socket.connect(f"{SERVER_URL}:{SERVER_PORT}") + + await socket.send("ERROR".encode("utf-8")) + await asyncio.sleep(1) + assert manager_server.done() + manager_server.cancel() + @pytest.mark.asyncio async def test_worker(): @@ -149,6 +168,24 @@ async def test_worker_error(): worker_task.cancel() +@pytest.mark.asyncio +async def test_worker_exit(): + context = zmq.Context() + socket = context.socket(REP) + socket.bind(f"{SERVER_URL}:{SERVER_PORT}") + + worker_task = asyncio.create_task(worker(SERVER_URL, SERVER_PORT, num_processes=1)) + + message = await socket.recv() + assert message == HOSTNAME.encode("utf-8") + + await socket.send_json("EXIT") + await asyncio.sleep(1) + assert worker_task.done() + + worker_task.cancel() + + @pytest.mark.asyncio async def test_no_prechunk(caplog): From 65ccef63f1f76eeea71720183cab5858dff929be Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 14:50:03 -0800 Subject: [PATCH 14/15] Fix and test num_chunks and num_workers error --- src/maggma/cli/distributed.py | 2 +- tests/cli/test_distributed.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index d29fbe72d..79a94d3b1 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -33,7 +33,7 @@ async def manager( """ logger = getLogger("Manager") - if not num_chunks and num_workers: + if not (num_chunks and num_workers): raise ValueError("Both num_chunks and num_workers must be non-zero") logger.info(f"Binding to Manager URL {url}:{port}") diff --git a/tests/cli/test_distributed.py b/tests/cli/test_distributed.py index 26b9bfb51..89b9a7ba6 100644 --- a/tests/cli/test_distributed.py +++ b/tests/cli/test_distributed.py @@ -56,6 +56,24 @@ def process_items(self, items): SERVER_PORT = 8234 +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.asyncio +async def test_wrong_worker_input(log_to_stdout): + + manager_server = asyncio.create_task( + manager( + SERVER_URL, + SERVER_PORT, + [DummyBuilder(dummy_prechunk=False)], + num_chunks=2, + num_workers=0, + ) + ) + + await asyncio.sleep(1) + manager_server.result() + + @pytest.mark.asyncio async def test_manager_give_out_chunks(log_to_stdout): From 250d25acb56551d7a498865dd56997e38756de39 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 26 Jan 2022 15:07:01 -0800 Subject: [PATCH 15/15] Remove unused import --- src/maggma/cli/distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index 79a94d3b1..e42b4cf73 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -3,7 +3,6 @@ import json from logging import getLogger -from multiprocessing.sharedctypes import Value import socket as pysocket from typing import List