1 Introduction
Metric learning is an important family of machine learning algorithms and has achieved success on several problems, including computer vision [kulis2009fast, guillaumin2009you, hua2007discriminant], text analysis [lebanon2006metric], meta learning [vinyals2016matching, snell2017prototypical] and others [slaney2008learning, xiong2006kernel, ye2016what]. Given a set of training samples, metric learning aims to learn a good distance measurement such that items in the same class are closer to each other in the learned metric space, which is crucial for classification and similarity search. Since this objective is directly related to the assumption of nearest neighbor classifiers, most metric learning algorithms can be naturally and successfully combined with Nearest Neighbor (NN) classifiers.
Adversarial robustness of machine learning algorithms has been studied extensively in recent years due to the need for robustness guarantees in real-world systems. It has been demonstrated that neural networks can be easily attacked by adversarial perturbations in the input space [szegedy2014intriguing, goodfellow2015explaining, biggio2018wild], and such perturbations can be computed efficiently in both white-box [carlini2017towards, madry2018towards] and black-box settings [chen2017zoo, ilyas2019prior, cheng2020signopt]. Therefore, many defense algorithms have been proposed to improve the robustness of neural networks [kurakin2017adversarial, madry2018towards]. Although these algorithms can successfully defend against standard attacks, it has been shown that many of them are vulnerable to stronger attacks when the attacker knows the defense mechanisms [carlini2017towards]. Therefore, recent research in adversarial defense of neural networks has shifted to the concept of “certified defense”, where the defender needs to provide a certification that no adversarial examples exist within a certain input region [wong2018provable, cohen2019certified, zhang2020towards].
In this paper, we consider the problem of learning a metric that is robust against adversarial input perturbations. It has been shown that nearest neighbor classifiers are not robust [papernot2016transferability, wang2019evaluating, sitawarin2020minimum]: a small and human-imperceptible perturbation in the input space can easily fool an NN classifier, so it is natural to investigate how to obtain a metric that improves the adversarial robustness. Despite being an important and interesting research question, to the best of our knowledge this problem has not been studied in the literature. Several caveats make this a hard problem: 1) Attack and defense algorithms for neural networks often rely on the smoothness of the function, while NN is a discrete step function whose gradient does not exist. 2) Even evaluating the robustness of NN (with the Euclidean distance) is harder than for neural networks. For instance, [wang2019evaluating] showed that attacks on NN are time-consuming and nontrivial. Furthermore, none of the existing work has considered general Mahalanobis distances. 3) Existing algorithms for evaluating the robustness of NN, including attack [yang2020robust] and verification [wang2019evaluating], are often non-differentiable, while training a robust metric requires a differentiable measurement of robustness.
To develop a provably robust metric learning algorithm, we formulate an objective function to learn a Mahalanobis distance, parameterized by a positive semidefinite matrix $M$, that maximizes the minimal adversarial perturbation on each sample. However, computing the minimal adversarial perturbation is intractable for NN, so to make the problem solvable, we propose an efficient formulation for lower-bounding the minimal adversarial perturbation, and this lower bound can be represented as an explicit function of $M$ to enable the gradient computation. We further develop several tricks to improve the efficiency of the overall procedure. Similar to certified defense algorithms for neural networks, the proposed algorithm can provide a certified robustness improvement on the resulting NN model with the learned metric. Decision boundaries of 1-NN with different Mahalanobis distances for a toy dataset are visualized in Figure 1. It can be observed that the proposed Adversarial Robust Metric Learning (ARML) method obtains a more robust metric on this example.
We conduct extensive experiments on six real world datasets and show that the proposed algorithm can improve both certified robust errors and the empirical robust errors (errors under adversarial attacks) over existing metric learning algorithms.
2 Background
Metric learning for nearest neighbor classifiers
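A nearest-neighbor classifier based on a Mahalanobis distance can be characterized by a training dataset and a positive semidefinite matrix. Let $\mathcal{X} = \mathbb{R}^D$ be the instance space and $\mathcal{Y} = [C]$ the label space, where $C$ is the number of classes. $S = \{(x_i, y_i)\}_{i=1}^N$ is the training set with $(x_i, y_i) \in \mathcal{X} \times \mathcal{Y}$ for every $i \in [N]$. $M \in \mathbb{R}^{D \times D}$ is a positive semidefinite matrix. The Mahalanobis distance for any $x, x' \in \mathcal{X}$ is defined as
$$d_M(x, x') = (x - x')^\top M (x - x'), \tag{1}$$
and a Mahalanobis K-NN classifier $f : \mathcal{X} \to \mathcal{Y}$ finds the $K$ nearest neighbors of the test instance in $S$ based on the Mahalanobis distance, and then predicts the label by majority voting among these neighbors.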
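The classifier described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation: it assumes the standard squared form $(x - x')^\top M (x - x')$ for the Mahalanobis distance, and the function names and toy data are made up for the example.

```python
from collections import Counter

def mahalanobis_sq(x, z, M):
    """Squared Mahalanobis distance (x - z)^T M (x - z) for a PSD matrix M."""
    d = [a - b for a, b in zip(x, z)]
    Md = [sum(m * c for m, c in zip(row, d)) for row in M]
    return sum(a * b for a, b in zip(d, Md))

def knn_predict(x, train_x, train_y, M, k=1):
    """Majority vote among the k training points closest to x under d_M."""
    order = sorted(range(len(train_x)),
                   key=lambda i: mahalanobis_sq(x, train_x[i], M))
    votes = Counter(train_y[i] for i in order[:k])
    return votes.most_common(1)[0][0]

# Toy data (hypothetical): two points per class in the plane.
train_x = [[0.0, 0.0], [0.0, 1.0], [3.0, 0.0], [3.0, 1.0]]
train_y = [0, 0, 1, 1]
identity = [[1.0, 0.0], [0.0, 1.0]]      # recovers the Euclidean case
stretch_x = [[4.0, 0.0], [0.0, 1.0]]     # weights the first coordinate more

label = knn_predict([1.0, 0.5], train_x, train_y, identity, k=3)
```

Changing `M` changes which training points count as "near", which is exactly the degree of freedom that metric learning optimizes.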
Many metric learning approaches aim to learn a good Mahalanobis distance based on training data, including [goldberger2004neighbourhood, davis2007information, weinberger2009distance, jain2010inductive, sugiyama2007dimensionality] (see more discussions in Section 5). However, none of these previous methods are trying to find a metric that is robust to small input perturbations.
Adversarial robustness and minimal adversarial perturbation
There are two important concepts in adversarial robustness: adversarial attack and adversarial verification (or robustness verification). An adversarial attack aims to find a perturbation that changes the prediction, and adversarial verification aims to find a radius within which no perturbation can change the prediction. Both can be reduced to the problem of finding the minimal adversarial perturbation. For a classifier $f$ on an instance $(x, y)$, the minimal adversarial perturbation can be defined as
$$\arg\min_{\delta} \|\delta\| \quad \text{s.t.} \quad f(x + \delta) \neq y, \tag{2}$$
which is the smallest perturbation that could lead to misclassification. Note that if $(x, y)$ is not correctly classified, the minimal adversarial perturbation is $0$, i.e., the zero vector. Let $\delta^*(x, y)$ denote the optimal solution and $\epsilon^*(x, y) = \|\delta^*(x, y)\|$ the optimal value. Obviously, $\delta^*(x, y)$ is also the solution of the optimal adversarial attack, and $\epsilon^*(x, y)$ is the solution of the optimal adversarial verification. For neural networks, it is often NP-complete to compute (2), so many efficient algorithms have been proposed for attack [goodfellow2015explaining, carlini2017towards, brendel2018decision, cheng2020signopt] and verification [wong2018provable, weng2018towards, mirman2018differentiable], corresponding to computing upper and lower bounds of (2) respectively. However, these methods do not work for discrete models such as nearest neighbor classifiers.
In this paper, our algorithm is based on a novel derivation of a lower bound of the minimal adversarial perturbation for Mahalanobis K-NN classifiers. To the best of our knowledge, there has been no previous work tackling this problem. Since the Mahalanobis K-NN classifier is parameterized by a positive semidefinite matrix $M$ and the training set $S$, we further write $\delta^*_S(x, y; M)$ and $\epsilon^*_S(x, y; M)$ to explicitly indicate their dependence on $M$ and $S$. In this paper we consider the $\ell_2$ norm in (2) for simplicity.
Certified and empirical robust error
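Let $\underline{\epsilon}(x, y)$ be a lower bound of the norm $\epsilon^*(x, y)$ of the minimal adversarial perturbation, possibly computed by a robustness verification algorithm. For a distribution $\mathcal{D}$ over $\mathcal{X} \times \mathcal{Y}$, the certified robust error with respect to the radius $\epsilon \ge 0$ is defined as the probability that $\underline{\epsilon}(x, y)$ is not greater than $\epsilon$, namely
$$\mathbb{E}_{(x, y) \sim \mathcal{D}} \big[ \mathbb{1}\{ \underline{\epsilon}(x, y) \le \epsilon \} \big]. \tag{3}$$
Note that in the case $\underline{\epsilon}(x, y) = \epsilon^*(x, y)$ and $\epsilon = 0$, the certified robust error reduces to the clean error. In this paper we investigate how to compute it for Mahalanobis K-NN classifiers.
On the other hand, adversarial attack algorithms try to find a feasible solution of (2), which gives an upper bound $\overline{\epsilon}(x, y) \ge \epsilon^*(x, y)$. Based on the upper bound, we can measure the empirical robust error of a model by
$$\mathbb{E}_{(x, y) \sim \mathcal{D}} \big[ \mathbb{1}\{ \overline{\epsilon}(x, y) \le \epsilon \} \big]. \tag{4}$$
Since $\overline{\epsilon}(x, y)$ is computed by an attack method, the empirical robust error is also called the attack error. A family of decision-based attack methods, which view the victim model as a black box, can be used to attack Mahalanobis K-NN classifiers [brendel2018decision, cheng2019query, cheng2020signopt].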
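Both error notions above are the same counting rule applied to different per-instance bounds. The following sketch illustrates this on a finite sample; the function name and the bound values are hypothetical and chosen only for illustration.

```python
def robust_error(bounds, radius):
    """Fraction of instances whose perturbation bound is <= radius.

    With verified lower bounds this gives the certified robust error;
    with attack-derived upper bounds it gives the empirical robust error.
    """
    return sum(1 for b in bounds if b <= radius) / len(bounds)

# Hypothetical per-instance bounds; the first instance is misclassified,
# so both of its bounds are 0.
lower_bounds = [0.0, 0.3, 0.8, 1.2]   # e.g. from a verification method
upper_bounds = [0.0, 0.5, 1.0, 1.5]   # e.g. from an attack method

clean_error = robust_error(lower_bounds, 0.0)
certified_at_04 = robust_error(lower_bounds, 0.4)
empirical_at_04 = robust_error(upper_bounds, 0.4)
```

Because every upper bound is at least as large as the matching lower bound, the empirical robust error can never exceed the certified robust error at the same radius.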
3 Adversarially robust metric learning
The objective of adversarially robust metric learning (ARML) is to learn the matrix $M$ from the training data such that the resulting Mahalanobis K-NN classifier has small certified and empirical robust errors.
3.1 Basic formulation
The goal is to learn a positive semidefinite matrix $M$ that minimizes the certified robust training error. Since the certified robust error defined in (3) is non-smooth, we replace the indicator function by a loss function. The resulting objective can be formulated as
$$\min_{G} \; \frac{1}{N} \sum_{i=1}^{N} \ell\Big( \epsilon^*_{S \setminus \{(x_i, y_i)\}}(x_i, y_i; M) \Big) \quad \text{s.t.} \quad M = G^\top G, \tag{5}$$
where $\ell$ is a monotonically non-increasing function, e.g., the hinge loss $[1 - \epsilon]_+$, the exponential loss $\exp(-\epsilon)$, the logistic loss $\log(1 + \exp(-\epsilon))$, or the “negative” loss $-\epsilon$. We employ $M = G^\top G$ to enforce $M$ to be positive semidefinite, and it is possible to derive a low-rank $M$ by constraining the shape of $G$. Note that the minimal adversarial perturbation is defined on the training set excluding $(x_i, y_i)$ itself, since otherwise a 1-nearest neighbor classifier with any distance measurement would have 100% accuracy. In this way, we minimize the “leave-one-out” certified robust error. The remaining problem is how to compute, exactly or approximately, the value $\epsilon^*$ in our training objective.
3.2 Bounding minimal adversarial perturbation for Mahalanobis K-NN
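For convenience, suppose $K$ is an odd number and denote $k = (K + 1)/2$. In the binary classification case for simplicity, i.e., $C = 2$, the computation of $\epsilon^*_S(x_{\text{test}}, y_{\text{test}}; M)$ for Mahalanobis K-NN can be formulated as
$$\min_{\substack{J \subseteq \{j : y_j \neq y_{\text{test}}\},\ |J| = k \\ I \subseteq \{i : y_i = y_{\text{test}}\},\ |I| = k - 1}} \; \min_{\delta} \|\delta\| \quad \text{s.t.} \quad d_M(x_{\text{test}} + \delta, x_j) \le d_M(x_{\text{test}} + \delta, x_i), \;\; \forall j \in J, \; \forall i \in \{i : y_i = y_{\text{test}}\} \setminus I. \tag{6}$$
This minimization formulation enumerates all the size-$K$ nearest neighbor sets containing at most $k - 1$ instances in the same class as the test instance, computes the minimum perturbation resulting in each nearest neighbor set, and takes the minimum of them.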
Obviously, solving (6) exactly has time complexity growing exponentially with , and furthermore, a numerical solution cannot be incorporated into the training objective (5) since we need to write as a function of for backpropagation. To address these issues, we resort to a lower bound of the optimal value of (6) rather than solving it exactly.
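First, we consider a simple triplet problem: given vectors $x^+, x^-, x \in \mathbb{R}^D$ and a positive semidefinite matrix $M$, find the minimum perturbation $\delta$ on $x$ such that the Mahalanobis distances satisfy $d_M(x + \delta, x^-) \le d_M(x + \delta, x^+)$. It can be formulated as the following optimization problem:
$$\min_{\delta} \|\delta\| \quad \text{s.t.} \quad d_M(x + \delta, x^-) \le d_M(x + \delta, x^+). \tag{7}$$
Note that the constraint in (7) can be written in linear form, so this is a convex quadratic programming problem with a linear constraint. We show that the optimal value of (7) can be expressed in closed form:
$$\frac{\big[ d_M(x, x^-) - d_M(x, x^+) \big]_+}{2 \sqrt{(x^+ - x^-)^\top M^\top M (x^+ - x^-)}}, \tag{8}$$
where $[\cdot]_+$ denotes $\max(\cdot, 0)$. The derivation of the optimal value is deferred to Appendix A. Note that if $M$ is the identity matrix and $d_M(x, x^+) < d_M(x, x^-)$ strictly holds, then the optimal value is the Euclidean distance from $x$ to the bisector between $x^+$ and $x^-$.
For convenience, we define the function $\tilde{\epsilon}$ as
$$\tilde{\epsilon}(x^+, x^-, x; M) = \frac{d_M(x, x^-) - d_M(x, x^+)}{2 \sqrt{(x^+ - x^-)^\top M^\top M (x^+ - x^-)}}. \tag{9}$$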
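The closed form for the triplet problem can be checked numerically. The sketch below is an illustration only: it assumes the expression $(d_M(x, x^-) - d_M(x, x^+)) / (2 \|M(x^+ - x^-)\|)$, clipped at zero for the optimal value, and the function names are made up for this example.

```python
import math

def mat_vec(M, v):
    return [sum(m * c for m, c in zip(row, v)) for row in M]

def dist_sq(x, z, M):
    """(x - z)^T M (x - z)."""
    d = [a - b for a, b in zip(x, z)]
    return sum(a * b for a, b in zip(d, mat_vec(M, d)))

def triplet_margin(x_pos, x_neg, x, M):
    """Signed value (d_M(x, x_neg) - d_M(x, x_pos)) / (2 * ||M (x_pos - x_neg)||).

    Assumes M (x_pos - x_neg) is nonzero; otherwise the division fails.
    """
    diff = [p - n for p, n in zip(x_pos, x_neg)]
    denom = 2.0 * math.sqrt(sum(c * c for c in mat_vec(M, diff)))
    return (dist_sq(x, x_neg, M) - dist_sq(x, x_pos, M)) / denom

def triplet_min_perturbation(x_pos, x_neg, x, M):
    """Norm of the smallest perturbation making x at least as close to x_neg."""
    return max(triplet_margin(x_pos, x_neg, x, M), 0.0)

identity = [[1.0, 0.0], [0.0, 1.0]]
# x = (1, 0) lies between x_pos = (0, 0) and x_neg = (4, 0); the bisector is
# the vertical line through (2, 0), at Euclidean distance 1 from x.
eps = triplet_min_perturbation([0.0, 0.0], [4.0, 0.0], [1.0, 0.0], identity)
```

With the identity matrix the result matches the geometric picture of moving `x` to the bisector; a point already closer to `x_neg` gets a negative margin and hence a minimum perturbation of zero.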
Then we could relax (6) further and have the following theorem:
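Theorem 1 (Robustness verification for Mahalanobis K-NN).
Given a Mahalanobis K-NN classifier parameterized by a neighbor parameter $K$, a training dataset $S$ and a positive semidefinite matrix $M$, for any instance $(x_{\text{test}}, y_{\text{test}})$ we have
$$\epsilon^*_S(x_{\text{test}}, y_{\text{test}}; M) \ge \operatorname*{kth\,min}_{j : y_j \neq y_{\text{test}}} \; \operatorname*{kth\,max}_{i : y_i = y_{\text{test}}} \; \tilde{\epsilon}(x_i, x_j, x_{\text{test}}; M), \tag{10}$$
where $\operatorname{kth\,max}$ and $\operatorname{kth\,min}$ select the $k$-th maximum and $k$-th minimum respectively with $k = (K + 1)/2$.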
The proof is deferred to Appendix B. In this way, we only need to compute $\tilde{\epsilon}(x_i, x_j, x_{\text{test}}; M)$ for each $i$ and $j$ in order to derive a lower bound of the minimal adversarial perturbation of Mahalanobis K-NN. This leads to an efficient algorithm to verify the robustness of Mahalanobis K-NN. The time complexity is $O(N^2)$ and independent of $K$.
In the general multiclass case, the constraint of (6) is a necessary condition for successful attacks, rather than a necessary and sufficient condition. As a result, the optimal value of (6) is a lower bound of the minimal adversarial perturbation. Therefore, Theorem 1 also holds for the multiclass case. Based on this lower bound of $\epsilon^*$, we derive the proposed ARML algorithm.
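The verification rule of Theorem 1 can be sketched as a nested selection over pairs of training points. This is an illustrative reading, assuming the bound is the $k$-th minimum over other-class points of the $k$-th maximum over same-class points of the signed triplet value, with $k = (K + 1)/2$; all names are hypothetical, and a non-positive result means that no positive radius is certified.

```python
import math

def mat_vec(M, v):
    return [sum(m * c for m, c in zip(row, v)) for row in M]

def dist_sq(x, z, M):
    d = [a - b for a, b in zip(x, z)]
    return sum(a * b for a, b in zip(d, mat_vec(M, d)))

def triplet_margin(x_pos, x_neg, x, M):
    diff = [p - n for p, n in zip(x_pos, x_neg)]
    denom = 2.0 * math.sqrt(sum(c * c for c in mat_vec(M, diff)))
    return (dist_sq(x, x_neg, M) - dist_sq(x, x_pos, M)) / denom

def certified_radius(x, y, train_x, train_y, M, K):
    """Lower bound on the minimal adversarial perturbation of a K-NN (K odd)."""
    k = (K + 1) // 2
    same = [p for p, label in zip(train_x, train_y) if label == y]
    other = [p for p, label in zip(train_x, train_y) if label != y]
    if len(same) < k or len(other) < k:
        raise ValueError("need at least k same-class and k other-class points")
    per_negative = []
    for x_neg in other:
        margins = sorted((triplet_margin(x_pos, x_neg, x, M) for x_pos in same),
                         reverse=True)
        per_negative.append(margins[k - 1])   # k-th maximum over same-class points
    return sorted(per_negative)[k - 1]        # k-th minimum over other-class points

# Toy data (hypothetical): three points of class 0 and two of class 1.
xs = [[0.0, 0.0], [0.0, 2.0], [0.0, -2.0], [4.0, 0.0], [6.0, 0.0]]
ys = [0, 0, 0, 1, 1]
identity = [[1.0, 0.0], [0.0, 1.0]]
radius = certified_radius([1.0, 0.0], 0, xs, ys, identity, K=3)
```

Sorting is used here for brevity; a linear-time selection routine would give the same result without the extra logarithmic factor.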
3.3 Training algorithm of adversarially robust metric learning
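By replacing the $\epsilon^*$ in (5) with the lower bound derived in Theorem 1, we get a trainable objective function for adversarially robust metric learning:
$$\min_{G} \; \frac{1}{N} \sum_{i=1}^{N} \ell\Big( \operatorname*{kth\,min}_{j : y_j \neq y_i} \; \operatorname*{kth\,max}_{l : l \neq i,\, y_l = y_i} \; \tilde{\epsilon}(x_l, x_j, x_i; M) \Big) \quad \text{s.t.} \quad M = G^\top G. \tag{11}$$
Although (11) is trainable since $\tilde{\epsilon}$ is a function of $M$, for large datasets it is time-consuming to run the inner min-max procedure. Furthermore, since we care about the generalization performance of the learned metric instead of the robust training error, it is unnecessary to compute the exact solution. Therefore, instead of computing the $\operatorname{kth\,max}$ and $\operatorname{kth\,min}$ exactly, we propose to sample positive and negative instances from the neighborhood, which leads to the following formulation:
$$\min_{G} \; \frac{1}{N} \sum_{i=1}^{N} \ell\Big( \tilde{\epsilon}\big( \mathrm{randnear}^+_M(x_i), \mathrm{randnear}^-_M(x_i), x_i; M \big) \Big) \quad \text{s.t.} \quad M = G^\top G, \tag{12}$$
where $\mathrm{randnear}^+_M(\cdot)$ denotes a sampling procedure for an instance in the same class within $x_i$’s neighborhood, and $\mathrm{randnear}^-_M(\cdot)$ denotes a sampling procedure for an instance in a different class, also within $x_i$’s neighborhood, and the distances are measured by the Mahalanobis distance $d_M$. In our implementation, we sample instances from a fixed number of nearest instances. As a result, the optimization formulation (12) approximately minimizes the certified robust error.
Our adversarially robust metric learning (ARML) algorithm is shown in Algorithm 1. At every iteration, $G$ is updated with the gradient, while the calculations of $\mathrm{randnear}^+_M(\cdot)$ and $\mathrm{randnear}^-_M(\cdot)$ do not contribute to the gradient for the sake of efficient and stable computation.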
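To make the sampled objective concrete, the sketch below evaluates it once for a given matrix `G` with the negative loss. It is not Algorithm 1: it omits the gradient step (which in practice would be delegated to an automatic differentiation library), it assumes each class has at least two training points, and all function and parameter names are hypothetical.

```python
import math
import random

def mat_vec(M, v):
    return [sum(m * c for m, c in zip(row, v)) for row in M]

def dist_sq(x, z, M):
    d = [a - b for a, b in zip(x, z)]
    return sum(a * b for a, b in zip(d, mat_vec(M, d)))

def triplet_margin(x_pos, x_neg, x, M):
    diff = [p - n for p, n in zip(x_pos, x_neg)]
    denom = 2.0 * math.sqrt(sum(c * c for c in mat_vec(M, diff)))
    return (dist_sq(x, x_neg, M) - dist_sq(x, x_pos, M)) / denom

def gram(G):
    """M = G^T G, positive semidefinite by construction."""
    cols = list(zip(*G))
    return [[sum(a * b for a, b in zip(ci, cj)) for cj in cols] for ci in cols]

def sampled_objective(G, xs, ys, n_near=10, rng=random):
    """Average negative loss over one positive/negative sample per anchor."""
    M = gram(G)
    total = 0.0
    for i, (x, y) in enumerate(zip(xs, ys)):
        # Leave-one-out: the anchor itself is never a same-class candidate.
        same = [p for j, (p, l) in enumerate(zip(xs, ys)) if l == y and j != i]
        other = [p for p, l in zip(xs, ys) if l != y]
        same.sort(key=lambda p: dist_sq(x, p, M))
        other.sort(key=lambda p: dist_sq(x, p, M))
        x_pos = rng.choice(same[:n_near])    # sampled near same-class point
        x_neg = rng.choice(other[:n_near])   # sampled near other-class point
        total += -triplet_margin(x_pos, x_neg, x, M)   # negative loss
    return total / len(xs)

# Toy data (hypothetical); n_near=1 makes the sampling deterministic.
xs = [[0.0, 0.0], [1.0, 0.0], [4.0, 0.0], [5.0, 0.0]]
ys = [0, 0, 1, 1]
value = sampled_objective([[1.0, 0.0], [0.0, 1.0]], xs, ys, n_near=1)
```

One property visible in this sketch is that rescaling `G` by a constant leaves the objective unchanged, because the numerator and denominator of the triplet value scale by the same factor.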
3.4 Exact minimal adversarial perturbation of Mahalanobis 1-NN
In the special Mahalanobis 1-NN case, we show a method to compute the exact minimal adversarial perturbation in a formulation similar to (6). However, this algorithm can only compute a numerical value of the minimal adversarial perturbation $\delta^*$, so it cannot be used at training time. We use this method to evaluate the robust error for the Mahalanobis 1-NN case in the experiments.
Computing the minimal adversarial perturbation for a Mahalanobis 1-NN classifier can be formulated as the following optimization problem:
$$\min_{j : y_j \neq y_{\text{test}}} \; \min_{\delta_j} \|\delta_j\| \quad \text{s.t.} \quad d_M(x_{\text{test}} + \delta_j, x_j) \le d_M(x_{\text{test}} + \delta_j, x_i), \;\; \forall i : y_i = y_{\text{test}}. \tag{13}$$
This is equivalent to considering each $x_j$ in a different class from $x_{\text{test}}$ and computing the minimum perturbation needed to make $x_{\text{test}}$ closer to $x_j$ than to all the training instances $x_i$ in the same class as $x_{\text{test}}$. It is noteworthy that the constraint of (13) can be equivalently written as
$$2 (x_i - x_j)^\top M \delta_j \le d_M(x_{\text{test}}, x_i) - d_M(x_{\text{test}}, x_j), \;\; \forall i : y_i = y_{\text{test}}, \tag{14}$$
whose two sides are affine in $\delta_j$. Therefore, the inner minimization is a convex quadratic programming problem and can be solved in polynomial time [kozlov1980polynomial]. As a result, it leads to a naive polynomial-time algorithm for finding the minimal adversarial perturbation of Mahalanobis 1-NN: solve all the inner quadratic programming problems and then select the minimum of them.
Instead, we propose a much more efficient method to solve (13). The main idea is to first compute a lower bound for each inner minimization problem; with these lower bounds, we can safely screen out most of the inner minimization problems without solving them exactly. This method is an extension of [wang2019evaluating], which only takes the Euclidean distance into consideration. See Algorithm 2 in Appendix C for details; this algorithm is used for computing certified robust errors of Mahalanobis 1-NN in the experimental section.
4 Experiments
We compare the proposed ARML (Adversarial Robust Metric Learning) method with the following baselines:

Euclidean: uses the Euclidean distance directly without learning any metric.

Neighbourhood components analysis (NCA) [goldberger2004neighbourhood]: maximizes a stochastic variant of the leave-one-out nearest neighbors score on the training set.

Large margin nearest neighbor (LMNN) [weinberger2009distance]: keeps nearest neighbors from the same class close, while keeping instances from different classes separated by a large margin.

Information Theoretic Metric Learning (ITML) [davis2007information]: minimizes the log-determinant divergence with similarity and dissimilarity constraints.

Local Fisher Discriminant Analysis (LFDA) [sugiyama2007dimensionality]: a modified version of linear discriminant analysis that rewrites scatter matrices in a pairwise manner.
For evaluation, we use six public datasets on which metric learning methods perform favorably in terms of clean errors, including four small or medium-sized datasets [chang2011libsvm]: Splice, Pendigits, Satimage and USPS, and two image datasets, MNIST [lecun1998gradient] and FashionMNIST [xiao2017fashion], which are widely used for robustness verification of neural networks. For the proposed method, we use the same hyperparameters for all the datasets (see Appendix D for the dataset statistics, more details of the experimental setting, and hyperparameter sensitivity analysis).
Table 1: Certified robust errors of Mahalanobis 1-NN at different perturbation radii.
MNIST  radius  0.000  0.500  1.000  1.500  2.000  2.500 

Euclidean  0.033  0.112  0.274  0.521  0.788  0.945  
NCA  0.025  0.140  0.452  0.839  0.977  1.000  
LMNN  0.032  0.641  0.999  1.000  1.000  1.000  
ITML  0.073  0.571  0.928  1.000  1.000  1.000  
LFDA  0.152  1.000  1.000  1.000  1.000  1.000  
ARML (Ours)  0.024  0.089  0.222  0.455  0.757  0.924  
FashionMNIST  radius  0.000  0.500  1.000  1.500  2.000  2.500 
Euclidean  0.145  0.381  0.606  0.790  0.879  0.943  
NCA  0.116  0.538  0.834  0.950  0.998  1.000  
LMNN  0.142  0.756  0.991  1.000  1.000  1.000  
ITML  0.163  0.672  0.929  0.998  1.000  1.000  
LFDA  0.211  1.000  1.000  1.000  1.000  1.000  
ARML (Ours)  0.127  0.348  0.568  0.763  0.859  0.928  
Splice  radius  0.000  0.100  0.200  0.300  0.400  0.500 
Euclidean  0.320  0.513  0.677  0.800  0.854  0.880  
NCA  0.130  0.252  0.404  0.584  0.733  0.836  
LMNN  0.190  0.345  0.533  0.697  0.814  0.874  
ITML  0.306  0.488  0.679  0.809  0.862  0.882  
LFDA  0.264  0.434  0.605  0.760  0.845  0.872  
ARML (Ours)  0.130  0.233  0.370  0.526  0.652  0.758  
Pendigits  radius  0.000  0.100  0.200  0.300  0.400  0.500 
Euclidean  0.032  0.119  0.347  0.606  0.829  0.969  
NCA  0.034  0.202  0.586  0.911  0.997  1.000  
LMNN  0.029  0.183  0.570  0.912  0.995  0.999  
ITML  0.049  0.308  0.794  0.991  1.000  1.000  
LFDA  0.042  0.236  0.603  0.912  0.998  1.000  
ARML (Ours)  0.028  0.115  0.344  0.598  0.823  0.967  
Satimage  radius  0.000  0.150  0.300  0.450  0.600  0.750 
Euclidean  0.108  0.642  0.864  0.905  0.928  0.951  
NCA  0.103  0.710  0.885  0.915  0.940  0.963  
LMNN  0.092  0.665  0.871  0.912  0.944  0.969  
ITML  0.127  0.807  0.979  1.000  1.000  1.000  
LFDA  0.125  0.836  0.919  0.956  0.992  1.000  
ARML (Ours)  0.095  0.605  0.839  0.899  0.920  0.946  
USPS  radius  0.000  0.500  1.000  1.500  2.000  2.500 
Euclidean  0.045  0.224  0.585  0.864  0.970  0.999  
NCA  0.056  0.384  0.888  0.987  1.000  1.000  
LMNN  0.046  0.825  1.000  1.000  1.000  1.000  
ITML  0.060  0.720  0.999  1.000  1.000  1.000  
LFDA  0.098  1.000  1.000  1.000  1.000  1.000  
ARML (Ours)  0.043  0.204  0.565  0.857  0.970  0.999 
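Table 2: Certified robust errors (first six value columns) and empirical robust errors (last six value columns) of Mahalanobis K-NN at different perturbation radii.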

MNIST  radius  0.000  0.500  1.000  1.500  2.000  2.500  0.000  0.500  1.000  1.500  2.000  2.500 
Euclidean  0.038  0.134  0.360  0.618  0.814  0.975  0.031  0.063  0.104  0.155  0.204  0.262  
NCA  0.030  0.175  0.528  0.870  0.986  1.000  0.027  0.063  0.120  0.216  0.330  0.535  
LMNN  0.040  0.669  1.000  1.000  1.000  1.000  0.036  0.121  0.336  0.775  0.972  1.000  
ITML  0.106  0.731  0.943  1.000  1.000  1.000  0.084  0.218  0.355  0.510  0.669  0.844  
LFDA  0.237  1.000  1.000  1.000  1.000  1.000  0.215  1.000  1.000  1.000  1.000  1.000  
ARML (Ours)  0.034  0.101  0.276  0.537  0.760  0.951  0.032  0.055  0.077  0.109  0.160  0.213  
FashionMNIST  radius  0.000  0.500  1.000  1.500  2.000  2.500  0.000  0.500  1.000  1.500  2.000  2.500 
Euclidean  0.160  0.420  0.650  0.800  0.895  0.946  0.143  0.227  0.298  0.360  0.420  0.489  
NCA  0.144  0.557  0.832  0.946  1.000  1.000  0.121  0.232  0.343  0.483  0.624  0.780  
LMNN  0.158  0.792  0.991  1.000  1.000  1.000  0.140  0.364  0.572  0.846  0.983  0.999  
ITML  0.236  0.784  0.949  1.000  1.000  1.000  0.209  0.460  0.692  0.892  0.978  1.000  
LFDA  0.291  1.000  1.000  1.000  1.000  1.000  0.263  0.870  0.951  0.975  0.988  0.995  
ARML (Ours)  0.152  0.371  0.589  0.755  0.856  0.924  0.134  0.202  0.274  0.344  0.403  0.487  
Splice  radius  0.000  0.100  0.200  0.300  0.400  0.500  0.000  0.100  0.200  0.300  0.400  0.500 
Euclidean  0.333  0.558  0.826  0.965  0.988  0.996  0.306  0.431  0.526  0.608  0.676  0.743  
NCA  0.103  0.209  0.415  0.659  0.824  0.921  0.103  0.173  0.274  0.414  0.570  0.684  
LMNN  0.149  0.332  0.630  0.851  0.969  0.994  0.149  0.241  0.357  0.492  0.621  0.722  
ITML  0.279  0.571  0.843  0.974  0.995  0.997  0.279  0.423  0.525  0.603  0.675  0.751  
LFDA  0.242  0.471  0.705  0.906  0.987  0.997  0.242  0.371  0.466  0.553  0.637  0.737  
ARML (Ours)  0.128  0.221  0.345  0.509  0.666  0.819  0.128  0.196  0.273  0.380  0.497  0.639  
Pendigits  radius  0.000  0.100  0.200  0.300  0.400  0.500  0.000  0.100  0.200  0.300  0.400  0.500 
Euclidean  0.039  0.126  0.316  0.577  0.784  0.937  0.036  0.085  0.155  0.248  0.371  0.528  
NCA  0.038  0.196  0.607  0.884  0.997  1.000  0.038  0.103  0.246  0.428  0.637  0.804  
LMNN  0.034  0.180  0.568  0.898  0.993  0.999  0.030  0.096  0.246  0.462  0.681  0.862  
ITML  0.060  0.334  0.773  0.987  1.000  1.000  0.060  0.149  0.343  0.616  0.814  0.926  
LFDA  0.047  0.228  0.595  0.904  1.000  1.000  0.043  0.104  0.248  0.490  0.705  0.842  
ARML (Ours)  0.035  0.114  0.308  0.568  0.780  0.937  0.034  0.078  0.138  0.235  0.368  0.516  
Satimage  radius  0.000  0.150  0.300  0.450  0.600  0.750  0.000  0.150  0.300  0.450  0.600  0.750 
Euclidean  0.101  0.579  0.842  0.899  0.927  0.948  0.091  0.237  0.482  0.682  0.816  0.897  
NCA  0.117  0.670  0.886  0.915  0.936  0.961  0.101  0.297  0.564  0.746  0.876  0.931  
LMNN  0.105  0.613  0.855  0.914  0.944  0.961  0.090  0.269  0.548  0.737  0.855  0.910  
ITML  0.130  0.768  0.959  1.000  1.000  1.000  0.109  0.411  0.757  0.939  0.990  1.000  
LFDA  0.128  0.779  0.904  0.958  0.995  1.000  0.112  0.389  0.673  0.860  0.950  0.986  
ARML (Ours)  0.103  0.540  0.824  0.898  0.920  0.943  0.092  0.228  0.464  0.668  0.817  0.896  
USPS  radius  0.000  0.500  1.000  1.500  2.000  2.500  0.000  0.500  1.000  1.500  2.000  2.500 
Euclidean  0.063  0.239  0.586  0.888  0.977  1.000  0.058  0.125  0.211  0.365  0.612  0.751  
NCA  0.072  0.367  0.903  0.986  1.000  1.000  0.063  0.158  0.365  0.686  0.899  0.980  
LMNN  0.062  0.856  1.000  1.000  1.000  1.000  0.055  0.359  0.890  0.999  1.000  1.000  
ITML  0.082  0.696  0.999  1.000  1.000  1.000  0.072  0.273  0.708  0.987  1.000  1.000  
LFDA  0.134  1.000  1.000  1.000  1.000  1.000  0.118  0.996  1.000  1.000  1.000  1.000  
ARML (Ours)  0.057  0.203  0.527  0.867  0.971  0.997  0.053  0.118  0.209  0.344  0.572  0.785 
Mahalanobis 1-NN
Certified robust errors of Mahalanobis 1-NN with respect to different perturbation radii are shown in Table 1. Note that for 1-NN, Algorithm 2, which solves (13), computes the exact minimal adversarial perturbation for each instance, so the values in Table 1 are both certified robust errors and empirical robust errors (attack errors). Also, note that when the radius is 0, the resulting certified robust error is equivalent to the clean error on the unperturbed test set.
We have three main observations from the experimental results. First, although NCA and LMNN achieve better clean errors (at radius 0) than Euclidean on most datasets, they are less robust to adversarial perturbations than Euclidean (except on the Splice dataset, on which Euclidean performs particularly poorly in terms of clean errors). Both NCA and LMNN suffer from the tradeoff between the clean error and the certified robust error. Second, ARML performs competitively with NCA and LMNN in terms of clean errors (it achieves the best on 4/6 of the datasets). Third and most importantly, ARML is much more robust than all the other methods in terms of certified robust errors for nearly all perturbation radii.
Mahalanobis K-NN
For K-NN models, it is intractable to compute the exact minimal adversarial perturbation, so we report both certified robust errors and empirical robust errors (attack errors). We set $K = 11$ for all the experiments. The certified robust error is computed by Theorem 1, which works for any Mahalanobis metric. We also conduct adversarial attacks on these models to derive the empirical robust error, which is a lower bound of the certified robust error, via a hard-label black-box attack method, the Boundary Attack [brendel2018decision]. Different from the 1-NN case, since neither the attack nor the robustness verification is optimal, there is a gap between the two numbers. These results are shown in Table 2.
The three observations for Mahalanobis 1-NN also hold for K-NN: NCA and LMNN have improved clean errors (empirical robust errors at radius 0), but this often comes with degraded robust errors compared with the Euclidean distance, while ARML achieves good robust errors as well as clean errors. The results suggest that ARML is more robust both provably (in terms of the certified robust error) and empirically (in terms of the empirical robust error).
5 Related work
Metric Learning
Metric learning aims to learn a new distance using supervision concerning the learned distance [kulis2013metric]. In this paper, we mainly focus on the linear metric learning: the learned distance is the squared Euclidean distance after applying the transformation globally, i.e., the Mahalanobis distance [goldberger2004neighbourhood, davis2007information, weinberger2009distance, jain2010inductive, sugiyama2007dimensionality]. There are also nonlinear models for metric learning, such as kernelized metric learning [kulis2006learning, chatpatanasiri2010a], local metric learning [frome2007learning, weinberger2008fast] and deep metric learning [chopra2005learning, schroff2015facenet]. Robustness verification for nonlinear metric learning and learning a provably robust nonlinear metric would be an interesting future work.
Adversarial robustness of neural networks
Empirical defense aims to learn a classifier which is robust to some adversarial attacks [kurakin2017adversarial, madry2018towards], but has no guarantee for the robustness to other stronger (or unknown) adversarial attacks [carlini2017towards, athalye2018obfuscated]. In contrast, certified defense provides a guarantee that no adversarial examples exist within a certain input region [wong2018provable, cohen2019certified, zhang2020towards]. The basic idea of these certified defense methods is to minimize the certified robust training error. However, all these methods for neural networks rely on the assumption of smoothness of the classifier, and hence could not be applied to the nearest neighbor classifiers.
Adversarial robustness of nearest neighbor classifiers
Most works on adversarial robustness of K-NN focus on adversarial attacks. Some papers propose to attack a differentiable substitute of K-NN [papernot2016transferability, sitawarin2020minimum], and others formalize the attack as a list of quadratic programming problems or linear programming problems [wang2019evaluating, yang2020robust]. As far as we know, there is only one paper considering adversarial verification for K-NN, but it only considers the Euclidean distance, and no certified defense method is proposed [wang2019evaluating]. In contrast, we propose the first adversarial verification method and the first certified defense (or provably robust learning) for Mahalanobis K-NN.
6 Conclusion
We propose a novel metric learning method named ARML to obtain a Mahalanobis distance that is robust to adversarial input perturbations. Experiments show that the proposed method can improve both clean errors and robust errors compared with existing metric learning algorithms.
References
Appendix A Optimal value of triplet problem
The triplet problem is formalized as below:
(15) 
It is equivalent to the optimization
(16) 
where we have
(17) 
(18) 
The dual function is
(19)  
(20) 
where holds for . Then the dual problem is
(21) 
The optimal point is
(22) 
and the optimal value is
(23) 
By the Slater’s condition, if holds, we have the strong duality. Therefore, the optimal value of (15) is
(24) 
In fact, it is easy to verify that even if obtains, the optimal value also holds.
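The steps of this appendix can be retraced as follows. This is a sketch under assumed notation rather than the original equations: $M$ is taken to be symmetric positive semidefinite, $d_M(u, v) = (u - v)^\top M (u - v)$, and the symbols $a$, $b$ and $\lambda$ are introduced here for illustration.

```latex
\begin{align*}
& d_M(x+\delta, x^-) - d_M(x+\delta, x^+)
  = \underbrace{d_M(x, x^-) - d_M(x, x^+)}_{b}
  + \underbrace{2\,(x^+ - x^-)^\top M}_{a^\top}\,\delta \;\le\; 0, \\
& g(\lambda) = \min_{\delta}\; \delta^\top \delta + \lambda\,(a^\top \delta + b)
  = -\tfrac{1}{4}\lambda^2\, a^\top a + \lambda b,
  \qquad \text{attained at } \delta = -\tfrac{\lambda}{2}\, a, \\
& \lambda^\star = \arg\max_{\lambda \ge 0} g(\lambda) = \frac{2\,[b]_+}{a^\top a},
  \qquad g(\lambda^\star) = \frac{[b]_+^2}{a^\top a}, \\
& \|\delta^\star\| = \sqrt{g(\lambda^\star)} = \frac{[b]_+}{\|a\|}
  = \frac{\big[d_M(x, x^-) - d_M(x, x^+)\big]_+}
         {2\sqrt{(x^+ - x^-)^\top M^\top M\,(x^+ - x^-)}}.
\end{align*}
```

The quadratic terms $\delta^\top M \delta$ cancel in the first line, which is why the constraint is linear in $\delta$ and the problem reduces to projecting the origin onto a half-space.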
Appendix B Proof of Theorem 1
Appendix C Details of computing exact minimal adversarial perturbation of Mahalanobis 1-NN
The overall algorithm is displayed in Algorithm 2. We denote as the optimal value of the inner minimization problem with respect to , and denote as its lower bound. We first sort the subproblems according to the ascending order of for . For every subproblem, we compute the lower bound of its optimal value. If the optimal value is too large, we just screen the subproblem safely without solving it exactly.
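The screening idea can be sketched independently of the quadratic programming solver. In this illustration the exact solver is passed in as a callable and the subproblems are visited in ascending order of their lower bounds; both choices, and all names, are assumptions made for the example rather than a transcription of Algorithm 2.

```python
import math

def screened_minimum(lower_bounds, solve_exact):
    """Minimum over subproblems, skipping those a lower bound rules out.

    lower_bounds[j] must not exceed the exact value of subproblem j, and
    solve_exact(j) is a (hypothetical) exact solver for subproblem j.
    Returns the minimum value and the number of subproblems solved exactly.
    """
    best = math.inf
    solved = 0
    for j in sorted(range(len(lower_bounds)), key=lower_bounds.__getitem__):
        if lower_bounds[j] >= best:
            break   # every remaining subproblem has an even larger lower bound
        best = min(best, solve_exact(j))
        solved += 1
    return best, solved

# Hypothetical exact values and valid lower bounds for four subproblems.
exact_values = [3.0, 1.0, 2.5, 5.0]
bounds = [2.0, 0.5, 1.5, 4.0]
best, solved = screened_minimum(bounds, lambda j: exact_values[j])
```

In this toy run only one of the four subproblems is solved exactly, since the first exact value already falls below every other lower bound.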
C.1 Greedy coordinate ascent (descent)
For the subproblem we have to solve exactly, we employ the greedy coordinate ascent method. Note that the inner minimization problem of (13) is a convex quadratic programming problem. We solve the problem by dealing with its dual formulation. The greedy coordinate ascent method is used because the optimal dual variables are very sparse. The algorithm is shown in Algorithm 3. At every iteration, only one dual variable is updated.
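C.2 Lower bound of inner minimization problem
The following theorem relies on the solution of the triplet problem.
Theorem 2.
The optimal value $\epsilon^{(j)}$ of the inner minimization of (13) with respect to $x_j$ is lower bounded as
$$\epsilon^{(j)} \ge \max_{i : y_i = y_{\text{test}}} \big[ \tilde{\epsilon}(x_i, x_j, x_{\text{test}}; M) \big]_+, \tag{32}$$
where $\tilde{\epsilon}$ is the function defined in (9).
Proof.
Relaxing the constraint of (13) by replacing the universal quantifier with a single index $i$, we know that $\epsilon^{(j)}$ is lower bounded by the optimal value of the following optimization problem:
$$\max_{i : y_i = y_{\text{test}}} \; \min_{\delta_j} \|\delta_j\| \tag{33}$$
$$\text{s.t.} \quad d_M(x_{\text{test}} + \delta_j, x_j) \le d_M(x_{\text{test}} + \delta_j, x_i). \tag{34}$$
Obviously, the inner problem is an instance of the triplet problem (7), so its optimal value is $[\tilde{\epsilon}(x_i, x_j, x_{\text{test}}; M)]_+$. ∎
In this way, we can derive a lower bound of the optimal value in closed form.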
Appendix D Experimental details
Datasets
Dataset statistics and test errors (on all test instances) of Euclidean K-NN are shown in Table 3. All training data are used to learn metrics, and 1,000 instances are randomly sampled to compute certified robust errors.
Table 3: Dataset statistics and test errors of Euclidean 1-NN and 11-NN.
Dataset  # features  # classes  # train  # test  1-NN test error  11-NN test error  

MNIST  784  10  60,000  10,000  0.031  0.033 
FashionMNIST  784  10  60,000  10,000  0.150  0.150 
Splice  60  2  1,000  2,175  0.295  0.291 
Pendigits  16  10  7,494  3,498  0.023  0.027 
Satimage  36  6  4,435  2,000  0.112  0.106 
USPS  256  10  7,291  2,007  0.049  0.060 
Hyperparameters
Hyperparameters of our ARML algorithm are fixed across all datasets. Specifically, the size of the neighborhood from which the two sampling procedures in (12) draw random instances is 10. In other words, at every iteration, we sample one instance from the nearest 10 instances in the same class as the anchor instance, and one instance from the nearest 10 instances in the other classes. We employ the Adam algorithm [kingma2015adam] to update parameters with gradients, and its parameters are in the default setting (learning rate: 0.001, betas: (0.9, 0.999)). The number of epochs is 1,000. The loss function is the negative loss.
Hyperparameter sensitivity
We investigate the sensitivity to the size of the neighborhood used by the two sampling procedures in (12). We plot the robust error curves against the radius for different neighborhood sizes in Figure 2 and Figure 3. The results suggest that ARML is not very sensitive to this hyperparameter in terms of certified and empirical robust errors, as long as it is not too small.
Implementations of NCA, LMNN, ITML and LFDA
We use the implementations in the metric-learn library [vazelhes2019metric] for NCA, LMNN, ITML and LFDA. Similar to ARML, hyperparameters are fixed across all datasets and are in the default setting. In particular, the maximum numbers of iterations for NCA, LMNN and ITML are 1,000, 100 and 1,000 respectively.