
# DeepSurvival

Deep Learning models for Survival Analysis


This repo began as an attempt to replicate DeepSurv[^1] in PyTorch, an early modern deep-learning approach to survival analysis that has proven effective. It may later be extended to replicate other models.

This approach is appealing in that it is still based on the Cox Proportional Hazards Model, yet it can accommodate neural network architectures to predict time-to-event data. It also opens up possibilities for enhancements explored in subsequent deep-learning survival models, such as Cox-Time[^2], which lets go of the proportional hazards assumption, and LogisticHazard, which is based on a discrete-time model.

> [!TIP]
> A comprehensive list of deep learning models for survival analysis can be found in the GitHub repo survival-org/DL4Survival, from Wiegrebe et al. (2024)[^3].

## Ideas

The model is in fact pretty simple; it requires some understanding of how the Cox model works and some ideas from neural network optimization. As this is a PyTorch implementation, the concepts will be portrayed in raw form within the code, involving components like a custom loss function.

### CPHM

As defined by the authors[^1], the CPHM assumes a linear combination of covariates, and the hazard function is defined as: $$h(t|x) = h_0(t) \exp(\beta^T x)$$

Where $h_0(t)$ is the baseline hazard function and $\beta$ is the vector of coefficients. Somewhat as in linear models, the model's coefficients are optimized by maximizing the partial likelihood, which is defined as: $$PL(\beta) = \prod_{i=1}^n \left( \frac{\exp(\beta^T x_i)}{\sum_{j \in R(t_i)} \exp(\beta^T x_j)} \right)^{\delta_i}$$ Here $R(t_i)$ is the risk set at time $t_i$ (those still at risk), and $\delta_i$ is the event indicator (1 if the event occurred, 0 if censored), an important quantity to account for when implementing the network.
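As a tiny worked example, the log partial likelihood can be computed directly from this formula; the times, event indicators, and linear predictor values below are made up for illustration, and ties are assumed absent:

```python
import numpy as np

# Toy data: 3 subjects, with observed times, event indicators,
# and (hypothetical) linear predictors eta_i = beta^T x_i.
times  = np.array([2.0, 3.0, 5.0])
events = np.array([1, 0, 1])        # 1 = event observed, 0 = censored
eta    = np.array([0.5, -0.2, 0.1])

log_pl = 0.0
for i in range(len(times)):
    if events[i] == 1:
        # Risk set R(t_i): everyone still under observation at t_i
        risk = times >= times[i]
        log_pl += eta[i] - np.log(np.exp(eta[risk]).sum())
```

The censored subject (index 1) contributes no term of its own but still appears in the risk sets of earlier events.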

### DeepSurv

The idea of DeepSurv is to replace the linear combination of covariates with a non-linear function, much like a neural network replacing a linear model. The hazard function is defined as: $$h(t|x) = h_0(t) \exp(g(x))$$ where the new element is $g(x)$, a non-linear function of the covariates that can be implemented as a neural network. Much like maximizing the partial likelihood, but reversed to suit the neural network setting, the objective here is a loss to be minimized, defined as the negative log partial likelihood: $$\ell(\theta) = -\sum_{i=1}^n \delta_i \left( g(x_i) - \log \sum_{j \in R(t_i)} \exp(g(x_j)) \right)$$ where $\theta$ are the parameters of the network and $g(x)$ is its output for input $x$.
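A minimal sketch of this loss in PyTorch, assuming no tied event times (the class name matches the one used in this README, but the body is an illustrative guess, not the repo's actual code):

```python
import torch
import torch.nn as nn

class NegativeLogPartialLikelihood(nn.Module):
    """Sketch of the negative log partial likelihood (no tied times)."""

    def forward(self, log_risk, time, event):
        # Sort descending by time so that a cumulative sum over the
        # sorted order runs exactly over each subject's risk set R(t_i).
        order = torch.argsort(time, descending=True)
        log_risk, event = log_risk[order], event[order]
        # log sum_{j in R(t_i)} exp(g(x_j)), computed stably
        log_cumsum = torch.logcumsumexp(log_risk, dim=0)
        n_events = event.sum().clamp(min=1.0)
        # Sum over observed events only (delta_i = 1)
        return -((log_risk - log_cumsum) * event).sum() / n_events
```

Dividing by the number of events is a common normalization so the loss scale is batch-size independent; the plain sum matches the formula above exactly.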

> [!NOTE]
> The summation is equivalent to the log transformation of the product in the original partial likelihood, and the negative sign converts the maximization problem into a minimization one, since we are looking to minimize a loss.
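Explicitly, taking the log of the product form of the partial likelihood turns the product into a sum and the exponent $\delta_i$ into a factor:

$$\log PL(\beta) = \sum_{i=1}^n \delta_i \left( \beta^T x_i - \log \sum_{j \in R(t_i)} \exp(\beta^T x_j) \right)$$

Substituting $g(x_i)$ for $\beta^T x_i$ and negating recovers $\ell(\theta)$.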

## Implementation

Essentially, the main components are:

- a dataset that accommodates the nature of the data (time-to-event) and the loss function (negative log partial likelihood); an item of which is a tuple of (features, event, time)
- a neural network model class inheriting from `torch.nn.Module` (as for any PyTorch model) $\iff g(x)$
- a custom loss function for the negative log partial likelihood $\iff \ell(\theta)$
- a training loop to optimize the model parameters using the defined loss function

To be able to do this, the main extensions are in the dataset and loss modules, where new classes inheriting from `torch.utils.data.Dataset` and `torch.nn.Module`, respectively, are defined:

- `SurvivalDataset`: takes as input the features, event indicators, and times, and implements the `__getitem__` method to return a tuple of (features, event, time) for each index (equivalent to standard `Dataset` behavior in PyTorch)
- `NegativeLogPartialLikelihood`: implements the `forward` method to compute the negative log partial likelihood loss

```python
dataset = SurvivalDataset(X, events, times)
loss_fn = NegativeLogPartialLikelihood()
```
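A minimal sketch of what `SurvivalDataset` could look like; the constructor signature and tensor dtypes are assumptions for illustration, not the repo's actual code:

```python
import torch
from torch.utils.data import Dataset

class SurvivalDataset(Dataset):
    """Sketch: wraps (features, event, time) triples for a DataLoader."""

    def __init__(self, X, events, times):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.events = torch.as_tensor(events, dtype=torch.float32)
        self.times = torch.as_tensor(times, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.events[idx], self.times[idx]
```

This plugs directly into a standard `torch.utils.data.DataLoader`, which will batch the three tensors in parallel.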

The Torch-facing functionality, from `__len__` and `__getitem__` to the loss's `forward`, is tested with pytest in the `tests` module.

Any neural network architecture can be used to produce $g(x)$, the risk score (or, to be precise, the log-risk score). For the user, the interface is the same as using PyTorch for any classification problem. Some things are implemented around it in the dataset and loss modules to accommodate the nature of the target ((time, event) pairs) as well as the loss function. A function is also included to stratify risk groups and plot Kaplan-Meier curves, which are commonly used to visualize the results of survival models.
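Putting the pieces together, a hypothetical end-to-end training loop could look like the following; the synthetic data, architecture, and hyperparameters are all illustrative, and the loss is re-implemented inline so the snippet is self-contained:

```python
import torch
import torch.nn as nn

def neg_log_partial_likelihood(log_risk, time, event):
    # Negative log partial likelihood, assuming no tied event times
    order = torch.argsort(time, descending=True)
    lr, ev = log_risk[order], event[order]
    log_cumsum = torch.logcumsumexp(lr, dim=0)
    return -((lr - log_cumsum) * ev).sum() / ev.sum().clamp(min=1.0)

torch.manual_seed(0)
n, p = 64, 5
X = torch.randn(n, p)
beta = torch.randn(p)
times = torch.rand(n) * torch.exp(-X @ beta)   # higher risk -> shorter time
events = (torch.rand(n) < 0.8).float()         # roughly 20% censoring

# g(x): any architecture works; its scalar output is the log-risk score
g = nn.Sequential(nn.Linear(p, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.Adam(g.parameters(), lr=1e-2)

first_loss = None
for epoch in range(100):
    optimizer.zero_grad()
    loss = neg_log_partial_likelihood(g(X).squeeze(-1), times, events)
    loss.backward()
    optimizer.step()
    if first_loss is None:
        first_loss = loss.item()
```

Risk stratification afterwards just thresholds `g(X)` (e.g. at its median) into high- and low-risk groups before plotting Kaplan-Meier curves for each.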

To test the implementation:

```bash
pip install .
python -m src.main
```

## Beyond DeepSurv

- CoxTime

[^1]: Katzman, Jared L., et al. "DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network." *BMC Medical Research Methodology* 18.1 (2018): 1-12.

[^2]: Kvamme, Håvard, Ørnulf Borgan, and Ida Scheel. "Time-to-event prediction with neural networks and Cox regression." *Journal of Machine Learning Research* 20.129 (2019): 1-30.

[^3]: Wiegrebe, S., et al. "Deep learning for survival analysis: a review." *Artificial Intelligence Review* 57.3 (2024): 65.
