Skip to content

Commit

Permalink
added more prancSTR tests for argparsing
Browse files Browse the repository at this point in the history
  • Loading branch information
gymreklab committed Nov 20, 2023
1 parent 60d6659 commit a89196a
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 67 deletions.
93 changes: 29 additions & 64 deletions trtools/prancSTR/prancSTR.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import trtools.utils.tr_harmonizer as trh
from trtools import __version__

READFIELD = "MALLREADS"
ZERO = 10e-200
MAXSTUTTEROFFSET = 200

Expand Down Expand Up @@ -107,7 +106,8 @@ def MaximizeMosaicLikelihoodBoth(reads, A, B,
f = Just_F_Pred(reads, A, B, C, stutter_probs)
iter_num += 1
if iter_num > maxiter:
if not quiet: common.WARNING("ML didn't converge reads=%s A=%s B=%s %s" %
if not quiet:
common.WARNING("ML didn't converge reads=%s A=%s B=%s %s" %
(str(reads), A, B, locname))
break
if abs(f-f_prev) < 0.01 and (f < 0.000001 or C == c_prev):
Expand Down Expand Up @@ -397,13 +397,14 @@ def getargs(): # pragma: no cover
"--out", help=("Output file prefix. Use stdout to print file to standard output"), type=str, required=True)
inout_group.add_argument("--vcftype", help="Options=%s" %
[str(item) for item in trh.VcfTypes.__members__], type=str, default="auto")
inout_group.add_argument("--samples", help="Comma-separated list of samples to process", type=str)
inout_group.add_argument("--samples", help="Comma-separated list of samples to process."
" Note samples not in the VCF are ignored.", type=str)
filter_group = parser.add_argument_group("Filtering group")
filter_group.add_argument("--region", help="Restrict to the region "
"chrom:start-end. Requires file to bgzipped and"
" tabix indexed.", type=str)
# filter_group.add_argument("--readfield", help="Select the field to extract reads from"
# " Options are between MALLREADS and ALLREADS.", type=str)
filter_group.add_argument("--readfield", help="Select the field to extract reads from"
" Options are between MALLREADS and ALLREADS.", type=str, default="MALLREADS")
filter_group.add_argument("--only-passing", help="Only process records "
" where FILTER==PASS", action="store_true")
filter_group.add_argument("--output-all", help="Force output results for all loci", action="store_true")
Expand Down Expand Up @@ -432,6 +433,9 @@ def main(args):
common.WARNING("Error: The output location {} is a "
"directory".format(args.out))
return 1
if args.readfield not in ["ALLREADS","MALLREADS"]:
common.WARNING("Error: args.readfield must be either ALLREADS or MALLREADS")
return 1

checkgz = args.region is not None
invcf = utils.LoadSingleReader(args.vcf, checkgz=checkgz)
Expand All @@ -458,7 +462,6 @@ def main(args):

start_time = time.time()
nrecords = 0
# READFIELD=args.readfield

if args.out == "stdout":
outf = sys.stdout
Expand All @@ -478,16 +481,16 @@ def main(args):
trrecord = trh.HarmonizeRecord(vcftype, record)

if args.only_passing and not args.output_all and (record.FILTER is not None):
common.WARNING("Skipping non-passing record %s" %
str(trrecord))
common.WARNING("Skipping record %s with non-passing VCF FILTER field." %
str(trrecord))
continue

########### Extract necessary info from the VCF file #######
# Stutter params for the locus. These are the same for all samples
# First check we have all the fields we need
if READFIELD not in trrecord.format.keys():
common.WARNING("Could not find MALLREADS for %s" %
str(trrecord))
if args.readfield not in trrecord.format.keys():
common.WARNING("Could not find read field %s for %s" %
(args.readfield, str(trrecord)))
continue
if "INFRAME_UP" not in trrecord.info.keys() or \
"INFRAME_DOWN" not in trrecord.info.keys() or \
Expand All @@ -500,52 +503,15 @@ def main(args):
stutter_d = 0.05
stutter_rho = 0.90
else:
outf = open(args.out + ".tab", "w")

# Header
header_cols = ["sample", "chrom", "pos", "locus", "motif",
"A", "B", "C", "f", "pval", "reads",
"mosaic_support", "stutter parameter u",
"stutter paramter d", "stutter paramter rho",
"quality factor", "read depth"]
outf.write("\t".join(header_cols)+"\n")

for record in region:
nrecords += 1
trrecord = trh.HarmonizeRecord(vcftype, record)

if args.only_passing and not args.output_all and (record.FILTER is not None):
common.WARNING("Skipping record %s with non-passing VCF FILTER field." %
str(trrecord))
continue

########### Extract necessary info from the VCF file #######
# Stutter params for the locus. These are the same for all samples
# First check we have all the fields we need
if READFIELD not in trrecord.format.keys():
common.WARNING("Could not find MALLREADS for %s" %
str(trrecord))
continue
if "INFRAME_UP" not in trrecord.info.keys() or \
"INFRAME_DOWN" not in trrecord.info.keys() or \
"INFRAME_PGEOM" not in trrecord.info.keys():
common.WARNING(
"Could not find stutter info for %s" % str(trrecord))
common.WARNING(
"Adding default stutter info for %s" % str(trrecord))
stutter_u = 0.05
stutter_d = 0.05
stutter_rho = 0.90
else:
stutter_u = trrecord.info["INFRAME_UP"]
stutter_d = trrecord.info["INFRAME_DOWN"]
stutter_rho = trrecord.info["INFRAME_PGEOM"]
if stutter_u == 0.0:
stutter_u = 0.01
if stutter_d == 0.0:
stutter_d = 0.01
if stutter_rho == 1.0:
stutter_rho = 0.95
stutter_u = trrecord.info["INFRAME_UP"]
stutter_d = trrecord.info["INFRAME_DOWN"]
stutter_rho = trrecord.info["INFRAME_PGEOM"]
if stutter_u == 0.0:
stutter_u = 0.01
if stutter_d == 0.0:
stutter_d = 0.01
if stutter_rho == 1.0:
stutter_rho = 0.95
stutter_probs = [StutterProb(d, stutter_u, stutter_d, stutter_rho) \
for d in range(-MAXSTUTTEROFFSET, MAXSTUTTEROFFSET)]
period = len(trrecord.motif)
Expand All @@ -558,7 +524,7 @@ def main(args):
# Array of "reads" vectors for each sample
# given in repeat units diff from ref
mallreads = [ExtractReadVector(item, period)
for item in trrecord.format[READFIELD]]
for item in trrecord.format[args.readfield]]

