Semi-Autoregressive Training Improves Mask-Predict Decoding

01/23/2020 ∙ by Marjan Ghazvininejad, et al.

The recently proposed mask-predict decoding algorithm has narrowed the performance gap between semi-autoregressive machine translation models and the traditional left-to-right approach. We introduce a new training method for conditional masked language models, SMART, which mimics the semi-autoregressive behavior of mask-predict, producing training examples that contain model predictions as part of their inputs. Models trained with SMART produce higher-quality translations when using mask-predict decoding, effectively closing the remaining performance gap with fully autoregressive models.




1 Introduction

While mainstream approaches to machine translation sequentially generate a translation token by token, recent advances in non-autoregressive Gu et al. (2018); Libovický and Helcl (2018); Sun et al. (2019) and semi-autoregressive decoding Lee et al. (2018); Stern et al. (2019); Gu et al. (2019a) have produced increasingly viable alternatives, which can decode substantially faster at some cost to performance. One such approach, mask-predict Ghazvininejad et al. (2019), repeatedly predicts the entire target sequence in parallel, conditioned on the most confident word predictions from the previous iteration. The underlying model, a conditional masked language model, is trained by masking part of the (gold) target sequence and predicting the missing tokens. During training, all observed (unmasked) tokens come from the ground truth data. However, at inference time, the observed tokens are high-confidence model predictions, creating a discrepancy that can hurt performance in practice.

To remedy this problem, we introduce SMART (Semi-Autoregressive Training), a new training process for conditional masked language models that better matches the semi-autoregressive nature of the mask-predict decoding algorithm. We first create training examples by starting with the gold target sequence and masking a subset of its tokens, just like the original training process. We then use the current model to predict the sequence from the partially-observed input, and mask a different subset of tokens to create the training example’s input. The model is then trained to predict the gold target sequence based on this partially-observed prediction-based input, as well as the source sequence (see Figure 1), allowing it to better correct mistakes made during the early iterations of the mask-predict decoding loop.

SMART improves the performance of mask-predict decoding by 0.5 to 1.0 BLEU, effectively closing the gap with fully autoregressive models. For example, on the WMT’14 EN-DE benchmark, we arrive at a BLEU score of 27.65, just under the 27.75 achieved by a strong autoregressive baseline. This result implies that mask-predict decoding is not only a fast alternative to autoregressive beam search, but also an accurate one.

[Figure 1: Example generation steps and final training example for the source "Das Hotel ist eine ideale Wahl für Geschäfts- und Urlaubs@@ reisen" and the gold target "The hotel is an ideal choice for business and leisure trips". The model's prediction from the partially masked gold sequence is "The hotel is an choice choice for business and leisure travellers", which is masked again to form the final training input.]
Figure 1: An illustration of how SMART generates new training examples. We start with the gold target sequence $Y$ and randomly mask some of its tokens, use the partially-observed gold sequence $Y_{obs}$ to predict the entire translation $\hat{Y}$, and then mask a random subset of tokens again. The resulting sequence $\hat{Y}_{obs}$ is used as the model’s input during optimization, alongside the source $X$, when training to predict the original gold sequence $Y$.

2 Background: Mask-Predict¹

¹ For further detail, see Ghazvininejad et al. (2019).

Conditional Masked Language Models

A conditional masked language model (CMLM) takes a source sequence $X$ and a partially-observed target sequence $Y_{obs}$ as input. It predicts the probabilities of the masked (unobserved) target sequence tokens $Y_{mask}$, assuming conditional independence between them (given the inputs).

Since each target token is either observed or masked, the predictions are effectively conditioned on the target sequence length $N$ as well, which must be predicted separately by the model.
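Written out, the conditional independence assumption means the probability of the masked tokens factorizes over positions. The formula below is our rendering in the notation $X$ (source), $Y_{obs}$ (observed target tokens), and $Y_{mask}$ (masked target tokens), not an equation copied from the original text:

```latex
P(Y_{mask} \mid X, Y_{obs}) = \prod_{y_i \in Y_{mask}} P(y_i \mid X, Y_{obs})
```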

Mask-Predict Decoding

Mask-predict generates the entire target sequence in a preset number of decoding iterations $T$. Given the predicted target sequence length $N$, decoding starts with a fully-masked target sequence.³ The model then predicts the entire sequence in parallel, setting each token to its most probable assignment ($y_i = \arg\max_w P(y_i = w \mid X, Y_{obs})$, where $X$ is the source and $Y_{obs}$ the observed target tokens).

³ In practice, the algorithm uses multiple length candidates, decodes each in parallel, and selects the best (highest-probability) result. Considering multiple length candidates is somewhat analogous to beam search in autoregressive decoding.

For each iteration $t$, the algorithm performs a mask step, in which the $n$ tokens with the lowest probabilities are replaced with a special mask token.⁴ This is followed by a predict step, where the model predicts the masked tokens while conditioning on the observed high-confidence predictions from the previous iterations.

⁴ The number of masked tokens $n$ gradually shrinks with $t$.
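The mask and predict steps above can be sketched in plain Python. This is a minimal, framework-free illustration under several assumptions: the model is a hypothetical function returning a (token, probability) pair for every position, only a single length candidate is decoded, and the number of re-masked tokens follows the linear decay $n = N \cdot (T - t)/T$, which we take from Ghazvininejad et al. (2019) rather than from this text. `mask_predict`, `toy_model`, and `MASK` are illustrative names, not part of any released code.

```python
from typing import Callable, List, Sequence, Tuple

MASK = "<mask>"

# Hypothetical model interface: given the source and a (partially masked)
# target, return the most probable token and its probability for every position.
Predictor = Callable[[Sequence[str], List[str]], List[Tuple[str, float]]]


def mask_predict(model: Predictor, source: Sequence[str], length: int, iterations: int) -> List[str]:
    """Sketch of mask-predict decoding for a single length candidate."""
    # First iteration: start from a fully masked target and predict every position.
    preds = model(source, [MASK] * length)
    tokens = [tok for tok, _ in preds]
    probs = [p for _, p in preds]

    for t in range(1, iterations):
        # Mask step: re-mask the n least confident tokens (assumed linear decay).
        n = int(length * (iterations - t) / iterations)
        if n == 0:
            break
        lowest = set(sorted(range(length), key=lambda i: probs[i])[:n])
        masked_input = [MASK if i in lowest else tok for i, tok in enumerate(tokens)]

        # Predict step: refill only the masked positions; observed tokens keep
        # their tokens and probabilities from earlier iterations.
        preds = model(source, masked_input)
        for i in lowest:
            tokens[i], probs[i] = preds[i]
    return tokens


def toy_model(source: Sequence[str], target: List[str]) -> List[Tuple[str, float]]:
    # Hypothetical stand-in for a trained CMLM: always proposes the same tokens.
    guesses = ["the", "hotel", "is"]
    confidences = [0.9, 0.4, 0.7]
    return list(zip(guesses, confidences))


print(mask_predict(toy_model, ["das", "Hotel", "ist"], length=3, iterations=4))
```

With the toy model the decoder prints `['the', 'hotel', 'is']`; a real CMLM would change its predictions as more of the target becomes observed, which is what makes the iterations useful.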

Non-Autoregressive Training

The original training process for CMLMs takes the gold target sequence $Y$ and masks out $n$ random tokens, where $n \sim \mathrm{Uniform}\{1, \dots, N\}$. The model then predicts only the masked tokens while conditioning on the observed target tokens, which are always correct. Training optimizes the cross-entropy between the predictions and the correct values of the masked tokens. We call this process NART (Non-Autoregressive Training) because it only uses gold data as its inputs, and does not condition on model predictions.
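A minimal sketch of how one NART training example and its loss could be built, assuming tokens are strings, a hypothetical `<mask>` symbol, and predicted distributions given as plain dictionaries from token to probability. A real implementation would operate on batched tensors; the function names here are ours.

```python
import math
import random
from typing import Dict, List, Sequence, Tuple

MASK = "<mask>"


def nart_example(gold: Sequence[str], rng: random.Random) -> Tuple[List[str], List[int]]:
    """Build one NART input: mask n ~ Uniform{1, ..., N} random gold tokens."""
    N = len(gold)
    n = rng.randint(1, N)  # randint is inclusive at both ends
    masked_positions = sorted(rng.sample(range(N), n))
    observed = list(gold)
    for i in masked_positions:
        observed[i] = MASK
    return observed, masked_positions


def masked_cross_entropy(pred_dists: Sequence[Dict[str, float]], gold: Sequence[str],
                         masked_positions: Sequence[int]) -> float:
    """NART loss: negative log-likelihood of the gold tokens at masked positions only."""
    return -sum(math.log(pred_dists[i][gold[i]]) for i in masked_positions)


rng = random.Random(0)
gold = "The hotel is an ideal choice".split()
observed, positions = nart_example(gold, rng)
print(observed, positions)
```

Because every observed token is copied from the gold sequence, the model never sees its own mistakes in the input, which is the mismatch SMART targets.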

3 Semi-Autoregressive Training

The non-autoregressive training process of Ghazvininejad et al. (2019), NART, creates training examples where all the observed tokens are correct – the right word type in the right position. This assumption does not hold for mask-predict decoding, since the observed tokens (high-confidence predictions from previous iterations) are not always correct. We introduce an improved training process for CMLMs that better reflects the semi-autoregressive nature of mask-predict decoding by creating training examples from predicted target sequences, not gold ones. We name this training procedure SMART (Semi-Autoregressive Training).

Like NART, we start with a gold target sequence $Y$ and randomly mask $n$ tokens, where $n$ is sampled uniformly from 1 to $N$ (the target’s length). The CMLM then predicts the entire sequence, including the observed tokens, creating a new sequence $\hat{Y}$ from the most probable assignments. We repeat the masking process, but with different random values, to construct the final training example; i.e., we sample $m$ and randomly mask $m$ tokens from $\hat{Y}$ to create the partially-observed target sequence $\hat{Y}_{obs}$. Figure 1 illustrates this process.⁵

⁵ We perform a double forward pass only when creating training examples. During inference, each mask-predict iteration includes only a single forward pass in the predict step.
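The SMART example construction can be sketched as below. `predict` stands in for the current CMLM (a hypothetical interface returning the most probable token at every position), and we assume $m$ is drawn the same way as $n$ (uniformly from 1 to $N$), since the text only says the masking is repeated with different random values. In a real training loop the first pass would run without gradient tracking, because only the second pass is optimized.

```python
import random
from typing import Callable, List, Sequence, Tuple

MASK = "<mask>"

# Hypothetical interface: the current CMLM's most probable token at every position.
Predict = Callable[[Sequence[str], List[str]], List[str]]


def random_mask(tokens: Sequence[str], rng: random.Random) -> List[str]:
    """Mask k random positions, with k drawn uniformly from 1..len(tokens)."""
    k = rng.randint(1, len(tokens))
    positions = set(rng.sample(range(len(tokens)), k))
    return [MASK if i in positions else tok for i, tok in enumerate(tokens)]


def smart_example(predict: Predict, source: Sequence[str], gold: Sequence[str],
                  rng: random.Random) -> Tuple[List[str], List[str]]:
    """Return (model input, training target) for one SMART example."""
    # First forward pass: mask n gold tokens, then predict the whole sequence
    # (observed positions included). This pass only builds the input.
    y_obs = random_mask(gold, rng)
    y_hat = predict(source, y_obs)
    # Mask m tokens of the *prediction*, with fresh random choices.
    y_hat_obs = random_mask(y_hat, rng)
    # The second (training) pass learns to map (source, y_hat_obs) to the gold sequence.
    return y_hat_obs, list(gold)


def toy_predict(source: Sequence[str], target: List[str]) -> List[str]:
    # Hypothetical CMLM stand-in: keeps observed tokens and fills every mask with "choice".
    return ["choice" if tok == MASK else tok for tok in target]


rng = random.Random(1)
gold = "The hotel is an ideal choice".split()
model_input, target = smart_example(toy_predict, ["Das", "Hotel"], gold, rng)
print(model_input, target)
```

The toy predictor mimics the kind of error shown in Figure 1 (a repeated "choice"), so the resulting input can contain wrong observed tokens while the target stays the gold sequence.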

The observed portion of the partially-observed target sequence $\hat{Y}_{obs}$ may contain incorrect observations because it is based on the model's predictions ($\hat{Y}$). Therefore, we optimize the cross entropy for predicting all tokens, not only the masked ones. This change allows models trained with SMART to fix incorrect observations during prediction, and can be integrated into the mask-predict algorithm by modifying the predict step: instead of predicting just the masked tokens, predict every target token, and update those tokens whose predictions differ from the input.
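The two changes in this paragraph, the all-token loss and the modified predict step, can be sketched as below. Token strings, dictionary-valued distributions, and the choice to keep the old probability for any position whose token does not change are our assumptions; the text does not say how the confidences of unchanged tokens are handled.

```python
import math
from typing import Dict, List, Sequence, Tuple

MASK = "<mask>"


def all_token_cross_entropy(pred_dists: Sequence[Dict[str, float]], gold: Sequence[str]) -> float:
    """SMART loss: negative log-likelihood of the gold token at every position,
    observed or masked, since observed tokens may be wrong model predictions."""
    return -sum(math.log(dist[g]) for dist, g in zip(pred_dists, gold))


def predict_all_step(inputs: Sequence[str], probs: Sequence[float],
                     preds: Sequence[Tuple[str, float]]) -> Tuple[List[str], List[float]]:
    """Modified predict step: every position is re-predicted, and any position
    whose prediction differs from the input token (masks included) is updated."""
    tokens, new_probs = list(inputs), list(probs)
    for i, (tok, p) in enumerate(preds):
        if tok != inputs[i]:
            tokens[i], new_probs[i] = tok, p
    return tokens, new_probs


inputs = ["The", MASK, "is", "an", "choice", "choice"]
probs = [0.9, 0.0, 0.8, 0.7, 0.3, 0.6]
preds = [("The", 0.95), ("hotel", 0.6), ("is", 0.9), ("an", 0.8), ("ideal", 0.5), ("choice", 0.7)]
print(predict_all_step(inputs, probs, preds))
```

In the example, the masked position is filled with "hotel" and the incorrectly observed "choice" in the fifth position is corrected to "ideal", which the original predict step (masked positions only) could not do.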

Training Mode   Decoding Iterations   WMT’14 EN-DE   WMT’14 DE-EN   WMT’17 EN-ZH   WMT’17 ZH-EN
NART            1                     18.05          21.83          24.23          13.64
SMART           1                     18.58          23.77          24.15          13.51
NART            4                     25.94          29.90          32.63          21.90
SMART           4                     27.03          30.87          33.37          22.61
NART            10                    27.03          30.53          33.19          23.21
SMART           10                    27.65          31.27          34.06          23.78
Table 1: The performance (test set BLEU) of semi-autoregressive training (SMART), compared to the original non-autoregressive training for CMLMs (NART). All models are decoded with mask-predict.
Model                                         Decoding Iterations   WMT’14 EN-DE   WMT’14 DE-EN   WMT’17 EN-ZH   WMT’17 ZH-EN
Autoregressive Transformer with Beam Search                         27.61          31.38          34.31          23.65
  + Knowledge Distillation                                          27.75          31.30          34.38          23.91
SMART CMLM with Mask-Predict                  10                    27.65          31.27          34.06          23.78
                                                                    27.64          31.44          34.10          24.12
Table 2: The performance (test set BLEU) of semi-autoregressive training (SMART), compared to the standard (sequential) transformer. Length beam, beam size, and length penalty are tuned for each model on the validation set.

4 Experiments

We demonstrate, over 4 benchmarks, that replacing the original CMLM training process with SMART produces higher quality translations when decoding with mask-predict. Moreover, we show that our new approach closes the performance gap between semi-autoregressive and fully autoregressive machine translation. Finally, we conduct an ablation study and analyze how SMART balances between easy and hard training examples.

4.1 Setup

We evaluate on two machine translation datasets, in both directions (four benchmarks overall): WMT’14 English-German (4.5M sentence pairs) and WMT’17 English-Chinese (20M sentence pairs). The datasets are tokenized into subword units using BPE Sennrich et al. (2016). We use the same preprocessed data as Vaswani et al. (2017) and Wu et al. (2019) for WMT’14 EN-DE and WMT’17 EN-ZH respectively. We evaluate performance with BLEU Papineni et al. (2002) for all language pairs except English to Chinese, where we use SacreBLEU Post (2018).⁶

⁶ SacreBLEU hash: BLEU+case.mixed+lang.en-zh+numrefs.1+smooth.exp+test.wmt17+tok.zh+version.1.3.7

We implement our experiments based on the code of mask-predict Ghazvininejad et al. (2019), which uses the standard model and optimization hyperparameters for transformers in the base configuration Vaswani et al. (2017): 512 model dimensions, 2048 hidden dimensions, model averaging, etc. We also follow the standard practice of knowledge distillation Gu et al. (2018); Ghazvininejad et al. (2019); Zhou et al. (2019) in the non-autoregressive machine translation literature, and train both our model and the baselines on translations produced by a large autoregressive transformer model. For autoregressive decoding, we tune the beam size and length penalty on the development set, and similarly tune the number of length candidates for mask-predict decoding.

4.2 Results

We first compare SMART to the original CMLM training process (NART). Table 1 shows that SMART typically produces better models, with an average gain of 0.71 BLEU. Even with a single decoding iteration (the purely non-autoregressive scenario), SMART produces better models in WMT’14 and falls short of the baseline by a slim margin in WMT’17 (0.08 and 0.13 BLEU).⁷

⁷ We show the NART numbers reported by Ghazvininejad et al. (2019), which used a fixed number of length candidates. For fair comparison, we decoded the NART models while tuning the number of length candidates on the development set, but observed only minor deviations from the fixed setting.
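As a quick check of the 0.71 figure, the sketch below averages the SMART-minus-NART differences over all twelve cells of Table 1, including the single-iteration rows. That this is the averaging the authors used is our reading; the BLEU values are copied from the table.

```python
# Test-set BLEU from Table 1, per number of decoding iterations,
# ordered as (EN-DE, DE-EN, EN-ZH, ZH-EN).
nart = {1: (18.05, 21.83, 24.23, 13.64), 4: (25.94, 29.90, 32.63, 21.90), 10: (27.03, 30.53, 33.19, 23.21)}
smart = {1: (18.58, 23.77, 24.15, 13.51), 4: (27.03, 30.87, 33.37, 22.61), 10: (27.65, 31.27, 34.06, 23.78)}

# One difference per (iterations, benchmark) cell.
gains = [s - n for it in nart for s, n in zip(smart[it], nart[it])]
average_gain = sum(gains) / len(gains)
print(f"average gain over {len(gains)} settings: {average_gain:.2f} BLEU")
```

This prints an average of 0.71 BLEU over 12 settings; the two negative cells are the single-iteration WMT’17 results discussed in the text.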

We also compare SMART-trained CMLMs decoded with mask-predict against autoregressive transformers decoded with beam search. Table 2 shows that a constant number of decoding steps (10) brings our semi-autoregressive approach very close to the autoregressive baseline. With the exception of English to Chinese, the performance differences are within the typical random seed variance. Increasing the number of mask-predict iterations beyond 10 yields even more balanced results; in two of the four benchmarks, the small performance margins are actually in favor of our semi-autoregressive approach.

4.3 Ablation Study

We consider several variations of our proposed method to quantify the effect of each component. To prevent overfitting, we evaluate on the development set using a fixed number of length candidates.

Repredicting All Tokens

Besides SMART, we also augment the mask-predict algorithm to predict all tokens, not only the masked ones, during the predict step (Section 3). Table 3 compares this new version of mask-predict to the original. We find that predicting all tokens increases performance by 0.40 BLEU on average when using 4 decoding iterations. With 10 decoding iterations, the gains shrink to around 0.08 BLEU, but are still consistently positive.

Predicted Tokens   Decoding Iterations   WMT’14 EN-DE   WMT’14 DE-EN
Masked Tokens      4                     25.18          29.61
All Tokens         4                     25.61          29.98
Masked Tokens      10                    26.06          30.29
All Tokens         10                    26.14          30.37
Table 3: The performance (development set BLEU) of SMART-trained models with two flavors of mask-predict: predicting only masked tokens (original version), and predicting all tokens at each iteration.
Forward Passes   1 iteration   4 iterations   10 iterations
2                24.24         29.98          30.37
3                23.74         29.89          30.28
4                23.81         29.66          30.10
Table 4: Increasing the number of forward passes used to produce each training example in SMART can negatively affect the resulting model. Performance measured on WMT’14 DE-EN (development set BLEU).
Cross-Entropy Loss    1 iteration   4 iterations   10 iterations
1st Pass + 2nd Pass   23.89         29.78          30.05
Only 2nd Pass         24.24         29.98          30.37
Table 5: Using gradients from the first forward pass in SMART can negatively affect the resulting model. Performance measured on WMT’14 DE-EN (development set BLEU).
Gold Mask Ratio (n/N)   1 iteration   4 iterations   10 iterations
0%                      23.05         29.30          29.69
25%                     23.19         29.41          29.84
50%                     23.04         29.99          30.15
75%                     23.40         29.87          30.36
100%                    16.78         18.44          18.62
Uniform                 24.24         29.98          30.37
Table 6: The effect of the gold masking ratio (as a proxy of training example difficulty) on performance, measured on WMT’14 DE-EN (development set BLEU).

Multi-Iteration SMART

Lee et al. (2018) also proposed a semi-autoregressive training regime, in which the training process imitated the iterative refinement decoding algorithm. They use four decoding iterations during training, while accumulating the gradients from every model invocation. We try to apply the same ideas to SMART, but find that they do not improve our method.

We first consider creating our training examples by performing multiple mask-predict iterations during training, instead of just two. Table 4 shows that training on examples created by three or four forward passes of the model yields slightly (but consistently) worse results.

We also experiment with applying the cross-entropy loss after each forward pass (instead of just the last one). Table 5 reveals that using these gradients produces slightly weaker models, suggesting that using only the examples produced by the latter forward pass provides the model with a better training signal.
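The multi-pass variant of Table 4 can be sketched by chaining extra mask-and-predict rounds before the final masking. We read "forward passes" as counting the final training pass, so `passes=2` reproduces the default SMART example; this interpretation, the helper names, and the toy predictor are assumptions. The per-pass loss of Table 5 is not shown, since it needs a framework that tracks gradients.

```python
import random
from typing import Callable, List, Sequence, Tuple

MASK = "<mask>"
# Hypothetical interface: the CMLM's most probable token at every position.
Predict = Callable[[Sequence[str], List[str]], List[str]]


def random_mask(tokens: Sequence[str], rng: random.Random) -> List[str]:
    """Mask k random positions, with k drawn uniformly from 1..len(tokens)."""
    k = rng.randint(1, len(tokens))
    positions = set(rng.sample(range(len(tokens)), k))
    return [MASK if i in positions else tok for i, tok in enumerate(tokens)]


def multi_pass_example(predict: Predict, source: Sequence[str], gold: Sequence[str],
                       passes: int, rng: random.Random) -> Tuple[List[str], List[str]]:
    """Use `passes` forward passes in total: (passes - 1) mask-and-predict rounds,
    then one training pass on the masked result. passes=2 is the default SMART
    example; passes=1 leaves a masked gold input, as in NART."""
    if passes < 1:
        raise ValueError("passes must be at least 1")
    seq = list(gold)
    for _ in range(passes - 1):
        seq = predict(source, random_mask(seq, rng))
    return random_mask(seq, rng), list(gold)


def toy_predict(source: Sequence[str], target: List[str]) -> List[str]:
    # Hypothetical CMLM stand-in: keeps observed tokens, fills masks with "choice".
    return ["choice" if tok == MASK else tok for tok in target]


gold = "The hotel is an ideal choice".split()
for passes in (2, 3, 4):
    model_input, target = multi_pass_example(toy_predict, [], gold, passes, random.Random(passes))
    print(passes, model_input)
```

Each extra round lets prediction errors compound before the model sees the input, which is one plausible reading of why Table 4 shows slightly worse results for three or four passes.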

Difficulty Analysis

SMART produces training examples from model predictions conditioned on partially-observed gold data ($Y_{obs}$). Intuitively, the amount of masked gold data will affect the difficulty of the resulting example. When 0% of the gold tokens are masked, the model will likely just copy its input (so its prediction $\hat{Y}$ equals the gold sequence $Y$) and produce easier training examples, effectively reducing SMART to NART. When 100% of the gold tokens are masked, the training example will be entirely prediction-based, posing a significantly harder challenge for the model.

To explore the effect of training example difficulty on performance, we replace the uniformly distributed number of masks $n$ with different fixed ratios. Table 6 shows that training with harder examples (50% to 75% gold mask ratio) improves performance, but that training with inputs that are not based on “a grain of truth” (100% gold mask ratio) is not conducive to a successful learning process. By sampling from a uniform distribution, SMART provides training examples from a broad spectrum of difficulties.
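The difficulty experiment only changes how many gold tokens are masked before the first forward pass. A small sketch, assuming fixed ratios are converted to counts by flooring (the text does not say how ratios were rounded); `num_gold_masks` is a hypothetical helper name.

```python
import random
from typing import Optional


def num_gold_masks(N: int, rng: random.Random, ratio: Optional[float] = None) -> int:
    """How many gold tokens to mask before SMART's first forward pass.
    ratio=None: n ~ Uniform{1, ..., N} (the default); otherwise n = floor(ratio * N)."""
    if ratio is None:
        return rng.randint(1, N)
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must lie in [0, 1]")
    return int(ratio * N)


rng = random.Random(0)
for r in (0.0, 0.25, 0.5, 0.75, 1.0, None):
    print(r, num_gold_masks(12, rng, r))
```

A ratio of 0 masks nothing, so the model can copy the gold input and SMART degenerates to NART-like inputs; a ratio of 1 masks everything, so the first-pass prediction has no gold tokens to build on.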

5 Related Work

SMART was inspired by the iterative refinement model of Lee et al. (2018), who also used a semi-autoregressive training method. While Lee et al. seed their model inputs with artificial noise during training, the only source of noise in SMART is the model predictions.

Other semi-autoregressive models have also been able to close the performance gap with beam-search-decoded autoregressive models. Shu et al. (2019) demonstrate how a latent-variable approach can outperform the autoregressive baseline on Japanese to English translation, but still observe a significant performance gap on WMT’14 EN-DE. Others have introduced insertion operators Stern et al. (2019); for example, the Levenshtein transformer Gu et al. (2019b) allows for both insertions and deletions, achieving equal-quality translations with a smaller number of decoding iterations. SMART achieves a similar result with a simple approach that requires neither latent variables nor insertions.

6 Conclusion

We introduced SMART (Semi-Autoregressive Training), a new training process for conditional masked language models that better matches the semi-autoregressive nature of the mask-predict decoding algorithm. SMART training produces models that are competitive with mainstream autoregressive models in terms of performance, while retaining the benefits of fast parallel decoding.


References

  • M. Ghazvininejad, O. Levy, Y. Liu, and L. Zettlemoyer (2019) Mask-predict: parallel decoding of conditional masked language models. In Proc. of EMNLP-IJCNLP.
  • J. Gu, J. Bradbury, C. Xiong, V. O. Li, and R. Socher (2018) Non-autoregressive neural machine translation. In Proc. of ICLR.
  • J. Gu, Q. Liu, and K. Cho (2019a) Insertion-based decoding with automatically inferred generation order. arXiv preprint arXiv:1902.01370.
  • J. Gu, C. Wang, and J. Zhao (2019b) Levenshtein transformer. In Proc. of NeurIPS.
  • J. D. Lee, E. Mansimov, and K. Cho (2018) Deterministic non-autoregressive neural sequence modeling by iterative refinement. In Proc. of EMNLP.
  • J. Libovický and J. Helcl (2018) End-to-end non-autoregressive neural machine translation with connectionist temporal classification. In Proc. of EMNLP.
  • K. Papineni, S. Roukos, T. Ward, and W. Zhu (2002) Bleu: a method for automatic evaluation of machine translation. In Proceedings of the 40th Annual Meeting of the Association for Computational Linguistics.
  • M. Post (2018) A call for clarity in reporting BLEU scores. In Proceedings of the Third Conference on Machine Translation: Research Papers.
  • R. Sennrich, B. Haddow, and A. Birch (2016) Neural machine translation of rare words with subword units. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers).
  • R. Shu, J. Lee, H. Nakayama, and K. Cho (2019) Latent-variable non-autoregressive neural machine translation with deterministic inference using a delta posterior. arXiv preprint arXiv:1908.07181.
  • M. Stern, W. Chan, J. R. Kiros, and J. Uszkoreit (2019) Insertion transformer: flexible sequence generation via insertion operations. In Proc. of ICML.
  • Z. Sun, Z. Li, H. Wang, D. He, Z. Lin, and Z. Deng (2019) Fast structured decoding for sequence models. In Advances in Neural Information Processing Systems, pp. 3011–3020.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008.
  • F. Wu, A. Fan, A. Baevski, Y. N. Dauphin, and M. Auli (2019) Pay less attention with lightweight and dynamic convolutions. In International Conference on Learning Representations.
  • C. Zhou, G. Neubig, and J. Gu (2019) Understanding knowledge distillation in non-autoregressive machine translation. arXiv preprint arXiv:1911.02727.