ConceptWhitening
What does a neural network encode about a concept as we traverse through the layers? Interpretability in machine learning is undoubtedly important, but the calculations of neural networks are very challenging to understand. Attempts to see inside their hidden layers can be misleading or unusable, or can rely on the latent space possessing properties that it may not have. In this work, rather than attempting to analyze a neural network posthoc, we introduce a mechanism, called concept whitening (CW), to alter a given layer of the network to allow us to better understand the computation leading up to that layer. When a concept whitening module is added to a CNN, the axes of the latent space can be aligned with concepts of interest. By experiment, we show that CW can provide a much clearer understanding of how the network gradually learns concepts over layers, without hurting predictive performance.
An important practical challenge that arises with neural networks is the fact that the units within their hidden (intermediate, convolutional) layers are not usually semantically understandable. This is particularly true with computer vision applications, where an expanding body of research has focused centrally on explaining the calculations of neural networks and other black box models. Some of the core questions considered in these posthoc analyses of neural networks include: “What concept does a unit in a hidden layer of a trained neural network represent?” or “Does this unit in the network represent a concept that a human might understand?”
The questions listed above are important, but it is not clear that they would naturally have satisfactory answers when performing posthoc analysis on a pre-trained neural network. In fact, there are several reasons why various types of posthoc analyses would not answer these questions. Efforts to interpret individual nodes of pre-trained neural networks (e.g., zhou2018interpreting; zhou2014object) have shown that some fraction of nodes can be identified as aligned with some high-level semantic meaning, but these special nodes do not provably contain the network’s full information about the concepts. That is, the nodes are not “pure,” and information about the concept could be scattered throughout the network. Concept-vector methods (kim2017interpretability; zhou2018interpretable; ghorbani2019towards) have also been used to analyze pre-trained neural networks. Here, vectors in the latent space are chosen to align with pre-defined or automatically-discovered concepts. While concept-vectors are more promising, they still make the assumption that the latent space of a neural network admits a posthoc analysis of a specific form. In particular, they assume that the latent space places members of each concept in one easy-to-classify portion of latent space. Since the latent space was not explicitly constructed to have this property, there is no reason to believe it holds.
Ideally, we would want a neural network whose latent space tells us how it is disentangling concepts, without needing to resort to extra classifiers like concept-vector methods (kim2017interpretability; ghorbani2019towards), without surveying humans (zhou2014object), and without other manipulations that rely on whether the geometry of a latent space serendipitously admits analysis of concepts. Rather than having to rely on assumptions that the latent space admits disentanglement, we would prefer to constrain the latent space directly. We might even wish that the concepts align themselves along the axes of the latent space, so that each point in the latent space has an interpretation in terms of known concepts.
Let us discuss how one would go about imposing such constraints on the latent space. In particular, we introduce the possibility of what we call concept whitening. Concept whitening (CW) is a module inserted into a neural network. It constrains the latent space to represent target concepts and also provides a straightforward means to extract them. It does not force the concepts to be learned as an intermediate step; rather, it constrains the latent space to be aligned along the concepts. For instance, let us say that, using CW on a lower layer of the network, the concept “airplane” is represented along one axis. By examining the images along this axis, we can find the lower-level abstraction that the network is using for the complex concept “airplane,” which might be white or silver objects with blue backgrounds. In the lower layers of a standard neural network we cannot necessarily find this abstraction, because the abstraction of “airplane” might be spread throughout latent space rather than along an “airplane” axis. By looking at images along the airplane axis at each layer, we see how the network gradually represents airplanes with an increasing level of sophistication and complexity.
Concept whitening could be used to replace a plain batch normalization step in a CNN backbone, because it combines batch whitening with an extra step involving a rotation matrix. Batch whitening usually provides helpful properties to latent spaces, but our goal requires the whitening to take place with respect to concepts; the use of the rotation matrix to align the concepts with the axes is the key to interpretability through disentangled concepts. Whitening decorrelates and normalizes each axis (i.e., it transforms the post-convolution latent space so that the covariance matrix between channels is the identity). Exploiting the property that a whitening transformation remains valid after applying an arbitrary rotation, the rotation matrix strategically matches the concepts to the axes.
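The rotation property can be checked numerically. The following NumPy sketch is our own illustration, not the authors' implementation: it whitens toy correlated data with ZCA whitening, rotates the result by a random orthogonal matrix, and confirms that the covariance stays the identity. The toy data, the `eps` regularizer, and the helper name `zca_whiten` are assumptions chosen for the example.

```python
import numpy as np

def zca_whiten(Z, eps=1e-10):
    """Return (mu, W) such that W @ (Z - mu) has identity covariance.

    Z has shape (d, n): one column per sample, as in the text.
    eps is a small regularizer added to the eigenvalues (an assumption).
    """
    mu = Z.mean(axis=1, keepdims=True)
    Zc = Z - mu
    Sigma = Zc @ Zc.T / Z.shape[1]
    lam, D = np.linalg.eigh(Sigma)
    W = D @ np.diag(1.0 / np.sqrt(lam + eps)) @ D.T
    return mu, W

rng = np.random.default_rng(0)
d, n = 8, 5000
# Correlated toy "latent features" that are not mean-centered
Z = rng.standard_normal((d, d)) @ rng.standard_normal((d, n)) + 3.0

mu, W = zca_whiten(Z)
Z_white = W @ (Z - mu)

# Random orthogonal matrix from the QR decomposition of a Gaussian matrix
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
Z_rot = Q.T @ Z_white  # same as whitening with W' = Q^T W

cov_white = Z_white @ Z_white.T / n
cov_rot = Z_rot @ Z_rot.T / n
print(np.allclose(cov_white, np.eye(d), atol=1e-4))  # True
print(np.allclose(cov_rot, np.eye(d), atol=1e-4))    # True
```

Because any orthogonal `Q` passes this check, the whitening constraint alone leaves the rotation free, which is exactly the degree of freedom CW uses to place concepts on axes.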
The concepts used in CW do not need to be the labels in the classification problem; they can be learned from an auxiliary dataset in which concepts are labeled. The concepts do not need to be labeled in the dataset involved in the main classification task (though they could be), and the main classification labels do not need to be available in the auxiliary concept dataset.
Through qualitative and quantitative experiments, we illustrate how concept whitening applied to the various layers of the neural network illuminates its internal calculations. We verify the interpretability and pureness of concepts in the disentangled latent space. Importantly for practice, we show that by replacing the batch normalization layer in pretrained state-of-the-art models with a CW module, the resulting neural network can achieve accuracy on par with the corresponding original black box neural network on large datasets, and it can do this within one additional epoch of further training. Thus, with fairly minimal effort, one can make a small modification to a neural network architecture (adding a CW module), and in return be able to easily visualize how the network is learning all of the different concepts at any chosen layer.
CW can show us how a concept is represented at a given layer of the network. What we find is that at lower layers, where the network cannot yet represent a complex concept, it often creates lower-level abstract concepts. For example, an airplane at an early layer is represented by an abstract concept that is white or gray objects on a blue background. A bed is represented by an abstract concept that seems to be characterized by warm colors (orange, yellow). In that sense, the CW layer can help us to discover new concepts that can be formally defined and built on.
There are several large and rapidly expanding bodies of relevant literature.
Interpretability and explainability of neural networks:
There have been two schools of thought on improving the interpretability of neural networks: (1) learning an inherently interpretable model; (2) providing post-hoc explanations for an existing neural network. CW falls within the first type, though it only sheds light on what the network is doing, rather than providing a full understanding of the network’s computations. To provide a full explanation of each computation would lead to more constraints and thus a loss in flexibility, whereas CW allows more flexibility in exchange for more general types of explanations. The vast majority of current works on neural networks are of the second type, explainability. A problem with the terminology is that “explanation” methods are often summary statistics of performance (e.g., local approximations, general trends on node activation) rather than actual explanations of the model’s calculations. For instance, if a node is found to activate when a certain concept is present in an image, it does not mean that all information (or even the majority of information) about this concept is involved with that particular node.
Saliency-based methods are the most common form of post-hoc explanations for neural networks (zeiler2014visualizing; simonyan2013deep; smilkov2017smoothgrad; selvaraju2017grad). These methods assign importance weights to each pixel of the input image to show the importance of each pixel to the image’s predicted class. Saliency maps are problematic for well-known reasons: they often highlight edges in images, regardless of the class. Thus, very similar explanations are given for multiple classes, and often none of them are useful explanations. Saliency methods can be unreliable and fragile (e.g., adebayo2018sanity).
Other work provides explanations of how the network’s latent features operate. Some measure the alignment of an individual internal unit or filter of trained neural networks to a predefined concept and find some units have relatively strong alignment to that concept (zhou2018interpreting; zhou2014object). While some units (i.e., filters) may align nicely with pre-defined concepts, the concept can be represented diffusely through many units (the concept representation by individual nodes is impure); this is because the network was not trained to have concepts expressed purely through individual nodes. To address this weakness, several concept-based post-hoc explanation approaches have recently been proposed that do not rely on the concept aligning with individual units (kim2017interpretability; zhou2018interpretable; ghorbani2019towards; yeh2019concept). Instead of analyzing individual units, these methods try to learn a linear combination of them to represent a predefined concept (kim2017interpretability) or to automatically discover concepts by clustering patches and defining the clusters as new concepts (ghorbani2019towards). Although these methods are promising, they are based on assumptions of the latent space that may not hold. For instance, these methods assume that a classifier (usually a linear classifier) exists on the latent space such that the concept is correctly classified. Since the network was not trained so that this assumption holds, it may not hold. More importantly, since the latent space is not shaped explicitly to handle this kind of concept-based explanation, unit vectors (directions) in the latent space may not represent concepts purely. We will give an example in the next section to show why latent spaces built without constraints may not achieve concept separation.
CW avoids these problems because it shapes the latent space through training. In that sense, CW is closer to work on inherently interpretable neural networks, though its use-case is in the spirit of concept vectors, in that it is useful for providing important directions in the latent space.
There are emerging works trying to build inherently interpretable neural networks. Like CW, they alter the network structure to encourage different forms of interpretability. For example, neural networks have been designed to perform case-based reasoning (chen2018looks; li2018deep), to incorporate logical or grammatical structures (li2017aognets; granmo2019convolutional; wu2019towards), to do classification based on hard attention (mnih2014recurrent; ba2014multiple; sermanet2014attention; elsayed2019saccader), or to do image recognition by decomposing the components of images (saralajew2019classification). These models all have different forms of interpretability than the one we consider (understanding how the latent spaces of each layer learn a known set of concepts). One work that is somewhat similar to ours is that of bouchacourt2019educe, who develop a concept-based deep learning method that is inherently interpretable, but it relies on specific properties of textual data that do not readily transfer to image data. zhang2018interpretable add losses to the filters to encourage them to detect object parts that can also be viewed as concepts. However, the method of zhang2018interpretable works only for object parts, while CW works for any type of concept, such as objects, colors, textures, etc. adel2018discovering transform the density of the current latent representation in an invertible way by normalizing flows and maximize the mutual information between the transformed representation and side information provided by human users. Although side information could also include concepts, adel2018discovering query for the side information by active learning, which is different from our approach.
Whitening and orthogonality:
Whitening is a linear transformation that transforms the covariance matrix of random input vectors into the identity matrix. It is a classical preprocessing step in data science. In the realm of deep learning, batch normalization (ioffe2015batch), which is widely used in many state-of-the-art neural network architectures, retains the standardization part of whitening but not the decorrelation. Earlier attempts (desjardins2015natural; luo2017learning) whiten by periodically estimating the whitening matrix, which leads to instability in training. Other methods (cogswell2015reducing) perform whitening by adding a decorrelation loss. By observing that SVD is differentiable, huang2018decorrelated and huang2019iterative develop ZCA whitening that is supported directly in back-propagation. siarohin2018whitening also propose a differentiable whitening block, but it is based on Cholesky whitening. The whitening part of our CW module borrows techniques from IterNorm (huang2019iterative) because it is differentiable and accelerated. CW is different from previous methods because its whitening matrix is multiplied by an orthogonal matrix and maximizes the activation of known concepts along the latent space axes.

In the field of deep learning, many initial works about incorporating orthogonality constraints target RNNs (vorontsov2017orthogonality; mhammedi2017efficient; wisdom2016full), since orthogonality can help avoid vanishing or exploding gradients in RNNs. Other work explores ways to learn orthogonal weights or representations for all types of neural networks, not just RNNs (harandi2016generalized; huang2018orthogonal; lezcano2019cheap; lezama2018ole). For example, lezama2018ole use special loss functions to force orthogonality. The optimization algorithms used in the above methods are all different from ours. For CW, we optimize the orthogonal matrix by the Cayley-transform-based curvilinear search algorithm proposed by wen2013feasible. While vorontsov2017orthogonality also use a Cayley transform, they do so with a fixed learning rate, which does not work effectively in our setting. More importantly, the goal of optimization with orthogonality constraints in all these works is completely different from ours: none of them tries to align columns of the orthogonal matrix with any type of concept.

Suppose $x_1, x_2, \ldots, x_n \in \mathcal{X}$ are samples in our dataset and $y_1, y_2, \ldots, y_n \in \mathcal{Y}$ are their labels. From the latent space $\mathcal{Z} \subseteq \mathbb{R}^d$ defined by a hidden layer, a DNN classifier $f: \mathcal{X} \to \mathcal{Y}$ can be divided into two parts: a feature extractor $\Phi: \mathcal{X} \to \mathcal{Z}$, with parameters $\theta$, and a classifier $g: \mathcal{Z} \to \mathcal{Y}$, parameterized by $\omega$. Then $z = \Phi(x; \theta)$ is the latent representation of the input $x$ and $f(x) = g(\Phi(x; \theta); \omega)$ is the predicted label. Suppose we are interested in $k$ concepts called $c_1, c_2, \ldots, c_k$. We can then pre-define $k$ auxiliary datasets $X_{c_1}, X_{c_2}, \ldots, X_{c_k}$ such that samples in $X_{c_j}$ are the most representative samples of concept $c_j$. Our goal is to learn $\Phi$ and $g$ simultaneously, such that (a) the classifier $g(\Phi(\cdot; \theta); \omega)$ can predict the label accurately; and (b) the $j$-th dimension $z_j$ of the latent representation $z$ aligns with concept $c_j$. In other words, samples in $X_{c_j}$ should have larger values of $z_j$ than other samples. Conversely, samples not in $X_{c_j}$ should have smaller values of $z_j$.
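Goal (b) can be made concrete with a simple rank statistic. The sketch below uses a hypothetical helper, `axis_alignment_auc`, which is an illustrative criterion and not a metric defined in this section: it measures how often a sample from the concept dataset $X_{c_j}$ has a larger value of $z_j$ than a sample outside it. A value of 1.0 means perfect alignment and 0.5 means the axis carries no information about the concept. The numbers are made up.

```python
import numpy as np

def axis_alignment_auc(z_concept, z_other):
    """Fraction of (concept, non-concept) pairs in which the concept sample
    has the larger value on the axis; 0.5 is chance, 1.0 is perfect."""
    z_concept = np.asarray(z_concept, dtype=float)
    z_other = np.asarray(z_other, dtype=float)
    diff = z_concept[:, None] - z_other[None, :]
    # Ties count as half, as in the Mann-Whitney U statistic
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

# Hypothetical values of z_j for samples inside and outside X_{c_j}
inside = [2.1, 1.7, 0.9]
outside = [0.2, -0.5, 1.0, -1.3]
print(axis_alignment_auc(inside, outside))  # 11 of 12 pairs ordered correctly, about 0.917
```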
Some posthoc explanation methods have looked at unit vectors in the direction of data where a concept is exhibited to measure how different concepts contribute to a classification task (zhou2018interpretable). Other methods consider directional derivatives towards data exhibiting the concept (kim2017interpretability), for the same reason. There are important reasons why these types of approaches may not work.
Consider, for instance, an elongated latent space similar to that illustrated in Figure 1(a). Here, two unit vectors pointing to different groups of data (perhaps exhibiting two separate concepts) may have a large inner product, suggesting that they may be part of the same concept, when in fact, they may not be similar at all. Worse, a unit vector in the yellow direction appears to indicate that the red concept has more extreme values of the yellow concept than the members of the yellow concept itself. Thus, even if the latent space is standardized, multiple unrelated concepts can still be found by traversing towards the same general direction, as shown in Figure 1(a). For the same reason, taking derivatives towards the parts of the space where various concepts tend to appear may yield similar derivatives for very different concepts. This is true even if these concepts are not co-located in latent space.
If the latent space is not mean-centered, that alone could cause problems for posthoc methods that compute directions towards concepts. Consider, for instance, a case where all points in the latent space are far from the origin. In that case, all concept directions point towards the same part of the space: the part where the data lies (see Figure 1(b)).
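A small NumPy illustration of this failure mode, using made-up two-dimensional data: two well-separated concepts that both lie far from the origin yield nearly parallel directions from the origin, while mean-centering the data separates them. The cluster positions, noise level, and the `cosine` helper are arbitrary assumptions for the example.

```python
import numpy as np

def cosine(u, v):
    """Cosine similarity between two vectors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

rng = np.random.default_rng(0)
offset = np.array([20.0, 20.0])  # every point lies far from the origin
concept_a = offset + np.array([3.0, 0.0]) + 0.3 * rng.standard_normal((200, 2))
concept_b = offset + np.array([0.0, 3.0]) + 0.3 * rng.standard_normal((200, 2))

# Directions from the origin towards each concept (uncentered latent space)
dir_a, dir_b = concept_a.mean(axis=0), concept_b.mean(axis=0)
cos_raw = cosine(dir_a, dir_b)

# The same directions after mean-centering the pooled data
center = np.vstack([concept_a, concept_b]).mean(axis=0)
cos_centered = cosine(dir_a - center, dir_b - center)

print(f"uncentered: {cos_raw:.3f}, centered: {cos_centered:.3f}")
```

The uncentered cosine is close to 1 even though the concepts do not overlap. With two equal-sized concepts, the centered directions point in exactly opposite directions, so the second value is -1.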
At the very least, if we are to examine directions in the latent space to look for known concepts, the latent space should be mean-centered, and its axes should be aligned with these concepts.
Fortunately, due to properties of high dimensional geometry (blum2016foundations), if our data are whitened (standardized and decorrelated), there is some hope that the concepts can be fully represented by unit vectors. Two reasons for this are shown by the following theorems, which are variations of standard results from high-dimensional geometry.
(Variation of well-known theorem) If a random vector $x$ is sampled from a $d$-dimensional spherical Gaussian distribution $\mathcal{N}(0, I_d)$, then for all $\beta \le \sqrt{d}$,

$$ P\Big( \big| \|x\|_2 - \sqrt{d} \big| \ge \beta \Big) \le 3e^{-c\beta^2}, $$

where $c$ is a fixed positive constant.

This theorem means that if the dimension $d$ of the latent space is large enough, which is reasonable for DNNs, then as long as the data follow a spherical Gaussian distribution, almost all data are distributed near the surface of the sphere with radius $\sqrt{d}$. In other words, since all data points have approximately the same distance to the origin, they are distinguished only by their directions. Therefore, as long as all samples representing the concept are near each other, and as long as non-concept samples are not nearby, a unit vector fully characterizes the location of that concept in latent space.
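The concentration described by this theorem is easy to observe empirically. The following sketch is an illustration under the theorem's assumption of standard Gaussian data, not part of the CW method. It samples from $\mathcal{N}(0, I_d)$ for a few dimensions and reports the mean and spread of $\|x\|/\sqrt{d}$; as $d$ grows, the mean approaches 1 and the spread shrinks.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 2000
results = {}
for d in (2, 64, 1024):
    x = rng.standard_normal((n_samples, d))  # rows are samples from N(0, I_d)
    ratio = np.linalg.norm(x, axis=1) / np.sqrt(d)
    results[d] = (ratio.mean(), ratio.std())
    print(f"d={d:5d}  mean ||x||/sqrt(d) = {ratio.mean():.3f}  std = {ratio.std():.3f}")
```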
(Variation of well-known theorem) Suppose two unit vectors are randomly drawn from a -dimensional unit sphere. Then for all , the angle between them obeys
This theorem indicates that two random unit vectors are nearly orthogonal with high probability when $d$ is large. Assuming that drawing images from different uncorrelated concepts is similar to drawing randomly distributed data, the unit vectors pointing to different concepts may be nearly orthogonal to each other. In that case, we may be able to find an orthonormal basis and use its first $k$ axes to represent the $k$ concepts of interest.

The above theorems provide the reasons why concept whitening might work: by aligning the samples representing a concept in the same direction, and aligning each concept in its own direction, unit vectors are able to fully characterize that concept in latent space.
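A quick simulation of the near-orthogonality claim, again an illustration rather than part of the method. Normalized Gaussian vectors are uniformly distributed on the unit sphere, so we can measure how far the angle between random pairs strays from 90 degrees as $d$ grows. The helper name and sample sizes are assumptions.

```python
import numpy as np

def random_unit_vectors(n, d, rng):
    """Draw n vectors uniformly from the unit sphere in R^d."""
    v = rng.standard_normal((n, d))
    return v / np.linalg.norm(v, axis=1, keepdims=True)

rng = np.random.default_rng(1)
mean_dev = {}
for d in (3, 100, 2000):
    u = random_unit_vectors(1000, d, rng)
    v = random_unit_vectors(1000, d, rng)
    cos = np.clip(np.sum(u * v, axis=1), -1.0, 1.0)
    angles = np.degrees(np.arccos(cos))
    mean_dev[d] = float(np.mean(np.abs(angles - 90.0)))
    print(f"d={d:5d}  mean |angle - 90 deg| = {mean_dev[d]:.2f}")
```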
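Let $Z_{d \times n}$ be the latent representation matrix of $n$ samples, in which each column $z_i \in \mathbb{R}^d$ contains the latent features of the $i$-th sample. Our Concept Whitening module (CW) consists of two parts: whitening and an orthogonal transformation. The whitening transformation $\psi$ decorrelates and standardizes the data by

$$ \psi(Z) = W\big(Z - \mu \mathbf{1}_{n \times 1}^{\top}\big), \tag{1} $$

where $\mu = \frac{1}{n}\sum_{i=1}^{n} z_i$ is the sample mean and $W_{d \times d}$ is the whitening matrix that obeys $W^{\top}W = \Sigma^{-1}$. Here, $\Sigma_{d \times d} = \frac{1}{n}(Z - \mu\mathbf{1}^{\top})(Z - \mu\mathbf{1}^{\top})^{\top}$ is the covariance matrix. The whitening matrix $W$ is not unique and can be calculated in many ways, such as ZCA whitening and Cholesky decomposition. Another important property of the whitening matrix is that it is rotation free: suppose $Q_{d \times d}$ is an orthogonal matrix; then

$$ W' = Q^{\top} W \tag{2} $$

is also a valid whitening matrix, since $W'^{\top}W' = W^{\top}QQ^{\top}W = W^{\top}W = \Sigma^{-1}$. In our module, after whitening the latent space to endow it with the properties discussed above, we still need to rotate the samples in their latent space such that the data from concept $c_j$, namely $X_{c_j}$, are highly activated on the $j$-th axis. Specifically, we need to find an orthogonal matrix $Q_{d \times d}$ whose $j$-th column $q_j$ is the $j$-th axis, by optimizing the following objective:

$$ \max_{q_1, q_2, \ldots, q_k} \; \sum_{j=1}^{k} \frac{1}{n_j}\, q_j^{\top} \psi(Z_{c_j}) \mathbf{1}_{n_j \times 1} \quad \text{s.t. } Q^{\top}Q = I_d, \tag{3} $$

where $Z_{c_j}$ is a $d \times n_j$ matrix denoting the latent representation of $X_{c_j}$, and $c_1, c_2, \ldots, c_k$ are the $k$ concepts of interest. An optimization problem with an orthogonality constraint like this can be solved by gradient-based approaches on the Stiefel manifold (e.g., the method of wen2013feasible).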
This whole procedure constitutes CW, and can be done for any given layer of a neural network as part of the training of the network.
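To make objective (3) concrete, the sketch below evaluates it for a single hypothetical concept whose whitened samples cluster around a direction `m1`. It is an illustration with assumed toy data (the samples are taken to be already whitened), not the training procedure: it only compares two fixed rotations and shows that aligning $q_1$ with the concept raises the objective while $Q$ stays orthogonal. The helper name `concept_objective` is hypothetical.

```python
import numpy as np

def concept_objective(Q, psi_concepts):
    """Objective (3): for each concept j, the mean projection of its whitened
    samples psi(Z_{c_j}) (shape d x n_j) onto the j-th column q_j of Q."""
    return sum(float(np.mean(Q[:, j] @ P)) for j, P in enumerate(psi_concepts))

rng = np.random.default_rng(0)
d = 4
m1 = np.array([1.0, 1.0, 0.0, 0.0])  # hypothetical concept direction
psi_c1 = m1[:, None] + 0.1 * rng.standard_normal((d, 50))

# Identity rotation: concept 1 is read off the first axis as-is
Q_id = np.eye(d)

# An orthogonal Q whose first column points along the concept direction
M = np.column_stack([m1 / np.linalg.norm(m1), rng.standard_normal((d, d - 1))])
Q_aligned, R = np.linalg.qr(M)
Q_aligned[:, 0] *= np.sign(R[0, 0])  # undo a possible sign flip from QR

print(concept_objective(Q_id, [psi_c1]))       # close to 1.0
print(concept_objective(Q_aligned, [psi_c1]))  # close to sqrt(2), about 1.414
```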
Whitening has not (to our knowledge) been previously applied to align the latent space to concepts. In the past, whitening has been used to speed up back-propagation. The specific whitening problem for speeding up back-propagation is different from that for concept alignment (the rotation matrix is not present in other work on whitening, nor is the notion of a concept); however, we can leverage some of the optimization tools used in that work on whitening (huang2019iterative; huang2018orthogonal; siarohin2018whitening). Specifically, we adapt ideas underlying the IterNorm algorithm (huang2019iterative), which employs Newton’s iterations to approximate ZCA whitening, to the problem studied here. Let us now describe how this is done.
The whitening matrix in ZCA is
$$ W = D\Lambda^{-1/2}D^{\top}, \tag{4} $$

where $\Lambda_{d \times d} = \operatorname{diag}(\lambda_1, \ldots, \lambda_d)$ and $D$ are the eigenvalue diagonal matrix and eigenvector matrix given by the eigenvalue decomposition of the covariance matrix, $\Sigma = D\Lambda D^{\top}$. Like other normalization methods, we calculate $\mu$ and $W$ for each mini-batch of data, and average them together to form the model used in testing.

As mentioned in Section 3.3, the challenging part for CW is that we also need to learn an orthogonal matrix by solving an optimization problem. To do this, we optimize the objective while strictly maintaining the matrix $Q$ to be orthogonal, by performing gradient descent with a curvilinear search on the Stiefel manifold (wen2013feasible), adjusted to deal with mini-batch data.
During training, our procedure must handle two types of data: data for calculating the main objective and the data representing the predefined concepts. The model is optimized by alternating optimization: the mini-batches of the main dataset and the auxiliary concept dataset are fed to the network, and the following two objectives are optimized in turns. The first objective is:
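$$ \min_{\theta,\, \omega,\, W,\, \mu} \; \frac{1}{n}\sum_{i=1}^{n} \ell\Big(g\big(Q^{\top}\psi(\Phi(x_i; \theta); W, \mu); \omega\big),\, y_i\Big), \tag{5} $$

where $\ell$ is the loss of the main classification task, $\Phi$ and $g$ are the layers before and after the CW module, parameterized by $\theta$ and $\omega$ respectively, and $\psi$ is a whitening transformation parameterized by the sample mean $\mu$ and whitening matrix $W$. $Q$ is the orthogonal matrix; together, $Q^{\top}\psi$ forms the CW module (which is also a valid whitening transformation). The second objective is

$$ \max_{q_1, q_2, \ldots, q_k} \; \sum_{j=1}^{k} \frac{1}{n_j} \sum_{x_i^{c_j} \in X_{c_j}} q_j^{\top}\, \psi\big(\Phi(x_i^{c_j}; \theta); W, \mu\big) \quad \text{s.t. } Q^{\top}Q = I_d. \tag{6} $$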
The orthogonal matrix $Q$ is fixed when training for the main objective, and the other parameters are fixed when training for $Q$. The optimization problem for $Q$ is a linear programming problem with quadratic constraints (LPQC), which is generally NP-hard. Since directly solving for the optimal solution is intractable, we optimize it by gradient methods on the Stiefel manifold. At each step $t$, the orthogonal matrix $Q$ is updated by the Cayley transform

$$ Q^{(t+1)} = \Big(I + \frac{\eta}{2}A\Big)^{-1}\Big(I - \frac{\eta}{2}A\Big)\, Q^{(t)}, $$

where $A = G\,(Q^{(t)})^{\top} - Q^{(t)}G^{\top}$ is a skew-symmetric matrix, $G$ is the gradient of the loss function with respect to $Q^{(t)}$, and $\eta$ is the learning rate. The optimization procedure is accelerated by a curvilinear search on the learning rate at each step (wen2013feasible). Note that, in the Cayley transform, stationary points are reached when $A = 0$, which has multiple solutions. Since the solutions lie in a high-dimensional space, these stationary points are very likely to be saddle points, which can be avoided by SGD. Therefore, we use the stochastic gradient calculated on a mini-batch of samples to replace $G$ at each step. To accelerate and stabilize the stochastic gradient, we also apply momentum to it in our implementation.

In CNNs, a feature map (a channel within one layer, created by a convolution of one filter) contains the information of how activated a part of the image is by a single filter. That filter may be a detector for a specific concept. Let us reshape the feature map into a vector, where each element of the vector represents how much one part of the image is activated by the filter. Thus, if the feature map for one filter is $h \times w$, then a vector of length $hw$ contains the activation information for that filter around the whole feature map. We do this reshaping procedure for each filter, which reshapes the output of a convolution layer $Z \in \mathbb{R}^{n \times d \times h \times w}$ into a matrix $Z \in \mathbb{R}^{d \times (nhw)}$. We then perform CW on the reshaped matrix. After doing this, the resulting matrix is still of size $d \times (nhw)$. If we reshape this matrix back to its original size as a tensor, one feature map of the tensor now (after training) represents whether a meaningful concept is detected at each location in the image for that layer. Note that the output of a filter is now a feature map, which is an $h \times w$ matrix, but the concept activation score we used in the optimization problem is a scalar. Therefore, we need to get an activation value from the feature map, and there are multiple ways to do it. We try the following calculations to define activation based on the feature map: (a) mean of all feature map values; (b) max of all feature map values; (c) mean of all positive feature map values; (d) mean of the down-sampled feature map obtained by max pooling. We use (d) in our experiments since it is good at capturing both high-level and low-level concepts. Detailed analysis and experiments about the choice of different activation calculations are discussed in Appendix A.

Let us discuss some aspects of practical implementation. The CW module can substitute for other normalization modules such as BatchNorm in a hidden layer of the CNN. Because both whitening and orthogonal optimization have a relatively high computational cost, one can leverage a pretrained model as a warm start. To do this, we might take a pretrained model (for the same main objective) that does not use CW, and replace a BatchNorm layer in that network with a CW layer. The model usually converges in one epoch (one pass over the data) if a pretrained model is used.
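The Cayley-transform update for $Q$ can be sketched in NumPy. This is a simplified illustration under stated assumptions, not the authors' implementation. It uses a fixed step size instead of the curvilinear search, and full gradients instead of mini-batch gradients with momentum. The objective is made up: `M` holds hypothetical mean whitened activations of $k$ concepts. The sketch shows that every step keeps $Q$ orthogonal while the loss (the negative concept objective) decreases over the iterations.

```python
import numpy as np

def cayley_step(Q, G, lr):
    """One Cayley-transform step on the orthogonal group.

    Q: current orthogonal matrix (d x d); G: gradient of the loss w.r.t. Q;
    lr: step size. Returns (I + lr/2 A)^{-1} (I - lr/2 A) Q, A = G Q^T - Q G^T.
    """
    I = np.eye(Q.shape[0])
    A = G @ Q.T - Q @ G.T  # skew-symmetric by construction
    return np.linalg.solve(I + 0.5 * lr * A, (I - 0.5 * lr * A) @ Q)

rng = np.random.default_rng(0)
d, k = 6, 2
M = rng.standard_normal((d, k))  # hypothetical concept mean activations, one column each

def loss_and_grad(Q):
    """Loss = -sum_j q_j^T m_j (negative concept objective) and its gradient."""
    loss = -float(np.sum(Q[:, :k] * M))
    G = np.zeros_like(Q)
    G[:, :k] = -M
    return loss, G

Q = np.eye(d)
loss_start, _ = loss_and_grad(Q)
for _ in range(200):
    _, G = loss_and_grad(Q)
    Q = cayley_step(Q, G, lr=0.1)
loss_end, _ = loss_and_grad(Q)

print(f"loss: {loss_start:.3f} -> {loss_end:.3f}")
print("max |Q^T Q - I|:", np.abs(Q.T @ Q - np.eye(d)).max())
```

Because $A$ is skew-symmetric, the Cayley transform of $A$ is orthogonal, so $Q$ never leaves the constraint set and no re-orthogonalization step is needed.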
In this section, we first show that after replacing one batch norm (BN) layer with our CW module, the accuracy of image recognition is still on par with the original model (4.1). After that, we visualize the concept basis we learn and show that the axes are aligned with the concepts assigned to them. Specifically, we display the images that are most activated along a single axis (4.2.1); we then show how two axes interact with each other (4.2.2); and we further show how the same concept evolves in different layers (4.2.3), where we have replaced one layer at a time. Then we measure the purity of our concept axes and compare with other concept-based neural network methods (4.3).
We evaluate the image recognition accuracy of the CNNs before and after adding a CW module. We show that simply replacing a BN module with a CW module and training for a single epoch leads to at most a small drop in performance on the main objective. Specifically, after replacing the BN module with the CW module, we trained popular CNN architectures including VGG16+BN (simonyan2014very), ResNet with 18 and 50 layers (he2016deep), and DenseNet161 (huang2017densely) on the Places365 (zhou2017places) dataset. The auxiliary concept dataset we used is MS COCO (lin2014microsoft). Each annotation in MS COCO, e.g., “person,” was used as one concept, and we selected all the images with this annotation (images containing a person) as the data representing the concept. To limit the total training time, we used pre-trained models for these architectures and fine-tuned them after BN was replaced with CW.
Table 1 shows the average test accuracy on the validation set of Places365 over 5 runs. We randomly selected 3 concepts to learn using CW for each run, and used the average of them to measure accuracy. We repeated this, applying CW to different layers and reported the average accuracy among the layers. The accuracy does not change much among CW applied to the different layers, as shown in Appendix B.
| Model | Top-1 acc. (%), Original | Top-1 acc. (%), +CW | Top-5 acc. (%), Original | Top-5 acc. (%), +CW |
|---|---|---|---|---|
| VGG16-BN | 53.6 | 53.3 | 84.2 | 83.8 |
| ResNet18 | 54.5 | 53.9 | 84.6 | 84.2 |
| ResNet50 | 54.7 | 54.9 | 85.1 | 85.2 |
| DenseNet161 | 55.3 | 55.5 | 85.2 | 85.6 |
Because we have leveraged a pretrained model, when training with CW we conduct only one additional epoch of training (one pass over the dataset) for each run. As shown in Table 1, the performance of these models using the CW module is on par with the original models: the largest difference in top-1 or top-5 accuracy is 0.6 percentage points (ResNet18, top-1). This means that in practice, if a pretrained model (using BN) exists, one can simply replace the BN module with a CW module and train it for one epoch, in which case the pretrained black-box model can be turned into a more interpretable model that is approximately equally accurate.
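The comparison can be recomputed directly from the values in Table 1. The short Python snippet below only re-tabulates the published numbers, computing the +CW minus original difference in percentage points for each model and metric.

```python
# Top-1 / top-5 accuracy (%) from Table 1: (original, +CW)
table1 = {
    "VGG16-BN": {"top1": (53.6, 53.3), "top5": (84.2, 83.8)},
    "ResNet18": {"top1": (54.5, 53.9), "top5": (84.6, 84.2)},
    "ResNet50": {"top1": (54.7, 54.9), "top5": (85.1, 85.2)},
    "DenseNet161": {"top1": (55.3, 55.5), "top5": (85.2, 85.6)},
}

# Round to one decimal to remove floating-point noise from the subtraction
deltas = {
    (model, metric): round(cw - orig, 1)
    for model, metrics in table1.items()
    for metric, (orig, cw) in metrics.items()
}
for (model, metric), delta in deltas.items():
    print(f"{model:12s} {metric}: {delta:+.1f}")
print("largest absolute change:", max(abs(v) for v in deltas.values()))  # 0.6
```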
In order to demonstrate the interpretability benefits of models equipped with CW modules, we visualize the concept basis in the CW module and validate that the axes are aligned with their assigned concepts. In detail, (a) we check the most activated images on these axes; (b) we look at how images are distributed in a 2D slice of the latent space; (c) we show how realizations of the same concept change if we apply CW on different layers. All of these experiments were done on a ResNet18 equipped with CW trained on Places365.
We sort all validation samples by their activation values (discussed in Section 3.4) to show how much they are related to the concept. Figure 2 shows the images that have the top-10 largest activations along three different concepts’ axes. Note that all these concepts are trained together using one CW module.
From Figure 2(b) we can see that all the top activated images have the same semantic meaning if the CW module is located at a higher layer (namely, the 16th layer). Figure 2(a) shows that when the CW module is applied to a lower layer (namely, the 2nd layer), it tends to capture low-level information such as color or texture characteristic of these concepts. For instance, the top activated images on the “airplane” axis generally have a blue background with a white or gray object in the middle, which is also typical of real airplane images. It is reasonable that the lower-layer CW module cannot extract complete information about high-level concepts such as “airplane,” since the model complexity of the first two layers is limited.
In that sense, the CW layer has discovered an abstraction of a more complex concept; namely, it has discovered that blue images with white objects are primitive representations of the “airplane” concept. Similarly, the network seems to have discovered that warm colors are a lower-level abstraction of the “bedroom” concept, and that a dark background with vertical light is an abstraction of the “person” concept.
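The sorting step can be sketched as follows, covering the four activation choices (a) to (d) discussed in Section 3.4. This is a hypothetical NumPy illustration, not the authors' code. The function names, the 2×2 non-overlapping pooling window used for option (d), and the random stand-in feature maps are all assumptions, since the text does not specify the pooling size.

```python
import numpy as np

def max_pool2d(fmap, k=2):
    """Non-overlapping k x k max pooling of an (h, w) map (ragged edges cropped)."""
    h, w = fmap.shape
    h2, w2 = h // k, w // k
    blocks = fmap[:h2 * k, :w2 * k].reshape(h2, k, w2, k)
    return blocks.max(axis=(1, 3))

def concept_activation(fmap, mode="pool_mean", k=2):
    """Scalar activation of one CW feature map under options (a)-(d)."""
    if mode == "mean":       # (a) mean of all values
        return float(fmap.mean())
    if mode == "max":        # (b) max of all values
        return float(fmap.max())
    if mode == "pos_mean":   # (c) mean of positive values (0.0 if none, an assumption)
        pos = fmap[fmap > 0]
        return float(pos.mean()) if pos.size else 0.0
    if mode == "pool_mean":  # (d) mean of the max-pooled map
        return float(max_pool2d(fmap, k).mean())
    raise ValueError(f"unknown mode: {mode}")

def top_activated(feature_maps, axis, n_top=10, mode="pool_mean"):
    """feature_maps: CW outputs of shape (N, d, h, w). Returns the indices of
    the n_top images with the largest activation on `axis`, plus all scores."""
    scores = np.array([concept_activation(f[axis], mode) for f in feature_maps])
    return np.argsort(-scores)[:n_top], scores

rng = np.random.default_rng(0)
feats = rng.standard_normal((50, 4, 8, 8))  # stand-in CW outputs: (N, d, h, w)
feats[7, 0] += 5.0                          # make image 7 strongly activated on axis 0
top, scores = top_activated(feats, axis=0, n_top=10)
print(top[:3])  # image 7 comes first
```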
Let us consider whether the latent space of CW captures joint information about different concepts. To investigate how the data are distributed in the new latent space, we pick a 2D slice of the latent space; that is, we select two axes and look at the subspace they form.
The data’s joint distribution on the two axes is shown in Figure 3. To visualize it, we first compute the activations of all validation data on the two axes, then divide the latent space into a grid of blocks whose extent on each axis runs from the minimum to the maximum activation value. For the grid shown in Figure 3(a), we randomly select one image that falls into each block and display it in its corresponding block; if no image falls into a block, the block remains black. From Figure 3(a), we observe that the axes are not only aligned with their assigned concepts but also incorporate joint information. For example, a “person in bed” has high activation on both the “person” axis and the “bed” axis.
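A minimal sketch of the grid construction in NumPy, assuming `act_x` and `act_y` hold the validation activations on the two chosen axes; the function `grid_sample` and the bin count are illustrative choices, not the authors' code. The returned `counts` array is the 2D histogram of images per block, and `picks` holds one randomly chosen image index per block (-1 for empty, i.e. black, blocks).

```python
import numpy as np

def grid_sample(act_x, act_y, n_bins=10, seed=0):
    """Bin samples on a 2D slice of the latent space and pick one per block.

    Returns (counts, picks): counts[i, j] is the number of samples in block
    (i, j); picks[i, j] is a random sample index from that block, or -1 if
    the block is empty.
    """
    rng = np.random.default_rng(seed)
    # The grid spans from the minimum to the maximum activation on each axis
    x_edges = np.linspace(act_x.min(), act_x.max(), n_bins + 1)
    y_edges = np.linspace(act_y.min(), act_y.max(), n_bins + 1)
    # Digitizing against interior edges yields block ids 0..n_bins-1 (max lands in the last block)
    xi = np.digitize(act_x, x_edges[1:-1])
    yi = np.digitize(act_y, y_edges[1:-1])
    counts = np.zeros((n_bins, n_bins), dtype=int)
    np.add.at(counts, (xi, yi), 1)
    picks = np.full((n_bins, n_bins), -1, dtype=int)
    for i in range(n_bins):
        for j in range(n_bins):
            members = np.flatnonzero((xi == i) & (yi == j))
            if members.size:
                picks[i, j] = rng.choice(members)
    return counts, picks

# Usage on toy activations (stand-ins for real validation activations)
rng = np.random.default_rng(1)
ax, ay = rng.normal(size=1000), rng.normal(size=1000)
counts, picks = grid_sample(ax, ay, n_bins=8)
```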
We also include a 2D histogram of the number of images that fall into each block. As shown in Figure 3(b), most images are distributed near the center, which agrees with Theorem 3.2: a sample’s feature vector is, with high probability, nearly orthogonal to the axis we pick, and consequently it has near-zero activation on that axis. Note that the 2D histogram does not contradict Theorem 3.1, since samples that have low activations on these two axes might have high activations on the unplotted axes. Therefore, their distances to the origin can still be similar to those of the highly activated samples.
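The near-orthogonality behind this observation can be illustrated numerically. The sketch below is a generic concentration-of-measure demonstration, not a proof of Theorem 3.2, and it assumes feature directions behave like uniformly random unit vectors: as the dimension grows, the average absolute projection onto one fixed axis shrinks, even though every vector keeps unit norm (consistent with the remark about Theorem 3.1).

```python
import numpy as np

def mean_abs_projection(dim, n_samples=2000, seed=0):
    """Average |<x, e_1>| over random unit vectors x in R^dim."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, dim))
    x /= np.linalg.norm(x, axis=1, keepdims=True)  # every vector now has unit norm
    return np.abs(x[:, 0]).mean()  # "activation" on one fixed axis

for d in (2, 16, 512):
    print(d, round(mean_abs_projection(d), 3))
```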
Although our objective is the same when we apply the CW module to different layers of the same CNN, the latent spaces we obtain might differ, because different layers might carry different levels of semantic meaning. It is therefore interesting to track how the representation of a single image changes as the CW module is applied to different layers of the CNN.
To better understand the latent representation, we plot a 2D slice of the latent space. Unlike in Section 4.2.2, here a point in the plot is specified not by the activation values themselves but by their percentile ranks. For example, a point at (p, q) means that the image is at the p-th percentile on the first axis and at the q-th percentile on the second axis. We use percentile ranks instead of raw values because, as mentioned in Section 4.2.2, most points lie near the center of the plot; ranking spreads the values out for plotting purposes.
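The percentile-rank coordinates can be computed with an argsort. This is a sketch: `percentile_ranks` is a hypothetical helper, ties are broken by order of appearance, and at least two samples are assumed.

```python
import numpy as np

def percentile_ranks(values):
    """Map each value to its rank among all samples, scaled to [0, 1].

    0 is the lowest activation and 1 the highest.
    """
    values = np.asarray(values)
    ranks = np.empty(len(values), dtype=float)
    # Position k in sorted order receives rank k
    ranks[np.argsort(values, kind="stable")] = np.arange(len(values))
    return ranks / (len(values) - 1)

acts = np.array([0.02, -0.01, 0.00, 0.35, 0.01])
print(percentile_ranks(acts))
```

In the toy example, four of the five raw values lie within 0.02 of zero, yet their ranks are spread evenly over [0, 1].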
Figure 4 shows the 2D representation plot of two representative images. Each point in the plot corresponds to the percentile rank representation of the image when the CW module is applied to a different layer. The points are connected by arrows according to the depth of the layer. These plots confirm that the abstract concepts learned in the lower layers tend to capture lower-level meaning (such as colors or shapes), while the higher layers capture high-level meaning (such as types of objects). For example, in the left image in Figure 4(a), the bed is blue, and blue is typical low-level information about the “airplane” class but not about the “bed” class, since bedrooms usually have warm colors. Therefore, in lower layers, the bed image has a higher ranking on the “airplane” axis than on the “bed” axis. However, when CW is applied to deeper layers, high-level information is available, and thus the image becomes highly ranked on the “bed” axis and lower on the “airplane” axis.
Figure 4(b) traces an image of a sunset through the network’s layers. The sunset lacks the typical blue coloring of a sky; its warm colors put it high on the “bedroom” concept and low on the “airplane” concept at the second layer. However, at higher layers, where the network can represent more sophisticated concepts, the image’s rank grows on the “airplane” concept (perhaps the network uses the presence of skies to detect airplanes) and decreases on the “bed” concept. From there, as we move to deeper layers, its rank on the “airplane” concept decreases slightly (perhaps because there is no airplane in the image), and its rank on the “bed” concept increases slightly.
In this subsection, we measure the interpretability of the latent space quantitatively. To define interpretability with respect to concepts, we measure the purity of the concepts learned with CW and compare it with that of other concept-based methods. Purity is measured by the AUC calculated from the activation values. Specifically, we choose 10 concepts to learn at the same time. Each concept dataset is divided into a training set and a testing set. After training the CW module on the training set, we obtain the testing samples’ activation values on the 10 concept axes. For each concept axis, we assign label 1 to samples of that concept and label 0 to all other samples. In this way, we can calculate the AUC score of the latent space with respect to each concept. The AUC score measures whether the samples belonging to a concept are ranked higher than the other samples; thus, it indicates the purity of the concept axis.
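The purity computation can be sketched as follows, assuming the test activations are collected in an (n_samples, n_concepts) array and each test sample carries the integer id of its concept. AUC is computed here through the Mann-Whitney formulation (the fraction of positive-negative pairs ranked correctly, with ties counting one half), which equals the area under the ROC curve; the helper names are hypothetical.

```python
import numpy as np

def auc_score(scores, labels):
    """ROC AUC via the Mann-Whitney U statistic (ties count as 1/2).

    scores: activation of each test sample on one concept axis.
    labels: 1 if the sample belongs to that concept, else 0.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    pos, neg = scores[labels], scores[~labels]
    # Compare every positive sample with every negative sample
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

def concept_purity(activations, concept_ids):
    """AUC of each concept axis k, treating samples of concept k as positives."""
    activations = np.asarray(activations)
    concept_ids = np.asarray(concept_ids)
    return np.array([auc_score(activations[:, k], concept_ids == k)
                     for k in range(activations.shape[1])])

# Toy example: 4 test samples, 2 concept axes
acts = [[0.9, 0.1], [0.3, 0.7], [0.2, 0.8], [0.4, 0.6]]
print(concept_purity(acts, [0, 0, 1, 1]))
```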
We compare the concept purity measured by AUC with that of the concept vectors learned by TCAV (kim2017interpretability), IBD (zhou2018interpretable), and filters in standard CNNs (zhou2014object). For TCAV and IBD, since these methods already find concept vectors, we use the samples’ projections onto the vectors to measure the AUC score. For filters in standard CNNs, we measure the AUC score of every filter and, separately for each concept, choose the best one to compare with our method (denoted “Best Filter”). Figure 5 compares these methods across different layers. The concepts learned in the CW module are generally purer than those of the other methods. This results from CW’s whitening of the latent space and optimization of the loss function, as illustrated in Sections 3.1 and 3.2.
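The two baselines can be scored the same way. In this hedged sketch, a TCAV or IBD concept vector is assumed to be given as a direction in the latent space, and samples are scored by their projection onto its unit vector; for "Best Filter", every filter's scalar activation is scored and the highest AUC is kept for the concept at hand. The helper names are hypothetical.

```python
import numpy as np

def auc_score(scores, labels):
    """ROC AUC as the fraction of correctly ordered positive-negative pairs."""
    s = np.asarray(scores, dtype=float)
    y = np.asarray(labels).astype(bool)
    diff = s[y][:, None] - s[~y][None, :]  # all positive-minus-negative differences
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

def projection_scores(features, concept_vector):
    """Score samples by projecting features (n, d) onto a unit concept direction."""
    v = np.asarray(concept_vector, dtype=float)
    return np.asarray(features, dtype=float) @ (v / np.linalg.norm(v))

def best_filter_auc(filter_acts, labels):
    """Evaluate each filter's activation as a concept score and keep the best.

    filter_acts: (n_samples, n_filters). Returns (best_filter_index, best_auc).
    """
    filter_acts = np.asarray(filter_acts, dtype=float)
    aucs = np.array([auc_score(filter_acts[:, j], labels)
                     for j in range(filter_acts.shape[1])])
    j = int(np.argmax(aucs))
    return j, float(aucs[j])

labels = [1, 1, 0, 0]
filters = [[0.1, 0.9], [0.2, 0.8], [0.7, 0.1], [0.9, 0.3]]
print(best_filter_auc(filters, labels))  # filter 1 separates the concept perfectly
print(auc_score(projection_scores(filters, [-1.0, 1.0]), labels))
```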
We also compare the correlation between axes of the latent space before and after the CW module is applied. For comparison with posthoc methods like TCAV and IBD, we measure the output of the BN modules in the pretrained model, because the outputs of these layers are mean-centered and normalized, which, as we discussed, are important properties for concept vectors. As the absolute correlation coefficients plotted in Figure 6 show, the axes still have relatively strong correlations after passing through the BN module. (If CW were applied instead of BN, the axes would be decorrelated.) This result reflects why purity of concepts is important: when the axes are pure, the signal of one concept can be concentrated on its own axis, while in standard CNNs, the concept could be distributed throughout the latent space.
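Why BN outputs can remain correlated while whitened outputs cannot is easy to check numerically. This sketch uses per-axis standardization as a stand-in for BN at inference (no affine parameters) and ZCA whitening as a stand-in for the whitening step of CW, applied to toy data; CW additionally rotates the whitened space with an orthogonal matrix, which leaves the identity covariance unchanged.

```python
import numpy as np

def abs_corr(z):
    """Absolute correlation coefficients between latent axes (columns of z)."""
    return np.abs(np.corrcoef(z, rowvar=False))

def batch_norm(z, eps=1e-5):
    # Per-axis centering and scaling: removes means and variances, not correlations
    return (z - z.mean(0)) / np.sqrt(z.var(0) + eps)

def zca_whiten(z, eps=1e-5):
    # Multiply centered data by Sigma^{-1/2}, making the covariance (nearly) the identity
    zc = z - z.mean(0)
    cov = zc.T @ zc / len(zc)
    vals, vecs = np.linalg.eigh(cov)
    w = vecs @ np.diag(1.0 / np.sqrt(vals + eps)) @ vecs.T
    return zc @ w

# Toy latent space: three axes that share a common factor
rng = np.random.default_rng(0)
base = rng.standard_normal((5000, 1))
z = np.hstack([base + 0.3 * rng.standard_normal((5000, 1)) for _ in range(3)])
off = ~np.eye(3, dtype=bool)  # mask selecting off-diagonal entries
print(abs_corr(batch_norm(z))[off].min())  # still strongly correlated
print(abs_corr(zca_whiten(z))[off].max())  # close to zero
```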
Concept whitening is a module placed at the bottleneck of a CNN to force the latent space to be disentangled and to align its axes with predefined concepts. By building an inherently interpretable CNN with concept whitening, we can gain intuition about how the network gradually learns the target concepts over the layers without harming performance on the main objective.
There are many avenues for future work. Since CW modules help humans define primitive abstract concepts, such as those we have seen the network use at early layers, it would be interesting to automatically detect and quantify these new concepts, in the spirit of ghorbani2019towards. Also, CW’s requirement that the outputs of all filters be completely decorrelated might be too strong for some tasks, because in practice some concepts, such as “airplane” and “sky”, may be highly correlated. In this case, we may want to soften our definition of CW: we could define several general topics that are uncorrelated with one another, and use multiple correlated filters to represent concepts within each general topic. In this scenario, instead of forcing the Gram matrix to be the identity matrix, we could make it block diagonal, and the orthogonal basis would become a set of orthogonal subspaces.
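One way the block-diagonal relaxation could look is sketched below. This is a hypothetical construction, not part of CW: the data are fully whitened and then multiplied by the symmetric square root of a chosen block-diagonal target, so that axes in different topics stay uncorrelated while axes within a topic keep a prescribed correlation. The target matrix and helper names are illustrative assumptions.

```python
import numpy as np

def inv_sqrt(m, eps=1e-5):
    """Symmetric inverse square root of a covariance matrix."""
    vals, vecs = np.linalg.eigh(m)
    return vecs @ np.diag(1.0 / np.sqrt(vals + eps)) @ vecs.T

def sqrt_psd(m):
    """Symmetric square root of a positive semidefinite matrix."""
    vals, vecs = np.linalg.eigh(m)
    return vecs @ np.diag(np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def block_whiten(z, target):
    """Hypothetical relaxation: whiten fully, then impose a block-diagonal covariance."""
    zc = z - z.mean(0)
    cov = zc.T @ zc / len(zc)
    return zc @ inv_sqrt(cov) @ sqrt_psd(target)

# Two hypothetical "topics", each represented by two correlated axes
target = np.zeros((4, 4))
target[:2, :2] = [[1.0, 0.8], [0.8, 1.0]]
target[2:, 2:] = [[1.0, 0.5], [0.5, 1.0]]

rng = np.random.default_rng(0)
mix = np.eye(4) + 0.5 * np.eye(4, k=1)  # toy correlated latent features
z = rng.standard_normal((4000, 4)) @ mix
out = block_whiten(z, target)
print(np.round(np.cov(out, rowvar=False, bias=True), 3))  # approximately `target`
```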
The output of a single filter is a feature map. However, a scalar is needed to quantify how strongly a sample activates a concept, and this scalar is used in both optimization and evaluation. Given a feature map, there are multiple reasonable ways to calculate the concept activation.
Specifically, we try the following calculations to produce an activation value:
Mean of all feature map values
Max of all feature map values
Mean of all positive feature map values
Mean of down-sampled feature map obtained by max pooling.
Figure 7 illustrates these four ways of calculating the activation. Among them, the mean of all values is more suitable for capturing low-level concepts, since such concepts are distributed throughout the feature map. For high-level concepts, the max value and the mean of positive values are more powerful: they can capture concepts such as objects, which usually occur in just one location rather than repeatedly throughout an image. The mean of max-pooled values combines the previous types and is capable of representing both high-level and low-level concepts. Intuitively, the mean of max-pooled values behaves more like the max function when applied to higher layers and more like the mean function when applied to lower layers. For higher layers, whose feature maps are spatially smaller, the mean is taken over only a few pooled values, so the max is the dominant calculation. In contrast, for lower layers, whose feature maps are much larger, each max is taken over a relatively small local region, and the mean is then taken over many such regions; hence the mean is the dominant calculation.
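The four activation definitions can be written as short NumPy functions over a single 2D feature map. This is a sketch: the non-overlapping pooling window, the cropping of maps that do not divide evenly, and returning 0 when no value is positive are assumptions rather than details from the text. Varying the pool size relative to the map size mimics the layer-size argument above: one window covering the whole map reproduces the max, and 1 x 1 windows reproduce the mean.

```python
import numpy as np

def activation_mean(fmap):
    return fmap.mean()

def activation_max(fmap):
    return fmap.max()

def activation_positive_mean(fmap):
    pos = fmap[fmap > 0]
    return pos.mean() if pos.size else 0.0  # assumption: 0 when nothing is positive

def activation_maxpool_mean(fmap, pool=2):
    """Max-pool with non-overlapping pool x pool windows, then average."""
    h, w = fmap.shape
    h2, w2 = h // pool, w // pool  # assumption: crop edges that do not fill a window
    blocks = fmap[:h2 * pool, :w2 * pool].reshape(h2, pool, w2, pool)
    return blocks.max(axis=(1, 3)).mean()

# A single strong, localized response (like an object) plus one negative value
fmap = np.array([[0, 0, 0, 0],
                 [0, 8, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 0, -4]], dtype=float)
print(activation_mean(fmap), activation_max(fmap),
      activation_positive_mean(fmap), activation_maxpool_mean(fmap, pool=2))
```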
Figure 8 shows the top-10 activated images under the four different calculations of concept activation. The CNN architecture, dataset, and depth of the CW module are the same as before. The figures show that when concept activation is calculated in different ways, the most activated images may look different, and the network may even discover completely different abstract concepts. For example, at this layer, the network discovered the abstraction of the concept “bed” to be warm colors when the activation was the mean of feature map values, whereas the abstract concept seems to involve color boundaries when the activation is the max value. Also, when the activation is the mean of all values, the “person” concept gives rise to an abstract concept involving dense texture, while under the mean of max-pooled values, the “person” concept is abstracted to a dark background with vertical lights. This difference in the discovered abstract concepts could be explained by the fact that these calculation methods focus on different locations within the image: the mean considers the whole image, while the max looks at only one place within it.
Table 2 shows concept AUC when different definitions of concept activation are used. The definition and calculation of concept AUC are the same as in Section 4.3, as are the dataset and CNN architecture. To compare how well these concept activations capture high-level and low-level concepts, we apply CW to a lower and a higher layer of ResNet18. Table 2 indicates that in the lower layer, the max value of the feature map yields lower AUC than the other calculation methods, while in the higher layer, the mean performs worse than the other methods. The max-pool-mean method performs well in both layers. This result matches our intuition that max-pool-mean combines the advantages of mean and max, making it suitable for capturing both low-level and high-level concepts.
Activation | AUC-“airplane”, lower layer | AUC-“airplane”, higher layer | AUC-“bed”, lower layer | AUC-“bed”, higher layer | AUC-“person”, lower layer | AUC-“person”, higher layer |
---|---|---|---|---|---|---|
Mean | 0.820 | 0.981 | 0.687 | 0.853 | 0.714 | 0.918 |
Max | 0.716 | 0.992 | 0.589 | 0.904 | 0.759 | 0.969 |
Positive-mean | 0.798 | 0.992 | 0.614 | 0.924 | 0.757 | 0.968 |
Max-pool-mean | 0.818 | 0.993 | 0.692 | 0.906 | 0.757 | 0.966 |
As mentioned in Section 4.1, we measure the main-objective accuracy when CW is applied to different layers. Tables 3 through 6 show the layer-wise test accuracy of different CNN architectures. The dataset and CNN architectures are the same as in Section 4.1. The results in Tables 3 through 6 indicate that no matter which layer CW is applied to, accuracy is not substantially impacted.
CW layer | Top-1 acc. | Top-5 acc. |
---|---|---|
 | 53.2 | 83.8 |
 | 53.3 | 83.8 |
 | 53.4 | 83.8 |
 | 53.4 | 83.9 |
 | 53.2 | 83.9 |
 | 53.3 | 83.8 |
 | 53.5 | 83.8 |
 | 53.3 | 83.9 |
 | 53.4 | 83.8 |
 | 53.2 | 83.8 |
 | 53.2 | 83.9 |
 | 53.3 | 83.7 |
CW layer | Top-1 acc. | Top-5 acc. |
---|---|---|
 | 55.2 | 85.4 |
 | 55.3 | 85.5 |
 | 55.3 | 85.5 |
 | 55.2 | 85.5 |
 | 55.3 | 85.5 |
 | 54.8 | 85.2 |
 | 54.7 | 85.0 |
 | 54.8 | 85.0 |
 | 54.7 | 85.0 |
 | 54.8 | 85.0 |
 | 54.8 | 85.1 |
 | 54.7 | 85.0 |
 | 54.8 | 85.1 |
 | 54.6 | 85.0 |
 | 54.7 | 84.9 |
 | 54.6 | 85.0 |
CW layer | Top-1 acc. | Top-5 acc. |
---|---|---|
 | 55.6 | 85.7 |
 | 55.5 | 85.5 |
 | 55.5 | 85.6 |
 | 55.5 | 85.6 |
CW layer | Top-1 acc. | Top-5 acc. |
---|---|---|
 | 53.9 | 84.2 |
 | 54.0 | 84.5 |
 | 54.0 | 84.3 |
 | 54.0 | 84.2 |
 | 54.0 | 84.3 |
 | 53.9 | 84.1 |
 | 53.7 | 83.9 |
 | 53.5 | 83.8 |
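The claim that accuracy is not substantially affected can be quantified directly from the tables. The snippet below copies the reported accuracies (in the order the rows appear, since the layer indices are not shown) and computes, for each architecture, the spread between the best and worst CW placement; it only redoes arithmetic on the numbers above, and the helper name `spread` is hypothetical.

```python
# Top-1 and top-5 test accuracies (%) copied row by row from the four
# layer-wise tables above, in the order the tables appear.
TOP1 = [
    [53.2, 53.3, 53.4, 53.4, 53.2, 53.3, 53.5, 53.3, 53.4, 53.2, 53.2, 53.3],
    [55.2, 55.3, 55.3, 55.2, 55.3, 54.8, 54.7, 54.8,
     54.7, 54.8, 54.8, 54.7, 54.8, 54.6, 54.7, 54.6],
    [55.6, 55.5, 55.5, 55.5],
    [53.9, 54.0, 54.0, 54.0, 54.0, 53.9, 53.7, 53.5],
]
TOP5 = [
    [83.8, 83.8, 83.8, 83.9, 83.9, 83.8, 83.8, 83.9, 83.8, 83.8, 83.9, 83.7],
    [85.4, 85.5, 85.5, 85.5, 85.5, 85.2, 85.0, 85.0,
     85.0, 85.0, 85.1, 85.0, 85.1, 85.0, 84.9, 85.0],
    [85.7, 85.5, 85.6, 85.6],
    [84.2, 84.5, 84.3, 84.2, 84.3, 84.1, 83.9, 83.8],
]

def spread(values):
    """Best minus worst accuracy across CW placements, in percentage points."""
    return round(max(values) - min(values), 1)

for i, (t1, t5) in enumerate(zip(TOP1, TOP5), start=1):
    print(f"table {i}: {len(t1)} CW placements, "
          f"top-1 spread {spread(t1)} pp, top-5 spread {spread(t5)} pp")
```

For these values, the largest spread is 0.7 percentage points in both top-1 and top-5 accuracy.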