Neural Message Passing for Multi-Label Classification

04/17/2019 ∙ by Jack Lanchantin, et al.

Multi-label classification (MLC) is the task of assigning a set of target labels to a given sample. Modeling the combinatorial label interactions in MLC has been a long-standing challenge. We propose Label Message Passing (LaMP) Neural Networks to efficiently model the joint prediction of multiple labels. LaMP treats labels as nodes on a label-interaction graph and computes the hidden representation of each label node conditioned on the input using attention-based neural message passing. Attention enables LaMP to assign different importance to neighbor nodes per label, learning how labels interact (implicitly). The proposed models are simple, accurate, interpretable, structure-agnostic, and, since LaMP is highly parallelizable, applicable to predicting dense labels. We validate the benefits of LaMP on seven real-world MLC datasets, covering a broad spectrum of input/output types and outperforming state-of-the-art results. Notably, LaMP enables intuitive interpretation of how classifying each label depends on the elements of a sample and, at the same time, on its interactions with other labels. We provide our code and datasets at https://github.com/QData/LaMP


1 Introduction

Multi-label classification (MLC) is receiving increasing attention in areas such as natural language processing, computational biology, and image recognition. Accurate and scalable MLC methods are urgently needed for applications like assigning topics to web articles or identifying binding proteins on DNA. The most common and straightforward MLC method is the binary relevance (BR) approach, which considers multiple target labels independently [46]. However, in many MLC tasks there is a clear dependency structure among labels, which BR methods ignore.

Unfortunately, accurately modelling all combinatorial label interactions is an NP-hard problem. Many types of models, including several based on deep neural networks (DNNs), have been introduced to approximately model such interactions and thereby boost classification accuracy.

The main concern of this paper is how to represent multiple labels jointly (and conditioned on the input features) in order to make accurate predictions. The most relevant literature addressing this concern falls roughly into three groups.

The first group, probabilistic classifier chain (PCC) models, formulates the joint label dependencies using the chain rule and performs MLC in a sequential prediction manner [39, 55, 36]. Notably, [36] used a recurrent neural network (RNN) sequence-to-sequence (Seq2Seq) architecture [17] for MLC and achieved state-of-the-art performance on multiple text-based datasets. However, these methods are inherently unfit for MLC tasks because they cannot be parallelized and perform poorly in dense label settings, i.e., when there are a large number of positive labels (since errors propagate in the sequential prediction). We refer the reader to the supplementary material for a full background and analysis of PCC methods (Appendix Section 5). The second group learns a shared latent space representing both input features and output labels, and then upsamples from that space to reconstruct the target labels [57, 6]. The main drawback of this group is the interpretability issue of a learned low-dimensional latent space, as many real-world applications prefer interpretable predictors. The third group models conditional label dependencies using a structured output or graphical model representation [29, 45]. However, these methods are often limited to considering only pairwise dependencies due to computational constraints, or are forced to use some variation of approximate inference which has no clear representation of conditional dependencies.

Thus our main question is: is it possible to have accurate, flexible, and explainable MLC methods that remain applicable when there are many, and dense, labels? This paper provides empirical results showing that this is possible by extending attention-based Message Passing Neural Networks (MPNNs) to learn the joint representation of multiple labels conditioned on input features.

MPNNs [15] are a class of methods that efficiently learn the joint representations of variables using neural message passing strategies. They provide a flexible framework for jointly modeling multiple variables that have no explicit ordering.

The key idea of our method is to rely entirely on attention-based neural message passing to draw global dependencies from labels to input features, and from labels to labels. To the best of our knowledge, this is the first extension of MPNNs to model a conditional joint representation of output labels, and additionally the first extension of MPNNs to model the interactions of variables whose exact structure is unknown. We name the proposed method Label Message Passing (LaMP) Networks since it performs neural message passing on an unknown, fully-connected label-to-label graph. Through intra-attention (also known as self-attention), LaMP assigns different importance to different neighbor nodes per label, dynamically learning how labels interact conditioned on a specific input. We further extend LaMP to cases where a known label-interaction graph is provided, by modifying the intra-attention to only attend over a node's known neighbors. LaMP networks allow for parallelization in training and testing and can work with dense labels, overcoming the drawbacks of PCC methods.

LaMP most closely belongs to the third MLC category mentioned above; however, it trains a unified model that classifies each label and models the label-to-label dependencies at the same time, in an end-to-end fashion. The important aspect is that LaMP networks automatically learn the output label dependency structure conditioned on a specific input using neural message passing. This structure can in turn easily be interpreted to understand the conditional dependencies.

The main contributions of this paper include: (1) Accurate MLC: our model achieves similar or better performance compared to the previous state of the art across five MLC metrics. We validate our model on eight MLC datasets which cover a wide spectrum of input data structures: sequences (English text, DNA), tabular (binary word vectors), graph (drug molecules), and images, as well as output label structures: unknown and known graph. (2) Interpretable: although deep-learning based systems have widely been viewed as "black boxes", our attention-based LaMP models allow for a straightforward way to extract three different types of model visualization: intermediate network predictions, label-to-feature dependencies, and label-to-label dependencies.

2 Method: LaMP Networks

Notations. We define the following notations, used throughout the paper. Let $\{(x^{(n)}, y^{(n)})\}_{n=1}^{N}$ be the set of data samples with inputs $x$ and outputs $y$. Each input $x$ is a (possibly ordered) set of $S$ components $\{x_1, x_2, \dots, x_S\}$, and each output $y$ is a set of $L$ labels $\{y_1, y_2, \dots, y_L\}$. MLC involves predicting the set of $L$ binary labels $\{y_1, \dots, y_L\}$, $y_i \in \{0, 1\}$, given input $x$.

In general, we represent the input feature components as embedded vectors $\{z_1, \dots, z_S\}$, $z_i \in \mathbb{R}^{d}$, using a learned embedding matrix $W^x$. Here $d$ is the embedding size, and $x_i$ can be any component of a particular input (for example, words in a sentence, patches of an image, nodes of a known graph, or one of the tabular features).

Similarly, labels can first be represented as embedded vectors $\{u_1^0, \dots, u_L^0\}$, $u_i^0 \in \mathbb{R}^{d}$, through a learned embedding matrix $W^y$, where $L$ denotes the number of labels. Here we use $u_i^t$ to represent the 'state' of the embedding of label $i$ after the $t$-th update step. This is because in LaMP networks, each label embedding is updated for $T$ steps before the predictions are made. The key idea of LaMP networks is that labels are represented as nodes on a label-interaction graph, with nodes denoted by the embedding vectors $u_i^t$. LaMP networks use MPNN modules with attention to pass messages from the input embeddings $\{z_j\}$ to the label embeddings $\{u_i\}$, and then within $\{u_i\}$, to model the joint prediction of labels.
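As a concrete illustration of this setup, the following minimal sketch (not the released LaMP code; the sizes and variable names are illustrative assumptions) builds the input and label node embeddings with standard PyTorch embedding layers.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: vocabulary of 10,000 input tokens, S = 12 components,
# L = 6 labels, embedding size d = 8.
vocab_size, S, L, d = 10_000, 12, 6, 8

W_x = nn.Embedding(vocab_size, d)        # input embedding matrix W^x
W_y = nn.Embedding(L, d)                 # label embedding matrix W^y

x = torch.randint(0, vocab_size, (S,))   # one input: S component indices
z = W_x(x)                               # (S, d) input node embeddings z_1 .. z_S
u0 = W_y(torch.arange(L))                # (L, d) initial label node states u_i^0
```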

2.1 Background: Message Passing Neural Networks

Message Passing Neural Networks (MPNNs) [15] are a generalization of graph neural networks (GNNs) [41]. MPNNs model variables as nodes on a graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$, where $\mathcal{V}$ describes the set of nodes (variables) and $\mathcal{E}$ denotes the set of edges (how variables interact with other variables). In MPNNs, joint representations of nodes and edges are modelled using message passing rather than explicit probabilistic formulations, allowing for efficient inference. MPNNs model the joint dependencies using a message function $M_t$ and a node update function $U_t$ for $T$ time steps, where $t$ is the current time step. The hidden state $v_i^t$ of node $i$ is updated based on messages $m_i^{t+1}$ from its neighboring nodes, defined by neighborhood $\mathcal{N}(i)$:

$m_i^{t+1} = \sum_{j \in \mathcal{N}(i)} M_t(v_i^t, v_j^t)$   (1)
$v_i^{t+1} = U_t(v_i^t, m_i^{t+1})$   (2)

After $T$ rounds of iterative updates to spread information to distant nodes, a readout function $R$ is used on the updated node embeddings to make predictions, such as classifying nodes or classifying properties of the graph.

Many possibilities exist for the functions $M_t$ and $U_t$. We specifically choose to pass messages using intra-attention (also called self-attention) neural message passing, which enables nodes to attend over their neighborhoods differentially. This allows the network to learn different importances for different nodes in a neighborhood, without depending on knowing the graph structure upfront (essentially learning the unknown graph structure) [53]. In this formulation, the message for node $i$ is obtained by a weighted sum of all its neighboring nodes, where the weights are calculated by attention and represent the importance of each neighbor for that node [2]. In the rest of the paper, we use "graph attention" and "neural message passing" interchangeably.

Intra-attention neural message passing works as follows. We first calculate the attention weight $e_{ij}$ for a pair of nodes $(i, j)$ using attention function $a(\cdot, \cdot)$:

$e_{ij} = a(v_i^t, v_j^t)$   (3)
$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{j' \in \mathcal{N}(i)} \exp(e_{ij'})}$   (4)
$a(v_i^t, v_j^t) = \frac{(W^q v_i^t)^{\top}(W^k v_j^t)}{\sqrt{d}}$   (5)

where $e_{ij}$ represents the importance of node $j$ for node $i$, however un-normalized. The $e_{ij}$ are normalized across all neighboring nodes of node $i$ using a softmax function (Eq. 4) to get $\alpha_{ij}$. For the attention function $a(\cdot, \cdot)$, we use a scaled dot product with node-wise linear transformations $W^q$ on node $i$ and $W^k$ on node $j$. Scaling by $\frac{1}{\sqrt{d}}$ is used to mitigate training issues [52].

Then we use a so-called attention message function $M_t$ to produce the message from node $j$ to node $i$ using the learned attention weights $\alpha_{ij}$ and another transformation matrix $W^v$:

$m_{ij}^{t+1} = \alpha_{ij} W^v v_j^t$   (6)
$m_i^{t+1} = v_i^t + \sum_{j \in \mathcal{N}(i)} m_{ij}^{t+1}$   (7)

Eq. 7 computes the full message $m_i^{t+1}$ for node $i$ by linearly combining the messages from all neighbor nodes $j \in \mathcal{N}(i)$, with a residual connection on the current state $v_i^t$.

Lastly, node $i$ is updated to its next state $v_i^{t+1}$ using message $m_i^{t+1}$ by a multi-layer perceptron (MLP) update function $U_t$, plus a residual connection:

$v_i^{t+1} = m_i^{t+1} + \mathrm{MLP}(m_i^{t+1})$   (8)
$\mathrm{MLP}(m_i^{t+1}) = W_2\, \mathrm{ReLU}(W_1 m_i^{t+1} + b_1) + b_2$   (9)

The update function is parameterized with matrices $W_1$ and $W_2$. It is important to note that in Eq. 9, $W_1$ and $W_2$ are shared (i.e., separately applied) across all nodes. This can be viewed as a 1-dimensional convolution operation with kernel and stride sizes of 1. Weight sharing across nodes is a key aspect of MPNNs, where node dependencies are learned in an order-invariant manner.

2.2 LaMP: Label Message Passing

Given the input embeddings $\{z_1, \dots, z_S\}$, the goal of Label Message Passing is to model the conditional dependencies between the label embeddings $\{u_1, \dots, u_L\}$ using Message Passing Neural Networks. We assume that the label embeddings are nodes on a label-interaction graph, where the initial states of the embeddings at $t = 0$ are obtained using the label embedding matrix $W^y$.

Each step of Label Message Passing consists of two parts in order to update the label embeddings: (a) Feature-to-Label Message Passing, where messages are passed from the input embeddings to the label embeddings, and (b) Label-to-Label Message Passing, where messages are passed between labels. An overview of our model is shown in Fig. 1. We explain these two parts in detail in the following subsections. LaMP networks use $T$ steps of attention-based neural message passing to update the label nodes before a readout function makes a prediction for each label on its final state $u_i^T$.

Figure 1: LaMP Networks. Given input $x$, we encode its components as embedded input nodes $\{z_1, \dots, z_S\}$. We encode labels as embedded label nodes $\{u_1, \dots, u_L\}$ of the label-interaction graph. First, MPNN_{x→y} is used to pass messages from the input nodes to the label nodes and update the label nodes. Then, MPNN_{y→y} is used to pass messages between the label nodes and update them. Finally, readout function $R$ performs node-level classification on the label nodes to make binary label predictions $\hat{y}_1, \dots, \hat{y}_L$.

Updating Label Embeddings via Feature-to-Label Message Passing

Given a particular input $x$ with embedded feature components $\{z_1, \dots, z_S\}$, the first step in LaMP is to update the label embeddings by passing messages from the input embeddings to the label embeddings, as shown in the "Feature-to-Label MP" block of Fig. 1. To do this, LaMP uses the neural message passing module MPNN_{x→y} to update each label node's embedding using the embeddings of all the components of an input.

That is, we update each $u_i$ using a weighted sum of all input embeddings $\{z_1, \dots, z_S\}$, in which the weights represent how important an input component is to the label node. The weights for the message are learned via Label-to-Feature attention (i.e., each label attends to each input embedding differently to compute the weights).

In this step, messages are only passed from the input nodes to the label nodes, and not vice versa (i.e. Feature-to-Label message passing is directed).

More specifically, to update label embedding $u_i^t$, MPNN_{x→y} uses attention message function $M_{xy}$ on all embeddings of the input to produce message $m_i^{t'}$, and MLP update function $U_{xy}$ to produce the updated intermediate embedding state $u_i^{t'}$:

$m_i^{t'} = \sum_{j=1}^{S} M_{xy}(u_i^t, z_j)$   (10)
$u_i^{t'} = U_{xy}(u_i^t, m_i^{t'})$   (11)

The key advantage of Feature-to-Label message passing with attention is that each label node can attend differently to input elements (e.g., different words in an input sentence).

Updating Label Embeddings via Label-to-Label Message Passing

At this point, an independent prediction could be made for each label conditioned on $x$ using $u_i^{t'}$. However, in order to consider label dependencies, we model interactions between the label nodes using Label-to-Label message passing and update them accordingly, as shown in the "Label-to-Label MP" block of Fig. 1. Given the exponentially large number of possible conditional dependencies, we use neural message passing as an efficient way to model such interactions, which has been shown to work well in practice for other tasks.

We assume there exists a label-interaction graph $\mathcal{G}_{yy} = (\mathcal{V}_{yy}, \mathcal{E}_{yy})$, where $\mathcal{V}_{yy}$ is the set of label nodes and $\mathcal{E}_{yy}$ includes all undirected pairwise edges connecting node $i$ and node $j$. At this stage, we use another message passing module, MPNN_{y→y}, to pass messages between labels and update them. The label embedding $u_i^{t'}$ is updated by a weighted combination, through attention, of all its neighbor label nodes $\{u_j^{t'} : j \in \mathcal{N}(i)\}$.

MPNN_{y→y} uses attention message function $M_{yy}$ on all neighbor label embeddings to produce message $m_i^{t+1}$, and MLP update function $U_{yy}$ to compute the updated embedding $u_i^{t+1}$:

$m_i^{t+1} = \sum_{j \in \mathcal{N}(i)} M_{yy}(u_i^{t'}, u_j^{t'})$   (12)
$u_i^{t+1} = U_{yy}(u_i^{t'}, m_i^{t+1})$   (13)

If there exists a known label-interaction graph $\mathcal{G}_{yy}$, the message for node $i$ is computed using its neighboring nodes $j \in \mathcal{N}(i)$, where the neighbors are defined by the graph. If there is no known graph, we assume a fully connected graph, which means $\mathcal{N}(i) = \{1, \dots, L\}$ (including $i$ itself).
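For concreteness, the following hedged sketch composes one LaMP step from two attention blocks: a Feature-to-Label pass where label nodes attend over input nodes, followed by a Label-to-Label pass where label nodes attend over (optionally masked) label nodes. The GraphAttention module, its shapes, and the adjacency-mask handling are our own simplifications of the method described above, not the released code.

```python
import math
import torch
import torch.nn as nn

class GraphAttention(nn.Module):
    """One attention-based message passing block: target nodes attend over source nodes."""
    def __init__(self, d):
        super().__init__()
        self.d = d
        self.W_q = nn.Linear(d, d, bias=False)
        self.W_k = nn.Linear(d, d, bias=False)
        self.W_v = nn.Linear(d, d, bias=False)
        self.mlp = nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, d))

    def forward(self, targets, sources, mask=None):
        # targets: (n_t, d) nodes being updated; sources: (n_s, d) nodes sending messages
        e = self.W_q(targets) @ self.W_k(sources).t() / math.sqrt(self.d)
        if mask is not None:                       # mask: (n_t, n_s), 1 = edge present
            e = e.masked_fill(mask == 0, float('-inf'))
        alpha = torch.softmax(e, dim=-1)
        m = targets + alpha @ self.W_v(sources)    # attention messages with residual
        return m + self.mlp(m)                     # MLP update with residual

class LaMPStep(nn.Module):
    """One step: Feature-to-Label message passing, then Label-to-Label message passing."""
    def __init__(self, d):
        super().__init__()
        self.feat2label = GraphAttention(d)        # plays the role of MPNN_{x->y}
        self.label2label = GraphAttention(d)       # plays the role of MPNN_{y->y}

    def forward(self, u, z, label_adj=None):
        u = self.feat2label(u, z)                  # labels attend over input features
        return self.label2label(u, u, label_adj)   # labels attend over neighbor labels

# usage: S = 12 input nodes, L = 6 labels, d = 8; label_adj=None means a fully connected label graph
step = LaMPStep(d=8)
u = step(torch.randn(6, 8), torch.randn(12, 8))
```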

Message Passing for Multiple Time Steps

To learn more complex relations among nodes, we compute a total of $T$ time steps of updates. This is essentially a stack of $T$ MPNN layers. In our implementation, the label embeddings are updated by MPNN_{x→y} and MPNN_{y→y} for $T$ time steps to produce $\{u_1^T, \dots, u_L^T\}$.

2.3 Readout Layer (MLC Predictions from the label embeddings)

After $T$ updates to the label embeddings, the last module predicts each label $\hat{y}_i \in \{0, 1\}$. A readout function $R$ projects each of the label embeddings $u_i^T$ using projection matrix $W^o \in \mathbb{R}^{L \times d}$, where row $w_i^o$ is the learned output vector for label $i$. The calculated vector of size $L$ is then fed through an element-wise sigmoid function to produce the probability of each label being positive:

$\hat{y}_i = \sigma(w_i^o \cdot u_i^T)$   (14)

2.4 Model Details

Multi-head Attention.  In order to allow a particular node to attend to multiple other nodes (or multiple groups of nodes) at once, LaMP uses multiple attention heads. Inspired by [52], we use $K$ independent attention heads for each $W$ matrix during the message computation, where each head's matrix columns are of dimension $d/K$. The generated representations are concatenated (denoted by $\Vert$) and linearly transformed by matrix $W^c$. Multi-head attention changes the message passing function $M_t$, but the update function $U_t$ stays the same.

$e_{ij}^{k} = \frac{(W^{q_k} v_i^t)^{\top}(W^{k_k} v_j^t)}{\sqrt{d/K}}$   (15)
$\alpha_{ij}^{k} = \frac{\exp(e_{ij}^{k})}{\sum_{j' \in \mathcal{N}(i)} \exp(e_{ij'}^{k})}$   (16)
$m_{ij}^{t+1} = W^c \left[ \alpha_{ij}^{1} W^{v_1} v_j^t \,\Vert\, \dots \,\Vert\, \alpha_{ij}^{K} W^{v_K} v_j^t \right]$   (17)
$m_i^{t+1} = v_i^t + \sum_{j \in \mathcal{N}(i)} m_{ij}^{t+1}$   (18)

The matrices $W^{q_k}$, $W^{k_k}$, $W^{v_k}$, and $W^c$ are not shared across time steps (but are shared across nodes).
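A hedged sketch of the multi-head message computation described above (Eqs. 15-18): each of the K heads uses its own d/K-dimensional projections, the per-head messages are concatenated, and a final matrix mixes them. Names, shapes, and the residual placement are our own assumptions.

```python
import math
import torch
import torch.nn as nn

class MultiHeadMessage(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        assert d % n_heads == 0
        self.d, self.h, self.dk = d, n_heads, d // n_heads
        self.W_q = nn.Linear(d, d, bias=False)    # K per-head query transforms, stacked
        self.W_k = nn.Linear(d, d, bias=False)    # K per-head key transforms, stacked
        self.W_v = nn.Linear(d, d, bias=False)    # K per-head value transforms, stacked
        self.W_c = nn.Linear(d, d, bias=False)    # mixes the concatenated heads

    def forward(self, targets, sources):
        n_t, n_s = targets.size(0), sources.size(0)
        # reshape to (heads, n, d/K)
        q = self.W_q(targets).view(n_t, self.h, self.dk).transpose(0, 1)
        k = self.W_k(sources).view(n_s, self.h, self.dk).transpose(0, 1)
        v = self.W_v(sources).view(n_s, self.h, self.dk).transpose(0, 1)
        e = q @ k.transpose(1, 2) / math.sqrt(self.dk)       # per-head importances (Eq. 15)
        alpha = torch.softmax(e, dim=-1)                     # per-head attention (Eq. 16)
        heads = alpha @ v                                    # per-head messages
        concat = heads.transpose(0, 1).reshape(n_t, self.d)  # concatenate heads (Eq. 17)
        return targets + self.W_c(concat)                    # mixed message with residual (Eq. 18)

# usage: 6 label nodes attend over 12 feature nodes with K = 4 heads and d = 8
msg = MultiHeadMessage(d=8, n_heads=4)(torch.randn(6, 8), torch.randn(12, 8))
```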

Label Embedding Weight Sharing.  To enforce each label's input embedding to correspond to that particular label, the label embedding matrix weights $W^y$ are shared with the readout projection matrix $W^o$. In other words, $W^y$ is used to produce the initial node vectors $u_i^0$, and then it is used again to calculate the pre-sigmoid output value for each label, so $W^o = W^y$. This was shown to be beneficial in Seq2Seq models for machine translation [38].
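Below is a minimal sketch of the readout with label-embedding weight sharing: a single matrix provides both the initial label node states and the per-label output vectors used in Eq. 14. The class name and interface are our own assumptions, not the released code.

```python
import torch
import torch.nn as nn

class TiedLabelReadout(nn.Module):
    """Shares one (L x d) matrix between the label embeddings and the readout projection."""
    def __init__(self, n_labels, d):
        super().__init__()
        self.W_y = nn.Parameter(torch.randn(n_labels, d) * 0.01)  # W^y, reused as W^o

    def initial_label_nodes(self):
        return self.W_y                                  # u_i^0 is row i of W^y

    def forward(self, u_final):
        # u_final: (L, d) final label states u_i^T
        logits = (self.W_y * u_final).sum(dim=-1)        # w_i^o . u_i^T, one scalar per label
        return torch.sigmoid(logits)                     # \hat{y}_i (Eq. 14)

readout = TiedLabelReadout(n_labels=6, d=8)
u0 = readout.initial_label_nodes()                       # fed into the message passing stack
probs = readout(torch.randn(6, 8))                       # per-label probabilities
```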

2.5 Loss Function

The final outputs of LaMP networks are trained using the mean binary cross entropy (BCE) over all $L$ outputs. For one sample, given the true binary label vector $y$ and the predicted labels $\hat{y}$, the output loss is:

$\mathcal{L}_{out} = -\frac{1}{L} \sum_{i=1}^{L} \left[ y_i \log \hat{y}_i + (1 - y_i) \log (1 - \hat{y}_i) \right]$   (19)

The final outputs $\hat{y}_i$ are computed from the final label node states $u_i^T$ (Eq. 14). However, since LaMP networks iteratively update the label nodes from $t = 0$ to $T$, we can "probe" the label nodes at each intermediate state from $t = 1$ to $T - 1$ and enforce an auxiliary loss on those states. To do this, we use the same matrix $W^o$ to extract the intermediate prediction at state $t$: $\hat{y}_i^{t} = \sigma(w_i^o \cdot u_i^t)$. We use the same BCE loss on these predictions to compute the intermediate loss $\mathcal{L}_{int}$:

$\mathcal{L}_{int} = -\frac{1}{T-1} \sum_{t=1}^{T-1} \frac{1}{L} \sum_{i=1}^{L} \left[ y_i \log \hat{y}_i^{t} + (1 - y_i) \log (1 - \hat{y}_i^{t}) \right]$   (20)

We note that the intermediate predictions are computed both for the states after Label-to-Label message passing and for the states after Feature-to-Label message passing. The final loss is a combination of the original and intermediate losses, where the intermediate loss is weighted by $\lambda$:

$\mathcal{L} = \mathcal{L}_{out} + \lambda \mathcal{L}_{int}$   (21)

In LaMP networks, $p(y_i \mid x)$ is approximated by jointly representing label $i$ using message passing from $x$ (via the input embeddings $\{z_j\}$) and from the embeddings of all neighboring labels $\{u_j : j \in \mathcal{N}(i)\}$.
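The following is a hedged sketch of the training objective in Eqs. 19-21: BCE on the final predictions plus a lambda-weighted mean BCE over predictions probed from the intermediate label states. The exact weighting across intermediate states and the default lambda value are our assumptions.

```python
import torch
import torch.nn.functional as F

def lamp_loss(y_true, y_final, y_intermediate, lam=0.2):
    """
    y_true:         (L,) binary targets
    y_final:        (L,) probabilities from the final label states (Eq. 14)
    y_intermediate: list of (L,) probabilities probed at intermediate states
    lam:            weight on the intermediate loss (lambda in Eq. 21); value here is illustrative
    """
    loss_out = F.binary_cross_entropy(y_final, y_true)              # Eq. 19
    loss_int = torch.stack(
        [F.binary_cross_entropy(p, y_true) for p in y_intermediate]
    ).mean() if y_intermediate else torch.zeros(())                 # Eq. 20
    return loss_out + lam * loss_int                                # Eq. 21

# usage with L = 4 labels and two probed intermediate states
y = torch.tensor([1., 0., 1., 0.])
loss = lamp_loss(y, torch.rand(4), [torch.rand(4), torch.rand(4)])
```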

2.6 LaMP Variation: Input Encoding with Feature Message Passing (FMP)

Thus far, we have assumed that we use the raw feature embeddings to pass messages to the labels. However, we could also update the feature embeddings before they are passed to the label nodes by modelling the interactions between features.

For a particular input $x$, we first assume that the input features are nodes on a graph $\mathcal{G}_{xx} = (\mathcal{V}_{xx}, \mathcal{E}_{xx})$, where $\mathcal{V}_{xx}$ is the set of feature nodes and $\mathcal{E}_{xx}$ includes all undirected pairwise edges connecting node $i$ and node $j$. MPNN_{x→x}, parameterized by its own attention and update weights, is used to pass messages between the input embeddings in order to update their states. Nodes on this graph are represented as embedding vectors $z_i^t$, where the initial states $z_i^0$ are obtained using embedding matrix $W^x$ on the input components $x_i$. The embeddings are then updated by MPNN_{x→x} using message passing for $T$ time steps to produce $\{z_1^T, \dots, z_S^T\}$.

To update input embedding $z_i^t$, MPNN_{x→x} uses attention message function $M_{xx}$ (Eq. 6) on all neighboring input embeddings to produce message $m_i^{t+1}$, and MLP update function $U_{xx}$ (Eq. 9) to produce the updated embedding $z_i^{t+1}$:

$m_i^{t+1} = \sum_{j \in \mathcal{N}(i)} M_{xx}(z_i^t, z_j^t)$   (22)
$z_i^{t+1} = U_{xx}(z_i^t, m_i^{t+1})$   (23)

If there exists a known graph, the message for node $i$ is computed using its neighboring nodes $j \in \mathcal{N}(i)$, where the neighbors are defined by the graph. If there is no known graph, we assume a fully connected graph, which means $\mathcal{N}(i) = \{1, \dots, S\}$. Inputs with a sequential ordering can be modelled as a fully connected graph using positional embeddings [4].

In summary, MPNN_{x→x} is used to update the input feature nodes by passing messages within the feature-interaction graph. MPNN_{x→y} is used to update the output label nodes by passing messages from the features to the labels (from input nodes $z_i$ to output nodes $u_i$). MPNN_{y→y} is used to update the output label nodes by passing messages within the label-interaction graph (between label nodes). Once messages have been passed for the integrative updates of the feature and label nodes, a readout function $R$ is used on the label nodes to make a binary classification prediction for each label, $\hat{y}_i$. Figure 1 shows the LaMP network without the feature-interaction graph.

2.7 Advantages of LaMP Models

Efficiently Handling Dense Label Predictions. 

It is known that autoregressive models such as RNN Seq2Seq suffer from the propagation of errors over the sequential positive label predictions. This makes it difficult for these models to handle dense samples, i.e., those with many positive labels. In addition, autoregressive models require a time-consuming post-processing step such as beam search to obtain the optimal label set. Lastly, autoregressive models require a predefined label ordering for training the sequential prediction, which can lead to instabilities at testing time [54].

Motivated by the drawbacks of autoregressive models for MLC, the proposed LaMP model removes the reliance on sequential predictions, beam search, and a chosen label ordering, while still modelling the label dependencies. This is particularly beneficial when the number of positive output labels is large (i.e., dense). LaMP networks predict the output set of labels all at once, which is made possible by the fact that inference does not use a probabilistic chain, yet a representation of label dependencies is still learned via Label-to-Label attention. As an additional benefit, as noted by [5], it may be useful to maintain 'soft' predictions for each label in MLC. This is a major drawback of the PCC models, which make 'hard' predictions of the positive labels, defaulting all other labels to 0.

Structure Agnostic.  Many input or output types are instances where the relational structure is not made explicit, and must be inferred or assumed [4]. LaMP networks allow for greater flexibility of both input structures (known structure such as sequence or graph, or unknown such as tabular), as well as output structures (e.g., known graph vs unknown structure). To the best of our knowledge, this is the first work to use MPNNs to infer the relational structure of the data by using attention mechanisms.

Interpretability.  Our formulation of LaMP allows us to visualize predictions in several different ways. First, since predictions are made in an iterative manner via graph update steps, we can “probe” each label’s state at each step to get intermediate predictions. Second, we can visualize the attention weights which automatically learn the relational structure. Combining these two visualization methods allows us to see how the predictions change from the initial predictions given only the input sequence to the final state where messages have been passed from other labels, leading us to better insights for specific MLC samples.

2.8 Connecting to Related Topics

Structured Output Predictions.  The use of graph attention in LaMP models is closely connected to the literature on structured output prediction for MLC. [14] used conditional random fields (CRFs) [29] to model dependencies among labels and features for MLC by learning a distribution over pairs of labels and input features, but such models are limited to pairwise dependencies.

To overcome the naive pairwise dependency constraint of CRFs, structured prediction energy networks (SPENs) [5] and related methods [50, 20] locally optimize an unconstrained structured output. In contrast to SPENs, which use an iterative refinement of the output label predictions, our method is a simpler feed-forward block that makes predictions in one step, yet it still models dependencies through attention mechanisms on embeddings, which gives the added interpretability benefit.

Multi-label Classification By Modeling Label Interaction Graphs.  [19] formulate MLC using a label graph and introduce a conditional dependency SVM, where they first train separate classifiers for each label given the input and all other true labels, and then use Gibbs sampling to find the optimal label set. The main drawback is that this method does not scale to a large number of labels. [42] propose a method to label the pairwise edges of randomly generated label graphs, which requires a chosen aggregation method over all random graphs. The authors introduce the idea that variation in the graph structure shifts the inductive bias of the base learners. Our fully connected label graph with attention on the neighboring nodes can be regarded as a form of graph ensemble learning [22]. [11] use graph neural networks for MLC, but focus on graph inputs. They do not explicitly model the label-to-label dependencies, thus resulting in worse performance than LaMP.

Graph Neural Networks (GNNs).  Passing embedding messages from node to neighbor nodes connects to a large body of literature on graph neural networks [4] and embedding models for structures [8].

The key idea is that instead of conducting probabilistic operations (e.g., product or re-normalization), the proposed models perform nonlinear function mappings in each step to learn feature representations of structured components. [15, 53, 3] all follow similar ideas to pass the embedding from node to neighbor nodes or neighbor edges.

There have been many recent works extending the basic GNN framework to update nodes using various message passing, update, and readout functions [26, 21, 31, 24, 59, 3, 15, 12]. We refer the readers to [4] for a survey. However, none of these have used GNNs for MLC. In addition, none of these have attempted to learn the graph structure by using neural attention on fully connected graphs.

3 Experiments

Reuters Bibtex Bookmarks Delicious RCV1 TFBS SIDER NUSWIDE
FastXML [37] - - - - 0.841 - - -
Madjarov [32] - 0.434 0.257 0.343 - - - -
SPEN [5] - 0.422 0.344 0.375 - - - -
RNN Seq2Seq [36] 0.894 0.393 0.362 0.320 0.890 0.249 0.356 0.329
Emb + MLP 0.854 0.363 0.368 0.371 0.865 0.167 0.766 0.371
Emb + LaMP_el 0.859 0.379 0.351 0.358 0.868 0.289 0.767 0.376
Emb + LaMP_fc 0.896 0.427 0.376 0.368 0.871 0.319 0.763 0.376
Emb + LaMP_pr 0.895 0.424 0.373 0.366 0.870 0.317 0.765 0.372
FMP + LaMP_el 0.883 0.435 0.375 0.369 0.887 0.310 0.766 -
FMP + LaMP_fc 0.906 0.445 0.389 0.372 0.889 0.321 0.764 -
FMP + LaMP_pr 0.902 0.447 0.386 0.372 0.887 0.321 0.766 -
Table 1: ebF1 Scores across all 8 datasets

We validate our model on eight real-world MLC datasets. These datasets vary in the number of samples, number of labels, input type (sequential, tabular, graph, vector), and output type (unknown, known label graph). They also cover a wide spectrum of input data types, including raw English text (sequential form), binary word vectors (tabular form), drug molecules (graph form), and images (vector form). Data statistics are in Table 6 and Section 8.1. Due to space limits, we move the details of the evaluation metrics to Section 8.2 and the hyper-parameters to Section 8.3. Details of previous results from the state-of-the-art baselines are in Section 8.4.

3.1 LaMP Variations

For LaMP models, we use two variations of input features and three variations of Label-to-Label Message Passing. For input features, we use (1) Emb, which is the raw learned feature embeddings of dimension $d$, and (2) FMP, which is the updated state of each feature embedding after 2 layers of Feature Message Passing, as explained in Section 2.6. (For NUS-WIDE, since we use the 128-dimensional cVLAD features as input to compare to [11], we cannot use the FMP method.) For each of the two input feature variations, we use three variations of the label graph which Label-to-Label Message Passing uses to update the labels given the input features, explained as follows.

LaMP_el uses an edgeless label graph: messages are not passed between labels, assuming no label dependencies.

LaMP_fc uses a fully connected label graph, where each label is able to attend to all other labels (including itself) in order to compute the messages.

LaMP_pr uses a prior label graph, where each label is able to attend only to its neighboring labels from the known label graph (including itself) in order to compute the messages. For RCV1, we use the known tree label structure, and for TFBS we use known protein-protein interactions (PPI) from [44]. For all other datasets, we create a graph by placing an edge in the adjacency matrix for every pair of labels that co-occur in any training set sample. This is summarized in the last column of Appendix Table 5.

3.2 Performance Evaluation

ebF1.  Table 1 shows the most commonly used evaluation, example-based F1 (ebF1) scores, for all datasets. LaMP outperforms the baseline MLP model, which assumes no label dependencies, as well as RNN Seq2Seq, which models label dependencies using a classifier chain. More importantly, we compare using an output graph with no edges (LaMP_el), which assumes no label dependencies, vs. output graphs with edges (LaMP_fc and LaMP_pr). The models have the same architecture and number of parameters, the only difference being the message passing between label nodes. We can see that for most datasets, modelling label dependencies with Label-to-Label message passing does in fact help. We found that using a known prior label structure (LaMP_pr) did not improve the results significantly over the fully connected graph. LaMP_fc predictions produced an average 1.8 ebF1 score increase over the independent LaMP_el predictions, and LaMP_pr an average 1.7 increase. When comparing to the MLP baseline, LaMP_fc and LaMP_pr produced average increases of 18.5 and 18.4, respectively.

miF1.  While high ebF1 scores indicate strong average F1 scores over all samples, the label-based Micro-averaged F1 (miF1) scores indicate strong results on the most frequent labels. Table 2 shows the miF1 scores for all datasets. LaMP_fc produced an average 1.6 miF1 score increase over the independent LaMP_el, and LaMP_pr an average 1.8 increase. When comparing to the MLP baseline, LaMP_fc and LaMP_pr resulted in average increases of 20.2 and 20.5, respectively.

maF1.  Contrarily, high label-based Macro-averaged F1 (maF1) scores indicate strong results on less frequent labels. Table 3 shows the maF1 scores, which show the strongest improvement of the LaMP_fc and LaMP_pr variants over independent predictions. LaMP_fc resulted in an average 2.4 maF1 score increase over the independent LaMP_el, and LaMP_pr an average 2.1 increase. This indicates that Label-to-Label message passing can help boost the accuracy of rare label predictions. When comparing to Emb + MLP, LaMP_fc and LaMP_pr produced average increases of 57.0 and 56.7, respectively.

Reuters Bibtex Bookmarks Delicious RCV1 TFBS SIDER NUSWIDE
FastXML [37] - - - - 0.847 - - -
SVM [9] 0.787 - - - - - - -
GAML [11] - - 0.333 - - - - 0.398
Madjarov [32] - 0.462 0.268 0.339 - - - -
RNN Seq2Seq [36] 0.858 0.384 0.329 0.329 0.884 0.311 0.389 0.418
Emb + MLP 0.835 0.389 0.349 0.385 0.855 0.218 0.795 0.465
Emb + LaMP_el 0.842 0.413 0.334 0.372 0.858 0.401 0.797 0.472
Emb + LaMP_fc 0.871 0.458 0.363 0.379 0.859 0.449 0.797 0.470
Emb + LaMP_pr 0.877 0.462 0.363 0.380 0.859 0.448 0.798 0.468
FMP + LaMP_el 0.870 0.455 0.355 0.381 0.877 0.445 0.795 -
FMP + LaMP_fc 0.886 0.465 0.373 0.384 0.877 0.450 0.795 -
FMP + LaMP_pr 0.889 0.473 0.371 0.386 0.877 0.449 0.797 -
Table 2: miF1 Scores across all 8 datasets

Other Metrics.  Due to space constraints, we report subset accuracy in Appendix (supplementary) Table 7. RNN Seq2Seq models mostly outperform all other models on this metric since they are trained to maximize it [36]. However, for all other metrics, RNN Seq2Seq does not perform as well, suggesting that for most applications, PCC models are not necessary. We also report Hamming Accuracy in Appendix Table 8, and we note that LaMP networks outperform or perform similarly to baseline methods, but we observe that this metric is mostly uninformative.

Metrics Performance Summary.  While LaMP does not explicitly model label dependencies as autoregressive or structured prediction models do, the attention weights do learn some dependencies among labels (Section 3.3). This is indicated by the fact that the LaMP variants which use Label-to-Label attention (LaMP_fc and LaMP_pr) mostly outperform the one which does not (LaMP_el), indicating that label dependencies are being learned.

Speed.  LaMP results in a mean of 1.7x and 5.0x training and testing speedups, respectively, over the previous state-of-the-art probabilistic MLC method, RNN Seq2Seq. Speedups over RNN Seq2Seq model are shown in Table 4.

Reuters Bibtex Bookmarks Delicious RCV1 TFBS SIDER NUSWIDE
SVM [9] 0.468 - - - - - - -
FastXML [37] - - - - 0.592 - - -
GAML [11] - - 0.217 - - - - 0.114
Madjarov [32] - 0.316 0.119 0.142 - - - -
RNN Seq2Seq [36] 0.457 0.282 0.237 0.166 0.741 0.210 0.207 0.143
Emb + MLP 0.366 0.275 0.248 0.180 0.667 0.094 0.665 0.173
Emb + LaMP_el 0.476 0.308 0.229 0.176 0.680 0.326 0.666 0.198
Emb + LaMP_fc 0.547 0.366 0.271 0.192 0.691 0.362 0.663 0.203
Emb + LaMP_pr 0.560 0.372 0.267 0.192 0.698 0.365 0.663 0.196
FMP + LaMP_el 0.508 0.353 0.266 0.192 0.742 0.368 0.664 -
FMP + LaMP_fc 0.520 0.371 0.286 0.195 0.743 0.364 0.668 -
FMP + LaMP_pr 0.517 0.376 0.280 0.196 0.740 0.364 0.664 -
Table 3: maF1 Scores across all 8 datasets

3.3 Interpretability Evaluation

The structure of LaMP networks allows for three different types of visualization methods to understand how the network predicts each label. We explain the three types here and show the results for a sample from the Bookmarks dataset using the FMP + LaMP model.

Intermediate Output Prediction.  One advantage of the multi-step formulation of label embedding updates is that it gives us the ability to probe the state of each label at intermediate steps and view the model's predictions at those steps. To do this, we use the readout function on each intermediate label embedding state to find the probability that the label embedding would predict a positive label. In other words, this is the post-sigmoid output of the readout function applied to each embedding at each step $t$. We note that each step contains two stages: the state after the Feature-to-Label message passing, and the state after the Label-to-Label message passing. The output after the second stage of the final step (i.e., the readout of $u_i^T$) is the model's final output.

Figure 2 (a.) shows the intermediate prediction outputs from the 2-step model. On the horizontal axis are a selected subset of all possible labels, with the red colored axis labels being all true positive labels. On the vertical axis, each row represents one of the four label embedding states in the 2-step model. Each cell represents the readout function's prediction for that label embedding's state. The brighter the grid cell, the more likely that label is positive at the current stage. Starting from the bottom, the first row shows the prediction of each label after the first Feature-to-Label message passing. The second row shows the prediction of each label after the first Label-to-Label message passing. This is then repeated once more in the third and fourth rows for the second step's output states, where the final output (the top row) is the network's final prediction. The most important aspect of this figure is that we can see the labels "design", "html", and "web design" all change from weakly positive to strongly positive after the first Label-to-Label message passing step (second row). In other words, this indicates that these labels change to a strongly positive prediction by passing messages between each other.

Figure 2: (a) Visualization of Model Predictions and Attention Weights. Intermediate Predictions: the readout function predictions for each intermediate state in the two update steps. (b) Label-to-Feature Attention Weights for the first step of Feature-to-Label message passing. (c) Label-to-Label Attention Weights for the first step of Label-to-Label message passing.

Label-to-Feature Attention.  While the iterative prediction visualization shows how the model updates its prediction of each label, it does not explicitly show how or why. To understand why each label changes its prediction, we first look at the Feature-to-Label attention, which tells us the input nodes that each label node attends to in order to update its state (thus producing the predictions in Figure 2 (a.)). Figure 2 (b.) shows which input nodes (i.e., features) each of the positive labels attends to in order to make its first update step. The colors represent the post-softmax attention weights (summed over the 4 attention heads), with the darker cells representing higher attention. In this example, we can see that the "web design" label attends to the "pick", "smart", and "version" features, but as we can see from the first row of Figure 2 (a.), the prediction for the current state of the "web design" label is not yet very strong.

Label-to-Label Attention.  Label-to-Feature attention shows us the input nodes that each label node attends to in order to make its first update, but the second stage of the label graph update is the Label-to-Label message passing step, where labels are updated according to the states of all other label nodes after the first Feature-to-Label message passing. Figure 2 (c.) shows the first Label-to-Label attention stage, where each label node can attend to the other label nodes in order to update its state. Here we show only the Label-to-Label attention for the positive labels in this example. We then look at the second row of Figure 2 (a.), which shows the model's prediction of each label node after the Label-to-Label attention. The interesting thing to note here is that many of the true positive labels change their state to positive after the positive labels attend to each other during the Label-to-Label attention step, indicating that dependencies are learned.

Attention weights for the second step are not as interpretable since they model higher order interactions. We have added these plots in Appendix Fig. 3.

Dataset Training Testing
Reuters 0.788 (1.5x) 0.116 (2.1x)
Bibtex 0.376 (2.1x) 0.080 (2.1x)
Delicious 3.172 (1.1x) 0.473 (3.2x)
Bookmarks 9.664 (1.2x) 1.849 (1.3x)
RCV1 98.346 (1.2x) 1.003 (1.7x)
TFBS 187.14 (2.5x) 13.04 (4.2x)
NUS-WIDE 3.201 (1.2x) 0.921 (8.0x)
SIDER 0.027 (2.5x) 0.003 (21x)
Table 4: Speed.

Each column shows training or testing speed for LaMP in minutes per epoch. Speedups over RNN Seq2Seq are in parentheses. Since LaMP does not depend on sequential prediction, we see a drastic speedup, especially during testing where RNN Seq2Seq requires beam search.

4 Conclusion

In this work we present Label Message Passing (LaMP) Networks, which achieve accuracy better than, or close to, that of previous methods across five metrics and seven datasets. In addition, LaMP's iterative label embedding updates with attention provide a straightforward way to shed light on the model's predictions, allowing us to extract three forms of visualization, including the conditional label dependencies which influence MLC classifications.

References

5 Appendix: MLC Background

5.1 Background of Multi-Label Classification:

MLC has a rich history in text [33, 51], images [46, 13], bioinformatics [46, 13], and many other domains. MLC methods can roughly be broken into several groups, which are explained as follows.

Label powerset models (LP) [49, 40] classify each input into one label combination from the set of all possible label combinations. LP explicitly models the joint distribution by predicting the one subset of all positive labels. Since the number of possible label sets grows exponentially in the total number of labels ($2^L$), classifying each possible label set is intractable for even a modest $L$. In addition, even in small tasks, LP suffers from the "subset scarcity problem", where only a small fraction of the label subsets are seen during training, leading to bad generalization.

Binary relevance (BR) methods predict each label separately, with a logistic regression classifier for each label [58, 16]. The naïve approach to BR prediction is to predict all labels independently of one another, assuming no dependencies among labels. That is, BR uses the following conditional probability, parameterized by learned weights $\theta$:

$p(y_1, \dots, y_L \mid x) = \prod_{i=1}^{L} p_{\theta}(y_i \mid x)$   (24)

Probabilistic classifier chain (PCC) methods [10, 39] are autoregressive models that estimate the true joint probability of the output labels given the input by using the chain rule, predicting one label at a time:

$p(y_1, \dots, y_L \mid x) = \prod_{i=1}^{L} p_{\theta}(y_i \mid y_1, \dots, y_{i-1}, x)$   (25)

Two issues with PCC models are that inference is very slow if $L$ is large, and that errors propagate as $L$ increases [34]. To mitigate the problems with both LP and PCC methods, one solution is to only predict the true labels in the LP subset. In other words, only the positive labels (a total of $P$ for a particular sample) are predicted and all other labels are ignored; we refer to this as the positive-only PCC variant. Similar to PCC, its joint probability can be computed as a product of conditional probabilities, but unlike PCC, only $P$ terms are predicted as positive:

$p(y \mid x) = \prod_{i=1}^{P} p_{\theta}(y_{p_i} \mid y_{p_1}, \dots, y_{p_{i-1}}, x)$   (26)

where $p_1, \dots, p_P$ index the positive labels.

This can be beneficial when the number of possible labels is large, reducing the total number of prediction steps. However, in both the full and positive-only PCC variants, inference is done using beam search, a costly search procedure used to find the highest-probability prediction.

Recently, recurrent neural network (RNN) based encoder-decoder models following the PCC formulations have shown state-of-the-art performance for solving MLC. However, the sequential nature of modeling label dependencies through an RNN limits their ability to be parallelized, to predict dense labels, and to provide interpretable results.

The main drawback of classifier chain models is that their inherently sequential nature precludes parallelization during training and inference. This can be detrimental when there are a large number of positive labels, as the classifier chain has to sequentially predict each label, and often requires beam search to obtain the optimal set. Aside from time-cost disadvantages, PCC methods have several other drawbacks. First, PCC methods require a defined ordering of labels for the sequential prediction, but MLC output labels are an unordered set, and the chosen order can lead to prediction instability [36]. Secondly, even if the optimal ordering is known, PCC methods struggle to accurately capture long-range dependencies among labels in cases where the number of positive labels is large (i.e., dense labels). For example, the Delicious dataset we used in the experiments has a median of 19 positive labels per sample, so it can be difficult to correctly predict the labels at the end of the prediction chain. Lastly, many real-world applications prefer interpretable predictors. For instance, in the task of predicting which proteins (labels) will bind to a DNA binding site, users care about how a prediction is made and how the interactions among labels (proteins) influence the binding predictions. An important task is modelling what is known as "co-binding" effects, where one protein will only bind if another specific protein is already binding, or similarly will not bind if another is already binding.

LaMP methods approximate the following factored formulation, where $\mathcal{N}(i)$ denotes the neighboring label nodes of label $i$:

$p(y_1, \dots, y_L \mid x) \approx \prod_{i=1}^{L} p_{\theta}(y_i \mid x, \{y_j : j \in \mathcal{N}(i)\})$   (27)

5.2 Seq2Seq Models

In machine translation (MT), sequence-to-sequence (Seq2Seq) models have proven to be the superior method, where an encoder RNN reads the source language sentence into an encoder hidden state, and a decoder RNN translates the hidden state into a target sentence, predicting each word autoregressively [43]. [2] improved this model by introducing “neural attention” which allows the decoder RNN to “attend” to every encoder word at each step of the autoregressive translation.

Recently, [36] showed that, across several metrics, state-of-the-art MLC results could be achieved by using a recurrent neural network (RNN) based encoder-to-decoder framework for Equation 26 (the positive-only PCC variant). They use a Seq2Seq RNN model (Seq2Seq Autoregressive) in which one RNN encodes $x$ and a second RNN predicts each positive label sequentially, until it predicts a 'stop' signal. This type of model seeks to maximize the 'subset accuracy', i.e., to correctly predict every label as its exact 0/1 value.

[52] eliminated the need for the recurrent network in MT by introducing the Transformer. Instead of using an RNN to model dependencies, the Transformer explicitly models pairwise dependencies among all of the features by using attention [2, 56] between signals. This speeds up training because RNNs cannot be fully parallelized, but the Transformer still uses an autoregressive decoder.

5.3 Drawbacks of Autoregressive Models

Seq2Seq MLC [36] uses an encoder RNN that encodes the elements of an input sequence, a decoder RNN that predicts output labels one after another, and beam search, which computes the probabilities of the next label predictions and chooses the solution with the maximum combined probability.

Autoregressive models have proven effective for machine translation and MLC [43, 2, 36]. However, predictions must be made sequentially, eliminating parallelization. Also, beam search is typically used at test time to find the optimal predictions. But beam search is limited by the time cost of large beam sizes, making it difficult to optimally predict many output labels [27].

In addition to speed constraints, beam search for autoregressive inference introduces a second drawback: initial wrong predictions will propagate when using a modest beam size (e.g. most models use a beam size of 5). This can lead to significant decreases in performance when the number of positive labels is large. For example, the Delicious dataset has a median of 19 positive labels per sample, and it can be very difficult to correctly predict the labels at the end of the prediction chain.

Autoregressive models are well suited for machine translation because these models mimic the sequential decoding process of real translation. However, for MLC, the output labels have no intrinsic ordering. While the joint probability of the output labels is independent of the label ordering via autoregressive based inference, the chosen ordering can make a difference in practice [54, 36]. Some ordering of labels must be used during training, and this chosen ordering can lead to unstable predictions at test time.

Our LaMP connects to [18] who removed the autoregressive decoder in MT with the Non-Autoregressive Transformer. In this model, the encoder makes a proxy prediction, called “fertilities”, which are used by the decoder to predict all translated words at once. The difference between their model and ours is that we have a constant label at each position, so we don’t need to marginalize over all possible labels at each position.

6 Appendix: Dataset Details

Dataset Input Type Domain #Train #Val #Test #Labels (L) #Features Prior Graph Structure
Reuters-21578 Sequential Text 6,993 777 3,019 90 23,662 Co-occur
RCV1-V2 Sequential Text 703,135 78,126 23,149 103 368,998 Tree
TFBS Sequential Biology 1,671,873 301,823 323,796 179 4 PPI
BibTex Binary Vector Text 4,377 487 2,515 159 1,836 Co-occur
Delicious Binary Vector Text 11,597 1,289 3,185 983 500 Co-occur
Bookmarks Binary Vector Text 48,000 12,000 27,856 208 368,998 Co-occur
NUS-WIDE Vector Image 129,431 32,358 107,859 85 128 Co-occur
SIDER Graph Drug 1,141 143 143 27 37 Co-occur
Table 5: Dataset Statistics. We use 7 well studied MLC datasets, plus our own TFBS protein binding dataset. Each dataset varies in the input type, number of samples, number of labels, and number of input features. The last column shows the prior graph structure type we explore for the LaMP model.
Dataset Mean Labels/Sample Median Labels/Sample Max Labels/Sample Mean Samples/Label Median Samples/Label Max Samples/Label
Reuters-21578 1.23 1 15 106.50 18 2,877
RCV1-V2 3.21 3 17 24,362 7,250 363,991
TFBS 7.62 2 178 84,047 45,389 466,876
BibTex 2.38 2 28 72.79 54 689
Delicious 19.06 20 25 250.15 85 5,189
Bookmarks 2.03 1 44 584.67 381 4,642
NUS-WIDE 1.86 1 12 3721.7 1104 44255
SIDER 15.3 16 26 731.07 851 1185
Table 6: Additional Dataset Statistics. Here we show additional statistics of the datasets with respect to the number of labels. This shows how each dataset has a varying degree of MLC difficulty regarding the number of labels which need to be predicted.

7 Appendix: Extra Metrics

Here we provide the results for two extra metrics: subset accuracy and Hamming accuracy.

Reuters Bibtex Bookmarks Delicious RCV1 TFBS SIDER NUSWIDE
Madjarov - 0.202 0.209 0.018 - - - -
RNN Seq2Seq 0.837 0.195 0.273 0.016 0.6798 0.114 0.000 0.252
Emb + MLP 0.774 0.151 0.234 0.180 0.620 0.040 0.014 0.263
Emb + LaMP_el 0.757 0.141 0.214 0.176 0.619 0.077 0.014 0.268
Emb + LaMP_fc 0.813 0.171 0.234 0.192 0.630 0.086 0.007 0.269
Emb + LaMP_pr 0.813 0.169 0.232 0.192 0.621 0.087 0.007 0.267
FMP + LaMP_el 0.808 0.158 0.231 0.192 0.656 0.084 0.007 -
FMP + LaMP_fc 0.835 0.182 0.242 0.195 0.660 0.090 0.014 -
FMP + LaMP_pr 0.828 0.185 0.241 0.196 0.659 0.090 0.007 -
Table 7: Subset Accuracy Scores across all 8 datasets
Reuters Bibtex Bookmarks Delicious RCV1 TFBS SIDER NUSWIDE
Madjarov - 0.988 0.991 0.982 - - - -
RNN Seq2Seq 0.996 0.985 0.990 0.980 0.9925 0.961 0.593 0.980
Emb + MLP 0.996 0.987 0.991 0.982 0.992 0.959 0.752 0.980
Emb + LaMP_el 0.996 0.987 0.991 0.982 0.992 0.963 0.750 0.980
Emb + LaMP_fc 0.997 0.988 0.992 0.982 0.992 0.964 0.752 0.980
Emb + LaMP_pr 0.997 0.988 0.991 0.982 0.992 0.964 0.751 0.980
FMP + LaMP_el 0.997 0.987 0.991 0.982 0.993 0.964 0.748 -
FMP + LaMP_fc 0.997 0.988 0.992 0.982 0.993 0.964 0.749 -
FMP + LaMP_pr 0.997 0.988 0.992 0.982 0.993 0.964 0.747 -
Table 8: Hamming Accuracy across all 8 datasets

8 Appendix: More about Experiments

Figure 3: Second-step visualization results corresponding to Fig. 2: (a) intermediate predictions, (b) Label-to-Feature attention weights for the second Feature-to-Label message passing step, (c) Label-to-Label attention weights for the second Label-to-Label message passing step.

8.1 Datasets

We test our method against baseline methods on several different multi-label classification datasets, summarized in Table 6. We use Reuters-21578 [30], Bibtex [48], Delicious [47], Bookmarks [23], RCV1-V2 [30], our own DNA protein binding dataset (TFBS) from [7], SIDER [28], which contains side effects of drug molecules, and NUS-WIDE, an image dataset. As shown in the table, each dataset has a varying number of samples, number of labels, positive labels per sample, and samples per label. For BibTex and Delicious, we use 10% of the provided training set for validation. For the TFBS dataset, we use 1 layer of convolution at the first layer to extract "words" from the DNA characters (A, C, G, T), as commonly done in deep learning models for DNA.

For datasets which have a sequential ordering of the input components (Reuters, RCV1), we add a positional encoding to the word embeddings, as used in [52] (sine and cosine functions of different frequencies), to encode the location of each word in the sentence. For datasets with no ordering or graph structure (Bibtex, Delicious, Bookmarks, which use bag-of-words input representations) we do not use positional encodings. For inputs with an explicit graph representation (SIDER), we use the known graph structure.
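For reference, here is a small sketch of the sinusoidal positional encoding of [52] that is added to the word embeddings for the sequential datasets; it follows the standard Transformer recipe (sine/cosine at geometrically spaced frequencies) rather than the authors' exact code.

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len, d):
    """Standard Transformer-style encoding: sine on even dimensions, cosine on odd dimensions."""
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)   # (seq_len, 1)
    div_term = torch.exp(torch.arange(0, d, 2, dtype=torch.float) * (-math.log(10000.0) / d))
    pe = torch.zeros(seq_len, d)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

# usage: add to (seq_len, d) word embeddings for sequential inputs (e.g., Reuters, RCV1)
embeddings = torch.randn(50, 8) + sinusoidal_positional_encoding(50, 8)
```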

8.2 Evaluation Metrics

Multi-label classification methods can be evaluated with many different metrics which each evaluate different strengths or weaknesses. We use the same 5 evaluation metrics from [36].

All of our autoregressive models predict only the positive labels before outputting a stop signal. This is a special case of PCC models (the positive-only variant, Eq. 26), which has been shown to outperform the binary prediction of each label in terms of both performance and speed. These models use beam search at inference time with a beam size of 5. For the non-autoregressive models, to convert the soft predictions to binary values, we chose the best threshold on the validation set from the same set of thresholds used in [50].

Example-based measures are defined by comparing the target vector $y$ to the prediction vector $\hat{y}$. Subset Accuracy (ACC) requires an exact match of the predicted labels and the true labels: $\mathrm{ACC}(y, \hat{y}) = \mathbb{1}[y = \hat{y}]$. Hamming Accuracy (HA) evaluates how many labels are correctly predicted in $\hat{y}$: $\mathrm{HA}(y, \hat{y}) = \frac{1}{L} \sum_{i=1}^{L} \mathbb{1}[y_i = \hat{y}_i]$. Example-based F1 (ebF1) measures the ratio of correctly predicted labels to the sum of the total true and predicted labels: $\mathrm{ebF1}(y, \hat{y}) = \frac{2 \sum_{i=1}^{L} y_i \hat{y}_i}{\sum_{i=1}^{L} y_i + \sum_{i=1}^{L} \hat{y}_i}$.

Label-based measures treat each label as a separate two-class prediction problem, and compute the number of true positives ($tp_i$), false positives ($fp_i$), and false negatives ($fn_i$) for each label $i$. Macro-averaged F1 (maF1) is the label-based F1 averaged over all labels: $\mathrm{maF1} = \frac{1}{L} \sum_{i=1}^{L} \frac{2\,tp_i}{2\,tp_i + fp_i + fn_i}$. Micro-averaged F1 (miF1) is the F1 computed from the $tp$, $fp$, and $fn$ counts aggregated over all labels: $\mathrm{miF1} = \frac{2 \sum_{i=1}^{L} tp_i}{2 \sum_{i=1}^{L} tp_i + \sum_{i=1}^{L} fp_i + \sum_{i=1}^{L} fn_i}$. High maF1 scores usually indicate high performance on less frequent labels. High miF1 scores usually indicate high performance on more frequent labels.
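For clarity, here is a small sketch (with our own function names, using numpy) of how the example-based and label-based F1 metrics above can be computed from binary target and prediction matrices.

```python
import numpy as np

def example_based_f1(Y, Y_hat):
    """Mean over samples of 2*|y AND y_hat| / (|y| + |y_hat|); Y, Y_hat are (N, L) binary arrays."""
    num = 2.0 * (Y * Y_hat).sum(axis=1)
    den = Y.sum(axis=1) + Y_hat.sum(axis=1)
    # convention choice: samples with no true and no predicted labels count as F1 = 1
    return np.mean(np.where(den > 0, num / np.maximum(den, 1), 1.0))

def micro_macro_f1(Y, Y_hat):
    tp = (Y * Y_hat).sum(axis=0).astype(float)           # per-label true positives
    fp = ((1 - Y) * Y_hat).sum(axis=0).astype(float)     # per-label false positives
    fn = (Y * (1 - Y_hat)).sum(axis=0).astype(float)     # per-label false negatives
    micro = 2 * tp.sum() / max(2 * tp.sum() + fp.sum() + fn.sum(), 1e-12)
    # convention choice: labels with no positives and no predictions contribute 0 to maF1
    per_label = 2 * tp / np.maximum(2 * tp + fp + fn, 1e-12)
    return micro, per_label.mean()                       # (miF1, maF1)

# usage on random binary matrices with N = 4 samples and L = 5 labels
Y = np.random.randint(0, 2, (4, 5)); Y_hat = np.random.randint(0, 2, (4, 5))
ebf1 = example_based_f1(Y, Y_hat); mif1, maf1 = micro_macro_f1(Y, Y_hat)
```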

8.3 Model Hyperparameter Tuning

For all datasets (Table 6), we use the same LaMP model with $T = 2$ time steps and $K = 4$ attention heads. We trained our models on an NVIDIA TITAN X Pascal with a batch size of 32. We used Adam [25] with a learning rate of 0.0002 and eps of 1e-08 for each dataset. We used a different dropout rate for the smaller datasets (Reuters, Bibtex, SIDER) than for all other datasets. The LaMP models also use layer normalization [1] around each of the attention and feedforward layers. All LaMP models are trained with the LaMP loss (Eq. 21). The hyperparameter $\lambda$ is selected as the best performing value on the validation set for each model. MLP models are trained with regular binary cross entropy (Eq. 19), and RNN Seq2Seq models are trained with cross entropy across all possible labels at each position. To convert the soft predictions into binary values, we use the same thresholds as in [5] and select the best one for each metric on the validation set. For the TFBS dataset, which uses DNA input sequences, we use one layer of convolution to get 512-dimensional embeddings, as commonly done for deep neural network prediction tasks on DNA sequences.

8.4 Baseline Comparisons

Briefly, we compare against the following methods for all reported datasets and metrics. For the results named "Madjarov", we take the best method for each reported metric from [32], who compared 12 different types of models including SVMs, decision trees, boosting, classification rules, and neural networks. For "SPEN": structured prediction energy networks from [5]. For "SVM": the SVM method from the Reuters dataset authors [9]. For "FastXML": the fast random forest model from [37]. For "GAML": graph attention for MLC from [11]. For "RNN Seq2Seq": the RNN sequence-to-sequence model from [35], which is a PCC model that predicts only the positive labels. For "Emb + MLP": we use the mean embedding of all input features as the input to a 4-layer multi-layer perceptron (MLP); this is a BR baseline which predicts all labels independently.