Training Multi-Layer Over-Parametrized Neural Network in Subquadratic Time
We consider the problem of training multi-layer over-parametrized neural networks to minimize the empirical risk induced by a loss function. In the typical setting of over-parametrization, the network width m is much larger than the data dimension d and the number of training samples n (m = poly(n, d)), which induces a prohibitively large weight matrix W ∈ ℝ^{m×m} per layer. Naively, one has to pay O(m^2) time to read the weight matrix and evaluate the neural network function in both the forward and backward computation. In this work, we show how to reduce the training cost per iteration. Specifically, we propose a framework that pays m^2 cost only in the initialization phase and achieves a truly subquadratic cost per iteration in terms of m, i.e., m^{2-Ω(1)} per iteration. To obtain this result, we make use of various techniques, including a shifted ReLU-based sparsifier, a lazy low-rank maintenance data structure, fast rectangular matrix multiplication, tensor-based sketching techniques, and preconditioning.
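To illustrate why a shifted ReLU sparsifier can cut per-iteration cost, here is a minimal sketch; it is not the paper's algorithm. With Gaussian weights and a unit-norm input, each pre-activation is distributed as N(0, 1), so a shifted ReLU max(z − b, 0) fires with probability at most exp(−b²/2). The threshold b = sqrt(0.4 · ln m) is a hypothetical illustrative choice that bounds this fraction by m^{-0.2}. When the hidden vector h has only s nonzeros, the next layer's product W2·h needs only the s columns of W2 for active neurons, so it costs O(m·s) instead of O(m^2). The names `shifted_relu`, `dense_matvec`, and `sparse_matvec` are hypothetical helpers. This sketch still computes the first layer's pre-activations densely in O(m^2) and does not attempt to avoid that cost.

```python
import math
import random

def shifted_relu(z, b):
    # Shifted ReLU: a neuron fires only when its pre-activation exceeds b.
    return max(z - b, 0.0)

def layer_forward(W, x, b):
    # Dense pre-activations, O(m^2); this sketch does not try to avoid this step.
    return [shifted_relu(sum(wij * xj for wij, xj in zip(row, x)), b) for row in W]

def dense_matvec(W, h):
    # O(m^2): touches every entry of W regardless of how sparse h is.
    return [sum(wij * hj for wij, hj in zip(row, h)) for row in W]

def sparse_matvec(W, h):
    # O(m * s), s = nnz(h): only the columns of W for active neurons are read.
    active = [j for j, hj in enumerate(h) if hj != 0.0]
    return [sum(row[j] * h[j] for j in active) for row in W], len(active)

m = 400
rng = random.Random(0)
W1 = [[rng.gauss(0.0, 1.0) for _ in range(m)] for _ in range(m)]
W2 = [[rng.gauss(0.0, 1.0) for _ in range(m)] for _ in range(m)]
x = [rng.gauss(0.0, 1.0) for _ in range(m)]
norm = math.sqrt(sum(v * v for v in x))
x = [v / norm for v in x]  # unit-norm input, so each pre-activation is N(0, 1)

b = math.sqrt(0.4 * math.log(m))  # hypothetical threshold: active fraction <= m^(-0.2)
h = layer_forward(W1, x, b)
y_dense = dense_matvec(W2, h)
y_sparse, s = sparse_matvec(W2, h)
print(f"active neurons: {s} / {m}")
```

Both products agree, but the sparse one reads only s of the m columns of W2. Sparsity alone does not make a whole iteration subquadratic. The abstract combines it with the other listed techniques for that.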