A Comprehensive Evaluation of Multi-task Learning and Multi-task Pre-training on EHR Time-series Data

Matthew B. A. McDermott, et al. (MIT), 07/20/2020

Multi-task learning (MTL) is a machine learning technique aiming to improve model performance by leveraging information across many tasks. It has been used extensively on various data modalities, including electronic health record (EHR) data. However, despite significant use on EHR data, there has been little systematic investigation of the utility of MTL across the diverse set of possible tasks and training schemes of interest in healthcare. In this work, we examine MTL across a battery of tasks on EHR time-series data. We find that while MTL commonly suffers from negative transfer, we can realize significant gains via MTL pre-training combined with single-task fine-tuning. We demonstrate that these gains can be achieved in a task-independent manner and offer not only minor improvements under traditional learning, but also notable gains in a few-shot learning context, suggesting this could be a scalable vehicle for improved performance in important healthcare contexts.


1 Introduction

Multi-task learning (MTL) is a machine learning technique aiming to improve model performance by leveraging information across many tasks. MTL has been explored extensively, especially in the computer vision and natural language processing domains Caruana_mtl ; luong2015multi ; zhang2017survey . Research has found that MTL can offer performance benefits for similar tasks, while for dissimilar tasks, it may induce negative transfer, where MTL harms overall performance ndirango2019generalization ; Wu2020Understanding . Additionally, some argue that MTL can act as a regularizer in learning ruder2017overview . Others have noted that MTL can induce gains in pre-training/single-task fine-tuning, few-shot learning, or fairness contexts DBLP:conf/interspeech/HeLYHC18 ; das2018mitigating ; liu2019multi ; tian2020rethinking ; mtlearning_fair_2019 .

One domain where MTL may be particularly helpful is machine learning for health (ML4H), a field which suffers from notable data difficulties that motivate the use of MTL. Clinical data is often smaller, higher-dimensional, and noisier than general domain data, and tasks are commonly susceptible to confounders that vary across institutions, time, and demographics pmlr-v106-prabhu19a ; ghassemi2019practical ; wiens2019no ; counterfactual_norm ; pmlr-v106-nestor19a ; PMID:29893864 , which may be alleviated by MTL’s regularizing effect. Additionally, clinical data poses novel challenges, such as the prevalence of diverse rare diseases 2019arXiv191113232C ; wakap2020estimating ; 10.1001/jamanetworkopen.2020.1965 , for which data is scarce and MTL’s possible benefits for few-shot learning would be critical, or the importance of model fairness 10.1001/jamainternmed.2018.3763 ; Obermeyer447 , where MTL’s possible ability to yield more equitable models under imbalanced data would be very valuable. These data difficulties impose substantial hurdles to effective learning in this domain and motivate the use and understanding of MTL in this context. While MTL has been used in ML4H harutyunyan_multitask_2017 ; suresh_learning_2018 ; Si2019 , a general understanding of its broader efficacy on clinical time-series is lacking. Further, important use cases of MTL, such as its ability to aid diverse tasks in a pre-training/fine-tuning setup, its efficacy for few-shot learning, and its advantages on imbalanced datasets have yet to be studied.

In this work, we provide a robust analysis of MTL across various learning contexts in ML4H. We design a broad set of tasks over physiological electronic health record (EHR) time-series data and use them to answer the following specific questions:

  1. How extensive is negative transfer among traditionally studied tasks?

  2. Can MTL consistently offer benefits in performance?

  3. Can MTL pre-training offer benefits for few-shot learning?

  4. Can MTL pre-training help reduce sensitivity to population imbalance?

Our analyses reveal that negative transfer is common, and conventional MTL often does not yield performance benefits. In contrast, we observe that MTL pre-training followed by single-task fine-tuning does match or exceed both single-task and multi-task training, and offers significant gains in the few-shot regime. Unfortunately, we do not find evidence that it remedies issues of population imbalance on the chosen demographics. Overall, these findings robustly suggest that multi-task pre-training, followed by task-specific fine-tuning, can offer advantages in machine learning for health, particularly in few-shot or small data contexts.

2 Related Works

Prior work has demonstrated the benefits of MTL in various contexts 10.1023/A:1007379606734 ; 10.1145/1014052.1014067 ; NIPS2006_3143 , including in deep neural models specifically for sequence learning luong2015multitask , face detection 8170321 , and self-supervised learning Doersch_2017_ICCV . MTL has also been explored specifically in EHR data, including using MTL to improve mortality, readmission, long length-of-stay, or phenotype (ICD code) prediction harutyunyan_multitask_2017 , in-hospital, 30-day, and 1-year mortality prediction directly Si2019 , molecular property prediction wu2018moleculenet , or adverse outcome prediction across different patient sub-populations suresh_learning_2018 . MTL has also been used to address incomplete health data hunt2018multi or rare disease detection liu2020multi ; 2019arXiv191113232C , and has been generally applied to many specific ML4H questions soleimani2017scalable ; zhou2013modeling ; schulam2015framework ; alaa2017bayesian ; futoma2017improved .

Despite the widespread use of MTL in ML4H, careful analyses of its use beyond direct performance comparisons on the tasks under study have been limited. Ding et al. ding2019effectiveness offer an assessment of the propensity for negative transfer within the context of ICD code prediction, finding that negative transfer can occur on a subset of tasks even while rarer tasks still obtain a benefit, but similar analyses beyond ICD-code prediction tasks do not exist.

Transfer learning has been used extensively in the medical imaging domain; however, only recently have researchers performed a focused investigation of the efficacy of transfer learning for medical imaging in general raghu2019transfusion . Pre-training has also been applied to EHR data, largely by using self-supervised autoregressive pre-training tasks shang_pre-training_2019 ; li2020behrt . In this work, we perform a more robust analysis of these techniques, moving to establish best practices.

3 Methods

(a) Our model is composed of an encoder, which is shared across all tasks, followed by task-specific decoders. In certain configurations (e.g., single-task training) only a subset of tasks are used in training.
(b) We use several multi-task regimes in this study, including single-task (ST), in which a separate model is trained on each task $t$; multi-task (MT), in which a single model is trained on all tasks simultaneously; and two fine-tuning regimes, fine-tune decoder (FTD) and fine-tune full (FTF), both of which start with task-omitted pre-training, where a model is trained on all tasks except for $t$, followed by fine-tuning on task $t$ alone, either allowing only the decoder to specialize (FTD) or allowing both the encoder and the decoder to specialize (FTF).
Figure 1: Our model and multi-task regimes.

3.1 Model Composition

We use a shared encoder sub-network, used across all tasks, and separate, independent per-task decoder heads for prediction. More formally, given a sample $x_i$ from a dataset $\mathcal{D}$ of EHR physiological time-series and tasks $t_1, \ldots, t_T$ with labels given by $y_i^{(1)}, \ldots, y_i^{(T)}$, we train a multi-task model structured via a shared encoder $f_{\theta}$ followed by task-specific decoder modules $g_{\phi_1}, \ldots, g_{\phi_T}$, such that the prediction of the model on task $t_j$ is given by $\hat{y}_i^{(j)} = g_{\phi_j}(f_{\theta}(x_i))$. Individual losses $\ell_j(\hat{y}_i^{(j)}, y_i^{(j)})$ are computed for each task, which are then summed across all tasks to form the overall learning loss $\mathcal{L} = \sum_{j=1}^{T} \ell_j(\hat{y}_i^{(j)}, y_i^{(j)})$, and learning over both the encoder parameters $\theta$ and the decoder parameters $\phi_1, \ldots, \phi_T$ is performed via the Adam DBLP:journals/corr/KingmaB14 variant of stochastic gradient descent over $\mathcal{L}$. This setup is shown in Figure 1(a).
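For concreteness, the following is a minimal PyTorch-style sketch of this architecture and summed loss. The class, dimensions, and loss-function dictionary are our own illustrative assumptions, not the authors' released implementation.

```python
import torch
import torch.nn as nn

class MultiTaskModel(nn.Module):
    """Shared encoder with independent per-task decoder heads."""

    def __init__(self, n_features, hidden_dim, task_output_dims):
        super().__init__()
        self.proj = nn.Linear(n_features, hidden_dim)      # common embedding space
        self.encoder = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        # One decoder head per task; here each is a single fully connected layer.
        self.decoders = nn.ModuleDict({
            task: nn.Linear(hidden_dim, out_dim)
            for task, out_dim in task_output_dims.items()
        })

    def forward(self, x):
        # x: (batch, time, n_features) physiological time-series window
        h, _ = self.encoder(self.proj(x))
        z = h[:, -1]                                        # fixed-size representation
        return {task: head(z) for task, head in self.decoders.items()}

def multitask_loss(preds, labels, loss_fns):
    """Overall learning loss: the sum of the individual per-task losses."""
    return sum(loss_fns[t](preds[t], labels[t]) for t in preds)
```

A single optimizer over all parameters, e.g. torch.optim.Adam(model.parameters()), then updates the encoder and every decoder jointly.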

3.2 Multi-task Regimes

In this work, recall that we are interested in answering 4 key questions about the efficacy of MTL over clinical data: how extensive is negative transfer, does MTL offer any direct performance benefits, and can MTL pre-training help us improve few-shot learning or ameliorate bias. To assess each of these, we use several MTL regimes, which we detail here and show graphically in Figure 1(b).

3.2.1 Single-task (ST)

As a baseline, we train an unconstrained (encoder & decoder) model for each task independently, optimizing the loss $\ell_j$ individually for each task $t_j$. Note that for single-task training, we retain the same structure of our model, but as we have a separate model for each task, the encoder and decoder are both functionally task-specific. These experiments serve as our comparison point for a global assessment of negative transfer, in which we compare a full MTL system trained across all tasks to the ST results, as well as our comparison point for the MTL pre-training system across the full-data, few-shot, and imbalanced-data settings.

3.2.2 Full Multi-task (MT)

We train unconstrained (encoder & decoder) models in a conventional multi-task setting across all tasks. These results are used as comparison points for single-task and fine-tuned experiments.

3.2.3 Pre-training/Fine-tuning

Task-Omitted Pre-training

Prior to fine-tuning, we must pre-train a model with the task of interest omitted. (We don't pre-train with the task of interest included, as we want to simulate fine-tuning on a completely unseen task, to assess the utility of MTL pre-training in ML4H across diverse, potentially unknown, downstream tasks, rather than on a static collection of tasks of interest.) For each task $t_j$, we train a model in a multi-task setting on all tasks except task $t_j$ by optimizing the loss $\mathcal{L}_{-j} = \sum_{k \neq j} \ell_k$. These models are then saved and used for several purposes. First, these models' performance on all tasks still included in the training ensemble (i.e., all tasks other than $t_j$) offers a local picture of the extent of negative transfer, by allowing us to compare to the full MT results and judge how removing a single task affects performance on the other tasks in the ensemble. This complements our more global comparison to the single-task system, which indicates how the inclusion of all other tasks affects performance. Second, these models are used as our pre-trained sources for our analyses of the effect of MTL pre-training on direct performance, few-shot learning, and imbalanced data, so that we can simulate adapting a multi-task model to a completely new task.
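A sketch of this task-omitted objective, reusing the (hypothetical) per-task prediction, label, and loss dictionaries from the sketch in Section 3.1:

```python
def task_omitted_loss(preds, labels, loss_fns, omitted_task):
    """Pre-training objective: sum losses over every task except the held-out task t_j."""
    return sum(loss_fns[t](preds[t], labels[t])
               for t in preds if t != omitted_task)
```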

Fine-tuned, full (FTF) & Fine-tuned, decoder-only (FTD)

Using the trained, task-omitted models discussed above, we fine-tune these models on the omitted task $t_j$ with loss $\ell_j$ directly, allowing either the full (encoder & decoder) model (FTF) or only the decoder (FTD) to update in fine-tuning. This style of training, in which we pre-train on all tasks except $t_j$, then fine-tune on task $t_j$ alone, is meant to simulate how we could adapt a pre-trained multi-task model to a novel task.
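A minimal sketch of the two fine-tuning variants, assuming the MultiTaskModel structure from the Section 3.1 sketch (the learning rate and optimizer setup are illustrative; the held-out task's decoder head starts freshly initialized, since pre-training never updated it):

```python
import torch

def prepare_finetuning(model, task, mode="FTF", lr=1e-4):
    """Return an optimizer that updates only the parameters allowed to specialize."""
    if mode == "FTD":
        # Decoder-only: freeze the shared encoder/projection, tune the task head alone.
        for p in model.parameters():
            p.requires_grad = False
        params = list(model.decoders[task].parameters())
        for p in params:
            p.requires_grad = True
    else:
        # FTF: the full model (encoder and decoder) specializes to the new task.
        params = list(model.parameters())
    return torch.optim.Adam(params, lr=lr)
```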

3.3 Fine-tuning Settings

We fine-tune our models under several modified versions of our data, shown graphically in Figure 2. These are designed to assess performance and to determine whether MT pre-training can help us adapt models to few-shot or imbalanced datasets. First, we fine-tune models over the full dataset in our full-data mode. Next, to simulate the adaptation of a pre-trained model to a new or rare disease with minimal data available, we fine-tune ST, FTF, and FTD models in a few-shot setting, with the training data sub-sampled to various degrees. Finally, to simulate adapting a model to a disease for which only imbalanced data is available, in our imbalanced setting we randomly subsample female patients (by genotypical sex johnson_mimic-iii_2016 ) within the fine-tuning data to varying degrees, including total removal. In both the imbalanced and few-shot modes, note that pre-training is performed as described in Section 3.2.3 on all tasks but the fine-tuning task $t_j$, over all available data, unaltered. Single-task models used in these settings are naturally not pre-trained, and are instead trained directly from scratch on the reduced datasets.
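A sketch of the two data reductions, assuming the fine-tuning cohort is a pandas DataFrame with one row per patient and a genotypical-sex column (the column name, values, and seed are illustrative):

```python
import pandas as pd

def few_shot_subset(train_df, frac, seed=0):
    """Few-shot setting: randomly keep a fraction of training patients (e.g. 0.01 for 1%)."""
    return train_df.sample(frac=frac, random_state=seed)

def sex_imbalanced_subset(train_df, female_keep_frac, seed=0):
    """Imbalanced setting: subsample only female patients; 0.0 removes them entirely."""
    female = train_df[train_df["gender"] == "F"].sample(frac=female_keep_frac, random_state=seed)
    other = train_df[train_df["gender"] != "F"]
    return pd.concat([other, female])
```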

Figure 2: We assess models (either trained or fine-tuned) on three regimes: full-data (left), few-shot (middle), in which fine-tuning data was randomly subsampled to various degrees, and imbalanced (right), in which fine-tuning data was subsampled in a sex-imbalanced fashion, with random percentages of genotypically female patients removed to simulate adapting to an imbalanced dataset.

4 Experimental Settings

Data

We use the MIMIC-III dataset, containing intensive care unit (ICU) visits from patients in the Beth Israel Deaconess Medical Center between 2001-2012 johnson_mimic-iii_2016 . Data is parsed using the MIMIC-Extract wang_mimic-extract:_2019 pipeline and cohort under default parameters. We use the extracted time-series of labs and vitals, along with continuous embeddings of treatments, as our input data and for our local auto-regressive tasks, and additionally use the static outputs and ICD codes for our tasks. We prepare additional data from MIMIC for our novel tasks, including 30-day readmission flags per-patient, and do not resuscitate (DNR) & comfort measures only (CMO) code changes. (Full code and pre-processed data are available at https://github.com/mmcdermott/comprehensive_MTL_EHR.)

We randomly split our data by patient via an 80-10-10 train, tuning, and test split, resulting in 17,500 patients in the train set and 2,188 patients in each of the tuning and test sets. Minibatches are generated by sampling a random 48h sequence from a patient's EHR record each training epoch. Tasks deemed rolling and autoregressive (see Table 1) are evaluated at 10 random time-points per patient. Static tasks are evaluated over the first 24 hours per patient, and terminal tasks are evaluated over the 48 hours prior to discharge/death.
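A sketch of the per-epoch window sampling, assuming each patient record is stored as an hourly (T, n_features) array (class and field names are our own; label slicing and padding of short stays are omitted):

```python
import numpy as np
from torch.utils.data import Dataset

class RandomWindowDataset(Dataset):
    """Yields one randomly positioned 48h window per patient each epoch."""

    def __init__(self, patient_records, window_hours=48):
        self.records = patient_records        # list of (T, n_features) hourly arrays
        self.window = window_hours

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        start = np.random.randint(0, max(1, rec.shape[0] - self.window + 1))
        return rec[start:start + self.window]
```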

Tasks

We use several tasks to assess the efficacy of MTL. We group these tasks into ten categories to ensure that omitting a task group removes all highly correlated tasks. Scores are reported at the per-category level, averaging macro AUROCs across all tasks within that category. Full descriptions of the tasks can be found in Supplementary Material A.
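Per-category scores of this kind could be computed as in the sketch below, using scikit-learn's macro-averaged AUROC (the task grouping and prediction dictionaries are illustrative; multi-class tasks would additionally need multi_class="ovr"):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def category_score(y_true_by_task, y_score_by_task, tasks_in_category):
    """Average the macro AUROC over all tasks within one category (e.g. both MOR horizons)."""
    per_task = [
        roc_auc_score(y_true_by_task[t], y_score_by_task[t], average="macro")
        for t in tasks_in_category
    ]
    return float(np.mean(per_task))
```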

Task Category | Abbr. | Specific Task | Temporality | Gap | Pred. Window | Pred. Type | Rel. Work | # Labels | MCA
Imminent Mortality | MOR | Mortality (24h) | Rolling | 2h | 24h | Bin. | harutyunyan_multitask_2017 | 1.8M | 0.980
 | | Mortality (48h) | Rolling | 6h | 48h | Bin. | | 1.7M | 0.962
Comfort Measures | CMO | CMO added (24h) | Rolling | 2h | 24h | Bin. | lojun2010investigating | 1.8M | 0.992
 | | CMO added (48h) | Rolling | 6h | 48h | Bin. | | 1.7M | 0.987
DNR Ordered | DNR | DNR added (24h) | Rolling | 2h | 24h | Bin. | lojun2010investigating | 1.6M | 0.988
 | | DNR added (48h) | Rolling | 6h | 48h | Bin. | | 1.6M | 0.981
Imminent Discharge | DIS | Discharge (24h) | Rolling | 2h | 24h | MC | bertsimas2020predicting | 1.5M | 0.730
 | | Discharge (48h) | Rolling | 6h | 48h | MC | | 1.4M | 0.473
ICD Code Prediction | ICD | See Suppl. Mat. A.2 | Static | 12h | N/A | ML | ding2019effectiveness ; harutyunyan_multitask_2017 | 7.7K | 0.691
Long Length-of-Stay | LOS | | Static | 12h | N/A | Bin. | pmlr-v106-nestor19a ; harutyunyan_multitask_2017 ; wang_mimic-extract:_2019 | 21.9K | 0.529
30 Day ICU Readmission | REA | | Terminal | N/A | N/A | Bin. | harutyunyan_multitask_2017 | 21.9K | 0.950
Final Acuity | ACU | | Static | 12h | N/A | MC | Che2018 ; wang_mimic-extract:_2019 ; Si2019 | 21.9K | 0.253
Next Timepoint | WBM | Will-be-measured | AR | 0h | 1h | ML | chang19a | 1.8M | 0.920
 | | Next hour | AR | 0h | 1h | Reg. | | 1.8M | N/A
Future Treatment Sequence | FTS | | AR | N/A | N/A | SMC | Wu_vassopressur_2016 ; peng2018improving | 1.8M | N/A
Table 1: Tasks that are used in the multi-task learning ensemble. Majority class accuracy (MCA) is reported for all classification tasks (macro, if task is multi-label) to give an estimate of the relative level of class imbalance of the task. Number of observed labels is reported for all tasks, reflecting the differences both between static and dynamic tasks (the former having one label per patient, the latter one per patient per hour), as well as reflecting that some tasks only have valid labels on a subset of patients/patient-hours. Both Future Treatment Sequence (FTS) and our more granular Final Acuity (ACU) task are, to the best of our knowledge, novel tasks.


Abbreviations: AR: Autoregressive, Bin.: binary classification, ML: binary multi-label classification, MC: multi-class classification, SMC: sequential decoding multi-class classification, Reg.: regression.

In addition to conventional classification tasks, we have two autoregressive tasks in our ensemble: multi-label prediction of which labs/vitals will be measured in the next hour, and continuous regression to the labs/vitals values observed in the next time-point (presuming they are measured). Next-hour regression was included in our ensemble as it was observed to be helpful to other tasks in preliminary results, but we do not report scores on the regression task, as there is no clear analog to classification performance metrics and, despite its benefit to other tasks in training, all tested models' overall performance on this task was no better than mean imputation.

We also use a novel future treatment sequence (FTS) task, for which individual treatments are first aggregated into categories (ventilation, vasopressors, and fluids), then formed into a sequential decoding task in which the model predicts which subsets of treatments are used over the remainder of the patient's stay, in sequence, but ignoring the duration of application of those treatments. While all of our other tasks use a single fully connected layer for their task decoders, this sequential task uses an LSTM recurrent neural network with teacher forcing lamb2016professor .
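A minimal sketch of such a sequential decoding head under teacher forcing, assuming the fixed-size encoder representation from the Section 3.1 sketch and an integer class index per future treatment set (vocabulary size, dimensions, and the omission of a start token are our simplifications):

```python
import torch
import torch.nn as nn

class FTSDecoder(nn.Module):
    """LSTM decoder predicting the sequence of future treatment-set classes."""

    def __init__(self, hidden_dim, n_treatment_classes):
        super().__init__()
        self.embed = nn.Embedding(n_treatment_classes, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, n_treatment_classes)

    def forward(self, encoder_state, target_seq):
        # encoder_state: (batch, hidden_dim) fixed-size patient representation.
        # target_seq: (batch, seq_len) ground-truth class indices; teacher forcing feeds
        # these (shifted by one step) as decoder inputs instead of the model's predictions.
        inputs = self.embed(target_seq[:, :-1])
        h0 = encoder_state.unsqueeze(0)          # (1, batch, hidden_dim) initial hidden state
        c0 = torch.zeros_like(h0)
        out, _ = self.lstm(inputs, (h0, c0))
        return self.out(out)                     # logits predicting steps 1..seq_len-1
```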

Model Architecture

Data is first projected into a common embedding space via a linear layer, then passed into a sequential encoder to produce a fixed-size representation. Finally, this fixed-size representation is fed into the task-specific prediction heads (a single fully connected layer with a softmax activation, except for FTS) to yield the task-specific output (see Figure 1(a)). We investigated a linear model, a gated recurrent unit (GRU) recurrent neural network cho-etal-2014-learning , and a transformer architecture vaswani_attention_2017 for our shared encoder; however, the main body of the paper features only the GRU results for brevity, as this was our best performing model and it (or variants thereof) is well-established in this space che2018recurrent . Full details and results for all architectures are reported in the supplementary materials. To assess significance and give variance estimates, we re-trained our models over the same train/validation/test split with different random initializations. (Other schemes would require re-tuning hyperparameters, which was computationally infeasible.)

Hyperparameter Tuning

Hyperparameter tuning was performed using the Hyperopt library bergstra_making_2013 , optimizing for the average AUROC across all tasks in the multi-task setting. These hyperparameters were then applied to the ST, FTD, and FTF models. We tuned hyperparameters in the multi-task setting for computational efficiency; performing a separate hyperparameter tuning run for every task independently would have inflated the hyperparameter search time by a factor of 10, which was not feasible. To assess whether choosing hyperparameters based solely on MTL performance introduced bias, we ran a single task through a separate hyperparameter tuning run, and we also analyzed the results of our main run as though we were choosing to optimize a single task (for example, instead of summing the losses of 10 tasks, only 1 task loss is used, which may require a higher learning rate to converge as fast as the MT models). In both cases, the difference in ultimate performance was negligible, so we concluded that this bias was an acceptable risk given the extent of the computational savings it offered. If anything, this effect would bias us in favor of pure multi-task (MT) results, whereas we do not find these to be top performing in practice. Further discussion of the hyperparameter search can be found in Supplementary Material B.
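A sketch of this tuning loop with Hyperopt; the search space, evaluation budget, and the train_and_eval_mt stand-in are illustrative assumptions, not the paper's actual configuration:

```python
from hyperopt import Trials, fmin, hp, tpe

def train_and_eval_mt(lr, hidden_dim, dropout):
    """Hypothetical stand-in: train the full multi-task model with these settings
    and return its average validation AUROC across all tasks."""
    raise NotImplementedError

space = {
    "lr": hp.loguniform("lr", -9, -4),                    # roughly 1e-4 to 2e-2
    "hidden_dim": hp.choice("hidden_dim", [64, 128, 256]),
    "dropout": hp.uniform("dropout", 0.0, 0.5),
}

def objective(params):
    # Hyperopt minimizes, so return the negative of the average multi-task AUROC.
    return -train_and_eval_mt(**params)

trials = Trials()
best = fmin(objective, space, algo=tpe.suggest, max_evals=50, trials=trials)
```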

5 Results & Discussion

In this section, we report results and comment on the GRU system across our four guiding questions. Full results for all model types are present in the supplementary material, Section F.

5.1 Most Tasks Have Negative Transfer for MTL

By examining our intermediate, task-omitted pre-training results, we can assess the propensity for negative transfer directly. In Figure 3, we can see that the performance of our models when a single task is withheld is often greater than their performance in the full multi-task regime, indicated by the majority of the mass of the violin plots in the left panel being greater than zero. We can also examine these results as a function of which tasks are included in the ensemble from the right panel, seeing that there is no single task which, when added to the ensemble, offers a consistent improvement on other tasks. This corroborates empirical and theoretical evidence that only highly correlated tasks result in positive task transfer ndirango2019generalization ; Wu2020Understanding . In general, both views support that task performance under direct MT training is generally hurt when more tasks are included in the ensemble, and there are no clear outlier tasks that consistently improve performance when included, indicating that negative transfer is prevalent.

Figure 3: Performance (AUROC, scaled by 100) change on a task $t$ between the full MT system and the MT system with another task $t'$ omitted (recall Section 3.2), shown in 2 ways. On the left, we show the change in AUROC when $t'$ is removed (y-axis) to demonstrate how performance on each individual task $t$ (x-axis) changes as other tasks $t'$ (colored dots) are removed. On the right, we show the same values (though inverted on the y-axis), but transposed: now the omitted task $t'$ is on the x-axis, showing how including $t'$ affects the performance on all other tasks in the ensemble (now via colored dots).
Each dot represents the mean of the relevant difference taken over several random samples; the width of the violin plot reflects the distribution of all possible differences across all tasks.
We can see from the left plot that on almost all tasks, the AUROC with a task omitted is larger than the AUROC under the full MT system, indicating significant negative transfer. Similarly, on the right we can see that there are no "universally" helpful tasks to include and some tasks are consistently harmful.
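The quantity plotted in Figure 3 can be written directly as a difference of AUROCs; a sketch, assuming dictionaries of per-task AUROCs for the full MT model and for each task-omitted model:

```python
def negative_transfer_deltas(auroc_full_mt, auroc_omitted):
    """auroc_full_mt[t]: AUROC on task t under full MT training.
    auroc_omitted[t_prime][t]: AUROC on task t when t_prime is omitted from training.
    Returns deltas[t][t_prime]; positive values mean removing t_prime helps task t,
    i.e., evidence of negative transfer from t_prime to t."""
    return {
        t: {tp: auroc_omitted[tp][t] - auroc_full_mt[t]
            for tp in auroc_omitted if t in auroc_omitted[tp]}
        for t in auroc_full_mt
    }
```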

5.2 Pre-training, Followed by FTF or FTD Fine-tuning Improves Final Performance

The final performance of the GRU model across all 4 main MTL regimes (ST, MT, FTD, and FTF) in the full-data setting is reported in Table 2. There are several takeaways from this table. First, some variant of multi-task training reaches best-in-class performance on all tasks but ICD code prediction, and some variant of fine-tuning performs best on all remaining tasks except final acuity prediction (ACU). Additionally, note that FTF alone outperforms or matches ST performance on all tasks but ICD code and imminent discharge (DIS) prediction. In the full-data setting, only DNR and WBM show statistically significant differences between ST and FTF, and REA and WBM show statistically significant differences between ST and FTD, in all cases with fine-tuning performing better (significance assessed via a t-test). Overall, this indicates that fine-tuning reliably matches original ST performance across this wide range of tasks, and may extend beyond it. Note that this is true despite the fact that MT performance underperforms ST performance on all tasks save 30-day readmission (REA), ACU, and the will-be-measured task (WBM). This suggests that fine-tuning synergistically builds on the strengths of both MT and ST training. Second, FTF outperforms or matches FTD results on all tasks save DIS, REA, and WBM. This suggests that the regularization benefit of freezing the encoder layers can be helpful, but as a general rule one should fine-tune the full architecture rather than just the decoder. Third, the fact that MT consistently underperforms ST training underscores our findings from Section 5.1, suggesting that negative transfer is common here.

Task | Full-data (ST, MT, FTD, FTF) | Few-shot 1% (ST, FTD, FTF)
MOR
CMO
DNR
DIS
ICD
LOS
REA
ACU
WBM
FTS
Table 2: GRU results (AUROC, scaled by 100) subdivided among different MTL regimes, under both the full-data fine-tuning setting and the few-shot (1%) setting. Bolded results indicate the top performing result for each task/evaluation setting. We can see that on all tasks save ICD code prediction, some variant of multi-tasking, most commonly FTF or FTD, matches or improves over ST training by some margin in the full-data setting, and on all tasks save DIS & FTS, some variant of fine-tuning outperforms ST training in the few-shot setting, sometimes by impressive margins.

5.3 FTF Greatly Improves Few-Shot Learning Performance

Final performance of each of our pre-training regimes and our ST baseline at various levels of reduced fine-tuning/training data (simulating increasingly limited few-shot contexts) is shown in Figure 4, with the 1% dataset size highlighted explicitly in Table 2. On all tasks except FTS, DIS, and REA, FTF offers improvements over both FTD and ST training even at extreme levels of dataset reduction, by margins of up to approximately 25% over ST training. The comparisons at the 1% data level between FTF and ST for MOR, CMO, DNR, LOS, and ACU are all statistically significant (computed via a t-test); notably, this means both instances where ST outperforms FTF (DIS and FTS) are not statistically significant deviations. These gains are particularly dramatic for the MOR, CMO, DNR, and LOS tasks, which retain strong improvements over ST training even at as low as 0.1% dataset size, when only approximately 20 patients would be used for fine-tuning. For MOR, performance under FTF drops only by approximately 10% between full data and 0.1% data, whereas for ST training it drops by nearly 30%. Note that we consistently see the largest gains on our rolling, binary classification tasks (MOR, CMO, and DNR), all of which show significant class imbalance (Table 1). This may suggest that this strategy is particularly suited to rolling tasks, imbalanced tasks, or binary classification tasks.
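The per-task significance comparisons above could be computed as in this sketch, given the AUROCs obtained from the independently re-initialized training runs of two regimes (the function is our illustration; the paper's thresholds are not reproduced here):

```python
import numpy as np
from scipy.stats import ttest_ind

def compare_regimes(aurocs_regime_a, aurocs_regime_b):
    """Two-sample t-test over per-seed AUROCs (e.g. FTF vs. ST on one task)."""
    stat, p_value = ttest_ind(np.asarray(aurocs_regime_a), np.asarray(aurocs_regime_b))
    return stat, p_value
```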

While these results do align with prior results in the natural imaging domain that multi-task pre-training can lead to significant advances in the few-shot domain tian2020rethinking , surprisingly FTD consistently underperforms FTF training. We expected that, given the reduced training set size, the decreased capacity enforced by a frozen encoder would have given FTD training an edge over FTF training, but this does not appear to be the case; instead, allowing the full model to tune is essential. In the context of clinical data, where diverse rare diseases are prevalent wakap2020estimating ; 10.1001/jamanetworkopen.2020.1965 and, even in the context of well understood diseases, individual treatment trajectories can be highly specialized hripcsak2016characterizing , the ability to fine-tune successfully on so little data is very valuable.

Figure 4: AUROC, both compared to an ST baseline (top) and raw (bottom), of multiple tasks (line color) and training styles (line style) as a function of dataset subsampling rate (x-axis). Higher is better. Tasks in the upper plots are grouped according to rolling tasks (left), static tasks (middle), and autoregressive tasks (right), with a particular example highlighted in raw units across all three training styles above. We can see that on our rolling tasks, FTF models tend to perform much better in general, even at drastically smaller dataset subsampling rates, than ST models, and FTF training for mortality retains strong performance even with only 0.1% of the data.

5.4 MTL Pre-training Does Not Address Minority Class Imbalance

We do not find that MTL pre-training is able to reduce performance bias in favor of the majority group in our experiments. While this does not eliminate the possibility that some form of MTL can aid in reducing bias, as others have reported in other domains, our work offers no evidence to suggest that multi-task pre-training on a balanced dataset, followed by task-specific fine-tuning on an imbalanced dataset, reduces model bias relative to single-task training on the imbalanced dataset directly. We should note, however, that most discrepancies observed in our experiments were small overall, even when all female patients were removed entirely from the fine-tuning dataset, so it may be that this approach would still offer some gains under datasets/models with greater degrees of bias. We present full commentary on these results in Supplementary Material Section D.

6 Conclusion

In this work, we defined a battery of tasks over EHR physiological time-series data, including two novel tasks, and used these tasks to profile a range of MTL strategies on clinical data. We find that while traditional MTL results in systemic negative transfer, using an MT, task-independent pre-training scheme, followed by task-specific fine-tuning, yields modest improvements under standard fine-tuning and can yield dramatic improvements in the few-shot context (up to a gain in AUROC of approximately 25% with 1% training data). This approach allowed our model to retain nearly maximal performance on prediction of imminent mortality with as little as 1% of the training data. We do not, however, find consistent evidence to suggest that FTF is more or less susceptible to the pitfalls of training on population-imbalanced data. This paper suggests that while MTL pre-training may not ameliorate model bias on imbalanced datasets, it nonetheless offers a scalable vehicle for improving performance in important clinical settings, including rare disease detection.

Broader Impact

This paper explores the utility of multi-task learning for solving foundational challenges in machine learning for health and biomedicine. Of particular interest in our analyses is the ability of multi-task learning to aid generalization to smaller datasets (simulating the application of a model in a rare disease setting) and to underrepresented subgroups (assessing the utility of multi-task learning for addressing fairness concerns). While we find that this approach does not offer improvements to fairness, and thus cannot be used to help solve that problem, we do find significant improvements in the few-shot setting, which could aid modelling of rare or emerging diseases, or settings where large-scale labelling is expensive or invasive. Further, the knowledge that this technique does not, at present, appear to help ameliorate bias helps the field understand the complex challenges it faces in ensuring our models are fair and help all patients.

References

  • [1] Ahmed M Alaa and Mihaela van der Schaar. Bayesian inference of individualized treatment effects using multi-task gaussian processes. In Advances in Neural Information Processing Systems, pages 3424–3432, 2017.
  • [2] Andreas Argyriou, Theodoros Evgeniou, and Massimiliano Pontil. Multi-task feature learning. In B. Schölkopf, J. C. Platt, and T. Hoffman, editors, Advances in Neural Information Processing Systems 19, pages 41–48. MIT Press, 2007.
  • [3] James Bergstra, Daniel Yamins, and David Cox. Making a Science of Model Search: Hyperparameter Optimization in Hundreds of Dimensions for Vision Architectures. In International Conference on Machine Learning, pages 115–123, February 2013.
  • [4] Dimitris Bertsimas, Jean Pauphilet, Jennifer Stevens, and Manu Tandon. Predicting inpatient flow at a major hospital using interpretable analytics. medRxiv, 2020.
  • [5] Rich Caruana. Multitask learning: A knowledge-based source of inductive bias. In Proceedings of the Tenth International Conference on International Conference on Machine Learning, ICML’93, page 41–48, San Francisco, CA, USA, 1993. Morgan Kaufmann Publishers Inc.
  • [6] Rich Caruana. Multitask learning. Mach. Learn., 28(1):41–75, July 1997.
  • [7] Chun-Hao Chang, Mingjie Mai, and Anna Goldenberg. Dynamic measurement scheduling for event forecasting using deep RL. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 951–960, Long Beach, California, USA, 09–15 Jun 2019. PMLR.
  • [8] Zhengping Che, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu. Recurrent neural networks for multivariate time series with missing values. Scientific Reports, 8(1):6085, Apr 2018.
  • [9] Zhengping Che, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu. Recurrent neural networks for multivariate time series with missing values. Scientific reports, 8(1):1–12, 2018.
  • [10] Zhengping Che, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu. Recurrent Neural Networks for Multivariate Time Series with Missing Values. Scientific Reports, 8(1):6085, April 2018.
  • [11] Kyunghyun Cho, Bart van Merriënboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder–decoder for statistical machine translation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 1724–1734, Doha, Qatar, October 2014. Association for Computational Linguistics.
  • [12] Limeng Cui, Siddharth Biswal, Lucas M. Glass, Greg Lever, Jimeng Sun, and Cao Xiao. CONAN: Complementary Pattern Augmentation for Rare Disease Detection. arXiv e-prints, page arXiv:1911.13232, November 2019.
  • [13] Abhijit Das, Antitza Dantcheva, and Francois Bremond. Mitigating bias in gender, age and ethnicity classification: a multi-task convolution neural network approach. In Proceedings of the European Conference on Computer Vision (ECCV), pages 0–0, 2018.
  • [14] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv:1810.04805 [cs], October 2018. arXiv: 1810.04805.
  • [15] Daisy Yi Ding, Chloé Simpson, Stephen Pfohl, Dave C Kale, Kenneth Jung, and Nigam H Shah. The effectiveness of multitask learning for phenotyping with electronic health records data. In PSB, pages 18–29. World Scientific, 2019.
  • [16] Carl Doersch and Andrew Zisserman. Multi-task self-supervised visual learning. In The IEEE International Conference on Computer Vision (ICCV), Oct 2017.
  • [17] Theodoros Evgeniou and Massimiliano Pontil. Regularized multi–task learning. In Proceedings of the Tenth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD ’04, page 109–117, New York, NY, USA, 2004. Association for Computing Machinery.
  • [18] Joseph Futoma, Sanjay Hariharan, Katherine Heller, Mark Sendak, Nathan Brajer, Meredith Clement, Armando Bedoya, and Cara O’Brien. An improved multi-output gaussian process rnn with real-time validation for early sepsis detection. In Finale Doshi-Velez, Jim Fackler, David Kale, Rajesh Ranganath, Byron Wallace, and Jenna Wiens, editors, Proceedings of the 2nd Machine Learning for Healthcare Conference, volume 68 of Proceedings of Machine Learning Research, pages 243–254, Boston, Massachusetts, 18–19 Aug 2017. PMLR.
  • [19] Marzyeh Ghassemi, Tristan Naumann, Peter Schulam, Andrew L Beam, Irene Y Chen, and Rajesh Ranganath. Practical guidance on artificial intelligence for health-care data. The Lancet Digital Health, 1(4):e157–e159, 2019.
  • [20] Milena A. Gianfrancesco, Suzanne Tamang, Jinoos Yazdany, and Gabriela Schmajuk. Potential Biases in Machine Learning Algorithms Using Electronic Health Record Data. JAMA Internal Medicine, 178(11):1544–1547, 11 2018.
  • [21] Hrayr Harutyunyan, Hrant Khachatrian, David C Kale, Greg Ver Steeg, and Aram Galstyan. Multitask learning and benchmarking with clinical time series data. Scientific data, 6(1):1–18, 2019.
  • [22] Di He, Boon Pang Lim, Xuesong Yang, Mark Hasegawa-Johnson, and Deming Chen. Improved ASR for under-resourced languages through multi-task learning with acoustic landmarks. In B. Yegnanarayana, editor, Interspeech 2018, 19th Annual Conference of the International Speech Communication Association, Hyderabad, India, 2-6 September 2018, pages 2618–2622. ISCA, 2018.
  • [23] G. Hripcsak, P.B. Ryan, J.D. Duke, N.H. Shah, R.W. Park, V. Huser, M.A. Suchard, M.J. Schuemie, F.J. DeFalco, A. Perotte, et al. Characterizing treatment pathways at scale using the ohdsi network. Proceedings of the National Academy of Sciences, 113(27):7329–7336, 2016.
  • [24] George Hripcsak, David J Albers, and Adler Perotte. Parameterizing time in electronic health record studies. Journal of the American Medical Informatics Association, 22(4):794–804, 02 2015.
  • [25] Xin J Hunt, Saba Emrani, Ilknur Kaynar Kabul, and Jorge Silva. Multi-task learning with incomplete data for healthcare. arXiv preprint arXiv:1807.02442, 2018.
  • [26] Alistair E. W. Johnson, Tom J. Pollard, Lu Shen, Li-wei H. Lehman, Mengling Feng, Mohammad Ghassemi, Benjamin Moody, Peter Szolovits, Leo Anthony Celi, and Roger G. Mark. MIMIC-III, a freely accessible critical care database. Scientific Data, 3(1):1–9, May 2016.
  • [27] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In Yoshua Bengio and Yann LeCun, editors, 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings, 2015.
  • [28] Alex M Lamb, Anirudh Goyal Alias Parth Goyal, Ying Zhang, Saizheng Zhang, Aaron C Courville, and Yoshua Bengio. Professor forcing: A new algorithm for training recurrent networks. In Advances In Neural Information Processing Systems, pages 4601–4609, 2016.
  • [29] Yikuan Li, Shishir Rao, José Roberto Ayala Solares, Abdelaali Hassaine, Rema Ramakrishnan, Dexter Canoy, Yajie Zhu, Kazem Rahimi, and Gholamreza Salimi-Khorshidi. Behrt: transformer for electronic health records. Scientific Reports, 10(1):1–12, 2020.
  • [30] Luchen Liu, Zequn Liu, Haoxian Wu, Zichang Wang, Jianhao Shen, Yiping Song, and Ming Zhang. Multi-task learning via adaptation to similar tasks for mortality prediction of diverse rare diseases. arXiv preprint arXiv:2004.05318, 2020.
  • [31] Xiaodong Liu, Pengcheng He, Weizhu Chen, and Jianfeng Gao. Multi-task deep neural networks for natural language understanding. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 4487–4496, Florence, Italy, July 2019. Association for Computational Linguistics.
  • [32] Sharon L Lojun, Christina J Sauper, Mitchell Medow, William J Long, Roger G Mark, and Regina Barzilay. Investigating resuscitation code assignment in the intensive care unit using structured and unstructured data. In AMIA Annual Symposium Proceedings, volume 2010, page 467. American Medical Informatics Association, 2010.
  • [33] Minh-Thang Luong, Quoc V. Le, Ilya Sutskever, Oriol Vinyals, and Lukasz Kaiser. Multi-task sequence to sequence learning, 2015.
  • [34] Minh-Thang Luong, Quoc V. Le, Ilya Sutskever, Oriol Vinyals, and Lukasz Kaiser. Multi-task sequence to sequence learning. In Yoshua Bengio and Yann LeCun, editors, 4th International Conference on Learning Representations, ICLR 2016, San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings, 2016.
  • [35] Aya A. Mitani and Sebastien Haneuse. Small Data Challenges of Studying Rare Diseases. JAMA Network Open, 3(3):e201965–e201965, 03 2020.
  • [36] Anthony Ndirango and Tyler Lee. Generalization in multitask deep neural classifiers: a statistical physics approach. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems 32, pages 15862–15871. Curran Associates, Inc., 2019.
  • [37] Bret Nestor, Matthew B. A. McDermott, Willie Boag, Gabriela Berner, Tristan Naumann, Michael C. Hughes, Anna Goldenberg, and Marzyeh Ghassemi. Feature robustness in non-stationary health records: Caveats to deployable model performance in common clinical machine learning tasks. In Finale Doshi-Velez, Jim Fackler, Ken Jung, David Kale, Rajesh Ranganath, Byron Wallace, and Jenna Wiens, editors, Proceedings of the 4th Machine Learning for Healthcare Conference, volume 106 of Proceedings of Machine Learning Research, pages 381–405, Ann Arbor, Michigan, 09–10 Aug 2019. PMLR.
  • [38] Ziad Obermeyer, Brian Powers, Christine Vogeli, and Sendhil Mullainathan. Dissecting racial bias in an algorithm used to manage the health of populations. Science, 366(6464):447–453, 2019.
  • [39] Luca Oneto, Michele Doninini, Amon Elders, and Massimiliano Pontil. Taking advantage of multitask learning for fair classification. In Proceedings of the 2019 AAAI/ACM Conference on AI, Ethics, and Society, AIES ’19, page 227–237, New York, NY, USA, 2019. Association for Computing Machinery.
  • [40] Xuefeng Peng, Yi Ding, David Wihl, Omer Gottesman, Matthieu Komorowski, Li-wei H Lehman, Andrew Ross, Aldo Faisal, and Finale Doshi-Velez. Improving sepsis treatment strategies by combining deep and kernel-based reinforcement learning. In AMIA Annual Symposium Proceedings, volume 2018, page 887. American Medical Informatics Association, 2018.
  • [41] Viraj Prabhu, Anitha Kannan, Murali Ravuri, Manish Chaplain, David Sontag, and Xavier Amatriain. Few-shot learning for dermatological disease diagnosis. In Finale Doshi-Velez, Jim Fackler, Ken Jung, David Kale, Rajesh Ranganath, Byron Wallace, and Jenna Wiens, editors, Proceedings of the 4th Machine Learning for Healthcare Conference, volume 106 of Proceedings of Machine Learning Research, pages 532–552, Ann Arbor, Michigan, 09–10 Aug 2019. PMLR.
  • [42] Maithra Raghu, Chiyuan Zhang, Jon Kleinberg, and Samy Bengio. Transfusion: Understanding transfer learning for medical imaging. In Advances in Neural Information Processing Systems, pages 3342–3352, 2019.
  • [43] Alvin Rajkomar, Eyal Oren, Kai Chen, Andrew M Dai, Nissan Hajaj, Michaela Hardt, Peter J Liu, Xiaobing Liu, Jake Marcus, Mimi Sun, et al. Scalable and accurate deep learning with electronic health records. NPJ Digital Medicine, 1(1):18, 2018.
  • [44] R. Ranjan, V. M. Patel, and R. Chellappa. Hyperface: A deep multi-task learning framework for face detection, landmark localization, pose estimation, and gender recognition. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(1):121–135, 2019.
  • [45] Sebastian Ruder. An overview of multi-task learning in deep neural networks. arXiv preprint arXiv:1706.05098, 2017.
  • [46] Peter Schulam and Suchi Saria. A framework for individualizing predictions of disease trajectories by exploiting multi-resolution structure. In Advances in Neural Information Processing Systems, pages 748–756, 2015.
  • [47] Junyuan Shang, Tengfei Ma, Cao Xiao, and Jimeng Sun. Pre-training of Graph Augmented Transformers for Medication Recommendation. arXiv:1906.00346 [cs], June 2019. arXiv: 1906.00346.
  • [48] Yuqi Si and Kirk Roberts. Deep patient representation of clinical notes via multi-task learning for mortality prediction. AMIA Joint Summits on Translational Science proceedings. AMIA Joint Summits on Translational Science, 2019:779–788, May 2019. 31259035[pmid].
  • [49] Vergil N Slee. The international classification of diseases: ninth revision (icd-9). Annals of internal medicine, 88(3):424–426, 1978.
  • [50] Hossein Soleimani, James Hensman, and Suchi Saria. Scalable joint models for reliable uncertainty-aware event prediction. IEEE transactions on pattern analysis and machine intelligence, 40(8):1948–1963, 2017.
  • [51] Adarsh Subbaswamy and Suchi Saria. Counterfactual normalization: Proactively addressing dataset shift using causal mechanisms. In Ricardo Silva, Amir Globerson, and Amir Globerson, editors, 34th Conference on Uncertainty in Artificial Intelligence 2018, UAI 2018, volume 2, pages 947–957. Association For Uncertainty in Artificial Intelligence (AUAI), 1 2018. 34th Conference on Uncertainty in Artificial Intelligence 2018, UAI 2018 ; Conference date: 06-08-2018 Through 10-08-2018.
  • [52] Harini Suresh, Jen J. Gong, and John Guttag. Learning Tasks for Multitask Learning: Heterogenous Patient Populations in the ICU. Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining - KDD ’18, pages 802–810, 2018. arXiv: 1806.02878.
  • [53] Yonglong Tian, Yue Wang, Dilip Krishnan, Joshua B Tenenbaum, and Phillip Isola. Rethinking few-shot image classification: a good embedding is all you need? arXiv preprint arXiv:2003.11539, 2020.
  • [54] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems 30, pages 5998–6008. Curran Associates, Inc., 2017.
  • [55] Stéphanie Nguengang Wakap, Deborah M Lambert, Annie Olry, Charlotte Rodwell, Charlotte Gueydan, Valérie Lanneau, Daniel Murphy, Yann Le Cam, and Ana Rath. Estimating cumulative point prevalence of rare diseases: analysis of the orphanet database. European Journal of Human Genetics, 28(2):165–173, 2020.
  • [56] Shirly Wang, Matthew B. A. McDermott, Geeticka Chauhan, Marzyeh Ghassemi, Michael C. Hughes, and Tristan Naumann. Mimic-extract: A data extraction, preprocessing, and representation pipeline for mimic-iii. In Proceedings of the ACM Conference on Health, Inference, and Learning, CHIL ’20, page 222–235, New York, NY, USA, 2020. Association for Computing Machinery.
  • [57] Jenna Wiens, Suchi Saria, Mark Sendak, Marzyeh Ghassemi, Vincent X Liu, Finale Doshi-Velez, Kenneth Jung, Katherine Heller, David Kale, Mohammed Saeed, et al. Do no harm: a roadmap for responsible machine learning for health care. Nature medicine, 25(9):1337–1340, 2019.
  • [58] Mike Wu, Marzyeh Ghassemi, Mengling Feng, Leo A Celi, Peter Szolovits, and Finale Doshi-Velez. Understanding vasopressor intervention and weaning: risk prediction in a public heterogeneous clinical time series database. Journal of the American Medical Informatics Association, 24(3):488–495, 10 2016.
  • [59] Sen Wu, Hongyang R. Zhang, and Christopher Ré. Understanding and improving information transfer in multi-task learning. In International Conference on Learning Representations, 2020.
  • [60] Zhenqin Wu, Bharath Ramsundar, Evan N Feinberg, Joseph Gomes, Caleb Geniesse, Aneesh S Pappu, Karl Leswing, and Vijay Pande. Moleculenet: a benchmark for molecular machine learning. Chemical science, 9(2):513–530, 2018.
  • [61] Cao Xiao, Edward Choi, and Jimeng Sun. Opportunities and challenges in developing deep learning models using electronic health records data: a systematic review. Journal of the American Medical Informatics Association : JAMIA, 25(10):1419—1428, October 2018.
  • [62] Yu Zhang and Qiang Yang. A survey on multi-task learning. arXiv preprint arXiv:1707.08114, 2017.
  • [63] Jiayu Zhou, Jun Liu, Vaibhav A Narayan, Jieping Ye, Alzheimer’s Disease Neuroimaging Initiative, et al. Modeling disease progression via multi-task learning. NeuroImage, 78:233–248, 2013.

Appendix A Detailed Task Definitions

A.1 Task Selection Overview

We intentionally chose a diverse set of tasks designed to span the kinds of tasks we often assess in ML4H. These include clinically motivated tasks, such as those that measure acuity, the likelihood of future treatments, or which labs & vitals will be measured in the next time-point. Operationally motivated tasks are also included, such as prediction of imminent discharge, ICD billing codes, and readmission risk. We also include a local auto-regressive task to encourage the model to maintain a faithful, full-capacity representation of the input.

The goal of selecting such a diverse battery of tasks, even while common MTL literature suggests that tasks must be similar in order for MTL to offer benefits [45], is that in a pre-training/fine-tuning context, in which we ostensibly don’t know what the fine-tuning task will look like at the time of pre-training, we need the diverse ensemble to ensure generalizability to a wide range of tasks. Further work could examine under what arrangements of task similarity MT pre-training offers benefits on fine-tuning tasks. However, we intuit that pre-training on a diverse set of tasks, rather than a tightly correlated set (which may be very divergent from the fine-tuning task), would offer the best performance on an unknown downstream task in the context of a pre-training/fine-tuning regime.

Below, we detail all our tasks used, their precise source in the input data, related work on each task, and give baseline statistics about their label spaces and chance frequencies.

A.2 Our Tasks

Imminent Mortality: MOR

Description: We predict imminent mortality across both a 24h and 48h window, using 2h and 6h gap times, respectively. These predictions can be used as indicators of imminent physiological decompensation, and spanning multiple prediction windows gives the system incentive to learn a representation both reflecting immediate and urgent, but not necessarily immediate, signals of decompensation. This task is a binary classification task.
Data Source: Time of death was extracted from MIMIC Extract’s provided static output [56].
Prior Art: Imminent mortality has been used as a proxy for physiological decompensation historically in several studies. Harutyunyan et al. [21], for example, explore this task.
Statistics: Imminent mortality prediction is a highly imbalanced task, with approximately 98%, 96.2% of patients not dying within 24, 48 hours, respectively.

Comfort Measures: CMO

Description: “Comfort Measures Only” (CMO) orders indicate that the (usually terminally ill) patient has requested to receive care only designed to provide comfort, not treatment, and otherwise the course of illness should be allowed to progress (typically to mortality). Predicting that a patient will soon add a CMO order provides another view towards a measure of imminent acuity. Like mortality, we predict CMO across both a 24h and 48h window, using a 2h/6h gap time, respectively. This task is a binary classification task.
Data Source: This signal was extracted from MIMIC directly, via the code_status table. A forked version of MIMIC Extract [56] with this and all other additions we made for our extraction will be made publicly available.
Prior Art: In the traditional ML4H community, CMO prediction is somewhat understudied. However, examples do exist, such as in the work of Lojun et al. [32], who use natural language processing over clinical notes and structured data to predict CMO codes and do not resuscitate (DNR) codes.
Statistics: CMO prediction is a highly imbalanced task, with roughly 99.2, 98.7% of patients not registering a CMO order within 24, 48 hours, respectively.

DNR Ordered: DNR

Description: “Do Not Resuscitate” (DNR) orders indicate that the patient has requested to not receive resuscitation care (e.g., cardiopulmonary resuscitation a.k.a. CPR) and that, should those interventions be necessary, the patient should instead be allowed to die. Predicting that a patient will soon request a DNR order provides another view towards a measure of imminent acuity. Like mortality, we predict DNR across both a 24h and 48h window, using 2h and 6h gap times, respectively. This task is a binary classification task.
Data Source: This signal was extracted from MIMIC directly, via the code_status table. A forked version of MIMIC Extract [56] with this and all other additions we made for our extraction will be made publicly available.
Prior Art: In the traditional ML4H community, DNR prediction is somewhat understudied. However, examples do exist, such as in the work of Lojun et al. [32], who use natural language processing over clinical notes and structured data to predict CMO codes and do not resuscitate (DNR) codes.
Statistics: DNR prediction is a highly imbalanced task, with roughly 98.8, 98.1% of patients not yielding a new DNR order within 24, 48 hours, respectively.

Imminent Discharge: DIS

Description: Like the prior tasks, we predict imminent discharge across both a 24h and 48h window, using a 2h/6h gap time. Unlike the prior tasks, the imminent discharge task is a multi-class classification task, forcing the model to predict where the patient will be discharged to, among several possible destinations outlined in Table 3. Whereas the former tasks provide a view into acuity by predicting an imminent event that indicates heightened acuity, predicting imminent discharge indicates that the patient has become less acutely ill. Additionally, prediction of imminent discharge has operational benefits, by enabling hospitals to estimate how many beds will be free in the ICU in the near-term future.
Data Source: We use the discharge time and location provided in the MIMIC Extract static output [56].
Prior Art: Imminent discharge has been primarily predicted in operational contexts, rather than for use as a signal of acuity; for example, Bertsimas et al. [4] predict imminent discharge to estimate patient flow and aid in scheduling.
Statistics: Each possible discharge location, along with the percent of patients that are discharged to that location within 24 hours/48 hours of any given timepoint, respectively, are shown in Table 3.

Discharge Location % @ 24h % @ 48h
No Discharge 73.0% 47.3%
Home Health Care 7.7% 15.1%
Home 7.3% 14.0%
Skilled Nursing Facility (SNF) 5.2% 10.3%
Rehab/Distinct Part Hosp 4.0% 7.9%
Long Term Care Hospital 1.1% 2.2%
Discharge-Transfer Cancer/Children Hospital 0.4% 0.9%
Short Term Hospital 0.3% 0.6%
Discharge-Transfer To Psych Hospital 0.3% 0.6%
Hospice-Home 0.3% 0.5%
Left Against Medical Advice 0.1% 0.2%
Hospice-Medical Facility 0.1% 0.2%
Home With Home Iv Provider 0.0% 0.1%
Integrated Care Facility (ICF) 0.0% 0.1%
Other Facility 0.0% 0.1%
Discharge-Transfer To Federal Hc 0.0% 0.0%
Snf-Medicaid Only Certif 0.0% 0.0%
Table 3: All discharge locations we predict, along with the percent of patient-hours across the entire dataset that are discharged to that location within 24, 48 hours, respectively.
ICD Code Prediction: ICD

Description: We predict the multi-label presence of each major ICD category [49] (full list reported in Table 4). ICD code prediction can be seen as a phenotyping task; however, as ICD codes are applied to summarize a patient’s entire stay for billing purposes, it is more accurately interpreted as an operational task which could aid a clinic’s billing department.
Data Source: This is extracted from the provided default ICD code dataframe output by MIMIC Extract [56].
Prior Art: Prediction of ICD codes is commonly used as a phenotyping task [21, 15].
Statistics: Percentage of patients with at least one code in each ICD category used for this task are shown in Table 4. Recall that results are reported via macro AUROC across all labels, so though there are a subset with extreme class imbalance, they only play a moderate role on the overall AUROC reported for this task.

Category % Patients
Circulatory 72.2%
Endocrine 63.5%
Respiratory 53.0%
Injury 50.0%
Digestive 48.9%
Ill Defined 48.7%
Genitourinary 48.0%
Blood 47.9%
Mental Health 43.2%
Infection 41.8%
Nervous 40.8%
Musculoskeletal 33.5%
Neoplasm 29.8%
Skin 22.7%
Congenital 8.3%
Pregnancy 1.2%
Unknown 0.02%
Perinatal 0.00%
Table 4: Percent of patients who have at least one ICD code within each of the below categories.
Long Length-of-Stay: LOS

Description: We translate the "remaining LOS" regression task into a binary classification task by classifying patients as either below or above the average length of stay, rounded to the nearest day, in our cohort, which was 3 days.
Data Source: We use the default LOS output in MIMIC Extract’s static outputs [56].
Prior Art: Long LOS has been predicted numerous times, both in a classification sense for 3-day LOS [56] and 7-day LOS [21].
Statistics: This is a balanced prediction task, with a positive rate of approximately 52.9%.

30 Day ICU Readmission: REA

Description: Rapid readmission is a serious operational concern to clinics, as they face financial penalties from certain insurance providers if a patient is discharged but rapidly requires readmission. Given the limitations of the MIMIC-III data, which only covers patients admitted to the ICU, we predict solely 30-day ICU readmission as a binary classification task, so that we can trust both our positive and negative labels. As MIMIC-Extract extracts a cohort consisting only of patients' first ICU stays [56], this task is also biased in that it is only evaluated on a patient's first ICU visit, and would not be applicable to a population of repeat patients.
Data Source: We constructed labels for this task ourselves, by checking whether MIMIC contains a second ICU admission record for the patient within 30 days. We used MIMIC's intime and outtime as our definitive record of admission start and end times. A minimal sketch of this label construction follows the statistics below.
Prior Art: Rajkomar et al. [43] examine overall hospital readmission in their work.
Statistics: This is a relatively imbalanced task, with approximately 95% of patients not being readmitted within 30 days.
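As a concrete illustration of this label construction, the following is a minimal pandas sketch. It assumes a hypothetical dataframe of ICU stays with subject_id, intime, and outtime columns; the column and function names are illustrative, not our pipeline code.

```python
import pandas as pd

def label_30d_icu_readmission(icustays: pd.DataFrame) -> pd.Series:
    """Return a 0/1 label per subject: 1 if a later ICU stay begins within
    30 days of the first stay's outtime, else 0 (hypothetical column names)."""
    stays = icustays.sort_values(["subject_id", "intime"]).copy()
    # Next ICU admission time for the same subject, if any.
    stays["next_intime"] = stays.groupby("subject_id")["intime"].shift(-1)
    # Keep only each subject's first ICU stay, matching the MIMIC-Extract cohort.
    first_stays = stays.drop_duplicates("subject_id", keep="first").set_index("subject_id")
    gap = first_stays["next_intime"] - first_stays["outtime"]
    # Patients with no later stay have a NaT gap and are labeled 0.
    return (gap <= pd.Timedelta(days=30)).astype(int)
```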

Final Acuity: ACU

Description: This task is an extension of the common in-hospital mortality prediction task. The labels cover a more granular target space, outlined in Table 5. This extension makes the task much more difficult, and renders our results not directly comparable to previously published numbers.
Data Source: We use the death time and discharge locations output by default in the MIMIC Extract static output for this task [56].
Prior Art: Various sub-forms of this task have been explored historically. In-ICU and in-hospital mortality, for example, have been explored in numerous ways [21, 56, 10]. Prediction of final discharge location is an extended version of the mortality task. Challenging the model to predict over both spaces jointly is novel, to the best of our knowledge.
Statistics: Label values and the percent of patients with each label are shown in Table 5.

Final Acuity Event % Patients
Discharge to Home Health Care 25.3%
Discharge to Home 24.0%
Discharge to SNF 17.2%
Discharge to Rehab/Distinct Part Hosp 13.2%
In ICU Mortality 7.4%
In Hospital Mortality 3.7%
Discharge to Long Term Care Hospital 3.6%
Discharge-Transfer Cancer/Children Hospital 1.5%
Discharge to Short Term Hospital 1.1%
Discharge-Transfer To Psych Hosp 1.0%
Discharge to Hospice-Home 0.9%
Left Against Medical Advice 0.4%
Discharge to Hospice-Medical Facility 0.3%
Discharge to Home With Home Iv Provider 0.2%
Discharge to ICF 0.1%
Discharge to Other Facility 0.1%
Discharge-Transfer To Federal Hc 0.0%
Discharge to SNF-Medicaid Only Certified 0.0%
Table 5: Possible labels for our “Final Acuity” task, with the % of Patients that have that label in our cohort.
Next Timepoint: WBM

Description: We predict two local autoregressive tasks designed to assess the model’s ability to forecast what will happen in the immediate next hour. First, we predict which labs & vitals will be measured in the next hour via multi-label binary classification; second, via continuous regression, we predict what values will be observed for those labs & vitals that are measured. For reporting purposes, we present only the classification task, as our regression performance was universally poor, save on a subset of commonly measured labs/vitals (in particular, blood pressures, oxygen saturation, and heart rate). However, the regression task is included in training ensembles because, in prior experiments, we observed that, surprisingly, removing it actually weakened the overall ensemble. A minimal sketch of this target construction follows Table 6.
Data Source: This is sourced directly from the time-varying hourly input features output by MIMIC Extract [56].
Prior Art: Predicting which labs will be measured at the next timepoint has been explored using reinforcement learning [7].
Statistics: The labs & vitals over which we predict, along with their observed measurement rates and average continuous values are shown in Table 6.

Lab / Vital Measurement Rate (%)
Heart Rate 91.6%
Respiratory Rate 90.2%
Diastolic Blood Pressure 88.8%
Systolic Blood Pressure 88.8%
Mean Blood Pressure 88.3%
Oxygen Saturation 87.6%
Temperature 29.8%
Glucose 23.2%
Central Venous Pressure 20.3%
Glascow Coma Scale Total 17.7%
Hematocrit 11.2%
Potassium 10.4%
Sodium 9.9%
Pulmonary Artery Pressure Systolic 9.4%
Chloride 9.4%
Ph 9.2%
Hemoglobin 9.0%
Creatinine 8.8%
Blood Urea Nitrogen 8.7%
Bicarbonate 8.6%
Magnesium 8.3%
Anion Gap 8.3%
Partial Pressure Of Carbon Dioxide 8.3%
Co2 (Etco2, Pco2, Etc.) 8.3%
Platelets 8.2%
Positive End-Expiratory Pressure Set 8.0%
White Blood Cell Count 7.9%
Calcium 7.1%
Fraction Inspired Oxygen Set 7.0%
Tidal Volume Observed 6.8%
Mean Corpuscular Hemoglobin Concentration 6.2%
Mean Corpuscular Volume 6.2%
Red Blood Cell Count 6.2%
Mean Corpuscular Hemoglobin 6.2%
Partial Thromboplastin Time 6.0%
Prothrombin Time Inr 5.7%
Prothrombin Time Pt 5.7%
Peak Inspiratory Pressure 5.6%
Phosphate 5.5%
Phosphorous 5.4%
Respiratory Rate Set 4.9%
Calcium Ionized 4.9%
Fraction Inspired Oxygen 4.7%
Tidal Volume Set 4.6%
Partial Pressure Of Oxygen 4.3%
Cardiac Index 3.6%
Co2 3.5%
Pulmonary Artery Pressure Mean 3.5%
Tidal Volume Spontaneous 3.5%
Plateau Pressure 3.4%
Systemic Vascular Resistance 3.4%
Potassium Serum 3.2%
Cardiac Output Thermodilution 3.0%
Lactate 2.7%
Weight 2.5%
Lactic Acid 2.4%
Table 6: The labs & vitals we predict over for our next timepoint task, along with the % of time they are measured. These are also the labs & vitals we use as input to our pipeline. Note that all inputs were centered and scaled to unit variance when measured, so our continuous regression task had outputs that were zero mean and unit variance.
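To make the two target types concrete, the following is a minimal sketch, under assumed array shapes, of how next-hour measurement-mask and value targets could be constructed from an hourly labs-and-vitals tensor; it is illustrative rather than our pipeline code.

```python
import numpy as np

def next_timepoint_targets(values: np.ndarray, measured: np.ndarray):
    """values: (T, D) hourly lab/vital values, already zero-mean/unit-variance;
    measured: (T, D) binary mask of which labs/vitals were actually measured.
    For hour t, the classification target is measured[t+1] and the regression
    target is values[t+1], with the regression loss masked to observed entries."""
    clf_targets = measured[1:]   # which labs/vitals appear in the next hour
    reg_targets = values[1:]     # their observed (standardized) values
    reg_mask = measured[1:]      # regression is only evaluated where measured
    return clf_targets, reg_targets, reg_mask

# Illustrative usage: a 24-hour window over the labs & vitals in Table 6.
T, D = 24, 56
rng = np.random.default_rng(0)
vals = rng.normal(size=(T, D))
mask = (rng.random((T, D)) < 0.1).astype(np.float32)
clf_t, reg_t, reg_m = next_timepoint_targets(vals, mask)
```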
Future Treatment Sequence: FTS

Description: This task is a global, autoregressive task designed to force the model to learn how to predict the high-level future of the patient’s care. Labels for this task are the sequence of treatment combinations the patient will receive over the remainder of their stay, in a duration-agnostic manner. In our context, individual treatments are aggregated first into the broad categories of ventilation, vasopressors, and fluid boluses, so each sequential label is an element of the powerset of these categories. This sequence of treatment sets is duration-agnostic—i.e., elements in the label sequence merely represent that a patient will receive this combination of treatments next in sequence, and do not comment on how long the patient will receive them. Accordingly, there are no sequential duplicates in this label sequence. The task-specific decoder is an LSTM recurrent neural network, and during both training and evaluation we use teacher forcing [28]—i.e., we pass in the true sequence of treatments when asking the system to predict the next element. The input from the encoder is used as the initial hidden state of this decoder model. This means that our evaluation results should not be interpreted as the model’s ability to correctly decode the full projection of future treatments, but rather as the model’s ability to understand how clinicians will transition between treatments given this patient’s unique record to date. Changing the formulation of this task to not use teacher forcing during evaluation would make the task much more difficult, and represents a valuable area of future work. This task is, to the best of our knowledge, novel. A minimal sketch of the label construction is included after Figure 5.
Data Source: We construct labels for this task ourselves based on the provided time-varying treatments produced by MIMIC Extract.
Prior Art: While this task directly has not been explored previously, various researchers have investigated learning optimal control policies for applications of treatments, including ventilators or vasopressors [58, 40, 24].
Statistics: We show the relative frequency of the various treatment combinations (recall that our labels here are subsets of treatments, expanded into a one-hot encoding over the entire powerset) in Figure 5.

Figure 5: An Upset plot showing the frequency of relative combinations of our three treatment types: Vasopressors (vaso), Ventilation (vent), and Fluid Bolus administration (bolus).
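As a concrete illustration of the duration-agnostic label construction described above, the following is a minimal sketch assuming an hourly binary treatment matrix with columns for ventilation, vasopressors, and fluid boluses; it collapses consecutive duplicates and maps each treatment set to an index over the powerset. Names and shapes are illustrative.

```python
from itertools import groupby
import numpy as np

TREATMENTS = ("vent", "vaso", "bolus")
# Enumerate the powerset of the three treatment categories (2**3 = 8 classes).
POWERSET = [
    tuple(t for j, t in enumerate(TREATMENTS) if (mask >> j) & 1)
    for mask in range(2 ** len(TREATMENTS))
]
POWERSET_INDEX = {combo: i for i, combo in enumerate(POWERSET)}

def future_treatment_sequence(hourly_treatments: np.ndarray, t: int) -> list:
    """hourly_treatments: (T, 3) binary matrix of (vent, vaso, bolus) per hour.
    Returns the duration-agnostic sequence of treatment-set labels from hour t
    onward: consecutive duplicate combinations are collapsed to one label."""
    future = hourly_treatments[t:].astype(bool)
    combos = [
        tuple(name for name, on in zip(TREATMENTS, row) if on)
        for row in future
    ]
    # groupby collapses runs of identical consecutive combinations.
    return [POWERSET_INDEX[c] for c, _ in groupby(combos)]
```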

Appendix B Hyperparameter Search Analysis

B.1 Hyperparameter Search Algorithm Details

We used the Bayesian hyperparameter tuning library Hyperopt [3]—specifically, Hyperopt’s Tree of Parzen Estimators (TPE) optimization method. All hyperparameter tuning was performed to optimize the average AUROC (or, for our regression task, an AUROC-analog score) across all tasks and labels under full MT training. We allowed Hyperopt both to tune the hyperparameters for all architectures and to select which architecture (GRU, linear baseline, Transformer) should be used. The system was permitted to devote more samples to higher-performing architectures. To ensure that all architectures were sampled to a reasonable degree, we also ran a separate, albeit smaller, set of Hyperopt iterations on each architecture independently, allowing the system to use the prior samples of that architecture from the joint run to inform its next hyperparameter selection. We also performed a mild amount of manual hyperparameter selection early in the process, largely to adjust sampling distributions for the Hyperopt search space before the final overall hyperparameter tuning run.
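For readers unfamiliar with Hyperopt, the following is a minimal sketch of a TPE-driven loop of the kind described above; the objective, the tiny search space, and the train_and_evaluate placeholder are illustrative, not our actual training code.

```python
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK

def train_and_evaluate(params):
    # Placeholder: in practice this would train the selected architecture with
    # `params` and return the average validation AUROC across tasks and labels.
    return 0.5 + 0.1 * float(params["hidden_dropout"])

def objective(params):
    # Hyperopt minimizes the returned loss, so we negate the AUROC-style score.
    return {"loss": -train_and_evaluate(params), "status": STATUS_OK}

space = {
    "arch": hp.choice("arch", ["gru", "linear", "transformer"]),
    "learning_rate": hp.lognormal("learning_rate", -7, 0.5),
    "hidden_dropout": hp.uniform("hidden_dropout", 0.0, 0.5),
}

trials = Trials()  # retains all past samples for TPE to condition on
best = fmin(fn=objective, space=space, algo=tpe.suggest,
            max_evals=50, trials=trials)
print(best)
```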

For final runs, we used the best performing (as evaluated on our 10% validation set) hyperparameters found across all iterations for each architecture independently. Final selections are discussed below.

B.2 Expanded Hyperparameter Search Biases Discussion

The training procedure described above was chosen for computational efficiency; by jointly optimizing over all tasks and architectures in a systematic, motivated fashion, we can ensure good coverage over all likely high-performing parameters for all tasks and architectures, while dramatically reducing our overall search time. However, this does induce two potential biases.

First, not all architectures received the same number of samples. In particular, architectures that both (1) took longer to run, and (2) seemed less promising in early experiments would receive fewer tuning samples. This may disproportionately penalize, for example, the transformer model, which both took the longest to run and had poor preliminary results.

Second, this system may favor multi-task systems over single-task systems, potentially in a task-specific manner. This is a particularly pressing concern given that MTL has been postulated to have a regularizing effect in past literature [45]; thus, we might worry that the hyperparameters chosen under multi-task training would inappropriately prefer less regularization than a truly optimal ST model would. We note two mitigating factors that leave us unconcerned that this bias plays a serious role. First, our ST models already outperform the full MT model, so our results do not appear concordant with this bias. Second, we additionally performed an (admittedly smaller) secondary round of hyperparameter tuning on a single task alone, analyzed the results of our full hyperparameter tuning system as though we were optimizing for that same task, and in both cases found the performance difference between the chosen optimal parameters to be negligible.

B.3 Search Space

For our hyperparameter search procedure, we searched over a wide variety of parameters, including the number of epochs, batch size, learning rate, learning rate decay paradigms, L2 regularization penalty, dropout, a weighting for the regression task losses’ contribution to training, the maximum length of a patient’s record included, the size, number, and configuration of various hidden layers, pooling and fully connected stack parameters, and various other model-specific options. All search distributions are shown in Table 7 (a sketch translating part of this space into Hyperopt primitives follows the table). Note that some of these search space parameters are specific to our implementation; for example, the “Encoder Hidden Size Multiplier” used in the transformer architecture encodes the relationship between the size of the overall internal transformer hidden state and the number of attention heads (the former must be divisible by the latter). Additionally, note that this search distribution reflects our final search distribution, used to enable the most granular refinement; optimal parameters for certain model types may have been chosen by earlier incarnations of the search on wider, more uncertain search distributions, or on search distributions with fewer options enabled.

Architecture Hyperparameter Search Space
Shared # Epochs Uniform[15, 30]
Batch Size Uniform[4, 64]
Learning Rate (LR) Lognormal[-7, 0.5]
LR Decay Loguniform[-2.3, 0]
LR Step Uniform[1, 25]
Hidden Dropout Uniform[0, 0.5]
Hidden Size Uniform[8, 256]
Weight Decay Uniform[0, 1]
Input Window Size (h) Uniform[12, 168]
Transformer Encoder Hidden Size Multiplier Uniform[4, 32]
Intermediate Size Uniform[32, 256]
# Attention Heads Uniform[2, 24]
# Hidden Layers Uniform[1, 4]
Use CLS Analog Choice[True, False]
GRU Bidirectional Choice[True, False]
# Hidden Layers Uniform[1, 3]
Encoder Hidden Layer Size Uniform[16, 512]
Encoder # Fully Connected Layers Uniform[0, 3]
GRU Pooling Method Choice[max, avg, last]
GRU FC Layer Base Size Uniform[32, 512]
GRU FC Layer Growth Loguniform[-1.1, 1.1]
Table 7: The Hyperopt search space we used in this work. Distributions are noted in pseudocode, but typically refer directly to the appropriate analog in Hyperopt (e.g., a uniform distribution over an integral parameter maps to the quantized uniform distribution that only outputs integers).

Shared hyperparameters were used across all 3 model types. The linear model required no non-shared hyperparameters.
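As an illustration of how the pseudocode distributions in Table 7 could map onto Hyperopt primitives, the following sketch encodes the shared block plus an architecture choice with nested, architecture-specific sub-spaces (integral uniform ranges become quantized quniform distributions, per the caption). This is an approximate translation under assumed parameter names, not our exact search-space definition.

```python
from hyperopt import hp

# Shared hyperparameters, used by all three model types (see Table 7).
shared = {
    "n_epochs": hp.quniform("n_epochs", 15, 30, 1),
    "batch_size": hp.quniform("batch_size", 4, 64, 1),
    "learning_rate": hp.lognormal("learning_rate", -7, 0.5),
    "lr_decay": hp.loguniform("lr_decay", -2.3, 0.0),
    "lr_step": hp.quniform("lr_step", 1, 25, 1),
    "hidden_dropout": hp.uniform("hidden_dropout", 0.0, 0.5),
    "hidden_size": hp.quniform("hidden_size", 8, 256, 1),
    "weight_decay": hp.uniform("weight_decay", 0.0, 1.0),
    "input_window_size_h": hp.quniform("input_window_size_h", 12, 168, 1),
}

# Architecture choice with nested sub-spaces; the linear baseline needs no
# non-shared hyperparameters.
architecture = hp.choice("architecture", [
    {"name": "linear"},
    {"name": "gru",
     "bidirectional": hp.choice("gru_bidirectional", [True, False]),
     "n_hidden_layers": hp.quniform("gru_n_hidden_layers", 1, 3, 1),
     "pooling": hp.choice("gru_pooling", ["max", "avg", "last"])},
    {"name": "transformer",
     "n_attention_heads": hp.quniform("n_attention_heads", 2, 24, 1),
     "n_hidden_layers": hp.quniform("tfm_n_hidden_layers", 1, 4, 1),
     "use_cls_analog": hp.choice("use_cls_analog", [True, False])},
])

search_space = {**shared, "architecture": architecture}
```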

B.4 Final Optimal Parameters

Projection Model

Our projection model ran for 22 epochs, using a projection dimension of 140, a batch size of 16, a learning rate of 0.00024 with no decay, and dropout of 0.22.

GRU

Our optimal GRU model was a unidirectional GRU with two 126-dimensional hidden layers, which ran for 18 epochs with an effective batch size of 254, a projection dimension of 233, a learning rate of 0.001 with no decay, and dropout of 0.42. It used last-element GRU pooling and had no fully connected layers post-processing the GRU outputs.

Transformer

Our optimal transformer ran for 24 epochs with a projection dimensionality of 72, 12 attention heads, 1 hidden layer, an intermediate size of 55, a specially added overall sequence sentinel token, a learning rate of 0.002, a batch size of 30, and dropout of 0.18.

Appendix C Full Architecture Details for Linear Baseline & Transformer Architectures

Linear Baseline Projection Architecture

For our linear baseline projection architecture, we simply project all inputs into a shared embedding space, concatenate the full input sequence together (as we use a fixed input window size, this is a fixed-size representation), then pass that representation through to the per-task decoders. Note that the only “multi-tasking” that happens in this architecture is that the projection layers mapping content to the input embedding space are shared.
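A minimal PyTorch sketch of this kind of baseline, under assumed dimensions and with only the shared projection multi-tasked, might look as follows; the window size, feature count, and task heads are illustrative placeholders rather than our exact configuration.

```python
import torch
import torch.nn as nn

class LinearProjectionBaseline(nn.Module):
    """Project each hourly input, concatenate the fixed-length window,
    then apply independent per-task decoder heads."""
    def __init__(self, in_dim, proj_dim, window, task_out_dims):
        super().__init__()
        self.proj = nn.Linear(in_dim, proj_dim)  # shared across all tasks
        self.heads = nn.ModuleDict({
            task: nn.Linear(proj_dim * window, out_dim)
            for task, out_dim in task_out_dims.items()
        })

    def forward(self, x):
        # x: (batch, window, in_dim) hourly features over a fixed window.
        z = self.proj(x).flatten(start_dim=1)  # (batch, window * proj_dim)
        return {task: head(z) for task, head in self.heads.items()}

# Illustrative usage: 24h window, 56 features, two hypothetical task heads.
model = LinearProjectionBaseline(in_dim=56, proj_dim=140, window=24,
                                 task_out_dims={"MOR": 1, "ACU": 18})
logits = model(torch.randn(8, 24, 56))
```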

Transformer

For our transformer architecture, we used a bidirectional transformer (i.e., a BERT-style architecture [14]) operating on the continuous (projected) embeddings of all input features. We did not employ any positional embeddings; however, we did add an auxiliary “CLS” token to the front of each sequence, which was used as the source of our pooled representation (much like BERT). This pooling strategy was preferred in our hyperparameter tuning.
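A minimal sketch of this CLS-style pooling, using PyTorch's built-in TransformerEncoder as a stand-in for the actual architecture (which we do not reproduce exactly here), is shown below; the dimensions loosely mirror the optimal transformer reported above but are otherwise illustrative.

```python
import torch
import torch.nn as nn

class CLSPooledTransformerEncoder(nn.Module):
    """Prepend a learned CLS embedding to the projected input sequence and
    use its output state as the pooled sequence representation."""
    def __init__(self, in_dim, d_model, n_heads, n_layers):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):
        # x: (batch, seq_len, in_dim); no positional embeddings are added,
        # mirroring the description above.
        z = self.proj(x)
        cls = self.cls.expand(z.size(0), -1, -1)
        h = self.encoder(torch.cat([cls, z], dim=1))
        return h[:, 0]  # pooled representation from the CLS position

pooled = CLSPooledTransformerEncoder(56, 72, 12, 1)(torch.randn(8, 24, 56))
```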

Appendix D GRU Population Imbalance Experiments

We examined our GRU model under population imbalance in the following manner. First, we pre-trained a multi-task model on the full, unaltered data. Next, we fine-tuned the model on a dataset that was randomly subsampled in a genotypic-sex-imbalanced manner, up to and including removing all patients who were genotypically female. We anticipated that for single-task models (which were not pre-trained, and instead tuned from scratch on these imbalanced datasets), the use of imbalanced data would engender significant biases in model performance favoring the majority class (genotypically male patients), and that the balanced multi-task pre-training would help ameliorate these biases. However, using the AUROC discrepancy between genotypically male and female patients as our guide, we did not see either effect. In Table 8, we show the Male - Female AUROC discrepancy of our GRU model under the ST, FTD, and FTF training regimes in the case where all genotypically female patients were removed during fine-tuning (a short sketch of this discrepancy metric follows Table 8). We see that in roughly half the cases this discrepancy is positive and in roughly half it is negative (indicating the model favors genotypically male and genotypically female patients about equally), and that the discrepancies are largely small. Unfortunately, due to time constraints, we were only able to run one sample of each of these runs, so it may be that a stronger effect would emerge with more samples and greater statistical power; however, these preliminary findings do not suggest that this is the case.

GRU
ST FTD FTF
Task
MOR
CMO
DNR
DIS
ICD
LOS
REA
ACU
WBM
FTS
Table 8: Male - Female AUROC discrepancy for models fine-tuned on data biased to contain no genotypically female patients in the training set. Bold indicates the lowest discrepancy (i.e., most favorable to genotypically female patients).
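The discrepancy metric used in Table 8 can be made concrete with a short sketch, assuming per-patient labels, predicted scores, and a binary genotypic-sex indicator; the function and variable names are illustrative.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def male_female_auroc_gap(y_true, y_score, is_male):
    """Male - Female AUROC discrepancy: positive values favor genotypically
    male patients, negative values favor genotypically female patients.
    Assumes both outcome classes are present within each subgroup."""
    y_true, y_score, is_male = map(np.asarray, (y_true, y_score, is_male))
    auroc_m = roc_auc_score(y_true[is_male], y_score[is_male])
    auroc_f = roc_auc_score(y_true[~is_male], y_score[~is_male])
    return auroc_m - auroc_f
```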

Appendix E Few-Shot Experiments under all architectures

We replicated our few-shot experiments under all 3 architectures, and present in Figures 6, 7, and 8 plots showing the performance of the ST, FTF, and FTD model types on subsampled datasets ranging from 1% to 100% for each task individually. While the dramatic gains of the GRU model are not replicated on the other model types, the transformer shows consistent gains from the fine-tuning regimes on MOR, CMO, LOS, and WBM, much like the GRU. Additionally, and somewhat surprisingly, while the GRU outperformed the transformer architecture on the full dataset as a general rule, in the very small data regime the transformer offers notable gains on some tasks (e.g., MOR). The linear projection model shows largely concordant behavior across all model types, with perhaps a slight gain towards the fine-tuning results on a subset of tasks, including MOR and WBM.

Figure 6: Few-shot experiments for GRU (duplicated from main body)
Figure 7: Few-shot experiments for the linear projection model
Figure 8: Few-shot experiments for the transformer model

Appendix F Full Results

Figure 9 shows full results for all models in the full data regime, including local negative transfer analyses (analogous to Section 5.1 in the main body) in the two left-most columns, and global results in the right-most column. We can draw several main takeaways from these results. First, while the transformer and GRU results both display significant common negative transfer, as evidenced by the majority of the violin-plot point mass lying above 0 in the left-most column, the linear projection model actually displays different behavior, with a significant extent of positive transfer as well. This is also reflected in the global results: for the linear system, unlike the GRU and transformer, the full multi-task system is commonly one of the top-performing systems, and it outperforms the ST model on all tasks except ICD code prediction.

Second, we see that MTL preferences are highly architecture dependent, with common discrepancies across the different architectures in which model types are preferred. This suggests researchers should be cognizant of strong relationships between MTL efficacy and architecture choice in future modelling. Lastly, we see that there is significant, task-dependent variability in the transformer results, much more so than for the other model types. For example, including the WBM task is extremely harmful for the DIS and LOS tasks, inducing a roughly 4% drop in AUROC for each, as well as inducing smaller costs to other tasks. This relationship is nowhere near as pronounced for the GRU model, and is almost reversed for the linear model. This may be related to the fact that, while the GRU model is our best performing model on average, the linear model outperforms both other model types on the WBM task by significant margins.

Number of independent repeats

To assess variance in our results, we ran a number of independent repeats of each run under different random seeds. Note that we did not alter the train/val/test split for these runs, as differing splits would require re-doing hyperparameter tuning, which was computationally prohibitive. The total number of samples for all runs is shown in Tables 9, 10, and 11. Additional samples, including both further repeats and additional train/val/test splits with accompanying hyperparameter tuning runs, would be run prior to camera-ready submission.

Figure 9: Left: The performance delta (positive meaning better than fully multi-task performance) on each of our tasks (x-axis) when the various other tasks are held out (point color; horizontal offset of colored points within each task column is only so the task-specific error bars can be simultaneously visible, and is consistent, but arbitrary), all measured in AUROC %. This shows that task performance commonly improves when the other tasks are held out. Middle: The average improvement across all other tasks (point color; horizontal offset within each task is merely for display purposes) when each of the possible held-out task groups (x-axis) is included in the ensemble (positive meaning inclusion of this task helps). We see that omitting a single task often improves the performance on other tasks. Right: Fully multi-task performance vs. single-task performance vs. fine-tuned performance. Fine-tuned or multi-task representations are quite consistently preferred. Top: The GRU model architecture. Middle: The projection model architecture. Bottom: The transformer model architecture.
MT ST FTD FTF ST Few-Shot % FTD Few-Shot % FTF Few-Shot %
Full-data 0.0 0.1 0.2 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2 0.0 0.1 0.2 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2 0.0 0.1 0.2 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2
MOR 10 3 3 5 5 5 5 5 4 2 2 2 2 2 2 2 2 2 3 3 3 2 3 3 2 2 2 2 1 2 4 5 5 5 4 5 5 5 5 5 4 3 1
CMO 10 3 3 5 5 5 5 5 4 3 3 3 3 3 3 3 3 2 3 3 3 2 3 3 3 2 2 2 1 2 4 5 5 5 4 5 5 5 5 5 3 4 1
DNR 10 3 3 5 5 5 5 5 4 3 3 3 3 3 2 2 2 2 3 3 3 2 3 3 3 3 3 2 1 2 4 5 5 5 4 5 5 5 5 5 3 3 1
ICD 10 3 3 5 5 5 5 5 4 2 2 2 2 2 2 2 2 2 3 3 3 2 3 3 2 2 2 2 1 2 4 5 5 5 4 5 5 5 5 5 3 1 1
LOS 10 3 3 5 5 5 5 5 4 2 2 2 2 2 2 2 2 2 3 3 3 2 3 3 2 2 2 2 1 2 4 5 5 5 4 5 5 5 5 5 4 3 1
REA 10 3 3 5 5 5 5 5 4 2 2 2 2 2 2 2 2 2 3 3 3 2 3 3 2 2 2 2 1 2 4 5 5 5 4 5 5 5 5 5 4 3 1
DIS 10 3 3 5 5 5 5 5 4 3 3 3 3 3 3 2 2 2 3 3 3 2 3 3 3 2 2 3 1 2 4 5 5 5 4 5 5 5 5 5 2 1 1
ACU 10 3 3 5 5 5 5 5 4 3 3 3 3 3 3 2 2 2 3 3 3 2 3 3 3 2 2 2 1 2 4 5 5 5 4 5 5 5 5 5 4 3 1
WBM 10 3 3 5 5 5 5 5 4 2 2 2 2 2 2 2 2 2 3 3 3 2 3 3 3 3 3 2 1 2 4 5 5 5 4 5 5 5 5 5 3 3 1
FTS 10 3 3 5 5 5 5 5 4 3 3 3 2 2 2 2 2 2 3 3 3 2 3 3 3 2 2 2 1 2 4 5 5 5 4 5 5 5 5 5 4 3 1
Table 9: Number of samples run for all modes of the GRU architecture. Additional samples will be run before any camera ready publication.
MT ST FTD FTF ST Few-Shot % FTD Few-Shot % FTF Few-Shot %
Full-data 0.2 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2 0.2 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2 0.2 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2
MOR 9 1 2 5 4 4 3 5 5 5 5 2 3 3 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 5 5
CMO 9 2 2 5 3 3 2 5 5 5 5 4 4 2 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 4 4
DNR 9 2 2 5 4 3 2 5 5 5 5 4 3 2 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 4 4
ICD 9 1 2 5 4 4 3 5 5 5 4 2 3 3 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 5 5
LOS 9 1 2 5 4 4 3 4 4 4 4 3 4 3 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 5 4
REA 9 1 2 5 4 4 2 5 4 4 4 3 3 2 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 5 4
DIS 9 2 2 5 4 4 3 4 4 4 4 3 4 3 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 5 5
ACU 9 2 2 5 4 3 2 3 4 4 4 3 4 2 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 4 4
WBM 9 1 2 5 4 3 2 5 5 4 4 3 4 3 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 5 4
FTS 9 2 2 5 4 4 2 5 5 5 5 3 4 2 1 2 2 2 2 2 2 2 2 2 2 2 5 5 4 5 5 5 5 4 5 5 4
Table 10: Number of samples run for all modes of the linear architecture. Additional samples will be run before any camera ready publication.
MT ST FTD FTF ST Few-Shot % FTD Few-Shot % FTF Few-Shot %
Full-data 0.1 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2 0.1 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2 0.1 0.3 0.6 1.0 1.8 3.2 5.6 10.0 17.8 31.6 56.2
MOR 7 2 1 3 4 3 3 3 3 3 3 3 3 2 3 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 2 2 1 1 1 1 1
CMO 7 2 1 2 3 4 4 4 4 4 4 3 3 3 2 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 1 1 1 1
DNR 7 2 1 3 4 3 3 3 3 3 3 3 2 3 2 1 1 1 1 1 1 1 1 1 1 1 3 3 3 2 2 1 1 1 1 1 1
ICD 7 2 1 3 4 3 3 3 3 3 3 3 4 4 2 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 2 1 1 1 1 1 1
LOS 7 2 1 3 3 4 3 3 3 3 3 3 2 2 2 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 3 2 2 1 1 1 1
REA 7 2 1 3 3 5 4 4 4 4 4 4 4 2 2 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 3 2 2 1 1 1 1
DIS 7 2 1 3 3 4 4 3 3 3 3 3 4 3 2 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 2 2 2 1 1 1 1
ACU 7 2 1 3 4 4 4 4 4 4 4 4 5 2 2 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 3 2 1 1 1 1 1
WBM 7 2 1 2 4 3 4 3 3 3 3 3 5 3 3 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 1 1 1 1 1 1
FTS 7 1 1 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1 1 1 1 1 1 1 3 3 3 3 2 1 1 1 1 1 1
Table 11: Number of samples run for all modes of the self-attention architecture. Additional samples will be run before any camera ready publication.