A network that learns Strassen multiplication

01/26/2016, by Veit Elser

We study neural networks whose only non-linear components are multipliers, to test a new training rule in a context where the precise representation of data is paramount. These networks are challenged to discover the rules of matrix multiplication, given many examples. By limiting the number of multipliers, the network is forced to discover the Strassen multiplication rules. This is the mathematical equivalent of finding low rank decompositions of the $n\times n$ matrix multiplication tensor, $M_n$. We train these networks with the conservative learning rule, which makes minimal changes to the weights so as to give the correct output for each input at the time the input-output pair is received. Conservative learning needs a few thousand examples to find the rank 7 decomposition of $M_2$, and $10^5$ for the rank 23 decomposition of $M_3$ (the lowest known). High precision is critical, especially for $M_3$, to discriminate between true decompositions and "border approximations".


1 Introduction

The current surge of activity in machine learning with neural networks, far from being a fad, is a response to hard empirical evidence. The powers of discrimination, generalization, and abstraction exhibited by these networks match and often outperform those of humans. At the same time, from a theoretical perspective, we know that the core representation of knowledge by these networks is only approximate. The representation of images as compositions of non-negative features, for example, works well only until the images become unnatural. While synthetic data avoids this shortcoming, the allure of the network is diminished when the learned knowledge is not new or surprising.

Questions related to the Strassen algorithm [S] for matrix multiplication provide an opportunity for testing neural networks in the interesting setting where the representation of knowledge is both exact and poorly understood by humans (an open mathematical problem). In this short contribution we define Strassen multiplication (SM) in engineering terms, construct a neural network that has the capacity to implement SM, introduce a simple protocol for training the network called conservative learning, and finish with the mandatory presentation of spectacular results.

2 What is Strassen multiplication?

In 1969 Volker Strassen published a paper [S] that showed, among other things, that a pair of $2\times 2$ matrices could be multiplied with just seven scalar multiplications, one less than the eight required in the standard algorithm. Strassen’s trick becomes practical on very large matrices, where it can be applied recursively on blocks, and finding the minimum number of scalar multiplications for general matrices turns out to be a deep mathematical problem. A glance at the network that implements Strassen multiplication (SM) in Figure 1 quickly establishes two things that make this application of machine learning unique:

  • The architecture of the network and the operations performed by its components are a perfect match to the mathematical problem, as opposed to being just a well motivated, robust platform for general machine learning applications.

  • The design of the network is deliberately naive and lives up to the standard that the machine should have the capacity to learn something that we don’t already know. Our network, in fact, does not even know how numbers are arranged to form a matrix!

Figure 1: A network for multiplying $2\times 2$ matrices that uses only seven multipliers. Multiplication by a constant (“weight”) occurs at each of the lines connecting input and output registers (for the matrices $A$, $B$, and $C$) to the multipliers in the middle layer of the network.

Figure 1 is an engineer’s representation of SM. The elements of the matrices $A$ and $B$ are read in from two registers of numbers, $a = (a_1, \ldots, a_4)$ and $b = (b_1, \ldots, b_4)$, on the left and right. Those from $a$ are ‘pooled’ into the left inputs of the seven multipliers in the middle layer of the network, those from $b$ into the right inputs. Pooling is the most general linear-homogeneous operation. Denoting the inputs to the multipliers $u_k$ and $v_k$, $k = 1, \ldots, 7$, the pooling equations are

(1)   $u_k = \sum_{i=1}^{4} (W_A)_{ki}\, a_i, \qquad v_k = \sum_{j=1}^{4} (W_B)_{kj}\, b_j, \qquad k = 1, \ldots, 7$

The sets of real numbers $(W_A)_{ki}$ and $(W_B)_{kj}$ are called weights, though they are not required to be non-negative. Multiplication by a weight is not counted as one of the multiplies in SM. Unlike the multiplications performed by the seven multipliers, weight multiplication is a linear operation in the inputs to the network. The challenge of SM is to find a set of weights that work for all conceivable matrices we wish to multiply. We will see that this set is not unique, and there exist transformations from one set to another that work just as well. In the most practical set found by Strassen [S], all the nonzero weights are just $\pm 1$, so the weight-multiplications are really just additions and subtractions.

The output end of the network works in the same way as the input end: the outputs of the seven multipliers are pooled into the four output registers of the product matrix $C$, $c_1, \ldots, c_4$, again in the most general way with a set of weights $(W_C)_{lk}$:

(2)   $c_l = \sum_{k=1}^{7} (W_C)_{lk}\, u_k\, v_k, \qquad l = 1, \ldots, 4$

Strassen’s discovery can now be expressed in purely engineering terms. That is, there exists a fixed set of weights $W_A$, $W_B$ and $W_C$ such that the four numbers in the product $C = AB$ of any pair of $2\times 2$ matrices can be computed with the 7-multiplier network of Figure 1 (i.e. the 8-multiplier network implicit in the standard algorithm is sub-optimal).
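
As a concrete illustration (our own sketch, not part of the original text), the network of Figure 1 can be written as three small weight matrices. Loading them with Strassen's $\pm 1$ weights, under the assumed row-major unrolling $a = (A_{11}, A_{12}, A_{21}, A_{22})$ and likewise for $b$ and $c$, reproduces ordinary matrix multiplication with seven multipliers:

    import numpy as np

    # Strassen's weights, arranged as the three pooling matrices of the
    # network in Figure 1 (rows of W_A, W_B index the seven multipliers).
    W_A = np.array([[ 1, 0, 0, 1],   # u1 = A11 + A22
                    [ 0, 0, 1, 1],   # u2 = A21 + A22
                    [ 1, 0, 0, 0],   # u3 = A11
                    [ 0, 0, 0, 1],   # u4 = A22
                    [ 1, 1, 0, 0],   # u5 = A11 + A12
                    [-1, 0, 1, 0],   # u6 = A21 - A11
                    [ 0, 1, 0,-1]])  # u7 = A12 - A22
    W_B = np.array([[ 1, 0, 0, 1],   # v1 = B11 + B22
                    [ 1, 0, 0, 0],   # v2 = B11
                    [ 0, 1, 0,-1],   # v3 = B12 - B22
                    [-1, 0, 1, 0],   # v4 = B21 - B11
                    [ 0, 0, 0, 1],   # v5 = B22
                    [ 1, 1, 0, 0],   # v6 = B11 + B12
                    [ 0, 0, 1, 1]])  # v7 = B21 + B22
    W_C = np.array([[ 1, 0, 0, 1,-1, 0, 1],   # C11
                    [ 0, 0, 1, 0, 1, 0, 0],   # C12
                    [ 0, 1, 0, 1, 0, 0, 0],   # C21
                    [ 1,-1, 1, 0, 0, 1, 0]])  # C22

    def network_product(a, b):
        """Forward pass of the 7-multiplier network: equations (1) and (2)."""
        u, v = W_A @ a, W_B @ b        # pool into the multiplier inputs
        return W_C @ (u * v)           # pool the 7 products into the output

    A, B = np.random.randn(2, 2), np.random.randn(2, 2)
    c = network_product(A.ravel(), B.ravel())
    assert np.allclose(c.reshape(2, 2), A @ B)   # agrees with ordinary matmul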

3 Mathematical digression: tensor rank

Combining (1) and (2) we obtain a curious statement about matrix multiplication, now generalized to $n\times n$ matrices (with the seven multipliers replaced by $r$ of them):

(3)   $c_l = \sum_{i,j} (M_n)_{lij}\, a_i\, b_j$
(4)   $(M_n)_{lij} = \sum_{k=1}^{r} (W_C)_{lk}\,(W_A)_{ki}\,(W_B)_{kj}$

The universal three-index set of numbers $(M_n)_{lij}$ is a tensor of order three, the $n\times n$-matrix multiplication tensor. We see that $M_n$ can be written as a sum of products of three 1-tensors, the atoms of all tensors. The minimum number of terms $r$ in the decomposition (4) into products of 1-tensors is called the rank of the tensor. Unlike the problem of determining the rank of a matrix (a 2-tensor), determining the ranks of higher order tensors such as $M_n$ is computationally difficult [H]. While we know $\mathrm{rank}(M_2) = 7$ [W], already for the next case we currently only have bounds [B, JDL]: $19 \le \mathrm{rank}(M_3) \le 23$.

A general transformation of the weights that gives another correct network (decomposition of $M_n$) can be inferred from the fact that the transformed matrices

(5)   $A' = P\,A\,Q^{-1}, \qquad B' = Q\,B\,R^{-1}, \qquad C' = P\,C\,R^{-1}$

for arbitrary invertible $P$, $Q$, $R$, satisfy the matrix product property $C' = A'B'$ whenever $A$, $B$, and $C$ do ($C = AB$). But feeding $A'$ and $B'$ into the network, and seeing the correct output $C'$, is equivalent to feeding in $A$ and $B$ and seeing $C$ in the output after the three sets of weights have been appropriately transformed by the linear relations (5).
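
With the conventions of (5), the product property follows in one line:

$A'B' = \bigl(P\,A\,Q^{-1}\bigr)\bigl(Q\,B\,R^{-1}\bigr) = P\,(AB)\,R^{-1} = P\,C\,R^{-1} = C'.$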

When we test our network for “matrix multiplication fidelity” we would like to do better than spot check its accuracy on sample instances. To assess how well it will perform on any instance of matrix multiplication we compute the root-mean-square error $\varepsilon$, where

(6)   $\varepsilon^2 = \dfrac{1}{n^6}\sum_{l,i,j}\Bigl((M_n)_{lij} - \sum_{k=1}^{r} (W_C)_{lk}\,(W_A)_{ki}\,(W_B)_{kj}\Bigr)^{2}$

is the mean-square error in how our weights are decomposing the true $M_n$.
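
The error measure (6) is easy to evaluate numerically. The sketch below (ours, not from the original text) builds $M_n$ explicitly and compares it with the rank-$r$ tensor assembled from the three weight matrices:

    import numpy as np

    def matmul_tensor(n):
        """(M_n)_{lij}: c_l = sum_{ij} (M_n)_{lij} a_i b_j for row-major unrolled A, B, C = AB."""
        M = np.zeros((n * n, n * n, n * n))
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    # C[i, j] += A[i, k] * B[k, j]
                    M[i * n + j, i * n + k, k * n + j] = 1.0
        return M

    def decomposition_error(W_A, W_B, W_C, n):
        """Root-mean-square error between M_n and the tensor built from the weights, per (4) and (6)."""
        approx = np.einsum('lk,ki,kj->lij', W_C, W_A, W_B)
        diff = matmul_tensor(n) - approx
        return np.sqrt(np.mean(diff ** 2))

With the Strassen weights of the earlier sketch, decomposition_error(W_A, W_B, W_C, 2) returns zero up to floating-point rounding.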

4 Conservative learning

The truly amazing thing about the network in Figure 1 is not that, given the special Strassen weights, it manages to produce the product with only seven multiplies. What is really remarkable is the fact that the network can learn a correct set of weights just by being shown enough examples of correctly multiplied matrices.

A huge industry has grown up around the problem of training networks. Most methods start with completely random weights and adjust these to minimize the discrepancy between the correct output and the actual output of the network. There are many schemes for minimization, though most are local and based on gradients of some measure of discrepancy (“loss”) with respect to the weights. We will use a somewhat different approach for the Strassen network called conservative learning.

Conservative learning is not the oxymoron it would seem, but a combination of two very reasonable principles:

  • When presented with a training item for which the network gives the wrong output, change the weights so that at least this item produces the correct output.

  • Make the smallest possible change to the weights when learning each new item.

By making the smallest possible change, when obsessing over the most recent item, we stand a better chance of not corrupting the accumulated knowledge derived from all previous items. To add mathematical support to this statement we analyze a very simple network.

Consider a network that implements a linear transformation from a set of inputs $x \in \mathbb{R}^m$ to a set of outputs $y \in \mathbb{R}^n$. We use matrix notation in what follows. The transformation to be learned corresponds to a matrix of weights $W^*$, and training samples are pairs $(x, y)$ where $y = W^* x$. Suppose $W$ is the current matrix and $(x, y)$ a new training item. As our network is linear, we choose to train only with normalized inputs, $x^T x = 1$. We wish to find the smallest change $\Delta W$ such that

(7)   $(W + \Delta W)\,x = y$

Taking the Frobenius norm to define the smallness of the change, some simple linear algebra gives us the conservative learning rule:

(8)   $\Delta W = (y - Wx)\,x^T$

The current weights are incremented by a rank 1 matrix comprising the input vector $x$ as one factor and the output discrepancy $y - Wx$ as the other. The new weights are now correct on the item $(x, y)$.
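
For completeness, a sketch of the "simple linear algebra" (our notation, anticipating the Lagrangian method of the appendix): minimize $\tfrac{1}{2}\|\Delta W\|_F^2$ subject to (7), with a vector of Lagrange multipliers $\mu$,

$\tfrac{1}{2}\|\Delta W\|_F^2 - \mu^T\bigl((W + \Delta W)\,x - y\bigr), \qquad \Delta W - \mu\,x^T = 0,$

so $\Delta W = \mu\,x^T$; inserting this into the constraint (7) and using $x^T x = 1$ gives $\mu = y - Wx$, which is (8).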

To better understand how we are doing relative to the corpus of all possible items, we make note of the following relationship between old and new weights:

(9)   $W + \Delta W = W\,(\mathbb{1} - x\,x^T) + y\,x^T$
(10)   $(W + \Delta W) - W^* = (W - W^*)\,(\mathbb{1} - x\,x^T)$

The operator $\mathbb{1} - x\,x^T$ projects rows to the orthogonal complement of $x$. From (10) we infer the norm inequality

(11)   $\|(W + \Delta W) - W^*\|_F \;\le\; \|W - W^*\|_F$

where we get equality only in the rare case that $x$ is orthogonal to all the rows of $W - W^*$. This inequality proves that conservatism is sound: by minimally accommodating each new item there is monotone improvement in our approximation of the transformation ($W^*$) itself.
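
The inequality is easy to check numerically. A toy sketch (ours, with made-up dimensions) applies the rank 1 update (8) to a random linear problem and verifies that the Frobenius distance to the true weights never grows:

    import numpy as np

    rng = np.random.default_rng(0)
    W_true = rng.standard_normal((3, 5))   # transformation to be learned
    W = rng.standard_normal((3, 5))        # random initial weights

    for _ in range(2000):
        x = rng.standard_normal(5)
        x /= np.linalg.norm(x)             # normalized input, x^T x = 1
        y = W_true @ x                     # correctly "labeled" training item
        before = np.linalg.norm(W - W_true)
        W += np.outer(y - W @ x, x)        # conservative update, equation (8)
        assert np.linalg.norm(W - W_true) <= before + 1e-12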

Conservative learning has the nice feature that there are no parameters that have to be tuned, such as the step size in gradient descent. This continues to be true when the principle is applied to general networks. However, multiple layers and nonlinear components make the problem of finding the exact weight modifications intractable. Fortunately, there is a systematic procedure for keeping track of the order of smallness of the changes in the general network so that an approximate learning rule can be written down. The update procedure for the weights in the multilayer setting has similarities with the back-propagation rules in gradient-based learning, but there are also some new modes of propagation. Not surprisingly, proving convergence is also out of reach. Our experiments with the Strassen network show spectacular convergence, even without imposing a small parameter to keep changes in check.

A derivation of the conservative learning rules for the Strassen network is given in the appendix. This network has all the features that need to be addressed in general networks and therefore serves as a good vehicle for explaining the method.

5 Conservative learning with the Strassen network

There are two sources of randomness when training a network: the distribution of the initial weights and the distribution of the training items. We did not explore this dimension in our experiments with the Strassen network. The weights were initialized with independent uniform samples from a symmetric interval about the origin. Clearly the scale of this interval is not arbitrary, since the tensor being decomposed has a definite scale. A scale for the weights is also implied by the error objective (6). Conservative learning proved not to be very sensitive to the scale of the initial weights, with only extreme limits (small and large scales) showing performance degradation. For simplicity we therefore chose to always initialize with samples drawn from the same fixed symmetric interval.

Adding an interesting bias to the distribution of training matrices is also something we did not try. Our input matrices were simply constructed from uniform samples in each element of $A$ and $B$. After rescaling to conform with our normalization convention $a^T a = b^T b = 1$, these together with the unrolled product $c$ were handed to the network for training.
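
In code, the training stream might look like the following sketch (ours; the sampling interval is a placeholder, not a value given in the text):

    import numpy as np

    rng = np.random.default_rng()

    def training_item(n=2):
        # independent uniform entries; the interval below is an assumption
        A = rng.uniform(-1.0, 1.0, (n, n))
        B = rng.uniform(-1.0, 1.0, (n, n))
        a, b = A.ravel(), B.ravel()
        a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)   # a^T a = b^T b = 1
        c = (a.reshape(n, n) @ b.reshape(n, n)).ravel()       # correctly multiplied pair
        return a, b, c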

Figure 2: Root-mean-square error $\varepsilon$ in the rank 7 decomposition of the $2\times 2$ matrix multiplication tensor $M_2$ by the network in Figure 1. Shown are results for five runs of the network, each starting with random weights and trained on streams of random instances of correctly multiplied $2\times 2$ matrices.

It astonished us how quickly the network discovers Strassen’s trick for $2\times 2$ matrices, when the weights are updated by the conservative learning rules. Figure 2 shows the decomposition error $\varepsilon$ converging linearly, after an initial training set of only a few thousand items. The five runs shown have slightly different convergence rates, a variability that can only be attributed to the initial weights and the early training items (before the onset of linear convergence). We did not compare with other learning rules, all of which, unlike conservative learning, come with parameters and batch protocols that we were not prepared to manage with competence.

When we remove one multiplier from the middle layer of the network in Figure 1, that is, attempt to find a rank 6 decomposition of $M_2$, the error fluctuates indefinitely. This of course is what we expect, since there is a proof that $M_2$ has rank 7 [W]. For $3\times 3$ matrices the network exhibits a new kind of behavior. Since it is known that $\mathrm{rank}(M_3) \le 23$ [JDL], convergent behavior is possible in principle when we have 23 multipliers. We find that this indeed happens in about 64% of all runs, an example of which is the lower plot in Figure 3. In the remaining runs $\varepsilon$ decays much more slowly; an example of this behavior is the higher plot in Figure 3.

The phenomenon of there being two qualitatively different asymptotic states of the network, when decomposing $M_3$, is consistent with the border rank property. This property, peculiar to tensors of order three and higher, refers to the topological closure properties of the space of tensor decompositions. When this space is not closed, the abstract closure defines tensor decompositions on the “border” of the space that have lower rank. In more concrete terms, it gives rise to approximate decompositions of lower rank, where the approximation as measured by $\varepsilon$ gets better when the weights are allowed to diverge without limit.

Figure 3: Root-mean-square error $\varepsilon$ in rank 23 decompositions of the $3\times 3$ matrix multiplication tensor $M_3$, on two runs. The absence of convincing linear convergence in one of the runs is accompanied (Figure 4) by a slow growth in the maximum weight magnitude.
Figure 4: The maximum weight magnitudes for the two training runs shown in Figure 3.

We see evidence of “border” behavior when we compare the evolution of the maximum weight magnitude for the two runs in Figure 3. These are plotted in Figure 4. In contrast to the linearly convergent run, where the maximum weight saturates, the run with slow convergence has a corresponding slow growth in the maximum weight. An interpretation of our results, consistent with what is known about tensor decompositions of $M_3$, is that the weights in the slowly converging runs are rank 23 border approximations to true rank 24 or higher decompositions. Running the network with 22 multipliers we only observe the slow convergence/growing weight behavior. While this does not prove the true rank is 23, it is consistent with $M_3$ having rank 22 border approximations [Sch]. That slow convergence was never observed in our experiments with $M_2$ is in agreement with the known fact [JML] that this tensor does not have lower rank border approximations.

Conservative learning with the Strassen network for $3\times 3$ matrices is summarized in Figure 5. This gives the distribution of the final decomposition error $\varepsilon$ for a large ensemble of runs, each limited to a fixed budget of training items and terminated when the error dropped below a set threshold. Runs with linear convergence, such as the one in Figure 3, typically required far fewer training items than this budget. The large peak at the low end of the distribution in Figure 5, about 64% of all runs, therefore gives the probability that the network finds a true rank 23 decomposition. As explained above, we believe the network is slowly converging to border approximations in the other runs.

Figure 5: Distribution of the final decomposition error $\varepsilon$ of $M_3$ over all runs, each given the same budget of training items. Runs were terminated when $\varepsilon$ dropped below a set threshold.

6 Appendix

The starting point for deriving the conservative learning rules for the Strassen network (Fig. 1) is the Lagrangian function

(12)   $\mathcal{L} \;=\; \tfrac{1}{2}\|\Delta W_A\|_F^2 + \tfrac{1}{2}\|\Delta W_B\|_F^2 + \tfrac{1}{2}\|\Delta W_C\|_F^2 \;+\; \lambda_u^T\bigl(u - (W_A + \Delta W_A)\,a\bigr) \;+\; \lambda_v^T\bigl(v - (W_B + \Delta W_B)\,b\bigr) \;+\; \lambda_c^T\bigl(c - (W_C + \Delta W_C)\,(u\circ v)\bigr)$

To streamline the derivation we use matrix notation. The new training item is the triplet of vectors $(a, b, c)$ corresponding to the unrolled matrices $A$, $B$, $C = AB$. We should think of $a$, $b$ as inputs to the network and $c$ as the output. Prior to this item the network has weight matrices $W_A$ and $W_B$ that map the inputs to the pairs of inputs of the multipliers in the middle layer, $u$ and $v$. The output of the multipliers is the vector $u\circ v$, where $\circ$ denotes componentwise multiplication. If the network weights are correct even for the new item, then mapping $u\circ v$ with the weight matrix $W_C$ should match the output vector $c$. In general, the three weight matrices have to be changed by $\Delta W_A$, $\Delta W_B$, $\Delta W_C$ for this to be true, and the first three terms of (12) are the Frobenius norm objective on these changes to keep them small. The last three terms are constraints imposed via three vectors of Lagrange multipliers, $\lambda_u$, $\lambda_v$, $\lambda_c$. These insure that with the changed weights the inputs/outputs of the network match the inputs/outputs of the multipliers in the middle layer.

Given the current weights and the new training item $(a, b, c)$, our task is then to find a stationary point of the Lagrangian for the variables $\Delta W_A$, $\Delta W_B$, $\Delta W_C$, $u$, $v$, and $\lambda_u$, $\lambda_v$, $\lambda_c$. In the derivation below we assume that the inputs are normalized as $a^T a = b^T b = 1$.

Stationarity with respect to $\Delta W_A$ and $\Delta W_B$ implies

(13)   $\Delta W_A = \lambda_u\,a^T, \qquad \Delta W_B = \lambda_v\,b^T$

Comparing with the update rule (8) for the linear network in section 4 prompts us to interpret $\lambda_u$ and $\lambda_v$ as discrepancies associated with the multiplier inputs. Multiplying (13) on the right respectively by $a$ and $b$ we also have

(14)   $\Delta W_A\,a = \lambda_u, \qquad \Delta W_B\,b = \lambda_v$

These equations and ones to follow are consistent with the Lagrange multipliers vanishing proportionately with the changes in the weights.

Before we impose stationarity with respect to the other variables, we define a set of approximate vectors associated with implementing the network using the current (unchanged) weights. This mode of evaluating the nodes (vectors) in the network is called a forward pass.

(15)   $\tilde u = W_A\,a, \qquad \tilde v = W_B\,b$
(16)   $\tilde c = W_C\,(\tilde u\circ\tilde v)$

Imposing stationarity with respect to $\lambda_u$ and $\lambda_v$ and comparing the resulting equations with (15) and (14), we obtain

(17)   $u - \tilde u = \lambda_u$
(18)   $v - \tilde v = \lambda_v$

This shows that the discrepancies represented by $\lambda_u$ and $\lambda_v$ are the differences between their true values (after the conservative-learning update) and their forward-pass values.

Learning in some sense starts at the output layer, where the $\tilde c$ of the forward pass is compared with the output $c$ of the training item. To obtain the conservative learning rule for this layer we start by imposing stationarity of $\mathcal{L}$ with respect to $\Delta W_C$:

(19)   $\Delta W_C = \lambda_c\,(u\circ v)^T$

We now make the first of a series of approximations. Extending to the output layer the property of the middle layer, that the pairs $(\lambda_u, \Delta W_A)$ and $(\lambda_v, \Delta W_B)$ have the same order of smallness, we expect $\lambda_c$ also to vanish proportionately with $\Delta W_C$. A good approximation of (19) is then to replace $u$ and $v$ by their forward pass values (the error being higher order),

(20)   $\Delta W_C \approx \lambda_c\,\tilde p^{\,T}$

where we have expressed the result in terms of the forward pass value of the multiplier outputs, $\tilde p = \tilde u\circ\tilde v$. Now imposing stationarity of $\mathcal{L}$ with respect to $\lambda_c$,

(21)   $c = (W_C + \Delta W_C)\,(u\circ v)$

and comparing with (16) we obtain

(22)   $u\circ v \;\approx\; \tilde p + \tilde u\circ\lambda_v + \tilde v\circ\lambda_u$
(23)   $c - \tilde c \;=\; W_C\,(\tilde u\circ\lambda_v + \tilde v\circ\lambda_u) + (\tilde p^{\,T}\tilde p)\,\lambda_c$

where we have neglected second order terms and used (13) and (20).

By imposing stationarity with respect to the two remaining sets of variables, $u$ and $v$, we will obtain equations that relate $\lambda_u$ and $\lambda_v$ to $\lambda_c$, enabling us to cast the equation for the output discrepancy (23) just in terms of the unknown $\lambda_c$:

(24)   $\lambda_u = \bigl((W_C + \Delta W_C)^T\lambda_c\bigr)\circ v, \qquad \lambda_v = \bigl((W_C + \Delta W_C)^T\lambda_c\bigr)\circ u$
(25)   $\lambda_u \approx (W_C^{\,T}\lambda_c)\circ\tilde v, \qquad \lambda_v \approx (W_C^{\,T}\lambda_c)\circ\tilde u$

In (25) we again discarded second order terms. Substituting $\lambda_u$ and $\lambda_v$ from (25) into (23) we arrive at the equation that begins the process of updating the weights:

(26)   $c - \tilde c \;=\; \Bigl[(\tilde p^{\,T}\tilde p)\,\mathbb{1} + W_C\,\mathrm{diag}(\tilde u\circ\tilde u + \tilde v\circ\tilde v)\,W_C^{\,T}\Bigr]\,\lambda_c$

Equation (26) relates the discrepancy, between the true (training) $c$ and the $\tilde c$ of the forward pass, to the Lagrange multipliers $\lambda_c$. Were we to neglect the off-diagonal terms in this linear matrix equation and determine $\lambda_c$ by

(27)   $\lambda_c = \dfrac{c - \tilde c}{\tilde p^{\,T}\tilde p}$

then the process of learning item $(a, b, c)$ would be completely analogous to the usual “back-propagation” scheme. After forward-propagating the inputs with $W_A$, $W_B$ and $W_C$ to determine the discrepancy $c - \tilde c$, the $\lambda_c$ given by (27) is back-propagated (25) by $W_C^{\,T}$ to get $\lambda_u$ and $\lambda_v$. The three Lagrange multipliers then determine the weight changes by (13) and (20).

Lacking an argument for discarding the off-diagonal terms in (26), we need to look for methods to solve this more complex linear equation. The off-diagonal terms correspond to one level of back-propagation followed by forward-propagation. An iterative solution of the linear equation would thus involve multiple backward-forward propagations between just the final two layers of the network. This is not as intimidating as it might seem for two reasons. First, when using the conjugate gradient (CG) method, the number of backward-forward iterations is bounded by the number of components of the vector $c$. Second, given a reasonable initial solution-estimate, CG usually requires only a few iterations in practice.

Since our derivation of the conservative learning rules has been based on the premise that the discrepancies are small, we are keeping with this premise when we take as our initial CG solution-estimate $\lambda_c = 0$. As a better motivated alternative to (27) we apply the smallest non-trivial number of CG iterations, a single one, to this solution-estimate. At this single-iteration level of CG the approximate solution has a simple interpretation. Let

(28)   $M = (\tilde p^{\,T}\tilde p)\,\mathbb{1} + W_C\,\mathrm{diag}(\tilde u\circ\tilde u + \tilde v\circ\tilde v)\,W_C^{\,T}$

be the symmetric positive definite matrix in our linear equation

(29)   $M\,\lambda_c = c - \tilde c$

The single-iteration CG solution has the form $\lambda_c = \gamma\,(c - \tilde c)$, where $\gamma$ is a scalar multiplier determined by projecting equation (29) on $c - \tilde c$:

(30)   $\gamma = \dfrac{(c - \tilde c)^T\,(c - \tilde c)}{(c - \tilde c)^T\,M\,(c - \tilde c)}$

The conservative learning alternative to (27) is therefore

(31)   $\lambda_c = \dfrac{(c - \tilde c)^T\,(c - \tilde c)}{(c - \tilde c)^T\,M\,(c - \tilde c)}\,(c - \tilde c)$

Not surprisingly this reduces to (27) when the off-diagonal terms in $M$ are dropped.

We conclude this appendix with a short summary of the conservative learning rules for updating the network weights when given a new training item $(a, b, c)$.

  1. Forward propagate using (15) and (16) to get $\tilde u$, $\tilde v$, $\tilde p$ and $\tilde c$.

  2. From the output discrepancy $c - \tilde c$ compute $\lambda_c$ using (31). Computing the scalar multiple of $c - \tilde c$ in this expression requires a single backward-forward propagation by the matrix $M$ in (28).

  3. The update of $W_C$ is the rank 1 matrix (20) constructed from $\lambda_c$ and $\tilde p$.

  4. Backward propagate $\lambda_c$ by (25) to obtain $\lambda_u$ and $\lambda_v$.

  5. The updates of $W_A$ and $W_B$ are given by the rank 1 matrices (13) constructed from $\lambda_u$, $\lambda_v$ and $a$, $b$.
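
Collecting the five steps, a compact sketch of one update (our own reading of the equations above; the inputs $a$, $b$ are assumed normalized and the discrepancy $c - \tilde c$ assumed nonzero):

    import numpy as np

    def conservative_update(W_A, W_B, W_C, a, b, c):
        # step 1: forward pass, equations (15)-(16)
        u, v = W_A @ a, W_B @ b
        p = u * v                                  # forward-pass multiplier outputs
        r = c - W_C @ p                            # output discrepancy

        # step 2: lambda_c from (31); computing M r is the single
        # backward-forward propagation by the matrix of (28)
        Mr = (p @ p) * r + W_C @ ((u * u + v * v) * (W_C.T @ r))
        lam_c = (r @ r) / (r @ Mr) * r

        # step 4 (using the pre-update W_C): back-propagate lambda_c, equation (25)
        back = W_C.T @ lam_c
        lam_u, lam_v = back * v, back * u

        # steps 3 and 5: rank 1 weight updates, equations (20) and (13)
        return (W_A + np.outer(lam_u, a),
                W_B + np.outer(lam_v, b),
                W_C + np.outer(lam_c, p))

A full run simply repeats this update over the training stream described in section 5, one item at a time.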

7 Acknowledgements

I thank Cris Moore for instigating this study, Jonathan Yedidia for suggesting the conservative learning method, and Alex Alemi for keeping things competitive. The Simons Foundation and ADI Lyric Labs provided financial support.

References

  • [S] V. Strassen, Gaussian elimination is not optimal, Numer. Math. 13, 354-356 (1969).
  • [H] J. Håstad, Tensor rank is NP-complete, J. Algorithms 11, 644-654 (1990).
  • [B] M. Bläser, On the complexity of the multiplication of matrices of small formats, J. Complexity 19, 43-60 (2003).
  • [JDL] J.D. Laderman, A noncommutative algorithm for multiplying 3×3 matrices using 23 multiplications, Bull. Amer. Math. Soc. 82, 126-128 (1976).
  • [W] S. Winograd, On multiplication of 2×2 matrices, Linear Algebra Appl. 4, 381-388 (1971).
  • [Sch] A. Schönhage, Partial and total matrix multiplication, SIAM J. Comput. 10, 434-455 (1981).
  • [JML] J.M. Landsberg, The border rank of the multiplication of 2×2 matrices is seven, J. Amer. Math. Soc. 19, 447-459 (2006).