Provably Robust Metric Learning

by   Lu Wang, et al.
Nanjing University, Inc.

Metric learning is an important family of algorithms for classification and similarity search, but the robustness of learned metrics against small adversarial perturbations is less studied. In this paper, we show that existing metric learning algorithms, which focus on boosting the clean accuracy, can result in metrics that are less robust than the Euclidean distance. To overcome this problem, we propose a novel metric learning algorithm to find a Mahalanobis distance that is robust against adversarial perturbations, and the robustness of the resulting model is certifiable. Experimental results show that the proposed metric learning algorithm improves both certified robust errors and empirical robust errors (errors under adversarial attacks). Furthermore, unlike neural network defenses which usually encounter a trade-off between clean and robust errors, our method does not sacrifice clean errors compared with previous metric learning methods. Our code is available at



1 Introduction

Metric learning has been an important family of machine learning algorithms and has achieved successes on several problems, including computer vision [kulis2009fast, guillaumin2009you, hua2007discriminant], text analysis [lebanon2006metric], meta learning [vinyals2016matching, snell2017prototypical] and others [slaney2008learning, xiong2006kernel, ye2016what]. Given a set of training samples, metric learning aims to learn a good distance measurement such that items in the same class are closer to each other in the learned metric space, which is crucial for classification and similarity search. Since this objective is directly related to the assumption of nearest neighbor classifiers, most metric learning algorithms can be naturally and successfully combined with $K$-Nearest Neighbor ($K$-NN) classifiers.

Adversarial robustness of machine learning algorithms has been studied extensively in recent years due to the need for robustness guarantees in real-world systems. It has been demonstrated that neural networks can be easily attacked by adversarial perturbations in the input space [szegedy2014intriguing, goodfellow2015explaining, biggio2018wild], and such perturbations can be computed efficiently in both white-box [carlini2017towards, madry2018towards] and black-box settings [chen2017zoo, ilyas2019prior, cheng2020signopt]. Therefore, many defense algorithms have been proposed to improve the robustness of neural networks [kurakin2017adversarial, madry2018towards]. Although these algorithms can successfully defend against standard attacks, it has been shown that many of them are vulnerable to stronger attacks when the attacker knows the defense mechanism [carlini2017towards]. Therefore, recent research in adversarial defense of neural networks has shifted to the concept of “certified defense”, where the defender needs to provide a certification that no adversarial examples exist within a certain input region [wong2018provable, cohen2019certified, zhang2020towards].

In this paper, we consider the problem of learning a metric that is robust against adversarial input perturbations. It has been shown that nearest neighbor classifiers are not robust [papernot2016transferability, wang2019evaluating, sitawarin2020minimum]: a small, human-imperceptible perturbation in the input space can easily fool a $K$-NN classifier, so it is natural to investigate how to obtain a metric that improves adversarial robustness. Despite being an important and interesting research question, to the best of our knowledge this problem has not been studied in the literature. Several caveats make it a hard problem: 1) attack and defense algorithms for neural networks often rely on the smoothness of the function, while $K$-NN is a discrete step function whose gradient does not exist; 2) even evaluating the robustness of $K$-NN (with the Euclidean distance) is harder than for neural networks: for instance, [wang2019evaluating] showed that attacks on $K$-NN are time-consuming and nontrivial, and none of the existing works has considered general Mahalanobis distances; 3) existing algorithms for evaluating the robustness of $K$-NN, including attack [yang2020robust] and verification [wang2019evaluating], are often non-differentiable, while training a robust metric requires a differentiable measurement of robustness.

To develop a provably robust metric learning algorithm, we formulate an objective function to learn a Mahalanobis distance, parameterized by a positive semi-definite matrix $M$, that maximizes the minimal adversarial perturbation on each sample. However, computing the minimal adversarial perturbation is intractable for $K$-NN, so to make the problem solvable, we propose an efficient formulation for lower-bounding the minimal adversarial perturbation, and this lower bound can be represented as an explicit function of $M$ to enable the gradient computation. We further develop several tricks to improve the efficiency of the overall procedure. Similar to certified defense algorithms for neural networks, the proposed algorithm can provide a certified robustness improvement on the resulting $K$-NN model with the learned metric. Decision boundaries of 1-NN with different Mahalanobis distances for a toy dataset are visualized in Figure 1. It can be observed that the proposed Adversarial Robust Metric Learning (ARML) method obtains a more robust metric on this example.

We conduct extensive experiments on six real-world datasets and show that the proposed algorithm can improve both certified robust errors and empirical robust errors (errors under adversarial attacks) over existing metric learning algorithms.

Figure 1: Decision boundaries of 1-NN with different Mahalanobis distances: (a) Euclidean, (b) NCA [goldberger2004neighbourhood], (c) ARML (Ours).

2 Background

Metric learning for nearest neighbor classifiers

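A nearest-neighbor classifier based on a Mahalanobis distance can be characterized by a training dataset and a positive semi-definite matrix. Let $\mathcal{X}\subseteq\mathbb{R}^D$ be the instance space and $\mathcal{Y}=\{1,\dots,C\}$ the label space, where $C$ is the number of classes. $S=\{(x_i,y_i)\}_{i=1}^{N}$ is the training set with $(x_i,y_i)\in\mathcal{X}\times\mathcal{Y}$ for every $i$, and $M\in\mathbb{R}^{D\times D}$ is a positive semi-definite matrix. The Mahalanobis distance for any $x,x'\in\mathcal{X}$ is defined as

$$d_M(x,x') = (x-x')^\top M\,(x-x'), \tag{1}$$

and a Mahalanobis $K$-NN classifier $f:\mathcal{X}\to\mathcal{Y}$ finds the $K$ nearest neighbors of the test instance in $S$ based on the Mahalanobis distance, and then predicts the label by majority voting among these neighbors.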
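As a concrete reference point, the following pure-Python sketch implements Eq. (1) and the majority-vote rule. The function names `mahalanobis` and `knn_predict` are hypothetical, and ties in the vote are broken in favor of the label encountered first among the nearest neighbors, a convention the text does not specify.

```python
from collections import Counter
from typing import List, Sequence

Vec = Sequence[float]
Mat = Sequence[Sequence[float]]


def mahalanobis(x: Vec, z: Vec, M: Mat) -> float:
    """d_M(x, z) = (x - z)^T M (x - z), as in Eq. (1)."""
    diff = [a - b for a, b in zip(x, z)]
    return sum(d * sum(m * e for m, e in zip(row, diff)) for d, row in zip(diff, M))


def knn_predict(x: Vec, X: List[Vec], labels: List[int], M: Mat, K: int) -> int:
    """Majority vote among the K training instances closest to x under d_M."""
    order = sorted(range(len(X)), key=lambda i: mahalanobis(x, X[i], M))
    votes = Counter(labels[i] for i in order[:K])
    # Counter.most_common orders equal counts by first insertion, i.e. nearest first.
    return votes.most_common(1)[0][0]


if __name__ == "__main__":
    X = [[0.0, 0.0], [2.0, 2.0]]
    labels = [0, 1]
    identity = [[1.0, 0.0], [0.0, 1.0]]
    shrink_first = [[0.01, 0.0], [0.0, 1.0]]  # down-weights the first feature
    print(knn_predict([0.0, 1.6], X, labels, identity, K=1))
    print(knn_predict([0.0, 1.6], X, labels, shrink_first, K=1))
```

The toy usage shows that the same query can receive a different 1-NN label once $M$ re-weights the features, which is exactly the degree of freedom metric learning exploits.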

Many metric learning approaches aim to learn a good Mahalanobis distance based on training data, including [goldberger2004neighbourhood, davis2007information, weinberger2009distance, jain2010inductive, sugiyama2007dimensionality] (see more discussions in Section 5). However, none of these methods attempts to find a metric that is robust to small input perturbations.

Adversarial robustness and minimal adversarial perturbation

There are two important concepts in adversarial robustness: adversarial attack and adversarial verification (or robustness verification). Adversarial attack aims to find a perturbation that changes the prediction, and adversarial verification aims to find a radius within which no perturbation could change the prediction. Both of them can be reduced to the problem of finding the minimal adversarial perturbation. For a classifier $f$ on an instance $(x, y)$, the minimal adversarial perturbation can be defined as

$$\arg\min_{\delta}\ \|\delta\| \quad\text{s.t.}\quad f(x+\delta)\neq y, \tag{2}$$

which is the smallest perturbation that could lead to misclassification. Note that if $x$ is not correctly classified, the minimal adversarial perturbation is $\delta = \mathbf{0}$, i.e., the zero vector. Let $\delta^*(x)$ and $\epsilon^*(x) = \|\delta^*(x)\|$ denote the optimal solution and the optimal value. Clearly, $\delta^*(x)$ is also the solution of the optimal adversarial attack, and $\epsilon^*(x)$ is the solution of the optimal adversarial verification. For neural networks, computing (2) is often NP-complete, so many efficient algorithms have been proposed for attack [goodfellow2015explaining, carlini2017towards, brendel2018decision, cheng2020signopt] and verification [wong2018provable, weng2018towards, mirman2018differentiable], corresponding to computing upper and lower bounds of (2), respectively. However, these methods do not work for discrete models such as nearest neighbor classifiers.

In this paper, our algorithm is based on a novel derivation of a lower bound of the minimal adversarial perturbation for Mahalanobis $K$-NN classifiers. To the best of our knowledge, there has been no previous work tackling this problem. Since the Mahalanobis $K$-NN classifier is parameterized by a positive semi-definite matrix $M$ and the training set $S$, we further write $\delta^*(x; S, M)$ and $\epsilon^*(x; S, M)$ to indicate their dependence on $M$ and $S$ explicitly. In this paper we consider the $\ell_2$ norm in (2) for simplicity.

Certified and empirical robust error

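Let $\underline{\epsilon}(x)$ be a lower bound of the norm of the minimal adversarial perturbation $\epsilon^*(x)$, possibly computed by a robustness verification algorithm. For a distribution $\mathcal{D}$ over $\mathcal{X}\times\mathcal{Y}$, the certified robust error with respect to the radius $\epsilon \ge 0$ is defined as the probability that $\underline{\epsilon}(x)$ is not greater than $\epsilon$, namely

$$\mathcal{R}_{\mathrm{cert}}(\epsilon) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\ \mathbb{1}\{\underline{\epsilon}(x) \le \epsilon\}. \tag{3}$$

Note that in the case $\epsilon = 0$ and $\underline{\epsilon}(x) = \epsilon^*(x)$, the certified robust error reduces to the clean error. In this paper we investigate how to compute it for Mahalanobis $K$-NN classifiers.

On the other hand, adversarial attack algorithms try to find a feasible solution of (2), which gives an upper bound $\bar{\epsilon}(x) \ge \epsilon^*(x)$. Based on the upper bound, we can measure the empirical robust error of a model by

$$\mathcal{R}_{\mathrm{emp}}(\epsilon) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\ \mathbb{1}\{\bar{\epsilon}(x) \le \epsilon\}. \tag{4}$$

Since $\bar{\epsilon}(x)$ is computed by an attack method, the empirical robust error is also called the attack error. A family of decision-based attack methods, which view the victim model as a black box, can be used to attack Mahalanobis $K$-NN classifiers [brendel2018decision, cheng2019query, cheng2020signopt].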
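On a finite test sample, (3) and (4) become simple empirical frequencies over per-instance bounds. The sketch below (the name `robust_error` is hypothetical) assumes misclassified instances have already been assigned a bound of 0; because an attack's upper bound is never below a verifier's lower bound, the empirical value can never exceed the certified one at the same radius.

```python
from typing import Sequence


def robust_error(bounds: Sequence[float], eps: float) -> float:
    """Fraction of instances whose perturbation bound is <= eps.

    With verified lower bounds this estimates the certified robust error (3);
    with attack-based upper bounds it estimates the empirical robust error (4).
    Misclassified instances are expected to carry a bound of 0.0.
    """
    if not bounds:
        raise ValueError("need at least one instance")
    return sum(1 for b in bounds if b <= eps) / len(bounds)


if __name__ == "__main__":
    lower = [0.0, 0.3, 0.8, 1.5]  # e.g. from a verifier such as Theorem 1
    upper = [0.0, 0.5, 1.2, 2.0]  # e.g. from a decision-based attack
    for eps in (0.0, 1.0, 2.0):
        print(eps, robust_error(lower, eps), robust_error(upper, eps))
```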

3 Adversarially robust metric learning

The objective of adversarially robust metric learning (ARML) is to learn the matrix $M$ from the training data such that the resulting Mahalanobis $K$-NN classifier has small certified and empirical robust errors.

3.1 Basic formulation

The goal is to learn a positive semi-definite matrix $M$ that minimizes the certified robust training error. Since the certified robust error defined in (3) is non-smooth, we replace the indicator function by a loss function. The resulting objective can be formulated as

$$\min_{G}\ \frac{1}{N}\sum_{i=1}^{N} \ell\Big(\epsilon^*\big(x_i;\ S\setminus\{(x_i,y_i)\},\ M\big)\Big) \quad\text{s.t.}\quad M = G^\top G, \tag{5}$$

where $\ell:\mathbb{R}\to\mathbb{R}$ is a monotonically non-increasing function, e.g., the hinge loss $[1-\epsilon]_+$, exponential loss $\exp(-\epsilon)$, logistic loss $\log(1+\exp(-\epsilon))$, or “negative” loss $-\epsilon$. We employ $M = G^\top G$ to enforce $M$ to be positive semi-definite, and it is possible to derive a low-rank $M$ by constraining the shape of $G$. Note that the minimal adversarial perturbation is defined on the training set excluding $(x_i, y_i)$, since otherwise a 1-nearest neighbor classifier with any distance measurement would have 100% training accuracy. In this way, we minimize the “leave-one-out” certified robust error. The remaining problem is how to compute or approximate $\epsilon^*(x_i; S\setminus\{(x_i,y_i)\}, M)$ in our training objective.
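To make the ingredients of (5) concrete, here is a small pure-Python sketch of the four listed surrogate losses and of the $M = G^\top G$ parameterization; the names `LOSSES`, `gram`, `quad_form` and `surrogate_objective` are hypothetical, and the per-instance leave-one-out perturbation norms are assumed to be given. A $G$ with fewer rows than columns yields a low-rank $M$, and $v^\top M v = \|Gv\|_2^2 \ge 0$ for any $G$, which is why the constraint guarantees positive semi-definiteness.

```python
import math
from typing import Callable, Dict, List, Sequence

# Monotonically non-increasing surrogates for the indicator 1{eps <= radius}.
# Arguments are perturbation norms, so they are non-negative in practice.
LOSSES: Dict[str, Callable[[float], float]] = {
    "hinge": lambda e: max(0.0, 1.0 - e),
    "exponential": lambda e: math.exp(-e),
    "logistic": lambda e: math.log1p(math.exp(-e)),
    "negative": lambda e: -e,
}


def gram(G: Sequence[Sequence[float]]) -> List[List[float]]:
    """Return M = G^T G for an r x D matrix G (r < D gives a low-rank M)."""
    r, d = len(G), len(G[0])
    return [[sum(G[k][a] * G[k][b] for k in range(r)) for b in range(d)]
            for a in range(d)]


def quad_form(M: Sequence[Sequence[float]], v: Sequence[float]) -> float:
    """v^T M v; equals ||G v||^2 >= 0 whenever M = G^T G."""
    n = len(v)
    return sum(v[a] * M[a][b] * v[b] for a in range(n) for b in range(n))


def surrogate_objective(perturbations: Sequence[float], loss: str = "negative") -> float:
    """Average loss over per-instance (leave-one-out) perturbation norms, as in (5)."""
    f = LOSSES[loss]
    return sum(f(e) for e in perturbations) / len(perturbations)
```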

3.2 Bounding minimal adversarial perturbation for Mahalanobis $K$-NN

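For convenience, suppose $K$ is an odd number and denote $k = (K+1)/2$. In the binary classification case for simplicity, i.e., $C = 2$, the computation of $\epsilon^*(x; S, M)$ for Mahalanobis $K$-NN could be formulated as

$$\epsilon^*(x; S, M) = \min_{\substack{\mathcal{J}\subseteq\mathcal{I}^-,\ |\mathcal{J}| = k\\ \mathcal{I}\subseteq\mathcal{I}^+,\ |\mathcal{I}| = k-1}}\ \min_{\delta}\ \|\delta\|_2 \quad\text{s.t.}\quad d_M(x+\delta, x_j) \le d_M(x+\delta, x_i),\ \ \forall j\in\mathcal{J},\ \forall i\in\mathcal{I}^+\setminus\mathcal{I}, \tag{6}$$

where $\mathcal{I}^+ = \{i : y_i = y\}$ and $\mathcal{I}^- = \{j : y_j \ne y\}$ index the training instances in the same class as the test instance $(x, y)$ and in a different class, respectively. This formulation enumerates all the $K$-size nearest neighbor sets containing at most $k-1$ instances in the same class as the test instance, computes the minimum perturbation resulting in each such $K$-nearest neighbor set, and takes the minimum of them.

Solving (6) exactly has time complexity growing exponentially with $K$. Furthermore, a numerical solution cannot be incorporated into the training objective (5), since we need to write $\epsilon^*$ as a function of $M$ for back-propagation. To address these issues, we resort to a lower bound of the optimal value of (6) rather than solving it exactly.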

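First, we consider a simple triplet problem: given vectors $x^+, x^-, x \in \mathbb{R}^D$ and a positive semi-definite matrix $M \in \mathbb{R}^{D\times D}$, find the minimum perturbation $\delta$ on $x$ such that $d_M(x+\delta, x^-) \le d_M(x+\delta, x^+)$ holds. It could be formulated as the following optimization problem

$$\min_{\delta}\ \|\delta\|_2 \quad\text{s.t.}\quad d_M(x+\delta, x^-) \le d_M(x+\delta, x^+). \tag{7}$$

Note that the constraint in (7) can be written in a linear form, because the quadratic terms $\delta^\top M\delta$ on the two sides cancel, so this is a convex quadratic programming problem with a linear constraint. We show that the optimal value of (7) can be expressed in closed form:

$$\min_{\delta}\ \|\delta\|_2 = \frac{\big[d_M(x, x^-) - d_M(x, x^+)\big]_+}{2\,\|M(x^+ - x^-)\|_2}, \tag{8}$$

where $[\cdot]_+$ denotes $\max(\cdot, 0)$. The derivation for the optimal value is deferred to Appendix A. Note that if $M$ is the identity matrix and $d_M(x, x^+) < d_M(x, x^-)$ strictly holds, then the optimal value is the Euclidean distance from $x$ to the perpendicular bisector of $x^+$ and $x^-$.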
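The closed form (8) is cheap to evaluate. The following pure-Python sketch (function names are hypothetical) implements it; for $M = I$ it reproduces the distance to the perpendicular bisector. The early return also covers the only case in which the denominator could vanish: writing $b = d_M(x,x^-) - d_M(x,x^+)$, one can expand $b = 2(x - x^+)^\top M(x^+ - x^-) + (x^+ - x^-)^\top M(x^+ - x^-)$ for symmetric $M$, so $M(x^+ - x^-) = 0$ forces $b = 0$.

```python
import math
from typing import Sequence

Vec = Sequence[float]
Mat = Sequence[Sequence[float]]


def mahalanobis(x: Vec, z: Vec, M: Mat) -> float:
    """d_M(x, z) = (x - z)^T M (x - z)."""
    diff = [a - b for a, b in zip(x, z)]
    return sum(d * sum(m * e for m, e in zip(row, diff)) for d, row in zip(diff, M))


def triplet_bound(x_pos: Vec, x_neg: Vec, x: Vec, M: Mat) -> float:
    """Optimal value of the triplet problem (7) via the closed form (8)."""
    num = mahalanobis(x, x_neg, M) - mahalanobis(x, x_pos, M)
    if num <= 0.0:
        # x is already at least as close to x_neg as to x_pos: delta = 0 suffices.
        return 0.0
    Mv = [sum(m * (p - n) for m, p, n in zip(row, x_pos, x_neg)) for row in M]
    return num / (2.0 * math.sqrt(sum(t * t for t in Mv)))


if __name__ == "__main__":
    I2 = [[1.0, 0.0], [0.0, 1.0]]
    # The bisector of (0,0) and (1,1) is the line x1 + x2 = 1,
    # and (0.5, 0) lies 0.5 / sqrt(2) away from it.
    print(triplet_bound([0.0, 0.0], [1.0, 1.0], [0.5, 0.0], I2))
```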

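For convenience, we define the function $\tilde{\epsilon}$ as

$$\tilde{\epsilon}(x^+, x^-, x; M) = \frac{\big[d_M(x, x^-) - d_M(x, x^+)\big]_+}{2\,\|M(x^+ - x^-)\|_2}. \tag{9}$$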
Then we could relax (6) further and have the following theorem:

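Theorem 1 (Robustness verification for Mahalanobis $K$-NN).

Given a Mahalanobis $K$-NN classifier parameterized by a neighbor parameter $K$, a training dataset $S$ and a positive semi-definite matrix $M$, for any instance $(x, y)$ we have

$$\epsilon^*(x; S, M)\ \ge\ \operatorname*{kthmin}_{j:\, y_j \ne y}\ \operatorname*{kthmax}_{i:\, y_i = y}\ \tilde{\epsilon}(x_i, x_j, x; M), \tag{10}$$

where $\operatorname{kthmax}$ and $\operatorname{kthmin}$ select the $k$-th maximum and $k$-th minimum respectively, with $k = (K+1)/2$.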

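The proof is deferred to Appendix B. In this way, we only need to compute $\tilde{\epsilon}(x_i, x_j, x; M)$ for each pair $(i, j)$ with $y_i = y$ and $y_j \ne y$ in order to derive a lower bound of the minimal adversarial perturbation of Mahalanobis $K$-NN, which leads to an efficient algorithm for verifying its robustness. After precomputing $Mx_i$ for all training instances in $O(ND^2)$ time, each pairwise value costs $O(D)$, so the total time complexity is $O(N^2D + ND^2)$, independent of $K$.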

In the general multi-class case, the constraint of (6), with $\mathcal{I}^-$ containing all instances from the other classes, is a necessary condition for a successful attack rather than a necessary and sufficient one: a misclassified instance has at most $k-1$ instances of its own class among its $K$ nearest neighbors. As a result, the optimal value of (6) is a lower bound of the minimal adversarial perturbation, and Theorem 1 also holds for the multi-class case. Based on this lower bound of $\epsilon^*$, we derive the proposed ARML algorithm.
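Theorem 1 reduces verification to pairwise triplet bounds followed by two order statistics. Below is a minimal pure-Python sketch with hypothetical names; it sorts instead of using linear-time selection for brevity, and it adopts two edge-case conventions the text does not discuss: a $k$-th maximum over fewer than $k$ same-class values is taken as 0 (the constraint set of (6) is then vacuous), and the bound is $+\infty$ when fewer than $k$ instances from other classes exist (assuming at least $K$ training instances).

```python
import math
from typing import List, Sequence

Vec = Sequence[float]
Mat = Sequence[Sequence[float]]


def mahalanobis(x: Vec, z: Vec, M: Mat) -> float:
    diff = [a - b for a, b in zip(x, z)]
    return sum(d * sum(m * e for m, e in zip(row, diff)) for d, row in zip(diff, M))


def triplet_bound(x_pos: Vec, x_neg: Vec, x: Vec, M: Mat) -> float:
    """tilde-epsilon from Eq. (9)."""
    num = mahalanobis(x, x_neg, M) - mahalanobis(x, x_pos, M)
    if num <= 0.0:
        return 0.0
    Mv = [sum(m * (p - n) for m, p, n in zip(row, x_pos, x_neg)) for row in M]
    return num / (2.0 * math.sqrt(sum(t * t for t in Mv)))


def knn_lower_bound(x: Vec, y: int, X: List[Vec], labels: List[int], M: Mat, K: int) -> float:
    """Lower bound (10) on the minimal l2 perturbation of Mahalanobis K-NN at (x, y)."""
    if K % 2 == 0:
        raise ValueError("Theorem 1 assumes an odd K")
    k = (K + 1) // 2
    pos = [xi for xi, yi in zip(X, labels) if yi == y]
    neg = [xj for xj, yj in zip(X, labels) if yj != y]
    if len(neg) < k:
        return math.inf  # x can never be outvoted
    per_neg = []
    for xj in neg:
        vals = sorted((triplet_bound(xi, xj, x, M) for xi in pos), reverse=True)
        per_neg.append(vals[k - 1] if len(vals) >= k else 0.0)  # k-th maximum over i
    return sorted(per_neg)[k - 1]  # k-th minimum over j
```

On the 1-D example in the tests, the bound for a correctly classified 3-NN query equals the true minimal perturbation, while a misclassified query gets a bound of exactly 0, as required by (2).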

3.3 Training algorithm of adversarially robust metric learning

By replacing $\epsilon^*$ in (5) with the lower bound derived in Theorem 1, we get a trainable objective function for adversarially robust metric learning:

$$\min_{G}\ \frac{1}{N}\sum_{i=1}^{N} \ell\Big(\operatorname*{kthmin}_{j:\, y_j \ne y_i}\ \operatorname*{kthmax}_{l \ne i:\, y_l = y_i}\ \tilde{\epsilon}(x_l, x_j, x_i; G^\top G)\Big). \tag{11}$$

Although (11) is trainable since $\tilde{\epsilon}$ is a function of $G$, for large datasets it is time-consuming to run the inner min-max procedure. Furthermore, since we care about the generalization performance of the learned metric rather than the robust training error, it is unnecessary to compute the exact solution. Therefore, instead of computing the $\operatorname{kthmin}$ and $\operatorname{kthmax}$ exactly, we propose to sample positive and negative instances from the neighborhood, which leads to the following formulation:

$$\min_{G}\ \frac{1}{N}\sum_{i=1}^{N} \mathbb{E}_{x^+\sim\mathcal{R}^+_M(x_i),\ x^-\sim\mathcal{R}^-_M(x_i)}\ \ell\big(\tilde{\epsilon}(x^+, x^-, x_i; G^\top G)\big), \tag{12}$$

where $\mathcal{R}^+_M(x_i)$ denotes a sampling procedure for an instance in the same class within $x_i$'s neighborhood, $\mathcal{R}^-_M(x_i)$ denotes a sampling procedure for an instance in a different class, also within $x_i$'s neighborhood, and the distances are measured by the Mahalanobis distance $d_M$. In our implementation, we sample instances from a fixed number of nearest instances. As a result, the optimization formulation (12) approximately minimizes the certified robust error.
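A minimal sketch of the two sampling procedures follows, assuming uniform sampling from the `n_near` nearest candidates under $d_M$ (Appendix D reports a neighborhood size of 10). The function name and the uniform choice are assumptions for illustration; the anchor itself is excluded from the same-class candidates, mirroring the leave-one-out objective.

```python
import random
from typing import List, Optional, Sequence, Tuple

Vec = Sequence[float]
Mat = Sequence[Sequence[float]]


def mahalanobis(x: Vec, z: Vec, M: Mat) -> float:
    diff = [a - b for a, b in zip(x, z)]
    return sum(d * sum(m * e for m, e in zip(row, diff)) for d, row in zip(diff, M))


def sample_pos_neg(i: int, X: List[Vec], labels: List[int], M: Mat,
                   n_near: int = 10,
                   rng: Optional[random.Random] = None) -> Tuple[int, int]:
    """Return (positive index, negative index) drawn from x_i's neighborhood under d_M."""
    rng = rng if rng is not None else random.Random()
    same = [t for t in range(len(X)) if t != i and labels[t] == labels[i]]
    other = [t for t in range(len(X)) if labels[t] != labels[i]]
    if not same or not other:
        raise ValueError("anchor needs a same-class and an other-class instance")

    def dist(t: int) -> float:
        return mahalanobis(X[i], X[t], M)

    pos = rng.choice(sorted(same, key=dist)[:n_near])
    neg = rng.choice(sorted(other, key=dist)[:n_near])
    return pos, neg
```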

Our adversarially robust metric learning (ARML) algorithm is shown in Algorithm 1. At every iteration, $G$ is updated with the gradient, while the sampling of positive and negative instances does not contribute to the gradient, for the sake of efficient and stable computation.

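Input: Training data $S$, number of epochs $T$.
Output: Positive semi-definite matrix $M$.
1 Initialize $G$ and $M$ as identity matrices $I$;
2 for $t = 1, \dots, T$ do
3       Update $G$ with the gradient of the objective (12), where positive and negative instances are sampled under the current $M$;
4       Update $M$ with the constraint $M = G^\top G$;
5 end for
Algorithm 1 Adversarially robust metric learning (ARML)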
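The loop below is a toy, pure-Python sketch of Algorithm 1 under several stated substitutions: it uses the negative loss, plain gradient descent with a forward finite-difference gradient in place of Adam with automatic differentiation, and one sampled positive/negative pair per anchor each epoch. All names and default values (`epochs`, `lr`, `h`, `n_near`) are hypothetical. As described above, sampling uses the current $M$ and is held fixed while the gradient is taken.

```python
import math
import random
from typing import List, Sequence, Tuple

Vec = Sequence[float]
Mat = List[List[float]]


def gram(G: Mat) -> Mat:
    """M = G^T G (always symmetric positive semi-definite)."""
    d = len(G[0])
    return [[sum(row[a] * row[b] for row in G) for b in range(d)] for a in range(d)]


def mahalanobis(x: Vec, z: Vec, M: Mat) -> float:
    diff = [a - b for a, b in zip(x, z)]
    return sum(d * sum(m * e for m, e in zip(row, diff)) for d, row in zip(diff, M))


def triplet_bound(x_pos: Vec, x_neg: Vec, x: Vec, M: Mat) -> float:
    num = mahalanobis(x, x_neg, M) - mahalanobis(x, x_pos, M)
    if num <= 0.0:
        return 0.0
    Mv = [sum(m * (p - n) for m, p, n in zip(row, x_pos, x_neg)) for row in M]
    return num / (2.0 * math.sqrt(sum(t * t for t in Mv)))


def sample_pair(i: int, X: List[Vec], labels: List[int], M: Mat,
                n_near: int, rng: random.Random) -> Tuple[int, int]:
    def dist(t: int) -> float:
        return mahalanobis(X[i], X[t], M)
    same = sorted((t for t in range(len(X)) if t != i and labels[t] == labels[i]), key=dist)
    other = sorted((t for t in range(len(X)) if labels[t] != labels[i]), key=dist)
    return rng.choice(same[:n_near]), rng.choice(other[:n_near])


def batch_loss(G: Mat, X: List[Vec], triplets: List[Tuple[int, int, int]]) -> float:
    """Negative loss l(e) = -e, averaged over fixed (anchor, pos, neg) triplets."""
    M = gram(G)
    return -sum(triplet_bound(X[p], X[n], X[a], M) for a, p, n in triplets) / len(triplets)


def train_arml(X: List[Vec], labels: List[int], epochs: int = 30, lr: float = 0.05,
               n_near: int = 10, h: float = 1e-6, seed: int = 0) -> Mat:
    rng = random.Random(seed)
    d = len(X[0])
    G = [[float(r == c) for c in range(d)] for r in range(d)]  # G = M = I
    M = gram(G)
    for _ in range(epochs):
        # Sample under the current M; the samples are constants for this step.
        triplets = [(i, *sample_pair(i, X, labels, M, n_near, rng)) for i in range(len(X))]
        base = batch_loss(G, X, triplets)
        grad = [[0.0] * d for _ in range(d)]
        for r in range(d):
            for c in range(d):
                G[r][c] += h
                grad[r][c] = (batch_loss(G, X, triplets) - base) / h
                G[r][c] -= h
        for r in range(d):
            for c in range(d):
                G[r][c] -= lr * grad[r][c]
        M = gram(G)  # enforce M = G^T G
    return M
```

A finite-difference gradient keeps the sketch dependency-free but costs $D^2$ extra loss evaluations per step; a real implementation would rely on automatic differentiation, since $\tilde{\epsilon}$ is an explicit function of $G$.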

3.4 Exact minimal adversarial perturbation of Mahalanobis 1-NN

In the special Mahalanobis 1-NN case, we show a method to compute the exact minimal adversarial perturbation using a formulation similar to (6). However, this algorithm only computes a numerical value of the minimal adversarial perturbation $\epsilon^*$, so it cannot be used at training time. We use this method to evaluate the robust error for the Mahalanobis 1-NN case in the experiments.

Computing the minimal adversarial perturbation for the Mahalanobis 1-NN classifier can be formulated as the following optimization problem:

$$\epsilon^*(x; S, M) = \min_{j:\, y_j \ne y}\ \min_{\delta}\ \|\delta\|_2 \quad\text{s.t.}\quad d_M(x+\delta, x_j) \le d_M(x+\delta, x_i),\ \ \forall i: y_i = y. \tag{13}$$

This is equivalent to considering each $x_j$ in a different class from $y$ and computing the minimum perturbation needed to make $x + \delta$ closer to $x_j$ than to all the training instances in the same class as $x$, i.e., all $x_i$ with $y_i = y$. It is noteworthy that the constraint of (13) could be equivalently written as

$$2\,(x_i - x_j)^\top M\,\delta\ \le\ d_M(x, x_i) - d_M(x, x_j), \quad \forall i: y_i = y, \tag{14}$$

all of which are affine in $\delta$. Therefore, the inner minimization is a convex quadratic programming problem and could be solved in polynomial time [kozlov1980polynomial]. As a result, it leads to a naive polynomial-time algorithm for finding the minimal adversarial perturbation of Mahalanobis 1-NN: solve all the inner quadratic programming problems and then select the minimum of them.
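The rewrite (14) can be checked numerically: for any $\delta$, $a_i^\top\delta - c_i$ equals $d_M(x+\delta, x_j) - d_M(x+\delta, x_i)$, so the affine and the distance forms of each constraint agree. The sketch below (hypothetical function name, pure Python, $M$ assumed symmetric) builds the rows $a_i = 2M(x_i - x_j)$ and right-hand sides $c_i = d_M(x, x_i) - d_M(x, x_j)$.

```python
from typing import List, Sequence, Tuple

Vec = Sequence[float]
Mat = Sequence[Sequence[float]]


def mahalanobis(x: Vec, z: Vec, M: Mat) -> float:
    diff = [a - b for a, b in zip(x, z)]
    return sum(d * sum(m * e for m, e in zip(row, diff)) for d, row in zip(diff, M))


def affine_constraints(x: Vec, y: int, j: int, X: List[Vec], labels: List[int],
                       M: Mat) -> List[Tuple[List[float], float]]:
    """Rows (a_i, c_i) of Eq. (14): a_i^T delta <= c_i for every same-class x_i."""
    rows = []
    for xi, yi in zip(X, labels):
        if yi != y:
            continue
        # a_i = 2 M (x_i - x_j); with symmetric M, a_i^T delta = 2 (x_i - x_j)^T M delta.
        a = [2.0 * sum(m * (p - q) for m, p, q in zip(row, xi, X[j])) for row in M]
        c = mahalanobis(x, xi, M) - mahalanobis(x, X[j], M)
        rows.append((a, c))
    return rows
```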

Instead, we propose a much more efficient method to solve (13). The main idea is to first compute a lower bound for each inner minimization problem; with these lower bounds, most of the inner minimization problems can be screened out safely without solving them exactly. This method extends [wang2019evaluating], which only considers the Euclidean distance. See Algorithm 2 in Appendix C for details; this algorithm is used for computing certified robust errors of Mahalanobis 1-NN in the experimental section.
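The screening logic can be isolated from the QP solver. In the sketch below, `lower_bound(j)` must never exceed the exact value returned by `solve_exact(j)`; both are hypothetical callables supplied by the caller. Under that assumption, skipping a subproblem whose lower bound is not below the incumbent is safe, and the processing order affects only how many exact solves are needed, not the result.

```python
import math
from typing import Callable, Hashable, Iterable, Tuple


def screened_min(candidates: Iterable[Hashable],
                 order_key: Callable[[Hashable], float],
                 lower_bound: Callable[[Hashable], float],
                 solve_exact: Callable[[Hashable], float]) -> Tuple[float, int]:
    """Return (minimum exact value, number of exact solves) with lower-bound screening."""
    best = math.inf
    n_solved = 0
    for j in sorted(candidates, key=order_key):
        if lower_bound(j) >= best:
            continue  # exact value >= lower bound >= best, so j cannot improve
        n_solved += 1
        best = min(best, solve_exact(j))
    return best, n_solved
```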

4 Experiments

We compare the proposed ARML (Adversarial Robust Metric Learning) method with the following baselines:

  • Euclidean: uses the Euclidean distance directly without learning any metric;

  • Neighbourhood components analysis (NCA) [goldberger2004neighbourhood]: maximizes a stochastic variant of the leave-one-out nearest neighbors score on the training set;

  • Large margin nearest neighbor (LMNN) [weinberger2009distance]: keeps nearest neighbors from the same class close, while keeping instances from different classes separated by a large margin;

  • Information Theoretic Metric Learning (ITML) [davis2007information]: minimizes the log-determinant divergence subject to similarity and dissimilarity constraints;

  • Local Fisher Discriminant Analysis (LFDA) [sugiyama2007dimensionality]: a modified version of linear discriminant analysis that rewrites the scatter matrices in a pairwise manner.

For evaluation, we use six public datasets on which metric learning methods perform favorably in terms of clean errors: four small or medium-sized datasets [chang2011libsvm], Splice, Pendigits, Satimage and USPS, and two image datasets, MNIST [lecun1998gradient] and Fashion-MNIST [xiao2017fashion], which are widely used for robustness verification of neural networks. For the proposed method, we use the same hyperparameters for all the datasets (see Appendix D for the dataset statistics, more details of the experimental setting, and the hyperparameter sensitivity analysis).

MNIST $\ell_2$-radius 0.000 0.500 1.000 1.500 2.000 2.500
Euclidean 0.033 0.112 0.274 0.521 0.788 0.945
NCA 0.025 0.140 0.452 0.839 0.977 1.000
LMNN 0.032 0.641 0.999 1.000 1.000 1.000
ITML 0.073 0.571 0.928 1.000 1.000 1.000
LFDA 0.152 1.000 1.000 1.000 1.000 1.000
ARML (Ours) 0.024 0.089 0.222 0.455 0.757 0.924
Fashion-MNIST $\ell_2$-radius 0.000 0.500 1.000 1.500 2.000 2.500
Euclidean 0.145 0.381 0.606 0.790 0.879 0.943
NCA 0.116 0.538 0.834 0.950 0.998 1.000
LMNN 0.142 0.756 0.991 1.000 1.000 1.000
ITML 0.163 0.672 0.929 0.998 1.000 1.000
LFDA 0.211 1.000 1.000 1.000 1.000 1.000
ARML (Ours) 0.127 0.348 0.568 0.763 0.859 0.928
Splice $\ell_2$-radius 0.000 0.100 0.200 0.300 0.400 0.500
Euclidean 0.320 0.513 0.677 0.800 0.854 0.880
NCA 0.130 0.252 0.404 0.584 0.733 0.836
LMNN 0.190 0.345 0.533 0.697 0.814 0.874
ITML 0.306 0.488 0.679 0.809 0.862 0.882
LFDA 0.264 0.434 0.605 0.760 0.845 0.872
ARML (Ours) 0.130 0.233 0.370 0.526 0.652 0.758
Pendigits $\ell_2$-radius 0.000 0.100 0.200 0.300 0.400 0.500
Euclidean 0.032 0.119 0.347 0.606 0.829 0.969
NCA 0.034 0.202 0.586 0.911 0.997 1.000
LMNN 0.029 0.183 0.570 0.912 0.995 0.999
ITML 0.049 0.308 0.794 0.991 1.000 1.000
LFDA 0.042 0.236 0.603 0.912 0.998 1.000
ARML (Ours) 0.028 0.115 0.344 0.598 0.823 0.967
Satimage $\ell_2$-radius 0.000 0.150 0.300 0.450 0.600 0.750
Euclidean 0.108 0.642 0.864 0.905 0.928 0.951
NCA 0.103 0.710 0.885 0.915 0.940 0.963
LMNN 0.092 0.665 0.871 0.912 0.944 0.969
ITML 0.127 0.807 0.979 1.000 1.000 1.000
LFDA 0.125 0.836 0.919 0.956 0.992 1.000
ARML (Ours) 0.095 0.605 0.839 0.899 0.920 0.946
USPS $\ell_2$-radius 0.000 0.500 1.000 1.500 2.000 2.500
Euclidean 0.045 0.224 0.585 0.864 0.970 0.999
NCA 0.056 0.384 0.888 0.987 1.000 1.000
LMNN 0.046 0.825 1.000 1.000 1.000 1.000
ITML 0.060 0.720 0.999 1.000 1.000 1.000
LFDA 0.098 1.000 1.000 1.000 1.000 1.000
ARML (Ours) 0.043 0.204 0.565 0.857 0.970 0.999
Table 1: Certified robust errors of Mahalanobis 1-NN. The best (minimum) certified robust errors among all methods are in bold. Note that the certified robust errors of 1-NN are also the optimal attack errors.
Certified robust errors Empirical robust errors
MNIST $\ell_2$-radius 0.000 0.500 1.000 1.500 2.000 2.500 0.000 0.500 1.000 1.500 2.000 2.500
Euclidean 0.038 0.134 0.360 0.618 0.814 0.975 0.031 0.063 0.104 0.155 0.204 0.262
NCA 0.030 0.175 0.528 0.870 0.986 1.000 0.027 0.063 0.120 0.216 0.330 0.535
LMNN 0.040 0.669 1.000 1.000 1.000 1.000 0.036 0.121 0.336 0.775 0.972 1.000
ITML 0.106 0.731 0.943 1.000 1.000 1.000 0.084 0.218 0.355 0.510 0.669 0.844
LFDA 0.237 1.000 1.000 1.000 1.000 1.000 0.215 1.000 1.000 1.000 1.000 1.000
ARML (Ours) 0.034 0.101 0.276 0.537 0.760 0.951 0.032 0.055 0.077 0.109 0.160 0.213
Fashion-MNIST $\ell_2$-radius 0.000 0.500 1.000 1.500 2.000 2.500 0.000 0.500 1.000 1.500 2.000 2.500
Euclidean 0.160 0.420 0.650 0.800 0.895 0.946 0.143 0.227 0.298 0.360 0.420 0.489
NCA 0.144 0.557 0.832 0.946 1.000 1.000 0.121 0.232 0.343 0.483 0.624 0.780
LMNN 0.158 0.792 0.991 1.000 1.000 1.000 0.140 0.364 0.572 0.846 0.983 0.999
ITML 0.236 0.784 0.949 1.000 1.000 1.000 0.209 0.460 0.692 0.892 0.978 1.000
LFDA 0.291 1.000 1.000 1.000 1.000 1.000 0.263 0.870 0.951 0.975 0.988 0.995
ARML (Ours) 0.152 0.371 0.589 0.755 0.856 0.924 0.134 0.202 0.274 0.344 0.403 0.487
Splice $\ell_2$-radius 0.000 0.100 0.200 0.300 0.400 0.500 0.000 0.100 0.200 0.300 0.400 0.500
Euclidean 0.333 0.558 0.826 0.965 0.988 0.996 0.306 0.431 0.526 0.608 0.676 0.743
NCA 0.103 0.209 0.415 0.659 0.824 0.921 0.103 0.173 0.274 0.414 0.570 0.684
LMNN 0.149 0.332 0.630 0.851 0.969 0.994 0.149 0.241 0.357 0.492 0.621 0.722
ITML 0.279 0.571 0.843 0.974 0.995 0.997 0.279 0.423 0.525 0.603 0.675 0.751
LFDA 0.242 0.471 0.705 0.906 0.987 0.997 0.242 0.371 0.466 0.553 0.637 0.737
ARML (Ours) 0.128 0.221 0.345 0.509 0.666 0.819 0.128 0.196 0.273 0.380 0.497 0.639
Pendigits $\ell_2$-radius 0.000 0.100 0.200 0.300 0.400 0.500 0.000 0.100 0.200 0.300 0.400 0.500
Euclidean 0.039 0.126 0.316 0.577 0.784 0.937 0.036 0.085 0.155 0.248 0.371 0.528
NCA 0.038 0.196 0.607 0.884 0.997 1.000 0.038 0.103 0.246 0.428 0.637 0.804
LMNN 0.034 0.180 0.568 0.898 0.993 0.999 0.030 0.096 0.246 0.462 0.681 0.862
ITML 0.060 0.334 0.773 0.987 1.000 1.000 0.060 0.149 0.343 0.616 0.814 0.926
LFDA 0.047 0.228 0.595 0.904 1.000 1.000 0.043 0.104 0.248 0.490 0.705 0.842
ARML (Ours) 0.035 0.114 0.308 0.568 0.780 0.937 0.034 0.078 0.138 0.235 0.368 0.516
Satimage $\ell_2$-radius 0.000 0.150 0.300 0.450 0.600 0.750 0.000 0.150 0.300 0.450 0.600 0.750
Euclidean 0.101 0.579 0.842 0.899 0.927 0.948 0.091 0.237 0.482 0.682 0.816 0.897
NCA 0.117 0.670 0.886 0.915 0.936 0.961 0.101 0.297 0.564 0.746 0.876 0.931
LMNN 0.105 0.613 0.855 0.914 0.944 0.961 0.090 0.269 0.548 0.737 0.855 0.910
ITML 0.130 0.768 0.959 1.000 1.000 1.000 0.109 0.411 0.757 0.939 0.990 1.000
LFDA 0.128 0.779 0.904 0.958 0.995 1.000 0.112 0.389 0.673 0.860 0.950 0.986
ARML (Ours) 0.103 0.540 0.824 0.898 0.920 0.943 0.092 0.228 0.464 0.668 0.817 0.896
USPS $\ell_2$-radius 0.000 0.500 1.000 1.500 2.000 2.500 0.000 0.500 1.000 1.500 2.000 2.500
Euclidean 0.063 0.239 0.586 0.888 0.977 1.000 0.058 0.125 0.211 0.365 0.612 0.751
NCA 0.072 0.367 0.903 0.986 1.000 1.000 0.063 0.158 0.365 0.686 0.899 0.980
LMNN 0.062 0.856 1.000 1.000 1.000 1.000 0.055 0.359 0.890 0.999 1.000 1.000
ITML 0.082 0.696 0.999 1.000 1.000 1.000 0.072 0.273 0.708 0.987 1.000 1.000
LFDA 0.134 1.000 1.000 1.000 1.000 1.000 0.118 0.996 1.000 1.000 1.000 1.000
ARML (Ours) 0.057 0.203 0.527 0.867 0.971 0.997 0.053 0.118 0.209 0.344 0.572 0.785
Table 2: Certified robust errors (left) and empirical robust errors (right) of Mahalanobis $K$-NN. The best (minimum) robust errors among all methods are in bold.

Mahalanobis 1-NN

Certified robust errors of Mahalanobis 1-NN with respect to different perturbation radii are shown in Table 1. Note that for 1-NN, Algorithm 2, which solves (13), computes the exact minimal adversarial perturbation for each instance, so the values in Table 1 are both certified robust errors and empirical robust errors (attack errors). Also note that when the radius is 0, the certified robust error equals the clean error on the unperturbed test set.

We have three main observations from the experimental results. First, although NCA and LMNN achieve better clean errors (at radius 0) than Euclidean on most datasets, they are less robust to adversarial perturbations than Euclidean (except on the Splice dataset, where Euclidean performs poorly even in terms of clean errors). Both NCA and LMNN suffer from a trade-off between the clean error and the certified robust error. Second, ARML performs competitively with NCA and LMNN in terms of clean errors (achieving the best on 4/6 of the datasets). Third, and most importantly, ARML is more robust than all the other methods in terms of certified robust errors for nearly all perturbation radii.

Mahalanobis $K$-NN

For $K$-NN models, computing the exact minimal adversarial perturbation is intractable, so we report both certified robust errors and empirical robust errors (attack errors). We set $K = 11$ for all the experiments. The certified robust error is computed by Theorem 1, which works for any Mahalanobis metric. On the other hand, we also attack these models with a hard-label black-box attack method, the Boundary Attack [brendel2018decision], to obtain the empirical robust error, which is a lower bound of the certified robust error. Unlike the 1-NN case, neither the attack nor the robustness verification is optimal, so there is a gap between the two numbers. These results are shown in Table 2.

The three observations for Mahalanobis 1-NN also hold for $K$-NN: NCA and LMNN have improved clean errors (empirical robust errors at radius 0), but this often comes with degraded robust errors compared with the Euclidean distance, while ARML achieves good robust errors as well as clean errors. The results suggest that ARML is more robust both provably (in terms of the certified robust error) and empirically (in terms of the empirical robust error).

5 Related work

Metric Learning

Metric learning aims to learn a new distance using supervision concerning the learned distance [kulis2013metric]. In this paper, we mainly focus on linear metric learning: the learned distance is the squared Euclidean distance after applying a linear transformation globally, i.e., the Mahalanobis distance [goldberger2004neighbourhood, davis2007information, weinberger2009distance, jain2010inductive, sugiyama2007dimensionality]. There are also nonlinear models for metric learning, such as kernelized metric learning [kulis2006learning, chatpatanasiri2010a], local metric learning [frome2007learning, weinberger2008fast] and deep metric learning [chopra2005learning, schroff2015facenet]. Robustness verification for nonlinear metric learning and learning a provably robust nonlinear metric are interesting directions for future work.

Adversarial robustness of neural networks

Empirical defense aims to learn a classifier which is robust to some adversarial attacks [kurakin2017adversarial, madry2018towards], but has no guarantee for the robustness to other stronger (or unknown) adversarial attacks [carlini2017towards, athalye2018obfuscated]. In contrast, certified defense provides a guarantee that no adversarial examples exist within a certain input region [wong2018provable, cohen2019certified, zhang2020towards]. The basic idea of these certified defense methods is to minimize the certified robust training error. However, all these methods for neural networks rely on the assumption of smoothness of the classifier, and hence could not be applied to the nearest neighbor classifiers.

Adversarial robustness of nearest neighbor classifiers

Most works on the adversarial robustness of $K$-NN focus on adversarial attacks. Some papers propose to attack a differentiable substitute of $K$-NN [papernot2016transferability, sitawarin2020minimum], and others formalize the attack as a list of quadratic programming problems or linear programming problems [wang2019evaluating, yang2020robust]. As far as we know, only one paper considers adversarial verification for $K$-NN, but it only considers the Euclidean distance and proposes no certified defense method [wang2019evaluating]. In contrast, we propose the first adversarial verification method and the first certified defense (or provably robust learning) for Mahalanobis $K$-NN.

6 Conclusion

We propose a novel metric learning method named ARML to obtain a Mahalanobis distance that is robust to adversarial input perturbations. Experiments show that the proposed method can improve both clean errors and robust errors compared with existing metric learning algorithms.


Appendix A Optimal value of triplet problem

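The triplet problem is formalized as below:

$$\min_{\delta}\ \|\delta\|_2 \quad\text{s.t.}\quad d_M(x+\delta, x^-) \le d_M(x+\delta, x^+). \tag{15}$$

It is equivalent to the optimization

$$\min_{\delta}\ \tfrac{1}{2}\delta^\top\delta \quad\text{s.t.}\quad a^\top\delta + b \le 0, \tag{16}$$

where we have

$$a = 2M(x^+ - x^-), \qquad b = d_M(x, x^-) - d_M(x, x^+).$$

The dual function is

$$g(\lambda) = \min_{\delta}\ \tfrac{1}{2}\delta^\top\delta + \lambda\,(a^\top\delta + b) = -\tfrac{1}{2}\lambda^2 a^\top a + \lambda b,$$

where $\delta = -\lambda a$ holds for the minimizer. Then the dual problem is

$$\max_{\lambda \ge 0}\ -\tfrac{1}{2}\lambda^2 a^\top a + \lambda b.$$

The optimal point is

$$\lambda^* = \frac{[b]_+}{a^\top a},$$

and the optimal value is

$$g(\lambda^*) = \frac{[b]_+^2}{2\,a^\top a}.$$

By Slater's condition, if $a \ne 0$ holds, strong duality holds, so $g(\lambda^*)$ is also the optimal value of (16). Therefore, the optimal value of (15) is

$$\sqrt{2\,g(\lambda^*)} = \frac{[b]_+}{\|a\|_2} = \frac{\big[d_M(x, x^-) - d_M(x, x^+)\big]_+}{2\,\|M(x^+ - x^-)\|_2}.$$

In fact, the case $a = 0$ is also covered: since $b = 2(x - x^+)^\top M(x^+ - x^-) + (x^+ - x^-)^\top M(x^+ - x^-)$, $a = 0$ implies $b = 0$, so $\delta = 0$ is optimal and the optimal value is $0$, which agrees with the formula above under the convention $0/0 = 0$.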

Appendix B Proof of Theorem 1


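Let $\mathcal{I}^+ = \{i : y_i = y\}$ and $\mathcal{I}^- = \{j : y_j \ne y\}$, and let $\epsilon(\mathcal{J}, \mathcal{I})$ denote the optimal value of the inner minimization problem of (6). By relaxing the constraint via replacing the universal quantifier, i.e., keeping only the single constraint for one pair $(j, i)$, we have

$$\epsilon(\mathcal{J}, \mathcal{I})\ \ge\ \max_{j\in\mathcal{J}}\ \max_{i\in\mathcal{I}^+\setminus\mathcal{I}}\ \tilde{\epsilon}(x_i, x_j, x; M)\ \ge\ \max_{j\in\mathcal{J}}\ \operatorname*{kthmax}_{i\in\mathcal{I}^+}\ \tilde{\epsilon}(x_i, x_j, x; M),$$

where the second inequality holds because $\mathcal{I}$ removes only $k-1$ elements from $\mathcal{I}^+$. Substitute it in (6) and then we have

$$\epsilon^*(x; S, M)\ \ge\ \min_{\mathcal{J}\subseteq\mathcal{I}^-,\ |\mathcal{J}| = k}\ \max_{j\in\mathcal{J}}\ \operatorname*{kthmax}_{i\in\mathcal{I}^+}\ \tilde{\epsilon}(x_i, x_j, x; M)\ =\ \operatorname*{kthmin}_{j\in\mathcal{I}^-}\ \operatorname*{kthmax}_{i\in\mathcal{I}^+}\ \tilde{\epsilon}(x_i, x_j, x; M). \qquad\blacksquare$$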

Appendix C Details of computing exact minimal adversarial perturbation of Mahalanobis 1-NN

The overall algorithm is displayed in Algorithm 2. We denote by $\epsilon_j$ the optimal value of the inner minimization problem of (13) with respect to $x_j$, and by $\underline{\epsilon}_j$ its lower bound. We first sort the subproblems in ascending order of $d_M(x, x_j)$ for $j$ with $y_j \ne y$. For every subproblem, we compute the lower bound of its optimal value. If this lower bound is not smaller than the best value found so far, the subproblem is screened out safely without being solved exactly.

Input: Test instance $(x, y)$, dataset $S$.
Output: Perturbation norm $\epsilon$.
1 Initialize $\epsilon = +\infty$;
2 Sort $\{j : y_j \ne y\}$ in ascending order of $d_M(x, x_j)$;
3 for $j$ according to the ascending order do
4       Compute a lower bound $\underline{\epsilon}_j$ of the inner minimization corresponding to $x_j$;
5       if $\underline{\epsilon}_j < \epsilon$ then
6             Solve the inner minimization problem exactly via the greedy coordinate ascent method and derive the optimal value $\epsilon_j$;
7             if $\epsilon_j < \epsilon$ then
8                   $\epsilon = \epsilon_j$;
9             end if
10       end if
11 end for
Algorithm 2 Computing the minimal adversarial perturbation for Mahalanobis 1-NN

c.1 Greedy coordinate ascent (descent)

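For each subproblem that must be solved exactly, we employ the greedy coordinate ascent method. Note that the inner minimization problem of (13) is a convex quadratic programming problem: stacking the affine constraints of (13) as $A\delta \le c$, it is equivalent to $\min_\delta \frac{1}{2}\|\delta\|_2^2$ s.t. $A\delta \le c$, whose dual is $\min_{\lambda \ge 0} \frac{1}{2}\lambda^\top AA^\top\lambda + c^\top\lambda$, with the primal solution recovered as $\delta = -A^\top\lambda$. We solve the problem via this dual formulation. The greedy coordinate ascent method is used because the optimal dual variables are very sparse. The algorithm is shown in Algorithm 3. At every iteration, only one dual variable is updated.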

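Input: $Q$, $c$, number of iterations $T$, tolerance $\tau$.
1 $\lambda = \mathbf{0}$, $g = c$;
2 for $t = 1$ to $T$ do
3       $\tilde{g}_r = g_r$ if $\lambda_r > 0$, else $\tilde{g}_r = \min(g_r, 0)$, for all $r$; $i = \arg\max_r |\tilde{g}_r|$;
       // choose a coordinate
4       if $|\tilde{g}_i| < \tau$ then
5             break;
6       end if
7       $\lambda_i' = \max(0, \lambda_i - g_i / Q_{ii})$, $\Delta = \lambda_i' - \lambda_i$, $\lambda_i = \lambda_i'$;
       // update the solution
8       $g = g + \Delta\, Q_{:,i}$;
       // update the gradient
9 end for
Output: $\lambda$.
Algorithm 3 Greedy coordinate descent for QP: $\min_{\lambda \ge 0}\ \frac{1}{2}\lambda^\top Q\lambda + c^\top\lambda$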
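The sketch below is one way to implement Algorithm 3 in pure Python for the inner problem of (13), written as $\min_\delta \frac{1}{2}\|\delta\|_2^2$ subject to $A\delta \le c$; the function name, the tolerance, and the iteration cap are hypothetical choices. It solves the dual with $Q = AA^\top$, updates one coordinate per iteration by exact minimization projected onto $\lambda_i \ge 0$, picks the coordinate with the largest projected-gradient violation, and recovers $\delta = -A^\top\lambda$.

```python
import math
from typing import List, Sequence, Tuple


def min_norm_qp(A: Sequence[Sequence[float]], c: Sequence[float],
                max_iter: int = 100_000, tol: float = 1e-12) -> Tuple[List[float], float]:
    """Solve min 0.5*||delta||^2 s.t. A delta <= c by greedy coordinate descent on the dual.

    Dual: min_{lam >= 0} 0.5 * lam^T Q lam + c^T lam with Q = A A^T;
    the primal solution is recovered as delta = -A^T lam. Returns (delta, ||delta||_2).
    """
    m = len(A)
    if m == 0:
        raise ValueError("need at least one constraint")
    Q = [[sum(p * q for p, q in zip(A[r], A[s])) for s in range(m)] for r in range(m)]
    lam = [0.0] * m
    g = [float(v) for v in c]  # dual gradient Q lam + c, evaluated at lam = 0
    for _ in range(max_iter):
        # choose a coordinate: largest violation of the projected-gradient conditions
        pg = [g[r] if lam[r] > 0.0 else min(g[r], 0.0) for r in range(m)]
        i = max(range(m), key=lambda r: abs(pg[r]))
        if abs(pg[i]) < tol:
            break
        if Q[i][i] == 0.0:
            # zero row a_i with c_i < 0: the constraint 0 <= c_i can never hold
            raise ValueError("infeasible constraint")
        # update the solution: exact minimization along coordinate i, kept >= 0
        new = max(0.0, lam[i] - g[i] / Q[i][i])
        step, lam[i] = new - lam[i], new
        # update the gradient
        for r in range(m):
            g[r] += step * Q[r][i]
    delta = [-sum(lam[r] * A[r][d] for r in range(m)) for d in range(len(A[0]))]
    return delta, math.sqrt(sum(v * v for v in delta))
```

With a single row, the result coincides with the triplet closed form (8); for example, $a = (-4, 0)$ and $c = -2$, which encode $x^+ = (0,0)$, $x^- = (2,0)$, $x = (0.5, 1)$ and $M = I$, give $\|\delta\|_2 = 0.5$.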

c.2 Lower bound of inner minimization problem

The following theorem is dependent on the solution of the triplet problem.

Theorem 2.

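The optimal value $\epsilon_j$ of the inner minimization of (13) with respect to $x_j$ is lower bounded as

$$\epsilon_j\ \ge\ \max_{i:\, y_i = y}\ \tilde{\epsilon}(x_i, x_j, x; M).$$

Relaxing the constraint of (13) by replacing the universal quantifier, we know that $\epsilon_j$ is lower bounded by the optimal value of the following optimization problem

$$\max_{i:\, y_i = y}\ \min_{\delta}\ \|\delta\|_2 \quad\text{s.t.}\quad d_M(x+\delta, x_j) \le d_M(x+\delta, x_i).$$

Clearly, the optimal value of the inner problem is $\tilde{\epsilon}(x_i, x_j, x; M)$. ∎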

In this way, we could derive a lower bound of the optimal value in closed form.

Appendix D Experimental details


Dataset statistics and test errors (on all test instances) of Euclidean $K$-NN ($K = 1$ and $K = 11$) are shown in Table 3. All training data are used to learn metrics, and 1,000 instances are randomly sampled to compute certified robust errors.

# features # classes # train # test 1-NN test error 11-NN test error
MNIST 784 10 60,000 10,000 0.031 0.033
Fashion-MNIST 784 10 60,000 10,000 0.150 0.150
Splice 60 2 1,000 2,175 0.295 0.291
Pendigits 16 10 7,494 3,498 0.023 0.027
Satimage 36 6 4,435 2,000 0.112 0.106
USPS 256 10 7,291 2,007 0.049 0.060
Table 3: Dataset statistics


Hyperparameters of our ARML algorithm are fixed across all datasets. Specifically, the size of the neighborhood from which positive and negative instances are randomly sampled is 10. In other words, at every iteration, we sample one instance from the 10 nearest instances in the same class as the anchor instance, and one instance from the 10 nearest instances in classes different from that of the anchor instance. We employ the Adam algorithm [kingma2015adam] to update parameters with gradients, using its default settings (learning rate: 0.001, betas: (0.9, 0.999)). The number of epochs is 1,000. The loss function is the negative loss.

Hyperparameter sensitivity

We investigate the sensitivity to the size of the neighborhood used for sampling positive and negative instances. We plot the robust error curves against the radius for different neighborhood sizes in Figure 2 and Figure 3. The results suggest that ARML is not very sensitive to this hyperparameter in terms of certified and empirical robust errors, as long as the neighborhood is not too small.

Figure 2: Sensitivity to neighborhood size on Splice: (a) 1-NN certified robust error, (b) $K$-NN certified robust error, (c) $K$-NN empirical robust error.
Figure 3: Sensitivity to neighborhood size on Satimage: (a) 1-NN certified robust error, (b) $K$-NN certified robust error, (c) $K$-NN empirical robust error.

Implementations of NCA, LMNN, ITML and LFDA

We use the implementations of the metric-learn library [vazelhes2019metric] for NCA, LMNN, ITML and LFDA. Similar to ARML, hyperparameters are fixed across all datasets and are in the default setting. In particular, the maximum numbers of iterations for NCA, LMNN and ITML are 1,000, 100 and 1,000 respectively.