diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py index fed63b9e..d7c996c7 100644 --- a/ehrapy/tools/_sa.py +++ b/ehrapy/tools/_sa.py @@ -3,7 +3,7 @@ import warnings from typing import TYPE_CHECKING, Literal -import numpy as np # This package is implicitly used +import numpy as np # noqa: TC002 import pandas as pd import statsmodels.api as sm import statsmodels.formula.api as smf @@ -217,9 +217,11 @@ def kaplan_meier( https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. - duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: The name of the column in anndata that contains the subjects’ death observation. + adata: AnnData object. + duration_col: The name of the column in the AnnData object that contains the subjects’ lifetimes. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. timeline: Return the best estimate at the values in timelines (positively increasing) entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations. If None, all members of the population entered study when they were "born". @@ -347,9 +349,7 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_ return dataframe -def _regression_model( - model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True -): +def _regression_model_data_frame_preparation(adata: AnnData, duration_col: str, accept_zero_duration=True): """Convenience function for regression models.""" df = anndata_to_df(adata) df = df.dropna() @@ -357,26 +357,67 @@ def _regression_model( if not accept_zero_duration: df.loc[df[duration_col] == 0, duration_col] += 1e-5 - model = model_class() - model.fit(df, duration_col, event_col, entry_col=entry_col) + return df - return model - -def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter: +def cox_ph( + adata: AnnData, + duration_col: str, + event_col: str = None, + *, + uns_key: str = "cox_ph", + alpha: float = 0.05, + label: str | None = None, + baseline_estimation_method: Literal["breslow", "spline", "piecewise"] = "breslow", + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + strata: list[str] | str | None = None, + n_baseline_knots: int = 4, + knots: list[float] | None = None, + breakpoints: list[float] | None = None, + weights_col: str | None = None, + cluster_col: str | None = None, + entry_col: str = None, + robust: bool = False, + formula: str = None, + batch_mode: bool = None, + show_progress: bool = False, + initial_point: np.ndarray | None = None, + fit_options: dict | None = None, +) -> CoxPHFitter: """Fit the Cox’s proportional hazard for the survival function. The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables. It models the hazard rate as a product of a baseline hazard function and an exponential function of the predictors, assuming proportional hazards over time. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'cox_ph' unless specified otherwise in the `uns_key` parameter. See https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: The name of the column in anndata that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the uns slot in the AnnData object. + alpha: The alpha value in the confidence intervals. + label: The name of the column of the estimate. + baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above. + strata: specify a list of columns to use in stratification. This is useful if a categorical covariate does not obey the proportional hazard assumption. This is used similar to the strata expression in R. See http://courses.washington.edu/b515/l17.pdf. + n_baseline_knots: Used when baseline_estimation_method="spline". Set the number of knots (interior & exterior) in the baseline hazard, which will be placed evenly along the time axis. Should be at least 2. Royston et. al, the authors of this model, suggest 4 to start, but any values between 2 and 8 are reasonable. If you need to customize the timestamps used to calculate the curve, use the knots parameter instead. + knots: When baseline_estimation_method="spline", this allows customizing the points in the time axis for the baseline hazard curve. To use evenly-spaced points in time, the n_baseline_knots parameter can be employed instead. + breakpoints: Used when baseline_estimation_method="piecewise". Set the positions of the baseline hazard breakpoints. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + cluster_col: The name of the column in DataFrame that contains the cluster variable. Using this forces the sandwich estimator (robust variance estimator) to be used. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ. + formula: an Wilkinson formula, like in R and statsmodels, for the right-hand-side. If left as None, all columns not assigned as durations, weights, etc. are used. Uses the library Formulaic for parsing. + batch_mode: Enabling batch_mode can be faster for datasets with a large number of ties. If left as `None`, lifelines will choose the best option. + show_progress: Since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + initial_point: set the starting point for the iterative solver. + fit_options: Additional keyword arguments to pass into the estimator. Returns: Fitted CoxPHFitter. @@ -388,24 +429,95 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg") """ - return _regression_model(CoxPHFitter, adata, duration_col, event_col, entry_col) + df = _regression_model_data_frame_preparation(adata, duration_col) + cox_ph = CoxPHFitter( + alpha=alpha, + label=label, + strata=strata, + baseline_estimation_method=baseline_estimation_method, + penalizer=penalizer, + l1_ratio=l1_ratio, + n_baseline_knots=n_baseline_knots, + knots=knots, + breakpoints=breakpoints, + ) + cox_ph.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + robust=robust, + initial_point=initial_point, + weights_col=weights_col, + cluster_col=cluster_col, + batch_mode=batch_mode, + formula=formula, + fit_options=fit_options, + show_progress=show_progress, + ) + + summary = cox_ph.summary + adata.uns[uns_key] = summary + return cox_ph -def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> WeibullAFTFitter: + +def weibull_aft( + adata: AnnData, + duration_col: str, + event_col: str, + *, + uns_key: str = "weibull_aft", + alpha: float = 0.05, + fit_intercept: bool = True, + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + model_ancillary: bool = True, + ancillary: bool | pd.DataFrame | str | None = None, + show_progress: bool = False, + weights_col: str | None = None, + robust: bool = False, + initial_point=None, + entry_col: str | None = None, + formula: str | None = None, + fit_options: dict | None = None, +) -> WeibullAFTFitter: """Fit the Weibull accelerated failure time regression for the survival function. The Weibull Accelerated Failure Time (AFT) survival regression model is a statistical method used to analyze time-to-event data, where the underlying assumption is that the logarithm of survival time follows a Weibull distribution. It models the survival time as an exponential function of the predictors, assuming a specific shape parameter for the distribution and allowing for accelerated or decelerated failure times based on the covariates. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'weibull_aft' unless specified otherwise in the `uns_key` parameter. + See https://lifelines.readthedocs.io/en/latest/fitters/regression/WeibullAFTFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in anndata that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the uns slot in the AnnData object. + alpha: The alpha value in the confidence intervals. + fit_intercept: Whether to fit an intercept term in the model. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above. + model_ancillary: set the model instance to always model the ancillary parameter with the supplied Dataframe. This is useful for grid-search optimization. + ancillary: Choose to model the ancillary parameters. + If None or False, explicitly do not fit the ancillary parameters using any covariates. + If True, model the ancillary parameters with the same covariates as ``df``. + If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``. + If str, should be a formula + show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ. + initial_point: set the starting point for the iterative solver. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/ + If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.) + fit_options: Additional keyword arguments to pass into the estimator. + Returns: Fitted WeibullAFTFitter. @@ -413,27 +525,96 @@ def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: st Examples: >>> import ehrapy as ep >>> adata = ep.dt.mimic_2(encoded=False) - >>> # Flip 'censor_fl' because 0 = death and 1 = censored >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> aft = ep.tl.weibull_aft(adata, "mort_day_censored", "censor_flg") + >>> adata = adata[:, ["mort_day_censored", "censor_flg"]] + >>> aft = ep.tl.weibull_aft(adata, duration_col="mort_day_censored", event_col="censor_flg") + >>> aft.print_summary() """ - return _regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False) + df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False) -def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> LogLogisticAFTFitter: + weibull_aft = WeibullAFTFitter( + alpha=alpha, + fit_intercept=fit_intercept, + penalizer=penalizer, + l1_ratio=l1_ratio, + model_ancillary=model_ancillary, + ) + + weibull_aft.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + ancillary=ancillary, + show_progress=show_progress, + weights_col=weights_col, + robust=robust, + initial_point=initial_point, + formula=formula, + fit_options=fit_options, + ) + + summary = weibull_aft.summary + adata.uns[uns_key] = summary + + return weibull_aft + + +def log_logistic_aft( + adata: AnnData, + duration_col: str, + event_col: str | None = None, + *, + uns_key: str = "log_logistic_aft", + alpha: float = 0.05, + fit_intercept: bool = True, + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + model_ancillary: bool = False, + ancillary: bool | pd.DataFrame | str | None = None, + show_progress: bool = False, + weights_col: str | None = None, + robust: bool = False, + initial_point=None, + entry_col: str | None = None, + formula: str | None = None, + fit_options: dict | None = None, +) -> LogLogisticAFTFitter: """Fit the log logistic accelerated failure time regression for the survival function. The Log-Logistic Accelerated Failure Time (AFT) survival regression model is a powerful statistical tool employed in the analysis of time-to-event data. This model operates under the assumption that the logarithm of survival time adheres to a log-logistic distribution, offering a flexible framework for understanding the impact of covariates on survival times. By modeling survival time as a function of predictors, the Log-Logistic AFT model enables researchers to explore how specific factors influence the acceleration or deceleration of failure times, providing valuable insights into the underlying mechanisms driving event occurrence. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'log_logistic_aft' unless specified otherwise in the `uns_key` parameter. + See https://lifelines.readthedocs.io/en/latest/fitters/regression/LogLogisticAFTFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in anndata that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the uns slot in the AnnData object. + alpha: The alpha value in the confidence intervals. + fit_intercept: Whether to fit an intercept term in the model. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above. + model_ancillary: Set the model instance to always model the ancillary parameter with the supplied Dataframe. This is useful for grid-search optimization. + ancillary: Choose to model the ancillary parameters. + If None or False, explicitly do not fit the ancillary parameters using any covariates. + If True, model the ancillary parameters with the same covariates as ``df``. + If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``. + If str, should be a formula + show_progress: Since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ. + initial_point: set the starting point for the iterative solver. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/ + If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.) + fit_options: Additional keyword arguments to pass into the estimator. Returns: Fitted LogLogisticAFTFitter. @@ -443,12 +624,38 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co >>> adata = ep.dt.mimic_2(encoded=False) >>> # Flip 'censor_fl' because 0 = death and 1 = censored >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> llf = ep.tl.log_logistic_aft(adata, "mort_day_censored", "censor_flg") + >>> adata = adata[:, ["mort_day_censored", "censor_flg"]] + >>> llf = ep.tl.log_logistic_aft(adata, duration_col="mort_day_censored", event_col="censor_flg") """ - return _regression_model( - LogLogisticAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False + df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False) + + log_logistic_aft = LogLogisticAFTFitter( + alpha=alpha, + fit_intercept=fit_intercept, + penalizer=penalizer, + l1_ratio=l1_ratio, + model_ancillary=model_ancillary, + ) + + log_logistic_aft.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + ancillary=ancillary, + show_progress=show_progress, + weights_col=weights_col, + robust=robust, + initial_point=initial_point, + formula=formula, + fit_options=fit_options, ) + summary = log_logistic_aft.summary + adata.uns[uns_key] = summary + + return log_logistic_aft + def _univariate_model( adata: AnnData, @@ -515,10 +722,11 @@ def nelson_aalen( See https://lifelines.readthedocs.io/en/latest/fitters/univariate/NelsonAalenFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: The name of the column in anndata that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. timeline: Return the best estimate at the values in timelines (positively increasing) entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations. If None, all members of the population entered study when they were "born". @@ -583,11 +791,11 @@ def weibull( See https://lifelines.readthedocs.io/en/latest/fitters/univariate/WeibullFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in the AnnData object that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. - adata: AnnData object with necessary columns `duration_col` and `event_col`. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. timeline: Return the best estimate at the values in timelines (positively increasing) entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations. If None, all members of the population entered study when they were "born". diff --git a/tests/tools/test_sa.py b/tests/tools/test_sa.py index 48d85b36..b2d1f6e5 100644 --- a/tests/tools/test_sa.py +++ b/tests/tools/test_sa.py @@ -84,15 +84,23 @@ def test_anova_glm(self): assert dataframe.iloc[1, 4] == 2 assert pytest.approx(dataframe.iloc[1, 5], 0.1) == 0.103185 - def _sa_function_assert(self, model, model_class): + def _sa_function_assert(self, model, model_class, adata=None): assert isinstance(model, model_class) assert len(model.durations) == 1776 assert sum(model.event_observed) == 497 - def _sa_func_test(self, sa_function, sa_class, mimic_2_sa): + if adata is not None: + model_summary = adata.uns.get("test") + assert model_summary is not None + assert model_summary.equals(model.summary) + + def _sa_func_test(self, sa_function, sa_class, mimic_2_sa, regression=False): adata, duration_col, event_col = mimic_2_sa + if regression: + sa = sa_function(adata, duration_col=duration_col, event_col=event_col, uns_key="test") + else: + sa = sa_function(adata, duration_col=duration_col, event_col=event_col) - sa = sa_function(adata, duration_col, event_col) self._sa_function_assert(sa, sa_class) def test_kmf(self, mimic_2_sa):