Skip to content

Commit

Permalink
add find_convergent_taxa function
Browse files Browse the repository at this point in the history
  • Loading branch information
farchaab committed Dec 6, 2024
1 parent bf0ebb9 commit 27b4085
Showing 1 changed file with 77 additions and 47 deletions.
124 changes: 77 additions & 47 deletions zamp/rules/DB_processing/scripts/tax_formatting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pandas as pd
import numpy as np
from snakemake.script import snakemake


# Functions
def propagate_nan(df):
Expand All @@ -23,49 +25,67 @@ def sorted_set(s):
return sorted(set(s))


def collapse_taxa(taxa, rank):
    """Collapse a list of taxon names into per-parent counts.

    Each taxon is reduced to its first space-separated word (for a name such
    as "Escherichia coli" that is "Escherichia"), and anything from
    "_placeholder_" onwards is stripped from that word. The number of taxa
    sharing each resulting parent is then reported.

    Args:
        taxa: Iterable of taxon name strings.
        rank: Rank name; its lower-cased first letter tags every entry
            (e.g. "Species" -> "s").

    Returns:
        A "/"-joined string of "<parent>_<letter>(<count>)" entries, in the
        order each parent was first encountered. Empty string for empty input.

    Example:
        >>> collapse_taxa(["Escherichia coli", "Escherichia albertii"], "Species")
        'Escherichia_s(2)'
    """
    # Local import keeps this block self-contained.
    from collections import Counter

    # str.split returns the whole string when the separator is absent, so the
    # placeholder suffix can be stripped unconditionally.
    # Counter preserves first-seen insertion order.
    counts = Counter(
        taxon.split(" ")[0].split("_placeholder_")[0] for taxon in taxa
    )
    return "/".join(
        f"{parent}_{rank[0].lower()}({count})" for parent, count in counts.items()
    )


def format_taxa(tax, rank, collapse=False, n=4):
    """Render the taxa of one cluster as a single string.

    Args:
        tax: Sized iterable of taxon name strings.
        rank: Rank name, passed on to ``collapse_taxa`` when collapsing.
        collapse: When true, collapsing is allowed.
        n: Collapsing only happens if ``tax`` holds more than ``n`` entries.

    Returns:
        The result of ``collapse_taxa(tax, rank)`` when ``collapse`` is true
        and ``tax`` has more than ``n`` entries; otherwise the names joined
        with "/".
    """
    should_collapse = collapse and len(tax) > n
    if not should_collapse:
        return "/".join(tax)
    return collapse_taxa(tax, rank)


def find_convergent_taxa(df):
    """Return rows whose taxon is shared by more than one parent taxon.

    Columns are read left to right as successive ranks. For every adjacent
    (parent, child) column pair, a row is selected when its child value also
    occurs with a different parent value elsewhere in ``df`` (for example,
    the same species name listed under two different genera).

    Args:
        df: DataFrame with one column per rank, ordered from parent to child.

    Returns:
        DataFrame of the selected rows with exact duplicate rows dropped.
        Rows are ordered rank pair by rank pair and, within a pair, by child
        value. When nothing is selected, an empty DataFrame that keeps the
        columns of ``df`` is returned.
    """
    inconsistent_rows = []

    for prev_col, col in zip(df.columns[:-1], df.columns[1:]):
        # Distinct parent count for each child value, broadcast back to rows.
        # Rows with a missing child value get NaN and never compare > 1.
        n_parents = df.groupby(col)[prev_col].transform("nunique")
        convergent = df.loc[(n_parents > 1).to_numpy()]
        if not convergent.empty:
            # A stable sort on the child column keeps rows grouped by child
            # value while preserving their original relative order.
            inconsistent_rows.append(convergent.sort_values(col, kind="stable"))

    if inconsistent_rows:
        return pd.concat(inconsistent_rows).drop_duplicates()
    # Empty slice (not a bare DataFrame) so callers still see the columns.
    return df.iloc[0:0]


def get_lengths(tax):
    """Return how many "/"-separated entries the string *tax* contains."""
    # N separators delimit N + 1 fields, matching len(tax.split("/")).
    separators = tax.count("/")
    return separators + 1


def ambiguous_taxa(row):
    """Return True if any cell of *row*, as a string, contains "/" or "_s(".

    Args:
        row: Iterable of cell values; each is converted with ``str``.

    Returns:
        True when at least one cell holds either marker, False otherwise
        (including for an empty row).
    """
    for cell in row:
        text = str(cell)
        if "/" in text or "_s(" in text:
            return True
    return False


# Ranks list
ranks = ["Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species"]
ranks_lim = snakemake.params.tax_collapse


