A Max-Sum algorithm for training discrete neural networks

05/20/2015 ∙ by Carlo Baldassi, et al.

We present an efficient learning algorithm for the problem of training neural networks with discrete synapses, a well-known hard (NP-complete) discrete optimization problem. The algorithm is a variant of the so-called Max-Sum (MS) algorithm. In particular, we show how, for bounded integer weights with q distinct states and independent concave a priori distribution (e.g. l_1 regularization), the algorithm's time complexity can be made to scale as O(N log N) per node update, thus putting it on par with alternative schemes, such as Belief Propagation (BP), without resorting to approximations. Two special cases are of particular interest: binary synapses W∈{-1,1} and ternary synapses W∈{-1,0,1} with l_0 regularization. The algorithm we present performs as well as BP on binary perceptron learning problems, and may be better suited to address the problem on fully-connected two-layer networks, since inherent symmetries in two-layer networks are naturally broken using the MS approach.




I Introduction

The problem of training an artificial, feed-forward neural network in a supervised way is a well-known optimization problem, with many applications in machine learning, inference etc. In general terms, the problem consists in obtaining an assignment of “synaptic weights” (i.e. the parameters of the model) such that the device realizes a transfer function which achieves the smallest possible error rate when tested on a given dataset of input-output examples. Time is usually assumed to be discretized. In a single-layer network, the transfer function is typically some non-linear function (e.g. a sigmoid or a step function) of the scalar product between a vector of inputs and the vector of synaptic weights. In multi-layer networks, many single-layer units operate in parallel on the same inputs, and their outputs provide the input to other similar (with a varying degree of similarity) units, until the last layer is reached.

The most popular and successful approaches to this kind of optimization problem are typically variants of the gradient descent algorithm, and in particular the back-propagation algorithm rumelhart1986learning . On single-layer networks with simple non-linearities in their output functions, these algorithms can even be shown to achieve optimal results in linear time engel_statistical_2001 ; on multi-layer networks, they suffer from the usual drawbacks of gradient descent (mostly the presence of local minima, and slow convergence under some circumstances).

On the other hand, gradient descent can only be applied to continuous problems. If the synaptic weights are restricted to take only discrete values, the above-mentioned family of methods cannot be applied; in fact, it is known that even the simplest version of the problem (classification using a single-layer network) becomes computationally hard (NP-complete) in the worst-case scenario amaldi1994review ; blum1992training . However, some theoretical properties of the networks, such as the storage capacity (i.e. the amount of information which can be effectively stored in the device by setting the synaptic weights), are only slightly worse in the case of discrete synapses, and other properties (e.g. robustness to noise and simplicity) would make them an attractive model for practical applications. Indeed, some experimental results petersen1998all ; o2005graded ; bartol2015 , as well as arguments from theoretical studies and computer simulations bhalla1999emergent ; bialek2001stability ; hayer2005molecular ; miller2005stability , suggest that long-term information storage may be achieved by using discrete, rather than continuous, synaptic states in biological neural networks.

Therefore, the study of neural network models with discrete weights is interesting both as a hard combinatorial optimization problem and for its potential applications, in practical implementations as well as for modeling biological networks. On the theoretical side, some light has been shed on the origin of the computational hardness of this kind of problem by studying the space of solutions with methods derived from Statistical Physics krauth_storage_1989 ; huang_origin_2014 . In brief, most solutions are isolated, i.e. far from each other, and the energy landscape is riddled with local minima. These minima tend to trap purely local search methods, which thus show very poor performance. On the application side, a family of heuristic algorithms derived from the cavity method has been devised, which exhibit very good performance on random instances, both in terms of solution time and in terms of scaling with the size of the problem.

In particular, it was first shown in braunstein_learning_2006 that a version of the Belief Propagation (BP) algorithm mezard_information_2009 with an added reinforcement term could efficiently solve a specific classification problem. The problem is to correctly classify $\alpha N$ random input-output associations using a single-layer network, or a tree-like two-layer network, with $N$ synapses. The algorithm succeeded up to a value of $\alpha$ close to the theoretical upper bound. For the single-layer case, the theoretical bound for binary synapses is $\alpha_c \simeq 0.83$ krauth_storage_1989 , while the algorithmic bound estimated from extensive simulations at large $N$ is somewhat lower. Two more algorithms, obtained as crudely simplified versions of the reinforced BP, were later shown baldassi_efficient_2007 ; baldassi_generalization_2009 to achieve very similar performance, despite being simpler and working in an on-line fashion. The time complexity of all these algorithms per pattern was measured numerically in the cited works. In the BP case, an update can be computed in a time linear in $N$ thanks to a Gaussian approximation which is valid at large $N$.

When considering multi-layer networks, the original BP approach of braunstein_learning_2006 can only effectively deal with tree-like network structures. Fully-connected structures (such as those commonly used in machine learning tasks) cannot be addressed with this approach, at least not straightforwardly. The obstacle is strong correlations arising from a permutation symmetry which emerges in the second layer.

In this paper, we present a new algorithm for addressing the problem of supervised training of networks with binary synapses. The algorithm is a variant of the so-called Max-Sum (MS) algorithm mezard_information_2009 with an additional reinforcement term (analogous to the reinforcement term used in braunstein_learning_2006 ). The MS algorithm is a particular zero-temperature limit of the BP algorithm, but it should be noted that this limit can be taken in different ways. In particular, the BP approach in braunstein_learning_2006 was applied directly at zero temperature, as patterns had to be learned with no errors. The MS approach we present here keeps the hard constraints imposing that no errors are made on the training set. In addition, we add external fields with a temperature that goes to zero in a second step. Small random external fields also break the permutation symmetry in multi-layer networks.

In the MS approach, the Gaussian approximation used in the BP approach cannot be applied, and a full convolution needs to be computed instead. In principle, this would add a factor of $N$ to the time complexity. However, as we shall show, exploiting the convexity properties of the problem allows us to simplify this computation, reducing the additional factor to just $\log N$.

On single-layer networks, this reinforced MS algorithm performs very similarly to the reinforced BP algorithm, in terms of both storage capacity and time required to perform each node update. However, the number of updates required to reach a solution scales polynomially with $N$, thus degrading the overall scaling. On fully-connected multi-layer networks, the MS algorithm performs noticeably better than BP.

The rest of the paper is organized as follows. In Sec. II we present the network model and the mathematical problem of learning. In Sec. III we present the MS approach for discrete weights. We show how the resulting equations can be solved efficiently thanks to properties of the convolution of concave piecewise-linear functions, and describe in complete detail the implementation for binary weights. Finally, in Sec. IV we show simulation results for the single-layer and two-layer cases.

II The network model

We consider networks composed of one or more elementary “building blocks” (units), each one having a number of discrete weights and a binary transfer function, which classify binary input vectors. Units can be combined into a composite function, in which the output of some units is used as the input of others. This can be done in various ways (called architectures), so as to produce a classification output for each input vector.

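We denote the input vectors as $\xi^\mu$ (where $\mu$ is a pattern index) and the weights as $W_{ki}$ (where $k$ is a unit index and $i \in \{1, \dots, N\}$ is an input index). In the following, the $W_{ki}$ are assumed to take $q$ evenly spaced values. We will then explicitly consider the cases $q = 2$ with $W_{ki} \in \{-1, 1\}$ and $q = 3$ with $W_{ki} \in \{-1, 0, 1\}$. The output of the $k$-th unit is given by:

$$\tau_k(\xi^\mu) = \mathrm{sign}\Big(\sum_{i=1}^{N} W_{ki}\, \xi_i^\mu\Big)$$

with the convention that $\mathrm{sign}(0) = 1$.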

We will consider two cases: a single-layer network and a two-layer committee machine. In single-layer networks, also called perceptrons rosenblatt1958perceptron , there is a single unit, and therefore we will omit the index $k$. Fully connected two-layer networks consist of $K$ units in the second layer, each of which receives the same input vector $\xi^\mu$. The output function of the device is

$$\sigma(\xi^\mu) = \mathrm{sign}\Big(\sum_{k=1}^{K} \tau_k(\xi^\mu)\Big)$$

This kind of architecture is also called a committee or consensus machine nilsson1965learning . When $K = 1$, this reduces to the perceptron case. In a tree-like committee machine, the input vectors would not be shared among the units. Rather, each unit would only have access to a subset of the input components, without overlap between the units. For given $N$ and $K$, tree-like architectures are generally less powerful (in terms of computational capabilities or storage capacity) than fully-connected ones, but are easier to train engel1992storage . Intermediate situations between these two extremes are also possible. In fully-connected committee machines there is a permutation symmetry in the indices $k$: any two machines which only differ by a permutation of the second layer's indices will produce the same output.
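As an illustration of the two architectures, the following Python sketch (assuming NumPy) computes the unit outputs and the committee output. It also checks numerically the permutation symmetry between the hidden units mentioned above. The convention $\mathrm{sign}(0) = +1$ and the function names are assumptions of this example.

```python
import itertools

import numpy as np


def sign(x):
    """Elementwise sign with the (assumed) convention sign(0) = +1."""
    return np.where(np.asarray(x) >= 0, 1, -1)


def unit_outputs(W, xi):
    """Outputs tau_k of the K units; W has shape (K, N), xi has shape (N,)."""
    return sign(W @ xi)


def committee_output(W, xi):
    """Fully connected committee machine: majority vote of the K units."""
    return int(sign(unit_outputs(W, xi).sum()))


rng = np.random.default_rng(0)
K, N = 3, 11                              # odd K avoids ties in the vote
W = rng.choice([-1, 1], size=(K, N))
patterns = rng.choice([-1, 1], size=(20, N))

# Permuting the hidden units (rows of W) never changes the device output.
reference = [committee_output(W, x) for x in patterns]
for perm in itertools.permutations(range(K)):
    assert [committee_output(W[list(perm)], x) for x in patterns] == reference
```

With a single row (`K = 1`), the committee output coincides with the output of that unit, i.e. the perceptron case.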

Throughout this paper we will consider supervised contexts, in which each pattern $\xi^\mu$ has an associated desired output $\sigma_D^\mu \in \{-1, 1\}$.

In classification (or storage) problems, $\alpha N$ association pairs of input vectors $\xi^\mu$ and corresponding desired outputs $\sigma_D^\mu$ are extracted from some probability distribution. The goal is to find a set of weights $\{W_{ki}\}$ such that $\sigma(\xi^\mu) = \sigma_D^\mu$ for all $\mu$.

In random generalization problems, the input patterns are still extracted from some probability distribution, but the desired outputs are computed from some rule, usually from a teacher device (teacher-student problem). The goal then is to learn the rule itself, i.e. to achieve the lowest possible error rate when presented with a pattern which was never seen during the training phase. If the teacher’s architecture is identical to that of the student device, this can be achieved when the student’s weights match those of the teacher (up to a permutation of the units’ indices in the fully-connected case).

In the following, we will always address the problem of minimizing the error function on the training patterns:

$$E(\{W_{ki}\}) = \sum_{\mu} \Theta\big(-\sigma_D^\mu\, \sigma(\xi^\mu)\big) - \sum_{k,i} \Gamma_{ki}(W_{ki})$$

where $\Theta$ is the Heaviside step function, $\Theta(x) = 1$ if $x > 0$ and $\Theta(x) = 0$ otherwise. The term $\Gamma_{ki}$ has the role of an external field, and can be used e.g. to implement a regularization scheme; in the following, we will always assume it to be concave. For example, we can implement $l_1$ regularization by setting $\Gamma_{ki}(W) = -\lambda |W|$, where $\lambda > 0$ is a parameter. The first term of expression (II) therefore counts the number of misclassified patterns, and the second one favours sparser solutions.
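Here is a minimal sketch of the energy in eq. (II) for a committee machine. The $l_1$ field $\Gamma(W) = -\lambda|W|$ is written directly as a penalty `lam * sum|W|`. The argument layout (weights of shape `(K, N)`, patterns of shape `(M, N)`), the convention $\mathrm{sign}(0) = +1$ and the function names are assumptions of this example.

```python
import numpy as np


def sign(x):
    """Elementwise sign with the (assumed) convention sign(0) = +1."""
    return np.where(np.asarray(x) >= 0, 1, -1)


def heaviside(x):
    """Theta(x) = 1 if x > 0 and 0 otherwise, as in the text."""
    return (np.asarray(x) > 0).astype(int)


def energy(W, patterns, targets, lam=0.0):
    """Training errors plus an l1 penalty, i.e. Gamma(W) = -lam * |W|.

    W: (K, N) weights; patterns: (M, N) inputs; targets: (M,) in {-1, +1}.
    """
    tau = sign(patterns @ W.T)               # (M, K) outputs of the units
    sigma = sign(tau.sum(axis=1))            # (M,) committee outputs
    errors = heaviside(-targets * sigma).sum()
    return errors + lam * np.abs(W).sum()


rng = np.random.default_rng(1)
teacher = rng.choice([-1, 0, 1], size=(3, 15))   # ternary weights, K = 3
X = rng.choice([-1, 1], size=(40, 15))
y = sign(sign(X @ teacher.T).sum(axis=1))         # labels produced by the teacher
```

A teacher machine makes no errors on the labels it produced itself, so its energy reduces to the regularization term alone.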

Throughout the paper, all random binary variables are assumed to be extracted from an unbiased i.i.d. distribution.

Under these conditions, it is known that in the limit $N \to \infty$ there are phase transitions at particular values of $\alpha$. Consider single units (perceptrons) with binary synapses. In the classification problem, the minimum number of errors is typically zero up to $\alpha_c \simeq 0.83$ krauth_storage_1989 . In the generalization problem, the number of devices compatible with the training set is larger than one up to $\alpha_{TS} \simeq 1.245$; beyond that point the teacher perceptron becomes the only solution to the problem.

III The Max-Sum algorithm

Following braunstein_learning_2006 , we can represent the optimization problem of finding the zeros of the first term of eq. (II) on a complete bipartite factor graph. Starting from the single-layer case, the graph has $N$ vertices (variable nodes) representing the values $W_i$, and $\alpha N$ factor nodes representing the error terms $\Theta(-\sigma_D^\mu\, \tau(\xi^\mu))$.

The standard MS equations for this graph involve two kinds of messages associated with each edge of the graph. We indicate with $\psi^t_{\mu\to i}(W_i)$ the message directed from factor node $\mu$ to variable $i$ at time step $t$, and with $\phi^t_{i\to\mu}(W_i)$ the message directed in the opposite direction.

These messages represent a certain zero-temperature limit of BP messages, but they also have a direct interpretation in terms of energy shifts of modified systems. Disregarding an inessential additive constant, consider the negated energy (II) restricted to configurations in which variable $i$ takes a specific value $W_i$. The message $\psi_{\mu\to i}(W_i)$ represents this quantity on a modified system in which the energy depends on $W_i$ only through the factor node $\mu$, i.e. in which all other terms involving $W_i$ are removed from the energy expression (II). Similarly, the message $\phi_{i\to\mu}(W_i)$ represents an analogous negated energy on a modified system in which the term $\mu$ is removed. For acyclic factor graphs, the MS equations can be thought of as the exact dynamic-programming solution. In our case the factor graph, being complete bipartite, is far from acyclic, and the equations are only approximate. For BP, the approximation is equivalent to that of the Thouless-Anderson-Palmer equations kabashima_cdma_2003 . It is expected to be exact in the single-layer case below the critical capacity krauth_storage_1989 . For a complete description of the MS equations and algorithm, see mezard_information_2009 .

The MS equations mezard_information_2009 for energy (II) are:


where $C_{\mu\to i}$ and $C_{i\to\mu}$ are normalization scalars. They fix the arbitrary additive constant of each message, and can be computed after the rest of the right-hand side. At any given time $t$, we can compute the single-site quantities


and use them to produce an assignment of the $W_i$'s:


The standard MS procedure thus consists in initializing the messages $\phi^0_{i\to\mu}$, then iterating eqs. (3) and (4). At each time step $t$, a vector $W^t$ is computed according to eqs. (5) and (6). The iteration stops when one of three conditions is met:

  • $E(W^t) = 0$ (in the absence of prior terms, i.e. when $\Gamma_i \equiv 0$);

  • the messages converge to a fixed point;

  • some maximum iteration limit is reached.
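To make the message-passing structure of eqs. (3)-(6) concrete, here is a brute-force sketch of plain MS (no reinforcement) for a single-layer binary perceptron. It evaluates the maximum in the factor-to-variable messages by exhaustive enumeration, so it is only usable for tiny $N$; the efficient evaluation is the subject of Sec. III.1. Several choices are assumptions of this sketch, not necessarily those of the original implementation: the array layout, the parallel update schedule, the normalization (each message shifted so that its maximum is zero), and the convention $\mathrm{sign}(0) = +1$.

```python
import itertools

import numpy as np

VALUES = np.array([-1, 1])                   # binary weights; column a <-> VALUES[a]


def sign(x):
    return 1 if x >= 0 else -1               # assumed convention: sign(0) = +1


def max_sum(X, y, Gamma, n_iter=20):
    """Plain Max-Sum for a single-layer binary perceptron, by brute force.

    X: (M, N) patterns, y: (M,) targets in {-1, +1}, Gamma: (N, 2) external
    fields, Gamma[i, a] being the field for W_i = VALUES[a].
    Returns the assignment argmax_W of the single-site fields, and the fields.
    """
    M, N = X.shape
    configs = np.array(list(itertools.product(VALUES, repeat=N)))   # (2^N, N)
    cols = (configs + 1) // 2                                        # value -> column
    psi = np.zeros((M, N, 2))                                        # factor -> variable
    for _ in range(n_iter):
        # Variable -> factor: external field plus all incoming psi except mu's.
        phi = (Gamma + psi.sum(axis=0))[None] - psi                  # (M, N, 2)
        phi -= phi.max(axis=2, keepdims=True)                        # max_W phi = 0
        for mu in range(M):
            err = np.array([sign(s) != y[mu] for s in configs @ X[mu]], dtype=float)
            incoming = phi[mu, np.arange(N), cols]                   # (2^N, N)
            total = incoming.sum(axis=1) - err
            for i in range(N):
                score = total - incoming[:, i]                       # drop j = i
                for a in range(2):
                    psi[mu, i, a] = score[cols[:, i] == a].max()
            psi[mu] -= psi[mu].max(axis=1, keepdims=True)            # max_W psi = 0
    fields = Gamma + psi.sum(axis=0)                                 # single-site quantities
    return VALUES[fields.argmax(axis=1)], fields
```

With a single pattern the factor graph is a tree (a star), so MS is exact. If the ground state is unique, which small random external fields make almost certain, the returned assignment is that ground state.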

Strictly speaking, standard MS is only guaranteed to reach a fixed point if the factor graph is acyclic, which is clearly not the case here. Furthermore, if the problem has more than one solution (ground state), the assignment in eq. (6) would not in general yield a solution, even in the acyclic case. In order to (heuristically) overcome these problems, we add a time-dependent reinforcement term to eqs. (4) and (5), analogously to what is done for BP braunstein_learning_2006 :


where $\gamma$ controls the reinforcement speed. In the case of standard BP, this reinforcement term implements a sort of “soft decimation” process, in which single-variable marginals are iteratively pushed to concentrate on a single value. In the case of MS, this process is useful to aid convergence. On a system in which the MS equations do not converge, the computed MS local fields still give some information about the ground states. They can therefore be used to iteratively “bootstrap” the system into one with very large external fields, i.e. fully polarized on a single configuration bailly-bechet_finding_2011 . The addition of this term introduces a dependence on the initial condition. Experimentally, this dependence can be made arbitrarily small by reducing $\gamma$, leading to more accurate results (see Sec. IV). The price is an increase in the number of steps required for convergence: our tests show that the convergence time scales as $\gamma^{-1}$.

Furthermore, in order to break symmetries between competing configurations, we add a small symmetry-breaking concave noise term to the external fields $\Gamma_i$. Together with the reinforcement term, this is sufficient to ensure, for all practical purposes, that the argmax in (6) is unique at every step of the iteration.

III.1 Max-Convolution

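While Eq. (7) can be computed efficiently in a straightforward way, the first term of Eq. (3) involves a maximum over an exponentially large set. The computation of Eq. (3) can be rendered tractable by introducing the partial sum $s = \sum_{j \neq i} W_j \xi_j^\mu$ as an auxiliary variable. This is done through an added term which is equal to $0$ if this constraint is satisfied and $-\infty$ otherwise, which leads to the following transformations: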


where, in the last step above, the function of $s$ appearing inside the maximum is defined as:


The right-hand side of (10) is usually called a “Max-Convolution” of the functions $\phi_{j\to\mu}$ for $j \neq i$, expressed as functions of $W_j \xi_j^\mu$. It is analogous to the standard convolution, but with the operations $(\max, +)$ substituting the usual $(+, \times)$. As for standard convolution, the operation is associative, which allows us to compute the convolution of the $N - 1$ functions recursively. Consider two functions with discrete domains $\{a_1, \dots, b_1\}$ and $\{a_2, \dots, b_2\}$. Their convolution can be computed in $O\big((b_1 - a_1 + 1)(b_2 - a_2 + 1)\big)$ operations and has domain $\{a_1 + a_2, \dots, b_1 + b_2\}$. It follows that (10) can be computed in $O(q^2 N^2)$ operations. In principle, this computation must be performed $N$ times for each pattern to compute all messages, for a total time of $O(q^2 N^3)$.
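The naive max-convolution and its recursive use can be sketched as follows (Python with NumPy). Functions are stored as arrays on $\{0, \dots, n-1\}$, and domain offsets are left to the caller. The sketch illustrates the $O(n_1 n_2)$ cost per pair and the associativity used above; the function names are illustrative, not the original code.

```python
import numpy as np


def max_convolve(f, g):
    """Max-plus convolution h[z] = max_{x + y = z} f[x] + g[y].

    f, g are arrays on {0, ..., n-1}; the result has len(f) + len(g) - 1
    entries and costs O(len(f) * len(g)) operations.
    """
    f, g = np.asarray(f, float), np.asarray(g, float)
    h = np.full(len(f) + len(g) - 1, -np.inf)
    for x, fx in enumerate(f):
        h[x:x + len(g)] = np.maximum(h[x:x + len(g)], fx + g)
    return h


def max_convolve_all(funcs):
    """Convolve many functions one after the other (uses associativity)."""
    h = np.zeros(1)                  # neutral element: value 0 at offset 0
    for f in funcs:
        h = max_convolve(h, f)
    return h
```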

A technique like the one described in braunstein_estimating_2008 ; braunstein_efficient_2011 ; braunstein_efficient_2009 can be employed to reduce this by a factor $N$, bringing it back to $O(q^2 N^2)$ operations per pattern update, as follows. For every $i \in \{1, \dots, N\}$, precompute two partial convolutions: one of the functions with indices $j < i$, and one of those with $j > i$. This takes $O(q^2 N^2)$ operations in total. The convolution of all functions with $j \neq i$ can then be computed as the convolution of these two partial results. Computing this convolution would require $O(q^2 N^2)$ operations but, fortunately, this will not be needed; the computation of (9) can proceed as:


where we defined the auxiliary vectors appearing in (11). These vectors can be pre-computed recursively, and (11) then requires only a number of operations linear in the domain size for each $i$. This gives a grand total of $O(q^2 N^2)$ operations per pattern update, or $O(\alpha q^2 N^3)$ per iteration. Unfortunately, this scaling is normally still too slow to be of practical use for moderately large values of $N$. We will thus proceed differently, by exploiting convexity properties. However, note that the above scaling is still the best we can achieve in the general case, in which the regularization terms are not concave.
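One way to realize the prefix/suffix idea is sketched below for a single pattern. It assumes weights in a symmetric range of $q$ consecutive integers (e.g. $\{-1, 0, 1\}$) and the error term of a single-layer unit. The error term depends only on which side of zero the total sum falls, so the inner maximum over the suffix reduces to running maxima, and each $i$ costs time linear in the domain size. The array layout, the name `cavity_messages` and the convention $\mathrm{sign}(0) = +1$ are assumptions. This sketch does not reproduce eq. (11) literally.

```python
import numpy as np


def max_convolve(f, g):
    """Naive max-plus convolution of two arrays indexed from 0."""
    h = np.full(len(f) + len(g) - 1, -np.inf)
    for x, fx in enumerate(f):
        h[x:x + len(g)] = np.maximum(h[x:x + len(g)], fx + g)
    return h


def cavity_messages(phi, xi, sigma, vmin=-1):
    """psi[i, k] = max over W_{-i} of [sum_{j != i} phi[j, W_j] - error],
    for W_i = vmin + k, with error = 1 if sign(sum_j W_j xi_j) != sigma.

    phi: (N, q) messages on the symmetric weight range vmin..-vmin,
    xi: (N,) inputs in {-1, +1}. Assumed convention: sign(0) = +1.
    """
    N, q = phi.shape
    # f[j] is phi_j as a function of x_j = W_j * xi_j (reversed if xi_j = -1).
    f = [p if s == 1 else p[::-1] for p, s in zip(np.asarray(phi, float), xi)]
    pre = [np.zeros(1)]                            # pre[i]: convolution of j < i
    for j in range(N):
        pre.append(max_convolve(pre[-1], f[j]))
    suf = [np.zeros(1)]                            # built backwards
    for j in reversed(range(N)):
        suf.append(max_convolve(f[j], suf[-1]))
    suf = suf[::-1]                                # suf[i]: convolution of j >= i
    psi = np.empty((N, q))
    for i in range(N):
        L, R = pre[i], suf[i + 1]
        # Running maxima of R: max over b >= c, and max over b < c.
        above = np.append(np.maximum.accumulate(R[::-1])[::-1], -np.inf)
        below = np.insert(np.maximum.accumulate(R), 0, -np.inf)
        a = np.arange(len(L))
        for k in range(q):
            t = (vmin + k) * xi[i]                 # contribution of W_i
            # total sum = (N - 1) * vmin + a + b + t, which is >= 0 iff b >= c
            c = np.clip(-((N - 1) * vmin + a + t), 0, len(R))
            ok, bad = (above[c], below[c]) if sigma == 1 else (below[c], above[c])
            psi[i, k] = np.max(L + np.maximum(ok, bad - 1.0))
    return psi
```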

At variance with standard discrete convolution, Max-Convolution does not in general have an analogue of the Fast Fourier Transform. Such an analogue would reduce the computation time of a convolution of functions with $n$ values from $O(n^2)$ to $O(n \log n)$. Nevertheless, for concave functions the convolution can be computed efficiently, as we will show below. Note that for this class of functions an operation analogous to the Fast Fourier Transform is the Legendre-Fenchel transform boudec_network_2001 . However, it will be simpler to work with the convolution directly in the original function space.

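First, let us recall well-known results about the max-convolution of concave piecewise-linear functions boudec_network_2001 . The max-convolution $h$ of two concave piecewise-linear functions $f$ and $g$ is itself concave and piecewise-linear. Moreover, $h$ can be built efficiently from $f$ and $g$. If $f$ and $g$ are defined on the intervals $[a_f, b_f]$ and $[a_g, b_g]$, then $h$ is defined on $[a_f + a_g, b_f + b_g]$, and its value at the left end is simply $h(a_f + a_g) = f(a_f) + g(a_g)$. Then, order the set of linear pieces of $f$ and $g$ by decreasing slope. Placing them one after the other, starting from the point $(a_f + a_g,\ f(a_f) + g(a_g))$, forms a continuous piecewise-linear function, which is $h$. The method is sketched in Fig. 1.

In symbols, we write each concave piecewise-linear function $f$ as a sequence of linear pieces, starting from the value $f(a_f)$ at the left end $a_f$. The $k$-th piece has length $\ell_k > 0$ and slope $c_k$:

$$f(x) = f(a_f) + \sum_{k} c_k \min\Big\{\ell_k,\ \Big(x - a_f - \sum_{k' < k} \ell_{k'}\Big)^+\Big\}, \qquad (y)^+ = \max\{y, 0\}$$

This function is concave because its slope on the $k$-th piece is $c_k$, and the slopes decrease with $k$. To compute the convolution of $f$ and $g$, just order the slopes. That is, compute a one-to-one map from the pieces of $f$ and $g$ to integers, such that a larger slope is mapped to a smaller integer. The max-convolution is still concave and piecewise-linear, and can thus be written in the same form. Its left end is $a_f + a_g$, its initial value is $f(a_f) + g(a_g)$, and its pieces (slopes and lengths) are given by the merged sequence.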

For more details about the max-convolution of piecewise-linear concave functions, see e.g. boudec_network_2001 (Part II).

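We now consider the case of functions defined on a discrete domain. Let $f$ and $g$ be concave discrete functions on the integer sets $\{a_f, \dots, b_f\}$ and $\{a_g, \dots, b_g\}$, respectively. We define the continuous extension $\bar f$ as the piecewise-linear interpolation of $f$, with value $-\infty$ for arguments outside $[a_f, b_f]$. This can be written in the piecewise-linear form above. It has $b_f - a_f$ pieces of unit length, with slopes $c_k = f(a_f + k) - f(a_f + k - 1)$, which are non-increasing in $k$ because $f$ is concave. It is easy to see that the max-convolution of $\bar f$ and $\bar g$ coincides with the discrete max-convolution of $f$ and $g$ on its (discrete) domain. The reason is simply that the former is also piecewise-linear, with kinks only at integer values.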
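For concave discrete functions, the slope-merging construction can be sketched in a few lines (Python with NumPy). Functions are stored as arrays on $\{0, \dots, n-1\}$, with offsets handled by the caller. A naive $O(n_1 n_2)$ version is included only for comparison; the function names are hypothetical.

```python
import numpy as np


def concave_max_convolve(f, g):
    """Max-convolution of two concave discrete functions by merging slopes.

    f, g: arrays on {0, ..., n-1} with non-increasing increments. The result
    starts at f[0] + g[0] and its unit pieces are those of f and g, sorted
    by decreasing slope.
    """
    f, g = np.asarray(f, float), np.asarray(g, float)
    slopes = np.sort(np.concatenate((np.diff(f), np.diff(g))))[::-1]
    return f[0] + g[0] + np.concatenate(([0.0], np.cumsum(slopes)))


def naive_max_convolve(f, g):
    """Reference O(len(f) * len(g)) implementation, for comparison."""
    h = np.full(len(f) + len(g) - 1, -np.inf)
    for x in range(len(f)):
        for y in range(len(g)):
            h[x + y] = max(h[x + y], f[x] + g[y])
    return h


rng = np.random.default_rng(4)
f = np.cumsum(np.sort(rng.normal(size=6))[::-1])    # decreasing increments -> concave
g = np.cumsum(np.sort(rng.normal(size=4))[::-1])
h = concave_max_convolve(f, g)
```

Sorting the concatenated slopes costs $O(n \log n)$. Since each input slope sequence is already sorted, a linear-time merge would also work.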

Figure 1: Sketch of the discrete max-convolution of two piecewise-linear concave functions. The result is simply obtained by sorting the pieces of the two functions in descending order of slope.

When computing the convolution of $N$ functions with domains in $\{0, \dots, q-1\}$, one can clearly order all slopes together in one initial pass, using $O(qN \log(qN))$ comparisons. It is also easy to see how to obtain a “cavity” convolution, in which one function $f_i$ is omitted. Given the full convolution, this takes $O(qN)$ time: one simply removes the slopes of $f_i$ from the full convolution.
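The cavity trick can be sketched as follows. All slopes are sorted once, and the sketch remembers which function each slope came from. Any cavity convolution is then obtained by filtering out the slopes of the omitted function, which preserves the order. The representation (start value, sorted slopes, owner indices) and the function names are assumptions of this sketch.

```python
import numpy as np


def full_convolution(funcs):
    """Sort the slopes of all concave discrete functions once.

    Returns (start, slopes, owner): the value of the full max-convolution at
    its left end, all slopes in decreasing order, and the index of the
    function each slope belongs to.
    """
    start = sum(float(f[0]) for f in funcs)
    slopes = np.concatenate([np.diff(f) for f in funcs])
    owner = np.concatenate([np.full(len(f) - 1, k) for k, f in enumerate(funcs)])
    order = np.argsort(-slopes, kind="stable")
    return start, slopes[order], owner[order]


def cavity_convolution(full, funcs, k):
    """Max-convolution of all functions except funcs[k], in linear time."""
    start, slopes, owner = full
    kept = slopes[owner != k]          # filtering keeps the decreasing order
    return start - float(funcs[k][0]) + np.concatenate(([0.0], np.cumsum(kept)))


def naive_max_convolve(f, g):
    """Reference O(len(f) * len(g)) max-convolution, for checking."""
    h = np.full(len(f) + len(g) - 1, -np.inf)
    for x in range(len(f)):
        for y in range(len(g)):
            h[x + y] = max(h[x + y], f[x] + g[y])
    return h
```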

In order to apply this to eq. (10), the only remaining step is a trivial affine mapping of the functions' arguments. The arguments are mapped from the set of values taken by $W_j \xi_j^\mu$ to the integer domain $\{0, \dots, q-1\}$. In the following, we will show explicitly how to do this for the binary case $q = 2$, but the argument can be easily generalized to arbitrary $q$. In the binary case the functions are linear and thus trivially concave. In the general case, we need to ensure that both the initial messages $\phi^0_{i\to\mu}$ and the external fields $\Gamma_i$ are concave. In that case, the iteration equations (3), (7) and (8) ensure that the concavity property holds at all time steps $t$.

III.2 The binary case

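We will show explicitly how to perform the computations efficiently in the binary case. Here we can simplify the notation by parametrizing each message with a single scalar. That is, up to additive constants, we can write $\psi^t_{\mu\to i}(W_i) = u^t_{\mu\to i} W_i$, $\phi^t_{i\to\mu}(W_i) = h^t_{i\to\mu} W_i$ and $\phi^t_i(W_i) = h^t_i W_i$. Eqs. (3) and (7) then become: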


Correspondingly, eqs. (8) and (6) simplify to:


In order to apply the results of the previous sections, and to perform efficiently the maximization over all possible assignments of $\{W_j\}_{j \neq i}$ in eq. (12), we first introduce the auxiliary quantities


For simplicity of notation, we will temporarily drop the indices $\mu$ and $t$, and denote by $h_j$ the scalar parametrizing the message $\phi_{j\to\mu}$. We will also assume that all $|h_j|$ values are different. As remarked above, the noise term is sufficient to ensure that this is the case; otherwise we can impose an arbitrary order without loss of generality. Let $F(s)$ denote the maximum of $\sum_j W_j h_j$ over configurations with $\sum_j W_j \xi_j = s$, defined over $s \in \{-N, -N+2, \dots, N\}$. With the above assumption, $F$ has a single absolute maximum, and is indeed concave. The absolute maximum is obtained with the special configuration $W^*$, given trivially by $W_j^* = \mathrm{sign}(h_j)$ for all $j$. This configuration corresponds to the value $s^* = \sum_j W_j^* \xi_j$. Any variable flip with respect to this configuration, i.e. any $j$ for which $W_j = -W_j^*$, adds a “cost” $2|h_j|$ in terms of $F$. Now partition the indices into two groups $J_+$ and $J_-$, defined by $J_\pm = \{j : W_j^* \xi_j = \pm 1\}$, and sort the indices within each group in ascending order of $|h_j|$. We can then compute $F(s)$ for each $s$ by constructively computing the corresponding optimal configuration $W^s$, as follows. We start from $s^*$, then proceed in steps of $2$ in both directions, subtracting the costs $2|h_j|$ in ascending order. The indices in $J_+$ are used for $s < s^*$, and those in $J_-$ for $s > s^*$.
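The constructive computation of $F(s)$ described above can be sketched as follows (Python with NumPy; the names are hypothetical). Starting from $W^*$, the cheapest flips in $J_+$ lower $s$, and those in $J_-$ raise it. A brute-force enumeration is included to check the result on a small instance.

```python
import itertools

import numpy as np


def binary_F(h, xi):
    """F(s) = max of sum_j W_j h_j over W in {-1, 1}^N with sum_j W_j xi_j = s.

    Built constructively: start from W*_j = sign(h_j), then flip the cheapest
    variables first (cost 2|h_j|), using J+ to lower s and J- to raise it.
    Returns a dict mapping s in {-N, -N+2, ..., N} to F(s).
    """
    h, xi = np.asarray(h, float), np.asarray(xi)
    d = np.where(h >= 0, 1, -1) * xi          # W*_j xi_j: +1 for J+, -1 for J-
    s_star, F_star = int(d.sum()), float(np.abs(h).sum())
    F = {s_star: F_star}
    for group, step in ((1, -2), (-1, 2)):
        costs = np.sort(2 * np.abs(h[d == group]))    # ascending costs
        for k, c in enumerate(np.cumsum(costs), start=1):
            F[s_star + step * k] = F_star - float(c)
    return F


def binary_F_brute(h, xi):
    """Exhaustive check, O(2^N)."""
    F = {}
    for W in itertools.product((-1, 1), repeat=len(h)):
        s, v = int(np.dot(W, xi)), float(np.dot(W, h))
        F[s] = max(F.get(s, -np.inf), v)
    return F
```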

This procedure also associates a “turning point” $z_j$ with each index $j$. It is defined as the value of $s$ at which the optimal value of $W_j$ changes sign, i.e. the point at which the cost $2|h_j|$ is paid in the construction above. This also implies that:


We can also bijectively associate an index with each turning-point value, by defining $j(z)$ as the index such that $z_{j(z)} = z$.
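One possible concrete convention for the turning points is assumed here for illustration; it is not necessarily the one used in the paper. Place $z_j$ at the midpoint of the step of size $2$ at which $W_j$ flips. That is, $z_j = s^* - (2r - 1)$ for the $r$-th cheapest index of $J_+$, and $z_j = s^* + (2r - 1)$ for the $r$-th cheapest index of $J_-$. With this choice, the turning points are exactly the values $\{-N+1, -N+3, \dots, N-1\}$, one per index. The optimal configuration at any $s$ then follows the simple rule $W_j^s = \xi_j\, \mathrm{sign}(s - z_j)$, which in turn gives $\sum_j \mathrm{sign}(s - z_j) = s$. The sketch below checks this rule against exhaustive enumeration; the function names are hypothetical.

```python
import numpy as np


def turning_points(h, xi):
    """Turning points z_j under the (assumed) midpoint convention.

    z_j = s* - (2r - 1) for the r-th cheapest index of J+ and
    z_j = s* + (2r - 1) for the r-th cheapest index of J-.
    """
    h, xi = np.asarray(h, float), np.asarray(xi)
    d = np.where(h >= 0, 1, -1) * xi          # +1 for J+, -1 for J-
    s_star = int(d.sum())
    z = np.empty(len(h), dtype=int)
    for group in (1, -1):
        idx = np.flatnonzero(d == group)
        idx = idx[np.argsort(np.abs(h[idx]))]  # cheapest flip first
        r = np.arange(1, len(idx) + 1)
        z[idx] = s_star - group * (2 * r - 1)
    return z


def optimal_config(z, xi, s):
    """Optimal W at constrained sum s: W_j = xi_j * sign(s - z_j)."""
    return np.asarray(xi) * np.where(s > z, 1, -1)
```

Because $s$ and $z_j$ always have opposite parity, $s - z_j$ is never zero, so the rule is unambiguous.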

Next, consider the same quantity where a variable $i$ is left out of the sum (see eq. (10)):


Clearly, one gets the same overall picture as before, except with a shifted argmax and shifted turning points. The shifts can be easily expressed in terms of the previous quantities, and the expressions used for computing eq. (12) can be written as:


The full details of the computation are provided in the Appendix (Details of the computation of the cavity fields). Here, we report the end result:

From this expression, we see that we can update the cavity fields $u_{\mu\to i}$ very efficiently for all $i$, using the following procedure:

  • We do one pass over the whole array of $h_j$. In this pass we determine the $W_j^*$ values, split the indices into $J_+$ and $J_-$, and compute $s^*$. This requires $O(N)$ operations, all of which are trivial.

  • We separately partially sort the indices in $J_+$ and $J_-$, and obtain the sorted costs and the turning points. This requires at most $O(N \log N)$ operations. A partial sort suffices because we have already computed $s^*$. We therefore know how many indices we need to sort, and from which set, until we get to those with turning points around $s = 0$. Also, we only need a few specific values of the cavity functions, rather than all values of $s$. This makes it likely for the procedure to be significantly less computationally expensive than the worst-case scenario.

  • For each $i$ we compute $u_{\mu\to i}$ from the equation above. This requires $O(N)$ operations in total, implemented in practice with three conditionals and a lookup for each $i$.
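The procedure above can be sketched in Python as follows. This is a sketch of the same idea under stated assumptions, rather than a transcription of eq. (21). It uses full sorts instead of partial ones. It evaluates each cavity function $F_{\setminus i}$ at a single boundary point through prefix sums of the sorted costs. It also exploits concavity: the best correctly-classified configuration lies at the argmax or at the boundary closest to it. The convention $\mathrm{sign}(0) = +1$, the scalar parametrization $\phi_{j\to\mu}(W) = h_j W$ and the function names are assumptions; a brute-force version is included for checking.

```python
import itertools

import numpy as np


def sgn(x):
    return 1 if x >= 0 else -1             # assumed convention: sign(0) = +1


def cavity_fields(h, xi, sigma):
    """u[i] = (psi_i(+1) - psi_i(-1)) / 2 for all i, in O(N log N), where
    psi_i(W_i) = max over W_{-i} of [sum_{j != i} W_j h_j - error] and
    error = 1 if sgn(sum_j W_j xi_j) != sigma.
    """
    h, xi = np.asarray(h, float), np.asarray(xi, int)
    N = len(h)
    d = np.where(h >= 0, 1, -1) * xi       # W*_j xi_j: +1 for J+, -1 for J-
    cost = 2 * np.abs(h)                   # price of flipping variable j
    s_star, F_star = int(d.sum()), float(np.abs(h).sum())
    rank, prefix = np.empty(N, dtype=int), {}
    for g in (1, -1):                      # one sort per group
        idx = np.flatnonzero(d == g)
        idx = idx[np.argsort(cost[idx])]
        rank[idx] = np.arange(len(idx))
        prefix[g] = np.concatenate(([0.0], np.cumsum(cost[idx])))

    def F_cav(i, s):
        # Cavity F without variable i: pay the k cheapest flips of one group,
        # skipping i itself if it would be among them.
        s0 = s_star - d[i]
        g = 1 if s < s0 else -1            # J+ flips lower s, J- flips raise it
        k = abs(s - s0) // 2
        p = prefix[g]
        paid = p[k + 1] - cost[i] if (d[i] == g and k > rank[i]) else p[k]
        return F_star - abs(h[i]) - paid

    def psi(i, w):
        t = w * xi[i]                      # contribution of W_i to the sum
        s0 = s_star - d[i]                 # argmax of the concave cavity F
        F0 = F_star - abs(h[i])
        if sgn(s0 + t) == sigma:
            return F0                      # the cavity optimum is already correct
        b = -t if sigma == 1 else -t - 1   # closest sum on the correct side
        if (b - s0) % 2:
            b += sigma                     # match parity, staying on that side
        correct = F_cav(i, b) if abs(b) <= N - 1 else -np.inf
        return max(F0 - 1.0, correct)      # pay the error, or move to b

    return np.array([(psi(i, 1) - psi(i, -1)) / 2 for i in range(N)])


def cavity_fields_brute(h, xi, sigma):
    """Exhaustive O(N 2^N) reference implementation."""
    N = len(h)
    u = np.zeros(N)
    for i in range(N):
        others = [j for j in range(N) if j != i]
        best = {}
        for w in (1, -1):
            best[w] = max(
                sum(W * h[j] for W, j in zip(ws, others))
                - (sgn(w * xi[i] + sum(W * xi[j] for W, j in zip(ws, others))) != sigma)
                for ws in itertools.product((1, -1), repeat=N - 1))
        u[i] = (best[1] - best[-1]) / 2
    return u
```

Apart from the two sorts, each cavity field costs $O(1)$, so a factor-node update costs $O(N \log N)$, in line with the scaling stated above.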

IV Numerical results

We tested extensively the binary case and the ternary case with $l_0$ regularization, for single-layer networks.

We start from the binary case. Fig. 2 shows the probability of finding a solution as a function of $\alpha$, with one curve for each fixed value of the reinforcement rate $\gamma$. Reducing $\gamma$ allows us to reach higher values of $\alpha$. The shape of the curves suggests that in the limit $\gamma \to 0$ there would be sharp transitions at critical values of $\alpha$. In the classification case (Fig. 2A), the transition occurs below the theoretical critical capacity $\alpha_c \simeq 0.83$. The transition value is comparable to the one obtained with the reinforced BP algorithm of braunstein_learning_2006 . In the generalization case (Fig. 2B), there are two transitions. The first one occurs before the first-order transition at $\alpha_{TS} \simeq 1.245$, where, according to the theory, the only solution is the teacher. The second transition occurs at a larger value of $\alpha$, compatible with the end of the meta-stable regime (see e.g. engel_statistical_2001 ); indeed, beyond this point the algorithm is able to correctly infer the teacher perceptron.

Figure 2: Solving probability. Probability of finding a solution for different values of $\alpha$, in the binary perceptron case, with different values of the reinforcement rate parameter $\gamma$. Performance improves with lower values of $\gamma$. A. Classification case. The theoretical capacity is $\alpha_c \simeq 0.83$ in this case. B. Generalization case. In this case, the problem has multiple solutions up to $\alpha_{TS} \simeq 1.245$, after which the only solution is the teacher.

A second batch of experiments on the same architecture, in the classification case, is shown in Fig. 3. In this case, we estimated the maximum value of $\gamma$ which allows us to find a solution, at different values of $\alpha$ and $N$. For each test sample, we started from a high value of $\gamma$ and checked whether the algorithm was able to find a solution. If it failed, we reduced $\gamma$ and tried the same sample again. In the cases shown, a solution was always found eventually. The results indicate that the required value of $\gamma$ decreases with $N$. The behaviour is well described by a power law in $N$, with an exponent and a prefactor that depend on $\alpha$. The number of iterations required is inversely proportional to $\gamma$ (not shown), so the number of iterations grows polynomially with $N$. In this respect, the overall solving time of the MS algorithm is worse than that of the reinforced BP. The value of the exponent remains within a limited range up to a certain value of $\alpha$, after which its magnitude decreases abruptly (see Fig. 3B). At large $\alpha$, the exponent seems to be reasonably well fit by a curve with a vertical asymptote. The position of the asymptote is an estimate of the critical capacity of the algorithm in the limit of large $N$.

Figure 3: Maximum reinforcement rate. A. Average value of the maximum reinforcement rate $\gamma$ which allows us to find a solution, in the binary perceptron classification case, at various values of $\alpha$ and $N$, in log-log scale. The reinforcement rate decreases with $N$ and $\alpha$. Error bars show the standard deviation of the distributions. Black lines show the result of power-law fits, one for each value of $\alpha$. The number of samples varied with $N$. B. The fitted values of the exponents in panel A. The continuous curve shows a fit of the large-$\alpha$ data by a function with a vertical asymptote; the position of the asymptote is an estimate of the critical capacity of the algorithm.

Figure 4: Learning of random patterns with a ternary perceptron, with a dilution ($l_0$ regularization) prior term. For solved instances, the average fraction of non-zero weights is also shown (standard deviations smaller than point size).

In the ternary single-layer case, we tested learning of random patterns with ternary weights and a concave bias (i.e. prior). In practice, the external field combines the symmetry-breaking noise term with a sufficiently large penalty on each non-zero weight. This favours zero weights, so that solutions with a minimal number of non-zero weights are sought, i.e. we add an $l_0$ regularization term. The results (see Fig. 4) are qualitatively similar to the binary case, with a larger capacity. The average fraction of non-zero weights in a solution grows when getting closer to the critical $\alpha$. It grows up to a value that is smaller than $2/3$, the value that maximizes the entropy of unconstrained perceptrons.

In the fully-connected multi-layer case, the algorithm does not get as close to the critical capacity as in the single-layer case, but it still achieves a non-zero capacity on rather large instances. For example, we considered the classification case with binary synapses. There, the algorithmic critical capacity measured for a committee machine shows a greater discriminatory power than the single-layer case with the same input size. The reason for the increased difficulty in this case is not completely clear. We speculate that it is due both to the permutation symmetry between the hidden units and to replica-symmetry-breaking effects. These effects tend to trap the algorithm, in its intermediate steps, in states which mix different clusters of solutions, making convergence difficult. Still, the use of symmetry-breaking noise helps achieve non-trivial results even in this case, which constitutes an improvement with respect to the standard BP algorithm.

V Conclusions

Up to now, the large-$N$ limit could be exploited in the BP equations for the learning problem with discrete synapses. This yields an extremely simple set of approximated equations, which made the computation of an iteration scale linearly with the problem size $N$. For the MS equations, however, those approximations cannot be made. A naive implementation has a much worse polynomial scaling, which is normally too slow for most practical purposes. In this work, we showed that the MS equations can be computed exactly (with no approximations) in time $O(N \log N)$ per node update, making the approach very attractive in practice.

A word is in order about the MS equations with the reinforcement term, which we propose as a valid alternative to Belief Propagation-based methods. We cannot claim any universal property of the reinforced equations from theoretical arguments, and we only tested a limited number of cases. Still, extensive simulations for these cases confirmed the same qualitative behaviour. So did previous results obtained by applying the same technique to other optimization problems of very different nature bailly-bechet_finding_2011 ; anthony_gitter_sharing_2013 ; altarelli_optimizing_2013 ; altarelli_containing_2014 . That is, the number of iterations until convergence scales as $\gamma^{-1}$, and the results improve monotonically as $\gamma$ decreases. As an additional advantage of MS, inherent symmetries present in the original system are naturally broken thanks to ad-hoc noise terms that are standard in MS. The MS equations are also computationally simpler, because they normally require only sum and max operations. By contrast, the BP equations require hyperbolic trigonometric functions. Extensive simulations for binary and ternary weights show that the performance is indeed very good. The algorithm achieves a capacity close to the theoretical one, very similar to that of Belief Propagation.

CB acknowledges the European Research Council for grant n° 267915.


Details of the computation of the cavity fields.

In this section we provide the full details of the computation leading to eq. (21).

As noted in the main text, the expression of the cavity quantities (see eq. (20)) is analogous to that of their non-cavity counterparts (eq. (18)), except that the argmax and the turning points have changed:


We need to express the relationship between the old turning points and the new ones. Having omitted the variable $i$, there is a global shift of the argmax. In addition, the turning points to the left (right) of the turning point of $i$ are shifted to the right (left), depending on whether $i$ belongs to $J_+$ or $J_-$:


(note that we chose to use the convention that ).

Therefore we obtain:


Next, we consider the cavity quantity:


which allows us to write eq. (12) as


(this is eq. (21) in the main text).

Note that the cavity function is concave and has its maximum at the shifted argmax. Using this fact and eq. (25), we can derive explicit formulas for the expressions which appear in the cavity field. We consider the two cases $W_i = \pm 1$ separately, and afterwards simplify the result with simple algebraic manipulations: