Shell scheduler: seems to be working
BerengerBerthoul committed Sep 4, 2024
1 parent 83df48b commit 7d58ed6
Showing 3 changed files with 63 additions and 85 deletions.
32 changes: 4 additions & 28 deletions pytest_parallel/mpi_reporter.py
@@ -6,6 +6,7 @@
from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
from .utils_mpi import number_of_working_processes, is_dyn_master_process
from .gather_report import gather_report_on_local_rank_0
from .static_scheduler_utils import group_items_by_parallel_steps


def mark_skip(item):
@@ -134,31 +135,6 @@ def pytest_runtest_logreport(self, report):
return True # ranks that don't participate in the tests don't have to report anything


def group_items_by_parallel_steps(items, n_workers):
add_n_procs(items)
items.sort(key=lambda item: item.n_proc, reverse=True)

remaining_n_procs_by_step = []
items_by_step = []
items_to_skip = []
for item in items:
if item.n_proc > n_workers:
items_to_skip += [item]
else:
found_step = False
for idx, remaining_procs in enumerate(remaining_n_procs_by_step):
if item.n_proc <= remaining_procs:
items_by_step[idx] += [item]
remaining_n_procs_by_step[idx] -= item.n_proc
found_step = True
break
if not found_step:
items_by_step += [[item]]
remaining_n_procs_by_step += [n_workers - item.n_proc]

return items_by_step, items_to_skip


def prepare_items_to_run(items, comm):
i_rank = comm.Get_rank()

@@ -237,9 +213,9 @@ def pytest_runtestloop(self, session) -> bool:

n_workers = self.global_comm.Get_size()

items_by_steps, items_to_skip = group_items_by_parallel_steps(
session.items, n_workers
)
add_n_procs(session.items)

items_by_steps, items_to_skip = group_items_by_parallel_steps(session.items, n_workers)

items = items_to_run_on_this_proc(
items_by_steps, items_to_skip, self.global_comm
92 changes: 35 additions & 57 deletions pytest_parallel/shell_static_scheduler.py
@@ -8,7 +8,8 @@
from . import socket_utils
from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
from .algo import partition

from .static_scheduler_utils import group_items_by_parallel_steps
from mpi4py import MPI

def mark_skip(item, ntasks):
n_proc_test = get_n_proc_for_test(item)
@@ -61,73 +62,57 @@ def _get_my_ip_address():

return my_ip


def submit_items(items_to_run, socket, main_invoke_params, ntasks):
def setup_socket(socket):
# Find our IP address
SCHEDULER_IP_ADDRESS = _get_my_ip_address()

# setup master's socket
socket.bind((SCHEDULER_IP_ADDRESS, 0)) # 0: let the OS choose an available port
socket.listen()
port = socket.getsockname()[1]

## generate SLURM header options
#if slurm_conf['file'] is not None:
# with open(slurm_conf['file']) as f:
# slurm_header = f.read()
# # Note:
# # ntasks is supposed to be <= to the number of the ntasks submitted to slurm
# # but since the header file can be arbitrary, we have no way to check at this point
#else:
# slurm_header = '#!/bin/bash\n'
# slurm_header += '\n'
# slurm_header += '#SBATCH --job-name=pytest_parallel\n'
# slurm_header += '#SBATCH --output=.pytest_parallel/slurm.%j.out\n'
# slurm_header += '#SBATCH --error=.pytest_parallel/slurm.%j.err\n'
# for opt in slurm_conf['options']:
# slurm_header += f'#SBATCH {opt}\n'
# slurm_header += f'#SBATCH --ntasks={ntasks}'

return SCHEDULER_IP_ADDRESS, port

def mpi_command(current_proc, n_proc):
mpi_vendor = MPI.get_vendor()[0]
if mpi_vendor == 'Intel MPI':
cmd = f'I_MPI_PIN_PROCESSOR_LIST={current_proc}-{current_proc+n_proc-1}; '
cmd += f'mpiexec -np {n_proc}'
return cmd
elif mpi_vendor == 'Open MPI':
cores = ','.join([str(i) for i in range(current_proc,current_proc+n_proc)])
return f'mpiexec --cpu-list {cores} -np {n_proc}'
else:
assert 0, f'Unknown MPI implementation "{mpi_vendor}"'
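# Example (illustrative only, not part of the commit): for a test needing 2 processes
# when cores 0-3 are already taken, mpi_command(4, 2) builds
#   'I_MPI_PIN_PROCESSOR_LIST=4-5; mpiexec -np 2'   with Intel MPI
#   'mpiexec --cpu-list 4,5 -np 2'                  with Open MPI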

def submit_items(items_to_run, SCHEDULER_IP_ADDRESS, port, main_invoke_params, ntasks, i_step, n_step):
# sort items by comm size to launch the bigger ones first (Note: in case SLURM prioritizes first-received items)
items = sorted(items_to_run, key=lambda item: item.n_proc, reverse=True)

# launch srun for each item
# launch `mpiexec` for each item
worker_flags=f"--_worker --_scheduler_ip_address={SCHEDULER_IP_ADDRESS} --_scheduler_port={port}"
cmds = []
current_proc = 0
for item in items:
test_idx = item.original_index
test_out_file_base = f'.pytest_parallel/{remove_exotic_chars(item.nodeid)}'
cmd = f'mpiexec -np {item.n_proc}'
cmd = mpi_command(current_proc, item.n_proc)
cmd += f' python3 -u -m pytest -s {worker_flags} {main_invoke_params} --_test_idx={test_idx} {item.config.rootpath}/{item.nodeid}'
cmd += f' > {test_out_file_base}'
cmds.append(cmd)
current_proc += item.n_proc

script = " & \\\n".join(cmds)
Path('.pytest_parallel').mkdir(exist_ok=True)
script_path = '.pytest_parallel/pytest_static_sched.sh'
script_path = f'.pytest_parallel/pytest_static_sched_{i_step}.sh'
with open(script_path,'w') as f:
f.write(script)

## submit SLURM job
#with open('.pytest_parallel/env_vars.sh','wb') as f:
# f.write(pytest._pytest_parallel_env_vars)

current_permissions = stat.S_IMODE(os.lstat(script_path).st_mode)
os.chmod(script_path, current_permissions | stat.S_IXUSR)

p = subprocess.Popen([script_path], shell=True, stdout=subprocess.PIPE)
print('\nLaunching tests...')
#returncode = p.wait()
#assert returncode==0, f'Error launching tests with `{script_path}`'

#if slurm_conf['sub_command'] is None:
# slurm_job_id = int(p.stdout.read())
#else:
# slurm_job_id = parse_job_id_from_submission_output(p.stdout.read())

#print(f'SLURM job {slurm_job_id} has been submitted')
#return slurm_job_id
return 0
print(f'\nLaunching tests (step {i_step}/{n_step})...')
return p

def receive_items(items, session, socket, n_item_to_recv):
while n_item_to_recv>0:
@@ -183,12 +168,7 @@ def pytest_runtestloop(self, session) -> bool:
## add proc to items
add_n_procs(session.items)

# isolate skips
print(f"{self.ntasks=}")
for i in session.items:
print(f"{i.n_proc=}")
has_enough_procs = lambda item: item.n_proc <= self.ntasks
items_to_run, items_to_skip = partition(session.items, has_enough_procs)
items_by_steps, items_to_skip = group_items_by_parallel_steps(session.items, self.ntasks)

# run skipped
for i, item in enumerate(items_to_skip):
@@ -198,20 +178,18 @@ def pytest_runtestloop(self, session) -> bool:
run_item_test(item, nextitem, session)

# schedule tests to run
n_item_to_receive = len(items_to_run)
if n_item_to_receive > 0:
self.slurm_job_id = submit_items(items_to_run, self.socket, self.main_invoke_params, self.ntasks)
if not self.detach: # The job steps are supposed to send their reports
receive_items(session.items, session, self.socket, n_item_to_receive)
SCHEDULER_IP_ADDRESS,port = setup_socket(self.socket)
n_step = len(items_by_steps)
for i_step,items in enumerate(items_by_steps):
n_item_to_receive = len(items)
sub_process = submit_items(items, SCHEDULER_IP_ADDRESS, port, self.main_invoke_params, self.ntasks, i_step, n_step)
if not self.detach: # The job steps are supposed to send their reports
receive_items(session.items, session, self.socket, n_item_to_receive)
returncode = sub_process.wait() # at this point, the sub-process should be done since items have been received
assert returncode==0, f'Error during step {i_step} of shell scheduler'

return True

#@pytest.hookimpl()
#def pytest_keyboard_interrupt(excinfo):
# if excinfo.slurm_job_id is not None:
# print(f'Calling `scancel {excinfo.slurm_job_id}`')
# subprocess.run(['scancel',str(excinfo.slurm_job_id)])

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(self, item):
"""
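For reference, here is a sketch of the kind of step script `submit_items` writes to `.pytest_parallel/pytest_static_sched_0.sh` (illustrative only — the scheduler address, port, test indices and output file names below are made up, and any extra pytest options from the original invocation would appear after the scheduler flags). The per-item `mpiexec` commands are joined with ` & \` so that all tests of a step run concurrently, each redirecting its output to its own file; this sketch assumes Open MPI:

mpiexec --cpu-list 0,1 -np 2 python3 -u -m pytest -s --_worker --_scheduler_ip_address=192.168.1.10 --_scheduler_port=35001 --_test_idx=3 /path/to/repo/test_feature.py::test_two_procs > .pytest_parallel/test_feature_py_test_two_procs & \
mpiexec --cpu-list 2 -np 1 python3 -u -m pytest -s --_worker --_scheduler_ip_address=192.168.1.10 --_scheduler_port=35001 --_test_idx=7 /path/to/repo/test_feature.py::test_one_proc > .pytest_parallel/test_feature_py_test_one_proc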
24 changes: 24 additions & 0 deletions pytest_parallel/static_scheduler_utils.py
@@ -0,0 +1,24 @@
def group_items_by_parallel_steps(items, n_workers):
items.sort(key=lambda item: item.n_proc, reverse=True)

remaining_n_procs_by_step = []
items_by_step = []
items_to_skip = []
for item in items:
if item.n_proc > n_workers:
items_to_skip += [item]
else:
found_step = False
for idx, remaining_procs in enumerate(remaining_n_procs_by_step):
if item.n_proc <= remaining_procs:
items_by_step[idx] += [item]
remaining_n_procs_by_step[idx] -= item.n_proc
found_step = True
break
if not found_step:
items_by_step += [[item]]
remaining_n_procs_by_step += [n_workers - item.n_proc]

return items_by_step, items_to_skip
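
As an illustration of the packing behaviour, here is a minimal sketch (not part of the commit — the stand-in items only carry the `n_proc` attribute that `add_n_procs` sets on real pytest items):

from types import SimpleNamespace

from pytest_parallel.static_scheduler_utils import group_items_by_parallel_steps

# Stand-in "items": only the n_proc attribute is used by the grouping.
items = [SimpleNamespace(n_proc=n) for n in (1, 4, 2, 8, 2)]

items_by_step, items_to_skip = group_items_by_parallel_steps(items, n_workers=4)

# First-fit decreasing on 4 workers:
#   step 0: the 4-proc test
#   step 1: the two 2-proc tests
#   step 2: the 1-proc test
# The 8-proc test exceeds n_workers, so it lands in items_to_skip.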

