Deep Learning models for Survival Analysis
This repo comes at first as an attempt to replicate DeepSurv[^1] using PyTorch, a first modern deep-learning approach that has proved to be effective. It is to be extended to replicate other models.
This approach is cool in the sense that it's still based on the Cox Proportional Hazards Model, yet can accommodate neural network architectures to predict time-to-event data. It also opens up possibilities for enhancements in deep learning survival models, as defined in subsequent models like Cox-Time[^2], which lets go of the proportional hazards assumption, and LogisticHazard, which is based on a discrete-time model.
> [!TIP]
> A super comprehensive list of deep learning models for survival analysis can be found in the GitHub repo survival-org/DL4Survival from Wiegrebe et al. (2024)[^3].
The model is in fact pretty simple, and requires some understanding of how the Cox model works, plus some ideas on neural network optimization. As this is a PyTorch implementation, concepts will be portrayed in a raw form within code, involving components like creating a custom loss function.
As defined by the authors[^1], the Cox Proportional Hazards Model (CPHM) assumes a linear combination of covariates, and the hazard function is defined as:

$$h(t \mid x) = h_0(t)\, e^{\beta^\top x}$$

Where $h_0(t)$ is the baseline hazard, $x$ the covariates, and $\beta$ the vector of coefficients.
The idea of DeepSurv is to replace the linear combination of covariates with a non-linear function, pretty much like a neural network replacing a linear model.
The hazard function is then defined as:

$$h(t \mid x) = h_0(t)\, e^{g_\theta(x)}$$

where $g_\theta(x)$ is the output of a neural network with parameters $\theta$.
The model is trained by minimizing the negative log partial likelihood:

$$\ell(\theta) = -\sum_{i : E_i = 1} \left( g_\theta(x_i) - \log \sum_{j \in \mathcal{R}(T_i)} e^{g_\theta(x_j)} \right)$$

where $E_i$ is the event indicator and $\mathcal{R}(T_i)$ is the risk set of individuals still at risk at time $T_i$.

> [!NOTE]
> The summation is equivalent to the log transformation of the product in the original partial likelihood, and the negative sign is there to convert the maximization problem into a minimization one, since we're looking to minimize a loss.
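As a minimal sketch of this loss in PyTorch (sorting by descending time so each sample's risk set is a prefix of the batch; the class name matches the one used later, but the exact implementation in the repo may differ):

```python
import torch
import torch.nn as nn

class NegativeLogPartialLikelihood(nn.Module):
    """Negative log partial likelihood for Cox-type models (Breslow-style ties)."""

    def forward(self, log_risk: torch.Tensor, event: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        # Sort by descending time so the risk set of sample i is samples [0..i]
        order = torch.argsort(time, descending=True)
        log_risk = log_risk[order].squeeze(-1)
        event = event[order].float()
        # log of the running sum of exp(g) = log of the risk-set denominator
        log_cumsum = torch.logcumsumexp(log_risk, dim=0)
        # sum over uncensored samples only; negate to minimize
        return -((log_risk - log_cumsum) * event).sum() / event.sum()
```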
Essentially, the main components are:

- a dataset that accommodates the nature of the data (time-to-event) and the loss function (negative log partial likelihood); an item of which is a tuple of (features, event, time)
- a neural network model class inheriting from `torch.nn.Module` (as is the case for any PyTorch model) $\iff g_\theta(x)$
- a custom loss function for the negative log partial likelihood $\iff \ell(\theta)$
- a training loop to optimize the model parameters using the defined loss function
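As a sketch, the network computing $g_\theta(x)$ can be any `nn.Module`; a simple MLP might look like the following (the class name and layer sizes are illustrative, not necessarily the repo's):

```python
import torch
import torch.nn as nn

class DeepSurvNet(nn.Module):
    """MLP mapping covariates x to a scalar log-risk g(x)."""

    def __init__(self, in_features: int, hidden: int = 32, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),  # single output: the log-risk g(x)
        )

    def forward(self, x):
        return self.net(x)
```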
So to be able to do this, the main extensions are in the dataset and loss modules, where new classes inheriting from `torch.utils.data.Dataset` and `torch.nn.Module` are defined, respectively:

- `SurvivalDataset`: takes as input the features, event indicators, and times, and implements the `__getitem__` method to return a tuple of (features, event, time) for each index (equivalent to standard `Dataset` behavior in PyTorch)
- `NegativeLogPartialLikelihood`: implements the `forward` method to compute the negative log partial likelihood
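A minimal sketch of such a dataset (argument handling and dtypes are illustrative; the repo's actual class may differ):

```python
import torch
from torch.utils.data import Dataset

class SurvivalDataset(Dataset):
    """Wraps (features, event indicator, time) triples for use with a DataLoader."""

    def __init__(self, X, events, times):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.events = torch.as_tensor(events, dtype=torch.float32)
        self.times = torch.as_tensor(times, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.events[idx], self.times[idx]
```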
```python
dataset = SurvivalDataset(X, events, times)
loss_fn = NegativeLogPartialLikelihood()
```

The PyTorch functionalities, from `__len__` and `__getitem__` to the loss `forward`, are tested using pytest in the `tests` module.
Any neural network architecture can be used to get $g_\theta(x)$.
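Putting the pieces together, a minimal training loop might look like this (toy data, network size, and hyperparameters are all illustrative; the loss is computed inline with the same logic as the loss class rather than importing from the repo):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# toy data: 64 samples, 5 covariates, binary event indicators, positive times
X = torch.randn(64, 5)
events = (torch.rand(64) < 0.7).float()
times = torch.rand(64) * 10

model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(50):
    optimizer.zero_grad()
    order = torch.argsort(times, descending=True)  # risk set of i = samples [0..i]
    g = model(X[order]).squeeze(-1)                # log-risk g(x)
    e = events[order]
    loss = -((g - torch.logcumsumexp(g, dim=0)) * e).sum() / e.sum()
    loss.backward()
    optimizer.step()
```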
To test the implementation:

```shell
pip install .
python -m src.main
```
[^1]: Katzman, Jared L., et al. "DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network." *BMC Medical Research Methodology* 18.1 (2018): 1-12.
[^2]: Kvamme, Håvard, Ørnulf Borgan, and Ida Scheel. "Time-to-event prediction with neural networks and Cox regression." *Journal of Machine Learning Research* 20.129 (2019): 1-30.
[^3]: Wiegrebe, S., Kopper, P., Sonabend, R., Bischl, B., & Bender, A. (2024). Deep learning for survival analysis: a review. *Artificial Intelligence Review*, 57(3), 65.
