Self-balanced Learning For Domain Generalization

08/31/2021 ∙ by Jin Kim, et al. ∙ Yonsei University

Domain generalization aims to learn a prediction model on multi-domain source data such that the model can generalize to a target domain with unknown statistics. Most existing approaches have been developed under the assumption that the source data is well-balanced in terms of both domain and class. However, real-world training data collected with different composition biases often exhibits severe distribution gaps across domains and classes, leading to substantial performance degradation. In this paper, we propose a self-balanced domain generalization framework that adaptively learns the weights of losses to alleviate the bias caused by the different distributions of the multi-domain source data. The self-balanced scheme is based on an auxiliary reweighting network that iteratively updates the weight of each loss, conditioned on domain and class information, by leveraging balanced meta data. Experimental results demonstrate the effectiveness of our method, which outperforms state-of-the-art domain generalization methods.


1 Introduction

Despite the impressive progress of deep learning technologies, domain shift is still a major obstacle to deploying them in numerous computer vision tasks. As most approaches are trained under the assumption that training and test data are sampled from the same distribution, they often fail to make accurate predictions on unseen, out-of-distribution test data [1, 2, 3, 4, 5].

To overcome this issue, domain generalization (DG) approaches attempt to design a generalizable model that minimizes the domain gap between training and test domains and performs evenly well under multiple data distributions. As part of this effort, data augmentation methods have been proposed to expand the source domains to a wider span of the data space. Volpi et al. [6] adversarially generated fictitious samples from target distributions defined within a certain Wasserstein distance of the source. Recently, Yue et al. [7] employed style transfer [8] to synthesize new training samples, and Carlucci et al. [9] proposed solving a jigsaw puzzle of shuffled image patches to improve the generalization performance of image classification. Although these approaches have demonstrated promising results, it is fairly hard to generate new training samples that fully cover the real-world distribution, which often leads to overfitting.

Figure 1: (a) Target domain: Caltech. (b) Target domain: LabelMe. Left: the classification accuracy of each class on the VLCS dataset. Right: the number of samples in the source domains. The variances of the sample counts with respect to classes and domains are also reported; a large variance implies a severe imbalance in the training data.
Figure 2: An overall methodological flow of SBDG. In step 1, the parameter $\theta$ of the task network is updated to $\hat{\theta}$ with a mini-batch from the imbalanced set. In step 2, the parameter $\phi$ of the reweighting network is updated to $\phi'$ with a mini-batch from the balanced set. In the final step, $\theta$ is optimized to $\theta'$ with the imbalanced set and the adaptive weight derived from $\phi'$.

Learning domain-invariant feature representations has become an attractive alternative for making the model robust to unseen target data [10, 11, 12, 13, 14]. As a pioneering work, a maximum mean discrepancy (MMD) constraint [15] was applied to adversarial autoencoders to align the distributions of multiple source domains [10]. Other approaches have improved generalization through a contrastive loss that embeds training samples of the same class close together in the latent space [14], or through an episodic training procedure that simulates domain shift during training [11]. Model-agnostic meta-learning approaches [16, 17, 18] have been introduced to improve the domain generality of a base learner by learning how to optimize it. While these approaches have achieved remarkable advances in domain generalization, training data distributions that are biased in terms of domains and classes remain a major obstacle that prevents them from achieving higher accuracy. For example, Fig. 1 depicts the classification performance on different target domains for a representative meta-learning-based DG method [18] and for our proposed method, together with the variances of the number of samples with respect to domains and classes. When the model is evaluated on the LabelMe (L) domain in Fig. 1(b), accuracy degrades significantly because of the imbalanced source domains.

In this paper, we propose a novel self-balanced domain generalization framework, termed SBDG, that addresses the domain/class imbalance problem more effectively by explicitly weighting the training loss. To this end, we extend the sample reweighting scheme [19], which was originally proposed for the class imbalance issue, to prioritize minority domains with relatively higher training losses. Specifically, we formulate an auxiliary reweighting network that can be integrated with domain generalization methods, adaptively balancing the training loss to fully leverage the information of each domain. Furthermore, model-agnostic meta-learning [16] is employed to train the reweighting network, guided by an unbiased meta dataset that is uniformly distributed in terms of domains and classes, as shown in Fig. 2. Experimental results show that our framework helps prevent the performance degradation caused by imbalanced training samples and achieves state-of-the-art performance compared to prior works.

2 Proposed Method

2.1 Problem Statement

The objective of domain generalization is to learn a model that performs well on an unseen domain using a set of observable source domains $\mathcal{D} = \{\mathcal{D}_1, \dots, \mathcal{D}_K\}$, where $K$ is the number of source domains. The source domain $\mathcal{D}_k^c = \{(x_i, y_i)\}_{i=1}^{n_k^c}$ for class $c$ contains $n_k^c$ pairs of input images $x_i$ and class labels $y_i$, where $n_k^c$ varies widely across domains and classes. Most existing methods give equal weight to every sample during training and ignore this imbalance.

Our method improves generalization performance by estimating adaptive learning weights for dominant and minority samples in each training iteration. As shown in Fig. 2, the proposed framework is composed of two parts: (1) the task network, which performs the target task, and (2) the auxiliary reweighting network, which balances the task loss of each sample. The task network $f(\cdot;\theta)$ predicts the probability $\hat{y}$ of the image class from an input image $x$ with parameter $\theta$. The auxiliary reweighting network $\mathcal{V}(\cdot;\phi)$ takes a loss $\ell_i$ and the conditional domain vector $d_i$ as inputs and outputs an adaptive weight $w_i$ for sample $i$ with network parameter $\phi$. Although image classification is the target task in this work, our framework could also be extended to other computer vision tasks such as object detection [20] and semantic segmentation [21]. We leave this as future work.

2.2 Training

To jointly train the reweighting and task networks, we first divide the source domains into a balanced meta dataset $\mathcal{D}^{\text{meta}}$ and an imbalanced set $\mathcal{D}^{\text{imb}}$:

$$\mathcal{D} = \mathcal{D}^{\text{meta}} \cup \mathcal{D}^{\text{imb}}. \tag{1}$$

The images in the balanced meta dataset are uniformly distributed in terms of domains and classes. The overall framework is illustrated in Fig. 2. We adopt one of two state-of-the-art methods (MLDG [18] or RSC [1]) as the task network for image classification, but any DG model can be used. The auxiliary reweighting network $\mathcal{V}$ is composed of two fully connected layers with the ReLU activation function, and a sigmoid output layer is adopted to ensure $w_i \in [0, 1]$. Next, we describe the three training steps in detail.
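As a concrete illustration of the reweighting network described above, the sketch below implements its forward pass in plain Python. It assumes one reading of the architecture (fully connected layer, ReLU, fully connected layer, sigmoid), a hidden width of 16, and a uniform random initialization; these choices and the names `ReweightingNet`, `linear`, and `init_layer` are illustrative assumptions, not details taken from the paper. The input is the scalar task loss concatenated with a one-hot domain vector, matching the inputs described for the reweighting network.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function; keeps the output in (0, 1)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def linear(weights, bias, v):
    # Fully connected layer: weights is an out_dim x in_dim list of rows
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weights, bias)]


def init_layer(in_dim, out_dim, rng):
    # Uniform init in [-1/sqrt(in_dim), 1/sqrt(in_dim)] (assumed; not specified in the paper)
    s = 1.0 / math.sqrt(in_dim)
    weights = [[rng.uniform(-s, s) for _ in range(in_dim)] for _ in range(out_dim)]
    return weights, [0.0] * out_dim


class ReweightingNet:
    """Hypothetical sketch: FC -> ReLU -> FC -> sigmoid on [loss, one_hot(domain)]."""

    def __init__(self, num_domains, hidden=16, seed=0):
        rng = random.Random(seed)
        self.num_domains = num_domains
        self.fc1 = init_layer(1 + num_domains, hidden, rng)
        self.fc2 = init_layer(hidden, 1, rng)

    def __call__(self, loss, domain):
        one_hot = [1.0 if k == domain else 0.0 for k in range(self.num_domains)]
        x = [loss] + one_hot                              # concatenation: loss ⊕ domain vector
        h = [max(0.0, v) for v in linear(*self.fc1, x)]   # ReLU
        return sigmoid(linear(*self.fc2, h)[0])           # adaptive weight w_i


# Usage: weights for three (loss, domain index) pairs from three source domains
net = ReweightingNet(num_domains=3)
weights = [net(loss, d) for loss, d in [(2.3, 0), (0.4, 1), (1.1, 2)]]
```

Because the domain enters as a one-hot block, each domain effectively receives its own learned offset in the first layer, which is one way the network can assign different weights to equal losses from different domains.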

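Step 1. In the first step, we update the current parameter $\theta$ of the task network with the imbalanced set. The image class is predicted by $f$ with the softmax function as follows:

$$\hat{y}_c = \frac{\exp\big(f_c(x;\theta)\big)}{\sum_{j=1}^{C} \exp\big(f_j(x;\theta)\big)}, \quad c = 1, \dots, C, \tag{2}$$

where $\hat{y} = (\hat{y}_1, \dots, \hat{y}_C)$ is the probability of the image classes, $x$ is a training image sampled from the imbalanced set $\mathcal{D}^{\text{imb}}$, and $C$ is the number of classes. The task loss $\ell_i(\theta)$ is the cross-entropy loss between $\hat{y}_i$ and $y_i$. While the task loss can take a different form depending on the target task, we use the cross-entropy loss for simplicity. Unlike [19], which uses only the loss of the task network, our reweighting network takes both the task loss and the one-hot domain conditional vector $d_i$, so that $\mathcal{V}$ can recognize the characteristics of each domain. The weight for loss $\ell_i$ is thus estimated as follows:

$$w_i = \mathcal{V}\big(\ell_i(\theta) \oplus d_i;\, \phi\big), \tag{3}$$

where $\oplus$ represents a concatenation operation. The current parameter $\theta$ of the task network is then updated by minimizing the following weighted loss with one gradient step:

$$\hat{\theta}(\phi) = \theta - \alpha \nabla_\theta \frac{1}{n} \sum_{i=1}^{n} w_i\, \ell_i(\theta), \tag{4}$$

where $\alpha$ is the step size and $n$ is the number of samples in a mini-batch.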
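The quantities in Step 1 can be sketched in a few lines of plain Python: a numerically stable softmax, the per-sample cross-entropy task loss, and the mini-batch average of the weighted losses that Step 1 minimizes. The logits and weights in the usage lines are made-up values; in the method itself the weights come from the reweighting network rather than being set by hand.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    # Task loss l_i: negative log-probability of the true class
    return -math.log(softmax(logits)[label])


def weighted_task_loss(batch_logits, labels, weights):
    """Return (1/n) * sum_i w_i * l_i over a mini-batch of size n, plus the per-sample losses."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    weighted = sum(w * loss_i for w, loss_i in zip(weights, losses)) / len(losses)
    return weighted, losses


# Usage with made-up logits for a two-class toy batch
batch_logits = [[0.0, 0.0], [2.0, 0.0]]
labels = [0, 1]
weights = [1.0, 0.5]  # hypothetical outputs of the reweighting network
loss, per_sample = weighted_task_loss(batch_logits, labels, weights)
```

Returning the per-sample losses alongside the weighted mean is convenient because the same losses are also the inputs to the reweighting network.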

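Step 2. The parameter $\phi$ of the reweighting network is then optimized on the balanced meta dataset, guided by the updated parameter $\hat{\theta}(\phi)$, using the following update:

$$\phi' = \phi - \beta\, \nabla_\phi \frac{1}{m} \sum_{j=1}^{m} \mathcal{L}^{\text{meta}}\big(y_j^{\text{meta}}, \hat{y}_j^{\text{meta}}\big), \quad \hat{y}_j^{\text{meta}} = \mathrm{softmax}\Big(f\big(x_j^{\text{meta}};\, \hat{\theta}(\phi)\big)\Big), \tag{5}$$

where $\mathcal{L}^{\text{meta}}$ is a cross-entropy loss, $x_j^{\text{meta}}$ is a training image with label $y_j^{\text{meta}}$ sampled from the balanced meta dataset $\mathcal{D}^{\text{meta}}$, $\hat{y}_j^{\text{meta}}$ is the estimated class vector with the parameter $\hat{\theta}(\phi)$, $\beta$ is the step size, and $m$ is the size of the mini-batch.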
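Because $\hat{\theta}(\phi)$ depends on $\phi$ only through the weights, the gradient in Step 2 can be unrolled with the chain rule. The expansion below is a sketch under two assumptions that mirror the Meta-Weight-Net analysis [19] rather than a formula stated in this paper: Step 1 is a single SGD step $\hat{\theta}(\phi) = \theta - \frac{\alpha}{n}\sum_{i} w_i(\phi)\,\nabla_\theta \ell_i(\theta)$ with $w_i(\phi) = \mathcal{V}(\ell_i(\theta) \oplus d_i;\phi)$, and the losses $\ell_i(\theta)$ are treated as constants with respect to $\phi$. Writing $\mathcal{L}^{\text{meta}}(\hat{\theta})$ for the mean meta loss over the $m$ balanced samples:

```latex
\nabla_\phi \mathcal{L}^{\text{meta}}\big(\hat{\theta}(\phi)\big)
  = \Big(\frac{\partial \hat{\theta}(\phi)}{\partial \phi}\Big)^{\!\top} \nabla_{\hat{\theta}} \mathcal{L}^{\text{meta}}(\hat{\theta})
  = -\frac{\alpha}{n} \sum_{i=1}^{n} G_i \, \nabla_\phi \mathcal{V}\big(\ell_i(\theta) \oplus d_i;\, \phi\big),
\qquad
G_i = \nabla_{\hat{\theta}} \mathcal{L}^{\text{meta}}(\hat{\theta})^{\top} \nabla_\theta \ell_i(\theta).
```

Substituting into a gradient step $\phi' = \phi - \beta \nabla_\phi \mathcal{L}^{\text{meta}}$ gives $\phi' = \phi + \frac{\alpha\beta}{n}\sum_i G_i\, \nabla_\phi \mathcal{V}(\ell_i(\theta) \oplus d_i;\phi)$, so $\phi$ moves along the ascent direction of $\sum_i G_i\, \mathcal{V}(\ell_i(\theta) \oplus d_i;\phi)$. In words, the update favors larger weights for training samples whose loss gradients agree with the gradient on the balanced meta data ($G_i > 0$) and smaller weights for samples whose gradients conflict with it ($G_i < 0$).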

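Hyperparameters: mini-batch sizes $n$, $m$; maximum iteration $T$; step sizes $\alpha$, $\beta$
Input: source training domains $\mathcal{D}$; imbalanced training set $\mathcal{D}^{\text{imb}}$; balanced training set $\mathcal{D}^{\text{meta}}$
Output: task parameter $\theta$

1:procedure Training($\mathcal{D}^{\text{imb}}$, $\mathcal{D}^{\text{meta}}$)
2:     Initialize parameters $\theta$ and $\phi$
3:     for $t = 1$ to $T$ do
4:         Sample a mini-batch $\{(x_i, y_i, d_i)\}_{i=1}^{n}$ from $\mathcal{D}^{\text{imb}}$
5:         Sample a mini-batch $\{(x_j^{\text{meta}}, y_j^{\text{meta}})\}_{j=1}^{m}$ from $\mathcal{D}^{\text{meta}}$
6:         Compute the weights $w_i$ by Eq. (3)
7:         Compute $\hat{\theta}(\phi)$ by Eq. (4)
8:         Update $\phi$ to $\phi'$ by Eq. (5)
9:         Update $\theta$ to $\theta'$ by Eq. (6)
10:     end for
11:end procedure
Algorithm 1 Self-Balanced Domain Generalization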

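Step 3. In the last step, the updated parameter $\phi'$ is used to produce the adaptive weight, and the parameter $\theta$ is optimized as follows:

$$\theta' = \theta - \alpha \nabla_\theta \frac{1}{n} \sum_{i=1}^{n} w_i'\, \ell_i(\theta), \quad w_i' = \mathcal{V}\big(\ell_i(\theta) \oplus d_i;\, \phi'\big), \tag{6}$$

where $\ell_i(\theta)$ is the cross-entropy loss computed from the prediction of the task network with the parameter $\theta$. The SBDG training procedure is summarized in Alg. 1, and the convergence of this update procedure can be proven mathematically in a similar way to [19].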
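To see the three steps of Alg. 1 end to end, the sketch below runs them on a deliberately tiny problem with hand-derived gradients, so it needs only the Python standard library. Several simplifications are assumptions made for illustration and are not the paper's setup: the task network is a binary logistic-regression model (equivalent to a two-class softmax with cross-entropy), and the reweighting network is reduced to a single linear layer plus sigmoid, $\mathcal{V}(\ell \oplus d;\phi) = \sigma(a\,\ell + b_d)$ with $\phi = (a, b)$. The helper names (`inner_step`, `meta_objective`, `sbdg_step`) are hypothetical. Step 2 applies the chain rule through the single SGD step of Step 1 and treats the training losses as fixed inputs to the reweighting function.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def loss_and_grad(theta, x, y):
    """Logistic loss log(1 + exp(-y * theta.x)) and its gradient w.r.t. theta (y in {-1, +1})."""
    margin = y * dot(theta, x)
    loss = max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    coef = -y * sigmoid(-margin)
    return loss, [coef * xi for xi in x]


def reweight(phi, loss, domain):
    """Toy stand-in for V(loss ⊕ one_hot(domain); phi): sigmoid(a * loss + b[domain])."""
    a, b = phi
    return sigmoid(a * loss + b[domain])


def weighted_step(theta, weights, grads, alpha):
    # theta - alpha * (1/n) * sum_i w_i * grad_i: one SGD step on the weighted loss
    n = len(grads)
    return [t - alpha / n * sum(w * g[k] for w, g in zip(weights, grads))
            for k, t in enumerate(theta)]


def inner_step(theta, phi, train_batch, alpha):
    """Step 1: theta_hat(phi), plus per-sample losses, gradients and weights at theta."""
    losses, grads = zip(*(loss_and_grad(theta, x, y) for x, y, _ in train_batch))
    weights = [reweight(phi, loss, d) for loss, (_, _, d) in zip(losses, train_batch)]
    return weighted_step(theta, weights, grads, alpha), losses, grads, weights


def meta_objective(theta, phi, train_batch, meta_batch, alpha):
    """Mean meta loss on the balanced batch at theta_hat(phi); handy for gradient checks."""
    theta_hat = inner_step(theta, phi, train_batch, alpha)[0]
    return sum(loss_and_grad(theta_hat, x, y)[0] for x, y in meta_batch) / len(meta_batch)


def sbdg_step(theta, phi, train_batch, meta_batch, alpha, beta):
    """One iteration; train_batch holds (x, y, domain) and meta_batch holds (x, y)."""
    n, (a, b) = len(train_batch), phi
    # Step 1: weighted SGD step on the imbalanced batch -> theta_hat(phi)
    theta_hat, losses, grads, weights = inner_step(theta, phi, train_batch, alpha)
    # Step 2: chain rule through theta_hat (losses are constant inputs to V)
    meta_grads = [loss_and_grad(theta_hat, x, y)[1] for x, y in meta_batch]
    g_meta = [sum(col) / len(meta_batch) for col in zip(*meta_grads)]
    grad_a, grad_b = 0.0, [0.0] * len(b)
    for w, loss, g, (_, _, d) in zip(weights, losses, grads, train_batch):
        # d(meta loss)/dz_i with z_i = a * loss_i + b[d_i]
        coeff = -alpha / n * dot(g_meta, g) * w * (1.0 - w)
        grad_a += coeff * loss
        grad_b[d] += coeff
    phi_new = (a - beta * grad_a, [bk - beta * gk for bk, gk in zip(b, grad_b)])
    # Step 3: recompute the weights with phi' and update theta from the same batch
    new_weights = [reweight(phi_new, loss, d) for loss, (_, _, d) in zip(losses, train_batch)]
    return weighted_step(theta, new_weights, grads, alpha), phi_new, grad_a, grad_b


# Toy usage: domain 0 dominates the imbalanced batch; the meta batch is class-balanced
train = [([1.0, 0.2], 1, 0), ([0.9, -0.1], 1, 0), ([1.1, 0.0], 1, 0), ([-0.2, 1.0], -1, 1)]
meta = [([1.0, 0.1], 1), ([-0.1, 1.1], -1)]
theta, phi = [0.0, 0.0], (0.0, [0.0, 0.0])
for _ in range(50):
    theta, phi, _, _ = sbdg_step(theta, phi, train, meta, alpha=0.5, beta=1.0)
```

Because every gradient is written out by hand, `meta_objective` can be used to check the Step 2 hypergradient against finite differences. A practical implementation would instead rely on an automatic-differentiation framework to differentiate through $\hat{\theta}(\phi)$.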

3 Experiments

Figure 3: The accuracy (red line) and the adaptive weight (blue line) are depicted in the top row, and the number of images for each class in the three source domains of the VLCS [22] dataset is shown in the bottom row. The target domain is set to (a) V, (b) L, (c) C, and (d) S, respectively. We measure the accuracy before training step 1 and the weight at training step 3 (see Sec. 2.2).

3.1 Experimental Settings

Dataset. We evaluate our method on the VLCS dataset [22], a commonly used domain generalization benchmark for image classification. It contains images from four different datasets (domains): VOC2007 (V), LabelMe (L), Caltech (C), and SUN09 (S). Each domain includes five classes: bird, car, chair, dog, and person. Following the standard protocol of [23], each domain of the VLCS dataset is randomly divided into a 70% training set and a 30% test set.

Implementation details. We constructed the balanced meta dataset by picking 12 samples from each domain-class pair, using random oversampling, from 30% of the training set. The mini-batch size of the reweighting network was 9 for each domain, totaling 27 for the three source domains. The learning rate of the reweighting network is set as …. Following previous works [9, 1], we apply random cropping, horizontal flipping, and RGB-to-grayscale conversion.
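A minimal sketch of the balanced meta-set construction described above, assuming the held-out pool is a list of `(domain, class, sample_id)` triples. Pairs with at least 12 samples are subsampled without replacement, and smaller pairs are topped up by random oversampling (sampling with replacement). How the paper handles duplicates or overlap with the imbalanced set is not specified, so these details, the function names, and the toy counts below are assumptions. The second helper draws 9 samples per domain to form a 27-sample meta mini-batch for three source domains.

```python
import random
from collections import defaultdict


def build_balanced_meta_set(pool, per_pair=12, seed=0):
    """Return exactly per_pair (domain, cls, sample_id) entries for every (domain, cls) pair."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for domain, cls, sample_id in pool:
        groups[(domain, cls)].append(sample_id)
    meta = []
    for (domain, cls), ids in sorted(groups.items()):
        if len(ids) >= per_pair:
            chosen = rng.sample(ids, per_pair)                         # without replacement
        else:
            chosen = ids + rng.choices(ids, k=per_pair - len(ids))     # random oversampling
        meta.extend((domain, cls, sid) for sid in chosen)
    return meta


def sample_meta_batch(meta, per_domain=9, seed=0):
    """Draw per_domain entries from each domain of the balanced meta set."""
    rng = random.Random(seed)
    by_domain = defaultdict(list)
    for item in meta:
        by_domain[item[0]].append(item)
    return [item for d in sorted(by_domain) for item in rng.sample(by_domain[d], per_domain)]


# Hypothetical pool: source domains L, S, V (target C held out) with uneven class counts
classes = ["bird", "car", "chair", "dog", "person"]
counts = {"L": [3, 40, 8, 2, 90], "S": [20, 30, 50, 5, 60], "V": [15, 25, 35, 10, 70]}
pool = [(d, c, f"{d}_{c}_{i}") for d, sizes in counts.items()
        for c, n in zip(classes, sizes) for i in range(n)]
meta = build_balanced_meta_set(pool)
batch = sample_meta_batch(meta)
```

Oversampling keeps every domain-class pair at the same size even when a pair has only a handful of examples, which is what makes the meta set uniform across both domains and classes.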

To evaluate the effect of our framework, we adopted two baseline methods as the task network $f$: (1) MLDG [18] and (2) RSC [1]. MLDG [18] is a well-known meta-learning-based DG algorithm that meta-learns how to generalize across domains. RSC [1] is the current state-of-the-art DG algorithm; it discards a small percentage of the features with the highest gradients at each epoch and trains the model with the remaining information. When training with MLDG [18], the gradients of the inner and outer updates are added with equal weight, and the reweighted loss is used in both the inner and outer updates. The batch size is 128 for each source domain, and the learning rate is set as …. The remaining hyperparameters for RSC are the same as in the original method [1]. We used AlexNet pre-trained on ImageNet as the backbone network for each baseline method, and we selected the best model via leave-one-domain-out validation [24].
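To make the MLDG integration concrete, the objective below is a sketch of how we read this setup, written in the form of the original MLDG objective [18] with the coefficient between its two terms set to 1 (gradients of the inner and outer updates added equally). As in MLDG, the source domains in each iteration are split into meta-train and meta-test domains with mini-batches $\bar{\mathcal{B}}$ and $\breve{\mathcal{B}}$; $\gamma$ denotes MLDG's inner step size (a symbol introduced here), and the per-sample weights $w_i$ from the reweighting network enter both terms.

```latex
\min_{\theta}\; F(\theta) + G\big(\theta - \gamma \nabla_\theta F(\theta)\big),
\qquad
F(\theta) = \frac{1}{|\bar{\mathcal{B}}|} \sum_{i \in \bar{\mathcal{B}}} w_i\, \ell_i(\theta),
\qquad
G(\theta) = \frac{1}{|\breve{\mathcal{B}}|} \sum_{i \in \breve{\mathcal{B}}} w_i\, \ell_i(\theta).
```

Under this reading, a minority-domain sample with a large weight influences both the simulated training step inside $G$ and the meta-test loss that judges it.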

 

Method                   Caltech101  LabelMe  SUN09   VOC2007  Avg.
Norm. source variance    (0.016)     (0.169)  (0.160) (0.150)
Deep-All [9]             96.25       59.72    64.51   70.58    72.76
CIDDG [25]               88.83       63.06    62.10   64.38    69.59
Undo-Bias [26]           93.63       63.49    61.32   69.99    72.11
MMD-AAE [10]             94.40       62.60    64.40   67.70    72.28
Epi-FCR [11]             94.10       64.30    65.90   67.10    72.90
JiGen [9]                96.93       60.90    64.30   70.62    73.19
MASF [24]                94.78       64.90    67.64   69.14    74.11
MLDG [18]                94.40       61.30    65.90   67.70    72.30
RSC [1]                  97.61       61.86    68.32   73.93    75.43
Ours with MLDG           95.05       63.97    66.08   67.45    73.14
Ours with RSC            96.47       65.85    68.68   72.59    75.90

Table 1: Performance comparison (accuracy, %) with state-of-the-art methods on the VLCS dataset. "Ours with MLDG" denotes the combination of SBDG with MLDG [18] as the task network, and "Ours with RSC" denotes that the task network is RSC [1]. The values in parentheses give the normalized variance of the source domains corresponding to each target domain.

3.2 Results

In Table 1, we compare our model with several state-of-the-art models on the VLCS dataset [22] in terms of classification accuracy on different target domains. The normalized variance of the source domains is computed for each target domain, e.g., using the ‘L’, ‘S’, and ‘V’ source domains for the ‘C’ target domain. While most algorithms work reasonably well on the ‘C’ domain, they perform poorly when ‘L’, ‘S’, or ‘V’ is the target domain. We attribute this to the fact that the source variance is much lower when ‘C’ is the target (0.016) than in the other cases (0.150 to 0.169), as shown in Table 1. Our method significantly improves accuracy when the variance is high (i.e., when the target domain is ‘L’ or ‘S’). We achieve 75.90% average accuracy when the task network is RSC [1]. In particular, the performance gain over RSC [1] is 3.99% on the ‘L’ domain and 0.36% on the ‘S’ domain. A similar performance gain is also observed when MLDG [18] is employed as the task network.
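The gains quoted above can be recomputed directly from the Table 1 entries. The short check below (plain Python, values copied from the table) also makes explicit that the RSC-based variant trades small drops on C and V for the larger gains on L and S.

```python
# Accuracies (%) from Table 1, ordered as Caltech101 (C), LabelMe (L), SUN09 (S), VOC2007 (V)
domains = ["C", "L", "S", "V"]
rsc = [97.61, 61.86, 68.32, 73.93]
ours_rsc = [96.47, 65.85, 68.68, 72.59]

# Per-domain gain of "Ours with RSC" over RSC, in percentage points
gains = {d: round(o - r, 2) for d, o, r in zip(domains, ours_rsc, rsc)}
avg_rsc = sum(rsc) / len(rsc)             # 75.43 in Table 1
avg_ours = sum(ours_rsc) / len(ours_rsc)  # about 75.8975, reported as 75.90
print(gains, round(avg_ours - avg_rsc, 2))
```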

3.3 Analysis

Table 2 shows the effectiveness of the conditional domain vector used in the auxiliary reweighting network. Providing domain information to the reweighting network yields more appropriate weights: the conditional domain vector makes the reweighting network aware that the imbalance differs across source domains. In Fig. 3, we show the adaptive weight (blue line) and the accuracy (red line) for each class in each source domain at the same iteration (top row), together with the number of training samples for each class and domain (bottom row). Note that the accuracy is measured before training step 1 and the adaptive weight is measured at training step 3. The results show that the domain-class imbalance degrades the performance of some classes and domains (green boxes in Fig. 3). The weight learned by the proposed reweighting network is inversely related to the accuracy, indicating that it places larger weights on the losses of poorly performing classes and domains so that learning proceeds in a balanced manner. This demonstrates that the proposed method copes well with the domain-class imbalance problem.

 

Domain vector   Caltech101  LabelMe  SUN09   VOC2007  Avg.
without         95.97       62.31    67.58   71.34    74.30
with            96.47       65.85    68.68   72.59    75.90

Table 2: Ablation study for the effect of the conditional domain vector on the VLCS dataset. We use RSC [1] as the task network.

4 Conclusion

In this work, we introduce a novel self-balanced domain generalization method to deal with imbalanced distributions in terms of domains and classes. We predict adaptive weights of losses through the auxiliary reweighting network in the training phase to effectively balance the impact of training samples. Our method outperforms prior approaches, especially when the source domains are largely imbalanced. Furthermore, we hope that this method can offer a new research direction for addressing the bias problem in domain adaptation.

Acknowledgement

This research was supported by the Yonsei University Research Fund of 2021 (2021-22-0001).

References

  • [1] Zeyi Huang, Haohan Wang, Eric P. Xing, and Dong Huang, “Self-challenging improves cross-domain generalization,” in ECCV, 2020.
  • [2] Antonio Torralba and Alexei A Efros, “Unbiased look at dataset bias,” in CVPR, 2011.
  • [3] Yaroslav Ganin and Victor Lempitsky, “Unsupervised domain adaptation by backpropagation,” in ICML, 2015.
  • [4] Shai Ben-David, John Blitzer, Koby Crammer, and Fernando Pereira, “Analysis of representations for domain adaptation,” in NeurIPS, 2007.
  • [5] Baochen Sun, Jiashi Feng, and Kate Saenko, “Return of frustratingly easy domain adaptation,” in AAAI, 2016.
  • [6] Riccardo Volpi, Hongseok Namkoong, Ozan Sener, John C Duchi, Vittorio Murino, and Silvio Savarese, “Generalizing to unseen domains via adversarial data augmentation,” in NeurIPS, 2018.
  • [7] Xiangyu Yue, Yang Zhang, Sicheng Zhao, Alberto Sangiovanni-Vincentelli, Kurt Keutzer, and Boqing Gong, “Domain randomization and pyramid consistency: Simulation-to-real generalization without accessing target domain data,” in ICCV, 2019.
  • [8] Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A Efros, “Unpaired image-to-image translation using cycle-consistent adversarial networks,” in ICCV, 2017.
  • [9] Fabio M Carlucci, Antonio D’Innocente, Silvia Bucci, Barbara Caputo, and Tatiana Tommasi, “Domain generalization by solving jigsaw puzzles,” in CVPR, 2019.
  • [10] Haoliang Li, Sinno Jialin Pan, Shiqi Wang, and Alex C Kot, “Domain generalization with adversarial feature learning,” in CVPR, 2018.
  • [11] Da Li, Jianshu Zhang, Yongxin Yang, Cong Liu, Yi-Zhe Song, and Timothy M Hospedales, “Episodic training for domain generalization,” in ICCV, 2019.
  • [12] Seonguk Seo, Yumin Suh, Dongwan Kim, Jongwoo Han, and Bohyung Han, “Learning to optimize domain specific normalization for domain generalization,” in ECCV, 2020.
  • [13] Toshihiko Matsuura and Tatsuya Harada, “Domain generalization using a mixture of multiple latent domains,” in AAAI, 2020.
  • [14] Saeid Motiian, Marco Piccirilli, Donald A Adjeroh, and Gianfranco Doretto, “Unified deep supervised domain adaptation and generalization,” in ICCV, 2017.
  • [15] Yujia Li, Kevin Swersky, and Rich Zemel, “Generative moment matching networks,” in ICML, 2015.
  • [16] Chelsea Finn, Pieter Abbeel, and Sergey Levine, “Model-agnostic meta-learning for fast adaptation of deep networks,” in ICML, 2017.
  • [17] Yogesh Balaji, Swami Sankaranarayanan, and Rama Chellappa, “Metareg: Towards domain generalization using meta-regularization,” in NeurIPS, 2018.
  • [18] Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy Hospedales, “Learning to generalize: Meta-learning for domain generalization,” in AAAI, 2018.
  • [19] Jun Shu, Qi Xie, Lixuan Yi, Qian Zhao, Sanping Zhou, Zongben Xu, and Deyu Meng, “Meta-weight-net: Learning an explicit mapping for sample weighting,” in NeurIPS, 2019.
  • [20] Chuang Gan, Tianbao Yang, and Boqing Gong, “Learning attributes equals multi-source domain generalization,” in CVPR, 2016.
  • [21] Sicheng Zhao, Bo Li, Xiangyu Yue, Yang Gu, Pengfei Xu, Runbo Hu, Hua Chai, and Kurt Keutzer, “Multi-source domain adaptation for semantic segmentation,” in NeurIPS, 2019.
  • [22] Chen Fang, Ye Xu, and Daniel N Rockmore, “Unbiased metric learning: On the utilization of multiple datasets and web images for softening bias,” in ICCV, 2013.
  • [23] Muhammad Ghifary, W Bastiaan Kleijn, Mengjie Zhang, and David Balduzzi, “Domain generalization for object recognition with multi-task autoencoders,” in ICCV, 2015.
  • [24] Qi Dou, Daniel Coelho de Castro, Konstantinos Kamnitsas, and Ben Glocker, “Domain generalization via model-agnostic learning of semantic features,” in NeurIPS, 2019.
  • [25] Ya Li, Xinmei Tian, Mingming Gong, Yajing Liu, Tongliang Liu, Kun Zhang, and Dacheng Tao, “Deep domain generalization via conditional invariant adversarial networks,” in ECCV, 2018.
  • [26] Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy M Hospedales, “Deeper, broader and artier domain generalization,” in ICCV, 2017.