Deep doubly robust outcome weighted learning


Abstract

Precision medicine is a framework that adapts treatment strategies to a patient’s individual characteristics and provides helpful clinical decision support. Existing research has extended the framework to various settings, but high-dimensional data have not yet been fully incorporated into the paradigm. We propose a new precision medicine approach, deep doubly robust outcome weighted learning (DDROWL), that can handle large and complex data. It is a machine learning tool that directly estimates the optimal decision rule and achieves the best of three worlds: deep learning, double robustness, and residual weighted learning. Two architectures have been implemented in the proposed method: a fully-connected feedforward neural network and the Deep Kernel Learning model, a Gaussian process with deep learning-filtered inputs. We compare and discuss the performance and limitations of different methods through a range of simulations. Using longitudinal and brain imaging data from patients with Alzheimer’s disease, we demonstrate the application of the proposed method in real-world clinical practice. With the implementation of deep learning, the proposed method extends the reach of precision medicine to abundant, high-dimensional data with greater flexibility and computational power.


Data availability

The code used to generate simulated data can be shared upon request. The National Alzheimer’s Coordinating Center database needs to be requested at https://naccdata.org/.

Code availability

Code is available upon request and will be posted on GitHub.

Notes

  1. The NACC database can be accessed at https://www.alz.washington.edu/WEB/researcher_home.html.

References

  • Angrist, J. D., Imbens, G. W., & Rubin, D. B. (1996). Identification of causal effects using instrumental variables. Journal of the American Statistical Association, 91(434), 444–455.

  • Athey, S., & Wager, S. (2021). Policy learning with observational data. Econometrica, 89(1), 133–161.

  • Barthold, D., Joyce, G., Diaz Brinton, R., Wharton, W., Kehoe, P. G., & Zissimopoulos, J. (2020). Association of combination statin and antihypertensive therapy with reduced Alzheimer’s disease and related dementia risk. PLoS ONE, 15(3), e0229541.

  • Beekly, D. L., Ramos, E. M., van Belle, G., Deitrich, W., Clark, A. D., Jacka, M. E., Kukull, W. A., et al. (2004). The National Alzheimer’s Coordinating Center (NACC) database: An Alzheimer disease database. Alzheimer Disease & Associated Disorders, 18(4), 270–277.

  • Bennett, A., & Kallus, N. (2020). Efficient policy learning from surrogate-loss classification reductions. In International conference on machine learning (pp. 788–798). PMLR.

  • Bergstra, J., Yamins, D., & Cox, D. D. (2013). Making a science of model search: Hyperparameter optimization in hundreds of dimensions for vision architectures. JMLR.

  • Bergstra, J. S., Bardenet, R., Bengio, Y., & Kégl, B. (2011). Algorithms for hyper-parameter optimization. In Advances in neural information processing systems (pp. 2546–2554).

  • Besser, L., Kukull, W., Knopman, D. S., Chui, H., Galasko, D., Weintraub, S., Jicha, G., Carlsson, C., Burns, J., Quinn, J., et al. (2018). Version 3 of the National Alzheimer’s Coordinating Center’s Uniform Data Set. Alzheimer Disease and Associated Disorders, 32(4), 351.

  • Dudík, M., Langford, J., & Li, L. (2011). Doubly robust policy evaluation and learning. arXiv:1103.4601.

  • Duron, E., Rigaud, A. S., Dubail, D., Mehrabian, S., Latour, F., Seux, M. L., & Hanon, O. (2009). Effects of antihypertensive therapy on cognitive decline in Alzheimer’s disease. American Journal of Hypertension, 22(9), 1020–1024.

  • Friedman, J., Hastie, T., & Tibshirani, R. (2010). Regularization paths for generalized linear models via coordinate descent. Journal of Statistical Software, 33(1), 1.

  • Gardner, J., Pleiss, G., Weinberger, K. Q., Bindel, D., & Wilson, A. G. (2018). GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Advances in neural information processing systems (pp. 7587–7597).

  • Gorgolewski, K., Burns, C. D., Madison, C., Clark, D., Halchenko, Y. O., Waskom, M. L., & Ghosh, S. S. (2011). Nipype: A flexible, lightweight and extensible neuroimaging data processing framework in Python. Frontiers in Neuroinformatics. https://doi.org/10.3389/fninf.2011.00013

  • Greenland, S., Pearl, J., & Robins, J. M. (1999). Confounding and collapsibility in causal inference. Statistical Science, 14(1), 29–46.

  • Guo, W., Zhou, X. H., & Ma, S. (2021). Estimation of optimal individualized treatment rules using a covariate-specific treatment effect curve with high-dimensional covariates. Journal of the American Statistical Association, 116(533), 309–321.

  • Hajjar, I., Hart, M., Chen, Y. L., Mack, W., Milberg, W., Chui, H., & Lipsitz, L. (2012). Effect of antihypertensive therapy on cognitive function in early executive cognitive impairment: A double-blind randomized clinical trial. Archives of Internal Medicine, 172(5), 442–444.

  • He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770–778).

  • Hernán, M. A., & Robins, J. M. (2020). Causal inference: What if. Chapman & Hall/CRC.

  • Holloway, S. T., Laber, E. B., Linn, K. A., Zhang, B., Davidian, M., & Tsiatis, A. A. (2018). DynTxRegime: Methods for estimating optimal dynamic treatment regimes. R package version 3.2.

  • Janocha, K., & Czarnecki, W. M. (2017). On loss functions for deep neural networks in classification. arXiv:1702.05659.

  • Jiang, X., Nelson, A. E., Cleveland, R. J., Beavers, D. P., Schwartz, T. A., Arbeeva, L., Alvarez, C., Callahan, L. F., Messier, S., Loeser, R., et al. (2021). Precision medicine approach to develop and internally validate optimal exercise and weight-loss treatments for overweight and obese adults with knee osteoarthritis: Data from a single-center randomized trial. Arthritis Care & Research, 73(5), 693–701.

  • Kallus, N. (2018). Balanced policy evaluation and learning. In Advances in neural information processing systems (pp. 8909–8920).

  • Kallus, N. (2020). Generalized optimal matching methods for causal inference. Journal of Machine Learning Research, 21(62).

  • Kang, J. D., & Schafer, J. L. (2007). Demystifying double robustness: A comparison of alternative strategies for estimating a population mean from incomplete data. Statistical Science, 22(4), 523–539.

  • Kosorok, M. R., & Laber, E. B. (2019). Precision medicine. Annual Review of Statistics and Its Application, 6, 263–286.

  • Leete, O. E., Kallus, N., Hudgens, M. G., Napravnik, S., & Kosorok, M. R. (2019). Balanced policy evaluation and learning for right censored data. arXiv:1911.05728.

  • Liang, M., Ye, T., & Fu, H. (2018). Estimating individualized optimal combination therapies through outcome weighted deep learning algorithms. Statistics in Medicine, 37(27), 3869–3886.

  • Liaw, A., Wiener, M., et al. (2002). Classification and regression by randomForest. R News, 2(3), 18–22.

  • Liaw, R., Liang, E., Nishihara, R., Moritz, P., Gonzalez, J. E., & Stoica, I. (2018). Tune: A research platform for distributed model selection and training. arXiv:1807.05118.

  • Liu, Y., Wang, Y., Kosorok, M. R., Zhao, Y., & Zeng, D. (2016). Robust hybrid learning for estimating personalized dynamic treatment regimens. arXiv:1611.02314.

  • Michel, J. (2020, March). ENNUI: An elegant neural network user interface.

  • Nie, X., Brunskill, E., & Wager, S. (2021). Learning when-to-treat policies. Journal of the American Statistical Association, 116(533), 392–409.

  • Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., & Lerer, A. (2017). Automatic differentiation in PyTorch. NIPS.

  • Plis, S. M., Hjelm, D. R., Salakhutdinov, R., Allen, E. A., Bockholt, H. J., Long, J. D., Johnson, H. J., Paulsen, J. S., Turner, J. A., & Calhoun, V. D. (2014). Deep learning for neuroimaging: A validation study. Frontiers in Neuroscience, 8, 229.

  • Provost, F. (2000). Machine learning from imbalanced data sets 101. In Proceedings of the AAAI’2000 workshop on imbalanced data sets (Vol. 68, pp. 1–3). AAAI Press.

  • Qian, M., & Murphy, S. A. (2011). Performance guarantees for individualized treatment rules. The Annals of Statistics, 39(2), 1180.

  • Robins, J. M., Rotnitzky, A., & Zhao, L. P. (1994). Estimation of regression coefficients when some regressors are not always observed. Journal of the American Statistical Association, 89(427), 846–866.

  • Scharfstein, D. O., Rotnitzky, A., & Robins, J. M. (1999). Adjusting for nonignorable drop-out using semiparametric nonresponse models. Journal of the American Statistical Association, 94(448), 1096–1120.

  • Shah, K., Qureshi, S. U., Johnson, M., Parikh, N., Schulz, P. E., & Kunik, M. E. (2009). Does use of antihypertensive drugs affect the incidence or progression of dementia? A systematic review. The American Journal of Geriatric Pharmacotherapy, 7(5), 250–261.

  • Shi, C., Blei, D. M., & Veitch, V. (2019). Adapting neural networks for the estimation of treatment effects. arXiv:1906.02120.

  • Sverdrup, E., Kanodia, A., Zhou, Z., Athey, S., & Wager, S. (2020). policytree: Policy learning via doubly robust empirical welfare maximization over trees. Journal of Open Source Software, 5(50), 2232.

  • Talo, M., Baloglu, U. B., Yıldırım, Ö., & Acharya, U. R. (2019). Application of deep transfer learning for automated brain abnormality classification using MR images. Cognitive Systems Research, 54, 176–188.

  • Tibshirani, J., Athey, S., Sverdrup, E., & Wager, S. (2023). grf: Generalized random forests. R package version 2.3.0.

  • Wainberg, M., Merico, D., Delong, A., & Frey, B. J. (2018). Deep learning in biomedicine. Nature Biotechnology, 36(9), 829–838.

  • Wilson, A. G., Hu, Z., Salakhutdinov, R., & Xing, E. P. (2016). Deep kernel learning. In Artificial intelligence and statistics (pp. 370–378).

  • Zhang, B., Tsiatis, A. A., Laber, E. B., & Davidian, M. (2012). A robust method for estimating optimal treatment regimes. Biometrics, 68(4), 1010–1018.

  • Zhao, Y., Zeng, D., Rush, A. J., & Kosorok, M. R. (2012). Estimating individualized treatment rules using outcome weighted learning. Journal of the American Statistical Association, 107(499), 1106–1118.

  • Zhao, Y. Q., Laber, E. B., Ning, Y., Saha, S., & Sands, B. E. (2019). Efficient augmentation and relaxation learning for individualized treatment rules using observational data. Journal of Machine Learning Research, 20(48), 1–23.

  • Zhou, X., & Kosorok, M. R. (2017). Augmented outcome-weighted learning for optimal treatment regimes. arXiv:1711.10654.

  • Zhou, X., Mayer-Hamblett, N., Khan, U., & Kosorok, M. R. (2017). Residual weighted learning for estimating individualized treatment rules. Journal of the American Statistical Association, 112(517), 169–187.

  • Zhou, Z., Athey, S., & Wager, S. (2023). Offline multi-action policy learning: Generalization and optimization. Operations Research, 71(1), 148–183.


Funding

There is no funding source to disclose.

Author information


Contributions

Concept and design: XJ, MRK. Acquisition, analysis, or interpretation of data: XJ, XZ, MRK. Drafting of the manuscript: XJ. Critical revision of the manuscript: XJ, XZ, MRK.

Corresponding author

Correspondence to Xiaotong Jiang.

Ethics declarations

Conflict of interest

There are no conflicts of interest to disclose.

Additional information

Editor: María Óskarsdóttir.

Publisher's Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Supplementary Information

Below is the link to the electronic supplementary material.

Supplementary file1 (PDF 390 KB)

Appendix

1.1 Appendix A: More on data extraction and preprocessing of the clinical data

Data collection methods and policies differed across centers and time periods and did not always conform to one another. Because such inconsistencies could introduce problematic, noisy heterogeneity, stringent inclusion and exclusion criteria were applied to keep the multi-center, multi-stage data relatively clean. We found more missing data in earlier form versions and later visits, with the most information collected at the initial visits. Since we were interested in baseline information, only observations from each subject’s first visit with the latest form version were included. For categorical variables, categories such as unknown (e.g., values of 9, 99, 999, 8888, 9999), not applicable (the submitted form did not collect such data or a skip pattern precludes such responses), or left blank were considered missing. For continuous variables, indicators of unknown, not assessed, and not available (e.g., − 4, 888.8) and extreme values outside of the normal range (e.g., height more than 80 inches) were considered missing.
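For concreteness, the following is a minimal pandas sketch of this recoding step. The sentinel code sets, column groupings, and valid ranges are illustrative stand-ins rather than the exact NACC/UDS codebook definitions.

```python
import numpy as np
import pandas as pd

# Illustrative sentinel codes; the actual codes are defined per variable in the NACC/UDS codebook.
CATEGORICAL_MISSING = {9, 99, 999, 8888, 9999}
CONTINUOUS_MISSING = {-4, 888.8}

def recode_missing(df: pd.DataFrame, categorical_cols, continuous_cols, valid_ranges) -> pd.DataFrame:
    """Replace unknown/not-applicable codes and out-of-range values with NaN."""
    df = df.copy()
    for col in categorical_cols:
        df[col] = df[col].where(~df[col].isin(CATEGORICAL_MISSING), np.nan)
    for col in continuous_cols:
        df[col] = df[col].where(~df[col].isin(CONTINUOUS_MISSING), np.nan)
        lo, hi = valid_ranges[col]                      # e.g., height: (36, 80) inches
        df[col] = df[col].where(df[col].between(lo, hi), np.nan)
    return df
```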

Because the UDS dataset has many forms and variables that contribute to the assessment of a subject’s cognitive status, covariates with moderate to high estimated Pearson correlations (\(> 0.5\)) with the outcome variable were excluded to avoid multicollinearity. Covariates with high estimated Pearson correlations (\(> 0.8\)) with other non-outcome covariates were excluded as well (e.g., height/weight and body mass index, and various CDR® scores). The ID data, which contain information linking the UDS and MRI data, were processed similarly to the UDS data, including removal of severely missing variables and multicollinearity.
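One way to implement this correlation screen is sketched below; the thresholds match those above, but the use of absolute correlations and the keep-the-first-of-each-pair rule are simplifying assumptions of the sketch, not details from the paper.

```python
import pandas as pd

def drop_correlated(df: pd.DataFrame, outcome: str, y_thresh: float = 0.5, x_thresh: float = 0.8) -> pd.DataFrame:
    """Drop covariates highly Pearson-correlated with the outcome or with each other."""
    corr = df.select_dtypes("number").corr()
    # Covariates with moderate-to-high correlation with the outcome.
    drop = {c for c in corr.columns if c != outcome and abs(corr.loc[c, outcome]) > y_thresh}
    # Among the rest, keep the first of each highly correlated pair and drop the second.
    covars = [c for c in corr.columns if c != outcome and c not in drop]
    for i, a in enumerate(covars):
        if a in drop:
            continue
        for b in covars[i + 1:]:
            if b not in drop and abs(corr.loc[a, b]) > x_thresh:
                drop.add(b)
    return df.drop(columns=sorted(drop))
```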

There were 48 clinical variables included in the analysis (Table 8). They are: demographics (age, gender, smoking status, research center), vitals (blood pressure, resting heart rate), previous and current medication/therapy (antiadrenergic, anxiolytic, anticoagulant, antidepressant, angiotensin converting enzyme inhibitor, antipsychotic, diuretic, estrogen hormone therapy, lipid lowering medication, nonsteroidal anti-inflammatory, vasodilator), previous medical history (anxiety, diabetes, angioplasty, cardiac bypass procedure, vascular brain injury, heart valve replacement, microhemorrhage, incontinence, insomnia, obsessive-compulsive disorder, Parkinson’s disease), number of visits, and time to complete trail making test.

Variables related to smoking (e.g., consumption of tobacco in the past 30 days, number of years smoking, whether or not the subject quit smoking, etc.) were combined into two variables: an indicator of current smoker (has not quit smoking, or has consumed tobacco in the last 30 days) and an indicator of ever smoked (smoked cigarettes in the last 30 days, more than 100 cigarettes in life, a non-zero number of years of smoking, or at least 1 cigarette smoked per day on average). Diabetes was redefined as a binary variable, with 1 representing Type 1, Type 2, or another type of diabetes (such as diabetes insipidus, latent autoimmune diabetes or Type 1.5, and gestational diabetes) and 0 representing no diabetes reported.
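A sketch of deriving these two smoking indicators and the binary diabetes variable is shown below. The column names (e.g., TOBAC30, SMOKYRS, DIABTYPE) and their codings are assumptions for illustration and should be checked against the UDS data dictionary.

```python
import pandas as pd

def derive_smoking_diabetes(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse detailed smoking/diabetes items into binary indicators (assumed column names and codings)."""
    out = pd.DataFrame(index=df.index)
    # Current smoker: has not quit smoking, or used tobacco in the last 30 days.
    out["current_smoker"] = ((df["QUITSMOK"] == 0) | (df["TOBAC30"] == 1)).astype(int)
    # Ever smoked: any lifetime or recent smoking item is positive.
    out["ever_smoked"] = (
        (df["TOBAC30"] == 1) | (df["TOBAC100"] == 1)
        | (df["SMOKYRS"] > 0) | (df["PACKSPER"] > 0)
    ).astype(int)
    # Diabetes: 1 for any reported type (Type 1, Type 2, or other), 0 for none reported.
    out["diabetes"] = df["DIABTYPE"].isin([1, 2, 3]).astype(int)
    return out
```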

There were a total of 5,616 MRI sessions available. Each session contains multiple DICOM files, with one file representing one MRI slice. The DICOM format was preferred because it contains image information such as slice position and sequence type in the headers. We extracted the MRI slices from each subject’s MRI session in a compressed folder and removed slices without series description or image position because series description informs the sequence type of the MRI scan and image position helps sort the slices in the right order. Since the consecutive slices differed by a matter of milliseconds, we selected every 5th slice from the 150 middle slices to save space and maintain the same reasonable image dimension for every subject. End slices were discarded as they contain less useful information about the brain. We restricted to one MRI per subject because we were only interested in baseline covariates and wanted to keep the input dimension consistent.
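The slice handling can be sketched with pydicom as follows. The directory layout, the .dcm file extension, and the use of the ImagePositionPatient z-coordinate for ordering are assumptions for illustration, not the authors' exact pipeline.

```python
from pathlib import Path

import pydicom

def select_slices(session_dir: str, n_middle: int = 150, step: int = 5):
    """Sort a session's DICOM slices by position and keep every `step`-th of the middle `n_middle`."""
    slices = []
    for f in Path(session_dir).glob("*.dcm"):
        ds = pydicom.dcmread(f)
        # Discard slices lacking a series description or an image position.
        if getattr(ds, "SeriesDescription", None) and getattr(ds, "ImagePositionPatient", None):
            slices.append(ds)
    # Sort along the slice axis (z-coordinate of ImagePositionPatient).
    slices.sort(key=lambda ds: float(ds.ImagePositionPatient[2]))
    start = max(0, (len(slices) - n_middle) // 2)
    middle = slices[start:start + n_middle]
    return middle[::step]
```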

The preprocessed ID and UDS data were merged by subject ID, center, and visit year. Only complete cases were used because addressing data incompleteness was not our main research focus here, and imputation for such multi-center observations is often unreliable or requires extremely careful handling. In the preprocessed data (before merging with MRI scans but after merging UDS with ID), there were 424 observations collected from 12 centers spanning 2015 to 2019. Among the 424 subjects, 48% had improved or maintained normal cognitive status and 48% currently used antihypertensive or blood pressure medication at the initial visit. The preprocessed MRI data were merged with the medical data by file locator information, so only subjects with qualifying UDS, ID, and MRI data were included, resulting in a sample size of 186. We further excluded 24 subjects whose MRI date was more than 200 days away from the initial visit date. The final sample size for the preprocessed data was 162. As mentioned above, we applied relatively strict inclusion criteria to ensure the input data for the DL models were relatively clean and similar. The MRI data were only lightly processed to preserve the original values but could be piped through more systematic image processing tools.

1.2 Appendix B: More on transfer learning and ResNet34

Transfer learning can be regarded as a feature selection tool because images often contain a large number of nuisance pixels. The outputs of ResNet34 prior to the dense layers have a reduced dimension of 1000, much smaller and more distilled than the original dimension. If they were fed directly into a DL architecture together with the medical data, the MRI data would dominate the input dimension. An alternative to transfer learning is to apply unsupervised learning, such as an autoencoder (AE), to the MRI data. An AE is a good dimension reduction method, but the encoded outputs are not necessarily good representations of the original input.

In addition to the reasons mentioned in the main text, ResNet34 was chosen as the pretrained model because it has lower model complexity and relatively low top-1 and top-5 errors on the ImageNet data compared with other well-known deep learning architectures such as AlexNet or VGG. The top-1 error is the proportion of test images whose true label does not match the prediction class with the highest estimated probability. The top-5 error is the proportion of test images whose true label is not among the 5 prediction classes with the highest estimated probabilities. We applied a pretrained model instead of training our own structure because the lower-level representations extracted from the earlier layers of existing models are generally transferable across images.
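A minimal sketch of this feature extraction step, using the torchvision ResNet34 pretrained on ImageNet, is shown below. The preprocessing choices (resize, center crop, channel replication, ImageNet normalization) are standard defaults assumed for illustration rather than details taken from the paper.

```python
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

# Pretrained ResNet34; its 1000-dimensional output (before any task-specific dense layers)
# serves as a compact representation of each MRI slice.
resnet = models.resnet34(pretrained=True)
resnet.eval()

preprocess = T.Compose([
    T.Resize(256), T.CenterCrop(224), T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

def extract_features(slice_image: Image.Image) -> torch.Tensor:
    """Map one (grayscale) MRI slice to a 1000-dim ImageNet-pretrained feature vector."""
    x = preprocess(slice_image.convert("RGB")).unsqueeze(0)  # replicate the channel to RGB
    with torch.no_grad():
        return resnet(x).squeeze(0)  # shape: (1000,)
```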

1.3 Appendix C: Definition of Value Functions

We used the same CV definitions of the value function estimator and its variance estimator as in Jiang et al. (2021). Let \(j=1,\ldots , MK\) index the MK tuning folds across the M repetitions and K CV folds, and let \(i = 1, \ldots , n_j\) index the ith observation in the jth overall fold. Cross validation was applied to a dataset of size \(n_{tr}\), which was split into training and validation sets. The CV estimated value function was used to compare tuning performance:

$$\begin{aligned} \widehat{V}^{cv} (\hat{d}_{n_{tr}}^{(-j)}) = \frac{\sum _{j=1}^{MK}\sum _{i=1}^{n_j} U_{ji} }{\sum _{j=1}^{MK}\sum _{i=1}^{n_j} W_{ji}}, \end{aligned}$$

where \(W_{ji} = \frac{1\{A_{ji} = \hat{d}_{n_{tr}}^{(-j)} ({\varvec{X}}_{ji}) \}}{\hat{P}^{(-j)}(A_{ji} \vert {\varvec{X}}_{ji})}\), \(U_{ji} = Y_{ji}W_{ji}\), \(\hat{d}_{n_{tr}}^{(-j)}\) is the decision rule estimated from the dataset of size \(n_{tr}\) with the jth fold left out, and \(\hat{P}^{(-j)}(A_{ji} \vert {\varvec{X}}_{ji})\) is the estimated propensity score. Its variance is used to measure the estimation uncertainty

$$\begin{aligned} \widehat{\text {Var}}[ \widehat{V}^{cv} (\hat{d}_{n_{tr}}^{(-j)})] = \frac{1}{K(MK-1)} \sum _{j=1}^{MK} \sum _{i=1}^{n_j} R_{ji}^2, \end{aligned}$$

where \(R_{ji} = \frac{1}{\bar{W}_j} U_{ji} - \frac{\bar{U}_j}{\bar{W}_j^2} W_{ji}\) is an influence function-inspired form of the value function with \(\bar{U}_j = \sum _{i=1}^{n_j}U_{ji}\) and \(\bar{W}_j = \sum _{i=1}^{n_j} W_{ji}\). By definition, \(\sum _{j=1}^{MK}\sum _{i=1}^{n_j} R_{ji} = 0\).
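A direct numpy translation of these two estimators might look like the following sketch. The interfaces (rules, propensities, fold_ids) are hypothetical conveniences rather than the authors' implementation.

```python
import numpy as np

def cv_value_and_variance(Y, A, X, rules, propensities, fold_ids, K):
    """CV value estimate and its variance over the MK held-out folds (formulas above).

    Hypothetical interface: rules[j](x) and propensities[j](a, x) come from the fit
    with fold j left out; fold_ids[i] in {0, ..., MK-1} is observation i's held-out fold.
    """
    Y, fold_ids = np.asarray(Y, dtype=float), np.asarray(fold_ids)
    MK = len(rules)

    W = np.array([
        float(A[i] == rules[fold_ids[i]](X[i])) / propensities[fold_ids[i]](A[i], X[i])
        for i in range(len(Y))
    ])
    U = Y * W
    V_cv = U.sum() / W.sum()                      # hat{V}^{cv}

    # Influence-function-style residuals, computed fold by fold.
    R = np.empty_like(U)
    for j in range(MK):
        idx = fold_ids == j
        U_bar, W_bar = U[idx].sum(), W[idx].sum()
        R[idx] = U[idx] / W_bar - (U_bar / W_bar**2) * W[idx]
    var_cv = (R**2).sum() / (K * (MK - 1))        # hat{Var}[hat{V}^{cv}]
    return V_cv, var_cv
```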

For testing results, there is no CV so \(j = 1, \ldots , M\). The estimated value function becomes

$$\begin{aligned} \widehat{V}(\hat{d}_{n_{te}}) = \frac{\sum _{i=1}^{n_{te}} U_i}{\sum _{i=1}^{n_{te}} W_i}, \end{aligned}$$

where \(n_{te}\) is the sample size of the testing set, and \(U_i, W_i\) are defined similarly to \(U_{ji}\) and \(W_{ji}\) but with \(i=1,\ldots , n_{te}\) and decision rule \(\hat{d}_{n_{te}}\). Its variance is given by

$$\begin{aligned} \widehat{\text {Var}} [\widehat{V}(\hat{d}_{n_{te}}) ] = \frac{\sum _{j=1}^{M} (\widehat{V}(\hat{d}_{n_{te}, j}) - \bar{V}(\hat{d}_{n_{te}, M}))^2}{M-1}, \end{aligned}$$

where \(\hat{d}_{n_{te}, j}\) is the single estimated decision rule from the jth repetition and \(\bar{V}(\hat{d}_{n_{te}, M}) = \frac{1}{M}\sum _{j=1}^{M} \widehat{V}(\hat{d}_{n_{te}, j})\) is the average of the estimated value functions over the M single estimated decision rules. The SD is the square root of the variance.
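A corresponding sketch for the test-set estimates is given below, again with hypothetical inputs: W_list[j] is assumed to hold the weights already computed for the jth repetition's estimated rule.

```python
import numpy as np

def test_value_and_variance(Y, W_list):
    """Test-set value per repetition and the across-repetition variance and SD (formulas above).

    Hypothetical interface: W_list[j] holds the weights
    W_i = 1{A_i = d_hat_j(X_i)} / P_hat(A_i | X_i) for the jth repetition's rule.
    """
    Y = np.asarray(Y, dtype=float)
    V = np.array([(Y * W).sum() / W.sum() for W in W_list])  # one hat{V} per repetition
    V_bar = V.mean()                                         # average over the M repetitions
    var = ((V - V_bar) ** 2).sum() / (len(V) - 1)
    return V, var, np.sqrt(var)                              # SD = sqrt(variance)
```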

Model performance during tuning and testing was judged by higher estimated value functions and lower SDs.


Cite this article

Jiang, X., Zhou, X. & Kosorok, M.R. Deep doubly robust outcome weighted learning. Mach Learn 113, 815–842 (2024). https://doi.org/10.1007/s10994-023-06484-w