# Extracting quality parameter
Q = trrecord.format['Q']
Expand Down Expand Up @@ -600,21 +566,20 @@ def main(args):
str(record.ID), trrecord.motif, str(
A), str(B),
str(best_C), str(best_f), str(pval),
trrecord.format[READFIELD][i],
trrecord.format[args.readfield][i],
str(reads.count(best_C)),
str(stutter_u), str(
stutter_d), str(stutter_rho),
str(q), str(dp)]) + '\n')
if args.debug:
common.WARNING("Inferred best_C=%s best_f=%s" %
(best_C, best_f))
#############################################################
if args.out == "stdout" and nrecords % 50 == 0:
#############################################################
if nrecords % 50 == 0 and not args.quiet:
common.MSG("Finished {} records, time/record={:.5}sec".format(nrecords,
(time.time() - start_time)/nrecords), debug=True)
(time.time() - start_time)/nrecords), debug=True)

if args.out == "stdout":
common.MSG("Performed analysis on {} records".format(nrecords), debug=True)
if not args.quiet: common.MSG("Performed analysis on {} records".format(nrecords), debug=True)

if outf is not None and args.out != "stdout":
outf.close()
Expand Down
91 changes: 89 additions & 2 deletions trtools/prancSTR/tests/test_prancSTR.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ def args(tmpdir):
args.debug = False
args.vcftype = "hipstr"
args.samples = None
args.quiet = False
args.quiet = True
args.output_all = False
args.readfield = "MALLREADS"
return args

