Add a --pplacer-threads option to aviary.

The requested value is capped at both --max-threads and 48 in the processor.
When the parsed arguments carry no pplacer_threads attribute, the previous
min(max_threads, 48) value is used. The tests check the value that ends up
in the generated config.yaml.

diff --git a/aviary/aviary.py b/aviary/aviary.py
index aad581d..95beaf3 100755
--- a/aviary/aviary.py
+++ b/aviary/aviary.py
@@ -153,6 +153,13 @@ def main():
         default=8,
     )
 
+    base_group.add_argument(
+        '-p', '--pplacer-threads', '--pplacer_threads',
+        help='The number of threads given to pplacer, values above --max-threads or 48 will be scaled down',
+        dest='pplacer_threads',
+        default=48,
+    )
+
     base_group.add_argument(
         '-n', '--n-cores', '--n_cores',
         help='Maximum number of cores available for use. Setting to multiples of max_threads will allow for multiple processes to be run in parallel.',
diff --git a/aviary/modules/processor.py b/aviary/modules/processor.py
index 2f247bf..b6a1a18 100644
--- a/aviary/modules/processor.py
+++ b/aviary/modules/processor.py
@@ -99,10 +99,15 @@ def __init__(self,
         self.output = os.path.abspath(args.output)
         self.threads = args.max_threads
         self.max_memory = args.max_memory
-        self.pplacer_threads = min(int(self.threads), 48)
         self.workflows = args.workflow
         self.request_gpu = args.request_gpu
 
+        # Cap pplacer threads at max_threads and 48; args may lack pplacer_threads.
+        try:
+            self.pplacer_threads = min(int(args.pplacer_threads), int(self.threads), 48)
+        except AttributeError:
+            self.pplacer_threads = min(int(self.threads), 48)
+
         try:
             self.strain_analysis = args.strain_analysis
         except AttributeError:
diff --git a/test/test_recover.py b/test/test_recover.py
index d1775d0..5bd2aad 100644
--- a/test/test_recover.py
+++ b/test/test_recover.py
@@ -358,6 +358,7 @@ def test_recover_config(self):
                 f"SINGLEM_METAPACKAGE_PATH=. "
                 f"aviary recover "
                 f"--refinery-max-iterations 3 "
+                f"--max-threads 8 "
                 f"--assembly {ASSEMBLY} "
                 f"-1 {FORWARD_READS} "
                 f"-2 {REVERSE_READS} "
@@ -373,6 +374,62 @@ def test_recover_config(self):
             config = load_configfile(config_path)
 
             self.assertEqual(config["refinery_max_iterations"], 3)
+            self.assertEqual(config["pplacer_threads"], 8)
+
+    def test_recover_config_many_threads(self):
+        """With --max-threads 128 and no --pplacer-threads, the config holds 48."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cmd = (
+                f"GTDBTK_DATA_PATH=. "
+                f"CHECKM2DB=. "
+                f"EGGNOG_DATA_DIR=. "
+                f"SINGLEM_METAPACKAGE_PATH=. "
+                f"aviary recover "
+                f"--max-threads 128 "
+                f"--assembly {ASSEMBLY} "
+                f"-1 {FORWARD_READS} "
+                f"-2 {REVERSE_READS} "
+                f"--output {tmpdir}/test --tmpdir {tmpdir} "
+                f"--conda-prefix {path_to_conda} "
+                f"--dryrun "
+                f"--snakemake-cmds \" --quiet\" "
+            )
+            extern.run(cmd)
+
+            config_path = os.path.join(tmpdir, "test", "config.yaml")
+            self.assertTrue(os.path.exists(config_path))
+            config = load_configfile(config_path)
+
+            self.assertEqual(config["refinery_max_iterations"], 5)
+            self.assertEqual(config["pplacer_threads"], 48)
+
+    def test_recover_config_many_pplacer_threads(self):
+        """With --max-threads 128 and --pplacer-threads 32, the config holds 32."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cmd = (
+                f"GTDBTK_DATA_PATH=. "
+                f"CHECKM2DB=. "
+                f"EGGNOG_DATA_DIR=. "
+                f"SINGLEM_METAPACKAGE_PATH=. "
+                f"aviary recover "
+                f"--max-threads 128 "
+                f"--pplacer-threads 32 "
+                f"--assembly {ASSEMBLY} "
+                f"-1 {FORWARD_READS} "
+                f"-2 {REVERSE_READS} "
+                f"--output {tmpdir}/test --tmpdir {tmpdir} "
+                f"--conda-prefix {path_to_conda} "
+                f"--dryrun "
+                f"--snakemake-cmds \" --quiet\" "
+            )
+            extern.run(cmd)
+
+            config_path = os.path.join(tmpdir, "test", "config.yaml")
+            self.assertTrue(os.path.exists(config_path))
+            config = load_configfile(config_path)
+
+            self.assertEqual(config["refinery_max_iterations"], 5)
+            self.assertEqual(config["pplacer_threads"], 32)
 
 if __name__ == '__main__':
     unittest.main()