Using a Cross-Task Grid of Linear Probes to Interpret CNN Model Predictions On Retinal Images

07/23/2021 · Katy Blumer, et al.

We analyze a dataset of retinal images using linear probes: linear regression models trained on a "target" task, using as input embeddings from a deep convolutional neural network (CNN) trained on a "source" task. We use this method across all possible pairings of 93 tasks in the UK Biobank dataset of retinal images, leading to ~164k different models. We analyze the performance of these linear probes by source and target task and by layer depth. We observe that representations from the middle layers of the network are more generalizable. We find that some target tasks are easily predicted irrespective of the source task, and that some other target tasks are more accurately predicted from correlated source tasks than from embeddings trained on the same task.


1 Introduction

Retinal fundus (internal eye) images have been well-studied in machine learning applications. Machine learning can predict retinal disease with great accuracy (Gulshan et al., 2016; Badar et al., 2020). However, many other, often surprising, features can also be predicted from these images: for example, visual acuity (Varadarajan et al., 2018), cardiovascular risk (Poplin et al., 2018), diabetes (Zhang et al., 2021), anaemia (Mitani et al., 2020) and many other variables (Rim et al., 2020). Many of these are novel predictions not known to be achievable from these images by human experts, and it would be useful to understand precisely which features in the fundus image make these quantities predictable.

The challenge in even framing this question is the lack of tractable formalisms for characterizing how predictions are made. One simple idea is to take a model that achieves the surprising outcome of predicting certain variables from retinal images and ask what else this model is able to predict effectively. To do this, we need a way to evaluate how effectively we can build simple extensions of the model's internal representation to predict other quantities of interest. Linear probes (Alain and Bengio, 2016) and concept activation vectors (TCAV) (Kim et al., 2018) are techniques that provide paths to doing this; in this work we focus on linear probes.

Linear probes have been widely used for interpretability, to understand the performance of deep models, with applications to language processing (Hewitt and Liang, 2019; Hewitt and Manning, 2019; Belinkov, 2021), computer vision (Alain and Bengio, 2016; Asano et al., 2019), speech (Oord et al., 2018), and, more generally, to understanding different neural network architectures (Raghu et al., 2017; Graziani et al., 2019; Horoi et al., 2020). In this paper, we use this technique to study and understand CNN model predictions on 93 different tasks based on the UK Biobank dataset (Sudlow et al., 2015) of retinal images and labels.

We find that embeddings from the middle layers of the networks (as opposed to those closest to the output) learn features that are more generalizable across multiple target tasks, with linear probes consistently making accurate predictions. Additionally, we find that some target tasks such as eye position (left vs. right eye) and refractive error are easily predicted irrespective of the source task the CNN was trained on. We also find that other target tasks such as height are better predicted by embeddings trained on correlated tasks such as blood testosterone or self-reported sex, than by embeddings trained on the original task. Ultimately these results give insight into which features of the input data make it possible to learn different target values.

2 Methods

Figure 1: Plots of linear model performance by layer, organized by target variable (row) and source variable (column). X-axis scale is consistent along each row. Orange and red dots are from linear models trained on the single-valued source prediction and source variable value, respectively.

Data

We used a dataset of retinal images from the UK Biobank study (Sudlow et al., 2015) containing 140,000 retinal fundus images from 68,000 patients. We separated 12.5% of the patients into a test set, with the rest in the train set. We used a set of 93 non-eye-disease-related variables available in the UK Biobank resource. These included demographic data such as age and self-reported sex (recorded in the data as female/male); measurements such as blood tests and visual acuity; and miscellaneous features such as eye position (whether the image is of the left or right eye). Figure 5 summarizes all 93 variables.
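To make the split concrete, the following is a minimal sketch (not the authors' pipeline) of a patient-level split, in which every image from a given patient falls on the same side of the split; the DataFrame layout and the patient_id column name are assumptions for illustration.

```python
# Minimal sketch (not the authors' pipeline): hold out ~12.5% of patients so
# that all images from a given patient land on the same side of the split.
# The DataFrame and its "patient_id" column are illustrative assumptions.
import numpy as np
import pandas as pd

def split_by_patient(images: pd.DataFrame, test_frac: float = 0.125, seed: int = 0):
    patients = images["patient_id"].unique()
    rng = np.random.default_rng(seed)
    n_test = int(round(test_frac * len(patients)))
    test_patients = set(rng.choice(patients, size=n_test, replace=False))
    in_test = images["patient_id"].isin(test_patients)
    return images[~in_test], images[in_test]
```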

2.1 Experiments

We trained identical deep convolutional Inception V3 models for each of the 93 variables in our dataset. The models were pre-trained on ImageNet with auxiliary loss turned off and then trained with early stopping for a maximum of 200,000 steps. We then trained linear regression models for the same set of variables. Each model used the output of an intermediate layer from one of the convolutional models as input. We used 19 different intermediate layers spanning the depth of the Inception V3 architecture, and spatially average-pooled the output to make an exact linear regression tractable. These models were trained on the same training set as the convolutional models.
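As a rough illustration of this setup, the sketch below pulls the output of one intermediate layer of a Keras InceptionV3, spatially average-pools it, and fits a linear probe on the pooled embedding. The layer name "mixed5" and the ImageNet-initialized weights are illustrative stand-ins for the fine-tuned source-task models; this is not the authors' code.

```python
# Minimal sketch: spatially average-pooled intermediate activations as input
# to a linear probe. Layer name and weights are assumptions for illustration.
import tensorflow as tf
from sklearn.linear_model import LinearRegression

base = tf.keras.applications.InceptionV3(weights="imagenet", include_top=False)
pooled = tf.keras.layers.GlobalAveragePooling2D()(base.get_layer("mixed5").output)
embedder = tf.keras.Model(inputs=base.input, outputs=pooled)

def fit_probe(x_train, y_train):
    # x_train: preprocessed images of shape (n, 299, 299, 3); y_train: one target variable
    z_train = embedder.predict(x_train, verbose=0)   # (n, n_channels) pooled embedding
    return LinearRegression().fit(z_train, y_train)  # the linear probe
```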

In total, this gave rise to ~164k different linear models: 93 “source” tasks (convolutional models) × 93 “target” tasks (linear models) × 19 intermediate layers. We evaluated each linear model on the test dataset and calculated either an AUC or a coefficient of determination (R²), depending on whether the target task/variable was binary or continuous.
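A minimal sketch of this metric choice, assuming a target is treated as binary exactly when it takes two distinct values:

```python
# Minimal sketch: AUC for binary targets, R^2 for continuous ones.
import numpy as np
from sklearn.metrics import r2_score, roc_auc_score

def probe_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    classes = np.unique(y_true)
    if len(classes) == 2:
        # AUC is rank-based, so the raw regression outputs can be used as scores.
        return roc_auc_score(y_true == classes[1], y_pred)
    return r2_score(y_true, y_pred)
```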

Along with linear regressions on intermediate layers, we also carried out linear regressions on the raw values of the variables themselves and on the single-valued predictions from each convolutional model, in order to distinguish tasks that shared common representations from tasks that were merely correlated with each other. The regressions were done on the training set and evaluated on the test set, just like the intermediate layer regressions.
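These baselines amount to one-dimensional regressions. A hedged sketch, with illustrative variable names and a score function (such as the probe_score helper sketched above) passed in:

```python
# Minimal sketch of the two scalar baselines: regress the target on (a) the raw
# value of the source variable or (b) the source CNN's single-valued prediction.
# Variable names are illustrative; fit on the train set, score on the test set.
import numpy as np
from sklearn.linear_model import LinearRegression

def scalar_baseline(x_train, y_train, x_test, y_test, score_fn):
    # x_*: 1-D arrays holding either the source variable or the CNN's prediction
    model = LinearRegression().fit(np.asarray(x_train).reshape(-1, 1), y_train)
    y_hat = model.predict(np.asarray(x_test).reshape(-1, 1))
    return score_fn(y_test, y_hat)
```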

3 Analysis and observations

In this section we present our analysis and observations from our experiments.

3.1 Middle layer representations generalize best

Figure 1 presents the performance of the linear probe models (AUC or R² on the y axis) on embeddings across layers (x axis) from models trained to predict different variables (self-reported sex, eye position, smoker status, etc.), one per column. We observe that performance across layers tends to follow the same pattern: it increases from the layers closest to the input until the middle layers, and then decreases again, except where the source and target tasks are the same (on-diagonal plots). When the source and target are the same, the performance plateaus or continues to increase in the final layers.

The shape of the curves is more similar along rows than along columns, suggesting that the difficulty of learning a given target task matters more than the differences between embeddings learned for different source tasks. This contrasts with the typical transfer learning setting, where it is much more common to tune the final layers of a model (those closest to the output) for new tasks; our results suggest that it may instead be preferable to transfer from a layer in the middle of the network.

In order to distinguish tasks that share a common representation from tasks that are merely correlated, we also trained linear regression models on each source task's ground-truth values and on the convolutional model's output predictions (Fig 1, red and orange points). These generally do much worse than models trained on intermediate embeddings, though there are cases where the performance is comparable. One such example is predicting systolic blood pressure with age as the source task, which makes sense as the two variables are known to be correlated (Wolf-Maier et al., 2003).

3.2 Specific middle layers perform best, generalizing well across multiple tasks

Figure 2: Histograms of best-performing layers for each [source, target] variable pair. Pairs where the two variables were the same are excluded.
Figure 3: Dimensions of embeddings (input to linear regression models) by layer of the source convolutional model. Note that embeddings were spatially average-pooled, so the dimension is just the number of channels in that layer’s output.

Next we ask: for the Inception network, are there specific layers that give the best performance when generalizing across tasks? Figure 2 shows a histogram of the best-performing layer for each [source, target] variable pair (pairs where the source and target tasks are the same were excluded). It is interesting to see two distinct peaks in the middle layers (around layers 6 and 7, as well as layers 11 and 12). Layer 11 appears to be consistently the most generalizable, ranking among the top 3 layers in many of our pairwise comparisons. We do not know, however, why there are two peaks. The linear models’ input does have different dimensions depending on the layer (Fig 3), but the peaks do not obviously correspond to the changes in size. The clustering could also be due to correlations between [source, target] pairs; this would be interesting to investigate further.
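In spirit, the histogram in Figure 2 can be computed from an array of probe scores indexed by source task, target task, and layer. The sketch below assumes such a hypothetical array (scores, of shape [n_sources, n_targets, n_layers]) and simply counts the argmax layer over the off-diagonal pairs.

```python
# Minimal sketch of Figure 2's computation: for every off-diagonal
# [source, target] pair, record which layer's probe scores best.
# `scores` is an assumed array of shape (n_sources, n_targets, n_layers).
import numpy as np

def best_layer_histogram(scores: np.ndarray) -> np.ndarray:
    n_sources, n_targets, n_layers = scores.shape
    best = scores.argmax(axis=-1)                          # best layer for each pair
    off_diag = ~np.eye(n_sources, n_targets, dtype=bool)   # exclude source == target
    return np.bincount(best[off_diag], minlength=n_layers)
```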

3.3 Performance bands for same target task

Figure 4: Comparison of performance on all source tasks for a given target task. Each line represents one source task, and the color is based on the performance of that variable as a target task with itself as the source task. Blue is best performance, red is worst. Dotted lines show the performance when the source variable is the same as the target. Observe that for ‘height’, representations from other source tasks (in this case, self-reported sex and blood testosterone) are better predictors.

As we would expect, the models generally perform best when the source and target tasks are the same (dotted lines, Fig 4). This is not always true, however: for example, “height” performs better with several other source tasks than with itself as the source task. The better sources were tasks such as blood testosterone and self-reported sex, which share two properties: they are correlated with height, and they are easier to predict than height (for which we never obtain an R² above 0.3). We might guess that, because height is so hard to predict, embeddings trained on it have less room to improve, an illustration of the utility of multitask learning.

We can also observe that there are clear bands and outliers. The bands are closer together in the earlier layers, likely because features learned in the earlier layers are more universal. The outliers generally make sense: for example, “blood pressure medication” as a target task performs well with itself as a source task, but it performs just as well with two related medications, aspirin and ACE inhibitors, as source tasks.

Figure 5: Cross-comparison of all tasks for layer 12. This layer was chosen as it was close to the best layer for most models, and showed the most interesting variation between variables. Tasks are ordered first by whether they’re binary or continuous, then by performance of that task as a target with the “eye position” source task (in practice, this worked almost as well as hierarchical clustering). The bolded source task is a random baseline, where the source CNN was trained on labels drawn from the standard uniform distribution.

3.4 Comparison across all tasks

When comparing a single layer for all task combinations, we again see that the target task is usually more important than the source task. But there are other interesting relationships: for example, age, eye position (left vs. right eye), and most measures of visual acuity are easily predicted regardless of the source task. Age is correlated with a small subset of tasks (see supplement), but predicting age with those as source tasks yields only about the same performance as many other source tasks.

4 Conclusion

To conclude, in this work we used the basic notion of linear probes to study models trained on 93 different tasks/variables in fundus images in the UK Biobank. We looked at nearly 164k models, examining different variables as source and target pairs. We find interesting patterns in performance on source embeddings at various layer depths, and in performance on various combinations of source and target task.

Overall, the results show that simple linear probes provide a rich environment for unravelling the relationships between the underlying data and labels, providing insight into why neural networks trained on single labels are able to make accurate predictions. Future work will use the different representations to unravel which features of images are responsible for the different accurate predictions.

Acknowledgements

This research has been conducted using the UK Biobank Resource (Sudlow et al., 2015). We would like to thank Abbi Ward for help with permissions and access to the dataset, Avinash Varadarajan for help with the dataset, annotations, and support with the training and evaluation infrastructure, Xinle Sheila Lu for helpful pointers on model inference, and Yun Liu and Maithra Raghu for helpful conversations and references.

References

  • G. Alain and Y. Bengio (2016) Understanding intermediate layers using linear classifier probes. arXiv preprint arXiv:1610.01644. Cited by: §1.
  • Y. Asano, C. Rupprecht, and A. Vedaldi (2019) Self-labelling via simultaneous clustering and representation learning. In ICLR, Cited by: §1.
  • M. Badar, M. Haris, and A. Fatima (2020) Application of deep learning for retinal image analysis: a review. Computer Science Review 35, pp. 100203. Cited by: §1.
  • Y. Belinkov (2021) Probing classifiers: promises, shortcomings, and alternatives. arXiv preprint arXiv:2102.12452. Cited by: §1.
  • M. Graziani, H. Muller, and V. Andrearczyk (2019) Interpreting intentionally flawed models with linear probes. In Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops, pp. 0–0. Cited by: §1.
  • V. Gulshan, L. Peng, M. Coram, M. C. Stumpe, D. Wu, A. Narayanaswamy, S. Venugopalan, K. Widner, T. Madams, J. Cuadros, et al. (2016) Development and validation of a deep learning algorithm for detection of diabetic retinopathy in retinal fundus photographs. JAMA 316 (22), pp. 2402–2410. Cited by: §1.
  • J. Hewitt and P. Liang (2019) Designing and interpreting probes with control tasks. arXiv preprint arXiv:1909.03368. Cited by: §1.
  • J. Hewitt and C. D. Manning (2019) A structural probe for finding syntax in word representations. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4129–4138. Cited by: §1.
  • S. Horoi, G. Lajoie, and G. Wolf (2020) Internal representation dynamics and geometry in recurrent neural networks. arXiv preprint arXiv:2001.03255. Cited by: §1.
  • B. Kim, M. Wattenberg, J. Gilmer, C. Cai, J. Wexler, F. Viegas, et al. (2018) Interpretability beyond feature attribution: quantitative testing with concept activation vectors (tcav). In International Conference on Machine Learning, pp. 2668–2677. Cited by: §1.
  • A. Mitani, A. Huang, S. Venugopalan, G. S. Corrado, L. Peng, D. R. Webster, N. Hammel, Y. Liu, and A. V. Varadarajan (2020) Detection of anaemia from retinal fundus images via deep learning. Nature Biomedical Engineering 4 (1), pp. 18–27. Cited by: §1.
  • A. v. d. Oord, Y. Li, and O. Vinyals (2018) Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748. Cited by: §1.
  • R. Poplin, A. V. Varadarajan, K. Blumer, Y. Liu, M. V. McConnell, G. S. Corrado, L. Peng, and D. R. Webster (2018) Prediction of cardiovascular risk factors from retinal fundus photographs via deep learning. Nature Biomedical Engineering 2 (3), pp. 158–164. Cited by: §1.
  • M. Raghu, J. Gilmer, J. Yosinski, and J. Sohl-Dickstein (2017) SVCCA: singular vector canonical correlation analysis for deep learning dynamics and interpretability. In NIPS, Cited by: §1.
  • T. H. Rim, G. Lee, Y. Kim, Y. Tham, C. J. Lee, S. J. Baik, Y. A. Kim, M. Yu, M. Deshmukh, B. K. Lee, et al. (2020) Prediction of systemic biomarkers from retinal photographs: development and validation of deep-learning algorithms. The Lancet Digital Health 2 (10), pp. e526–e536. Cited by: §1.
  • C. Sudlow, J. Gallacher, N. Allen, V. Beral, P. Burton, J. Danesh, P. Downey, P. Elliott, J. Green, M. Landray, et al. (2015) UK Biobank: an open access resource for identifying the causes of a wide range of complex diseases of middle and old age. PLOS Med 12 (3), pp. e1001779. Cited by: §1, §2, §4.
  • A. V. Varadarajan, R. Poplin, K. Blumer, C. Angermueller, J. Ledsam, R. Chopra, P. A. Keane, G. S. Corrado, L. Peng, and D. R. Webster (2018) Deep learning for predicting refractive error from retinal fundus images. Investigative Ophthalmology & Visual Science 59 (7), pp. 2861–2868. Cited by: §1.
  • K. Wolf-Maier, R. S. Cooper, J. R. Banegas, S. Giampaoli, H. Hense, M. Joffres, M. Kastarinen, N. Poulter, P. Primatesta, F. Rodríguez-Artalejo, B. Stegmayr, M. Thamm, J. Tuomilehto, D. Vanuzzo, and F. Vescio (2003) Hypertension Prevalence and Blood Pressure Levels in 6 European Countries, Canada, and the United States. JAMA 289 (18), pp. 2363–2369. Cited by: §3.1.
  • K. Zhang, X. Liu, J. Xu, J. Yuan, W. Cai, T. Chen, K. Wang, Y. Gao, S. Nie, X. Xu, X. Qin, Y. Su, W. Xu, A. Olvera, K. Xue, Z. Li, M. Zhang, X. Zeng, C. L. Zhang, O. Li, E. E. Zhang, J. Zhu, Y. Xu, D. Kermany, K. Zhou, Y. Pan, S. Li, I. F. Lai, Y. Chi, C. Wang, M. Pei, G. Zang, Q. Zhang, J. Lau, D. Lam, X. Zou, A. Wumaier, J. Wang, Y. Shen, F. F. Hou, P. Zhang, T. Xu, Y. Zhou, and G. Wang (2021) Deep-learning models for the detection and incidence prediction of chronic kidney disease and type 2 diabetes from retinal fundus images. Nature Biomedical Engineering. Cited by: §1.