# Test no such file or directory
Expand All @@ -29,13 +30,99 @@ def test_WrongFile(args, vcfdir):
retcode = main(args)
assert retcode==1

# Test the right file or directory
# real path but not VCF
fname = os.path.join(vcfdir, "CEU_test.vcf.gz.tbi")
args.vcf = fname
retcode = main(args)
assert retcode==1

# Test bad output directory
def test_BadOutdir(args, vcfdir, tmpdir):
fname = os.path.join(vcfdir, "test_hipstr.vcf")
args.vcf = fname
args.out = str(tmpdir / "bad/test")
retcode = main(args)
assert retcode==1

args.out = str(tmpdir)+os.sep
print(args.out)
retcode = main(args)
assert retcode==1

# Test a good VCF file
def test_RightFile(args, vcfdir):
fname = os.path.join(vcfdir, "test_hipstr.vcf")
args.vcftype = "auto"
args.vcf = fname
retcode = main(args)
assert retcode==0

args.quiet = False
retcode = main(args)
assert retcode==0

# Wrong VCF type
args.vcftype = "advntr"
retcode = main(args)
assert retcode==1

# Test a multi-sample VCF file
def test_MosaicCase(args, vcfdir):
fname = os.path.join(vcfdir, "CEU_test.vcf.gz")
args.vcf = fname
args.quiet = True
retcode = main(args)
assert retcode==0

args.quiet = False
retcode = main(args)
assert retcode==0

# Specific region
# Note: cyvcf2 handles region parsing and will
# output a warning if no intervals found
args.quiet = True
args.region = "chr1:987287-987288"
retcode = main(args)
assert retcode==0

# Samples list
args.samples = "NA12878"
retcode = main(args)
assert retcode==0

# Bad samples list should just give no output
args.samples = "XYZ"
retcode = main(args)
assert retcode==0

# With only passing
args.samples = "NA12878"
args.only_passing = True
args.region = None
retcode = main(args)
assert retcode==0

# With only passing - debug mode
args.only_passing = True
args.region = None
args.debug = True
retcode = main(args)
assert retcode==0

# Write to stdout
args.samples = "NA12878"
args.out = "stdout"
retcode == main(args)
assert retcode==0

# With bad readfield
args.readfield = "badreadfield"
retcode = main(args)
assert retcode==1



# Test the probability of observing a certain repeat length
def test_StutterProb1():
delta = 0
Expand Down
Binary file added trtools/testsupport/sample_vcfs/CEU_test.vcf.gz
Binary file not shown.
Binary file added trtools/testsupport/sample_vcfs/CEU_test.vcf.gz.tbi
Binary file not shown.
7 changes: 6 additions & 1 deletion trtools/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def LoadSingleReader(
-------
reader : Optional[cyvcf2.VCF]
The cyvcf2.VCF instance, or None if the VCF is not present
or could not be opened
"""
# check that vcf_loc is a file or file descriptor (ex: '/dev/stdin')
if not os.path.exists(vcf_loc) or os.path.isdir(vcf_loc):
Expand All @@ -46,7 +47,11 @@ def LoadSingleReader(
if not os.path.isfile(vcf_loc+".tbi"):
common.WARNING("Could not find VCF index %s.tbi"%vcf_loc)
return None
return cyvcf2.VCF(vcf_loc)
try:
return cyvcf2.VCF(vcf_loc)
except OSError:
common.WARNING("Could not open VCF file %s. Is it really VCF?"%vcf_loc)
return None

def LoadReaders(
vcf_locs: List[str],
Expand Down

0 comments on commit a89196a

Please sign in to comment.