diff --git a/.slurm_draft/draft_0_subprocess_auto_join/master.py b/.slurm_draft/draft_0_subprocess_auto_join/master.py
new file mode 100644
index 0000000..d74cda9
--- /dev/null
+++ b/.slurm_draft/draft_0_subprocess_auto_join/master.py
@@ -0,0 +1,8 @@
+import subprocess
+
+print('master begin')
+p = subprocess.Popen(['python -u worker.py > out.txt 2> err.txt'], shell=True)
+# worker.py is launched asynchronously
+print('master end')
+# master.py will finish quickly (not waiting for worker.py to finish)
+# However, worker.py will still continue until it is done
diff --git a/.slurm_draft/draft_0_subprocess_auto_join/worker.py b/.slurm_draft/draft_0_subprocess_auto_join/worker.py
new file mode 100644
index 0000000..3dc96de
--- /dev/null
+++ b/.slurm_draft/draft_0_subprocess_auto_join/worker.py
@@ -0,0 +1,5 @@
+import time
+
+for i in range(0,500):
+    print(f'worker {i}')
+    time.sleep(1)
diff --git a/.slurm_draft/draft_1_sockets/master.py b/.slurm_draft/draft_1_sockets/master.py
new file mode 100644
index 0000000..ea7513b
--- /dev/null
+++ b/.slurm_draft/draft_1_sockets/master.py
@@ -0,0 +1,36 @@
+import socket
+#from socket_utils import send
+import socket_utils
+import subprocess
+
+LOCALHOST = '127.0.0.1'
+
+with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+    # setup master's socket
+    s.bind((LOCALHOST, 0)) # 0: let the OS choose an available port
+    s.listen()
+    port = s.getsockname()[1]
+
+    n = 4
+
+    # launch workers
+    workers = []
+    for i in range(0,n):
+        p = subprocess.Popen([f'python -u worker.py {port} {i} > out.txt 2> err.txt'], shell=True)
+        workers.append(p)
+
+    # recv worker messages
+    remaining_workers = n
+    while remaining_workers > 0:
+        conn, addr = s.accept()
+        with conn:
+            print(f"Connected by {addr}")
+            msg = socket_utils.recv(conn)
+            print(msg)
+        remaining_workers -= 1
+
+    # wait for workers to finish and collect error codes
+    returncodes = []
+    for p in workers:
+        returncode = p.wait()
+        returncodes.append(returncode)
diff --git a/.slurm_draft/draft_1_sockets/socket_utils.py b/.slurm_draft/draft_1_sockets/socket_utils.py
new file mode 100644
index 0000000..24c899e
--- /dev/null
+++ b/.slurm_draft/draft_1_sockets/socket_utils.py
@@ -0,0 +1,31 @@
+def send(sock, msg):
+    msg_bytes = msg.encode('utf-8')
+
+    msg_len = len(msg_bytes)
+    sent = sock.send(msg_len.to_bytes(8,'big')) # send int64 big endian
+    if sent == 0:
+        raise RuntimeError('Socket send broken: could not send message size')
+
+    totalsent = 0
+    while totalsent < msg_len:
+        sent = sock.send(msg_bytes[totalsent:])
+        if sent == 0:
+            raise RuntimeError('Socket send broken: could not send message')
+        totalsent = totalsent + sent
+
+def recv(sock):
+    msg_len_bytes = sock.recv(8)
+    if msg_len_bytes == b'':
+        raise RuntimeError('Socket recv broken: no message size')
+    msg_len = int.from_bytes(msg_len_bytes, 'big')
+
+    chunks = []
+    bytes_recv = 0
+    while bytes_recv < msg_len:
+        chunk = sock.recv(min(msg_len-bytes_recv, 4096))
+        if chunk == b'':
+            raise RuntimeError('Socket recv broken: could not receive message')
+        chunks.append(chunk)
+        bytes_recv += len(chunk)
+    msg_bytes = b''.join(chunks)
+    return msg_bytes.decode('utf-8')
diff --git a/.slurm_draft/draft_1_sockets/worker.py b/.slurm_draft/draft_1_sockets/worker.py
new file mode 100644
index 0000000..fb0b283
--- /dev/null
+++ b/.slurm_draft/draft_1_sockets/worker.py
@@ -0,0 +1,14 @@
+import socket
+import socket_utils
+#import time
+import sys
+
+LOCALHOST = '127.0.0.1'
+
+assert len(sys.argv) == 3
+server_port = int(sys.argv[1])
+test_idx = int(sys.argv[2])
+
+with socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) as s: + s.connect((LOCALHOST, server_port)) + socket_utils.send(s, f'Hello from {test_idx}') diff --git a/.slurm_draft/draft_2_srun/master_1p.sh b/.slurm_draft/draft_2_srun/master_1p.sh new file mode 100755 index 0000000..cf40bb3 --- /dev/null +++ b/.slurm_draft/draft_2_srun/master_1p.sh @@ -0,0 +1,17 @@ +echo "launch proc 0" +srun --exclusive --ntasks=1 --qos c1_inter_giga -l bash worker.sh 0 & +echo "detach proc 0" + +echo "launch proc 1" +srun --exclusive --ntasks=1 --qos c1_inter_giga -l bash worker.sh 1 & +echo "detach proc 1" + +echo "launch proc 2" +srun --exclusive --ntasks=1 --qos c1_inter_giga -l bash worker.sh 2 & +echo "detach proc 2" + +echo "launch proc 3" +srun --exclusive --ntasks=1 --qos c1_inter_giga -l bash worker.sh 3 & +echo "detach proc 3" + +wait diff --git a/.slurm_draft/draft_2_srun/master_1p_no_exclusive.sh b/.slurm_draft/draft_2_srun/master_1p_no_exclusive.sh new file mode 100755 index 0000000..710ffe3 --- /dev/null +++ b/.slurm_draft/draft_2_srun/master_1p_no_exclusive.sh @@ -0,0 +1,18 @@ + +echo "launch proc 0" +srun --ntasks=1 --qos c1_inter_giga -l bash worker.sh 0 & +echo "detach proc 0" + +echo "launch proc 1" +srun --ntasks=1 --qos c1_inter_giga -l bash worker.sh 1 & +echo "detach proc 1" + +echo "launch proc 2" +srun --ntasks=1 --qos c1_inter_giga -l bash worker.sh 2 & +echo "detach proc 2" + +echo "launch proc 3" +srun --ntasks=1 --qos c1_inter_giga -l bash worker.sh 3 & +echo "detach proc 3" + +wait diff --git a/.slurm_draft/draft_2_srun/master_multi_node.sh b/.slurm_draft/draft_2_srun/master_multi_node.sh new file mode 100755 index 0000000..9b21674 --- /dev/null +++ b/.slurm_draft/draft_2_srun/master_multi_node.sh @@ -0,0 +1,17 @@ +echo "launch proc 0" +srun --exclusive --ntasks=43 --qos c1_inter_giga -l bash worker.sh 0 & +echo "detach proc 0" + +echo "launch proc 1" +srun --exclusive --ntasks=4 --qos c1_inter_giga -l bash worker.sh 1 & +echo "detach proc 1" + +echo "launch proc 2" +srun --exclusive --ntasks=45 --qos c1_inter_giga -l bash worker.sh 2 & +echo "detach proc 2" + +echo "launch proc 3" +srun --exclusive --ntasks=4 --qos c1_inter_giga -l bash worker.sh 3 & +echo "detach proc 3" + +wait diff --git a/.slurm_draft/draft_2_srun/master_simple.sh b/.slurm_draft/draft_2_srun/master_simple.sh new file mode 100755 index 0000000..8acce2e --- /dev/null +++ b/.slurm_draft/draft_2_srun/master_simple.sh @@ -0,0 +1 @@ +srun --ntasks=1 -l bash worker.sh 0 diff --git a/.slurm_draft/draft_2_srun/slurm_job.sh b/.slurm_draft/draft_2_srun/slurm_job.sh new file mode 100755 index 0000000..fc320a1 --- /dev/null +++ b/.slurm_draft/draft_2_srun/slurm_job.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +#SBATCH --job-name=pytest_par +#SBATCH --time 00:30:00 +#SBATCH --qos=co_short_std +#SBATCH --ntasks=88 +#SBATCH --nodes=2-2 +#SBATCH --output=slurm.%j.out +#SBATCH --error=slurm.%j.err + +#date +#source /scratchm/sonics/dist/2023-11/source.sh --env sonics_dev --compiler gcc@12 --mpi intel-oneapi + +date +#./master_1p.sh +#./master_1p_no_exclusive.sh +./master_multi_node.sh +date diff --git a/.slurm_draft/draft_2_srun/slurm_mpi_job.sh b/.slurm_draft/draft_2_srun/slurm_mpi_job.sh new file mode 100755 index 0000000..63cc68d --- /dev/null +++ b/.slurm_draft/draft_2_srun/slurm_mpi_job.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name=test_slurm_pytest +#SBATCH --ntasks=48 +#SBATCH --time 0-0:10 +#SBATCH --qos=c1_test_giga +#SBATCH --output=slurm.%j.out +#SBATCH --error=slurm.%j.err + +#date +#source /scratchm/sonics/dist/2023-11/source.sh --env sonics_dev 
--compiler gcc@12 --mpi intel-oneapi
+
+date
+#python3 master.py
+./master.sh
+date
diff --git a/.slurm_draft/draft_2_srun/worker.py b/.slurm_draft/draft_2_srun/worker.py
new file mode 100644
index 0000000..e536b43
--- /dev/null
+++ b/.slurm_draft/draft_2_srun/worker.py
@@ -0,0 +1,18 @@
+import socket
+import socket_utils
+import time
+import sys
+import datetime
+
+LOCALHOST = '127.0.0.1'
+
+assert len(sys.argv) == 3
+server_port = int(sys.argv[1])
+test_idx = int(sys.argv[2])
+
+print(f'start proc {test_idx} - ',datetime.datetime.now())
+
+#with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+#    time.sleep(10)
+#    s.connect((LOCALHOST, server_port))
+#    socket_utils.send(s, f'Hello from {test_idx}')
diff --git a/.slurm_draft/draft_2_srun/worker.sh b/.slurm_draft/draft_2_srun/worker.sh
new file mode 100755
index 0000000..c61fab8
--- /dev/null
+++ b/.slurm_draft/draft_2_srun/worker.sh
@@ -0,0 +1,7 @@
+
+printf "start proc $1 "
+hostname
+date
+sleep 10
+printf "end proc $1 "
+date
diff --git a/.slurm_draft/draft_3_sockets_ip/job_worker.sh b/.slurm_draft/draft_3_sockets_ip/job_worker.sh
new file mode 100644
index 0000000..c6e2d5b
--- /dev/null
+++ b/.slurm_draft/draft_3_sockets_ip/job_worker.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+#MSUB -r deploy_test
+#MSUB -o sonics.out
+#MSUB -e sonics.err
+#MSUB -n 1
+#MSUB -T 1600
+#MSUB -A 
+####MSUB -x
+#MSUB -q milan
+#MSUB -Q test
+#MSUB -m scratch,work
+
+hostname -I
+python worker.py
diff --git a/.slurm_draft/draft_3_sockets_ip/machine_conf.py b/.slurm_draft/draft_3_sockets_ip/machine_conf.py
new file mode 100644
index 0000000..d1c97b0
--- /dev/null
+++ b/.slurm_draft/draft_3_sockets_ip/machine_conf.py
@@ -0,0 +1,4 @@
+
+ips = ['10.30.14.3', '10.136.0.3', '172.28.5.3', '10.137.0.3', '10.137.128.3']
+ip = ips[1]
+port = 10000
diff --git a/.slurm_draft/draft_3_sockets_ip/master.py b/.slurm_draft/draft_3_sockets_ip/master.py
new file mode 100644
index 0000000..4d89fb9
--- /dev/null
+++ b/.slurm_draft/draft_3_sockets_ip/master.py
@@ -0,0 +1,18 @@
+import socket
+#from socket_utils import send
+import socket_utils
+import subprocess
+from machine_conf import ip, port
+
+with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+    # setup master's socket
+    s.bind((ip, port)) # fixed ip/port taken from machine_conf (a port of 0 would let the OS choose)
+    s.listen()
+    port = s.getsockname()[1]
+
+    print('waiting for socket connection')
+    conn, addr = s.accept()
+    with conn:
+        print(f"Connected by {addr}")
+        msg = socket_utils.recv(conn)
+        print(msg)
diff --git a/.slurm_draft/draft_3_sockets_ip/socket_utils.py b/.slurm_draft/draft_3_sockets_ip/socket_utils.py
new file mode 100644
index 0000000..3e5acd2
--- /dev/null
+++ b/.slurm_draft/draft_3_sockets_ip/socket_utils.py
@@ -0,0 +1,31 @@
+def send(sock, msg):
+    msg_bytes = msg.encode('utf-8')
+
+    msg_len = len(msg_bytes)
+    sent = sock.send(msg_len.to_bytes(4,'big')) # send message size as a 4-byte big-endian integer
+    if sent == 0:
+        raise RuntimeError('Socket send broken: could not send message size')
+
+    totalsent = 0
+    while totalsent < msg_len:
+        sent = sock.send(msg_bytes[totalsent:])
+        if sent == 0:
+            raise RuntimeError('Socket send broken: could not send message')
+        totalsent = totalsent + sent
+
+def recv(sock):
+    msg_len_bytes = sock.recv(4)
+    if msg_len_bytes == b'':
+        raise RuntimeError('Socket recv broken: no message size')
+    msg_len = int.from_bytes(msg_len_bytes, 'big')
+
+    chunks = []
+    bytes_recv = 0
+    while bytes_recv < msg_len:
+        chunk = sock.recv(min(msg_len-bytes_recv, 4096))
+        if chunk == b'':
+            raise RuntimeError('Socket recv broken: could not receive message')
+        chunks.append(chunk)
+        bytes_recv += len(chunk)
+    msg_bytes = b''.join(chunks)
+    return msg_bytes.decode('utf-8')
diff --git a/.slurm_draft/draft_3_sockets_ip/worker.py b/.slurm_draft/draft_3_sockets_ip/worker.py
new file mode 100644
index 0000000..7358cd2
--- /dev/null
+++ b/.slurm_draft/draft_3_sockets_ip/worker.py
@@ -0,0 +1,10 @@
+import socket
+import socket_utils
+import time
+
+from machine_conf import ip, port
+
+with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+    s.connect((ip, port))
+    socket_utils.send(s, f'Hello from {socket.gethostname()}')
+    time.sleep(2)
diff --git a/.slurm_draft/hello_mpi.cpp b/.slurm_draft/hello_mpi.cpp
new file mode 100644
index 0000000..8feb2fc
--- /dev/null
+++ b/.slurm_draft/hello_mpi.cpp
@@ -0,0 +1,16 @@
+#include "mpi.h"
+#include <iostream>
+
+int main(int argc, char *argv[]) {
+  MPI_Init(&argc, &argv);
+
+  int world_size;
+  MPI_Comm_size(MPI_COMM_WORLD, &world_size);
+  int rank;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  std::cout << "mpi proc = " << rank << "/" << world_size << "\n";
+
+  MPI_Finalize();
+
+  return 0;
+}
diff --git a/.slurm_draft/master.py b/.slurm_draft/master.py
new file mode 100644
index 0000000..b4d1861
--- /dev/null
+++ b/.slurm_draft/master.py
@@ -0,0 +1,41 @@
+import socket
+#from socket_utils import send
+import socket_utils
+import subprocess
+import datetime
+
+LOCALHOST = '127.0.0.1'
+
+with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+    # setup master's socket
+    s.bind((LOCALHOST, 0)) # 0: let the OS choose an available port
+    s.listen()
+    port = s.getsockname()[1]
+
+    n = 4
+
+    # launch workers
+    workers = []
+    for i in range(0,n):
+        #p = subprocess.Popen([f'python3 -u worker.py {port} {i} > out.txt 2> err.txt'], shell=True)
+        print('starting subprocess - ',datetime.datetime.now())
+        p = subprocess.Popen([f'srun --exclusive --ntasks=1 --qos c1_inter_giga -l python3 -u worker.py {port} {i} > out_{i}.txt 2> err_{i}.txt'], shell=True) # --exclusive for SLURM to parallelize with srun (https://stackoverflow.com/a/66805905/1583122)
+        print('detached subprocess - ',datetime.datetime.now())
+        workers.append(p)
+
+    # recv worker messages
+    remaining_workers = n
+    while remaining_workers > 0:
+        print(f'remaining_workers={remaining_workers} - ',datetime.datetime.now())
+        conn, addr = s.accept()
+        with conn:
+            msg = socket_utils.recv(conn)
+            print(msg)
+        remaining_workers -= 1
+
+    # wait for workers to finish and collect error codes
+    returncodes = []
+    for p in workers:
+        print('wait to finish - ',datetime.datetime.now())
+        returncode = p.wait()
+        returncodes.append(returncode)
diff --git a/.slurm_draft/run.sh b/.slurm_draft/run.sh
new file mode 100644
index 0000000..8b91d52
--- /dev/null
+++ b/.slurm_draft/run.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+#SBATCH --job-name=pytest_par
+#SBATCH --time 00:30:00
+#SBATCH --qos=co_short_std
+#SBATCH --ntasks=1
+##SBATCH --nodes=2-2
+#SBATCH --output=slurm.%j.out
+#SBATCH --error=slurm.%j.err
+
+#echo $TOTO
+whoami
+#srun --exclusive --ntasks=1 -l hostname
+nproc --all
diff --git a/.slurm_draft/slurm_impi.sh b/.slurm_draft/slurm_impi.sh
new file mode 100755
index 0000000..63cc68d
--- /dev/null
+++ b/.slurm_draft/slurm_impi.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+#SBATCH --job-name=test_slurm_pytest
+#SBATCH --ntasks=48
+#SBATCH --time 0-0:10
+#SBATCH --qos=c1_test_giga
+#SBATCH --output=slurm.%j.out
+#SBATCH --error=slurm.%j.err
+
+#date
+#source /scratchm/sonics/dist/2023-11/source.sh --env sonics_dev --compiler gcc@12 --mpi intel-oneapi
+
+date
+#python3 master.py
+./master.sh
+date
diff --git a/.slurm_draft/socket_utils.py b/.slurm_draft/socket_utils.py
new file mode 100644
index 0000000..3e5acd2
--- /dev/null
+++ b/.slurm_draft/socket_utils.py
@@ -0,0 +1,31 @@
+def send(sock, msg):
+    msg_bytes = msg.encode('utf-8')
+
+    msg_len = len(msg_bytes)
+    sent = sock.send(msg_len.to_bytes(4,'big')) # send message size as a 4-byte big-endian integer
+    if sent == 0:
+        raise RuntimeError('Socket send broken: could not send message size')
+
+    totalsent = 0
+    while totalsent < msg_len:
+        sent = sock.send(msg_bytes[totalsent:])
+        if sent == 0:
+            raise RuntimeError('Socket send broken: could not send message')
+        totalsent = totalsent + sent
+
+def recv(sock):
+    msg_len_bytes = sock.recv(4)
+    if msg_len_bytes == b'':
+        raise RuntimeError('Socket recv broken: no message size')
+    msg_len = int.from_bytes(msg_len_bytes, 'big')
+
+    chunks = []
+    bytes_recv = 0
+    while bytes_recv < msg_len:
+        chunk = sock.recv(min(msg_len-bytes_recv, 4096))
+        if chunk == b'':
+            raise RuntimeError('Socket recv broken: could not receive message')
+        chunks.append(chunk)
+        bytes_recv += len(chunk)
+    msg_bytes = b''.join(chunks)
+    return msg_bytes.decode('utf-8')
diff --git a/.slurm_draft/worker.py b/.slurm_draft/worker.py
new file mode 100644
index 0000000..3084a00
--- /dev/null
+++ b/.slurm_draft/worker.py
@@ -0,0 +1,37 @@
+import socket
+import socket_utils
+import time
+import sys
+import datetime
+from mpi4py import MPI
+
+assert len(sys.argv) == 4
+scheduler_ip = sys.argv[1]
+server_port = int(sys.argv[2])
+test_idx = int(sys.argv[3])
+
+comm = MPI.COMM_WORLD
+print(f'start at {scheduler_ip}@{server_port} test {test_idx} at rank {comm.Get_rank()}/{comm.Get_size()} exec on {socket.gethostname()} - ',datetime.datetime.now())
+
+if comm.Get_rank() == 0:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.connect((scheduler_ip, server_port))
+        #time.sleep(10+5*test_idx)
+        #msg = f'Hello from test {test_idx} at rank {comm.Get_rank()}/{comm.Get_size()} exec on {socket.gethostname()}'
+        #socket_utils.send(s, msg)
+        info = {
+            'test_idx': test_idx,
+            'setup': {
+                'outcome': 'passed',
+                'longrepr': f'setup msg {test_idx}',
+            },
+            'call': {
+                'outcome': 'failed',
+                'longrepr': f'call msg {test_idx}',
+            },
+            'teardown': {
+                'outcome': 'passed',
+                'longrepr': f'teardown msg {test_idx}',
+            },
+        }
+        socket_utils.send(s, str(info))
diff --git a/pytest_parallel/algo.py b/pytest_parallel/algo.py
index b74b4e3..bf3e6f7 100644
--- a/pytest_parallel/algo.py
+++ b/pytest_parallel/algo.py
@@ -8,11 +8,11 @@ def identity(elem):
     return elem
 
 
-def partition(seq, pred):
+def partition(xs, pred):
     """
-    partitions sequence `seq` into
-    `xs_true` with elements of `seq` that satisfy predicate `pred`
-    `xs_false` with elements of `seq` that don't satisfy predicate `pred`
+    partitions sequence `xs` into
+    `xs_true` with elements of `xs` that satisfy predicate `pred`
+    `xs_false` with elements of `xs` that don't satisfy predicate `pred`
     then returns `xs_true`, `xs_false`
 
     Complexity:
@@ -25,7 +25,7 @@
     """
     xs_true = []
     xs_false = []
-    for elem in seq:
+    for elem in xs:
         if pred(elem):
             xs_true.append(elem)
         else:
@@ -33,7 +33,7 @@
     return xs_true, xs_false
 
 
-def partition_point(seq, pred):
+def partition_point(xs, pred):
     """
     Gives the partition point of sequence `xs`
     That is, the index i where
@@ -52,25 +52,25 @@
     Constant
     """
     i = 0
-    j = len(seq)
+    j = len(xs)
     while i < j:
         mid = (i + j) // 2
-        if pred(seq[mid]):
+        if pred(xs[mid]):
             i = mid + 1
         else:
             j = mid
     return i
 
 
-def lower_bound(seq, value, key=identity, comp=operator.lt):
+def lower_bound(xs, value, key=identity, comp=operator.lt):
     def pred(elem):
         return comp(key(elem), value)
 
-    return partition_point(seq, pred)
+    return partition_point(xs, pred)
 
 
-def upper_bound(seq, value, key=identity, comp=operator.lt):
+def upper_bound(xs, value, key=identity, comp=operator.lt):
     def pred(elem):
         return not comp(value, key(elem))
 
-    return partition_point(seq, pred)
+    return partition_point(xs, pred)
diff --git a/pytest_parallel/gather_report.py b/pytest_parallel/gather_report.py
new file mode 100644
index 0000000..ad7a4e3
--- /dev/null
+++ b/pytest_parallel/gather_report.py
@@ -0,0 +1,77 @@
+from mpi4py import MPI
+from _pytest._code.code import (
+    ExceptionChainRepr,
+    ReprTraceback,
+    ReprEntryNative,
+    ReprFileLocation,
+)
+
+
+def gather_report(mpi_reports, n_sub_rank):
+    assert len(mpi_reports) == n_sub_rank
+
+    report_init = mpi_reports[0]
+    goutcome = report_init.outcome
+    glongrepr = report_init.longrepr
+
+    collect_longrepr = []
+    # > We need to rebuild a TestReport object, location can be false # TODO ?
+    for i_sub_rank, test_report in enumerate(mpi_reports):
+        if test_report.outcome == "failed":
+            goutcome = "failed"
+
+        if test_report.longrepr:
+            msg = f"On rank {i_sub_rank} of {n_sub_rank}"
+            full_msg = f"\n-------------------------------- {msg} --------------------------------"
+            fake_trace_back = ReprTraceback([ReprEntryNative(full_msg)], None, None)
+            collect_longrepr.append(
+                (fake_trace_back, ReprFileLocation(*report_init.location), None)
+            )
+            collect_longrepr.append(
+                (test_report.longrepr, ReprFileLocation(*report_init.location), None)
+            )
+
+    if len(collect_longrepr) > 0:
+        glongrepr = ExceptionChainRepr(collect_longrepr)
+
+    return goutcome, glongrepr
+
+
+def gather_report_on_local_rank_0(report):
+    """
+    Gather reports from all procs participating in the test on rank 0 of the sub_comm
+    """
+    sub_comm = report.sub_comm
+    del report.sub_comm # No need to keep it in the report
+                        # Furthermore we need to serialize the report
+                        # and mpi4py does not know how to serialize report.sub_comm
+    i_sub_rank = sub_comm.Get_rank()
+    n_sub_rank = sub_comm.Get_size()
+
+    if (
+        report.outcome != "skipped"
+    ):  # Skipped tests are only known by proc 0 -> no merge required
+        # Warning: PyTest reports can actually be quite big
+        request = sub_comm.isend(report, dest=0, tag=i_sub_rank)
+
+        if i_sub_rank == 0:
+            mpi_reports = n_sub_rank * [None]
+            for _ in range(n_sub_rank):
+                status = MPI.Status()
+
+                mpi_report = sub_comm.recv(
+                    source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status
+                )
+                mpi_reports[status.Get_source()] = mpi_report
+
+            assert (
+                None not in mpi_reports
+            )  # should have received from all ranks of `sub_comm`
+            goutcome, glongrepr = gather_report(mpi_reports, n_sub_rank)
+
+            report.outcome = goutcome
+            report.longrepr = glongrepr
+
+        request.wait()
+
+    sub_comm.barrier()
diff --git a/pytest_parallel/mpi_reporter.py b/pytest_parallel/mpi_reporter.py
index 309a148..e9faf6c 100644
--- a/pytest_parallel/mpi_reporter.py
+++ b/pytest_parallel/mpi_reporter.py
@@ -1,91 +1,20 @@
 import numpy as np
 import pytest
 
-from _pytest._code.code import (
-    ExceptionChainRepr,
-    ReprTraceback,
-    ReprEntryNative,
-    ReprFileLocation,
-)
 from mpi4py import MPI
 
 from .algo import partition, lower_bound
-from .utils import (
-    number_of_working_processes,
-    get_n_proc_for_test,
-    mark_skip,
-    add_n_procs,
-    is_dyn_master_process,
-)
-
-
-def gather_report(mpi_reports, n_sub_rank):
-    assert len(mpi_reports) == 
n_sub_rank - - report_init = mpi_reports[0] - goutcome = report_init.outcome - glongrepr = report_init.longrepr - - collect_longrepr = [] - # > We need to rebuild a TestReport object, location can be false # TODO ? - for i_sub_rank, test_report in enumerate(mpi_reports): - if test_report.outcome == "failed": - goutcome = "failed" - - if test_report.longrepr: - msg = f"On rank {i_sub_rank} of {n_sub_rank}" - full_msg = f"\n-------------------------------- {msg} --------------------------------" - fake_trace_back = ReprTraceback([ReprEntryNative(full_msg)], None, None) - collect_longrepr.append( - (fake_trace_back, ReprFileLocation(*report_init.location), None) - ) - collect_longrepr.append( - (test_report.longrepr, ReprFileLocation(*report_init.location), None) - ) - - if len(collect_longrepr) > 0: - glongrepr = ExceptionChainRepr(collect_longrepr) - - return goutcome, glongrepr - - -def gather_report_on_local_rank_0(report): - """ - Gather reports from all procs participating in the test on rank 0 of the sub_comm - """ - sub_comm = report.sub_comm - del report.sub_comm # No need to keep it in the report - # Furthermore we need to serialize the report - # and mpi4py does not know how to serialize report.sub_comm - i_sub_rank = sub_comm.Get_rank() - n_sub_rank = sub_comm.Get_size() - - if ( - report.outcome != "skipped" - ): # Skipped test are only known by proc 0 -> no merge required - # Warning: PyTest reports can actually be quite big - request = sub_comm.isend(report, dest=0, tag=i_sub_rank) - - if i_sub_rank == 0: - mpi_reports = n_sub_rank * [None] - for _ in range(n_sub_rank): - status = MPI.Status() - - mpi_report = sub_comm.recv( - source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status - ) - mpi_reports[status.Get_source()] = mpi_report - - assert ( - None not in mpi_reports - ) # should have received from all ranks of `sub_comm` - goutcome, glongrepr = gather_report(mpi_reports, n_sub_rank) - - report.outcome = goutcome - report.longrepr = glongrepr +from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index +from .utils_mpi import number_of_working_processes, is_dyn_master_process +from .gather_report import gather_report_on_local_rank_0 - request.wait() - sub_comm.barrier() +def mark_skip(item): + comm = MPI.COMM_WORLD + n_rank = comm.Get_size() + n_proc_test = get_n_proc_for_test(item) + skip_msg = f"Not enough procs to execute: {n_proc_test} required but only {n_rank} available" + item.add_marker(pytest.mark.skip(reason=skip_msg), append=False) + item.marker_mpi_skip = True def filter_and_add_sub_comm(items, global_comm): @@ -128,7 +57,7 @@ def pytest_collection_modifyitems(self, config, items): @pytest.hookimpl(hookwrapper=True, tryfirst=True) def pytest_runtestloop(self, session) -> bool: - outcome = yield + _ = yield # prevent return value being non-zero (ExitCode.NO_TESTS_COLLECTED) # when no test run on non-master if self.global_comm.Get_rank() != 0 and session.testscollected == 0: @@ -175,16 +104,7 @@ def group_items_by_parallel_steps(items, n_workers): return items_by_step, items_to_skip -def run_item_test(item, nextitem, session): - item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem) - if session.shouldfail: - raise session.Failed(session.shouldfail) - if session.shouldstop: - raise session.Interrupted(session.shouldstop) - - def prepare_items_to_run(items, comm): - n_rank = comm.Get_size() i_rank = comm.Get_rank() items_to_run = [] @@ -318,6 +238,7 @@ def pytest_runtest_logreport(self, report): report.outcome = 
mpi_report.outcome report.longrepr = mpi_report.longrepr + report.duration = mpi_report.duration def sub_comm_from_ranks(global_comm, sub_ranks): @@ -349,10 +270,6 @@ def _key(item): return lower_bound(items, max_needed_n_proc, _key) -def mark_original_index(items): - for i, item in enumerate(items): - item.original_index = i - ########### Client/Server ########### SCHEDULED_WORK_TAG = 0 @@ -381,7 +298,7 @@ def schedule_test(item, available_procs, inter_comm): item.sub_ranks = sub_ranks - # the procs are busy + # mark the procs as busy for sub_rank in sub_ranks: available_procs[sub_rank] = False @@ -408,17 +325,14 @@ def wait_test_to_complete(items_to_run, session, available_procs, inter_comm): # get associated item item = items_to_run[original_idx] - n_proc = item.n_proc sub_ranks = item.sub_ranks assert first_rank_done in sub_ranks - # receive done message from all other proc associated to the item + # receive done message from all other procs associated to the item for sub_rank in sub_ranks: if sub_rank != first_rank_done: rank_original_idx = inter_comm.recv(source=sub_rank, tag=WORK_DONE_TAG) - assert ( - rank_original_idx == original_idx - ) # sub_rank is supposed to have worked on the same test + assert (rank_original_idx == original_idx) # sub_rank is supposed to have worked on the same test # the procs are now available for sub_rank in sub_ranks: @@ -442,21 +356,19 @@ def receive_run_and_report_tests( # receive original_idx, sub_ranks = inter_comm.recv(source=0, tag=SCHEDULED_WORK_TAG) if original_idx == -1: - return # signal work is done + return # signal work is done # run sub_comm = sub_comm_from_ranks(global_comm, sub_ranks) item = items_to_run[original_idx] item.sub_comm = sub_comm - nextitem = None # not known at this point + nextitem = None # not known at this point run_item_test(item, nextitem, session) # signal work is done for the test inter_comm.send(original_idx, dest=0, tag=WORK_DONE_TAG) - MPI.Request.waitall( - current_item_requests - ) # make sure all report isends have been received + MPI.Request.waitall(current_item_requests) # make sure all report isends have been received current_item_requests.clear() @@ -468,6 +380,8 @@ def __init__(self, global_comm, inter_comm): @pytest.hookimpl(tryfirst=True) def pytest_pyfunc_call(self, pyfuncitem): + # This is where the test is normally run. + # Since the master process only collects the reports, it needs to *not* run anything. 
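+        # Returning a non-None value from this `firstresult` hook prevents
+        # pytest's default implementation from actually calling the test function.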
cond = is_dyn_master_process(self.inter_comm) and not ( hasattr(pyfuncitem, "marker_mpi_skip") and pyfuncitem.marker_mpi_skip ) @@ -520,25 +434,19 @@ def pytest_runtestloop(self, session) -> bool: while len(items_left_to_run) > 0: n_av_procs = np.sum(available_procs) - item_idx = item_with_biggest_admissible_n_proc( - items_left_to_run, n_av_procs - ) + item_idx = item_with_biggest_admissible_n_proc(items_left_to_run, n_av_procs) if item_idx == -1: - wait_test_to_complete( - items_to_run, session, available_procs, self.inter_comm - ) + wait_test_to_complete(items_to_run, session, available_procs, self.inter_comm) else: - schedule_test( - items_left_to_run[item_idx], available_procs, self.inter_comm - ) + schedule_test(items_left_to_run[item_idx], available_procs, self.inter_comm) del items_left_to_run[item_idx] wait_last_tests_to_complete( items_to_run, session, available_procs, self.inter_comm ) signal_all_done(self.inter_comm) - else: # worker proc + else: # worker proc receive_run_and_report_tests( items_to_run, session, @@ -585,12 +493,12 @@ def pytest_runtest_logreport(self, report): # The idea of the scheduler is the following: # The server schedules test over clients # A client executes the test then report to the server it is done - # The server execute the PyTest pipeline to make it think it ran the test (but only receives the reports of the client) + # The server executes the PyTest pipeline to make it think it ran the test (but only receives the reports of the client) # The server continues its scheduling # Here in the report, we need an isend, because the ordering is the following: # Client: run test, isend reports, send done to server # Server: recv done from client, 'run' test (i.e. recv reports) - # So the 'send done' must be received before the 'send report' + # So the 'send done' must be received before the 'isend report' request = self.inter_comm.isend(report, dest=0, tag=tag) self.current_item_requests.append(request) else: # global master: receive @@ -600,3 +508,4 @@ def pytest_runtest_logreport(self, report): report.outcome = mpi_report.outcome report.longrepr = mpi_report.longrepr + report.duration = mpi_report.duration diff --git a/pytest_parallel/plugin.py b/pytest_parallel/plugin.py index 4eb3802..908ad12 100644 --- a/pytest_parallel/plugin.py +++ b/pytest_parallel/plugin.py @@ -1,65 +1,161 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +import sys +import subprocess import tempfile import pytest from pathlib import Path -from mpi4py import MPI - - -from .mpi_reporter import SequentialScheduler, StaticScheduler, DynamicScheduler -from .utils import spawn_master_process, is_master_process +import argparse # -------------------------------------------------------------------------- def pytest_addoption(parser): parser.addoption( - "--scheduler", - dest="scheduler", - choices=["sequential", "static", "dynamic"], - default="sequential", + '--scheduler', + dest='scheduler', + choices=['sequential', 'static', 'dynamic', 'slurm'], + default='sequential', + help='Method used by pytest_parallel to schedule tests', ) + parser.addoption('--n-workers', dest='n_workers', type=int, help='Max number of processes to run in parallel') + + parser.addoption('--slurm-options', dest='slurm_options', type=str, help='list of SLURM options e.g. 
"--time=00:30:00 --qos=my_queue --n_tasks=4"') + parser.addoption('--slurm-additional-cmds', dest='slurm_additional_cmds', type=str, help='list of commands to pass to SLURM job e.g. "source my_env.sh"') + parser.addoption('--slurm-file', dest='slurm_file', type=str, help='Path to file containing header of SLURM job') # TODO DEL + parser.addoption('--slurm-sub-command', dest='slurm_sub_command', type=str, help='SLURM submission command (defaults to `sbatch`)') # TODO DEL + + if sys.version_info >= (3,9): + parser.addoption('--slurm-export-env', dest='slurm_export_env', action=argparse.BooleanOptionalAction, default=True) + else: + parser.addoption('--slurm-export-env', dest='slurm_export_env', default=False, action='store_true') + parser.addoption('--no-slurm-export-env', dest='slurm_export_env', action='store_false') + + parser.addoption('--detach', dest='detach', action='store_true', help='Detach SLURM jobs: do not send reports to the scheduling process (useful to launch slurm job.sh separately)') + + # Private to SLURM scheduler + parser.addoption('--_worker', dest='_worker', action='store_true', help='Internal pytest_parallel option') + parser.addoption('--_scheduler_ip_address', dest='_scheduler_ip_address', type=str, help='Internal pytest_parallel option') + parser.addoption('--_scheduler_port', dest='_scheduler_port', type=int, help='Internal pytest_parallel option') + parser.addoption('--_test_idx' , dest='_test_idx' , type=int, help='Internal pytest_parallel option') + + # Note: + # we need to NOT import mpi4py when pytest_parallel + # is called with the SLURM scheduler + # because it can mess SLURM `srun` + if "--scheduler=slurm" in sys.argv: + assert 'mpi4py.MPI' not in sys.modules, 'Internal pytest_parallel error: mpi4py.MPI should not be imported' \ + ' when we are about to register and environment for SLURM' \ + ' (because importing mpi4py.MPI makes the current process look like and MPI process,' \ + ' and SLURM does not like that)' + + r = subprocess.run(['env','--null'], stdout=subprocess.PIPE) # `--null`: end each output line with NUL, required by `sbatch --export-file` + + assert r.returncode==0, 'SLURM scheduler: error when writing `env` to `pytest_slurm/env_vars.sh`' + pytest._pytest_parallel_env_vars = r.stdout # -------------------------------------------------------------------------- @pytest.hookimpl(trylast=True) def pytest_configure(config): - global_comm = MPI.COMM_WORLD - - scheduler = config.getoption("scheduler") - if scheduler == "sequential": - plugin = SequentialScheduler(global_comm) - elif scheduler == "static": - plugin = StaticScheduler(global_comm) - elif scheduler == "dynamic": - inter_comm = spawn_master_process(global_comm) - plugin = DynamicScheduler(global_comm, inter_comm) - else: - assert 0 + # Get options and check dependent/incompatible options + scheduler = config.getoption('scheduler') + n_workers = config.getoption('n_workers') + slurm_options = config.getoption('slurm_options') + slurm_additional_cmds = config.getoption('slurm_additional_cmds') + slurm_worker = config.getoption('_worker') + slurm_file = config.getoption('slurm_file') + slurm_export_env = config.getoption('slurm_export_env') + slurm_sub_command = config.getoption('slurm_sub_command') + detach = config.getoption('detach') + if scheduler != 'slurm': + assert not slurm_worker, 'Option `--slurm-worker` only available when `--scheduler=slurm`' + assert not slurm_options, 'Option `--slurm-options` only available when `--scheduler=slurm`' + assert not slurm_additional_cmds, 'Option 
`--slurm-additional-cmds` only available when `--scheduler=slurm`' + assert not slurm_file, 'Option `--slurm-file` only available when `--scheduler=slurm`' + + if scheduler == 'slurm' and not slurm_worker: + assert slurm_options or slurm_file, 'You need to specify either `--slurm-options` or `--slurm-file` when `--scheduler=slurm`' + if slurm_options: + assert not slurm_file, 'You need to specify either `--slurm-options` or `--slurm-file`, but not both' + if slurm_file: + assert not slurm_options, 'You need to specify either `--slurm-options` or `--slurm-file`, but not both' + assert not slurm_additional_cmds, 'You cannot specify `--slurm-additional-cmds` together with `--slurm-file`' + + from .process_scheduler import ProcessScheduler + + enable_terminal_reporter = True + + # List of all invoke options except slurm options + ## reconstruct complete invoke string + quoted_invoke_params = [] + for arg in config.invocation_params.args: + if ' ' in arg and not '--slurm-options' in arg: + quoted_invoke_params.append("'"+arg+"'") + else: + quoted_invoke_params.append(arg) + main_invoke_params = ' '.join(quoted_invoke_params) + ## pull apart `--slurm-options` for special treatement + main_invoke_params = main_invoke_params.replace(f'--slurm-options={slurm_options}', '') + for file_or_dir in config.option.file_or_dir: + main_invoke_params = main_invoke_params.replace(file_or_dir, '') + slurm_option_list = slurm_options.split() if slurm_options is not None else [] + slurm_conf = { + 'options' : slurm_option_list, + 'additional_cmds': slurm_additional_cmds, + 'file' : slurm_file, + 'export_env' : slurm_export_env, + 'sub_command' : slurm_sub_command, + } + plugin = ProcessScheduler(main_invoke_params, n_workers, slurm_conf, detach) - config.pluginmanager.register(plugin, "pytest_parallel") + else: + from mpi4py import MPI + from .mpi_reporter import SequentialScheduler, StaticScheduler, DynamicScheduler + from .process_worker import ProcessWorker + from .utils_mpi import spawn_master_process, should_enable_terminal_reporter + + global_comm = MPI.COMM_WORLD + enable_terminal_reporter = should_enable_terminal_reporter(global_comm, scheduler) + + if scheduler == 'sequential': + plugin = SequentialScheduler(global_comm) + elif scheduler == 'static': + plugin = StaticScheduler(global_comm) + elif scheduler == 'dynamic': + inter_comm = spawn_master_process(global_comm) + plugin = DynamicScheduler(global_comm, inter_comm) + elif scheduler == 'slurm' and slurm_worker: + scheduler_ip_address = config.getoption('_scheduler_ip_address') + scheduler_port = config.getoption('_scheduler_port') + test_idx = config.getoption('_test_idx') + plugin = ProcessWorker(scheduler_ip_address, scheduler_port, test_idx, detach) + else: + assert 0 + + config.pluginmanager.register(plugin, 'pytest_parallel') # only report to terminal if master process - if not is_master_process(global_comm, scheduler): - terminal_reporter = config.pluginmanager.getplugin("terminalreporter") + if not enable_terminal_reporter: + terminal_reporter = config.pluginmanager.getplugin('terminalreporter') config.pluginmanager.unregister(terminal_reporter) # -------------------------------------------------------------------------- @pytest.fixture def comm(request): - """ - Only return a previous MPI Communicator (build at prepare step ) - """ - return request.node.sub_comm # TODO clean + ''' + Returns the MPI Communicator created by pytest_parallel + ''' + return request.node.sub_comm # 
-------------------------------------------------------------------------- class CollectiveTemporaryDirectory: - """ + ''' Context manager creating a tmp dir in parallel and removing it at the exit - """ + ''' def __init__(self, comm): self.comm = comm @@ -80,8 +176,8 @@ def __exit__(self, type, value, traceback): @pytest.fixture def mpi_tmpdir(comm): - """ + ''' This function ensure that one process handles the naming of temporary folders. - """ + ''' with CollectiveTemporaryDirectory(comm) as tmpdir: yield tmpdir diff --git a/pytest_parallel/process_scheduler.py b/pytest_parallel/process_scheduler.py new file mode 100644 index 0000000..edfe0ee --- /dev/null +++ b/pytest_parallel/process_scheduler.py @@ -0,0 +1,241 @@ +import pytest +import subprocess +import socket +import pickle +from pathlib import Path +from . import socket_utils +from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index +from .algo import partition + + +def mark_skip(item, slurm_ntasks): + n_proc_test = get_n_proc_for_test(item) + skip_msg = f"Not enough procs to execute: {n_proc_test} required but only {slurm_ntasks} available" + item.add_marker(pytest.mark.skip(reason=skip_msg), append=False) + item.marker_mpi_skip = True + +def replace_sub_strings(s, subs, replacement): + res = s + for sub in subs: + res = res.replace(sub,replacement) + return res + +def remove_exotic_chars(s): + return replace_sub_strings(str(s), ['[',']','/', ':'], '_') + +def parse_job_id_from_submission_output(s): + # At this point, we are trying to guess -_- + # Here we supposed that the command for submitting the job + # returned string with only one number, + # and that this number is the job id + import re + return int(re.search(r'\d+', str(s)).group()) + + +# https://stackoverflow.com/a/34177358 +def command_exists(cmd_name): + """Check whether `name` is on PATH and marked as executable.""" + from shutil import which + return which(cmd_name) is not None + +def _get_my_ip_address(): + hostname = socket.gethostname() + + assert command_exists('tracepath'), 'pytest_parallel SLURM scheduler: command `tracepath` is not available' + cmd = ['tracepath','-4','-n',hostname] + r = subprocess.run(cmd, stdout=subprocess.PIPE) + assert r.returncode==0, f'pytest_parallel SLURM scheduler: error running command `{" ".join(cmd)}`' + ips = r.stdout.decode("utf-8") + + try: + my_ip = ips.split('\n')[0].split(':')[1].split()[0] + except: + assert 0, f'pytest_parallel SLURM scheduler: error parsing result `{ips}` of command `{" ".join(cmd)}`' + import ipaddress + try: + ipaddress.ip_address(my_ip) + except ValueError: + assert 0, f'pytest_parallel SLURM scheduler: error parsing result `{ips}` of command `{" ".join(cmd)}`' + + return my_ip + + +def submit_items(items_to_run, socket, main_invoke_params, slurm_ntasks, slurm_conf): + # Find IP our address + SCHEDULER_IP_ADDRESS = _get_my_ip_address() + + # setup master's socket + socket.bind((SCHEDULER_IP_ADDRESS, 0)) # 0: let the OS choose an available port + socket.listen() + port = socket.getsockname()[1] + + # generate SLURM header options + if slurm_conf['file'] is not None: + with open(slurm_conf['file']) as f: + slurm_header = f.read() + # Note: + # slurm_ntasks is supposed to be <= to the number of the ntasks submitted to slurm + # but since the header file can be arbitrary, we have no way to check at this point + else: + slurm_header = '#!/bin/bash\n' + slurm_header += '\n' + slurm_header += '#SBATCH --job-name=pytest_parallel\n' + slurm_header += '#SBATCH 
--output=pytest_slurm/slurm.%j.out\n' + slurm_header += '#SBATCH --error=pytest_slurm/slurm.%j.err\n' + for opt in slurm_conf['options']: + slurm_header += f'#SBATCH {opt}\n' + slurm_header += f'#SBATCH --ntasks={slurm_ntasks}' + + # sort item by comm size to launch bigger first (Note: in case SLURM prioritize first-received items) + items = sorted(items_to_run, key=lambda item: item.n_proc, reverse=True) + + # launch srun for each item + worker_flags=f"--_worker --_scheduler_ip_address={SCHEDULER_IP_ADDRESS} --_scheduler_port={port}" + cmds = '' + if slurm_conf['additional_cmds'] is not None: + cmds += slurm_conf['additional_cmds'] + '\n' + for item in items: + test_idx = item.original_index + test_out_file_base = f'pytest_slurm/{remove_exotic_chars(item.nodeid)}' + cmd = f'srun --exclusive --ntasks={item.n_proc} -l' + cmd += f' python3 -u -m pytest -s {worker_flags} {main_invoke_params} --_test_idx={test_idx} {item.config.rootpath}/{item.nodeid}' + cmd += f' > {test_out_file_base} 2>&1' + cmd += ' &\n' # launch everything in parallel + cmds += cmd + cmds += 'wait\n' + + pytest_slurm = f'{slurm_header}\n\n{cmds}' + Path('pytest_slurm').mkdir(exist_ok=True) + with open('pytest_slurm/job.sh','w') as f: + f.write(pytest_slurm) + + # submit SLURM job + with open('pytest_slurm/env_vars.sh','wb') as f: + f.write(pytest._pytest_parallel_env_vars) + + if slurm_conf['sub_command'] is None: + if slurm_conf['export_env']: + sbatch_cmd = 'sbatch --parsable --export-file=pytest_slurm/env_vars.sh pytest_slurm/job.sh' + else: + sbatch_cmd = 'sbatch --parsable pytest_slurm/job.sh' + else: + sbatch_cmd = slurm_conf['sub_command'] + ' pytest_slurm/job.sh' + + p = subprocess.Popen([sbatch_cmd], shell=True, stdout=subprocess.PIPE) + print('\nSubmitting tests to SLURM...') + returncode = p.wait() + assert returncode==0, f'Error when submitting to SLURM with `{sbatch_cmd}`' + + if slurm_conf['sub_command'] is None: + slurm_job_id = int(p.stdout.read()) + else: + slurm_job_id = parse_job_id_from_submission_output(p.stdout.read()) + + print(f'SLURM job {slurm_job_id} has been submitted') + return slurm_job_id + +def receive_items(items, session, socket, n_item_to_recv): + while n_item_to_recv>0: + conn, addr = socket.accept() + with conn: + msg = socket_utils.recv(conn) + test_info = pickle.loads(msg) # the worker is supposed to have send a dict with the correct structured information + test_idx = test_info['test_idx'] + if test_info['fatal_error'] is not None: + assert 0, f'{test_info["fatal_error"]}' + item = items[test_idx] + item.sub_comm = None + item.info = test_info + + # "run" the test (i.e. trigger PyTest pipeline but do not really run the code) + nextitem = None # not known at this point + run_item_test(items[test_idx], nextitem, session) + n_item_to_recv -= 1 + +class ProcessScheduler: + def __init__(self, main_invoke_params, slurm_ntasks, slurm_conf, detach): + self.main_invoke_params = main_invoke_params + self.slurm_ntasks = slurm_ntasks + self.slurm_conf = slurm_conf + self.detach = detach + + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TODO close at the end + self.slurm_job_id = None + + @pytest.hookimpl(tryfirst=True) + def pytest_pyfunc_call(self, pyfuncitem): + # This is where the test is normally run. + # Since the scheduler process only collects the reports, it needs to *not* run anything. 
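+        # The real outcome/longrepr/duration are restored later, in `pytest_runtest_logreport`,
+        # from the `item.info` dict received from the worker (see `receive_items` above).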
+        if not (hasattr(pyfuncitem, "marker_mpi_skip") and pyfuncitem.marker_mpi_skip):
+            return True # for this hook, `firstresult=True` so returning a non-None value will stop other hooks from running
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_runtestloop(self, session) -> bool:
+        # same beginning as PyTest's default
+        if (
+            session.testsfailed
+            and not session.config.option.continue_on_collection_errors
+        ):
+            raise session.Interrupted(
+                "%d error%s during collection"
+                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+            )
+
+        if session.config.option.collectonly:
+            return True
+
+        # mark original position
+        mark_original_index(session.items)
+        ## add proc to items
+        add_n_procs(session.items)
+
+        # isolate skips
+        has_enough_procs = lambda item: item.n_proc <= self.slurm_ntasks
+        items_to_run, items_to_skip = partition(session.items, has_enough_procs)
+
+        # run skipped tests
+        for i, item in enumerate(items_to_skip):
+            item.sub_comm = None
+            mark_skip(item, self.slurm_ntasks)
+            nextitem = items_to_skip[i + 1] if i + 1 < len(items_to_skip) else None
+            run_item_test(item, nextitem, session)
+
+        # schedule tests to run
+        n_item_to_receive = len(items_to_run)
+        if n_item_to_receive > 0:
+            self.slurm_job_id = submit_items(items_to_run, self.socket, self.main_invoke_params, self.slurm_ntasks, self.slurm_conf)
+            if not self.detach: # The job steps are supposed to send their reports
+                receive_items(session.items, session, self.socket, n_item_to_receive)
+
+        return True
+
+    @pytest.hookimpl()
+    def pytest_keyboard_interrupt(self, excinfo):
+        if self.slurm_job_id is not None:
+            print(f'Calling `scancel {self.slurm_job_id}`')
+            subprocess.run(['scancel',str(self.slurm_job_id)])
+
+    @pytest.hookimpl(hookwrapper=True)
+    def pytest_runtest_makereport(self, item):
+        """
+        Need to hook to pass the worker's test info to `pytest_runtest_logreport`,
+        and for that we add it to the only argument of `pytest_runtest_logreport`, that is, `report`.
+        Also, if the test was skipped, mark the report accordingly.
+        """
+        result = yield
+        report = result.get_result()
+        if hasattr(item, "marker_mpi_skip") and item.marker_mpi_skip:
+            report.mpi_skip = True
+        else:
+            report.info = item.info
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_runtest_logreport(self, report):
+        if hasattr(report, "mpi_skip") and report.mpi_skip:
+            pass
+        else:
+            assert report.when in ("setup", "call", "teardown") # only known tags
+
+            report.outcome  = report.info[report.when]['outcome']
+            report.longrepr = report.info[report.when]['longrepr']
+            report.duration = report.info[report.when]['duration']
diff --git a/pytest_parallel/process_worker.py b/pytest_parallel/process_worker.py
new file mode 100644
index 0000000..a3c7f22
--- /dev/null
+++ b/pytest_parallel/process_worker.py
@@ -0,0 +1,69 @@
+import pytest
+
+from mpi4py import MPI
+
+import socket
+import pickle
+from . 
import socket_utils +from .utils import get_n_proc_for_test, run_item_test +from .gather_report import gather_report_on_local_rank_0 + +class ProcessWorker: + def __init__(self, scheduler_ip_address, scheduler_port, test_idx, detach): + self.scheduler_ip_address = scheduler_ip_address + self.scheduler_port = scheduler_port + self.test_idx = test_idx + self.detach = detach + + @pytest.hookimpl(tryfirst=True) + def pytest_runtestloop(self, session) -> bool: + comm = MPI.COMM_WORLD + assert len(session.items) == 1, f'INTERNAL FATAL ERROR in pytest_parallel with slurm scheduling: should only have one test per worker, but got {len(session.items)}' + item = session.items[0] + test_comm_size = get_n_proc_for_test(item) + + item.sub_comm = comm + item.test_info = {'test_idx': self.test_idx, 'fatal_error': None} + + + if comm.Get_size() != test_comm_size: # fatal error, SLURM and MPI do not interoperate correctly + error_info = f'FATAL ERROR in pytest_parallel with slurm scheduling: test `{item.nodeid}`' \ + f' uses a `comm` of size {test_comm_size} but was launched with size {comm.Get_size()}.\n' \ + f' This generally indicates that `srun` does not interoperate correctly with MPI.' + + item.test_info['fatal_error'] = error_info + else: # normal case: the test can be run + nextitem = None + run_item_test(item, nextitem, session) + + if not self.detach and comm.Get_rank() == 0: # not detached: proc 0 is expected to send results to scheduling process + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect((self.scheduler_ip_address, self.scheduler_port)) + socket_utils.send(s, pickle.dumps(item.test_info)) + + if item.test_info['fatal_error'] is not None: + assert 0, f'{item.test_info["fatal_error"]}' + + return True + + @pytest.hookimpl(hookwrapper=True) + def pytest_runtest_makereport(self, item): + """ + We need to hook to pass the test sub-comm to `pytest_runtest_logreport`, + and for that we add the sub-comm to the only argument of `pytest_runtest_logreport`, that is, `report` + We also need to pass `item.test_info` so that we can update it + """ + result = yield + report = result.get_result() + report.sub_comm = item.sub_comm + report.test_info = item.test_info + + @pytest.hookimpl(tryfirst=True) + def pytest_runtest_logreport(self, report): + assert report.when in ("setup", "call", "teardown") # only known tags + gather_report_on_local_rank_0(report) + report.test_info.update({report.when: {'outcome' : report.outcome, + 'longrepr': report.longrepr, + 'duration': report.duration, }}) + + diff --git a/pytest_parallel/socket_utils.py b/pytest_parallel/socket_utils.py new file mode 100644 index 0000000..f434270 --- /dev/null +++ b/pytest_parallel/socket_utils.py @@ -0,0 +1,29 @@ +def send(sock, msg_bytes): + msg_len = len(msg_bytes) + sent = sock.send(msg_len.to_bytes(8,'big')) # send int64 big endian + if sent == 0: + raise RuntimeError('Socket send broken: could not send message size') + + totalsent = 0 + while totalsent < msg_len: + sent = sock.send(msg_bytes[totalsent:]) + if sent == 0: + raise RuntimeError('Socket send broken: could not send message') + totalsent = totalsent + sent + +def recv(sock): + msg_len_bytes = sock.recv(8) + if msg_len_bytes == b'': + raise RuntimeError('Socket recv broken: message has no size') + msg_len = int.from_bytes(msg_len_bytes, 'big') + + chunks = [] + bytes_recv = 0 + while bytes_recv < msg_len: + chunk = sock.recv(min(msg_len-bytes_recv, 4096)) + if chunk == b'': + raise RuntimeError('Socket recv broken: could not receive message') + 
chunks.append(chunk) + bytes_recv += len(chunk) + msg_bytes = b''.join(chunks) + return msg_bytes diff --git a/pytest_parallel/utils.py b/pytest_parallel/utils.py index bcdf971..0869e8d 100644 --- a/pytest_parallel/utils.py +++ b/pytest_parallel/utils.py @@ -2,9 +2,6 @@ import pytest from _pytest.nodes import Item -from mpi4py import MPI - - def get_n_proc_for_test(item: Item) -> int : if not hasattr(item, 'callspec'): return 1 # no callspec, so no `comm` => sequential test case try: @@ -18,47 +15,14 @@ def add_n_procs(items): item.n_proc = get_n_proc_for_test(item) -def mark_skip(item): - comm = MPI.COMM_WORLD - n_rank = comm.Get_size() - n_proc_test = get_n_proc_for_test(item) - skip_msg = f"Not enough procs to execute: {n_proc_test} required but only {n_rank} available" - item.add_marker(pytest.mark.skip(reason=skip_msg), append=False) - item.marker_mpi_skip = True - - -def is_dyn_master_process(comm): - parent_comm = comm.Get_parent() - if parent_comm == MPI.COMM_NULL: - return False - return True - - -def is_master_process(comm, scheduler): - if scheduler == "dynamic": - return is_dyn_master_process(comm) - return comm.Get_rank() == 0 - - -def spawn_master_process(global_comm): - if not is_dyn_master_process(global_comm): - error_codes = [] - if sys.argv[0].endswith(".py"): - inter_comm = global_comm.Spawn( - sys.executable, args=sys.argv, maxprocs=1, errcodes=error_codes - ) - else: - inter_comm = global_comm.Spawn( - sys.argv[0], args=sys.argv[1:], maxprocs=1, errcodes=error_codes - ) - for error_code in error_codes: - if error_code != 0: - assert 0 - return inter_comm - return global_comm.Get_parent() +def run_item_test(item, nextitem, session): + item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem) + if session.shouldfail: + raise session.Failed(session.shouldfail) + if session.shouldstop: + raise session.Interrupted(session.shouldstop) -def number_of_working_processes(comm): - if is_dyn_master_process(comm): - return comm.Get_remote_size() - return comm.Get_size() +def mark_original_index(items): + for i, item in enumerate(items): + item.original_index = i diff --git a/pytest_parallel/utils_mpi.py b/pytest_parallel/utils_mpi.py new file mode 100644 index 0000000..3466235 --- /dev/null +++ b/pytest_parallel/utils_mpi.py @@ -0,0 +1,40 @@ +import sys +from mpi4py import MPI + + +def is_dyn_master_process(comm): + parent_comm = comm.Get_parent() + if parent_comm == MPI.COMM_NULL: + return False + return True + + +def should_enable_terminal_reporter(comm, scheduler): + if scheduler == "dynamic": + return is_dyn_master_process(comm) + else: + return comm.Get_rank() == 0 + + +def spawn_master_process(global_comm): + if not is_dyn_master_process(global_comm): + error_codes = [] + if sys.argv[0].endswith(".py"): + inter_comm = global_comm.Spawn( + sys.executable, args=sys.argv, maxprocs=1, errcodes=error_codes + ) + else: + inter_comm = global_comm.Spawn( + sys.argv[0], args=sys.argv[1:], maxprocs=1, errcodes=error_codes + ) + for error_code in error_codes: + if error_code != 0: + assert 0 + return inter_comm + return global_comm.Get_parent() + + +def number_of_working_processes(comm): + if is_dyn_master_process(comm): + return comm.Get_remote_size() + return comm.Get_size() diff --git a/source.sh b/source.sh new file mode 100644 index 0000000..3a4ddea --- /dev/null +++ b/source.sh @@ -0,0 +1,19 @@ +module purge +#module use --append /scratchm/sonics/opt_el8/modules/linux-centos7-broadwell +#module use --append /scratchm/sonics/opt_el8/modules/linux-rhel8-broadwell 
+#module use --append /scratchm/sonics/usr/modules/ +#module load intel-oneapi-mpi-2021.6.0-gcc-8.3.1-hjimxhi +module load impi/.23.2.0 +#source /scratchm/sonics/prod/spacky_2022-09-23_el8/external/spack/var/spack/environments/maia_gcc_intel-oneapi_2023-05/loads +#module load intel/23.2.0 +#module load gcc/12.1.0 +#module load intel/.mkl-2021.2.0 +#export I_MPI_CC=$CC +#export I_MPI_CXX=$CXX +#export I_MPI_FC=$FC +#export I_MPI_F90=$FC +##unset I_MPI_PMI_LIBRARY +#echo "CC=$CC" +#echo "CXX=$CXX" +#echo "FC=$FC" +#echo "MPI library: intel-oneapi" diff --git a/test/pytest_parallel_refs/terminal_fixture_error b/test/pytest_parallel_refs/terminal_fixture_error index 5b71322..1b01ba5 100644 --- a/test/pytest_parallel_refs/terminal_fixture_error +++ b/test/pytest_parallel_refs/terminal_fixture_error @@ -1,10 +1,12 @@ [=]+ test session starts [=]+ platform [^\n]* cachedir: [^\n]* -?(metadata: [^\n]*)? +?(?:metadata: [^\n]*)? rootdir: [^\n]* -?(plugins: [^\n]*)? +?(?:plugins: [^\n]*)? collecting ... [\s]*collected 1 item[\s]* +?(?:Submitting tests to SLURM...)? +?(?:SLURM job [^\n]* has been submitted)? [^\n]*test_fixture_error.py::test_fixture_error ERROR diff --git a/test/pytest_parallel_refs/terminal_parametrize b/test/pytest_parallel_refs/terminal_parametrize index 4b5742d..be1d3f9 100644 --- a/test/pytest_parallel_refs/terminal_parametrize +++ b/test/pytest_parallel_refs/terminal_parametrize @@ -1,10 +1,12 @@ [=]+ test session starts [=]+ platform [^\n]* cachedir: [^\n]* -?(metadata: [^\n]*)? +?(?:metadata: [^\n]*)? rootdir: [^\n]* -?(plugins: [^\n]*)? +?(?:plugins: [^\n]*)? collecting ... [\s]*collected 8 items[\s]* +?(?:Submitting tests to SLURM...)? +?(?:SLURM job [^\n]* has been submitted)? (?# [This is a comment until the closing parenthesis] The following lines use the regex 'lookahead' mechanism to match the lines in any order diff --git a/test/pytest_parallel_refs/terminal_scheduling b/test/pytest_parallel_refs/terminal_scheduling index 373cf54..dbedd2a 100644 --- a/test/pytest_parallel_refs/terminal_scheduling +++ b/test/pytest_parallel_refs/terminal_scheduling @@ -1,10 +1,12 @@ [=]+ test session starts [=]+ platform [^\n]* cachedir: [^\n]* -?(metadata: [^\n]*)? +?(?:metadata: [^\n]*)? rootdir: [^\n]* -?(plugins: [^\n]*)? +?(?:plugins: [^\n]*)? collecting ... [\s]*collected 8 items[\s]* +?(?:Submitting tests to SLURM...)? +?(?:SLURM job [^\n]* has been submitted)? (?# [This is a comment until the closing parenthesis] The following lines use the regex 'lookahead' mechanism to match the lines in any order diff --git a/test/pytest_parallel_refs/terminal_seq b/test/pytest_parallel_refs/terminal_seq index 5006863..b4c7b85 100644 --- a/test/pytest_parallel_refs/terminal_seq +++ b/test/pytest_parallel_refs/terminal_seq @@ -1,10 +1,12 @@ [=]+ test session starts [=]+ platform [^\n]* cachedir: [^\n]* -?(metadata: [^\n]*)? +?(?:metadata: [^\n]*)? rootdir: [^\n]* -?(plugins: [^\n]*)? +?(?:plugins: [^\n]*)? collecting ... [\s]*collected 2 items[\s]* +?(?:Submitting tests to SLURM...)? +?(?:SLURM job [^\n]* has been submitted)? 
 
 (?# [This is a comment until the closing parenthesis]
 The following lines use the regex 'lookahead' mechanism to match the lines in any order
diff --git a/test/pytest_parallel_refs/terminal_success_0_fail_1 b/test/pytest_parallel_refs/terminal_success_0_fail_1
index 0ba37d6..cdd0d6c 100644
--- a/test/pytest_parallel_refs/terminal_success_0_fail_1
+++ b/test/pytest_parallel_refs/terminal_success_0_fail_1
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 1 item[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 [^\n]*test_success_0_fail_1.py::test_fail_one_rank\[2\] FAILED
 
diff --git a/test/pytest_parallel_refs/terminal_two_fail_tests_one_proc b/test/pytest_parallel_refs/terminal_two_fail_tests_one_proc
index 48346b2..52deb14 100644
--- a/test/pytest_parallel_refs/terminal_two_fail_tests_one_proc
+++ b/test/pytest_parallel_refs/terminal_two_fail_tests_one_proc
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 2 items[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 (?# [This is a comment until the closing parenthesis]
 The following lines use the regex 'lookahead' mechanism to match the lines in any order
diff --git a/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs b/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs
index 52c8eb0..ab491f3 100644
--- a/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs
+++ b/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 2 items[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 (?# [This is a comment until the closing parenthesis]
 The following lines use the regex 'lookahead' mechanism to match the lines in any order
diff --git a/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs_skip b/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs_skip
index 515d53c..15ce2fe 100644
--- a/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs_skip
+++ b/test/pytest_parallel_refs/terminal_two_fail_tests_two_procs_skip
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 2 items[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 [^\n]*test_two_fail_tests_two_procs.py::test_fail_1\[2\] SKIPPED[^\n]*
 [^\n]*test_two_fail_tests_two_procs.py::test_fail_2\[2\] SKIPPED[^\n]*
diff --git a/test/pytest_parallel_refs/terminal_two_success_fail_tests_two_procs b/test/pytest_parallel_refs/terminal_two_success_fail_tests_two_procs
index 2b4449f..88e6d91 100644
--- a/test/pytest_parallel_refs/terminal_two_success_fail_tests_two_procs
+++ b/test/pytest_parallel_refs/terminal_two_success_fail_tests_two_procs
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 2 items[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 (?# [This is a comment until the closing parenthesis]
 The following lines use the regex 'lookahead' mechanism to match the lines in any order
diff --git a/test/pytest_parallel_refs/terminal_two_success_tests_one_proc b/test/pytest_parallel_refs/terminal_two_success_tests_one_proc
index 62f2947..949b2c3 100644
--- a/test/pytest_parallel_refs/terminal_two_success_tests_one_proc
+++ b/test/pytest_parallel_refs/terminal_two_success_tests_one_proc
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 2 items[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 (?# [This is a comment until the closing parenthesis]
 The following lines use the regex 'lookahead' mechanism to match the lines in any order
diff --git a/test/pytest_parallel_refs/terminal_two_success_tests_two_procs b/test/pytest_parallel_refs/terminal_two_success_tests_two_procs
index 947e968..42360f7 100644
--- a/test/pytest_parallel_refs/terminal_two_success_tests_two_procs
+++ b/test/pytest_parallel_refs/terminal_two_success_tests_two_procs
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 2 items[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 (?# [This is a comment until the closing parenthesis]
 The following lines use the regex 'lookahead' mechanism to match the lines in any order
diff --git a/test/pytest_parallel_refs/terminal_two_success_tests_two_procs_skip b/test/pytest_parallel_refs/terminal_two_success_tests_two_procs_skip
index 3e5bd1c..2980d68 100644
--- a/test/pytest_parallel_refs/terminal_two_success_tests_two_procs_skip
+++ b/test/pytest_parallel_refs/terminal_two_success_tests_two_procs_skip
@@ -1,10 +1,12 @@
 [=]+ test session starts [=]+
 platform [^\n]*
 cachedir: [^\n]*
-?(metadata: [^\n]*)?
+?(?:metadata: [^\n]*)?
 rootdir: [^\n]*
-?(plugins: [^\n]*)?
+?(?:plugins: [^\n]*)?
 collecting ... [\s]*collected 2 items[\s]*
+?(?:Submitting tests to SLURM...)?
+?(?:SLURM job [^\n]* has been submitted)?
 
 [^\n]*test_two_success_tests_two_procs.py::test_success_1\[2\] SKIPPED[^\n]*
 [^\n]*test_two_success_tests_two_procs.py::test_success_2\[2\] SKIPPED[^\n]*
diff --git a/test/test_pytest_parallel.py b/test/test_pytest_parallel.py
index c12b5c3..e9e69a1 100644
--- a/test/test_pytest_parallel.py
+++ b/test/test_pytest_parallel.py
@@ -2,25 +2,23 @@
 Test that pytest_parallel gives the correct outputs
 by running it on a set of examples,
 then comparing it to template references
-
-We run the checks with pytest
-But, since we are in the process of testing pytest_parallel,
-the testing framework (this file!)
-MUST DISABLE pytest_parallel when we run its tests
-(but of course its tests will in turn run tests with pytest_parallel enabled)
 """
+
+
+# pytest_parallel MUST NOT be plugged into its testing framework environment:
+# it will be plugged in by the framework when needed (see `run_pytest_parallel_test`)
+# (else we would use pytest_parallel to test pytest_parallel, which is logically wrong)
 import os
-import sys
+pytest_plugins = os.getenv('PYTEST_PLUGINS')
+assert pytest_plugins is None or 'pytest_parallel.plugin' not in pytest_plugins
+import sys
 import re
 import subprocess
 from pathlib import Path
 
 import pytest
 
-# def test_env_ok(pytestconfig):
-#     assert not pytestconfig.pluginmanager.hasplugin('pytest_parallel')
-
-
 root_dir = Path(__file__).parent
 tests_dir = root_dir / "pytest_parallel_tests"
 refs_dir = root_dir / "pytest_parallel_refs"
@@ -54,11 +52,11 @@ def run_pytest_parallel_test(test_name, n_workers, scheduler, capfd, suffix=""):
     stderr_file_path.unlink(missing_ok=True)
 
     test_env = os.environ.copy()
-    # According to Gentoo people it is a good practise for CI
-    # To disable autoload and enforce explicit plugin loading
+    # To test pytest_parallel, we need to launch pytest with the plugin loaded explicitly
     if "PYTEST_DISABLE_PLUGIN_AUTOLOAD" not in test_env:
         test_env["PYTEST_DISABLE_PLUGIN_AUTOLOAD"] = "1"
     cmd = f"mpiexec -n {n_workers} pytest -p pytest_parallel.plugin -s -ra -vv --color=no --scheduler={scheduler} {test_file_path}"
+    #cmd = f"pytest -p pytest_parallel.plugin -s -ra -vv --color=no --scheduler={scheduler} --slurm_options='--time=00:30:00 --qos=co_short_std --ntasks={n_workers}' {test_file_path}"
     subprocess.run(cmd, shell=True, text=True, env=test_env)
     captured = capfd.readouterr()
     with open(output_file_path, "w", encoding="utf-8", newline="\n") as f:
@@ -69,16 +67,15 @@
     assert ref_match(output_file_name)
 
 
-param_scheduler = (
-    ["sequential", "static", "dynamic"]
-    if sys.platform != "win32"
-    else ["sequential", "static"]
-)
+param_scheduler = ["sequential", "static", "dynamic"]
+# TODO "slurm" scheduler
+#param_scheduler = ["slurm"]
+if sys.platform == "win32":
+    param_scheduler = ["sequential", "static"]
 
 # fmt: off
 @pytest.mark.parametrize("scheduler", param_scheduler)
 class TestPytestParallel:
-
     def test_00(self, scheduler, capfd): run_pytest_parallel_test('seq' , 1, scheduler, capfd)
     def test_01(self, scheduler, capfd): run_pytest_parallel_test('two_success_tests_one_proc' , 1, scheduler, capfd) # need at least 1 proc
@@ -110,14 +107,14 @@ def test_19(self, scheduler, capfd): run_pytest_parallel_test('scheduling'
 # fmt: on
 
 ## If one test fail, it may be useful to debug regex matching along the following lines
-# test = 'two_fail_tests_one_proc'
-# file = 'terminal_'+test
+#test = 'two_fail_tests_one_proc'
+#file = 'terminal_'+test
 #
-# template_path = refs_dir/file
-# with open(template_path, 'r') as f:
-#     ref_regex = f.read()
-# output_path = output_dir/file
-# with open(output_path, 'r') as f:
-#     result = f.read()
+#template_path = refs_dir/file
+#with open(template_path, 'r') as f:
+#    ref_regex = f.read()
+#output_path = output_dir/file
+#with open(output_path, 'r') as f:
+#    result = f.read()
 #
-# print(re.findall(ref_regex, result, flags=re.DOTALL))
+#print(re.findall(ref_regex, result, flags=re.DOTALL))
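
A note on the `pytest_parallel/utils.py` split above: `run_item_test` wraps pytest's `pytest_runtest_protocol` hook and turns the session's stop conditions into exceptions, while `mark_original_index` records collection order so results can be re-sorted after tests run out of order. The sketch below shows how a scheduler's `pytest_runtestloop` might drive these helpers; it is an illustrative assumption about their intended use, not the plugin's actual loop.

    # Hypothetical scheduler loop (sketch only -- not pytest_parallel's real code)
    from pytest_parallel.utils import add_n_procs, mark_original_index, run_item_test

    def pytest_runtestloop(session):
        items = session.items
        mark_original_index(items)  # remember collection order for later reporting
        add_n_procs(items)          # attach item.n_proc from the `comm` fixture param
        for i, item in enumerate(items):
            nextitem = items[i + 1] if i + 1 < len(items) else None
            run_item_test(item, nextitem, session)  # raises Failed/Interrupted on -x/--maxfail
        return True  # tell pytest we ran the loop ourselves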
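
The `spawn_master_process` helper moved into `utils_mpi.py` relies on MPI dynamic process management: the worker processes collectively `Spawn` one extra "master" process (re-running the same command line), and the two sides then talk over the returned inter-communicator; `Get_parent()` is how the spawned side recognizes itself. Here is a self-contained toy of that pattern (the file name `spawn_demo.py` is made up for the example, and the `errcodes` handling is omitted for brevity):

    # spawn_demo.py -- run with e.g.: mpiexec -n 2 python spawn_demo.py
    import sys
    from mpi4py import MPI

    parent = MPI.Comm.Get_parent()  # == MPI.COMM_NULL unless this process was spawned

    if parent == MPI.COMM_NULL:
        # worker side: collectively spawn 1 master that re-runs this very script
        inter_comm = MPI.COMM_WORLD.Spawn(sys.executable, args=[__file__], maxprocs=1)
        msg = inter_comm.recv(source=0)  # source = rank 0 of the *remote* (master) group
        print(f'worker {MPI.COMM_WORLD.Get_rank()} received: {msg}')
    else:
        # master side: `parent` is the inter-communicator back to the spawning workers
        for rank in range(parent.Get_remote_size()):  # cf. number_of_working_processes
            parent.send('hello from the spawned master', dest=rank)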
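
The `terminal_*` reference files above are regex templates matched against captured pytest output. The `(...)?` to `(?:...)?` churn converts optional capturing groups into non-capturing ones, which matters because the debug snippet uses `re.findall`: with capturing groups it returns group tuples rather than whole matches. The two new optional lines let the same templates also accept the SLURM submission banner. A toy illustration of both mechanisms (optional lines, and lookaheads for order-independent matching), with made-up output strings:

    import re

    # Toy version of the terminal_* templates: an optional, non-capturing line for
    # the SLURM banner, then lookaheads so the result lines may come in any order.
    template = (r"[=]+ test session starts [=]+\n"
                r"(?:Submitting tests to SLURM\.\.\.\n)?"   # line may be absent
                r"(?=.*test_a PASSED)(?=.*test_b PASSED)")  # both present, any order

    out_direct = "=== test session starts ===\ntest_b PASSED\ntest_a PASSED\n"
    out_slurm  = ("=== test session starts ===\nSubmitting tests to SLURM...\n"
                  "test_a PASSED\ntest_b PASSED\n")
    assert re.match(template, out_direct, flags=re.DOTALL)
    assert re.match(template, out_slurm,  flags=re.DOTALL)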
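
Finally, on the harness change in `test_pytest_parallel.py`: with `PYTEST_DISABLE_PLUGIN_AUTOLOAD=1` nothing is loaded through entry points, so `pytest_parallel` runs only in the inner pytest processes where it is requested with `-p pytest_parallel.plugin` (and the outer run additionally asserts it is absent from `PYTEST_PLUGINS`). The snippet below reproduces one inner invocation by hand, mirroring the `cmd` built in `run_pytest_parallel_test`; the test-file path is illustrative:

    import os
    import subprocess

    env = os.environ.copy()
    env["PYTEST_DISABLE_PLUGIN_AUTOLOAD"] = "1"  # no entry-point auto-loading...
    cmd = ("mpiexec -n 2 pytest -p pytest_parallel.plugin "  # ...so load it explicitly
           "-s -ra -vv --color=no --scheduler=static "
           "test/pytest_parallel_tests/test_seq.py")         # illustrative path
    subprocess.run(cmd, shell=True, text=True, check=True, env=env)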