Temporal Knowledge Distillation for On-device Audio Classification

by Kwanghee Choi, et al.

Improving the performance of on-device audio classification models remains a challenge given the computational limits of the mobile environment. Many studies leverage knowledge distillation to boost predictive performance by transferring the knowledge from large models to on-device models. However, most of them either fail to capture the temporal information that is crucial to audio classification tasks or require the teacher and the student to have similar architectures. In this paper, we propose a new knowledge distillation method designed to transfer the temporal knowledge embedded in the attention weights of large models to on-device models. Our distillation method is applicable to various types of architectures, including non-attention-based architectures such as CNNs or RNNs, without any architectural change during inference. Through extensive experiments on both an audio event detection dataset and a noisy keyword spotting dataset, we show that our proposed method improves the predictive performance across diverse on-device architectures.




1 Introduction

With the ubiquity of real-time communication, on-device audio understanding has received great attention. On-device models have achieved performance comparable to that of large models on several tasks such as keyword spotting (KWS) [20, 17]. Nevertheless, on-device models still struggle with more complex tasks (e.g., audio event detection (AED) [9]). It is challenging to improve the performance of on-device models due to the restricted memory and computing resources of the mobile environment.

Several studies [10, 14] use knowledge distillation (KD) [13] to tackle this problem, transferring the knowledge of large models (teachers) to on-device models (students) without incurring any computational overhead at inference time. Many of these approaches focus on the knowledge embedded in the logits produced by the classification layer [10, 14, 3], mainly because logit-based distillation can be applied even when the teacher and the student have dissimilar architectures. However, temporal information, which is known to be beneficial in audio tasks [15], cannot be easily distilled once it is compressed into classifier logits.

With the success of the transformer [19], recent studies [22, 4] have focused on distilling the knowledge from self-attention maps, which preserve the temporal information. However, their methods are limited to transferring knowledge between the same transformer-based architectures, and even the smallest transformer variants remain computationally expensive for many mobile devices. Moreover, it is not straightforward to transfer the knowledge in the self-attention maps of a large transformer-based model to other architectures such as convolutional neural networks (CNNs) or recurrent neural networks (RNNs).

In this paper, we introduce a simple yet effective method that distills the temporal knowledge from the attention weights of large models to on-device models of various architectures. We first employ XLSR-wav2vec 2.0 [7] as a teacher model and extract attention weights from its self-attention maps. We then design an attention distillation loss for the on-device (student) models, attaching a simple attention layer at training time only to align the teacher and the student attention weights. To evaluate the effectiveness of our proposed method, we conduct experiments on a real-world AED dataset (FSD50K [9]) and a noisy KWS dataset. The noisy KWS dataset is constructed by injecting samples from an existing KWS dataset (Google Speech Commands v2 [20]) into various noisy audio clips [11], which makes temporal information more important for classifying each keyword. Experimental results demonstrate that our method improves the predictive performance of various on-device models without any architectural change during inference.

2 Proposed Method

Figure 1: Illustration of our proposed method.

In this section, we first describe a large-scale transformer-based model that is used as a teacher model (2.1) and on-device models employed as student models (2.2). We then introduce our method that transfers the temporal knowledge from self-attention of the teacher model to student models without any architectural changes during inference (2.3). Our proposed method is illustrated in Figure 1.

2.1 Teacher: Large Transformer-based Model

We employ XLSR-wav2vec 2.0 [7] as our teacher model $T$, a large-scale transformer-based ASR model with state-of-the-art performance on multilingual ASR. The teacher model maps raw audio $x$ to latent representations $Z = (z_1, \dots, z_{N_T})$ for $N_T$ time steps using a convolutional feature encoder. The latent representations are passed through consecutive transformer layers to output context representations $C^T = (c^T_1, \dots, c^T_{N_T})$, where each $c^T_i$ is a $D_T$-dimensional vector.

To perform audio classification, we attach a fully-connected layer to the output of $T$. Similar to the fine-tuning of language models [8], we feed only the first output $c^T_1$ to the fully-connected layer. $T$ and the fully-connected layer are trained end to end on the audio classification datasets (details in Sec. 3).
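As a minimal sketch of the classification head described above, assuming the teacher's context representations are already available as a list of $N_T$ vectors, the following pure-Python code applies a fully-connected layer to $c^T_1$ only. The function names and toy weights are hypothetical and not taken from the authors' code.

```python
from typing import List

Vector = List[float]


def linear(x: Vector, weight: List[Vector], bias: Vector) -> Vector:
    """Fully-connected layer y = W x + b, with W stored as num_classes rows of length D_T."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def teacher_logits(context: List[Vector], weight: List[Vector], bias: Vector) -> Vector:
    """Classify from c^T_1 only; the remaining N_T - 1 outputs are ignored."""
    return linear(context[0], weight, bias)


logits = teacher_logits(
    context=[[1.0, 2.0], [5.0, 5.0], [9.0, 9.0]],  # N_T = 3 toy context vectors, D_T = 2
    weight=[[1.0, 0.0], [0.0, 1.0]],               # 2 hypothetical classes
    bias=[0.0, 0.5],
)
```

This mirrors the BERT-style convention of classifying from the first position, so the later time steps influence the prediction only through self-attention.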

2.2 Student: Lightweight On-device Models

In this paper, we consider the following on-device audio classification models adopted by [17] as our student models $S$: a simple RNN-based model (LSTM-P [12]), a CNN-based model (TC-ResNet [5]), a model that uses both CNN and RNN (CRNN [2]), a model including an attention mechanism (Att-RNN [6]), and a multi-head variant of Att-RNN (MHAtt-RNN [17]). For the student models, we pass the raw audio to an MFCC-based feature encoder. The student models extract context representations $C^S = (c^S_1, \dots, c^S_{N_S})$ for $N_S$ time steps, where each $c^S_i$ is a $D_S$-dimensional vector. Note that $N_S$ and $D_S$ differ across student models depending on their architecture.

We attach an attention mechanism that extracts attention weights $a^S \in \mathbb{R}^{N_S}$ to every student model except Att-RNN and MHAtt-RNN, which already include one. The attention weights $a^S$ act as the recipient of the high-level knowledge transferred from the teacher model $T$. They are computed by applying the softmax function to the inner products of the context representations $c^S_i$ and a query $q$:

$$a^S_i = \frac{\exp(q^\top c^S_i)}{\sum_{j=1}^{N_S} \exp(q^\top c^S_j)}, \quad i = 1, \dots, N_S.$$

There are many strategies for designing the query $q$, e.g., random initialization [21] or a projection of the middle context representation [6]; we choose the latter. As training progresses, $a^S$ learns to capture the importance of each context representation $c^S_i$. For Att-RNN, we directly use its attention weights as $a^S$. For MHAtt-RNN, which adopts a multi-head attention mechanism, we are motivated by [19] to choose one of the heads to yield $a^S$. Although this approach does not apply to arbitrary architectures, the only requirement on the student model is that it outputs intermediate features that preserve the temporal information. Many prominent architectures already satisfy this requirement, e.g., through the CNN feature map before the global average pooling layer or the output sequence of an RNN.
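The student attention layer can be sketched as follows. This is an illustrative pure-Python version under the assumption that the query is a linear projection $q = W c^S_{\lfloor N_S/2 \rfloor}$ of the middle context representation (in the spirit of Att-RNN [6]); the matrix `proj` stands in for a learned parameter, and all names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix: List[Vector], vec: Vector) -> Vector:
    """Matrix-vector product for row-major nested lists."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def student_attention(context: List[Vector], proj: List[Vector]) -> Vector:
    """Attention weights a^S over N_S time steps.

    context: N_S context vectors c^S_i, each of length D_S.
    proj:    hypothetical D_S x D_S projection producing the query q
             from the middle context vector.
    """
    middle = context[len(context) // 2]
    query = matvec(proj, middle)
    # Inner product q^T c^S_i for every time step, then softmax over time.
    scores = [sum(q * c for q, c in zip(query, vec)) for vec in context]
    return softmax(scores)


weights = student_attention(
    context=[[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]],
    proj=[[1.0, 0.0], [0.0, 1.0]],
)
```

In this toy input only the middle frame aligns with the query, so it receives the largest weight, while the two identical outer frames receive equal weights.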

2.3 Temporal Knowledge Distillation

To extract the temporal knowledge from the teacher model, we leverage the self-attention maps from the transformer layers of the teacher model $T$. We apply the attention rollout technique [1] to these self-attention maps, which yields a single unified attention map $R \in \mathbb{R}^{N_T \times N_T}$. We use the first row of $R$ as the teacher attention weights $a^T \in \mathbb{R}^{N_T}$, since the teacher performs audio classification based on the context representation of the first time step.
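Attention rollout [1] can be sketched as below. Following the usual formulation, this sketch assumes that the heads of each layer are averaged, that the residual connection is modeled as $0.5A + 0.5I$, and that the per-layer maps are multiplied from the first layer upward. The nested-list input format (first layer first) and the function names are hypothetical choices for illustration.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of row-major nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def average_heads(heads: List[Matrix]) -> Matrix:
    """Average the N_T x N_T attention maps of all heads in one layer."""
    n = len(heads[0])
    return [[sum(h[i][j] for h in heads) / len(heads) for j in range(n)] for i in range(n)]


def add_residual(att: Matrix) -> Matrix:
    """Model the residual connection as 0.5 * A + 0.5 * I (rows still sum to 1)."""
    n = len(att)
    return [[0.5 * att[i][j] + (0.5 if i == j else 0.0) for j in range(n)] for i in range(n)]


def attention_rollout(layers: List[List[Matrix]]) -> Matrix:
    """layers[l][h]: attention map of head h in layer l, ordered from the first layer."""
    rollout = add_residual(average_heads(layers[0]))
    for heads in layers[1:]:
        # Left-multiply so that R = A_hat_last @ ... @ A_hat_first.
        rollout = matmul(add_residual(average_heads(heads)), rollout)
    return rollout


def teacher_attention(layers: List[List[Matrix]]) -> List[float]:
    """a^T: the row of the rollout map belonging to the first output position."""
    return attention_rollout(layers)[0]


uniform = [[0.5, 0.5], [0.5, 0.5]]
a_t = teacher_attention([[uniform], [uniform]])  # two layers, one head each
```

Because every factor is row-stochastic, each row of the rollout map, and hence $a^T$, is itself a probability distribution over the $N_T$ input positions.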

To transfer the temporal knowledge from the teacher model to the student model, we align $a^T$ and $a^S$ using a loss $\mathcal{L}_{att}$. We define $\mathcal{L}_{att}$ as the Kullback-Leibler (KL) divergence between the two attention weights, so that minimizing the loss penalizes misalignment. However, $\mathcal{L}_{att}$ cannot always be computed directly, since the dimensions of the two weights $a^T$ and $a^S$ may not match ($N_T \neq N_S$). Therefore, we employ simple linear interpolation to match the dimension of $a^T$ to that of $a^S$ while preserving the temporal knowledge. Applying linear interpolation to $a^T$ yields the final teacher attention weights $\tilde{a}^T \in \mathbb{R}^{N_S}$. Using the attention weights of the teacher and student models, the loss term is computed as follows:

$$\mathcal{L}_{att} = D_{KL}\left(\tilde{a}^T \,\|\, a^S\right),$$

where $D_{KL}$ is the KL divergence.
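The interpolation and KL steps can be illustrated with the following sketch. It assumes interpolation that aligns the first and last time steps of $a^T$ with those of $a^S$, and it renormalizes $\tilde{a}^T$ to sum to one after interpolation; neither detail is specified in the text, so both are assumptions. The function names are hypothetical.

```python
import math
from typing import List


def linear_interpolate(values: List[float], new_len: int) -> List[float]:
    """Resample a 1-D sequence to new_len points, aligning first and last samples."""
    old_len = len(values)
    if old_len < 2 or new_len < 2:
        raise ValueError("both lengths must be at least 2")
    scale = (old_len - 1) / (new_len - 1)
    out = []
    for j in range(new_len):
        x = j * scale                      # fractional position in the teacher sequence
        left = min(int(x), old_len - 2)    # clamp so that left + 1 stays in range
        frac = x - left
        out.append((1.0 - frac) * values[left] + frac * values[left + 1])
    return out


def kl_divergence(p: List[float], q: List[float], eps: float = 1e-12) -> float:
    """D_KL(p || q); terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def attention_distillation_loss(a_teacher: List[float], a_student: List[float]) -> float:
    """L_att = D_KL(a~^T || a^S) after resampling a^T to the student length N_S."""
    resampled = linear_interpolate(a_teacher, len(a_student))
    total = sum(resampled)
    a_teacher_tilde = [v / total for v in resampled]  # renormalize (assumption)
    return kl_divergence(a_teacher_tilde, a_student)
```

The loss is zero when the resampled teacher distribution matches the student's exactly and grows as the student attends to different time steps than the teacher.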

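The final student loss is defined as follows:

$$\mathcal{L} = \mathcal{L}_{cls} + \lambda \mathcal{L}_{att},$$

where $\mathcal{L}_{cls}$ is the cross-entropy-based classification loss of the student model and $\lambda$ is a hyperparameter that controls the relative influence of the two loss terms.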
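A toy computation of the combined objective, assuming the additive weighting $\mathcal{L}_{cls} + \lambda \mathcal{L}_{att}$, shows that setting $\lambda = 0$ recovers plain cross-entropy. The helper names are hypothetical.

```python
import math
from typing import List


def cross_entropy(logits: List[float], label: int) -> float:
    """-log softmax(logits)[label], computed with the log-sum-exp trick."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_norm - logits[label]


def student_loss(logits: List[float], label: int, att_loss: float, lam: float) -> float:
    """L = L_cls + lam * L_att (additive weighting assumed)."""
    return cross_entropy(logits, label) + lam * att_loss
```

With `lam=0.0` the attention term vanishes regardless of `att_loss`, which corresponds to the vanilla baseline used in the experiments.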

Note that the student models other than Att-RNN and MHAtt-RNN use the attention weights only during training to receive the knowledge from the teacher model; they do not compute them during inference. In other words, there is no architectural change in the model at inference time, and hence no computational overhead for on-device inference.

3 Experiments

3.1 Experimental Setup

Datasets. We verify the effectiveness of our proposed KD method with experiments on a real-world AED dataset (FSD50K) and a noisy KWS dataset (which we call Noisy Speech Commands v2).

FSD50K [9]: FSD50K is a multi-label audio event detection dataset consisting of real-world audio recordings. It is composed of 51,197 human-labeled audio clips spanning 200 classes, with lengths ranging from 0.3 to 30 seconds. Inputs shorter than 30 seconds are zero-padded to 30 seconds.

Noisy Speech Commands v2: To clearly demonstrate the effectiveness of our method in distilling the high-level knowledge of temporal information, we construct a noisy KWS dataset by inserting samples from an existing KWS dataset, Speech Commands v2 [20], into background speech noise. The Speech Commands v2 dataset [20] contains 105,829 one-second utterances of 35 words, stored as 16-bit mono PCM WAVE files with a 16 kHz sample rate. Following the settings of [17, 20], we use their training splits and the 12 class labels, which include a “silence” label for no speech and an “unknown” label covering an additional 20 keywords. We generate synthetic audio by injecting the one-second speech command clips into background speech noise obtained from the “Hubbub”, “speech noise”, and “speech babble” classes of the AudioSet dataset [11]. The speech noise is randomly cropped to a predefined duration, and the location of the speech command within the noise is sampled uniformly. We define four Noisy Speech Commands datasets, using fixed noise durations of 2, 4, 6, and 8 seconds, respectively. The noisy datasets use the same training splits as the Speech Commands v2 dataset. Audio mixing is performed as a weighted sum of the one-second Speech Commands v2 PCM signal and the AudioSet noise PCM signal.
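A sketch of the mixing procedure described above follows. The exact mixing weights are not given here, so `command_weight` and `noise_weight` are hypothetical parameters. The sketch also assumes that the noise weight is applied to the entire cropped noise clip and that both signals are float PCM sample lists at 16 kHz; the function name is likewise hypothetical.

```python
import random
from typing import List, Tuple

SAMPLE_RATE = 16_000  # Hz, matching Speech Commands v2


def make_noisy_example(command: List[float], noise: List[float], duration_s: int,
                       command_weight: float, noise_weight: float,
                       rng: random.Random) -> Tuple[List[float], int]:
    """Return (mixture, keyword_start_sample) for one synthetic noisy sample."""
    target_len = duration_s * SAMPLE_RATE
    if len(noise) < target_len or len(command) > target_len:
        raise ValueError("noise too short or command too long for this duration")
    # Randomly crop the AudioSet speech noise to the predefined duration.
    crop_start = rng.randint(0, len(noise) - target_len)
    mixture = [noise_weight * s for s in noise[crop_start:crop_start + target_len]]
    # Place the keyword at a uniformly sampled offset inside the noise.
    offset = rng.randint(0, target_len - len(command))
    for i, s in enumerate(command):
        mixture[offset + i] += command_weight * s
    return mixture, offset
```

Returning the offset makes it possible to mark the keyword location, as done for the attention visualizations in Sec. 3.2.3.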
Baselines. We employ the five on-device models described in Sec. 2.2 as baseline models. We vary the hyperparameter $\lambda$ that weights the attention distillation loss and observe how the predictive performance changes, which shows the effect of our proposed method. When $\lambda = 0$, the loss reduces to the vanilla cross-entropy loss without knowledge distillation.
Metrics. Following [9] and [17], we use mean average precision (mAP) and accuracy to evaluate performance on FSD50K and Noisy Speech Commands v2, which are multi-label and single-label classification datasets, respectively. Higher is better for both metrics.
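For reference, a common way to compute mAP on multi-label data is to average the per-class (non-interpolated) average precision over classes. The sketch below assumes that definition and skips classes with no positive clips; the evaluation code of [9] may differ, e.g., in tie handling. Function names are hypothetical.

```python
from typing import List, Sequence


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Non-interpolated AP: mean precision at the rank of each positive clip."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, precisions = 0, []
    for rank, idx in enumerate(order, start=1):
        if labels[idx]:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(scores: List[List[float]], labels: List[List[int]]) -> float:
    """Macro-average AP over classes; rows are clips, columns are classes."""
    aps = []
    for c in range(len(scores[0])):
        class_labels = [row[c] for row in labels]
        if not any(class_labels):
            continue                         # AP is undefined without positives
        aps.append(average_precision([row[c] for row in scores], class_labels))
    return sum(aps) / len(aps)
```

Accuracy on the single-label KWS task is simply the fraction of clips whose highest-scoring class matches the label.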
Implementation Details. The teacher model (XLSR-wav2vec 2.0) is pretrained on a multilingual speech dataset [7]. We fine-tune the teacher model with a batch size of 16 for 50 epochs. For FSD50K and Noisy Speech Commands v2, we set the learning rate to 2e-5 and 5e-4, respectively, with 1K warmup steps. During training, we apply SpecAugment [16] with probability 0.75, consisting of two frequency masks of 10% width each.
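The frequency masking described above can be sketched as follows, operating on a time-by-feature matrix. The sketch assumes that each mask has a fixed width of 10% of the feature bins and that masked values are set to zero; whether the width is fixed or sampled up to 10% is not specified, and the function name is hypothetical.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]


def frequency_mask(features: Matrix, num_masks: int = 2, width_frac: float = 0.1,
                   p: float = 0.75, rng: Optional[random.Random] = None) -> Matrix:
    """Zero out num_masks frequency bands of width_frac * num_bins each, with probability p.

    features: time-by-frequency matrix (one row per frame).
    """
    rng = rng or random.Random()
    masked = [row[:] for row in features]          # leave the input untouched
    if rng.random() >= p:                          # skip augmentation with probability 1 - p
        return masked
    num_bins = len(masked[0])
    width = max(1, round(width_frac * num_bins))
    for _ in range(num_masks):
        start = rng.randint(0, num_bins - width)  # masks may overlap
        for frame in masked:
            for f in range(start, start + width):
                frame[f] = 0.0
    return masked
```

Each mask spans the same bins in every frame, which is what distinguishes frequency masking from time masking.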

The on-device models are trained for 20K iterations with a batch size of 100. Audio is resampled to 16 kHz, and the student models take MFCC representations of the audio as input. We employ a best-keeping strategy that retains the weights with the best validation performance under the evaluation metric. Our experiments build on existing code from [17].
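The best-keeping strategy amounts to periodic validation with checkpoint selection. The sketch below assumes validation every 400 steps (the interval stated in the table captions) and uses hypothetical callback names for the training step, the evaluation, and the weight snapshot.

```python
from typing import Any, Callable, Optional, Tuple


def train_with_best_keeping(train_step: Callable[[int], None],
                            evaluate: Callable[[], float],
                            snapshot: Callable[[], Any],
                            total_steps: int = 20_000,
                            eval_every: int = 400) -> Tuple[Optional[Any], float, int]:
    """Run training and return (best_weights, best_score, best_step).

    The validation metric (mAP or accuracy) is assumed to be higher-is-better.
    """
    best_weights, best_score, best_step = None, float("-inf"), -1
    for step in range(1, total_steps + 1):
        train_step(step)
        if step % eval_every == 0:
            score = evaluate()
            if score > best_score:          # strict: earlier checkpoint wins ties
                best_weights, best_score, best_step = snapshot(), score, step
    return best_weights, best_score, best_step
```

The reported test score is then computed with `best_weights` rather than with the weights from the final iteration.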

3.2 Evaluation Results

3.2.1 Results on Real-world AED dataset

Model                   Vanilla (λ = 0)   Attention distillation (three λ settings)
wav2vec 2.0 (teacher)   0.5498            N/A
LSTM-P                  0.1141            0.1274   0.1300   0.1043
TC-ResNet               0.1814            0.1841   0.1951   0.1509
CRNN                    0.2789            0.2670   0.2835   0.3053
Att-RNN                 0.2856            0.3471   0.2885   0.2891
MHAtt-RNN               0.2647            0.1694   0.3182   0.3317
Table 1: Performance comparison on the FSD50K dataset. Test mAP of the best model found by the validation is reported, where the validation mAP is obtained every 400 steps.

Table 1 demonstrates the effectiveness of our method on the FSD50K dataset. With the best λ setting, our method clearly improves all the representative architectures we chose: the mAP scores of LSTM-P, TC-ResNet, CRNN, Att-RNN, and MHAtt-RNN increase by 13.9%, 7.6%, 9.5%, 21.5%, and 25.3% (relative), respectively. The improvement is especially large for the attention-based models, Att-RNN and MHAtt-RNN.
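The relative improvements quoted above can be reproduced from Table 1 by comparing the vanilla mAP with the best of the three attention distillation settings. The dictionary below simply transcribes Table 1; the names are illustrative.

```python
# Table 1 FSD50K test mAP: (vanilla, [three attention-distillation settings]).
TABLE_1 = {
    "LSTM-P": (0.1141, [0.1274, 0.1300, 0.1043]),
    "TC-ResNet": (0.1814, [0.1841, 0.1951, 0.1509]),
    "CRNN": (0.2789, [0.2670, 0.2835, 0.3053]),
    "Att-RNN": (0.2856, [0.3471, 0.2885, 0.2891]),
    "MHAtt-RNN": (0.2647, [0.1694, 0.3182, 0.3317]),
}


def relative_gain_percent(vanilla: float, distilled: list) -> float:
    """Relative mAP improvement of the best distillation setting over vanilla, in %."""
    return 100.0 * (max(distilled) - vanilla) / vanilla


gains = {model: round(relative_gain_percent(v, d), 1) for model, (v, d) in TABLE_1.items()}
```

Note that these figures use the best λ per architecture; some individual settings (e.g., the third LSTM-P column) fall below the vanilla score.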

3.2.2 Results on Noisy KWS dataset

Audio  Model                   Vanilla (λ = 0)   Attention distillation (three λ settings)
2s     wav2vec 2.0 (teacher)   90.59             N/A
       LSTM-P                  88.73             88.98   89.31   88.92
       TC-ResNet               87.77             88.08   86.21   86.27
       CRNN                    89.96             90.06   90.00   89.46
       Att-RNN                 89.88             91.31   91.67   90.94
       MHAtt-RNN               89.75             91.25   91.67   91.75
4s     wav2vec 2.0 (teacher)   91.22             N/A
       LSTM-P                  85.19             88.23   88.52   89.08
       TC-ResNet               87.60             88.33   87.21   84.62
       CRNN                    89.69             89.44   89.46   90.21
       Att-RNN                 90.65             91.98   91.79   91.58
       MHAtt-RNN               91.19             91.58   92.12   91.73
6s     wav2vec 2.0 (teacher)   90.93             N/A
       LSTM-P                  45.27             69.21   85.58   85.10
       TC-ResNet               86.00             86.85   84.23   82.10
       CRNN                    88.58             89.88   89.29   89.46
       Att-RNN                 90.88             90.77   90.73   91.19
       MHAtt-RNN               90.58             90.96   91.67   91.10
8s     wav2vec 2.0 (teacher)   90.95             N/A
       LSTM-P                  78.44             82.19   34.58   66.25
       TC-ResNet               77.81             85.79   85.71   80.15
       CRNN                    88.94             89.02   89.79   89.77
       Att-RNN                 88.81             90.44   90.98   90.75
       MHAtt-RNN               88.33             91.50   91.79   91.35
Table 2: Performance comparison on Noisy Speech Commands v2 dataset. Test accuracy (%) of the best model found by the validation accuracy is reported, where the validation accuracy is obtained every 400 steps. Best accuracies are in bold, and the performance of the student models that outperform the teacher model is underlined.

Table 2 summarizes the evaluation results on Noisy Speech Commands v2. For every student architecture and every noise duration, at least one λ setting of our method outperforms the vanilla model. Furthermore, attention-based student models trained with attention distillation often exceed the teacher's performance, whereas the vanilla student models are always inferior to the teacher. We also observe that, for the non-attention-based models, the accuracy gap between training with the vanilla and the attention distillation losses tends to grow as the audio length increases. This suggests that these models benefit from the temporal knowledge extracted from attention weights without any architectural change.

3.2.3 Further Analysis

Figure 2: Visualization of attention weights extracted from multiple models. We input an arbitrary sample from the Noisy Speech Commands v2 dataset with 8 seconds of noise. The location of the one-second keyword is marked in all plots.

In Figure 2, we visualize the attention weights extracted from the teacher model and all student models for an 8-second sample from Noisy Speech Commands v2. For each student architecture, we select the model with the best accuracy in Table 2. The gray region marks the keyword location within the audio. We observe that, even though the teacher model is trained only with the classification label, its attention weights successfully focus on the keyword location. All the on-device models also attend to similar positions inside the keyword location, indicating that the teacher and student attention weights are well aligned.

4 Conclusion

In this paper, we proposed a novel attention distillation method that transfers temporal knowledge from large teacher models to on-device student audio classification models. We extract attention weights from both the teacher and the student models and align them via KL divergence. Our method can be applied to various architectures with no architectural change during inference. Extensive experiments on an audio event detection dataset and a noisy keyword spotting dataset show that our method improves the predictive performance of on-device models by successfully learning the attention weights of the large teacher model.


  • [1] S. Abnar and W. H. Zuidema (2020). Quantifying attention flow in transformers. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (ACL 2020), Online, July 5-10.
  • [2] S. Ö. Arik, M. Kliegl, R. Child, J. Hestness, A. Gibiansky, C. Fougner, R. Prenger, and A. Coates (2017). Convolutional recurrent neural networks for small-footprint keyword spotting. In Proc. Interspeech 2017.
  • [3] A. Berg, M. O’Connor, and M. T. Cruz (2021). Keyword Transformer: A self-attention model for keyword spotting. In Proc. Interspeech 2021, pp. 4249–4253.
  • [4] C. Chang, C. Kao, M. Sun, and C. Wang (2020). Intra-utterance similarity preserving knowledge distillation for audio tagging. In Proc. Interspeech 2020.
  • [5] S. Choi, S. Seo, B. Shin, H. Byun, M. Kersner, B. Kim, D. Kim, and S. Ha (2019). Temporal convolution for real-time keyword spotting on mobile devices. In Proc. Interspeech 2019.
  • [6] D. Coimbra de Andrade, S. Leo, M. Loesener Da Silva Viana, and C. Bernkopf (2018). A neural attention model for speech command recognition. arXiv preprint arXiv:1808.08929.
  • [7] A. Conneau, A. Baevski, R. Collobert, A. Mohamed, and M. Auli (2020). Unsupervised cross-lingual representation learning for speech recognition. arXiv preprint arXiv:2006.13979.
  • [8] J. Devlin, M. Chang, K. Lee, and K. Toutanova (2019). BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL 2019, Volume 1 (Long and Short Papers).
  • [9] E. Fonseca, X. Favory, J. Pons, F. Font, and X. Serra (2020). FSD50K: An open dataset of human-labeled sound events. arXiv preprint arXiv:2010.00475.
  • [10] H. Futami, H. Inaguma, S. Ueno, M. Mimura, S. Sakai, and T. Kawahara (2020). Distilling the knowledge of BERT for sequence-to-sequence ASR. In Proc. Interspeech 2020.
  • [11] J. F. Gemmeke, D. P. W. Ellis, D. Freedman, A. Jansen, W. Lawrence, R. C. Moore, M. Plakal, and M. Ritter (2017). Audio Set: An ontology and human-labeled dataset for audio events. In Proc. IEEE ICASSP 2017, New Orleans, LA.
  • [12] F. A. Gers, N. N. Schraudolph, and J. Schmidhuber (2002). Learning precise timing with LSTM recurrent networks. J. Mach. Learn. Res. 3, pp. 115–143.
  • [13] G. Hinton, O. Vinyals, and J. Dean (2015). Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
  • [14] L. Lu, M. Guo, and S. Renals (2017). Knowledge distillation for small-footprint highway networks. In Proc. IEEE ICASSP 2017, pp. 4820–4824.
  • [15] R. M. Mun’im, N. Inoue, and K. Shinoda (2019). Sequence-level knowledge distillation for model compression of attention-based sequence-to-sequence speech recognition. In Proc. IEEE ICASSP 2019, pp. 6151–6155.
  • [16] D. S. Park, W. Chan, Y. Zhang, C. Chiu, B. Zoph, E. D. Cubuk, and Q. V. Le (2019). SpecAugment: A simple data augmentation method for automatic speech recognition. In Proc. Interspeech 2019, pp. 2613–2617.
  • [17] O. Rybakov, N. Kononenko, N. Subrahmanya, M. Visontai, and S. Laurenzo (2020). Streaming keyword spotting on mobile devices. In Proc. Interspeech 2020.
  • [18] R. V. Swaminathan, B. King, G. P. Strimel, J. Droppo, and A. Mouchtaris (2021). CoDERT: Distilling encoder representations with co-learning for transducer-based speech recognition. CoRR abs/2106.07734.
  • [19] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017). Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008.
  • [20] P. Warden (2018). Speech Commands: A dataset for limited-vocabulary speech recognition. arXiv preprint arXiv:1804.03209.
  • [21] Z. Yang, D. Yang, C. Dyer, X. He, A. Smola, and E. Hovy (2016). Hierarchical attention networks for document classification. In Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pp. 1480–1489.
  • [22] S. Zagoruyko and N. Komodakis (2017). Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer. In Proc. ICLR 2017.