
A Closer Look at Loss Weighting in Multi-Task Learning

11/20/2021
by   Baijiong Lin, et al.

Multi-Task Learning (MTL) has achieved great success in various fields; however, how to balance different tasks so as to avoid negative effects is still a key problem. To achieve task balancing, many works balance task losses or gradients. In this paper, we unify eight representative task balancing methods from the perspective of loss weighting and provide a consistent experimental comparison. Moreover, we surprisingly find that training an MTL model with random weights sampled from a distribution can achieve performance comparable to state-of-the-art baselines. Based on this finding, we propose a simple yet effective weighting strategy called Random Loss Weighting (RLW), which can be implemented with only one additional line of code over existing works. Theoretically, we analyze the convergence of RLW and reveal that RLW has a higher probability of escaping local minima than models with fixed task weights, resulting in better generalization ability. Empirically, we extensively evaluate the proposed RLW method on six image datasets and four multilingual tasks from the XTREME benchmark to show its effectiveness compared with state-of-the-art strategies.



1 Introduction

Multi-Task Learning (MTL) (ZhangY21; Vandenhende21) aims to jointly train several related tasks to improve their generalization performance by leveraging common knowledge among them. Since this learning paradigm can not only significantly reduce the model size and increase the inference speed but also improve the performance, it has been successfully applied in various fields of deep learning, such as Computer Vision (CV) (Vandenhende21), Natural Language Processing (NLP) (czy21), reinforcement learning (ZhangY21), and so on. However, when the tasks are not sufficiently related, which may be reflected by conflicting gradients or dominating gradients (pcgrad), training a multi-task model is more difficult than training the tasks separately, because some tasks dominantly influence the model parameters, leading to unsatisfactory performance on the other tasks. This phenomenon is known as the task balancing problem (Vandenhende21) in MTL. Recently, several works have tackled this issue from an optimization perspective by balancing task losses or gradients.

In this paper, we investigate eight State-Of-The-Art (SOTA) task balancing approaches and unify them as loss weighting strategies. According to how they generate loss weights, these methods can be divided into three types: the solving approach, which obtains the weights by directly solving a quadratic optimization problem in a multi-objective formulation (sk18); the calculating approach, which computes the weights, e.g., by projecting conflicting gradients (pcgrad); and the learning approach, which learns the weights in a gradient descent manner (chen2018gradnorm). On the other hand, the implementation details of these SOTA weighting methods differ, for example in the backbone networks used for training or the metrics used for evaluation, leading to inconsistent comparisons. We therefore provide a unified testbed on six CV datasets and four multilingual problems from the XTREME benchmark (hu20b) for a fair comparison of these SOTA weighting strategies.

In addition, inspired by the dynamic weighting processes in those SOTA strategies, where loss weights vary over training iterations or epochs, we ask a simple question:

what will happen if an MTL model is trained with random loss weights?

Specifically, in each training iteration, we first sample the loss weights from a distribution, normalize them, and then minimize the aggregated loss weighted by the normalized random weights. Surprisingly, this seemingly unreliable method can not only converge but also achieve performance comparable to the SOTA weighting strategies. Based on this observation, we propose a simple yet effective weighting strategy for MTL, called Random Loss Weighting (RLW). RLW can be implemented by adding only one line of code over existing works and does not incur any additional computational cost. An implementation example of RLW in PyTorch (PaszkeGMLBCKLGA19) is shown below.

import torch
import torch.nn.functional as F

outputs = model(inputs)
loss = criterion(outputs, labels)  # [1, task_num] vector of per-task losses
weight = F.softmax(torch.randn(task_num), dim=-1)  # RLW is only this!
loss = torch.sum(loss * weight)    # aggregate the losses with the random weights
optimizer.zero_grad()
loss.backward()
optimizer.step()

To show the effectiveness of RLW, we provide both theoretical analyses and empirical evaluations. Firstly, when optimized by stochastic gradient descent or its variants, the objective function of RLW can be considered a doubly stochastic optimization problem, where the randomness comes from both the mini-batch sampling of the data for each task and the random sampling of the loss weights. From this perspective, we give a convergence analysis for RLW. Besides, we show that RLW has a higher probability of escaping local minima than methods with fixed loss weights, resulting in better generalization performance. Empirically, as described before, we compare RLW with SOTA weighting approaches on six CV datasets and four multilingual problems to show its competitive performance.

In summary, the main contributions of this paper are four-fold.

  • We provide a unified testbed on six multi-task computer vision datasets and four multilingual problems from the XTREME benchmark for a fair comparison among eight SOTA weighting methods and the proposed RLW method.

  • We propose a simple yet effective RLW strategy, which we believe is an overlooked baseline in MTL.

  • We provide a convergence guarantee and an effectiveness analysis for RLW.

  • Experiments show that RLW achieves performance comparable to SOTA weighting methods without incurring any additional computational cost.

2 Preliminary

Suppose there are $T$ tasks and task $t$ has its corresponding dataset $\mathcal{D}_t$. An MTL model usually contains two parts of parameters: task-sharing parameters $\theta$ and task-specific parameters $\{\psi_t\}_{t=1}^T$. For example, in CV, $\theta$ usually denotes the parameters of the feature extractor shared by all tasks and $\psi_t$ represents the task-specific output module for task $t$. Let $\ell_t$ denote a task-specific loss function for task $t$. Then the objective function of an MTL model can be formulated as

$$\min_{\theta, \{\psi_t\}_{t=1}^T} \sum_{t=1}^{T} \lambda_t \mathcal{L}_t(\theta, \psi_t; \mathcal{D}_t), \qquad (1)$$

where $\mathcal{L}_t(\theta, \psi_t; \mathcal{D}_t)$ denotes the average loss on $\mathcal{D}_t$ for task $t$ and $\{\lambda_t\}_{t=1}^T$ are task-specific loss weights with the constraint that $\lambda_t \ge 0$ for all $t$. When minimizing Eq. (1) by Stochastic Gradient Descent (SGD) or its variants, the task-specific parameters $\psi_t$ are simply updated based on the corresponding task gradient $\nabla_{\psi_t} \mathcal{L}_t$, while the task-sharing parameters $\theta$ should be updated by all the task losses jointly as

$$\theta \leftarrow \theta - \eta \sum_{t=1}^{T} \lambda_t \nabla_{\theta} \mathcal{L}_t(\theta, \psi_t; \mathcal{D}_t), \qquad (2)$$

where $\eta$ is a learning rate. Obviously, for the update of the task-sharing parameters $\theta$, the loss weights (i.e., $\{\lambda_t\}$ in Eq. (1)) influence $\theta$ essentially through the aggregated gradient, and weighting the gradients during the backward process as in Eq. (2) has the same effect as weighting the losses when the same weights are used. Therefore, we can ignore the level on which the weights act and focus on the generation of the weights. For simplicity, both types of weighting are referred to as loss weighting in the following sections.
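To make this equivalence concrete, the following minimal sketch (a hypothetical two-task toy model, not the paper's networks) numerically checks that weighting the losses before summation, as in Eq. (1), produces the same gradient on the task-sharing parameters as weighting the per-task gradients, as in Eq. (2):

import torch

# Hypothetical toy model: a shared layer (theta) and two task-specific heads (psi_t).
shared = torch.nn.Linear(4, 4)
heads = [torch.nn.Linear(4, 1) for _ in range(2)]
x = torch.randn(8, 4)
weights = torch.tensor([0.3, 0.7])  # fixed loss weights for the check

# Loss weighting: backpropagate the weighted sum of the task losses once.
losses = torch.stack([heads[t](shared(x)).pow(2).mean() for t in range(2)])
(weights * losses).sum().backward()
loss_weighted_grad = shared.weight.grad.clone()
shared.zero_grad()

# Gradient weighting: compute each task's gradient w.r.t. the shared layer, then combine.
combined = torch.zeros_like(shared.weight)
for t in range(2):
    loss_t = heads[t](shared(x)).pow(2).mean()
    combined += weights[t] * torch.autograd.grad(loss_t, shared.weight)[0]

print(torch.allclose(loss_weighted_grad, combined, atol=1e-6))  # True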

Apparently, the simplest method for loss weighting is to set the same weight for every task, i.e., without loss of generality, $\lambda_t = 1$ for all $t$. This approach is a common baseline in MTL and is called Equally Weighting (EW) in this paper. To tackle the task balancing problem and improve the performance of MTL models, several works study how to generate appropriate weights. In this paper, we investigate eight SOTA weighting strategies, i.e., Gradient Normalization (GradNorm) (chen2018gradnorm), Uncertainty Weights (UW) (kgc18), MGDA (sk18), Dynamic Weight Average (DWA) (ljd19), Projecting Conflicting Gradients (PCGrad) (pcgrad), Gradient Sign Dropout (GradDrop) (ChenNHLKCA20), Impartial Multi-Task Learning (IMTL) (liu2021imtl), and Gradient Vaccine (GradVac) (wang2021gradient).

According to the different ways of generating loss weights, we categorize these loss weighting strategies into three types: the learning approach, the solving approach, and the calculating approach. Both GradNorm and UW consider the loss weights in Eq. (1) as learnable parameters and explicitly optimize them by gradient descent. MGDA casts MTL as a multi-objective optimization problem and obtains the loss weights in Eq. (1) by solving a quadratic programming problem. DWA, PCGrad, GradDrop, and GradVac directly compute the weights by combining the gradients and/or losses of all the tasks. IMTL is a hybrid strategy that combines the learning and calculating approaches. We summarize these strategies from the perspective of loss weighting in Table 5 in Appendix A.

We now unify those eight SOTA methods as loss weighting strategies, i.e., as generating the loss weights in Eq. (1). Noticeably, almost all the existing strategies except EW incur intensive computation to generate the loss weights in every iteration, such as solving a quadratic optimization problem in MGDA or operating on high-dimensional gradients in PCGrad, GradDrop, IMTL, and GradVac. Different from those strategies, the proposed RLW strategy generates loss weights by sampling, and thus is as efficient as EW without bringing additional computational costs.

3 The RLW Method

In this section, we introduce the proposed RLW method. RLW is a simple loss weighting strategy that considers the loss weights $\lambda = (\lambda_1, \dots, \lambda_T)$ as random variables. Formally, the objective function of the RLW method is formulated as

$$\min_{\theta}\ \mathbb{E}_{\lambda}\left[\sum_{t=1}^{T} \tilde{\lambda}_t \mathcal{L}_t(\theta; \mathcal{D}_t)\right], \qquad (3)$$

where $\mathbb{E}[\cdot]$ denotes the expectation and where we omit the task-specific parameters in Eq. (3) for brevity. To guarantee that the loss weights in $\tilde{\lambda}$ are non-negative, we can first sample $\lambda$ from any distribution $p(\lambda)$ and then normalize $\lambda$ into $\tilde{\lambda}$ via a mapping $f$, where $f$ is a normalization function (for example, the softmax function) whose output lies in the convex hull $\Delta_{T-1}$ in $\mathbb{R}^T$, i.e., $\tilde{\lambda} \in \Delta_{T-1}$ means $\sum_{t=1}^{T} \tilde{\lambda}_t = 1$ and $\tilde{\lambda}_t \ge 0$ for all $t$. Note that in most cases the distribution of $\tilde{\lambda}$ is different from $p(\lambda)$.

In Eq. (3), the loss is usually too complex to compute the expectation exactly, so a stochastic approximation scheme is adopted to minimize Eq. (3). When mini-batch SGD (bottou1991stochastic) or one of its variants is used to minimize Eq. (3), as in most deep learning models, Eq. (3) can be viewed as a doubly stochastic optimization problem, where the randomness comes from both the mini-batch data sampling for each task and the random sampling of the loss weights. In the following, we show that the approximated gradient $g = \sum_{t=1}^{T} \tilde{\lambda}_t \nabla_{\theta} \mathcal{L}_t(\theta; \mathcal{B}_t)$ is an unbiased estimate of the gradient of the objective in Eq. (3), where $\mathcal{B}_t$ denotes a mini-batch of data sampled from task $t$. Specifically, since $\mathcal{B}_t$ is a mini-batch sampled from $\mathcal{D}_t$ to calculate the stochastic gradient that approximates the full gradient for task $t$, we have $\mathbb{E}_{\mathcal{B}_t}[\nabla_{\theta} \mathcal{L}_t(\theta; \mathcal{B}_t)] = \nabla_{\theta} \mathcal{L}_t(\theta; \mathcal{D}_t)$. Therefore, when we further randomly sample a weight vector $\tilde{\lambda}$, we have

$$\mathbb{E}_{\tilde{\lambda}}\,\mathbb{E}_{\{\mathcal{B}_t\}}\left[\sum_{t=1}^{T} \tilde{\lambda}_t \nabla_{\theta} \mathcal{L}_t(\theta; \mathcal{B}_t)\right] = \mathbb{E}_{\tilde{\lambda}}\left[\sum_{t=1}^{T} \tilde{\lambda}_t \nabla_{\theta} \mathcal{L}_t(\theta; \mathcal{D}_t)\right] = \nabla_{\theta}\, \mathbb{E}_{\tilde{\lambda}}\left[\sum_{t=1}^{T} \tilde{\lambda}_t \mathcal{L}_t(\theta; \mathcal{D}_t)\right],$$

which verifies that $g$ is an unbiased estimate.
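The unbiasedness can also be checked numerically. The sketch below (toy quadratic losses with hypothetical dimensions, not from the paper) Monte Carlo averages the doubly stochastic gradient and compares it with the correspondingly weighted full gradient:

import torch

torch.manual_seed(0)
task_num, dim, n = 3, 5, 128
theta = torch.randn(dim, requires_grad=True)
A = [torch.randn(n, dim) for _ in range(task_num)]   # toy per-task data
b = [torch.randn(n) for _ in range(task_num)]

def task_grad(t, idx):
    # gradient of the task-t loss mean((A_t[idx] @ theta - b_t[idx])^2) w.r.t. theta
    residual = A[t][idx] @ theta - b[t][idx]
    return torch.autograd.grad(residual.pow(2).mean(), theta)[0]

full_grads = [task_grad(t, torch.arange(n)) for t in range(task_num)]

samples = 5000
avg_grad, avg_weight = torch.zeros(dim), torch.zeros(task_num)
for _ in range(samples):
    w = torch.softmax(torch.randn(task_num), dim=-1)  # random normalized loss weights
    idx = torch.randint(0, n, (32,))                  # random mini-batch per iteration
    avg_grad += sum(w[t] * task_grad(t, idx) for t in range(task_num))
    avg_weight += w
avg_grad, avg_weight = avg_grad / samples, avg_weight / samples

expected = sum(avg_weight[t] * full_grads[t] for t in range(task_num))
print(torch.dist(avg_grad, expected))  # approaches 0 as `samples` grows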

In practice, it is very easy to implement the RLW method without modifying the network architecture or incurring additional computational costs. Specifically, in each iteration, we first sample $\lambda$ from $p(\lambda)$ and normalize it to obtain $\tilde{\lambda}$ via an appropriate normalization function $f$, and then minimize the aggregated loss weighted by $\tilde{\lambda}$. The entire algorithm of RLW (i.e., minimizing Eq. (3)) via SGD is shown in Algorithm 1. Apparently, the only difference between the proposed RLW strategy and the widely used EW strategy is Line 7 in Algorithm 1, which is very easy to implement with only one line of code.

Require: number of iterations $K$, number of tasks $T$, learning rate $\eta$, datasets $\{\mathcal{D}_t\}_{t=1}^T$, weight distribution $p(\lambda)$
1: Randomly initialize $\theta_0$;
2: for $k = 0$ to $K-1$ do
3:     for $t = 1$ to $T$ do
4:         Sample a mini-batch of data $\mathcal{B}_t$ from $\mathcal{D}_t$;
5:         Compute loss $\mathcal{L}_t(\theta_k; \mathcal{B}_t)$;
6:     end for
7:     Sample weights $\lambda$ from $p(\lambda)$ and normalize into $\tilde{\lambda}$ via $f$; // RLW is only this
8:     $\theta_{k+1} \leftarrow \theta_k - \eta \sum_{t=1}^{T} \tilde{\lambda}_t \nabla_{\theta} \mathcal{L}_t(\theta_k; \mathcal{B}_t)$;
9: end for
Algorithm 1 Optimization Algorithm for RLW by SGD
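A minimal PyTorch sketch of Algorithm 1 might look as follows; `model`, `criterions`, and `loaders` are hypothetical stand-ins for the shared-plus-heads network, the per-task loss functions, and per-task mini-batch iterators, respectively:

import torch
import torch.nn.functional as F

def train_rlw(model, criterions, loaders, task_num, iters, lr):
    """Minimal sketch of Algorithm 1: one RLW update per iteration.

    `model(x, t)` is assumed to run the shared encoder plus the head of
    task t; `criterions[t]` and `loaders[t]` are per-task loss functions
    and (infinite) data iterators; all of these names are hypothetical.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(iters):
        losses = []
        for t in range(task_num):
            inputs, labels = next(loaders[t])          # sample a mini-batch for task t
            losses.append(criterions[t](model(inputs, t), labels))
        losses = torch.stack(losses)
        weight = F.softmax(torch.randn(task_num), dim=-1)  # RLW: sample and normalize weights
        optimizer.zero_grad()
        (weight * losses).sum().backward()
        optimizer.step()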

In this paper, we use six different distributions for $p(\lambda)$ in the proposed RLW method: the uniform distribution on $[0, 1]$ (denoted by Uniform), the standard normal distribution (denoted by Normal), the Dirichlet distribution (denoted by Dirichlet), the Bernoulli distribution (denoted by Bernoulli), the Bernoulli distribution with the additional constraint that exactly one weight is non-zero (denoted by constrained Bernoulli), and the normal distribution whose mean and variance are themselves sampled from a uniform distribution for each task (denoted by random Normal). We set $f$ to a normalization that rescales the sampled weights to sum to one if $p(\lambda)$ is the Bernoulli or constrained Bernoulli distribution, and to the softmax function for the other types of distribution. When sampling from the first five distributions, $\mathbb{E}[\tilde{\lambda}]$ is simply proportional to the equal-weight vector, so the comparison with the EW strategy is fair. When $p(\lambda)$ is the random Normal distribution, each $\lambda_t$ is sampled from a normal distribution with random mean and variance, so the expectation in Eq. (3) is intractable to compute, and combining RLW with such a distribution further shows its effectiveness.
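The six samplers can be implemented as below; since the source leaves the exact distribution parameters unspecified, the specific values here (e.g., a Bernoulli probability of 0.5 and a Dirichlet concentration of 1) are illustrative assumptions:

import torch
import torch.nn.functional as F

def sample_rlw_weights(task_num, dist="normal"):
    """Sample unnormalized weights lambda, then normalize to lambda-tilde."""
    if dist == "uniform":
        return F.softmax(torch.rand(task_num), dim=-1)        # Uniform on [0, 1]
    if dist == "normal":
        return F.softmax(torch.randn(task_num), dim=-1)       # standard Normal
    if dist == "dirichlet":
        lam = torch.distributions.Dirichlet(torch.ones(task_num)).sample()
        return F.softmax(lam, dim=-1)
    if dist == "bernoulli":
        lam = torch.bernoulli(0.5 * torch.ones(task_num))     # 0/1 weights
        return lam / lam.sum().clamp(min=1.0)                 # rescale to sum to one; guard all-zero draw
    if dist == "constrained_bernoulli":
        lam = torch.zeros(task_num)
        lam[torch.randint(task_num, (1,))] = 1.0              # exactly one active task
        return lam
    if dist == "random_normal":
        mean, std = torch.rand(task_num), torch.rand(task_num)
        return F.softmax(mean + std * torch.randn(task_num), dim=-1)
    raise ValueError(dist)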

When sampling from a Bernoulli distribution, the weight of each task is either 0 or 1, i.e., $\lambda_t \in \{0, 1\}$ for all $t$. In this way, only a subset of tasks contributes to updating the task-sharing parameters $\theta$ in each iteration, which can be viewed as mini-batch sampling at the task level. With the additional constraint that exactly one weight is non-zero, only one task is involved in the update of the task-sharing parameters in each iteration. Although some works (DongWHYW15; LiuGHDDW15; SogaardG16; SubramanianTBP18; SanhWR19; LiuHCG19) adopt this strategy to train an MTL model, it is a special case of the proposed RLW strategy, and beyond existing works, we also provide theoretical analyses to show the effectiveness of the proposed RLW method.

4 Analysis

As the optimization procedure of the RLW method can be viewed as doubly stochastic optimization, this strategy increases the randomness compared with methods that fix the loss weights and optimize via SGD (denoted by FW), of which EW is a special case. In this section, we analyze how the extra randomness from the loss weight sampling affects the convergence and effectiveness of RLW compared with FW.

Where there is no risk of confusion, we simply write $\ell_t(\theta)$ for the loss function of task $t$ for brevity in this section and Appendix B. For the ease of analysis, we need to make the following assumption.

Assumption 1.

The loss function $\ell_t$ of task $t$ is Lipschitz continuous w.r.t. $\theta$, so that $\|\nabla \ell_t(\theta)\|$ is bounded by a constant $G$. The loss weights in $\tilde{\lambda}$ satisfy $\sum_{t=1}^{T} \tilde{\lambda}_t = 1$.

In the following theorem, we analyze the convergence property of Algorithm 1 for RLW.

Theorem 1.

Suppose the loss function $\ell_t$ of task $t$ is $\mu_t$-strongly convex, define $\mu = \min_t \mu_t$, and denote by $\theta_k$ the solution in the $k$-th iteration. When the step size, or equivalently the learning rate $\eta$ in SGD, is a sufficiently small constant, under Assumption 1 we have

$$\mathbb{E}\|\theta_k - \theta_*\|^2 \le (1 - 2\eta\mu)^k\, \|\theta_0 - \theta_*\|^2 + \frac{\eta \sigma^2}{2\mu}, \qquad (4)$$

where $\theta_*$ denotes the optimal solution and $\sigma^2$ bounds the second moment of the doubly stochastic gradient. Then, for any positive $\epsilon$, $\mathbb{E}\|\theta_k - \theta_*\|^2 \le \epsilon$ can be achieved after $k = \mathcal{O}(\log(1/\epsilon))$ iterations with an appropriately small step size $\eta = \mathcal{O}(\epsilon)$.

Theorem 1 shows that the RLW method with a fixed step size converges linearly up to a radius around the optimal solution, similar to FW according to the properties of the standard SGD method (moulines2011non; NeedellSW16). Although RLW has a larger variance term than FW, possibly requiring more iterations for RLW to reach the same accuracy as FW, our empirical experiments in Appendix C.1 show that this has little impact in practice.

We next analyze the effectiveness of the RLW method from the perspective of stochastic optimization. It has been observed that the SGD method can escape sharp local minima and converge to better solutions than Gradient Descent (GD) under various settings with the help of noisy gradients (hardt2016train; kleinberg2018alternative). Inspired by those works, we first provide the following Theorem 2 and then leverage it to show that the extra randomness in the RLW method helps RLW escape sharp local minima and achieve better generalization performance than FW.

For the ease of presentation, we introduce some notation. We consider the update step of these stochastic methods as $\theta_{k+1} = \theta_k - \eta(\nabla F(\theta_k) + \xi)$, where $\xi$ is a noise term with $\mathbb{E}[\xi] = 0$ and $\mathbb{E}\|\xi\|^2 \le r^2$; here $r$ denotes the intensity of the noise. For the analysis, we construct an intermediate sequence $\nu_k = \theta_k - \eta \nabla F(\theta_k)$. Then we get $\mathbb{E}[\nu_{k+1}] = \nu_k - \eta \nabla\, \mathbb{E}_{\xi}[F(\nu_k - \eta\xi)]$, so the sequence $\{\nu_k\}$ can be regarded as approximately using GD to minimize the function $\mathbb{E}_{\xi}[F(\theta - \eta\xi)]$, i.e., $F$ convolved with the noise.

Theorem 2.

Suppose $\nabla F$ is $L$-Lipschitz continuous and the step size $\eta$ is sufficiently small. If the loss function of task $t$ is $c$-one-point strongly convex w.r.t. a local minimum $\theta_*$ after being convolved with the noise $\xi$, i.e., $\langle -\nabla\, \mathbb{E}_{\xi}[F(\theta - \eta\xi)],\ \theta_* - \theta \rangle \ge c\, \|\theta_* - \theta\|^2$, then under Assumption 1, after sufficiently many iterations, $\|\theta_k - \theta_*\|$ is, with high probability, bounded by a radius that grows with the noise intensity $r$ and shrinks as $c$ grows.

Firstly, Theorem 2 only requires that the objective is one-point strongly convex w.r.t. $\theta_*$ after being convolved with the noise $\xi$, which is much weaker than a convexity assumption and can hold for deep neural networks. Moreover, this theorem implies that for both the RLW and FW methods, their solutions have a high probability of getting close to a local minimum $\theta_*$, depending on the noise $\xi$. Note that by adding extra noise, sharp local minima disappear and only flat local minima with large diameters survive (kleinberg2018alternative). On the other hand, those flat local minima may satisfy the one-point strong convexity assumption in Theorem 2; thus, the diameter of the flat local minimum converged to is affected by the noise intensity. Due to the extra randomness from the loss weight sampling, the RLW method provides stronger noise (i.e., a larger $r$) than FW (see Appendix B.3). Hence RLW can better escape sharp local minima and converge to flatter local minima than FW, resulting in better generalization performance.
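The following toy simulation (all constants arbitrary and purely illustrative, not from the paper) makes this intuition concrete: noisy gradient steps on a 1-D loss with one sharp and one flat minimum end up in the flat basin more often as the noise intensity grows.

import math, random

# Toy 1-D loss with a sharp minimum at x = 0 (width 0.3) and a flat
# minimum at x = 3 (width 1.5); grad() is its analytic derivative.
def grad(x):
    sharp = (2 * x / 0.09) * math.exp(-(x / 0.3) ** 2)
    flat = (2 * (x - 3) / 2.25) * math.exp(-((x - 3) / 1.5) ** 2)
    return sharp + flat

def flat_basin_rate(noise_std, lr=0.02, steps=2000, trials=200):
    hits = 0
    for _ in range(trials):
        x = 0.0                                   # start at the sharp minimum
        for _ in range(steps):
            x -= lr * (grad(x) + noise_std * random.gauss(0.0, 1.0))
        hits += abs(x - 3.0) < 1.5                # ended in the flat basin?
    return hits / trials

print(flat_basin_rate(0.3), flat_basin_rate(5.0))  # stronger noise escapes the sharp basin more often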

5 Experiments

In this section, we empirically evaluate the proposed RLW method by conducting experiments on six computer vision datasets (i.e., NYUv2, CityScapes, CelebA, PASCAL-Context, Office-31, and Office-Home) and four multilingual problems from the XTREME benchmark (hu20b). Due to the page limit, the experimental results on the CityScapes, CelebA, Office-31, and Office-Home datasets and on two of the multilingual problems are deferred to Appendix C.

5.1 Datasets

The NYUv2 dataset (silberman2012indoor) is an indoor scene understanding dataset consisting of video sequences recorded by the RGB and depth cameras of the Microsoft Kinect. It contains 795 training and 654 validation images with ground truths. This dataset includes three tasks: 13-class semantic segmentation, depth estimation, and surface normal prediction.

The PASCAL-Context dataset (MottaghiCLCLFUY14) is an annotation extension of the PASCAL VOC 2010 challenge. It contains 10,103 images, divided into two parts: 4,998 for training and 5,105 for validation. We consider two annotated tasks in this dataset: 21-class semantic segmentation and 7-class human part segmentation. Following (ManinisRK19), we generate two additional tasks, saliency estimation and surface normal estimation, whose ground-truth labels are computed by label distillation using pretrained state-of-the-art models (BansalCRGR17; ChenZPSA18).

The XTREME benchmark (hu20b) is a large-scale multilingual multi-task benchmark for evaluating cross-lingual generalization, which covers fifty languages and contains nine tasks. We conduct experiments on four tasks from this benchmark: Named Entity Recognition (NER), Part-Of-Speech (POS) tagging, Natural Language Inference (NLI), and Paraphrase Identification (PI). For each task, we construct a multilingual problem by choosing the four languages with the largest amounts of data. More details are provided in Appendix C.6.

5.2 Implementation Details

The network architecture used adopts the hard-parameter-sharing pattern (Caruana93), which shares the bottom layers of the network across all tasks and uses separate top layers for each task; a minimal sketch is given below. Architectures with other parameter sharing schemes are evaluated in Section 5.7. The implementation details for each dataset are introduced in the following.
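As an illustration, a hard-parameter-sharing model might be sketched as follows (the layer sizes, the toy encoder, and the classification heads are hypothetical, not the networks used in the experiments):

import torch.nn as nn

class HardSharingModel(nn.Module):
    def __init__(self, task_num, feat_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(            # task-sharing parameters (bottom layers)
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, feat_dim),
        )
        self.heads = nn.ModuleList(              # task-specific parameters (top layers)
            [nn.Linear(feat_dim, 10) for _ in range(task_num)]
        )

    def forward(self, x):
        z = self.encoder(x)                      # shared features
        return [head(z) for head in self.heads]  # one output per task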

For the NYUv2 dataset, the DeepLabV3+ architecture (ChenZPSA18) is used. Specifically, a ResNet-50 network pretrained on the ImageNet dataset, with dilated convolutions (YuKF17), is used as the shared encoder among tasks, and the Atrous Spatial Pyramid Pooling (ASPP) module (ChenZPSA18) is used as the task-specific head for each task. Input images are resized to a fixed resolution. The Adam optimizer (kingma2014adam) with weight decay is used for training and the batch size is set to 8. We use the cross-entropy loss, the $L_1$ loss, and the cosine loss as the loss functions of the semantic segmentation, depth estimation, and surface normal prediction tasks, respectively. For the PASCAL-Context dataset, the network architecture is similar to that on the NYUv2 dataset, with a shallower ResNet-18 network used as the shared encoder due to constraints on computing resources. All input images are resized to a fixed resolution. The Adam optimizer with weight decay is applied for training and the batch size is set to 12. The cross-entropy loss is used for the two segmentation tasks and the saliency estimation task, while the surface normal estimation task uses the $L_1$ loss. For each multilingual problem in the XTREME benchmark, a pretrained multilingual BERT (mBERT) model (DevlinCLT19) implemented via the open-source transformers library (WolfDSCDMCRLFDS20) is used as the shared encoder among languages, and a fully connected layer is used as the language-specific output layer for each language. The Adam optimizer with weight decay is used for training and the batch size is set to 32. The cross-entropy loss is used for all four multilingual problems.
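For the multilingual setup, a hedged sketch is given below; the checkpoint name matches the public mBERT release, while the label count and the [CLS] pooling are illustrative assumptions rather than the paper's exact configuration:

import torch.nn as nn
from transformers import AutoModel

class MultilingualModel(nn.Module):
    def __init__(self, languages, num_labels=2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("bert-base-multilingual-cased")
        hidden = self.encoder.config.hidden_size
        self.heads = nn.ModuleDict(              # one output layer per language
            {lang: nn.Linear(hidden, num_labels) for lang in languages}
        )

    def forward(self, input_ids, attention_mask, lang):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]        # [CLS] token representation
        return self.heads[lang](cls)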

5.3 Evaluation Metric

To measure the performance of MTL models with a scalar metric, for homogeneous MTL problems (e.g., the Office-31 dataset), which contain tasks of the same type such as classification, we directly average the performance metrics over tasks. For heterogeneous MTL problems (e.g., the NYUv2 dataset), which contain tasks of different types, following (ManinisRK19; Vandenhende21) we compute the average relative improvement over the EW method on each metric of each task as

$$\Delta_p = 100\% \times \frac{1}{T} \sum_{t=1}^{T} \frac{1}{M_t} \sum_{m=1}^{M_t} \frac{(-1)^{s_{t,m}} \left(P_{t,m} - P^{\mathrm{EW}}_{t,m}\right)}{P^{\mathrm{EW}}_{t,m}},$$

where $M_t$ denotes the number of metrics in task $t$, $P_{t,m}$ denotes the performance of a task balancing strategy on the $m$-th metric of task $t$, $P^{\mathrm{EW}}_{t,m}$ is defined similarly for the EW strategy, and $s_{t,m}$ is set to 0 if a higher value indicates better performance for the $m$-th metric of task $t$ and to 1 otherwise.
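In code, the metric might be computed as follows (the container layout, per-task lists of metric values and a matching table of higher-is-better flags, is a hypothetical convention):

def delta_p(results, ew_results, higher_better):
    """Average relative improvement over EW, in percent.

    results/ew_results: dict mapping task -> list of metric values;
    higher_better: dict mapping task -> list of bools, one per metric.
    """
    total = 0.0
    for task in results:
        m = len(results[task])
        task_sum = 0.0
        for i in range(m):
            sign = 1.0 if higher_better[task][i] else -1.0
            base = ew_results[task][i]
            task_sum += sign * (results[task][i] - base) / base
        total += task_sum / m                    # average over the task's metrics
    return 100.0 * total / len(results)          # average over tasks, in percent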

5.4 Results on the NYUv2 Dataset

The results on the NYUv2 validation dataset are shown in Table 1. Noticeably, the proposed RLW strategy achieves performance comparable to the SOTA baseline methods. Firstly, RLW with all six weight distributions consistently outperforms EW, which implies that training in a doubly stochastic manner leads to better generalization. Secondly, RLW achieves a balanced improvement across all tasks; that is, RLW has comparable or better performance on each metric of each task compared with EW, resulting in a large $\Delta_p$. In contrast, many baseline methods achieve unbalanced performance across the tasks; for example, IMTL significantly outperforms the other methods on the normal prediction task but performs unsatisfactorily on the other two tasks. This balance is one advantage of the proposed RLW strategy, since MTL aims to improve the generalization performance of every task as much as possible. Thirdly, RLW with some distributions (i.e., constrained Bernoulli and Normal) improves the generalization performance by more than 1%, which is significantly better than the baseline methods. Notably, RLW with the constrained Bernoulli distribution entirely dominates not only the EW method but also some baseline methods, such as GradNorm, UW, DWA, and GradDrop, on every task, which demonstrates the effectiveness of the RLW method.

Moreover, we compare the average time of training one epoch for each loss weighting strategy with the same batch size (i.e., 8) on a single NVIDIA GeForce RTX 3090 GPU. The training speed of each method relative to the EW method is reported in Table 1. Noticeably, the proposed RLW strategy is as computationally efficient as EW, while some baseline methods are compute-intensive. For example, PCGrad and GradVac take about twice the time of EW per epoch because they compute per-task gradients of the shared parameters, and MGDA spends considerable time solving a complex quadratic programming problem. Furthermore, the overhead of those baseline methods increases as the network becomes deeper, while the proposed RLW strategy is architecture-agnostic and always as efficient as EW.

Combining the above analyses, we believe the proposed RLW method is an effective and efficient loss weighting strategy for MTL.

Weighting | Segmentation | Depth | Surface Normal
Strategy | mIoU↑ Pix Acc↑ | Abs Err↓ Rel Err↓ | Angle Distance↓ (Mean, Median) | Within t°↑ (11.25, 22.5, 30)
EW 53.91 75.56 0.3840 0.1567 23.6338 17.2451 34.94 60.65 71.81
GradNorm 53.81 75.35 0.3863 0.1556 23.6106 17.2565 34.98 60.58 71.76
UW 53.15 75.41 0.3817 0.1576 23.6487 17.2040 34.98 60.71 71.80
MGDA 53.66 75.37 0.3864 0.1610 23.4757 16.9912 35.44 61.17 72.16
DWA 53.33 75.42 0.3834 0.1556 23.5806 17.1242 35.18 60.88 71.91
PCGrad 53.34 75.43 0.3857 0.1600 23.2293 16.6966 36.09 61.80 72.66
GradDrop 53.80 75.56 0.3857 0.1587 23.8726 17.1406 35.10 60.72 71.60
IMTL 52.90 74.88 0.3883 0.1632 23.0534 16.5304 36.30 62.20 73.08
GradVac 53.52 75.43 0.3840 0.1559 23.2892 16.8601 35.67 61.53 72.46
RLW (Uniform) 54.09 75.78 0.3826 0.1563 23.6272 17.2711 34.73 60.67 71.87
RLW (Normal) 54.19 75.98 0.3789 0.1570 23.1984 16.7944 35.71 61.74 72.77
RLW (Dirichlet) 53.54 75.45 0.3834 0.1547 23.6392 17.0715 35.28 60.92 71.88
RLW (Bernoulli) 53.72 75.62 0.3850 0.1610 23.1413 16.6591 36.08 61.98 72.86
RLW (constrained Bernoulli) 54.32 75.78 0.3779 0.1533 23.2101 16.9354 35.41 61.44 72.58
RLW (random Normal) 54.08 75.77 0.3815 0.1581 23.5598 16.9577 35.53 61.20 72.13
Table 1: Performance on the NYUv2 validation dataset with three tasks: 13-class semantic segmentation, depth estimation, and surface normal prediction. The best results for each task on each measure are highlighted in bold. ↑ (↓) indicates that the higher (lower) the result, the better the performance.

5.5 Results on the PASCAL-Context Dataset

The results on the PASCAL-Context validation dataset are shown in Table 2. The empirical observations are similar to those on the NYUv2 dataset. Specifically, RLW with different distributions outperforms EW, which means RLW has better generalization ability. Most baseline methods perform unsatisfactorily on this dataset, and the best baseline, MGDA, achieves the largest $\Delta_p$ of 0.18%; thus, RLW outperforms many baseline methods. Moreover, RLW with the constrained Bernoulli distribution achieves the highest improvement of 0.46%. Furthermore, RLW does achieve a more balanced improvement than some baseline methods. For example, GradNorm does not perform well on the saliency estimation task, leading to the lowest $\Delta_p$, and although IMTL significantly improves the performance of the surface normal estimation task, it performs unsatisfactorily on the other tasks, especially semantic segmentation. The training speed of each method relative to the EW method is similar to that on the NYUv2 dataset and is hence omitted from Table 2.

Weighting | SS | HPS | Saliency | Surface Normal
Strategy | mIoU↑ | mIoU↑ | mIoU↑ maxF↑ | Angle Distance↓ (Mean, RMSE) | Within t°↑ (11.25, 22.5, 30)
EW 64.52 58.71 64.31 77.11 17.6444 26.1634 42.24 75.93 86.97
GradNorm 64.13 58.49 61.64 72.46 18.0455 26.4642 41.03 74.90 86.26
UW 63.72 59.13 64.47 77.26 17.4962 26.0463 42.64 76.48 87.35
MGDA 63.34 58.86 64.79 77.54 17.3070 25.8584 43.30 77.14 87.79
DWA 64.32 58.61 64.30 77.15 17.4065 25.9242 42.72 76.71 87.59
PCGrad 63.58 58.68 63.79 76.71 17.2376 25.8572 43.69 77.11 87.70
GradDrop 64.04 59.36 62.31 73.10 17.3246 25.8966 43.29 76.92 87.63
IMTL 62.67 58.35 62.92 73.21 16.8026 25.4852 45.02 78.49 88.65
GradVac 62.99 58.63 64.30 77.15 17.1852 25.7621 43.55 77.41 88.00
RLW (Uniform) 63.52 59.03 63.75 76.71 17.0261 25.6528 44.21 77.77 88.24
RLW (Normal) 64.14 58.43 64.05 76.86 17.1794 25.7734 43.69 77.31 87.91
RLW (Dirichlet) 63.97 58.88 64.30 77.19 17.2147 25.8093 43.67 77.33 87.89
RLW (Bernoulli) 64.34 58.35 64.28 77.03 17.3379 25.9016 43.18 76.93 87.66
RLW (constrained Bernoulli) 65.07 58.52 64.19 76.96 17.3377 25.9005 43.30 77.04 87.69
RLW (random Normal) 64.09 59.15 63.84 76.76 17.3860 25.9472 42.98 76.73 87.54
Table 2: Performance on the PASCAL-Context validation dataset with four tasks: 21-class semantic segmentation (abbreviated as SS), 7-class human parts segmentation (abbreviated as HPS), saliency estimation, and surface normal prediction. The best results for each task on each measure are highlighted in bold. ↑ (↓) indicates that the higher (lower) the result, the better the performance.

5.6 Results on the XTREME benchmark

Weighting POS (F1 Score) PI (Accuracy)
Strategy en zh te vi Avg en zh de es Avg
EW 95.02 88.89 91.16 87.11 90.55 94.09 84.59 89.44 90.24 89.59
GradNorm 95.01 88.91 91.88 87.06 90.71 94.39 85.94 90.99 91.44 90.69
UW 94.89 88.77 90.96 87.12 90.44 93.74 85.44 90.24 91.29 90.18
MGDA 95.08 88.97 92.35 87.12 90.88 94.64 84.99 89.84 90.89 90.09
DWA 95.02 89.03 91.87 87.27 90.80 94.69 84.99 89.49 91.44 90.15
PCGrad 94.85 88.42 90.72 86.71 90.18 94.19 85.49 89.09 91.24 90.00
GradDrop 95.08 89.06 90.65 87.17 90.49 94.29 84.44 89.69 90.94 89.84
IMTL 94.87 88.80 92.18 86.72 90.65 94.54 84.79 90.14 90.99 90.12
GradVac 94.87 88.41 90.62 86.47 90.09 94.29 84.94 89.19 90.89 89.83
RLW (Uniform) 95.06 89.00 92.31 86.93 90.83 94.69 85.79 90.29 91.94 90.68
RLW (Normal) 95.01 88.87 92.86 86.85 90.90 94.39 85.34 90.04 91.84 90.40
RLW (Dirichlet) 95.16 88.96 91.64 87.24 90.75 94.24 84.39 89.99 90.99 89.90
RLW (Bernoulli) 95.13 89.10 91.13 87.03 90.60 95.09 85.89 90.24 91.99 90.80
RLW (constrained Bernoulli) 94.98 89.05 92.33 86.87 90.81 94.69 85.49 90.19 90.99 90.34
RLW (random Normal) 95.07 89.00 91.10 87.25 90.60 94.79 84.94 89.54 90.99 90.07
Table 3: Performance on two multilingual problems, i.e. POS and PI from the XTREME benchmark. The best results for each language are highlighted in bold.

We study four multilingual problems from the XTREME benchmark (hu20b) and show the experimental results for POS and PI in Table 3. Due to the page limit, the results for NLI and NER are placed in Table 12 in the Appendix. Different from the heterogeneous MTL problems on the NYUv2 and PASCAL-Context datasets, in these multilingual problems each language has its own input data, which is usually called a homogeneous MTL problem (ZhangY21). According to the results in Table 3, RLW with diverse distributions still outperforms EW on both multilingual problems, which further shows the effectiveness of RLW on a different type of MTL problem. Besides, RLW achieves comparable and even better performance than the baseline methods; for example, RLW achieves the highest average F1 score on the POS problem and the best average accuracy on the PI problem.

5.7 RLW with Different Architectures

The proposed RLW strategy can be seamlessly combined with other MTL network architectures without increasing the computational cost. To show this, we combine the RLW strategy with three SOTA MTL architectures, i.e., the cross-stitch network (MisraSGH16), the Multi-Task Attention Network (MTAN) (ljd19), and the CNN with Neural Discriminative Dimensionality Reduction layers (NDDR-CNN) (gao2019nddr), and evaluate all methods on the NYUv2 dataset. Experimental results with the MTAN architecture are shown in Table 4. Due to the page limit, the results for the cross-stitch and NDDR-CNN networks are put in Tables 6 and 7 in the Appendix, respectively.

According to the results, we make some observations. Firstly, compared with the performance of the DMTL architecture (i.e., the vanilla hard-parameter-sharing network used in Table 1), the performance of each loss weighting strategy with the deeper MTAN architecture is improved, especially on the surface normal estimation task, owing to the larger capacity of MTAN. Secondly, the RLW strategy with different distributions outperforms EW, which indicates the effectiveness of RLW with more advanced and deeper architectures. Thirdly, compared with the baseline methods, the proposed RLW method achieves competitive performance; for example, RLW with the random Normal distribution improves over EW by 0.76% and is among the top three methods.

Weighting | Segmentation | Depth | Surface Normal
Strategy | mIoU↑ Pix Acc↑ | Abs Err↓ Rel Err↓ | Angle Distance↓ (Mean, Median) | Within t°↑ (11.25, 22.5, 30)
EW 53.77 75.79 0.3789 0.1546 22.8344 16.6021 36.37 62.45 73.32
GradNorm 54.75 75.89 0.3797 0.1537 22.6295 16.2829 37.09 63.12 73.87
UW 54.80 75.97 0.3767 0.1536 22.6744 16.3050 36.90 63.14 73.87
MGDA 53.94 75.97 0.3788 0.1573 22.6157 16.1395 37.24 63.38 73.94
DWA 53.71 75.72 0.3801 0.1539 23.1560 16.8354 35.95 61.85 72.79
PCGrad 53.83 75.70 0.3823 0.1568 22.9481 16.3528 37.05 62.76 73.37
GradDrop 53.83 75.85 0.3754 0.1530 22.7846 16.4156 36.77 62.74 73.51
IMTL 53.39 75.20 0.3807 0.1549 22.2571 15.8336 38.04 64.11 74.59
GradVac 54.52 75.68 0.3755 0.1546 22.9389 16.5692 36.48 62.33 73.20
RLW (Uniform) 54.27 75.51 0.3820 0.1548 22.9640 16.4375 36.89 62.65 73.29
RLW (Normal) 53.70 75.62 0.3791 0.1551 22.8395 16.3328 37.05 62.82 73.44
RLW (Dirichlet) 53.36 75.08 0.3778 0.1514 22.8803 16.3579 36.99 62.80 73.40
RLW (Bernoulli) 53.46 75.74 0.3820 0.1517 22.5642 16.2013 37.21 63.29 73.96
RLW (constrained Bernoulli) 54.11 75.47 0.3821 0.1558 22.7969 16.2204 36.93 62.90 73.56
RLW (random Normal) 54.10 75.77 0.3802 0.1554 22.4400 16.0336 37.74 63.68 74.17
Table 4: Performance under the MTAN architecture on the NYUv2 validation dataset with three tasks: 13-class semantic segmentation, depth estimation, and surface normal prediction. The best results for each task on each measure are highlighted in bold. ↑ (↓) indicates that the higher (lower) the result, the better the performance.

6 Conclusions

In this paper, we have unified eight state-of-the-art task balancing methods from the loss weighting perspective. Based on randomly sampling task weights from a distribution, we have proposed a simple RLW strategy that achieves performance comparable to state-of-the-art baselines. We have analyzed the convergence of the proposed RLW method and the double stochasticity that helps it escape sharp local minima. Finally, we have provided a consistent and fair comparison showing the effectiveness of the proposed RLW method on six computer vision datasets and four multilingual tasks from the XTREME benchmark. In future work, we will apply the proposed RLW method to more fields, such as reinforcement learning.

Appendix

Appendix A Summary of loss weighting strategies

In this section, we summarize the eight SOTA methods introduced in Section 2 from the perspective of loss weighting. We define $G = (g_1, \dots, g_T)$, where $g_t$ denotes the gradient of $\ell_t$ with respect to $\theta$. Let $\mathbb{R}^T_{+}$ denote the non-negative subspace of the $T$-dimensional space $\mathbb{R}^T$, $I$ denote an identity matrix, $\mathrm{diag}(v)$ denote a diagonal matrix with principal diagonal $v$, $\mathrm{ReLU}(\cdot)$ be the ReLU operation, $\mathrm{sign}(\cdot)$ represent the sign function, and $\mathbb{1}[\cdot]$ denote the indicator function. The summary of the different methods is given in Table 5.

Approach | Strategy
- | EW
Learning | GradNorm (with a pre-defined hyperparameter), UW, IMTL-L
Solving | MGDA
Calculating | DWA, PCGrad, GradDrop (with a pre-defined hyperparameter), IMTL-G, GradVac (with a pre-defined hyperparameter)
Sampling | RLW (ours), where $f$ is a normalization function
Table 5: A summary of SOTA weighting strategies from the perspective of loss weighting, grouped by how the loss weights $\tilde{\lambda}$ are generated in each iteration.

Appendix B Proofs of Section 4

B.1 Proof of Theorem 1

Since $\ell_t$ is $\mu_t$-strongly convex w.r.t. $\theta$ for $t = 1, \dots, T$, for any two points $\theta_1$ and $\theta_2$ we have

$$\ell_t(\theta_1) \ge \ell_t(\theta_2) + \langle \nabla \ell_t(\theta_2), \theta_1 - \theta_2 \rangle + \frac{\mu_t}{2}\|\theta_1 - \theta_2\|^2. \qquad (5)$$

Since $\sum_{t=1}^{T} \tilde{\lambda}_t = 1$ with $\tilde{\lambda}_t \ge 0$, we have $\sum_{t=1}^{T} \tilde{\lambda}_t \mu_t \ge \mu$, where $\mu = \min_t \mu_t$. Then, for any $\tilde{\lambda}$, the weighted loss $\sum_{t=1}^{T} \tilde{\lambda}_t \ell_t$ is $\mu$-strongly convex.

With the notations in Theorem 1, expanding the SGD update gives

$$\mathbb{E}\|\theta_{k+1} - \theta_*\|^2 = \mathbb{E}\|\theta_k - \theta_*\|^2 - 2\eta\, \mathbb{E}\langle g_k, \theta_k - \theta_* \rangle + \eta^2\, \mathbb{E}\|g_k\|^2,$$

where $g_k$ denotes the doubly stochastic gradient in the $k$-th iteration. The cross term is bounded via the unbiasedness of $g_k$ and the $\mu$-strong convexity of the weighted loss, while the last term is bounded by $\sigma^2$ via the Cauchy-Schwarz inequality and the Lipschitz continuity in Assumption 1. Then, by defining $\beta = 1 - 2\eta\mu$, we obtain

$$\mathbb{E}\|\theta_{k+1} - \theta_*\|^2 \le \beta\, \mathbb{E}\|\theta_k - \theta_*\|^2 + \eta^2 \sigma^2. \qquad (6)$$

If $\beta < 1$, we recursively apply inequality (6) over the first $k$ iterations and obtain

$$\mathbb{E}\|\theta_k - \theta_*\|^2 \le \beta^k\, \|\theta_0 - \theta_*\|^2 + \frac{\eta \sigma^2}{2\mu}.$$

Thus inequality (4) holds if $0 < \eta < \frac{1}{2\mu}$. Finally, choosing $\eta$ small enough that the residual term $\frac{\eta\sigma^2}{2\mu}$ is at most $\frac{\epsilon}{2}$, and $k$ large enough that $\beta^k \|\theta_0 - \theta_*\|^2$ is at most $\frac{\epsilon}{2}$, yields $\mathbb{E}\|\theta_k - \theta_*\|^2 \le \epsilon$ after $k = \mathcal{O}(\log(1/\epsilon))$ iterations.

B.2 Proof of Theorem 2

Since $\mathbb{E}[\xi] = 0$ and $\mathbb{E}\|\xi\|^2 \le r^2$, the update $\theta_{k+1} = \theta_k - \eta(\nabla F(\theta_k) + \xi)$ can be viewed, through the intermediate sequence $\nu_k = \theta_k - \eta \nabla F(\theta_k)$ defined in Section 4, as approximate gradient descent on $F$ convolved with the noise $\xi$. Since the loss function of task $t$ is $c$-one-point strongly convex w.r.t. the given point $\theta_*$ after being convolved with the noise $\xi$, similarly to inequality (5) we have

$$\langle -\nabla\, \mathbb{E}_{\xi}[F(\nu_k - \eta \xi)],\ \theta_* - \nu_k \rangle \ge c\, \|\theta_* - \nu_k\|^2.$$

Since $\nabla F$ is $L$-Lipschitz continuous, for any two points $\theta_1$ and $\theta_2$ we have

$$\|\nabla F(\theta_1) - \nabla F(\theta_2)\| \le L\, \|\theta_1 - \theta_2\|. \qquad (7)$$

Note that $\|\nu_k - \theta_k\| = \eta\, \|\nabla F(\theta_k)\|$ is bounded under Assumption 1; therefore, the convolved objective inherits the Lipschitz continuity of $F$. Combining these facts, the expected squared distance $\mathbb{E}\|\nu_{k+1} - \theta_*\|^2$ contracts at each step up to an additive term proportional to $\eta^2 r^2$, where the contraction follows from the one-point strong convexity assumption and the additive term follows from the Lipschitz continuity. Choosing the step size $\eta$ sufficiently small and unrolling this recursion over $k$ iterations, together with a standard concentration argument, yields the high-probability bound stated in Theorem 2.