Commit
adding deephit
juAlberge committed Jan 20, 2024
1 parent b66721b commit c2a6bf7
Showing 2 changed files with 309 additions and 0 deletions.
148 changes: 148 additions & 0 deletions examples/deep_hit_example.py
@@ -0,0 +1,148 @@
# %%
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from hazardous._deep_hit import _DeepHit
from hazardous.data._competing_weibull import make_synthetic_competing_weibull

seed = 0
independent_censoring = False
complex_features = True

bunch = make_synthetic_competing_weibull(
    n_samples=10000,
    n_events=3,
    n_features=10,
    return_X_y=False,
    independent_censoring=independent_censoring,
    censoring_relative_scale=1.5,
    random_state=seed,
    complex_features=complex_features,
)
X, y, y_uncensored = bunch.X, bunch.y, bunch.y_uncensored

censoring_rate = (y["event"] == 0).mean()
censoring_kind = "independent" if independent_censoring else "dependent"
ax = sns.histplot(
    y,
    x="duration",
    hue="event",
    multiple="stack",
    palette="magma",
)
ax.set_title(f"{censoring_kind} censoring rate {censoring_rate:.2%}")

# %%
# Let's compare the DeepHit marginal incidence to the Aalen-Johansen
# estimate and assess potential biases.
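#
# Note: "marginal incidence" here means the average of the predicted
# conditional cumulative incidence curves over the test samples
# (y_pred.mean(axis=1)[event_id - 1] below); for an unbiased model it
# should track the Aalen-Johansen estimate.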
import warnings
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from lifelines import AalenJohansenFitter


X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)
deephit = _DeepHit()
deephit.fit(X_train, y_train)

# %%
y_pred = deephit.predict_cumulative_incidence(X_test)
n_events = y["event"].nunique() - 1
fig, axes = plt.subplots(ncols=n_events, sharey=True, figsize=(12, 5))

for ax, event_id in tqdm(zip(axes, range(1, n_events + 1))):
    times = deephit.labtrans.cuts

    for idx in range(3):
        ax.plot(
            times,
            y_pred[event_id - 1, idx, :],
            label=f"DeepHit sample {idx}",
            linestyle="--",
        )

    ax.plot(
        times,
        y_pred.mean(axis=1)[event_id - 1, :],
        label="DeepHit marginal",
        linewidth=3,
    )

    with warnings.catch_warnings(record=True):
        # Record, and thereby silence, the warnings (e.g. about tied
        # durations) raised by AalenJohansenFitter.
        warnings.simplefilter("always")

        aj = AalenJohansenFitter(calculate_variance=False, seed=seed).fit(
            durations=y["duration"],
            event_observed=y["event"],
            event_of_interest=event_id,
        )

        aj_uncensored = AalenJohansenFitter(calculate_variance=False, seed=seed).fit(
            durations=y_uncensored["duration"],
            event_observed=y_uncensored["event"],
            event_of_interest=event_id,
        )

    aj.plot(ax=ax, label="AJ", color="k")
    aj_uncensored.plot(ax=ax, label="AJ uncensored", color="k", linestyle="--")

    ax.set_title(f"Event {event_id}")
    ax.grid()
    ax.legend()

# %%
from scipy.interpolate import interp1d
from hazardous.metrics import brier_score_incidence
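# Informal reminder of what the metric computes: at each time horizon,
# brier_score_incidence compares the predicted incidence of the event of
# interest to the observed binary outcome, reweighting samples by inverse
# probability of censoring weights (IPCW); lower values are better.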


fig, axes = plt.subplots(ncols=n_events, sharey=True, figsize=(12, 5))

times = deephit.labtrans.cuts

for ax, event_id in tqdm(zip(axes, range(1, n_events + 1))):
    y_pred_event = y_pred[event_id - 1]
    deephit_brier_score = brier_score_incidence(
        y_train,
        y_test,
        y_pred_event,
        times,
        event_of_interest=event_id,
    )

    ax.plot(times, deephit_brier_score, label="DeepHit Brier score")

    with warnings.catch_warnings(record=True):
        # Record, and thereby silence, the warnings raised by
        # AalenJohansenFitter.
        warnings.simplefilter("always")

        aj = AalenJohansenFitter(calculate_variance=False, seed=seed).fit(
            durations=y["duration"],
            event_observed=y["event"],
            event_of_interest=event_id,
        )

    times_aj = aj.cumulative_density_.index
    y_pred_aj = aj.cumulative_density_.to_numpy()[:, 0]
    y_pred_aj = interp1d(
        x=times_aj,
        y=y_pred_aj,
        kind="linear",
    )(times)

    y_pred_aj = np.vstack([y_pred_aj for _ in range(X_test.shape[0])])

    aj_brier_score = brier_score_incidence(
        y_train,
        y_test,
        y_pred_aj,
        times,
        event_of_interest=event_id,
    )

    ax.plot(times, aj_brier_score, label="AJ Brier score")

    ax.set_title(f"Event {event_id}")
    ax.grid()
    ax.legend()
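
# %%
# Extra illustrative step (not in the original example): reduce the Brier
# score curves of the last plotted event to a single time-integrated value.
ibs_deephit = np.trapz(deephit_brier_score, times) / (times[-1] - times[0])
ibs_aj = np.trapz(aj_brier_score, times) / (times[-1] - times[0])
print(f"IBS event {event_id} - DeepHit: {ibs_deephit:.4f}, AJ: {ibs_aj:.4f}")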
161 changes: 161 additions & 0 deletions hazardous/_deep_hit.py
@@ -0,0 +1,161 @@
import numpy as np
import torch
import torchtuples as tt
from pycox.models import DeepHit
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from sklearn.model_selection import train_test_split

SEED = 0

np.random.seed(1234)
_ = torch.manual_seed(1234)


class LabTransform(LabTransDiscreteTime):
    def transform(self, durations, events):
        durations, is_event = super().transform(durations, events > 0)
        events[is_event == 0] = 0
        return durations, events.astype("int64")
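
# Behavior note (illustrative): the parent LabTransDiscreteTime discretizes
# durations into `num_durations` bins and works with a binary event
# indicator; the override above restores the competing event labels (1..K)
# and forces the label to 0 wherever the sample was marked as censored.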


class CauseSpecificNet(torch.nn.Module):
    """Network structure similar to the DeepHit paper, but without the residual
    connections (for simplicity).
    """

    def __init__(
        self,
        in_features,
        num_nodes_shared,
        num_nodes_indiv,
        num_risks,
        out_features,
        batch_norm=True,
        dropout=None,
    ):
        super().__init__()
        self.shared_net = tt.practical.MLPVanilla(
            in_features,
            num_nodes_shared[:-1],
            num_nodes_shared[-1],
            batch_norm,
            dropout,
        )
        self.risk_nets = torch.nn.ModuleList()
        for _ in range(num_risks):
            net = tt.practical.MLPVanilla(
                num_nodes_shared[-1],
                num_nodes_indiv,
                out_features,
                batch_norm,
                dropout,
            )
            self.risk_nets.append(net)

    def forward(self, input):
        out = self.shared_net(input)
        out = [net(out) for net in self.risk_nets]
        out = torch.stack(out, dim=1)
        return out
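
# Shape sketch (illustrative values, not prescriptive): with
#   net = CauseSpecificNet(10, [64, 64], [32], num_risks=3, out_features=10)
# an input batch of shape (batch_size, 10) yields an output of shape
# (batch_size, 3, 10), i.e. (batch, n_risks, n_time_bins), the layout
# pycox's DeepHit expects for competing risks.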


def get_x(df):
    return df.values.astype("float32")


def get_target(df):
    return (
        df["duration"].astype("float32").values,
        df["event"].astype("int32").values,
    )


class _DeepHit:
    def __init__(
        self,
        num_nodes_shared=[64, 64],
        num_nodes_indiv=[32],
        batch_size=256,
        epochs=512,
        callbacks=[tt.callbacks.EarlyStoppingCycle()],
        verbose=False,
        num_durations=10,
        batch_norm=True,
        dropout=None,
        alpha=0.2,
        sigma=0.1,
        optimizer=tt.optim.AdamWR(
            lr=0.01, decoupled_weight_decay=0.01, cycle_eta_multiplier=0.8
        ),
    ):
        self.num_durations = num_durations
        self.num_nodes_shared = num_nodes_shared
        self.num_nodes_indiv = num_nodes_indiv
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.alpha = alpha
        self.sigma = sigma
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.epochs = epochs
        self.callbacks = callbacks
        self.verbose = verbose

    def fit(self, X, y):
        X_train_, X_val_, y_train_, y_val_ = train_test_split(
            X, y, test_size=0.2, random_state=SEED
        )

        X_train = get_x(X_train_)
        X_val = get_x(X_val_)
        y_train = get_target(y_train_)
        y_val = get_target(y_val_)

        self.labtrans = LabTransform(self.num_durations)

        y_train = self.labtrans.fit_transform(*y_train)
        y_val = self.labtrans.transform(*y_val)
        self.in_features = X_train.shape[1]
        self.num_risks = y_train[1].max()

        self.net = CauseSpecificNet(
            in_features=self.in_features,
            num_nodes_shared=self.num_nodes_shared,
            num_nodes_indiv=self.num_nodes_indiv,
            num_risks=self.num_risks,
            out_features=len(self.labtrans.cuts),
            batch_norm=self.batch_norm,
            dropout=self.dropout,
        )

        self.model = DeepHit(
            net=self.net,
            optimizer=self.optimizer,
            alpha=self.alpha,
            sigma=self.sigma,
            duration_index=self.labtrans.cuts,
        )

        self.model.fit(
            X_train,
            y_train,
            batch_size=self.batch_size,
            epochs=self.epochs,
            callbacks=self.callbacks,
            verbose=self.verbose,
            val_data=(X_val, y_val),
        )
        # Follow the scikit-learn convention of returning the fitted estimator.
        return self

    def predict_survival_function(self, X):
        X_ = get_x(X)
        return self.model.predict_surv_df(X_)

    def predict_cumulative_incidence(self, X):
        X_ = get_x(X)
        cifs = self.model.predict_cif(X_)
        # predict_cif returns (n_risks, n_times, n_samples); reorder to
        # (n_risks, n_samples, n_times).
        cifs = np.swapaxes(cifs, 1, 2)
        return cifs

    def predict_proba(self, X):
        X_ = get_x(X)
        return self.model.predict_pmf(X_)
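
# Minimal usage sketch (mirroring examples/deep_hit_example.py):
#   model = _DeepHit(num_durations=10)
#   model.fit(X_train, y_train)  # y_*: DataFrame with "duration" and "event"
#   cifs = model.predict_cumulative_incidence(X_test)
#   # cifs has shape (n_risks, n_samples, n_time_bins)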
