diff --git a/docs/_static/docstring_previews/coxph_forestplot.png b/docs/_static/docstring_previews/coxph_forestplot.png new file mode 100644 index 00000000..82e72b89 Binary files /dev/null and b/docs/_static/docstring_previews/coxph_forestplot.png differ diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 99b17e70..ac088bca 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 99b17e7039699548a908433fa3ee6b5cbac5e29f +Subproject commit ac088bcabae5de8516ca9a5aa036b4e3cdf67df6 diff --git a/ehrapy/plot/__init__.py b/ehrapy/plot/__init__.py index 0c740e95..4a57b84e 100644 --- a/ehrapy/plot/__init__.py +++ b/ehrapy/plot/__init__.py @@ -2,6 +2,6 @@ from ehrapy.plot._colormaps import * # noqa: F403 from ehrapy.plot._missingno_pl_api import * # noqa: F403 from ehrapy.plot._scanpy_pl_api import * # noqa: F403 -from ehrapy.plot._survival_analysis import kaplan_meier, kmf, ols +from ehrapy.plot._survival_analysis import cox_ph_forestplot, kaplan_meier, ols from ehrapy.plot.causal_inference._dowhy import causal_effect from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised diff --git a/ehrapy/plot/_survival_analysis.py b/ehrapy/plot/_survival_analysis.py index 717f9202..5230ed1a 100644 --- a/ehrapy/plot/_survival_analysis.py +++ b/ehrapy/plot/_survival_analysis.py @@ -4,17 +4,20 @@ from typing import TYPE_CHECKING import matplotlib.pyplot as plt +import matplotlib.ticker as ticker import numpy as np +import pandas as pd +from matplotlib import gridspec from numpy import ndarray from ehrapy.plot import scatter if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable, Sequence from xmlrpc.client import Boolean from anndata import AnnData - from lifelines import KaplanMeierFitter + from lifelines import CoxPHFitter, KaplanMeierFitter from matplotlib.axes import Axes from statsmodels.regression.linear_model import RegressionResults @@ -293,5 
def cox_ph_forestplot(
    cox_ph: CoxPHFitter,
    *,
    labels: Iterable[str] | None = None,
    fig_size: tuple = (10, 10),
    t_adjuster: float = 0.1,
    ecolor: str = "dimgray",
    size: int = 3,
    marker: str = "o",
    decimal: int = 2,
    text_size: int = 12,
    color: str = "k",
    show: bool | None = None,
    title: str | None = None,
) -> tuple[plt.Figure, Axes] | None:
    """Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.

    The method requires a fitted CoxPHFitter object from the lifelines library.
    The left panel shows one marker with a horizontal error bar per row of ``cox_ph.summary``;
    the right panel is a table with the ``coef`` value and its 95% confidence interval.

    Inspired by ``zepid.graphics.EffectMeasurePlot`` (zEpid Package, https://pypi.org/project/zepid/).

    Args:
        cox_ph: Fitted CoxPHFitter object from the lifelines library.
        labels: Labels for each coefficient, one per row of ``cox_ph.summary``.
            Defaults to the index of ``cox_ph.summary``.
        fig_size: Width, height in inches.
        t_adjuster: Vertical offset of the table's bounding box, used to align table rows with the plotted rows.
        ecolor: Color of the error bars.
        size: Size of the markers (multiplied by 25 before being passed to ``scatter``).
        marker: Marker style.
        decimal: Number of decimal places to display.
        text_size: Font size of the text.
        color: Color of the markers.
        show: If truthy, nothing is returned. Otherwise the figure and the plot axes are returned.
            This function does not call ``plt.show()`` itself.
        title: Set the title of the plot.

    Returns:
        ``(fig, ax)`` when `show` is falsy, otherwise `None`.

    Raises:
        ValueError: If `labels` is given and its length differs from the number of rows in ``cox_ph.summary``.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> adata_subset = adata[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
        >>> coxph = ep.tl.cox_ph(adata_subset, event_col="censor_flg", duration_col="mort_day_censored")
        >>> ep.pl.cox_ph_forestplot(coxph)

        .. image:: /_static/docstring_previews/coxph_forestplot.png

    """
    coxph_summary = cox_ph.summary
    auc_col = "coef"
    lower_col = "coef lower 95%"
    upper_col = "coef upper 95%"
    n_rows = len(coxph_summary)

    if labels is None:
        labels = coxph_summary.index
    else:
        # Materialise so that one-shot iterables work and the length can be validated
        # before any figure is created.
        labels = list(labels)
        if len(labels) != n_rows:
            raise ValueError(f"Expected {n_rows} labels (one per coefficient), got {len(labels)}.")

    # Build the table rows. Values are read positionally from the underlying arrays:
    # integer keys on a label-indexed Series (``series[i]``) are deprecated in pandas.
    tval = []
    for coef, lower, upper in zip(
        coxph_summary[auc_col].to_numpy(),
        coxph_summary[lower_col].to_numpy(),
        coxph_summary[upper_col].to_numpy(),
    ):
        if np.isnan(coef):
            # Keep an empty row so the table stays aligned with the plot.
            tval.append([" ", " "])
            continue
        if isinstance(coef, float) and isinstance(lower, float) and isinstance(upper, float):
            coef, lower, upper = (round(value, decimal) for value in (coef, lower, upper))
        tval.append([coef, f"({lower!s}, {upper!s})"])
    # Every row (including NaN rows) gets a tick at its positional index.
    ytick = list(range(n_rows))

    # Pad the x-range by 0.1 beyond the widest confidence interval.
    x_axis_upper_bound = round(((pd.to_numeric(coxph_summary[upper_col])).max() + 0.1), 2)
    x_axis_lower_bound = round(((pd.to_numeric(coxph_summary[lower_col])).min() - 0.1), 1)

    fig = plt.figure(figsize=fig_size)
    gspec = gridspec.GridSpec(1, 6)
    plot = plt.subplot(gspec[0, 0:4])  # plot of data
    tabl = plt.subplot(gspec[0, 4:])  # table
    plot.set_ylim(-1, n_rows)  # spacing out y-axis properly

    # NOTE(review): the plotted column is `coef`. If these are log hazard ratios, the
    # no-effect reference would be 0 rather than 1 (here and in the x-ticks below) -- confirm.
    plot.axvline(1, color="gray", zorder=1)
    lower_diff = coxph_summary[auc_col] - coxph_summary[lower_col]
    upper_diff = coxph_summary[upper_col] - coxph_summary[auc_col]
    # Error bars only (no marker, no connecting line); the markers are drawn by scatter on top.
    plot.errorbar(
        coxph_summary[auc_col],
        coxph_summary.index,
        xerr=[lower_diff, upper_diff],
        marker="None",
        zorder=2,
        ecolor=ecolor,
        linewidth=0,
        elinewidth=1,
    )
    plot.scatter(
        coxph_summary[auc_col], coxph_summary.index, c=color, s=(size * 25), marker=marker, zorder=3, edgecolors="None"
    )
    plot.xaxis.set_ticks_position("bottom")
    plot.yaxis.set_ticks_position("left")
    plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
    plot.get_xaxis().set_minor_formatter(ticker.NullFormatter())
    plot.set_yticks(ytick)
    plot.set_xlim([x_axis_lower_bound, x_axis_upper_bound])
    plot.set_xticks([x_axis_lower_bound, 1, x_axis_upper_bound])
    plot.set_xticklabels([x_axis_lower_bound, 1, x_axis_upper_bound])
    plot.set_yticklabels(labels)
    plot.tick_params(axis="y", labelsize=text_size)
    plot.yaxis.set_ticks_position("none")
    plot.invert_yaxis()  # invert y-axis to align values properly with table

    tb = tabl.table(
        cellText=tval, cellLoc="center", loc="right", colLabels=[auc_col, "95% CI"], bbox=[0, t_adjuster, 1, 1]
    )
    tabl.axis("off")
    tb.auto_set_font_size(False)
    tb.set_fontsize(text_size)
    for cell in tb.get_celld().values():
        cell.set_linewidth(0)  # borderless table
    for side in ("top", "right", "left"):
        plot.spines[side].set_visible(False)

    if title:
        # NOTE(review): plt.title targets the current axes, which is presumably the table
        # axes (created last) -- confirm this is the intended placement.
        plt.title(title)

    if show:
        return None
    return fig, plot
@pytest.fixture
def mimic_2():
    """Return the object produced by ``ep.dt.mimic_2()`` called with its default arguments."""
    return ep.dt.mimic_2()
def test_coxph_forestplot(mimic_2, check_same_image):
    """Fit a Cox PH model on a column subset of ``mimic_2``, draw the forest plot and pass it to ``check_same_image``."""
    columns = ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]
    fitted_model = ep.tl.cox_ph(mimic_2[:, columns], duration_col="mort_day_censored", event_col="censor_flg")

    plot_kwargs = {"fig_size": (12, 3), "t_adjuster": 0.15, "marker": "o", "size": 2, "text_size": 14}
    # Only the figure is compared; the returned axes are not needed here.
    figure, _ = ep.pl.cox_ph_forestplot(fitted_model, **plot_kwargs)

    check_same_image(fig=figure, base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot", tol=2e-1)