diff --git a/pytest_parallel/plugin.py b/pytest_parallel/plugin.py index fcbf86d..d0383a7 100644 --- a/pytest_parallel/plugin.py +++ b/pytest_parallel/plugin.py @@ -24,6 +24,7 @@ def pytest_addoption(parser): parser.addoption('--timeout', dest='timeout', type=int, default=7200, help='Timeout') parser.addoption('--slurm-options', dest='slurm_options', type=str, help='list of SLURM options e.g. "--time=00:30:00 --qos=my_queue --n_tasks=4"') + parser.addoption('--slurm-srun-options', dest='slurm_srun_options', type=str, help='list of SLURM srun options e.g. "--mem-per-cpu=4GB"') parser.addoption('--slurm-additional-cmds', dest='slurm_additional_cmds', type=str, help='list of commands to pass to SLURM job e.g. "source my_env.sh"') parser.addoption('--slurm-file', dest='slurm_file', type=str, help='Path to file containing header of SLURM job') # TODO DEL parser.addoption('--slurm-sub-command', dest='slurm_sub_command', type=str, help='SLURM submission command (defaults to `sbatch`)') # TODO DEL @@ -78,6 +79,7 @@ def pytest_configure(config): scheduler = config.getoption('scheduler') n_workers = config.getoption('n_workers') slurm_options = config.getoption('slurm_options') + slurm_srun_options = config.getoption('slurm_srun_options') slurm_additional_cmds = config.getoption('slurm_additional_cmds') is_worker = config.getoption('_worker') slurm_file = config.getoption('slurm_file') @@ -90,6 +92,7 @@ def pytest_configure(config): assert n_workers, f'You need to specify `--n-workers` when `--scheduler={scheduler}`' if scheduler != 'slurm': assert not slurm_options, 'Option `--slurm-options` only available when `--scheduler=slurm`' + assert not slurm_srun_options, 'Option `--slurms-run-options` only available when `--scheduler=slurm`' assert not slurm_additional_cmds, 'Option `--slurm-additional-cmds` only available when `--scheduler=slurm`' assert not slurm_file, 'Option `--slurm-file` only available when `--scheduler=slurm`' @@ -116,9 +119,10 @@ def pytest_configure(config): slurm_option_list = slurm_options.split() if slurm_options is not None else [] slurm_conf = { 'options' : slurm_option_list, + 'srun_options' : slurm_srun_options, 'additional_cmds': slurm_additional_cmds, 'file' : slurm_file, - 'export_env' : slurm_export_env, + 'export_env' : slurm_export_env, 'sub_command' : slurm_sub_command, } plugin = ProcessScheduler(main_invoke_params, n_workers, slurm_conf, detach) diff --git a/pytest_parallel/process_scheduler.py b/pytest_parallel/process_scheduler.py index f18ff2c..9e7c47c 100644 --- a/pytest_parallel/process_scheduler.py +++ b/pytest_parallel/process_scheduler.py @@ -90,6 +90,8 @@ def submit_items(items_to_run, socket, main_invoke_params, ntasks, slurm_conf): items = sorted(items_to_run, key=lambda item: item.n_proc, reverse=True) # launch srun for each item + if srun_options := slurm_conf['srun_options'] is None: + srun_options = '' worker_flags=f"--_worker --_scheduler_ip_address={SCHEDULER_IP_ADDRESS} --_scheduler_port={port}" cmds = '' if slurm_conf['additional_cmds'] is not None: @@ -97,7 +99,7 @@ def submit_items(items_to_run, socket, main_invoke_params, ntasks, slurm_conf): for item in items: test_idx = item.original_index test_out_file_base = f'.pytest_parallel/{remove_exotic_chars(item.nodeid)}' - cmd = f'srun --exclusive --ntasks={item.n_proc} -l' + cmd = f'srun {srun_options} --exclusive --ntasks={item.n_proc} -l' cmd += f' python3 -u -m pytest -s {worker_flags} {main_invoke_params} --_test_idx={test_idx} {item.config.rootpath}/{item.nodeid}' cmd += f' > {test_out_file_base} 2>&1' cmd += ' &\n' # launch everything in parallel