LAW: Learning to Auto Weight

05/27/2019 ∙ by Zhenmao Li, et al. ∙ SenseTime Corporation, Beihang University

Example weighting algorithms are an effective solution to the training bias problem. However, typical methods are limited by human knowledge and require laborious hyperparameter tuning. In this study, we propose a novel example weighting framework called Learning to Auto Weight (LAW), which adaptively learns a weighting policy from data using reinforcement learning (RL). To shrink the huge search space of a complete training process, we divide the training procedure, which consists of numerous iterations, into a small number of stages, and then search for a low-dimensional continuous vector as the action, which determines the weight of each sample. To make training more efficient, we design a novel reward that removes randomness from the RL process. Experimental results demonstrate the superiority of the weighting policy explored by LAW over the standard training pipeline. In particular, compared with baselines, LAW finds a better weighting schedule that achieves higher accuracy on the original CIFAR dataset, and over 10 dataset with 30




1 Introduction

Although the quantity of training samples is critical for current state-of-the-art deep neural networks (DNNs), data quality also has a significant impact on how well DNNs can exploit their capacity across tasks. In supervised learning, it is commonly assumed that training and test examples are drawn i.i.d. from the same distribution. In practice, however, this assumption does not always hold, and training bias problems, mainly label noise and class imbalance, are frequently encountered.

It is widely known that example weighting (in this paper, “example” and “sample” are used interchangeably) is an effective solution to the training bias problem. For the label noise problem, zhang2016understanding establishes that DNNs for image classification trained with stochastic gradient descent (SGD) can easily fit a random labelling of the training data. Therefore, Jiang et al. propose MentorNet jiang2017mentornet , which provides a weighting scheme for a StudentNet so that it focuses on samples whose labels are probably correct. For the class imbalance problem, Ren et al. ren2018learning propose a meta-learning algorithm that learns to assign weights to training examples based on their gradient directions. Compared with common techniques such as re-sampling he2008learning ; maciejewski2011local and cost-sensitive learning tang2008svms ; lin2017focal , this method learns directly from the data rather than requiring prior knowledge.

As a data selection strategy, example reweighting can also be viewed as ‘soft’ sampling, which can help find better solutions and speed up training. Previous studies on curriculum learning bengio2009curriculum and self-paced learning kumar2010self show that feeding a model with data of the proper difficulty at different training stages can significantly improve generalization. However, most curricula in the literature are predefined for specific tasks using prior knowledge. As far as we know, only Fan et al. DBLP:journals/corr/FanTQBL17 propose a framework that learns a policy to automatically select training data using reinforcement learning (RL).

In this paper, we propose a novel example weighting framework, called Learning to Auto Weight (LAW), that adaptively learns a weighting policy from data using RL. Although RL has been successfully employed in various tasks such as network architecture design zoph2016neural and data augmentation cubuk2018autoaugment , standard RL has several inherent problems that make it difficult to learn an effective auto weighting policy directly. The first challenge is the huge search space caused by the numerous iterations of a training process. To alleviate this problem, we divide the training process into a small number of stages, each consisting of successive iterations, so that the number of time steps for policy search is limited to the number of stages. The second challenge is that RL usually uses data inefficiently and cannot make full use of the training set. Hence, we modify the update step of the deep deterministic policy gradient (DDPG) lillicrap2015continuous algorithm to make the best use of the data in the replay buffer, and extend it to multiple processes like other RL methods mnih2016asynchronous . Last but not least, the mini-batches sampled from the whole dataset are composed differently in every episode, which introduces randomness into the policy search: the RL method cannot tell whether a higher reward comes from a better action or just from some favorable data points. This can be considered a kind of credit assignment problem. We solve it with a novel reward measurement: two identical networks are trained with the same data at each step, one for searching actions and the other for reference, and the reward is the difference in accuracy between them.

Based on the above analysis, we first divide the whole training process into a small number of stages, which are sets of iteration steps. Then we use DDPG to search for a deterministic continuous vector as the action that decides the weights of data points. Meanwhile, we train two networks at the same time: one for searching the weighting policy and the other as a reference model for computing the reward. Experimental results demonstrate the superiority of the weighting policy explored by LAW over the standard training pipeline. In particular, compared with baselines, LAW finds a better weighting schedule that achieves higher accuracy on the original CIFAR datasets and considerably higher accuracy on CIFAR datasets contaminated with label noise.

Our contributions are listed as follows:

  • We propose a novel example weighting framework called LAW, which adaptively learns a weighting policy from data based on RL. LAW finds sample weighting schedules that achieve higher accuracy on the original or contaminated CIFAR10 and CIFAR100 datasets, without extra information about the label noise.

  • To shrink the huge search space of a complete training process, our framework operates on stages, each a set of successive iterations, rather than on individual training steps as in standard RL.

  • To make reinforcement learning more efficient, the proposed framework makes full use of the data in the replay buffer and extends the DDPG algorithm to multiple processes.

  • To alleviate the credit assignment problem caused by data randomness, we design a novel reward that removes this randomness and makes reinforcement learning more stable.

2 Related Work

Weighting The practice of weighting each training example has been well investigated in previous studies. Weighting algorithms mainly address two kinds of problems: label noise and class imbalance. If a model converges to the optimal solution on a training set with noisy labels, there can be large performance gaps on the test set; this phenomenon has been analyzed in zhang2016understanding ; neyshabur2017exploring ; arpit2017closer . Various regularization terms on the example weights have been proposed to prevent overfitting to corrupted labels ma2017self ; jiang2015self . Recently, Jiang et al. propose MentorNet jiang2017mentornet , which provides a weighting scheme for a StudentNet so that it focuses on samples whose labels are probably correct. However, to acquire a proper MentorNet, extra information, such as the correct labels on a dataset, must be provided during training. Class imbalance, on the other hand, is usually caused by the cost and difficulty of collecting rarely seen classes. Kahn and Marshall kahn1953methods propose importance sampling, which assigns weights to samples to match one distribution to another. Lin et al. lin2017focal propose the focal loss, which addresses class imbalance with a soft weighting scheme that emphasizes harder examples. Other techniques such as cost-sensitive weighting ting2000comparative ; khan2018cost are also useful for class imbalance problems. Previous methods usually require prior knowledge to determine a specific weighting mechanism, and their performance deteriorates when an accurate description of the dataset is unavailable. To learn from the data instead, Ren et al. ren2018learning propose a meta-learning algorithm that learns to assign weights to training examples based on their gradient directions.

Curriculum Learning Inspired by the observation that humans learn much better when data points are presented in a meaningful order (e.g., from easy to difficult), bengio2009curriculum formalizes a training strategy called curriculum learning, which promotes learning with examples of increasing difficulty. This idea has been empirically verified and applied in a variety of areas (kumar2010self, ; lee2011learning, ; supancic2013self, ; jiang2014easy, ; graves2017automated, ; turian2010word, ). The self-paced method (kumar2010self, ) defines the curriculum by considering easy items with small losses in early stages and adding items with large losses in later stages. peng2019accelerating builds a more efficient batch selection method based on typicality sampling, where the typicality of each sample is estimated by its density. jiang2014self formalizes a curriculum that prefers samples that are both easy and diverse. alain2015variance reduces gradient variance with a sampling proposal proportional to the L2-norm of the gradient. Curricula in the existing literature are usually determined by heuristic rules and thus require laborious tuning of hyperparameters. To the best of our knowledge, only Fan et al. DBLP:journals/corr/FanTQBL17 propose a framework, called NDF, that learns a policy to automatically select training data based on reinforcement learning. Compared with DBLP:journals/corr/FanTQBL17 , our framework acts as a “soft” sampler and learns to weight data points adaptively instead of filtering them, which makes training smoother and more stable.

3 Preliminaries

Plenty of reinforcement learning methods have been proposed mnih2015human ; silver2016mastering ; silver2017mastering ; schulman2017proximal ; lillicrap2015continuous ; mnih2016asynchronous . In a standard reinforcement learning setting, an agent interacts with an environment over a number of discrete time steps. At each time step $t$, the environment emits a state $s_t$ from the state set $\mathcal{S}$; the agent receives the state and selects an action $a_t$ from the set of possible actions $\mathcal{A}$ according to a policy $\pi$, which maps states to actions. After the action $a_t$ is executed, the environment emits both the next state $s_{t+1}$ and a reward $r_t$ to the agent. This process continues step by step until a terminal state is reached. Mathematically, this process can be modeled as a Markov decision process (MDP) defined by a tuple $(\mathcal{S}, \mathcal{A}, P, r, \gamma)$, which contains the elements above together with a transition probability $P(s_{t+1} \mid s_t, a_t)$ from $s_t$ to $s_{t+1}$ and a discount factor $\gamma$. The goal of the agent is to maximize the expected total accumulated reward $\mathbb{E}[R_t]$, where $R_t = \sum_{k \ge 0} \gamma^k r_{t+k}$.

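The state-value function $V^{\pi}(s) = \mathbb{E}[R_t \mid s_t = s]$ is the expected return from state $s$ when following $\pi$; the action-value function $Q^{\pi}(s, a) = \mathbb{E}[R_t \mid s_t = s, a_t = a]$ is the expected return after selecting action $a$ in state $s$. Value-based methods optimize objectives built on these value functions, whereas policy-based methods optimize the policy directly.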

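In value-based methods, function approximators such as neural networks are used to model $Q(s, a; \theta)$. For example, Q-learning aims to find the optimal action-value function $Q^{*}$ with the corresponding greedy policy $\pi(s) = \arg\max_{a} Q^{*}(s, a)$. Its parameters are updated by minimizing the squared temporal-difference error:

$$L(\theta) = \Big( r_t + \gamma \max_{a'} Q(s_{t+1}, a'; \theta) - Q(s_t, a_t; \theta) \Big)^2 \qquad (1)$$

where $\theta$ denotes the function parameters. In state $s_t$, the agent takes action $a_t$, leading to reward $r_t$ and next state $s_{t+1}$; $a'$ is a possible action in the next state. The tuple $(s_t, a_t, r_t, s_{t+1})$ is called a transition. Different versions of Q-learning have similar update functions.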
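The sketch below applies Equation 1 to a lookup table instead of a neural network, which is an illustrative assumption: for a table, a gradient step on the squared error with the target held fixed simply moves $Q(s_t, a_t)$ a fraction `alpha` of the way toward $r_t + \gamma \max_{a'} Q(s_{t+1}, a')$. The learning rate `alpha`, the state and action names, and the toy transition are hypothetical.

```python
from collections import defaultdict


def q_learning_update(q, transition, actions, alpha=0.5, gamma=0.9, terminal=False):
    """Apply one tabular Q-learning update for a transition (s, a, r, s')."""
    s, a, r, s_next = transition
    # Bootstrapped target r + gamma * max_a' Q(s', a'); just r at a terminal state.
    best_next = 0.0 if terminal else max(q[(s_next, b)] for b in actions)
    target = r + gamma * best_next
    td_error = target - q[(s, a)]
    # Step Q(s, a) a fraction alpha of the way toward the target.
    q[(s, a)] += alpha * td_error
    return td_error


q = defaultdict(float)  # unseen (state, action) pairs start at 0.0
q[("s1", "left")] = 2.0
err = q_learning_update(q, ("s0", "right", 1.0, "s1"), actions=["left", "right"])
```

Here the target is $1 + 0.9 \times 2 = 2.8$, so the TD error is 2.8 and $Q(s_0, \text{right})$ moves halfway, to 1.4.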

Policy-based methods also use neural networks as function approximators, but they parameterize the policy $\pi_\theta$ rather than a value function. Basically, they compute an estimator of the policy gradient and apply stochastic gradient ascent. The most commonly used gradient estimator has the following form:

$$\hat{g} = \hat{\mathbb{E}}_t \Big[ \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t \Big] \qquad (2)$$

where $\hat{A}_t$ is an estimator of the advantage function, defined as $A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$.
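For a softmax policy, $\nabla_\theta \log \pi_\theta(a)$ has a closed form, which makes Equation 2 easy to check by hand. The sketch below assumes a single-state (bandit-style) policy whose parameters are one logit per action and averages $\nabla_\theta \log \pi_\theta(a_t)\,\hat{A}_t$ over sampled actions; it is a toy estimator, not the actor used by LAW.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def policy_gradient_estimate(logits, actions, advantages):
    """Sample average of grad log pi(a_t) * A_t for a softmax policy.

    With pi = softmax(theta), d log pi(a) / d theta_j = 1[j == a] - pi(j).
    """
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for a, adv in zip(actions, advantages):
        for j, p in enumerate(probs):
            grad[j] += ((1.0 if j == a else 0.0) - p) * adv
    return [g / len(actions) for g in grad]


# Action 0 had a positive advantage, action 1 a negative one.
g = policy_gradient_estimate([0.0, 0.0], actions=[0, 1], advantages=[1.0, -1.0])
```

A gradient-ascent step along `g` raises the logit of action 0 and lowers that of action 1, as the advantages suggest.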

Value-based methods such as Q-learning are suited to low-dimensional discrete actions, while policy-based methods are suited to both discrete and continuous actions. In this work, our objective is to search for a low-dimensional continuous vector as the action, which decides the weight of each item in a batch, and the action is deterministic. Therefore, we choose the classical policy-based method deep deterministic policy gradient (DDPG) (lillicrap2015continuous, ).

4 Learning to Auto Weight

In a standard SGD training process, one iteration consists of a forward and a backward propagation on an input mini-batch of samples, and the whole training process usually contains $T$ successive iterations before the final model is obtained. In this process, the order in which samples are fed has a significant influence on accuracy; for example, one effective order is from easy to difficult across training stages bengio2009curriculum . Curricula in previous work are usually determined by heuristic rules, whereas in this study we aim to propose a novel weighting framework that learns from data. To this end, we use reinforcement learning to learn an automatic example weighting policy that builds an implicit order of feeding samples. The agent is the sample weighting model to be learned, while the environment is the network to be trained on a given dataset.

Our framework, LAW, is based on DDPG and illustrated in Figure 1, where the left half constitutes the Agent and the right half constitutes the Environment. Given the State, which includes the current training stage, the historical training loss and the validation accuracy, the Actor outputs an action that weights each Item, where each item is encoded by features including its loss, entropy, density and label. We then compute a weighted mean loss from the item weights and use it to train the Target Network for policy searching. At the same time, we update the Reference Network using the original, unweighted loss. The Relative Accuracy, i.e., the difference in accuracy between the Target Network and the Reference Network, serves as the Reward used to update the Critic. Note that the Reference Network and the Target Network are identical except for the losses used to update them.

Figure 1: The framework of LAW.

4.1 Stage-MDP

A complete training process involves thousands of iterations, far too many for the agent to act and update its policy at every one of them. To alleviate this problem, we introduce a sparse updating strategy for the agent, called stage-MDP. We divide the training process into $K$ stages, and the agent is updated once per stage, where each stage contains a fixed set of successive iterations. Since $K \ll T$, where $T$ is the number of training iterations in an episode, the number of decision steps drops from $\mathcal{O}(T)$ to $\mathcal{O}(K)$, and the search space shrinks greatly. The MDP tuple is described in detail below.
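The stage partition can be sketched as follows. The helper names are hypothetical; the example numbers come from Section 5.1 (15k iterations per episode, a 4k-iteration warm-up and 500 iterations per stage), for which the agent makes 22 decisions per episode instead of acting at every iteration.

```python
def stage_boundaries(total_iters, iters_per_stage, warmup_iters=0):
    """Split iterations [warmup_iters, total_iters) into consecutive stages.

    Returns one (first_iter, last_iter) pair per stage, i.e. per agent decision.
    A trailing partial stage is kept so that every iteration belongs to a stage.
    """
    stages = []
    start = warmup_iters
    while start < total_iters:
        end = min(start + iters_per_stage, total_iters) - 1
        stages.append((start, end))
        start = end + 1
    return stages


def stage_of(iteration, iters_per_stage, warmup_iters=0):
    """Index of the stage containing `iteration`, or None during warm-up."""
    if iteration < warmup_iters:
        return None
    return (iteration - warmup_iters) // iters_per_stage


stages = stage_boundaries(15_000, 500, warmup_iters=4_000)
print(len(stages), stages[0], stages[-1])
```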

State: The state of the network at time step $t$ (i.e., at the beginning of a stage) is denoted as $s_t$ and must be informative enough to generate the best action. Considering both effectiveness and time consumption, we use a training phase descriptor as the state, which combines the current training stage, the historical training loss and the historical validation accuracy.

Action: For LAW, the action $a_t$ is a continuous parameter vector that decides the weights of items according to their $d$-dimensional features. Effective feature descriptors are therefore needed. The features we use are as follows:
Training Loss: One practical descriptor is the training loss, which is frequently used in curriculum learning, hard example mining and self-paced methods. The loss describes the difficulty level of a sample. When outliers are present, it can help to discard data points with large losses.

Entropy: For classification tasks, the entropy of the predicted class distribution reflects the hardness of the input sample. Hard examples tend to have large entropy, while samples with small and stable entropy are probably easy.

Density: The density of a sample's feature reveals how informative or typical the sample is and can be used to measure its importance; samples with large density should receive more attention. The density of features could be computed and saved beforehand with Gaussian kernel density estimation peng2019accelerating . For simplicity, we instead compute a similarity matrix from the samples' logits within one batch, and for each sample we average its similarities with the other data points to approximate its density.
Label: We also use the label, since label information can help remove some biases in the dataset, such as class imbalance.
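A minimal sketch of the four item features for one batch, in plain Python. The similarity measure behind the density proxy is not specified above, so cosine similarity between logit vectors is an assumption, as are the function names and the use of the raw class index as the label feature (a one-hot encoding would be a natural alternative).

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cosine(u, v):
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def item_features(batch_logits, labels):
    """Return one [loss, entropy, density, label] row per item in the batch."""
    n = len(batch_logits)
    rows = []
    for i, (z, y) in enumerate(zip(batch_logits, labels)):
        p = softmax(z)
        loss = -math.log(p[y])                               # cross-entropy loss
        entropy = -sum(q * math.log(q) for q in p if q > 0)  # prediction entropy
        # Density proxy: mean similarity to the other items in the same batch.
        sims = [cosine(z, batch_logits[j]) for j in range(n) if j != i]
        density = sum(sims) / len(sims) if sims else 0.0
        rows.append([loss, entropy, density, float(y)])
    return rows


rows = item_features([[2.0, 0.0], [2.0, 0.0], [0.0, 2.0]], labels=[0, 0, 1])
```

The first two items share a logit direction, so their density is higher than that of the third item; the features would then be normalized before weighting, as described next.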

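Input: training data $\mathcal{D}$, number of episodes $M$, batch size $B$, number of steps in one episode $T$, number of steps in one stage $T_s$
  Randomly initialize the actor $\mu(s \mid \theta^{\mu})$ and the critic $Q(s, a \mid \theta^{Q})$ with weights $\theta^{\mu}$ and $\theta^{Q}$. Copy the weights to the target networks $\mu'$ and $Q'$: $\theta^{\mu'} \leftarrow \theta^{\mu}$, $\theta^{Q'} \leftarrow \theta^{Q}$. Initialize the replay buffer $R$
  for episode $= 1, \dots, M$ do
     Randomly initialize the training network with weights $\theta$
     Copy the weights to the reference model: $\theta_{\text{ref}} \leftarrow \theta$. Initialize a random process $\mathcal{N}$ for action exploration
     for $t = 1, \dots, T$ do
        if $t$ is the first step of a stage then
           Collect the state $s$ and select the action $a = \mu(s \mid \theta^{\mu}) + \mathcal{N}_t$
        end if
        Train the reference model for one step. Collect features for the data points in one batch and compute the loss weights according to Equation 3. Weight the items' losses and train the target model
        if $t$ is the last step of a stage then
           Compute the reward $r$ by Equation 4. Store the transition $(s, a, r, s')$ in $R$. Update the actor and critic networks (Algorithm 2)
        end if
     end for
  end for
  Output: the target actor network
Algorithm 1 LAW-DDPG algorithm for learning to auto weight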
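The control flow of Algorithm 1 can be traced with the skeleton below. `ToyEnv` and `ToyAgent` are hypothetical stand-ins: no network is trained and the accuracies are random numbers, so only the timing is meaningful. The skeleton also includes the warm-up period described in Section 5.1; with that section's schedule, one episode stores 22 transitions and triggers 22 agent updates.

```python
import random


class ToyEnv:
    """Stand-in for the target/reference networks; accuracies are random numbers."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.iteration = 0

    def reset(self):              # new target net, weights copied to the reference net
        self.iteration = 0

    def state(self):              # real LAW: stage index, loss and accuracy history
        return (self.iteration,)

    def sample_batch(self):
        return [self.rng.random() for _ in range(4)]

    def train_reference(self, batch):
        pass                      # plain mean loss would be used here

    def train_target(self, batch, action):
        self.iteration += 1       # weighted loss would be used when action is not None

    def accuracies(self):         # (target accuracy, reference accuracy)
        return self.rng.random(), self.rng.random()


class ToyAgent:
    def __init__(self):
        self.updates = 0

    def act(self, state):
        return [0.0]              # real LAW: actor output plus exploration noise

    def update(self, buffer):
        self.updates += 1         # real LAW: Algorithm 2 over the whole buffer


def run_episode(env, agent, buffer, total_iters, iters_per_stage, warmup_iters):
    """Control flow of one LAW episode (Algorithm 1) with a warm-up period."""
    env.reset()
    state = action = None
    for t in range(total_iters):
        offset = t - warmup_iters
        in_stage = offset >= 0
        if in_stage and offset % iters_per_stage == 0:        # first step of a stage
            state = env.state()
            action = agent.act(state)
        batch = env.sample_batch()                           # same batch for both nets
        env.train_reference(batch)
        env.train_target(batch, action if in_stage else None)  # no weighting in warm-up
        if in_stage and offset % iters_per_stage == iters_per_stage - 1:  # last step
            acc_target, acc_ref = env.accuracies()
            reward = acc_target - acc_ref                    # Equation 4
            buffer.append((state, action, reward, env.state()))
            agent.update(buffer)
    return buffer
```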

We normalize the features to make reinforcement learning more stable. Once the features are extracted, the weight of each item is computed by Equation 3 from its feature vector $f$ and the action $a_t$, which is selected by the policy according to the state $s_t$. Specifically, the agent (the auto weight model) receives the state at the beginning of each stage and selects a vector as the action, which decides each sample's weight throughout that stage.
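The sketch below assumes one possible form for Equation 3: a logistic function of the dot product between the action and each normalized feature vector, rescaled so that the weights in a batch average to 1. The rule and the function names are hypothetical; the sketch only illustrates how a single action vector turns per-item features into per-item weights and a weighted mean loss for the target network.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def item_weights(action, features):
    """Hypothetical weighting rule: w_i = sigmoid(a . f_i), rescaled to mean 1.

    `action` has one coefficient per feature dimension; `features` holds one
    normalized feature vector per item in the batch.
    """
    raw = [sigmoid(sum(a * f for a, f in zip(action, f_i))) for f_i in features]
    mean = sum(raw) / len(raw)
    return [w / mean for w in raw]


def weighted_mean_loss(losses, weights):
    """Loss for the target network; the reference network uses the plain mean."""
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)


weights = item_weights([1.0], [[1.0], [-1.0]])  # first item is up-weighted
```

With a zero action every weight equals 1, so the weighted loss reduces to the reference network's plain mean loss; this makes "no weighting" an easy point for the policy to start from.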

Reward: The reward is the main signal guiding reinforcement learning. Because the data points drawn from the whole dataset differ at every step, they introduce inevitable randomness, and it is hard to decide what makes the reward higher. The reward must therefore remove this uncertainty so that reinforcement learning can assign credit to better actions. In each episode, we train two networks on the same data: one for searching the policy, called the target network, and the other for reference, called the reference network. The reward is the accuracy of the target network minus the accuracy of the reference network:

$$r_t = \mathrm{Acc}_{\text{target}} - \mathrm{Acc}_{\text{ref}} \qquad (4)$$
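Equation 4 can be computed as below, a minimal sketch assuming that predictions and labels are integer class indices; the function names are hypothetical. Scoring both networks on the same samples, after training them on the same batches from the same initial weights, means shared sources of variation affect both accuracies alike, so the difference mainly reflects the effect of the weighting.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that match the labels."""
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def relative_accuracy_reward(target_preds, reference_preds, labels):
    """Reward = Acc(target) - Acc(reference), both measured on the same samples."""
    return accuracy(target_preds, labels) - accuracy(reference_preds, labels)


labels = [0, 1, 1, 0]
reward = relative_accuracy_reward([0, 1, 1, 1], [0, 0, 1, 1], labels)
```

Here the target network gets 3 of 4 samples right and the reference network 2 of 4, giving a reward of 0.25.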
4.2 Learning Algorithm

DDPG is used to learn the automatic weighting policy. Since each time step corresponds to a stage, we modify the part that updates the actor and critic networks. Instead of sampling transitions at each time step, we use all the data in the replay buffer to update the actor and critic networks for several epochs, so that the networks take full advantage of the cached data. In addition, we extend the algorithm to multiple processes in order to obtain enough samples in a short time. We only collect the state and the action at the first step of each new stage. The details are listed in Algorithm 1, and the procedure for updating the actor and critic networks is given in Algorithm 2.

Some other tricks also help make the reinforcement learning work. We add a regularization term to the reward to limit the norm of the action. In addition, rewards in early stages are less important, so we weight the reward of each stage by a factor (Equation 5) that depends on the current epoch, the total number of training epochs and a scale adjustment rate, placing more emphasis on later stages.
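One hypothetical instantiation of this shaping is sketched below. The stage weight `(epoch / total_epochs) ** rho` and the L2 penalty `lam * ||action||` are assumptions chosen only to match the description (rewards from later stages count more; large actions are penalized); they are not the paper's Equation 5, and the defaults are illustrative.

```python
import math


def shaped_reward(acc_gap, action, epoch, total_epochs, rho=2.0, lam=0.01):
    """Hypothetical stage-weighted reward with an action-norm penalty.

    acc_gap: Acc(target) - Acc(reference) for the current stage (Equation 4).
    The weight (epoch / total_epochs) ** rho grows over training, so later
    stages dominate; rho plays the role of a scale adjustment rate.
    """
    stage_weight = (epoch / total_epochs) ** rho
    action_norm = math.sqrt(sum(a * a for a in action))
    return stage_weight * acc_gap - lam * action_norm


early = shaped_reward(0.02, [0.0], epoch=2, total_epochs=10)
late = shaped_reward(0.02, [0.0], epoch=8, total_epochs=10)
```

The same accuracy gap earns a larger reward late in training (`late > early`), which matches the increasing trend described for Equation 5.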

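  Input: replay buffer $R$, actor and critic networks, number of epochs $N_e$, policy batch size $N_b$. Process the data of the replay buffer to generate transition data
  for epoch $= 1, \dots, N_e$ do
     Shuffle the transition data
     for each mini-batch of $N_b$ transitions $(s_i, a_i, r_i, s_{i+1})$ do
        Set $y_i = r_i + \gamma\, Q'\big(s_{i+1}, \mu'(s_{i+1} \mid \theta^{\mu'}) \mid \theta^{Q'}\big)$. Update the critic by minimizing the MSE loss $L = \frac{1}{N_b} \sum_i \big(y_i - Q(s_i, a_i \mid \theta^{Q})\big)^2$. Update the actor one step as in lillicrap2015continuous
     end for
  end for
Algorithm 2 LAW-DDPG: update actor and critic network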
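Algorithm 2 can be illustrated with a toy DDPG that keeps every gradient explicit. Assumptions: scalar states and actions, a linear critic $Q(s, a) = w_s s + w_a a + b$ and a linear actor $\mu(s) = k s$ in place of LAW's 4-layer MLPs, a `done` flag for the final stage, hypothetical learning rates, and the soft target update `tau` from standard DDPG. The part that mirrors the modification described above is `update_from_buffer`, which sweeps the whole replay buffer for several epochs instead of sampling a single mini-batch.

```python
import random


class LinearDDPG:
    """Toy DDPG with scalar state/action and linear approximators.

    Critic: Q(s, a) = w_s * s + w_a * a + b.  Actor: mu(s) = k * s.
    """

    def __init__(self, gamma=0.9, lr_critic=0.05, lr_actor=0.01, tau=0.1):
        self.params = {"w_s": 0.0, "w_a": 0.0, "b": 0.0, "k": 0.0}
        self.target = dict(self.params)  # target networks Q' and mu'
        self.gamma, self.lr_c, self.lr_a, self.tau = gamma, lr_critic, lr_actor, tau

    @staticmethod
    def q(p, s, a):
        return p["w_s"] * s + p["w_a"] * a + p["b"]

    def train_on_batch(self, batch):
        p, tp = self.params, self.target
        grads = {name: 0.0 for name in p}
        for s, a, r, s_next, done in batch:
            # y_i = r_i + gamma * Q'(s_{i+1}, mu'(s_{i+1})), or r_i at the final stage.
            y = r if done else r + self.gamma * self.q(tp, s_next, tp["k"] * s_next)
            err = self.q(p, s, a) - y
            # Gradients of the halved squared error w.r.t. the critic parameters.
            grads["w_s"] += err * s
            grads["w_a"] += err * a
            grads["b"] += err
            # Deterministic policy gradient: dQ/da at a = mu(s), times dmu/dk = s.
            grads["k"] += p["w_a"] * s
        n = len(batch)
        for name in ("w_s", "w_a", "b"):
            p[name] -= self.lr_c * grads[name] / n   # descend on the critic loss
        p["k"] += self.lr_a * grads["k"] / n          # ascend on Q(s, mu(s))
        for name in p:                                # soft target update
            tp[name] = self.tau * p[name] + (1 - self.tau) * tp[name]


def update_from_buffer(agent, buffer, epochs, batch_size, seed=0):
    """Algorithm 2: sweep the whole replay buffer for several epochs."""
    rng = random.Random(seed)
    data = list(buffer)
    steps = 0
    for _ in range(epochs):
        rng.shuffle(data)                             # shuffle the transition data
        for start in range(0, len(data), batch_size):
            agent.train_on_batch(data[start:start + batch_size])
            steps += 1
    return steps


buffer = [(0.1 * i, 0.0, 1.0, 0.1 * (i + 1), i == 9) for i in range(10)]
agent = LinearDDPG()
print(update_from_buffer(agent, buffer, epochs=3, batch_size=4))  # 9 gradient steps
```

Because every cached transition is revisited in each epoch, the few transitions produced per episode (one per stage) are used several times rather than once.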

5 Experiments

5.1 Experimental setup

We demonstrate the effectiveness of LAW on the image classification datasets CIFAR10 and CIFAR100 krizhevsky2009learning , each of which consists of 50,000 training and 10,000 test color images of size 32×32. All experiments are implemented in PyTorch paszke2017automatic on Nvidia Tesla V100 with batch size 128. The actor and critic are modeled by 4-layer MLPs. The target and reference networks are VGG19 simonyan2014very , a well-known CNN model. We train the actor and critic for 500 episodes with Adam kingma2014adam , using separate learning rates for the actor and the critic. In each episode, the target and reference networks are optimized by SGD with momentum 0.9 and weight decay 0.0001; the learning rate starts at 0.01 and is divided by 10 at iteration 11k. The terminal state of an episode is reached at iteration 15k. Moreover, during the first 4k iterations, the actor and critic are not trained and the VGG19 model is optimized without any weighting from LAW; we call this period warm-up. After warm-up, the actor and critic are optimized every 500 iterations (500 iterations constitute one stage) and used to search for a proper policy that focuses on important samples through weighting. To obtain a reliable reward, we randomly draw 50% of the samples from the test set as a validation set, while the performance of the actor is tested on the other half. Our code will be released online.
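The learning-rate schedule and the validation split above can be written down directly. This sketch encodes only the numbers stated in this section; the function names and the fixed `seed` are assumptions.

```python
import random


def learning_rate(iteration, base_lr=0.01, drop_at=11_000, factor=10.0):
    """Step schedule for the target/reference networks: 0.01, divided by 10 at 11k."""
    return base_lr / factor if iteration >= drop_at else base_lr


def split_test_set(num_test=10_000, seed=0):
    """Randomly split test indices in half: one half for the reward, one for evaluation."""
    rng = random.Random(seed)
    indices = list(range(num_test))
    rng.shuffle(indices)
    half = num_test // 2
    return indices[:half], indices[half:]


reward_indices, eval_indices = split_test_set()
```

Keeping the two halves disjoint means the accuracy used to report the learned policy is measured on samples that never shaped its reward.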

5.2 Evaluation of image classification

We conduct experiments on the CIFAR10 and CIFAR100 image classification tasks. During training, we define the reward as the top-1 accuracy gap between the target and reference networks. Figure 2 plots the reward over three episodes. The accuracy gaps become positive in the later period. Moreover, this trend is consistent with Equation 5, which places more weight on rewards in later stages.

Figure 2: The rewards during training in three episodes. (a) CIFAR10; (b) CIFAR100.

The loss gap curves are shown in Figure 3. The loss gap is defined as the difference between the mean item loss of the target network and that of the reference network. For CIFAR10 (Figure 3(a)), the loss of the target network is higher than that of the reference network at all iterations, which indicates that LAW helps the target network focus on samples with larger classification errors. As training progresses, the loss gap gradually narrows as the losses of both models converge toward zero. The CIFAR100 curves in Figure 3(b) show similar results.

Figure 3: The loss gap between the target and reference models in three episodes. (a) CIFAR10; (b) CIFAR100.

Finally, to evaluate LAW, we use the optimized actor to automatically weight the sample losses. With the learned policy, the classification accuracy of the target network improves. The gain in top-1 accuracy on CIFAR is plotted in Figure 4. In early stages, the curves fluctuate around 0, showing that the target and reference networks perform roughly the same. In the later period, however, LAW clearly improves the performance of the target model.

Figure 4: The relative test accuracy on the CIFAR10 and CIFAR100 datasets. (a) CIFAR10; (b) CIFAR100. The y-axis is the accuracy gap between the target and reference networks and the x-axis is the iteration step. The red line corresponds to the clean CIFAR dataset and the green line to the CIFAR dataset with label noise (noisy-CIFAR).

Figure 5: The loss gap between the target and reference models in three episodes on noisy-CIFAR. (a) CIFAR10; (b) CIFAR100.

5.3 Effects on noisy labels

To evaluate LAW under label noise, we construct noisy-CIFAR datasets by randomly selecting 30% of the samples and changing their labels to random categories. Figure 5 shows the loss gaps on noisy-CIFAR under the same training settings as in Section 5.2. In contrast to the clean case in Figure 3(a), the loss gaps are all below zero, which indicates that the policy learned by LAW can identify data points with corrupted labels and reduce their weights. Figure 4 illustrates the advantage of LAW: for both CIFAR10 and CIFAR100, the final accuracy of the target model is substantially higher than that of the reference model. Table 1 shows the detailed performance of the proposed method. Compared with the baselines, LAW finds a better weighting schedule, improving top-1 accuracy on noisy-CIFAR10 and noisy-CIFAR100 by 10.17 and 11.55 percentage points, respectively.
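A minimal sketch of the label corruption, under one assumption the text leaves open: each selected sample receives a label drawn uniformly from the other classes, so exactly 30% of the labels end up wrong (drawing from all classes, including the original one, is another common variant). The function name and `seed` are hypothetical.

```python
import random


def corrupt_labels(labels, num_classes, noise_rate=0.3, seed=0):
    """Return a copy of `labels` in which a `noise_rate` share of entries is replaced.

    Each selected sample gets a label drawn uniformly from the other classes.
    """
    rng = random.Random(seed)
    noisy = list(labels)
    n_noisy = round(noise_rate * len(labels))
    for i in rng.sample(range(len(labels)), n_noisy):
        choices = [c for c in range(num_classes) if c != labels[i]]
        noisy[i] = rng.choice(choices)
    return noisy


clean = [i % 10 for i in range(1000)]
noisy = corrupt_labels(clean, num_classes=10)
```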

| Method | CIFAR10 without noise | CIFAR10 with noise | CIFAR10 imbalance | CIFAR100 without noise | CIFAR100 with noise |
| --- | --- | --- | --- | --- | --- |
| reference model | 80.42 | 70.70 | 75.99 | 54.03 | 31.76 |
| target model (LAW) | 81.45 | 80.87 | 77.63 | 55.70 | 43.31 |

Table 1: The top-1 accuracy (%) of the reference model and the target model on the CIFAR datasets.
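The gains discussed in this section follow directly from Table 1. The snippet below computes the absolute gain of the target model over the reference model, in percentage points, for each setting; the dictionary layout is only a convenience for the calculation.

```python
# Top-1 accuracy (%) from Table 1: (reference model, target model with LAW).
table1 = {
    "CIFAR10 without noise": (80.42, 81.45),
    "CIFAR10 with noise": (70.70, 80.87),
    "CIFAR10 imbalance": (75.99, 77.63),
    "CIFAR100 without noise": (54.03, 55.70),
    "CIFAR100 with noise": (31.76, 43.31),
}

# Absolute gain of LAW over the reference model, in percentage points.
gains = {setting: round(target - ref, 2) for setting, (ref, target) in table1.items()}
for setting, gain in gains.items():
    print(f"{setting}: +{gain:.2f}")
```

The gains are roughly 1 to 1.7 points on the clean and imbalanced settings and above 10 points on both noisy settings.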

5.4 Effects on imbalanced data

LAW can also mitigate the class imbalance problem. We construct an imbalanced dataset by discarding a portion of the samples from two classes of CIFAR10 while keeping the other classes unchanged. Figure 6(a) shows that the policy explored by LAW handles this setting well. We also plot the mean weight of the data points in one batch in Figure 6(b), where the mean weight of the data points from the two reduced classes (the red line) is much higher than that of the others. This indicates that the policy searched by LAW increases the weights of samples from rarely seen classes. Detailed performance is given in Table 1.
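The imbalanced subset can be constructed as sketched below. The reduced classes and the share of their samples that is kept are parameters of this sketch (`reduced_classes`, `keep_fraction`), as is the function name; the values in the example are hypothetical.

```python
import random


def make_imbalanced(labels, reduced_classes, keep_fraction, seed=0):
    """Return sorted indices of a class-imbalanced subset.

    Samples from `reduced_classes` are subsampled to `keep_fraction` of their
    original count; all samples from the other classes are kept.
    """
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    kept = []
    for y, idxs in by_class.items():
        if y in reduced_classes:
            kept.extend(rng.sample(idxs, round(keep_fraction * len(idxs))))
        else:
            kept.extend(idxs)
    return sorted(kept)


labels = [i % 10 for i in range(1000)]  # 100 samples per class
subset = make_imbalanced(labels, reduced_classes={3, 5}, keep_fraction=0.1)
```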

Figure 6: The relative test accuracy and the mean weight on imbalanced CIFAR10. (a) Relative top-1 test accuracy; (b) mean weight.

6 Conclusion

In this paper, we propose a novel example weighting framework called LAW, which adaptively learns a weighting policy from data based on reinforcement learning. Experimental results demonstrate the superiority of the weighting policy explored by LAW over the standard training pipeline. In future work, we plan to conduct experiments with more kinds of network architectures on various datasets.