Affective computing is a field of research concerned with building computational models for emotion recognition in order to help computers understand, analyze, and mimic human emotions [picard2000affective]. Non-invasive technologies such as recording of brain signals with Electroencephalogram (EEG) have been widely used for affective computing [zheng2015investigating, zheng2017multimodal, zhang2020rfnet].
EEG is a non-stationary time-series, generally with a large number of dimensions (high dimensionality) and a high sampling rate (temporal resolution). As a result, recent deep learning solutions for EEG representation learning often require complex networks to sufficiently learn the information contained within EEG signals. Specifically, Capsule Networks (CapsNet) [sabour2017dynamic] have been applied to EEG for affective computing and achieved state-of-the-art results [zhang2019capsule]. Nonetheless, approaches that rely on capsules for EEG representation learning contain a large number of parameters, which makes online, real-time deployment on smart devices difficult.
In this paper, we propose a novel method for knowledge distillation based on capsule networks, capable of being used for both classification and regression tasks in the context of EEG representation learning. Specifically, we first revisit the capsule-based model proposed in [zhang2019capsule] and use a similar architecture as our student network. We then develop a novel knowledge distillation framework via capsules, which transfers knowledge contained in both higher and lower level capsules from the teacher to the student. Next, in order to utilize more training data, we pre-train the teacher network on large amounts of cross-subject EEG data. We then fine-tune the pre-trained teacher on intra-subject data to learn subject-specific knowledge. Afterwards, we evaluate the student on subject-specific experiments with the help of privileged information learned by the teacher and transferred to the student. Our experiments show that a compact student network with only of the number of parameters of the original teacher network can achieve competitive results in comparison. Moreover, our approach improves the robustness of the compact student when faced with limited training samples. Lastly, our experiments on two separate public datasets show competitive results by the student model for one of the datasets, while outperforming the related work on the other dataset.
Our contributions are summarized as follows. (1) For the first time, we propose a distillation pipeline based on a teacher-student framework capable of capsule network compression, while improving overall student performance. (2) To the best of our knowledge, this is the first time that knowledge distillation has been implemented for affective EEG representation learning. (3) Our proposed method can be used for both classification and regression tasks. Experiments on SEED and SEED-VIG datasets show that our model performs well on SEED, while achieving a new state-of-the-art on SEED-VIG.
II Related Work
Affective Computing with EEG.
A variety of deep learning techniques have been used to learn the most discriminative features extracted from EEG for affective computing. For example, in [wu2018regression, huo2016driving], a Graph regularized Extreme Learning Machine (GELM) was employed to predict fatigue. To learn the time-dependency and spatial information in EEG signals, a Spatial-Temporal Recurrent Neural Network (STRNN) was used in [zhang2018spatial], achieving strong performance. In addition to spatiotemporal feature learning, the Regional to Global Brain-Spatial-Temporal Neural Network (R2G-STNN) was proposed to minimize the domain shift by applying a discriminator [li2019regional]. To further investigate dependencies between EEG electrodes, the authors of [zhang2020variational] proposed Variational Pathway Reasoning (VPR) for emotion classification, achieving state-of-the-art results. The VPR pipeline employed Long Short-Term Memory (LSTM) to learn sequential information between electrodes, thus encoding the pathway around them. This method then used a Bayesian probabilistic approach to learn the pathways' scaling factors and identify the pathway with the most salient pair-wise connections [zhang2020variational]. In [zhong2020eeg], Regularized Graph Neural Networks (RGNN) were employed to explore the graph connections of EEG electrodes, approaching the best results with fully explored topological knowledge. In [zhang2019capsule], an LSTM-CapsNet model was proposed to explore both temporal and spatial information to predict vigilance. Very recently, Zhang and Etemad proposed a Riemannian Fusion Network (RFNet) to learn the temporal information through an LSTM-attention network [zhang2020rfnet]. This method also learned spatial information through a parallel Riemannian-based approach with the Spatial Covariance Matrix (SCM) as input.
Capsule Networks. Capsule networks were proposed in [sabour2017dynamic] to learn the part-whole relationships of objects through iterative routing among capsules at different levels. Pipelines using capsule networks have achieved state-of-the-art results in some areas of Natural Language Processing (NLP) such as intent detection [zhang2018joint] and multi-label classification [chen2020hyperbolic], as well as in computer vision tasks such as expression recognition [sepas2101capsfield] and low resolution image recognition [singh2019dual]. Recently, an LSTM-CapsNet architecture was successfully proposed for EEG-based affective computing [zhang2019capsule].
Knowledge Distillation. Vapnik et al. proposed a learning paradigm to enable machine learning from other machines with privileged information in the training stage [vapnik2009new, vapnik2015learning]. The paradigm relies on a teacher machine with more discriminative information than the student machine, using Support Vector Machine (SVM). The teacher machine is effective when its expected error is smaller than the student's, as theoretically shown in [vapnik1998statistical]. However, there are several constraints in the paradigm such as the restriction to SVM, fixed parameters, and lack of information on hard labels in the student machine training [lopez2015unifying].
In order to develop an effective knowledge transfer pipeline suitable for deep learning, knowledge distillation was proposed in [hinton2015distilling]. Knowledge distillation transfers soft target distributions learned by a cumbersome model (teacher) to a smaller model (student), where the architecture and parameters of the models can be customized [hinton2015distilling]. For this purpose, the KL divergence of the soft target distribution is minimized between the student and teacher networks, thus enabling pure knowledge distillation during the training stage [hinton2015distilling]. In [lopez2015unifying], knowledge distillation was described as a two-step process in which (i) the teacher learns the data using hard labels, and (ii) the student learns the data using soft labels computed from the teacher, as well as the hard labels. Consequently, the student not only receives the privileged knowledge from the teacher, but also learns the hard labels during the distillation process. In [phuong2019towards], an analysis of data geometry and optimization bias also showed the benefits of knowledge distillation. Very recently, knowledge distillation has been successfully implemented in a number of computer vision areas such as video captioning [Pan_2020_CVPR] and prediction regularization [Yun_2020_CVPR]. This concept has also been recently used in different areas of NLP such as large model compression [sun2019patient, tang2019natural].
III Our Approach
Overview. We aim to distill information from a heavy and cumbersome capsule-based model to a lightweight model for subject-specific tasks suitable for both classification and regression. To do so, we propose a four-step process. (i) Developing two separate networks containing CapsNet architectures, one called the teacher network and the other the student network. (ii) Pre-training the teacher network on the large amounts of available cross-subject data. (iii) Using the pre-trained teacher to then learn information embedded in capsules with intra-subject data. (iv) Training the student on intra-subject data with the help of the privileged information learned by the teacher via capsules. As shown by our results in Section 4, this process enables us to maximally compress the model with minimal loss in performance.
In the following sections, we first revisit the architecture of the LSTM-CapsNet model used in this study, and then introduce our novel distillation framework via capsules.
III-A Revisiting LSTM-CapsNet Architecture
Feature Space. The input EEG data are first pre-processed, followed by feature extraction. Pre-processing of EEG has been kept consistent with previous works on the same datasets [zheng2015investigating, zheng2017multimodal]. Specifically, the signals were downsampled to Hz from Hz. A band-pass filter of Hz followed by a notch filter at Hz was then applied to the raw EEG signals to minimize artifacts and suppress power line noise. Data normalization was then applied to scale signal amplitudes into the range of , thus reducing the discrepancy between EEG collected from various subjects and data recording periods [zhang2020rfnet].
In the feature extraction step, we extracted two types of features, namely Power Spectral Density (PSD) and Differential Entropy (DE). To extract PSD features, we first applied consecutive -second Hanning windows with no overlap on each -second EEG segment, thus reducing spectral leakage caused by finite windowing (the value for along with other parameters used for pre-processing and the network architectures, are presented later in Section 4). We then have a total of number of Hanning windows in each EEG segment. We then applied the Short-Time Fourier Transform with the Hanning window to transform signals from the time domain to the frequency domain. Afterwards, we compute the logarithm of PSD using Eq. 1 and DE using Eq. 2 (with the assumption of the Gaussian distribution of the signal) in different frequency bands [zhang2020rfnet].
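As a rough illustration of the two feature types, the following sketch computes a log band power and a Gaussian differential entropy for one windowed segment. It is written in plain Python for readability and is not the authors' implementation. The function names, the direct DFT, and the power normalization are assumptions made for this example. The DE expression is the standard closed form for a Gaussian variable with variance sigma^2, namely 0.5 * log(2 * pi * e * sigma^2); how the per-band variance is estimated is left to the caller.

```python
import cmath
import math


def hann(n):
    """Symmetric Hann (Hanning) window of length n."""
    return [0.5 - 0.5 * math.cos(2 * math.pi * i / (n - 1)) for i in range(n)]


def band_power(window, fs, f_lo, f_hi):
    """Mean power of the Hann-tapered DFT bins with frequency in [f_lo, f_hi)."""
    n = len(window)
    tapered = [x * w for x, w in zip(window, hann(n))]
    powers = []
    for k in range(n // 2 + 1):
        freq = k * fs / n
        if f_lo <= freq < f_hi:
            coeff = sum(x * cmath.exp(-2j * math.pi * k * t / n)
                        for t, x in enumerate(tapered))
            powers.append(abs(coeff) ** 2 / n)
    if not powers:
        raise ValueError("no DFT bins fall inside the requested band")
    return sum(powers) / len(powers)


def log_psd(window, fs, f_lo, f_hi):
    """Logarithm of the band power (illustrative normalization)."""
    return math.log(band_power(window, fs, f_lo, f_hi))


def differential_entropy(variance):
    """Differential entropy of a Gaussian variable with the given variance."""
    return 0.5 * math.log(2 * math.pi * math.e * variance)
```

For example, a 10 Hz sinusoid sampled at 100 Hz yields far more power in an 8 to 13 Hz band than in a 20 to 30 Hz band; these band edges and the sampling rate are illustrative values, not the paper's settings.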
LSTM Network. We employ an LSTM network to learn the time-dependencies within the EEG signals. Specifically, we feed the extracted features from each of the number of windows to the corresponding number of cells of the input LSTM layer with hidden units. The time-dependent information learned by each LSTM cell is reshaped from units to a square matrix for further processing.
Lower Level Capsules. We employ a 2D convolution layer with number of output channels, kernel size of , and stride of , to capture local features. We then apply another 2D convolution layer with output channels, kernel and stride to produce lower level capsules, yielding number of lower level capsules with dimensions.
Higher Level Capsules. Higher level capsules are designed to learn global information as opposed to lower level capsules which capture local information [sabour2017dynamic]. Let us denote and as the number and dimension of higher level capsules, respectively. is consistent with the number of categories in our classification task and can be empirically tuned for regression tasks. We enable the higher level capsules to have larger degrees of freedom by setting [sabour2017dynamic].
Capsule Network. Capsule networks have been used to learn the 'part-whole' relationships between lower level capsules and higher level capsules [sabour2017dynamic]. The capsule network assigns attention scores from lower level capsules ('part' information) to higher level ones ('whole' information) through dynamic routing [sabour2017dynamic]. Specifically, following the notation of [sabour2017dynamic], a CapsNet used as an attention mechanism is established between the prediction vector $\hat{u}_{j|i}$ and the output of higher level capsules ($v_j$). $\hat{u}_{j|i}$ represents the prediction from each lower level capsule $i$ to each higher level capsule $j$. The prediction vector is expressed as the multiplication of the weight matrix $W_{ij}$ with the output of lower level capsule $u_i$, i.e., $\hat{u}_{j|i} = W_{ij} u_i$. $s_j = \sum_i c_{ij} \hat{u}_{j|i}$ represents the total input to capsule $j$. The output of higher level capsule $v_j$ is the squashed output of $s_j$, which normalizes the length of $s_j$ into the range of $[0, 1)$. The coupling coefficients $c_{ij}$ are the softmax outputs of logit $b_{ij}$, where $b_{ij}$ represents the log prior probabilities that will be updated by the iterative process $b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j$ after its zero initialization.
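The routing procedure described above can be sketched in a few lines of plain Python. This is a minimal illustration of dynamic routing as defined in [sabour2017dynamic], not the authors' code: the nested-list layout `u_hat[i][j]` (prediction vector from lower level capsule `i` to higher level capsule `j`), the function names, and the default of three routing iterations are assumptions of this example.

```python
import math


def squash(s):
    """Shrink vector s so that its length lies in [0, 1), keeping its direction."""
    norm_sq = sum(x * x for x in s)
    if norm_sq == 0.0:
        return [0.0 for _ in s]
    scale = norm_sq / (1.0 + norm_sq) / math.sqrt(norm_sq)
    return [scale * x for x in s]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dynamic_routing(u_hat, num_iters=3):
    """Return (v, c): higher level capsule outputs and coupling coefficients."""
    n_low, n_high, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * n_high for _ in range(n_low)]  # routing logits, zero-initialized
    v, c = [], []
    for _ in range(num_iters):
        # Coupling coefficients: softmax of the logits over higher level capsules.
        c = [softmax(row) for row in b]
        v = []
        for j in range(n_high):
            # Total input to capsule j: weighted sum of the prediction vectors.
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(n_low))
                   for d in range(dim)]
            v.append(squash(s_j))
        # Agreement update: raise the logit when a prediction aligns with the output.
        for i in range(n_low):
            for j in range(n_high):
                b[i][j] += sum(u_hat[i][j][d] * v[j][d] for d in range(dim))
    return v, c
```

When the lower level capsules agree on one higher level capsule and disagree on another, the coupling coefficients shift toward the capsule on which they agree.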
III-B Proposed Method for Knowledge Distillation
An overview of our novel knowledge distillation framework is illustrated in Figure 1. We develop a knowledge distillation framework to compress the large model without performance degradation. To do so, we first employ the LSTM-CapsNet architecture described above as the teacher network. Next, we pre-train the teacher on cross-subject data and then fine-tune it on intra-subject data in order to adapt to subject-specific features. Finally, we train the student model with the help of privileged information learned by the teacher, and then evaluate it on intra-subject data. In order to fully explore the privileged information, we learn the inter-dimension relationships of lower level capsules, as well as the information embedded in higher level capsules, by minimizing the discrepancy between the teacher and student networks.
Knowledge Distillation via Lower Level Capsules. Lower level capsules contain local features where their different dimensions capture different aspects of the information space [sabour2017dynamic, zhang2019capsule]. For example, when trained with handwritten digit images, each dimension of lower level capsules contains different information on digit-specific variations, for instance, scale, thickness, and width [sabour2017dynamic]. Therefore, we explore the similarity [tung2019similarity] of the capsules’ inter-dimension correlations between the teacher and student networks. In order to learn such inter-dimension correlations, we first calculate the covariance matrix of the lower level capsules as:
where represent the covariance matrices of lower level capsules of the teacher () and student () networks respectively. Dimension is kept the same for both networks. The number of lower level capsules of the teacher is greater than or equal to the student’s as . Next, we compute the squared Euclidean distance between the two normalized covariance matrices [tung2019similarity] as the similarity loss:
where is the Frobenius norm.
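One way to read the lower level capsule loss is sketched below in plain Python. Because the teacher and the student may have different numbers of capsules but share the capsule dimension, a covariance taken across dimensions has the same shape for both networks and can be compared directly. The mean-centering, the division by the number of capsules, and the use of the Frobenius norm for normalization are assumptions of this sketch, and the function names are hypothetical; the exact form used in the paper is not reproduced here.

```python
import math


def covariance(capsules):
    """D x D covariance across capsule dimensions; `capsules` is an N x D list."""
    n, d = len(capsules), len(capsules[0])
    means = [sum(row[k] for row in capsules) / n for k in range(d)]
    return [[sum((row[a] - means[a]) * (row[b] - means[b]) for row in capsules) / n
             for b in range(d)]
            for a in range(d)]


def frobenius_norm(matrix):
    """Square root of the sum of squared entries."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def normalize(matrix):
    """Scale a matrix to unit Frobenius norm (returned unchanged if all zero)."""
    norm = frobenius_norm(matrix)
    if norm == 0.0:
        return [list(row) for row in matrix]
    return [[x / norm for x in row] for row in matrix]


def similarity_loss(teacher_caps, student_caps):
    """Squared distance between the normalized teacher and student covariances."""
    c_t = normalize(covariance(teacher_caps))
    c_s = normalize(covariance(student_caps))
    diff = [[a - b for a, b in zip(row_t, row_s)] for row_t, row_s in zip(c_t, c_s)]
    return frobenius_norm(diff) ** 2
```

In this formulation the loss is zero whenever the student reproduces the teacher's inter-dimension correlation pattern, regardless of overall scale or of the number of capsules.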
Knowledge Distillation via Higher Level Capsules. Higher level capsules include global information, where the length of their output vectors represents the probability that the entity corresponding to that capsule exists. Such information is further used as 'soft target labels' for knowledge distillation. Specifically, we employ KL divergence to measure the difference of information distribution in higher level capsules between teacher () and student network (). The following equation is used:
where is the softmax operator, is the temperature parameter [hinton2015distilling], and a logarithm operation is applied to the student output to help accelerate the distillation process. We use throughout the experiments.
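The higher level capsule loss can be illustrated with the following plain-Python sketch, which applies a temperature-scaled softmax to the capsule lengths of both networks and computes the KL divergence from the teacher distribution to the student distribution, taking the logarithm on the student side. The function names are hypothetical, the temperature is left as a required argument because its value is not restated here, and no additional temperature-dependent rescaling of the loss is applied; these are assumptions of the example.

```python
import math


def softmax_with_temperature(values, temperature):
    """Softmax of values / temperature, computed in a numerically stable way."""
    scaled = [v / temperature for v in values]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_kl_loss(teacher_lengths, student_lengths, temperature):
    """KL(teacher || student) between softened capsule-length distributions."""
    p_teacher = softmax_with_temperature(teacher_lengths, temperature)
    # The logarithm is taken on the student side, as in a log-softmax output.
    log_p_student = [math.log(p)
                     for p in softmax_with_temperature(student_lengths, temperature)]
    return sum(p_t * (math.log(p_t) - log_p_s)
               for p_t, log_p_s in zip(p_teacher, log_p_student))
```

A higher temperature flattens both distributions, so the student is also encouraged to match the relative lengths of the capsules that do not correspond to the correct class.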
Teacher Network. We use an architecture similar to the state-of-the-art network from [zhang2019capsule] as our teacher network. We employ -stacked LSTM layers with units. Layer normalization [ba2016layer] followed by LeakyReLU is applied after each LSTM layer. For the SEED dataset, we set the number of higher level capsules to three, consistent with the number of emotion classes, in order to use margin loss for classification [sabour2017dynamic]. For the regression task in the SEED-VIG dataset, we empirically set as the number of higher level capsules. The output is then followed by a fully connected layer containing hidden units with Sigmoid activation. The details are shown in Table I.
Student Network. The student network has the same architecture as the teacher but with fewer parameters. Specifically, the student network contains a single layer of LSTM with fewer hidden units , yielding a smaller number of lower level capsules . The parameter details are presented later in Section 4.4.
Training Loss Function.
The training loss function includes three parts, namely the lower level capsule distillation loss, the higher level capsule distillation loss, and a task-specific loss. The task-specific loss depends on the task type (classification vs. regression). For the regression task, we employ a fully connected layer ( hidden units) with a sigmoid activation function as in [zhang2019capsule] to enable the Mean Squared Error (MSE) loss calculation (). For the classification task, we use margin loss (Eq. 6) as recommended in [sabour2017dynamic]:
$$L_k = T_k \max(0, m^{+} - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^{-})^2 \quad (6)$$
where $\|v_k\|$ is the length of the output of higher level capsule $k$, and $T_k = 1$ if class $k$ is the correct prediction, otherwise $T_k = 0$. The first part of the equation will be zero if and only if the probability of correct prediction is greater than $m^{+}$. The second part of the loss function will be zero if and only if the probability of incorrect prediction is less than $m^{-}$.
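The margin loss of [sabour2017dynamic] can be written as a short plain-Python function over the lengths of the higher level capsule outputs. The default margins of 0.9 and 0.1 and the down-weighting factor of 0.5 are the values suggested in [sabour2017dynamic] and are used here only as illustrative defaults; they are not claimed to be the settings of this work, and the function name is hypothetical.

```python
def margin_loss(capsule_lengths, target_class,
                m_plus=0.9, m_minus=0.1, scale=0.5):
    """Sum of per-class margin losses for a single sample.

    capsule_lengths[k] is the length of the output of higher level capsule k,
    interpreted as the probability that class k is present.
    """
    total = 0.0
    for k, length in enumerate(capsule_lengths):
        if k == target_class:
            # Penalize the correct class when its capsule is shorter than m_plus.
            total += max(0.0, m_plus - length) ** 2
        else:
            # Penalize other classes when their capsules are longer than m_minus.
            total += scale * max(0.0, length - m_minus) ** 2
    return total
```

With three classes, a confident and correct prediction such as lengths `[0.95, 0.05, 0.05]` for target class 0 incurs zero loss, while an ambiguous one such as `[0.5, 0.5, 0.05]` is penalized on both the correct and the competing capsule.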
Consequently, the total loss is shown as:
where is the scaling factor, and and are trade-off hyper-parameters for lower and higher level capsules distillation loss, respectively.
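TABLE I: Network details (columns: Feature Space, Temporal Info., Lower Level Capsules, Higher Level Capsules, Regression Layer).
Comparison with existing methods (surviving row): [zhang2020rfnet]; features: SCM, DE, PSD; method: RFNet.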
SEED. The SEED dataset was collected by [zheng2015investigating] to perform emotion recognition with three categories of positive, negative, and neutral emotions. emotion-related videos were selected as stimuli in each experiment. subjects, including females and males, performed a total of experiments, where each subject participated in the experiments in two different runs. Each run contained sessions. Each session started with a -second notice before playing the video clips, followed by approximately minutes of watching the movie clip, and concluded by seconds of self-assessment. Each session ended with a -second relaxation. EEG channels were recorded with a sampling rate of Hz using the international system.
SEED-VIG. The SEED-VIG dataset [zheng2017multimodal] contains EEG recordings to estimate drivers’ continuous vigilance levels. subjects ( female and male) participated in the experiment and drove a simulated vehicle in a virtual environment. The experiment took around minutes. overall consecutive EEG segments were recorded in each experiment. The duration of eye blinks and eye closures as well as the duration of fixation and saccade [zheng2017multimodal], which were all measured using eye-tracking glasses, were used to measure the output ground-truth labels called PERCLOS. The EEG signals were recorded from locations with a sampling rate of Hz using the international system.
IV-B Evaluation Protocol
Teacher Network. We pre-train the teacher network on the cross-subject data. We use leave-one-subject-out cross-validation to pre-train a teacher for each subject. Consequently, the pre-trained teacher used for each specific subject has not seen the data from that subject during training. For the SEED dataset, we have EEG trials for each experiment run per subject, yielding a total of EEG trials for training, and EEG trials for testing. Similarly, in the SEED-VIG dataset, we have a total of EEG trials for training and EEG trials for testing.
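The leave-one-subject-out pre-training scheme can be sketched as follows. The subject identifiers and the function name are hypothetical and serve only to illustrate how each teacher is pre-trained without access to the data of the subject it will later be fine-tuned on.

```python
def leave_one_subject_out(subject_ids):
    """Yield (held_out, training_subjects) pairs, one pair per subject."""
    for held_out in subject_ids:
        training = [s for s in subject_ids if s != held_out]
        yield held_out, training


# One teacher is pre-trained per held-out subject, using all other subjects.
splits = list(leave_one_subject_out(["s01", "s02", "s03"]))
```

Each pair gives the subject whose data is withheld and the list of subjects whose data is used to pre-train the corresponding teacher.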
Student Network. We train and evaluate the student network on intra-subject data. We follow the same evaluation protocol as the related works [zheng2015investigating, zheng2017multimodal, zhang2020rfnet]. In the SEED dataset, we use the pre-defined first sessions and the last sessions as the training set ( EEG trials) and test set ( EEG trials), respectively [zheng2015investigating]. In the SEED-VIG dataset, we employ -fold cross-validation for our train-test set split as in [zheng2017multimodal].
We adopt both Pearson Correlation Coefficient (PCC) and Root Mean Squared Error (RMSE) as evaluation metrics for the regression task in the SEED-VIG dataset[zheng2017multimodal], while accuracy (Acc.) is used as the evaluation metric for classification in the SEED dataset [zheng2015investigating].
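For reference, the two regression metrics can be computed as in the plain-Python sketch below. The function names are chosen for this example; in practice, equivalent routines from a numerical library would typically be used.

```python
import math


def rmse(predictions, targets):
    """Root Mean Squared Error between two equal-length sequences."""
    n = len(targets)
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(predictions, targets)) / n)


def pcc(predictions, targets):
    """Pearson Correlation Coefficient between two equal-length sequences."""
    n = len(targets)
    mean_p = sum(predictions) / n
    mean_t = sum(targets) / n
    cov = sum((p - mean_p) * (t - mean_t) for p, t in zip(predictions, targets))
    var_p = sum((p - mean_p) ** 2 for p in predictions)
    var_t = sum((t - mean_t) ** 2 for t in targets)
    return cov / math.sqrt(var_p * var_t)
```

Lower RMSE and higher PCC indicate better agreement between the predicted and the ground-truth vigilance levels.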
IV-C Implementation Details
Feature Extraction. We use different frequency bands in the feature extraction step for each dataset. For the SEED dataset, we use five frequency bands, notably delta, theta, alpha, beta, and gamma bands [zheng2015investigating]. Accordingly we have features extracted from each -second window. For the SEED-VIG dataset, we use frequency bands with resolution, starting from to [zheng2017multimodal]. We thus have features extracted from each window.
Other Hyper-Parameters and Training. In this work, we apply weight clipping to avoid gradient explosion. In the teacher pre-training phase, we run a total of epochs. The learning rate is initialized to and decreases by times after th epoch, then drops again by times after the th epoch. For the rest of the experiments (fine-tuning and subject-specific phases), training is performed with epochs, with a fixed learning rate of . We employ the Adam algorithm with default decay rates for optimization. The batch sizes are set to during teacher pre-training and for all the other experiments. We set the scaling factor to , and the trade-off hyper-parameter to for SEED and for SEED-VIG, respectively. The parameter is set to be for both datasets. All hyper-parameters were empirically tuned on the validation set. All of our experiments are implemented using PyTorch [paszke2019pytorch] on a pair of NVIDIA Tesla P100 GPUs.
Student Model Size. We evaluate the impact of knowledge distillation on student networks with different numbers of parameters. To do so, we select several numbers of hidden units in the single LSTM layer, yielding number of lower level capsules for each of the four student networks. As shown in Table IV, each student network has a different number of parameters and a respective compression ratio. Figure 2 presents the performance of the students with different numbers of parameters. On both datasets, we observe consistent performance improvement for the student when our knowledge distillation framework is used, compared to when the student is trained from scratch. For example, for the SEED dataset, the student achieves a boost with the help of the teacher, while for the SEED-VIG dataset, the performance of the student improves by in RMSE and in PCC.
Comparison with Existing Methods. We compare the performance of the best student, i.e. student 1, using our knowledge distillation framework with existing methods. Table II shows the recent related work on the SEED dataset. Our student model with distillation obtains an accuracy of , achieving a good result. Table III shows the related work on the SEED-VIG dataset. Our student model obtains an RMSE of and a PCC of , setting new state-of-the-art results.
Fewer Training Samples. We investigate the role of distillation when fewer training samples are available. To do so, we randomly select different subsets of training samples () for training, with the same random seed throughout all experiments. We then compare the performance of the student trained without distillation and the student trained with the privileged information learned from the teacher, when different amounts of training samples are used. We conduct the experiments using student model which has the fewest parameters among all students. Figure 3 shows the impact of our framework on the performance of the student network when fewer training samples are available. For SEED, we observe clear performance improvements brought by distillation when less than of training samples are available. Specifically, with only of training samples, the model performance improves by with the help of our proposed method. When more than of the training samples are used, improvements are marginal, and the two models eventually converge. For SEED-VIG, where the task is regression, utilizing a small subset of training samples may result in validation labels having output values that do not appear during training. As a result, this experiment proves challenging when using very limited training samples. As shown in Figure 3, the performance of both models drops significantly when the number of available training samples decreases. We observe that distillation does not improve the student's performance when less than of training samples are available, while for larger subsets, slight improvements are consistently achieved.
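The fixed-seed subsampling used in this experiment can be sketched with the Python standard library. The function name, the seed value, and the rounding rule are assumptions of this example; the point is only that the same seed yields the same subset for both the distilled and the non-distilled student, so the comparison is made on identical training data.

```python
import random


def subsample(indices, fraction, seed=0):
    """Return a reproducible random subset containing `fraction` of the indices."""
    rng = random.Random(seed)  # local generator, so global random state is untouched
    k = max(1, round(fraction * len(indices)))
    return sorted(rng.sample(indices, k))


# Both students would be trained on the same subset when the seed is shared.
subset_a = subsample(list(range(100)), 0.2, seed=42)
subset_b = subsample(list(range(100)), 0.2, seed=42)
```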
We proposed a novel distillation framework on EEG representations for effective model compression. The framework is built on a capsule-based network and utilizes similarities among capsules for knowledge distillation. Our proposed method was applied to both classification and regression tasks on two popular public EEG datasets in the field of affective computing. Our method shows strong performance on one dataset and achieves state-of-the-art results on the other. Our experiments show that the performance gain of the student network trained with the teacher's privileged information, compared to the same student trained without the teacher, increases as the student network becomes more compressed. Moreover, further experiments illustrate that our method is less sensitive to the size of the training set, helping the student learn more effectively when fewer training samples are available.