Stochastic Optimization of Sorting Networks via Continuous Relaxations

03/21/2019 ∙ by Aditya Grover, et al. ∙ Stanford University

Sorting input objects is an important step in many machine learning pipelines. However, the sorting operator is non-differentiable with respect to its inputs, which prohibits end-to-end gradient-based optimization. In this work, we propose NeuralSort, a general-purpose continuous relaxation of the output of the sorting operator from permutation matrices to the set of unimodal row-stochastic matrices, where every row sums to one and has a distinct arg max. This relaxation permits straight-through optimization of any computational graph involving a sorting operation. Further, we use this relaxation to enable gradient-based stochastic optimization over the combinatorially large space of permutations by deriving a reparameterized gradient estimator for the Plackett-Luce family of distributions over permutations. We demonstrate the usefulness of our framework on three tasks that require learning semantic orderings of high-dimensional objects, including a fully differentiable, parameterized extension of the k-nearest neighbors algorithm.




1 Introduction

Learning to automatically sort objects is useful in many machine learning applications, such as top-$k$ multi-class classification (Berrada et al., 2018), ranking documents for information retrieval (Liu et al., 2009), and multi-object target tracking in computer vision (Bar-Shalom & Li, 1995). Such algorithms typically require learning informative representations of complex, high-dimensional data, such as images, before sorting and subsequent downstream processing. For instance, the $k$-nearest neighbors image classification algorithm, which orders the neighbors based on distances in the canonical pixel basis, can be highly suboptimal for classification (Weinberger et al., 2006). Deep neural networks can instead be used to learn representations, but these representations cannot be optimized end-to-end for a downstream sorting-based objective, since the sorting operator is not differentiable with respect to its input.

In this work, we seek to remedy this shortcoming by proposing NeuralSort, a continuous relaxation to the sorting operator that is differentiable almost everywhere with respect to the inputs. The output of any sorting algorithm can be viewed as a permutation matrix, which is a square matrix with entries in $\{0, 1\}$ such that every row and every column sums to 1. Instead of a permutation matrix, NeuralSort returns a unimodal row-stochastic matrix. A unimodal row-stochastic matrix is defined as a square matrix with non-negative real entries, where each row sums to 1 and the row-wise arg max indices are all distinct. All permutation matrices are unimodal row-stochastic matrices. NeuralSort has a temperature knob that controls the degree of approximation, such that in the limit of zero temperature, we recover a permutation matrix that sorts the inputs. Even for a non-zero temperature, we can efficiently project any unimodal matrix to the desired permutation matrix via a simple row-wise arg max operation. Hence, NeuralSort is also suitable for efficient straight-through gradient optimization (Bengio et al., 2013), which requires “exact” permutation matrices to evaluate learning objectives.

As the second primary contribution, we consider the use of NeuralSort for stochastic optimization over permutations. In many cases, such as latent variable models, the permutations may be latent but directly influence observed behavior; e.g., utility and choice models are often expressed as distributions over permutations which govern the observed decisions of agents (Regenwetter et al., 2006; Chierichetti et al., 2018). By learning distributions over unobserved permutations, we can account for the uncertainty in these permutations in a principled manner. However, the challenge with stochastic optimization over discrete distributions lies in gradient estimation with respect to the distribution parameters. Vanilla REINFORCE estimators are impractical in most cases unless paired with custom control variates for low-variance gradient estimation (Glasserman, 2013).

In this regard, we consider the Plackett-Luce (PL) family of distributions over permutations (Plackett, 1975; Luce, 1959). A common modeling choice for ranking models, the PL distribution is parameterized by $n$ scores, with its support defined over the symmetric group consisting of $n!$ permutations. We derive a reparameterizable sampler for stochastic optimization with respect to this distribution, based on Gumbel perturbations to the (log-)scores. However, the reparameterized sampler requires sorting these perturbed scores, and hence the gradients of a downstream learning objective with respect to the scores are not defined. By using NeuralSort instead, we can approximate the objective and obtain well-defined reparameterized gradient estimates for stochastic optimization.

Finally, we apply NeuralSort to tasks that require us to learn semantic orderings of complex, high-dimensional input data. First, we consider sorting images of handwritten digits, where the goal is to learn to sort images by their unobserved labels. Our second task extends the first one to quantile regression, where we want to estimate the median (50th percentile) of a set of handwritten numbers. In addition to identifying the index of the median image in the sequence, we need to learn to map the inferred median digit to its scalar representation. In the third task, we propose an algorithm that learns a basis representation for the $k$-nearest neighbors (kNN) classifier in an end-to-end procedure. Because the choice of the $k$ nearest neighbors requires a non-differentiable sorting, we use NeuralSort to obtain an approximate, differentiable surrogate. On all tasks, we observe significant empirical improvements due to NeuralSort over the relevant baselines and competing relaxations to permutation matrices.

2 Preliminaries

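An $n$-dimensional permutation $z = [z_1, z_2, \ldots, z_n]^T$ is a list of unique indices $\{1, 2, \ldots, n\}$. Every permutation $z$ is associated with a permutation matrix $P_z \in \{0, 1\}^{n \times n}$ with entries given as:

$$P_z[i, j] = \begin{cases} 1 & \text{if } j = z_i \\ 0 & \text{otherwise.} \end{cases}$$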

Let $\mathcal{Z}_n$ denote the set of all $n!$ possible permutations in the symmetric group. We define the $\mathrm{sort}: \mathbb{R}^n \to \mathcal{Z}_n$ operator as a mapping of $n$ real-valued inputs to a permutation corresponding to a descending ordering of these inputs. E.g., if the input vector $s = [9, 1, 5, 2]^T$, then $\mathrm{sort}(s) = [1, 3, 4, 2]^T$, since the largest element is at the first index, the second largest element is at the third index, and so on. In case of ties, elements are assigned indices in the order they appear. We can obtain the sorted vector simply via $P_{\mathrm{sort}(s)}\, s$.
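To make these conventions concrete (1-based indices, descending order, ties broken by position, and row $i$ of $P_z$ selecting the element of rank $i$), here is a minimal pure-Python sketch. The helper names `sort_perm`, `perm_matrix` and `matvec` are hypothetical, chosen only for this illustration.

```python
def sort_perm(s):
    """sort(s): 1-based indices of s in descending order of value.

    Python's sorted() is stable, so tied elements keep their original order,
    matching the tie-breaking rule in the text."""
    return [i + 1 for i in sorted(range(len(s)), key=lambda i: -s[i])]


def perm_matrix(z):
    """P_z with P_z[i][j] = 1 if j == z_i (both 1-based in the text), else 0."""
    n = len(z)
    return [[1 if j + 1 == z[i] else 0 for j in range(n)] for i in range(n)]


def matvec(P, s):
    """Plain matrix-vector product P s."""
    return [sum(p * x for p, x in zip(row, s)) for row in P]


s = [9, 1, 5, 2]
z = sort_perm(s)
print(z)                          # [1, 3, 4, 2]
print(matvec(perm_matrix(z), s))  # [9, 5, 2, 1]
```

Multiplying by $P_{\mathrm{sort}(s)}$ simply gathers the entries of $s$ in rank order, which is why the sorted vector is $P_{\mathrm{sort}(s)}\, s$.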

2.1 Plackett-Luce distributions

The family of Plackett-Luce distributions over permutations is best described via a generative process: Consider a sequence of $n$ items, each associated with a canonical index $i = 1, 2, \ldots, n$. A common assumption in ranking models is that the underlying generating process for any observed permutation of $n$ items satisfies Luce’s choice axiom (Luce, 1959). Mathematically, this axiom defines the ‘choice’ probability of an item with index $i$ as $q(i) \propto s_i$, where $s_i > 0$ is interpreted as the score of the item with index $i$. The normalization constant is given by $Z = \sum_{i \in \{1, 2, \ldots, n\}} s_i$.

If we choose the items one at a time (without replacement) based on these choice probabilities, we obtain a discrete distribution over all possible permutations. This distribution is referred to as the Plackett-Luce (PL) distribution, and its probability mass function for any $z \in \mathcal{Z}_n$ is given by:

$$q(z|s) = \frac{s_{z_1}}{Z}\, \frac{s_{z_2}}{Z - s_{z_1}} \cdots \frac{s_{z_n}}{Z - \sum_{i=1}^{n-1} s_{z_i}} \qquad (1)$$

where $s = [s_1, s_2, \ldots, s_n]^T$ is the vector of scores parameterizing this distribution (Plackett, 1975).
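The sequential choice process behind the PL probability mass function can be checked numerically. The sketch below (pure Python; `pl_pmf` is a hypothetical helper name) multiplies the Luce choice probabilities step by step, removing each chosen item's score from the normalizer, and confirms on a small example that the probabilities over all $n!$ permutations sum to one.

```python
from itertools import permutations


def pl_pmf(z, s):
    """Plackett-Luce probability q(z | s).

    z: permutation as a list of 1-based indices; s: list of positive scores."""
    remaining = sum(s)  # Z at the first step
    prob = 1.0
    for idx in z:
        score = s[idx - 1]
        prob *= score / remaining  # Luce choice probability among unchosen items
        remaining -= score         # choose without replacement
    return prob


s = [3.0, 1.0, 2.0]
all_perms = [list(p) for p in permutations(range(1, len(s) + 1))]
print(round(sum(pl_pmf(z, s) for z in all_perms), 10))  # 1.0
print(pl_pmf([1, 3, 2], s))  # (3/6) * (2/3) * (1/1), i.e. about 0.333
```

In this example the most probable permutation is [1, 3, 2], which is exactly the descending sort of the scores.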

2.2 Stochastic computation graphs

The abstraction of stochastic computation graphs (SCG) compactly specifies the forward value and the backward gradient computation for computational circuits. An SCG is a directed acyclic graph that consists of three kinds of nodes: input nodes which specify external inputs (including parameters), deterministic nodes which are deterministic functions of their parents, and stochastic nodes which are distributed conditionally on their parents. See Schulman et al. (2015) for a review.

To define gradients of an objective function with respect to any node in the graph, the chain rule necessitates that the gradients with respect to the intermediate nodes are well-defined. This is not the case for the sort operator. In Section 3, we propose to extend stochastic computation graphs with nodes corresponding to a relaxation of the deterministic sort operator. In Section 4, we further use this relaxation to extend computation graphs to include stochastic nodes corresponding to distributions over permutations. The proofs of all theoretical results in this work are deferred to Appendix B.

3 NeuralSort: The Relaxed Sorting Operator

Figure 1: Stochastic computation graph with a deterministic node $z$ corresponding to the output of a sort operator applied to the scores $s$.

Our goal is to optimize training objectives involving a sort operator with gradient-based methods. Consider the optimization of objectives written in the following form:

$$L(\theta, s) = f(P_z; \theta) \quad \text{where } z = \mathrm{sort}(s). \qquad (2)$$

Here, $s \in \mathbb{R}^n$ denotes a vector of $n$ real-valued scores, $z$ is the permutation that (deterministically) sorts the scores $s$, and $f(\cdot)$ is an arbitrary function of interest assumed to be differentiable w.r.t. a set of parameters $\theta$ and $z$. For example, in a ranking application, these scores could correspond to the inferred relevances of $n$ webpages and $f(\cdot)$ could be a ranking loss. Figure 1 shows the stochastic computation graph corresponding to the objective in Eq. 2. We note that this could represent part of a more complex computation graph, which we skip for ease of presentation while maintaining the generality of the scope of this work.

While the gradient of the above objective w.r.t. $\theta$ is well-defined and can be computed via standard backpropagation, the gradient w.r.t. the scores $s$ is not defined since the sort operator is not differentiable w.r.t. $s$. Our solution is to derive a relaxation to the sort operator that leads to a surrogate objective with well-defined gradients. In particular, we seek to use such a relaxation to replace the permutation matrix $P_z$ in Eq. 2 with an approximation $\hat{P}_z$ such that the surrogate objective $f(\hat{P}_z; \theta)$ is differentiable w.r.t. the scores $s$.

The general recipe to relax non-differentiable operators with a discrete codomain $\mathcal{Y}$ is to consider differentiable alternatives that map the input to a larger continuous codomain $\hat{\mathcal{Y}}$ with desirable properties. For gradient-based optimization, we are interested in two key properties:

  1. The relaxation is continuous everywhere and differentiable (almost-)everywhere with respect to elements in the input domain.

  2. There exists a computationally efficient projection from $\hat{\mathcal{Y}}$ back to $\mathcal{Y}$.

Relaxations satisfying the first requirement are amenable to automatic differentiation for optimizing stochastic computation graphs. The second requirement is useful for evaluating metrics and losses that necessarily require a discrete output akin to the one obtained from the original, non-relaxed operator.

E.g., in straight-through gradient estimation (Bengio et al., 2013; Jang et al., 2017), the non-relaxed operator is used for evaluating the learning objective in the forward pass and the relaxed operator is used in the backward pass for gradient estimation.

The canonical example is the 0/1 loss used for binary classification. While the 0/1 loss is discontinuous w.r.t. its inputs (real-valued predictions from a model), surrogates such as the logistic and hinge losses are continuous everywhere and differentiable almost-everywhere (property 1), and can give hard binary predictions via thresholding (property 2).

Note: For brevity, we assume that the arg max operator is applied over a set of elements with a unique maximizer and hence, the operator has well-defined semantics. With some additional bookkeeping for resolving ties, the results in this section hold even if the elements to be sorted are not unique. See Appendix C.

Unimodal Row Stochastic Matrices.

The sort operator maps the input vector to a permutation, or equivalently a permutation matrix. Our relaxation to sort is motivated by the geometric structure of permutation matrices. The set of permutation matrices is a subset of doubly-stochastic matrices, i.e., non-negative matrices such that every row and column sums to one. If we remove the requirement that every column should sum to one, we obtain a larger set of row stochastic matrices. In this work, we propose a relaxation to sort that maps inputs to an alternate subset of row stochastic matrices, which we refer to as the unimodal row stochastic matrices.

Definition 1 (Unimodal Row Stochastic Matrices).

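An $n \times n$ matrix $U$ is Unimodal Row Stochastic if it satisfies the following conditions:

  1. Non-negativity: $U[i, j] \geq 0 \quad \forall i, j \in \{1, 2, \ldots, n\}$.

  2. Row Affinity: $\sum_{j=1}^{n} U[i, j] = 1 \quad \forall i \in \{1, 2, \ldots, n\}$.

  3. Argmax Permutation: Let $u$ denote an $n$-dimensional vector with entries such that $u_i = \arg\max_j U[i, j] \;\; \forall i \in \{1, 2, \ldots, n\}$. Then, $u \in \mathcal{Z}_n$, i.e., it is a valid permutation.

We denote $\mathcal{U}_n$ as the set of $n \times n$ unimodal row stochastic matrices.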

Figure 2: Center: Venn diagram relationships between permutation matrices, doubly-stochastic matrices, unimodal row stochastic matrices, and row stochastic matrices. Left: A doubly-stochastic matrix that is not unimodal. Right: A unimodal matrix that is not doubly-stochastic.

All row stochastic matrices satisfy the first two conditions. The third condition is useful for gradient-based optimization involving sorting-based losses: it provides a straightforward mechanism for extracting a permutation from a unimodal row stochastic matrix via a row-wise arg max operation. Figure 2 shows the relationships between the different subsets of square matrices.
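Definition 1 translates directly into a membership test. The following is a minimal pure-Python sketch (the function name `is_unimodal_row_stochastic` and the tolerance are our own choices); the two 3×3 matrices are illustrative examples of the "left" and "right" cases of Figure 2, not the matrices drawn in the figure.

```python
def is_unimodal_row_stochastic(U, tol=1e-9):
    """Check the three conditions of Definition 1 for a square matrix given as a list of rows."""
    n = len(U)
    if any(len(row) != n for row in U):
        return False
    # 1. Non-negativity.
    if any(x < 0 for row in U for x in row):
        return False
    # 2. Row affinity: each row sums to one (up to floating-point tolerance).
    if any(abs(sum(row) - 1.0) > tol for row in U):
        return False
    # 3. Argmax permutation: each row has a single maximizer (as assumed in the text),
    #    and the row-wise maximizers together form a permutation of the columns.
    u = []
    for row in U:
        best = max(row)
        winners = [j for j, x in enumerate(row) if x == best]
        if len(winners) != 1:
            return False
        u.append(winners[0])
    return sorted(u) == list(range(n))


doubly_but_not_unimodal = [[0.50, 0.30, 0.20],
                           [0.45, 0.35, 0.20],
                           [0.05, 0.35, 0.60]]  # rows 1 and 2 share arg max = column 1
unimodal_but_not_doubly = [[0.5, 0.3, 0.2],
                           [0.3, 0.6, 0.1],
                           [0.1, 0.3, 0.6]]     # column sums are 0.9, 1.2, 0.9
print(is_unimodal_row_stochastic(doubly_but_not_unimodal))  # False
print(is_unimodal_row_stochastic(unimodal_but_not_doubly))  # True
```

The list `u` built in the third check is exactly the permutation that the row-wise arg max projection returns.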


Our relaxation to the sort operator is based on a standard identity for evaluating the sum of the $k$ largest elements in any input vector.

Lemma 2.

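[Lemma 1 in Ogryczak & Tamir (2003)] For an input vector $s = [s_1, s_2, \ldots, s_n]^T$ that is sorted as $s_{[1]} \geq s_{[2]} \geq \cdots \geq s_{[n]}$, we have the sum of the $k$-largest elements given as:

$$\sum_{i=1}^{k} s_{[i]} = \min_{\lambda \in \{s_1, s_2, \ldots, s_n\}} \left( \lambda k + \sum_{i=1}^{n} \max(s_i - \lambda, 0) \right). \qquad (3)$$

The identity in Lemma 2 outputs the sum of the top-$k$ elements. The $k$-th largest element itself can be recovered by taking the difference of the sum of the top-$k$ elements and the top-$(k-1)$ elements.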
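Lemma 2 and the differencing step can be verified numerically. This pure-Python sketch (helper names are hypothetical) evaluates the right-hand side of the identity by brute force over $\lambda \in \{s_1, \ldots, s_n\}$, compares it with a direct top-$k$ sum, and then recovers each $k$-th largest element as a difference of consecutive top-$k$ sums.

```python
def top_k_sum_via_lemma(s, k):
    """Right-hand side of Lemma 2: min over lambda in s of
    lambda * k + sum_i max(s_i - lambda, 0)."""
    return min(lam * k + sum(max(x - lam, 0.0) for x in s) for lam in s)


def kth_largest_via_lemma(s, k):
    """k-th largest element as (top-k sum) - (top-(k-1) sum).

    For k = 1 the second term is the identity with k = 0, which evaluates to 0
    (take lambda = max(s))."""
    return top_k_sum_via_lemma(s, k) - top_k_sum_via_lemma(s, k - 1)


s = [0.3, 2.5, -1.0, 1.2, 0.7]
for k in range(1, len(s) + 1):
    direct = sum(sorted(s, reverse=True)[:k])
    assert abs(top_k_sum_via_lemma(s, k) - direct) < 1e-12
print([round(kth_largest_via_lemma(s, k), 6) for k in range(1, len(s) + 1)])
# [2.5, 1.2, 0.7, 0.3, -1.0]
```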

Corollary 3.

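Let $s$ be a real-valued vector of length $n$. Let $A_s$ denote the matrix of absolute pairwise differences of the elements of $s$ such that $A_s[i, j] = |s_i - s_j|$. The permutation matrix $P_{\mathrm{sort}(s)}$ corresponding to $\mathrm{sort}(s)$ is given by:

$$P_{\mathrm{sort}(s)}[i, j] = \begin{cases} 1 & \text{if } j = \arg\max\left[(n + 1 - 2i)\, s - A_s \mathbb{1}\right] \\ 0 & \text{otherwise} \end{cases} \qquad (4)$$

where $\mathbb{1}$ denotes the column vector of all ones.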

E.g., for odd $n$, if we set $i = \lfloor (n+1)/2 \rfloor$, then the non-zero entry in the $i$-th row $P_{\mathrm{sort}(s)}[i, :]$ corresponds to the element with the minimum sum of (absolute) distances to the other elements. As desired, this corresponds to the median element. The relaxation requires $O(n^2)$ operations to compute $A_s$, as opposed to the $O(n \log n)$ overall complexity for the best known sorting algorithms. In practice, however, it is highly parallelizable and can be implemented efficiently on GPU hardware.
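Corollary 3 can be read as an algorithm: for each rank $i$, form the vector $(n + 1 - 2i)\, s - A_s \mathbb{1}$ and take its arg max. The pure-Python sketch below (the name `sort_perm_via_corollary` is hypothetical) does exactly that, assuming distinct entries, and also shows the median special case.

```python
def sort_perm_via_corollary(s):
    """sort(s) as 1-based indices, read off from Corollary 3:
    row i of P_sort(s) has its 1 at argmax_j [(n + 1 - 2i) s_j - (A_s 1)_j].
    Assumes the entries of s are distinct."""
    n = len(s)
    a1 = [sum(abs(sj - sk) for sk in s) for sj in s]  # (A_s 1)_j = sum_k |s_j - s_k|
    z = []
    for i in range(1, n + 1):
        scores = [(n + 1 - 2 * i) * sj - a1j for sj, a1j in zip(s, a1)]
        z.append(max(range(n), key=scores.__getitem__) + 1)
    return z


print(sort_perm_via_corollary([9.0, 1.0, 5.0, 2.0]))  # [1, 3, 4, 2]

# For odd n, the middle row (i = (n + 1) / 2) has coefficient 0, so it picks the element
# with the smallest total absolute distance to the others: the median.
v = [0.31, -2.0, 4.5, 1.1, 0.0]
print(v[sort_perm_via_corollary(v)[2] - 1])  # 0.31
```

Computing `a1` is the $O(n^2)$ step mentioned above; every row then reuses it.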

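The arg max operator is non-differentiable, which prohibits the direct use of Corollary 3 for gradient computation. Instead, we propose to replace the arg max operator with softmax to obtain a continuous relaxation $\hat{P}_{\mathrm{sort}(s)}(\tau)$. In particular, the $i$-th row of $\hat{P}_{\mathrm{sort}(s)}(\tau)$ is given by:

$$\hat{P}_{\mathrm{sort}(s)}[i, :](\tau) = \mathrm{softmax}\left[ \frac{(n + 1 - 2i)\, s - A_s \mathbb{1}}{\tau} \right] \qquad (5)$$

where $\tau > 0$ is a temperature parameter. Our relaxation is continuous everywhere and differentiable almost everywhere with respect to the elements of $s$. Furthermore, we have the following result.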
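The relaxation in Eq. 5 is short enough to write out directly. This pure-Python sketch (the helper names `softmax` and `neural_sort` are hypothetical; a real implementation would use an autodiff framework on batched tensors) builds $\hat{P}_{\mathrm{sort}(s)}(\tau)$ row by row and checks that each row is a probability vector whose arg max recovers $\mathrm{sort}(s)$.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def neural_sort(s, tau):
    """Relaxed permutation matrix P_hat_sort(s)(tau) of Eq. 5, returned as n rows.

    Row i (1-based) is softmax(((n + 1 - 2i) * s - A_s 1) / tau)."""
    n = len(s)
    a1 = [sum(abs(sj - sk) for sk in s) for sj in s]  # A_s 1
    return [softmax([((n + 1 - 2 * i) * sj - a1j) / tau for sj, a1j in zip(s, a1)])
            for i in range(1, n + 1)]


P_hat = neural_sort([9.0, 1.0, 5.0, 2.0], tau=1.0)
print([round(sum(row), 6) for row in P_hat])  # [1.0, 1.0, 1.0, 1.0]
# Row-wise arg max recovers sort(s) = [1, 3, 4, 2], as Theorem 4 states.
print([max(range(len(row)), key=row.__getitem__) + 1 for row in P_hat])  # [1, 3, 4, 2]
```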

Theorem 4.

Let $\hat{P}_{\mathrm{sort}(s)}(\tau)$ denote the continuous relaxation to the permutation matrix $P_{\mathrm{sort}(s)}$ for an arbitrary input vector $s$ and temperature $\tau$ defined in Eq. 5. Then, we have:

  1. Unimodality: $\forall \tau > 0$, $\hat{P}_{\mathrm{sort}(s)}(\tau)$ is a unimodal row stochastic matrix. Further, let $u$ denote the permutation obtained by applying arg max row-wise to $\hat{P}_{\mathrm{sort}(s)}(\tau)$. Then, $u = \mathrm{sort}(s)$.

  2. Limiting behavior: If we assume that the entries of $s$ are drawn independently from a distribution that is absolutely continuous w.r.t. the Lebesgue measure in $\mathbb{R}$, then the following convergence holds almost surely:

$$\lim_{\tau \to 0^{+}} \hat{P}_{\mathrm{sort}(s)}[i, :](\tau) = P_{\mathrm{sort}(s)}[i, :] \quad \forall i \in \{1, 2, \ldots, n\}. \qquad (6)$$

Unimodality allows for efficient projection of the relaxed permutation matrix $\hat{P}_{\mathrm{sort}(s)}$ to the hard matrix $P_{\mathrm{sort}(s)}$ via a row-wise arg max, e.g., for straight-through gradients. For analyzing limiting behavior, independent draws ensure that the elements of $s$ are distinct almost surely. The temperature $\tau$ controls the degree of smoothness of our approximation: the approximation becomes tighter as the temperature is reduced. In practice, however, this comes at the cost of higher variance in the gradient estimates, which is typically lower for larger temperatures.
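Both parts of Theorem 4 can be observed on a toy input. In this pure-Python sketch (helper names are hypothetical), the row-wise arg max projection yields the same hard permutation matrix at every temperature, while the largest entrywise gap between $\hat{P}_{\mathrm{sort}(s)}(\tau)$ and that hard matrix shrinks as $\tau$ decreases. A straight-through estimator would use the projected matrix in the forward pass and differentiate through the relaxed one; that part needs an autodiff framework and is not shown here.

```python
import math


def neural_sort(s, tau):
    """Relaxed sort matrix of Eq. 5 (row i is softmax(((n + 1 - 2i) s - A_s 1) / tau))."""
    n = len(s)
    a1 = [sum(abs(sj - sk) for sk in s) for sj in s]
    rows = []
    for i in range(1, n + 1):
        logits = [((n + 1 - 2 * i) * sj - a1j) / tau for sj, a1j in zip(s, a1)]
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


def project_to_permutation(P_hat):
    """Row-wise arg max projection to a hard 0/1 matrix (a permutation matrix by unimodality)."""
    n = len(P_hat)
    hard = []
    for row in P_hat:
        j_star = max(range(n), key=row.__getitem__)
        hard.append([1.0 if j == j_star else 0.0 for j in range(n)])
    return hard


def max_abs_gap(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))


s = [0.9, 0.1, 0.5, 0.2]
taus = [10.0, 1.0, 0.1, 0.01]
gaps = []
for tau in taus:
    P_hat = neural_sort(s, tau)
    gaps.append(max_abs_gap(P_hat, project_to_permutation(P_hat)))
print([f"{g:.2e}" for g in gaps])  # shrinks toward 0 as tau decreases
```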

4 Stochastic Optimization over Permutations

Figure 3: Stochastic computation graphs with stochastic nodes corresponding to permutations: (a) stochastic; (b) reparameterized stochastic. Squares denote deterministic nodes and circles denote stochastic nodes.

In many scenarios, we would like the ability to express our uncertainty in inferring a permutation, e.g., in latent variable models with latent nodes corresponding to permutations. Random variables that assume values corresponding to permutations can be represented via stochastic nodes in the stochastic computation graph. For optimizing the parameters of such a graph, consider the following class of objectives:

$$L(\theta, s) = \mathbb{E}_{q(z|s)}\left[ f(P_z; \theta) \right] \qquad (7)$$

where $\theta$ and $s$ denote sets of parameters, $P_z$ is the permutation matrix corresponding to the permutation $z$, $q(\cdot)$ is a parameterized distribution over the elements of the symmetric group $\mathcal{Z}_n$, and $f(\cdot)$ is an arbitrary function of interest assumed to be differentiable in $\theta$ and $z$. The SCG is shown in Figure 3(a). In contrast to the SCG considered in the previous section (Figure 1), here we are dealing with a distribution over permutations as opposed to a single (deterministically computed) one.

While such objectives are typically intractable to evaluate exactly since they require summing over a combinatorially large set, we can obtain unbiased estimates efficiently via Monte Carlo. Monte Carlo estimates of gradients w.r.t. $\theta$ can be derived simply via linearity of expectation. However, the gradient estimates w.r.t. $s$ cannot be obtained directly since the sampling distribution depends on $s$. The REINFORCE gradient estimator (Glynn, 1990; Williams, 1992; Fu, 2006) uses the fact that $\nabla_s q(z|s) = q(z|s)\, \nabla_s \log q(z|s)$ to derive the following Monte Carlo gradient estimates:

$$\nabla_s L(\theta, s) = \mathbb{E}_{q(z|s)}\left[ f(P_z; \theta)\, \nabla_s \log q(z|s) \right]. \qquad (8)$$


4.1 Reparameterized gradient estimators for PL distributions

REINFORCE gradient estimators typically suffer from high variance (Schulman et al., 2015; Glasserman, 2013). Reparameterized samplers provide an alternate gradient estimator by expressing samples from a distribution as a deterministic function of its parameters and a fixed source of randomness (Kingma & Welling, 2014; Rezende et al., 2014; Titsias & Lázaro-Gredilla, 2014). Since the randomness is from a fixed distribution, Monte Carlo gradient estimates can be derived by pushing the gradient operator inside the expectation (via linearity). In this section, we will derive a reparameterized sampler and gradient estimator for the Plackett-Luce (PL) family of distributions.

Let the score for an item be an unobserved random variable drawn from some underlying score distribution (Thurstone, 1927). Now for each item, we draw a score from its corresponding score distribution. Next, we generate a permutation by applying the deterministic sort operator to these randomly sampled scores. Interestingly, prior work has shown that the resulting distribution over permutations corresponds to a PL distribution if and only if the scores are sampled independently from Gumbel distributions with identical scales.

Proposition 5.

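[adapted from Yellott Jr (1977)] Let $s$ be a vector of scores for the $n$ items. For each item $i$, sample $g_i \sim \mathrm{Gumbel}(0, \beta)$ independently with location zero and a fixed scale $\beta$. Let $\tilde{s}$ denote the vector of Gumbel perturbed log-scores with entries such that $\tilde{s}_i = \beta \log s_i + g_i$. Then:

$$q(\tilde{s}_{z_1} \geq \tilde{s}_{z_2} \geq \cdots \geq \tilde{s}_{z_n}) = \frac{s_{z_1}}{Z}\, \frac{s_{z_2}}{Z - s_{z_1}} \cdots \frac{s_{z_n}}{Z - \sum_{i=1}^{n-1} s_{z_i}}. \qquad (9)$$

For ease of presentation, we assume $\beta = 1$ in the rest of this work. Proposition 5 provides a method for sampling from PL distributions with parameters $s$ by adding Gumbel perturbations to the log-scores and applying the sort operator to the perturbed log-scores. This procedure can be seen as a reparameterization trick that expresses a sample from the PL distribution as a deterministic function of the scores and a fixed source of randomness (Figure 3(b)). Letting $g$ denote the vector of i.i.d. Gumbel perturbations, we can express the objective in Eq. 7 as:

$$L(\theta, s) = \mathbb{E}_{g}\left[ f(P_{\mathrm{sort}(\log s + g)}; \theta) \right]. \qquad (10)$$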
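Proposition 5 can be checked empirically: sorting Gumbel-perturbed log-scores should reproduce the PL probabilities. The pure-Python sketch below (helper names are hypothetical) draws standard Gumbel noise by inverse CDF ($\beta = 1$), sorts in descending order, and prints empirical frequencies next to the exact PL probabilities for a three-item example.

```python
import math
import random
from collections import Counter
from itertools import permutations


def standard_gumbel(rng):
    """Gumbel(0, 1) draw by inverse CDF: -log(-log(U)) with U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return exactly 0.0; log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def pl_sample_via_gumbel(s, rng):
    """Proposition 5 with beta = 1: sort the Gumbel-perturbed log-scores in descending order."""
    perturbed = [math.log(si) + standard_gumbel(rng) for si in s]
    return tuple(sorted(range(1, len(s) + 1), key=lambda i: -perturbed[i - 1]))


def pl_pmf(z, s):
    """Exact Plackett-Luce probability q(z | s)."""
    remaining, prob = sum(s), 1.0
    for idx in z:
        prob *= s[idx - 1] / remaining
        remaining -= s[idx - 1]
    return prob


rng = random.Random(0)
s = [3.0, 1.0, 2.0]
num_samples = 50_000
counts = Counter(pl_sample_via_gumbel(s, rng) for _ in range(num_samples))
for z in permutations(range(1, len(s) + 1)):
    print(z, round(counts[z] / num_samples, 3), round(pl_pmf(z, s), 3))
```

All randomness enters through the noise, so the sample is a deterministic function of $s$ once the noise is drawn; the only non-differentiable piece left is the final `sorted` call.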


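While the reparameterized sampler removes the dependence of the expectation on the parameters $s$, it introduces a sort operator in the computation graph such that the overall objective is non-differentiable in $s$. In order to obtain a differentiable surrogate, we approximate the objective based on the NeuralSort relaxation to the sort operator:

$$\mathbb{E}_{g}\left[ f(P_{\mathrm{sort}(\log s + g)}; \theta) \right] \approx \mathbb{E}_{g}\left[ f(\hat{P}_{\mathrm{sort}(\log s + g)}; \theta) \right] := \hat{L}(\theta, s). \qquad (11)$$

Accordingly, we get the following reparameterized gradient estimates for the approximation:

$$\nabla_s \hat{L}(\theta, s) = \mathbb{E}_{g}\left[ \nabla_s f(\hat{P}_{\mathrm{sort}(\log s + g)}; \theta) \right] \qquad (12)$$

which can be estimated efficiently via Monte Carlo because the expectation is with respect to a distribution that does not depend on $s$.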
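The effect of the reparameterization can be seen without an autodiff library. In the pure-Python sketch below (all helper names and the toy downstream function are hypothetical), the Gumbel noise is drawn once and held fixed, so the Monte Carlo estimate of the relaxed objective becomes a deterministic, continuous, almost-everywhere differentiable function of the log-scores, and central finite differences give a well-defined gradient. The gradient is taken w.r.t. $\log s$; the chain rule $\partial/\partial s_j = (1/s_j)\, \partial/\partial \log s_j$ converts it to the form in the text. An actual implementation would backpropagate through the softmax instead of using finite differences.

```python
import math
import random


def neural_sort(s, tau):
    """Relaxed sort matrix of Eq. 5: row i is softmax(((n + 1 - 2i) s - A_s 1) / tau)."""
    n = len(s)
    a1 = [sum(abs(sj - sk) for sk in s) for sj in s]  # A_s 1
    rows = []
    for i in range(1, n + 1):
        logits = [((n + 1 - 2 * i) * sj - a1j) / tau for sj, a1j in zip(s, a1)]
        m = max(logits)  # shift for numerical stability
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


def standard_gumbel(rng):
    """Gumbel(0, 1) draw by inverse CDF, rejecting the measure-zero value U = 0."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def surrogate(log_s, noise, tau, f):
    """Monte Carlo estimate of the relaxed objective for a fixed set of Gumbel draws."""
    return sum(f(neural_sort([a + g for a, g in zip(log_s, gs)], tau)) for gs in noise) / len(noise)


def top1_soft_indicator(P_hat):
    """Hypothetical downstream f: relaxed weight that item 1 occupies rank 1."""
    return P_hat[0][0]


rng = random.Random(0)
log_s = [math.log(x) for x in (3.0, 1.0, 2.0)]
noise = [[standard_gumbel(rng) for _ in log_s] for _ in range(2000)]  # drawn once, then fixed
tau, h = 0.5, 1e-5

# With the noise held fixed, central finite differences give a well-defined gradient.
grad = []
for j in range(len(log_s)):
    up = [x + (h if k == j else 0.0) for k, x in enumerate(log_s)]
    down = [x - (h if k == j else 0.0) for k, x in enumerate(log_s)]
    grad.append((surrogate(up, noise, tau, top1_soft_indicator)
                 - surrogate(down, noise, tau, top1_soft_indicator)) / (2 * h))
print([round(g, 3) for g in grad])
```

Because adding a constant to every log-score leaves both the perturbed ordering and the relaxed matrix unchanged, the components of this gradient sum to (numerically) zero.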

5 Discussion and Related Work

The problem of learning to rank documents based on relevance has been studied extensively in the context of information retrieval. In particular, listwise approaches learn functions that map objects to scores. Much of this work concerns the PL distribution: the RankNet algorithm (Burges et al., 2005) can be interpreted as maximizing the PL likelihood of pairwise comparisons between items, while the ListMLE ranking algorithm in Xia et al. (2008) extends this with a loss that maximizes the PL likelihood of ground-truth permutations directly. The differentiable pairwise approaches to ranking, such as Rigutini et al. (2011), learn to approximate the comparator between pairs of objects. Our work considers a generalized setting where sorting-based operators can be inserted anywhere in computation graphs to extend traditional pipelines, e.g., kNN.

Prior works have proposed relaxations of permutation matrices to the Birkhoff polytope, which is defined as the convex hull of the set of permutation matrices, a.k.a. the set of doubly-stochastic matrices. A doubly-stochastic matrix is a permutation matrix iff it is orthogonal, and continuous relaxations based on these matrices have been used previously for solving NP-complete problems such as seriation and graph matching (Fogel et al., 2013; Fiori et al., 2013; Lim & Wright, 2014). Adams & Zemel (2011) proposed the use of the Sinkhorn operator to map any square matrix to the Birkhoff polytope. They interpret the resulting doubly-stochastic matrix as the marginals of a distribution over permutations. Mena et al. (2018) propose an alternate method where the square matrix defines a latent distribution over the doubly-stochastic matrices themselves. These distributions can be sampled from by adding elementwise Gumbel perturbations. Linderman et al. (2018) propose a rounding procedure that uses the Sinkhorn operator to directly sample matrices near the Birkhoff polytope. Unlike Mena et al. (2018), the resulting distribution over matrices has a tractable density. In practice, however, the approach of Mena et al. (2018) performs better, and it is the main baseline we compare against in our experiments in Section 6.

As discussed in Section 3, NeuralSort relaxes the output of the sort operator from permutation matrices to the set of unimodal row-stochastic matrices. For the stochastic setting, the PL distribution permits efficient sampling as well as exact and tractable density estimation, making it an attractive choice for several applications, e.g., variational inference over latent permutations. Our reparameterizable sampler, while also making use of the Gumbel distribution, is based on a result unique to the PL distribution (Proposition 5).

The use of the Gumbel distribution for defining continuous relaxations to discrete distributions was first proposed concurrently by Jang et al. (2017) and Maddison et al. (2017) for categorical variables, referred to as Gumbel-Softmax. The number of possible permutations grows factorially with the dimension, and thus any distribution over $n$-dimensional permutations can be equivalently seen as a distribution over $n!$ categories. Gumbel-Softmax does not scale to a combinatorially large number of categories (Kim et al., 2016; Mussmann et al., 2017), necessitating the use of alternate relaxations, such as the one considered in this work.

6 Experiments

We refer to the two approaches proposed in Sections 3 and 4 as Deterministic NeuralSort and Stochastic NeuralSort, respectively. For additional hyperparameter details and analysis, see the Appendix.


6.1 Sorting handwritten numbers



Figure 4: Sorting and quantile regression (Task 1: sorting loss; Task 2: median regression loss). The model is trained to sort sequences of large-MNIST images (Task 1) and regress the median value (Task 2). In the example shown, the ground-truth permutation sorts the input sequence from largest to smallest, with 9803 being the largest and 1270 the smallest label. Blue illustrates the true median image, together with its ground-truth sorted index and value.

Dataset. We first create the large-MNIST dataset, which extends the MNIST dataset of handwritten digits. The dataset consists of multi-digit images, each a concatenation of 4 randomly selected individual images from MNIST. Each image is associated with a real-valued label, which corresponds to its concatenated MNIST labels (e.g., an image formed from the digits 1, 2, 7 and 0 has the label 1270). Using the large-MNIST dataset, we finally create a dataset of sequences. Every sequence in this dataset consists of $n$ randomly sampled large-MNIST images.

Setup. Given a dataset of sequences of large-MNIST images, our goal is to learn to predict the permutation that sorts the labels of the sequence of images, given a training set of ground-truth permutations. Figure 4 (Task 1) illustrates this task on an example sequence of large-MNIST images. This task is a challenging extension of the one considered by Mena et al. (2018) in sorting scalars, since it involves learning the semantics of high-dimensional objects prior to sorting. A good model needs to learn to dissect the individual digits in an image, rank these digits, and finally, compose such rankings based on the digit positions within an image. The available supervision, in the form of the ground-truth permutation, is very weak compared to a classification setting that gives direct access to the image labels.

Baselines. All baselines use a CNN that is shared across all images in a sequence to map each large-MNIST image to a feature space. The vanilla row-stochastic (RS) baseline concatenates the CNN representations for $n$ images into a single vector that is fed into a multilayer perceptron that outputs $n$ multiclass predictions of the image probabilities for each rank. The Sinkhorn and Gumbel-Sinkhorn baselines, as discussed in Section 5, use the Sinkhorn operator to map the stacked CNN representations for the $n$ objects into a doubly-stochastic matrix. For all methods, we minimized the cross-entropy loss between the predicted matrix and the ground-truth permutation matrix.

Results. Following Mena et al. (2018), our evaluation metric is the proportion of correctly predicted permutations on a test set of sequences. Additionally, we evaluate the proportion of individual elements ranked correctly. Table 1 demonstrates that the approaches based on the proposed sorting relaxation significantly outperform the baseline approaches for all $n$ considered. The performance of the deterministic and stochastic variants is comparable. The vanilla RS baseline performs well in ranking individual elements, but is poor at recovering the full permutation matrix.

We believe the poor performance of the Sinkhorn baselines is partly because these methods were designed and evaluated for matchings. Like the output of sort, matchings can also be represented as permutation matrices. However, distributions over matchings need not satisfy Luce’s choice axiom or imply a total ordering, which could explain the poor performance on the tasks considered.

Algorithm                   n = 3          n = 5          n = 7          n = 9          n = 15
Vanilla RS                  0.467 (0.801)  0.093 (0.603)  0.009 (0.492)  0.000 (0.113)  0.000 (0.067)
Sinkhorn                    0.462 (0.561)  0.038 (0.293)  0.001 (0.197)  0.000 (0.143)  0.000 (0.078)
Gumbel-Sinkhorn             0.484 (0.575)  0.033 (0.295)  0.001 (0.189)  0.000 (0.146)  0.000 (0.078)
Deterministic NeuralSort    0.930 (0.951)  0.837 (0.927)  0.738 (0.909)  0.649 (0.896)  0.386 (0.857)
Stochastic NeuralSort       0.927 (0.950)  0.835 (0.926)  0.741 (0.909)  0.646 (0.895)  0.418 (0.862)
Table 1: Average sorting accuracy on the test set. The first value is the proportion of permutations correctly identified; the value in parentheses is the proportion of individual element ranks correctly identified.

6.2 Quantile regression

Setup. In this experiment, we extend the sorting task to regression. Again, each sequence contains $n$ large-MNIST images, and the regression target for each sequence is the 50th quantile (i.e., the median) of the labels of the images in the sequence. Figure 4 (Task 2) illustrates this task on an example sequence of $n = 5$ large-MNIST images, where the goal is to output the third largest label. The design of this task highlights two key challenges, since it explicitly requires learning both a suitable representation for sorting high-dimensional inputs and a secondary function that approximates the label itself (regression). Again, the supervision, in the form of the label of only a single image at an arbitrary and unknown location in the sequence, is weak.

Baselines. In addition to Sinkhorn and Gumbel-Sinkhorn, we design two more baselines. The Constant baseline always returns the median of the full range of possible outputs, ignoring the input sequence. This corresponds to a prediction of about 5000, since we are sampling large-MNIST images uniformly in the range of four-digit numbers. The vanilla neural net (NN) baseline directly maps the input sequence of images to a real-valued prediction for the median.

Algorithm                   n = 5            n = 9            n = 15
Constant (Simulated)        356.79 (0.00)    227.31 (0.00)    146.94 (0.00)
Vanilla NN                  1004.70 (0.85)   699.15 (0.82)    562.97 (0.79)
Sinkhorn                    343.60 (0.25)    231.87 (0.19)    156.27 (0.04)
Gumbel-Sinkhorn             344.28 (0.25)    232.56 (0.23)    157.34 (0.06)
Deterministic NeuralSort    45.50 (0.95)     34.98 (0.94)     34.78 (0.92)
Stochastic NeuralSort       33.80 (0.94)     31.43 (0.93)     29.34 (0.90)
Table 2: Test mean squared error ($\times 10^{-4}$) and $R^2$ values (in parentheses) for quantile regression.
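The simulated Constant row of Table 2 can be cross-checked analytically, assuming the three columns correspond to sequence lengths $n = 5, 9, 15$ and approximating the four-digit labels as continuous Uniform(0, 10000). Under that assumption the median of $n$ labels (for odd $n$) is a scaled Beta$(k, n - k + 1)$ order statistic with $k = (n+1)/2$, whose mean is exactly half the range, so a constant prediction of 5000 is unbiased and its MSE equals the median's variance. A minimal sketch (the function name is hypothetical):

```python
def constant_baseline_mse(n, label_range=10_000.0):
    """MSE, in units of 10^4 as in Table 2, of always predicting label_range / 2.

    Assumes the n labels are i.i.d. continuous Uniform(0, label_range) and n is odd.
    The median is then label_range times the ((n + 1) / 2)-th order statistic of n
    standard uniforms, a Beta(k, n - k + 1) variable with mean exactly 1/2, so the
    constant prediction is unbiased and its MSE equals the median's variance."""
    k = (n + 1) // 2
    beta_var = k * (n - k + 1) / ((n + 1) ** 2 * (n + 2))
    return beta_var * label_range ** 2 / 1e4


for n in (5, 9, 15):
    print(n, round(constant_baseline_mse(n), 2))  # 357.14, 227.27, 147.06
```

The analytic values agree with the simulated 356.79, 227.31 and 146.94 to within a fraction of a unit.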

Results. Our evaluation metrics are the mean squared error (MSE) and $R^2$ on a test set of sequences. Results for $n \in \{5, 9, 15\}$ images are shown in Table 2. The Vanilla NN baseline, while incurring a large MSE, is competitive on the $R^2$ metric. The other baselines give comparable performance on the MSE metric. The proposed NeuralSort approaches outperform the competing methods on both metrics. The stochastic NeuralSort approach is consistently the best performer on MSE, while the deterministic NeuralSort is slightly better on the $R^2$ metric.

6.3 End-to-end, differentiable k-Nearest Neighbors



Figure 5: Differentiable kNN. The model is trained such that the representations of the training points that have the same label as the query point are closer to it (i.e., included in the top-$k$) than others.

Setup. In this experiment, we design a fully differentiable, end-to-end $k$-nearest neighbors (kNN) classifier. Unlike a standard kNN classifier, which computes distances between points in a predefined space, we learn a representation of the data points before evaluating the $k$-nearest neighbors.

We are given access to a dataset $\mathcal{D}$ of $(x, y)$ pairs of standard input data and their class labels, respectively. The differentiable kNN algorithm has three hyperparameters: the number of training neighbors $n$, the number of top candidates $k$, and the sorting temperature $\tau$. Every sequence of items here consists of a query point $x$ and a randomly sampled subset of $n$ candidate nearest neighbors from the training set, say $\{x_1, x_2, \ldots, x_n\}$. In principle, we could use the entire training set (excluding the query point) as candidate points, but this can hurt learning both computationally and statistically. The query points are randomly sampled from the train/validation/test sets as appropriate, but the nearest neighbors are always sampled from the training set. The loss function optimizes for a representation space $h_\phi(\cdot)$ (e.g., a CNN) such that the top-$k$ candidate points with the minimum Euclidean distance to the query point in the representation space have the same label as the query point. Note that at test time, once the representation space is learned, we can use the entire training set as the set of candidate points, akin to a standard kNN classifier. Figure 5 illustrates the proposed algorithm.

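Formally, for any datapoint $x$, let $z$ denote a permutation of the $n$ candidate points. The uniformly-weighted kNN loss, denoted as $\ell_{\mathrm{kNN}}(\cdot)$, can be written as follows:

$$\ell_{\mathrm{kNN}}(\widehat{P}_z, y, y_1, y_2, \dots, y_n) = -\frac{1}{k} \sum_{j=1}^{k} \sum_{i=1}^{n} \mathbb{1}(y_i = y)\, \widehat{P}_z[j, i] \tag{13}$$

where $\{y_1, y_2, \dots, y_n\}$ are the labels for the candidate points. Note that when $\widehat{P}_z$ is an exact permutation matrix (i.e., temperature $\tau \to 0$), this expression is exactly the negative of the fraction of the $k$ nearest neighbors that have the same label as $x$.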
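To make the loss concrete, the following minimal pure-Python sketch evaluates the uniformly-weighted kNN loss for a given (possibly relaxed) permutation matrix, stored as a list of rows in which row $j$ corresponds to the $j$-th largest score and column $i$ to the $i$-th candidate. The function names `knn_loss` and `exact_sort_matrix` are illustrative and not taken from the paper's codebase; `exact_sort_matrix` builds the hard permutation matrix that the relaxation approaches as the temperature goes to zero.

```python
def exact_sort_matrix(scores):
    """Hard sort matrix: row j is one-hot at the index of the (j+1)-th largest score."""
    n = len(scores)
    order = sorted(range(n), key=lambda i: -scores[i])  # indices by decreasing score
    return [[1.0 if col == order[row] else 0.0 for col in range(n)] for row in range(n)]


def knn_loss(P_hat, y, candidate_labels, k):
    """-(1/k) * sum of P_hat[j][i] over the top-k rows j and candidates i labeled y."""
    total = 0.0
    for j in range(k):  # rows 0..k-1 correspond to the k highest-scoring candidates
        for i, y_i in enumerate(candidate_labels):
            if y_i == y:
                total += P_hat[j][i]
    return -total / k


# Scores are negative distances, so the nearest candidates have the largest scores.
scores = [-0.1, -2.0, -0.5, -3.0]
labels = [1, 0, 1, 1]
P = exact_sort_matrix(scores)
print(knn_loss(P, y=1, candidate_labels=labels, k=2))  # -1.0: both of the 2 nearest share the label
print(knn_loss(P, y=1, candidate_labels=labels, k=3))  # about -0.667: 2 of the 3 nearest share it
```

With a hard permutation matrix, the loss is exactly the negated fraction of same-label neighbors; with a relaxed matrix, the same code returns a smooth surrogate.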

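Using Eq. 13, the training objectives for Deterministic and Stochastic NeuralSort are given as:

Deterministic: $$\min_{\phi} \frac{1}{|D|} \sum_{(x, y) \in D} \ell_{\mathrm{kNN}}(\widehat{P}_{\mathrm{sort}(s)}, y, y_1, \dots, y_n) \tag{14}$$
Stochastic: $$\min_{\phi} \frac{1}{|D|} \sum_{(x, y) \in D} \mathbb{E}_{z \sim q(z \mid s)}\left[\ell_{\mathrm{kNN}}(\widehat{P}_z, y, y_1, \dots, y_n)\right] \tag{15}$$

where each entry of $s$ is given as $s_j = -\lVert h_\phi(x) - h_\phi(x_j) \rVert_2^2$.
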
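For a single query, the deterministic and stochastic objectives can be sketched in pure Python as follows. All of the following are illustrative assumptions: the representation `h` is any function returning a vector (the identity on 2-D points below is a hypothetical stand-in for a CNN), the scores are negative squared Euclidean distances so that the nearest candidates are ranked first, and the stochastic estimate adds standard Gumbel noise directly to these scores, i.e., it treats them as log-parameters of the Plackett-Luce distribution, because negative distances cannot be passed through a logarithm. The expectation in Eq. 15 is approximated by a Monte Carlo average, and no gradients are computed here.

```python
import math
import random


def neuralsort(s, tau):
    """Row i (1-indexed) is softmax(((n + 1 - 2i) * s - A_s 1) / tau)."""
    n = len(s)
    row_sums = [sum(abs(s_j - s_k) for s_k in s) for s_j in s]  # (A_s 1)_j
    P_hat = []
    for i in range(1, n + 1):
        logits = [((n + 1 - 2 * i) * s[j] - row_sums[j]) / tau for j in range(n)]
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in logits]
        z = sum(exps)
        P_hat.append([e / z for e in exps])
    return P_hat


def knn_loss(P_hat, y, ys, k):
    """Uniformly-weighted kNN loss (Eq. 13)."""
    return -sum(P_hat[j][i] for j in range(k) for i, y_i in enumerate(ys) if y_i == y) / k


def knn_scores(h, x, xs):
    """s_j = -||h(x) - h(x_j)||^2, so nearer candidates get larger scores."""
    hx = h(x)
    return [-sum((a - b) ** 2 for a, b in zip(hx, h(x_j))) for x_j in xs]


def deterministic_objective(h, x, y, xs, ys, k, tau):
    return knn_loss(neuralsort(knn_scores(h, x, xs), tau), y, ys, k)


def stochastic_objective(h, x, y, xs, ys, k, tau, n_samples, rng):
    s = knn_scores(h, x, xs)
    total = 0.0
    for _ in range(n_samples):
        # Standard Gumbel noise via -log(-log U); eps keeps both logarithms finite.
        g = [-math.log(-math.log(rng.random() + 1e-10) + 1e-10) for _ in s]
        total += knn_loss(neuralsort([s_j + g_j for s_j, g_j in zip(s, g)], tau), y, ys, k)
    return total / n_samples


def h(v):
    return v  # hypothetical stand-in for the learned representation h_phi


x, y = (0.0, 0.0), 1
xs = [(0.1, 0.0), (2.0, 0.0), (0.0, 0.5), (3.0, 1.0)]
ys = [1, 0, 1, 1]
print(deterministic_objective(h, x, y, xs, ys, k=2, tau=0.01))  # close to -1.0
print(deterministic_objective(h, x, y, xs, ys, k=2, tau=1e6))   # close to -0.75
print(stochastic_objective(h, x, y, xs, ys, k=2, tau=0.01, n_samples=100, rng=random.Random(0)))
```

At a low temperature the deterministic objective matches the hard kNN loss; at a very high temperature every row becomes nearly uniform, so the loss tends to the negated fraction of same-label candidates among all $n$.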
Algorithm MNIST Fashion-MNIST CIFAR-10
kNN 97.2% 85.8% 35.4%
kNN+PCA 97.6% 85.9% 40.9%
kNN+AE 97.6% 87.5% 44.2%
kNN + Deterministic NeuralSort 99.5% 93.5% 90.7%
kNN + Stochastic NeuralSort 99.4% 93.4% 89.5%
CNN (w/o kNN) 99.4% 93.4% 95.1%
Table 3: Average test kNN classification accuracies from $n$ neighbors for the best value of $k$.

Datasets. We consider three benchmark datasets: the MNIST dataset of handwritten digits, the Fashion-MNIST dataset of fashion apparel, and the CIFAR-10 dataset of natural images (no data augmentation), each with the canonical splits for training and testing.

Baselines. We consider kNN baselines that operate in three standard representation spaces: the canonical pixel basis, the basis specified by the top principal components (PCA), and the latent space of an autoencoder (AE). Additionally, we experimented with several values of $k$ and with two neighbor-weighting schemes: uniform weighting of all $k$-nearest neighbors and weighting of the nearest neighbors by the inverse of their distance. For completeness, we also trained a CNN with the same architecture as the one used for NeuralSort (except the final layer) using the cross-entropy loss.
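The two neighbor-weighting schemes used by the baselines can be sketched as follows, as a minimal pure-Python illustration on points that are assumed to have already been mapped into one of the fixed representation spaces (pixels, PCA, or autoencoder codes). The function `knn_predict`, its `weighting` values, and the small `eps` guarding against zero distances are assumptions made for this example, which requires Python 3.8+ for `math.dist`.

```python
import math
from collections import defaultdict


def knn_predict(query, points, labels, k, weighting="uniform", eps=1e-12):
    """Predict a label from the k nearest points in a fixed representation space.

    weighting="uniform": each of the k neighbors casts one vote.
    weighting="inverse": each neighbor's vote is weighted by 1 / distance.
    """
    dists = [math.dist(query, p) for p in points]
    nearest = sorted(range(len(points)), key=lambda i: dists[i])[:k]
    votes = defaultdict(float)
    for i in nearest:
        if weighting == "uniform":
            w = 1.0
        elif weighting == "inverse":
            w = 1.0 / (dists[i] + eps)  # eps avoids division by zero for duplicates
        else:
            raise ValueError(f"unknown weighting: {weighting}")
        votes[labels[i]] += w
    return max(votes, key=votes.get)


points = [(0.1, 0.0), (1.0, 0.0), (0.0, 1.1), (5.0, 5.0)]
labels = ["a", "b", "b", "a"]
print(knn_predict((0.0, 0.0), points, labels, k=3))                       # b (2 votes vs 1)
print(knn_predict((0.0, 0.0), points, labels, k=3, weighting="inverse"))  # a (weight 10 vs about 1.91)
```

The example shows why both schemes were tried: a single very close neighbor can outvote two farther ones only under inverse-distance weighting.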

Results. We report the classification accuracies on the standard test sets in Table 3. On all three datasets, the differentiable kNN classifier outperforms all the baseline kNN variants, including the convolutional autoencoder approach. Its performance is also much closer to the accuracy of a standard CNN.

7 Conclusion

In this paper, we proposed NeuralSort, a continuous relaxation of the sorting operator to the set of unimodal row-stochastic matrices. Our relaxation facilitates gradient estimation on any computation graph involving a sort operator. Further, we derived a reparameterized gradient estimator for the Plackett-Luce distribution for efficient stochastic optimization over permutations. On three illustrative tasks, including a fully differentiable $k$-nearest neighbors classifier, our proposed relaxations outperform prior work in end-to-end learning of semantic orderings of high-dimensional objects.

In the future, we would like to explore alternate relaxations to sorting as well as applications that extend widely-used algorithms such as beam search (Goyal et al., 2018). Both deterministic and stochastic NeuralSort are easy to implement. We provide reference implementations in TensorFlow (Abadi et al., 2016) and PyTorch (Paszke et al., 2017) in Appendix A. The full codebase for this work is open-sourced at


This research was supported by NSF (#1651565, #1522054, #1733686), ONR, AFOSR (FA9550-19-1-0024), and FLI. AG is supported by an MSR Ph.D. fellowship and a Stanford Data Science scholarship. We are thankful to Jordan Alexander, Kristy Choi, Adithya Ganesh, Karan Goel, Neal Jean, Daniel Levy, Jiaming Song, Yang Song, Serena Yeung, and Hugh Zhang for helpful comments.


  • Abadi et al. (2016) Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. TensorFlow: a system for large-scale machine learning. In Operating Systems Design and Implementation, 2016.
  • Adams & Zemel (2011) Ryan Prescott Adams and Richard S Zemel. Ranking via Sinkhorn propagation. arXiv preprint arXiv:1106.1925, 2011.
  • Balog et al. (2017) Matej Balog, Nilesh Tripuraneni, Zoubin Ghahramani, and Adrian Weller. Lost relatives of the Gumbel trick. In International Conference on Machine Learning, 2017.
  • Bar-Shalom & Li (1995) Yaakov Bar-Shalom and Xiao-Rong Li. Multitarget-multisensor tracking: principles and techniques. 1995.
  • Bengio et al. (2013) Yoshua Bengio, Nicholas Léonard, and Aaron Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • Berrada et al. (2018) Leonard Berrada, Andrew Zisserman, and M Pawan Kumar. Smooth loss functions for deep top-k classification. In International Conference on Learning Representations, 2018.
  • Burges et al. (2005) Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. Learning to rank using gradient descent. In International Conference on Machine Learning, 2005.
  • Chierichetti et al. (2018) Flavio Chierichetti, Ravi Kumar, and Andrew Tomkins. Discrete choice, permutations, and reconstruction. In Symposium on Discrete Algorithms, 2018.
  • Fiori et al. (2013) Marcelo Fiori, Pablo Sprechmann, Joshua Vogelstein, Pablo Musé, and Guillermo Sapiro. Robust multimodal graph matching: Sparse coding meets graph matching. In Advances in Neural Information Processing Systems, 2013.
  • Fogel et al. (2013) Fajwel Fogel, Rodolphe Jenatton, Francis Bach, and Alexandre d’Aspremont. Convex relaxations for permutation problems. In Advances in Neural Information Processing Systems, 2013.
  • Fu (2006) Michael C Fu. Gradient estimation. Handbooks in operations research and management science, 13:575–616, 2006.
  • Gao & Pavel (2017) Bolin Gao and Lacra Pavel. On the properties of the softmax function with application in game theory and reinforcement learning. arXiv preprint arXiv:1704.00805, 2017.
  • Glasserman (2013) Paul Glasserman. Monte Carlo methods in financial engineering, volume 53. Springer Science & Business Media, 2013.
  • Glynn (1990) Peter W Glynn. Likelihood ratio gradient estimation for stochastic systems. Communications of the ACM, 33(10):75–84, 1990.
  • Goyal et al. (2018) Kartik Goyal, Graham Neubig, Chris Dyer, and Taylor Berg-Kirkpatrick. A continuous relaxation of beam search for end-to-end training of neural sequence models. In AAAI Conference on Artificial Intelligence, 2018.
  • Jang et al. (2017) Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with Gumbel-softmax. In International Conference on Learning Representations, 2017.
  • Kim et al. (2016) Carolyn Kim, Ashish Sabharwal, and Stefano Ermon. Exact sampling with integer linear programs and random perturbations. In AAAI Conference on Artificial Intelligence, 2016.
  • Kingma & Welling (2014) Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2014.
  • Lim & Wright (2014) Cong Han Lim and Stephen J Wright. Sorting network relaxations for vector permutation problems. arXiv preprint arXiv:1407.6609, 2014.
  • Linderman et al. (2018) Scott W Linderman, Gonzalo E Mena, Hal Cooper, Liam Paninski, and John P Cunningham. Reparameterizing the Birkhoff polytope for variational permutation inference. In International Conference on Artificial Intelligence and Statistics, 2018.
  • Liu et al. (2009) Tie-Yan Liu et al. Learning to rank for information retrieval. Foundations and Trends® in Information Retrieval, 3(3):225–331, 2009.
  • Luce (1959) R Duncan Luce. Individual choice behavior: A theoretical analysis. Courier Corporation, 1959.
  • Maddison et al. (2017) Chris J Maddison, Andriy Mnih, and Yee Whye Teh. The concrete distribution: A continuous relaxation of discrete random variables. In International Conference on Learning Representations, 2017.
  • Mena et al. (2018) Gonzalo Mena, David Belanger, Scott Linderman, and Jasper Snoek. Learning latent permutations with Gumbel-Sinkhorn networks. In International Conference on Learning Representations, 2018.
  • Mussmann et al. (2017) Stephen Mussmann, Daniel Levy, and Stefano Ermon. Fast amortized inference and learning in log-linear models with randomly perturbed nearest neighbor search. In Uncertainty in Artificial Intelligence, 2017.
  • Ogryczak & Tamir (2003) Wlodzimierz Ogryczak and Arie Tamir. Minimizing the sum of the largest functions in linear time. Information Processing Letters, 85(3):117–122, 2003.
  • Paszke et al. (2017) Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in PyTorch. 2017.
  • Plackett (1975) Robin L Plackett. The analysis of permutations. Applied Statistics, pp. 193–202, 1975.
  • Regenwetter et al. (2006) Michel Regenwetter, Bernard Grofman, Ilia Tsetlin, and AAJ Marley. Behavioral social choice: probabilistic models, statistical inference, and applications. Cambridge University Press, 2006.
  • Rezende et al. (2014) Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, 2014.
  • Rigutini et al. (2011) Leonardo Rigutini, Tiziano Papini, Marco Maggini, and Franco Scarselli. Sortnet: Learning to rank by a neural preference function. IEEE transactions on neural networks, 22(9):1368–1380, 2011.
  • Schulman et al. (2015) John Schulman, Nicolas Heess, Theophane Weber, and Pieter Abbeel. Gradient estimation using stochastic computation graphs. In Advances in Neural Information Processing Systems, 2015.
  • Thurstone (1927) Louis L Thurstone. A law of comparative judgment. Psychological review, 34(4):273, 1927.
  • Titsias & Lázaro-Gredilla (2014) Michalis Titsias and Miguel Lázaro-Gredilla. Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning, 2014.
  • Weinberger et al. (2006) Kilian Q Weinberger, John Blitzer, and Lawrence K Saul. Distance metric learning for large margin nearest neighbor classification. In Advances in Neural Information Processing Systems, 2006.
  • Williams (1992) Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229–256, 1992.
  • Xia et al. (2008) Fen Xia, Tie-Yan Liu, Jue Wang, Wensheng Zhang, and Hang Li. Listwise approach to learning to rank: theory and algorithm. In International Conference on Machine Learning, 2008.
  • Yellott Jr (1977) John I Yellott Jr. The relationship between Luce’s choice axiom, Thurstone’s theory of comparative judgment, and the double exponential distribution. Journal of Mathematical Psychology, 15(2):109–144, 1977.


Appendix A Sorting Operator

A.1 TensorFlow

Sorting Relaxation for Deterministic NeuralSort:

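import tensorflow as tf

def deterministic_NeuralSort(s, tau):
    """Relaxed sort operator.

    s: input elements to be sorted. Shape: batch_size x n x 1.
    tau: temperature for relaxation. Scalar.
    Returns P_hat of shape batch_size x n x n; row i approximates the
    one-hot indicator of the i-th largest element of s.
    """
    n = tf.shape(s)[1]
    # A_s[b, j, k] = |s_j - s_k|
    A_s = tf.abs(s - tf.transpose(s, perm=[0, 2, 1]))
    # B[b, j, 0] = (A_s 1)_j; broadcasting below makes this equal to A_s 1 1^T.
    B = tf.reduce_sum(A_s, axis=-1, keepdims=True)
    # scaling[i] = n + 1 - 2(i + 1) for 0-indexed rows i.
    scaling = tf.cast(n + 1 - 2 * (tf.range(n) + 1), dtype=s.dtype)
    # C[b, j, i] = scaling[i] * s_j (outer product via broadcasting).
    C = s * scaling[tf.newaxis, tf.newaxis, :]
    P_max = tf.transpose(C - B, perm=[0, 2, 1])
    P_hat = tf.nn.softmax(P_max / tau, axis=-1)
    return P_hat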

Reparameterized Sampler for Stochastic NeuralSort:

def sample_gumbel(samples_shape, eps=1e-10):
    # Inverse-CDF sampling of standard Gumbel(0, 1) noise: -log(-log U), U ~ Uniform(0, 1).
    U = tf.random.uniform(samples_shape, minval=0, maxval=1)
    return -tf.math.log(-tf.math.log(U + eps) + eps)

def stochastic_NeuralSort(s, n_samples, tau):
    """Relaxed samples from the Plackett-Luce (PL) distribution.

    s: (positive) parameters of the PL distribution. Shape: batch_size x n x 1.
    n_samples: number of samples from the PL distribution. Scalar.
    tau: temperature for the relaxation. Scalar.
    Returns P_hat of shape n_samples x batch_size x n x n.
    """
    batch_size = tf.shape(s)[0]
    n = tf.shape(s)[1]
    # Perturb the log-scores with i.i.d. Gumbel noise (broadcast over n_samples).
    log_s_perturb = tf.math.log(s) + sample_gumbel([n_samples, batch_size, n, 1])
    log_s_perturb = tf.reshape(log_s_perturb, [n_samples * batch_size, n, 1])
    P_hat = deterministic_NeuralSort(log_s_perturb, tau)
    P_hat = tf.reshape(P_hat, [n_samples, batch_size, n, n])
    return P_hat

A.2 PyTorch

Sorting Relaxation for Deterministic NeuralSort:

import torch

def deterministic_NeuralSort(s, tau):
    """Relaxed sort operator.

    s: input elements to be sorted. Shape: batch_size x n x 1.
    tau: temperature for relaxation. Scalar.
    Returns P_hat of shape batch_size x n x n.
    """
    n = s.size(1)
    one = torch.ones((n, 1), dtype=s.dtype, device=s.device)
    # A_s[b, j, k] = |s_j - s_k|
    A_s = torch.abs(s - s.permute(0, 2, 1))
    # B = A_s 1 1^T: every column holds the row sums (A_s 1)_j.
    B = torch.matmul(A_s, torch.matmul(one, one.transpose(0, 1)))
    # scaling[i] = n + 1 - 2(i + 1) for 0-indexed rows i.
    scaling = (n + 1 - 2 * (torch.arange(n, device=s.device) + 1)).to(s.dtype)
    C = torch.matmul(s, scaling.unsqueeze(0))
    P_max = (C - B).permute(0, 2, 1)
    P_hat = torch.softmax(P_max / tau, dim=-1)
    return P_hat

Reparameterized Sampler for Stochastic NeuralSort:

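def sample_gumbel(samples_shape, eps=1e-10, device=None):
    # Inverse-CDF sampling of standard Gumbel(0, 1) noise: -log(-log U), U ~ Uniform(0, 1).
    U = torch.rand(samples_shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def stochastic_NeuralSort(s, n_samples, tau):
    """Relaxed samples from the Plackett-Luce (PL) distribution.

    s: (positive) parameters of the PL distribution. Shape: batch_size x n x 1.
    n_samples: number of samples from the PL distribution. Scalar.
    tau: temperature for the relaxation. Scalar.
    Returns P_hat of shape n_samples x batch_size x n x n.
    """
    batch_size = s.size(0)
    n = s.size(1)
    # Perturb the log-scores with i.i.d. Gumbel noise (broadcast over n_samples).
    log_s_perturb = torch.log(s) + sample_gumbel([n_samples, batch_size, n, 1], device=s.device)
    log_s_perturb = log_s_perturb.view(n_samples * batch_size, n, 1)
    P_hat = deterministic_NeuralSort(log_s_perturb, tau)
    P_hat = P_hat.view(n_samples, batch_size, n, n)
    return P_hat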

Appendix B Proofs of Theoretical Results

B.1 Lemma 2


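For any value of $\lambda$, the following inequalities hold:

$$\sum_{i=1}^{k} s_{[i]} = \lambda k + \sum_{i=1}^{k} \left(s_{[i]} - \lambda\right) \le \lambda k + \sum_{i=1}^{k} \max\left(s_{[i]} - \lambda, 0\right) \le \lambda k + \sum_{i=1}^{n} \max\left(s_i - \lambda, 0\right).$$

Furthermore, for $\lambda = s_{[k]}$:

$$\lambda k + \sum_{i=1}^{n} \max\left(s_i - \lambda, 0\right) = k\, s_{[k]} + \sum_{i=1}^{k} \left(s_{[i]} - s_{[k]}\right) = \sum_{i=1}^{k} s_{[i]}.$$

This finishes the proof. ∎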
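As a numerical sanity check, assuming Lemma 2 takes the variational form used in its proof (the sum of the $k$ largest entries of $s$ equals the minimum over $\lambda \in \{s_1, \dots, s_n\}$ of $\lambda k + \sum_i \max(s_i - \lambda, 0)$), the following pure-Python sketch compares both sides on random vectors. The function names are illustrative.

```python
import random


def topk_sum(s, k):
    """Sum of the k largest entries of s, computed by sorting."""
    return sum(sorted(s, reverse=True)[:k])


def variational_topk_sum(s, k):
    """min over lambda in s of: lambda * k + sum_i max(s_i - lambda, 0)."""
    return min(lam * k + sum(max(s_i - lam, 0.0) for s_i in s) for lam in s)


rng = random.Random(0)
for _ in range(1000):
    n = rng.randint(1, 8)
    s = [rng.uniform(-5.0, 5.0) for _ in range(n)]
    k = rng.randint(1, n)
    assert abs(topk_sum(s, k) - variational_topk_sum(s, k)) < 1e-9
print(topk_sum([3.0, 1.0, 2.0], 2), variational_topk_sum([3.0, 1.0, 2.0], 2))  # 5.0 5.0
```

Restricting $\lambda$ to the entries of $s$ suffices because, as the proof shows, the minimum is attained at $\lambda = s_{[k]}$.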

B.2 Corollary 3


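We first consider at exactly what values of $\lambda$ the sum in Lemma 2 is minimized. For simplicity we will only prove the case where all values of $s$ are distinct.

The equality $\sum_{i=1}^{k} s_{[i]} = \lambda k + \sum_{i=1}^{n} \max(s_i - \lambda, 0)$ holds only when $s_{[k+1]} \le \lambda \le s_{[k]}$ (with the convention $s_{[n+1]} = -\infty$). By Lemma 2, these values of $\lambda$ also minimize the RHS of the equality.

Symmetrically, if one considers the score vector $t = -s$, then $\lambda(n-k+1) + \sum_{i=1}^{n} \max(t_i - \lambda, 0)$ is minimized at $t_{[n-k+2]} \le \lambda \le t_{[n-k+1]}$.

Replacing $\lambda$ by $-\lambda$ and using the definition of $t$ implies that $\lambda(k-1-n) + \sum_{i=1}^{n} \max(\lambda - s_i, 0)$ is minimized at $s_{[k]} \le \lambda \le s_{[k-1]}$ (with the convention $s_{[0]} = +\infty$).

Since $s_{[k]}$ is the only value lying in both intervals, it is the unique minimizer of the sum of the two expressions. Using $\max(a, 0) + \max(-a, 0) = |a|$, it follows that:

$$s_{[k]} = \arg\min_{\lambda \in \{s_1, \dots, s_n\}} \Big( \lambda(2k-1-n) + \sum_{i=1}^{n} |s_i - \lambda| \Big) = \arg\max_{\lambda \in \{s_1, \dots, s_n\}} \Big( \lambda(n+1-2k) - \sum_{i=1}^{n} |s_i - \lambda| \Big).$$

Thus, since $(A_s \mathbb{1})_j = \sum_{i=1}^{n} |s_i - s_j|$, if $s_j = s_{[k]}$, then $j = \arg\max\left[(n+1-2k)s - A_s \mathbb{1}\right]$. This finishes the proof.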
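The conclusion of Corollary 3 can also be checked numerically: for a vector with distinct entries, the index maximizing $(n + 1 - 2k)s_j - (A_s \mathbb{1})_j$ should be the index of the $k$-th largest entry. The pure-Python sketch below uses distinct integer scores so that the comparison is exact; the helper names are illustrative.

```python
import random


def corollary3_index(s, k):
    """arg max over j of (n + 1 - 2k) * s_j - sum_i |s_i - s_j|, with k 1-indexed."""
    n = len(s)
    vals = [(n + 1 - 2 * k) * s_j - sum(abs(s_i - s_j) for s_i in s) for s_j in s]
    return max(range(n), key=lambda j: vals[j])


def kth_largest_index(s, k):
    """Index of the k-th largest entry of s (k is 1-indexed)."""
    return sorted(range(len(s)), key=lambda j: -s[j])[k - 1]


rng = random.Random(1)
for _ in range(1000):
    n = rng.randint(1, 8)
    s = rng.sample(range(-50, 50), n)  # distinct integers, so there are no ties
    for k in range(1, n + 1):
        assert corollary3_index(s, k) == kth_largest_index(s, k)
print([corollary3_index([3, 1, 2], k) for k in (1, 2, 3)])  # [0, 2, 1]
```

Evaluating this expression for every $k$ at once is exactly what the relaxation does before the softmax, one row per $k$.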

B.3 Theorem 4

We prove the two properties in the statement of the theorem independently:

  1. Unimodality


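    By definition of the softmax function, the entries of $\widehat{P}_{\mathrm{sort}(s)}$ are positive and every row sums to 1. To show that $\widehat{P}_{\mathrm{sort}(s)}$ satisfies the argmax permutation property, we show that its row-wise argmax indices form the permutation $\mathrm{sort}(s)$. Formally, for any given row $i$, we construct the argmax permutation vector $u$ as:

    $$u_i = \arg\max_j \left[\mathrm{softmax}\left(\left((n+1-2i)s - A_s \mathbb{1}\right)/\tau\right)\right]_j = \arg\max_j \left[(n+1-2i)s - A_s \mathbb{1}\right]_j = [i],$$

    where the square-bracket notation $[i]$ denotes the index of the $i$-th largest element of $s$. The first step follows from the fact that dividing by $\tau > 0$ and applying the softmax function are both monotonically increasing and hence preserve the argmax. The second equality directly follows from Corollary 3. By definition, $\mathrm{sort}(s) = [[1], [2], \dots, [n]]^T$, so $u = \mathrm{sort}(s)$, finishing the proof. ∎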

  2. Limiting behavior


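    As shown in Gao & Pavel (2017), the softmax function may be equivalently defined as $\mathrm{softmax}(z/\tau) = \arg\max_{x \in \Delta^{n-1}} \left( x^T z - \tau \sum_{i=1}^{n} x_i \log x_i \right)$, where $\Delta^{n-1}$ is the probability simplex. In particular, $\lim_{\tau \to 0^+} \mathrm{softmax}(z/\tau) = \arg\max_{x \in \Delta^{n-1}} x^T z$, which is the one-hot vector at the largest entry of $z$ whenever that entry is unique. The distributional assumptions ensure that the elements of $s$ are distinct a.s., so plugging in $z = (n+1-2i)s - A_s \mathbb{1}$ for each row $i$ completes the proof. ∎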
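Both properties can be illustrated with a small pure-Python version of the relaxation in which row $i$ is $\mathrm{softmax}(((n + 1 - 2i)s - A_s \mathbb{1})/\tau)$, mirroring the reference implementations in Appendix A. The check below, on an illustrative score vector, confirms that each row is a probability vector, that the row-wise argmax indices recover the sorting permutation at a moderate temperature, and that the matrix approaches the exact permutation matrix as $\tau$ becomes small. This is a sketch for intuition, not a substitute for the proof.

```python
import math


def neuralsort(s, tau):
    """Row i (1-indexed) is softmax(((n + 1 - 2i) * s - A_s 1) / tau)."""
    n = len(s)
    row_sums = [sum(abs(s_j - s_k) for s_k in s) for s_j in s]  # (A_s 1)_j
    P_hat = []
    for i in range(1, n + 1):
        logits = [((n + 1 - 2 * i) * s[j] - row_sums[j]) / tau for j in range(n)]
        m = max(logits)  # stabilize the softmax
        exps = [math.exp(v - m) for v in logits]
        z = sum(exps)
        P_hat.append([e / z for e in exps])
    return P_hat


def sort_matrix(s):
    """Exact permutation matrix: row i is one-hot at the index of the (i+1)-th largest entry."""
    order = sorted(range(len(s)), key=lambda j: -s[j])
    return [[1.0 if j == order[i] else 0.0 for j in range(len(s))] for i in range(len(s))]


s = [0.3, -1.2, 2.5, 0.9]
P_soft = neuralsort(s, tau=1.0)
for row in P_soft:
    assert all(p > 0 for p in row) and abs(sum(row) - 1.0) < 1e-12  # row-stochastic
argmaxes = [max(range(len(row)), key=row.__getitem__) for row in P_soft]
print(argmaxes)  # [2, 3, 0, 1], the indices of s in decreasing order

P_cold = neuralsort(s, tau=1e-3)
gap = max(abs(a - b) for ra, rb in zip(P_cold, sort_matrix(s)) for a, b in zip(ra, rb))
print(gap < 1e-9)  # True
```

Even at $\tau = 1$ the rows are far from one-hot, yet their argmax indices already spell out the sort; lowering $\tau$ only sharpens them.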

B.4 Proposition 5

This result follows from an earlier result by Yellott Jr (1977). We give the proof sketch below and refer the reader to Yellott Jr (1977) for more details.


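Consider independent random variables $X_1, \dots, X_n$ such that $X_i \sim \mathrm{Exp}(s_i)$, i.e., $X_i$ is exponentially distributed with rate $s_i$.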

We may prove by induction a generalization of the memoryless property:

If we assume as inductive hypothesis that , we complete the induction as:

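It follows from a familiar property of the argmin of exponential distributions, namely $q(X_1 \le \min_{k \ge 2} X_k) = s_1 / \sum_{k=1}^{n} s_k$, together with the memoryless property above, that:

$$q(X_1 \le X_2 \le \cdots \le X_n) = \frac{s_1}{\sum_{k=1}^{n} s_k}\, q(X_2 \le \cdots \le X_n),$$

and by another induction, we have $q(X_1 \le X_2 \le \cdots \le X_n) = \prod_{i=1}^{n} \frac{s_i}{\sum_{k=i}^{n} s_k}$.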

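Finally, following the argument of Balog et al. (2017), we apply the strictly decreasing function $f(x) = -\log x$ to this identity. Since $-\log X_i$ has the same distribution as $\log s_i + g_i$ with $g_i \sim \mathrm{Gumbel}(0, 1)$, by the definition of the Gumbel distribution, this implies:

$$q(\log s_1 + g_1 \ge \log s_2 + g_2 \ge \cdots \ge \log s_n + g_n) = \prod_{i=1}^{n} \frac{s_i}{\sum_{k=i}^{n} s_k}.$$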
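The identity can be checked by simulation. The pure-Python sketch below, with an illustrative score vector, draws permutations by sorting Gumbel-perturbed log-scores in decreasing order, as the PyTorch sampler in Appendix A does before relaxing the sort, and compares their empirical frequencies with the Plackett-Luce probabilities $\prod_i s_{z_i} / \sum_{k \ge i} s_{z_k}$. The function names are illustrative, and the Monte Carlo frequencies agree only up to sampling noise (a standard error of roughly 0.0015 at this sample size).

```python
import math
import random
from collections import Counter
from itertools import permutations


def plackett_luce_prob(s, z):
    """q(z | s) = prod_i s[z_i] / (sum of s over the items not yet placed)."""
    p, remaining = 1.0, sum(s)
    for idx in z:
        p *= s[idx] / remaining
        remaining -= s[idx]
    return p


def gumbel(rng):
    """Standard Gumbel(0, 1) sample by inverse CDF, keeping U strictly inside (0, 1)."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def sample_by_gumbel_perturbation(s, rng):
    """Sort the perturbed log-scores log(s_i) + g_i in decreasing order."""
    keys = [math.log(s_i) + gumbel(rng) for s_i in s]
    return tuple(sorted(range(len(s)), key=lambda i: -keys[i]))


rng = random.Random(0)
s = [1.0, 2.0, 3.0]
N = 100_000
counts = Counter(sample_by_gumbel_perturbation(s, rng) for _ in range(N))
for z in permutations(range(len(s))):
    print(z, round(counts[z] / N, 3), round(plackett_luce_prob(s, z), 3))
```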

Appendix C Arg Max semantics for Tied Max Elements

When applying the $\arg\max$ operator to a vector with duplicate entries attaining the $\max$ value, we need to define the semantics of $\arg\max$ to handle ties in the context of the proposed relaxation.

Definition 6.

For any vector with ties, let $\arg\max\mathrm{set}(\cdot)$ denote the operator that returns the set of all indices containing the $\max$ element. We define the $\arg\max$ of the $i$-th row in a matrix $M$ recursively:

  1. If there exists an index $j$ that is a member of $\arg\max\mathrm{set}(M[i, :])$ and has not been assigned as an $\arg\max$ of any row $k < i$, then the $\arg\max$ is the smallest such index.

  2. Otherwise, the $\arg\max$ is the smallest index that is a member of $\arg\max\mathrm{set}(M[i, :])$.

This function is efficiently computable with additional bookkeeping.
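The recursive tie-breaking rule in Definition 6 needs only a set of already-assigned indices as bookkeeping. A minimal pure-Python sketch follows; the function names and the optional tolerance `tol` for treating near-equal floating-point values as ties are assumptions made for illustration, while the definition itself corresponds to exact equality (`tol = 0.0`).

```python
def argmax_set(row, tol=0.0):
    """All indices whose value is within tol of the row maximum (tol is an illustrative knob)."""
    m = max(row)
    return [j for j, v in enumerate(row) if m - v <= tol]


def tie_aware_argmax(M):
    """Row-wise arg max following Definition 6.

    Row i takes the smallest index in its argmax set that no earlier row has been
    assigned; if every index in the set is already taken, it takes the smallest one.
    """
    assigned = set()  # bookkeeping: indices already used by earlier rows
    result = []
    for row in M:
        candidates = argmax_set(row)  # already in increasing index order
        free = [j for j in candidates if j not in assigned]
        choice = free[0] if free else candidates[0]
        assigned.add(choice)
        result.append(choice)
    return result


third = 1.0 / 3.0
print(tie_aware_argmax([[third] * 3] * 3))  # [0, 1, 2]: ties resolved into a permutation
print(tie_aware_argmax([[0.6, 0.2, 0.2], [0.7, 0.2, 0.1], [0.2, 0.2, 0.6]]))  # [0, 0, 2]
```

The first example shows the intended effect: rows that are exact ties (as happens for tied inputs) are still mapped to distinct indices whenever possible.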

Lemma 7.

For an input vector with the sort permutation matrix given as , we have = if and only if there exists a row such that for all .


From Eq. 5, we have the -th row of given as:

. Therefore, we have the equations:

for some fixed normalization constant . As the function is invertible, both directions of the lemma follow immediately. ∎

Lemma 8.