
Commit

If the read length is 500 bases or less, change the mappy preset to `sr`.
akikuno committed May 22, 2024
1 parent d21827f commit 6e56804
Showing 4 changed files with 82 additions and 105 deletions.
18 changes: 14 additions & 4 deletions src/DAJIN2/core/preprocess/mapping.py
@@ -6,7 +6,7 @@
from pathlib import Path
from typing import Generator

from DAJIN2.utils.dna_handler import revcomp
from DAJIN2.utils import dna_handler, sam_handler


def to_sam(
@@ -55,7 +55,7 @@ def to_sam(

# Handle reverse complement for negative strand
if hit.strand == -1:
query_seq = revcomp(query_seq)
query_seq = dna_handler.revcomp(query_seq)
if query_qual:
query_qual = query_qual[::-1]

@@ -113,13 +113,23 @@ def generate_sam(

for path_fasta in paths_fasta:
name_fasta = Path(path_fasta).stem
for preset in ["map-ont", "splice"]:
len_sequence = len(Path(path_fasta).read_text().split("\n")[1])  # sequence is on the second line of a single-record FASTA
if len_sequence < 500:
presets = ["sr"]
else:
presets = ["map-ont", "splice"]

for preset in presets:
sam = to_sam(path_fasta, path_fastq, preset=preset, threads=ARGS.threads, options=mappy_options)

sam_removed = sam_handler.remove_overlapped_reads([record.split("\t") for record in sam])

if is_control and is_insertion:
path_sam = Path(out_directory, f"{preset}_{name_fasta}_{ARGS.sample_name}.sam")
else:
path_sam = Path(out_directory, f"{preset}_{name_fasta}.sam")
path_sam.write_text("\n".join(sam))

path_sam.write_text("\n".join("\t".join(record) for record in sam_removed))


########################################################################
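The core of this change: `generate_sam` now inspects the reference length and picks minimap2 presets accordingly. A minimal standalone sketch of the rule (assuming, as the code above does, a two-line FASTA with the sequence on the second line; `choose_presets` is a hypothetical helper name, not part of the codebase):

```python
from pathlib import Path

def choose_presets(path_fasta: str) -> list[str]:
    # Assumes a single-record FASTA: header on line 1, sequence on line 2.
    sequence = Path(path_fasta).read_text().split("\n")[1]
    if len(sequence) < 500:
        return ["sr"]  # short reads: minimap2's short-read preset
    return ["map-ont", "splice"]  # long reads: run both presets, pick the best per read
```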
121 changes: 64 additions & 57 deletions src/DAJIN2/core/preprocess/midsv_caller.py
@@ -4,9 +4,7 @@

from pathlib import Path
from typing import Generator
from itertools import chain, groupby

from collections import Counter
from collections import Counter, defaultdict

from DAJIN2.utils import io, sam_handler, cssplits_handler

@@ -24,51 +22,60 @@ def has_inversion_in_splice(CIGAR: str) -> bool:
return False


def extract_qname_of_map_ont(sam_ont: Generator[list[str]], sam_splice: Generator[list[str]]) -> set[str]:
"""Extract qname of reads from `map-ont` when:
- no inversion signal in `splice` alignment (insertion + deletion)
- single read
- long alignment length
"""
alignments_splice = {s[0]: s for s in sam_splice if not s[0].startswith("@")}
alignments_ont = sorted([s for s in sam_ont if not s[0].startswith("@")], key=lambda x: x[0])
qname_of_map_ont = set()
for qname_ont, group in groupby(alignments_ont, key=lambda x: x[0]):
group = list(group)

if qname_ont not in alignments_splice:
qname_of_map_ont.add(qname_ont)
def extract_preset_and_cigar_by_qname(path_sam_files: Generator[Path]) -> dict[str, dict[str, str]]:
preset_cigar_by_qname = defaultdict(dict)
# Extract preset and CIGAR
for path in path_sam_files:
preset = path.stem.split("_")[0]
sam = io.read_sam(path)
for record in sam:
if record[0].startswith("@"):
continue
qname = record[0]
cigar = record[5]
preset_cigar_by_qname[qname].update({preset: cigar})
return dict(preset_cigar_by_qname)
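
For orientation, the mapping returned above associates each read name with the CIGAR string it received under each preset; with hypothetical read names and CIGARs it might look like:

```python
preset_cigar_by_qname = {
    "read1": {"map-ont": "100M", "splice": "50M500N50M"},
    "read2": {"map-ont": "80M20S"},  # read2 aligned under map-ont only
}
```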


def extract_best_preset(preset_cigar_by_qname: dict[str, dict[str, str]]) -> dict[str, str]:
best_preset = defaultdict(str)
for qname in preset_cigar_by_qname:
alignment_lengths = {
preset: sam_handler.calculate_alignment_length(cigar)
for preset, cigar in preset_cigar_by_qname[qname].items()
}

# If all alignment lengths are the same, prioritize map-ont
if len(set(alignment_lengths.values())) == 1 and "map-ont" in alignment_lengths:
best_preset[qname] = "map-ont"
continue

cigar_splice = alignments_splice[qname_ont][5]
# Define a custom key function to prioritize map-ont
def custom_key(key: str) -> tuple[int, bool]:
return (alignment_lengths[key], key == "map-ont")

# If preset=splice and inversion is present, `midsv.transform` will not work, so use preset=map-ont.
if has_inversion_in_splice(cigar_splice):
qname_of_map_ont.add(qname_ont)
continue
max_key = max(alignment_lengths, key=custom_key)
best_preset[qname] = max_key

# only accept single read or inversion reads
# TODO: Is the simple condition "there are three reads" really sufficient to identify inversion reads?
if len(group) == 2 or len(group) >= 4:
continue
cigar_ont = group[0][5]
return dict(best_preset)
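
The selection rule can be exercised in isolation. A minimal re-implementation for illustration (alignment lengths are precomputed here, whereas the real code derives them from CIGAR strings via `sam_handler.calculate_alignment_length`):

```python
def pick_best_preset(alignment_lengths: dict[str, int]) -> str:
    # All presets tied on length: prefer map-ont outright.
    if len(set(alignment_lengths.values())) == 1 and "map-ont" in alignment_lengths:
        return "map-ont"
    # Otherwise take the longest alignment; equal lengths still favor map-ont.
    return max(alignment_lengths, key=lambda k: (alignment_lengths[k], k == "map-ont"))

assert pick_best_preset({"map-ont": 100, "splice": 100}) == "map-ont"
assert pick_best_preset({"map-ont": 100, "splice": 150}) == "splice"
assert pick_best_preset({"sr": 90}) == "sr"
```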

alignment_length_ont = sam_handler.calculate_alignment_length(cigar_ont)
alignment_length_splice = sam_handler.calculate_alignment_length(cigar_splice)

if alignment_length_ont >= alignment_length_splice:
qname_of_map_ont.add(qname_ont)
return qname_of_map_ont


def filter_sam_by_preset(sam: Generator[list[str]], qname_of_map_ont: set, preset: str = "map-ont") -> Generator:
for alignment in sam:
if alignment[0].startswith("@"):
yield alignment
elif preset == "map-ont" and alignment[0] in qname_of_map_ont:
yield alignment
elif preset != "map-ont" and alignment[0] not in qname_of_map_ont:
yield alignment
def extract_best_alignment_length_from_sam(
path_sam_files: Generator[Path], best_preset: dict[str, str]
) -> Generator[list[str]]:
flag_header = False
for path in path_sam_files:
preset = path.stem.split("_")[0]
sam = io.read_sam(path)
for record in sam:
if record[0].startswith("@"):
if not flag_header:
yield record
else:
qname = record[0]
if best_preset[qname] == preset:
yield record
flag_header = True
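
The `flag_header` toggle above merges several per-preset SAM files into one stream while keeping header lines (`@`-prefixed) from the first file only. The same de-duplication pattern on plain lists, as a sketch (the real function additionally filters alignment records by each read's best preset):

```python
def merge_sam_streams(sams: list[list[list[str]]]) -> list[list[str]]:
    merged: list[list[str]] = []
    header_done = False
    for sam in sams:
        for record in sam:
            if record[0].startswith("@"):
                if not header_done:
                    merged.append(record)  # headers from the first file only
            else:
                merged.append(record)  # alignment records always pass through
        header_done = True
    return merged
```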


def transform_to_midsv_format(sam: Generator[list[str]]) -> Generator[list[dict]]:
@@ -127,11 +134,6 @@ def filter_samples_by_n_proportion(midsv_sample: Generator[dict], threshold: int
# convert_consecutive_indels_to_match
###########################################################

"""
Due to alignment errors, there can be instances where a true match is mistakenly replaced with "insertion following a deletion".
For example, although it should be "=C,=T", it gets replaced by "-C,+C|=T". In such cases, a process is performed to revert it back to "=C,=T".
"""


def convert_consecutive_indels_to_match(cssplit: str) -> str:
i = 0
@@ -176,11 +178,20 @@ def convert_consecutive_indels_to_match(cssplit: str) -> str:


def convert_consecutive_indels(midsv_sample: Generator) -> Generator[list[dict]]:
"""
Due to alignment errors, there can be instances where a true match is mistakenly replaced with "insertion following a deletion".
For example, although it should be "=C,=T", it gets replaced by "-C,+C|=T". In such cases, a process is performed to revert it back to "=C,=T".
"""
for m in midsv_sample:
m["CSSPLIT"] = convert_consecutive_indels_to_match(m["CSSPLIT"])
yield m
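
Per the docstring, plus the unchanged and empty cases visible in the test table at the bottom of this diff, the expected behavior is:

```python
from DAJIN2.core.preprocess.midsv_caller import convert_consecutive_indels_to_match

assert convert_consecutive_indels_to_match("-C,+C|=T") == "=C,=T"  # artifact reverted to matches
assert convert_consecutive_indels_to_match("=A,=C,=G,=T,=A") == "=A,=C,=G,=T,=A"  # already clean
assert convert_consecutive_indels_to_match("") == ""  # empty input passes through
```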


###########################################################
# reallocate_insertions_within_deletion
###########################################################


def reallocate_insertions_within_deletion(midsv_sample: Generator) -> Generator[list[dict]]:
for m in midsv_sample:
m["CSSPLIT"] = cssplits_handler.reallocate_insertion_within_deletion(m["CSSPLIT"])
@@ -203,23 +214,19 @@ def generate_midsv(ARGS, is_control: bool = False, is_insertion: bool = False) -
"""
Set the destination for midsv as `barcode01/midsv/insertion1_barcode02.json` when control is barcode01, sample is barcode02, and the allele is insertion1.
"""
path_ont = Path(ARGS.tempdir, name, "sam", f"map-ont_{allele}_{ARGS.sample_name}.sam")
path_splice = Path(ARGS.tempdir, name, "sam", f"splice_{allele}_{ARGS.sample_name}.sam")
path_sam_files = list(Path(ARGS.tempdir, name, "sam").glob(f"*_{allele}_{ARGS.sample_name}.sam"))
path_output_midsv = Path(ARGS.tempdir, name, "midsv", f"{allele}_{ARGS.sample_name}.json")
else:
"""
Set the destination for midsv as `barcode02/midsv/insertion1.json` when the sample is barcode02 and the allele is insertion1.
"""
path_ont = Path(ARGS.tempdir, name, "sam", f"map-ont_{allele}.sam")
path_splice = Path(ARGS.tempdir, name, "sam", f"splice_{allele}.sam")
path_sam_files = list(Path(ARGS.tempdir, name, "sam").glob(f"*{allele}.sam"))
path_output_midsv = Path(ARGS.tempdir, name, "midsv", f"{allele}.json")

sam_ont = sam_handler.remove_overlapped_reads(list(io.read_sam(path_ont)))
sam_splice = sam_handler.remove_overlapped_reads(list(io.read_sam(path_splice)))
qname_of_map_ont = extract_qname_of_map_ont(sam_ont, sam_splice)
sam_of_map_ont = filter_sam_by_preset(sam_ont, qname_of_map_ont, preset="map-ont")
sam_of_splice = filter_sam_by_preset(sam_splice, qname_of_map_ont, preset="splice")
midsv_chaind = transform_to_midsv_format(chain(sam_of_map_ont, sam_of_splice))
preset_cigar_by_qname = extract_preset_and_cigar_by_qname(path_sam_files)
best_preset = extract_best_preset(preset_cigar_by_qname)
sam_best_alignments = extract_best_alignment_length_from_sam(path_sam_files, best_preset)
midsv_chaind = transform_to_midsv_format(sam_best_alignments)
midsv_sample = replace_internal_n_to_d(midsv_chaind, sequence)
midsv_sample = convert_flag_to_strand(midsv_sample)
midsv_sample = filter_samples_by_n_proportion(midsv_sample)
3 changes: 3 additions & 0 deletions src/DAJIN2/core/report/bam_exporter.py
@@ -110,6 +110,9 @@ def update_sam(sam: list, GENOME_COODINATES: dict = {}) -> list:

def export_to_bam(TEMPDIR, NAME, GENOME_COODINATES, THREADS, UUID, RESULT_SAMPLE=None, is_control=False) -> None:
path_sam_input = Path(TEMPDIR, NAME, "sam", "map-ont_control.sam")
if not path_sam_input.exists(): # Short-read runs produce only sr_*.sam, so fall back to the sr preset.
path_sam_input = Path(TEMPDIR, NAME, "sam", "sr_control.sam")

sam_records = list(io.read_sam(path_sam_input))

# Update sam
45 changes: 1 addition & 44 deletions tests/src/preprocess/test_midsv_caller.py
@@ -1,10 +1,8 @@
from __future__ import annotations

import pytest
import midsv

from DAJIN2.core.preprocess.midsv_caller import has_inversion_in_splice
from DAJIN2.core.preprocess.midsv_caller import extract_qname_of_map_ont
from DAJIN2.core.preprocess.midsv_caller import replace_internal_n_to_d
from DAJIN2.core.preprocess.midsv_caller import convert_flag_to_strand
from DAJIN2.core.preprocess.midsv_caller import convert_consecutive_indels_to_match
@@ -61,47 +59,6 @@ def test_has_inversion_in_splice_random_deletion():
# test_cigar = [s[5] for s in test if not s[0].startswith("@")]
# assert not has_inversion_in_splice(test_cigar[0])

###########################################################
# extract_qname_of_map_ont
###########################################################


def test_extract_qname_of_map_ont_simulation():
sam_ont = [["@header"], ["read1", "", "", "", "", "5M"]]
sam_splice = [["@header"], ["read1", "", "", "", "", "5M"]]
qname_of_map_ont = extract_qname_of_map_ont(iter(sam_ont), iter(sam_splice))
assert qname_of_map_ont == {"read1"}

# Large deletion
sam_ont = [["@header"], ["read1", "", "", "", "", "5M"], ["read1", "", "", "", "", "5M"]]
sam_splice = [["@header"], ["read1", "", "", "", "", "5M100D5M"]]
qname_of_map_ont = extract_qname_of_map_ont(iter(sam_ont), iter(sam_splice))
assert qname_of_map_ont == set()

# Inversion
sam_ont = [
["@header"],
["read1", "", "", "", "", "5M"],
["read1", "", "", "", "", "100M"],
["read1", "", "", "", "", "5M"],
]
sam_splice = [["@header"], ["read1", "", "", "", "", "5M100I100N5M"]]
qname_of_map_ont = extract_qname_of_map_ont(iter(sam_ont), iter(sam_splice))
assert qname_of_map_ont == {"read1"}

# Inversion with single read in map-ont
sam_ont = [["@header"], ["read1", "", "", "", "", "5M10N5M"]]
sam_splice = [["@header"], ["read1", "", "", "", "", "5M100I100N5M"]]
qname_of_map_ont = extract_qname_of_map_ont(iter(sam_ont), iter(sam_splice))
assert qname_of_map_ont == {"read1"}


def test_extract_qname_of_map_ont_real():
sam_ont = list(midsv.read_sam(Path("tests", "data", "preprocess", "midsv_caller", "stx2-ont_deletion.sam")))
sam_splice = list(midsv.read_sam(Path("tests", "data", "preprocess", "midsv_caller", "stx2-splice_deletion.sam")))
qname_of_map_ont = extract_qname_of_map_ont(iter(sam_ont), iter(sam_splice))
assert qname_of_map_ont == {"stx2-small-deletion"}


###########################################################
# replace n to d
@@ -201,7 +158,7 @@ def test_convert_flag_to_strand(input_sample, expected_output):
# no change
("=A,=C,=G,=T,=A", "=A,=C,=G,=T,=A"),
# empty
("", "")
("", ""),
],
)
def test_convert_consecutive_indels_to_match(cons, expected):
