From c2a6bf7c1c7c1f4f1cef78d2e4f923da94254a56 Mon Sep 17 00:00:00 2001
From: Julie Alberge
Date: Sat, 20 Jan 2024 17:04:50 +0100
Subject: [PATCH] adding deephit

---
 examples/deep_hit_example.py | 150 ++++++++++++++++++++++++++++++++
 hazardous/_deep_hit.py       | 173 +++++++++++++++++++++++++++++++++++
 2 files changed, 323 insertions(+)
 create mode 100644 examples/deep_hit_example.py
 create mode 100644 hazardous/_deep_hit.py

diff --git a/examples/deep_hit_example.py b/examples/deep_hit_example.py
new file mode 100644
index 0000000..5b86cf3
--- /dev/null
+++ b/examples/deep_hit_example.py
@@ -0,0 +1,150 @@
+# %%
+import numpy as np
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from hazardous._deep_hit import _DeepHit
+from hazardous.data._competing_weibull import make_synthetic_competing_weibull
+
+seed = 0
+independent_censoring = False
+complex_features = True
+
+bunch = make_synthetic_competing_weibull(
+    n_samples=10000,
+    n_events=3,
+    n_features=10,
+    return_X_y=False,
+    independent_censoring=independent_censoring,
+    censoring_relative_scale=1.5,
+    random_state=seed,
+    complex_features=complex_features,
+)
+X, y, y_uncensored = bunch.X, bunch.y, bunch.y_uncensored
+
+censoring_rate = (y["event"] == 0).mean()
+censoring_kind = "independent" if independent_censoring else "dependent"
+ax = sns.histplot(
+    y,
+    x="duration",
+    hue="event",
+    multiple="stack",
+    palette="magma",
+)
+ax.set_title(f"{censoring_kind} censoring rate {censoring_rate:.2%}")
+
+# %%
+# Let's compare the DeepHit marginal incidence to Aalen-Johansen
+# and assess potential biases.
+import warnings
+from tqdm import tqdm
+from sklearn.model_selection import train_test_split
+from lifelines import AalenJohansenFitter
+
+
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)
+deephit = _DeepHit()
+deephit.fit(X_train, y_train)
+
+# %%
+y_pred = deephit.predict_cumulative_incidence(X_test)
+n_events = y["event"].nunique() - 1
+fig, axes = plt.subplots(ncols=n_events, sharey=True, figsize=(12, 5))
+
+for ax, event_id in tqdm(zip(axes, range(1, n_events + 1))):
+    times = deephit.labtrans.cuts
+
+    for idx in range(3):
+        ax.plot(
+            times,
+            y_pred[event_id - 1, idx, :],
+            label=f"DeepHit sample {idx}",
+            linestyle="--",
+        )
+
+    ax.plot(
+        times,
+        y_pred.mean(axis=1)[event_id - 1, :],
+        label="DeepHit marginal",
+        linewidth=3,
+    )
+
+    with warnings.catch_warnings(record=True):
+        # Cause all warnings to always be triggered.
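+        # `record=True` captures the warnings in a list instead of printing
+        # them, keeping the noisy AalenJohansenFitter output quiet.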
+        warnings.simplefilter("always")
+
+        aj = AalenJohansenFitter(calculate_variance=False, seed=seed).fit(
+            durations=y["duration"],
+            event_observed=y["event"],
+            event_of_interest=event_id,
+        )
+
+        aj_uncensored = AalenJohansenFitter(calculate_variance=False, seed=seed).fit(
+            durations=y_uncensored["duration"],
+            event_observed=y_uncensored["event"],
+            event_of_interest=event_id,
+        )
+
+    aj.plot(ax=ax, label="AJ", color="k")
+    aj_uncensored.plot(ax=ax, label="AJ uncensored", color="k", linestyle="--")
+
+    ax.set_title(f"Event {event_id}")
+    ax.grid()
+    ax.legend()
+
+# %%
+from scipy.interpolate import interp1d
+from hazardous.metrics import brier_score_incidence
+
+
+fig, axes = plt.subplots(ncols=n_events, sharey=True, figsize=(12, 5))
+
+times = deephit.labtrans.cuts
+
+for ax, event_id in tqdm(zip(axes, range(1, n_events + 1))):
+    y_pred_event = y_pred[event_id - 1]
+    deephit_brier_score = brier_score_incidence(
+        y_train,
+        y_test,
+        y_pred_event,
+        times,
+        event_of_interest=event_id,
+    )
+
+    ax.plot(times, deephit_brier_score, label="DeepHit Brier score")
+
+    with warnings.catch_warnings(record=True):
+        # Cause all warnings to always be triggered.
+        warnings.simplefilter("always")
+
+        aj = AalenJohansenFitter(calculate_variance=False, seed=seed).fit(
+            durations=y["duration"],
+            event_observed=y["event"],
+            event_of_interest=event_id,
+        )
+
+    times_aj = aj.cumulative_density_.index
+    y_pred_aj = aj.cumulative_density_.to_numpy()[:, 0]
+    y_pred_aj = interp1d(
+        x=times_aj,
+        y=y_pred_aj,
+        kind="linear",
+    )(times)
+
+    y_pred_aj = np.vstack([y_pred_aj for _ in range(X_test.shape[0])])
+
+    aj_brier_score = brier_score_incidence(
+        y_train,
+        y_test,
+        y_pred_aj,
+        times,
+        event_of_interest=event_id,
+    )
+
+    ax.plot(times, aj_brier_score, label="AJ Brier score")
+
+    ax.set_title(f"Event {event_id}")
+    ax.grid()
+    ax.legend()
diff --git a/hazardous/_deep_hit.py b/hazardous/_deep_hit.py
new file mode 100644
index 0000000..fb4159f
--- /dev/null
+++ b/hazardous/_deep_hit.py
@@ -0,0 +1,173 @@
+import numpy as np
+import torch
+import torchtuples as tt
+from pycox.models import DeepHit
+from pycox.preprocessing.label_transforms import LabTransDiscreteTime
+from sklearn.model_selection import train_test_split
+
+SEED = 0
+
+np.random.seed(1234)
+_ = torch.manual_seed(1234)
+
+
+class LabTransform(LabTransDiscreteTime):
+    def transform(self, durations, events):
+        durations, is_event = super().transform(durations, events > 0)
+        events[is_event == 0] = 0
+        return durations, events.astype("int64")
+
+
+class CauseSpecificNet(torch.nn.Module):
+    """Network structure similar to the DeepHit paper, but without the residual
+    connections (for simplicity).
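+
+    A shared MLPVanilla trunk feeds one small sub-network per competing
+    risk; the per-risk outputs are stacked into a tensor of shape
+    (batch_size, num_risks, out_features).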
+    """
+
+    def __init__(
+        self,
+        in_features,
+        num_nodes_shared,
+        num_nodes_indiv,
+        num_risks,
+        out_features,
+        batch_norm=True,
+        dropout=None,
+    ):
+        super().__init__()
+        self.shared_net = tt.practical.MLPVanilla(
+            in_features,
+            num_nodes_shared[:-1],
+            num_nodes_shared[-1],
+            batch_norm,
+            dropout,
+        )
+        self.risk_nets = torch.nn.ModuleList()
+        for _ in range(num_risks):
+            net = tt.practical.MLPVanilla(
+                num_nodes_shared[-1],
+                num_nodes_indiv,
+                out_features,
+                batch_norm,
+                dropout,
+            )
+            self.risk_nets.append(net)
+
+    def forward(self, input):
+        out = self.shared_net(input)
+        out = [net(out) for net in self.risk_nets]
+        out = torch.stack(out, dim=1)
+        return out
+
+
+def get_x(df):
+    return df.values.astype("float32")
+
+
+def get_target(df):
+    return (
+        df["duration"].astype("float32").values,
+        df["event"].astype("int32").values,
+    )
+
+
+class _DeepHit:
+    def __init__(
+        self,
+        num_nodes_shared=[64, 64],
+        num_nodes_indiv=[32],
+        batch_size=256,
+        epochs=512,
+        callbacks=[tt.callbacks.EarlyStoppingCycle()],
+        verbose=False,
+        num_durations=10,
+        batch_norm=True,
+        dropout=None,
+        alpha=0.2,
+        sigma=0.1,
+        optimizer=tt.optim.AdamWR(
+            lr=0.01, decoupled_weight_decay=0.01, cycle_eta_multiplier=0.8
+        ),
+    ):
+        self.num_durations = num_durations
+        self.num_nodes_shared = num_nodes_shared
+        self.num_nodes_indiv = num_nodes_indiv
+        self.batch_norm = batch_norm
+        self.dropout = dropout
+        self.alpha = alpha
+        self.sigma = sigma
+        self.optimizer = optimizer
+        self.batch_size = batch_size
+        self.epochs = epochs
+        self.callbacks = callbacks
+        self.verbose = verbose
+
+    def fit(self, X, y):
+        X_train_, X_val_, y_train_, y_val_ = train_test_split(
+            X, y, test_size=0.2, random_state=SEED
+        )
+
+        X_train = get_x(X_train_)
+        X_val = get_x(X_val_)
+        y_train = get_target(y_train_)
+        y_val = get_target(y_val_)
+
+        self.labtrans = LabTransform(self.num_durations)
+
+        y_train = self.labtrans.fit_transform(*y_train)
+        y_val = self.labtrans.transform(*y_val)
+        self.in_features = X_train.shape[1]
+        self.num_risks = y_train[1].max()
+
+        self.net = CauseSpecificNet(
+            in_features=self.in_features,
+            num_nodes_shared=self.num_nodes_shared,
+            num_nodes_indiv=self.num_nodes_indiv,
+            num_risks=self.num_risks,
+            out_features=len(self.labtrans.cuts),
+            batch_norm=self.batch_norm,
+            dropout=self.dropout,
+        )
+
+        self.model = DeepHit(
+            net=self.net,
+            optimizer=self.optimizer,
+            alpha=self.alpha,
+            sigma=self.sigma,
+            duration_index=self.labtrans.cuts,
+        )
+
+        self.model.fit(
+            X_train,
+            y_train,
+            batch_size=self.batch_size,
+            epochs=self.epochs,
+            callbacks=self.callbacks,
+            verbose=self.verbose,
+            val_data=(X_val, y_val),
+        )
+
+    def predict_survival_function(self, X):
+        X_ = get_x(X)
+        return self.model.predict_surv_df(X_)
+
+    def predict_cumulative_incidence(self, X):
+        X_ = get_x(X)
+        cifs = self.model.predict_cif(X_)
+        # predict_cif returns (n_events, n_times, n_samples); move samples to axis 1.
+        cifs = np.swapaxes(cifs, 1, 2)
+        return cifs
+
+    def predict_proba(self, X):
+        X_ = get_x(X)
+        return self.model.predict_pmf(X_)
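+
+
+# Usage sketch, mirroring examples/deep_hit_example.py:
+#   model = _DeepHit()
+#   model.fit(X_train, y_train)
+#   cifs = model.predict_cumulative_incidence(X_test)
+#   # cifs has shape (n_events, n_samples, n_time_grid_points)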