learn-to-hash
Neural LSH [ICLR 2020] - Using supervised learning to produce better space partitions for fast nearest neighbor search.
Most of the efficient sublinear-time indexing algorithms for the high-dimensional nearest neighbor search problem (NNS) are based on space partitions of the ambient space R^d. Inspired by recent theoretical work on NNS for general metric spaces [Andoni, Naor, Nikolov, Razenshteyn, Waingarten STOC 2018, FOCS 2018], we develop a new framework for constructing such partitions that reduces the problem to balanced graph partitioning followed by supervised classification. We instantiate this general approach with the KaHIP graph partitioner [Sanders, Schulz SEA 2013] and neural networks, respectively, to obtain a new partitioning procedure called Neural Locality-Sensitive Hashing (Neural LSH). On several standard benchmarks for NNS, our experiments show that the partitions found by Neural LSH consistently outperform partitions found by quantization- and tree-based methods.
The Nearest Neighbor Search (NNS) problem is defined as follows. Given an n-point dataset P in a d-dimensional Euclidean space R^d, we want to preprocess P to answer k-nearest neighbor queries quickly. That is, given a query point q ∈ R^d, we want to find the k data points from P that are closest to q. NNS is a cornerstone of modern data analysis and, at the same time, a fundamental geometric data structure problem that has led to many exciting theoretical developments over the past decades. See, e.g., [WLKC16, AIR18] for an overview.
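To make the problem statement concrete, here is a minimal exact k-NN baseline: a linear scan that computes all n distances per query. It is only a sketch of the definition (and of the per-query cost that sublinear indexing tries to avoid); the names `euclidean` and `knn_brute_force` are hypothetical and not taken from the paper's code.

```python
import heapq
import math
from typing import List, Sequence

Point = Sequence[float]


def euclidean(a: Point, b: Point) -> float:
    """Euclidean distance between two d-dimensional points."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def knn_brute_force(dataset: List[Point], q: Point, k: int) -> List[int]:
    """Indices of the k points of `dataset` closest to q, nearest first.

    Exact linear scan: O(n * d) work per query. Ties are broken by index.
    """
    return heapq.nsmallest(
        k, range(len(dataset)), key=lambda i: (euclidean(dataset[i], q), i)
    )


P = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
print(knn_brute_force(P, (0.2, 0.1), k=2))  # [0, 1]
```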
The two main approaches to constructing efficient NNS data structures are indexing and sketching. The goal of indexing is to construct a data structure that, given a query point, produces a small subset of P (called the candidate set) that includes the desired neighbors. In contrast, the goal of sketching is to compute compressed representations of points (e.g., compact binary hash codes with the Hamming distance used as an estimator [WSSJ14, WLKC16]) to enable computing approximate distances quickly. Indexing and sketching can be (and often are) combined to maximize performance [JDJ17]. Both indexing and sketching have been the topic of a vast amount of theoretical and empirical literature. In this work, we consider the indexing problem and focus on optimizing the trade-off between three metrics: the number of reported candidates, the fraction of the true nearest neighbors among the candidates, and the computational efficiency of the indexing data structure.
Most efficient indexing methods are based on space partitions (with some exceptions mentioned below). The overarching idea is to find a partition of the ambient space and split the dataset accordingly. Given a query point q, we identify the part containing q and form the resulting list of candidates from the data points residing in the same part. To boost the search accuracy, it is often necessary to add all data points from nearby parts to the candidate list (this is often referred to as the multi-probe technique). Popular indexing methods include locality-sensitive hashing (LSH) [LJW07, AIL15, DSN17]; quantization-based approaches, where partitions are obtained via k-means clustering of the dataset [JDS11, BL12]; and tree-based methods such as random-projection trees or PCA trees [Spr91, BCG05, DS13, KS18].
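The partition-plus-multi-probe scheme just described can be sketched generically. The sketch below assumes a hypothetical `rank_parts` function that maps any point to part IDs ordered from most to least promising; LSH, k-means, trees, or a learned classifier would each supply their own version. Each data point is stored in its top-ranked part, and a query with `probes` > 1 takes the union of several parts.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence

Point = Sequence[float]
# Hypothetical interface: maps a point to part IDs, most promising first.
RankParts = Callable[[Point], List[int]]


class PartitionIndex:
    """Stores each data point in one part; answers queries by probing parts in ranked order."""

    def __init__(self, dataset: List[Point], rank_parts: RankParts):
        self.rank_parts = rank_parts
        self.buckets: Dict[int, List[int]] = defaultdict(list)
        for i, p in enumerate(dataset):
            self.buckets[rank_parts(p)[0]].append(i)  # the part containing p

    def candidates(self, q: Point, probes: int = 1) -> List[int]:
        """Multi-probe: data points from the `probes` highest-ranked parts for q."""
        out: List[int] = []
        for part in self.rank_parts(q)[:probes]:
            out.extend(self.buckets.get(part, []))
        return out


def halfspace_rank(p: Point) -> List[int]:
    """Toy partition into 2 parts by the hyperplane x[0] = 0; the other side is the 2nd probe."""
    return [0, 1] if p[0] < 0 else [1, 0]


data = [(-2.0, 1.0), (-0.5, 0.0), (0.3, 0.0), (4.0, 1.0)]
index = PartitionIndex(data, halfspace_rank)
print(index.candidates((-0.1, 0.0)))            # [0, 1]
print(index.candidates((-0.1, 0.0), probes=2))  # [0, 1, 2, 3]
```

With one probe, the query (-0.1, 0.0) misses data point 2 at (0.3, 0.0), which is about as close to it as point 1; the second probe recovers it at the price of more candidates.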
Recently, there has been a large body of work that studies how modern machine learning techniques (such as neural networks) can help tackle various classic algorithmic problems (for a small sample, see [KBC18, BDSV18, LV18, Mit18]). Similar methods have been used to improve the sketching approach to NNS (more on this below) [WLKC16]. However, when it comes to indexing, only very "rudimentary" unsupervised techniques such as PCA or k-means have been (successfully) used. This state of affairs naturally leads to the following general question:
Can we employ modern (supervised) machine learning techniques to find good space partitions for nearest neighbor search?
In this paper we address the aforementioned challenge and present a new framework for finding high-quality space partitions of R^d by directly optimizing an objective function that quantifies the performance of a partition for the NNS problem. At a high level, our approach consists of two steps. First, we perform balanced partitioning of the k-NN graph built on the data points, where each point is connected to its k nearest neighbors. Then we train a model to solve a supervised classification task, with the data points as inputs and the partition found in the first step as labels (see Figure 1 for an illustration). The resulting classifier induces a partition of the whole space R^d, which is our end result.
The new framework has multiple benefits:
We directly reduce the question of interest (geometric partitioning) to two well-studied problems, namely graph partitioning and supervised learning.
Our reduction is very flexible and uses partitioning and learning in a black-box way. This allows us to plug in various models (linear models, neural networks, etc.) and explore the trade-off between the quality and the algorithmic efficiency of the resulting partitions.
Our framework aims to optimize an objective function that directly controls the quality of a partition (assuming the distribution of queries is similar to the distribution of the data points), whereas many of the previous methods that work well in practice (e.g., k-means) use various proxies instead.
It is important to note that our method is unsupervised; in particular, it does not require any given labeling of the input data points. Instead, we harness supervised learning to extend the solution to a finite unsupervised problem – graph partitioning on a fully observed set of points – to a space partition that generalizes to any unseen point in R^d.
Further, we emphasize the importance of balanced partitions in the indexing problem. In a balanced partition of R^d, all parts contain roughly the same number of data points. Unbalanced partitions lead to large variance in the number of candidates reported for different queries, and hence to an unpredictable computational cost. Conversely, balanced partitions allow us to control the number of candidates by setting the total number of parts in the partition as well as the number of parts probed per query. A priori, it is unclear how to partition R^d so as to respect the balance of a given dataset. This makes the combinatorial portion of our approach particularly useful: balanced graph partitioning is a well-studied problem, and our supervised extension to R^d naturally preserves the balance by virtue of attaining high training accuracy.
We instantiate our framework with the KaHIP algorithm [SS13] for the partitioning step, and with linear models and small neural networks for the learning step. We evaluate our approach on several standard benchmarks for NNS [ABF17] and conclude that, in terms of the quality of the resulting partitions, it consistently outperforms quantization-based and tree-based partitioning procedures while maintaining comparable algorithmic efficiency. In the high-accuracy regime, our framework yields partitions that require processing up to 2.3× fewer candidates than alternative approaches.
As a baseline method we use k-means clustering. It produces a partition of the dataset into m parts in a way that naturally extends to all of R^d: a query point q is assigned to the part with the closest centroid. (More generally, for multi-probe querying, we can rank the parts by the distance of their centroids to q.) This simple scheme produces very high-quality results for indexing.
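A minimal sketch of this baseline follows: plain Lloyd's iterations with centroids initialized from random data points, plus centroid-distance ranking for multi-probe queries. The function names are hypothetical, and production systems would typically use a tuned library implementation with a better initialization (e.g., k-means++); this only illustrates how the partition extends to any query point.

```python
import random
from typing import List, Sequence, Tuple

Point = Sequence[float]
Centroid = Tuple[float, ...]


def dist2(a: Point, b: Point) -> float:
    """Squared Euclidean distance (sufficient for ranking)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(data: List[Point], m: int, iters: int = 20, seed: int = 0) -> List[Centroid]:
    """Plain Lloyd's algorithm with centroids initialized from m random data points."""
    rng = random.Random(seed)
    centroids: List[Centroid] = [tuple(p) for p in rng.sample(data, m)]
    for _ in range(iters):
        clusters: List[List[Point]] = [[] for _ in range(m)]
        for p in data:
            nearest = min(range(m), key=lambda c: dist2(p, centroids[c]))
            clusters[nearest].append(p)
        # Move each centroid to the mean of its cluster; keep it if the cluster is empty.
        centroids = [
            tuple(sum(coord) / len(cl) for coord in zip(*cl)) if cl else centroids[j]
            for j, cl in enumerate(clusters)
        ]
    return centroids


def rank_parts_by_centroid(centroids: List[Centroid], q: Point) -> List[int]:
    """Multi-probe order: part IDs sorted by the distance of their centroid to q."""
    return sorted(range(len(centroids)), key=lambda j: dist2(q, centroids[j]))


data = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
C = kmeans(data, m=2)
order = rank_parts_by_centroid(C, (0.2, 0.2))
print(order)  # the part whose centroid is near the origin comes first
```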
The new framework is inspired by a recent line of theoretical work that studies the NNS problem for general metric spaces [ANN18a, ANN18b]. The two relevant contributions of these works are as follows. First, they prove that graphs embedded into sufficiently "nice" metric spaces (including Euclidean space, but also many others) with short edges but without "dense regions" must have sparse cuts. Second, for the special case of normed spaces defined on R^d, such sparse cuts can be assumed to be induced by geometrically nice subsets of the ambient d-dimensional space. This is directly related to the method developed in the present paper, where the starting point is a sparse (multi-)cut in a graph embedded into R^d, which is then "deformed" into a geometrically nice cut using supervised learning.
On the empirical side, the fastest indexing techniques for the NNS problem are currently graph-based [MY18]. The high-level idea is to construct a graph on the dataset (it can be the k-NN graph, but other constructions are also possible) and then, for each query, perform a walk that eventually converges to the nearest neighbor. Although very fast, graph-based approaches have suboptimal "locality of reference", which makes them less suitable for several modern architectures. For instance, this is the case when the algorithm is run on a GPU [JDJ17] or when the data is stored in external memory [SWQ14]. This justifies further study of partition-based methods.
Machine learning techniques are particularly useful for the sketching approach, leading to a vast body of research under the label "learning to hash" [WSSJ14, WLKC16]. In particular, several recent works employed neural networks to obtain high-quality sketches [LLW15, SDSJ19]. The fundamental difference from our work is that sketching is designed to speed up linear scans over the dataset by reducing the cost of each distance evaluation, while indexing is designed for sublinear-time search by reducing the number of distance evaluations. Note that, in principle, one could use sketches to generate space partitions, since a b-bit sketch induces a partition of R^d into 2^b parts. However, this is a substantially different use of sketches than the one intended in the above-mentioned works. Indeed, we observed that partitions induced by high-quality sketching techniques do not perform well compared to, say, quantization-based partitions.
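To illustrate how a b-bit sketch induces a partition of R^d into 2^b parts, the sketch below uses random-hyperplane sign bits. This is a classic data-oblivious choice assumed here only for simplicity; the learned sketches cited above are constructed differently. Each code value in `range(2**b)` names one cell of the induced partition.

```python
import random
from typing import Callable, Sequence


def make_sign_sketch(d: int, b: int, seed: int = 0) -> Callable[[Sequence[float]], int]:
    """Return a function mapping x in R^d to a b-bit code, i.e. to one of 2**b parts."""
    rng = random.Random(seed)
    # b random hyperplanes through the origin, one per bit.
    planes = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(b)]

    def sketch(x: Sequence[float]) -> int:
        code = 0
        for bit, h in enumerate(planes):
            if sum(hi * xi for hi, xi in zip(h, x)) >= 0:
                code |= 1 << bit  # set this bit when x is on the positive side
        return code

    return sketch


sketch = make_sign_sketch(d=3, b=4)
x = (1.0, 2.0, -0.5)
print(sketch(x), sketch((2.0, 4.0, -1.0)))  # same part: the code depends only on direction
```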
A different application of neural networks related to NNS is to optimize the performance of the nearest neighbor classifier. Given a labeled dataset in a classification setting, the idea is to learn a representation of the dataset – either as sketches [KW17, JZPG17] or as a high-dimensional embedding [ST18] – that would render the nearest neighbor classifier (i.e., labeling each query point with the label of its nearest data point) as accurate as possible. Apart from not producing an indexing method, these works are also different from ours by being inherently supervised, relying on a fully labeled dataset, whereas our approach is unsupervised.
Given a dataset P ⊂ R^d of n points and a number of parts m, our goal is to find a "simple" partition of R^d into m parts with the following properties:
Balanced: The number of data points in each part is not much larger than n/m.
Locality sensitive: For a typical query point q ∈ R^d, most of its k nearest neighbors belong to the same part of the partition. We assume that queries and data points come from similar distributions.
Simple: The partition should admit a compact description. For example, we might look for a space partition induced by hyperplanes.
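One way to quantify why the balance property matters: if every part holds at most a (1 + ε) factor more than the n/m average (ε is a slack parameter introduced here for illustration, not taken from the paper), then probing t parts bounds the candidate count for every query. A short derivation with hypothetical numbers:

```latex
|P \cap \text{part}_j| \le (1+\varepsilon)\,\frac{n}{m} \quad \forall j
\quad\Longrightarrow\quad
|\text{candidates}_t(q)| \le t\,(1+\varepsilon)\,\frac{n}{m},
\qquad
\text{e.g. } n = 10^6,\ m = 1000,\ t = 10,\ \varepsilon = 0.05:\ \ 10 \cdot 1.05 \cdot 1000 = 10\,500.
```

Without balance, the same t probes could return anywhere from a handful of points to a large fraction of P, which is the query-time variance discussed in the introduction.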
First, suppose that the query is chosen as a uniformly random data point, q ∼ P. Let G be the k-NN graph of P, whose vertices are the data points, with each vertex connected to its k nearest neighbors. Then the above problem boils down to partitioning the vertices of G into m parts such that each part contains roughly n/m vertices and the number of edges crossing between different parts is as small as possible (see Figure 1(b)). This balanced graph partitioning problem is extremely well studied, and there are combinatorial partitioning solvers available that produce very high-quality solutions. In our implementation, we use the open-source solver KaHIP [SS13].
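The objective handed to the graph partitioner can be written down directly. The sketch below builds the k-NN graph by brute force (symmetrizing the directed neighbor relation, which is an assumption of this sketch) and scores a given labeling by its number of cut edges and its imbalance. It does not reproduce KaHIP; the labels are assumed to come from some balanced partitioner, and all names are hypothetical.

```python
from typing import List, Sequence, Set, Tuple

Point = Sequence[float]
Edge = Tuple[int, int]


def dist2(a: Point, b: Point) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn_graph(data: List[Point], k: int) -> Set[Edge]:
    """Undirected k-NN graph: {i, j} is an edge if j is among i's k nearest neighbors or vice versa."""
    edges: Set[Edge] = set()
    for i, p in enumerate(data):
        others = sorted((j for j in range(len(data)) if j != i), key=lambda j: dist2(p, data[j]))
        for j in others[:k]:
            edges.add((min(i, j), max(i, j)))
    return edges


def partition_quality(edges: Set[Edge], labels: List[int], m: int) -> Tuple[int, float]:
    """(number of cut edges, imbalance), where imbalance = largest part size / (n / m)."""
    cut = sum(1 for i, j in edges if labels[i] != labels[j])
    sizes = [0] * m
    for part in labels:
        sizes[part] += 1
    return cut, max(sizes) / (len(labels) / m)


data = [(0.0,), (1.0,), (2.0,), (10.0,), (11.0,), (12.0,)]
G = knn_graph(data, k=2)
print(partition_quality(G, [0, 0, 0, 1, 1, 1], m=2))  # (0, 1.0): balanced, no cut edges
print(partition_quality(G, [0, 1, 0, 1, 0, 1], m=2))  # (4, 1.0): balanced but cuts 4 edges
```

Both labelings are perfectly balanced, but only the first keeps every point together with its nearest neighbors, which is what the partitioner is asked to achieve.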
More generally, we need to handle out-of-sample queries, i.e., queries q that are not contained in P. Consider the partition of G (equivalently, of the dataset P) found by the graph partitioner. To turn it into a solution to our problem, we need to extend it to a "simple" partition of the whole space R^d that works well for query points. To accomplish this, we train a model that, given a query point q ∈ R^d, predicts which of the m parts of the graph partition q belongs to (see Figure 1(c)). We use the dataset P as the training set and the graph partition as the labels, i.e., each data point is labeled with the ID of the part containing it. The geometric intuition behind this learning step is that, even though the graph partition is obtained by combinatorial means and in principle might consist of ill-behaved subsets of R^d, in most practical scenarios we expect it to be close to being induced by a simple partition of the ambient space. For example, if the dataset is fairly well distributed on the unit sphere and the number of parts is m = 2, a balanced cut of G should be close to a hyperplane.
The choice of model to train depends on the level of “simplicity” we wish to impose on the final partition . For instance, if we are interested in a hyperplane partition, we can train a linear model using SVM or regression. In this paper, we instantiate the learning step with both linear models and small-sized neural networks. Here, there is a natural tension between the size of the model we train and the accuracy of the resulting classifier, and hence the quality of the partition we produce. A larger model would yield better NNS accuracy, at the expense of computational efficiency. We discuss this more in Section 3.
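For m = 2 and a hyperplane partition, the learning step can be as small as a logistic regression. The sketch below trains one with plain full-batch gradient descent on labels assumed to come from a graph partitioner, then assigns unseen points to a side of the learned hyperplane. It is an illustrative stand-in (the optimizer, learning rate, epoch count, and data are arbitrary choices here), not the paper's training setup; for m > 2 one would use a multinomial (softmax) model instead.

```python
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_logistic(data: List[Point], labels: List[int], lr: float = 0.5,
                   epochs: int = 500) -> Tuple[List[float], float]:
    """Binary logistic regression by full-batch gradient descent on the log-loss."""
    d, n = len(data[0]), len(data)
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for x, y in zip(data, labels):
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y  # dLoss/dz
            for t in range(d):
                grad_w[t] += err * x[t]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def predict_part(w: List[float], b: float, q: Point) -> int:
    """The learned hyperplane w.q + b = 0 splits R^d into part 0 and part 1."""
    return 1 if sum(wi * qi for wi, qi in zip(w, q)) + b > 0 else 0


# Labels as a graph partitioner might return them for m = 2 parts (hypothetical data).
data = [(-1.0, 0.2), (-0.8, -0.3), (-1.2, 0.1), (1.0, -0.1), (0.9, 0.4), (1.1, 0.0)]
labels = [0, 0, 0, 1, 1, 1]
w, b = train_logistic(data, labels)
print(predict_part(w, b, (0.8, 0.0)), predict_part(w, b, (-0.7, 0.0)))  # 1 0
```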
Given a query point q, the trained model can be used to assign it to a part of the partition and to search for nearest neighbors among the data points in that part. In order to achieve high search accuracy, we actually train the model to predict several parts for a given query point that are likely to contain its nearest neighbors. For neural networks, this can be done naturally by taking the several largest outputs of the last layer. By searching through more parts (in the order of preference predicted by the model), we can achieve better accuracy, allowing for a trade-off between computational resources and accuracy.
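Turning the model's output into a multi-probe candidate list amounts to ranking parts by score. The minimal sketch below assumes the last layer returns one logit per part and that the index stores data point IDs per part in `buckets` (a hypothetical structure); taking more parts (larger t) grows the candidate list and, typically, the accuracy.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn last-layer outputs into a probability distribution over parts."""
    mx = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - mx) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_parts(logits: Sequence[float], t: int) -> List[int]:
    """The t most likely parts, best first (same order as sorting the raw logits)."""
    probs = softmax(logits)
    return sorted(range(len(probs)), key=lambda j: -probs[j])[:t]


def multiprobe_candidates(logits: Sequence[float], buckets: Dict[int, List[int]], t: int) -> List[int]:
    """Union of the data points stored in the t most likely parts."""
    return [i for part in top_parts(logits, t) for i in buckets.get(part, [])]


logits = [0.1, 2.0, 1.5, -1.0]                 # hypothetical network output for one query
buckets = {0: [3], 1: [0, 4], 2: [1], 3: [2]}  # part ID -> data point IDs
print([len(multiprobe_candidates(logits, buckets, t)) for t in range(1, 5)])  # [2, 3, 4, 5]
```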
When the required number of parts is large, it pays off, for efficiency, to produce the partition in a hierarchical manner. Namely, we first find a partition of R^d into m parts, then recursively partition each of the parts into m parts, and so on, repeating the partitioning for L levels (see Figure 2 for an illustration). The total number of parts in the overall partition is m^L. The advantage of such a hierarchical partition is that it is much simpler to navigate than a one-shot partition with m^L parts.
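A hierarchical partition is navigated top-down: each level only chooses among m children, so reaching one of the m^L final parts costs L small decisions. The sketch below uses hypothetical toy routers (sign tests on one coordinate) in place of trained models and shows only single-path descent; multi-probe search over a hierarchy would need an extra ranking step not shown here.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Point = Sequence[float]
Router = Callable[[Point], int]


class HierNode:
    """One node of a hierarchical partition: a router over m children (no children at the last level)."""

    def __init__(self, router: Router, children: Optional[List["HierNode"]] = None):
        self.router = router
        self.children = children


def leaf_path(node: Optional[HierNode], q: Point) -> Tuple[int, ...]:
    """Descend a single root-to-leaf path; the sequence of chosen parts names q's final part."""
    path: List[int] = []
    while node is not None:
        j = node.router(q)
        path.append(j)
        node = node.children[j] if node.children else None
    return tuple(path)


def sign_router(axis: int) -> Router:
    """Toy m = 2 router: which side of the hyperplane x[axis] = 0 the point is on."""
    return lambda q: 1 if q[axis] > 0 else 0


# L = 2 levels with m = 2 parts per node, hence m ** L = 4 final parts.
root = HierNode(sign_router(0), [HierNode(sign_router(1)), HierNode(sign_router(1))])
print(leaf_path(root, (0.5, -2.0)))  # (1, 0)
parts = {leaf_path(root, (x, y)) for x in (-1.0, 1.0) for y in (-1.0, 1.0)}
print(len(parts))  # 4
```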
In one instantiation of the supervised learning component, we use neural networks with a small number of layers and constrained hidden dimensions. The exact parameters depend on the size of the training set, and are specified in the next section.
In order to support effective multi-probe querying, we need to infer not just the part in which the query point resides, but rather a distribution over the parts that are likely to contain this point and its neighbors. A t-probe candidate list is then formed from all data points in the t most likely parts.
To accomplish this, we use soft labels for the data points, generated as follows. For S ≥ 1 and a data point p, the soft label y_p is the distribution of the part containing a point chosen uniformly at random among the S nearest neighbors of p (including p itself). Then, for a predicted distribution ŷ, we seek to minimize the KL divergence between y_p and ŷ: D_KL(y_p ‖ ŷ) = Σ_j y_p(j) log(y_p(j) / ŷ(j)).
The purpose of the soft labels is to give the neural network information about the ranking of parts when searching for nearest neighbors. Optimizing against the soft labels allows the model to predict multiple parts more accurately, which is necessary for achieving high accuracy via multi-probe querying.
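The soft labels and the loss can be computed as follows. This is a minimal sketch assuming that the S nearest neighbors of each data point (with the point itself included) and the part ID of every data point are already known; the function names are hypothetical, and in practice the loss would be computed by the training framework on the network's softmax output.

```python
import math
from typing import List, Sequence


def soft_label(neighbors: Sequence[int], part_of: Sequence[int], m: int) -> List[float]:
    """Distribution of the part of a point drawn uniformly from `neighbors`.

    `neighbors` lists the S nearest neighbors of p, including p itself.
    """
    label = [0.0] * m
    for j in neighbors:
        label[part_of[j]] += 1.0 / len(neighbors)
    return label


def kl_divergence(target: Sequence[float], predicted: Sequence[float], eps: float = 1e-12) -> float:
    """D_KL(target || predicted); parts with zero target mass contribute nothing."""
    return sum(t * math.log(t / max(q, eps)) for t, q in zip(target, predicted) if t > 0)


part_of = [0] * 10
part_of[5] = 1                                # hypothetical partition: point 5 lies in part 1
y = soft_label([7, 2, 5, 9], part_of, m=3)    # p = 7 and its 3 nearest neighbors (S = 4)
print(y)                                      # [0.75, 0.25, 0.0]
print(kl_divergence(y, [0.5, 0.5, 0.0]))      # about 0.1308
```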
S is a hyperparameter that needs to be tuned. In practice, accuracy improves when S is noticeably larger than k, as more neighbors give the network a more "complete" distribution over parts.
For the experimental evaluation, we use three standard ANN benchmarks [ABF17]: SIFT (image descriptors, 1M 128-dimensional points), GloVe (word embeddings [PSM14], approximately 1.2M 100-dimensional points, normalized), and MNIST (images of digits, 60K 784-dimensional points). All three datasets come with a set of query points, which we use for evaluation. We include the results for SIFT and GloVe in the main text and for MNIST in Appendix A.
We mainly investigate the trade-off between the number of candidates generated for a query point and the k-NN accuracy, defined as the fraction of its k nearest neighbors that are among those candidates. The number of candidates determines the processing time of an individual query. Over the entire query set, we report both the average and a high quantile of the number of candidates. The former measures the throughput of the data structure (the number of queries per second), while the latter measures its latency (the maximum time per query, modulo a small fraction of outliers).
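The two evaluation metrics can be sketched as follows. The quantile level is left as a parameter because this sketch does not assume the paper's exact value, and the nearest-rank convention used for the empirical quantile is one common choice among several; all candidate counts below are hypothetical.

```python
import math
from typing import Sequence


def knn_accuracy(true_neighbors: Sequence[int], candidates: Sequence[int]) -> float:
    """Fraction of the k true nearest neighbors that appear among the candidates."""
    cand = set(candidates)
    return sum(1 for i in true_neighbors if i in cand) / len(true_neighbors)


def quantile(values: Sequence[float], alpha: float) -> float:
    """Nearest-rank empirical alpha-quantile: smallest value with at least an alpha fraction <= it."""
    s = sorted(values)
    return s[max(0, math.ceil(alpha * len(s)) - 1)]


print(knn_accuracy([1, 2, 3, 4], [2, 4, 9, 1]))  # 0.75

counts = [100, 120, 110, 900, 105]  # hypothetical candidates per query
print(sum(counts) / len(counts))    # 267.0  (average: throughput)
print(quantile(counts, 0.8))        # 120
print(quantile(counts, 0.95))       # 900    (one slow query dominates the tail)
```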
We mostly focus on parameter regimes that lead to high k-NN accuracy (above a fixed threshold). The same value of k is used in all of our experiments.
We evaluate two variants of our method, corresponding to two different choices of the supervised learning component in our framework.
Neural LSH: In this variant we use small neural networks; their exact architecture is detailed in the next section. We compare Neural LSH to partitions obtained by k-means clustering. As mentioned in Section 1, this method produces high-quality partitions of the dataset that naturally extend to all of R^d, and the other existing methods we tried (such as LSH) did not match its performance. We evaluate partitions for two different values of the number of parts m. We test both one-level (non-hierarchical) and two-level (hierarchical) partitions. Queries are multi-probe.
Regression LSH: This variant uses logistic regression as the supervised learning component and, as a result, produces very simple partitions induced by hyperplanes. We compare this method with PCA trees [Spr91, KZN08, AAKK14], random projection trees [DS13], and recursive bisections using 2-means clustering. We build trees of hierarchical bisections of bounded depth D (thus, the total number of leaves is up to 2^D). The query procedure descends a single root-to-leaf path and returns the candidates in that leaf.
Neural LSH uses a fixed neural network architecture for the top-level partition and a fixed architecture for all second-level partitions. Both architectures consist of several blocks, where each block is a fully-connected layer + batch normalization [IS15] + ReLU activations. The final block is followed by a fully-connected layer and a softmax layer. The resulting network predicts a distribution over the parts of the partition. The only difference between the top-level and second-level network architectures is their number of blocks (b) and the size of their hidden layers (s); the two networks use different values of b and s. To reduce overfitting, we use dropout during training. The networks on both levels are trained using the Adam optimizer [KB15] for a limited number of epochs, and we reduce the learning rate multiplicatively at regular intervals.
A hierarchical partition produces a tree in which each node corresponds to a partition into m parts. In our experiments, we evaluate two values of m; the corresponding two-level partitions thus have m^2 parts in total. For the larger value of m, each part contains very few data points, which is too small for supervised learning without overfitting. Therefore, in the two-level experiment with the larger m, we use Neural LSH at the top level and k-means clustering at the bottom level. In the other experiments (two levels with the smaller m, and one level with either value of m), we use Neural LSH at all levels.
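To relate the architecture to memory footprint, one can count its learnable parameters. The sketch below assumes the block structure described above (b blocks of fully-connected layer, batch normalization, and ReLU with hidden width s, followed by a fully-connected output layer over m parts) and counts 2s batch-norm parameters per block; the sizes in the example are hypothetical, not the paper's settings. Dividing by d expresses the model size as an equivalent number of stored d-dimensional points, the unit used when comparing model sizes against k-means.

```python
def mlp_param_count(d: int, m: int, b: int, s: int) -> int:
    """Learnable parameters of: b x [Linear -> BatchNorm -> ReLU], then Linear -> softmax.

    d = input dimension, m = number of parts, b = number of blocks, s = hidden width.
    BatchNorm contributes a learnable scale and shift per feature; its running
    statistics are buffers and are not counted. Dropout, ReLU, and softmax have no parameters.
    """
    total = 0
    fan_in = d
    for _ in range(b):
        total += fan_in * s + s  # fully-connected weights + biases
        total += 2 * s           # batch-norm scale and shift
        fan_in = s
    total += fan_in * m + m      # output layer: one logit per part
    return total


# Hypothetical sizes, not the paper's: d = 128, m = 10 parts, b = 2 blocks of width 64.
params = mlp_param_count(d=128, m=10, b=2, s=64)
print(params)        # 13322
print(params / 128)  # 104.078125 "equivalent d-dimensional points"
```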
Note that multiple partitions of R^d can be combined in ways other than the hierarchical approach we evaluate. For example, a common technique called Product Quantization combines multiple invocations of k-means in a Cartesian-product fashion over a decomposition of R^d into orthogonal subspaces [JDS11, BL12]. There are various techniques to tune and improve this approach [NF13, GHKS14, WGS17]. Since the focus of our paper is to compare the quality of individual partitions, we use hierarchical partitioning as a baseline approach to combining partitions. Nonetheless, we note that the above Cartesian-product approach and related ideas can be readily applied to Neural LSH as well.
We slightly modify the KaHIP partitioner to make it more efficient on k-NN graphs. Namely, we introduce a hard threshold on the number of iterations of the local search part of the algorithm, which speeds up the partitioning dramatically while barely affecting the quality of the resulting partitions.
Figure 3 shows the empirical comparison of Neural LSH with k-means. The points shown are those that attained at least the target accuracy. We note that one of the reported two-level settings is the best-performing configuration of k-means for both SIFT and GloVe (in terms of the minimum number of candidates that attains the target accuracy). Thus, we evaluate the baseline at its optimal performance.
In all settings considered, Neural LSH yields consistently better partitions than k-means. Depending on the setting, k-means requires significantly more candidates to achieve the same accuracy:
Up to 2.176× as many for the average number of candidates for GloVe;
Up to 2.308× as many for the high quantiles of candidates for GloVe;
Up to 1.182× as many for the average number of candidates for SIFT;
Up to 1.348× as many for the high quantiles of candidates for SIFT.
Figure 4 lists the largest multiplicative advantage in the number of candidates of Neural LSH over k-means, for accuracy values above the target threshold. Specifically, for every configuration of k-means, we compute the ratio between the number of candidates in that configuration and the number of candidates of Neural LSH in its optimal configuration among those that attained at least the same accuracy as that k-means configuration. The table lists the maximum ratio over all accuracy values above the threshold.
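The ratio computation just described can be sketched as follows. Each configuration is summarized as a hypothetical (accuracy, candidates) pair, and the accuracy threshold is a parameter rather than the paper's value; the measurements in the example are made up for illustration, not results from the paper.

```python
from typing import List, Optional, Tuple

Config = Tuple[float, float]  # (k-NN accuracy, number of candidates) of one configuration


def max_advantage(baseline: List[Config], ours: List[Config], min_acc: float) -> Optional[float]:
    """Largest ratio baseline_candidates / ours_candidates over sufficiently accurate configurations.

    For each baseline configuration with accuracy >= min_acc, compare against the
    cheapest configuration of ours whose accuracy is at least as high.
    """
    best: Optional[float] = None
    for acc, cand in baseline:
        if acc < min_acc:
            continue
        matching = [c for a, c in ours if a >= acc]
        if not matching:
            continue  # no configuration of ours reaches this accuracy
        ratio = cand / min(matching)
        best = ratio if best is None else max(best, ratio)
    return best


# Hypothetical measurements, not results from the paper.
kmeans_cfgs = [(0.80, 1000.0), (0.90, 3000.0), (0.95, 10000.0)]
nlsh_cfgs = [(0.85, 700.0), (0.92, 1500.0), (0.97, 4000.0)]
print(max_advantage(kmeans_cfgs, nlsh_cfgs, min_acc=0.85))  # 2.5
```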
We also note that in all settings except two-level partitioning with the larger m, Neural LSH produces partitions for which the high quantiles of the number of candidates are very close to the average number of candidates, which indicates very little variance in query time across query points. (As mentioned earlier, in that excepted setting Neural LSH uses k-means at the second level, due to the large overall number of parts compared to the size of the datasets; this explains why the gap between the average and the high-quantile number of candidates of Neural LSH is larger there.) In contrast, the respective gap in the partitions produced by k-means is much larger since, unlike Neural LSH, k-means does not directly favor balanced partitions. This implies that Neural LSH might be particularly suitable for latency-critical NNS applications.
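Figure 4: largest ratio between the number of candidates required by k-means and by Neural LSH
Setting | GloVe, averages | GloVe, high quantiles | SIFT, averages | SIFT, high quantiles
One level, smaller m | 1.745 | 2.125 | 1.031 | 1.240
One level, larger m | 1.491 | 1.752 | 1.047 | 1.348
Two levels, smaller m | 2.176 | 2.308 | 1.113 | 1.306
Two levels, larger m | 1.241 | 1.154 | 1.182 | 1.192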
The largest model learned by Neural LSH takes as much memory as storing a considerably larger number of points than k-means, which stores at most one centroid per part. The equivalent number of points differs between SIFT and GloVe, reflecting the different network architectures used for them as well as their different dimensionality. Nonetheless, we believe the larger model size is acceptable for Neural LSH, for the following reasons:
In most NNS applications, the bottleneck in the high-accuracy regime is the memory accesses needed to retrieve candidates and their further processing (such as distance computations, exact or approximate). The model size is not a hindrance as long as it does not exceed certain reasonable limits (e.g., it should fit into a CPU cache). Neural LSH significantly reduces the memory access cost while increasing the model size by an acceptable amount.
We have observed that the quality of the Neural LSH partitions is not very sensitive to decreasing the sizes of the hidden layers. For the sake of concreteness, the model sizes we report are the largest ones that still lead to improved performance. Larger models do not increase the accuracy, and sometimes decrease it due to overfitting.
Here we compare binary decision trees, where in each tree node a hyperplane is used to determine which of the two subtrees to descend into. We generate hyperplanes via multiple methods: Regression LSH, cutting the dataset into two equal halves along the top PCA direction [Spr91, KZN08], 2-means clustering, and random projections of the centered dataset [DS13, KS18]. We build trees of bounded depth D, which corresponds to hierarchical partitions with up to 2^D parts in total. We summarize the results for the GloVe and SIFT datasets in Figure 5. For random projections, we run each configuration several times and average the results.
For GloVe, Regression LSH significantly outperforms 2-means, while for SIFT, Regression LSH essentially matches 2-means in terms of the average number of candidates but shows a noticeable advantage in terms of the high quantiles. In both instances, Regression LSH significantly outperforms PCA trees, and all of the above methods dramatically improve upon random projections.
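Two of the baseline splitters can be sketched compactly: a balanced bisection along the top PCA direction (computed here by power iteration) and along a random direction. Splitting at the median projection is a simplification assumed for this sketch (random projection trees in the literature may choose split points differently), and Regression LSH and 2-means splitters are not shown. Queries descend a single root-to-leaf path, as in the comparison above; all names and data are hypothetical.

```python
import math
import random
from typing import Callable, Dict, List, Sequence

Vec = List[float]
DirectionFn = Callable[[List[Sequence[float]], random.Random], Vec]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def normalize(v: Vec) -> Vec:
    norm = math.sqrt(dot(v, v)) or 1.0  # guard against the zero vector
    return [x / norm for x in v]


def top_pca_direction(points: List[Sequence[float]], rng: random.Random, iters: int = 50) -> Vec:
    """Top principal direction by power iteration on the (implicit) covariance matrix."""
    mu = [sum(col) / len(points) for col in zip(*points)]
    xs = [[x - c for x, c in zip(p, mu)] for p in points]
    v = normalize([rng.gauss(0.0, 1.0) for _ in mu])
    for _ in range(iters):
        w = [0.0] * len(v)
        for x in xs:  # w = sum_x x (x . v), i.e. n times C v
            c = dot(x, v)
            for t in range(len(w)):
                w[t] += c * x[t]
        v = normalize(w)
    return v


def random_direction(points: List[Sequence[float]], rng: random.Random) -> Vec:
    """Random-projection split direction: a uniformly random unit vector."""
    return normalize([rng.gauss(0.0, 1.0) for _ in points[0]])


def build_tree(data, idx: List[int], depth: int, direction_fn: DirectionFn, rng: random.Random) -> Dict:
    """Balanced bisection tree: split each node at the median projection onto a direction."""
    if depth == 0 or len(idx) < 2:
        return {"leaf": idx}
    u = direction_fn([data[i] for i in idx], rng)
    order = sorted(idx, key=lambda i: dot(data[i], u))
    half = len(order) // 2
    thr = (dot(data[order[half - 1]], u) + dot(data[order[half]], u)) / 2
    return {"u": u, "thr": thr,
            "left": build_tree(data, order[:half], depth - 1, direction_fn, rng),
            "right": build_tree(data, order[half:], depth - 1, direction_fn, rng)}


def query(tree: Dict, q: Sequence[float]) -> List[int]:
    """Descend a single root-to-leaf path and return the candidates stored at the leaf."""
    while "leaf" not in tree:
        tree = tree["left"] if dot(q, tree["u"]) < tree["thr"] else tree["right"]
    return tree["leaf"]


data = [(float(i), (i % 3) * 0.01) for i in range(8)]  # hypothetical points along a line
pca_tree = build_tree(data, list(range(8)), 2, top_pca_direction, random.Random(0))
print(sorted(query(pca_tree, (6.2, 0.0))))  # [6, 7]
rp_tree = build_tree(data, list(range(8)), 2, random_direction, random.Random(1))
print(len(query(rp_tree, (3.3, 0.0))))      # 2: every leaf holds n / 2**depth points
```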
Note, however, that random projections have an additional benefit: to boost the search accuracy, it is enough to repeat the sampling process several times and generate an ensemble of decision trees instead of a single tree. This allows each individual tree to be relatively deep, which decreases the overall number of candidates, trading space for query time. The other approaches considered (Regression LSH, 2-means, PCA tree) are inherently deterministic, and boosting their accuracy requires more care: for instance, one can use partitioning into blocks in the spirit of [JDS11], or see [KS18] for alternative approaches. Since we focus on individual partitions and not ensembles, we leave this issue out of scope.
In this paper, we presented a new technique for finding partitions of R^d that support high-performance indexing for sublinear-time NNS. It proceeds in two major steps:
We start with combinatorial balanced partitioning of the k-NN graph of the dataset;
We extend the resulting partition to the whole ambient space by using supervised classification (such as logistic regression, neural networks, etc.).
Our experiments show that the new approach consistently outperforms quantization-based and tree-based partitions.
We believe that this study is just the first step in exploring the new partitioning approach, and there are a number of exciting open problems we would like to highlight:
Can we jointly optimize a graph partition and a classifier? By making the two components aware of each other, we expect the quality of the resulting partition of R^d to improve.
Can our approach be extended to learning several high-quality partitions that complement each other? Such an ensemble can potentially be used to trade query time for memory usage [ALRW17].
Can we use machine learning techniques to improve graph-based indexing techniques [MY18] for NNS? (This is in contrast to partition-based indexing, as done in this work.)
Our framework is an example of combinatorial tools aiding “continuous” learning techniques. A more open-ended question is whether there are other problems that can benefit from such symbiosis.
In Proceedings of the 50th Annual ACM SIGACT Symposium on Theory of Computing, pages 787–800. ACM, 2018.
Cartesian k-means. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3017–3024, 2013.
GloVe: Global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 1532–1543, 2014.
We include experimental results for the MNIST dataset, where all experiments are performed exactly as for SIFT and GloVe. Consistent with the trend observed for SIFT and GloVe, Neural LSH outperforms k-means (see Figure 6) both in terms of the average number of candidates and, especially, in terms of the high quantiles. We also compare Regression LSH with recursive 2-means, as well as with PCA trees and random projections (see Figure 7), and Regression LSH consistently outperforms the other methods.