Climbing the WOL: Training for Cheaper Inference

07/02/2020 ∙ by Zichang Liu, et al. ∙ Stanford University, Rice University

Efficient inference for wide output layers (WOLs) is an essential yet challenging task in large scale machine learning. Most approaches reduce this problem to approximate maximum inner product search (MIPS), which relies heavily on the observation that, for a given model, ground truth labels correspond to the logits of highest value during full model inference. However, this assumption is restrictive in practice. In this paper, we argue that approximate MIPS subroutines, despite having sub-linear computation time, are sub-optimal because they are tailored to retrieving large inner products with high recall rather than retrieving the correct labels. With WOLs, the labels often have moderate inner products, which makes approximate MIPS more challenging. We propose an alternative problem formulation, called Label Sensitive Sampling (LSS), where the objective is to tailor the system to ensure retrieval of the correct label. Accordingly, we propose a novel learned hash approach, which is significantly more efficient than MIPS baselines while sustaining high inference accuracy. Our extensive evaluation indicates that LSS can match or even outperform full inference accuracy with around 5x speed up and 87% energy reduction.




1 Introduction

In recent years, neural networks with wide output layers have obtained promising results in various applications such as recommendation systems xue2017deep; Bhatia16; fan2019mobius and language modeling bengio2003neural; mikolov2010recurrent; mikolov2013efficient. One of the significant challenges of deploying such models lies in the massive computation cost of giant matrix multiplications in wide output layers (WOLs), which can easily contain millions of neurons Bhatia16. To tackle this problem, many existing methods focus on one principal direction: reducing computation by approximating the full WOL output with a few representative logits. A common formulation for this direction in the literature is to pose the problem as a Maximum Inner Product Search (MIPS) shrivastava2014asymmetric problem and exploit MIPS approximation algorithms shrivastava2014improved; guo2016quantization; zhang2018navigating. Specifically, each input embedding from the previous hidden layer serves as a query, and neurons from the output layer are treated as data. The goal is to find the top-k neurons that have the maximum inner product with the query in sub-linear time. The model then makes predictions based on the logits computed from the query and only the selected neurons. Current approximate MIPS algorithms index the data with various data structures, which greatly reduces search computation. As expected, there is a trade-off between search efficiency and accuracy.
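As a minimal illustration of this formulation (our own pure-Python sketch, not code from the paper), the snippet below treats each output neuron as a weight row plus a bias, uses the input embedding as the query, and returns the k neuron ids with the largest logits. Calling `top_k` without `ids` is the exact linear scan over all N neurons; passing a candidate subset mimics predicting from the few neurons an approximate MIPS index would return. The helper names and the toy weights are hypothetical.

```python
import heapq


def logits(x, W, b, ids):
    """Logit <w_i, x> + b_i for every neuron id in `ids`."""
    return {i: sum(w * v for w, v in zip(W[i], x)) + b[i] for i in ids}


def top_k(x, W, b, k, ids=None):
    """Return the k neuron ids with the largest logits, largest first.

    ids=None scans all N neurons (exact MIPS, cost linear in N);
    passing a candidate subset scores only the retrieved neurons.
    """
    if ids is None:
        ids = range(len(W))
    scores = logits(x, W, b, ids)
    return heapq.nlargest(k, scores, key=scores.get)


# Toy WOL with N = 4 neurons and embedding dimension d = 2.
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
b = [0.0, 0.0, 0.0, 0.2]
x = [1.0, 2.0]
print(top_k(x, W, b, k=2))              # [2, 1]  (full scan)
print(top_k(x, W, b, k=1, ids=[0, 1]))  # [1]     (candidate subset only)
```

Scoring only `ids` is the whole point of the approximate approaches discussed here: the cost drops from N dot products to the size of the candidate set, and accuracy then depends entirely on whether the right neurons are among the candidates.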

Shortcomings of approximate MIPS formalism: Approximate MIPS is a natural formulation for WOL inference, since the inference phase of neural networks (NNs) treats logits (a monotonic function of the inner product between data embeddings and label neurons) as the scores of the labels shrivastava2014asymmetric. Under this view, the label with the topmost inner product is taken to be the correct one, which justifies approximating the MIPS subroutine. However, the approximation introduces a new source of trade-offs. The hardness of MIPS, or of any near-neighbor search, depends on two things: 1) the value of the inner product of the correct class we want to retrieve and 2) the gap between the best inner product values and the values of the other inner products under consideration. With a large number of classes, the inner product values, even for the correct class, are often small, while many other classes have roughly the same inner products. As a result, it is unreasonable to expect any approximate MIPS subroutine to retrieve the correct labels both efficiently and accurately, and the final prediction of the network is very sensitive to this retrieval accuracy.

Better retrieval can even beat full softmax: In practice, NN models do not generalize perfectly to test data; specifically, the correct class might not have the maximum prediction score. Consider some statistics from the Delicious-200K dataset deli200k: during the inference phase of a fully trained classification model with a softmax WOL of over 200K neurons, the average rank of the label neurons by inner product is only 498.14 out of 205,443. However, suppose we have an oracle that retrieves a set of neurons in a WOL with two properties: 1) the neuron representing the correct label is in the set, and 2) all other neurons in the retrieved set are likely to have a smaller inner product than the correct label. With this retrieval oracle in place, even if the correct label does not have the highest logit in the full prediction, it would have the highest logit in the retrieved set, leading to even better accuracy than full softmax.
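The argument can be checked on a toy example with hypothetical logits (the numbers below are made up for illustration, not taken from Delicious-200K): the label ranks third under full softmax, so the full argmax is wrong, yet a retrieved set that contains the label and only lower-scoring neurons yields the correct prediction.

```python
def predict(scores, ids):
    """Index of the highest score among the given class ids."""
    return max(ids, key=lambda i: scores[i])


# Hypothetical logits for 5 classes; the ground-truth label is class 2.
scores = [2.0, 0.3, 1.5, 0.9, 1.7]
label = 2

full_rank = 1 + sum(s > scores[label] for s in scores)  # rank of the label
full_pred = predict(scores, range(len(scores)))           # full softmax argmax

# An oracle set: it contains the label, and every other member scores lower.
retrieved = [1, 2, 3]
subset_pred = predict(scores, retrieved)

print(full_rank, full_pred, subset_pred)  # 3 0 2
```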

Therefore, based on the above observation, we design a superior retrieval mechanism, Label Sensitive Sampling (LSS), which enforces the two objectives mentioned above. The construction of this mechanism leverages the pre-trained model and its corresponding training dataset. We summarize our contributions as follows:


  • We identify an objective gap between efficient inference and the classical MIPS formulation. Moreover, to bridge this gap, we observe that inference over a perfect subset of output neurons is both efficient and accurate.

  • Based on the above observation, we propose Label Sensitive Sampling (LSS), a hashing-based method that uses a learning mechanism to incorporate ground truth information in the retrieval function. This retrieval mechanism can sample a small subset from WOL neurons with a high probability of including label neurons.

  • We provide rigorous evaluations of our method on four large benchmark datasets using two different model architectures. We show that our method achieves up to 5x speed up and up to 8x energy reduction with little or no loss in accuracy compared to full computation inference.

2 Related Work

2.1 Efficient Inference for Wide Output Layers

Many approaches have been developed for efficient inference over WOLs, and most of them can be categorized as approximate MIPS. zhang2018navigating proposed a graph-based method that maps the database vectors into a proximity graph tan2019efficient; zhou2019mobius and outperforms traditional PCA bachrach2014speeding or SVD shim2017svd approaches in language modeling tasks. However, graph-based methods suffer severe performance degradation in parallel settings because greedy walks over the graph are difficult to batch. Meanwhile, morozov2018non also pointed out potential risks of the asymmetric transformation in zhang2018navigating. On the other hand, several MIPS solvers guo2016quantization; wu2017multiscale have been proposed for inference over WOLs. However, these solvers spend substantial computation to gain accuracy and are both energy- and time-consuming, even with full parallelism. We provide a detailed literature review in Appendix A.

2.2 Hashing Algorithms for Large Scale Learning

Hashing based data structures are widely applied in machine learning tasks at scale chen2018lshff; spring2020mutual. In formal terms, we consider $\mathcal{H}$ as a family of hash functions that maps $\mathbb{R}^D$ to some set $\mathcal{U}$.

Definition 1 (LSH Family).

A family $\mathcal{H}$ is called $(S_0, cS_0, p_1, p_2)$-sensitive if, for any two points $x, y \in \mathbb{R}^D$, a hash function $h$ chosen uniformly from $\mathcal{H}$ satisfies:

  • if $\mathrm{Sim}(x, y) \ge S_0$ then $\Pr_{\mathcal{H}}(h(x) = h(y)) \ge p_1$

  • if $\mathrm{Sim}(x, y) \le cS_0$ then $\Pr_{\mathcal{H}}(h(x) = h(y)) \le p_2$

Here $\mathrm{Sim}$ is a similarity measure, and $c < 1$ and $p_1 > p_2$ are required. Details are presented in Appendix A. The general idea of these LSH functions is to pre-partition the dataset into buckets so that vectors within the same bucket are similar shrivastava2014asymmetric; Proc:Indyk_STOC98; indyk2006polylogarithmic. Therefore, given a query vector, the computation can be focused on a tiny subset of the large database. Taking advantage of this massive computation reduction, hashing methods have been applied in: (1) Feature representation: li2011hashing; li2012one demonstrate efficient linear learning via permutation-based hashing that preserves Jaccard similarity. (2) Neural network training: SLIDE proposes a sub-linear deep learning engine that uses LSH to select neurons in the forward and backward passes of NN training and achieves higher efficiency on CPU than a TensorFlow implementation on GPU. (3) Fast nearest neighbor search: shrivastava2014defense; wang2017flash provide algorithms that tackle efficiency bottlenecks in metric similarity search in ultra-high-dimensional spaces.
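SimHash, the LSH variant used later in this paper, satisfies Definition 1 for cosine similarity: for a random Gaussian hyperplane, the probability that two vectors fall on the same side is $1 - \theta(x, y)/\pi$, a monotonically increasing function of their cosine similarity. The sketch below (pure Python, helper names hypothetical) compares a Monte Carlo estimate with that closed form.

```python
import math
import random


def simhash_bit(r, v):
    """One SimHash bit: the side of the random hyperplane r that v falls on."""
    return sum(a * c for a, c in zip(r, v)) >= 0.0


def empirical_collision(x, y, trials=20000, seed=0):
    """Fraction of random Gaussian hyperplanes that put x and y on the same side."""
    rng = random.Random(seed)
    same = 0
    for _ in range(trials):
        r = [rng.gauss(0.0, 1.0) for _ in x]
        same += simhash_bit(r, x) == simhash_bit(r, y)
    return same / trials


def simhash_collision(x, y):
    """Closed form for SimHash: Pr[h(x) = h(y)] = 1 - angle(x, y) / pi."""
    dot = sum(a * c for a, c in zip(x, y))
    norms = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(c * c for c in y))
    cos = max(-1.0, min(1.0, dot / norms))
    return 1.0 - math.acos(cos) / math.pi


x, y = [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]  # 45 degrees apart
print(simhash_collision(x, y))            # about 0.75
print(empirical_collision(x, y))          # close to 0.75
```

Because the collision probability grows with cosine similarity, any threshold pair $S_0 > cS_0$ yields $p_1 > p_2$, which is exactly the requirement in Definition 1.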

Figure 1: The LSS pipeline in two stages: 1) During preprocessing, we incorporate label information from training data in LSH hash functions and rebuild hash tables accordingly. 2) During the actual inference, the softmax computation of the WOL is based on a subset of neurons retrieved by the input embeddings from hash tables rather than the full set of neurons.

3 Bridging the Gap Between MIPS and NN Inference

3.1 Notation and Settings

In the WOL setting, we denote the WOL's weight matrix as $W \in \mathbb{R}^{N \times d}$ and its bias vector as $b \in \mathbb{R}^{N}$, where $N$ is the size of the output layer (number of classes) and $d$ is the embedding dimension. The WOL can be represented by a set of neurons $\mathcal{N} = \{n_1, \dots, n_N\}$, where each neuron $n_i = (w_i, b_i)$ consists of $w_i$, the $i$-th row of $W$, and $b_i$, the $i$-th element of $b$. Typically, $N \gg d$ for a wide output layer. During inference, given an input embedding $x \in \mathbb{R}^d$ from the previous hidden layer, the output of a forward pass through the WOL is $o = \sigma(Wx + b)$, where $\sigma$ is some activation function that translates the logits into probabilities; then, the indices of the $k$ largest logits are returned as the predicted classes. This formulation of WOLs can be applied to the softmax output layer in language modeling and extreme classification, as well as to matrix factorization in collaborative filtering.
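A minimal sketch of this forward pass, assuming softmax as the activation $\sigma$ (the names `wol_forward` and `top_k_indices` and the toy values are illustrative): because softmax is monotonic, ranking classes by probability gives the same top-k indices as ranking by raw logits, which is why the prediction step can work directly with inner products.

```python
import math


def wol_forward(x, W, b):
    """Return (logits z = Wx + b, probabilities softmax(z)) for one embedding x."""
    z = [sum(w * v for w, v in zip(row, x)) + bi for row, bi in zip(W, b)]
    m = max(z)                                # subtract the max for numerical stability
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return z, [ei / s for ei in e]


def top_k_indices(values, k):
    """Indices of the k largest entries, largest first."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]


W = [[0.5, -1.0], [2.0, 0.0], [0.0, 1.0]]    # N = 3 neurons, d = 2
b = [0.1, -0.5, 0.0]
x = [1.0, 1.0]
z, probs = wol_forward(x, W, b)
# Softmax is monotonic, so ranking by logits or by probabilities agrees.
print(top_k_indices(z, 2), top_k_indices(probs, 2))  # [1, 2] [1, 2]
```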


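Our first goal is to construct the retrieval oracle introduced in Section 1. We formulate the objective of the oracle as the construction of a Perfect Retrieval Set. For each input embedding $x$ (query), we sample a subset of neurons $\mathcal{S}$ such that $\mathcal{S} \subset \mathcal{N}$ and $|\mathcal{S}| \ll N$. In the sampling process, we want to maximize the probability of retrieving label neurons. Moreover, label neurons should have the highest inner products within the subset. Formally,

Definition 2 (Perfect Retrieval Set).

Given a WOL with neurons $\mathcal{N} = \{n_1, \dots, n_N\}$, for each input embedding $x$ with labels $\mathcal{Y} \subseteq \{1, \dots, N\}$ in the multi-label setting, we want to sample a subset $\mathcal{S} \subset \mathcal{N}$ with size $|\mathcal{S}| \ll N$ such that $n_y \in \mathcal{S}$ for all $y \in \mathcal{Y}$, and $\langle x, n_y \rangle > \langle x, n_j \rangle$ for all $y \in \mathcal{Y}$ and all $n_j \in \mathcal{S}$ with $j \notin \mathcal{Y}$.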
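Definition 2 can be written as a direct check. The sketch below is our own illustration, assuming biases have already been folded into the neuron vectors and that labels and the retrieved set are given as sets of neuron ids; it tests the two conditions (every label retrieved, labels outrank every other retrieved neuron) and does not check the size constraint $|\mathcal{S}| \ll N$. The function name is hypothetical.

```python
def dot(u, v):
    return sum(a * c for a, c in zip(u, v))


def is_perfect_retrieval_set(x, neurons, labels, S):
    """Check the two conditions of Definition 2 for one query x.

    neurons: list of neuron vectors (bias assumed folded in);
    labels: set of ground-truth neuron ids; S: set of retrieved ids.
    """
    if not labels <= S:                  # every label neuron must be retrieved
        return False
    others = S - labels
    if not others:
        return True
    worst_label = min(dot(x, neurons[y]) for y in labels)
    best_other = max(dot(x, neurons[j]) for j in others)
    return worst_label > best_other      # labels outrank every other retrieved neuron


neurons = [[3.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.5, 0.0]]
x = [1.0, 0.0]
print(is_perfect_retrieval_set(x, neurons, {2}, {1, 2, 3}))  # True
print(is_perfect_retrieval_set(x, neurons, {2}, {0, 2}))     # False: neuron 0 outranks the label
print(is_perfect_retrieval_set(x, neurons, {2}, {1, 3}))     # False: label not retrieved
```

Note that the first set is "perfect" even though neuron 0 has the largest inner product overall, which is exactly the situation in which sampled inference can beat full softmax.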

3.2 Algorithm Overview

To construct our retrieval oracle, which ideally returns a Perfect Retrieval Set, we introduce Label Sensitive Sampling (LSS). LSS exploits Locality Sensitive Hashing (LSH) along with an efficient hyperplane (hash function) training procedure to approximate the maximum inner product in WOLs using label information. We choose a particular variant of LSH called SimHash charikar2002similarity, which is parameterized by hyperplanes in the dimensionality of the query. The full workflow is illustrated in Figure 1.

LSS works in two separate phases: an offline preprocessing stage (Algorithm 1) followed by an online inference stage (Algorithm 2). We summarize the offline construction in three steps: (1) Given the weights of a trained model, we build $L$ locality sensitive hash tables via hash functions constructed from (initially) randomly generated hyperplanes. (2) Based on the neurons retrieved from the current hash tables by each query, we iteratively update the hyperplanes (hash functions) with our designed sampling loss. (3) We rebuild the hash tables with the updated hash functions. For the online inference phase, we summarize LSS in two steps: (1) Given an input from the test set, we compute the forward pass up to the last layer to obtain an embedding $x$, and we query the hash tables with $x$. (2) We mark the retrieved neurons as “active” and all other neurons as inactive for this input. Finally, we perform prediction on the “active” neurons only, and the top-ranked neurons (with the highest logits) are returned as the prediction.

Why Hashing-based Indexing: In this work, we insert only neuron IDs into the hash tables before performing LSS. Hash tables offer three major advantages as the data structure for efficient inference: (1) Efficiency: in the preprocessing phase introduced in Appendix A, hash tables have lower time and space complexity than tree or graph methods ann, and in the query phase, hash table lookups are faster than greedy walks on tree or graph structures. (2) Differentiability: the projection step in hash functions represents a space partition that can be adjusted via gradient-based methods wang2017survey; hashnet. (3) Scalability: compared to MIPS solvers guo2016quantization and graph methods morozov2018non; zhang2018navigating, hashing methods require less computation and are more amenable to multi-threading wang2017flash. Therefore, hashing methods are capable of massively parallel inference on CPUs.

3.3 Preprocessing: Construction of Hash Tables

Given a trained model, we construct $L$ hash tables, where each hash table has a capacity of $2^K$ buckets, and insert each WOL neuron into each hash table. We do so using $K$ binary hash functions per table (the hash table keys are constructed by concatenating the $K$ binary hashes together). In total, we need $K \times L$ hash functions. Specifically, the $K$ hash bits of an input $v$ are generated by the function $h(v) = \operatorname{sign}(\Theta^{\top} v)$, where each column of $\Theta \in \mathbb{R}^{(d+1) \times K}$ is drawn i.i.d. from $\mathcal{N}(0, I)$. This is equivalent to the method used in SimHash charikar2002similarity. Geometrically, each column $\theta_j$ ($j = 1, \dots, K$) represents a projection hyperplane in $\mathbb{R}^{d+1}$, so that the space is partitioned by $K$ hyperplanes.
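The table construction can be sketched as follows, assuming neuron vectors already include the bias (dimension d+1) and that each table uses its own K Gaussian hyperplanes. Only neuron ids are stored, and a bucket key is the integer formed by concatenating the K sign bits, so each table has at most $2^K$ buckets. Function names are hypothetical, and this is not the authors' C++ implementation.

```python
import random
from collections import defaultdict


def make_hyperplanes(dim, K, L, seed=0):
    """L groups of K Gaussian hyperplanes (one group per hash table)."""
    rng = random.Random(seed)
    return [[[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(K)]
            for _ in range(L)]


def bucket_key(planes, v):
    """Concatenate the K sign bits of v into an integer key in [0, 2**K)."""
    key = 0
    for r in planes:
        bit = 1 if sum(a * c for a, c in zip(r, v)) >= 0.0 else 0
        key = (key << 1) | bit
    return key


def build_tables(neurons, hyperplanes):
    """Insert every neuron id into each of the L tables."""
    tables = []
    for planes in hyperplanes:
        table = defaultdict(list)
        for i, n in enumerate(neurons):
            table[bucket_key(planes, n)].append(i)
        tables.append(table)
    return tables


rng = random.Random(1)
neurons = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(100)]  # d + 1 = 5
hyperplanes = make_hyperplanes(dim=5, K=3, L=2)
tables = build_tables(neurons, hyperplanes)
print([len(t) for t in tables])  # non-empty buckets per table (at most 2**3 = 8)
```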

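Algorithm 1 Preprocessing
  Input: trained WOL neurons $\mathcal{N}$, training set, number of hash bits $K$, number of hash tables $L$, ranking thresholds $t_p$ and $t_n$
  Generate random hyperplanes and build hash tables $T_1, \dots, T_L$ from $\mathcal{N}$
  for each input in the training set do
     Compute its embedding $x$.
     $S = \emptyset$
     for $l = 1, \dots, L$ do
         $S = S \cup$ Query($T_l$, $x$)
     end for
     Collect positive and negative pairs for $x$ from $S$
  end for
  shuffle the collected pairs
  Update the hyperplanes by minimizing the IUL over the pairs
  Rebuild $T_1, \dots, T_L$ with the updated hyperplanes
  return hyperplanes, $T_1, \dots, T_L$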

During initialization, we insert neuron IDs into each hash table. Recall that each neuron can be represented by the concatenation of its weight and bias parameters, $n_i = [w_i; b_i]$. Therefore, for each $n_i$, we generate $L$ hash codes, which constitute $L$ hash table keys in total, and insert $i$ into the $L$ hash tables accordingly. For each input in the training set, we collect its embedding $x$ before it is fed forward to the WOL. Then, $[x; 1]$ serves as an input embedding query to retrieve the corresponding neuron IDs from the hash tables. For simplicity, we omit the bias and the appended constant and directly use $x$ and $n_i$ in the following sections.
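The concatenation trick works because appending a constant 1 to the query turns the biased logit into a plain inner product: $\langle [x; 1], [w_i; b_i] \rangle = \langle x, w_i \rangle + b_i$. A tiny check, with hypothetical values and helper names, assuming the query is extended with exactly 1:

```python
def fold_neuron(w, b):
    """Neuron n_i = [w_i; b_i] as a single (d+1)-dimensional vector."""
    return list(w) + [b]


def fold_query(x):
    """Query [x; 1], so that <[x; 1], [w; b]> = <x, w> + b."""
    return list(x) + [1.0]


def dot(u, v):
    return sum(a * c for a, c in zip(u, v))


w, b, x = [0.5, -2.0, 1.0], 0.25, [2.0, 1.0, 3.0]
print(dot(fold_query(x), fold_neuron(w, b)), dot(x, w) + b)  # 2.25 2.25
```

This is what lets a single set of SimHash hyperplanes hash queries and biased neurons in the same $(d+1)$-dimensional space.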

For the retrieved sets to have the properties of a Perfect Retrieval Set, the hash functions should satisfy the following: (1) the collision probability between the input embedding query and its ground truth label neurons is high; (2) the collision probability between the query and its non-label neurons is low; (3) neurons are distributed evenly over all buckets for better load balancing, which keeps the overhead low (otherwise there is no efficiency gain). We formally define our ideal hash function as follows:

Definition 3 (Label Sensitive Hash Family).

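A hash family $\mathcal{H}$ is called $(p_1, p_2)$-label-sensitive, with $p_1 > p_2$, if for a triplet $(x, n_i, n_j)$, a hash function $h$ chosen uniformly from $\mathcal{H}$ satisfies:

  • if $i$ is a label of $x$, then $\Pr_{\mathcal{H}}(h(x) = h(n_i)) \ge p_1$

  • if $j$ is not a label of $x$, then $\Pr_{\mathcal{H}}(h(x) = h(n_j)) \le p_2$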

We approximate hash functions from such a family with an iterative learning mechanism that encourages the above three properties using an Index Update Loss (IUL). The key to this learning process is the collection of positive and negative pairwise training samples. For each input embedding $x$, we retrieve its corresponding set of neurons $\mathcal{S}$ from the existing hash tables. Then, pairwise training samples are collected according to the following criteria:

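  • positive pair: $(x, n_y)$, where $y$ is a ground truth label of $x$ and $n_y$ ranks within the top $t_p$ neurons by inner product with $x$

  • negative pair: $(x, n_j)$, where $n_j \in \mathcal{S}$ is a retrieved non-label neuron that ranks below the top $t_n$ neurons by inner product with $x$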

Difference from Standard Learning to MIPS: The construction of the positive and negative sets is essential. Standard learning approaches focus on the MIPS objective and use every positive and negative pair to train the hash function, potentially solving a harder problem than needed. Instead, we only use the negative pairs arising from the retrieved buckets, plus the positive pairs for correct labels that the buckets missed. Overall, our training is aware of the retrieval mechanism and only enforces what is needed for classification.
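The pair collection can be sketched as below. This is an illustration under stated assumptions: we take the criteria to be a rank threshold `t_p` on label neurons (positives) and a rank threshold `t_n` on retrieved non-label neurons (negatives), with ranks computed by sorting all N inner products, which is affordable offline but not at inference time. Names and numbers are hypothetical.

```python
def rank_of(scores):
    """1-based rank of every neuron id by descending inner product."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return {i: r + 1 for r, i in enumerate(order)}


def collect_pairs(scores, labels, retrieved, t_p, t_n):
    """Return (positive ids, negative ids) for one query.

    scores: inner products <x, n_i> for all N neurons;
    labels: ground-truth ids; retrieved: ids returned by the hash tables.
    """
    rank = rank_of(scores)
    positives = [y for y in labels if rank[y] <= t_p]
    negatives = [j for j in retrieved if j not in labels and rank[j] > t_n]
    return positives, negatives


scores = [0.9, 0.1, 0.7, 0.3, 0.8, 0.2]  # hypothetical inner products, N = 6
pos, neg = collect_pairs(scores, labels={2}, retrieved=[1, 2, 3, 4], t_p=3, t_n=4)
print(pos, neg)  # [2] [1]
```

Neuron 4 is retrieved but ranks second, so it is not used as a negative: only retrieved neurons with clearly low inner products are pushed out, matching the retrieval-aware spirit described above.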
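Index Update Loss (IUL): We use the Hamming distance between the hash codes of a training pair as a measure of their difference. Since $\operatorname{sign}$ is a discrete function, we use $\tanh$ as a differentiable approximation. For $K$-bit codes $h(x), h(y) \in \{-1, 1\}^K$, we know that $d_H(h(x), h(y)) = \frac{1}{2}\left(K - \langle h(x), h(y) \rangle\right)$ hashnet. Therefore, a large inner product between two hash codes means a small Hamming distance, and since SimHash is locality sensitive, a query and a neuron with a high inner product tend to have similar hash codes. Thus, we propose an Index Update Loss based on Hamming distance to update the hash functions (initially random hyperplanes) with the collected pairs. Formally,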


where , , and .

The intuition behind our IUL design is that positive pairs are encouraged to land in the same bucket while negative pairs are pushed into different buckets. Positive samples are used to maximize the probability of including correct label neurons. At the same time, the loss raises the relative inner product rank of label neurons within the retrieved set by lowering the probability of retrieving other, non-essential neurons. Negative samples keep the bucket size relatively small by pushing low inner product neurons out of the bucket; otherwise, all neurons would ultimately converge to the same bucket in each table. We collect the pairwise training data from each retrieved set because it directly reflects how the current hyperplanes separate the data. The two inner product ranking thresholds $t_p$ and $t_n$ control the inner product quality of positive and negative pairs, and in any valid setting we have $t_p < t_n$. Otherwise, when a label neuron has a small logit, the nature of LSH makes it challenging to train hyperplanes so that low inner product neurons are retrieved while high inner product neurons are excluded.
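The exact IUL equation is not reproduced in this text, so the following is only an assumed surrogate in the same spirit, not the paper's loss: codes are relaxed with tanh, the Hamming identity $d_H = \frac{1}{2}(K - \langle c(x), c(n) \rangle)$ is applied to the relaxed codes, and gradient steps decrease this distance for positive pairs and increase it for negative pairs. The gradient is derived by hand; function names and the learning rate are hypothetical.

```python
import math
import random


def relaxed_codes(planes, v):
    """tanh(theta_k . v) for each of the K hyperplanes (a differentiable sign)."""
    return [math.tanh(sum(a * c for a, c in zip(p, v))) for p in planes]


def relaxed_hamming(planes, x, n):
    """(K - <c(x), c(n)>) / 2: the Hamming identity applied to relaxed codes."""
    cx, cn = relaxed_codes(planes, x), relaxed_codes(planes, n)
    return 0.5 * (len(planes) - sum(a * c for a, c in zip(cx, cn)))


def sgd_step(planes, x, n, sign, lr=0.1):
    """One gradient step on sign * relaxed_hamming (+1 pulls a pair, -1 pushes it)."""
    for p in planes:
        cx = math.tanh(sum(a * c for a, c in zip(p, x)))  # computed at the old p
        cn = math.tanh(sum(a * c for a, c in zip(p, n)))
        for j in range(len(p)):
            # d/dp_j of -(cx * cn) / 2, using d tanh(u)/du = 1 - tanh(u)^2
            grad = -0.5 * ((1 - cx * cx) * cn * x[j] + cx * (1 - cn * cn) * n[j])
            p[j] -= lr * sign * grad


rng = random.Random(0)
planes = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]  # K = 4
x, n_pos = [1.0, 0.5, -0.2], [-0.3, 1.0, 0.4]
before = relaxed_hamming(planes, x, n_pos)
for _ in range(100):
    sgd_step(planes, x, n_pos, sign=+1)
after = relaxed_hamming(planes, x, n_pos)
print(round(before, 3), round(after, 3))  # the positive pair's distance shrinks
```

In a full pipeline, such steps would run over shuffled mini-batches of pairs (Figure 2 uses a hash-function training batch of 256), after which the tables are rebuilt with the updated hyperplanes.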

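Algorithm 2 Inference of one sample
  Input: test input, trained model, hash functions $h_1, \dots, h_L$, hash tables $T_1, \dots, T_L$
  Compute the embedding $x$ of the input from the penultimate layer
  $S = \emptyset$
  for $l = 1, \dots, L$ do
     $S = S \cup$ Query($T_l$, $h_l(x)$)
  end for
  return the indices of the largest logits $\langle x, n_i \rangle$ over $i \in S$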

3.4 Online Efficient Inference

Having presented the core component of our proposal, we now describe the online inference process. The model weights, hash functions, and hash tables are frozen before inference. For each input in the test set, we first compute the output embedding of the penultimate layer. The hash codes of this embedding are then generated to retrieve the corresponding neurons from the hash tables. The next step is similar to the usual WOL inference routine for making predictions, except that instead of performing full inference, the model only computes the logits of the retrieved neurons.
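Putting the pieces together, one online query can be sketched as: hash the embedding with each table's hyperplanes, take the union of the matching buckets as the active set, and compute logits only for active neurons. In this sketch, untrained random SimHash hyperplanes stand in for the learned ones, biases are assumed folded into the neuron vectors, and all names are hypothetical.

```python
import heapq
import random
from collections import defaultdict


def bucket_key(planes, v):
    """Integer key formed by the K SimHash sign bits of v."""
    key = 0
    for r in planes:
        key = (key << 1) | int(sum(a * c for a, c in zip(r, v)) >= 0.0)
    return key


def build_tables(neurons, hyperplanes):
    """One table per hyperplane group; each maps a bucket key to neuron ids."""
    tables = []
    for planes in hyperplanes:
        table = defaultdict(list)
        for i, n in enumerate(neurons):
            table[bucket_key(planes, n)].append(i)
        tables.append(table)
    return tables


def sampled_predict(x, neurons, hyperplanes, tables, k=1):
    """Query every table, mark the union as active, rank only active neurons."""
    active = set()
    for planes, table in zip(hyperplanes, tables):
        active.update(table.get(bucket_key(planes, x), []))
    scores = {i: sum(a * c for a, c in zip(neurons[i], x)) for i in active}
    return heapq.nlargest(k, scores, key=scores.get), active


rng = random.Random(0)
d1, N, K, L = 8, 1000, 6, 4          # toy sizes; d1 = embedding dim + 1 (bias)
neurons = [[rng.gauss(0.0, 1.0) for _ in range(d1)] for _ in range(N)]
hyperplanes = [[[rng.gauss(0.0, 1.0) for _ in range(d1)] for _ in range(K)]
               for _ in range(L)]
tables = build_tables(neurons, hyperplanes)

x = neurons[42]                       # a query identical to neuron 42
pred, active = sampled_predict(x, neurons, hyperplanes, tables, k=5)
print(len(active), "of", N, "neurons active; top-5:", pred)
```

With random hyperplanes the active set is small but label-agnostic; the learning step described above is what makes the buckets favor label neurons.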

4 Evaluation

In this section, we evaluate the effectiveness of LSS for efficient WOL inference on two large scale extreme classification datasets and two language modeling datasets. Specifically, we would like to answer the following questions: (1) Does LSS outperform other efficient inference approaches in energy and time? (2) How do the internal metrics of LSS change during its learning process? (3) Can LSS surpass the accuracy of full inference in a shorter time?

Datasets and Models: For extreme classification, we use a standard fully connected neural network with one hidden layer of size 128 and evaluate on two datasets: Wiki10-31K wiki10 and Delicious200K deli200k. For language modeling, we use a standard fully connected network with one hidden layer of size 128 for Text8 text8, and a two-layer LSTM network with a hidden dimension of 200 for Wiki-Text2 wikitext2. We present more experiment details in Appendix B.

Baselines: We compare the proposed LSS against the following state-of-the-art methods: (1) SLIDE SLIDE is a deep learning system written in C++ that uses locality-sensitive hashing for faster training; we implement this method for inference. (2) Graph Decoder (GD) is a MIPS method proposed for efficient softmax inference in zhang2018navigating that combines the asymmetric transform in bachrach2014speeding with HNSW malkov2018efficient. Here we use the original implementation of HNSW HNSW and pre-process the data according to zhang2018navigating. (3) ip-NSW is a state-of-the-art graph-based MIPS algorithm proposed in morozov2018non; ipnsw. It belongs to the direct MIPS category and improves on GD. (4) Product Quantization (PQ) pq is a MIPS solver based on K-means and an asymmetric transformation. We implement this method following Facebook's popular open-source ANNS platform JDH17. (5) FULL is regular, parallelized NN inference using all neurons in the last layer.

Implementation and Experiment Setting:

All experiments are conducted on a machine equipped with two 20-core/40-thread processors (Intel Xeon(R) E5-2698 v4 2.20GHz) running Ubuntu 16.04.5 LTS. LSS for the output layer is written in C++ and compiled with GCC 7 and OpenMP. The full inference is implemented in PyTorch. GD, ip-NSW, and PQ are implemented in C++ with OpenMP. All implementations are multi-threaded and use all CPU cores. For all baselines, we report the best results after an extensive hyperparameter search. CPU energy consumption is monitored over time with the command line tool described in Appendix C.

Evaluation Metric:

We compare our method against the baselines on multiple evaluation metrics: (1) Precision@k (P@k) for multi-label classification tasks. (2) Label Recall, reported as Label Retrieval Rate in Table 1, is the fraction of ground truth labels that appear in the retrieved set. (3) Time is the average wall-clock time, in seconds, for passing 1000 test samples through the last layer. (4) Energy Consumption, measured in Joules, is the average CPU power (Watts) over the inference period multiplied by the inference time, reported per 1000 samples. (5) Collision Probability is the probability that a pair of inputs is hashed to the same bucket of a fixed hash table. We expect positive pairs to have high collision probabilities and negative pairs to have low collision probabilities.
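For concreteness, the metrics can be computed as below. This is a sketch assuming ranked predictions are lists of class ids and labels are sets; the energy helper simply multiplies average power by elapsed time and normalizes to 1000 samples, as described above. Function names and example numbers are hypothetical.

```python
def precision_at_k(ranked, labels, k):
    """P@k: fraction of the top-k predictions that are ground-truth labels."""
    return sum(1 for i in ranked[:k] if i in labels) / k


def label_recall(retrieved, labels):
    """Fraction of ground-truth labels that appear in the retrieved set."""
    return len(set(retrieved) & set(labels)) / len(labels)


def energy_per_1000(avg_power_watts, total_seconds, n_samples):
    """Joules per 1000 samples: average power x time, normalized per 1000 samples."""
    return avg_power_watts * total_seconds * 1000 / n_samples


ranked, labels = [4, 7, 1, 9, 3], {7, 3, 8}
print(precision_at_k(ranked, labels, 1))   # 0.0
print(precision_at_k(ranked, labels, 5))   # 0.4
print(label_recall([1, 3, 4, 7], labels))  # 2 of 3 labels retrieved
print(energy_per_1000(50.0, 2.0, 4000))    # 25.0
```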

4.1 Main Result

In this section, we compare LSS with the baseline methods on the trade-off between accuracy and efficiency. For each method, we aim to minimize the time and energy spent on inference while maximizing P@1 and P@5. Following this strategy, we report the best performance of LSS and all other baselines on the four datasets in Table 1. From this table, we observe that: (1) Among the approximate methods, LSS achieves the best P@1 and P@5 on Text8, Wiki10-31K, and Wiki-Text2, and on Delicious200K it stays within 0.015 of full inference on both metrics. (2) LSS uses the smallest sample size, at most 6% of the neurons in the output layer. Furthermore, on larger datasets, LSS samples even fewer neurons: in Delicious200K, the output space has over 200K neurons, while LSS uses only 424 neurons for inference computation; in Text8, the output dimension is 1,355,336, while LSS uses only 965 neurons on average. LSS can match full accuracy with much less computation. (3) Most importantly, LSS achieves up to a 5.1x reduction in time and an 8.2x reduction in energy consumption.

The experimental results validate our argument regarding the gap between MIPS and inference, as well as our choice of using hash tables. (1) We observe that MIPS approximation algorithms usually have a low label retrieval rate. (2) Even though other baselines only visit a small portion of neurons, they fail to achieve consistent speed up. These observations validate our reason for choosing a hashing-based approach, as it is easier to parallelize. Previous works zhang2018navigating; chen2018learning compared the performances of different methods under a single CPU thread setting, which is not a practical simulation for real-world cloud systems. Moreover, methods such as ip-NSW or GD are ill-suited to exploit the full parallelization offered by multi-core CPUs and tend to have a large number of irregular memory accesses. This limitation significantly degrades their performance even compared to exact MIPS computation on CPU with the current PyTorch framework.

Based on the above results, we answer the first question from the beginning of the section: compared to full inference, LSS can perform inference with comparable accuracy using as little as 12% of the energy and under 20% of the time (8.70 J vs. 71.34 J and 0.81 s vs. 4.16 s per 1000 samples on Delicious200K). Moreover, even in scenarios where full inference is highly parallelizable and outperforms all other approximate MIPS approaches, LSS still achieves the best efficiency.
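The ratios quoted above follow directly from the Delicious200K rows of Table 1; this short computation reproduces them (the values are copied from the table, nothing else is assumed):

```python
t_lss, t_full = 0.81, 4.16   # seconds per 1000 samples (Delicious200K, Table 1)
e_lss, e_full = 8.70, 71.34  # Joules per 1000 samples

speedup = t_full / t_lss
energy_ratio = e_full / e_lss
print(f"{speedup:.1f}x faster, {energy_ratio:.1f}x less energy")  # 5.1x faster, 8.2x less energy
print(f"LSS uses {e_lss / e_full:.1%} of the energy and {t_lss / t_full:.1%} of the time")
```

The same numbers give the 87% energy reduction cited in the abstract, since $1 - 8.70/71.34 \approx 0.878$.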

p@1 0.4245 0.4391 0.1079 0.0693 0.4362
p@5 0.3473 0.3619 0.1180 0.0256 0.3581
Sample size 424 full full 3000 3000
Label Retrieval Rate 0.889 1 0.3464 0.0900 0.7000
Avg. Time Per 1000 samples (s) 0.81 (5.1x) 4.16 10.51 2.45 2.29
Avg. Energy Per 1000 samples (J) 8.70 (8.2x) 71.34 116.61 33.88 29.05
Results for Delicious200K

p@1 0.9132 0.9129 0.1631 0.8299 0.9129
p@5 0.7404 0.7370 0.1631 0.6652 0.7370
Sample size 965 full full 3000 3000
Label Retrieval Rate 1 1 0.5842 0.8977 0.9908
Avg. Time Per 1000 samples (s) 0.56 (3.3x) 1.88 13.92 2.07 2.09
Avg. Energy Per 1000 samples (J) 4.98 (6.4x) 31.99 174.49 22.66 20.40
Results for Text8

p@1 0.8018 0.8232 0.3309 0.3207 0.7636
p@5 0.4822 0.5700 0.3259 0.1603 0.4790
Sample size 559 full full 1500 3000
Label Retrieval Rate 0.9779 1 0.8905 0.4854 0.9163
Avg. Time Per 1000 samples (s) 0.39 (1.9x) 0.76 4.06 1.65 1.69
Avg. Energy Per 1000 samples (J) 3.53 (3.0x) 10.69 39.28 15.75 15.80
Results for Wiki10-31K

p@1 0.4265 0.4044 0.2234 0.0750 0.1369
p@5 0.0837 0.0774 0.0430 0.0271 0.0478
Sample size 3071 full full 5365 4956
Label Retrieval Rate 0.9284 1 0.6654 0.8705 0.9215
Avg. Time Per 1000 samples (s) 0.36 (1.7x) 0.63 10.57 1.60 1.76
Avg. Energy Per 1000 samples (J) 3.20 (2.9x) 9.31 128.76 23.92 27.07
Results for Wiki-Text2
Table 1: Baseline comparisons on various datasets
Figure 2: The blue line plots the collision probability between positive pairs, and the green line plots the collision probability between negative pairs. For all experiments, the batch size for hash function training is 256. The starting point represents the probabilities under random SimHash.

4.2 Inner Metrics of LSS

Collision Probability: We investigate the process of training LSS hash functions. Figure 2 shows the collision probability for both the positive and the negative pairs we collect. In both Text8 and Delicious200K, the collision probability between positive pairs increases and converges to a level above 0.9. Meanwhile, the collision probability between negative pairs decreases throughout training: in Delicious200K it converges to around 0.1, while in Text8 it converges close to zero. This observation helps explain the significantly higher label retrieval rate of LSS: the learned hash functions find projections that assign higher collision probabilities to label neurons with high inner products. As mentioned in Section 1, the average inner product rank of label neurons among all 205,443 neurons is only 498.14. Since hashing does not change the inner products themselves, the improvement shows up within the retrieved set: with the LSS-learned hash functions, label neurons reach an average inner product rank of 7.77 among the retrieved neurons.

Choosing K and L: We investigate the choice of $K$, the number of projections, and $L$, the number of hash tables, on Delicious200K. $K$ and $L$ are the two main hyperparameters affecting computation time and accuracy. The objective of this experiment is to establish the robustness of LSS's accuracy across various sample sizes, which directly relates to the trade-off between accuracy and inference efficiency. As reported in Table 2, $K = 4$ and $L = 1$ leads to the most efficient inference with tolerable accuracy loss. We choose this setting because it requires fewer hash code computations (determined by $K$), fewer table lookups (determined by $L$), and a smaller last-layer matrix multiplication (determined by the sample size). On the other hand, P@1 and P@5 do not vary much as $K$ and $L$ change. This suggests that the learning objectives consistently guide LSS toward a set of retrieved neurons with decent P@1 and P@5, largely independent of the hash table parameters.

K=4 K=6 K=8
P@1 P@5 Sample Size P@1 P@5 Sample Size P@1 P@5 Sample Size
L=1 0.4245 0.3473 424 NA NA 0 NA NA 0
L=10 0.4602 0.3676 2560 0.4488 0.3733 875.53 0.4408 0.3598 153.31
L=50 0.4405 0.3659 15568 0.4455 0.3599 2122.47 0.4457 0.3615 360
Table 2: Effect of K and L on P@1, P@5, and sample size for the Delicious-200K dataset.

Energy Efficiency: We summarize the energy usage of LSS in Table 1. On all datasets, we observe significant energy reduction. To separate the effect of the learning mechanism from that of the data structure alone, we measure the energy usage of SLIDE, which uses random SimHash. On the Wiki10-31K dataset, we initialize the hash tables with random projections and then tune $K$ and $L$. Using the same optimization criterion described in Section 4.1, SLIDE retrieves 741 neurons on average and takes 0.68 s and 9.89 J to process 1000 samples. Compared to full inference (0.76 s and 10.69 J), SLIDE requires less time and energy, which demonstrates the benefit of hashing-based retrieval: hash code calculation and hash table lookup are both energy-efficient operations. LSS, in turn, outperforms SLIDE (0.39 s and 3.53 J). This further reduction directly relates to its smaller sample size (559 vs. 741 neurons) and demonstrates the power of our learning mechanism.

Based on the analysis of internal metrics, we answer the second question: the learning process in LSS generates better projections, which benefit (1) label neuron retrieval, (2) robustness to hash table parameters, and (3) energy and time efficiency.

4.3 Accuracy Advantage

We now demonstrate the potential of surpassing full inference accuracy with the right retrieval mechanism. Here we change the optimization objective: instead of finding the optimal accuracy-efficiency trade-off, we search for LSS parameters that outperform full inference in accuracy with sub-sampled neurons. Table 3 presents the highest accuracy LSS achieves on each dataset. On Delicious-200K and Wiki-Text2, LSS outperforms full computation in both P@1 and P@5. On Text8, LSS matches full inference in both metrics, and on Wiki10-31K it matches P@1 but trails in P@5 (0.4822 vs. 0.5700). Based on these results, we answer the third question: LSS is capable of surpassing full inference by sampling a subset of neurons for each input embedding, which suggests its strength in distinguishing label neurons.

Dataset Wiki10-31K Delicious200K Text8 Wiki-Text2
LSS p@1 0.8232 0.4596 0.9129 0.4265
LSS p@5 0.4822 0.3676 0.7370 0.0837
Sample size 2372 1487 1008 3071
Full p@1 0.8232 0.4391 0.9129 0.4044
Full p@5 0.5700 0.3619 0.7370 0.0774
Table 3: Highest accuracy achieved by LSS on the four datasets, compared with full inference.

5 Conclusion

In this paper, we introduce LSS, a hashing-based approach that performs energy- and time-efficient inference for neural networks with wide output layers. We present a novel problem formulation that identifies and bridges the gap between efficient inference and maximum inner product search (MIPS). We propose a combination of hashing-based data structures and hyperplane learning objectives for efficient retrieval of label neurons. We present a novel index update loss that dynamically adapts the hash functions to reduce the sample size while preserving prediction accuracy. We compare LSS against graph-based MIPS methods, direct MIPS solvers, and exact full inference on four real-world datasets. LSS substantially outperforms the MIPS baselines on both accuracy and efficiency metrics, with up to 8x energy reduction and up to 5x speedup compared to ideally parallelized full inference.

Broader Impact

This paper proposes a novel method for efficient inference on models with large output layers, which have a wide range of real-world applications such as recommendation systems and language models. According to Facebook, its deep neural network-based recommendation systems consume more than 70% of the cloud server's workload. Such large models require massive computation and lead to high electricity usage and CO2 emissions. We show that our method achieves 60-80% energy reduction with negligible accuracy loss. Our method is also orthogonal to quantization, which is recognized as the go-to way to reduce energy usage by 30-40%. We believe our work takes a solid step toward more environmentally and financially friendly machine learning.


Appendix A Related Literature

A.1 Maximum Inner Product Search

The deployment of deep neural network (NN) models in cloud services often involves a Wide Output Layer (WOL). For recommendation NN models [Bhatia16, xue2017deep], the number of neurons in the WOL equals the number of items to be recommended; for language modelling, it equals the vocabulary size. This large output space becomes the computational bottleneck. In this paper, we specifically focus on inference efficiency, which is a major concern for deployment in a cloud computing setting.

To tackle this inefficiency, various methods focus on efficiently retrieving the Top-K logits produced by the NN model. Most of these methods can be categorized as approximate Maximum Inner Product Search (MIPS). Formally, given a set $S = \{x_1, \dots, x_n\} \subset \mathbb{R}^d$ containing all neurons of the WOL as high-dimensional vectors, and an input embedding $q \in \mathbb{R}^d$ to the output layer as the query, we aim to develop an efficient algorithm for computing

$$p = \arg\max_{x \in S} \; q^\top x.$$

There are two main categories of efficient inference via MIPS. The first reduces MIPS to classical approximate nearest neighbor search (ANNS) [shrivastava2014asymmetric, shrivastava2014improved] in two steps: (1) pre-process each data vector $x$ to $P(x)$ and the query $q$ to $Q(q)$ asymmetrically, so that $\arg\max_{x \in S} q^\top x = \arg\min_{x \in S} D(P(x), Q(q))$, where $D$ is the cosine or Euclidean distance; (2) perform ANNS with indexing structures such as quantization [JDH17] or small world graphs [malkov2012scalable, malkov2014approximate, malkov2018efficient]. The second category performs inner product search directly, without the reduction. Graph-based indices suit this direct search because their edge definitions can be flexibly modified [morozov2018non, zhou2019mobius, tan2019efficient].
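
To make the reduction in step (1) concrete, the sketch below compares exact top-k MIPS by brute force with one well-known norm-augmentation transform, chosen here for illustration and not necessarily the transform used in the cited works. Each data vector x is mapped to P(x) = [x/M, sqrt(1 - ||x/M||^2)], where M is the largest data norm, and the query q to Q(q) = [q/||q||, 0]. Every P(x) then has unit norm, so the cosine similarity between P(x) and Q(q) equals q·x / (M ||q||), and ranking by cosine reproduces the MIPS ranking. All function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


def top_k_mips(neurons, query, k):
    """Exact MIPS by brute force: score every neuron and keep the k largest."""
    return sorted(range(len(neurons)), key=lambda i: dot(neurons[i], query), reverse=True)[:k]


def top_k_cosine(vectors, query, k):
    """Exact nearest neighbours under cosine similarity (largest cosine first)."""
    qn = norm(query)
    return sorted(range(len(vectors)),
                  key=lambda i: dot(vectors[i], query) / (norm(vectors[i]) * qn),
                  reverse=True)[:k]


def transform_data(neurons):
    """P(x) = [x / M, sqrt(1 - ||x / M||^2)], with M the largest neuron norm."""
    m = max(norm(x) for x in neurons)
    out = []
    for x in neurons:
        scaled = [v / m for v in x]
        # max(0, .) guards against tiny negative values caused by rounding
        out.append(scaled + [math.sqrt(max(0.0, 1.0 - dot(scaled, scaled)))])
    return out


def transform_query(query):
    """Q(q) = [q / ||q||, 0]."""
    qn = norm(query)
    return [v / qn for v in query] + [0.0]


if __name__ == "__main__":
    neurons = [[0.1, 0.1], [3.0, 1.0], [0.5, 2.0], [-1.0, 1.0]]
    query = [1.0, 1.0]
    print(top_k_mips(neurons, query, 2))                                     # [1, 2]
    print(top_k_cosine(neurons, query, 1))                                   # [0]: cosine ignores norms
    print(top_k_cosine(transform_data(neurons), transform_query(query), 2))  # [1, 2]
```

In this toy example, plain cosine similarity prefers the tiny neuron 0 because it points in exactly the query's direction, while the largest inner products belong to neurons 1 and 2. After the transform, an ordinary cosine-based ANNS index recovers the MIPS answer.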

A.2 Locality Sensitive Hashing

In this section, we briefly review locality sensitive hashing (LSH) [Proc:Indyk_STOC98, indyk2006polylogarithmic]. The high-level idea of LSH is to place similar items into the same bucket of a hash table with high probability. Formally, we consider $\mathcal{H}$, a family of hash functions that maps $\mathbb{R}^d$ to some set $\mathcal{S}$.

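Definition 4 (LSH Family).

A family $\mathcal{H}$ is called $(S_0, cS_0, p_1, p_2)$-sensitive if, for any two points $x, y \in \mathbb{R}^d$ and $h$ chosen uniformly from $\mathcal{H}$, the following holds:

  • if $\mathrm{Sim}(x, y) \ge S_0$, then $\Pr_{\mathcal{H}}[h(x) = h(y)] \ge p_1$;

  • if $\mathrm{Sim}(x, y) \le cS_0$, then $\Pr_{\mathcal{H}}[h(x) = h(y)] \le p_2$.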

Typically, $p_1 > p_2$ and $c < 1$ are required. The algorithm uses two parameters, $(K, L)$. We construct $L$ independent hash tables from the collection $\mathcal{C}$ of data vectors. Each hash table has a meta-hash function $H$ formed by concatenating $K$ random independent hash functions from $\mathcal{H}$. Given a query, we collect one bucket from each hash table and return the union of these buckets. Intuitively, the meta-hash function makes the buckets sparse and reduces the number of false positives, because only valid nearest-neighbor items are likely to match all $K$ hash values for a given query. Taking the union of the buckets reduces the number of false negatives by increasing the number of buckets that could hold valid nearest-neighbor items. A sufficient condition for a hash family to be an LSH family is that the collision probability is a monotonically increasing function of the similarity, i.e.,

$$\Pr_{\mathcal{H}}[h(x) = h(y)] = f(\mathrm{Sim}(x, y)),$$

where $f$ is a monotonically increasing function.
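
As a concrete instance of the monotone-collision condition, signed random projections (random hyperplanes, often called SimHash) form a standard LSH family for cosine similarity: for h(x) = sign(r·x) with a Gaussian vector r, Pr[h(x) = h(y)] = 1 - θ(x, y)/π, which increases with cos θ. The sketch below assumes this family purely for illustration and also evaluates the standard (K, L) amplification: with K concatenated bits per table and L tables, an item whose single-bit collision probability is p lands in at least one of the query's buckets with probability 1 - (1 - p^K)^L. Names and parameter values are hypothetical.

```python
import math
import random


def simhash_collision_prob(cos_sim):
    """f(s) = 1 - arccos(s) / pi for one random-hyperplane bit; increasing in s."""
    return 1.0 - math.acos(cos_sim) / math.pi


def retrieval_prob(p, k, l):
    """Chance that an item shares the query's bucket in at least one of l tables,
    when each table key concatenates k independent bits that each collide w.p. p."""
    return 1.0 - (1.0 - p ** k) ** l


def empirical_collision(x, y, trials, rng):
    """Monte Carlo estimate of Pr[sign(r . x) == sign(r . y)] with Gaussian r."""
    hits = 0
    for _ in range(trials):
        r = [rng.gauss(0.0, 1.0) for _ in x]
        hx = sum(a * b for a, b in zip(r, x)) >= 0
        hy = sum(a * b for a, b in zip(r, y)) >= 0
        hits += hx == hy
    return hits / trials


if __name__ == "__main__":
    p_close = simhash_collision_prob(math.cos(math.pi / 4))  # 45-degree pair
    p_far = simhash_collision_prob(0.0)                      # orthogonal pair
    print(p_close, empirical_collision([1.0, 0.0], [1.0, 1.0], 20000, random.Random(0)))
    print(retrieval_prob(p_close, 4, 10), retrieval_prob(p_far, 4, 10))
```

With K = 4 and L = 10, a pair at 45 degrees (p = 0.75) is retrieved with probability about 0.98, while an orthogonal pair (p = 0.5) is retrieved with probability about 0.48. Raising K pushes down the retrieval probability of dissimilar items, and raising L restores it for similar ones.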

The overall algorithm for generating nearest-neighbor candidates works in two phases (see [spring2017new, chen2018lsh] for details):

Pre-processing Phase: We construct $L$ hash tables from the data by storing all elements $x \in \mathcal{C}$. We store only pointers to the vectors in the hash tables, because storing whole data vectors would be very memory inefficient.

Query Phase: Given a query $q$, we search for its nearest neighbors by taking the union of the buckets collected from the $L$ hash tables. Note that we do not scan all the elements in $\mathcal{C}$; we only probe $L$ different buckets, one per hash table. After generating the set of potential candidates, we compute the distance between the query and each candidate, and sort the candidates to find the nearest neighbor.
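
The two phases can be sketched as a small index built from random-hyperplane hashes. This is an illustrative sketch under stated assumptions, not the implementation of [spring2017new, chen2018lsh] or of LSS: each of the L tables is keyed by K sign bits, the tables store integer ids rather than vectors (mirroring the pointer-only storage described above), and candidates are re-ranked by exact inner product to match the MIPS setting, where a distance could be used instead. The class and method names are hypothetical.

```python
import random
from collections import defaultdict


class HyperplaneLSHIndex:
    """L hash tables; each table's meta-hash concatenates K random-hyperplane sign bits.
    Buckets hold integer ids (pointers into self.data), not copies of the vectors."""

    def __init__(self, dim, k, l, seed=0):
        rng = random.Random(seed)
        # planes[t][j] is the normal vector of the j-th hyperplane used by table t
        self.planes = [[[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(k)]
                       for _ in range(l)]
        self.tables = [defaultdict(list) for _ in range(l)]
        self.data = []

    def _key(self, t, v):
        # Meta-hash of table t: a tuple of K bits, one per hyperplane
        return tuple(sum(a * b for a, b in zip(p, v)) >= 0 for p in self.planes[t])

    def build(self, vectors):
        # Pre-processing phase: put each id into exactly one bucket per table
        self.data = vectors
        for i, v in enumerate(vectors):
            for t, table in enumerate(self.tables):
                table[self._key(t, v)].append(i)

    def query(self, q, top=1):
        # Query phase: probe one bucket per table and take the union of the ids
        candidates = set()
        for t, table in enumerate(self.tables):
            candidates.update(table.get(self._key(t, q), []))
        # Only the candidates are scored exactly (here by inner product)
        ranked = sorted(candidates,
                        key=lambda i: sum(a * b for a, b in zip(self.data[i], q)),
                        reverse=True)
        return ranked[:top], len(candidates)
```

A query identical to a stored vector always falls into that vector's bucket in every table, so it is guaranteed to be among the candidates; how many other items get scored depends on K and L, which is the cost that the sample size in Table 3 measures.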

Appendix B Experiment Details

B.1 Dataset Statistics

We present experiments on four datasets. The first two, Wiki10-31k and Delicious-200K, are obtained from the Extreme Classification Repository [Bhatia16], a benchmark for various recommendation systems. Each Extreme Classification dataset uses Bag-of-words (BoW) features as input and a multi-hot label vector as output. For language modelling, we use two datasets with two models. For the Word2vec model, we use the text8 dataset [text8] with three preprocessing steps: (1) remove words with frequency less than 2 from the vocabulary and replace each removed word with 'UNK'; (2) represent each word in the document as a one-hot input vector; (3) for each input word, represent its 25 preceding and 25 following words as a multi-hot vector, and use this vector as the label. We also use an RNN-based language model trained on the Wiki-Text-2 dataset [wikitext2]. Here, given a 35-word sequence as a multi-hot input vector, the task is to predict the next 35-word sequence, so the label vector is also multi-hot. Table 4 summarizes the datasets.
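
The three text8 preprocessing steps can be sketched as follows. Assumptions made for illustration: the vocabulary is indexed in sorted order, each window is simply truncated at the start and end of the corpus, and the multi-hot label is stored as a sorted list of active indices rather than a dense vector. The function name `build_skipgram_examples` is hypothetical; the default window matches the 25 words on each side described above.

```python
from collections import Counter


def build_skipgram_examples(tokens, window=25, min_count=2):
    """Return (vocab, examples); each example is (center_id, sorted context ids)."""
    counts = Counter(tokens)
    # Step 1: words seen fewer than min_count times are replaced by 'UNK'
    tokens = [w if counts[w] >= min_count else "UNK" for w in tokens]
    vocab = {w: i for i, w in enumerate(sorted(set(tokens)))}
    ids = [vocab[w] for w in tokens]
    examples = []
    for pos, center in enumerate(ids):
        # Step 2: the input is the one-hot index of the center word
        # Step 3: the label is the multi-hot set of words within `window` on each side
        lo, hi = max(0, pos - window), min(len(ids), pos + window + 1)
        context = {ids[j] for j in range(lo, hi) if j != pos}
        examples.append((center, sorted(context)))
    return vocab, examples
```

For example, with the toy corpus "the cat sat on the mat the end" and window=1, only "the" survives the frequency filter, so the vocabulary is {"UNK": 0, "the": 1} and the first example is (1, [0]).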

Dataset Wiki10-31k Delicious-200K Text8 Wiki-Text-2
Output Dimension 30938 205443 1355336 50000
Input Dimension 101938 782585 1355336 50000
Training Samples 14146 6616 11903644 725434
Testing Samples 196606 100095 5101563 245550

Table 4: Summary of the dimensions and sample counts of our benchmark datasets

B.2 Tasks and Models

Extreme Classification

The Extreme Classification model predicts labels from an ultra-high-dimensional label space, given ultra-high-dimensional BoW input features. The network architecture is: (1) an embedding layer that maps the multi-hot input vector into a dense 128-dimensional vector; (2) a ReLU activation function; (3) an output layer with as many neurons as the label dimensionality.

Word2vec The Word2vec model predicts the neighboring words given the central word. The network architecture is: (1) an embedding layer that maps the one-hot input vector into a dense 128-dimensional vector; (2) a ReLU activation function; (3) an output layer with as many neurons as the vocabulary size.
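
The embedding, ReLU, and output-layer architecture shared by the Extreme Classification and Word2vec models can be sketched as a plain forward pass. Multiplying a multi-hot (or one-hot) input by the embedding matrix is just summing the embedding rows of the active inputs, so the embedding step is cheap; the output layer, by contrast, computes one inner product per output neuron. The bias term and all names are assumptions made for illustration.

```python
def forward(active_inputs, embedding, output_weights, output_bias):
    """Embedding -> ReLU -> wide output layer for a multi-hot (or one-hot) input.

    active_inputs:  indices of the non-zero input features (each assumed to be 1)
    embedding:      input_dim rows, each a list of length d
    output_weights: one row of length d per output neuron
    output_bias:    one bias per output neuron (assumed; the text does not specify)
    """
    d = len(embedding[0])
    # Multi-hot input times the embedding matrix = sum of the active rows
    h = [0.0] * d
    for i in active_inputs:
        h = [a + b for a, b in zip(h, embedding[i])]
    h = [max(0.0, v) for v in h]  # ReLU
    # Wide output layer: one inner product per output neuron (the expensive part)
    return [sum(w * v for w, v in zip(row, h)) + b
            for row, b in zip(output_weights, output_bias)]
```

Only the last line grows with the number of output neurons, which is why the approaches discussed in this paper target the output layer.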

RNN The RNN language model predicts the next sequence of words given the current sequence of words. From input to output, the network architecture is: (1) an embedding layer that maps the multi-hot input vector into a dense 200-dimensional vector; (2) a first dropout layer; (3) two LSTM layers with hidden size 200; (4) a second dropout layer; (5) an output layer with as many neurons as the vocabulary size.
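
A quick count shows why the output layer dominates inference cost in all three models. Assuming a dense output layer and ignoring the bias and activation costs, the number of multiply-accumulate operations (MACs) for one output-layer evaluation is the output dimension from Table 4 times the width of the layer feeding it (128 for the Extreme Classification and Word2vec models, 200 for the RNN). For the RNN, this is the cost per output-layer evaluation.

```python
# (output dimension from Table 4, width of the layer feeding the output layer)
MODELS = {
    "Wiki10-31k": (30938, 128),
    "Delicious-200K": (205443, 128),
    "Text8 (Word2vec)": (1355336, 128),
    "Wiki-Text-2 (RNN)": (50000, 200),
}


def output_layer_macs(n_out, hidden):
    """Multiply-accumulate operations for one dense output-layer evaluation."""
    return n_out * hidden


for name, (n_out, hidden) in MODELS.items():
    print(f"{name}: {output_layer_macs(n_out, hidden):,} MACs per evaluation")
```

This gives 3,960,064 MACs for Wiki10-31k, 26,296,704 for Delicious-200K, 173,483,008 for Text8, and 10,000,000 for Wiki-Text-2. Scoring only a sampled subset of neurons scales this cost down in proportion to the sample size.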

B.3 Index Update Loss (IUL)

  • positive pair

  • negative pair

We use the Hamming distance to quantify the difference between the hash codes of a training pair. Since $\mathrm{sign}(\cdot)$ is a discrete function, we use $\tanh(\cdot)$ as a differentiable approximation. For $K$-bit codes $h_i, h_j \in \{-1, 1\}^K$, the Hamming distance satisfies $\mathrm{dist}_H(h_i, h_j) = \frac{1}{2}\left(K - \langle h_i, h_j \rangle\right)$ [hashnet]. Therefore, if the inner product between the query embedding and a particular neuron is high, they tend to have similar hash codes. We thus propose an Index Update Loss, based on the Hamming distance, that updates the hash functions (random hyperplanes) with the collected positive and negative pairs. Formally,


where , , and .
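
The Hamming identity used above follows from counting bits: for codes with entries in {-1, +1}, each agreeing bit contributes +1 and each disagreeing bit contributes -1 to the inner product, so the inner product equals K - 2 dist_H. The sketch below checks this and shows the tanh relaxation of sign(·) applied to random-hyperplane projections. It is only the differentiable Hamming surrogate, not the full Index Update Loss, whose pair weighting is not reproduced here; all names are hypothetical.

```python
import math


def hamming(a, b):
    """Number of positions where two {-1, +1} codes disagree."""
    return sum(x != y for x, y in zip(a, b))


def hamming_via_inner(a, b):
    """dist_H = (K - <a, b>) / 2 for K-bit {-1, +1} codes."""
    return 0.5 * (len(a) - sum(x * y for x, y in zip(a, b)))


def relaxed_codes(planes, v):
    """tanh(w . v) for each hyperplane w: a smooth stand-in for sign(w . v)."""
    return [math.tanh(sum(w_i * v_i for w_i, v_i in zip(w, v))) for w in planes]


def relaxed_hamming(planes, x, y):
    """Differentiable Hamming surrogate: the identity above applied to tanh codes."""
    return hamming_via_inner(relaxed_codes(planes, x), relaxed_codes(planes, y))
```

Because tanh saturates to ±1 for large projections, the surrogate approaches the true Hamming distance of the sign codes, while remaining differentiable with respect to the hyperplanes for small projections.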

Appendix C Energy Measurement

We measure the energy consumption of each inference method with a monitoring tool. Command-line utilities, including s-tui, were used to monitor CPU power consumption in watts (joules per second) at 1-second intervals throughout inference, for each dataset and each method. The baseline power consumption is subtracted from the average power during inference, so that we measure and compare the energy expenditure of only the inference step of each method.
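
Turning the 1-second power readings into an energy figure amounts to subtracting the baseline power and multiplying by the inference time. A minimal sketch, assuming the readings have already been collected (for example with a monitoring tool such as s-tui) into a list of watts sampled at a fixed interval; the function name and the numbers in the example are hypothetical, not measurements from this paper.

```python
def inference_energy_joules(power_samples_w, base_power_w, interval_s=1.0):
    """Energy attributable to inference: (mean power - baseline power) * duration.

    power_samples_w: CPU power readings in watts, one every `interval_s` seconds
    base_power_w:    baseline (idle) CPU power in watts
    """
    duration_s = len(power_samples_w) * interval_s
    mean_power_w = sum(power_samples_w) / len(power_samples_w)
    return (mean_power_w - base_power_w) * duration_s
```

For instance, hypothetical readings of 50, 52, 51, and 49 W over four seconds with a 20 W baseline give (50.5 - 20) × 4 = 122 J; relative energy savings between methods are then ratios of such values.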