Skip to content

Commit

Permalink
In a determination where N equals 100%, due to float type precision i…
Browse files Browse the repository at this point in the history
…ssues, situations such as `100 != 100.000002` occurred, leading to unexpected conditional branching. Therefore, the condition was changed to "having only one key and that key being `N`".
  • Loading branch information
akikuno committed Nov 10, 2023
1 parent cc482a2 commit ec33d36
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 51 deletions.
97 changes: 56 additions & 41 deletions src/DAJIN2/core/consensus/consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,51 @@
###########################################################


def convert_to_percentage(cssplits: list[list[str]], mutation_loci: list[set[str]]) -> list[dict[str, float]]:
"""
Convert sequences and mutations into percentages, annotating sequence errors as "SEQERROR".
"""
# Transpose the cssplits to facilitate per-position processing
cssplits_transposed = [list(cs) for cs in zip(*cssplits)]
coverage = len(cssplits_transposed[0])

cons_percentage = []
for cs_transposed, mut_loci in zip(cssplits_transposed, mutation_loci):
count_cs = defaultdict(float)
for cs in cs_transposed:
# Annotate as "SEQERROR" if the condition is met
if cs[0] in {"+", "-", "*"} and cs[0] not in mut_loci:
cs = "SEQERROR"
count_cs[cs] += 1 / coverage * 100
cons_percentage.append(dict(count_cs))

return cons_percentage


def remove_all_n(cons_percentage: list[dict[str, float]]) -> list[dict[str, float]]:
for c in cons_percentage:
if c == {"N": 100}:
if len(c) == 1 and "N" in c:
continue
_ = c.pop("N", None)
c.pop("N", None)

return cons_percentage


def replace_sequence_errror(cons_percentage: list[dict[str, float]]) -> list[dict[str, float]]:
"""replace sequence error as distributing according to proportion of cs tags"""
cons_percentage_update = []
def replace_sequence_error(cons_percentage: list[dict[str, float]]) -> list[dict[str, float]]:
"""
Replace sequence error as distributing according to proportion of cs tags
If a dictionary contains only "SEQERROR", it is replaced with {"N": 100}. Otherwise, "SEQERROR" is removed.
"""
cons_percentage_replaced = []
for cons_per in cons_percentage:
if "SEQERROR" not in cons_per:
cons_percentage_update.append(cons_per)
continue
if len(cons_per) == 1 and cons_per["SEQERROR"]:
cons_percentage_update.append({"N": 100})
# Replace a dictionary containing only "SEQERROR" with {"N": 100}
if len(cons_per) == 1 and "SEQERROR" in cons_per:
cons_percentage_replaced.append({"N": 100})
continue
cons_per_update = dict()
div = 100 / (sum(cons_per.values()) - cons_per["SEQERROR"])
for key, val in cons_per.items():
if key == "SEQERROR":
continue
cons_per_update[key] = val * div
cons_percentage_update.append(cons_per_update)
return cons_percentage_update
cons_per.pop("SEQERROR", None)
cons_percentage_replaced.append(cons_per)

return cons_percentage_replaced


