diff --git a/repo_utils/sub_tests/anno.sh b/repo_utils/sub_tests/anno.sh index c9806d66..961c3c4d 100644 --- a/repo_utils/sub_tests/anno.sh +++ b/repo_utils/sub_tests/anno.sh @@ -54,7 +54,7 @@ fi # grm run test_anno_grm \ - $truv anno grm -i $INDIR/variants/input2.vcf.gz -r $REF -o $OD/grm.jl + $truv anno grm -i $INDIR/variants/input2.vcf.gz -r $REF -o $OD/grm.jl -t 2 if [ $test_anno_grm ]; then assert_exit_code 0 df_check test_anno_grm $ANSDIR/anno/grm.jl $OD/grm.jl diff --git a/truvari/annotations/grm.py b/truvari/annotations/grm.py index 8b0581cc..37f6401a 100644 --- a/truvari/annotations/grm.py +++ b/truvari/annotations/grm.py @@ -7,6 +7,7 @@ import types import logging import argparse +import functools import multiprocessing from collections import namedtuple @@ -30,8 +31,6 @@ def setproctitle(_): """ dummy function """ return -# Data shared with workers; must be populated before workers are started. -grm_shared = types.SimpleNamespace() def make_kmers(ref, entry, kmer=25): @@ -213,13 +212,13 @@ def line_to_entry(fields): info_dict) -def read_vcf_lines(ref_name, start, stop): +def read_vcf_lines(in_fn, ref_name, start, stop): """ Faster VCF parsing """ logging.debug(f"Starting region {ref_name}:{start}-{stop}") - tb = tabix.open(grm_shared.input) + tb = tabix.open(in_fn) try: yield from tb.query(ref_name, start, stop) except tabix.TabixError as e: @@ -228,10 +227,11 @@ def read_vcf_lines(ref_name, start, stop): logging.debug(f"Done region {ref_name}:{start}-{stop}") -def process_entries(ref_section): +def process_entries(ref_section, grm_shared): """ Calculate GRMs for a set of vcf entries """ + grm_shared.aligner = BwaAligner(grm_shared.ref_filename, '-a') ref_name, start, stop = ref_section ref = pysam.FastaFile(grm_shared.ref_filename) aligner = grm_shared.aligner @@ -240,7 +240,7 @@ def process_entries(ref_section): minsize = grm_shared.min_size rows = [] next_progress = 0 - for line in read_vcf_lines(ref_name, start, stop): + for line in read_vcf_lines(grm_shared.input, ref_name, start, stop): if "SVLEN" not in line[7] and abs(len(line[3]) - len(line[4])) < minsize: continue entry = line_to_entry(line) @@ -305,21 +305,22 @@ def grm_main(cmdargs): else: m_ranges = truvari.bed_ranges(args.regions) - grm_shared.aligner = BwaAligner(args.reference, '-a') header = ["key"] for prefix in ["rup_", "rdn_", "aup_", "adn_"]: for key in ["nhits", "avg_q", "avg_ed", "avg_mat", "avg_mis", "dir_hits", "com_hits", "max_q", "max_ed", "max_mat", "max_mis", "max_strand", "min_q", "min_ed", "min_mat", "min_mis", "min_strand"]: header.append(prefix + key) + grm_shared = types.SimpleNamespace() grm_shared.header = header grm_shared.ref_filename = args.reference grm_shared.kmersize = args.kmersize grm_shared.input = args.input grm_shared.min_size = args.min_size + m_process_entries = functools.partial(process_entries, grm_shared=grm_shared) with multiprocessing.Pool(args.threads, maxtasksperchild=1) as pool: logging.info("Processing") - chunks = pool.imap(process_entries, m_ranges) + chunks = pool.imap(m_process_entries, m_ranges) pool.close() data = pd.concat(chunks, ignore_index=True) logging.info("Saving; df shape %s", data.shape)