Interpretable Disentanglement of Neural Networks by Extracting Class-Specific Subnetwork

10/07/2019 ∙ by Yulong Wang, et al.

We propose a novel perspective for understanding deep neural networks in an interpretable disentanglement form. For each semantic class, we extract a class-specific functional subnetwork from the original full model; the subnetwork has a compressed structure while maintaining comparable prediction performance. The structural representations of the extracted subnetworks reflect the semantic similarities between their corresponding classes. We also apply the extracted subnetworks to visual explanation and adversarial example detection tasks, simply by replacing the original full model with the class-specific subnetworks. Experiments demonstrate that this intuitive operation effectively improves explanation saliency accuracy for gradient-based explanation methods and increases the detection rate of confidence score-based adversarial example detection methods.




1 Introduction

Deep neural networks (DNNs) have recently transformed many areas, including visual perception, language understanding, and reinforcement learning. Although they have become the most representative intelligent systems, with dominant performance, DNNs are criticized for lacking transparency and interpretability. A better understanding of the working mechanisms of machine learning systems has become a pressing need, one that benefits not only academic research but also the many critical industries with strict safety requirements.

Figure 1: Method overview. For each class, we extract a subnetwork from the full model by learning to activate only a fraction of the neurons in each layer. The extracted class-specific subnetwork focuses on a single class prediction while maintaining performance comparable to the full model.

In this paper, we propose a simple and interpretable disentanglement form for deep neural networks, which not only reveals the functional behavior of neural networks but also improves applications such as visual explanation [8] and adversarial example detection [1]. The main idea is to extract a class-specific subnetwork for each semantic category from a pre-trained full model while maintaining comparable prediction performance (Figure 1). To extract the subnetworks effectively, we use a knowledge distillation criterion [2] and a model pruning strategy [4]. We observe that the architectural similarities among the highly compressed subnetworks mirror the semantic similarities among their corresponding categories.

Furthermore, the interpretable subnetworks extracted by the proposed method can also benefit other tasks. One application is visual explanation, which identifies the salient regions or features relevant to the model's prediction. We propose a simple improvement for gradient-based visual explanation methods: replace the full model weights with the subnetwork for the requested explanation class. This leads to more accurate and concise salient regions. A similar technique applies to adversarial sample detection, which identifies malicious samples crafted to fool DNN classifiers. We propose to use the features generated by the class-specific subnetworks to construct a confidence score-based detector, since we observe that these features separate adversarial data from true data better than those generated by the full model.

2 Method

Let $f_\theta$ denote a neural network with pre-trained weights $\theta$. For one test sample $\mathbf{x}$, the output vector $f_\theta(\mathbf{x})$ is passed to the softmax activation function, which generates the prediction probability for each class. The output for class $c$, usually termed the logit, is denoted $f_\theta^c(\mathbf{x})$, and $\mathcal{X} = \{\mathbf{x}_i\}_{i=1}^{N}$ denotes a collection of samples.

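To extract class-specific subnetworks, we formulate the problem under the knowledge distillation framework [2]. Specifically, let $p(y \mid \mathbf{x}; \theta)$ denote the original prediction made by the full model parametrized by $\theta$ for a single sample $\mathbf{x}$. We want to extract a subnetwork with parameters $\theta_c$ for class $c$ whose prediction $p(y \mid \mathbf{x}; \theta_c)$ is close to that of the full model as measured by the KL divergence. The objective is therefore

$$\min_{\theta_c} \; \mathbb{E}_{\mathbf{x}} \Big[ D_{\mathrm{KL}}\big( p(y \mid \mathbf{x}; \theta) \,\big\|\, p(y \mid \mathbf{x}; \theta_c) \big) \Big] + \lambda R(\theta_c), \qquad (1)$$

where $R(\theta_c)$ is an extra regularization term that encourages $\theta_c$ to be sufficiently sparse, and $\lambda$ balances the two terms. We adopt the $\ell_1$-norm as the regularization term.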
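To make the distillation objective above concrete, here is a minimal pure-Python sketch that evaluates it on a small batch. It assumes logits are given as plain lists and that the subnetwork is summarized by a flat list of gate values (a hypothetical stand-in for $\theta_c$), whose $\ell_1$-norm plays the role of $R(\theta_c)$; it evaluates the loss only and does not optimize it.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_objective(full_logits, sub_logits, gates, lam):
    """Batch-averaged KL(full || sub) plus lam times the L1 norm of the gates.

    full_logits / sub_logits: per-sample logit vectors from the full model and
    the candidate subnetwork. gates: flat list of gate values standing in for
    the subnetwork parameters (a hypothetical representation).
    """
    n = len(full_logits)
    kl = sum(kl_divergence(softmax(f), softmax(s))
             for f, s in zip(full_logits, sub_logits)) / n
    return kl + lam * sum(abs(g) for g in gates)


# A subnetwork that reproduces the full model pays only the sparsity penalty.
print(distillation_objective([[1.0, 2.0, 3.0]], [[1.0, 2.0, 3.0]], [0.5, -0.5], 0.1))
```

When the subnetwork's predictions match the full model exactly, the KL term vanishes and only the sparsity penalty remains, which is what pushes the gates toward zero during training.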

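We observe that the original full model already performs well on single-class binary classification. Therefore, the probability that $\mathbf{x}$ belongs to class $c$ can be represented by transforming the output logit into $p(y = c \mid \mathbf{x}; \theta) = \sigma\big(f_\theta^c(\mathbf{x})\big)$, where $\sigma$ is the sigmoid function. Restricting Equation (1) to this binary prediction, and noting that the KL divergence between two Bernoulli distributions equals their binary cross entropy minus the entropy of the first distribution (a constant with respect to $\theta_c$), we can rewrite the objective as

$$\min_{\theta_c} \; \mathbb{E}_{\mathbf{x}} \Big[ \mathrm{BCE}\big( \sigma(f_\theta^c(\mathbf{x})),\, \sigma(f_{\theta_c}^c(\mathbf{x})) \big) \Big] + \lambda R(\theta_c), \qquad (2)$$

where BCE stands for the binary cross entropy function $\mathrm{BCE}(p, q) = -p \log q - (1 - p) \log (1 - q)$. Using a Monte Carlo approximation over the samples in $\mathcal{X}$, we obtain the final objective for learning the subnetwork $\theta_c$:

$$\min_{\theta_c} \; \frac{1}{N} \sum_{i=1}^{N} \mathrm{BCE}\big( \sigma(f_\theta^c(\mathbf{x}_i)),\, \sigma(f_{\theta_c}^c(\mathbf{x}_i)) \big) + \lambda R(\theta_c). \qquad (3)$$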
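The step from the KL objective to the BCE objective, and the Monte Carlo estimate, can be checked numerically with the sketch below. It is illustrative only: `mc_subnetwork_loss` is a hypothetical helper that takes the class-$c$ logits of the full model and the subnetwork for a batch, plus a flat list of gate values used as the $\ell_1$ regularizer.

```python
import math


def sigmoid(z):
    """Numerically stable logistic sigmoid."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce(p, q):
    """Binary cross entropy BCE(p, q) = -p log q - (1 - p) log(1 - q), for 0 < q < 1."""
    return -p * math.log(q) - (1.0 - p) * math.log(1.0 - q)


def bernoulli_kl(p, q):
    """KL(Bern(p) || Bern(q)) for 0 < p, q < 1."""
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))


def bernoulli_entropy(p):
    return -p * math.log(p) - (1.0 - p) * math.log(1.0 - p)


def mc_subnetwork_loss(full_logits_c, sub_logits_c, gates, lam):
    """Monte Carlo objective: mean BCE between the full model's sigmoid
    probability for class c (soft target) and the subnetwork's, plus
    lam times the L1 norm of the (hypothetical) gate values."""
    n = len(full_logits_c)
    data_term = sum(bce(sigmoid(f), sigmoid(s))
                    for f, s in zip(full_logits_c, sub_logits_c)) / n
    return data_term + lam * sum(abs(g) for g in gates)


# KL = BCE - H(p), and H(p) does not depend on the subnetwork.
p, q = sigmoid(2.0), sigmoid(0.5)
print(abs(bernoulli_kl(p, q) - (bce(p, q) - bernoulli_entropy(p))) < 1e-12)  # prints True
```

Because the entropy term depends only on the full model's output, minimizing BCE and minimizing the Bernoulli KL give the same optimal subnetwork.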


As for the parametrization of $\theta_c$, we associate control gates with the output channels of multiple layers in the network. The control gate values then modulate the output features of the $l$-th layer through channel-wise multiplication.
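The channel-wise gating can be sketched as follows. This is an illustration under stated assumptions: a layer's output is a nested Python list (channels × height × width), and each output channel has exactly one scalar gate; `apply_channel_gates` is a hypothetical name.

```python
def apply_channel_gates(feature_map, gates):
    """Multiply every activation in channel k of one layer's output by gates[k].

    feature_map: list of channels, each an H x W nested list of activations.
    gates: one scalar per channel (hypothetical parametrization; with all
    gates equal to 1 the layer output is unchanged, i.e. the full model).
    """
    if len(feature_map) != len(gates):
        raise ValueError("expected exactly one gate per output channel")
    return [[[g * v for v in row] for row in channel]
            for channel, g in zip(feature_map, gates)]


layer_output = [[[1.0, 2.0], [3.0, 4.0]],   # channel 0
                [[5.0, 6.0], [7.0, 8.0]]]   # channel 1
print(apply_channel_gates(layer_output, [1.0, 0.0]))
# prints [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]]
```

A gate of zero switches a channel off entirely, so a sparse gate vector directly describes which channels the class-specific subnetwork keeps.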

3 Adversarial Sample Detection

One family of adversarial sample detection methods is based on discriminating confidence scores. By evaluating a confidence score with density estimators fitted to the training data in feature space, one can judge whether a sample lies on the true data manifold with high probability.

Figure 2: Class-specific subnetworks can improve the discriminability between true data and adversarial data in feature space. Five hundred true images and their corresponding adversarial samples are displayed using a UMAP embedding projection of VGG16 penultimate-layer features. Different predicted labels are marked with different colors. (a) Feature embeddings generated by the original full model weights. (b) Feature embeddings produced by the class-specific subnetworks.

With class-specific subnetworks, we can further increase the discriminability of confidence score methods. Figure 2(a) displays clustering patterns, but the adversarial samples overlap heavily with the true data manifold. Figure 2(b) demonstrates that adversarial samples become more separable when the features are generated by class-specific subnetworks.

Figure 3: Subnetwork representation visualization by 2D UMAP embedding projection. Each panel shows 1,000 subnetworks, grouped into 25 clusters using a hierarchical clustering algorithm. Some apparent groups are annotated with the label names of their members. These subnetwork representations reflect the semantic similarity between their classes.

Figure 4: Visual explanation saliency using class-specific model subnetworks. Each column shows that, for a gradient-based explanation method, the visual saliency can be improved simply by replacing the full model weights with the subnetwork for the requested explanation class.
Method          | AlexNet Normal | AlexNet Subnet | VGG16 Normal | VGG16 Subnet | ResNet50 Normal | ResNet50 Subnet
Saliency        | 3.5 / 46.11    | 3.5 / 44.80    | 5.0 / 48.01  | 5.0 / 44.21  | 5.0 / 49.20     | 4.0 / 41.64
Deconv [8]      | 4.5 / 49.61    | 4.5 / 47.00    | 5.0 / 49.87  | 5.0 / 46.85  | 3.5 / 47.26     | 4.0 / 50.32
Guided-BP [10]  | 5.0 / 43.05    | 4.5 / 42.20    | 7.0 / 41.66  | 6.0 / 40.83  | 6.5 / 47.52     | 4.0 / 40.25
GradCAM [7]     | 1.0 / 54.13    | 1.0 / 51.48    | 1.0 / 49.21  | 1.0 / 46.94  | 1.0 / 42.63     | 1.0 / 45.14
IntGrad [11]    | 7.5 / 43.01    | 6.0 / 42.71    | 9.0 / 43.05  | 6.0 / 41.05  | 8.5 / 47.60     | 4.0 / 41.85
SmoothGrad [9]  | 7.0 / 48.72    | 4.0 / 47.01    | 9.5 / 53.57  | 2.0 / 47.75  | 3.0 / 52.37     | 2.5 / 51.98

Table 1: Weakly supervised localization errors on the ImageNet validation set. Each cell gives α / Err(%). “Normal” indicates the standard explanation practice of using the full model weights regardless of the requested explanation class. “Subnet” indicates the proposed practice of replacing them with the class-specific subnetwork. α is optimized on 5,000 held-out images from the ImageNet training set. “Err” indicates the localization error (lower is better).
Dataset   | Method      | Detection AUROC (%)              | Unknown attack detection AUROC (%)
          |             | FGSM  | BIM   | DeepFool | CW    | FGSM (seen) | BIM   | DeepFool | CW
CIFAR-10  | KD+PU       | 81.21 | 82.28 | 81.07    | 55.93 | 81.21       | 16.16 | 76.80    | 56.30
          | LID         | 99.71 | 96.39 | 88.47    | 82.93 | 99.71       | 95.38 | 71.86    | 77.53
          | Mahalanobis | 99.92 | 99.59 | 91.53    | 95.85 | 99.92       | 98.91 | 78.06    | 93.90
          | Ours        | 99.97 | 99.17 | 91.91    | 96.88 | 99.97       | 99.11 | 82.45    | 95.62
CIFAR-100 | KD+PU       | 89.90 | 83.67 | 80.22    | 77.37 | 89.90       | 68.85 | 57.78    | 73.72
          | LID         | 89.27 | 85.19 | 64.80    | 75.35 | 89.27       | 55.82 | 63.15    | 75.03
          | Mahalanobis | 99.77 | 96.72 | 83.93    | 91.65 | 99.77       | 96.38 | 81.95    | 90.96
          | Ours        | 99.81 | 96.95 | 82.44    | 94.41 | 99.81       | 95.84 | 77.80    | 92.56
SVHN      | KD+PU       | 82.67 | 66.19 | 89.71    | 76.57 | 82.67       | 43.21 | 74.26    | 67.85
          | LID         | 95.72 | 87.41 | 88.81    | 85.66 | 95.72       | 84.88 | 67.28    | 76.58
          | Mahalanobis | 99.63 | 97.14 | 95.46    | 92.13 | 99.63       | 95.39 | 72.20    | 86.73
          | Ours        | 99.54 | 97.24 | 95.82    | 93.63 | 99.54       | 96.38 | 78.75    | 91.09

Table 2: Adversarial sample detection AUROC (%) for different methods. Our method improves upon the Mahalanobis distance score by replacing the feature extractor with class-specific subnetworks to estimate the empirical means and covariance matrices. For unknown attack detection, only FGSM samples (denoted “seen”) are used to train the logistic regression detector.

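Detection based on Mahalanobis distance. Here we formally present the improved detection algorithm based on the Mahalanobis distance score proposed in [3]. Let $h(\mathbf{x})$ denote the feature of a sample $\mathbf{x}$ at a given layer of the full model. Following the definition in [3], the Mahalanobis confidence score of a test sample $\mathbf{x}$ is computed from the Mahalanobis distance between $h(\mathbf{x})$ and its closest class-conditional Gaussian distribution in feature space:

$$M(\mathbf{x}) = \max_{c} \; -\big(h(\mathbf{x}) - \hat{\mu}_c\big)^{\top} \hat{\Sigma}^{-1} \big(h(\mathbf{x}) - \hat{\mu}_c\big), \qquad (4)$$

where $\hat{\mu}_c$ and $\hat{\Sigma}$ are the empirical class mean and covariance of the features of the training samples $\{(\mathbf{x}_i, y_i)\}_{i=1}^{N}$:

$$\hat{\mu}_c = \frac{1}{N_c} \sum_{i: y_i = c} h(\mathbf{x}_i), \qquad \hat{\Sigma} = \frac{1}{N} \sum_{c} \sum_{i: y_i = c} \big(h(\mathbf{x}_i) - \hat{\mu}_c\big)\big(h(\mathbf{x}_i) - \hat{\mu}_c\big)^{\top}, \qquad (5)$$

with $N_c$ the number of training samples of class $c$.

With the extracted subnetworks, we modify the empirical mean and covariance estimation by using the class-specific feature $h_{\theta_c}(\mathbf{x})$ in place of $h(\mathbf{x})$. Suppose that $\theta_c$ is the resulting subnetwork for class $c$. Then $\hat{\mu}_c$ and $\hat{\Sigma}$ can be estimated by

$$\hat{\mu}_c = \frac{1}{N_c} \sum_{i: y_i = c} h_{\theta_c}(\mathbf{x}_i), \qquad \hat{\Sigma} = \frac{1}{N} \sum_{c} \sum_{i: y_i = c} \big(h_{\theta_c}(\mathbf{x}_i) - \hat{\mu}_c\big)\big(h_{\theta_c}(\mathbf{x}_i) - \hat{\mu}_c\big)^{\top}. \qquad (6)$$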
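A minimal pure-Python sketch of the subnetwork-based estimation and scoring follows. It assumes the training features have already been extracted with each class's own subnetwork. Pairing each class's Gaussian with the test feature from that same class's subnetwork at test time is our assumption, since the text specifies only how the means and shared covariance are estimated. The hand-written Gauss-Jordan inverse is for illustration; a real implementation would use a linear-algebra library.

```python
def mean_vector(vectors):
    n = len(vectors)
    return [sum(v[j] for v in vectors) / n for j in range(len(vectors[0]))]


def invert(matrix):
    """Gauss-Jordan inverse with partial pivoting (fine for tiny matrices)."""
    n = len(matrix)
    aug = [[float(v) for v in row] + [1.0 if i == j else 0.0 for j in range(n)]
           for i, row in enumerate(matrix)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        if abs(aug[pivot][col]) < 1e-12:
            raise ValueError("covariance matrix is singular")
        aug[col], aug[pivot] = aug[pivot], aug[col]
        pv = aug[col][col]
        aug[col] = [v / pv for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def fit_subnetwork_gaussians(features_by_class):
    """Class means and one shared covariance from class-specific features.

    features_by_class[c]: list of h_{theta_c}(x_i) for the training samples of
    class c, each already computed with the class-c subnetwork.
    """
    means = {c: mean_vector(f) for c, f in features_by_class.items()}
    d = len(next(iter(means.values())))
    total = sum(len(f) for f in features_by_class.values())
    cov = [[0.0] * d for _ in range(d)]
    for c, feats in features_by_class.items():
        for v in feats:
            diff = [a - b for a, b in zip(v, means[c])]
            for i in range(d):
                for j in range(d):
                    cov[i][j] += diff[i] * diff[j] / total
    return means, cov


def mahalanobis_score(test_features, means, cov):
    """max_c -(h_c - mu_c)^T Sigma^{-1} (h_c - mu_c), where test_features[c] is
    the test sample's feature from the class-c subnetwork (assumed pairing)."""
    inv = invert(cov)
    best = float("-inf")
    for c, mu in means.items():
        diff = [a - b for a, b in zip(test_features[c], mu)]
        quad = sum(diff[i] * inv[i][j] * diff[j]
                   for i in range(len(diff)) for j in range(len(diff)))
        best = max(best, -quad)
    return best


train = {0: [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]],
         1: [[11.0, 0.0], [9.0, 0.0], [10.0, 1.0], [10.0, -1.0]]}
means, cov = fit_subnetwork_gaussians(train)
print(cov)                                                             # prints [[0.5, 0.0], [0.0, 0.5]]
print(mahalanobis_score({0: [1.0, 1.0], 1: [20.0, 0.0]}, means, cov))  # prints -4.0
```

A higher (less negative) score means the sample lies closer to some class-conditional Gaussian, i.e. it looks more like true data.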


Similar to [5], features from other, lower-level layers of the network can also be combined to estimate confidence, and a logistic regression detector is trained on held-out validation data to weight the importance of each feature.
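The feature-weighting step can be sketched as a small logistic regression over per-layer confidence scores. This is a sketch under stated assumptions, not the setup of [3]: each sample is a vector of per-layer scores, label 1 marks adversarial and 0 clean, and training uses plain batch gradient descent. In practice the scores would be standardized first; the toy data below is already on a comparable scale.

```python
import math


def _sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_logistic_regression(score_vectors, labels, lr=0.1, epochs=2000):
    """Fit w, b so that sigmoid(w . s + b) estimates P(adversarial | s).

    score_vectors: one list of per-layer confidence scores per sample.
    labels: 1 for adversarial, 0 for clean. Batch gradient descent on mean log-loss.
    """
    d, n = len(score_vectors[0]), len(labels)
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for s, y in zip(score_vectors, labels):
            err = _sigmoid(sum(wi * si for wi, si in zip(w, s)) + b) - y
            for j in range(d):
                grad_w[j] += err * s[j] / n
            grad_b += err / n
        w = [wi - lr * gj for wi, gj in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def predict_proba(w, b, s):
    """Estimated probability that a sample with score vector s is adversarial."""
    return _sigmoid(sum(wi * si for wi, si in zip(w, s)) + b)


# Toy standardized two-layer scores: clean samples score high, adversarial low.
clean = [[1.0, 0.8], [0.9, 1.1], [1.2, 1.0]]
adversarial = [[-1.0, -0.9], [-1.1, -1.2], [-0.8, -1.0]]
w, b = train_logistic_regression(clean + adversarial, [0, 0, 0, 1, 1, 1])
```

The learned weights come out negative here because a low Mahalanobis-style confidence score is evidence of an adversarial input.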

4 Experiments

We extract class-specific subnetworks from three typical ImageNet pre-trained networks: AlexNet, VGG16, and ResNet50. When optimizing the subnetwork for class $c$, a balanced training set is sampled dynamically at each epoch; it contains all 1,000 images of class $c$ and an equal number of images randomly chosen from all the other classes. After a fixed number of training epochs, the final subnetwork is selected as the one with the minimum loss among those below a preset sparsity level, with a fixed balance parameter. The mini-batch size is 64, and the learning rate of the Adam optimizer is 0.1 in all experiments.
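The per-epoch balanced sampling can be sketched as below. We assume that "randomly chosen images from all the other classes" means sampling uniformly without replacement from the pooled images of the other classes; `balanced_epoch` and `minibatches` are hypothetical helper names, and only the mini-batch size of 64 comes from the text.

```python
import random


def balanced_epoch(labels, target_class, rng):
    """Indices for one epoch: every sample of target_class plus an equal number
    drawn uniformly without replacement from all other classes, shuffled."""
    positives = [i for i, y in enumerate(labels) if y == target_class]
    negatives = [i for i, y in enumerate(labels) if y != target_class]
    chosen = positives + rng.sample(negatives, min(len(positives), len(negatives)))
    rng.shuffle(chosen)
    return chosen


def minibatches(indices, batch_size=64):
    """Split an epoch's indices into consecutive mini-batches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


# Toy label list: 10 images of class 0 and 30 each of classes 1 and 2.
labels = [0] * 10 + [1] * 30 + [2] * 30
epoch = balanced_epoch(labels, target_class=0, rng=random.Random(0))
print(len(epoch), sum(labels[i] == 0 for i in epoch))  # prints 20 10
```

Resampling the negatives every epoch lets the subnetwork see many different other-class images over training while each epoch stays balanced.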

Subnetwork Visualization. For each subnetwork, the associated control gates reflect the utilization of each layer when predicting a specific class. Figure 3 displays the relationships between different class-specific subnetworks projected onto a 2D plane using the UMAP algorithm. We observe that the subnetwork representations tend to be more similar when their corresponding labels are semantically closer.
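Figure 3 is produced with UMAP and hierarchical clustering over subnetwork representations. As a much simpler stand-in (not the authors' pipeline), the sketch below assumes a subnetwork's representation is the concatenation of its per-layer gate values and compares subnetworks by cosine similarity. The class names and gate values are made up for illustration.

```python
import math


def subnetwork_representation(gates_per_layer):
    """Concatenate the per-layer gate vectors of one subnetwork into one vector."""
    return [g for layer in gates_per_layer for g in layer]


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(target, reps):
    """Name of the subnetwork whose representation is closest (by cosine) to reps[target]."""
    others = [name for name in reps if name != target]
    return max(others, key=lambda name: cosine_similarity(reps[target], reps[name]))


# Hypothetical gate values for three tiny two-layer subnetworks.
reps = {
    "tabby cat": subnetwork_representation([[1.0, 0.9, 0.0], [0.8, 0.0]]),
    "tiger cat": subnetwork_representation([[0.9, 1.0, 0.1], [0.7, 0.0]]),
    "sports car": subnetwork_representation([[0.0, 0.1, 1.0], [0.0, 0.9]]),
}
print(most_similar("tabby cat", reps))  # prints tiger cat
```

In this toy setup, the two cat subnetworks keep overlapping channels and therefore end up close to each other, which is the kind of pattern Figure 3 visualizes at scale.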

Improving Visual Explanation. Visual explanation methods usually present highlighted salient regions in the input image as explanation results. Most visual explanation methods [8, 10, 11, 9, 7] generate visual saliency by following specific predefined “layerwise attribution backpropagation” rules [6]. Here we propose a simple alternative explanation procedure for these gradient-based explanation methods: use the class-specific subnetwork as the model weights when explaining the requested class. Figure 4 shows that with the extracted class-specific subnetworks, these methods generate clearer and more accurate salient regions that focus on the main objects.
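The explanation procedure only changes which weights are used in the backward pass. The toy sketch below illustrates this with a hypothetical two-layer ReLU network whose hidden units are gated; gradients are derived by hand rather than with an autograd library, and plain gradient-magnitude saliency stands in for the methods cited above.

```python
def input_gradient(W1, W2, x, c, gates=None):
    """d logit_c / d x for logit_c = W2[c] . relu(gates * (W1 x)).

    gates=None means the full model (every gate equal to 1); a class-specific
    subnetwork is represented by its (hypothetical) hidden-unit gate values.
    """
    pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]  # W1 x
    if gates is None:
        gates = [1.0] * len(pre)
    # Signal reaching each hidden unit: output weight * relu'(gated input) * gate.
    delta = [w * (1.0 if g * z > 0 else 0.0) * g
             for w, g, z in zip(W2[c], gates, pre)]
    return [sum(delta[k] * W1[k][j] for k in range(len(delta)))
            for j in range(len(x))]


def gradient_saliency(W1, W2, x, c, gates=None):
    """Vanilla gradient saliency: absolute input gradient of the class-c logit."""
    return [abs(v) for v in input_gradient(W1, W2, x, c, gates)]


W1 = [[1.0, 0.0], [0.0, 1.0]]  # two hidden units, each reading one input pixel
W2 = [[1.0, 1.0]]              # a single class that uses both hidden units
x = [1.0, 1.0]
print(gradient_saliency(W1, W2, x, 0))              # prints [1.0, 1.0]
print(gradient_saliency(W1, W2, x, 0, [1.0, 0.0]))  # prints [1.0, 0.0]
```

Closing a gate removes that unit's contribution to the backward pass, so the saliency concentrates on the inputs the class-specific subnetwork actually relies on.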

Weakly Supervised Object Localization. To evaluate the improvement of visual explanation methods more rigorously, we adopt the Weakly Supervised Object Localization (WSOL) evaluation protocol. Table 1 summarizes the results. The proposed practice reduces localization errors in 16 of the 18 combinations of explanation method and network; the exceptions are Deconv and GradCAM on ResNet50. These results indicate that the proposed practice can help improve gradient-based visual explanation methods.
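The WSOL evaluation can be sketched as below. The thresholding rule (keep pixels above a threshold parameter, called `alpha` here, times the mean saliency, then take their tight bounding box) is our assumption about how the tuned hyperparameter is used. The IoU ≥ 0.5 criterion is the standard ImageNet localization criterion, and a prediction counts as correct if it matches any ground-truth box of the image.

```python
def box_from_saliency(sal, alpha):
    """Tight inclusive box (x0, y0, x1, y1) around pixels whose saliency exceeds
    alpha * mean saliency (assumed rule); None if no pixel survives."""
    flat = [v for row in sal for v in row]
    thr = alpha * sum(flat) / len(flat)
    pts = [(x, y) for y, row in enumerate(sal) for x, v in enumerate(row) if v > thr]
    if not pts:
        return None
    xs = [p[0] for p in pts]
    ys = [p[1] for p in pts]
    return (min(xs), min(ys), max(xs), max(ys))


def box_area(r):
    return (r[2] - r[0] + 1) * (r[3] - r[1] + 1)


def iou(a, b):
    """Intersection over union of two inclusive pixel boxes (x0, y0, x1, y1)."""
    ix = max(0, min(a[2], b[2]) - max(a[0], b[0]) + 1)
    iy = max(0, min(a[3], b[3]) - max(a[1], b[1]) + 1)
    inter = ix * iy
    return inter / (box_area(a) + box_area(b) - inter)


def localization_error(saliency_maps, gt_boxes_per_image, alpha):
    """Percentage of images whose predicted box has IoU < 0.5 with every ground-truth box."""
    misses = 0
    for sal, gt_boxes in zip(saliency_maps, gt_boxes_per_image):
        box = box_from_saliency(sal, alpha)
        if box is None or max(iou(box, g) for g in gt_boxes) < 0.5:
            misses += 1
    return 100.0 * misses / len(saliency_maps)


sal = [[9, 9, 0, 0],
       [9, 9, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]]
print(localization_error([sal, sal], [[(0, 0, 1, 1)], [(2, 2, 3, 3)]], alpha=1.0))  # prints 50.0
```

Tuning `alpha` on held-out images and then reporting the error on the validation set mirrors the separation between α and Err(%) in Table 1.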

Detecting Adversarial Samples. Following the experimental setup of [3], we experiment with four attack methods: FGSM, BIM, DeepFool, and the $\ell_2$ version of the CW attack. We first extract class-specific subnetworks for each dataset. Then the Mahalanobis scores are calculated according to Equation (6). The subsequent logistic regression detector setup is the same as in [3].

We compare against three state-of-the-art logistic regression detectors, based on 1) the combination of kernel density (KD) and predictive uncertainty (PU) [1], 2) local intrinsic dimensionality (LID) scores [5], and 3) Mahalanobis distance scores [3].

The middle columns of Table 2 summarize the detection results. Our method achieves the highest AUROC in 9 of the 12 dataset and attack combinations; the exceptions are BIM on CIFAR-10, DeepFool on CIFAR-100, and FGSM on SVHN, where the Mahalanobis baseline is slightly better. We also train the logistic regression detector on FGSM samples only and evaluate its detection performance on the other types of adversarial samples; the right columns of Table 2 summarize these results. Our method still outperforms the baselines in 7 of the 9 unseen-attack settings, the exceptions being BIM and DeepFool on CIFAR-100. These results support the effectiveness of class-specific subnetworks for detecting adversarial examples.
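Table 2 reports AUROC. As a reference for what the metric measures, here is a minimal sketch of its Mann-Whitney form: the probability that a randomly chosen adversarial sample receives a higher detector output than a randomly chosen clean one, with ties counted as one half. This O(n·m) version is for illustration only; the detector outputs below are made up.

```python
def auroc(pos_scores, neg_scores):
    """AUROC as P(score_pos > score_neg) + 0.5 * P(tie) over all positive/negative pairs."""
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


# Detector outputs (higher = more likely adversarial).
adversarial = [0.9, 0.3]
clean = [0.5, 0.1]
print(auroc(adversarial, clean))  # prints 0.75
```

An AUROC of 1.0 means every adversarial sample is ranked above every clean one; 0.5 corresponds to chance.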

5 Conclusion

In this paper, we explore the possibility of understanding DNNs through disentangled subnetworks. We find that the similarities among the extracted subnetworks reflect the semantic similarities among their corresponding classes. Furthermore, the proposed techniques effectively improve the localization accuracy of visual explanation methods and the detection performance of adversarial sample detection methods.