Traditional machine learning algorithms that optimize the empirical risk often generalize poorly under distributional shifts caused by latent heterogeneity or selection biases, both of which are widespread in real-world data Daumé III and Marcu (2006); Torralba and Efros (2011). Guaranteeing that a machine learning algorithm generalizes well to data drawn out-of-distribution is of paramount importance, especially in high-stakes applications such as financial analysis, criminal justice and medical diagnosis Kukar (2003); Rudin and Ustun (2018). This is known as the out-of-distribution (OOD) generalization problem Arjovsky et al. (2019).
To ensure OOD generalization ability, invariant learning methods assume the existence of causally invariant correlations and exploit them through given environments, which makes their performance heavily dependent on the quality of those environments. Further, the requirement for environment labels is too strict to meet in practice, since real-world datasets are frequently assembled by merging data from multiple sources without explicit source labels. Recently, several works Creager et al. (2020); Liu et al. (2021) have been proposed to relax such restrictions. Creager et al. (2020) first infer the environments directly according to a given biased model and then perform invariant learning. But the two stages cannot be jointly optimized, and the quality of the inferred environments depends heavily on the pre-provided biased model. Further, for complicated data, using invariant representations for environment inference is harmful, since the environment-specific features are gradually discarded, which extinguishes the latent heterogeneity and renders data from different latent environments indistinguishable. Liu et al. (2021) design a mechanism in which two interactive modules, for environment inference and invariant learning respectively, promote each other. However, it can only deal with scenarios where invariant and variant features are decomposed at the raw feature level, and it breaks down when the decomposition can only be performed in representation space (e.g., image data).
This paper focuses on integrating latent heterogeneity exploration and invariant learning at the representation level. In order to incorporate representation learning with theoretical guarantees, we introduce the Neural Tangent Kernel (NTK, Jacot et al. (2018)) into our algorithm. According to NTK theory Jacot et al. (2018), training a wide neural network is equivalent to linear regression using Neural Tangent Features (NTF), which converts non-linear neural networks into linear regression in NTF space and makes the integration possible. Based on this, we propose our Kernelized Heterogeneous Risk Minimization (KerHRM) algorithm, which jointly optimizes the latent heterogeneity exploration module and the invariant learning module in NTF space. Specifically, for the invariant learning module we propose our novel Invariant Gradient Descent (IGD), which performs invariant learning in NTF space and then feeds the result back to the neural network as an appointed invariant gradient direction. For the heterogeneity exploration module, we construct an orthogonal heterogeneity-aware kernel to capture the environment-specific features and to further accelerate the heterogeneity exploration. Theoretically, we analyze our heterogeneity exploration algorithm with rate-distortion theory and justify the orthogonality property of the built kernel, which together illustrate the mutual promotion between the two modules. Empirically, experiments on both synthetic and real-world data validate the superiority of KerHRM in terms of out-of-distribution generalization performance.
Following Arjovsky et al. (2019); Chang et al. (2020a), we consider a dataset that mixes data collected from multiple training environments. Here, environment labels are unavailable, as in most real applications.
is a random variable on indices of training environments and is the distribution of data and label in environment . The goal of this work is to find a predictor with good out-of-distribution generalization performance, which is formalized as:
where represents the risk of predictor on environment , and the loss function. Note that is the random variable on indices of all possible environments such that . Usually, for all , the data and label distribution can be quite different from that of the training environments . Therefore, the problem in equation 1 is referred to as the Out-of-Distribution (OOD) Generalization problem Arjovsky et al. (2019). Since it is impossible to characterize the latent environments without any prior knowledge or structural assumptions, the invariance assumption is proposed for invariant learning:
There exists random variable such that the following properties hold:
a. : for all , we have holds.
b. : .
This assumption indicates invariance and sufficiency for predicting the target using the random variable in Assumption 2.1, which is known as an invariant representation, whose relationship with the target is stable across environments. To acquire such a representation, a line of work Chang et al. (2020b); Koyama and Yamaguchi (2020); Liu et al. (2021) proposes to find the maximal invariant predictor (MIP) of an invariance set, both of which are defined as follows:
The invariance set with respect to is defined as:
where is the Shannon entropy of a random variable. The corresponding maximal invariant predictor (MIP) of is defined as , where measures Shannon mutual information between two random variables.
However, recent works Chang et al. (2020b); Koyama and Yamaguchi (2020) on finding MIP solutions rely on the availability of data from multiple training environments, a requirement that is hard to meet in practice. Further, their validity is largely determined by the given environments. Since the invariance set regularized by the training environments is often too large, the learned MIP may contain variant components and fail to generalize well. Based on this, Heterogeneous Risk Minimization (HRM, Liu et al. (2021)) proposes to generate environments with a minimal invariance set and to conduct invariant prediction with the learned environments. However, HRM can only deal with simple scenarios where the invariant and variant features are separated at the raw feature level, and it breaks down when the data are an unknown transformation of these features, since the decomposition can then only be performed in representation space. In this work, we focus on integrating latent heterogeneity exploration and invariant learning in general scenarios where the invariant features are latent in the data, a setting that arises readily in real applications.
(Problem Setting) Assume that the data are generated by an unknown transformation function from an invariant component satisfying Assumption 2.1 and a variant component (following the functional representation lemma El Gamal and Kim (2011)). Given a heterogeneous dataset without environment labels, the task is to generate environments with a minimal invariance set and, meanwhile, to learn invariant models.
Following the analysis in section 2, generating environments with a minimal invariance set is equivalent to generating environments in which the relationship between the variant components and the target varies as much as possible, so as to exclude variant parts from the invariance set.
Despite this insight, the latent invariant and variant components make it impossible to generate such environments directly. In this work, we propose our Kernelized Heterogeneous Risk Minimization (KerHRM) algorithm with two interactive modules: a frontend for heterogeneity exploration and a backend for invariant prediction. Specifically, given pooled data, the algorithm starts with the heterogeneity exploration module, which uses a learned heterogeneity-aware kernel to generate the environments. The learned environments are used by the invariant prediction module to produce an invariant direction in Neural Tangent Feature (NTF) space that captures the invariant components, and this direction is then used to guide the gradient descent of the neural network. After that, we update the kernel to make it orthogonal to the invariant direction, so as to better capture the variant components and realize the mutual promotion between the two modules iteratively. The whole framework is jointly optimized, so that the mutual promotion between heterogeneity exploration and invariant learning can be fully leveraged. For ease of exposition, we begin with the invariant prediction step to illustrate our algorithm; the flow of the whole algorithm is shown in figure 1.
3.1 Invariant Prediction Module: Invariant Gradient Descent with Learned Environments (Step 1)
For our invariant learning module, we propose the Invariant Gradient Descent (IGD) algorithm. Taking the learned environments as input, IGD first performs invariant learning in Neural Tangent Feature (NTF, Jacot et al. (2018)) space to obtain the invariant direction, and then uses this direction to guide the whole neural network in learning the parameters of the invariant model (the neural network).
Neural Tangent Feature Space NTK theory Jacot et al. (2018) shows that training a wide neural network is equivalent to linear regression using non-linear NTFs, as in equation 4. For each data point, the corresponding NTF is the gradient of the network output with respect to the network's parameters, so its dimension equals the number of parameters. First, we would like to dissect the feature components within the NTFs by decomposing the invariant and variant components hidden in them. Therefore, we propose to perform Singular Value Decomposition (SVD) on the NTF matrix:
Intuitively, in equation 3, each row of the right factor of the decomposition represents one feature component of the NTFs, and we keep the feature components with the top largest singular values to represent the data. This low-rank decomposition is justified both theoretically Udell and Townsend (2019); Oymak et al. (2019) and empirically Arora et al. (2019). Since SVD ensures that the feature components are mutually orthogonal, the neural tangent feature of each data point can be decomposed into a weighted sum of these components, where each weight denotes the strength of the corresponding feature component in that data point. However, since neural networks have millions of parameters, the high dimensionality prevents us from learning directly on the NTFs. Therefore, we rewrite the initial formulation of linear regression as:
where the reduced parameter reflects how the model parameters utilize the feature components. Since the matrix of feature components is orthogonal, fitting with the full NTFs is equivalent to fitting with the reduced NTFs. In this way, we convert the original high-dimensional regression problem into the low-dimensional one in equation 5, since in wide neural networks the number of retained components is much smaller than the number of parameters.
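The reduction described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the NTF matrix is a hypothetical random low-rank stand-in, and the names `reduce_ntf`, `reduced` and `components` are introduced here only for illustration. The sketch checks that a linear model fitted on the reduced features gives the same predictions as the corresponding model in the full NTF space.

```python
import numpy as np

def reduce_ntf(ntf, k):
    """Truncated SVD of an (n, p) NTF matrix.

    Returns the reduced features, shape (n, k), and the k feature
    components, shape (k, p), with orthonormal rows.
    """
    u, s, vt = np.linalg.svd(ntf, full_matrices=False)
    reduced = u[:, :k] * s[:k]   # coordinates of each point on the top-k components
    components = vt[:k]          # top-k right singular vectors
    return reduced, components

rng = np.random.default_rng(0)
n, p, k = 50, 400, 5
# Hypothetical stand-in for real NTFs: an exactly rank-k (n, p) matrix.
ntf = rng.normal(size=(n, k)) @ rng.normal(size=(k, p))
y = rng.normal(size=n)

reduced, components = reduce_ntf(ntf, k)

# Fit in the k-dimensional reduced space ...
theta, *_ = np.linalg.lstsq(reduced, y, rcond=None)
# ... and map the solution back to the p-dimensional NTF space.
w = components.T @ theta

pred_reduced = reduced @ theta
pred_full = ntf @ w
```

Because the rows of `components` are orthonormal, `ntf @ components.T` recovers `reduced`, which is why the two prediction vectors coincide for an exactly rank-`k` matrix; for real NTFs the match is only approximate and depends on how fast the singular values decay.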
Invariant Learning with Reduced NTFs We can now perform invariant learning on the reduced NTFs in linear space. In this work, we adopt the invariant regularizer proposed in Koyama and Yamaguchi (2020) to learn the invariant direction, owing to its optimality guarantees, and the objective function is:
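As a hedged illustration of this kind of objective, the sketch below assumes that the regularizer takes the variance-of-gradients form (mean risk across environments plus a weight times the trace of the across-environment variance of the risk gradients). The squared loss, the weight `lam`, the toy data and all function names are assumptions made for this example, not details taken from the paper.

```python
import numpy as np

def env_gradient(theta, feats, y):
    """Gradient of the mean squared error of a linear model in one environment."""
    return 2.0 * feats.T @ (feats @ theta - y) / len(y)

def invariance_objective(theta, envs, lam=1.0):
    """Return (objective, penalty): mean risk + lam * trace of gradient variance."""
    risks = [np.mean((feats @ theta - y) ** 2) for feats, y in envs]
    grads = np.stack([env_gradient(theta, feats, y) for feats, y in envs])
    penalty = grads.var(axis=0).sum()  # trace of the covariance of per-environment gradients
    return float(np.mean(risks) + lam * penalty), float(penalty)

def make_env(n, slope, rng):
    # Hypothetical toy data: feature 0 drives y in every environment,
    # feature 1 depends on feature 0 with an environment-specific slope.
    s = rng.normal(size=n)
    v = slope * s + 0.1 * rng.normal(size=n)
    y = s + 0.5 * rng.normal(size=n)
    return np.column_stack([s, v]), y

rng = np.random.default_rng(0)
envs = [make_env(2000, 2.0, rng), make_env(2000, -1.0, rng)]

# A parameter that uses only the stable feature vs. one that uses only the unstable one.
_, penalty_invariant = invariance_objective(np.array([1.0, 0.0]), envs)
_, penalty_variant = invariance_objective(np.array([0.0, 1.0]), envs)
```

In this toy setting the penalty is close to zero for the parameter that relies on the stable feature and large for the one that relies on the environment-dependent feature, which is the behaviour the regularizer is meant to reward.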
Guide Neural Network with Invariant Direction With the learned invariant direction, it remains to feed it back to the neural network's parameters. Since neural networks have millions of parameters, it is difficult to obtain the corresponding parameters directly. Therefore, we design a loss function to approximate the projection. Note that , we have
Therefore, we can ensure the updated parameters satisfy that is parallel to , which leads to the following loss function:
where the first term is the empirical prediction loss over the training data and the second term enforces the invariance property of the neural network.
3.2 Variant Component Decomposition with the Invariant Direction (Step 2)
The core of KerHRM is the mutual promotion between the heterogeneity exploration module and the invariant learning module. Following our insight, we should leverage the variant components to exploit the latent heterogeneity. Therefore, given the better invariant direction learned by the invariant learning module, which captures the invariant components in the data, it remains to capture better variant components so as to further accelerate the heterogeneity exploration procedure. To this end, we design a clustering kernel on the reduced NTF space with the help of the invariant direction learned in section 3.1. Recalling the NTF decomposition in equation 3, the initial similarity of two data points can be decomposed as:
With the invariant direction learned by the invariant learning module in the current iteration, we can wipe out the invariant components that it uses via
which gives a new heterogeneity-aware kernel that better captures the variant components as .
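One natural way to picture this step, assuming that "wiping out" the invariant components is an orthogonal projection of the reduced NTFs onto the complement of the invariant direction, is sketched below. The projection form, the linear kernel and the names `remove_direction` and `heterogeneity_kernel` are assumptions for illustration only.

```python
import numpy as np

def remove_direction(feats, direction):
    """Project each row of feats onto the orthogonal complement of direction."""
    unit = direction / np.linalg.norm(direction)
    return feats - np.outer(feats @ unit, unit)

def heterogeneity_kernel(feats, direction):
    """Gram matrix of the features after the given direction has been removed."""
    variant = remove_direction(feats, direction)
    return variant @ variant.T

rng = np.random.default_rng(0)
reduced = rng.normal(size=(6, 4))             # hypothetical reduced NTFs
theta_inv = np.array([1.0, 2.0, 0.0, -1.0])   # hypothetical invariant direction

variant = remove_direction(reduced, theta_inv)
kernel = heterogeneity_kernel(reduced, theta_inv)
```

Under this assumption every projected feature is orthogonal to the invariant direction, so two points that differ only along that direction look identical to the kernel, and similarity is driven by the remaining (variant) components.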
3.3 Heterogeneity Exploration Module: Heterogeneity Exploration with the Learned Kernel (Step 3)
The heterogeneity exploration module takes one heterogeneous dataset as input and outputs a learned multi-environment partition for the invariant prediction module. We implement it as a clustering algorithm with kernel regression, given the heterogeneity-aware kernel that captures the variant components in the data. Following the analysis above, only the variant components should be leveraged to identify the latent heterogeneity. Therefore, we use the kernel as well as the variant components learned in section 3.2 to capture the different relationships between the variant components and the target, and we use the conditional distribution of the target given the variant components as the cluster centre. Specifically, we assume each cluster centre to be a Gaussian around the prediction of a centre-specific model:
For the given data points , the empirical distribution can be modeled as . Under this setting, we propose a convex clustering algorithm, which aims at finding a mixture distribution in the distribution set defined as:
to fit the empirical data best. Therefore, the original objective function and the simplified one are:
Note that our clustering algorithm differs from others in that the cluster centres are learned models with their own parameters. For optimization, we use the EM algorithm to optimize the centre parameters and the mixture weights iteratively. Specifically, when optimizing a cluster centre model, we use kernel regression with the heterogeneity-aware kernel to avoid computing the features explicitly and to allow a large feature dimension. To generate the learned environments, we assign each point to each cluster with the corresponding probability.
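The clustering idea can be sketched as an EM loop over a mixture of regressors. This is a simplified illustration, not the method as implemented by the authors: it uses explicit features with weighted least squares instead of kernel regression, a fixed Gaussian width `sigma`, and hypothetical toy data with two latent environments whose feature-target relationships have opposite signs.

```python
import numpy as np

def em_regression_clusters(feats, y, n_clusters, sigma=0.5, n_iter=50, seed=0):
    """EM for a mixture of linear regressors (each cluster centre is a model)."""
    rng = np.random.default_rng(seed)
    n, d = feats.shape
    centres = rng.normal(size=(n_clusters, d))        # regression weights per cluster
    weights = np.full(n_clusters, 1.0 / n_clusters)   # mixture weights
    for _ in range(n_iter):
        # E-step: responsibilities under a Gaussian around each centre's prediction.
        resid = y[:, None] - feats @ centres.T        # shape (n, n_clusters)
        log_resp = np.log(weights) - 0.5 * (resid / sigma) ** 2
        log_resp -= log_resp.max(axis=1, keepdims=True)
        resp = np.exp(log_resp)
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: weighted least squares per centre, then update mixture weights.
        for j in range(n_clusters):
            w = resp[:, j]
            gram = feats.T @ (w[:, None] * feats) + 1e-6 * np.eye(d)
            centres[j] = np.linalg.solve(gram, feats.T @ (w * y))
        weights = np.clip(resp.mean(axis=0), 1e-12, None)
    return centres, weights, resp

# Hypothetical data: two latent environments with slopes +2 and -2.
rng = np.random.default_rng(1)
v = rng.normal(size=(400, 1))
env = np.repeat([0, 1], 200)
y = np.where(env == 0, 2.0, -2.0) * v[:, 0] + 0.1 * rng.normal(size=400)

centres, weights, resp = em_regression_clusters(v, y, n_clusters=2)
assignment = resp.argmax(axis=1)   # hard environment labels from soft responsibilities
```

The sketch assigns points by the most probable cluster for simplicity, whereas the text above assigns points to clusters with the corresponding probability. Points whose feature is close to zero carry little information about which relationship generated them, so they are assigned almost arbitrarily.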
4 Theoretical Analysis
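In this section, we provide theoretical justifications of the mutual promotion between the heterogeneity exploration module and the invariant learning module. Our algorithm does not violate the theoretical analysis in Koyama and Yamaguchi (2020) and Jacot et al. (2018), which proves that better environments from the heterogeneity exploration module benefit the MIP learned by the invariant learning module. To complete the argument for mutual promotion, we therefore only need to justify that a better invariant direction from the invariant learning module benefits the learning of environments in the heterogeneity exploration module.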
Then, similar to Lashkari and Golland (2007), we use rate-distortion theory to demonstrate why a larger distance between cluster centres benefits our convex clustering as well as the quality of the learned environments.
(Rate-Distortion) For the proposed convex clustering algorithm, we have:
where is a discrete random variable over the space
is a discrete random variable over the spacewhich denotes the probability of -th data point belonging to -th cluster, are the marginal distribution of random variable respectively, and the Shannon mutual information. Note that the optimal can be obtained by the optimal and therefore we only minimize the r.h.s with respect to .
Actually models the conditional distribution . If in the underlying distribution of the empirical data differs a lot between different clusters, the optimizer will put more effort into optimizing to avoid inducing a larger error, resulting in less effort spent on optimizing and a relatively larger . This means the data points have a larger mutual information with the cluster index , and thus the clustering tends to be more accurate.
2. Orthogonality Property: a better invariant direction yields better variant components. First, we prove the orthogonality property between the invariant direction (equation 6) and the parameters of the cluster centres.
(Orthogonality Property) Denote the data matrix of -th environment and , then for each , we have and , where denotes the column space and the null space.
Theorem 4.2 justifies that the parameter space of the clustering model, as well as the space of the learned variant components, is orthogonal to the invariant direction, which indicates that a better invariant direction yields better variant components and therefore better heterogeneity. Taking (1) and (2) together, we conclude that better results of the invariant learning module promote the latent heterogeneity exploration in the heterogeneity exploration module because of the larger between-cluster distance. Finally, we use a linear but general setting for further clarification.
Assume that data points from environments are generated as follows:
where with equal probability, the coefficient varies across environment , is the invariant feature and following functional representation lemma El Gamal and Kim (2011) is the variant feature with and its relationship with the target relies on the environment-specific .
In example 4, when achieves optimal, we have 10, we have , which directly shows that in the next iteration, uses solely variant components in to learn environments with diverse , which by lemma 4.1 and theorem 4.1 gives the best clustering results.
In this section, we validate the effectiveness of our method on synthetic data and real-world data.
Baselines We compare our proposed KerHRM with the following methods:
Empirical Risk Minimization (ERM)
Distributionally Robust Optimization (DRO, Duchi and Namkoong (2018))
Environment Inference for Invariant Learning (EIIL, Creager et al. (2020))
Heterogeneous Risk Minimization (HRM, Liu et al. (2021))
Invariant Risk Minimization (IRM, Arjovsky et al. (2019)) with environment labels
We choose one typical DRO method Duchi and Namkoong (2018), as DRO is another main branch of methods for the OOD generalization problem in the same setting as ours (no environment labels). HRM and EIIL are other methods that infer environments for invariant learning without environment labels. We choose IRM as another baseline for its prominence in invariant learning, but note that IRM relies on multiple training environments, so we provide environment labels for it, while the other methods do not need them. Further, as an ablation study, we run KerHRM for only one iteration without the feedback loop and denote it as Static KerHRM. For all experiments, we use a two-layer MLP with 1024 hidden units.
Evaluation Metrics To evaluate the prediction performance, for tasks with only one testing environment, we simply use the prediction accuracy on the testing environment. For tasks with multiple testing environments, we introduce two metrics, the mean and the standard deviation of the error across the testing environments, and we use the average mean squared error as the error in each environment.
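The two multi-environment metrics can be computed as in the short sketch below. It assumes the sample standard deviation (dividing by the number of environments minus one), which the text does not specify; the environment names and error values are made up for illustration.

```python
import statistics

def mean_and_std_error(env_errors):
    """Mean and sample standard deviation of per-environment errors."""
    errors = list(env_errors.values())
    return statistics.mean(errors), statistics.stdev(errors)

# Hypothetical per-environment mean squared errors.
env_errors = {"env_1": 0.30, "env_2": 0.35, "env_3": 0.40}
mean_error, std_error = mean_and_std_error(env_errors)
```

A low mean indicates good average performance, while a low standard deviation indicates that performance is stable across the testing environments.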
5.1 Synthetic Data
Classification with Spurious Correlation
Following Sagawa et al. (2020), we induce the spurious correlation between the label and a spurious attribute . Specifically, each environment is characterized by its bias rate , where the bias rate represents that for data, , and for the other data, . Intuitively, measures the strength and direction of the spurious correlation between the label and spurious attribute , where larger signifies higher spurious correlation between and , and represents the direction of such spurious correlation, since there is no spurious correlation when . We assume , where is the invariant feature generated from label and the variant feature generated from spurious attribute :
is a random orthogonal matrix that scrambles the invariant and variant components, which makes the setting more realistic. Typically, we set to make the model more prone to use the spurious , since is more informative.
In training, we set and generate 2000 data points, where points are from environment with and the others are from environment with . For our method, we set the cluster number . In testing, we generate 1000 data points from environment with to induce distributional shifts from training. In this experiment, we vary the bias rate of environment and the scrambling matrix , which can be an orthogonal or identity matrix (as done in Arjovsky et al. (2019)), and results over 10 runs are reported in Table 1.
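A data generator in the spirit of this setup might look like the sketch below. Every concrete choice is an assumption made for illustration, since the exact values are not recoverable from the text: labels in {-1, +1}, the spurious attribute agreeing with the label with probability `bias_rate`, Gaussian features centred at the label and the attribute, the dimensions, the noise scales, and the function names.

```python
import numpy as np

def random_orthogonal(dim, rng):
    """Random orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    q, _ = np.linalg.qr(rng.normal(size=(dim, dim)))
    return q

def make_environment(n, bias_rate, rng, dim_s=5, dim_v=5,
                     sigma_s=3.0, sigma_v=1.0, scramble=None):
    """Sample (x, y, a) for one environment with the given bias rate."""
    y = rng.choice([-1.0, 1.0], size=n)
    agree = rng.random(n) < bias_rate                    # attribute agrees with label
    a = np.where(agree, y, -y)
    s = y[:, None] + sigma_s * rng.normal(size=(n, dim_s))   # invariant features
    v = a[:, None] + sigma_v * rng.normal(size=(n, dim_v))   # variant features
    x = np.concatenate([s, v], axis=1)
    if scramble is not None:
        x = x @ scramble.T                               # mix invariant and variant parts
    return x, y, a

rng = np.random.default_rng(0)
scramble = random_orthogonal(10, rng)
x, y, a = make_environment(2000, 0.9, rng, scramble=scramble)
```

With a smaller noise scale on the variant features than on the invariant ones, the spurious attribute is the easier signal to fit in training, which is what makes a change of bias rate at test time harmful for a model that relies on it.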
From the results, we make the following observations. ERM suffers from the distributional shifts between training and testing, which yields the worst testing performance. DRO provides only slight resistance to distributional shifts, which we attribute to the over-pessimism problem Frogner et al. (2019b). EIIL achieves the best training performance but also performs poorly in testing. HRM outperforms the above three baselines, but its testing accuracy is only around random guessing (0.50), because the simple raw feature setting assumed in Liu et al. (2021) is violated. IRM performs better when the heterogeneity between training environments is large ( is small), which verifies our analysis in section 2 that the performance of invariant learning methods depends heavily on the quality of the given environments. Compared with all baselines, KerHRM performs the best, with the highest testing accuracy and the lowest , showing its superiority over IRM and the original HRM.
Further, we empirically analyze the sensitivity of KerHRM to the choice of cluster number. We set and test the performance with respectively. Results compared with IRM are shown in Table 2. From the results, we can see that the cluster number of our method does not need to equal the ground truth number (the ground truth is 2), and that KerHRM is not sensitive to the choice of cluster number. Intuitively, we only need the learned environments to reflect the variance of the relationships between the variant components and the target; we do not require the environments to be the ground truth. However, we notice that when the cluster number is far from the proper one, the convergence of the clustering algorithm is much slower.
Regression with Selection Bias
In this setting, we induce the spurious correlation between the label and spurious attributes through a selection bias mechanism, similar to that in Kuang et al. (2020). We assume and , where is a non-linear function such that remains invariant across environments while changes arbitrarily. For simplicity, we select data with probability according to a certain variable :
where . Intuitively, eventually controls the strength and direction of the spurious correlation between and (i.e., if , a data point whose is close to its is more likely to be selected). The larger the value of , the stronger the spurious correlation between and ; means positive correlation and vice versa. Therefore, here we use to define different environments.
In training, we generate 1000 points from environment with a predefined and points from with . In testing, to simulate distributional shifts, we generate data points for 6 environments with . We compare our KerHRM with ERM, DRO, EIIL and IRM. We conduct experiments with different settings on and the scrambled matrix .
From the results in Table 3, we have the following analysis. ERM, DRO and EIIL perform poorly, with high average and stability errors, which is similar to the classification experiments (Table 1). The results of HRM are quite different in the two scenarios: Scenario 1 corresponds to the simple raw feature setting (no scrambling) in Liu et al. (2021), whereas Scenario 2 violates this setting with a random orthogonal scrambling matrix and greatly harms HRM. Compared with all baselines, KerHRM achieves the lowest average error in 5/6 settings, and its superiority is especially clear in our more general setting (Scenario 2).
Table 3 panels: Scenario 1, Non-Scrambled Setting; Scenario 2, Scrambled Setting (random orthogonal scrambling matrix).
To further validate our method's capacity under general settings, we use the colored MNIST dataset, where the data are high-dimensional non-linear transformations of invariant features (digits) and variant features (color). Following Arjovsky et al. (2019), we build a synthetic binary classification task, where each image is colored either red or green in a way that strongly and spuriously correlates with the class label. First, a binary label is assigned to each image according to its digit: one class for digits 0-4 and the other for digits 5-9. Second, we sample the color id by flipping the label with probability , which forms the environments, where for the first training environment, for the second training environment and for the testing environment. Third, we induce noisy labels by randomly flipping the label with probability 0.2.
We randomly sample 2500 images for each environment, and the two training environments are mixed without environment labels for ERM, DRO, EIIL, HRM and KerHRM, while for IRM the environment labels are provided. For IRM, we sample 1000 data points from each of the two training environments and select the hyper-parameters that maximize the minimum accuracy over the two validation environments. Note that we have no access to the testing environment during training, so we cannot resort to testing data to select the best model, which is more reasonable than, and different from, the protocol in Arjovsky et al. (2019). For the other methods, since we have no access to environment labels, we simply pool the 2000 data points for validation. The results are shown in Table 4, where Perfect Inv. Model represents the oracle results that can be achieved under this setting. We run each method 5 times and report the average accuracy; since the variance of all methods is relatively small, we omit it from the table.
From the results, KerHRM generalizes HRM to much more complicated data and consistently achieves the best performance. KerHRM even outperforms IRM significantly in an unfair setting where we provide perfect environment labels for IRM, which shows the limitation of manually labeled environments. Further, to best show the mutual promotion between the two modules, we plot the training and testing accuracy as well as the KL-divergence between the learned environments over iterations in figure 3. From figure 3, we first validate the mutual promotion between the two modules, since the KL-divergence and the testing accuracy increase together over iterations. Second, figure 3 agrees with our analysis in section 2 that the performance of invariant learning methods is highly correlated with the heterogeneity of the environments, which sheds light on the importance of leveraging the intrinsic heterogeneity in training data for invariant learning.
5.2 Real-world Data
In this experiment, we test our method on a real-world regression dataset (Kaggle) of house sales prices from King County, USA (https://www.kaggle.com/c/house-prices-advanced-regression-techniques/data), where the target variable is the transaction price of the house and each sample contains 17 predictive variables, such as the year built, the number of bedrooms and the square footage of the home. Since it is reasonable to assume that the relationships between the predictive variables and the target vary over time (for example, the pricing mode may change over time), there exist distributional shifts in the price-prediction task with respect to the build year of the houses. Specifically, the houses in this dataset were built between , and we divide the whole dataset into 6 periods, each of which spans two decades. Note that the later periods have larger distributional shifts. We train all methods on the first period where and test on the other 5 periods, and we report the average results over 10 runs in figure 3. For IRM, we further divide period 1 into two decades to form the environments it is provided with.
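The period-based split can be sketched as below. The first build year (1900 here) is a placeholder assumption, since the exact range is not recoverable from the text; the 20-year span and the 6 periods follow the description above, and the function names are hypothetical.

```python
def period_index(year_built, first_year=1900, span=20, n_periods=6):
    """Map a build year to a period index in [0, n_periods - 1]."""
    idx = (year_built - first_year) // span
    return max(0, min(n_periods - 1, idx))

def split_by_period(years):
    """Indices of training rows (first period) and test rows grouped by later period."""
    train, test = [], {}
    for row, year in enumerate(years):
        p = period_index(year)
        if p == 0:
            train.append(row)
        else:
            test.setdefault(p, []).append(row)
    return train, test

# Hypothetical build years for six houses.
years = [1905, 1919, 1920, 1950, 1999, 2014]
train_rows, test_rows = split_by_period(years)
```

Training only on the earliest period and testing on each later period separately makes the per-period error a direct measure of how performance degrades as the shift from the training distribution grows.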
Analysis The testing errors of ERM and DRO increase sharply across environments, indicating the existence of distributional shifts between environments. IRM performs better than ERM and DRO, which shows the usefulness of environment labels for OOD generalization and the possibility of learning an invariant predictor from multiple environments. The proposed KerHRM outperforms EIIL and HRM, which validates its superiority in heterogeneity exploration. KerHRM even outperforms IRM, which indicates the limitation of manually labeled environments in invariant learning and the necessity of latent heterogeneity exploration.
Although the proposed KerHRM is a competitive method, it has several limitations. First, since the heterogeneity exploration module takes model parameters as cluster centres, a strict convergence guarantee for our clustering algorithm is quite hard to establish. Empirically, we find that when the pre-defined cluster number is far from the ground truth, the convergence of the heterogeneity exploration module becomes quite slow. Further, this restriction also affects the analysis of the mutual promotion between the two modules, for which we can only provide some empirical verification. Besides, although we incorporate the Neural Tangent Kernel to deal with data beyond the raw feature level, how to deal with more complicated data remains unsolved. Also, how to combine deep learning with the mutual promotion between the two modules needs further investigation, and we leave it for future work.
In this paper, we propose the KerHRM algorithm for the OOD generalization problem, which achieves both latent heterogeneity exploration and invariant prediction. From our theoretical and empirical analysis, we find that the heterogeneity of environments plays a key role in invariant learning, which is consistent with some recent analysis Rosenfeld et al. (2021) and opens a new line of research for the OOD generalization problem. Our code is available at https://github.com/LJSthu/Kernelized-HRM.
This work was supported by the National Key R&D Program of China (No. 2018AAA0102004).
Appendix A
A.1 Experimental Details
In this section, we introduce the experimental details as well as additional results. In all experiments, we take for our KerHRM and select the best one according to the validation results.
Classification with Spurious Correlation
For our synthetic data, we set and to make the model more prone to use the spurious , since is more informative.
Regression with Selection Bias
In this setting, the correlations among covariates are perturbed through a selection bias mechanism. According to assumption 2.1, we assume and is independent of while the covariates in are dependent on each other. We assume and remains invariant across environments while can change arbitrarily.
Therefore, we generate training data points with the help of auxiliary variables as follows:
To induce model misspecification, we generate as:
where , and . For our synthetic data, we set , and . As we assume that remains unchanged while can vary across environments, we design a data selection mechanism to induce this kind of distribution shift. For simplicity, we select data points according to a certain variable set :
where . Given a certain , a data point is selected if and only if (i.e., if , a data point whose is close to its is more likely to be selected). Intuitively, eventually controls the strength and direction of the spurious correlation between and . The larger the value of , the stronger the spurious correlation between and ; means positive correlation and vice versa. Therefore, here we use to define different environments.
A.2 Proof of Theorems
A.2.1 Proof of Theorem 2.1
First, we would like to prove that a random variable satisfying assumption 2.1 is MIP.
A representation satisfying assumption 2.1 is the maximal invariant predictor.
: To prove . If is not the maximal invariant predictor, assume