The Information Complexity of Learning Tasks, their Structure and their Distance

04/05/2019 ∙ Alessandro Achille, et al. ∙ Scuola Normale Superiore / SISSA

We introduce an asymmetric distance in the space of learning tasks, and a framework to compute their complexity. These concepts are foundational to the practice of transfer learning, ubiquitous in Deep Learning, whereby a parametric model is pre-trained for a task, and then used for another after fine-tuning. The framework we develop is intrinsically non-asymptotic, capturing the finite nature of the training dataset, yet it allows distinguishing learning from memorization. It encompasses, as special cases, classical notions from Kolmogorov complexity, Shannon, and Fisher Information. However, unlike some of those frameworks, it can be applied easily to large-scale models and real-world datasets. It is the first framework to explicitly account for the optimization scheme, which plays a crucial role in Deep Learning, in measuring complexity and information.


1 Introduction

The widespread use of Deep Learning is due in part to its flexibility: One can pre-train a deep neural network for a task, say finding cats and dogs in images, and then fine-tune it for another, say detecting tumors in a mammogram, or controlling a self-driving vehicle. Sometimes it works. So far, however, it has not been possible to predict whether such a transfer learning practice will work, and how well. Even the most fundamental questions are still unanswered: How far apart are two tasks? In what space do tasks live? What is the complexity of a learning task? How difficult is it to transfer from one task to another? In this paper, we lay the foundations for answering these questions.

Summary of contributions

  1. We introduce a distance between learning tasks (Section 4), where tasks are represented by finite datasets of input data and discrete output classes (labels). Each task can have a different number of labels and a different number of samples. The distance behaves properly with respect to composition and inclusion relations (Section 4) and can handle corner cases, such as learning random labels, where other distances such as Kolmogorov’s fail (Section 3). The distance is asymmetric by design, as it may be easier to learn one task starting from the solution to another, than the reverse (Section 4).

  2. Some learning tasks are more difficult than others, so complexity plays a key role in their distance (Section 3). We define a new notion of complexity of a learning task that captures as special cases Kolmogorov complexity, Shannon Mutual Information, and the Fisher Information of the parameters of a model (Section 5.1). It leverages and extends the classical notions of Structure Function and minimal sufficient statistics. It also provides a new interpretation of the PAC-Bayes generalization bound (Section 5.2).

  3. We show how a parametric function class, such as deep neural networks trained with stochastic gradient descent, can be employed to measure the complexity of a learning task (Section 5). Such complexity is related to the “information in the parameters” of the network. Experiments show that the resulting distance correlates with the ease of transfer learning by fine-tuning a pre-trained model (Section 7).

  4. We show that the asymmetric distance between tasks is a lower bound to the cost of transfer learning using deep neural networks. We introduce the notion of accessibility (or reachability) of a learning task, and illustrate examples of tasks that, while close to each other, are not accessible from one another. Accessibility depends on the global geometry of the space of tasks and on the optimization scheme employed. We conjecture that tasks are accessible through a local optimization scheme when there exists a continuous path of minimizers of the Structure Function (Section 6.2).

Our notion of complexity of a learning task formalizes intuitions regarding overfitting and generalization, and differentiates learning from memorization. It is, to the best of our knowledge, the first complexity measure to take into account the training algorithm, rather than just its asymptotic outcome. It is also the first to enable practical computation for models with millions of parameters, such as those currently in use in real applications.

Organization and nature of the paper

The four main sections (3, 4; 5, 6) are organized as a matrix: Along the rows, Sections 3 and 4 use the language of Kolmogorov to introduce the notions of complexity, structure, and distance of learning tasks, whereas Sections 5 and 6 introduce our more general framework, with classical concepts from Kolmogorov, Shannon and Fisher emerging as special cases. Along the columns, Sections 3 and 5 deal with complexity and the associated notion of information of a single task, whereas Sections 4 and 6 deal with the distance between tasks.

Related work

This work aims to build the theoretical foundations of transfer learning, an active area of research too vast to review here (see [5] for a recent survey). More specifically, [14] computes pairwise distances between tasks by experimentally fine-tuning models for each pair, which does not scale, rather than using the characteristics of a task to predict whether such fine-tuning will succeed; [1] uses the Fisher Information Matrix as a linear embedding and predicts fine-tuning performance on large-scale tasks, providing empirical validation of our framework in Section 5. The existence of tasks that are similar according to most distances, yet cannot be “reached” from one another, has been observed in [3] and motivates our exploration of the dynamics of learning in Section 6. The notion of reachability in dynamical systems has been studied extensively for known models, whereas in our case the model changes as the training progresses. Numerical methods for reachability computation [9] are restricted to low-dimensional spaces, whereas typical models in use have millions of parameters. Our work leverages classical results from Kolmogorov’s complexity theory [13], classical statistical inference, and information theory, and relates to recent theoretical frameworks for Deep Learning, including the Information Bottleneck Principle [12], with important differences that we outline throughout the paper.

2 Preliminaries and nomenclature

In supervised learning, one is given a finite (training) dataset D = {(x_i, y_i)}_{i=1}^N of N samples, where x_i is the input data (e.g., an image) and y_i is the output (e.g., a label). The goal is to learn (i.e., estimate the parameters w of) a model p_w (a parametric function) that maps inputs x to estimated outputs y, so that some loss (or risk) is minimized on unseen (test) data. It is common to assume that D consists of i.i.d. samples from some unknown distribution p(x, y), which are used to assemble a loss function that is minimized with respect to the parameters w, so that p_w(y|x) is close to the “true” posterior p(y|x).

In this work however, we make no assumption on the data generation process, nor do we attempt to approximate the true posterior p(y|x), even when we consider datasets composed of samples from some unknown distribution. Instead, we adopt Kolmogorov’s approach to learning the “structure” in the data, directly accounting for the finite sample size N, the approximation error in the loss L_D, and the complexity of the model. The essential elements of Kolmogorov Complexity theory that we use are summarized in [13]. Deep neural networks (DNNs) are a particularly powerful and efficient class of parametrized models, obtained by successive compositions (layers) of multiplication by weight (parameter) matrices and simple element-wise non-linear operations. In Deep Learning, the optimization scheme acts as an implicit regularizer in the loss. For a measure of complexity to be relevant to Deep Learning, therefore, it must take into account both the loss function and the optimization scheme. To the best of our knowledge, none of the classical measures do so.

2.1 Deep Neural Networks

Deep neural networks are a class of functions (models) that implement a successive composition (layers) of affine operations (where both the linear term and the offset are considered as weights) and a non-linearity (such as a saturation or rectification). One of the most common non-linearities is the rectified linear unit (ReLU), which leaves the positive part unchanged and zeroes the negative part. The first layer is thus of the form f_1(x) = ReLU(W_1 x + b_1), where: W_1 and b_1 are the weights, or improperly the “weight vector”; ReLU(z) = max(z, 0) is defined as the component-wise maximum between z and 0. The output of the first layer is therefore given by z_1 = ReLU(W_1 x + b_1). The second layer performs an operation of the same type, taking the output of the first layer as input, and with a different set of weights, and so on. The last layer produces a probability vector p_w(y|x), with components p_w(y = k|x), usually through a soft-max operation, i.e., p_w(y = k|x) = exp(z_k) / Σ_j exp(z_j), where z is the output of the last affine layer. The k-th entry of this probability vector represents the probability of the input being of class k, as assessed by the model.

The learning criterion is given by maximum likelihood: w* = argmin_w L_D(w), where L_D(w) is the loss function and is given by L_D(w) = −Σ_{i=1}^N log p_w(y_i | x_i). The loss function needs to be regularized, explicitly or implicitly, since the number of weights is typically larger than the number of samples in the training set. The loss function can be interpreted as an empirical approximation of the average cross-entropy E_{(x,y)∼p}[ −log p_w(y|x) ], which is minimized when p_w(y|x) = p(y|x).
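As a minimal illustration of the notation above (a NumPy sketch with a two-layer network, not the architectures used in the experiments), the forward pass and the loss L_D(w) can be written as:

    import numpy as np

    def forward(x, W1, b1, W2, b2):
        """Two-layer ReLU network returning a probability vector via soft-max."""
        h = np.maximum(W1 @ x + b1, 0.0)        # first layer: affine + ReLU
        z = W2 @ h + b2                          # last layer: affine (logits)
        z = z - z.max()                          # subtract max for numerical stability
        return np.exp(z) / np.exp(z).sum()       # soft-max: p_w(y = k | x)

    def loss(data, params):
        """Empirical cross-entropy L_D(w) = -sum_i log p_w(y_i | x_i)."""
        W1, b1, W2, b2 = params
        return -sum(np.log(forward(x, W1, b1, W2, b2)[y]) for x, y in data)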

The training process consists of minimizing the empirical cross-entropy L_D(w) using stochastic gradient descent (SGD). At every iteration, SGD makes a step in the direction of the (negative) gradient of L_D(w), which is approximated using a random subset of the training set (a minibatch). The length of the step is a hyperparameter called the learning rate.
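A bare-bones version of this training loop, as a sketch only: grad_fn is an assumed helper returning the gradient of the minibatch loss with respect to each parameter (e.g., obtained via automatic differentiation).

    import numpy as np

    def sgd(data, params, grad_fn, lr=0.1, batch_size=32, steps=1000, seed=0):
        """At each step, estimate the gradient of L_D on a random minibatch and
        move the weights by -lr (the learning rate) times that estimate."""
        rng = np.random.default_rng(seed)
        params = [p.copy() for p in params]
        for _ in range(steps):
            idx = rng.choice(len(data), size=batch_size, replace=False)
            grads = grad_fn([data[i] for i in idx], params)
            params = [p - lr * g for p, g in zip(params, grads)]
        return params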

3 Complexity of a Learning Task

A learning task is specified by a dataset D. However, the same task could be specified by different datasets. Their complexity is not just related to the size of the input, the output, or the number of samples: the popular MNIST and CIFAR-10 classification tasks are similar on these counts, yet one is very easy to learn and the other is not. Instead, the complexity of a dataset is a function of two factors: the underlying structure shared among data points, and the variability that individual points exhibit relative to the shared structure. This split is not unique: a dataset can have many different explanations, each leaving a different amount of residual variability. These two factors are captured in the following definition. Let D be a dataset. We define the complexity of D as

C(D) = min_q { K(q) + L_D(q) },    (1)

where L_D(q) = −Σ_{i=1}^N log q(y_i | x_i) is the empirical classification (cross-entropy) loss, and the minimum is over all possible computable probability distributions q(y|x) of the label y given the input x. By K(q) we denote the Kolmogorov complexity of the distribution q.
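Kolmogorov complexity is not computable, but the trade-off expressed by eq. 1 can be mimicked with a crude, purely heuristic sketch: here the compressed length of the serialized model parameters stands in for K(q). This is only an illustration, not the quantity actually used later in the paper.

    import gzip, pickle
    import numpy as np

    def empirical_loss(data, q):
        """L_D(q) = -sum_i log q(y_i | x_i), with q(x) returning a probability vector."""
        return -sum(np.log(q(x)[y]) for x, y in data)

    def two_part_score(data, q, q_params):
        """Heuristic stand-in for eq. 1: description length of the model plus residual
        loss (gzip of the pickled parameters is only a rough proxy for K(q), and the
        bits-versus-NATS bookkeeping is ignored)."""
        return len(gzip.compress(pickle.dumps(q_params))) + empirical_loss(data, q)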

At first sight, C(D) appears similar to a conditioned version of the two-part code in [13, Appendix II]:

min_q { K(q) − log q(y | x) },    (2)

where y and x are strings obtained by concatenating all labels and inputs of D, respectively. In Section 3.1, we recall that eq. 2 coincides with the conditional Kolmogorov complexity K(y|x) of the string y given x. Instead, in eq. 1 we only consider factorized distributions q(y|x) = Π_i q(y_i | x_i). This has major consequences for determining the complexity of a task: The distribution minimizing eq. 2 does not need to encode in K(q) all task-relevant information, as it can, at test time, extract any missing information from the training set (cf. Section 3.3). Hence, K(q) in eq. 2 alone would not be a valid measure of complexity. Instead, the distribution minimizing eq. 1 can only access a single datum at test time; hence, all the structure in the dataset has to be encoded in q.

Moreover note that, unlike eq. 2, C(D) is invariant to permutations of the dataset. Suppose that a dataset for a binary classification task is ordered so that all negative samples precede the positive samples. Then we would have K(y|x) = O(log N) regardless of the complexity of the task, as it suffices to encode the number of negative and positive samples to reproduce the string y exactly. However, we show in Section 3.3 that making K(y|x) permutation invariant does not yield a sensible measure of complexity of a task. To address permutation invariance, [8] proposes the following definition, which uses deterministic functions rather than probability distributions:

(3)

The following proposition compares these definitions of complexity.

[Measures of complexity] Up to an additive term which does not depend on the dataset D, we have:

  1. The minimum in eq. 2 coincides with the conditional Kolmogorov complexity K(y|x) of the label string given the input string.

  2. K(yπ | xπ) ≤ C(D) for any permutation π of D.

  3. For every , there is a dataset such that and for any permutation . Therefore, is not simply the (average) complexity of encoding some permutation of .

  4. When is defined, i.e., if there is a function such that for each , then we have . In addition, if we have an oracle that provides a bijective map .

Assuming the data are sampled i.i.d. from a computable probability distribution, the following proposition characterizes the complexity of the dataset. In particular, it shows that, asymptotically, the complexity of the dataset is given by the noise in the labels and the complexity of the distribution generating the data. How to ignore the effect of noise on the complexity, and what happens in the non-asymptotic regime, is central to Kolmogorov’s framework, which we will build upon in the next sections.

Fix a probability distribution p(x, y), and assume that p(y|x) is computable. If D is a collection of N i.i.d. samples (x_i, y_i) ∼ p(x, y), then:

  1. The expected value of C(D) satisfies

    N H_p(y|x) ≤ E_D[ C(D) ] ≤ N H_p(y|x) + K(p),

    where H_p(y|x) is the conditional entropy of y given x.

  2. For any there is such that, with probability , for any we have the equality

    and is the only computable distribution for which the equality holds.

It is instructive to test our definition of complexity on a dataset of random labels, whereby each input is assigned a label at random. We will revisit this case often, as it challenges many of the extant theories of Deep Learning [15].

[Complexity of random labels] Suppose that each input of the dataset is associated to a label sampled uniformly at random in a fixed finite set of K labels, so the true conditional distribution p(y|x) = 1/K has a constant complexity. Under the same assumptions of Section 3.4, the expected value of C(D) is N log K up to a constant, by Section 3. Since the uniform distribution incurs a loss of exactly N log K on any such dataset, the complexity of a “typical” dataset with random labels is approximately N log K.
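For a sense of scale, a quick back-of-the-envelope check in Python, with N = 50,000 inputs and K = 10 equiprobable labels (a CIFAR-10-sized example, chosen here purely for illustration):

    import math
    N, K = 50_000, 10
    print(N * math.log(K))   # ~ 115,129 NATS: nothing to do but memorize every label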

In a sense, learning random labels is a very complex task: They cannot be predicted from the input, so the model is forced to memorize them (“overfitting”). Accordingly, the complexity C(D) of Section 3 is very high. However, by construction there is no structure in the data, so the model cannot generalize to unseen data. In other words, this is a complex memorization task, but a trivial learning task, as there is nothing to learn from the data. We would like the definition of complexity to reflect this, and to differentiate between learning and memorization.

Another important aspect not captured by this definition of complexity is the role of performance in learning the structure in the data. For example, one can train a trivial model to distinguish airplanes from fireplaces by counting the number of blue pixels. It will not be very precise. To achieve a small error, however, one must learn what makes an airplane different from a fireplace even if the latter is painted blue.

3.1 Structure Function of a Task

The trade-off between the loss L_D(q) achievable by a solution q on a dataset D and its complexity K(q) is captured by the Structure Function [13]:

S_D(t) = min_{q : K(q) ≤ t} L_D(q).    (4)

It is a decreasing function that reaches zero for sufficiently high complexity, depending on the task: As we increase the allowed complexity, the loss at first decreases rapidly, while simple models correctly classify the easy samples. After all shared structure is captured, the only way to further reduce the loss is to memorize individual samples, leading to the worst possible trade-off of one NAT of complexity for each unit of decrease in loss. Eventually, every dataset enters this linear (overfitting) regime. For random labels, this happens at the outset.

[Structure Function for random labels] By definition of C(D) we have K(q) + L_D(q) ≥ C(D) for any computable probability distribution q. For a typical dataset with random labels (see Section 3), we have that C(D) ≈ N log K. Therefore S_D(t) ≥ N log K − t.

The lower bound can be achieved by memorizing the labels of t / log K data points. Therefore S_D(t) ≈ max(N log K − t, 0).

The Structure Function of a dataset cannot be computed in general ([13, Section VII]). In Section 5 we introduce a generalized version that can: Figure 2 shows the result on common datasets. The predicted fast decrease in the loss as complexity increases is followed by the asymptotic linear phase. The sharp phase transition to the linear regime is clearly visible as a function of the loss: As the parameter β weighting complexity increases, a plateau is reached that depends on the task. Note that for random labels the loss decreases linearly as expected (left).

3.2 Task Lagrangian and Minimal Sufficiency

The constrained optimization problem in the definition of the Structure Function (eq. 4) has an associated Lagrangian L_D(q) + β K(q), where β is a Lagrange multiplier that trades off the complexity K(q) of the model with the fidelity L_D(q). If we take the minimum over q, we obtain a family of complexity measures parametrized by β:

C_β(D) = min_q { L_D(q) + β K(q) }.    (5)

As a function of β, this is the Legendre transform of the Structure Function S_D(t). To minimize C_β(D), we can keep increasing the complexity of the model as long as each additional NAT of complexity reduces the loss by more than β, that is, until the marginal return drops below the constant β we have selected.
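Numerically, eq. 5 can be read as a Legendre-type transform of a tabulated Structure Function. A minimal sketch, where the grid and the toy random-label S_D(t) from the example above are illustrative assumptions:

    import numpy as np

    def lagrangian_family(t_grid, S_values, betas):
        """C_beta(D) = min_t [ S_D(t) + beta * t ], evaluated on a grid for each beta."""
        return [float(np.min(S_values + beta * t_grid)) for beta in betas]

    # Toy example: random-label structure function S_D(t) = max(N log K - t, 0).
    t = np.linspace(0, 2e5, 2001)
    S = np.maximum(50_000 * np.log(10) - t, 0.0)
    print(lagrangian_family(t, S, betas=[0.5, 1.0, 2.0]))
    # roughly [0.5, 1.0, 1.0] * N log K: for beta >= 1, memorization is never worthwhile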

If q is a minimizer of eq. 5 for β = 1, the corresponding Kolmogorov complexity K(q) is the value at which the Structure Function reaches the linear regime. Thus, the special case β = 1 marks the transition to overfitting, and is related to Kolmogorov’s notion of Minimal Sufficient Statistic [13]. Since we are using Kolmogorov’s complexity, β = 1 is the worst possible trade-off.

Given a task D, let β* be the largest β for which the minimum in eq. 5 is not realized by a constant distribution q(y|x) = q(y). Then β* ≥ 1, and β* = 1 if D is a typical dataset with random labels.

As we have seen for random labels, a dataset may be complex and yet exhibit little underlying structure. We say that a distribution q is a Kolmogorov sufficient statistic of D if it minimizes eq. 5 at β = 1, that is, if K(q) + L_D(q) = C_1(D). It is minimal if it also minimizes K(q) among all sufficient statistics, that is, if it is the smallest statistic that is able to solve the task optimally. The rationale is that the smallest statistic that solves a task should not squander resources by memorizing nuisance variability. Rather, it should only capture the important information shared among the data. This is shown in the following example.

For random labels, both the distribution q that memorizes all the labels in the dataset and the uniform distribution q(y|x) = 1/K are sufficient statistics. However, only the latter is minimal, since its complexity K(q) is a constant which does not depend on the size of the dataset. There is no structure to be extracted from a dataset of random labels.

The level of complexity of a model is an important design parameter the practitioner wishes to control. Rather than seeking minimal sufficient statistics, we explore the entire trade-off space by introducing the notion of β-sufficiency.

Given a dataset D, define a β-sufficient statistic of D as a probability distribution q such that L_D(q) + β K(q) = C_β(D). We say that q is a β-minimal sufficient statistic if it also minimizes K(q) among all β-sufficient statistics.

Notice that, for β = 1, Section 3.2 reduces to minimal sufficiency in the sense of Kolmogorov.

4 Asymmetric Distance between Tasks

We now introduce a one-parameter family of distances between tasks. The hope is for them to correlate with the ease of transfer learning. Since it is easier to learn a simple task from a complex one, it is desirable for the distance to be asymmetric.

The asymmetric distance between tasks D1 and D2 at level β is

d_β(D1 → D2) = max_{q1} min_{q2} K(q2 | q1),

where q1 varies among all β-minimal sufficient statistics of D1 and q2 varies among all β-minimal sufficient statistics of D2.

The intuition behind this definition is the following: for a task D2 to be close to D1, every β-minimal sufficient statistic of D1 should be close to some β-minimal sufficient statistic of D2. Then, every optimal model of D1 can be fine-tuned into some optimal model of D2.

The asymmetric distance satisfies the following properties (up to additive constants independent of the tasks):

  •  d_β(D1 → D2) ≥ 0 (positivity);

  •  d_β(D → D) = 0 (a task is close to itself);

  •  d_β(D1 → D3) ≤ d_β(D1 → D2) + d_β(D2 → D3) (triangle inequality).

We now derive a characterization of the distance between tasks based on the complexity of their composition, amenable to generalization in Section 6. Denote by D1 ⊔ D2 the disjoint union of two datasets D1 and D2, defined as

D1 ⊔ D2 = { ((x, 1), y) : (x, y) ∈ D1 } ∪ { ((x, 2), y) : (x, y) ∈ D2 }.

Notice that an index is added to the input, in order to recognize the original dataset. A desirable property for a distance between tasks would be that d_β(D1 ⊔ D2 → D1) ≈ 0: Indeed, a model that performs well on D1 ⊔ D2 should be easily fine-tuned to a model that performs well on D1 alone. Adding the index in the definition of D1 ⊔ D2 is essential, as we can see in the following example.

Let D be a typical dataset with random labels, and let D' ⊆ D be the set of data points satisfying some property of high Kolmogorov complexity k. Then the whole dataset D has a trivial structure, whereas D' has a complicated structure. If both D and D' are large, by Section 3 the Kolmogorov complexity of a minimal sufficient statistic of D is a small constant, whereas the complexity of a minimal sufficient statistic of D' is comparable to k.

We now prove that, under the hypotheses of Section 3.2, the property d_β(D1 ⊔ D2 → D1) ≈ 0 is satisfied.

Suppose that D1 and D2 are obtained by sampling from two fixed distributions p1(x, y) and p2(x, y), such that p1(y|x) and p2(y|x) are computable. Then, with high probability and for N1 and N2 sufficiently large, d_β(D1 ⊔ D2 → D1) ≈ 0.

Under the hypotheses of Section 4, we have that

d_β(D1 → D2) = max_{q1} min_{q12} K(q12 | q1),

where q12 varies among the β-minimal sufficient statistics of D1 ⊔ D2 and q1 varies among those of D1.
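As a concrete sketch of the composition used above (the helper names are hypothetical, and the complexity callable stands in for the uncomputable C_β, e.g. the DNN-based estimate introduced in Section 5; the difference below is only a heuristic proxy in the spirit of the composition property, not the max-min definition itself):

    def disjoint_union(d1, d2):
        """D1 ⊔ D2: tag each input with the index of the dataset it came from."""
        return [((x, 1), y) for x, y in d1] + [((x, 2), y) for x, y in d2]

    def distance_proxy(d1, d2, complexity):
        """Extra complexity that D2 adds on top of a solution for D1."""
        return complexity(disjoint_union(d1, d2)) - complexity(d1)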

We now have a way of comparing different learning tasks, at least in theory. The asymmetric distance allows us to quantify how difficult it is to learn a new task D2 given a solution to D1. However, quantities defined in terms of Kolmogorov complexity are difficult to handle in practice, and may behave well only in an asymptotic regime. In the next section, we introduce a generalization of the framework developed so far that can be instantiated for a particular model class, such as deep networks.

5 Information in the Model Parameters

Whatever structure or “information” was captured from the dataset, it ought to be measurable from the model parameters, since they are all we have left after training. As we will see in the next section, this intuition is faulty, as how we converge to a set of parameters (i.e., the optimization algorithm) also affects what information we can extract from the data. For now, we focus on generalizing the theory in the previous section with an eye towards computability. Although most of our arguments are general, we focus on deep neural networks (DNNs) as a model class. They can have millions of parameters, so measuring their information can be non-trivial.

One way to compute the information in the parameters is to measure their coding length at some level of precision, independent of the particular task. This is suboptimal, as only a small subset of the weights of a trained neural network matters: Imagine changing a certain component of the weights and observing no change in the loss. Arguably, that weight “contains no information” about the dataset. For the purpose of storing the trained model, that weight could be replaced with any constant, or randomized each time the network is used. The loss landscape has small curvature in the coordinate direction corresponding to that weight (we will elaborate on this point after Section 5.1). On the other hand, imagine changing the least significant bit of another component of the weights and noticing a large increase in the loss. That weight is very “informative,” so it is useful to store its value with high precision.

With these observations in mind, we allow the weights to be encoded with some uncertainty, through a probability distribution Q(w|D) which depends on the dataset D. For example, a Dirac delta Q(w|D) = δ(w − w*) corresponds to an exact encoding of the weight vector w*. If we fix a reference “prior” distribution P(w), [7] shows that the labels y can be reconstructed from the inputs x and the prior P by using E_{w∼Q}[ L_D(p_w) ] + KL(Q ‖ P) additional NATS. This expression resembles the right-hand side of eq. 1 in capturing a trade-off between fidelity and complexity. Here, Kolmogorov complexity has been replaced by the Kullback-Leibler (KL) divergence KL(Q ‖ P), which we call the information in the parameters of the model. (This should not be confused with the information in x and y, or in any intermediate representation built by the model, such as the activations of a DNN, which is more frequently studied. Information in the parameters and information in the activations are different and, in the case of DNNs, are related through the Emergence Bound [4].) This leads to the following new definition of complexity.

The complexity of the task D at level β, using the posterior Q and the prior P, is given by

C_β(D; P, Q) = E_{w∼Q}[ L_D(p_w) ] + β KL( Q ‖ P ).    (6)

The second term, KL(Q ‖ P), measures the information in the parameters of the model. We refer to the first term, E_{w∼Q}[ L_D(p_w) ], as the (expected) reconstruction error of the label under the hypothesis Q.

We call Q a “posterior” as it is a distribution decided after seeing the dataset D. There is no implied Bayesian interpretation, as Q can be any distribution. Similarly, P is a “prior” because it is picked before the dataset is seen. Depending on the choice, this expression can be computed in closed form or estimated (Section 5.1). For instance, when Q is concentrated on a single (discretized) weight vector, the expression reduces to the length of a two-part code for the labels using the model class. However, Section 5 is more general and can be extended to the continuous case, or to cases where there is a bona fide posterior distribution, as in variational inference and Bayesian Neural Networks.
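As a concrete sketch of how eq. 6 might be estimated in the Gaussian case of Section 5.1 (the helper names, the diagonal covariance, and the Monte Carlo loss estimate are assumptions of this sketch, not the authors' implementation):

    import numpy as np

    def gaussian_kl_diag(mu_q, var_q, var_p):
        """KL( N(mu_q, diag(var_q)) || N(0, var_p * I) ), in NATS."""
        return 0.5 * np.sum(var_q / var_p + mu_q**2 / var_p - 1.0 + np.log(var_p / var_q))

    def complexity_estimate(data, loss_fn, mu_q, var_q, var_p, beta, n_samples=10, seed=0):
        """C_beta(D; P, Q) ~ E_{w~Q}[ L_D(w) ] + beta * KL(Q || P), with the expected
        loss estimated by sampling weights from the posterior Q."""
        rng = np.random.default_rng(seed)
        losses = [loss_fn(data, mu_q + np.sqrt(var_q) * rng.standard_normal(mu_q.shape))
                  for _ in range(n_samples)]
        return float(np.mean(losses)) + beta * gaussian_kl_diag(mu_q, var_q, var_p)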

Another fundamental difference is that, while eq. 1 measures the complexity in terms of the best trade-off obtainable by the model class (in that case, the class of all computable probability distributions), the complexity C_β(D; P, Q) takes into account both the particular model class and the training algorithm, i.e., the map D ↦ Q(w|D), as we shall see in Section 6.2.

5.1 Relation with Kolmogorov, Shannon, and Fisher Information

Since the choice of prior in Section 5 is arbitrary, we investigate three special cases: The “universal prior” of all computable distributions; an “adapted prior” which relies on a probability distribution over datasets; an uninformative prior, agnostic of the dataset.

We start with the first case, which provides a link between Section 5 and the framework of Section 3. For a given weight vector w, we define the universal prior P(w) = 2^{−K(w)} / Z, where Z is a normalization constant. This can be interpreted as follows: for every w, choose a minimal program that outputs w, and assign it a probability which decreases exponentially with the length of the program.

[Kolmogorov Complexity of the Weights] Let P be the universal prior, and let Q = δ_{w*} be a Dirac delta. Then the information in the weights KL(Q ‖ P) equals the Kolmogorov complexity K(w*) of the weights, up to a constant.

We now turn to the second case, which provides a link with Shannon mutual information.

[Shannon Information in the Weights] Assume the dataset D is sampled from a distribution p(D), and let the outcome of training on a sampled dataset be described by a distribution Q(w|D). Then the prior minimizing the expected complexity is the marginal P(w) = E_D[ Q(w|D) ], and the expected information in the weights is given by

E_D[ KL( Q(w|D) ‖ P(w) ) ] = I(w; D).    (7)

Here I(w; D) is Shannon’s mutual information between the weights and the dataset, where the weights are seen as a (stochastic) function of the dataset given by the training algorithm (e.g., SGD).

Note that, in this case, the prior is optimal given the choice of the training algorithm (i.e., the map D ↦ Q(w|D)) and the distribution p(D) of training datasets. However, the distribution p(D) is generally unknown, as we are often given a single dataset to train on. Even if it were known, computing the marginal distribution P(w) = E_D[ Q(w|D) ] over all possible datasets would not be realistic, as it is high-dimensional and has complex interactions between different components. Nevertheless, it is interesting that the information in the parameters specializes to Shannon’s mutual information in the weights [4].

The third case, namely an uninformative prior with a Gaussian posterior, is the most practical, and provides a link to the Fisher Information Matrix and the learning dynamics of common optimization algorithms such as SGD.

[Fisher Information in the Weights] Choose an isotropic Gaussian prior P = N(0, λ² I). Let the posterior be also Gaussian: Q = N(w*, Σ), where w* is a local minimum of the cross-entropy loss. Then, for β → 0, we have that:

  • the covariance Σ which minimizes C_β(D; P, Q) tends to β F(w*)^{−1}, proportional to the inverse of the Fisher Information Matrix F(w*) at the minimum (this is in accordance with the Cramér-Rao bound);

  • the information in the weights is given, to leading order, by KL(Q ‖ P) ≈ (1/2) log det( I + (λ²/β) F(w*) ).

Recalling that the Fisher Information measures the local curvature, this proposition confirms the qualitative discussion of the beginning of Section 5: The optimal covariance gives high variability to the directions of low curvature, which are “less informative,” whereas it gives low variability to the “more informative” directions of high curvature. The Fisher Information describes the information contained in the weights about the dataset. In Section 6 we discuss how to get there.
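Empirically, the diagonal of the Fisher Information Matrix that appears above can be estimated by averaging squared gradients of the log-likelihood, with labels drawn from the model's own predictive distribution. A minimal sketch, where grad_logp and sample_label are assumed helpers (e.g., implemented with automatic differentiation):

    import numpy as np

    def empirical_fisher_diag(inputs, w, grad_logp, sample_label, n_rounds=1, seed=0):
        """Diagonal FIM estimate: F_kk ~ E_x E_{y~p_w(y|x)} [ (d/dw_k log p_w(y|x))^2 ]."""
        rng = np.random.default_rng(seed)
        fisher = np.zeros_like(w)
        count = 0
        for _ in range(n_rounds):
            for x in inputs:
                y = sample_label(x, w, rng)     # y ~ p_w(y | x), sampled from the model
                g = grad_logp(x, y, w)          # gradient of log p_w(y | x) w.r.t. w
                fisher += g**2
                count += 1
        return fisher / count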

5.2 Connections with the PAC-Bayes Bound

The Lagrangian C_β(D; P, Q) admits another interpretation: it upper-bounds the test error, as shown by the PAC-Bayes test error bound:

[[11, Theorem 2]] Assume the dataset D is sampled i.i.d. from a distribution p(x, y), and assume that the per-sample loss used to train is bounded by 1 (we can reduce to this case by clipping and rescaling the loss). For any fixed β, prior P, and weight distribution Q, with probability at least 1 − δ over the sampling of D, we have:

(8)
(9)

where L_test(Q) is the expected per-sample test error that the model incurs using the weight distribution Q. Hence, we see that minimizing the Lagrangian C_β(D; P, Q) can be interpreted as minimizing an upper bound on the test error of the model, rather than directly minimizing the training error. This is in accordance with the intuition developed earlier, that minimizing C_β forces the model to capture the structure of the data. It is also interesting to consider the following bound on the expectation over the sampling of D ([11, Theorem 4]):

As we have seen in Section 5.1, for the optimal choice of prior minimizing the bound, we have E_D[ KL(Q ‖ P) ] = I(w; D). Hence, the Shannon Information that the weights of the model have about the dataset is the measure of complexity that gives (in expectation) the strongest generalization bound. This has also been noted in [4]. In [6], a non-vacuous generalization bound is computed for DNNs, using (non-centered and non-isotropic) Gaussian prior and posterior distributions.
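For intuition about the scale of such bounds, the following sketch evaluates one common relaxation of the PAC-Bayes bound (McAllester's form, for a loss clipped to [0, 1]); the exact constants in [11, Theorem 2] may differ, and the numbers below are purely illustrative:

    import math

    def pac_bayes_bound(train_loss, kl, n, delta=0.05):
        """One common relaxation for a [0,1]-valued loss:
        L_test(Q) <= L_train(Q) + sqrt( (KL(Q||P) + ln(2 sqrt(n)/delta)) / (2 n) )."""
        return train_loss + math.sqrt((kl + math.log(2 * math.sqrt(n) / delta)) / (2 * n))

    # Illustrative numbers: 8% training error, 2500 NATS of information in the weights.
    print(pac_bayes_bound(train_loss=0.08, kl=2500.0, n=50_000, delta=0.05))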

6 Generalized distance, reachability, and learnability of tasks

Unlike eq. 1, the definition of complexity in eq. 6 properly captures the complexity of a dataset for a particular model class and training algorithm. Motivated by this, we now define a distance between datasets which is tailored to the model.

Throughout this section, fix a parametrized model class {p_w}, a reference prior P, and a class of posterior distributions. The case we are most interested in is that of DNNs, with an uninformative prior and a Gaussian posterior (see Section 5.1). Our starting point is to generalize Kolmogorov’s Structure Function framework. Consider the following generalized Structure Function:

S_D(t) = min_{Q : KL(Q ‖ P) ≤ t} E_{w∼Q}[ L_D(p_w) ].

Here the minimum is taken among all posterior distributions Q in the chosen class. Similarly to what we have seen in Section 3.2, this minimization problem has C_β(D; P, Q) as its associated Lagrangian. We say that Q is a β-sufficient statistic if it minimizes C_β(D; P, Q). It is a β-minimal sufficient statistic if it also minimizes KL(Q ‖ P). Motivated by Section 4, we then introduce the following distance.

The asymmetric distance between tasks D1 and D2 at level β is

where Q1 is a β-minimal sufficient statistic for D1, and Q2 is a β-minimal sufficient statistic for D2.

While more general and amenable to computation than the distance of Section 4, this is not an actual distance, as it lacks several of the properties in Section 4. Nonetheless, we will show that it does indeed capture the difficulty of fine-tuning from one dataset to another using specific model families (such as DNNs) and training algorithms (SGD), and that it empirically shows good correlation with other distances (e.g., taxonomical), when those are defined.

6.1 Reachability of a Task through a Local Learning Algorithm

Until now, we have only considered global minimization problems, where we aim to find the best solution satisfying some optimal trade-off. However, many learning algorithms (e.g., SGD) are local: Starting from the current solution, they take a small step to greedily minimize some objective function. While this greatly reduces the complexity of the algorithm, it raises the question of which conditions allow such an algorithm to recover an optimal solution to the task.

Given a distribution Q, denote by L_D(Q) the expected loss E_{w∼Q}[ L_D(p_w) ]. Fix a metric d on the class of posterior distributions such that both L_D(Q) and KL(Q ‖ P) are continuous as functions of Q. In this way, the Lagrangian C_β(D; P, Q) is continuous in the joint variable (Q, β). In the case of DNNs, where P is the uninformative prior and the posterior class is the class of Gaussian distributions N(w, Σ), we can take for example d to be the Wasserstein distance, or the Euclidean distance between the parameters (w, Σ).

[ε-local learning algorithm] Fix ε > 0. We say that a step is ε-local if, starting from a given statistic Q, it finds the statistic Q' that minimizes C_β(D; P, Q') among all Q' such that d(Q, Q') ≤ ε. We say that a learning algorithm is ε-local if it only takes ε-local steps.

In the limit ε → 0, this reduces to gradient descent on the Lagrangian C_β(D; P, Q). Notice however that this is not the same as performing gradient descent on the cross-entropy loss L_D(Q), unless β = 0. Indeed, minimizing L_D(Q), unlike minimizing C_β (Section 5.2), gives no guarantees on the performance on test data, as the learning algorithm could simply memorize the dataset D. We will show in the next section that a DNN trained with SGD can actually be interpreted as a local learning algorithm minimizing C_β. A natural question is: Does an ε-local learning algorithm always recover a global minimum of C_β?

In practice, when training a DNN with SGD, one starts with a high learning rate and anneals it throughout the training. In our framework, this corresponds to starting with a high value of β and gradually decreasing it to a final value. This helps avoid degenerate solutions, because the model starts by favouring structure over memorization.

An ε-local learning algorithm with annealing is a learning algorithm that alternates ε-local steps (which change the distribution Q) and annealing steps (which can decrease the value of β).

Notice that, if the annealing is slow, then an ε-local learning algorithm with annealing can be regarded as a discrete gradient descent with respect to the joint variable (Q, β). The following result gives a sufficient condition for an ε-local learning algorithm with annealing to recover the global minimum of the final Lagrangian.

Fix an annealing schedule β_0 > β_1 > ... > β_T. Suppose that, for every global minimizer Q of C_{β_i}, there exists a global minimizer Q' of C_{β_{i+1}} with d(Q, Q') ≤ ε. Then, an ε-local learning algorithm with annealing that starts from a global minimizer of C_{β_0}, and performs one ε-local step after each annealing step, computes a global minimizer of C_{β_T}.

We say that a task that satisfies the conditions of Section 6.1 for some annealing schedule is ε-connected. In general, we cannot exclude that an ε-local learning algorithm gets stuck in a local minimum of C_β. For example, this is bound to happen if there is no sequence Q_0, Q_1, ..., Q_T such that Q_i is a global minimum of C_{β_i} and d(Q_i, Q_{i+1}) ≤ ε for all i. In particular, this happens if there is no continuous path of global minima, which is an intrinsic property of the dataset with respect to the function class.
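In the Gaussian case, an ε-local algorithm with annealing can be sketched as gradient descent on the Lagrangian with a decreasing β. The gradient callables and the schedule below are assumptions of this sketch; the step size plays the role of ε:

    import numpy as np

    def local_descent_with_annealing(w0, grad_loss, grad_info, betas,
                                     steps_per_beta=100, lr=1e-2):
        """Alternate annealing steps (move to the next, smaller beta) with small
        local steps that greedily reduce L_D(w) + beta * info(w)."""
        w = np.asarray(w0, dtype=float).copy()
        for beta in betas:                    # annealing schedule beta_0 > beta_1 > ...
            for _ in range(steps_per_beta):
                w = w - lr * (grad_loss(w) + beta * grad_info(w))
        return w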

6.2 SGD as a local learning algorithm

So far, we have introduced an abstract notion of distance between tasks. However, we have not yet shown that this notion is useful for DNNs, or that indeed SGD is a local algorithm in the sense of Section 6.1.

In [2], a step in this direction is taken. It is shown that, to first approximation, the probability of SGD converging in a given time T to a configuration w that solves task D, starting from a configuration w0, is given by:

(10)

when using an isotropic Gaussian prior and the optimal Gaussian posterior of Section 5.1. Here λ is related to the weight decay coefficient used to train the network, and the temperature is a parameter that depends on the learning rate and the batch size.

From eq. 10 we see that, with high probability, SGD takes steps that minimize the effective potential (static part), while trying to minimize the distance traveled in the given time (dynamic part). Hence, SGD can be seen as a stochastic implementation of a local learning algorithm in the sense of Section 6.1.

In particular, this has the non-intuitive implication that SGD does not greedily optimize the loss function, as it may seem from its update equation. Rather, on average, the complexity of the recovered solution affects the dynamics of the optimization process, and changes the objective. Therefore, the complexity we have introduced in Section 5 is not simply an abstract means to study different trade-offs, but rather plays a concrete role in the learning dynamics.

Since SGD can be seen as a local learning algorithm, annealing the learning rate during training (i.e., annealing the parameter β) can be interpreted as a way of learning the structure of a task by slowly sweeping the Structure Function. Hence, SGD with annealing adaptively changes the complexity of the model, even if its dimensionality is fixed at the outset. This creates non-trivial dynamics that turn out to be beneficial to avoid memorization. It also points to the importance of the initial transient of learning during the annealing, a prime area for investigation beyond the asymptotics of the particular value of the weights, or of the local structure of the residual around them, at convergence.

Another consequence is that, when the task is not ε-connected, SGD may fail to recover the optimal structure of the data. One may wonder whether there are examples of simple tasks which are severely non-ε-connected, and whether SGD actually fails to solve them.

A particularly interesting example comes from a biological inspiration. [3] reports an example of two tasks (a dataset of blurred images and one of the same images at high resolution) which are apparently close to each other, but such that a network trained on the first one cannot properly be fine-tuned to solve the second one. This peculiar phenomenon has analogues in biology, and it is shown in [3] to be closely correlated with changes of the Fisher Information Matrix. Within our framework, this can be interpreted as biasing the initial optimization process toward a minimum of the Lagrangian for the blurred task which, while also being a local minimum of the Lagrangian of the original task, is not close to any of its global minima. Hence, the local learning algorithm, rather than learning the correct structure, starts performing sub-optimal greedy choices.

Even though the static term of eq. 10 is high (i.e., the distance is small), this example indicates that the dynamic term can give a non-trivial contribution. This phenomenon is observed across different architectures and local optimization schemes. It opens the door to further investigations into the dynamics of differential learning.

7 Empirical validation

Figure 1: (Left) Estimated distance matrix between several tasks. Each entry shows the distance going from the column task to the row task. Going from a complex task like CIFAR-100 to a simpler task (like MNIST) is always easier than the converse. Subtasks are close to the full tasks (e.g., the subsets of “artificial” and “natural” objects of CIFAR-100 are both close to CIFAR-100). Similar tasks on the domain of small black and white images (Fashion MNIST, MNIST, Letters) are also closer to one another than to natural images. Inverting the colors of the Fashion images leads to a very similar task (I-Fashion), as expected. (Right) t-SNE embedding of several organism-species classification tasks and clothing-attribute classification tasks based on their distance, reproduced from [1], which uses a similar definition of task distance based on the Fisher Information. Intuitively, similar tasks cluster together. In the case of species classification, this largely follows the taxonomical structure.

The theoretical framework developed in this paper has tangible ramifications. A robust version of the asymmetric distance between tasks described here has been used in [1] to create a metric embedding of hundreds of real-world datasets (Figure 1, right). The structure of the embedding shows good accordance with the complexity of the tasks (Figure 2), with intuitive notions of similarity, and with other metrics, such as taxonomical distance, when those are available (Figure 1). The metric embedding allows tackling important meta-tasks, such as predicting the performance of a pre-trained model on a new dataset, given its performance on past similar datasets. This greatly simplifies the process of recommending a pre-trained model as the starting point for fine-tuning on a new task, which otherwise would require running a large experiment to determine the best model. The results are also shown to improve compared to pre-training a model on the most complex available task, i.e., full ImageNet classification.

In [3], it is shown that there are tasks close to each other (with respect to our asymmetric distance) between which models nonetheless cannot easily be fine-tuned. Indeed, there are two nearly identical tasks (classifying CIFAR images, or a slightly blurred version of them) which turn out to be unreachable from one another: Once pre-trained on the blurred data, no matter how much additional training occurs, performance on the non-blurred data remains sub-optimal.

Figure 2: (Left) Approximation of the Structure Function. Notice that the cross-entropy loss on simpler datasets rapidly decreases as we increase the capacity of the model. On the other hand, for random labels, the decrease in error is much slower and follows an almost linear trend, as predicted by Section 3.1. (Right) Plot of the loss obtained when minimizing the Lagrangian C_β(D; P, Q), as a function of β (which controls the trade-off between complexity and error). All datasets show a sharp transition between overfitting and underfitting at some value of β. Simpler datasets can still be fit for a high β, since models of low complexity can already correctly classify the data. On the other hand, more complex datasets need a very low β to fit the data: In particular, random labels have the worst trade-off. The position of the phase transition depends on the complexity of the data distribution, and is mostly independent of the dataset size: the random-label datasets all transition at a similar point, despite the difference in size.

Besides these independent validations, we also report some additional experiments that illustrate the concepts introduced in this paper. In Figure 2 (right), we plot the loss as a function of β. As predicted by Section 3.2, the Lagrangian exhibits a sharp phase transition at a critical value of β, when the model transitions from fitting the dataset with a trivial uniform distribution to actually fitting the labels. Regardless of their size, datasets of random labels transition around the same value of β. The other datasets transition at a higher value, which depends on their complexity. For example, a complex dataset such as CIFAR-100 transitions at a much lower β than a simple dataset such as MNIST. Notice that, in this experiment, the critical value for random labels is not β = 1: This is because the complexity is computed using an uninformative prior (see Section 5.1), and not the universal prior as in Section 3.2.

8 Discussion and Open Problems

The modern practice of Deep Learning is rich with success stories where high-dimensional parametric models are trained on large datasets, and then adapted (fine-tuned) to specific tasks. This process requires considerable effort and expertise, and few tools are available to predict whether a given pre-trained model will succeed when fine-tuned for a different task. We have started developing a language to reason about transfer learning in the abstract, and analytical tools to predict the success of transfer learning.

The first step is to properly define the tasks and the space they live in. We take a minimalistic approach, and identify a task with a finite dataset. The second step is to endow the space of tasks with a metric. This is non-trivial, since different datasets can have a different cardinality and dimension, and one needs to capture the fact that a simple task is usually “closer” to a complex one than vice-versa (in the sense that it is easier to fine-tune a model from a complex task to a simpler one). Thus, a notion of complexity of a learning task needs to be defined first.

We introduce a notion of complexity and a notion of distance for learning tasks that are very general, and encompass well-known analogues in Kolmogorov complexity theory and Information Theory as special cases. They combine characteristics of Kolmogorov’s framework (focusing on finite datasets rather than asymptotics) with the intuitive notions and ease of computation of Information Theory. We use deep neural networks to compute information quantities. On one hand, this provides a convenient way of instantiating our general theory. On the other hand, it allows us to measure the complexity of a Deep Network, and to reason about generalization (learning vs. memorization) in a non-asymptotic language, in terms of quantities that are measurable from finite datasets.

Our theory exposes interesting connections among Deep Learning, Complexity Theory, Information Theory, and PAC-Bayes theory. These connections represent a fertile ground for further theoretical investigation.

Much work is yet to be done. The (static) distance we introduce only gives a lower bound to the feasibility of transfer learning with deep neural networks, and dynamics plays an important role as well. In particular, there are tasks that are very close, yet there is no likely path between them, so fine-tuning typically fails. This is observed broadly across different architectures and optimization schemes, but also across different species in biology, pointing to fundamental complexity and information phenomena yet to be fully unraveled.

Figure 3: Correlation between test error and the trace of the Fisher Information Matrix on several tasks (reproduced from [1]). The Fisher Information Matrix emerges as a complexity measure when using an uninformative prior (Section 5.1). This plot shows that the FIM trace correlates with the error obtained on the task, and so the FIM is indeed a sensible measure of the complexity of a task.

Acknowledgment

Work supported by ONR N00014-19-1-2229 and ARO W911NF-17-1-0304.

References

Appendix A Proofs

Proof of Section 3.

(1) Let q be a distribution attaining the minimum in eq. 2. We can compress y using q, for example with a code of length −log q(y|x). A program that outputs y given x then only needs to encode the distribution q and the code for y, requiring K(q) − log q(y|x) NATS, so K(y|x) is at most the minimum in eq. 2, up to a constant. For the opposite inequality, let p be the program that witnesses K(y|x), and let q be the distribution concentrated on its output. Then K(q) ≤ K(y|x) and −log q(y|x) = 0, up to constants.

(2) Clearly, the minimum in eq. 2 is at most C(D), as eq. 1 minimizes over a smaller subset of distributions. Since C(D) is permutation invariant, we have K(yπ | xπ) ≤ C(Dπ) = C(D). Hence, K(yπ | xπ) ≤ C(D) for any permutation π.

(3) Fix a function such that , and let be a program for . Now, consider the dataset , where

Here, the first bit is added in order to recognize the “special” data point. We have for any permutation , because the concatenation of the input data of contains an encoding of . On the other hand, , where is the dataset obtained from by removing the special data point, and if is sufficiently large.

(4) Let be a function such that for every . Consider the probability distribution defined by for every . Then we have that , and . Choosing such that , we get .

To prove the equality, let be the bijective function provided by the oracle. Now, create a list of codes so that contains the encoding of constructed using the distribution . The length of the prefix code is (we need a prefix code so that we can concatenate all the codes). Now, given the distribution , we can construct a function such that as follows: Given , compute , read the code , and decode it using the distribution to obtain the correct . ∎

Lemma .

Fix a probability distribution on , and assume that is computable. Suppose that is a collection of i.i.d. samples . Let , and let be the maximum likelihood estimation (MLE) of . Then, for every there exists such that, in the limit we have

with probability .

Proof.

Consider and as vectors of size . Expand at :

Where is on the line connecting and . The first term is zero, since the MLE estimation is by definition a minimum of . Now, recall that the MLE converges to the real distribution as , where is the Fisher Information Matrix computed in , and . Let be such that, with probability , . Then,

Proof of Section 3.

(i) By Shannon’s coding theorem, the expected value of K(y|x) is at least N H_p(y|x). Then, the first inequality follows from part (1) of Section 3. If we use p itself in the definition of C(D), we obtain C(D) ≤ K(p) + L_D(p). The expected value of L_D(p) is N H_p(y|x), so we obtain the second inequality.

(ii) We need to prove that, for any other distribution , we have:

In fact, suppose that . Then we have

Notice that we can lower-bound the LHS using the MLE estimator , which by definition minimizes . By Appendix A, we have , hence

. By the central limit theorem, the LHS grows as