Skip to content

Commit

Permalink
Half of #154 fixed
Browse files Browse the repository at this point in the history
No longer using global simplenamespace to pass around arguments in trf.
GRM will need more testing to ensure BwaAligner is fine being initialized per-thread.
  • Loading branch information
ACEnglish committed May 12, 2023
1 parent 1d6e48d commit f9f3630
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -324,4 +324,4 @@ exclude-protected=_asdict,_fields,_replace,_source,_make

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
overgeneral-exceptions=builtins.Exception
45 changes: 20 additions & 25 deletions truvari/annotations/trf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,17 @@
import sys
import json
import math
import types
import shutil
import logging
import argparse
import functools
import multiprocessing
from io import StringIO
from functools import cmp_to_key

import pysam
import tabix
import truvari

trfshared = types.SimpleNamespace()

try:
from setproctitle import setproctitle # pylint: disable=import-error,useless-suppression
except ModuleNotFoundError:
Expand Down Expand Up @@ -50,7 +47,7 @@ def compare_scores(a, b):
elif aspan < bspan:
ret = -1
return ret
score_sorter = cmp_to_key(compare_scores)
score_sorter = functools.cmp_to_key(compare_scores)

class TRFAnno():
"""
Expand Down Expand Up @@ -344,7 +341,7 @@ def build_from_top(self):
self.tanno = TRFAnno(cur_anno, ref_seq, self.motif_sim)


def process_ref_region(region):
def process_ref_region(region, args):
"""
Process a section of the reference.
Tries to run TRF only once
Expand All @@ -355,7 +352,7 @@ def process_ref_region(region):
logging.debug(f"Starting region {region['chrom']}:{region['start']}-{region['end']}")
setproctitle(f"trf {region['chrom']}:{region['start']}-{region['end']}")

vcf = pysam.VariantFile(trfshared.args.input)
vcf = pysam.VariantFile(args.input)
new_header = edit_header(vcf.header)
out = StringIO()

Expand All @@ -366,10 +363,10 @@ def process_ref_region(region):
return None


