1. Introduction
An emerging trend in natural language processing is to train a language model in an unsupervised fashion on a large corpus of text, and then to fine-tune the model for a specific task (Radford et al., 2018; Puri et al., 2018; Devlin et al., 2018). The language model often takes the form of an LSTM (Jozefowicz et al., 2016) or a Transformer network (Vaswani et al., 2017).
These models already contain millions of parameters and will continue to grow even larger. Recently, Yang et al. (2018) demonstrated that the expressiveness of a single Softmax layer was insufficient for the language modeling task. Their proposed solution was the Mixture of Softmax (MoS) layer, which combines several independent Softmax layers together. The number of Softmax layers typically ranges between 3 and 15, so the proposed solution requires significantly more space, especially for larger vocabularies.
Training large-scale models efficiently is a challenging task. Numerous publications describe how to leverage multi-GPU data parallelism and mixed-precision training effectively (Hoffer et al., 2017; Ott et al., 2018; Micikevicius et al., 2018). A key tool for improving training time is to increase the batch size, taking advantage of the massive parallelism provided by GPUs. However, increasing the batch size also requires significant amounts of memory. Often, a practitioner will sacrifice batch size for a larger, more expressive model. For example, Puri et al. (2018) showed that doubling the dimensionality of a multiplicative LSTM (Krause et al., 2016) from 4096 to 8192 forced them to reduce the batch size per GPU.
One culprit that aggravates the memory capacity issue is the auxiliary parameters used by first-order optimization algorithms, which are commonly used to accelerate the convergence rate of the model. Our proposed solution is to compress the auxiliary parameters of the optimizer using the Count-Sketch data structure (Charikar et al., 2002), freeing up memory for either a more expressive model or a larger batch size for faster training.
We primarily focus on compressing the auxiliary variables for the embedding and Softmax layers. These layers contain a significant portion of the model’s parameters, and the set of active features or classes is extremely sparse for many tasks (Spring and Shrivastava, 2017). Consider the language modeling task, where each sentence contains only a few words out of a large vocabulary. There are several algorithms that impose sparsity on the Softmax layer to improve training time. However, memory consumption remains a major challenge. Since the distribution of words follows a power law, Sampled Softmax (Jean et al., 2014) is commonly used to train language models. Shrivastava and Li (2014), Vijayanarasimhan et al. (2014), and Yen et al. (2018a) have proposed using approximate nearest-neighbor search to find the output classes that contain the highest gradients.
Our solution takes advantage of the sparsity present in the Embedding and Softmax layers, so the computational cost scales with the gradient sparsity. We directly insert the sparse gradients into the count-sketch, and then retrieve an approximation of the auxiliary variable. Furthermore, we can easily trade off the capacity of the count-sketch to maintain the optimizer’s performance, without increasing the cost of updating or querying the structure. In Section 5, we formally prove this graceful memory trade-off by analyzing the convergence rate of our count-sketch optimizer.
On the 1-Billion Word dataset, we train an LSTM language model using the Adam optimizer, leveraging our count-sketch technique. By compressing the auxiliary variables for the Embedding and Softmax layers, we reduce the memory usage during training by 25% without any accuracy or performance penalty. For an Amazon extreme classification task with over 49.5 million classes, we reduce the training time by 38% by increasing the mini-batch size 3.5× using our count-sketch optimizer.
2. Count-Sketch and Streaming Setting
In the traditional streaming setting, we are given a high-dimensional vector $x \in \mathbb{R}^n$ that is too costly to store in memory. We only see a very long sequence of updates over time. The only information available at time $t$ is of the form $(i, \Delta)$, which means that coordinate $i$ is updated by the amount $\Delta$. We are given a limited amount of storage, sublinear in $n$, which means that we can never store the entire vector. Sketching algorithms aim to estimate the value of the current item $i$ after any number of updates using only this limited memory.
The Count-Sketch is a popular algorithm for estimation in the streaming setting. Count-Sketch keeps a matrix of bins $S$ of size $v \times w$, where the depth $v$ and the width $w$ are chosen based on the desired accuracy guarantees. The algorithm uses $v$ random hash functions $h_j$ for $j \in \{1, \dots, v\}$ to map the vector’s components to $w$ different bins, $h_j : \{1, \dots, n\} \to \{1, \dots, w\}$. In particular, for any row $j$ of sketch $S$, component $i$ is hashed into bin $S_{j, h_j(i)}$. In addition, Count-Sketch uses $v$ random sign functions $s_j$ to map the components of the vector randomly to $\{+1, -1\}$, $s_j : \{1, \dots, n\} \to \{+1, -1\}$.
The Count-Sketch supports two operations: UPDATE(item $i$, increment $\Delta$) and QUERY(item $i$). The UPDATE operation updates the sketch with any observed increment. More formally, for an increment $\Delta$ to an item $i$, the sketch is updated by adding $s_j(i) \cdot \Delta$ to the cell $S_{j, h_j(i)}$ for every row $j \in \{1, \dots, v\}$. The QUERY operation returns an estimate for component $i$: the median of the $v$ associated counters, each multiplied by its sign, $s_j(i) \cdot S_{j, h_j(i)}$. If the updates are strictly non-negative, the sign functions are dropped and we return the minimum value across all the counters; this variant is the Count-Min Sketch.
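The UPDATE and QUERY operations can be sketched in a few lines of Python. This is a minimal illustration rather than the authors’ implementation: the class name, the `seed` argument, and the use of Python’s built-in `hash` on salted integer tuples (standing in for a proper universal hash family) are all assumptions made for the example.

```python
import random
import statistics


class CountSketch:
    """Toy Count-Sketch: `depth` rows of `width` signed counters."""

    def __init__(self, depth, width, seed=0):
        rng = random.Random(seed)
        self.depth, self.width = depth, width
        self.table = [[0.0] * width for _ in range(depth)]
        # One (bin salt, sign salt) pair per row. Salted tuple hashing is an
        # illustrative stand-in for the random hash and sign functions.
        self.salts = [(rng.getrandbits(32), rng.getrandbits(32))
                      for _ in range(depth)]

    def _bin(self, row, item):
        return hash((self.salts[row][0], item)) % self.width

    def _sign(self, row, item):
        return 1.0 if hash((self.salts[row][1], item)) % 2 == 0 else -1.0

    def update(self, item, delta):
        # Add the signed increment to one counter in every row.
        for j in range(self.depth):
            self.table[j][self._bin(j, item)] += self._sign(j, item) * delta

    def query(self, item):
        # Median of the sign-corrected counters across rows.
        return statistics.median(
            self._sign(j, item) * self.table[j][self._bin(j, item)]
            for j in range(self.depth)
        )
```

Because both operations touch exactly one counter per row, their cost depends on the depth of the sketch and not on the dimensionality of the underlying vector.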
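Count-Sketch Error (Charikar et al., 2002): Let $\hat{x}_i$ be the Count-Sketch estimate of component $i$ from vector $x$. For any component $i$, with probability $1 - \delta$, a Count-Sketch matrix with width $\Theta(1/\epsilon^2)$ and depth $\Theta(\log(1/\delta))$ satisfies:
$$x_i - \epsilon \|x\|_2 \le \hat{x}_i \le x_i + \epsilon \|x\|_2$$
Count-Min Sketch Error (Cormode and Muthukrishnan, 2005): Let $\hat{x}_i$ be the Count-Min Sketch estimate of component $i$ from vector $x$. For any component $i$, with probability $1 - \delta$, a Count-Min Sketch matrix with width $\Theta(1/\epsilon)$ and depth $\Theta(\log(1/\delta))$ satisfies:
$$x_i \le \hat{x}_i \le x_i + \epsilon \|x\|_1$$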
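The one-sided nature of the Count-Min Sketch error can be checked empirically. The sketch below is an illustrative toy under stated assumptions: the class and function names, the salted `hash` calls, and the synthetic skewed data are all hypothetical and not taken from the paper.

```python
import random


class CountMinSketch:
    """Toy Count-Min Sketch for non-negative increments."""

    def __init__(self, depth, width, seed=0):
        rng = random.Random(seed)
        self.width = width
        self.salts = [rng.getrandbits(32) for _ in range(depth)]
        self.table = [[0.0] * width for _ in range(depth)]

    def _bin(self, row, item):
        return hash((self.salts[row], item)) % self.width

    def update(self, item, delta):
        if delta < 0:
            raise ValueError("Count-Min Sketch expects non-negative increments")
        for j, row in enumerate(self.table):
            row[self._bin(j, item)] += delta

    def query(self, item):
        # Collisions can only add mass, so the minimum is the tightest estimate.
        return min(row[self._bin(j, item)] for j, row in enumerate(self.table))


def check_bounds(num_items=500, depth=3, width=32, seed=1):
    """Return (no underestimates?, worst overestimate as a fraction of ||x||_1)."""
    rng = random.Random(seed)
    truth = [rng.random() ** 4 for _ in range(num_items)]  # skewed, non-negative
    cms = CountMinSketch(depth, width)
    for i, value in enumerate(truth):
        cms.update(i, value)
    l1_norm = sum(truth)
    worst = max(cms.query(i) - truth[i] for i in range(num_items))
    never_under = all(cms.query(i) >= truth[i] - 1e-12 for i in range(num_items))
    return never_under, worst / l1_norm
```

The second return value plays the role of an empirical $\epsilon$: widening the sketch should drive it down, while the first value stays true regardless of the width.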
3. Intuition
Our goal is to compress the auxiliary variables without incurring significant accuracy loss. Unfortunately, it is unclear how to select an appropriate compression scheme without additional information on the parameter distribution. The challenge is that the parameter distribution can change over time, so any static assumption on the approximation is likely to hurt accuracy. Fortunately, in this section we show that there is a potential solution.
Power Law in Auxiliary Variables over Time: In Figure 2, we plot the auxiliary variables sorted according to their absolute values during training. To understand the dynamics over time, we show the parameters at two different epochs, 5 and 40. The plots clearly indicate a power-law behavior where only a few parameters have large magnitudes. In Figure 1, we confirm this behavior for every iteration by plotting the midpoint dividing the head and tails. The auxiliary variables have long tails throughout the training process. Also, this behavior is invariant across the two datasets (Wikitext-2 and ImageNet). To the best of our knowledge, this is the first work that empirically shows the existence of a power-law distribution in the gradients and auxiliary variables during training. To dig deeper, we also show the identities of the top-100 parameters (the head of the power-law distribution) for epochs 5, 20, and 40 in Figure 2. The top identities change over time, which makes it difficult to cluster parameters into predefined, static clusters.
Power law and linear sequence of updates: In summary, we need to compress a power-law distribution where the top-k identities are constantly changing. Fortunately, the auxiliary variables are updated in a linear fashion. The updates can be written as a linear operator over updates (see Section 4). The count-sketch is a dynamic, low-memory data structure that preserves high-magnitude parameters accurately while allowing for any sequence of linear updates. The linearity of updates allows us to guarantee that the count-sketch provides an accurate estimation of parameters with high probability at every stage in the iteration. The power-law distribution and linear updates make sketching-based ideas a perfect fit for this problem.
4. Count-Sketch Optimizers
A major chunk of the parameters in the deep network are contained in the fully-connected layers (Han et al., 2015). Fortunately, for the embedding and softmax layers, the set of active features or classes and their corresponding gradient updates are sparse. Our insight is to use the count-sketch data structure to accurately represent the auxiliary variables in a compressed manner. We will insert the sparse gradient information into the count-sketch and retrieve an approximate value for the auxiliary variable whenever needed.
In the deep learning setting, the high-dimensional vector is analogous to the matrices used to represent the auxiliary variables. The auxiliary variables are represented with matrices in $\mathbb{R}^{n \times d}$, where $n$ is the number of features in the embedding layer or the number of classes in the softmax layer. Since the dimensionality of the columns $d$ is usually in the low thousands, we represent the auxiliary variables with a count-sketch tensor in $\mathbb{R}^{v \times w \times d}$ where $v \cdot w \ll n$. This count-sketch tensor preserves structured sparsity where values are read from memory in contiguous chunks along the last dimension of the tensor. See Fig. 3 for a visualization. This tensor structure maintains high performance with GPUs and CPU SIMD vector instructions. On the other hand, the rows are compressed by randomly combining features and classes together.
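The tensor layout described above can be illustrated with a small pure-Python sketch. It is an assumption-laden toy, not the released PyTorch code: the class name, method names, and salted `hash` calls are hypothetical, and nested lists stand in for a contiguous tensor.

```python
import random
import statistics


class CountSketchTensor:
    """Sketch of shape [depth, width, dim] standing in for an [n, dim] matrix."""

    def __init__(self, depth, width, dim, seed=0):
        rng = random.Random(seed)
        self.depth, self.width, self.dim = depth, width, dim
        self.salts = [(rng.getrandbits(32), rng.getrandbits(32))
                      for _ in range(depth)]
        self.bins = [[[0.0] * dim for _ in range(width)] for _ in range(depth)]

    def _locate(self, j, row_id):
        # Hash the row index only; the last dimension is never hashed.
        bin_salt, sign_salt = self.salts[j]
        b = hash((bin_salt, row_id)) % self.width
        s = 1.0 if hash((sign_salt, row_id)) % 2 == 0 else -1.0
        return b, s

    def update(self, row_id, delta):
        """Add the length-`dim` vector `delta` to the sketched row `row_id`."""
        for j in range(self.depth):
            b, s = self._locate(j, row_id)
            cell = self.bins[j][b]
            for k in range(self.dim):
                cell[k] += s * delta[k]

    def query(self, row_id):
        """Element-wise median over the `depth` signed estimates of the row."""
        estimates = []
        for j in range(self.depth):
            b, s = self._locate(j, row_id)
            estimates.append([s * x for x in self.bins[j][b]])
        return [statistics.median(col) for col in zip(*estimates)]

    def num_parameters(self):
        return self.depth * self.width * self.dim
```

Only the row index is hashed, so each update or query reads whole length-`dim` vectors, which is the contiguous access pattern the paragraph above refers to.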
Here is a brief overview of three popular first-order optimizers whose auxiliary variables we seek to compress: Momentum (Sutskever et al., 2013; Polyak, 1964) remembers a history of gradient updates, which smooths out random oscillations and accelerates convergence. Adaptive gradient descent algorithms alter the learning rate for each feature based on the frequency of its updates. Sparse, rare features are given larger updates and a higher learning rate. These methods track a history of squared gradients for each feature. Adagrad (Duchi et al., 2011) divides the gradient by the square root of the cumulative squared gradient. Adam (Kingma and Ba, 2014) combines momentum and adaptive learning rates, so it tracks an exponential average of both the gradients and the squared gradients.
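The count-sketch data structure expects to receive a stream of updates $\Delta$. For the Momentum and Adam optimizers, we need to transform the update operation into a form that is compatible with the count-sketch. For an auxiliary variable $X$, the desired update operation is $X \mathrel{+}= \Delta$. Given the appropriate update operation, we replace the addition assignment operator $\mathrel{+}=$ for the original matrix with the Update-Query operation for the Count-Sketch Tensor.
For Momentum, the update rule, given some gradient $g_t$ and decay rate $\gamma$, is $m_t = \gamma \, m_{t-1} + g_t$, which is equivalent to $m_t \mathrel{+}= (\gamma - 1) \, m_{t-1} + g_t$. For the Adam optimizer, given some constant $c$ and an update $\Delta$, the update rule for the exponential moving average is $x_t = c \, x_{t-1} + (1 - c) \, \Delta$, which is equivalent to $x_t \mathrel{+}= (1 - c)(\Delta - x_{t-1})$.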
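The rewriting of both rules into additive form can be verified numerically. The following sketch uses hypothetical function names and scalar values, and assumes the standard forms of the two rules: momentum as decay times previous value plus gradient, and the Adam averages as an exponential moving average. It checks that adding the increment to the previous value reproduces the direct update.

```python
def momentum_direct(m_prev, grad, gamma):
    # Standard form: new momentum computed from scratch.
    return gamma * m_prev + grad


def momentum_increment(m_prev, grad, gamma):
    # Additive form: the quantity to add so that m_prev + increment == m_t.
    return (gamma - 1.0) * m_prev + grad


def ema_direct(x_prev, update, c):
    # Exponential moving average with decay constant c.
    return c * x_prev + (1.0 - c) * update


def ema_increment(x_prev, update, c):
    # Additive form of the same moving average.
    return (1.0 - c) * (update - x_prev)
```

In a sketched optimizer, `m_prev` or `x_prev` would be the value returned by a query, and the increment would be passed to the update operation. This is why the two operations are paired.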
The Count-Sketch is essentially a plug-and-play replacement that saves memory, while retaining the speed and accuracy of the original matrix. Normally, algorithms that compress memory to save space are slower than their dense counterparts. However, the count-sketch can leverage sparsity by lazily performing updates with high efficiency. In addition, we can gracefully increase the size of the count-sketch for greater accuracy with minimal additional computational cost.
Count-Min Sketch Cleaning Heuristic:
Since the Count-Min Sketch only accepts non-negative values, it can only overestimate the desired value. The Count-Min Sketch is used to estimate the adaptive learning rate for the Adagrad and Adam optimizers. Therefore, an overestimate will prematurely slow the learning rate for certain elements. Our heuristic solution is to clean the sketch periodically by multiplying the tensor by a constant $\alpha$, where $0 \le \alpha \le 1$, every $C$ iterations. Instead of this heuristic, an alternative is to use principled adaptive sketches (Shrivastava et al., 2016), which can continuously clean the sketch and decay the overestimates over time.
Periodic cleaning works well with the Count-Min Sketch because it provides a better estimate for the top elements. During training, the accumulation of updates allows the heavy-hitter estimates to emerge in the sketch (Aghazadeh et al., 2018). Due to stochastic gradient descent, there is a certain amount of noise in the gradient, so cleaning immediately after each update destroys the internal state of the sketch. Furthermore, cleaning reduces the scale of the sketch, reducing the overall noise level. If the signal-to-noise ratio is too low, future heavy hitters are ignored because their values are comparable to the noise in the sketch.
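The mechanics of periodic cleaning can be sketched as follows. This toy loop is only an illustration: the function names, the synthetic Gaussian "gradients", and the default scaling constant are assumptions, and the interval of 125 iterations is borrowed from the MegaFace experiment described later in the paper.

```python
import random


def make_sketch(depth, width):
    return [[0.0] * width for _ in range(depth)]


def bins_for(item, salts, width):
    # One bin per row; salted hashing stands in for the random hash functions.
    return [hash((salt, item)) % width for salt in salts]


def run(num_steps, clean_every=None, alpha=0.5, depth=3, width=8, seed=0):
    """Accumulate squared 'gradients' in a Count-Min sketch, optionally cleaning.

    Returns the total mass left in the sketch after `num_steps` updates.
    """
    rng = random.Random(seed)
    salts = [rng.getrandbits(32) for _ in range(depth)]
    sketch = make_sketch(depth, width)
    for step in range(1, num_steps + 1):
        item = rng.randrange(1000)           # a sparse update touches one item
        grad = rng.gauss(0.0, 1.0)
        for j, b in enumerate(bins_for(item, salts, width)):
            sketch[j][b] += grad * grad      # non-negative increment
        if clean_every and step % clean_every == 0:
            for row in sketch:               # cleaning: rescale every counter
                for b in range(width):
                    row[b] *= alpha
    return sum(sum(row) for row in sketch)
```

Comparing `run(500, clean_every=125)` with `run(500)` shows how cleaning shrinks the accumulated mass, and with it the collision noise that inflates every estimate.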
5. Theoretical Analysis
For stochastic non-convex optimization (Zaheer et al., 2018), we measure how the algorithm converges to a stationary point at iteration $t$, i.e., $\|\nabla f(x_t)\|^2 \le c$ for some small constant $c$. In our analysis, we focus on the Count-Min Sketch Adam optimizer where we do not track the 1st moment, i.e., $\beta_1 = 0$. This optimizer was used in the Amazon Extreme Classification task (see Section 7.3) in order to save additional memory, similar to the Adafactor optimizer (Shazeer and Stern, 2018).
We assume that the objective function is smooth with bounded gradients. In addition, we receive an unbiased stochastic gradient estimate with fixed variance. Then, the following theorem holds:
Theorem 5.1.
Let the learning rate . Assume , , and are selected such that and . Given a CountMin Sketch matrix with width and depth , we have the following bound that holds for CountMin Sketch Adam with probability where :
The proof of Theorem 5.1 is found in the Appendix. For comparison, we have the convergence bound from (Zaheer et al., 2018) for the standard Adam optimizer where :
Discussion: The bounds are similar except for the additional term caused by the Count-Min Sketch approximation. The theorem states that the Count-Min Sketch Adam converges to a region around a stationary point. The additional error term depends on the adaptivity of the optimizer, the error rate of the sketch, and the gradient norm. The error rate is inversely proportional to the width of the sketch and corresponds to the number of collisions along each row in the sketch. We can improve convergence gracefully by increasing the sketch’s width, which reduces the error caused when multiple components collide in the same bin. In practice, we bound the gradient norm by a reasonable constant to prevent instability. When the sketch is sufficiently wide, the error term becomes a small constant.
Note that the gradient norm decreases over time. Thus, the error caused by the count-sketch approximation decreases as the algorithm progresses, and we can shrink the sketch. A nice property of the count-sketch data structure is that you can add one half of the sketch to the other, reducing its size by half while maintaining its accuracy guarantees. Please see Matusevych et al. (2012) for more details.
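The folding property can be demonstrated on a single sketch row. The example below is a minimal sketch under one explicit assumption that is not stated in the text: bins are chosen as a raw hash value modulo the width, so halving an even width keeps every item addressable. All function names and the sample updates are hypothetical.

```python
def bin_index(raw_hash, width):
    # Assumption: bins are raw_hash mod width, so for even widths
    # (raw_hash mod width) mod (width / 2) == raw_hash mod (width / 2).
    return raw_hash % width


def build_row(updates, width):
    """Apply (raw_hash, sign, delta) updates to a fresh row of `width` bins."""
    row = [0.0] * width
    for raw_hash, sign, delta in updates:
        row[bin_index(raw_hash, width)] += sign * delta
    return row


def fold_in_half(row):
    """Add the upper half of a sketch row onto the lower half."""
    half = len(row) // 2
    if len(row) != 2 * half:
        raise ValueError("row width must be even")
    return [row[b] + row[b + half] for b in range(half)]
```

Under this assumption, folding a row of width 8 gives the same counters as building a row of width 4 directly from the same updates, so the smaller sketch behaves as if it had been that size from the start.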
The probability of exceeding the Count-Min Sketch error bound decreases exponentially with the depth of the sketch. In our theoretical results, the depth of the sketch depends logarithmically on the number of parameters and the number of time steps. However, our experiments show that a modest depth of 3-5 is sufficient.
6. Related Work
Feature Compression: A straightforward option is to use dimensionality reduction techniques to minimize the number of features, which in turn decreases the size of the model and optimizer simultaneously. Tito Svenstrup et al. (2017) describe a hash embedding scheme where the output embedding for a feature is a weighted sum between the embedding vectors and the weight vector. Their goal was to minimize the size of the embedding layer while preserving its flexibility to model large vocabularies. However, dramatically reducing the feature space may sacrifice model accuracy. For example, training the BERT language model (Devlin et al., 2018) on a GPU with 12-16 GB of memory requires a smaller, less effective architecture than the full-sized model trained on the 64 GB Google TPU.
Gradient Checkpointing: Siskind and Pearlmutter (2018) and Chen et al. (2016) describe an orthogonal approach where training an $N$-layer neural network requires $O(\sqrt{N})$ memory. Their insight was that storing the activations for the backpropagation pass is the most memory-intensive part of training. Instead of storing all the activations, their algorithm checkpoints certain sections of the neural network and lazily recomputes the activations during the backpropagation phase. In other words, their approach saves memory at the cost of extra computation time.
Low-Rank Approximation: A low-rank approximation has the potential to reduce the number of parameters from $O(nd)$ to $O(nr + rd)$ where $r \ll \min(n, d)$. However, updating the low-rank matrices is non-trivial. Shazeer and Stern (2018) demonstrated that there exists a unique, fast update rule for a rank-1 approximation that minimizes the I-divergence between the approximation and the original matrix. Their rank-1 approximation was limited to non-negative matrices, so only the second moment of the Adam optimizer was compressed in their experiments. The drawback of this approach is that it requires materializing the entire matrix via an outer product, which is prohibitive for large-scale embedding and softmax layers. In addition, since their update rule only applies to rank-1 vectors, their approach lacks the flexibility to increase the model’s memory capacity gracefully.
Count-Sketch: The original objective of the Count-Sketch data structure was to estimate the frequency of various events in the streaming setting. Recently, Aghazadeh et al. (2018) and Tai et al. (2018) demonstrated that the Count-Sketch can learn a compressed model that accurately preserves the features with the largest weights. Their objective focused on feature extraction in ultra-high dimensional settings and was limited to simple, linear models. In this work, we seek to use the Count-Sketch to preserve the different auxiliary variables maintained by commonly used first-order optimizers. The ideal solution is for the memory cost of the optimizer to grow sublinearly with the model size, giving us the flexibility to increase the model’s capacity.
Type  CountSketch  LowRank 

Memory  
Gradient Type  Sparse  Dense 
Memory Control  Flexible  Fixed 
Query Time 
7. Experiments
All of the experiments were performed with the PyTorch framework on a single machine with 2x Intel Xeon E5-2660 v4 processors (28 cores / 56 threads), 512 GB of memory, and a single Nvidia Tesla V100. The code for the Count-Sketch Optimizer is available online (https://github.com/rdspring1/CountSketchOptimizers). We designed the experiments to answer these questions:
Do the model’s gradients and the optimizer’s auxiliary variables follow a power-law distribution?
How accurate is our estimate of the auxiliary variables retrieved from the count-sketch data structure?
What is the effect of cleaning the count-min sketch on convergence time and accuracy?
How well does our count-sketch optimizer compare against the low-rank approximation given the same number of parameters?
Does our count-sketch optimizer match the original baseline in terms of speed and accuracy?
Here are the five datasets used in the experiments:

Wikitext-2 (Merity et al., 2016): This dataset was extracted from Wikipedia and contains 2M training tokens with a vocabulary size of 33,278. (10.8 MB)
Wikitext-103 (Merity et al., 2016): A larger version of the Wikitext-2 dataset that contains 103M training tokens with a vocabulary size of 267,735. (539.2 MB)
1-Billion Word (LM1B) (Chelba et al., 2013): This large-scale corpus contains 0.8 billion training tokens and a vocabulary of 793,471 words. (4.1 GB) An open-sourced PyTorch model is available online (https://github.com/rdspring1/PyTorch_GBW_LM).
MegaFace: A facial recognition dataset derived from MegaFace (Challenge 2) (http://megaface.cs.washington.edu/). Each person is a candidate class, but we only select classes with at least 10 images. Thus, this sampled dataset contains 1,943,802 examples with 80,204 classes. 10K images are randomly sampled to create the test dataset. (4 GB)
Amazon: This sampled recommendation dataset contains 70.3 million examples and over 49.5 million object classes. (20.9 GB)
We implemented the following approaches to compare and contrast against our approach:

Non-Negative Matrix Factorization (NMF) Rank-1: This decomposition minimizes the I-divergence between the auxiliary variable and the approximation formed from two rank-1 vectors. However, it is limited to non-negative matrices, so it cannot compress the auxiliary variables for Momentum or the 1st Moment of Adam (Shazeer and Stern, 2018).
Rank-1: After each update, we perform a singular value decomposition (SVD) of the auxiliary variable, and only keep the top singular value and its corresponding vectors. During the subsequent update, the auxiliary variable is reconstructed via an outer product. Unlike the NMF Rank-1 approximation, this approach is not limited to non-negative values, but it is extremely slow and cannot be used in practice.
Count-Sketch: As described in Section 4. This approach is also not limited to non-negative values and is capable of compressing the auxiliary variables for all optimizers efficiently.
Title  Symbol
Count-Sketch  CS
Low-Rank  LR
Adam 1st Moment  M
Adam 2nd Moment  V
Non-Negative Matrix Factorization  NMF
7.1. Small-Scale Experiments
Wikitext-2: The language model is a 2-layer LSTM with 672 hidden units. The dimensionality of the word embeddings is equal to the number of hidden units. The model is unrolled for 35 steps of backpropagation through time (BPTT). The model is regularized via Dropout with a 50% chance of disabling a unit. We train the model for 40 epochs with a mini-batch size of 20. For Momentum, the learning rate is 2.5, the decay rate is 0.9, and we clip the gradient norm to 0.25. For Adam, the learning rate is 0.001, the beta values are (0.9, 0.999), and the gradient norm is clipped to 1. We reduce the learning rate whenever the validation error plateaus. We use the full softmax layer, so only the embedding layer is sparse for this dataset.
Norm Approximation Error: Fig. 4 shows the norm of the difference between the approximation and the original auxiliary variable over several training iterations. The left figure is for the Momentum optimizer, while the right figure is for the 2nd Moment of the Adam optimizer. All of the methods are given a roughly equal number of parameters to approximate the original auxiliary variable. For the Wikitext-2 dataset, the embedding and softmax layers use [33,278, 256] matrices. Therefore, the rank-1 decomposition uses two vectors with 33,278 + 256 = 33,534 parameters. The count-sketch data structure is represented with a [3, 16, 672] tensor, containing 32,256 parameters. Our count-sketch approach maps the 33,278-word vocabulary into 16 distinct bins, so there are about 2,080 collisions for each bucket.
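The parameter budget in the paragraph above is simple arithmetic and can be reproduced directly. The variable names below are illustrative; the numbers are the ones quoted in the text.

```python
vocab, rank1_cols = 33_278, 256        # matrix shape quoted in the text
depth, width, dim = 3, 16, 672         # count-sketch tensor shape quoted in the text

rank1_params = vocab + rank1_cols      # one column vector plus one row vector
sketch_params = depth * width * dim    # every cell of the sketch tensor
rows_per_bin = vocab / width           # average vocabulary rows sharing a bin

print(rank1_params, sketch_params, round(rows_per_bin))
```

Both approximations therefore work with roughly 33K parameters, which is what makes the error comparison a like-for-like one.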
The Adam optimizer’s 2nd Moment is strictly non-negative and is suitable for the NMF Rank-1 approximation. For the Momentum variable, we supplement the NMF decomposition with the SVD decomposition. The SVD decomposition maintains a good approximation of the Momentum variable. However, it is extremely slow during training, so we only show the approximation error for the first epoch of training. As expected, the NMF Rank-1 baseline poorly approximates the momentum variable, which is not strictly non-negative. It experiences significant variance in its approximation quality. The Count-Sketch is a consistent estimator for both variables, with slightly more error.
Test Perplexity: Tables 3 and 4 show the test perplexity after training the model with the Momentum and Adam optimizers. For the Momentum optimizer, the NMF low-rank approximation performs poorly, reinforcing the results from Fig. 4. When only the 2nd moment is compressed, the NMF low-rank and Count-Sketch approximations have negligible differences. When we compress both the 1st and 2nd moments with the Count-Sketch, there is some minor accuracy loss from the original optimizer.
Momentum  CS  LR-NMF
94.25  95.93  176.31

CS-MV  Adam  CS-V  LR-NMF-V
109.24  105.14  106.32  106.21
MegaFace: For this experiment, we obtain pretrained embeddings of size 512 from the FaceNet architecture (Schroff et al., 2015) trained on the MS-Celeb-1M dataset (https://github.com/davidsandberg/facenet). Afterwards, we train a softmax classifier on the MegaFace dataset using LSH Sampling (Yen et al., 2018b; Vijayanarasimhan et al., 2014). For LSH Sampling, we use SimHash, i.e., Signed Random Projection (SRP), with K=15 bits per hash fingerprint. There are L=16 hash tables that are rebuilt every 250 iterations. For Adam, the learning rate is 0.001 and the beta values are (0.9, 0.999). For Adagrad, the learning rate is 0.1. All the models were trained for 10 epochs.
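For readers unfamiliar with SimHash, the following sketch shows how signed random projections produce a K-bit fingerprint and how L tables yield candidate classes. It is a generic illustration of the technique, not the sampling code used in the experiments. All function names are hypothetical, and Gaussian projection planes are an assumption.

```python
import random


def make_projections(num_bits, dim, seed=0):
    """Draw `num_bits` random hyperplanes in `dim` dimensions."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_bits)]


def simhash(vector, projections):
    """Pack the signs of the random projections into an integer fingerprint."""
    code = 0
    for plane in projections:
        dot = sum(p * x for p, x in zip(plane, vector))
        code = (code << 1) | (1 if dot >= 0 else 0)
    return code


def build_tables(vectors, num_tables, num_bits, dim, seed=0):
    """Index every vector in `num_tables` independent hash tables."""
    tables = []
    for t in range(num_tables):
        planes = make_projections(num_bits, dim, seed + t)
        buckets = {}
        for idx, vec in enumerate(vectors):
            buckets.setdefault(simhash(vec, planes), []).append(idx)
        tables.append((planes, buckets))
    return tables


def candidates(query, tables):
    """Union of the buckets that the query falls into across all tables."""
    found = set()
    for planes, buckets in tables:
        found.update(buckets.get(simhash(query, planes), []))
    return found
```

With `num_bits=15` and `num_tables=16`, matching the K and L quoted above, the candidate set contains the classes whose weight vectors point in a similar direction to the query. Rebuilding the tables periodically accounts for the weights changing during training.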
Fig. 5 shows the effect of cleaning the Count-Min Sketch Tensor on its corresponding optimizer. We measure how the testing accuracy, convergence rate, and auxiliary variable error change because of cleaning for the Adam and Adagrad optimizers. The Count-Min Sketch tensor is a fraction of the original variable’s size. For Adam, the cleaning scheme multiplies the count-min sketch by a constant every 125 iterations. For Adagrad, the rate of cleaning is the same, but a different constant is used.
For both Adam and Adagrad, there is a noticeable drop in norm error with cleaning, which reflects positively in terms of test accuracy and convergence. For Adam, the count-sketch optimizer with cleaning closely matches the convergence rate of the baseline and slightly surpasses its test accuracy. The test accuracy for Count-Sketch with cleaning is 69.4%, while the baseline is 69.03%. For Adagrad, cleaning did not improve the initial convergence rate, but allowed the final test accuracy to match the baseline. Cleaning yields a solid improvement in test accuracy for the Count-Sketch Adagrad optimizer.
Given that the Adam optimizer already contains an exponential decay term, it is surprising that cleaning is necessary. However, despite further hyperparameter tuning, the count-sketch optimizer with cleaning still achieves the best performance. For dense gradients, the decay term is applied to all elements. Since the gradients are sparse, only the non-zero elements are updated. Thus, the decay is applied in an irregular fashion for the elements in the sketch.
7.2. Large-Scale Language Model
Since the Wikitext-103 and LM1B datasets have large vocabularies, we use Sampled Softmax (Jean et al., 2014) to induce sparsity in the softmax layer and for faster training. Each Count-Sketch Tensor is 5× smaller than the original variable. Therefore, there are at least 15 collisions for each bin on average.
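The collision count and compression ratio can be checked against the vocabulary sizes listed in the dataset descriptions and the tensor shapes reported in the results for this section ([3, 17,849, 256] for Wikitext-103 and [3, 52,898, 256] for LM1B). The helper name below is hypothetical. It assumes the original auxiliary variable has one row per vocabulary word and the same last dimension as the sketch.

```python
def sketch_stats(vocab_size, depth, width):
    """Return (average rows per bin in one sketch row, overall compression)."""
    rows_per_bin = vocab_size / width            # collisions per bin, per row
    compression = vocab_size / (depth * width)   # original rows / sketch rows
    return rows_per_bin, compression


wikitext103 = sketch_stats(267_735, depth=3, width=17_849)
lm1b = sketch_stats(793_471, depth=3, width=52_898)
print(wikitext103, lm1b)
```

Both configurations use the same ratio: about 15 vocabulary rows share each bin, and three sketch rows together hold about one fifth of the original parameters.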
Adagrad on Wikitext-103: Our language model is a single-layer LSTM with 1024 hidden units. The dimensionality of the word embeddings is 256 and we use a projection layer between the LSTM and Softmax layers. The model is unrolled for 35 steps of BPTT. The model is regularized via Dropout. We train the model for 25 epochs with a mini-batch size of 1024. For the Adagrad optimizer, the gradient norm is clipped to 0.1, and the learning rate starts at 0.4 and decays linearly to 0 during training.
Results: For the Wikitext-103 dataset, we allocated a [3, 17,849, 256] Count-Sketch tensor for each auxiliary variable. By providing the Count-Sketch with more parameters, our method has notably better test accuracy than the NMF low-rank approximation while using only slightly more memory. In addition, despite using more parameters than the low-rank approximation, the count-sketch optimizer is still somewhat faster. Finally, the low-rank approximation fails to meet the same accuracy as the original baseline, while, surprisingly, the count-sketch optimizer has the best test perplexity.
Metric  Adagrad  CS  LR-NMF
Time  6.4  6.6  6.7
Size  10,625  10,089  10,077
Test Perplexity  57.63  56.07  58.27
Adam on LM1B: For the 1-Billion Word dataset, our goal is to mimic multi-GPU distributed training on a single GPU. The original batch size is 128 with a learning rate of 5e-4. By increasing our batch size from 128 to 1024, we scale our learning rate linearly by 8× (Goyal et al., 2017). In addition, we decay our learning rate linearly to zero over 5 training epochs. We double the LSTM size from 1024 to 2048, but keep the word embedding size at 256. The model is unrolled for 20 steps of BPTT. Dropout is kept at a nominal rate and the gradient norm is clipped to 1. A surprising side effect of increasing the batch size was that we roughly halved our training time, from 12.25 hours to 6.25 hours per epoch, despite using a single GPU.
Results: For the 1-Billion Word dataset, we allocated a [3, 52,898, 256] Count-Sketch tensor for each auxiliary variable. Our primary comparison is only with the 2nd moment because the NMF low-rank approximation is not applicable to the 1st moment. The count-sketch is slightly more accurate than the low-rank approximation. When both the 1st and 2nd moments are compressed with the count-sketch tensor, its accuracy is on par with the low-rank approximation that compresses only the 2nd moment. In general, the count-sketch tensor is faster than the low-rank approach while using substantially less GPU memory. For large matrices, there is a noticeable cost to reconstructing the entire matrix in order to update only a sparse subset of values.
Metric  CS-MV  Adam  CS-V  LR-NMF-V
Time  27.1  26.4  26.75  29.2
Size  8,591  11,707  10,167  13,259

Epoch  CS-MV  Adam  CS-V  LR-NMF-V
1  50.78  48.48  49.49  50.04
2  46.08  45.34  45.22  45.60
3  43.71  42.79  42.95  43.55
4  41.82  41.15  41.23  41.82
5  40.55  39.90  39.88  40.41
7.3. Extreme Classification
For the extremely large-scale classification task, we conducted our experiments on an Amazon recommendation dataset. The task is to predict an object out of over 49 million classes given a query. The text query is parsed into trigram features. Feature hashing is applied to convert the strings into integers. The input feature dimension is 80K. On average, there are 30 non-zero features per query, so the input layer is very sparse and suitable for our Count-Sketch optimizer. We trained a single-hidden-layer, fully-connected neural network with an embedding dimension of 1024.
A traditional softmax classifier would require over 200 GB of memory, which is well beyond the memory capacity of the largest GPUs. Instead, we leverage a novel approach for extreme classification called Merged-Averaged Classifiers via Hashing (MACH) (Huang et al., 2018). This algorithm randomly merges the output classes into a manageable number of coarse-grained meta-classes via universal hashing. Several independent, fully-connected neural networks are trained to solve this meta-class classification task. Each meta-classifier is associated with a unique hash function that creates a distinct class mapping. At inference time, we recover the scores for the original classes by aggregating the meta-class scores assigned to the original output class. For this experiment, we used 20K meta-classes in the output layer of each meta-classifier. For high-accuracy models, we use 32 meta-classifiers. Each individual meta-classifier required 414 MB of memory for a total of 12.95 GB. Therefore, our ensemble MACH classifier used far less memory (12.95 GB versus over 200 GB) than a monolithic softmax classifier.
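The meta-class merging and score aggregation can be sketched as follows. This is a toy illustration of the idea as described in the paragraph above, not the MACH implementation. The function names are hypothetical, salted `hash` calls stand in for universal hashing, and averaging the meta-class scores is an assumed aggregation rule.

```python
import random


def make_hashes(num_classifiers, num_meta, seed=0):
    """One class-to-meta-class mapping per meta-classifier."""
    rng = random.Random(seed)
    salts = [rng.getrandbits(32) for _ in range(num_classifiers)]
    # Bind each salt as a default argument so every lambda keeps its own.
    return [lambda c, s=s: hash((s, c)) % num_meta for s in salts]


def aggregate(meta_scores, hashes, class_id):
    """Average the meta-class scores that `class_id` maps to in each classifier."""
    total = sum(scores[h(class_id)] for scores, h in zip(meta_scores, hashes))
    return total / len(hashes)


def one_hot_scores(hashes, num_meta, true_class):
    """Idealized meta-classifier outputs that are certain about `true_class`."""
    all_scores = []
    for h in hashes:
        scores = [0.0] * num_meta
        scores[h(true_class)] = 1.0
        all_scores.append(scores)
    return all_scores
```

A wrong class reaches the maximum score only if it collides with the true class under every hash function, which becomes less likely as meta-classifiers are added. This is the sense in which several coarse classifiers can recover fine-grained predictions.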
Since we are primarily interested in faster training times, we limit ourselves to 4 metaclassifiers in this experiment. For our baseline, each metaclassifier is trained using the Adam optimizer with a batch size of 750. Given these settings, a single metaclassifier takes 4 GB of GPU memory, allowing us to train 4 models in parallel on a single GPU. For maximum memory savings, we eliminate the 1st moment and use a countmin sketch tensor of size [3, 266, 1024] for the 2nd moment (1% of original size). By using the Adam CountSketch optimizer, we reduce the memory cost for each model from 4 GB to 2.6 GB (45% smaller). We take of advantage of this extra memory by increasing the batch size from 750 to 2600 (3.5 larger). As a result, the running time per epoch decreased from 5.32 hours to 3.3 hours (38% faster).
We measure the accuracy of the MACH model using the Recall@100 metric on a test dataset containing 20K queries. First, we evaluate the meta-classifiers and aggregate their scores. Then, we check how often the target class appears within the top 100 scores generated by the classifier. A major bottleneck during evaluation is sorting the 49.5 million classes to find the top 100 scores. Since we are only comparing the model’s relative performance and are interested in fast running times, we down-sample the scores from 49.5 million to 1 million. The class subset contains the target classes for all 20K test queries and a random sample of the remaining classes. Given 16 meta-classifiers, the Adam baseline has a 0.6881 recall, while the Count-Sketch optimizer achieves a 0.6889 recall.
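The evaluation procedure can be sketched as follows. The helper names `recall_at_k` and `sample_candidates` are hypothetical, the sizes are scaled down, and a heap-based partial selection stands in for whatever top-k routine was actually used.

```python
import heapq
import random


def recall_at_k(scores, target, k=100):
    """Return 1 if target is among the k highest-scoring candidate classes.

    scores maps each candidate class id to its aggregated score.
    """
    top_k = heapq.nlargest(k, scores, key=scores.get)
    return int(target in top_k)


def sample_candidates(num_classes, targets, sample_size, rng):
    """Keep every test target plus random other classes, sample_size in total."""
    candidates = set(targets)
    while len(candidates) < sample_size:
        candidates.add(rng.randrange(num_classes))
    return candidates


rng = random.Random(0)
candidates = sample_candidates(10_000, targets=[5, 17], sample_size=500, rng=rng)
toy_scores = {c: rng.random() for c in candidates}  # placeholder scores
print(recall_at_k(toy_scores, target=5))
```

Because every target class is kept in the candidate set, down-sampling can only remove competitors, so the reported recall is an optimistic but consistent basis for comparing two optimizers.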
| Type | Batch Size | Epoch Time (hours) | Recall@100 |
|------|------------|--------------------|------------|
| Adam | 750        | 5.32               | 0.4704     |
| CS-V | 2600       | 3.3                | 0.4789     |
8. Conclusion and Future Work
In this paper, we present the concept of a count-sketch tensor to compress the auxiliary variables associated with popular first-order optimizers. The count-sketch tensor retains the constant-time update and query operations, while maintaining structured sparsity for high-speed vectorized operations. The count-sketch tensor can reduce the memory usage of large-scale models with minimal cost by taking advantage of the model’s sparsity. Going forward, we are interested in compressing the auxiliary variables associated with the hidden layers without incurring any performance penalty. We hope to leverage recent ideas of adding sparsity to the hidden layers in order to increase the size of the model without increasing its computational cost (Spring and Shrivastava, 2017; Shazeer et al., 2017; Wen et al., 2017). Structured sparsity in the hidden layers would mesh well with our current approach for the Embedding and Softmax layers.
References
 Aghazadeh et al. (2018) Amirali Aghazadeh, Ryan Spring, Daniel Lejeune, Gautam Dasarathy, Anshumali Shrivastava, and Richard Baraniuk. 2018. MISSION: Ultra Large-Scale Feature Selection using Count-Sketches. In Proceedings of the 35th International Conference on Machine Learning (Proceedings of Machine Learning Research), Jennifer Dy and Andreas Krause (Eds.), Vol. 80. PMLR, Stockholmsmässan, Stockholm Sweden, 80–88. http://proceedings.mlr.press/v80/aghazadeh18a.html
 Charikar et al. (2002) M. Charikar, K. Chen, and M. Farach-Colton. 2002. Finding frequent items in data streams. In Intl. Colloquium on Automata, Languages, and Programming. Springer, 693–703.
 Chelba et al. (2013) Ciprian Chelba, Tomas Mikolov, Mike Schuster, Qi Ge, Thorsten Brants, Phillipp Koehn, and Tony Robinson. 2013. One billion word benchmark for measuring progress in statistical language modeling. arXiv preprint arXiv:1312.3005 (2013).
 Chen et al. (2016) Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. 2016. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174 (2016).
 Cormode and Muthukrishnan (2005) Graham Cormode and Shan Muthukrishnan. 2005. An improved data stream summary: the count-min sketch and its applications. Journal of Algorithms 55, 1 (2005), 58–75.
 Devlin et al. (2018) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv preprint arXiv:1810.04805 (2018).
 Duchi et al. (2011) John Duchi, Elad Hazan, and Yoram Singer. 2011. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research 12, Jul (2011), 2121–2159.
 Goyal et al. (2017) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. 2017. Accurate, large minibatch SGD: training imagenet in 1 hour. arXiv preprint arXiv:1706.02677 (2017).
 Han et al. (2015) Song Han, Huizi Mao, and William J Dally. 2015. Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding. arXiv preprint arXiv:1510.00149 (2015).
 Hoffer et al. (2017) Elad Hoffer, Itay Hubara, and Daniel Soudry. 2017. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.). Curran Associates, Inc., 1731–1741.
 Huang et al. (2018) Qixuan Huang, Yiqiu Wang, Tharun Medini, and Anshumali Shrivastava. 2018. Extreme Classification in Log Memory. arXiv preprint arXiv:1810.04254 (2018).
 Jean et al. (2014) Sébastien Jean, Kyunghyun Cho, Roland Memisevic, and Yoshua Bengio. 2014. On using very large target vocabulary for neural machine translation. arXiv preprint arXiv:1412.2007 (2014).
 Jozefowicz et al. (2016) Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, and Yonghui Wu. 2016. Exploring the limits of language modeling. (2016). https://arxiv.org/pdf/1602.02410.pdf
 Kingma and Ba (2014) Diederik P Kingma and Jimmy Ba. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014).
 Krause et al. (2016) Ben Krause, Liang Lu, Iain Murray, and Steve Renals. 2016. Multiplicative LSTM for sequence modelling. arXiv preprint arXiv:1609.07959 (2016).
 Matusevych et al. (2012) Sergiy Matusevych, Alex Smola, and Amr Ahmed. 2012. Hokusai-sketching streams in real time. arXiv preprint arXiv:1210.4891 (2012).
 Merity et al. (2016) Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. 2016. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843 (2016).
 Micikevicius et al. (2018) Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, and Hao Wu. 2018. Mixed Precision Training. In International Conference on Learning Representations. https://openreview.net/forum?id=r1gs9JgRZ
 Ott et al. (2018) Myle Ott, Sergey Edunov, David Grangier, and Michael Auli. 2018. Scaling Neural Machine Translation. arXiv preprint arXiv:1806.00187 (2018).
 Polyak (1964) Boris T Polyak. 1964. Some methods of speeding up the convergence of iteration methods. U. S. S. R. Comput. Math. and Math. Phys. 4, 5 (1964), 1–17.
 Puri et al. (2018) Raul Puri, Robert Kirby, Nikolai Yakovenko, and Bryan Catanzaro. 2018. Large Scale Language Modeling: Converging on 40GB of Text in Four Hours. arXiv preprint arXiv:1808.01371 (2018).
 Radford et al. (2018) Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. 2018. Improving language understanding by generative pre-training. Online (2018).

 Schroff et al. (2015) Florian Schroff, Dmitry Kalenichenko, and James Philbin. 2015. Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition. 815–823.
 Shazeer et al. (2017) Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. 2017. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538 (2017).
 Shazeer and Stern (2018) Noam Shazeer and Mitchell Stern. 2018. Adafactor: Adaptive Learning Rates with Sublinear Memory Cost. In Proceedings of the 35th International Conference on Machine Learning (Proceedings of Machine Learning Research), Jennifer Dy and Andreas Krause (Eds.), Vol. 80. PMLR, Stockholmsmässan, Stockholm Sweden, 4596–4604. http://proceedings.mlr.press/v80/shazeer18a.html
 Shrivastava et al. (2016) Anshumali Shrivastava, Arnd Christian Konig, and Mikhail Bilenko. 2016. Time adaptive sketches (ada-sketches) for summarizing data streams. In Proceedings of the 2016 International Conference on Management of Data. ACM, 1417–1432.
 Shrivastava and Li (2014) Anshumali Shrivastava and Ping Li. 2014. Asymmetric LSH (ALSH) for sublinear time maximum inner product search (MIPS). In Advances in Neural Information Processing Systems. 2321–2329.
 Siskind and Pearlmutter (2018) Jeffrey Mark Siskind and Barak A Pearlmutter. 2018. Divide-and-conquer checkpointing for arbitrary programs with no user annotation. Optimization Methods and Software 33, 4-6 (2018), 1288–1330.
 Spring and Shrivastava (2017) Ryan Spring and Anshumali Shrivastava. 2017. Scalable and sustainable deep learning via randomized hashing. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. ACM, 445–454.
 Sutskever et al. (2013) Ilya Sutskever, James Martens, George Dahl, and Geoffrey Hinton. 2013. On the importance of initialization and momentum in deep learning. In International conference on machine learning. 1139–1147.
 Tai et al. (2018) Kai Sheng Tai, Vatsal Sharan, Peter Bailis, and Gregory Valiant. 2018. Sketching Linear Classifiers over Data Streams. In Proceedings of the 2018 International Conference on Management of Data (SIGMOD ’18). ACM, New York, NY, USA, 757–772. https://doi.org/10.1145/3183713.3196930
 Tito Svenstrup et al. (2017) Dan Tito Svenstrup, Jonas Hansen, and Ole Winther. 2017. Hash Embeddings for Efficient Word Representations. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.). Curran Associates, Inc., 4928–4936.
 Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. Attention is All you Need. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.). Curran Associates, Inc., 5998–6008. http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
 Vijayanarasimhan et al. (2014) Sudheendra Vijayanarasimhan, Jonathon Shlens, Rajat Monga, and Jay Yagnik. 2014. Deep networks with large output spaces. arXiv preprint arXiv:1412.7479 (2014).
 Wen et al. (2017) Wei Wen, Yuxiong He, Samyam Rajbhandari, Minjia Zhang, Wenhan Wang, Fang Liu, Bin Hu, Yiran Chen, and Hai Li. 2017. Learning intrinsic sparse structures within long short-term memory. arXiv preprint arXiv:1709.05027 (2017).
 Yang et al. (2018) Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, and William W. Cohen. 2018. Breaking the Softmax Bottleneck: A High-Rank RNN Language Model. In International Conference on Learning Representations. https://openreview.net/forum?id=HkwZSG-CZ
 Yen et al. (2018a) Ian En-Hsu Yen, Satyen Kale, Felix Yu, Daniel Holtmann-Rice, Sanjiv Kumar, and Pradeep Ravikumar. 2018a. Loss Decomposition for Fast Learning in Large Output Spaces. In Proceedings of the 35th International Conference on Machine Learning (Proceedings of Machine Learning Research), Jennifer Dy and Andreas Krause (Eds.), Vol. 80. PMLR, Stockholmsmässan, Stockholm Sweden, 5640–5649. http://proceedings.mlr.press/v80/yen18a.html
 Yen et al. (2018b) Ian En-Hsu Yen, Satyen Kale, Felix Yu, Daniel Holtmann-Rice, Sanjiv Kumar, and Pradeep Ravikumar. 2018b. Loss Decomposition for Fast Learning in Large Output Spaces. In Proceedings of the 35th International Conference on Machine Learning (Proceedings of Machine Learning Research), Jennifer Dy and Andreas Krause (Eds.), Vol. 80. PMLR, Stockholmsmässan, Stockholm Sweden, 5640–5649.
 Zaheer et al. (2018) Manzil Zaheer, Sashank Reddi, Devendra Sachan, Satyen Kale, and Sanjiv Kumar. 2018. Adaptive Methods for Nonconvex Optimization. In Advances in Neural Information Processing Systems 31, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett (Eds.). Curran Associates, Inc., 9815–9825. http://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
Appendix A Appendix
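Count-Sketch Error Bound: (Charikar et al., 2002) Let $\hat{x}_i$ be the Count-Sketch estimate of component $i$ from vector $x \in \mathbb{R}^d$. For any component $x_i$, with probability $1 - \delta$, a Count-Sketch matrix with width $\Theta(1/\epsilon_1^2)$ and depth $\Theta(\log(d/\delta))$ satisfies
(1) $x_i - \epsilon_1 \|x\|_2 \le \hat{x}_i \le x_i + \epsilon_1 \|x\|_2$
Count-Min Sketch Error Bound: (Cormode and Muthukrishnan, 2005) Let $\hat{x}_i$ be the Count-Min Sketch estimate of component $i$ from vector $x \in \mathbb{R}^d$. For any component $x_i$, with probability $1 - \delta$, a Count-Min Sketch matrix with width $\Theta(1/\epsilon_1)$ and depth $\Theta(\log(d/\delta))$ satisfies
(2) $x_i \le \hat{x}_i \le x_i + \epsilon_1 \|x\|_1$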
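The one-sided nature of the Count-Min Sketch error can be observed on a small example. The class below is a minimal sketch, not the tensor implementation used in the experiments: the hash family ((a·i + b) mod p) mod width, the sizes, and the name `CountMinSketch` are assumptions for illustration, and all updates are non-negative.

```python
import random

P = 2_147_483_647  # prime modulus for the illustrative hash family


class CountMinSketch:
    def __init__(self, depth, width, seed=0):
        rng = random.Random(seed)
        self.width = width
        self.table = [[0] * width for _ in range(depth)]
        # One (a, b) pair per row defines h_j(i) = ((a*i + b) mod P) mod width.
        self.params = [(rng.randrange(1, P), rng.randrange(P)) for _ in range(depth)]

    def _bucket(self, row, index):
        a, b = self.params[row]
        return ((a * index + b) % P) % self.width

    def update(self, index, value):
        """Add a non-negative value to component `index` in every row."""
        for row in range(len(self.table)):
            self.table[row][self._bucket(row, index)] += value

    def query(self, index):
        """Take the minimum over rows; collisions can only inflate a bucket."""
        return min(
            self.table[row][self._bucket(row, index)]
            for row in range(len(self.table))
        )


rng = random.Random(1)
x = [rng.randrange(10) for _ in range(1000)]  # non-negative vector
sketch = CountMinSketch(depth=3, width=64)
for i, value in enumerate(x):
    sketch.update(i, value)
estimates = [sketch.query(i) for i in range(len(x))]
print("largest overestimate:", max(e - t for e, t in zip(estimates, x)))
```

The lower bound holds deterministically for non-negative data, while the upper bound holds only with the stated probability and depends on the width.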
For stochastic non-convex optimization, we measure how the algorithm converges to a stationary point, i.e., $\|\nabla f(x_t)\|^2 \le c$ for some constant $c$. Notation: batch size $b$, learning rate $\eta_t$, 2nd moment decay rate $\beta_2$, count-min sketch error rate $\epsilon_1$, count-min sketch failure probability $\delta$. Assumptions: Here are the assumptions used in our analysis:

Function $f$ is $L$-smooth: there exists a constant $L$ such that $\|\nabla f(x) - \nabla f(y)\| \le L \|x - y\|$ for all $x, y \in \mathbb{R}^d$.

Function $f$ has bounded gradients.

The stochastic gradient oracle provides us with an unbiased estimate with fixed variance. Let $\xi_t$ represent the randomness (due to mini-batch sampling) at iteration $t$.
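For simplicity and to save additional memory by not tracking the 1st moment, let $\beta_1 = 0$. In this form, the optimizer is commonly called RMSPROP. Therefore, the update rule for all $i \in [d]$ is
(3) $x_{t+1,i} = x_{t,i} - \eta_t \frac{g_{t,i}}{\sqrt{\hat{v}_{t,i}} + \epsilon}$
where $\hat{v}_{t,i}$ represents the Count-Min Sketch estimate of component $i$ from vector $v_t = v_{t-1} + (1 - \beta_2)(g_t^2 - v_{t-1})$.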
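A minimal sketch of this update, assuming the usual RMSProp recursion $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$ for the 2nd moment, is given next. The query, update, re-query ordering, the hash family, the hyperparameter defaults, and the class name `SketchedRMSProp` are assumptions for illustration rather than the implementation analyzed here.

```python
import math
import random

P = 2_147_483_647  # prime modulus for the illustrative hash family


class SketchedRMSProp:
    """RMSProp (Adam with beta1 = 0) with the 2nd moment in a count-min sketch."""

    def __init__(self, depth=3, width=16, lr=0.01, beta2=0.999, eps=1e-8, seed=0):
        rng = random.Random(seed)
        self.lr, self.beta2, self.eps, self.width = lr, beta2, eps, width
        self.table = [[0.0] * width for _ in range(depth)]
        self.params = [(rng.randrange(1, P), rng.randrange(P)) for _ in range(depth)]

    def _buckets(self, i):
        # One bucket per row for component i.
        return [((a * i + b) % P) % self.width for a, b in self.params]

    def _query(self, buckets):
        # Count-min estimate: minimum over the rows.
        return min(row[j] for row, j in zip(self.table, buckets))

    def step(self, x, sparse_grad):
        """Update the list x in place for the non-zero gradient entries."""
        for i, g in sparse_grad.items():
            buckets = self._buckets(i)
            v_prev = self._query(buckets)
            # Increment that turns v_prev into beta2*v_prev + (1-beta2)*g^2.
            delta = (1.0 - self.beta2) * (g * g - v_prev)
            for row, j in zip(self.table, buckets):
                row[j] += delta
            v_hat = max(self._query(buckets), 0.0)  # guard against rounding
            x[i] -= self.lr * g / (math.sqrt(v_hat) + self.eps)


opt = SketchedRMSProp(lr=0.1, beta2=0.9)
x = [1.0, -2.0, 0.5]
opt.step(x, {0: 1.0})  # only component 0 has a non-zero gradient
print(x)
```

Only the buckets touched by non-zero gradient components are read or written, so the cost of a step scales with the number of active features rather than with the full parameter dimension.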
Theorem A.1.
Let learning rate and batch size . Assume , , and are selected such that and . Given a CountMin Sketch matrix width and depth , we have the following bound that holds for CountMin Sketch Adam with probability where :
Proof.
Given that the function is smooth and by the optimizer update rule, we derive the following:
(4) 
Next, we take the expectation of , given that we know (assumed fixed):
The second equality occurs because is an unbiased estimate of . Now, we upper-bound the term :
From Lemma A.4, we have the second equality. The second inequality occurs because of Lemma A.3, which is derived using the CountMin Sketch error bound. The third inequality occurs because and when we drop from .
By substituting the upper-bound for , we arrive at the following:
The first inequality follows because the function has bounded gradients  . Now, the second inequality holds because . In addition, we split the and terms using the linearity of expectation. For the third inequality, we use the result and definitions in Lemma A.2. From the specified parameters for , , and , we assume the following conditions hold: and .