# Read tables
tax_df = pd.read_csv(snakemake.input.tax, sep="\t")
Expand Down Expand Up @@ -157,8 +177,10 @@ def problematic_taxa(row):

if "silva" in snakemake.params.db_name:
## Get classified species index
index = ~prop_tax_df["Species"].str.contains(f"{placeholder["Species"]}", regex=False)
## Add genus name in species for classified species
index = ~prop_tax_df["Species"].str.contains(
f"{placeholder["Species"]}", regex=False
)
## Add parent name in species for classified species
prop_tax_df.loc[index, "Species"] = (
prop_tax_df.loc[index, "Genus"] + " " + prop_tax_df.loc[index, "Species"]
)
Expand Down Expand Up @@ -199,9 +221,7 @@ def problematic_taxa(row):
] + clust_df.loc[
clust_df[f"{rank}"].str.contains(f"{placeholder[rank]}", regex=False),
"rank_idx",
].astype(
str
)
].astype(str)

## Count sequences in each cluster
clust_df["seq_counts"] = clust_df.groupby("clust_id")["seq_id"].transform("count")
Expand All @@ -217,16 +237,28 @@ def problematic_taxa(row):


# Flag clusters with multiple taxa as discrepant for all ranks
## print only 2 species and 4 genus (by default) in formatted_rank
## collapse to per-parent counts in formatted_rank when a cluster has more than 4 taxa (by default)
## print all ranks in raw_rank
for rank in ranks:
multi_df.loc[:, f"formatted_{rank}"] = multi_df.loc[:, f"{rank}"].apply(
lambda x: format_discrepant_tax(rank, x, ranks_lim)
)
multi_df.loc[:, f"all_{rank}"] = multi_df.loc[:, f"{rank}"].apply(
lambda x: format_discrepant_tax(rank, x)
lambda x: format_taxa(x, rank)
)
multi_df.loc[:, f"formatted_{rank}"] = multi_df.loc[:, f"{rank}"].apply(
lambda x: format_taxa(x, rank, collapse=True)
)

# Find if there are any convergent Species after collapsing
formatted_ranks = [f"formatted_{rank}" for rank in ranks]
conv_df = find_convergent_taxa(multi_df[formatted_ranks])
# If there are convergent taxa
if not conv_df.empty:
# Increase collapse threshold based on the max number of species per cluster
new_collapse_threshold = int(conv_df.formatted_Species.apply(get_lengths).max() + 1)
# Re-collapse with increased threshold to get unique species names
for rank in ranks:
multi_df.loc[:, f"formatted_{rank}"] = multi_df.loc[:, f"{rank}"].apply(
lambda x: format_taxa(x, rank, collapse=True, n=new_collapse_threshold)
)

## Add seq_id, clust_id, clust_rep and seq_counts to dataframe
multi_df = multi_df.merge(
Expand All @@ -239,14 +271,14 @@ def problematic_taxa(row):

# Df with collapsed taxonomy

formatted_ranks = [f"formatted_{rank}" for rank in ranks]

multi_formatted_df = multi_df[
["clust_id", "clust_rep", "seq_id"] + formatted_ranks + ["seq_counts"]
]
multi_formatted_df.columns = cols
collapsed_df = pd.concat([single_df[cols], multi_formatted_df])
collapsed_df["taxpath"] = collapsed_df[ranks].T.agg(";".join)
collapsed_df = collapsed_df.replace("_placeholder","", regex=True)
collapsed_df = collapsed_df.replace("_placeholder", "", regex=True)

# Df with uncollapsed taxonomy

Expand All @@ -259,17 +291,15 @@ def problematic_taxa(row):
all_df["taxpath"] = all_df[ranks].T.agg(";".join)
all_df = all_df.replace("_placeholder", "", regex=True)
## Flag discrepant taxa as problematic
all_df.loc[:, "problematic_taxa"] = all_df.apply(problematic_taxa, axis=1)
all_df.loc[:, "ambiguous_taxa"] = all_df.apply(ambiguous_taxa, axis=1)


# Save tables

all_df[all_df.problematic_taxa == True].to_csv(
snakemake.output.problematic, sep="\t", index=False
)
all_df[all_df.ambiguous_taxa].to_csv(snakemake.output.ambiguous, sep="\t", index=False)

collapsed_df[["seq_id", "taxpath"]].to_csv(
snakemake.output.formatted_tax, sep="\t", index=False, header=False
snakemake.output.collapsed, sep="\t", index=False, header=False
)

all_df[["seq_id", "taxpath"]].to_csv(
Expand Down

0 comments on commit 27b4085

Please sign in to comment.