diff --git a/figures/supplement_fig_1.py b/figures/supplement_fig_1.py
new file mode 100644
index 0000000..e54fa8d
--- /dev/null
+++ b/figures/supplement_fig_1.py
@@ -0,0 +1,209 @@
+import gzip
+import itertools
+import json
+import os
+import subprocess
+
+import matplotlib.pyplot as plt  # type: ignore
+import numpy as np
+import pandas as pd
+import seaborn as sns  # type: ignore
+from scipy.stats import mannwhitneyu  # type: ignore
+
+dashboard = os.path.expanduser("~/code/mgs-pipeline/dashboard/")
+
+with open(os.path.join(dashboard, "human_virus_sample_counts.json")) as inf:
+    human_virus_sample_counts = json.load(inf)
+
+with open(os.path.join(dashboard, "metadata_samples.json")) as inf:
+    metadata_samples = json.load(inf)
+
+with open(os.path.join(dashboard, "metadata_bioprojects.json")) as inf:
+    metadata_bioprojects = json.load(inf)
+
+with open(os.path.join(dashboard, "metadata_papers.json")) as inf:
+    metadata_papers = json.load(inf)
+
+with open(os.path.join(dashboard, "taxonomic_names.json")) as inf:
+    taxonomic_names = json.load(inf)
+
+
+studies = list(metadata_papers.keys())
+
+
+target_taxa = {
+    2731341: ("duplodnaviria", "DNA Viruses"),
+    2732004: ("varidnaviria", "DNA Viruses"),
+    2731342: ("monodnaviria", "DNA Viruses"),
+    2842242: ("ribozyviria", "RNA Viruses"),
+    687329: ("anelloviridae", "DNA Viruses"),
+    2559587: ("riboviria", "RNA Viruses"),
+}
+
+
+plotting_data = []
+for study in studies:
+    # Drop studies that aren't based on wastewater treatment plant (WTP) samples
+    if study not in [
+        "Bengtsson-Palme 2016",
+        "Munk 2022",
+        "Brinch 2020",
+        "Ng 2019",
+        "Maritz 2019",
+        "Brumfield 2022",
+        "Rothman 2021",
+        "Yang 2020",
+        "Spurbeck 2023",
+        "Crits-Christoph 2021",
+    ]:
+        continue
+    if study == "McCall 2023":
+        continue
+
+    sequencing_type = metadata_papers[study]["na_type"]
+
+    for bioproject in metadata_papers[study]["projects"]:
+        samples = metadata_bioprojects[bioproject]
+
+        if study == "Bengtsson-Palme 2016":
+            samples = [
+                sample
+                for sample in samples
+                if metadata_samples[sample]["fine_location"].startswith(
+                    "Inlet"
+                )
+            ]
+
+        if study == "Ng 2019":
+            samples = [
+                sample
+                for sample in samples
+                if metadata_samples[sample]["fine_location"] == "Influent"
+            ]
+        for sample in samples:
+            if metadata_samples[sample].get("enrichment") == "panel":
+                continue
+
+            if study == "Brumfield 2022":
+                # only study where we have separate DNA and RNA sequencing samples
+                sequencing_type = metadata_samples[sample]["na_type"]
+
+            cladecounts = "%s.tsv.gz" % sample
+            if not os.path.exists(f"../cladecounts/{cladecounts}"):
+                subprocess.check_call(
+                    [
+                        "aws",
+                        "s3",
+                        "cp",
+                        "s3://nao-mgs/%s/cladecounts/%s"
+                        % (bioproject, cladecounts),
+                        "../cladecounts/",
+                    ]
+                )
+            with gzip.open(f"../cladecounts/{cladecounts}") as inf:
+                na_type_abundances = {
+                    "DNA Viruses": 0,
+                    "RNA Viruses": 0,
+                }
+                for line in inf:
+                    (
+                        line_taxid,
+                        _,
+                        _,
+                        clade_assignments,
+                        _,
+                    ) = line.strip().split()
+                    taxid = int(line_taxid)
+                    clade_hits = int(clade_assignments)
+                    if taxid in target_taxa:
+                        nucleic_acid_type = target_taxa[taxid][1]
+                        relative_abundance = (
+                            clade_hits / metadata_samples[sample]["reads"]
+                        )
+
+                        na_type_abundances[
+                            nucleic_acid_type
+                        ] += relative_abundance
+
+            plotting_data.append(
+                {
+                    "study": study,
+                    "sample": sample,
+                    **na_type_abundances,
+                    "sequencing_type": sequencing_type,
+                }
+            )
+
+
+df = pd.DataFrame(plotting_data)
+df_plotting = df[(df.drop(["study", "sample"], axis=1) != 0).all(1)].melt(
+    id_vars=["study", "sample", "sequencing_type"],
+    value_vars=["DNA Viruses",
"RNA Viruses"], + var_name="virus_na_type", + value_name="relative_abundance", +) + +df_plotting["seq_na_virus_combo"] = ( + df_plotting["sequencing_type"] + + " Sequencing / " + + df_plotting["virus_na_type"] +) + +seq_na_virus_combo_ordered = [ + "DNA Sequencing / DNA Viruses", + "RNA Sequencing / DNA Viruses", + "DNA Sequencing / RNA Viruses", + "RNA Sequencing / RNA Viruses", +] + +df_plotting["seq_na_virus_combo"] = pd.Categorical( + df_plotting["seq_na_virus_combo"], + categories=seq_na_virus_combo_ordered, + ordered=True, +) + +df_plotting = df_plotting.sort_values("seq_na_virus_combo") + +combinations = list(itertools.pairwise(seq_na_virus_combo_ordered)) + +p_values = [] + +for combination in combinations: + _, p_value = mannwhitneyu( + df_plotting[df_plotting["seq_na_virus_combo"] == combination[0]][ + "relative_abundance" + ], + df_plotting[df_plotting["seq_na_virus_combo"] == combination[1]][ + "relative_abundance" + ], + ) + p_values.append(p_value) + print( + f"p_value when comparing '{combination[0]}' and '{combination[1]}' = {p_value}" + ) + +df_plotting["log_relative_abundance"] = np.log10( + df_plotting["relative_abundance"] +) + +plt.figure(figsize=(8, 4)) +sns.boxplot( + x="log_relative_abundance", y="seq_na_virus_combo", data=df_plotting +) + +plt.xlabel("Logged Relative Abundance") + +plt.ylabel("") +ax = plt.gca() + + +for i in range(0, 3): + p_value = p_values[i] # pulling out the respective p_value + ax.text( + 0.2, + i + 0.5, + f"p < {0.0001}" if p_value < 0.0001 else f"p = {round(p_value, 3)}", + ) + +plt.tight_layout() +plt.savefig("supplement_figure_1.pdf") diff --git a/figures/supplement_fig_2_to_4.py b/figures/supplement_fig_2_to_4.py new file mode 100755 index 0000000..91624ee --- /dev/null +++ b/figures/supplement_fig_2_to_4.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +import sys +from pathlib import Path + +sys.path.append("..") + +import matplotlib.patches as mpatches # type: ignore +import matplotlib.pyplot as plt # type: ignore +import matplotlib.ticker as ticker # type: ignore +import numpy as np +import pandas as pd +import seaborn as sns # type: ignore + +from pathogens import pathogens + + +def nucleic_acid(pathogen: str) -> str: + return pathogens[pathogen].pathogen_chars.na_type.value + + +def selection_round(pathogen: str) -> str: + return pathogens[pathogen].pathogen_chars.selection.value + + +def study_name(study: str) -> str: + return { + "brinch": "Brinch (DNA)", + "crits_christoph": "Crits-Christoph", + "rothman": "Rothman", + "spurbeck": "Spurbeck", + }[study] + + +plt.rcParams["font.size"] = 8 + + +def separate_viruses(ax) -> None: + yticks = ax.get_yticks() + ax.hlines( + [(y1 + y2) / 2 for y1, y2 in zip(yticks[:-1], yticks[1:])], + *ax.get_xlim(), + color="grey", + linewidth=0.3, + linestyle=":", + ) + + +def adjust_axes(ax, predictor_type: str, target_x: str) -> None: + yticks = ax.get_yticks() + # Y-axis is reflected + ax.set_ylim([max(yticks) + 0.5, min(yticks - 0.5)]) + ax.tick_params(left=False) + ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_func)) + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.vlines( + ax.get_xticks()[1:-1], + *ax.get_ylim(), + color="grey", + linewidth=0.3, + linestyle=":", + zorder=-1, + ) + target_percentage = { + "log10ra_at_1in1000": "0.1%", + "log10ra_at_1in10000": "0.01%", + "log10ra_at_1in10": "10%", + } + # ax.set_xscale("log") + ax.set_xlabel( + r"$\mathrm{RA}$" + f"{predictor_type[0]}" + 
f"({target_percentage[target_x]})" + f": expected relative abundance at {target_percentage[target_x]} " + f"{predictor_type} " + ) + ax.set_ylabel("") + + +def plot_violin( + ax, + data: pd.DataFrame, + viral_reads: pd.DataFrame, + y: str, + x: str, + sorting_order: list[str], + ascending: list[bool], + hatch_zero_counts: bool = False, + violin_scale=1.0, +) -> None: + assert len(sorting_order) == len(ascending) + plotting_order = viral_reads.sort_values( + sorting_order, ascending=ascending + ).reset_index() + sns.violinplot( + ax=ax, + data=data, + x=x, + y=y, + order=plotting_order[y].unique(), + hue="study", + hue_order=plotting_order.study.unique(), + inner=None, + linewidth=0.0, + bw=0.5, + scale="area", + scale_hue=False, + cut=0, + ) + x_min = ax.get_xlim()[0] + for num_reads, patches in zip(plotting_order.viral_reads, ax.collections): + # alpha = min((num_reads + 1) / 10, 1.0) + if num_reads == 0: + alpha = 0.5 + elif num_reads < 10: + alpha = 0.5 + else: + alpha = 1.0 + patches.set_alpha(alpha) + # Make violins fatter and hatch if zero counts + for path in patches.get_paths(): + y_mid = path.vertices[0, 1] + path.vertices[:, 1] = ( + violin_scale * (path.vertices[:, 1] - y_mid) + y_mid + ) + if (not hatch_zero_counts) and (num_reads == 0): + color = patches.get_facecolor() + y_max = np.max(path.vertices[:, 1]) + y_min = np.min(path.vertices[:, 1]) + x_max = path.vertices[np.argmax(path.vertices[:, 1]), 0] + rect = mpatches.Rectangle( + (x_min, y_min), + x_max - x_min, + y_max - y_min, + facecolor=color, + linewidth=0.0, + alpha=alpha, + fill=False, + hatch="|||", + edgecolor=color, + ) + ax.add_patch(rect) + + +def format_func(value, tick_number): + return r"$10^{{{}}}$".format(int(value)) + + +def plot_incidence( + data: pd.DataFrame, input_data: pd.DataFrame, ax: plt.Axes, target_x: str +) -> plt.Axes: + predictor_type = "incidence" + if target_x == "log10ra_at_1in10": + ax.set_xlim((-15, -1)) + ax.set_xticks(list(range(-15, 1, 2))) + else: + ax.set_xlim((-15, -3)) + ax.set_xticks(list(range(-15, -1, 2))) + + plot_violin( + ax=ax, + data=data[ + (data.predictor_type == predictor_type) + & (data.location == "Overall") + & ~( + (data.study == "Crits-Christoph") + & (data.pathogen == "influenza") + ) + ], + x=target_x, + viral_reads=count_viral_reads( + input_data[input_data.predictor_type == predictor_type] + ), + y="tidy_name", + sorting_order=[ + "nucleic_acid", + "selection_round", + "samples_observed_by_tidy_name", + "tidy_name", + "study", + ], + ascending=[False, True, False, True, False], + violin_scale=2.0, + ) + + separate_viruses(ax) + adjust_axes(ax, predictor_type=predictor_type, target_x=target_x) + legend = ax.legend( + title="MGS study", + bbox_to_anchor=(1.02, 1), + loc="upper left", + borderaxespad=0, + frameon=False, + ) + for legend_handle in legend.legend_handles: # type: ignore + legend_handle.set_edgecolor(legend_handle.get_facecolor()) # type: ignore + + ax_title = ax.set_title("a", fontweight="bold") + ax_title.set_position((-0.16, 0)) + return ax + + +def plot_prevalence( + data: pd.DataFrame, input_data: pd.DataFrame, ax: plt.Axes, target_x: str +) -> plt.Axes: + predictor_type = "prevalence" + + plot_violin( + ax=ax, + data=data[ + (data.predictor_type == predictor_type) + & (data.location == "Overall") + ], + viral_reads=count_viral_reads( + input_data[input_data.predictor_type == predictor_type] + ), + x=target_x, + y="tidy_name", + sorting_order=[ + "nucleic_acid", + "selection_round", + "samples_observed_by_tidy_name", + "tidy_name", + "study", + ], 
+ ascending=[False, True, False, True, False], + violin_scale=1.5, + ) + if target_x == "log10ra_at_1in10": + ax.set_xlim((-15, -1)) + ax.set_xticks(list(range(-15, 1, 2))) + else: + ax.set_xlim((-15, -3)) + ax.set_xticks(list(range(-15, -1, 2))) + separate_viruses(ax) + # TODO Get these values automatically + num_rna_1 = 2 + num_dna_1 = 4 + ax.hlines( + [num_rna_1 - 0.5, num_rna_1 + num_dna_1 - 0.5], + *ax.get_xlim(), + linestyle="solid", + color="k", + linewidth=0.5, + ) + text_x = np.log10(1.1e-3) + ax.text(text_x, -0.4, "RNA viruses\nSelection Round 1", va="top") + ax.text( + text_x, num_rna_1 - 0.4, "DNA viruses\nSelection Round 1", va="top" + ) + ax.text( + text_x, + num_rna_1 + num_dna_1 - 0.4, + "DNA viruses\nSelection Round 2", + va="top", + ) + adjust_axes(ax, predictor_type=predictor_type, target_x=target_x) + legend = ax.legend( + title="MGS study", + bbox_to_anchor=(1.02, 0), + loc="lower left", + borderaxespad=0, + frameon=False, + ) + for legend_handle in legend.legend_handles: # type: ignore + legend_handle.set_edgecolor(legend_handle.get_facecolor()) # type: ignore + + ax_title = ax.set_title("b", fontweight="bold") + ax_title.set_position((-0.16, 0)) + + return ax + + +def count_viral_reads( + df: pd.DataFrame, by_location: bool = False +) -> pd.DataFrame: + groups = [ + "pathogen", + "tidy_name", + "predictor_type", + "study", + "nucleic_acid", + "selection_round", + ] + if by_location: + groups.append("location") + out = df.groupby(groups)[["viral_reads", "observed?"]].sum().reset_index() + out["reads_by_tidy_name"] = out.viral_reads.groupby( + out.tidy_name + ).transform("sum") + out["samples_observed_by_tidy_name"] = ( + out["observed?"].groupby(out.tidy_name).transform("sum") + ) + + return out + + +def composite_figure( + data: pd.DataFrame, + input_data: pd.DataFrame, + target_x: str, +) -> plt.Figure: + fig = plt.figure( + figsize=(5, 8), + ) + gs = fig.add_gridspec(2, 1, height_ratios=[5, 12], hspace=0.2) + plot_incidence(data, input_data, fig.add_subplot(gs[0, 0]), target_x) + plot_prevalence(data, input_data, fig.add_subplot(gs[1, 0]), target_x) + + return fig + + +def save_plot(fig, figdir: Path, name: str) -> None: + for ext in ["pdf", "png"]: + fig.savefig( + figdir / f"{name}.{ext}", + bbox_inches="tight", + dpi=600, + ) + + +def start() -> None: + parent_dir = Path("..") + figdir = Path(parent_dir / "figures") + figdir.mkdir(exist_ok=True) + + fits_df = pd.read_csv(parent_dir / "fits.tsv", sep="\t") + fits_df["study"] = fits_df.study.map(study_name) + + input_df = pd.read_csv(parent_dir / "input.tsv", sep="\t") + input_df["study"] = input_df.study.map(study_name) + # TODO: Store these in the files instead? 
+ input_df["nucleic_acid"] = input_df.pathogen.map(nucleic_acid) + input_df["selection_round"] = input_df.pathogen.map(selection_round) + input_df["observed?"] = input_df.viral_reads > 0 + # For consistency between dataframes (TODO: fix that elsewhere) + input_df["location"] = input_df.fine_location + + fits_df["log10ra"] = np.log10(fits_df.ra_at_1in100) + + fits_df = fits_df[fits_df["pathogen"] != "aav5"] # FIX ME + input_df = input_df[input_df["pathogen"] != "aav5"] # FIX ME + + for target_x, change_factor in [ + ("log10ra_at_1in1000", -1), + ("log10ra_at_1in10000", -2), + ("log10ra_at_1in10", 1), + ]: + fits_df[target_x] = fits_df.log10ra + change_factor + fig = composite_figure(fits_df, input_df, target_x) + save_plot(fig, figdir, f"supplement_fig_2_{target_x}") + + +if __name__ == "__main__": + start() diff --git a/figures/supplement_fig_5.py b/figures/supplement_fig_5.py new file mode 100644 index 0000000..bc9982f --- /dev/null +++ b/figures/supplement_fig_5.py @@ -0,0 +1,214 @@ +import csv +from collections import defaultdict +from dataclasses import dataclass +from math import log + +import matplotlib.pyplot as plt # type: ignore +import numpy as np +import pandas as pd +import seaborn as sns # type: ignore + +PERCENTILES = [5, 25, 50, 75, 95] + + +def fits_df() -> pd.DataFrame: + data: defaultdict[str, list] = defaultdict(list) + + for p in PERCENTILES: + data[f"{p}%"] = [] + + with open("fits_summary.tsv") as datafile: + reader = csv.DictReader(datafile, delimiter="\t") + for row in reader: + if row["location"] != "Overall": + continue + if (row["study"]) == "brinch": + continue + data["predictor_type"].append(row["predictor_type"]) + data["virus"].append(row["tidy_name"]) + data["study"].append(row["study"]) + data["location"].append(row["location"]) + data["mean"].append(float(row["mean"])) + data["std"].append(float(row["std"])) + data["min"].append(float(row["min"])) + data["max"].append(float(row["max"])) + for p in PERCENTILES: + data[f"{p}%"].append(log(float(row[f"{p}%"]), 10)) + + df = pd.DataFrame.from_dict(data) + return df + + +def reads_df() -> pd.DataFrame: + df = pd.read_csv("input.tsv", sep="\t") + return df + + +def study_name(study: str) -> str: + return { + "brinch": "Brinch (DNA)", + "crits_christoph": "Crits-Christoph", + "rothman": "Rothman", + "spurbeck": "Spurbeck", + }[study] + + +def count_viral_reads( + df: pd.DataFrame, by_location: bool = False +) -> pd.DataFrame: + groups = [ + "pathogen", + "study", + "tidy_name", + ] + reads_by_study_and_pathogen = ( + df.groupby(groups)[["viral_reads"]].sum().reset_index() + ) + # create a tsv + reads_by_study_and_pathogen.to_csv( + "reads_by_study_and_pathogen.tsv", sep="\t" + ) + return reads_by_study_and_pathogen + + +def compute_diffs(df: pd.DataFrame) -> pd.DataFrame: + viruses = df["virus"].unique() + results_data = defaultdict(list) + for virus in viruses: + virus_df = df[df["virus"] == virus] + if virus_df["study"].nunique() < 2: + continue + if (virus_df["viral_reads"] == 0).sum() >= 2: + continue + virus_df = virus_df[virus_df["viral_reads"] != 0] + + predictor_type = virus_df["predictor_type"].unique() + + min_median_index = virus_df["50%"].idxmin() + max_median_index = virus_df["50%"].idxmax() + if virus in ["HSV-1", "CMV"]: + print(virus, virus_df.loc[min_median_index, "study"]) + print(virus, virus_df.loc[max_median_index, "study"]) + diff_median = ( + virus_df.loc[min_median_index, "50%"] + - virus_df.loc[max_median_index, "50%"] + ) + low_diff = ( + virus_df.loc[min_median_index, "5%"] + - 
virus_df.loc[max_median_index, "95%"]
+        )
+        high_diff = (
+            virus_df.loc[min_median_index, "95%"]
+            - virus_df.loc[max_median_index, "5%"]
+        )
+
+        results_data["virus"].append(virus)
+        results_data["diff_median"].append(diff_median)
+        results_data["low_diff"].append(low_diff)
+        results_data["high_diff"].append(high_diff)
+        results_data["predictor_type"].append(predictor_type)
+        results_data["selected_studies"].append(
+            [
+                virus_df.loc[min_median_index, "study"],
+                virus_df.loc[max_median_index, "study"],
+            ]
+        )
+    df = pd.DataFrame.from_dict(results_data)
+    return df
+
+
+def plot_df(df: pd.DataFrame) -> None:
+    df = df.sort_values(by="diff_median", ascending=False).reset_index(
+        drop=True
+    )
+    df = df.sort_values(by="predictor_type", ascending=False).reset_index(
+        drop=True
+    )
+    fig, ax = plt.subplots(figsize=(6, 6))
+    print(df["diff_median"])
+    scatter = ax.scatter(
+        x=df["diff_median"],
+        y=range(len(df)),
+        alpha=0.6,
+        edgecolors="w",
+    )
+    for i in range(len(df)):
+        ax.plot(
+            [df["low_diff"][i], df["high_diff"][i]],
+            [i, i],
+            color="k",
+        )
+
+    x_min, x_max = ax.get_xlim()
+
+    for i in range(round(x_min), round(x_max)):
+        ax.axvline(i, color="k", alpha=0.1, lw=0.5)
+
+    for i in range(len(df)):
+        min_study, max_study = df["selected_studies"][i]
+
+        study_combo = f"{study_name(min_study)} <-> {study_name(max_study)}"
+        ax.text(
+            x_min - 1,
+            i - 0.3,
+            study_combo,
+            ha="right",
+            va="center",
+            fontsize=8,
+        )
+
+        if i == (len(df)) - 1:
+            ax.text(
+                x_max + 0.2,
+                i,
+                "Incidence\nViruses",
+                ha="left",
+                va="center",
+            )
+
+        if i == len(df) - 1:
+            break
+        if df["predictor_type"][i] != df["predictor_type"][i + 1]:
+            ax.axhline(i + 0.5, color="black", alpha=0.3, linestyle="--")
+            ax.text(
+                x_max + 0.2,
+                i + 0.05,
+                "Prevalence\nViruses",
+                ha="left",
+                va="center",
+            )
+
+    ax.spines["right"].set_visible(False)
+    ax.spines["top"].set_visible(False)
+
+    ax.set_yticks(range(len(df)))
+    ax.set_yticklabels(df["virus"])
+
+    ax.set_title("RA(1%) Difference (median, 90% CI)")
+    ax.set_xlabel(
+        "OOM difference between lowest and highest study estimate (based on median location)"
+    )
+
+    plt.savefig("supplement_fig_5.png", dpi=600, bbox_inches="tight")
+    plt.clf()
+
+
+def start():
+    reads_data = reads_df()
+    fits_data = fits_df()
+
+    viral_counts = count_viral_reads(reads_data)
+
+    fits_data_w_reads = pd.merge(
+        fits_data,
+        viral_counts,
+        how="left",
+        left_on=["virus", "study"],
+        right_on=["tidy_name", "study"],
+    )
+    diffs_df = compute_diffs(fits_data_w_reads)
+    plot_df(diffs_df)
+
+
+if __name__ == "__main__":
+    start()
diff --git a/pathogen_properties.py b/pathogen_properties.py
index b941a12..37df4a3 100755
--- a/pathogen_properties.py
+++ b/pathogen_properties.py
@@ -541,4 +541,4 @@ def by_taxids(
 #
 # It might be better to handle this in the modeling step, but by the time we
 # get to that point the granularity of the input data has been discarded.
-QUANTITY_WHEN_NONE_OBSERVED = 0.001
+QUANTITY_WHEN_NONE_OBSERVED = 0.1
diff --git a/pathogens/aav5.py b/pathogens/aav5.py
index cc4edd9..784c248 100644
--- a/pathogens/aav5.py
+++ b/pathogens/aav5.py
@@ -113,14 +113,10 @@ def estimate_prevalences() -> list[Prevalence]:
         date_source=Variable(date="2018"),
         location_source=Variable(country="Denmark"),
     )
-    return [
-        us_2020,
-        us_2021,
-        dk_2015,
-        dk_2016,
-        dk_2017,
-        dk_2018,
-    ]
+    return []
+    # Dropped, because of zero counts across all studies; lends little
+    # additional information to the preprint (and complicates explanation
+    # of how we selected viruses).
def estimate_incidences() -> list[IncidenceRate]: diff --git a/test.py b/test.py index a323468..b223199 100755 --- a/test.py +++ b/test.py @@ -57,8 +57,7 @@ def test_properties_exist(self): for estimate in pathogen.estimate_incidences(): self.assertIsInstance(estimate, IncidenceRate) saw_estimate = True - - if pathogen_name in ["aav6", "hbv", "hsv_2"]: + if pathogen_name in ["aav5", "aav6", "hbv", "hsv_2"]: # It's expected that these pathogens have no estimates; see # https://docs.google.com/document/d/1IIeOFKNqAwf9NTJeVFRSl_Q9asvu9_TGc_HSrlXg8PI/edit self.assertFalse(saw_estimate)