Evaluating the fairness of fine-tuning strategies in self-supervised learning

10/01/2021 ∙ by Jason Ramapuram, et al. ∙ Apple Inc.

In this work we examine how fine-tuning impacts the fairness of contrastive Self-Supervised Learning (SSL) models. Our findings indicate that Batch Normalization (BN) statistics play a crucial role, and that updating only the BN statistics of a pre-trained SSL backbone improves its downstream fairness (36% improvement on the worst-subgroup fairness metric). This is competitive with supervised learning, while taking 4.4x less time to train and requiring only 0.35% of the model parameters to be updated. Inspired by recent work in supervised learning, we further find that updating BN statistics and training residual skip connections (12.3% of the parameters) achieves performance close to that of a fully fine-tuned model, while taking 1.33x less time to train.


1 Introduction

SSL is an effective pre-training strategy in the image (Chen et al., 2020; Grill et al., 2020; Caron et al., 2021, 2020; Zbontar et al., 2021; Bardes et al., 2021), language (Devlin et al., 2019), video (Alayrac et al., 2020) and audio (Deng et al., 2009) domains. These large-scale SSL models are trained without the use of (potentially) biased human annotations, and attain better-than-supervised performance when fine-tuned on small supervised datasets. The performance guarantees for many of these large-scale SSL models are strongly coupled with the use of BN (Chen et al., 2020; Zbontar et al., 2021; Caron et al., 2020; Alayrac et al., 2020; Fetterman and Albrecht, 2020). BN (Ioffe and Szegedy, 2015) tends to favor subgroups of the dataset that contain more samples, negatively impacting downstream model performance for under-represented subpopulations.
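To make the BN concern concrete, the following minimal PyTorch sketch illustrates how BN running statistics become dominated by the majority subgroup under imbalance. The subgroup sizes, feature dimensionality and distributions below are illustrative assumptions, not values from the paper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative imbalanced dataset: 95% of samples come from a majority
# subgroup (features centered at 0) and 5% from a minority subgroup
# (features centered at 3). These proportions are assumptions.
majority = torch.randn(950, 8)
minority = torch.randn(50, 8) + 3.0
data = torch.cat([majority, minority])[torch.randperm(1000)]

bn = nn.BatchNorm1d(8)
bn.train()  # train mode: running_mean / running_var update per batch
for batch in data.split(100):
    bn(batch)

# The running mean tracks the pooled, majority-dominated mean (~0.1),
# so at inference the minority subgroup is normalized with statistics
# that describe it poorly.
print(bn.running_mean)        # close to the majority subgroup's mean
print(minority.mean(dim=0))   # ~3.0, far from the running statistics
```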

To understand how SSL model fairness is impacted by fine-tuning, we evaluate a number of tuning strategies. We find that the treatment of BN statistics is a dominant factor in determining downstream fairness. When tuning a linear task head, freezing both the BN statistics and the backbone parameters reduces performance by up to 36% on the worst-subgroup fairness metric, whereas allowing the BN statistics to update reduces the performance gap against a fully fine-tuned model, while taking 4.4x less time to train and updating only 0.35% of the total model parameters.

2 Results

Our baseline SSL model uses the SimCLR framework and optimization procedure (Chen et al., 2020; Goyal et al., 2017), pre-trained without labels on the Celeb-A train split (162,770 samples). We then attach a linear head to the backbone and evaluate five scenarios inspired by analysis in supervised learning (Frankle et al., 2020): fully fine-tuned (Full FT); frozen backbone, updating residual skip connections and BN statistics (BN Stats+Skip); frozen backbone, updating BN affine parameters and BN statistics (BN Stats+Affine); frozen backbone, updating BN statistics (BN Stats); and fully frozen backbone (Frozen). Updates use supervised information from the Celeb-A train split.
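As a rough illustration of how these scenarios differ, here is a minimal PyTorch sketch. It is not the authors' code: it assumes a torchvision ResNet-50 backbone and reads "residual skip connections" as the parameterized downsample branches. The modes differ only in which parameters receive gradients and whether BN layers stay in training mode so their running statistics update:

```python
import torch.nn as nn
from torchvision.models import resnet50

def configure_backbone(mode: str) -> nn.Module:
    """Freeze a ResNet-50 backbone according to one of the five tuning modes:
    'full_ft' | 'bn_stats_skip' | 'bn_stats_affine' | 'bn_stats' | 'frozen'."""
    backbone = resnet50()  # assume SSL pre-trained weights are loaded here

    # Freeze everything unless fully fine-tuning.
    for p in backbone.parameters():
        p.requires_grad = (mode == "full_ft")

    if mode == "bn_stats_affine":
        # Unfreeze only the BN scale (weight) and shift (bias) parameters.
        for m in backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.requires_grad = True
                m.bias.requires_grad = True
    elif mode == "bn_stats_skip":
        # Unfreeze the parameterized residual skip ("downsample") branches.
        for name, p in backbone.named_parameters():
            if "downsample" in name:
                p.requires_grad = True

    # 'frozen' keeps BN in eval mode, so running statistics stay fixed;
    # every other mode leaves BN in train mode, so running_mean and
    # running_var keep updating from the supervised data.
    backbone.train(mode != "frozen")
    return backbone
```

In every mode the attached linear head is trained as usual; the distinguishing factor is which backbone parameters receive gradients and whether the BN buffers (running statistics) are allowed to change.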

Table 1: Attribute-wise gap and worst F1 scores for SSL (Frozen), SSL (BN Stats), SSL (BN Stats+Affine), SSL (BN Stats+Skip), SSL (Full FT) and Supervised, with the per-attribute under-representation statistic ρ: bald (.02), double chin (.05), chubby (.06), wearing necktie (.07), wearing necklace (.12), no beard (.17), straight hair (.21), big lips (.24), wavy hair (.32), male (.42), wearing lipstick (.47), and all. The attributes presented are approximately uniformly distributed across ρ, displaying balanced and imbalanced model behaviour. We note that the gap (worst) statistic is smaller (larger) for attributes with large ρ. [Per-model table values omitted.]

Figure 1: Top: total number of parameters and buffers updated per model. Bottom: distribution of the worst statistic over the 1560 (t, a) combinations on the Celeb-A test split. Individual attribute thresholds are calibrated on the Celeb-A train split with no conditioning. A distribution with more mass at higher values indicates a fairer model in the absolute sense. SSL (Full FT), SSL (BN Stats+Skip) and Supervised perform similarly across the board, with the fully fine-tuned SSL model being marginally better. SSL (Frozen) drastically underperforms these models; however, the performance gap can be closed by simply allowing the running statistics to update, yielding SSL (BN Stats).

Training Procedure

We evaluate on the Celeb-A test split (19,962 samples) using the 40-dimensional binary attribute prediction task. We baseline our SSL model against a strong supervised learning model that uses the same ResNet-50 (He et al., 2016) backbone. To choose hyper-parameters, we perform a random search (twenty trials) across optimizers (Huo et al., 2021; Kingma and Ba, 2015), learning rates and schedulers (Goyal et al., 2017; Smith and Topin, 2017), weight decay, training epochs, and linear warmup intervals. An equivalent compute budget is used for the SSL fine-tuning and supervised models, and we report results for the best-performing model from each search.
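A hedged sketch of such a random search follows. The optimizer and scheduler families echo the citations above, but the specific names, ranges and grids here are illustrative assumptions, not the paper's search space:

```python
import random

random.seed(0)

# Illustrative search space; names and ranges are assumptions.
SPACE = {
    "optimizer": ["adam", "sgd"],
    "scheduler": ["linear_warmup_cosine", "one_cycle"],
    "lr": lambda: 10 ** random.uniform(-4, -1),
    "weight_decay": lambda: 10 ** random.uniform(-6, -3),
    "epochs": [30, 60, 90],
    "warmup_epochs": [0, 5, 10],
}

def sample_trial() -> dict:
    """Draw one hyper-parameter configuration uniformly at random."""
    return {k: (v() if callable(v) else random.choice(v))
            for k, v in SPACE.items()}

# Twenty trials, each trained under the same compute budget; the
# best-performing configuration from each search is reported.
trials = [sample_trial() for _ in range(20)]
```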

Quantifying fairness is challenging due to its multifaceted nature (Garg et al., 2020), with some facets mutually incompatible (Friedler et al., 2016). In this work, taking the F1 score as a performance measure, a fair model maximizes the F1 score of the worst-treated subgroup (worst) and minimizes performance differences across subgroups (gap). Concretely, let $\mathcal{T}$ be the set of all 40 Celeb-A categories and let $F1_{t|a}$ denote the F1 score achieved on task $t \in \mathcal{T}$ for the subpopulation with attribute $a \in \mathcal{T}$ true (for example, $F1_{\text{wearing hat}|\text{blurry}}$ is the F1 score for blurry images when predicting wearing hat), with $F1_{t|\neg a}$ defined equivalently for populations with $a$ false. We define

$$\mathrm{gap}_{t|a} = \left| F1_{t|a} - F1_{t|\neg a} \right|, \qquad \mathrm{worst}_{t|a} = \min\left( F1_{t|a},\ F1_{t|\neg a} \right) \tag{1}$$

Model performance across the 1560 $(t, a)$ combinations (on-diagonal $t = a$ terms are omitted to ensure all metric components are well-defined) is summarized in Figure 1 and Table 1.
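As a concrete companion to Eq. (1), here is a minimal sketch (an illustration, not the authors' evaluation code) that computes gap and worst from per-subgroup F1 scores with scikit-learn; the binary attribute matrices y_true and y_pred are assumed inputs:

```python
import numpy as np
from sklearn.metrics import f1_score

def fairness_stats(y_true: np.ndarray, y_pred: np.ndarray):
    """Compute gap and worst (Eq. 1) for every off-diagonal (t, a) pair.

    y_true, y_pred: binary matrices of shape (num_samples, 40) holding
    ground-truth and predicted Celeb-A attributes. Returns two dicts
    mapping (t, a) -> gap_{t|a} and (t, a) -> worst_{t|a} over the
    40 * 39 = 1560 combinations."""
    num_attrs = y_true.shape[1]
    gap, worst = {}, {}
    for t in range(num_attrs):
        for a in range(num_attrs):
            if t == a:  # omit on-diagonal terms (see text)
                continue
            pos = y_true[:, a] == 1  # subpopulation with attribute a true
            neg = ~pos               # subpopulation with attribute a false
            f1_pos = f1_score(y_true[pos, t], y_pred[pos, t], zero_division=0)
            f1_neg = f1_score(y_true[neg, t], y_pred[neg, t], zero_division=0)
            gap[(t, a)] = abs(f1_pos - f1_neg)
            worst[(t, a)] = min(f1_pos, f1_neg)
    return gap, worst
```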

3 Conclusion

Models that produce fair representation vectors can directly improve the fairness of any downstream task that uses them, and, through developer APIs, can affect fairness at large scale. In this work, we quantify the effect that various fine-tuning strategies have on downstream fairness and observe the crucial role played by BN statistics. We demonstrate that updating only the BN statistics minimizes the gap between an end-to-end trained model and a frozen SSL model, improving worst-case subgroup fairness by 36% while taking 4.4x less time to train.

References