Skip to content

Commit

Permalink
Update consensus.call_consensus: For mutations determined to be seq…
Browse files Browse the repository at this point in the history
…uence errors, we previously replaced them with unknown (`N`), but this `N` had low interpretability. Therefore, mutations that DAJIN2 determines to be sequence errors will now be assigned the same base as the reference genome.
  • Loading branch information
akikuno committed Jul 12, 2024
1 parent ab571b7 commit 1f46215
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 41 deletions.
37 changes: 12 additions & 25 deletions src/DAJIN2/core/consensus/consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
###########################################################


def convert_to_percentage(cssplits: list[list[str]], mutation_loci: list[set[str]]) -> list[dict[str, float]]:
def convert_to_percentage(
cssplits: list[list[str]], mutation_loci: list[set[str]], sequence: str
) -> list[dict[str, float]]:
"""
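Convert sequences and mutations into percentages. A mutation whose operator is not listed in the mutation loci for its position is treated as a sequence error and counted as a match to the reference base (before this commit it was annotated as "SEQERROR").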
"""
Expand All @@ -22,13 +24,13 @@ def convert_to_percentage(cssplits: list[list[str]], mutation_loci: list[set[str
coverage = len(cssplits_transposed[0])

cons_percentage = []
for cs_transposed, mut_loci in zip(cssplits_transposed, mutation_loci):
for cs_transposed, mut_loci, nucleotide in zip(cssplits_transposed, mutation_loci, sequence):
count_cs = defaultdict(float)
for cs in cs_transposed:
operator = cs[0]
if operator in {"+", "-", "*"}:
if operator not in mut_loci:
cs = "SEQERROR"
cs = f"={nucleotide.upper()}"
count_cs[cs] += 1 / coverage * 100
cons_percentage.append(dict(count_cs))

Expand All @@ -44,23 +46,6 @@ def remove_all_n(cons_percentage: list[dict[str, float]]) -> list[dict[str, floa
return cons_percentage


def replace_sequence_error(cons_percentage: list[dict[str, float]]) -> list[dict[str, float]]:
    """
    Strip the "SEQERROR" tag from every position of a consensus percentage table.

    A position whose only tag is "SEQERROR" is replaced with {"N": 100}.
    For any other position the "SEQERROR" entry is dropped and the remaining
    tags keep their percentages unchanged (they are not rescaled here).

    The input dictionaries are left untouched; new dictionaries are returned.
    """
    cons_percentage_replaced = []
    for cons_per in cons_percentage:
        # A position made up solely of sequence errors carries no usable call.
        if len(cons_per) == 1 and "SEQERROR" in cons_per:
            cons_percentage_replaced.append({"N": 100})
            continue
        # Build a filtered copy rather than popping from the caller's dict.
        cons_percentage_replaced.append(
            {cs: percent for cs, percent in cons_per.items() if cs != "SEQERROR"}
        )

    return cons_percentage_replaced


def adjust_to_100_percent(cons_percentage: list[dict[str, float]]) -> list[dict[str, float]]:
adjusted_percentages = []

Expand All @@ -74,13 +59,12 @@ def adjust_to_100_percent(cons_percentage: list[dict[str, float]]) -> list[dict[
return adjusted_percentages


def call_percentage(cssplits: list[list[str]], mutation_loci: list[set[str]]) -> list[dict[str, float]]:
def call_percentage(cssplits: list[list[str]], mutation_loci: list[set[str]], sequence: str) -> list[dict[str, float]]:
"""Call the position weight matrix at loci that differ from the reference.
- Loci that do not differ are annotated as "Match" or "Unknown(N)"
"""
cons_percentage = convert_to_percentage(cssplits, mutation_loci)
cons_percentage = convert_to_percentage(cssplits, mutation_loci, sequence)
cons_percentage = remove_all_n(cons_percentage)
cons_percentage = replace_sequence_error(cons_percentage)
return adjust_to_100_percent(cons_percentage)


Expand All @@ -96,20 +80,23 @@ class ConsensusKey:
percent: float


def call_consensus(tempdir: Path, sample_name: str, clust_sample: list[dict]) -> tuple[dict[list], dict[str]]:
def call_consensus(
tempdir: Path, sample_name: str, fasta_alleles: dict[str, str], clust_sample: list[dict]
) -> tuple[dict[list], dict[str]]:
clust_sample.sort(key=lambda x: [x["ALLELE"], x["LABEL"]])

cons_percentages = {}
cons_sequences = {}

for (allele, label), group in groupby(clust_sample, key=lambda x: [x["ALLELE"], x["LABEL"]]):
clust = list(group)
sequence = fasta_alleles[allele]

path_consensus = Path(tempdir, sample_name, "consensus", allele, str(label))
cons_mutation_loci = io.load_pickle(Path(path_consensus, "mutation_loci.pickle"))

cssplits = [cs["CSSPLIT"].split(",") for cs in clust]
cons_percentage = call_percentage(cssplits, cons_mutation_loci)
cons_percentage = call_percentage(cssplits, cons_mutation_loci, sequence)

key = ConsensusKey(allele, label, clust[0]["PERCENT"])
cons_percentages[key] = cons_percentage
Expand Down
2 changes: 1 addition & 1 deletion src/DAJIN2/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def execute_sample(arguments: dict):

consensus.cache_mutation_loci(ARGS, clust_downsampled)

cons_percentage, cons_sequence = consensus.call_consensus(ARGS.tempdir, ARGS.sample_name, clust_downsampled)
cons_percentage, cons_sequence = consensus.call_consensus(ARGS.tempdir, ARGS.sample_name, ARGS.fasta_alleles, clust_downsampled)

allele_names = consensus.call_allele_name(cons_sequence, cons_percentage, ARGS.fasta_alleles)
cons_percentage = consensus.update_key_by_allele_name(cons_percentage, allele_names)
Expand Down
19 changes: 4 additions & 15 deletions tests/src/consensus/test_consensus.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
from src.DAJIN2.core.consensus.consensus import (
adjust_to_100_percent,
call_percentage,
replace_sequence_error,
)

###########################################################
# replace_sequence
###########################################################


def test_replace_sequence_error():
    """A mixed position loses its "SEQERROR" entry; a SEQERROR-only position becomes {"N": 100}."""
    observed = replace_sequence_error([{"A": 25, "C": 25, "SEQERROR": 50}, {"SEQERROR": 100}])
    assert observed == [{"A": 25, "C": 25}, {"N": 100}]


###########################################################
# adjust_to_100_percent
###########################################################
Expand All @@ -38,13 +26,14 @@ def test_adjust_to_100_percent_float():


def test_call_percentage():
cssplits = [["+A|=C", "-T", "=C", "=A", "=T"], ["-C", "=T", "=C", "*AT", "*AT"]]
cssplits = [["+A|=C", "-T", "=C", "=A", "=T"], ["-C", "=T", "=C", "*AT", "*TA"]]
mutation_loci = [{"+", "-"}, {"-"}, {}, {}, {"*"}]
sequence = "CTCAT"
expected_output = [
{"+A|=C": 50.0, "-C": 50.0},
{"-T": 50.0, "=T": 50.0},
{"=C": 100.0},
{"=A": 100.0},
{"=T": 50.0, "*AT": 50.0},
{"=T": 50.0, "*TA": 50.0},
]
assert call_percentage(cssplits, mutation_loci) == expected_output
assert call_percentage(cssplits, mutation_loci, sequence) == expected_output

0 comments on commit 1f46215

Please sign in to comment.