Order-free Learning Alleviating Exposure Bias in Multi-label Classification

09/08/2019 ∙ by Che-Ping Tsai, et al.

Multi-label classification (MLC) assigns multiple labels to each sample. Prior studies show that MLC can be transformed to a sequence prediction problem with a recurrent neural network (RNN) decoder to model the label dependency. However, training an RNN decoder requires a predefined order of labels, which is not directly available in the MLC specification. Besides, an RNN trained this way tends to overfit the label combinations in the training set and has difficulty generating unseen label sequences. In this paper, we propose a new framework for MLC which does not rely on a predefined label order and thus alleviates exposure bias. The experimental results on three multi-label classification benchmark datasets show that our method outperforms competitive baselines by a large margin. We also find that the proposed approach has a higher probability of generating label combinations not seen during training than the baseline models, indicating that it has better generalization capability.





Multi-label classification (MLC) is a fundamental but challenging problem in machine learning with applications such as text categorization [25], sound event detection [6, 26], and image classification [20]. In contrast to single-label classification, multi-label predictors must not only relate labels to the corresponding instances but also exploit the underlying label structure. For instance, the text classification dataset RCV1 [10] organizes its labels in a hierarchical tree structure.

Recent studies show that MLC can be transformed to a sequence prediction problem by probabilistic classifier chains (PCC) [15, 3, 5]. PCC models the joint probability of the output labels with the chain rule and predicts each label based on the previously generated labels. Furthermore, PCC can be replaced by an RNN decoder to model label correlation. Wang et al. (2016) propose the CNN-RNN architecture to capture both image-label relevance and semantic label dependency in multi-label image classification. Nam et al. (2017) and Yang et al. (2018) show that state-of-the-art multi-label text classification results can be achieved by using a sequence-to-sequence (seq2seq) architecture to encode input text sequences and decode labels sequentially.
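For reference, the chain-rule factorization that PCC relies on can be written as follows, where $x$ is the input and $y_1, \ldots, y_L$ are the labels taken in the chosen chain order (notation ours, not taken from the cited works):

$$ p(y_1, \ldots, y_L \mid x) = \prod_{i=1}^{L} p\left(y_i \mid x, y_1, \ldots, y_{i-1}\right) $$

Each factor conditions on the labels placed earlier in the chain, so the model depends on which label comes first, second, and so on; this is where the ordering problem discussed next comes from.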

However, this kind of RNN-based decoder suffers from several problems. First, these models are trained using maximum likelihood estimation (MLE) on target label sequences, which relies on a predefined ordering of labels. Previous studies [23, 24] show that the ordering has a significant impact on performance. This issue also appears in PCC, where it has been addressed by ensemble averaging [15, 3], ensemble pruning [11], pre-analysis of the label dependencies with Bayes nets [19], and integrating beam search with training to determine a suitable label ordering [9]. However, these approaches rely on training multiple models to ensemble or to determine a proper order, which is computationally expensive.

Although Nam et al. (2017) and Yang et al. (2018) compare several ordering strategies and suggest ordering positive labels by frequency in descending order (from frequent to rare labels), imposing a strict order on labels is unnatural and may break down label correlations in a chain. Furthermore, we find that this kind of model tends to overfit the label combinations in the training set and generalizes poorly.

Second, during training, the RNN-based models are always conditioned on correct prefixes, whereas during inference the prefixes are generated by the model itself, which leads to a problem known as exposure bias [14] in seq2seq learning. Errors may then propagate, since the model can end up in a part of the state space that it has not seen during training [17].

In this paper, we propose a novel learning algorithm for RNN-based decoders in multi-label classification that does not rely on a predefined label order. The proposed approach is inspired by optimal completion distillation (OCD) [16], a training procedure for optimizing seq2seq models. In this algorithm, we feed the RNN decoder tokens sampled from the current model. Hence, the model encounters different label orders and wrong prefixes during training, so it explores more and exposure bias is alleviated.

Another common and straightforward way to avoid the need for ordered labels in MLC is binary relevance (BR) [21], which decomposes MLC into multiple independent binary classification problems, one per label. However, this yields a model that cannot take advantage of label co-occurrences. In this paper, we further propose to help the model learn better with an auxiliary BR decoder that is trained jointly with the RNN decoder in a multitask learning (MTL) framework.

In addition, at the inference stage, the predictions of the BR decoder can be combined with those of the RNN decoder to further improve performance. We propose two methods to combine their probabilities. Extensive experiments show that the proposed model outperforms competitive baselines by a large margin on three multi-label classification benchmark datasets: two text classification datasets and one sound event classification dataset.

The contributions of this paper are as follows:

  • We propose a novel training algorithm for multi-label classification which predicts labels autoregressively but does not require a predefined label order for training.

  • We compare our methods with competitive baseline models on three multi-label classification datasets and demonstrate the effectiveness of the proposed models.

  • We systematically analyze the problem of exposure bias and the effectiveness of scheduled sampling [1] in multi-label classification.

Related work

RNN-based multi-label classification

To free the RNN-based MLC classifier from a predefined label order, Chen et al. (2017) [2] propose the order-free RNN, which dynamically decides the target label at each time step during training by choosing the label in the target label set with the highest predicted probability; hence, the model learns a label order by itself. Although the order can change during training, this approach still needs an initial label order to start the training process. In our experiments, the order-free RNN generalizes poorly to unseen label combinations. Also, as the model is always supplied with correct labels, it suffers from exposure bias.
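To make the order-free RNN's target selection concrete, here is a minimal sketch of the rule described above, written under our own assumptions: at each step the target is the not-yet-emitted ground-truth label with the highest predicted probability, and `<eos>` once every target label has been emitted. The helper name `order_free_target` and the dict-based representation are hypothetical and do not come from [2].

```python
from typing import Dict, List, Set

EOS = "<eos>"


def order_free_target(probs: Dict[str, float], targets: Set[str], emitted: List[str]) -> str:
    """Training target for one decoding step of an order-free RNN (hypothetical helper).

    probs:   the model's predicted probability for each label at this step
    targets: the ground-truth label set of the sample
    emitted: labels already fed to the decoder (correct ones only, since the
             order-free RNN is always conditioned on ground-truth prefixes)
    """
    remaining = targets - set(emitted)
    if not remaining:
        return EOS  # every target label has been emitted, so the target is <eos>
    # The target is the remaining ground-truth label the model finds most likely.
    return max(remaining, key=lambda label: probs.get(label, 0.0))


probs = {"A": 0.1, "B": 0.5, "C": 0.3, "D": 0.05, EOS: 0.05}
print(order_free_target(probs, {"A", "B", "D"}, []))     # B
print(order_free_target(probs, {"A", "B", "D"}, ["B"]))  # A
```

Label C is never chosen as a target even though the model rates it above A, because only ground-truth labels qualify; the decoder is then fed that correct label rather than its own prediction, which is the source of the exposure bias mentioned above.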

To handle both exposure bias and label order, other studies apply reinforcement learning (RL) algorithms to MLC. He et al. (2018) apply an off-policy Q-learning algorithm to multi-label image classification. Yang et al. (2018) use two decoders for multi-label text classification, one trained with MLE and the other with a self-critical policy gradient algorithm. However, Q-learning and policy gradients cannot easily incorporate ground-truth sequence information except via the reward function, as the model is rewarded only at the end of each episode. Indeed, the method of He et al. (2018) does not work without pretraining on the target dataset. By contrast, we use optimal completion distillation (OCD) [16] for MLC, which optimizes a token-level log-loss; training is therefore stable and requires neither a pretrained initialization nor joint optimization with MLE.

Optimal Completion Distillation (OCD)

Our work is inspired by OCD [16], which was first used for end-to-end speech recognition, where it achieved state-of-the-art performance. In contrast to MLE, OCD encourages the model to extend the current prefix with any token that leads to the optimal edit distance, assigning equal probabilities to all such tokens in the target policy that the model learns from. We use OCD to train the RNN decoder in MLC; the training details are given in section Learning for RNN decoder. In contrast to the original OCD [16], which optimizes the edit distance, the proposed approach optimizes the number of missing and false-alarm labels.

Model architecture

An overview of the proposed model is shown in Fig. 1. The model is composed of three components: an encoder, an RNN decoder, and a binary relevance decoder. Here, we describe the model for multi-label text classification as an instance of MLC; for other types of MLC (e.g., sound event classification), the encoder architecture can be changed accordingly.

Figure 1: Overview of the proposed model, which is composed of three components: an encoder, an RNN decoder, and a binary relevance decoder. The RNN decoder outputs a sampled label sequence, while the binary relevance decoder outputs a vector of predicted label probabilities.

We employ a bidirectional LSTM as the encoder. It reads the input sequence of words in both the forward and backward directions and computes a hidden state for each word.


RNN decoder

The RNN decoder seeks to predict labels sequentially. It is potentially more powerful than the binary relevance decoder because each prediction is conditioned on the previous predictions, so it implicitly learns label dependencies. We implement it using LSTMs with an attention mechanism; hence, the encoder and RNN decoder form a seq2seq model. In particular, after setting the initial hidden state of the decoder, we compute the decoder hidden state and output at each time step from the label predicted at the previous time step, and estimate the label probabilities from them.


During sampling, we add a mask vector [25], whose length equals the number of labels in the dataset, to prevent the model from predicting repeated labels.
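A minimal sketch of masked sampling, assuming the convention we understand from [25]: the mask holds minus infinity at the indices of labels already generated and 0 elsewhere, and it is added to the decoder's output scores before the softmax. The function names and the plain-list representation are illustrative assumptions, not the authors' code.

```python
import math
import random
from typing import List, Sequence


def masked_softmax(logits: Sequence[float], generated: Sequence[int]) -> List[float]:
    """Softmax over label scores after adding a mask that is -inf at already
    generated label indices and 0 elsewhere (assumed mask convention)."""
    mask = [0.0] * len(logits)
    for idx in generated:
        mask[idx] = -math.inf
    scores = [s + m for s, m in zip(logits, mask)]
    top = max(scores)  # subtract the max for numerical stability; at least one entry must stay finite
    exps = [math.exp(s - top) for s in scores]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def sample_label(logits: Sequence[float], generated: Sequence[int], rng=random) -> int:
    """Draw the next label index; masked (already generated) labels have probability 0."""
    probs = masked_softmax(logits, generated)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.5, 0.0]                # three labels plus <eos> at index 3
print(masked_softmax(logits, generated=[0])[0])  # 0.0: label 0 can no longer be chosen
print(sample_label(logits, [0], rng) != 0)       # True
```

The `<eos>` index is never masked here, so the decoder can always stop; how the real model handles `<eos>` in the mask is not stated in the text above.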


Binary relevance decoder

The binary relevance (BR) decoder is an auxiliary decoder that helps train the encoder within the multitask learning (MTL) framework. It acts as an independent binary classifier for each label, helping the model learn better. Another advantage of the BR decoder is that the predictions of both the RNN and BR decoders can be combined to further improve performance.

In particular, we feed the final hidden state of the encoder to a DNN whose final prediction layer has one output per label, with sigmoid activations, to predict the probability of each label. To mitigate vanishing gradients for long input sequences, we add another attention module: we calculate the context vector of the attention mechanism using the output of the fully-connected layers, and then compute the label probabilities by applying the sigmoid function to the product of a weight matrix and the concatenation of the fully-connected output and the context vector.

Order-Free Training

In this section, we derive the training objectives for the RNN decoder and the BR decoder, and the multitask learning objective.

Learning for RNN decoder

RNN decoder learning as RL

To reduce exposure bias and free the model from relying on a predefined label order, we never train on ground-truth target sequences. Instead, we approach the MLC problem from an RL perspective. The model plays the role of an agent whose action $\hat{y}_t$ is the label generated at time $t$ and whose state $\hat{y}_{<t}$ is the sequence of labels output before time $t$. The policy $\pi_\theta(\hat{y}_t \mid \hat{y}_{<t}, x)$ is a probability distribution over actions $\hat{y}_t$ given the state $\hat{y}_{<t}$ and the input $x$. Once the process ends with an end-of-sentence token, the agent receives a reward $r$.

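In our approach, the reward is defined as

$$ r(\hat{y}, y) = -\left( \left| y \setminus \hat{y} \right| + \left| \hat{y} \setminus y \right| \right) \qquad (7) $$

where $y$ and $\hat{y}$ are the ground-truth labels and the sequence of labels generated by the RNN decoder (both treated as sets), and $B \setminus A$ is the relative complement of set A in set B.

The first and second terms inside the parentheses are the number of labels that were not predicted and the number of misclassified labels, respectively.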
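The reward can be sketched in a few lines of Python. This assumes the reward is the negated count of missing plus false-alarm labels, which matches Table 1, where a perfect prediction earns a total reward of 0 and a wrong label lowers it; the function name `reward` is our own.

```python
from typing import Iterable

EOS = "<eos>"


def reward(generated: Iterable[str], targets: Iterable[str]) -> int:
    """Negated number of missing labels plus false-alarm labels (Eq. 7 as we read it)."""
    pred = {label for label in generated if label != EOS}
    gold = set(targets)
    missing = gold - pred        # relative complement of pred in gold: labels never predicted
    false_alarms = pred - gold   # relative complement of gold in pred: wrongly predicted labels
    return -(len(missing) + len(false_alarms))


print(reward(["B", "C", "A", "D", EOS], {"A", "B", "D"}))  # -1: C is a false alarm
print(reward(["B", "A", "D", EOS], {"A", "B", "D"}))       # 0: perfect prediction
```

Because the labels are compared as sets, the reward is the same for every permutation of a correct label set, which is what lets the training avoid a predefined order.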

Optimal completion distillation

However, typical RL algorithms, such as Q-learning and policy gradients, cannot easily incorporate ground-truth sequence information except via the reward function. Here we introduce optimal Q-values, which evaluate each action $a$ at each time step $t$.

The optimal Q-value $Q^*(\hat{y}_{<t}, a)$ is the maximum total reward the agent can acquire after taking action $a$ at state $\hat{y}_{<t}$ and then following the optimal action sequence. It can be expressed as

$$ Q^*(\hat{y}_{<t}, a) = \max_{\tilde{y}} \, r\big([\hat{y}_{<t}, a, \tilde{y}], y\big) \qquad (8) $$

where $y$ is the ground truth and $[\hat{y}_{<t}, a, \tilde{y}]$ is a complete sequence formed by concatenating the tokens generated before time $t$, the action $a$ at time $t$, and a subsequent action sequence $\tilde{y}$; the maximizing $\tilde{y}$ is the optimal subsequent action sequence.

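The optimal policy at time $t$ is obtained by taking a softmax over the optimal Q-values $Q^*$ of all possible actions. Formally,

$$ \pi^*(a \mid \hat{y}_{<t}) = \frac{\exp\big(Q^*(\hat{y}_{<t}, a) / \tau\big)}{\sum_{a'} \exp\big(Q^*(\hat{y}_{<t}, a') / \tau\big)} \qquad (9) $$

where $\tau$ is a temperature parameter. If $\tau$ is close to 0, $\pi^*$ is a hard target that spreads its probability uniformly over the actions with the highest optimal Q-value. Table 1 shows an example illustrating the optimal policy in the OCD training procedure.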
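For the missing-plus-false-alarm reward, the optimal Q-values have a simple closed form, which the sketch below uses instead of searching over completions. It assumes, as in Table 1, that labels already in the prefix are masked out. A target label not yet emitted keeps the best achievable reward unchanged, a non-target label costs one extra false alarm, and stopping with `<eos>` leaves all remaining targets missing. The helper names are hypothetical.

```python
import math
from typing import Dict, Iterable, Sequence, Set

EOS = "<eos>"


def optimal_q_values(prefix: Sequence[str], targets: Set[str], labels: Iterable[str]) -> Dict[str, float]:
    """Optimal Q-value of each allowed next action under r = -(|missing| + |false alarms|)."""
    done = set(prefix)
    false_alarms = len(done - targets)  # already committed; no completion can undo them
    missing = targets - done
    q = {}
    for a in labels:
        if a in done:
            continue                       # masked: labels are never repeated
        if a in targets:
            q[a] = -false_alarms           # best completion emits the remaining targets, then <eos>
        else:
            q[a] = -false_alarms - 1       # one extra false alarm; the rest can still be completed
    q[EOS] = -false_alarms - len(missing)  # stopping now leaves every missing target missing
    return q


def optimal_policy(q: Dict[str, float], tau: float) -> Dict[str, float]:
    """Softmax over optimal Q-values with temperature tau (Eq. 9)."""
    top = max(q.values())
    weights = {a: math.exp((v - top) / tau) for a, v in q.items()}
    total = sum(weights.values())
    return {a: w / total for a, w in weights.items()}


# Time 0 of Table 1: targets {A, B, D}, empty prefix, near-zero temperature.
print(optimal_policy(optimal_q_values([], {"A", "B", "D"}, "ABCD"), tau=1e-3))
# roughly 1/3 each for A, B, D and 0 for C and <eos>
```

With a small `tau`, the non-optimal actions receive weights that underflow to 0, reproducing the hard targets listed in Table 1.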

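Given a dataset $\mathcal{D}$ of input-label pairs $(x, y)$, we first draw a generated sequence $\hat{y}$ from the RNN decoder by sampling. The loss function $\mathcal{L}_{\mathrm{OCD}}$ is then obtained by calculating the KL divergence between the optimal policy and the model's distribution over labels at every time step $t$:

$$ \mathcal{L}_{\mathrm{OCD}} = \sum_{(x, y) \in \mathcal{D}} \sum_{t} D_{\mathrm{KL}}\big( \pi^*(\cdot \mid \hat{y}_{<t}) \,\big\|\, \pi_\theta(\cdot \mid \hat{y}_{<t}, x) \big) \qquad (10) $$

where $\pi_\theta$ is the distribution predicted by the RNN decoder. In other words, we “distill” knowledge from the optimal policies, obtained by completing each prefix with optimal action sequences, into the RNN decoder, so that the RNN decoder learns to behave like the optimal policy.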
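The per-sequence loss is a sum of KL divergences, one per decoding step. The sketch below computes it for distributions stored as dicts; the helper names are ours, and it assumes the model assigns nonzero probability wherever the optimal policy does (otherwise the KL divergence is infinite and `math.log` would fail).

```python
import math
from typing import Dict, List


def kl_divergence(target: Dict[str, float], model: Dict[str, float]) -> float:
    """KL(target || model) over a discrete action set; zero-probability target terms contribute 0."""
    return sum(p * math.log(p / model[a]) for a, p in target.items() if p > 0.0)


def ocd_loss(optimal_policies: List[Dict[str, float]], model_dists: List[Dict[str, float]]) -> float:
    """Sum of per-step KL divergences along one sampled sequence (hypothetical helper)."""
    return sum(kl_divergence(t, m) for t, m in zip(optimal_policies, model_dists))


uniform = {"A": 0.5, "B": 0.5}
print(ocd_loss([{"A": 1.0}, {"B": 1.0}], [uniform, uniform]))  # 2 * log 2, about 1.386
```

Since the entropy of the target policy does not depend on the model, minimizing this KL divergence is equivalent to minimizing the token-level cross-entropy (log-loss) against the optimal policy.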

In contrast to MLE, OCD encourages the model to extend the prefix with any target that results in the same evaluation metric score. Therefore, the OCD objective focuses on all target labels that have not yet been predicted and assigns them equal probabilities. Once all the target labels have been generated, the objective guides the model to produce the end-of-sentence token with probability 1 (for $\tau \to 0$).


Since the OCD targets depend only on the tokens generated previously, we do not need a human-defined label order to train the RNN decoder. The label order is instead automatically determined at each time step. In addition, we always train on sequences generated from the current model, thus alleviating exposure bias. Note that we can substitute the reward function (Eq. 7) with other example-based test metrics such as the example-based F1 score (Eq. 19 in the Appendix), but these lead to the same OCD targets as the rewards of all the target labels are the same.

Time | OCD targets | Optimal policy $\pi^*$ over (A, B, C, D, <eos>) | Prediction
0 | A, B, D | (1/3, 1/3, 0, 1/3, 0) | B
1 | A, D | (1/2, 0, 0, 1/2, 0) | C
2 | A, D | (1/2, 0, 0, 1/2, 0) | A
3 | D | (0, 0, 0, 1, 0) | D
4 | <eos> | (0, 0, 0, 0, 1) | <eos>
Table 1: A training example of optimal completion distillation, where there are 4 kinds of labels (A, B, C, D) and an end-of-sentence token (<eos>). Labels A, B, and D are the targets of this instance, and the vectors of $\pi^*$ represent the probabilities of labels A, B, C, D and <eos>, respectively. We set $\tau \to 0$ here, so the optimal policy only encourages labels with the highest optimal Q-values. For example, at time 1, since the model has predicted the correct token B at time 0, there are two optimal extended tokens {A, D}, which result in a total reward of 0 (Eq. 7) when combined with proper suffixes. Then we sample from the current policy and predict the incorrect token C, which decreases the best achievable reward to -1 (Eq. 7) at time 2.

Learning for binary relevance decoder

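For the binary relevance decoder, given the ground truth in binary format $\mathbf{y} \in \{0, 1\}^L$, where $L$ is the number of labels, we use the binary cross-entropy loss as the objective:

$$ \mathcal{L}_{\mathrm{BR}} = -\sum_{i=1}^{L} \big[ y_i \log p_i + (1 - y_i) \log (1 - p_i) \big] \qquad (11) $$

where $\mathbf{p} = (p_1, \ldots, p_L)$ is the output of the BR decoder, a vector of length $L$ representing the predicted probability of each label.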
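As a sketch of the BR objective, the binary cross-entropy over labels takes a few lines. We sum over labels (rather than averaging) and clip probabilities away from 0 and 1 to avoid `log(0)`; both choices are our assumptions, since the normalization used in the experiments is not stated here.

```python
import math
from typing import Sequence


def br_loss(y: Sequence[int], p: Sequence[float], eps: float = 1e-12) -> float:
    """Binary cross-entropy summed over labels (Eq. 11); y is 0/1, p holds sigmoid outputs."""
    total = 0.0
    for yi, pi in zip(y, p):
        pi = min(max(pi, eps), 1.0 - eps)  # clip so that both logarithms stay finite
        total -= yi * math.log(pi) + (1 - yi) * math.log(1.0 - pi)
    return total


print(br_loss([1, 0, 1], [0.9, 0.2, 0.6]))  # small when the probabilities agree with the labels
```

Each label contributes an independent term, which is exactly why BR alone cannot model label co-occurrence.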

Multitask Learning (MTL)

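The MTL objective combines the OCD loss of the RNN decoder and the binary cross-entropy loss of the BR decoder, with a weight that balances the two terms.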

Decoder Integration

In this section, we seek to use both the RNN and BR decoders to find the optimal hypothesis $\hat{y}^* = (\hat{y}_1, \ldots, \hat{y}_T)$, which consists of predicted labels $\hat{y}_1, \ldots, \hat{y}_{T-1}$, where $\hat{y}_T$ is the end-of-sentence token indicating that the decoding process of the RNN decoder has ended.

For the BR decoder, the outputs $p_i$ after the sigmoid activation are designed to estimate the posterior probability $P(y_i = 1 \mid x)$ of each label $i$. Therefore, the theoretically optimal threshold for converting a probability into a binary value is 0.5, which is equivalent to finding the hypothesis that maximizes Eq. 13 below, the product of the probabilities of all the labels:

$$ P_{\mathrm{BR}}(\hat{y} \mid x) = \prod_{i \in \hat{y}} p_i \prod_{i \notin \hat{y}} (1 - p_i) \qquad (13) $$
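The equivalence between thresholding at 0.5 and maximizing Eq. 13 holds because each label contributes its own factor, $p_i$ if selected and $1 - p_i$ otherwise, and the larger factor is chosen exactly when $p_i > 0.5$. The sketch below checks this by brute force over all label subsets; the function names are ours, and probabilities are assumed to lie strictly between 0 and 1.

```python
import itertools
import math
from typing import FrozenSet, Sequence


def br_log_prob(selected: FrozenSet[int], p: Sequence[float]) -> float:
    """log of prod_{i in selected} p_i * prod_{i not in selected} (1 - p_i)  (Eq. 13)."""
    return sum(math.log(pi) if i in selected else math.log(1.0 - pi) for i, pi in enumerate(p))


def threshold_decode(p: Sequence[float], threshold: float = 0.5) -> FrozenSet[int]:
    """Select every label whose probability exceeds the threshold."""
    return frozenset(i for i, pi in enumerate(p) if pi > threshold)


def brute_force_decode(p: Sequence[float]) -> FrozenSet[int]:
    """Enumerate all 2^L label subsets and keep the one with the highest Eq. 13 score."""
    subsets = (frozenset(c) for r in range(len(p) + 1) for c in itertools.combinations(range(len(p)), r))
    return max(subsets, key=lambda s: br_log_prob(s, p))


p = [0.9, 0.2, 0.6, 0.4]
print(threshold_decode(p) == brute_force_decode(p))  # True
```

Brute force is exponential in the number of labels and serves only as a check; thresholding gives the same answer in linear time.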


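For the RNN decoder, inference is typically performed with a beam search that maximizes Eq. 14 [25, 2]. Given input $x$, the probability of the predicted hypothesis path is

$$ P_{\mathrm{RNN}}(\hat{y} \mid x) = \prod_{t=1}^{T} p(\hat{y}_t \mid \hat{y}_{<t}, x) \qquad (14) $$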


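To combine the predictions of the RNN and BR decoders, we simply take the product of Eq. 13 and Eq. 14 to yield the final objective function, Eq. 15:

$$ \hat{y}^* = \arg\max_{\hat{y}} \, P_{\mathrm{RNN}}(\hat{y} \mid x) \, P_{\mathrm{BR}}(\hat{y} \mid x) \qquad (15) $$

where $P_{\mathrm{BR}}$ and $P_{\mathrm{RNN}}$ are the probabilities given by Eq. 13 and Eq. 14, respectively.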


Nonetheless, this equation is not easy to solve, because the RNN decoder outputs the probability of selecting a particular label at each time step, while the BR decoder produces the unconditional probabilities of all the labels at once. To incorporate the BR label probabilities into the score, we provide two decoding strategies to find the best hypothesis.

Logistic Rescoring

In this method, we first obtain a set of complete hypotheses using beam search with the RNN decoder only, and then rescore each hypothesis using the probabilities produced by the BR decoder with Eq. 13. Finally, we select the hypothesis with the highest rescored probability as the final prediction.
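A minimal sketch of logistic rescoring, assuming the rescored probability is the product in Eq. 15, i.e., the sum of the RNN log-probability from beam search and the BR log-probability from Eq. 13. The beam contents below are made-up numbers for illustration, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Hypothesis = Tuple[Tuple[int, ...], float]  # (label indices, RNN log-probability from beam search)


def br_log_prob(labels: Sequence[int], p: Sequence[float]) -> float:
    """log P_BR (Eq. 13): log p_i for chosen labels, log(1 - p_i) for the rest."""
    chosen = set(labels)
    return sum(math.log(pi) if i in chosen else math.log(1.0 - pi) for i, pi in enumerate(p))


def logistic_rescore(beam: List[Hypothesis], p: Sequence[float]) -> Tuple[int, ...]:
    """Pick the complete hypothesis maximizing log P_RNN + log P_BR (assumed combination)."""
    best_labels, _ = max(beam, key=lambda h: h[1] + br_log_prob(h[0], p))
    return best_labels


beam = [((0,), math.log(0.5)), ((0, 2), math.log(0.3))]
p = [0.9, 0.1, 0.8]
print(logistic_rescore(beam, p))  # (0, 2): BR evidence for label 2 outweighs the RNN preference
```

The RNN alone would return `(0,)`, since it has the higher beam score; rescoring can only choose among hypotheses the beam already contains, which is the limitation that one-pass joint decoding addresses.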

Logistic Joint Decoding

This method decodes in a single pass. We conduct a beam search whose score incorporates the BR probabilities at every step (see the derivation in Appendix Derivation of logistic joint decoding).

Note that we manually set the BR probability of the end-of-sentence token, since this token does not exist in the outputs of the BR decoder.
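One way to see why Eq. 15 can be scored step by step (our own derivation, which may differ in detail from the appendix): Eq. 13 equals $\prod_i (1 - p_i)$, a constant that does not depend on the hypothesis, times $\prod_{i \in \hat{y}} p_i / (1 - p_i)$. So each beam step can add the RNN log-probability plus the BR log-odds of the emitted label. The sketch checks numerically that this score differs from the log of Eq. 15 only by that constant. The BR probability of 0.5 for `<eos>` (zero log-odds) is an assumption, since the value used in the experiments is not given here.

```python
import math
from typing import Sequence


def joint_step_score(rnn_prob: float, br_prob: float) -> float:
    """Per-step score: RNN log-probability plus BR log-odds of the emitted token."""
    return math.log(rnn_prob) + math.log(br_prob / (1.0 - br_prob))


def joint_sequence_score(rnn_probs: Sequence[float], labels: Sequence[int],
                         p: Sequence[float], eos_br_prob: float = 0.5) -> float:
    """Sum of step scores; rnn_probs has one entry per label plus a final one for <eos>."""
    score = sum(joint_step_score(r, p[l]) for r, l in zip(rnn_probs, labels))
    return score + joint_step_score(rnn_probs[len(labels)], eos_br_prob)  # assumed <eos> BR prob


def eq15_log_score(rnn_probs: Sequence[float], labels: Sequence[int], p: Sequence[float]) -> float:
    """log P_RNN + log P_BR for a complete hypothesis (Eq. 14 times Eq. 13, in log space)."""
    chosen = set(labels)
    log_rnn = sum(math.log(r) for r in rnn_probs)
    log_br = sum(math.log(pi) if i in chosen else math.log(1.0 - pi) for i, pi in enumerate(p))
    return log_rnn + log_br


p = [0.9, 0.2, 0.7]
const = sum(math.log(1.0 - pi) for pi in p)  # identical for every hypothesis
diff = eq15_log_score([0.6, 0.5, 0.8], [0, 2], p) - joint_sequence_score([0.6, 0.5, 0.8], [0, 2], p)
print(abs(diff - const) < 1e-9)  # True
```

Because the constant is shared by all hypotheses, ranking beams by the per-step score is the same as ranking them by Eq. 15, provided labels are not repeated (the decoder's mask guarantees this).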

Experimental Setup

We validate our proposed model on two multi-label text classification datasets, AAPD [25] and Reuters-21578, and on a sound event classification dataset, Audio set [6], released by Google. They span a wide variety in terms of the number of samples, the number of labels, and the number of words per sample. Due to space limits, an additional text classification experiment on RCV1-V2 [10], the data statistics, the experimental settings, and descriptions of the five evaluation metrics are given in the Appendix.

Models | maF1 | miF1 | ebF1 | ACC | HA | Average
(a) Seq2set (simp.) [24] | - | 0.705 | - | - | 0.9753 | -
(b) Seq2set [24] | - | 0.698 | - | - | 0.9751 | -
(c) SGM+GE [25] | - | 0.710 | - | - | 0.9755 | -
(d) BR | 0.523 | 0.694 | 0.695 | 0.368 | 0.9741 | 0.651
(e) BR++ | 0.521 | 0.700 | 0.703 | 0.390 | 0.9750 | 0.658
(f) Seq2seq | 0.511 | 0.695 | 0.707 | 0.421 | 0.9743 | 0.662
(g) Seq2seq + SS | 0.541 | 0.703 | 0.713 | 0.406 | 0.9742 | 0.667
(h) Order-free RNN | 0.539 | 0.696 | 0.708 | 0.413 | 0.9742 | 0.666
(i) Order-free RNN + SS | 0.548 | 0.699 | 0.709 | 0.416 | 0.9743 | 0.669
Proposed methods
(j) OCD | 0.541 | 0.707 | 0.723 | 0.403 | 0.9740 | 0.670
(k) OCD + MTL, RNN dec. | 0.578 | 0.711 | 0.727 | 0.391 | 0.9742 | 0.676
(l) OCD + MTL, BR dec. | 0.562 | 0.711 | 0.718 | 0.382 | 0.9760 | 0.670
(m) OCD + MTL, logistic rescore | 0.585 | 0.720 | 0.736 | 0.395 | 0.9749 | 0.682
(n) OCD + MTL, logistic joint dec. | 0.580 | 0.719 | 0.731 | 0.399 | 0.9753 | 0.681
Table 2: Performance on AAPD

Evaluation Metrics

Multi-label classification can be evaluated with multiple metrics, which capture different aspects of the problem. Following Nam et al. (2017), we use five metrics: subset accuracy (ACC), Hamming accuracy (HA), example-based F1 (ebF1), macro-averaged F1 (maF1), and micro-averaged F1 (miF1).
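The five metrics can be sketched as follows, using the standard definitions as we understand them; the authors' exact definitions are in their Appendix and are not reproduced here. Label sets are represented as sets of label indices, and the zero-denominator conventions (an empty prediction for an empty gold set counts as a perfect ebF1; an F1 with a zero denominator counts as 0) are our own assumptions.

```python
from typing import Dict, List, Set


def mlc_metrics(gold: List[Set[int]], pred: List[Set[int]], num_labels: int) -> Dict[str, float]:
    """ACC, HA, ebF1, maF1 and miF1 over label-index sets (a sketch; edge cases are ours)."""
    n = len(gold)
    acc = sum(g == p for g, p in zip(gold, pred)) / n                         # exact set match
    ha = sum(1 - len(g ^ p) / num_labels for g, p in zip(gold, pred)) / n     # label-wise agreement
    eb = 0.0
    for g, p in zip(gold, pred):
        denom = len(g) + len(p)
        eb += 1.0 if denom == 0 else 2 * len(g & p) / denom                   # per-sample F1
    tp = [0] * num_labels
    fp = [0] * num_labels
    fn = [0] * num_labels
    for g, p in zip(gold, pred):
        for i in g & p:
            tp[i] += 1
        for i in p - g:
            fp[i] += 1
        for i in g - p:
            fn[i] += 1
    per_label = [0.0 if 2 * t + f + m == 0 else 2 * t / (2 * t + f + m)
                 for t, f, m in zip(tp, fp, fn)]                              # per-label F1
    micro_denom = 2 * sum(tp) + sum(fp) + sum(fn)
    return {
        "ACC": acc,
        "HA": ha,
        "ebF1": eb / n,
        "maF1": sum(per_label) / num_labels,                                  # rare labels weigh equally
        "miF1": 0.0 if micro_denom == 0 else 2 * sum(tp) / micro_denom,       # pooled counts
    }


print(mlc_metrics([{0, 1}, {2}], [{0}, {2, 3}], num_labels=4))
```

The averaging explains the pattern in the tables: maF1 gives every label equal weight, so it is the metric most sensitive to rare labels, while miF1 pools counts and is dominated by frequent labels.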


We compare our methods with the following baselines. For a fair comparison, all models share the same encoder architecture except BR++, and Seq2seq, Order-free RNN, and the proposed approaches share the same RNN decoder architecture.

  • Binary Relevance (BR) is the model trained with logistic loss (Eq. 11), and consists of an encoder and a BR decoder.

  • Binary Relevance++ (BR++) is a model with a larger encoder but the same training algorithm as BR. Because the MTL model has more parameters than BR, for a fair comparison we increase the number of layers in the encoder RNN so that the number of parameters is approximately equal to that of the MTL model. (Since Yu et al. (2018) have tested different architectures of BR models on Audio set, we do not use BR++ as a baseline on Audio set.)

  • Seq2seq [12] is composed of an RNN encoder and an RNN decoder with an attention mechanism, and is trained with MLE. The target label sequences are ordered from frequent to rare, which yields better performance [12, 25]. (For Audio set, the encoder architecture is described in the Appendix and is not based on an RNN; for consistency, we still denote this model as Seq2seq.)

  • Order-free RNN [2] trains the RNN decoder without a predefined label order, using an algorithm proposed for multi-label image classification, but it suffers from exposure bias.

Since scheduled sampling also tackles the problem of exposure bias, we also compare Seq2seq and Order-free RNN trained with scheduled sampling (SS), denoted Seq2seq + SS and Order-free RNN + SS. The effectiveness of scheduled sampling is discussed in detail in section Effectiveness of scheduled sampling and OCD.
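For readers unfamiliar with scheduled sampling [1], the sketch below shows its core decision: at each decoder step during training, feed the ground-truth previous label with probability eps, otherwise feed the model's own sampled label, with eps decaying over training. We use the inverse sigmoid decay from [1]; the decay constant `k=1000` and the helper names are arbitrary illustrative choices, not the settings used in these experiments.

```python
import math
import random


def inverse_sigmoid_teacher_prob(step: int, k: float = 1000.0) -> float:
    """Probability of feeding the ground-truth previous label at a given training step.

    Inverse sigmoid decay eps = k / (k + exp(step / k)) from Bengio et al. [1];
    k >= 1 controls how quickly teacher forcing is phased out.
    """
    x = step / k
    if x > 700.0:  # math.exp would overflow; eps is effectively zero here
        return 0.0
    return k / (k + math.exp(x))


def next_decoder_input(gold_prev, sampled_prev, step: int, rng=random, k: float = 1000.0):
    """Scheduled sampling: teacher-force with probability eps, else feed the model's own sample."""
    eps = inverse_sigmoid_teacher_prob(step, k)
    return gold_prev if rng.random() < eps else sampled_prev


rng = random.Random(0)
print(round(inverse_sigmoid_teacher_prob(0), 3))       # 0.999: almost always teacher forcing early on
print(next_decoder_input("A", "C", 10**9, rng))        # C: late in training the model's own label is fed
```

In MLC, the "ground-truth previous label" presupposes a target order (frequency order for Seq2seq, the dynamically chosen target for Order-free RNN), so a sampled label may not fit that order, which is the drawback of SS discussed in the results.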

Results and Discussion

Figure 2: Performance of BR, OCD and MTL models on AAPD validation set with different decoding strategies during training. The x-axis denotes the number of updates; we use different scales on the y-axis for each measure.
Models | maF1 | miF1 | ebF1 | ACC | HA | Average
[4] | 0.468 | 0.787 | - | - | - | -
[12] | 0.457 | 0.855 | 0.891 | 0.828 | 0.996 | 0.805
BR | 0.442 | 0.861 | 0.878 | 0.817 | 0.9964 | 0.799
BR++ | 0.407 | 0.852 | 0.861 | 0.812 | 0.9962 | 0.786
Seq2seq | 0.465 | 0.862 | 0.895 | 0.834 | 0.9965 | 0.811
Seq2seq + SS | 0.464 | 0.856 | 0.895 | 0.834 | 0.9965 | 0.809
Order-free RNN | 0.445 | 0.862 | 0.901 | 0.835 | 0.9963 | 0.806
Order-free RNN + SS | 0.452 | 0.859 | 0.896 | 0.836 | 0.9962 | 0.808
Proposed methods
OCD | 0.458 | 0.872 | 0.903 | 0.839 | 0.9966 | 0.814
OCD + MTL, RNN dec. | 0.475 | 0.874 | 0.905 | 0.844 | 0.9966 | 0.819
OCD + MTL, BR dec. | 0.459 | 0.877 | 0.898 | 0.835 | 0.9966 | 0.813
OCD + MTL, logistic rescore | 0.477 | 0.875 | 0.903 | 0.842 | 0.9967 | 0.819
OCD + MTL, logistic joint dec. | 0.490 | 0.874 | 0.904 | 0.843 | 0.9967 | 0.822
Table 3: Performance comparisons on Reuters-21578

In the following, we show the results of the baseline models and the proposed methods on the three datasets. For MTL models, we show the results of four decoding strategies (RNN decoder only, BR decoder only, logistic rescoring, and logistic joint decoding), following section Decoder Integration. For a simple comparison, we also compute the average of the five metrics as a reference. Note that in Tables 2-4, blue text marks results where the proposed methods surpass all the baselines, and bold text marks the highest performance in each measure.

Experiments on AAPD

The experimental results on the AAPD dataset are shown in Table 2. Different models excel at different metrics. For example, RNN decoder based models, i.e., Seq2seq in row (f) and Order-free RNN in row (h), perform well on ACC, whereas BR and BR++ have better results in terms of HA but show clear weaknesses in predicting rare labels (cf. especially maF1). However, OCD in row (j) outperforms all the baselines (rows (d)-(i)) on average (0.670), especially in miF1 (0.707) and ebF1 (0.723), which verifies the power of the proposed training algorithm. (OCD does not outperform the baselines in ACC; the reason is given in the following discussion.)

Models | maF1 | miF1 | ebF1 | ACC | HA | Average
BR | 0.349 | 0.480 | 0.416 | 0.086 | 0.9957 | 0.465
Seq2seq | 0.345 | 0.448 | 0.421 | 0.140 | 0.9942 | 0.470
Seq2seq + SS | 0.340 | 0.448 | 0.419 | 0.137 | 0.9943 | 0.468
Order-free RNN | 0.310 | 0.438 | 0.410 | 0.096 | 0.9940 | 0.450
Order-free RNN + SS | 0.310 | 0.437 | 0.408 | 0.095 | 0.9947 | 0.449
Proposed methods
OCD | 0.353 | 0.465 | 0.435 | 0.117 | 0.9941 | 0.473
OCD + MTL, RNN dec. | 0.359 | 0.466 | 0.438 | 0.115 | 0.9940 | 0.474
OCD + MTL, BR dec. | 0.353 | 0.485 | 0.420 | 0.075 | 0.9950 | 0.466
OCD + MTL, logistic rescore | 0.378 | 0.487 | 0.456 | 0.096 | 0.9940 | 0.482
OCD + MTL, logistic joint dec. | 0.377 | 0.488 | 0.454 | 0.119 | 0.9945 | 0.487
Table 4: Performance comparisons on Audio set.
Figure 3: Position-wise accuracy of different models at each time step on Audio set. OCD+MTL was decoded with logistic joint decoding. Note that the length of the longest generated (reference) label sequence is 12.

For MTL, we report the performance of the decoding strategies from section Decoder Integration. The first two decoding methods (rows (k), (l)) consider only the predictions of one decoder, while the other two (rows (m), (n)) combine the predictions using different decoding strategies. With MTL, performance improves on all the metrics except ACC (row (j) vs. row (k)). In addition, the two combined decoding methods (rows (m), (n)) achieve the best average performance and outperform previous works (rows (a), (b), (c)) in miF1. Interestingly, BR is also improved significantly with MTL (row (d) vs. row (l)); its encoder may implicitly learn label dependencies through the RNN decoder, which the original BR (row (d)) ignores.

Fig. 2 shows the validation ACC and miF1 curves for OCD, BR, and MTL with three decoding methods. BR performs the worst and converges the slowest. With the help of MTL, however, BR converges faster and reaches better performance. MTL also helps to improve the performance of the OCD model.

Models | Seen miF1 | Seen ebF1 | Unseen miF1 | Unseen ebF1
Seq2seq | 0.730 | 0.749 | 0.508 | 0.503
Seq2seq + SS | 0.736 | 0.754 | 0.517 | 0.515
Order-free RNN | 0.732 | 0.746 | 0.496 | 0.494
Order-free RNN + SS | 0.724 | 0.740 | 0.520 | 0.517
OCD (correct prefix) | 0.726 | 0.741 | 0.513 | 0.515
OCD | 0.746 | 0.771 | 0.521 | 0.530
Table 5: Performance comparison on the re-split AAPD, whose test set contains 2000 samples whose label sets occur in the training set (seen test set) and 2000 samples whose label sets do not (unseen test set). OCD (correct prefix) means that we only sample correct labels in the training phase.

Experiments on Reuters-21578

In comparison with AAPD, Reuters-21578 is a relatively small dataset. The average number of labels per sample is only 1.24, and over 80% of the samples have only one label, making it a relatively simple dataset.

Table 3 shows the results. These results again demonstrate the superiority of OCD and the performance gains afforded by the MTL framework. Since over 80% of the test samples in this corpus have only one label, we also report results on test samples with more than one label in Section Analysis of Reuters-21578 in the Appendix, to better assess the effect of the proposed approaches on multi-label classification.

Experiments on Audio set

Models | Case 1 | Case 2 | Case 3
Ground truth | cs.it, math.it, cs.ds | cs.lg, stat.ml, math.st, stat.th | cs.it, math.it, cs.ds, cs.dc
Seq2seq | cs.it, math.it, cs.ni | cs.it, math.it, math.st, stat.th | cs.it, math.it, cs.ni
Order-free RNN | math.it, cs.it | math.it, cs.it, stat.th, math.st | math.it, cs.it, cs.ni
OCD | math.it, cs.it, cs.ds | stat.ml, stat.th, cs.lg, math.st | cs.it, math.it
OCD + MTL + joint dec. | cs.it, math.it | math.st, stat.th, stat.ml, cs.lg, cs.it, math.it | math.it, cs.it
Table 6: Examples of generated label sequences on AAPD from different models
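Dataset | Count | Reference | OCD | Seq2seq | Seq2seq + SS | OfRNN | OfRNN + SS
AAPD | distinct | 392 | 302 | 214 | 293 | 251 | 259
AAPD | unseen | 43 | 30 | 1 | 3 | 1 | 4
Reuters-21578 | distinct | 210 | 159 | 135 | 140 | 139 | 144
Reuters-21578 | unseen | 94 | 37 | 15 | 16 | 23 | 26
Audio set | distinct | 6300 | 4787 | 1974 | 1781 | 1806 | 1789
Audio set | unseen | 2445 | 1943 | 4 | 8 | 263 | 237
Table 7: Number of distinct generated (or reference) label combinations (distinct), and number of generated (or reference) label combinations that do not occur in the training label sets (unseen), on three datasets. OfRNN denotes Order-free RNN.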

Table 4 shows the performance of each model. In this experiment, all models have similar performance in HA. Surprisingly, BR is a competitive baseline and performs especially well in miF1. Seq2seq achieves the best performance in terms of ACC, consistent with the observation on AAPD. Overall, OCD performs better than all the baselines on average, and MTL further improves the performance. OCD outperforms the other RNN decoder-based models in maF1, miF1, and ebF1, and, apart from HA, where all models are similar, performs worse than BR only in terms of miF1.


Propagation of errors

Fig. 3 shows the position-wise accuracy of different models at each time step of the RNN decoder on Audio set. The accuracy is calculated by checking whether each generated label is in the reference label set and then averaging over samples at each time step. If a generated label sequence is shorter than the corresponding target label sequence, the unpredicted part of the sequence is counted as wrong.
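The procedure just described can be sketched as follows. One detail is our assumption: at position t, a sample is counted only if its generated sequence or its reference set extends that far, so positions beyond both lengths are skipped, while positions the model never filled count as wrong. Generated sequences are assumed to exclude `<eos>`, and the function name is hypothetical.

```python
from typing import List, Sequence, Set


def position_wise_accuracy(generated: List[Sequence[str]], references: List[Set[str]]) -> List[float]:
    """Accuracy at each decoding position.

    generated:  predicted label sequences without <eos>
    references: ground-truth label sets
    """
    max_len = max(max(len(g), len(r)) for g, r in zip(generated, references))
    accuracies = []
    for t in range(max_len):
        correct = total = 0
        for g, r in zip(generated, references):
            if t >= max(len(g), len(r)):
                continue  # neither sequence reaches position t for this sample (assumption)
            total += 1
            if t < len(g) and g[t] in r:  # unpredicted positions count as wrong
                correct += 1
        accuracies.append(correct / total)
    return accuracies


print(position_wise_accuracy([["A", "C"], ["B"]], [{"A", "B", "D"}, {"B"}]))  # [1.0, 0.0, 0.0]
```

In the example, the first sample stops after two labels although its reference has three, so its third position is counted as an error.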

We can see that the accuracy of all models decreases dramatically along the x-axis. Because the labels are generated sequentially, at test time the models condition on generated prefixes that may contain errors, and this problem may be amplified for longer sequences as errors accumulate. Compared with the baseline models, OCD+MTL and OCD perform better after position 2, which suggests that they are more robust against error propagation, or exposure bias. A similar phenomenon can be observed on AAPD (Fig. 6 in the Appendix).

Effectiveness of scheduled sampling and OCD

To demonstrate the effectiveness of scheduled sampling and OCD in dealing with exposure bias, we evaluate the models on unseen label combinations, where the models encounter unseen situations and exposure bias may be more severe.

In this experiment, since there are only 43 samples with unseen label combinations in the original AAPD test set, we re-split the AAPD dataset into 47,840 training samples and 4,000 samples each for the validation and test sets. Both the validation and test sets contain 2,000 samples whose label sets occur in the training set and 2,000 samples whose label sets do not.

Table 5 shows the results on the re-split AAPD. OCD (correct prefix) means that we only sample correct labels in the training phase, so this model never encounters a wrong prefix during training. Clearly, all models perform worse on the unseen test set. SS noticeably improves the performance on the unseen test set for both Seq2seq and Order-free RNN. Additionally, OCD with correct prefixes, which suffers from exposure bias, performs worse than OCD on both test sets. These results suggest that sampling wrong labels from the predicted distribution helps models become more robust when encountering rare situations.

SS for MLC has a potential drawback. The input labels of RNN decoders obtained by sampling could be labels which do not conform to the predefined order. This may mislead the model. However, there is no label ordering in OCD, so this problem does not exist.

On both the seen and unseen test sets, OCD performs the best, since OCD not only alleviates exposure bias but also does not need a predefined order. Results for all five metrics and another experiment on exposure bias on AAPD can be found in the Appendix.

Problem of overfitting

Table 7 shows the number of distinct generated label combinations and the number of generated label combinations that do not occur in the training label sets on the three datasets. Seq2seq and OfRNN (Order-free RNN) produce fewer kinds of label combinations on AAPD and Reuters-21578. As they tend to “remember” label combinations, their generated label sets are largely alike, indicating poor generalization to unseen label combinations. Because Seq2seq is conservative and almost exclusively generates label combinations it has seen in the training set, it achieves the highest ACC in Tables 2 and 4. Models with SS produce more kinds of label combinations, except on Audio set. OCD produces the most unseen label combinations on all three datasets, since it encounters different label permutations during training.
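Both counts in Table 7 treat a label combination as an unordered set, so different permutations of the same labels count once. A minimal sketch, assuming predictions are label sequences and training labels are sets (the helper name is ours):

```python
from typing import Iterable, Sequence, Set, Tuple


def combination_stats(predictions: Iterable[Sequence[str]],
                      training_label_sets: Iterable[Set[str]]) -> Tuple[int, int]:
    """Return (number of distinct predicted combinations, number not seen in training)."""
    seen = {frozenset(s) for s in training_label_sets}
    distinct = {frozenset(p) for p in predictions}  # order-insensitive: {a, b} == {b, a}
    unseen = {c for c in distinct if c not in seen}
    return len(distinct), len(unseen)


train = [{"a", "b"}, {"c"}]
preds = [["b", "a"], ["a", "b"], ["c", "a"], ["c"]]
print(combination_stats(preds, train))  # (3, 1): only {a, c} is new
```

Using `frozenset` makes the comparison independent of generation order, which matters when comparing OCD (flexible orders) with Seq2seq (fixed frequency order).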

Case study

Table 6 shows examples of label sequences generated by different models on AAPD. Consider the labels cs.it and math.it in the three cases: Seq2seq always produces them from frequent to rare, which is the same as the ground-truth order, while Order-free RNN always produces them in a fixed order that it has learned implicitly. In contrast, OCD generates label sequences in flexible orders, because it encounters different label permutations in the sampling process during training.


In this paper, we propose a new framework for multi-label classification based on optimal completion distillation and multitask learning. Extensive experimental results show that our method outperforms competitive baselines by a large margin. Furthermore, we systematically analyze exposure bias and the effectiveness of scheduled sampling.


  • [1] S. Bengio, O. Vinyals, N. Jaitly, and N. Shazeer (2015) Scheduled sampling for sequence prediction with recurrent neural networks. In Advances in Neural Information Processing Systems, pp. 1171–1179.
  • [2] S. Chen, Y. Chen, C. Yeh, and Y. F. Wang (2017) Order-free RNN with visual attention for multi-label classification. arXiv preprint arXiv:1707.05495.
  • [3] W. Cheng, E. Hüllermeier, and K. J. Dembczynski (2010) Bayes optimal multilabel classification via probabilistic classifier chains. In Proceedings of the 27th International Conference on Machine Learning (ICML-10), pp. 279–286.
  • [4] F. Debole et al. (2005) An analysis of the relative hardness of Reuters-21578 subsets. Journal of the American Society for Information Science and Technology 56 (6), pp. 584–596.
  • [5] K. Dembczynski, W. Waegeman, and E. Hüllermeier (2012) An analysis of chaining in multi-label classification. In ECAI, pp. 294–299.
  • [6] J. F. Gemmeke, D. P. Ellis, D. Freedman, A. Jansen, W. Lawrence, R. C. Moore, M. Plakal, and M. Ritter (2017) Audio set: an ontology and human-labeled dataset for audio events. In 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 776–780.
  • [7] D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
  • [8] Q. Kong, Y. Xu, W. Wang, and M. D. Plumbley (2018) Audio set classification with attention model: a probabilistic perspective. In 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 316–320.
  • [9] A. Kumar, S. Vembu, A. K. Menon, and C. Elkan (2013) Beam search algorithms for multilabel learning. Machine Learning 92 (1), pp. 65–89.
  • [10] D. D. Lewis, Y. Yang, T. G. Rose, and F. Li (2004) RCV1: a new benchmark collection for text categorization research. Journal of Machine Learning Research 5 (Apr), pp. 361–397.
  • [11] N. Li and Z. Zhou (2013) Selective ensemble of classifier chains. In International Workshop on Multiple Classifier Systems, pp. 146–156.
  • [12] J. Nam, E. L. Mencía, H. J. Kim, and J. Fürnkranz (2017) Maximizing subset accuracy with recurrent neural networks in multi-label classification. In Advances in Neural Information Processing Systems, pp. 5413–5423.
  • [13] J. R. Quevedo, O. Luaces, and A. Bahamonde (2012) Multilabel classifiers with a probabilistic thresholding strategy. Pattern Recognition 45 (2), pp. 876–883.
  • [14] M. Ranzato, S. Chopra, M. Auli, and W. Zaremba (2015) Sequence level training with recurrent neural networks. arXiv preprint arXiv:1511.06732.
  • [15] J. Read, B. Pfahringer, G. Holmes, and E. Frank (2011) Classifier chains for multi-label classification. Machine Learning 85 (3), pp. 333.
  • [16] S. Sabour, W. Chan, and M. Norouzi (2018) Optimal completion distillation for sequence learning. arXiv preprint arXiv:1810.01398.
  • [17] R. Senge, J. J. Del Coz, and E. Hüllermeier (2014) On the problem of error propagation in classifier chains for multi-label classification. In Data Analysis, Machine Learning and Knowledge Discovery, pp. 163–170.
  • [18] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov (2014) Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research 15 (1), pp. 1929–1958.
  • [19] L. E. Sucar, C. Bielza, E. F. Morales, P. Hernandez-Leal, J. H. Zaragoza, and P. Larrañaga (2014) Multi-label classification with Bayesian network-based chain classifiers. Pattern Recognition Letters 41, pp. 14–22.
  • [20] C. Tsai and H. Lee (2018) Adversarial learning of label dependency: a novel framework for multi-class classification. arXiv preprint arXiv:1811.04689.
  • [21] G. Tsoumakas and I. Katakis (2007) Multi-label classification: an overview. International Journal of Data Warehousing and Mining (IJDWM) 3 (3), pp. 1–13.
  • [22] L. Tu and K. Gimpel (2018) Learning approximate inference networks for structured prediction. arXiv preprint arXiv:1803.03376.
  • [23] O. Vinyals, S. Bengio, and M. Kudlur (2016) Order matters: sequence to sequence for sets. In International Conference on Learning Representations (ICLR).
  • [24] P. Yang, S. Ma, Y. Zhang, J. Lin, Q. Su, and X. Sun (2018) A deep reinforced sequence-to-set model for multi-label text classification. arXiv preprint arXiv:1809.03118.
  • [25] P. Yang, X. Sun, W. Li, S. Ma, W. Wu, and H. Wang (2018) SGM: sequence generation model for multi-label classification. In Proceedings of the 27th International Conference on Computational Linguistics, pp. 3915–3926.
  • [26] C. Yu, K. S. Barsim, Q. Kong, and B. Yang (2018) Multi-level attention model for weakly supervised audio classification. arXiv preprint arXiv:1803.02353.


Derivation of logistic joint decoding

In this appendix, we derive the equation for logistic joint decoding (Eq. 16). We first reformulate Eq. 13.


Since the second term of Eq. 17 does not depend on , we can substitute it into Eq. 15.


Datasets and Preprocessing

Dataset         # train     # val   # test   # labels   Words/sample   Labels/sample
Reuters-21578   6,993       776     3,019    90         53.94          1.24
AAPD            53,840      1,000   1,000    54         163.42         2.41
RCV1-V2         802,414     1,000   1,000    103        123.94         3.24
Audio set       2,063,949   0       20,371   527        -              1.98
Table 8: Summary of datasets: number of training samples (# train), validation samples (# val), test samples (# test), and labels (# labels). Words/sample is the average number of words per sample and Labels/sample is the average number of labels per sample.

We used three multi-label text classification datasets and one sound event classification dataset. The statistics of the four datasets are shown in Table 8.

  • Reuters-21578 (http://www.daviddlewis.com/resources/testcollections/): The Reuters-21578 dataset is a collection of around 10,000 documents with 90 classes that appeared on the Reuters newswire in 1987.

  • Arxiv Academic Paper Dataset (AAPD): This dataset, provided by Yang et al. [25], consists of the abstracts and corresponding subjects of 55,840 academic computer science papers from arXiv. Each paper has 2.41 subjects on average.

  • Reuters Corpus Volume I (RCV1-V2): The RCV1-V2 dataset [10] consists of 804,414 manually categorized newswire stories with 103 topics.

  • Audio Set: Audio set was released by Google [6] and consists of over 2 million 10-second audio clips covering 527 classes of audio events, including music, speech, vehicles, and animal sounds. Because Google released only bottleneck features extracted by a pretrained ResNet-50, we used these features as model inputs: each audio clip is represented by ten 128-dimensional bottleneck feature vectors.

When preparing Reuters-21578, we followed Nam et al. [12] and randomly set aside 10% of the training instances as the validation set. For AAPD and RCV1-V2, we used the training/validation/test splits from Yang et al. [25]. For these three text datasets, we filtered out samples with more than 500 words, which removed about 0.5% of the samples in each dataset. For Audio set, we followed Yu et al. [26] and used Google's original training/test split (no validation set).
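A minimal sketch of the two text-preprocessing steps above, assuming each sample is stored as a pair of whitespace tokens and labels. The helper names, the fixed seed, and rounding the 10% hold-out down are assumptions rather than details from the paper, and the text does not say whether filtering happens before or after the split.

```python
import random
from typing import List, Sequence, Tuple

Sample = Tuple[List[str], List[str]]  # (tokens, labels)


def filter_long(samples: Sequence[Sample], max_words: int = 500) -> List[Sample]:
    # Keep only samples with at most `max_words` tokens.
    return [s for s in samples if len(s[0]) <= max_words]


def split_validation(samples: Sequence[Sample], frac: float = 0.1,
                     seed: int = 0) -> Tuple[List[Sample], List[Sample]]:
    # Shuffle indices with a fixed seed and hold out floor(frac * n) samples.
    idx = list(range(len(samples)))
    random.Random(seed).shuffle(idx)
    n_val = int(frac * len(samples))
    val = [samples[i] for i in idx[:n_val]]
    train = [samples[i] for i in idx[n_val:]]
    return train, val


docs = [(["w"] * n, ["acq"]) for n in (120, 640, 35, 80)]  # hypothetical toy samples
print(len(filter_long(docs)))  # 3
```

For instance, applied to 7,769 instances (the sum of the Reuters training and validation counts in Table 8), `split_validation` holds out 776 of them and keeps 6,993.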

Experimental Settings

We implemented our experiments in PyTorch. Some hyperparameters of the models on the four datasets are shown in Table 9.


We used the Adam optimizer [7] with a learning rate of 0.0005. To avoid overfitting, we used dropout [18]; we also clipped gradients to a maximum norm of 10. For OCD models, we set the softmax temperature so that the targets were hard (the zero-temperature limit). For models with scheduled sampling, we decreased the teacher-forcing ratio from 1.0 at the start of training to 0.7 at the end. For MTL models, the OCD loss and the logistic loss were weighted equally (weight 1).
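The scheduled-sampling setting above can be sketched as follows. The text gives only the endpoints of the teacher-forcing ratio (1.0 and 0.7), so the linear per-update decay here is an assumption, and `teacher_forcing_ratio`, `decoder_inputs`, and `predict_next` are hypothetical names standing in for the real training loop and model.

```python
import random
from typing import Callable, List, Sequence


def teacher_forcing_ratio(step: int, total_steps: int,
                          start: float = 1.0, end: float = 0.7) -> float:
    # Linearly interpolate from `start` (first update) to `end` (last update).
    if total_steps <= 1:
        return end
    frac = min(max(step / (total_steps - 1), 0.0), 1.0)
    return start + (end - start) * frac


def decoder_inputs(gold: Sequence[str],
                   predict_next: Callable[[List[str]], str],
                   ratio: float, rng: random.Random,
                   bos: str = "<bos>") -> List[str]:
    # The input at step t is either gold label t-1 (teacher forcing, with
    # probability `ratio`) or the model's own prediction for that position.
    inputs = [bos]
    for t in range(1, len(gold)):
        if rng.random() < ratio:
            inputs.append(gold[t - 1])
        else:
            inputs.append(predict_next(inputs))
    return inputs
```

With `ratio=1.0` this reduces to ordinary teacher forcing (`[bos] + gold[:-1]`); as the ratio falls toward 0.7, roughly 30% of decoder inputs come from the model itself.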

The settings specific to the three multi-label text classification datasets and to Audio set are described below.

Multi-label text classification

The BR decoder is a 3-layer DNN with 512 leaky-ReLU units. Word embeddings of size 512 were initialized randomly.

We trained each model for a fixed number of epochs and monitored its performance on the validation set every 1000 updates. After training, we chose the model with the best micro-F1 score on the validation set and evaluated it on the test set.

During testing, we set the beam size to 6 for beam search in all RNN decoders. For models with BR decoders, we chose the threshold that maximized the micro-F1 score on the validation set and selected labels whose scores exceeded it [22, 13].
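The threshold selection for BR decoders can be sketched as a grid search over candidate thresholds on validation data, keeping labels whose score is strictly greater than the threshold. The candidate grid and the function names are hypothetical; only the selection criterion (validation micro-F1) comes from the text.

```python
from typing import Iterable, List, Optional, Sequence, Tuple


def micro_f1(y_true: Sequence[Sequence[int]],
             y_pred: Sequence[Sequence[int]]) -> float:
    # Pool true positives, false positives, and false negatives over all labels.
    tp = fp = fn = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            tp += int(t == 1 and p == 1)
            fp += int(t == 0 and p == 1)
            fn += int(t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_threshold(y_true: Sequence[Sequence[int]],
                   scores: Sequence[Sequence[float]],
                   candidates: Iterable[float]) -> Tuple[Optional[float], float]:
    # Return the candidate whose predictions (score > threshold) give the
    # highest validation micro-F1, together with that score.
    best_t: Optional[float] = None
    best_f = -1.0
    for th in candidates:
        pred: List[List[int]] = [[int(s > th) for s in row] for row in scores]
        f = micro_f1(y_true, pred)
        if f > best_f:
            best_t, best_f = th, f
    return best_t, best_f


val_true = [[1, 0, 1], [0, 1, 0]]
val_scores = [[0.9, 0.4, 0.6], [0.2, 0.7, 0.55]]
print(best_threshold(val_true, val_scores, [0.3, 0.5, 0.58, 0.7]))  # (0.58, 1.0)
```

A single global threshold is shared by all labels here; per-label thresholds are a common alternative but are not what the text describes.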

All hyperparameters were tuned on the baseline seq2seq model until its performance approximately matched that reported for the same model in previous work [25, 24]. We then applied the same architecture and hyperparameters to all models for a fair comparison.

Multi-label sound event classification

For the BR model in this experiment, we used the architecture of Yu et al. [26] (the 2-A-1-A model; code: https://github.com/qiuqiangkong/audioset_classification). To the best of our knowledge, it achieved the best performance on Audio set. We reimplemented the model and obtained similar performance: a mean average precision of 0.349, without aggregating the output probabilities of the models from different training epochs. We did not use mean average precision as an evaluation metric, because it measures the ranking of per-label confidence scores, which RNN-based approaches cannot provide since they only generate hard label sequences.

Yu et al. [26] trained their model with the binary logistic loss (Eq. 11). We decompose it into two parts: a BR decoder, which is the final fully connected output layer with a sigmoid activation, and an encoder, which is the remaining network without that output layer. For the RNN-based models, we used the encoder output as the initial state of the RNN decoder, which consists of 2 LSTM layers with 512 hidden units. We also used mini-batch balancing [8].

Because we follow the setting of [26] and there is no validation set, we trained models for 10 epochs and then evaluated them on the test set. We set the beam size to 6 for beam search in all RNN decoders. For models with BR decoders, we used a fixed threshold.
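For readers unfamiliar with beam search over label sequences (beam size 6 in all RNN decoders above), here is a minimal Python sketch. It assumes a hypothetical `next_logprobs(prefix)` callback returning log-probabilities for candidate next labels, including an end-of-sequence token. Forbidding repeated labels and omitting length normalization are assumptions of this sketch; the paper's decoder may differ.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = "<eos>"
Beam = Tuple[Tuple[str, ...], float]  # (label prefix, summed log-probability)


def beam_search(next_logprobs: Callable[[Tuple[str, ...]], Dict[str, float]],
                beam_size: int = 6, max_len: int = 10) -> Beam:
    beams: List[Beam] = [((), 0.0)]
    finished: List[Beam] = []
    for _ in range(max_len):
        candidates: List[Beam] = []
        for prefix, score in beams:
            for label, lp in next_logprobs(prefix).items():
                if label == EOS:
                    finished.append((prefix, score + lp))  # hypothesis ends here
                elif label not in prefix:  # a label set has no duplicates
                    candidates.append((prefix + (label,), score + lp))
        if not candidates:
            break
        # Keep the `beam_size` highest-scoring unfinished prefixes.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    pool = finished or beams
    return max(pool, key=lambda c: c[1])


# Toy next-label distributions (hypothetical), keyed by prefix.
TOY = {
    (): {"a": math.log(0.6), "b": math.log(0.4)},
    ("a",): {EOS: math.log(0.3), "b": math.log(0.7)},
    ("b",): {EOS: math.log(0.9), "a": math.log(0.1)},
    ("a", "b"): {EOS: 0.0},
    ("b", "a"): {EOS: 0.0},
}
best, score = beam_search(lambda p: TOY[p], beam_size=6)
print(best)  # ('a', 'b')
```

In the toy example, ("a", "b") has probability 0.6 × 0.7 = 0.42, which beats ("b",) at 0.4 × 0.9 = 0.36.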

Dataset     Vocab size   LSTM layers (enc., dec.)   Batch size   Dropout
Reuters     22,747       (2, 2)                     96           0.5
AAPD        30,000       (2, 2)                     128          0.5
RCV1-V2     50,000       (2, 3)                     96           0.3
Audio set   -            (-, 2)                     500          0.5
Table 9: Hyperparameters for each dataset. An LSTM layer entry of (2, 3) means that the RNN encoder and decoder have 2 and 3 layers, respectively; "-" means the component does not exist.

Evaluation Metrics

The five metrics fall into two groups.

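Example-based measures compare the target vector $\mathbf{y}_i \in \{0,1\}^L$ with the prediction vector $\hat{\mathbf{y}}_i \in \{0,1\}^L$ for each of the $N$ test samples, where $L$ is the number of labels. Subset accuracy (ACC) is the strictest metric: it is the percentage of samples for which all labels are classified correctly,

$$\mathrm{ACC} = \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}\left[\mathbf{y}_i = \hat{\mathbf{y}}_i\right].$$

Hamming accuracy (HA) is the ratio of correctly classified label decisions to the total number of label decisions,

$$\mathrm{HA} = \frac{1}{NL}\sum_{i=1}^{N}\sum_{j=1}^{L} \mathbb{1}\left[y_{ij} = \hat{y}_{ij}\right].$$

Example-based F1 (ebF1), defined by Eq. 19, measures the ratio of correctly predicted labels to the total number of predicted and ground-truth labels:

$$\mathrm{ebF1} = \frac{1}{N}\sum_{i=1}^{N} \frac{2\sum_{j=1}^{L} y_{ij}\,\hat{y}_{ij}}{\sum_{j=1}^{L} y_{ij} + \sum_{j=1}^{L} \hat{y}_{ij}} \qquad (19)$$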
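The example-based measures ACC, HA, and ebF1 can be computed directly from binary indicator vectors, as in this minimal Python sketch. The function names are hypothetical. Scoring a sample with both an empty target and an empty prediction as F1 = 1 is an assumed convention that the text does not specify.

```python
from typing import Sequence

Vec = Sequence[int]  # binary indicator vector over the L labels


def subset_accuracy(Y: Sequence[Vec], P: Sequence[Vec]) -> float:
    # Fraction of samples whose whole predicted vector matches the target.
    return sum(list(y) == list(p) for y, p in zip(Y, P)) / len(Y)


def hamming_accuracy(Y: Sequence[Vec], P: Sequence[Vec]) -> float:
    # Fraction of individual (sample, label) decisions that are correct.
    correct = sum(yj == pj for y, p in zip(Y, P) for yj, pj in zip(y, p))
    return correct / (len(Y) * len(Y[0]))


def example_f1(Y: Sequence[Vec], P: Sequence[Vec]) -> float:
    # Per-sample F1 = 2|y ∩ p| / (|y| + |p|), averaged over samples.
    total = 0.0
    for y, p in zip(Y, P):
        inter = sum(yj * pj for yj, pj in zip(y, p))
        denom = sum(y) + sum(p)
        total += 2 * inter / denom if denom else 1.0  # both empty: count as perfect
    return total / len(Y)


Y = [[1, 0, 1, 0], [0, 1, 0, 0]]
P = [[1, 0, 0, 0], [0, 1, 0, 0]]
print(subset_accuracy(Y, P), hamming_accuracy(Y, P))  # 0.5 0.875
```

The toy example shows why ACC is the strictest measure: missing a single label in the first sample costs a whole sample in ACC, but only one of eight decisions in HA.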


Label-based measures treat each label as a separate two-class prediction problem, and compute the number of true positives (tp), false positives (fp), and false negatives (fn) for each label over the dataset.

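Macro-averaged F1 (maF1) is the average of the per-label F1 scores over the $L$ labels (Eq. 20), and micro-averaged F1 (miF1) is computed by summing tp, fp, and fn over all labels and then calculating the F1 score (Eq. 21):

$$\mathrm{maF1} = \frac{1}{L}\sum_{j=1}^{L}\frac{2\,tp_j}{2\,tp_j + fp_j + fn_j} \qquad (20)$$

$$\mathrm{miF1} = \frac{2\sum_{j=1}^{L} tp_j}{\sum_{j=1}^{L}\left(2\,tp_j + fp_j + fn_j\right)} \qquad (21)$$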

High maF1 scores usually indicate high performance on rare labels, while high miF1 scores usually indicate high performance on frequent labels [12].
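The label-based measures just described can be sketched in Python by first collecting per-label tp, fp, and fn counts. The function names are hypothetical. Giving a label that is never present or predicted an F1 of 0 is an assumed convention; implementations differ on this point, and the choice affects maF1.

```python
from typing import List, Sequence, Tuple

Vec = Sequence[int]


def label_counts(Y: Sequence[Vec],
                 P: Sequence[Vec]) -> Tuple[List[int], List[int], List[int]]:
    # Per-label true positives, false positives, and false negatives.
    n_labels = len(Y[0])
    tp, fp, fn = [0] * n_labels, [0] * n_labels, [0] * n_labels
    for y, p in zip(Y, P):
        for j, (yj, pj) in enumerate(zip(y, p)):
            if yj and pj:
                tp[j] += 1
            elif pj:
                fp[j] += 1
            elif yj:
                fn[j] += 1
    return tp, fp, fn


def _f1(tp: int, fp: int, fn: int) -> float:
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0  # label never present nor predicted


def macro_f1_score(Y: Sequence[Vec], P: Sequence[Vec]) -> float:
    # Average of per-label F1 scores: every label counts equally.
    tp, fp, fn = label_counts(Y, P)
    return sum(_f1(a, b, c) for a, b, c in zip(tp, fp, fn)) / len(tp)


def micro_f1_score(Y: Sequence[Vec], P: Sequence[Vec]) -> float:
    # Pool counts across labels first: frequent labels dominate.
    tp, fp, fn = label_counts(Y, P)
    return _f1(sum(tp), sum(fp), sum(fn))


Y = [[1, 0, 1], [0, 1, 0], [1, 0, 0]]
P = [[1, 0, 0], [0, 1, 1], [1, 0, 0]]
print(micro_f1_score(Y, P))  # 0.75
```

Here the third label is never predicted correctly. It pulls maF1 down to 2/3, but affects miF1 less (0.75), which illustrates why maF1 is more sensitive to rare labels.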


Experiment on RCV1-V2

Models                            maF1   miF1   ebF1   ACC    HA      Average
BR                                0.671  0.868  0.881  0.642  0.9919  0.811
BR++                              0.650  0.867  0.881  0.646  0.9919  0.807
Seq2seq                           0.654  0.864  0.881  0.662  0.9916  0.811
Seq2seq + SS                      0.653  0.860  0.878  0.658  0.9914  0.809
Order-free RNN                    0.660  0.863  0.878  0.650  0.9917  0.809
Order-free RNN + SS               0.637  0.862  0.876  0.662  0.9917  0.806
Proposed methods
OCD                               0.668  0.866  0.882  0.654  0.9918  0.812
OCD + MTL, RNN dec.               0.671  0.867  0.882  0.651  0.9918  0.813
OCD + MTL, BR dec.                0.663  0.869  0.885  0.637  0.9920  0.813
OCD + MTL, logistic rescore       0.676  0.869  0.884  0.653  0.9919  0.815
OCD + MTL, logistic joint dec.    0.674  0.871  0.885  0.658  0.9920  0.816
Table 10: Performance comparisons on RCV1-V2.

Table 10 shows the results. Compared with AAPD and Reuters-21578, RCV1-V2 contains many more documents. Moreover, its labels have a hierarchical structure: if a document is assigned a leaf label, it is also assigned every label on the path from the root to that leaf in the label tree. Consequently, every parent label occurs in at least as many documents as each of its children, so if we sort the labels from frequent to rare, parent labels precede their child labels (up to ties).
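The argument above can be checked on a toy label tree with a short Python sketch. It closes each document's labels upward to the root, counts label frequencies, and sorts labels from frequent to rare. The tree uses RCV1-style topic codes purely as illustrative names, and the helper names are hypothetical. Because a parent and child can tie in frequency, the sketch breaks ties by depth; that tie-breaking rule is an assumption, not something the text specifies.

```python
from collections import Counter
from typing import Dict, Iterable, List, Optional, Set

Parent = Dict[str, Optional[str]]  # child -> parent (None for a root)


def with_ancestors(labels: Iterable[str], parent: Parent) -> Set[str]:
    # Close a label set upward: a leaf label implies every ancestor up to the root.
    out: Set[str] = set()
    for lab in labels:
        node: Optional[str] = lab
        while node is not None and node not in out:
            out.add(node)
            node = parent.get(node)
    return out


def depth(lab: str, parent: Parent) -> int:
    d = 0
    while parent.get(lab) is not None:
        lab = parent[lab]
        d += 1
    return d


def frequency_order(docs: Iterable[Iterable[str]], parent: Parent) -> List[str]:
    counts: Counter = Counter()
    for labels in docs:
        counts.update(with_ancestors(labels, parent))
    # Frequent to rare; ties broken by depth so a parent never follows its child.
    return sorted(counts, key=lambda lab: (-counts[lab], depth(lab, parent), lab))


tree: Parent = {"CCAT": None, "C15": "CCAT", "C151": "C15", "GCAT": None, "GSPO": "GCAT"}
docs = [["C151"], ["C15"], ["GSPO"], ["C151", "GSPO"]]
print(frequency_order(docs, tree))  # ['CCAT', 'C15', 'GCAT', 'GSPO', 'C151']
```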

On this dataset, OCD shows a smaller performance gain. This may be because the predefined order encodes rich information about the hierarchical structure, which OCD does not exploit. However, datasets whose labels have hierarchical structures are not common.

Detailed results on re-split AAPD

Models                    maF1   miF1   ebF1   ACC    HA      Average
Seen test set
Seq2seq                   0.530  0.730  0.749  0.453  0.9771  0.688
Seq2seq + SS              0.551  0.736  0.754  0.449  0.9774  0.693
Order-free RNN            0.545  0.732  0.746  0.468  0.9777  0.694
Order-free RNN + SS       0.546  0.724  0.740  0.415  0.9764  0.680
OCD (correct prefix)      0.543  0.726  0.741  0.452  0.9770  0.688
OCD                       0.571  0.746  0.771  0.443  0.9780  0.702
Unseen test set
Seq2seq                   0.402  0.503  0.508  0.002  0.9550  0.474
Seq2seq + SS              0.418  0.515  0.517  0.009  0.9562  0.483
Order-free RNN            0.391  0.494  0.496  0.006  0.9560  0.469
Order-free RNN + SS       0.426  0.517  0.520  0.040  0.9557  0.492
OCD (correct prefix)      0.421  0.513  0.515  0.006  0.9566  0.482
OCD                       0.446  0.521  0.530  0.017  0.9553  0.494
Table 11: Detailed results on re-split AAPD (Table 5).

Table 11 shows the detailed results behind Table 5. Interestingly, all models have difficulty predicting the complete set of correct labels for unseen label combinations, so subset accuracy on the unseen test set is extremely low.

Impact of exposure bias on AAPD

Figure 4: Example-based F1 score of test samples versus the number of times the label combination appears in the training set on AAPD. “OfRNN” denotes order-free RNN.

Fig. 4 shows the example-based F1 score of test samples for different models versus the number of times their label combination appears in the AAPD training set. Clearly, the more often a label combination appears in training, the better the models perform on it. Interestingly, scheduled sampling (SS) helps Seq2seq and OfRNN on rare label combinations but not on frequent ones. This may be because models already perform poorly on rare label combinations, where exposure bias becomes more severe and wrong predictions are more likely; hence, SS is more helpful on rare examples.
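The analysis behind Fig. 4 can be approximated by bucketing test samples according to how often their exact label combination occurs in the training set, then averaging example-based F1 within each bucket. In this Python sketch, the bucket edges, the function name, and the toy data are hypothetical. The last edge must be infinite so that every sample falls into some bucket.

```python
from collections import Counter, defaultdict
from typing import Dict, Iterable, Sequence


def f1_by_combo_frequency(train_labels: Iterable[Iterable[str]],
                          test_labels: Sequence[Iterable[str]],
                          test_preds: Sequence[Iterable[str]],
                          edges: Sequence[float]) -> Dict[float, float]:
    # How often each exact label combination occurs in the training set.
    freq = Counter(frozenset(labels) for labels in train_labels)
    sums: Dict[float, float] = defaultdict(float)
    counts: Dict[float, int] = defaultdict(int)
    for gold, pred in zip(test_labels, test_preds):
        g, p = set(gold), set(pred)
        denom = len(g) + len(p)
        f1 = 2 * len(g & p) / denom if denom else 1.0  # example-based F1
        n = freq[frozenset(g)]
        # First bucket whose upper edge covers n (edges must end with infinity).
        edge = next(e for e in edges if n <= e)
        sums[edge] += f1
        counts[edge] += 1
    return {e: sums[e] / counts[e] for e in edges if counts[e]}


train_sets = [["cs.lg"], ["cs.lg"], ["cs.lg", "stat.ml"]]
gold_sets = [["cs.lg"], ["cs.lg", "stat.ml"], ["cs.cv"]]
pred_sets = [["cs.lg"], ["cs.lg"], ["cs.cv"]]
by_bucket = f1_by_combo_frequency(train_sets, gold_sets, pred_sets,
                                  edges=[0, 1, float("inf")])
print(by_bucket)
```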

On frequent label combinations, SS performs worse: for models with SS, sampled labels may not conform to the predefined order, which can mislead the model. In contrast, OCD performs consistently well. Because the OCD loss depends on the input prefixes and the ground-truth sequence is never supplied to the model, the model explores more states during training and is therefore more robust in all situations.
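To make the role of prefixes concrete, here is a minimal Python sketch of how optimal-completion targets could be formed when the target is an unordered label set. It rests on assumptions not stated in this section:

- every gold label not yet in the prefix is an equally optimal next step;
- wrong labels already in the prefix are ignored;
- end-of-sequence is the only optimal step once all gold labels have been generated.

Hard targets are taken as a uniform distribution over the optimal labels. The names `ocd_targets` and `ocd_step_loss` are hypothetical, and this is a sketch rather than the paper's implementation.

```python
from typing import Dict, Iterable, Sequence

EOS = "<eos>"


def ocd_targets(prefix: Sequence[str], gold: Iterable[str]) -> Dict[str, float]:
    # Any gold label not yet generated is an optimal next step; once every gold
    # label has appeared, the only optimal step is to stop.
    remaining = sorted(set(gold) - set(prefix))
    if not remaining:
        return {EOS: 1.0}
    weight = 1.0 / len(remaining)
    return {lab: weight for lab in remaining}  # hard target: uniform over optimal labels


def ocd_step_loss(log_probs: Dict[str, float], targets: Dict[str, float]) -> float:
    # Cross-entropy between the target distribution and the model's next-label
    # log-probabilities at one step of a model-generated (sampled) prefix.
    return -sum(q * log_probs[lab] for lab, q in targets.items())


print(ocd_targets(["math.it"], ["cs.it", "math.it", "cs.lg"]))
# {'cs.it': 0.5, 'cs.lg': 0.5}
```

The targets depend only on which gold labels the model's own prefix has already produced, not on any fixed order. This illustrates why such training exposes the decoder to many label permutations.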

Analysis of Reuters-21578

Models                            maF1   miF1   ebF1   ACC    HA      Average
BR                                0.315  0.706  0.712  0.365  0.9850  0.617
Seq2seq                           0.316  0.712  0.718  0.405  0.9855  0.627
Seq2seq + SS                      0.325  0.718  0.722  0.380  0.9859  0.626
Order-free RNN                    0.331  0.730  0.735  0.425  0.9862  0.641
Order-free RNN + SS               0.324  0.699  0.711  0.400  0.9849  0.624
Proposed methods
OCD                               0.319  0.734  0.741  0.415  0.9864  0.639
OCD + MTL, RNN dec.               0.335  0.745  0.749  0.440  0.9870  0.651
OCD + MTL, BR dec.                0.322  0.739  0.737  0.430  0.9869  0.643
OCD + MTL, logistic rescore       0.337  0.750  0.752  0.435  0.9869  0.652
OCD + MTL, logistic joint dec.    0.342  0.743  0.746  0.435  0.9870  0.651
Table 12: Performance comparisons on the Reuters-21578 test samples with more than one label.

Table 12 shows results on the subset of the Reuters-21578 test set whose samples have more than one label; this smaller test set has 405 samples. Compared with Table 3, all models perform worse on this subset, since the single-label samples have been removed. However, the performance gap between the baseline models and the proposed methods is larger, which further supports the advantage of OCD and MTL.

Average ranking of models

Figure 5: Average ranks of different methods on four different datasets. The smaller the rank value, the better the performance. The MTL results are decoded by logistic joint decoding; “OfRNN” denotes order-free RNN.

Fig. 5 shows the average rank of each method over the four datasets for each metric. All models achieve similar HA on these datasets. MTL performs best overall, followed by OCD. Seq2seq ranks best in terms of ACC but worse on the other metrics.
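The average ranks in Fig. 5 can be reproduced in principle with a short Python sketch. For one metric, rank the models on each dataset (rank 1 = highest score, which suits all five metrics since higher is better), then average each model's rank across datasets. Sharing averaged ranks among tied models is an assumption; the text does not say how ties were handled. The sketch also assumes every model is scored on every dataset. Function names and the toy scores are hypothetical.

```python
from typing import Dict

Scores = Dict[str, float]  # model name -> metric value (higher is better)


def ranks(scores: Scores) -> Dict[str, float]:
    # Rank 1 = best; tied models share the mean of the positions they occupy.
    ordered = sorted(scores, key=lambda m: -scores[m])
    out: Dict[str, float] = {}
    i = 0
    while i < len(ordered):
        j = i
        while j + 1 < len(ordered) and scores[ordered[j + 1]] == scores[ordered[i]]:
            j += 1
        for k in range(i, j + 1):
            out[ordered[k]] = (i + j) / 2 + 1
        i = j + 1
    return out


def average_ranks(per_dataset: Dict[str, Scores]) -> Dict[str, float]:
    # Average each model's rank over datasets (one metric at a time).
    totals: Dict[str, float] = {}
    for scores in per_dataset.values():
        for model, r in ranks(scores).items():
            totals[model] = totals.get(model, 0.0) + r
    return {m: t / len(per_dataset) for m, t in totals.items()}


toy = {"d1": {"A": 0.9, "B": 0.8, "C": 0.8}, "d2": {"A": 0.7, "B": 0.75, "C": 0.6}}
print(average_ranks(toy))  # {'A': 1.5, 'B': 1.75, 'C': 2.75}
```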

Position-wise accuracy on AAPD

Figure 6: Position-wise accuracy of different models at each time step on AAPD. OCD+MTL was decoded by logistic joint decoding. Note that the longest generated (reference) label sequence has length 6.