1 Introduction

The primary task of survival analysis is to determine the timing of one or multiple events, which may signify the moment a mechanical system malfunctions, the point at which a company moves from deficit to surplus, the time of a patient's death, and so on, depending on the specific circumstance (Lee & Whitmore, 2006). Among these scenarios, survival analysis for medical data poses the most severe challenges (Collett, 2023). Some medical datasets are longitudinal, as exemplified by electronic health records (EHRs), where multiple observations of each patient's covariates are recorded over time. Survival models must be capable of handling such measurements and learning from their continuous temporal trends. Moreover, observations in longitudinal data are often sparse, so any reliable survival model must handle missing values effectively, even when the missing rates are exceedingly high (Singer & Willett, 1991). Additionally, censoring is a fundamental aspect of survival data, referring to cases in which complete information about a subject's survival time or event occurrence is not fully observed within the study period (Leung et al., 1997). Censoring means the exact timing of the event is unknown, so there is no ground truth for comparative learning, which in turn poses significant challenges for deep survival learning. Existing deep learning approaches typically mitigate this issue by guaranteeing that no event occurs before the censoring time; however, the temporal behaviour of events after censoring frequently remains inadequately modelled.

Developing survival analysis models requires regressing the probability of survival over a defined period, and a high-quality estimate of the probability distribution is essential for time-to-event prediction. Parametric survival models, the first category, can generate a high-quality probability density function (PDF) or survival curve by predetermining a stochastic distribution; however, their precision is contingent on the validity of all underlying assumptions. In contrast, non-parametric models do not presume any prior distribution of events, but they struggle to accurately predict the PDF over extended temporal spans in medical datasets, consequently yielding a PDF or survival curve of comparatively lower quality.

To address the challenges in the development of survival models and to mitigate the limitations inherent in existing models, we propose UniSurv, a non-parametric model based on the Transformer architecture. In particular, UniSurv can: 1) generate a higher-quality PDF resembling a normal distribution without any prior probability assumption, and significantly improve accuracy for predicting censored cases by integrating a novel Margin-Mean-Variance loss; 2) use distinct embedding branches for static and dynamic feature extraction; 3) effectively handle longitudinal data with high missing rates as well as various data modalities. The superiority of UniSurv is substantiated through empirical evidence obtained from real and synthetic datasets.

2 Literature

Semi- and fully parametric models rely heavily on explicit assumptions about the underlying distribution of event times. They provide a structured framework for understanding the relationship between covariates and the occurrence of events over time. However, the strength of these assumptions results in overly simplistic probability distributions predicted by the models, and the resulting lack of flexibility renders these models impractical in various scenarios. The Cox proportional hazards (CPH) model (Cox, 1972) is a prime example in this field. It estimates the hazard function \(\lambda (t|X)\) by multiplying a predetermined baseline hazard function \(\lambda _{0}(t)\) with the learnt representation of features g(X). Subsequent studies (Faraggi & Simon, 1995; Vinzamuri & Reddy, 2013; Luck et al., 2017) have used more sophisticated models to improve the CPH model. However, the oversimplified stochastic process continues to constrain their predictive capabilities, and CPH cannot perform dynamic analysis. Meanwhile, Nagpal et al. (2021b) introduced Deep Survival Machines (DSM), which postulates that the survival function is a composition of multiple Weibull and log-normal distributions, whose parameters are estimated by a multi-layer perceptron (MLP). Nagpal et al. (2021a) further proposed Recurrent DSM (RDSM), which incorporates a recurrent neural network (RNN) into DSM and thereby enables dynamic analysis. Nonetheless, DSM models exhibit suboptimal accuracy in predicting event times, and their loss functions frequently diverge during training, contributing to overfitting.

Some recent works have concentrated on static analysis. For example, DeepSurv (Katzman et al., 2018) employs an MLP to replace the parametric form of the hazard function in the conventional CPH, resulting in a semi-parametric variant of the CPH model. The incorporation of neural networks enhances its flexibility by enabling the model to learn nonlinear relationships from covariates more adeptly. Deep Cox Mixtures (DCM) (Nagpal et al., 2021c) retains the same underlying proportional-hazards assumption, but within latent groups: employing variational autoencoders for clustering, DCM assumes that proportional hazards hold within each latent group. Moreover, Ishwaran et al. (2008) propose an extension of the random forest (RF) algorithm, named Random Survival Forest (RSF), which breaks through the inherent assumptions of CPH. It computes risk scores through Nelson-Aalen estimators within the partitions established by the RF. RSF assumes independence among the trees in the forest, which might not always hold; this assumption can degrade its performance when correlations or dependencies exist among survival trees.

Several studies have explored the dynamic analysis field. Lee et al. (2018) propose DeepHit, a non-parametric model for competing risk events. The encoder of DeepHit is constructed as a joint MLP, while its decoder employs a series of distinct MLPs to address individual events, generating a separate PDF for each event. Lee et al. (2019) further extend it into Dynamic DeepHit (DDH) by replacing the encoder with RNNs followed by an attention mechanism, in order to process longitudinal data. A primary limitation lies in the arbitrary fluctuations between adjacent predictions in the output layer, which introduce noise into the final PDFs; this phenomenon becomes particularly pronounced when forecasting over a long time horizon. Survival SEQ2SEQ (SS2S) (Pourjafari et al., 2022) addresses this issue by using RNN cells in its decoder to generate smoother PDFs. However, its approach to handling censoring is overly simplistic, focusing solely on the premise that events should not occur before the censoring time, without addressing any potential implications after it. Another model that shares a similar concern in handling censoring is the Transformer-based Deep Survival Model (TDSM) (Hu et al., 2021); this deficiency is evident in the design of their loss functions. Previous Transformer architectures in survival analysis include TDSM and SurvTRACE (Wang & Sun, 2022), but neither has been extended to handle dynamic analysis.

In summary, while most existing methods can perform static analysis, only three of them can handle longitudinal data. Despite the advancements in these recent works, there is still no universal model that jointly handles missing data, processes various input formats and produces well-organized PDFs.

3 Method

In this section, we introduce our formal framework UniSurv, an adaptive Transformer-based architecture for survival analysis. We assume that the available survival dataset is subject to right censoring.

3.1 Survival notation

We denote time-invariant and time-varying covariates by \(\pmb {x}_n\) and \(\pmb {x}_v\), probability by p, time by T, t or \(\tau\), the PDF by p(t) and the survival function by S(t). We represent the survival dataset as \(\{(\pmb {x}_{n}^{i}, \pmb {x}_{v}^{i}, T^{i}, \delta ^{i})\}_{i=1}^{N}\), where, for individual i, \(\delta ^{i}\) is the event indicator, taken from the set \(\{0,1\}\) in the absence of competing risks, and \(T^{i}\) is the event or censoring time depending on \(\delta ^{i}\). We omit the explicit dependence on i throughout this and the next subsections to simplify notation.

We assume that time \(t \in \{T_{0}, T_{1},\ldots , T_{max}\}\) in order to fit a discrete survival model, where t is a discrete random variable and the time steps \(T_j\) are equally spaced. The cumulative distribution function (CDF) of t can be easily calculated from its PDF as

$$\begin{aligned} CDF(T_{j} \mid (\pmb {x}_n,\pmb {x}_v)) = \sum _{t=T_0}^{T_j}p_{t} \end{aligned}$$
(1)

Having defined the probability that an event has occurred by duration \(T_{j}\), the survival function can then be estimated as the probability that the survival time t is at least \(T_{j}\). It can also be represented as the complement of the CDF as

$$\begin{aligned} S(T_{j} \mid (\pmb {x}_n,\pmb {x}_v)) = 1 - CDF(T_{j} \mid (\pmb {x}_n,\pmb {x}_v)) = \sum _{t=T_j}^{T_{max}}p_{t} \end{aligned}$$
(2)

3.2 Model description

Fig. 1 The illustration of (a) the architecture of the UniSurv model and (b) a schematic representation of UniSurv during the training and testing stages

Figure 1a presents a comprehensive illustration of the UniSurv model. It integrates a novel survival loss design to enable a seamless end-to-end learning procedure, and comprises dynamic and static extraction components coupled with a Transformer encoder module, culminating in a softmax layer at the output. Figure 1b depicts the conceptual process of UniSurv during both the training and testing stages.

3.2.1 Static and dynamic extractions

We integrate a variation of the last-observation-carried-forward (LOCF) method to handle missing data. It duplicates the value of the last observation to replace subsequent missing values for up to \(\Delta _{\tau }=3\) time steps; values that are still missing after LOCF are imputed with the mean (for continuous covariates) or mode (for binary covariates) of all previous points up to \(T_{max}\), ensuring that no time point \(T_j\) contains missing data (Lee et al., 2019). Next, we extract latent representations of the time-invariant features \(\pmb {x}_n\) and the time-varying features \(\pmb {x}_v\) with separate static (-s) and tabular-data-based dynamic (-d1) extraction modules. These modules are constructed with MLPs to handle numerical tabular data. The representation of \(\pmb {x}_n\) is replicated \(T_{max}+1\) times so as to cover \(T_0\) as well. For static modelling, these replicas are transmitted directly to the encoder; for dynamic modelling, they are concatenated with the representation of \(\pmb {x}_v\) at the corresponding time points before transmission. Meanwhile, a convolutional neural network (CNN) variant (-d2) of the dynamic extraction module can be used for image-like input formats. As shown in Fig. 1a, the extracted feature can be shared across a predefined time window \(T_w\) in light of data sparsity.
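The following is a minimal sketch of this imputation step, assuming the longitudinal covariates of one subject are stored as a (time × feature) array with NaN marking missing entries; the carry-forward limit \(\Delta _{\tau }=3\) and the continuous/binary split follow the text above, while the helper name and the handling of entries with no earlier observation are illustrative assumptions rather than the authors' exact code.

```python
import numpy as np

def impute_longitudinal(x, is_binary, locf_limit=3):
    """LOCF with a carry-forward limit, then mean/mode back-fill.

    x         : (T_max + 1, D) array with np.nan for missing entries
    is_binary : (D,) boolean mask, True for binary covariates
    """
    x = x.copy()
    T, D = x.shape
    for d in range(D):
        last_val, age = np.nan, np.inf
        for t in range(T):
            if not np.isnan(x[t, d]):
                last_val, age = x[t, d], 0          # fresh observation
                continue
            age += 1
            if not np.isnan(last_val) and age <= locf_limit:
                x[t, d] = last_val                  # carry forward up to Delta_tau steps
            else:
                prev = x[:t, d][~np.isnan(x[:t, d])]  # all previous (possibly filled) points
                if prev.size:
                    if is_binary[d]:
                        vals, counts = np.unique(prev, return_counts=True)
                        x[t, d] = vals[np.argmax(counts)]   # mode for binary covariates
                    else:
                        x[t, d] = prev.mean()               # mean for continuous covariates
                # entries with no earlier observation are left as NaN in this sketch
    return x
```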

3.2.2 Transformer encoder

The core of the encoder module is a Transformer, which treats each patient as a 'sentence' and the embedded features as the 'words' of that sentence. For an input sample, the number of words corresponds to the duration \(t \in \{0,1,\ldots ,T_{max}\}\), where we predefine \(T_0=0\) and \(T_{max}\) is a hyper-parameter selected according to the longest temporal record in a dataset. The previously concatenated representations pass through an MLP followed by layer normalization to obtain the embedded features. Following the conventional Transformer approach, we use the sine-cosine positional embedding (Vaswani et al., 2017) as the temporal embedding and add it to the embedded features, whose length is set to the embedding dimension \(d_{m}\) of the Transformer. The Transformer encoder then processes the embedded features and produces \(T_{max}+1\) outputs, each of shape \(1 \times d_{m}\). It is worth noting that the self-attention layers in the encoder are modified to prevent positions from attending to subsequent positions: the attention scores of all such illegal connections are masked out by setting them to \(-\infty\) (Vaswani et al., 2017). Next, the outputs at all time points are fed into an exclusive 2-layer MLP. The first layer, of shape \(d_{m} \times \frac{d_{m}}{2}\), is followed by a rectified linear unit (ReLU) and layer normalization; the second layer, of shape \(\frac{d_{m}}{2} \times 1\), is followed by a softmax layer that produces the individual estimated PDF. The estimated survival function \(\hat{S}(t \mid (\pmb {x}_n,\pmb {x}_v))\) can then be calculated from Eq. 2, which ensures that its monotonicity is preserved. Moreover, in discrete survival analysis the mean lifetime \(\mu\) can be approximated by the sum of the survival probabilities up to \(T_{max}\); using Eq. 2 we obtain the estimated mean lifetime as

$$\begin{aligned} \hat{\mu } = \sum _{t=T_0}^{T_{max}}\hat{S}(t)=\sum _{t=T_0}^{T_{max}}\sum _{\tau =t}^{T_{max}} \hat{p}_{\tau } \approx \sum _{t=T_0}^{T_{max}} t \cdot \hat{p}_{t} \end{aligned}$$
(3)

where the approximately-equal sign arises from the presence of the time point \(T_0=0\). The variance of the distribution is computed as

$$\begin{aligned} v = \sum _{t=T_0}^{T_{max}} \hat{p}_{t}\cdot (t-\hat{\mu })^{2} \end{aligned}$$
(4)
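For concreteness, the following minimal sketch computes the quantities in Eqs. 1-4, assuming the network's softmax output `pdf` has shape (batch, \(T_{max}+1\)) over the unit-spaced grid \(t=0,\ldots ,T_{max}\); the survival function uses the right-hand sum of Eq. 2 and the mean uses the right-hand approximation of Eq. 3. Function and tensor names are illustrative.

```python
import torch

def pdf_to_survival_stats(pdf):
    """pdf: (B, T_max + 1) softmax output over t = 0, 1, ..., T_max; rows sum to 1."""
    t = torch.arange(pdf.shape[1], dtype=pdf.dtype)        # discrete time grid

    cdf = torch.cumsum(pdf, dim=1)                         # Eq. 1
    surv = 1.0 - cdf + pdf                                 # Eq. 2: S(T_j) = sum_{t >= T_j} p_t
    mu = (t * pdf).sum(dim=1)                              # Eq. 3: estimated mean lifetime
    var = (pdf * (t - mu.unsqueeze(1)) ** 2).sum(dim=1)    # Eq. 4: spread of the distribution
    return surv, mu, var
```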

3.3 Loss function

To robustly estimate the uncensored survival time via distribution learning and to generate a smooth PDF, UniSurv adopts a variation of the Mean-Variance loss (Pan et al., 2018), which requires each training sample to have a corresponding event-time label. However, censored samples in a survival dataset have only a censoring time rather than an event time, and using the censoring time as the label would mislead the model and bias its predictions. To overcome this, we employ the “margin time” concept (Haider et al., 2020) to assign a “best guess” to each censored subject based on the non-parametric population Kaplan-Meier (KM) estimator (Kaplan & Meier, 1958). Given a subject i censored at time \(T^{i}\), its margin event time is calculated as

$$\begin{aligned} e_{m}^{i} = T^{i} + \frac{\int _{T^{i}}^{T_{max}}S_{km(D_{tr})}(t)dt}{S_{km(D_{tr})}(T^{i})} \end{aligned}$$
(5)

where \(S_{km(D_{tr})}\) is the KM estimate derived from the training dataset. It is worth mentioning that we integrate only up to \(T_{max}\) in this work, ensuring that \(e_{m}\) remains within the bounds of \(T_{max}\); this stands in contrast to the approach of Haider et al. (2020), who extend the KM curve through risky extrapolation beyond the observed values.
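A minimal sketch of Eq. 5 on the discrete time grid is given below: the KM curve is estimated from the training set with a plain product-limit estimator (written out here for self-containment rather than taken from the authors' code), and the integral is truncated at \(T_{max}\) as described above. The grid convention and names are illustrative assumptions.

```python
import numpy as np

def km_curve(times, events, t_grid):
    """Product-limit (Kaplan-Meier) estimate of S(t) evaluated on a discrete grid."""
    surv = np.ones(len(t_grid))
    s = 1.0
    for j, t in enumerate(t_grid):
        at_risk = np.sum(times >= t)
        deaths = np.sum((times == t) & (events == 1))
        if at_risk > 0:
            s *= 1.0 - deaths / at_risk
        surv[j] = s
    return surv

def margin_time(T_c, t_grid, km):
    """Eq. 5: best-guess event time for a subject censored at T_c."""
    mask = t_grid >= T_c
    if not mask.any() or km[mask][0] == 0:
        return T_c                                  # nothing left to integrate over
    dt = t_grid[1] - t_grid[0]                      # equal spacing of the discrete grid
    # integrate the KM curve from T_c up to T_max only (no extrapolation)
    return T_c + (km[mask] * dt).sum() / km[mask][0]
```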

We denote \(T^{i}\) as the corresponding ground-truth event/censoring time for individual i. With the estimated mean lifetime \(\hat{\mu }^{i}\), the Margin-Mean loss can be computed as

$$\begin{aligned} \mathcal {L}_{mm} = \frac{1}{2} \sum _{i=1}^{N}\Big {(}\delta ^i \cdot (\hat{\mu }^i - T^i)^2 + (1-\delta ^i) \cdot \omega ^i \cdot (\hat{\mu }^i - e_{m}^i)^2\Big {)} \end{aligned}$$
(6)

where \(\omega ^i = 1 - S_{km(D_{tr})}(T^i)\) assigns high confidence weights to late censoring times and low weights to early ones. The Margin-Mean loss penalizes dissimilarity between the estimated mean lifetime and the actual event/margin event time. Besides, the Variance loss is calculated as

$$\begin{aligned} \mathcal {L}_{v} = \sum _{i=1}^{N}v^{i} \end{aligned}$$
(7)

which regulates the spread of the estimated survival distribution, limiting it to a narrow range around the mean. Combined with Eq. 4, \(\mathcal {L}_{v}\) pushes the probabilities at time points far from \(\hat{\mu }^{i}\) towards 0. The softmax loss, also known as the cross-entropy loss, is computed as

$$\begin{aligned} \mathcal {L}_{s} = -\sum _{i=1}^{N}\log p_{i, T^i} \end{aligned}$$
(8)

which is further utilized to aid early training convergence, as the Margin-Mean-Variance loss alone may experience substantial fluctuations (Pan et al., 2018).

Finally, tailored to uncensored cases, we utilize the discordant loss

$$\begin{aligned} \mathcal {L}_{d} = \sum _{i=1}^{N}\delta ^i \cdot max\{0, (T^{k}-T^{i})-(\hat{\mu }^k-\hat{\mu }^i)\} \end{aligned}$$
(9)

which penalizes randomly sampled discordant pairs to improve the model's pairwise ranking ability. The procedure resembles a randomized algorithm: for each confirmed (uncensored) individual i, we randomly sample, with replacement, an individual k such that \(T^{k}>T^{i}\), and require that the difference between the estimated times is not smaller than the difference between the ground truths. \(\mathcal {L}_{d}\) further penalizes discordant pairs because, when \(T^k\) and \(T^i\) are close, the Margin-Mean-Variance loss cannot effectively discriminate them and may fall into a local optimum.

The total loss to train UniSurv is the combination of the above four losses as

$$\begin{aligned} \mathcal {L}_{total} = \mathcal {L}_{s} + \lambda _{m}\mathcal {L}_{mm} + \lambda _{v}\mathcal {L}_{v} + \lambda _{d}\mathcal {L}_{d} \end{aligned}$$
(10)

where \(\lambda _{m}\), \(\lambda _{v}\), \(\lambda _{d}\) are weights for the corresponding loss functions.
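A minimal PyTorch sketch of Eqs. 6-10 follows, assuming `pdf`, `mu` and `var` come from the network output as in Sect. 3.2.2, that the discrete grid has unit spacing so observed times can index the PDF directly, that `e_m` and `omega` have been precomputed from the training KM curve (Eq. 5), and that a discordant partner index `k_idx` with \(T^{k}>T^{i}\) has already been sampled; names and the batch-level summation are illustrative rather than the authors' implementation.

```python
import torch

def unisurv_loss(pdf, mu, var, T, delta, e_m, omega, k_idx,
                 lam_m=1.0, lam_v=1.0, lam_d=1.0):
    """Total loss of Eq. 10; lam_* are the weights lambda_m, lambda_v, lambda_d."""
    # Eq. 6: Margin-Mean loss (event time for uncensored, margin time for censored)
    l_mm = 0.5 * (delta * (mu - T) ** 2
                  + (1 - delta) * omega * (mu - e_m) ** 2).sum()

    # Eq. 7: Variance loss, keeps the PDF concentrated around its mean
    l_v = var.sum()

    # Eq. 8: softmax / cross-entropy loss at the labelled time index
    l_s = -torch.log(pdf.gather(1, T.long().unsqueeze(1)).squeeze(1) + 1e-8).sum()

    # Eq. 9: discordant loss over sampled pairs (i, k) with T[k] > T[i]
    l_d = (delta * torch.clamp((T[k_idx] - T) - (mu[k_idx] - mu), min=0.0)).sum()

    # Eq. 10: weighted combination
    return l_s + lam_m * l_mm + lam_v * l_v + lam_d * l_d
```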

4 Experiments

In this section, we demonstrate the effectiveness of UniSurv by comparing it with other benchmarks on real and synthetic datasets from static and dynamic settings.

4.1 Datasets

To highlight the right-skewed characteristic of survival data, we utilized three real-world datasets and two long-tailed synthetic datasets.

4.1.1 Static datasets

The Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT) (Knaus et al., 1995) is a large static survival dataset of seriously ill hospitalized adults. The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC) (Curtis et al., 2012) is a static breast cancer dataset aiming to distinguish its subtypes based on the molecular characteristics. Their pre-processing strategies follow DeepSurv (Katzman et al., 2018).

We also generate a static synthetic dataset (SYNTH-s) in the style of that in Lee et al. (2018) but without competing risks. The dataset contains \(N=15{,}100\) examples drawn from the stochastic process

$$\begin{aligned} \begin{aligned}&\pmb {x}_{n}^{i} \sim \mathcal {N} (0, \textbf{I})\\&T^{i} \sim \textrm{exp}(\pmb {\gamma }_{n}^{T}\pmb {x}_{n}^{i}) \end{aligned} \end{aligned}$$
(11)

where \(\pmb {x}_{n}^{i}\) is a 4-dimensional vector of covariates and \(\pmb {\gamma }_{n}=\pmb {10}\). We randomly select \(50\%\) of patients to be right-censored, with censoring times drawn uniformly from \([0, T^{i}]\). More details are listed in Table 1.
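A minimal sketch of this censoring step, assuming event times `T` have already been drawn according to Eq. 11; variable names and the random generator are illustrative.

```python
import numpy as np

rng = np.random.default_rng(42)

def apply_random_censoring(T, censor_frac=0.5):
    """Right-censor a random fraction of subjects with C ~ Uniform(0, T_i)."""
    N = len(T)
    delta = np.ones(N, dtype=int)                      # 1 = event observed
    censored = rng.choice(N, size=int(censor_frac * N), replace=False)
    T = T.copy()
    T[censored] = rng.uniform(0.0, T[censored])        # censoring time in [0, T_i]
    delta[censored] = 0                                # 0 = censored
    return T, delta
```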

Table 1 Descriptive statistics of three real world medical datasets and two synthetic datasets

4.1.2 Dynamic datasets

On the basis of SYNTH-s, we further generate a dynamic synthetic dataset (SYNTH-d) by adding dynamic variables that follow a Weibull distribution and introducing temporal noise disturbances to make them vary over time:

$$\begin{aligned} \begin{aligned}&\pmb {x}_{v}^{i}(t) \sim \frac{a}{\beta }\Big (\frac{\pmb {x}}{\beta }\Big )^{(a-1)}exp\Big (-(\frac{\pmb {x}}{\beta })^a\Big ) + \mathcal {N} (0, \textbf{I})\\&T^{i} \sim \textrm{exp}\Big (\pmb {\gamma }_{v_{1}}^{T} \cdot max\big (\pmb {x}_{v_{1}}^{i}(t)\big )+\pmb {\gamma }_{v_{2}}^{T} \cdot min\big (\pmb {x}_{v_{2}}^{i}(t)\big )+\pmb {\gamma }_{n}^{T}\pmb {x}_{n}^{i}\Big ) \end{aligned} \end{aligned}$$
(12)

where \(\pmb {x}_{v}^{i}\) is a \(20 \times T_{max}\) matrix of dynamic variables over all time points, a is the shape parameter, \(\beta\) is the scale parameter, \(\pmb {\gamma }_{v_{1}}=\pmb {\gamma }_{v_{2}}=\pmb {5}\), and \(max(\cdot )\) and \(min(\cdot )\) operate along the temporal dimension. Besides, \(v_{1}\) and \(v_{2}\) are two randomly selected subsets satisfying \(v_{1} \cap v_{2} = \varnothing\) and \(v_{1} \cup v_{2} = v\). \(T^{i}\) is then resampled, and censoring is introduced in the same way as before.

Moreover, the MSReactor (Merlo et al., 2021) dataset is a quantifiable, objective collection of cognition measurements obtained via a longitudinal computerized test for Multiple Sclerosis (MS), combined with 8 additional static covariates. In each test, patients are instructed to respond as quickly as possible to on-screen stimuli, and their reaction time is recorded in milliseconds (ms). The test comprises 3 different tasks probing psychomotor function, attention and working memory. Each patient undergoes the test a number of times after diagnosis and prior to the occurrence of the event/censoring, with at least a one-month interval between two adjacent tests. The survival event is characterized by EDSS progression under the six-month confirmed disability worsening rule (Hunter et al., 2021). Numerous studies have indicated that reaction data could offer a more responsive approach for detecting subclinical cognitive impairment than current cognitive assessment methods (Foong et al., 2023; Pham et al., 2021; Whitehouse et al., 2019).

Fig. 2 The illustration of the reaction tensor representation of a single individual in MSReactor

The longitudinal reaction tests are treated as time-varying covariates. However, owing to certain redundancies in MSReactor, evident through pronounced inter-column associations, the features of adjacent columns are strongly correlated, deviating from conventional tabular data in which each column carries highly streamlined information. Without specialized data preprocessing and suitable model architectures, existing survival models may struggle to extract meaningful latent patterns from these data. Therefore, we transform the longitudinal tabular part into a composite “reaction tensor” after monthly imputation and use a dedicated module to process it. Specifically, each patient has a unique 3-dimensional reaction tensor (Footnote 1), as shown in Fig. 2. Its Z-axis corresponds to the 3 different tasks; the X-axis is the response dimension with a fixed length of 30, corresponding to the 30 responses the patient gives for each task in every test; and the Y-axis corresponds to the monthly test occasions, with a fixed length from the start time \(T_0\) to the end time \(T_{max}\). The reaction tensor is divided along the Y-axis into several smaller tensors of length \(T_w\), as in Fig. 1a.

4.2 Evaluation metrics

We utilize ranking measures, namely the concordance index (C-index) from the lifelines library (Davidson-Pilon, 2019) and the mean cumulative area under the ROC curve (mAUC) from the scikit-survival library (Pölsterl, 2020), together with the mean absolute error (MAE) as an accuracy measure, as the evaluation metrics for all experiments.

4.2.1 Concordance

C-index (Uno et al., 2011) is able to estimate ranking ability by comparing relative risks across all pairs in the test set as

$$\begin{aligned} \text {C-index} = \frac{ {\textstyle \sum _{i, k}} \delta ^i \cdot \mathbb {I}(T^{i}<T^{k})\cdot \mathbb {I}(\hat{\mu }^{i}<\hat{\mu }^{k})}{{\textstyle \sum _{i, k}} \delta ^i \cdot \mathbb {I}(T^{i}<T^{k})} \end{aligned}$$
(13)

where \(\mathbb {I}(\star )\) is an indicator function, and \(\delta ^i=1\) if \(T^i\) is uncensored and 0 otherwise, consistent with the event indicator defined in Sect. 3.1.

4.2.2 MAE-uncensored

MAE-Uncensored (MAE-U) compensates for the inability of the C-index to measure the absolute accuracy of the estimated survival times. It is computed as

$$\begin{aligned} \text {MAE-U} = \frac{{\textstyle \sum _{i}} \delta ^i \cdot \left| T^i - \hat{\mu }^i \right| }{{\textstyle \sum _{i}} \delta ^i} \end{aligned}$$
(14)

4.2.3 MAE-Hinge

MAE-Hinge (MAE-H) is a one-sided MAE for censored cases only, complementary to MAE-U, which covers only uncensored cases. It incurs a penalty only when the predicted time \(\hat{\mu }\) is earlier than the censoring time T, as follows

$$\begin{aligned} \text {MAE-H} = \frac{{\textstyle \sum _{i}} (1-\delta ^i) \cdot max\{ T^i - \hat{\mu }^i , 0\}}{{\textstyle \sum _{i}} (1-\delta ^i)} \end{aligned}$$
(15)
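For concreteness, a minimal sketch of the three metrics above, assuming arrays of observed times `T`, event indicators `delta` (1 = uncensored, as in Sect. 3.1) and predicted mean lifetimes `mu_hat`; the C-index is delegated to the lifelines implementation cited in Sect. 4.2, while the two MAEs follow Eqs. 14 and 15 directly.

```python
import numpy as np
from lifelines.utils import concordance_index

def evaluate(T, delta, mu_hat):
    # C-index (Eq. 13): a higher predicted lifetime should mean a later event
    cindex = concordance_index(T, mu_hat, event_observed=delta)

    # MAE-U (Eq. 14): absolute error on uncensored subjects only
    mae_u = np.abs(T - mu_hat)[delta == 1].mean()

    # MAE-H (Eq. 15): one-sided error, penalised only when the prediction
    # falls before the censoring time
    mae_h = np.maximum(T - mu_hat, 0.0)[delta == 0].mean()

    return cindex, mae_u, mae_h
```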

4.2.4 Mean cumulative area under ROC curve

The area under the ROC curve for survival analysis treats the survival problem as binary classification at various quantiles of event times and defines sensitivity and specificity as time-dependent measures (Lambert & Chevret, 2016). The cumulative AUC measures the model's ability to discriminate individuals who fail by a specified time t (\(T_j\le t\)) from subjects who fail after this time (\(T_j > t\)). We compute the mAUC by integrating the cumulative AUC over all intervals \((T_j, T_{j+1})\).
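A minimal sketch of the mAUC computation using scikit-survival's cumulative_dynamic_auc, assuming the risk score is the negated estimated lifetime \(-\hat{\mu }\) and the evaluation times lie strictly within the observed follow-up of the test set; array names are illustrative and this is not necessarily the authors' exact pipeline.

```python
import numpy as np
from sksurv.util import Surv
from sksurv.metrics import cumulative_dynamic_auc

def mean_cumulative_auc(T_train, d_train, T_test, d_test, mu_hat_test, times):
    y_train = Surv.from_arrays(event=d_train.astype(bool), time=T_train)
    y_test = Surv.from_arrays(event=d_test.astype(bool), time=T_test)

    # higher risk = earlier predicted event, so negate the estimated lifetime
    risk = -mu_hat_test
    auc_t, mean_auc = cumulative_dynamic_auc(y_train, y_test, risk, times)
    return auc_t, mean_auc    # time-dependent AUC curve and its mean (mAUC)
```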

4.3 Experimental setting

We compare with five static benchmarks, including CPH, DeepSurv, DeepHit, DSM and TDSM, and two dynamic benchmarks (Footnote 2), DDH and RDSM. As the static datasets do not contain longitudinal covariates, the dynamic extraction module of UniSurv is deactivated for them, yielding the variant named UniSurv-s. For MSReactor, the dynamic extraction module has two variants based on different data representations: UniSurv-d1 for the tabular representation and UniSurv-d2 for the image-like representation. More implementation and hyperparameter details are given in Appendix 2 (Footnote 3).

For a fair comparison, we use the C-index as the early-stopping criterion for all approaches, as it covers more subjects than the MAE. We report results using cross-validation, randomly splitting each dataset 5 times into training, validation and test sets with a 7:1:2 ratio. All experiments are implemented in PyTorch 2.0.1 in the same environment with a fixed random seed.
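A minimal sketch of this repeated 7:1:2 splitting, assuming index-level splits with scikit-learn and one seed per repetition; the helper name and seeds are illustrative.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def repeated_splits(n_samples, n_repeats=5):
    """Yield (train, val, test) index arrays with a 7:1:2 ratio."""
    idx = np.arange(n_samples)
    for seed in range(n_repeats):
        train_val, test = train_test_split(idx, test_size=0.2, random_state=seed)
        train, val = train_test_split(train_val, test_size=0.125, random_state=seed)
        yield train, val, test    # 0.8 * 0.125 = 0.1, giving 70/10/20 overall
```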

4.4 Benchmarking results

Table 2 Benchmarking on three static datasets

Performance comparisons for all datasets are summarized in Tables 2 and 3. We bold the best results and underline the second best. Statistical significance is determined by a paired t-test between the best result and each of the others.

4.4.1 Static modelling results

In terms of the C-index, our UniSurv-s secures first position on SYNTH-s and second position on both SUPPORT and METABRIC. It also achieves the best mAUC on METABRIC and SYNTH-s and the second best on SUPPORT. DeepSurv shows comparable ranking performance on the two real-world datasets, illustrating that parametric models still hold a slight advantage over non-parametric models when their probability-distribution assumptions are sound. Meanwhile, the performance of the other four models varies, creating a competitive landscape that makes it difficult to judge them definitively under a single ranking metric.

Meanwhile, UniSurv performs well in MAE-U and exhibits notably superior performance in MAE-H, with statistical significance compared to the other models; only DeepSurv on SUPPORT is comparable with ours on both MAEs. Conversely, TDSM, while excelling in MAE-U, lags notably behind in MAE-H. This is because the loss design of TDSM leads to overfitting on uncensored data throughout learning, failing to capture the fact that most censored samples have longer survival times. The inadequacy of TDSM's predictions for censored cases is also evident in Fig. 3, in which red lines mark the difference between the true censoring time and the estimated mean lifetime for a number of censored cases. We show the METABRIC results for TDSM, UniSurv-s and the second-best MAE-U model, DeepHit. The more numerous and the longer the red lines, the less sensitive the model is to censoring. It can be observed that UniSurv provides accurate predictions for the majority of censored cases. This outcome can be attributed to the incorporation of the MAE-margin concept within the Margin-Mean loss \(\mathcal {L}_{mm}\) in Eq. 6, which leverages prior knowledge from the training dataset to effectively “enforce” the predicted survival time to exceed the censoring time. On the other hand, DeepHit exhibits significant inefficiency in forecasting longer censoring times. Similar to TDSM, this is due to the absence of constraints beyond the censoring time in its loss design, which may give rise to a systemic bias when predicting censored cases.

4.4.2 Dynamic modelling results

As depicted in Table 3, UniSurv-d1 demonstrates superior performance over the two other models on the longitudinal datasets, as evidenced by higher C-index and mAUC values and lower values for both MAEs. However, the performance of these three methods is generally suboptimal, as their C-index values remain below 0.6. This likely arises from the fact that the temporal data in MSReactor diverge from conventional survival tabular data, instead representing a reaction test applied to MS patients; traditional models struggle to extract meaningful insights from this intricate and redundant information. Notably, when we preprocess the computerized test data into the “reaction tensor” and employ a CNN to extract latent features, the performance of UniSurv-d2 surpasses the others with statistically significant improvements. However, this “tensor” method has not demonstrated effectiveness on SYNTH-d, primarily because each variable \(x_{v}\) is generated from an isotropic distribution, so the variables are mutually independent and uncorrelated.

Fig. 3 The difference between the estimated lifetime \(\hat{\mu }^i\) (blue dot) and the true censoring time \(T^i\) (green square) for TDSM, DeepHit and UniSurv-s on METABRIC. Each red line indicates the difference when \(\hat{\mu }^i<T^i\) for individual i; no line is drawn in the opposite case

Table 3 Benchmarking on two dynamic datasets

4.4.3 The implication of data distribution

Fig. 4 The time-dependent AUC. The dashed line shows the mAUC corresponding to each colored curve

As shown in Fig. 4, all five histograms depict survival-time distributions skewed towards the early segment of the time horizon, while censoring times tend to cluster in the latter half, especially in SUPPORT, SYNTH-s, MSReactor and SYNTH-d. This makes it difficult for survival models to maintain predictive accuracy over time, as evidenced by the time-dependent AUC (TD-AUC). For example, UniSurv-s, DeepHit and DSM, and their dynamic variants (UniSurv-d2, DDH, RDSM), all exhibit a consistent decline in TD-AUC as time progresses. However, UniSurv still outperforms the others, especially on the two dynamic datasets. For METABRIC, owing to its relatively low censoring rate and evenly distributed censoring cases, all three models maintain their TD-AUC quite well, with some even showing an upward trend, particularly UniSurv. This affirms that a Transformer encoder trained with the Margin-Mean-Variance loss can effectively alleviate the challenges posed by survival datasets with long-tailed distributions.

Fig. 5 The estimated PDFs from DDH and UniSurv for five randomly selected uncensored cases in MSReactor. Each color represents an individual

4.5 Importance of masked attention mechanism

In the context of leveraging a Transformer for inference, the masking within the attention mechanism is inherently discretionary, contingent on whether each output requires contributions from all inputs or only from designated ones. For static survival data, the design of UniSurv does not entail any distinction in latent features across time points beyond the temporal embedding; hence there is no risk of information leakage, rendering the masking mechanism inconsequential. For instance, it is not employed in TDSM. As demonstrated in Table 2, the overall performance of UniSurv-s is not affected by removing the masking mechanism, and the slight performance fluctuations are negligible. However, when dealing with dynamic survival data, missing data are inevitable, and imputations following event or censoring times may give rise to retro-active prediction concerns. The masking mechanism therefore becomes imperative in this scenario. In Table 3, both UniSurv-d1 and UniSurv-d2 exhibit a comparable degree of performance decline on both datasets when masking is removed, as evidenced by their ranking ability.
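A minimal sketch of the causal (subsequent-position) mask described in Sect. 3.2.2 and discussed here, assuming PyTorch's additive attention-mask convention in which \(-\infty\) entries block attention; the helper name is illustrative.

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    """Upper-triangular additive mask: position t may attend to positions 0..t only."""
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)   # 0 on/below the diagonal, -inf above

# usage with a Transformer encoder over T_max + 1 time steps (shapes illustrative):
# out = encoder(embedded_features, mask=causal_mask(T_max + 1))
```

For the static variant UniSurv-s, the mask argument can simply be omitted, consistent with the observation above that masking is inconsequential in that setting.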

4.6 Comparison of PDF visualizations

In addition to predictive accuracy, the quality of the estimated individual PDF is another crucial consideration when comparing non-parametric survival models. The shape of the PDF generated by UniSurv is specifically governed by the Margin-Mean-Variance loss and remains unaffected by the choice of extraction module. In Fig. 5, we compare the PDF outputs for 5 randomly selected uncensored cases in MSReactor. We contrast DDH and UniSurv because neither makes assumptions about the shape of the PDF, whereas RDSM relies on strong assumptions related to the Weibull and log-normal distributions. As described in the preceding sections and shown in Fig. 5, our \(\mathcal {L}_{mm}\) penalizes dissimilarity between the peak of the PDF and the ground truth. Moreover, diverging from the disordered PDFs of DDH, \(\mathcal {L}_{v}\) regulates the spread of the PDF and constrains it into a distinct pattern. In contrast, despite DDH using the same MLP-plus-softmax output layer, the high fluctuations of its PDFs can be attributed to shortcomings in its loss design.

A unimodal survival PDF offers several advantages: it better reflects the time-to-event and naturally calibrates the median survival time on the survival curve; for example, Rindt et al. (2022) employed several pre-defined unimodal distributions for survival modelling. UniSurv departs from this assumption yet achieves the same objective through its distinctive loss design. The current over-concentrated PDF is not optimal, and appropriately relaxing the constraints that \(\mathcal {L}_{v}\) imposes on the shape will become necessary.

Fig. 6 Comparison results of the ablation study and the effectiveness analysis. All experiments use the UniSurv-d2 setting on the MSReactor dataset

4.7 Ablation study

We further conduct an ablation study of the losses on MSReactor using UniSurv-d2, to demonstrate the contribution of each loss. In Fig. 6a, it is evident that an incomplete loss combination can sometimes lead to a lower MAE-U or MAE-H; however, this often reflects local optimization, which shows up in the shape of the PDF. In Fig. 6f, we compare the PDFs under selected scenarios: \(\mathcal {L}_{mm}\) only, \(\mathcal {L}_{v}\) only and \(\mathcal {L}_{total}\). Relying solely on \(\mathcal {L}_{mm}\), owing to the absence of the \(\mathcal {L}_{v}\) constraint, tends to produce a probability distribution biased towards uniformity around the ground truth. On the other hand, training solely with \(\mathcal {L}_{v}\) generates irregular PDFs and fails to acquire meaningful information. This explains why these scenarios do not yield the optimal C-index. Besides, including \(\mathcal {L}_{d}\) further enhances performance by mitigating the occurrence of discordant pairs. In addition, incorporating \(\mathcal {L}_{s}\) results in faster convergence and significant performance improvement, particularly evident in the C-index. However, \(\mathcal {L}_{s}\) also leads to a rapid concentration of probability mass near the event time, which can result in overly concentrated PDFs and potential calibration issues. As shown in Fig. 6e, we compare the averaged PDF shape with and without \(\mathcal {L}_{s}\) for the combinations presented in Fig. 6a. The averaged shape is calculated by aligning all PDF peaks using Dynamic Time Warping (Giorgino, 2009) and then averaging them on a normalized horizon. It is apparent that the absence of \(\mathcal {L}_{s}\) yields a multimodal, smoother, and more realistic PDF. Hence, the inclusion of \(\mathcal {L}_{s}\) involves a trade-off between ranking and calibration ability.

4.8 Effectiveness analysis

4.8.1 Sensitivity of time window \(T_w\)

Figure 6b shows the effect of \(T_w\). When \(T_w=8\), the model achieves the highest C-index and the lowest MAE-H, which is associated with the progression rate of MS; however, over the same range MAE-U shows its poorest performance. The fluctuations in MAE-U and MAE-H clearly exhibit contrasting patterns. This disparity can be attributed to the disparate distributions of censored and uncensored cases within MSReactor, as shown in Fig. 4, and underscores that there is still room to enhance the robustness of UniSurv.

4.8.2 Sensitivity of loss weights \(\lambda _{m}\) and \(\lambda _{v}\)

As the number of losses increases, finding the optimal weight combination becomes challenging, but a grid search can address this, and the four losses do not need to be standardized to a similar magnitude. The characteristics of different datasets can lead to distinct optimal loss weights. We assess the sensitivities of \(\lambda _{m}\) and \(\lambda _{v}\) on MSReactor in Fig. 6c, d, and some selected PDFs are shown in Fig. 6g. The model is robust to small variations in \(\lambda _{m}\) or \(\lambda _{v}\), as performance near their optimal values does not degrade significantly; in some cases the two MAEs even improve. This phenomenon is attributed to the opposing fluctuation trends of the MAEs, indicating a trade-off made by UniSurv during training. Notably, the C-index is more sensitive to changes in \(\lambda _{v}\) than to variations in \(\lambda _{m}\). As the variations in both weights increase, deviations in the PDF gradually emerge, with its peak drifting further from the actual event time and assuming irregular shapes.

4.8.3 Sensitivity of larger and noised synthetic datasets

To emphasize UniSurv’s reliability for higher dimensionality datasets and its robustness to data noise, we have expanded the number of features for the existing synthetic datasets SYNTH-s and SYNTH-d without altering event or censoring settings, resulting in new SYNTH-sk and SYNTH-dk datasets, where k denotes the dimension of \(\pmb {x}_{n}^{i}\) in Eq. 11 and Eq. 12 is increased from 4 to \(4 \cdot 5^{k}\), and the dimension of \(\pmb {x}_{v}^{i}\) in Eq. 12 is increased from 20 to \(20 \cdot 5^{k}\). Additionally, we introduced noise \(\pmb {\epsilon }^{i} \sim \epsilon _{0} \cdot \mathcal {N} (0, \textbf{I})\) to \(\pmb {x}_{n}^{i}\) and \(\pmb {x}_{v}^{i}\) separately in all datasets. As the results shown in Table 4, UniSurv performs well on high-dimensional datasets and exhibits robustness to small levels of noise interference.

Table 4 Noise sensitivity analysis on different sizes of synthetic datasets

5 Conclusion and discussion

In this paper, we propose a non-parametric discrete survival model named UniSurv. Departing from existing models that use RNNs to process longitudinal data, we employ a Transformer to handle dynamic analysis adeptly. In particular, our survival framework first integrates imputation to handle missing data, then incorporates different embedding branches for time-varying and time-invariant feature extraction. The Transformer encoder takes the merged features as input and outputs the individual PDF. We also demonstrate how to process image-like data using module variations and how to select a time window based on the progression speed of the disease to share information. This is particularly beneficial in medicine, where obtaining regular time-series medical images in the real world is challenging.

Furthermore, our novel Margin-Mean-Variance loss effectively produces a smooth, unimodal PDF, demonstrating clear superiority over other discrete models; importantly, the proposed loss can be seamlessly embedded into various discrete survival models. Moreover, it significantly enhances prediction accuracy, particularly for patients with extended censoring times. Applying poorly performing models in such scenarios could evidently disrupt physicians' judgments and place unnecessary burdens on both society and healthcare institutions, so this constitutes a valuable contribution. Although our current PDF may appear overly concentrated around event times, akin to many models relying on strong probability assumptions, resulting in unconventional survival curves, we intend to modify \(\mathcal {L}_{s}\) and \(\mathcal {L}_{v}\) to relax certain constraints in the future. This adjustment aims to yield a more elegant PDF, characterized by a smoother and less abrupt distribution, while maintaining overall performance. Meanwhile, adapting UniSurv to accommodate other censoring scenarios, such as left truncation and interval-censored data, presents an interesting direction for future research. Additionally, expanding the scope to include post-processing statistics for interpreting risk predictions in both static and dynamic analyses of disease progression is necessary; for example, individual explanations of predicted probabilities can be achieved through SHapley Additive exPlanations (SHAP) (Pieszko et al., 2023; Krzyziński et al., 2023). This approach is expected to result in more effective health care.