
Learning Modular Structures That Generalize Out-of-Distribution

Out-of-distribution (O.O.D.) generalization remains a key challenge for real-world machine learning systems. We describe a method for O.O.D. generalization that, through training, encourages models to preserve only those features in the network that are well reused across multiple training domains. Our method combines two complementary neuron-level regularizers with a probabilistic differentiable binary mask over the network, to extract a modular sub-network that achieves better O.O.D. performance than the original network. Preliminary evaluation on two benchmark datasets corroborates the promise of our method.






Recent work has shown that neural networks trained on observational data are prone to spurious correlations, relying on shortcuts learned from the training data instead of modelling the underlying mechanism Geirhos_2020. As a result, they fail to transfer to more challenging testing conditions, such as real-world scenarios. Recent works show that modularity is a useful inductive bias that can lead to better systematic generalization goyal2020recurrent; csordas2021are. We seek to understand whether networks can be structurally constrained to prefer modular solutions. In this context, zhang2021subnetwork show that a fully-trained network contains sub-networks that are less susceptible to spurious correlations, and introduce a method to extract this structure from a trained network. We study whether we can obtain such solutions through training itself, by regularizing the network to avoid fitting spurious correlations in the data. We introduce objectives that explicitly incorporate the structure of the network and induce modular structures at every layer. Our method constrains the network to be a compositional hierarchy of expert modules, promoting the emergence of features that are specialized and reused across multiple training domains. We show that our method boosts the O.O.D. performance of networks on two benchmark datasets.


We first motivate our method. A deep neural network contains several layers of neurons, each neuron serving as a feature for every neuron in the next layer. Every fundamental sub-function (e.g. a single convolutional filter) in the network is associated with separate neurons that arise from transformations (e.g. a dot product) of the function with the input. Our aim is to discover features that are well used (activated) across multiple training domains, and to detect and prevent redundant features in the network. Hence, specialization and reuse are the two key principles underlying our method, which we describe below.

First, our objective for specialization regularizes such that every feature in the network should be a different composition of the available sub-features. That is, every feature should fit as few features as necessary, and should differ as much as possible in the features fit, minimizing redundancy. However, directly encouraging this on the weights would unnecessarily constrain the power of the network.

Hence, we use a differentiable probabilistic binary mask over the network weights, relaxed by the Gumbel-Sigmoid estimator jang2017categorical. Each mask value represents the probability of sampling the corresponding weight. During every training iteration, once the mask is sampled, it is binarized by thresholding its sigmoid output. Once trained, we obtain deterministic masks by binarizing the final values, hence extracting a subnetwork described by the mask.
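The sampling and binarization steps can be illustrated with a minimal pure-Python sketch. It assumes that each weight carries a real-valued mask logit and that the hard mask is obtained by thresholding the relaxed sample at 0.5; the function names, the temperature parameter, and the threshold are illustrative assumptions, not details taken from the paper. Only the forward pass is shown; in an autodiff framework, gradients would typically be passed through the relaxed sample.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logistic_noise(rng):
    """Logistic(0, 1) noise, the difference of two Gumbel(0, 1) samples."""
    u = rng.random()
    # Clamp away from 0 and 1 so both logarithms stay finite.
    u = min(max(u, 1e-12), 1.0 - 1e-12)
    return math.log(u) - math.log(1.0 - u)


def sample_soft_mask(logits, temperature, rng):
    """Relaxed Bernoulli (Gumbel-Sigmoid) sample for each mask logit."""
    return [sigmoid((a + logistic_noise(rng)) / temperature) for a in logits]


def binarize(soft_mask, threshold=0.5):
    """Hard 0/1 mask obtained by thresholding the relaxed sample (assumed rule)."""
    return [1 if s > threshold else 0 for s in soft_mask]


def deterministic_mask(logits, threshold=0.5):
    """Mask used after training: threshold the sampling probabilities."""
    return [1 if sigmoid(a) > threshold else 0 for a in logits]


rng = random.Random(0)
mask_logits = [4.0, -4.0, 0.0]  # hypothetical learned logits
soft = sample_soft_mask(mask_logits, temperature=0.5, rng=rng)
hard = binarize(soft)
final = deterministic_mask(mask_logits)
```

Lower temperatures push the relaxed samples closer to 0 or 1, at the cost of noisier gradients.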

As part of our specialization objective, we impose a regularization on the continuous masks of the weights. It is summed over the layers of the network, over the features in each layer, and over the outgoing weights of each feature.

Note that we minimize the (square of) sum of sampling probabilities of weights outgoing from each feature in the current layer, allowing it to be fit sparsely by only a few required features from above. Consequently, this encourages features in the next layer to fit a minimally overlapping set of features from the current layer, leading to each of the former specializing in their underlying function.
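A small sketch of this penalty, read literally from the description: for every feature, sum the sampling probabilities of its outgoing weights, square the result, and accumulate over features and layers. The nested-list layout `mask_probs[l][i][j]` and the absence of any normalization constant are assumptions made for illustration.

```python
def specialization_penalty(mask_probs):
    """Sum over layers and features of (sum of outgoing mask probabilities)^2.

    mask_probs[l][i][j] is the sampling probability of the j-th outgoing
    weight of feature i in layer l (a hypothetical layout).
    """
    total = 0.0
    for layer in mask_probs:
        for outgoing in layer:
            total += sum(outgoing) ** 2
    return total


# One layer, two features, four outgoing weights each.
dense = [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]
sparse = [[[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]]
penalty_dense = specialization_penalty(dense)
penalty_sparse = specialization_penalty(sparse)
```

Because the sum is squared, a feature that feeds many units in the next layer is penalized much more than one that feeds only a few, which is what pushes each downstream unit toward a small, distinct set of inputs.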

Although this objective would encourage specialization, every feature in the current layer may not be necessary, as extra features may correspond to unnecessary functions. The network must automatically be able to decide how many features to keep. However, constructing an objective that can be used to restrict the number of features in a layer is non-trivial, since in the worst case, every feature may be necessary depending on the task and model capacity at hand.

Here, we hypothesise that the necessary features are those that are reused across domains by multiple specialist functions above. Consequently, we regularize to preserve only those features that have a large number of outgoing weights sampled with high probability, discarding features that are not well reused. We enforce this through a second objective on the masks, grouping the outgoing weights of each feature.

This term is inspired by group lasso regularization KimTree; applying it can effectively zero out the masks of all the outgoing weights of some features. Unlike group lasso, which regularizes the weights and can have overlapping groups, we apply it to the masks, and our groups do not overlap.
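For intuition, a group-lasso-style penalty over non-overlapping groups can be sketched as below, with one group per feature containing the mask probabilities of its outgoing weights. This is the standard unweighted group lasso form applied to masks, offered as an assumption; the exact expression used in the paper is not recoverable from the text, and the `mask_probs[l][i][j]` layout is hypothetical.

```python
import math


def reuse_penalty(mask_probs):
    """Sum over layers and features of the L2 norm of each feature's
    outgoing mask probabilities (one non-overlapping group per feature)."""
    total = 0.0
    for layer in mask_probs:
        for outgoing in layer:
            total += math.sqrt(sum(p * p for p in outgoing))
    return total


# One layer with two features: the second feature's group is fully zeroed,
# so it contributes nothing to the penalty.
example = [[[3.0, 4.0], [0.0, 0.0]]]
value = reuse_penalty(example)
```

Since the L2 norm is not squared, the penalty tends to drive whole groups to exactly zero rather than shrinking every entry a little, which corresponds to discarding a feature entirely.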

Our final objective is, therefore, a weighted combination of the loss function used for the task, a general-purpose regularizer, and our two regularization terms, each of which has its own weight.
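The objectives can be written compactly as follows. This is one formulation consistent with the verbal description, not necessarily the authors' exact equations: every symbol is an assumed name ($m^{l}_{ij}$ for the sampling probability of the $j$-th outgoing weight of feature $i$ in layer $l$, $L$ for the number of layers, $N_l$ for the number of features in layer $l$, $K^{l}_{i}$ for the number of outgoing weights of feature $i$, $\mathcal{L}_{\text{task}}$ for the task loss, $R$ for the general-purpose regularizer, and $\alpha$, $\beta$ for the two weights), and any normalization constants are omitted.

```latex
\begin{align}
\mathcal{L}_{\text{spec}} &= \sum_{l=1}^{L} \sum_{i=1}^{N_l}
    \Big( \sum_{j=1}^{K^{l}_{i}} m^{l}_{ij} \Big)^{2}, \\
\mathcal{L}_{\text{reuse}} &= \sum_{l=1}^{L} \sum_{i=1}^{N_l}
    \sqrt{ \sum_{j=1}^{K^{l}_{i}} \big( m^{l}_{ij} \big)^{2} }, \\
\mathcal{L} &= \mathcal{L}_{\text{task}} + R
    + \alpha\, \mathcal{L}_{\text{spec}} + \beta\, \mathcal{L}_{\text{reuse}}.
\end{align}
```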

Training with the regularized differentiable mask on data consisting of multiple training domains conditionally activates only those weights shared across multiple domains. Consequently, the sub-network contains features that are invariant to the domain, and hence aids O.O.D. generalization.

Preliminary Results

We present preliminary results of our method on two benchmark O.O.D. generalization datasets: Colored MNIST (C-MNIST) and Rotated MNIST (R-MNIST). Each dataset is artificially biased so that, in the training set, a certain degree of correlation is induced between spurious variables and the class label. In the test set, the correlation is reversed. The goal of O.O.D. generalization is to encourage the model to fit the invariant features and ignore other correlated variables, while training and validating only on in-distribution data.
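The kind of bias described here can be illustrated with a toy sketch in which a binary "color" attribute agrees with a binary label with some probability in training, and with the complementary probability at test time. The binary setup, the function names, and the agreement probabilities are hypothetical and are not the settings of the actual benchmarks.

```python
import random


def assign_spurious_color(labels, agree_prob, rng):
    """Give each binary label a binary color that matches it with
    probability agree_prob and is flipped otherwise."""
    colors = []
    for y in labels:
        if rng.random() < agree_prob:
            colors.append(y)
        else:
            colors.append(1 - y)
    return colors


def agreement_rate(labels, colors):
    """Fraction of examples whose color equals their label."""
    matches = sum(1 for y, c in zip(labels, colors) if y == c)
    return matches / len(labels)


rng = random.Random(0)
labels = [rng.randint(0, 1) for _ in range(1000)]
train_colors = assign_spurious_color(labels, agree_prob=0.9, rng=rng)  # color predicts label
test_colors = assign_spurious_color(labels, agree_prob=0.1, rng=rng)   # correlation reversed
```

A model that relies on color does well under the training correlation and poorly once it is reversed, which is why such datasets expose reliance on spurious features.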

Model   Method          C-MNIST   R-MNIST
CNN     ERM             35.23     96.5
CNN     ERM + modReg    38.20     96.7
CNN     IRM             67.69     97.3
CNN     IRM + modReg    71.88     98.1
MLP     ERM             34.27     94.45
MLP     ERM + modReg    36.91     95.43
MLP     IRM             72.58     97.4
MLP     IRM + modReg    75.59     97.9
Table 1: Results of the proposed method on multiple architectures, across datasets.

Our method is versatile, and can be applied on top of any training algorithm. Here, we apply it on top of empirical risk minimization (ERM), the standard approach to machine learning problems, and invariant risk minimization (IRM) arjovsky2020invariant, a method that estimates invariant, causal predictors from multiple training environments.

The preliminary results in Table 1 support the effectiveness of our method, which gives consistent gains across the two datasets and architectures considered. In particular, it gives considerable boosts on the heavily biased C-MNIST dataset (roughly 2.6 to 4.2 points), and also improves performance on R-MNIST (by 0.2 to about 1 point), where existing methods have already reached their potential.

Future Work

We plan to scale up our method and test its effectiveness on larger datasets. Further, we also plan to take our method forward and evaluate on larger architectures such as ResNets, and on top of other existing O.O.D. generalization methods.