m_stack = AnnoStack(list(iter_tr_regions(trfshared.args.repeats,
m_stack = AnnoStack(list(iter_tr_regions(args.repeats,
(region["chrom"], region["start"], region["end"]))),
pysam.FastaFile(trfshared.args.reference),
trfshared.args.motif_similarity)
pysam.FastaFile(args.reference),
args.motif_similarity)

batch = []
fa_fn = truvari.make_temp_filename(suffix=".fa")
Expand All @@ -389,23 +386,23 @@ def process_ref_region(region):

svtype = truvari.entry_variant_type(entry)
svlen = truvari.entry_size(entry)
if svlen < trfshared.args.min_length or svtype not in [truvari.SV.DEL, truvari.SV.INS]:
if svlen < args.min_length or svtype not in [truvari.SV.DEL, truvari.SV.INS]:
out.write(str(edit_entry(entry, None, new_header)))
continue

if svtype == truvari.SV.DEL:
m_anno = m_stack.tanno.del_annotate(entry, svlen)
out.write(str(edit_entry(entry, m_anno, new_header)))
elif svtype == truvari.SV.INS:
m_anno = m_stack.tanno.ins_estimate_anno(entry) if not trfshared.args.no_estimate else None
m_anno = m_stack.tanno.ins_estimate_anno(entry) if not args.no_estimate else None
if m_anno:
out.write(str(edit_entry(entry, m_anno, new_header)))
else:
batch.append((entry, m_stack.tanno))
fa_out.write(f">{len(batch) - 1}\n{m_stack.tanno.make_seq(entry, 'INS')}\n")

if batch:
annotations = run_trf(fa_fn, trfshared.args.executable, trfshared.args.trf_params)
annotations = run_trf(fa_fn, args.executable, args.trf_params)
for key, (entry, tanno) in enumerate(batch):
key = str(key)
m_anno = None
Expand All @@ -420,17 +417,17 @@ def process_ref_region(region):
logging.debug(f"Done region {region['chrom']}:{region['start']}-{region['end']}")
return out.read()

def process_tr_region(region):
def process_tr_region(region, args):
"""
Process vcf lines from a tr reference section
"""
logging.debug(f"Starting region {region['chrom']}:{region['start']}-{region['end']}")
setproctitle(f"trf {region['chrom']}:{region['start']}-{region['end']}")

ref = pysam.FastaFile(trfshared.args.reference)
ref = pysam.FastaFile(args.reference)
ref_seq = ref.fetch(region["chrom"], region["start"], region["end"])
tanno = TRFAnno(region, ref_seq, trfshared.args.motif_similarity, trfshared.args.buffer)
vcf = pysam.VariantFile(trfshared.args.input)
tanno = TRFAnno(region, ref_seq, args.motif_similarity, args.buffer)
vcf = pysam.VariantFile(args.input)
new_header = edit_header(vcf.header)
out = StringIO()

Expand All @@ -448,15 +445,15 @@ def process_tr_region(region):
continue
svtype = truvari.entry_variant_type(entry)
svlen = truvari.entry_size(entry)
if svlen < trfshared.args.min_length or svtype not in [truvari.SV.DEL, truvari.SV.INS]:
if svlen < args.min_length or svtype not in [truvari.SV.DEL, truvari.SV.INS]:
out.write(str(edit_entry(entry, None, new_header)))
continue

if svtype == truvari.SV.DEL:
m_anno = tanno.del_annotate(entry, svlen)
out.write(str(edit_entry(entry, m_anno, new_header)))
elif svtype == truvari.SV.INS:
m_anno = tanno.ins_estimate_anno(entry) if not trfshared.args.no_estimate else None
m_anno = tanno.ins_estimate_anno(entry) if not args.no_estimate else None
if m_anno:
out.write(str(edit_entry(entry, m_anno, new_header)))
else:
Expand All @@ -465,7 +462,7 @@ def process_tr_region(region):
fa_out.write(f">{len(batch) - 1}\n{seq}\n")

if batch:
annotations = run_trf(fa_fn, trfshared.args.executable, trfshared.args.trf_params)
annotations = run_trf(fa_fn, args.executable, args.trf_params)
for key, entry in enumerate(batch):
key = str(key)
m_anno = None
Expand Down Expand Up @@ -634,20 +631,18 @@ def trf_main(cmdargs):
""" TRF annotation """
args = parse_args(cmdargs)
check_params(args)
trfshared.args = args

m_regions = None
m_process = None
if args.regions_only:
m_regions = iter_tr_regions(args.repeats)
m_process = process_tr_region
m_process = functools.partial(process_tr_region, args=args)
else:
# refactor. need streaming mode
m_regions = truvari.ref_ranges(args.reference, chunk_size=int(args.chunk_size * 1e6))
m_process = process_ref_region

m_process = functools.partial(process_ref_region, args=args)

vcf = pysam.VariantFile(trfshared.args.input)
vcf = pysam.VariantFile(args.input)
new_header = edit_header(vcf.header)

with multiprocessing.Pool(args.threads, maxtasksperchild=1) as pool:
Expand Down
4 changes: 2 additions & 2 deletions truvari/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def make_region_report(data):
Given a refine counts DataFrame, calculate the performance of
PPV, TNR, etc. Also adds 'state' column to regions inplace
"""
false_pos = (data['out_fp'] != 0)
false_neg = (data['out_fn'] != 0)
false_pos = data['out_fp'] != 0
false_neg = data['out_fn'] != 0
any_false = false_pos | false_neg

true_positives = (data['out_tp'] != 0) & (data['out_tpbase'] != 0) & ~any_false
Expand Down

0 comments on commit f9f3630

Please sign in to comment.