def adjust_to_100_percent(cons_percentage: list[dict[str, float]]) -> list[dict[str, float]]:
Expand All @@ -55,23 +74,13 @@ def adjust_to_100_percent(cons_percentage: list[dict[str, float]]) -> list[dict[
return adjusted_percentages


def call_percentage(cssplits: list[str], mutation_loci) -> list[dict[str, float]]:
def call_percentage(cssplits: list[list[str]], mutation_loci: list[set[str]]) -> list[dict[str, float]]:
"""call position weight matrix in defferent loci.
- non defferent loci are annotated to "Match" or "Unknown(N)"
- sequence errors are annotated to "SEQERROR"
"""
cssplits_transposed = [list(cs) for cs in zip(*cssplits)]
coverage = len(cssplits)
cons_percentage = []
for cs_transposed, mut_loci in zip(cssplits_transposed, mutation_loci):
count_cs = defaultdict(float)
for cs in cs_transposed:
if cs[0] in {"+", "-", "*"} and cs[0] not in mut_loci:
cs = "SEQERROR"
count_cs[cs] += 1 / coverage * 100
cons_percentage.append(dict(count_cs))
cons_percentage = convert_to_percentage(cssplits, mutation_loci)
cons_percentage = remove_all_n(cons_percentage)
cons_percentage = replace_sequence_errror(cons_percentage)
cons_percentage = replace_sequence_error(cons_percentage)
return adjust_to_100_percent(cons_percentage)


Expand All @@ -80,14 +89,14 @@ def call_percentage(cssplits: list[str], mutation_loci) -> list[dict[str, float]
###########################################################


def _process_base(cons: str) -> str:
def cstag_to_base(cons: str) -> str:
if cons.startswith("="): # match
return cons.replace("=", "")
elif cons.startswith("-"): # deletion
if cons.startswith("-"): # deletion
return ""
elif cons.startswith("*"): # substitution
if cons.startswith("*"): # substitution
return cons[-1]
elif cons.startswith("+"): # insertion
if cons.startswith("+"): # insertion
cons_ins = cons.split("|")
if cons_ins[-1].startswith("="): # match after insertion
cons = cons.replace("=", "")
Expand All @@ -99,13 +108,13 @@ def _process_base(cons: str) -> str:
return cons


def _call_sequence(cons_percentage: list[dict[str, float]]) -> str:
def call_sequence(cons_percentage: list[dict[str, float]]) -> str:
consensus_sequence = []
n_left, n_right = find_n_boundaries(cons_percentage)
for i, cons_per in enumerate(cons_percentage):
if n_left < i < n_right:
cons = max(cons_per, key=cons_per.get)
consensus_sequence.append(_process_base(cons))
consensus_sequence.append(cstag_to_base(cons))
else:
consensus_sequence.append("N")
return "".join(consensus_sequence)
Expand All @@ -127,15 +136,21 @@ def call_consensus(
) -> tuple[defaultdict[list], defaultdict[str]]:
cons_percentages = defaultdict(list)
cons_sequences = defaultdict(str)
mutation_loci_cache = dict()
clust_sample.sort(key=lambda x: x["LABEL"])
for label, group in groupby(clust_sample, key=lambda x: x["LABEL"]):
clust = list(group)
allele = clust[0]["ALLELE"]
key = ConsensusKey(allele, label, clust[0]["PERCENT"])
cssplits = [cs["CSSPLIT"].split(",") for cs in clust]
with open(Path(TEMPDIR, SAMPLE_NAME, "mutation_loci", f"{allele}.pickle"), "rb") as p:
mutation_loci = pickle.load(p)

if allele not in mutation_loci_cache:
with open(Path(TEMPDIR, SAMPLE_NAME, "mutation_loci", f"{allele}.pickle"), "rb") as p:
mutation_loci_cache[allele] = pickle.load(p)
mutation_loci = mutation_loci_cache[allele]

cons_percentage = call_percentage(cssplits, mutation_loci)

key = ConsensusKey(allele, label, clust[0]["PERCENT"])
cons_percentages[key] = cons_percentage
cons_sequences[key] = _call_sequence(cons_percentage)
cons_sequences[key] = call_sequence(cons_percentage)
return cons_percentages, cons_sequences
33 changes: 23 additions & 10 deletions tests/src/consensus/test_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from src.DAJIN2.core.consensus.consensus import (
# _remove_nonconsecutive_n,
replace_sequence_errror,
replace_sequence_error,
adjust_to_100_percent,
call_percentage,
_process_base,
_call_sequence,
cstag_to_base,
call_sequence,
)


Expand Down Expand Up @@ -34,10 +35,22 @@
# assert result == expected_output


def test_replace_sequence_errror():
def test_replace_sequence_error():
cons_percentage = [{"A": 25, "C": 25, "SEQERROR": 50}, {"SEQERROR": 100}]
expected_output = [{"A": 50, "C": 50}, {"N": 100}]
assert replace_sequence_errror(cons_percentage) == expected_output
expected_output = [{"A": 25, "C": 25}, {"N": 100}]
assert replace_sequence_error(cons_percentage) == expected_output


def test_adjust_to_100_percent():
test = [{"A": 25, "C": 25}, {"N": 100}]
expected = [{"A": 50, "C": 50}, {"N": 100}]
assert adjust_to_100_percent(test) == expected


def test_adjust_to_100_percent_float():
test = [{"A": 20.1, "C": 19.9}]
expected = [{"A": 50.25, "C": 49.75}]
assert adjust_to_100_percent(test) == expected


def test_call_percentage():
Expand All @@ -52,7 +65,7 @@ def test_call_percentage():
###########################################################


# Example test cases for _process_base function
# Example test cases for cstag_to_base function
@pytest.mark.parametrize(
"cons, expected_output",
[
Expand All @@ -68,8 +81,8 @@ def test_call_percentage():
("", ""),
],
)
def test_process_base(cons, expected_output):
result = _process_base(cons)
def test_cstag_to_base(cons, expected_output):
result = cstag_to_base(cons)
assert result == expected_output


Expand All @@ -88,5 +101,5 @@ def test_process_base(cons, expected_output):
],
)
def test_call_sequence(cons_percentage_by_key, expected_sequence):
result_sequence = _call_sequence(cons_percentage_by_key)
result_sequence = call_sequence(cons_percentage_by_key)
assert result_sequence == expected_sequence

0 comments on commit ec33d36

Please sign in to comment.