Discrete Infomax Codes for Meta-Learning

05/28/2019 · by Yoonho Lee, et al. · POSTECH, Kakao Corp.

Learning compact discrete representations of data is itself a key task in addition to facilitating subsequent processing. It is also relevant to meta-learning since a latent representation shared across related tasks enables a model to adapt to new tasks quickly. In this paper, we present a method for learning a stochastic encoder that yields discrete p-way codes of length d by maximizing the mutual information between representations and labels. We show that previous loss functions for deep metric learning are approximations to this information-theoretic objective function. Our model, Discrete InfoMax Codes (DIMCO), learns to produce a short representation of data that can be used to classify novel classes given only a few labeled examples. Our analysis shows that using shorter codes reduces overfitting in the context of few-shot classification. Experiments show that DIMCO requires less memory (i.e., code length) for performance similar to previous methods and that our method is particularly effective when the training dataset is small.


1 Introduction

The task of learning a representation of data that reveals its underlying structure is a fundamental problem in machine learning. Deep neural networks (Krizhevsky et al., 2012) have achieved remarkable success in this problem by learning a hierarchy of representations where each representation (i.e., layer activation) directly determines the next. However, deep representations learned on one task require additional fine-tuning before they can be used on new tasks. This process can be time-consuming to perform for each new task and can also cause the model to overfit.

Deep metric learning models (Hoffer and Ailon, 2015; Snell et al., 2017) similarly learn representations of data using a neural network, but differ from ordinary deep networks in that they learn a representation that can directly be used to classify a novel dataset. Such representations have been interpreted as meta-learned knowledge which a nearest-neighbor classifier can use to classify novel classes given only a few examples. The primary focus of previous work (Koch et al., 2015; Vinyals et al., 2016; Snell et al., 2017; Sung et al., 2018; Oreshkin et al., 2018) was on the metric induced by comparing two datapoints through their learned representations.

Rather than focusing on the metric, we tackle the problem of optimizing the representation itself. In particular, we argue that a good representation of data should be as concise as possible while still being able to predict class labels. We propose Discrete InfoMax COdes (DIMCO), a model that learns a discrete representation of data by directly maximizing the mutual information between representations and labels; this quantity can be evaluated in closed form because the representations are discrete. This approach has the advantage that it requires neither a mapping from representations to labels nor batches split into train and test sets.

Our specific contributions are:

  1. We derive generalization bounds for meta-learning that show the roles of task size and the number of tasks.

  2. We propose DIMCO, a model that learns concise discrete codes. DIMCO (1) generalizes better than previous models when trained with small datasets, and (2) is more memory- and time-efficient for image retrieval.

2 Supervised Representation Learning

We outline two tasks which can be seen as instances of the more general problem of supervised representation learning. We define supervised representation learning as the task of using class labels to learn useful representations of data. This problem differs from standard classification as it aims to learn a representation that generalizes to other datasets rather than directly predicting the labels themselves.

Few-shot Classification

The few-shot classification task consists of episodes, each of which is a small dataset with a train/test split. In N-way K-shot classification, each episode has a train set with K labeled datapoints from each of N classes, and a test set of unlabeled instances from the same N classes. Within each episode, the model observes the train set, predicts the labels of the test set images, and is evaluated on its accuracy.
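To make the episode structure concrete, the following is a minimal NumPy sketch (not taken from the paper) of sampling an N-way K-shot episode from an array of labeled images; the function and argument names are illustrative.

```python
import numpy as np

def sample_episode(images, labels, n_way=5, k_shot=1, n_query=15, rng=None):
    """Sample one N-way K-shot episode: a small labeled train set and a held-out test set."""
    rng = np.random.default_rng() if rng is None else rng
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    support_x, support_y, query_x, query_y = [], [], [], []
    for episode_label, c in enumerate(classes):
        idx = rng.permutation(np.where(labels == c)[0])[: k_shot + n_query]
        support_x.append(images[idx[:k_shot]])           # K labeled examples of this class
        support_y.append(np.full(k_shot, episode_label))
        query_x.append(images[idx[k_shot:]])             # held-out examples to classify
        query_y.append(np.full(n_query, episode_label))
    return (np.concatenate(support_x), np.concatenate(support_y),
            np.concatenate(query_x), np.concatenate(query_y))
```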

Image Retrieval

Image retrieval is the problem of taking a query image and retrieving the most similar image from a large database of images. Models for this task are evaluated by measuring the similarity between a query image and a retrieved image. An example of such a measure is Recall@k:

Recall@K = (1/|Q|) Σ_{q ∈ Q} 1[ at least one of the K nearest retrieved images is relevant to q ]   (1)

where the definition of "relevant" depends on the specific dataset. For class-labeled images, an image is relevant to a query image if the two belong to the same class.
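As a concrete reference, here is a minimal NumPy sketch of Recall@K under the class-label notion of relevance described above; the use of Euclidean distance and the variable names are assumptions for illustration.

```python
import numpy as np

def recall_at_k(query_emb, query_labels, db_emb, db_labels, k=1):
    """Fraction of queries whose k nearest database items include at least one same-class item."""
    # Pairwise squared Euclidean distances between queries and database items.
    d2 = ((query_emb[:, None, :] - db_emb[None, :, :]) ** 2).sum(-1)
    topk = np.argsort(d2, axis=1)[:, :k]                 # indices of the k nearest neighbors
    hits = db_labels[topk] == query_labels[:, None]      # shape (num_queries, k)
    return hits.any(axis=1).mean()
```

When the queries are drawn from the database itself, each query should additionally be excluded from its own neighbor list.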

Learning a continuous representation and comparing data in embedding space has been proposed as a solution to both few-shot classification (Vinyals et al., 2016; Snell et al., 2017) and image retrieval (Hoffer and Ailon, 2015; Sohn, 2016). We show in section 6 that the metrics for these two problems are strongly correlated, which motivates our consideration of the more general problem of supervised representation learning. In the next section, we propose an alternative information-theoretic objective for supervised representation learning.

3 An Information-Theoretic Perspective on Representation Learning

Throughout this section, we denote data, representations, and labels as x, z, and y, respectively. Capital symbols X, Z, Y denote the random variables corresponding to x, z, y.

The mutual information between two random variables X and Y is defined as

I(X; Y) = Σ_{x, y} p(x, y) log [ p(x, y) / ( p(x) p(y) ) ]   (2)

I(X; Y) is a symmetric quantity which measures the amount of information shared between X and Y. It is zero when X and Y are independent and increases with the correlation between X and Y. We refer the reader to (Cover and Thomas, 2012) for further exposition.
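For reference, a small NumPy sketch that evaluates this definition for a discrete joint distribution given as a probability table (a hypothetical example, not part of the paper):

```python
import numpy as np

def mutual_information(p_xy):
    """I(X; Y) in nats for a joint probability table p_xy of shape (|X|, |Y|)."""
    p_x = p_xy.sum(axis=1, keepdims=True)
    p_y = p_xy.sum(axis=0, keepdims=True)
    mask = p_xy > 0                       # treat 0 * log 0 as 0
    return np.sum(p_xy[mask] * np.log(p_xy[mask] / (p_x @ p_y)[mask]))

# Perfectly correlated binary variables give I(X; Y) = log 2 ≈ 0.6931 nats.
print(mutual_information(np.array([[0.5, 0.0], [0.0, 0.5]])))
```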

3.1 Problem Setup

We now describe our meta-learning problem setup. Define a task τ to be a joint distribution over data and labels. Let tasks τ_1, …, τ_n be sampled i.i.d. from a distribution over tasks p(τ). Associated with each task τ_i is a dataset D_i, which is a set of m i.i.d. samples from τ_i.

Denote the model parameters by θ and the representation by Z_θ = f_θ(X), to make explicit its dependence on the data and the parameters. Our learning objective is the expected negative mutual information between representation and labels:

L(θ) = E_{τ ∼ p(τ)} [ −I(Z_θ; Y) ]   (3)

This differs from previous formulations of representation learning in the following ways:

  1. The objective is the negative mutual information within each batch.

  2. Each task is not split into a train/test set.

This objective is closely related to previous loss functions and evaluation metrics for supervised representation learning. We show in appendix A that previous loss functions can be seen as approximations to this quantity, and experiments in section 6 show that the mutual information is strongly correlated with metrics such as few-shot accuracy and Recall@k.

3.2 Generalization Bound

We bound the true expected loss using the empirical loss:

Theorem 1.

Let L(θ) be defined as above. Let Î_{D_i}(Z_θ; Y) be the empirical estimate of the mutual information computed from the finite dataset D_i of size m, and define the empirical loss as

L̂(θ) = (1/n) Σ_{i=1}^{n} −Î_{D_i}(Z_θ; Y)   (4)

The following inequality holds with high probability:

(5)
Proof.

See appendix B. ∎

The generalization gap consists of three terms: two decrease as the per-task dataset size m increases, and the other decreases as the number of tasks n increases. Typically in few-shot learning, n is very large while m is small; a miniImageNet N-way K-shot episode contains only a handful of labeled examples, while the number of possible episodes is enormous. We therefore claim that the terms involving m are the main difficulty in generalizing to new tasks. We see from theorem 1 that using short representations (i.e., a small code space p^d) can compensate for having a small per-task train set (i.e., small m).

4 Discrete Infomax Codes (DIMCO)


Figure 1: A graphical overview of Discrete InfoMax COdes (DIMCO). A dataset consists of pairs of images x and labels y. DIMCO is a stochastic encoder that maps each image to a distribution over discrete codes z. Each discrete code is a p-way code of length d: it consists of d symbols, each taking one of p values. Inside the grid that represents the possible codes, the most likely row and column are colored, and their intersection is the most likely code. DIMCO is optimized by maximizing the mutual information between the discrete code and the label within each batch.

We now present our model, Discrete InfoMax COdes (DIMCO). Motivated by section 3, DIMCO produces a short discrete code and is trained by maximizing the mutual information I(Z; Y) between codes and labels. Figure 1 graphically shows the overall structure of DIMCO.

4.1 Factorized Discrete Codes

We propose a factorized discrete representation scheme which lets us represent discrete distributions with exponentially fewer parameters than listing the probability of each event. We represent each event as the product of d independent events, each of which has p possible outcomes. We thus have p^d events in total, but require only dp parameters to represent the distribution. Binary codes can be viewed as a special case of this scheme with p = 2. This factorization trick allows us to consider very large code spaces (section 6). The representation requires only d log₂ p bits per datapoint, whereas a d-dimensional continuous embedding requires 32d bits (assuming 32-bit floats).
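A quick numeric illustration of this bookkeeping, using arbitrary example values of p and d (not the paper's settings):

```python
import math

p, d = 64, 16                      # alphabet size and code length (arbitrary example values)
num_codes = p ** d                 # number of distinct codes: p^d
num_params = d * p                 # parameters needed by the factorized distribution
bits_discrete = d * math.log2(p)   # bits to store one factorized code
bits_float32 = 32 * d              # bits for a d-dimensional embedding of 32-bit floats
print(num_codes, num_params, bits_discrete, bits_float32)
```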

4.2 Model

Recall that we represent a given image using d independent discrete distributions, each of which has p possible values. First, a (convolutional) neural network f_θ takes the image x as input and outputs a vector of length dp, which we reshape into a matrix of size d × p:

L = reshape(f_θ(x)) ∈ ℝ^{d×p}   (6)

Each row of this matrix holds the logits of one discrete distribution. We apply the softmax function to each row to obtain probabilities:

p_{ij} = exp(L_{ij}) / Σ_{j′=1}^{p} exp(L_{ij′}),   i = 1, …, d   (7)

The i-th codeword z_i is sampled from the categorical distribution with these probabilities:

z_i ∼ Categorical(p_{i1}, …, p_{ip})   (8)

The representation z of the image is the concatenation of the codewords:

z = (z_1, …, z_d)   (9)
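The following is a minimal PyTorch sketch of this encoder head (eqs. 6–8), assuming an arbitrary backbone network that outputs a flat feature vector; the class name, default values of p and d, and the backbone are illustrative rather than the paper's exact architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiscreteCodeHead(nn.Module):
    """Maps backbone features to d categorical distributions over p symbols."""

    def __init__(self, backbone, feature_dim, d=16, p=64):
        super().__init__()
        self.backbone = backbone                    # any CNN producing (batch, feature_dim) features
        self.fc = nn.Linear(feature_dim, d * p)     # vector of length dp (eq. 6)
        self.d, self.p = d, p

    def forward(self, x):
        logits = self.fc(self.backbone(x)).view(-1, self.d, self.p)  # reshape to (batch, d, p)
        return F.softmax(logits, dim=-1)                             # row-wise softmax (eq. 7)

    def sample_code(self, x):
        probs = self.forward(x)
        # One categorical sample per codeword position (eq. 8); shape (batch, d).
        return torch.distributions.Categorical(probs=probs).sample()
```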

4.3 Training

Recall that Z is a discrete random variable and p_θ(z | x) is its distribution. Instead of sampling z, we directly use these probabilities to compute the objective:

I(Z; Y) = H(Z) − H(Z | Y)   (10)

The first term, H(Z), can be calculated by averaging the probabilities over the batch and computing the entropy of the result:

H(Z) = − Σ_z p̄(z) log p̄(z),   where p̄(z) = (1/m) Σ_{j=1}^{m} p_θ(z | x_j)   (11)

The second term is H(Z | Y) = Σ_{c=1}^{C} p(y = c) H(Z | Y = c), where C is the number of classes. The marginal probability p(y = c) is the frequency of class c in the batch, and H(Z | Y = c) is obtained by computing (11) using only the datapoints x_j for which y_j = c.

Though we have motivated the use of −I(Z; Y) as a loss function throughout this paper, the decomposition in (10) provides yet another perspective. Minimizing H(Z | Y) encourages discriminative behavior: it encourages the average code distribution of each class to be as concentrated as possible. Maximizing H(Z) incentivizes the model to use all possible values of Z overall.

We emphasize that such closed-form computation of I(Z; Y) is only possible because we are using discrete codes.
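As an illustration, here is a minimal PyTorch sketch of this batch-wise objective. It assumes the entropies are computed independently for each of the d codeword positions and then summed, which keeps the computation tractable; the exact reduction used in the paper may differ.

```python
import torch

def entropy(p, eps=1e-12):
    """Entropy of categorical distributions given along the last dimension."""
    return -(p * (p + eps).log()).sum(dim=-1)

def neg_mutual_information(probs, labels):
    """-I(Z; Y) = -(H(Z) - H(Z|Y)), summed over the d codeword positions.

    probs:  (batch, d, p) per-example codeword probabilities (eq. 7)
    labels: (batch,) integer class labels within the batch
    """
    h_z = entropy(probs.mean(dim=0)).sum()             # H(Z): entropy of batch-averaged probabilities
    h_z_given_y = 0.0
    for c in labels.unique():
        mask = labels == c
        p_c = mask.float().mean()                      # empirical p(y = c) in the batch
        h_z_given_y = h_z_given_y + p_c * entropy(probs[mask].mean(dim=0)).sum()
    return -(h_z - h_z_given_y)                        # minimize this to maximize I(Z; Y)
```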

4.4 Evaluation

We map all images, both train and test, to their probability matrices (eqs. 6, 7). We then map each training image to its most likely code:

ẑ = (argmax_j p_{1j}, …, argmax_j p_{dj})   (12)

Fix a train image and a test image, and let ẑ be the most likely code of the train image. The similarity between the train image and the test image is measured by the probability of the test image producing ẑ. This amounts to computing the product (in practice, we sum log probabilities for numerical stability) of the test image's probabilities p_{i ẑ_i} over positions i:

sim(x_train, x_test) = Π_{i=1}^{d} p_{i ẑ_i}(x_test)   (13)

We use this as a similarity metric for both few-shot classification and image retrieval. We perform few-shot classification by computing the most likely code for each class via eq. (12) and assigning each test image to the class with the highest value of (13). We similarly perform image retrieval by mapping each support image to its most likely code (12) and, for each query image, retrieving the support image with the highest value of (13).
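A short PyTorch sketch of this evaluation path, using the same (batch, d, p) probability tensors as above; the helper names are illustrative.

```python
import torch

def most_likely_code(probs):
    """Eq. (12): per-position argmax code; output shape (..., d)."""
    return probs.argmax(dim=-1)

def log_score(query_probs, code, eps=1e-12):
    """Eq. (13) in log space: sum_i log p_i[code_i] for each query image.

    query_probs: (num_query, d, p) probabilities of the query images
    code:        (d,) discrete code of a support image or class
    """
    idx = code.view(1, -1, 1).expand(query_probs.size(0), -1, 1)
    return (query_probs.gather(-1, idx).squeeze(-1) + eps).log().sum(dim=-1)

def classify(query_probs, class_codes):
    """Few-shot classification: score each query against every class code, take the argmax."""
    scores = torch.stack([log_score(query_probs, c) for c in class_codes], dim=-1)
    return scores.argmax(dim=-1)
```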

5 Related Work

Information Bottleneck

The concept of learning short descriptions of data that maximally correlate with the label is closely related to the information bottleneck (Tishby et al., 2000). This principle states that I(Z; Y) should be maximized while I(Z; X) is simultaneously minimized. DIMCO maximizes I(Z; Y) while keeping H(Z) low via a hyperparameter (the code size). DIMCO is also related to the deterministic information bottleneck (Strouse and Schwab, 2017), which extends the information bottleneck by minimizing H(Z) instead of I(Z; X). Note that these quantities are related by the inequality I(Z; X) ≤ H(Z), which is tight when Z is an efficient code of X.

Information Theory and Unsupervised Representation Learning

Many works have applied information-theoretic principles to unsupervised representation learning. Bell and Sejnowski (1995) uses a mutual information objective to derive an algorithm for blind source separation. Slonim et al. (2005) derives a clustering algorithm based on the rate-distortion tradeoff. Chen et al. (2016) optimizes a lower bound of the mutual information to make a subset of its latent dimensions correlate with specific pre-specified features. Alemi et al. (2017) analyzes the objective of VAEs from a rate-distortion theory perspective. Our work also uses information-theoretic principles for representation learning, but we apply these principles to a supervised meta-learning setting.

Discrete Representations

Discrete representations have been studied at least since the beginning of information theory (Shannon, 1948). Recent deep learning methods have proposed ways to directly learn discrete representations. Rolfe (2016) and van den Oord et al. (2017) learn variational autoencoders with discrete latent variables. Hu et al. (2017) learns discrete representations in an unsupervised manner by maximizing the mutual information between representation and data. In contrast, DIMCO assumes a supervised setting and performs infomax using labels instead of data.

Jeong and Song (2018) is close in spirit to our model: their method learns a quantizable continuous representation. Within each batch, their algorithm solves a minimum cost flow problem to find the locally optimal binary hash code. The training procedure of DIMCO is much simpler since it directly computes its loss function without requiring such inner-loop optimization. Additionally, the focus of Jeong and Song (2018) is on the speedup gained by using sparse binary hash codes, whereas our work focuses on learning an efficient (dense) discrete representation of data.

Factorized Representations

The idea of using factorized representations to increase representational power has appeared in other contexts. Jegou et al. (2011) factorizes a continuous input into a Cartesian product of quantized low-dimensional subspaces. Norouzi and Fleet (2013) uses factorized representations to encode a large number of cluster centers with a small memory footprint. Vaswani et al. (2017) uses multi-head attention as one of its core components, which factorizes the output into the Cartesian product of dot-product attention in several independent subspaces.

Metric Learning

Our analysis provides a unifying view of embedding-based meta-learning (Vinyals et al., 2016; Snell et al., 2017; Oreshkin et al., 2018) and image retrieval (Hoffer and Ailon, 2015; Oh Song et al., 2016; Sohn, 2016; Movshovitz-Attias et al., 2017; Wu et al., 2017; Duan et al., 2018) from the perspective of supervised representation learning. We show in appendix A that the loss functions of these methods can be seen as approximations to the mutual information I(Z; Y). While all of these previous methods require a train/test (also called anchor/query) split within each task, DIMCO simply optimizes an information-theoretic quantity of each batch, removing the need for such structured batch construction.

Meta-Learning with Simple Inner-Loop Learners

Many works on gradient-based meta-learning have reported benefits from using few task-specific parameters. Lee and Choi (2018) learns a subset of the full network to alter during task-specific learning. Rusu et al. (2018) explicitly represents each task with a low-dimensional latent space. Zintgraf et al. (2018) alters only a pre-specified subset of the full network during task-specific learning. Our results further support the consensus that meta-learning models with simple task-specific learners generalize to new tasks more easily. We additionally connect this idea to information-theoretic principles and use this connection to derive generalization bounds for few-shot learning.

6 Experiments

We use the miniImageNet (Ravi and Larochelle, 2016) and CUB200 (Wah et al., 2011) datasets with their standard splits in our experiments. The miniImageNet dataset is a subset of ImageNet (Krizhevsky et al., 2012) constructed for few-shot classification. It consists of 100 classes, each containing 600 images of size 84×84. The classes are split into 64 training, 16 validation, and 20 test classes. The Caltech-UCSD Birds-200-2011 (CUB200) dataset consists of 11,788 images of birds from 200 classes. The classes are split into 100 training and 100 test classes.

We use two different CNN backbones for our experiments: the 4-layer convnet commonly used for meta-learning (Finn et al., 2017; Sung et al., 2018; Liu et al., 2018), and the Inception network (Szegedy et al., 2015) with batch normalization (Ioffe and Szegedy, 2015), which is commonly used for deep image retrieval (Sohn, 2016; Movshovitz-Attias et al., 2017; Wu et al., 2017).

6.1 Correlation of Metrics


Figure 2: Results obtained from a small trained DIMCO model. Each code's location can be expressed as a pair (i, j), where i indexes the codeword and j indexes its value. The figure shows the top 10 images in the test set that assign the highest probability to a specific code. Code locations were (1, 53), (1, 58), (2, 4), (4, 2), counterclockwise from top left.

This experiment verifies whether mutual information is indeed a reasonable measure of representation quality. Using the miniImageNet dataset, we trained many independent runs of DIMCO with varying hyperparameters. We used the test split to compute five metrics: few-shot accuracies for several way and shot settings, Recall@k, and the mutual information I(Z; Y).

Due to space constraints, we show the pairwise correlations between these metrics in fig. 5 of the appendix. We see that all metrics are very strongly correlated. We point out that while I(Z; Y) correlates with previous metrics for fixed p and d, it is not suitable as a general evaluation metric since its scale depends on these hyperparameters: it grows with the size of the code space.

6.2 What does each code learn?

We inspected what features were encoded in a small DIMCO model after training on miniImageNet. Recall that each image produces a d × p matrix of probabilities (eqs. 6, 7). For each entry of this matrix, we plotted the top images in the test set that assign the highest probability to that entry. We show images corresponding to four such entries in fig. 2 and more in fig. 7 of the appendix.

The top left code in fig. 2 is representative of the bookshelf class. On the other hand, the bottom right code corresponds to animals with fur and assigns high probability to images of many different classes. We interpret this as DIMCO learning a distributed representation: by aggregating such complementary features in each of its codewords, DIMCO is able to classify novel classes given only a few datapoints.

6.3 Small Train Set


Figure 3: Performance of various methods trained with small datasets. The lowest y-axis value for each metric corresponds to the expected performance of random guessing. Quantities shown are the mean and standard deviation of the top runs from a hyperparameter sweep for each configuration.

This experiment shows how each model performs when learning from a small dataset. We trained each model using only a limited number of samples from each training class in the miniImageNet dataset; for example, using k samples per class reduces the full train split (64 classes, 600 images per class) to 64 classes with k images per class. We compare against three methods: prototypical networks (Snell et al., 2017), Triplet networks (Hoffer and Ailon, 2015), and multiclass N-pair loss (Sohn, 2016). After training with a subsampled dataset, we test using the full test split.

Few-shot accuracies and Recall@1 of each method are shown in fig. 3. First, note that DIMCO is the only method that can be trained with a single example per class. This is because the other methods require at least one train and one test example per class within each batch, while DIMCO requires no such train/test split and simply maximizes the mutual information within a batch. DIMCO learns much more effectively when the number of examples per class is low. We attribute this to our model's low inner-loop generalization gap (section 3.2). Because our model can learn effectively from small batches, it can also learn from a small total number of training examples.

6.4 Fine-Grained Image Retrieval


Figure 4: Image retrieval performance of DIMCO and N-pair loss on the CUB200 dataset. The y-axis of both figures is the Recall@1 metric, and error bars show the standard deviation over runs for each configuration. The x-axes show (left) the bits required to store one representation and (right) the seconds required to perform retrieval for one query. Both x-axes are log-scale.

We conducted a fine-grained image retrieval experiment using the CUB200 dataset. We compare DIMCO to the multiclass N-pair loss (Sohn, 2016), a state-of-the-art deep image retrieval method. For this experiment only, we use the Inception network described at the beginning of this section. Using the same Inception encoder backbone, we trained DIMCO with various code sizes (p, d) and the multiclass N-pair loss with various embedding dimensions. We measured the time per query for each method on a single Tesla P40 GPU by averaging the time required over many batches of queries.

Results in fig. 4 show that the compact codes of DIMCO require roughly an order of magnitude less memory than N-pair loss for similar performance, with benefits in retrieval query time as well. This experiment also demonstrates that discrete representations can match the performance of state-of-the-art methods on this relatively large-scale task, and that DIMCO can be trained with large neural network backbones without significant overfitting. For comparison, experiments reported in Mishra et al. (2017) indicate that MAML (Finn et al., 2017) overfits tremendously when trained with a deeper backbone.

7 Conclusion

We introduced DIMCO, a model that learns a discrete representation of data by directly optimizing the mutual information with the label. To evaluate our initial intuition that shorter representations generalize better between tasks, we provided generalization bounds that become tighter as the representation gets shorter. We additionally performed meta-learning experiments to show that the concise representations learned by DIMCO generalize well even when learning from very small datasets.

Previous meta-learning models required batches with a specific structure: an evenly balanced train/test split. Because DIMCO can be trained using any batch of labeled data, we believe it is a step towards bridging the gap between the seemingly disparate problems of few-shot classification and traditional classification.

References

Appendix A Previous Loss Functions Are Approximations to Mutual Information

Cross-entropy Loss

The cross-entropy loss has directly been used for few-shot classification [Vinyals et al., 2016, Snell et al., 2017].

Let q_φ(y | z) be a parameterized prediction of y given z, which tries to approximate the true conditional distribution p(y | z). Typically in a classification network, φ consists of the parameters of a learned projection matrix and q_φ is the final linear (softmax) layer. The expected cross-entropy loss can be written as

E_{x, y} [ −log q_φ(y | z) ]   (14)

Assuming that the approximate distribution q_φ(y | z) is sufficiently close to p(y | z), minimizing (14) can be seen as

min_θ E [ −log q_φ(y | z) ] ≈ min_θ H(Y | Z)   (15)
= min_θ [ H(Y) − I(Z; Y) ] = max_θ I(Z; Y)   (16)

where the last equality uses the fact that H(Y) is independent of the model parameters. Therefore, cross-entropy minimization is approximate maximization of the mutual information I(Z; Y) between representations and labels.

The approximation is that we parameterized q_φ(y | z) as a linear projection. This structure cannot generalize to new classes because the parameters φ are specific to the labels seen during training. For a model to generalize to unseen classes, one must amortize the learning of this approximate conditional distribution. [Vinyals et al., 2016, Snell et al., 2017] sidestepped this issue by using the embeddings of each class in place of φ.

Triplet Loss

The Triplet loss [Hoffer and Ailon, 2015] is defined as

(17)

where f, f⁺, and f⁻ are the embedding vectors of the query, positive, and negative images. Let y denote the label of the query datapoint. Recall that the density of a unit Gaussian is p(x) = C₁ exp(−C₂ ‖x − μ‖²), where C₁ and C₂ are constants. Let q⁺ and q⁻ be unit Gaussian distributions centered at f⁺ and f⁻, respectively. We have

(18)
(19)
(20)

Two approximations were made in the process. We first assumed that the embedding distribution of images not belonging to class y is equal to the distribution of all embeddings. This is reasonable when each class represents only a small fraction of the full data. We also approximated the embedding distributions with unit Gaussian distributions centered at single samples from each.

N-pair Loss

Multiclass N-pair loss [Sohn, 2016] was proposed as an alternative to the Triplet loss. This loss function requires one positive embedding f⁺ and multiple negative embeddings f_1⁻, …, f_{N−1}⁻, and takes the form

ℓ_{N-pair} = log ( 1 + Σ_{i=1}^{N−1} exp( fᵀ f_i⁻ − fᵀ f⁺ ) )   (21)

This can be seen as the cross-entropy loss applied to the softmax over the inner products fᵀ f⁺, fᵀ f_1⁻, …, fᵀ f_{N−1}⁻.

Following the same logic as the cross-entropy loss, this is also an approximation to I(Z; Y). This objective should have less variance than the Triplet loss since it approximates the marginal embedding distribution using more examples.
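For concreteness, a common PyTorch formulation of this loss for a single anchor is sketched below (treating the positive as class 0 in a softmax over inner products recovers eq. (21)); the exact implementation in [Sohn, 2016] may differ.

```python
import torch
import torch.nn.functional as F

def npair_loss(anchor, positive, negatives):
    """Multiclass N-pair loss for one anchor.

    anchor, positive: (dim,) embeddings; negatives: (n_neg, dim) embeddings.
    """
    logits = torch.cat([positive.unsqueeze(0), negatives]) @ anchor   # inner products, shape (1 + n_neg,)
    target = torch.zeros(1, dtype=torch.long)                         # the positive is class 0
    return F.cross_entropy(logits.unsqueeze(0), target)               # = log(1 + sum exp(f·f_i⁻ − f·f⁺))
```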

Adversarial Metric Learning

Deep Adversarial Metric Learning [Duan et al., 2018] tackles the problem of most negative examples being uninformative by directly generating meaningful negative embeddings. This model employs a generator which takes as input the embeddings of the anchor, positive, and negative images. The generator then outputs a "synthetic negative" embedding that is hard to distinguish from a positive embedding while being close to the negative embedding.

This can be seen as optimizing

(22)

by estimating the negative embedding distribution using a generative network rather than directly from samples. Rather than modelling the marginal distribution of negative embeddings, this method models it conditionally so that the synthetic negative is hard to distinguish from the positive embedding while remaining sufficiently close to both the anchor and the negative.

Appendix B Proof of Theorem 1

The following lemma was proved in Shamir et al. [2010], and we restate it using our notation.

Lemma 1.

Let Z be a random mapping of X. Let D be a sample of size m drawn i.i.d. from the joint probability distribution p(X, Y). Denote the empirical mutual information between Z and Y estimated from D as Î(Z; Y). For any δ ∈ (0, 1), the following holds with probability at least 1 − δ:

(23)

We simplify this and plug in our specific quantities of interest (the code Z_θ, which takes at most p^d values, and the labels Y):

(24)

We similarly bound the error caused by estimating the expectation over tasks with a finite number n of tasks sampled from p(τ). Denote the finite-sample estimate of L(θ) as

(25)

Let the mapping from x to z be parameterized by θ, and let this model have VC dimension V. Using standard uniform convergence results, we can state that with high probability,

(26)

where V is the VC dimension of the hypothesis class parameterized by θ.

Combining equations (24) and (26), we have with high probability

(27)
(28)
(29)

Appendix C Experiments and Implementation Details

Hardware

Every experiment was conducted on a single Nvidia V100 GPU with CUDA 9.2. We used PyTorch version 1.0.1. Each experiment was performed with a different fixed initial seed; we manually fix seeds with manual_seed() for Python, PyTorch, and NumPy.

Optimizer

For experiments with the 4-layer convnet, we use the Adam optimizer [Kingma and Ba, 2014] with learning rate 3e-4. For the Inception network, we use SGD with momentum and learning rate 3e-5.

Correlation of Metrics experiment


Figure 5: Correlation between metrics for representation learning. Each row and column represents one of the five considered metrics. Cells on the diagonal are kernel density estimates of value frequencies, and the other cells are scatterplots that correspond to pairs of different metrics.


Figure 6: Correlation between few-shot accuracy and retrieval measures.

We report the average over batches of few-shot accuracies and mutual information. Each quantity was computed using balanced batches containing an equal number of images from several different classes. We additionally show in fig. 6 the correlation between few-shot accuracies, Recall@k, and NMI using three previously proposed losses (triplet, N-pair, protonet).


Figure 7: Additional examples.

Small Train Set Experiment

For this experiment, we used the Adam optimizer and performed a log-uniform hyperparameter sweep over the learning rate. For DIMCO, we additionally swept the code hyperparameters p and d. For the other methods, we fixed the embedding dimension. For each combination of loss and number of training examples per class, we ran the experiment multiple times and reported the mean and standard deviation of the top runs.