In both supervised and unsupervised image generation tasks, one tries to balance a set of criteria, some overlapping and some presenting natural tradeoffs. For example, in image analogy tasks, multiple conditioned generation criteria, some based on comparing the input to the output using an MSE term and some on perceptual losses, may be combined with one or more adversarial losses. If the mapping is two-sided, one also employs a circularity loss. If the image is of a composite nature, mask-based losses are added, and so on.
It has been demonstrated repeatedly that adding losses can be beneficial, and in many cases, formulating the desiderata as loss terms is straightforward. However, as more loss terms are added, the space of possible balancing weights grows exponentially, and more resources need to be allocated to identify good configurations that would justify the added terms. Another common challenge is that the optimal balance may change at different stages of the training process mescheder2017numerics. This is especially true for adversarial losses, which often become unstable. It is, therefore, necessary to have the balancing terms update dynamically during training.
In this work, we introduce Multi-Term Adam (MTAdam), an optimization algorithm for multi-term loss functions. MTAdam extends Adam kingma2014Adam and allows effective training with an unweighted multi-term loss objective. Thus, MTAdam can streamline the computationally demanding hyperparameter search that is required for effectively weighting multi-term loss objectives.
At every training iteration, a dynamic weight is assigned to each of the loss terms based on the magnitude of the gradient that each term entails. The weights are assigned in a way that balances between these gradients and equates their magnitude.
This, however, would be an ineffective balancing method, without three crucial components: (i) the balancing needs to occur independently for each layer of the neural network, since the relative contributions of the losses vary, depending on the layer, (ii) the balancing needs to be anchored by a dominant loss term in order to allow a natural progression of the effective learning rate, i.e., one cannot normalize such that the magnitude becomes a constant, and (iii) the update step needs to take into account the maximal variance among all losses, to support sufficient explorations in places of the parameter space, in which one of the losses becomes more sensitive.
The main focus of our experiments is in the field of conditional adversarial image generation. This domain is known to require multiple loss terms. Moreover, the quality of the generated images is evaluated using acceptable success metrics that are not directly optimized by any of the loss terms, such as FID fid. Our results show that MTAdam is able to recover from unbalanced starting points, in which the weight parameters are set inappropriately, while Adam and other baseline methods cannot.
2 Related Work
SGD with momentum (Nesterov) rumelhart1986learning is an optimization algorithm that extends SGD by updating the network’s parameters using a moving average of the gradients, rather than the gradient at each step. Root Mean Square Propagation (RMSProp) rmsprop extends SGD by dividing the learning rate, during the update step, by the root of the moving average of the second moment of the gradients of each parameter. Adam kingma2014Adam combines both principles: it maintains the first and second moments of the gradients of each learned parameter and applies them during the update step. Adam has become a dominant optimizer that is applied across many applications and, in particular, is the de-facto standard in the field of adversarial training. It is known for improving convergence and for working well with the default values of its own hyperparameters ($\beta_1$ and $\beta_2$).
Multiple methods have been suggested for selecting hyperparameters srinivas2009gaussian; wang2016optimization; bergstra2011algorithms; feurer2014using. Hyperband li2017hyperband formulates hyperparameter search as an infinite-armed bandit problem: it utilizes a predefined amount of resources while searching for the configuration of hyperparameters that maximizes the given success criterion. It can provide an order-of-magnitude speedup compared to Bayesian optimization bergstra2011algorithms; snoek2012practical; hutter2011sequential. We utilize Hyperband as a baseline that can directly optimize the FID score.
Image generation techniques utilize multiple loss terms, including pixel-wise norms, perceptual losses johnson2016perceptual, and one or more adversarial loss terms NIPS2014_5423. The need for a careful configuration of the weight parameters hinders research in this field and increases its costs.
In our experiments, we employ three different image generation methods, which vary in the type of supervision employed or the task being solved: (i) pix2pix isola2017image generates an image in a target domain based on an input image in a source domain, after observing matching pairs during training. (ii) CycleGAN zhu2017unpaired performs the same task while training in an unsupervised manner on unmatched images from the two domains. (iii) SRGAN ledig2017photo generates high-resolution images from low-resolution ones and is trained in a supervised way.
Multi-task learning shares some similarities with multi-term learning. In multi-task learning, one learns a few tasks concurrently, each associated with one or more loss terms. The most common approaches utilize hard parameter sharing caruana1997multitask; kendall2018multi; he2017mask; long2015learning; ren2015faster, in which a single model, with multiple task-specific heads, is trained. This approach is extremely effective when the tasks are highly related, as in MaskRCNN he2017mask. In the soft parameter sharing approach, each task has its own model, and a regularization term encourages the parameters of the different models to be similar yang2016trace; duong2015low. While MTAdam can be applied to multi-task learning (of both types) without substantial modifications, we leave the empirical validation of this ability for future investigations.
The Adam algorithm optimizes one stochastic objective function $f_t(\theta)$ over the set of parameters $\theta$, where $t$ is an index of the current mini-batch of samples. In contrast, MTAdam optimizes a set of such terms, one per loss term. While Adam’s task is to minimize the expected value $\mathbb{E}[f_t(\theta)]$ w.r.t. the parameters $\theta$, MTAdam minimizes a weighted average of the terms. The weights of this mixture are all positive, but otherwise unknown. The guiding principle for determining the weights at each iteration is that the moving average of the gradient magnitude of each term should be equal across terms. This magnitude is evaluated and balanced at every layer of the neural network.
In Adam, two moments are continuously updated using a moving average scheme: $m_t$ is the first moment of the gradient and $v_t$ is the second moment. Both are vectors of the same size as the parameter vector $\theta$. The moving averages are computed using the mixing coefficients $\beta_1$ and $\beta_2$ for the two moments.
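As a rough illustration of the moving-average scheme described above, the following sketch applies the standard Adam update kingma2014Adam to a plain Python list of parameters. The function name `adam_step`, the step size `lr`, and the constant `eps` are assumptions made for this example and are not taken from the text.

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on lists of floats; t is the 1-based step index."""
    new_theta, new_m, new_v = [], [], []
    for p, g, m_j, v_j in zip(theta, grad, m, v):
        m_j = beta1 * m_j + (1 - beta1) * g        # first moment: moving average of g
        v_j = beta2 * v_j + (1 - beta2) * g * g    # second moment: moving average of g^2
        m_hat = m_j / (1 - beta1 ** t)             # bias correction of both moments
        v_hat = v_j / (1 - beta2 ** t)
        new_theta.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_j)
        new_v.append(v_j)
    return new_theta, new_m, new_v

# Toy usage: minimize f(x) = x^2 from x = 1.0 (the gradient is 2x).
theta, m, v = [1.0], [0.0], [0.0]
for t in range(1, 2001):
    grad = [2.0 * theta[0]]
    theta, m, v = adam_step(theta, grad, m, v, t, lr=0.01)
```

Because of the bias correction, the very first step has a size close to `lr`, independently of the scale of the gradient.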
MTAdam records such moments for each term separately. In addition, it uses a third mixing coefficient in order to maintain, for each loss term, the moving average of the gradient magnitude at each layer.
Adam borrows from the SGD with momentum method (Nesterov) and updates the vector of parameters based on the weighted first moment of the gradient. In MTAdam, the first moment is computed based on a weighted gradient, in which the gradient of each layer for every term is rescaled such that its magnitude is normalized by the ratio between the moving-average gradient magnitude of the first term and that of the term itself, at that layer. This way, across all layers, and at every time point, the terms contribute equally to the gradient step.
The Adam optimization algorithm is depicted on the left side of Alg. 1 and MTAdam on the right. In line 1, MTAdam initializes one pair of first and second moment vectors for each loss term, whereas Adam initializes a single pair. In line 2, and different from Adam, MTAdam initializes the first moments of the gradient magnitudes, one per loss term and layer. In line 3, both MTAdam and Adam iterate over the stochastic mini-batches, performing training steps.
In line 4, MTAdam iterates over the loss terms. For each loss term, MTAdam calculates its gradients over each one of the network layers (line 5), in a manner analogous to the way Adam computes its (single) gradient vector. In lines 6-8, MTAdam iterates over the layers, updates the moving average of the gradient magnitude for each layer and loss term, and normalizes the gradients of the current layer and loss term by multiplying them by the ratio between the moving-average magnitude of the first loss term and that of the current loss term, both taken at the current layer. As a result, all gradient magnitudes become similar to that of the first loss term.
In line 5, we calculate the gradients of each specific loss term w.r.t. the parameters, across all layers. The gradient is denoted separately for each layer index, and the concatenation of all per-layer gradients forms the gradient vector of the term.
Lines 6-8 of MTAdam do not have an Adam analog. The operation performed normalizes the magnitude of the gradients of all loss terms, to match the magnitude of the first loss term. This assigns a unique role to the first term, as the primary loss to which all other losses are compared. By linking the magnitude to that of a concrete loss, and not to a static value (e.g., normalizing to have a unit norm), we maintain the relationship between the training progression and the learning rate.
The normalization iterates over the loss terms, and for each gradient, the first moment of the magnitude of the gradient is updated. The gradient magnitude is then normalized by that of the first loss term.
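The per-layer normalization described above can be sketched as follows, using plain Python lists in place of tensors. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `balance_gradients`, the value of the magnitude mixing coefficient `beta3`, the constant `eps`, and the zero initialization of the moving averages are all hypothetical choices made for the example.

```python
import math

def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vec))

def balance_gradients(grads, mag_avg, beta3=0.9, eps=1e-12):
    """Rescale per-layer gradients so every term matches the first term.

    grads[i][l]   : gradient (list of floats) of loss term i at layer l
    mag_avg[i][l] : moving average of the gradient magnitude, updated in place
    beta3, eps    : assumed values, not taken from the text
    """
    balanced = []
    for i, term_grads in enumerate(grads):
        scaled_layers = []
        for l, g in enumerate(term_grads):
            # Update the moving average of the magnitude for this term and layer.
            mag_avg[i][l] = beta3 * mag_avg[i][l] + (1 - beta3) * l2_norm(g)
            # Anchor to the first term (index 0), whose average is already updated.
            scale = mag_avg[0][l] / (mag_avg[i][l] + eps)
            scaled_layers.append([scale * x for x in g])
        balanced.append(scaled_layers)
    return balanced

grads = [
    [[3.0, 4.0], [1.0, 0.0]],        # term 0 (anchor): layer norms 5 and 1
    [[300.0, 400.0], [0.0, 0.01]],   # term 1: layer norms 500 and 0.01
]
mag_avg = [[0.0, 0.0], [0.0, 0.0]]
balanced = balance_gradients(grads, mag_avg)
```

In this toy case, the second term is scaled down at the first layer and scaled up at the second layer, so that at each layer its magnitude matches that of the anchor term. A single global scale per term could not achieve both at once, which is in line with the motivation given for per-layer balancing.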
Then, in lines 9-12, MTAdam updates the first and second moments for each parameter and each loss term and computes their bias correction. This is similar to Adam, except that the moments are calculated separately for each loss term. In lines 13-15, MTAdam iterates over the loss terms and calculates the step contributed by each term. The steps are summed over the loss terms, and the result is used to update the parameters.
In Adam, the update size is normalized by the second moment. In MTAdam, we divide by the maximal second moment among all loss terms. This division allows MTAdam to make smaller gradient steps when a lower certainty is introduced by at least one of the loss terms. The motivation is that if even one of the losses is in a high-sensitivity region, where small updates create rapid changes to this term, then the step should be small, regardless of the term that led to it. The importance of this maximization is demonstrated in the ablation study in Sec. 4.3.
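The effect of dividing by the maximal second moment can be illustrated with a small sketch. It assumes that the bias-corrected moments of every term are already available, and that the per-term steps are summed and share a single denominator; the function name `combined_step` and the values of `lr` and `eps` are hypothetical.

```python
import math

def combined_step(theta, m_hats, v_hats, lr=1e-3, eps=1e-8):
    """m_hats[i][j], v_hats[i][j]: bias-corrected moments of term i, parameter j."""
    new_theta = []
    for j, p in enumerate(theta):
        # Largest second moment among all loss terms for this parameter.
        v_max = max(v[j] for v in v_hats)
        # Sum the first moments of all terms and scale by the shared denominator.
        step = sum(m[j] for m in m_hats) / (math.sqrt(v_max) + eps)
        new_theta.append(p - lr * step)
    return new_theta

# Two terms with equal first moments but different second moments (1 and 4).
theta = combined_step([0.0], [[1.0], [1.0]], [[1.0], [4.0]], lr=0.001)
```

Here the shared denominator is 2, so the combined step is about 0.001. Scaling each term by its own second moment, as Adam would, gives a larger step of 0.0015, since the less sensitive term would not be slowed down.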
Memory and Run Time Analysis Adam utilizes a pair of 1st and 2nd moments for each learned parameter, so its memory complexity is linear in the number of learned parameters. MTAdam utilizes a separate pair of 1st and 2nd moments per loss term for each parameter. In addition, MTAdam maintains a first moment of the gradient magnitude per loss term for each layer. These two extensions make the memory complexity proportional to the number of loss terms times the sum of the number of parameters and the number of layers. In MTAdam, the run time complexity also depends on the number of loss terms. The dependence on the number of layers in Alg. 1 can be absorbed in the dependence on the number of parameters.
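For a concrete sense of scale, the optimizer state can be counted as in the sketch below. The function `optimizer_state_entries` and the network sizes in the usage lines are hypothetical; the count assumes one pair of moments per term and parameter, plus one magnitude average per term and layer.

```python
def optimizer_state_entries(num_params, num_terms=1, num_layers=0):
    """Number of floats kept by the optimizer (illustrative count only)."""
    moments = 2 * num_terms * num_params   # first and second moment per term and parameter
    magnitudes = num_terms * num_layers    # one magnitude average per term and layer
    return moments + magnitudes

adam_state = optimizer_state_entries(1_000_000)                                # single term, no magnitudes
mtadam_state = optimizer_state_entries(1_000_000, num_terms=3, num_layers=50)  # three terms, 50 layers
```

The per-layer entries are negligible next to the per-parameter moments, which is why the dependence on the number of layers can be absorbed.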
We compare the results of MTAdam with five baselines: (1-3) Adam, RMSProp, and SGD with momentum, applied with unbalanced weightings. (4) Hyperband li2017hyperband applied to perform a hyper-parameter search for the lambdas between the loss terms, utilizing the FID metric. (5) Adam optimizer applied with a balanced weighting.
We note that baseline (4) benefits from running training multiple times and that baseline (5) employs the weights proposed for each method after a development process that is likely to have included a hyper-parameter search, in which multiple runs were evaluated by the developers of each method. For each optimization method, we employ the default parameters in PyTorch. For the Adam experiments, we set the hyperparameters $\beta_1$ and $\beta_2$ to 0.9 and 0.999, respectively. For MTAdam, $\beta_1$ and $\beta_2$ are configured with the same values as Adam and .
4.1 MNIST classification
In order to turn MNIST into an unbalanced multi-term experiment, we compute the loss for each of the digits separately, creating ten loss terms, each weighted by a random weight drawn from the uniform distribution between 1 and 1000. The test set is unweighted, which causes classes that are associated with lower weights to suffer from underfitting.
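One way such an unbalanced objective could be assembled is sketched below. The helper `per_digit_loss_terms`, the choice of averaging the losses within each digit, the random seed, and the toy loss values are assumptions made for illustration; the text does not specify these details.

```python
import random

def per_digit_loss_terms(sample_losses, labels, num_classes=10):
    """Group per-sample losses by label and average them, one term per digit."""
    sums = [0.0] * num_classes
    counts = [0] * num_classes
    for loss, y in zip(sample_losses, labels):
        sums[y] += loss
        counts[y] += 1
    # A digit that is absent from the mini-batch contributes a zero term.
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]

# One random weight per digit, drawn uniformly between 1 and 1000.
rng = random.Random(0)
weights = [rng.uniform(1.0, 1000.0) for _ in range(10)]

# Toy mini-batch: per-sample losses and their digit labels.
sample_losses = [0.5, 1.5, 2.0, 0.2]
labels = [3, 3, 7, 0]
terms = per_digit_loss_terms(sample_losses, labels)
weighted_terms = [w * t for w, t in zip(weights, terms)]
```

A single-term optimizer would minimize the sum of `weighted_terms`, whereas a multi-term optimizer receives the ten entries separately.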
The architecture of the official PyTorch MNIST example is used: two convolutional layers followed by two fully connected layers. The experiment is repeated 100 times and, to save computation, Hyperband is not tested. The results in Tab. 1 show a clear advantage for MTAdam over the other unbalanced alternatives.
|85.4 ± 1.23||87.2 ± 0.97||88.8 ± 0.84||97.9 ± 0.07||98.3 ± 0.05|
4.2 Image synthesis
We demonstrate the ability of MTAdam to effectively converge when applied with unbalanced multi-term loss objectives. To this end, we compare the performance of MTAdam with other optimizers, evaluated on three methods: pix2pix isola2017image, CycleGAN zhu2017unpaired, and SRGAN ledig2017photo.
We used the learning rate as found in the implementation of each method isola2017image; zhu2017unpaired; ledig2017photo. Performance is evaluated using various metrics: L1, L2, PSNR, NMSE (normalized MSE), FID fid, and SSIM wang2004image. In each case, we follow the metrics used in the original work, with the addition of FID.
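For reference, two of the simpler metrics can be computed as in the sketch below, which operates on flattened images given as lists of floats. The exact definitions used in the experiments are not stated in the text; the formulas here follow common conventions (PSNR relative to a peak value `max_val`, and NMSE normalized by the energy of the target), and the function names are chosen for illustration.

```python
import math

def mse(x, y):
    """Mean squared error between two equally sized lists."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)

def psnr(x, y, max_val=1.0):
    """Peak signal-to-noise ratio in dB; max_val is the assumed peak pixel value."""
    err = mse(x, y)
    if err == 0:
        return float("inf")
    return 10.0 * math.log10(max_val ** 2 / err)

def nmse(target, pred):
    """Squared error normalized by the energy of the target (assumed nonzero)."""
    num = sum((t - p) ** 2 for t, p in zip(target, pred))
    den = sum(t ** 2 for t in target)
    return num / den

target = [1.0, 0.0, 1.0, 0.0]
pred = [0.9, 0.1, 0.9, 0.1]
psnr_value = psnr(target, pred)   # close to 20 dB for this toy example
nmse_value = nmse(target, pred)   # close to 0.02 for this toy example
```

FID and SSIM require considerably more machinery (a pretrained feature extractor and local window statistics, respectively) and are therefore not sketched here.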
(left) L1 and (right) FID per epoch on the validation set of Facade for three pix2pix variants. Balanced-Adam and Unbalanced-Adam employ Adam, with a balanced and unbalanced weighting, respectively. Unbalanced-MTAdam utilizes MTAdam.
Pix2pix Experiments The objective function of the pix2pix generator is a weighted sum of two terms: the GAN loss of the generator and the pixel loss. In the original method, the two balancing weights are set to 100 and 1. In our study, we unbalance pix2pix models by setting the weight that is prescribed as 1 to 100 (which implies a 1:1 ratio between the two loss terms).
Two datasets are used: matching facade images and their semantic labels tylevcek2013spatial and aerial photographs and matching maps isola2017image. Performance is reported on a holdout test set of each benchmark.
Fig. 1 depicts the test performance of multiple models at each training epoch. The experiment’s name contains ‘unbalanced’ for the case of equal loss weights, and ‘balanced’ when using the prescribed values. As can be seen, unbalanced-MTAdam converges similarly to balanced-Adam. Specifically, the MTAdam experiment converges substantially better than the unbalanced Adam experiment, leading to improved L1 and FID scores.
In Tab. 2, MTAdam is compared with all baselines, applied for training pix2pix models. The top four models in each section utilize an identical unbalanced setting, each applied with a different optimizer. The Hyperband experiment utilizes the FID metric to perform a hyperparameter search on the lambda weight. The Hyperband experiments incorporate 40 trials (i.e., 40 training processes, which gives Hyperband a great advantage), with each trial randomly sampling a different lambda from a predefined range. We report Hyperband performance using the trial that is associated with the chosen lambda. In the aerial photograph experiment, Hyperband failed to choose a value that is close to the original value of 1. In the facade experiment, Hyperband sampled at least one lambda value between 0.5 and 10, yet the retrieved best model utilizes a higher lambda value of 162.89, since this value showed a preferable FID value on the validation set. In the maps dataset, a value of 59.61 was selected.
The results of the table clearly show the advantage of MTAdam over all baseline methods. In addition, it also shows a slight improvement in performance in comparison to the usage of Adam on the prescribed weights. Fig. 2 presents two representative samples from the facades test set. Pix2pix-Unbalanced introduces visual artifacts and suffers from mode collapse (it generates the same corrupted patch in the top right corner of many images). Pix2pix-Unbalanced-MTAdam yields higher-quality images, similar to those of the original Pix2Pix, using the prescribed weights.
The CycleGAN objective function is composed of six loss terms: a GAN loss, a cycle loss, and an identity loss for each one of the sides (A or B). The weights of the GAN, cycle, and identity terms are set to 1, 10 and 0.5, respectively. In our experiments, we unbalance CycleGAN by setting one of these weights to 1000, leaving the other two unchanged.
Fig. 3 compares the convergence of three CycleGAN models: unbalanced MTAdam, unbalanced Adam, and balanced Adam. All models are applied on the horse2zebra dataset deng2009imagenet. As can be seen, the convergence of MTAdam is competitive with that of the Adam experiment applied with balanced weighting, and is much better than that of Adam on the unbalanced weights. Tab. 3 presents the performance of MTAdam applied on CycleGAN, compared to all five baselines, and evaluated on two datasets. MTAdam with an unbalanced initialization yields performance that is competitive with the Adam method, which uses the prescribed hyperparameters.
Fig. 4 exhibits representative images from the CycleGAN models, showing that MTAdam, even when applied to a loss with unbalanced weights, matches the results of Adam on the prescribed weights.
In the Super Resolution GAN (SRGAN) ledig2017photo, the objective function is a weighted combination of three terms: a GAN loss, a perceptual loss, and an MSE loss. The three weights are set to 1, 0.001, and 0.006. We unbalance SRGAN experiments by setting the two smaller weights to 1 (i.e., applying all three terms with the same coefficient).
We evaluate SRGAN-Unbalanced-MTAdam on two test sets, Set14 set14 and BSD100 BSD100. The results, listed in Tab. 5, demonstrate that MTAdam can effectively recover from unbalanced weights, while the other optimization methods suffer a degradation in performance.
|pix2pix facade||CycleGAN horse2zebra|
|(ii)||Removing L6, changing to||46.8||3850||0.258||133.5||70.1||75.2||69.5||145.7|
|(iii)||No scaling by in L8||46.7||3862||0.242||136.3||72.3||82.4||70.6||167.2|
|(iv)||Line 14: Scaling like Adam||48.2||3956||0.245||171.5||73.5||210.5||71.1||190.3|
|(v)||Line 14: scaling by mean instead of max||48.3||4020||0.249||156.8||73.7||208.4||71.6||173.5|
4.3 Ablation Study
Tab. 5 presents an ablation study for unbalanced pix2pix on the facade images and for unbalanced CycleGAN on the horse2zebra dataset. The following variants are considered (the descriptions refer to Alg. 1): (i) treating all layers as one layer and eliminating lines 6-8 altogether. (ii) treating all layers as one, and performing the normalization in line 8 once for the entire gradient, i.e., still normalizing by the magnitude of the gradient of the first term. (iii) scaling the gradients of each layer and each term in line 8 by the inverse of their own moving-average magnitude, but not by the magnitude of the first term. (iv) replacing the maximal second moment in line 14 with the second moment of each individual term, in a way analogous to line 14 in Adam. (v) replacing the same term with the mean of the second moments over the loss terms.
The results, shown in Tab. 5, indicate that it is crucial to employ a per-layer analysis in the way it is done in MTAdam, that normalizing by the magnitude of the gradient of an anchor term is highly beneficial, and that the maximal variance is preferable to the alternative scaling terms.
MTAdam is shown to be a widely applicable optimizer, which can dynamically balance multiple loss terms in an effective way. While tested on image generation tasks, it is a general algorithm, which can be used in other types of tasks that require the optimization of multiple terms, such as domain adaptation and some forms of self-supervised learning. Our code can be found at https://github.com/ItzikMalkiel/MTAdam. MTAdam is implemented as a generic PyTorch optimizer, and applying it is almost as simple as applying Adam.
This project has received funding from the European Research Council (ERC) under the European Union's Horizon 2020 research and innovation program (grant ERC CoG 725974).
Appendix A Image synthesis samples
Fig. 5 presents a few representative samples from the CycleGAN zhu2017unpaired experiments, employed to transfer horse images to zebras. As can be seen, the unbalanced Adam training completely fails to generate zebra images and collapses to the identity mapping. This can be attributed to the domination of the cyclic loss in the unbalanced setting, which dictates the convergence, leaving the other loss terms ineffective. On the other hand, the unbalanced MTAdam model was able to successfully generate zebra images, with a quality that is similar to or better than that of the balanced Adam model (which uses the prescribed weights). In particular, in the first row, we can see that CycleGAN-Unbalanced-MTAdam was able to outperform the CycleGAN-Balanced-Adam experiment, as the latter fails to generate a zebra image for this particular sample, while MTAdam was able to generate a fairly good quality image.
Fig. 6 depicts a few samples from the same models described above, this time employed to generate horse images from zebra images. As can be seen, the unbalanced Adam fails again to generate horse images, collapses to the identity mapping, and this time also introduces visual artifacts in a few images (see the green artifact in the bottom row). In addition, MTAdam yields images of the same quality as the balanced Adam experiment.
Fig. 7 presents visual results for the SRGAN ledig2017photo experiments. The images were taken from Set14 set14. All models were trained to generate high-resolution images from low-resolution images, with a factor of 4x upscaling. As can be seen, the SRGAN model that employs unbalanced weights and Adam optimizer yields images with visual artifacts, and low fidelity. This can be attributed to the domination of the GAN and perceptual loss terms over the pixel-wise term. In contrast, the SRGAN model that employs our MTAdam optimizer with unbalanced weights yields images of the same quality as the SRGAN that utilizes the prescribed weights ledig2017photo and Adam.