Survival Analysis using pytorch
This library takes a deep learning approach to Survival Analysis.
A lot of classification problems are actually survival analysis problems and haven't been tackled as such. For example, consider a cancer patient and you take X-ray data from that patient. Over time, patients will eventually die from cancer (lets ignore the case where people will die from other diseases). The usual approach is to say here is the X-ray (x) and will the patient die in the next 30 days or not (y).
Survival analysis instead asks the question given the input (x) and a time(t), what is the probability that a patient will survive for a time greater than t. Considering the training dataset if a patient is still alive, in the classification case it would be thought of as y = 0. In survival analysis we say that it is a censored observation since the patient will die at a certain time in the future when the experiment is not being conducted.
The above analogy can be thought of in other scenarios such as churn prediction as well.
A proper dive into theory can be seen here.
Well, if you torch a life... you probability wouldn't survive. 😬
There are 3 models in here that can be used.
- [Kaplan Meier Model]
- [Proportional Hazard Models]
- [Accelerated Failure Time Models]
All 3 models require you to input a pandas dataframe with the columns
"t", "e"
indicating time elapsed and a binary variable, event if a death (1) or live (0) instance is observed respectively. They are all capable of doingfit
andplot_survival_function
.
This is the most simplistic model.
This model attempts to model the instantaneous hazard of an instance given time. It does this by binning time and finding the cumulative hazard upto a given point in time. It's extention the cox model, takes into account other variables that are not time dependant such that the above mentioned hazard can grow or shrink proportional to the risk associated with non-temporal feature based hazard.
from torchlife.model import ModelHazard
model = ModelHazard('cox')
model.fit(df)
inst_hazard, surv_probability = model.predict(df)
This model attempts to model (the mode and not average) time such that it is a function of the non-temporal features, x.
You are free to choose the distribution of the error, however, Gumbel
distribution is a popular option in SA literature. The distribution needs to be over the real domain and not just positive since we are modelling log time.
from torchlife.model import ModelAFT
model = ModelAFT('Gumbel')
model.fit(df)
surv_prob = model.predict(df)
mode_time = model.predict_time(df)
Special thanks to the following libraries and resources.
- lifelines and especially Cameron Davidson-Pilon
- pymc3 survival analysis examples
- nbdev
- pytorch lightning
- Generalised Linear Models by Germán Rodríguez, Chapter 7.
pip install torchlife
We need a dataframe that has a column named 't' indicating time, and 'e' indicating a death event.
import pandas as pd
import numpy as np
url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
df = pd.read_csv(url)
df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)
df.head()
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
t | e | fin | age | race | wexp | mar | paro | prio | |
---|---|---|---|---|---|---|---|---|---|
0 | 20 | 1 | 0 | 27 | 1 | 0 | 0 | 1 | 3 |
1 | 17 | 1 | 0 | 18 | 1 | 0 | 0 | 1 | 8 |
2 | 25 | 1 | 0 | 19 | 0 | 1 | 0 | 1 | 13 |
3 | 52 | 0 | 1 | 23 | 1 | 1 | 1 | 1 | 1 |
4 | 52 | 0 | 0 | 19 | 0 | 1 | 0 | 1 | 3 |
from torchlife.model import ModelHazard
model = ModelHazard('cox', lr=0.5)
model.fit(df)
λ, S = model.predict(df)
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
| Name | Type | Params
--------------------------------------------
0 | base | ProportionalHazard | 12
--------------------------------------------
12 Trainable params
0 Non-trainable params
12 Total params
Epoch 0: 75%|███████▌ | 3/4 [00:00<00:00, 25.43it/s, loss=nan, v_num=49]
Epoch 0: 100%|██████████| 4/4 [00:00<00:00, 16.39it/s, loss=nan, v_num=49]
Epoch 1: 75%|███████▌ | 3/4 [00:00<00:00, 23.74it/s, loss=nan, v_num=49]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 15.25it/s, loss=nan, v_num=49]
Epoch 2: 75%|███████▌ | 3/4 [00:00<00:00, 22.98it/s, loss=nan, v_num=49]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 15.53it/s, loss=nan, v_num=49]
Epoch 3: 75%|███████▌ | 3/4 [00:00<00:00, 20.23it/s, loss=nan, v_num=49]
Validating: 0it [00:00, ?it/s]�[A
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 13.30it/s, loss=nan, v_num=49]
�[ASaving latest checkpoint...
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 12.92it/s, loss=nan, v_num=49]
Let's plot the survival function for the 4th element in the dataframe:
x = df.drop(['t', 'e'], axis=1).iloc[2]
t = np.arange(df['t'].max())
model.plot_survival_function(t, x)