Aligned Cross Entropy for Non-Autoregressive Machine Translation

04/03/2020 ∙ by Marjan Ghazvininejad, et al.

Non-autoregressive machine translation models significantly speed up decoding by allowing for parallel prediction of the entire target sequence. However, modeling word order is more challenging due to the lack of autoregressive factors in the model. This difficulty is compounded during training with cross entropy loss, which can highly penalize small shifts in word order. In this paper, we propose aligned cross entropy (AXE) as an alternative loss function for training of non-autoregressive models. AXE uses a differentiable dynamic program to assign loss based on the best possible monotonic alignment between target tokens and model predictions. AXE-based training of conditional masked language models (CMLMs) substantially improves performance on major WMT benchmarks, while setting a new state of the art for non-autoregressive models.




1 Introduction

Non-autoregressive machine translation models can significantly improve decoding speed by predicting every word in parallel (Gu et al., 2018; Libovický and Helcl, 2018). This advantage comes at a cost to performance since modeling word order is trickier when the model cannot condition on its previous predictions. A range of semi-autoregressive models (Lee et al., 2018; Stern et al., 2019; Gu et al., 2019; Ghazvininejad et al., 2019) have shown there is a speed-accuracy tradeoff that can be optimized with limited forms of autoregression. However, increasing performance of the purely non-autoregressive models without sacrificing decoding speed remains an open challenge. In this paper, we present a new training loss for non-autoregressive machine translation that softens the penalty for word order errors, and significantly improves performance with no modification to the model or to the decoding algorithm.

Position | 1 | 2 | 3 | 4 | 5
Target | it | tastes | pretty | good | though
Model Predictions (Top 5) | but | it | tastes | delicious | .
 | however | that | makes | good | ,
 | for | this | looks | tasty | so
 | and | for | taste | fine | though
 | though | the | feels | exquisite | !
Figure 1: The model predictions are quite similar to the target, but misaligned by one token. The first and second target tokens (it tastes) are predicted in the second and third positions, respectively, leaving only the predictions in the fourth and fifth positions aligned with the target. The cross-entropy loss will heavily penalize the predictions in the first, second, and third positions.

Existing models (both autoregressive and non-autoregressive) are typically trained with cross entropy loss. Cross entropy is a strict loss function, where a penalty is incurred for every word that is predicted out of position, even for output sequences with small edit distances (see Figure 1). Autoregressive models learn to avoid such penalties, since words are generated conditioned on the sentence prefix. However, non-autoregressive models do not know the exact sentence prefix, and should (intuitively) focus more on root errors (e.g. a missing word) while allowing more partial credit for cascading errors (the right word in the wrong place).

To achieve this more relaxed loss, we introduce aligned cross entropy (AXE), a new objective function that computes the cross entropy loss based on an alignment between the sequence of token labels and the sequence of token distribution predictions. AXE uses dynamic programming to find the monotonic alignment that minimizes the cross entropy loss. It provides non-autoregressive models with a more accurate training signal by ignoring absolute positions and focusing on relative order and lexical matching. We efficiently implement AXE via matrix operations, and use it to train conditional masked language models (CMLM; Ghazvininejad et al., 2019) for machine translation. AXE only slightly increases training time compared to cross entropy, and requires no changes to parallel argmax decoding.

Extensive experiments on machine translation benchmarks demonstrate that AXE substantially boosts the performance of CMLMs, while having the same decoding speed. In WMT’14 EN-DE, training CMLMs with AXE (instead of the regular cross entropy loss) increases performance by 5 BLEU points; we observe similar trends in WMT’16 EN-RO and WMT’17 EN-ZH. Moreover, AXE CMLMs significantly outperform state-of-the-art non-autoregressive models, such as FlowSeq (Ma et al., 2019), as well as the recent CRF-based semi-autoregressive model with bigram LM decoding (Sun et al., 2019). Our detailed analysis suggests that training with AXE makes models more confident in their predictions, thus reducing multimodality, and alleviating a key problem in non-autoregressive machine translation.

2 Aligned Cross Entropy

Let $Y$ be a target sequence of $n$ tokens $Y_1, \ldots, Y_n$, and $P$ be the model predictions, a sequence of $m$ token probability distributions $P_1, \ldots, P_m$. Our goal is to find a monotonic alignment between $Y$ and $P$ that will minimize the cross entropy loss, and thus focus the penalty on lexical errors (predicting the wrong token) rather than positional errors (predicting the right token in the wrong place).

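We define an alignment $\alpha$ to be a function that maps target positions to prediction positions, i.e. $\alpha: \{1, \ldots, n\} \rightarrow \{1, \ldots, m\}$. We further assume that this alignment is monotonic, i.e. $i \le j$ iff $\alpha(i) \le \alpha(j)$. Given a specific alignment $\alpha$, we define a conditional loss as:

$$\mathrm{AXE}(Y, P \mid \alpha) = -\sum_{i=1}^{n} \log P_{\alpha(i)}(Y_i) \; - \sum_{k \notin \{\alpha(1), \ldots, \alpha(n)\}} \log P_k(\epsilon) \tag{1}$$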

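  Input: tokens $Y$, predictions $P$
  $A_{0,0} = 0$
  for $i = 1$ to $n$ do
     $A_{i,0} = A_{i-1,0} - \delta \cdot \log P_1(Y_i)$
  end for
  for $j = 1$ to $m$ do
     $A_{0,j} = A_{0,j-1} - \log P_j(\epsilon)$
  end for
  for $i = 1$ to $n$ do
     for $j = 1$ to $m$ do
        align $= A_{i-1,j-1} - \log P_j(Y_i)$
        skip_prediction $= A_{i,j-1} - \log P_j(\epsilon)$
        skip_target $= A_{i-1,j} - \delta \cdot \log P_j(Y_i)$
        $A_{i,j} = \min\{$align, skip_prediction, skip_target$\}$
     end for
  end for
  return $A_{n,m}$
Algorithm 1 Aligned Cross Entropy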

The first term of this loss function is an aligned cross entropy between $Y$ and $P$, and the second term is a penalty for unaligned predictions. Epsilon ($\epsilon$) is a special “blank” token in our vocabulary that appears in the probability distributions, but that does not appear in the final output string.
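The two terms of the conditional loss can be sketched in a few lines of Python. This is only an illustration under stated assumptions: predictions are represented as dictionaries from token to probability, the blank token is given the hypothetical name `<eps>`, and the alignment is passed as a list of 1-based prediction positions.

```python
import math

BLANK = "<eps>"  # hypothetical name for the blank token (epsilon)


def conditional_axe(targets, predictions, alignment):
    """Conditional loss for one fixed monotonic alignment.

    targets:     list of n target tokens
    predictions: list of m dicts mapping token -> probability (must be > 0)
    alignment:   list of n 1-based prediction positions, one per target
    """
    # First term: cross entropy of each target under its aligned prediction.
    aligned = -sum(
        math.log(predictions[a - 1][y]) for y, a in zip(targets, alignment)
    )
    # Second term: every prediction that no target maps to must emit a blank.
    used = set(alignment)
    unaligned = -sum(
        math.log(p[BLANK])
        for k, p in enumerate(predictions, start=1)
        if k not in used
    )
    return aligned + unaligned


if __name__ == "__main__":
    toy_targets = ["a", "b"]
    toy_predictions = [
        {BLANK: 0.5, "a": 0.25, "b": 0.25},
        {BLANK: 0.25, "a": 0.5, "b": 0.25},
        {BLANK: 0.25, "a": 0.25, "b": 0.5},
    ]
    # Shifted alignment: first prediction is left unaligned.
    print(conditional_axe(toy_targets, toy_predictions, [2, 3]))
```

In the toy example, aligning the targets to the second and third predictions is cheaper than aligning them to the first and second, because the first prediction puts most of its mass on the blank.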

Now, the final loss is the minimum over all possible monotonic alignments of the conditional loss:

$$\mathrm{AXE}(Y, P) = \min_{\alpha} \mathrm{AXE}(Y, P \mid \alpha) \tag{2}$$

Finding the optimal monotonic alignment between two sequences is a well studied problem. For instance, dynamic time warping (DTW) (Sakoe and Chiba, 1978) is a well-known algorithm for finding the optimal alignment between two different time series. Here we have extended the idea to compute the optimal alignment between a sequence of target tokens and a sequence of prediction probability distributions. We use a simple dynamic program to find the optimal alignment while calculating the AXE loss.

Operation | Update | Description
Align | $A_{i,j} = A_{i-1,j-1} - \log P_j(Y_i)$ | Aligns the current target $Y_i$ with the current prediction $P_j$, updating $A$ along the diagonal.
Skip Prediction | $A_{i,j} = A_{i,j-1} - \log P_j(\epsilon)$ | Skips the current prediction $P_j$ by predicting an empty token ($\epsilon$), updating $A$ along the $j$ axis. This operation is akin to inserting an empty token to the target sequence at the $i$-th position.
Skip Target | $A_{i,j} = A_{i-1,j} - \delta \cdot \log P_j(Y_i)$ | Skips the current target $Y_i$ by predicting it without incrementing the prediction iterator $j$, updating $A$ along the $i$ axis. This operation is akin to duplicating the prediction $P_j$. The hyperparameter $\delta$ controls how expensive this operation is; high values of $\delta$ will discourage alignments that skip too many target tokens.
Table 1: The three local update operators in AXE’s dynamic program.
Target it tastes pretty good though
Alignment 2 3 3 4 5
Model Predictions (Top 5) but it tastes delicious
however makes good .
that looks tasty ,
for this taste fine so
and for feels exquisite though
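Figure 2: An example illustrating how AXE aligns model predictions with the target sequence. The operations that created the optimal alignment in this example are: (1) skip prediction $P_1$ (cost: $-\log P_1(\epsilon)$), (2) align $Y_1$ to $P_2$ (cost: $-\log P_2(Y_1)$), (3) align $Y_2$ to $P_3$ (cost: $-\log P_3(Y_2)$), (4) skip target $Y_3$ (cost: $-\delta \cdot \log P_3(Y_3)$), (5) align $Y_4$ to $P_4$ (cost: $-\log P_4(Y_4)$), (6) align $Y_5$ to $P_5$ (cost: $-\log P_5(Y_5)$).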
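Read together with the alignment row (2 3 3 4 5), the six operations add up to a single loss value. The sum below is a worked restatement of the example, writing $Y_i$ for the $i$-th target token, $P_j$ for the $j$-th prediction, $\epsilon$ for the blank token, and $\delta$ for the skip target penalty; these symbols are assumed to match the paper's notation.

```latex
\begin{aligned}
\mathrm{AXE}(Y, P)
  ={}& -\log P_1(\epsilon)          && \text{(1) skip prediction } P_1 \\
     & -\log P_2(Y_1)               && \text{(2) align ``it'' to } P_2 \\
     & -\log P_3(Y_2)               && \text{(3) align ``tastes'' to } P_3 \\
     & -\delta \cdot \log P_3(Y_3)  && \text{(4) skip target ``pretty''} \\
     & -\log P_4(Y_4)               && \text{(5) align ``good'' to } P_4 \\
     & -\log P_5(Y_5)               && \text{(6) align ``though'' to } P_5
\end{aligned}
```

Only terms (1) and (4) correspond to root errors; the remaining four terms are ordinary cross entropy on correctly ordered tokens.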

Dynamic Programming

Given a sequence of target tokens $Y = Y_1, \ldots, Y_n$ and a sequence of predictions $P = P_1, \ldots, P_m$, we propose a method to find the score of the optimal alignment between any prefix of these two sequences $Y_{1:i}$ and $P_{1:j}$, for any $i \le n$ and $j \le m$. The score of the optimal alignment for the full sequences is obtained at $i = n$ and $j = m$.

We start by defining a matrix $A$ of $n$ by $m$ dimensions, respectively corresponding to $Y$ and $P$, where $A_{i,j}$ represents the minimum loss value for aligning $Y_{1:i}$ to $P_{1:j}$ as defined in Equation 2. We initialize $A_{0,0}$ to be $0$ and then proceed to fill the matrix by taking the local minimum at each cell from three possible operators: Align, Skip Prediction, and Skip Target. Table 1 describes each operation and its update formula. Once the matrix is full, the cell $A_{n,m}$ will contain the cross entropy loss of the optimal alignment. Algorithm 1 lays out a straightforward implementation of AXE’s dynamic program.

According to Equation 2, the optimal alignment can be many-to-one, where multiple target positions can be mapped to a single prediction. This would be computed by aligning the first mapped token and skipping the rest of target tokens. To discourage skipping too many target tokens, we penalize skip target operators separately with a parameter $\delta$ as described in Table 1. Setting $\delta = 1$ will result in the loss function defined in Equation 2, but as we show in our ablation study (Section 4.3), higher values yield better performance in practice.
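The dynamic program described above can be sketched in plain Python. This is a minimal illustration, not the authors' implementation: the dictionary representation of predictions, the blank token name `<eps>`, and the handling of the first column (targets that precede any consumed prediction are charged against the first prediction with the skip target penalty) are all assumptions. It computes only the loss value; training would need a differentiable tensor version.

```python
import math

BLANK = "<eps>"  # hypothetical name for the blank token (epsilon)


def axe_loss(targets, predictions, delta=1.0):
    """Minimum aligned cross entropy over monotonic alignments.

    targets:     list of n >= 1 target tokens
    predictions: list of m >= 1 dicts mapping token -> probability (> 0)
    delta:       skip target penalty coefficient
    """
    n, m = len(targets), len(predictions)
    # A[i][j]: best loss for the first i targets and the first j predictions.
    A = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        # Assumed boundary: no prediction consumed yet, charge against P_1.
        A[i][0] = A[i - 1][0] - delta * math.log(predictions[0][targets[i - 1]])
    for j in range(1, m + 1):
        # No target consumed yet: every prediction so far emits a blank.
        A[0][j] = A[0][j - 1] - math.log(predictions[j - 1][BLANK])
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            nll_token = -math.log(predictions[j - 1][targets[i - 1]])
            nll_blank = -math.log(predictions[j - 1][BLANK])
            align = A[i - 1][j - 1] + nll_token
            skip_prediction = A[i][j - 1] + nll_blank
            skip_target = A[i - 1][j] + delta * nll_token
            A[i][j] = min(align, skip_prediction, skip_target)
    return A[n][m]
```

With `delta=1.0` the returned value should coincide with a brute-force minimum of the conditional loss over all non-decreasing alignments, which is a convenient sanity check on small inputs.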

Efficient Implementation

The implementation in Algorithm 1 has $O(nm)$ time complexity. However, multiple updates of the matrix $A$ can be parallelized on GPUs and other tensor-processing architectures. Rather than iterating over each cell, we iterate over each anti-diagonal, computing all the values along the anti-diagonal in parallel. In other words, we first compute the values of $A_{1,1}$, followed by $A_{1,2}$ and $A_{2,1}$, etc. Since the number of anti-diagonals is $n + m - 1$, we arrive at a time complexity of $O(n + m)$. Since $m$ is typically on the same order of magnitude as $n$, the linear cost of computing AXE during training becomes negligible compared to forward and backward passes through the model. (Footnote 1: Batch implementation of this algorithm is straightforward.) By doing so, we are able to achieve training times similar to (about 1.2 times slower than) training with cross entropy loss.
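The anti-diagonal ordering can be illustrated without any tensor library. In the sketch below, every cell on one anti-diagonal is computed from the two previous anti-diagonals only, so the list comprehension in the loop is the step a GPU implementation would run in parallel. The data layout, the blank token name `<eps>`, and the first-column boundary are assumptions, as is the choice of plain Python over a tensor framework.

```python
import math

BLANK = "<eps>"  # hypothetical name for the blank token (epsilon)


def axe_loss_antidiagonal(targets, predictions, delta=1.0):
    """Same dynamic program as a row-by-row fill, visited by anti-diagonals."""
    n, m = len(targets), len(predictions)
    A = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        # Assumed boundary: charge early targets against the first prediction.
        A[i][0] = A[i - 1][0] - delta * math.log(predictions[0][targets[i - 1]])
    for j in range(1, m + 1):
        A[0][j] = A[0][j - 1] - math.log(predictions[j - 1][BLANK])

    def cell(i, j):
        nll_token = -math.log(predictions[j - 1][targets[i - 1]])
        nll_blank = -math.log(predictions[j - 1][BLANK])
        return min(
            A[i - 1][j - 1] + nll_token,      # align
            A[i][j - 1] + nll_blank,          # skip prediction
            A[i - 1][j] + delta * nll_token,  # skip target
        )

    # Interior cells with i + j == d form one anti-diagonal; d runs over
    # n + m - 1 values, matching the number of sequential steps.
    for d in range(2, n + m + 1):
        rows = range(max(1, d - m), min(n, d - 1) + 1)
        values = [cell(i, d - i) for i in rows]  # independent of each other
        for i, value in zip(rows, values):
            A[i][d - i] = value
    return A[n][m]
```

Computing all `values` before writing any of them back mirrors the parallel update: no cell on an anti-diagonal reads another cell of the same anti-diagonal.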


Figure 2 depicts an example application of AXE. We see that the predictions are generally good, but start with a shift with respect to the target. This misalignment would cause the regular cross entropy loss to severely penalize the first three predictions, even though $P_2$ and $P_3$ are correct when aligned with $Y_1$ and $Y_2$. AXE, on the other hand, finds an alignment between the target and the predictions, which allows it to focus the penalty on the redundant prediction in $P_1$ and the missing token $Y_3$, i.e. the root errors.

3 Training Non-Autoregressive Models

We use AXE to train conditional masked language models (CMLMs) for non-autoregressive machine translation (Ghazvininejad et al., 2019). (Footnote 2: While in this work we apply AXE to CMLMs, the loss function can be used to train other models as well. We leave further investigation of this direction to future work.)

3.1 Conditional Masked Language Models

A conditional masked language model takes a source sequence $X$ and a partially-observed target sequence $Y_{obs}$ as input, and predicts the probabilities of the masked (unobserved) target sequence tokens $Y_{mask}$. The underlying architecture is an encoder-decoder transformer (Vaswani et al., 2017).

In the original paper, CMLMs are used for machine translation where a random subset of tokens is masked at training time. However, at inference all target tokens are masked ($Y = Y_{mask}$) and the length of $Y_{mask}$ (the number of masked tokens) is unknown. To estimate the length of $Y_{mask}$, an auxiliary task is introduced to predict the target length based on the source sequence $X$. (Footnote 3: See Ghazvininejad et al. (2019) for further detail.)

3.2 Adapting CMLMs to AXE

In our case, the model can also produce blank tokens ($\epsilon$), which effectively shorten the predicted sequence’s length. To account for potentially skipped tokens during inference, we multiply the predicted length by a hyperparameter $\lambda$ (which is tuned on the validation set) before applying argmax decoding.
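The inference-time adjustment can be sketched as two small helpers: one that inflates the predicted length, and one that performs argmax decoding and drops blanks. The function names, the blank token name `<eps>`, and the use of rounding to obtain an integer length are assumptions made for illustration; the text above does not specify how the scaled length is converted to an integer.

```python
BLANK = "<eps>"  # hypothetical name for the blank token (epsilon)


def scaled_length(predicted_length, length_multiplier):
    """Inflate the predicted target length to leave room for blanks.

    Rounding to the nearest integer is an assumption of this sketch.
    """
    return max(1, round(predicted_length * length_multiplier))


def argmax_decode(predictions):
    """Pick the most probable token at each position, then remove blanks.

    predictions: list of dicts mapping token -> probability
    """
    tokens = [max(p, key=p.get) for p in predictions]
    return [token for token in tokens if token != BLANK]
```

For instance, a model that predicts a blank in the first position followed by "it" and "tastes" yields the two-token output "it tastes", which is why the number of decoded positions needs to be slightly larger than the expected output length.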

3.3 Adapting the Training Objectives to AXE

Since this work focuses on the purely non-autoregressive setting, the entire target sequence will be masked at inference time ($Y = Y_{mask}$). The same does not have to hold for training; we can utilize partially observed sequences in order to provide the learner with easier and more focused training examples. We experiment with three variations:

Unobserved Input, Predict All

All the tokens in the target sequence are masked, and the model is expected to predict all of them. This is a direct replication of the task at inference time. While AXE allows for the number of masked tokens $m$ to be different from the length of the gold target sequence $n$, we found that setting $m = n$ produced better models in preliminary experiments.

Partially-Observed Input, Predict All

As in the original CMLM training process, a random subset of the target sequence is masked before being passed onto the model as input. (Footnote 4: The number of masked input tokens is distributed uniformly between $1$ and $n$.) We then apply AXE on the entire sequence, regardless of which tokens were observed. When training on partially-observed inputs, we always set $m = n$ to avoid further alterations of the gold target sequence beyond masking.

Partially-Observed Input, Predict Masks

The straightforward application of AXE to CMLM training (which ignores whether each token was masked or observed) works well in practice. However, we can also allow AXE to skip the observed tokens when computing cross entropy, and focus the training signal on the actual task. We do so by setting $P_i(Y_i) = 1$ for every observed token $Y_i$; i.e. if the $i$-th token is observed and is aligned with the prediction corresponding to the same position ($\alpha(i) = i$), there is no penalty. Our ablation studies show that this modification provides a modest but consistent boost in performance (see Section 4.3). As a result, we use this setting for training our model.
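One way to picture the "predict masks" variant is as a preprocessing step on the predictions before the loss is computed. The helper below is a hypothetical sketch: it assumes one prediction per target position, represents predictions as dictionaries, and takes a list of booleans marking which target positions were observed.

```python
def reveal_observed(targets, predictions, observed):
    """Give observed gold tokens probability 1 at their own position.

    targets:     list of n target tokens
    predictions: list of n dicts mapping token -> probability
    observed:    list of n booleans, True where the token was given as input
    Returns adjusted copies; the inputs are left untouched.
    """
    adjusted = []
    for token, prediction, is_observed in zip(targets, predictions, observed):
        prediction = dict(prediction)  # copy so the caller's data is unchanged
        if is_observed:
            # -log(1) = 0, so aligning this token in place costs nothing.
            prediction[token] = 1.0
        adjusted.append(prediction)
    return adjusted
```

The adjusted dictionaries are no longer normalized distributions; they are only meant to be fed into the loss, where the zero cost at observed positions leaves the training signal on the masked tokens.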

Model | WMT’14 EN-DE | WMT’14 DE-EN | WMT’16 EN-RO | WMT’16 RO-EN | WMT’17 EN-ZH | WMT’17 ZH-EN
Cross Entropy CMLM (Ghazvininejad et al., 2019) | 18.05 | 21.83 | 27.32 | 28.20 | 24.23 | 13.64
AXE CMLM (Ours) | 23.53 | 27.90 | 30.75 | 31.54 | 30.88 | 19.79
Table 2: The performance (test set BLEU) of AXE CMLM compared to cross entropy CMLM on all of our benchmarks. Both models are purely non-autoregressive, using a single forward pass during argmax decoding.
Model | Decoding Iterations | WMT’14 EN-DE | WMT’14 DE-EN | WMT’16 EN-RO | WMT’16 RO-EN
Transformer Base | | 27.61 | 31.38 | 34.28 | 33.99
    + Knowledge Distillation | | 27.75 | 31.30 | - | -
Iterative Refinement (Lee et al., 2018) | 1 | 13.91 | 16.77 | 24.45 | 25.73
CTC Loss (Libovický and Helcl, 2018) | 1 | 17.68 | 19.80 | 19.93 | 24.71
NAT w/ Fertility (Gu et al., 2018) | 1 | 17.69 | 21.47 | 27.29 | 29.06
Cross Entropy CMLM (Ghazvininejad et al., 2019) | 1 | 18.05 | 21.83 | 27.32 | 28.20
Auxiliary Regularization (Wang et al., 2019) | 1 | 20.65 | 24.77 | - | -
Bag-of-ngrams Loss (Shao et al., 2019) | 1 | 20.90 | 24.61 | 28.31 | 29.29
Hint-based Training (Li et al., 2019) | 1 | 21.11 | 25.24 | - | -
FlowSeq (Ma et al., 2019) | 1 | 21.45 | 26.16 | 29.34 | 30.44
Bigram CRF (Sun et al., 2019) | 1 | 23.44 | 27.22 | - | -
AXE CMLM (Ours) | 1 | 23.53 | 27.90 | 30.75 | 31.54
Table 3: The performance (test set BLEU) of CMLMs trained with AXE, compared to other non-autoregressive methods. The standard (autoregressive) transformer results are also reported for reference.

4 Experiments

We evaluate CMLMs trained with AXE on 6 standard machine translation benchmarks, and demonstrate that AXE significantly improves performance over cross entropy trained CMLMs and over recently-proposed non-autoregressive models as well.

4.1 Setup

Translation Benchmarks

We evaluate our method on both directions of three standard machine translation datasets with various training data sizes: WMT’14 English-German (4.5M sentence pairs), WMT’16 English-Romanian (610k pairs), and WMT’17 English-Chinese (20M pairs). The datasets are tokenized into subword units using BPE (Sennrich et al., 2016). (Footnote 5: We run joint BPE for all language pairs except English-Chinese.) We use the same data and preprocessing as Vaswani et al. (2017), Lee et al. (2018), and Wu et al. (2019) for WMT’14 EN-DE, WMT’16 EN-RO, and WMT’17 EN-ZH respectively. We evaluate performance with BLEU (Papineni et al., 2002) for all language pairs, except for translating English to Chinese where we use SacreBLEU (Post, 2018). (Footnote 6: SacreBLEU hash: BLEU+case.mixed+lang.en-zh+numrefs.1+smooth.exp+test.wmt17+tok.zh+version.1.3.7)


Hyperparameters

We generally follow the transformer base hyperparameters (Vaswani et al., 2017): 6 layers for the encoder and decoder, 8 attention heads per layer, 512 model dimensions, and 2048 hidden dimensions. We follow the weight initialization schema from BERT (Devlin et al., 2018), and sample weights from $\mathcal{N}(0, 0.02)$, set biases to zero, and set layer normalization parameters to $\beta = 0$ and $\gamma = 1$. For regularization, we set dropout to $0.3$, and use $0.01$ weight decay and label smoothing with $\epsilon = 0.1$. We train batches of 128k tokens using Adam (Kingma and Ba, 2015) with $\beta = (0.9, 0.999)$ and $\epsilon = 10^{-6}$. The learning rate warms up to $5 \cdot 10^{-4}$ within 10k steps, and then decays with the inverse square-root schedule. We train all models for 300k steps. We measure the validation loss at the end of each epoch, and average the 5 best checkpoints based on their validation loss to create the final model. We train all models with mixed precision floating point arithmetic on 16 Nvidia V100 GPUs. For autoregressive decoding, we use a beam size of $b = 5$ (Vaswani et al., 2017) and tune the length penalty on the validation set. Similarly we use $\ell = 5$ length candidates for CMLM models, and tune the length multiplier ($\lambda \in [1.05, 1.1]$) and the target skipping penalty ($\delta \in [1, 5]$) on the validation set. (Footnote 7: Our preliminary analysis shows that AXE selects Skip Prediction 5–10% of the time, roughly suggesting that five to ten percent of generated tokens are epsilons. Hence, we search the same range for the length multiplier.)
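The warmup followed by inverse square-root decay mentioned above can be written as a small function of the training step. This is a sketch under assumptions: the warmup is taken to be linear from zero, and the peak learning rate is left as a required argument rather than hard-coded.

```python
def inverse_sqrt_lr(step, peak_lr, warmup_steps=10_000):
    """Learning rate at a given step (steps counted from 1).

    Linear warmup to peak_lr over warmup_steps (assumed shape), then decay
    proportional to 1 / sqrt(step), continuous at the end of warmup.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * (warmup_steps / step) ** 0.5
```

Under this schedule the rate halves each time the step count quadruples after warmup, so most of the 300k training steps run at a fraction of the peak rate.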

Knowledge Distillation

Similar to previous work on non-autoregressive translation (Gu et al., 2018; Lee et al., 2018; Ghazvininejad et al., 2019; Stern et al., 2019), we use sequence-level knowledge distillation (Kim and Rush, 2016) by training CMLMs on translations generated by a standard left-to-right transformer model (transformer large for WMT’14 EN-DE and WMT’17 EN-ZH, transformer base for WMT’16 EN-RO). We report the performance of standard autoregressive base transformers trained on distilled data for WMT’14 EN-DE and WMT’17 EN-ZH.

Model | Decoding Iterations | WMT’14 EN-DE | WMT’14 DE-EN | WMT’16 EN-RO | WMT’16 RO-EN
Knowledge Distillation
AXE CMLM (Ours) | 1 | 23.53 | 27.90 | 30.75 | 31.54
Raw Data
Cross Entropy CMLM (Ghazvininejad et al., 2019) | 1 | 10.64 | - | 21.22 | -
CTC Loss (Libovický and Helcl, 2018) | 1 | 17.68 | 19.80 | 19.93 | 24.71
FlowSeq (Ma et al., 2019) | 1 | 18.55 | 23.36 | 29.26 | 30.16
AXE CMLM (Ours) | 1 | 20.40 | 24.90 | 30.47 | 31.42
Table 4: The performance (test set BLEU) of AXE CMLM, compared to other non-autoregressive methods on raw data. The result of AXE CMLM trained with distillation is also reported as a reference.

4.2 Main Results

AXE vs Cross Entropy

We first compare the performance of AXE-trained CMLMs to that of CMLMs trained with the original cross entropy loss. Table 2 shows that training with AXE substantially increases the performance of CMLMs across all benchmarks. On average, we gain 5.2 BLEU by replacing cross entropy with AXE, with gains of up to 6.65 BLEU in WMT’17 EN-ZH.
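The average and maximum gains quoted above follow directly from the two rows of Table 2. The short check below recomputes them; the column order (EN-DE, DE-EN, EN-RO, RO-EN, EN-ZH, ZH-EN) is inferred from the other tables and from the EN-ZH figure in the text.

```python
# Test-set BLEU from Table 2, in the assumed column order:
# WMT'14 EN-DE, DE-EN, WMT'16 EN-RO, RO-EN, WMT'17 EN-ZH, ZH-EN.
cross_entropy = [18.05, 21.83, 27.32, 28.20, 24.23, 13.64]
axe = [23.53, 27.90, 30.75, 31.54, 30.88, 19.79]

# Per-benchmark improvement, rounded to the precision of the table.
gains = [round(a - c, 2) for a, c in zip(axe, cross_entropy)]
average_gain = sum(gains) / len(gains)
largest_gain = max(gains)

print(gains, round(average_gain, 1), largest_gain)
```

The largest entry is the fifth one, which is consistent with the text attributing the biggest improvement to WMT’17 EN-ZH.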

State of the Art

We compare the performance of CMLMs with AXE against nine strong baseline models: the fertility-based sequence-to-sequence model (Gu et al., 2018), transformers trained with CTC loss (Libovický and Helcl, 2018), the iterative refinement approach (Lee et al., 2018), transformers trained with auxiliary regularization (Wang et al., 2019), CMLMs trained with (regular) cross entropy loss (Ghazvininejad et al., 2019), FlowSeq, a latent variable model based on generative flow (Ma et al., 2019), hint-based training (Li et al., 2019), bag-of-ngrams training (Shao et al., 2019), and the CRF-based semi-autoregressive model (Sun et al., 2019). All of these models except the last one are purely non-autoregressive, while the CRF-based model uses bigram statistics during decoding, which deviates from the purely non-autoregressive setting. (Footnote 8: CMLMs (Ghazvininejad et al., 2019) and the iterative refinement method (Lee et al., 2018) are presented as semi-autoregressive models that run in multiple decoding iterations. However, the first decoding iteration of these models is purely non-autoregressive, which is what we use as our baselines.)

Table 3 shows that our system yields the highest BLEU scores of all non-autoregressive models. AXE-trained CMLMs outperform the best purely non-autoregressive model (FlowSeq) on both directions of WMT’14 EN-DE and WMT’16 EN-RO by 1.6 BLEU on average. Moreover, our approach achieves higher BLEU scores than the semi-autoregressive CRF decoder across all available benchmarks.

Raw Data

Finally, we compare the performance of AXE to other methods that train on raw data without knowledge distillation. Table 4 shows that AXE CMLMs still significantly outperform other non-autoregressive models in the raw data scenario. In addition, comparing raw data to knowledge distillation training follows previously-published results that demonstrate the importance of knowledge distillation for non-autoregressive approaches  (Gu et al., 2018; Ghazvininejad et al., 2019; Zhou et al., 2019), although the gap is much smaller for WMT’16 EN-RO.

4.3 Ablation Study

In this section, we consider several variations of our proposed method to investigate the effect of each component. We test the performance of AXE CMLMs with these variations on the WMT’14 DE-EN and EN-DE datasets. To prevent overfitting, we evaluate on the validation set using length candidates.

Input Tokens | Loss Function | WMT’14 EN-DE | WMT’14 DE-EN
Unobserved | All Tokens | 21.97 | 26.32
Partially-Observed | All Tokens | 22.80 | 27.59
Partially-Observed | Only Masks | 23.13 | 28.01
Table 5: The effect of different training objectives on performance, measured on WMT’14 DE-EN and EN-DE (validation set BLEU).

Different Training Objectives Table 5 shows the effects of different training objectives (Section 3.3), in which all or part of the target tokens are masked and the loss function is calculated on all tokens or masked tokens only. We find that simulating the inference scenario, where all tokens are unobserved, is actually less effective than revealing a subset of the target tokens as input during training. We speculate that partially-observed inputs add easier examples to the training set, allowing for better optimization as in curriculum learning (Bengio et al., 2009). We also see that including only the masked tokens in the loss function gives us a modest but consistent boost in performance, possibly because the training signal is focused on the actual task.

Skip Target Penalty ($\delta$) | WMT’14 EN-DE BLEU | WMT’14 EN-DE Skip Target | WMT’14 DE-EN BLEU | WMT’14 DE-EN Skip Target
1 | 22.60 | 17.57% | 26.84 | 16.87%
2 | 23.01 | 10.91% | 27.77 | 10.53%
3 | 22.85 | 9.56% | 27.87 | 9.04%
4 | 22.90 | 8.14% | 28.01 | 7.83%
5 | 23.13 | 7.40% | 27.79 | 6.95%
Table 6: The effect of changing the skip target penalty coefficient $\delta$ on performance (BLEU) and the percentage of target words that were skipped, using the validation sets of WMT’14 DE-EN and EN-DE.

Skip Target Penalty The hyperparameter $\delta$ acts as a coefficient for the penalty associated with skipping a target token (see Table 1 for a definition). We experiment with different values of $\delta$, and report our findings in Table 6. We observe that tuning $\delta$ can significantly improve performance with respect to the default of $\delta = 1$. As intended, high values of $\delta$ discourage alignments that skip too many target tokens.

Length Multiplier The length multiplier $\lambda$ inflates the length predicted by a CMLM to account for extra blank tokens ($\epsilon$) that the model could potentially generate (see Section 3.2 for more detail). Table 7 compares the effect of different length multiplier values. Using the best length multiplier increases the performance by 0.53 BLEU on average for WMT’14 EN-DE and WMT’16 EN-RO.

Length Multiplier | WMT’14 EN-DE | WMT’14 DE-EN | WMT’16 EN-RO | WMT’16 RO-EN
 | 22.96 | 27.50 | 30.43 | 32.22
 | 23.06 | 27.56 | 30.43 | 32.25
 | 23.09 | 27.70 | 30.66 | 32.50
 | 23.11 | 27.81 | 30.75 | 32.69
 | 23.13 | 27.85 | 30.88 | 32.83
 | 23.13 | 27.93 | 30.88 | 32.94
 | 23.06 | 28.01 | 31.01 | 32.84
 | 23.09 | 27.93 | 31.10 | 32.61
 | 23.06 | 27.71 | 31.14 | 32.45
 | 23.07 | 27.68 | 31.06 | 32.14
 | 22.92 | 27.49 | 30.85 | 32.01
Table 7: The effect of tuning the length multiplier on performance (BLEU), using the validation set.

5 Analysis

We provide a qualitative analysis that offers some insight into where AXE improves over cross entropy, and into potential directions for future research on non-autoregressive generation.

AXE Handles Long Sequences Better

We first measure performance of cross entropy versus AXE-trained CMLMs for different sequence lengths. We use compare-mt (Neubig et al., 2019) to split the test sets of WMT’14 EN-DE and DE-EN into different buckets based on target sequence length and calculate BLEU for each bucket. Table 8 shows that the performance of models trained with cross entropy drops drastically as the sequence length increases, while the performance of AXE-trained models remains relatively stable. One explanation for this result is that the longer the sequence, the more likely we are to observe misalignments between the model’s predictions and the target; AXE realigns these cases, providing the model with a cleaner signal for modeling long sequences.

Direction | Target Length | Cross Entropy | AXE
WMT’14 EN-DE | | 18.75 | 20.48
 | | 21.69 | 23.92
 | | 18.64 | 24.21
 | | 15.37 | 22.65
 | | 14.04 | 23.04
 | | 11.62 | 23.43
WMT’14 DE-EN | | 22.57 | 24.39
 | | 25.28 | 27.86
 | | 22.43 | 28.78
 | | 19.03 | 27.18
 | | 16.16 | 27.55
 | | 12.23 | 27.64
Table 8: The performance (test set BLEU) of cross entropy CMLM and AXE CMLM on WMT’14 EN-DE and DE-EN, bucketed by target sequence length.
(a) Short sequences (less than 10 tokens).
(b) Long sequences (more than 30 tokens).
Figure 3: The average prediction probability assigned to a token as a function of its relative distance from where it was generated in the sequence. Each plot shows the average probabilities for CMLMs trained with cross entropy (dashed blue line) and AXE (solid red line).

AXE Increases Position Confidence

We also study how confident each model is about the position of each generated token. Ideally, we would like each predicted token to have a high probability at the position in which it was predicted and a very low probability in the neighboring positions. After applying argmax decoding, we compute the probability assigned to each generated token in all positions of the sequence and average these probabilities based on the relative distance (positive or negative) to the generated position. Figure 3 plots these averaged probabilities for both short (fewer than 10 tokens) and long (more than 30 tokens) target sequences.
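The measurement described above can be sketched for a single decoded sentence as follows. The function name, the dictionary representation of predictions, and the cap on the distance are assumptions for illustration; the analysis in the text averages over a whole test set rather than one sentence.

```python
from collections import defaultdict


def confidence_by_distance(predictions, max_distance=2):
    """Average probability of each argmax token at nearby positions.

    predictions: list of dicts mapping token -> probability, one per position
    Returns {relative_distance: average probability}, where distance 0 is the
    position at which the token was generated.
    """
    generated = [max(p, key=p.get) for p in predictions]
    totals = defaultdict(float)
    counts = defaultdict(int)
    for i, token in enumerate(generated):
        for j, prediction in enumerate(predictions):
            distance = j - i
            if abs(distance) <= max_distance:
                # Tokens missing from a position's dict count as probability 0.
                totals[distance] += prediction.get(token, 0.0)
                counts[distance] += 1
    return {d: totals[d] / counts[d] for d in sorted(totals)}
```

A sharply peaked model gives a high value at distance 0 and values near zero elsewhere, while a model that hedges across positions shows noticeable mass at distances of plus or minus one.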

Both models are rather confident in their predictions for short sequences (Figure 3(a)): the probability has a high peak at the generated position and drops rapidly as we move further away. However, for longer sentences (Figure 3(b)), we observe that the plot for cross entropy has lost its sharpness. Specifically, the immediate neighbors of the prediction position (relative distance $\pm 1$) receive, on average, almost a third of the peak probability. Meanwhile, the probabilities predicted by the AXE-trained model are significantly sharper, assigning negligible probabilities to the generated token in neighboring positions when compared to the center.

One way to explain this result is that cross entropy training encourages each prediction to place some probability mass on the tokens of its neighbors, in order to “hedge its bets” in case the predictions are misaligned with the target. Since AXE finds the best alignment before computing the actual loss, spreading the probability mass of a token among its neighbors is no longer necessary.

AXE Reduces Multimodality

We further argue that AXE reduces the multimodality problem in non-autoregressive machine translation (Gu et al., 2018). Due to minimal coordination between predictions in many non-autoregressive models, a model might consider many possible translations at the same time. In this situation, the model might merge two or more different translations and generate an inconsistent output that is typically characterized by token repetitions. We therefore use the frequency of repeated tokens as a proxy for measuring multimodality in a model.

Table 9 shows the repetition rate for cross entropy and AXE-trained CMLMs. Replacing cross entropy with AXE drastically reduces multimodality, decreasing the number of repetitions by a multiplicative factor of 12.

Model | WMT’14 EN-DE | WMT’14 DE-EN
Cross Entropy CMLM | 16.72% | 12.31%
AXE CMLM | 1.41% | 1.03%
Table 9: The percentage of repeated tokens on the test sets of WMT’14 EN-DE and DE-EN.
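A repetition rate of this kind can be computed with a few lines of code. The definition used below (a token counts as repeated when it equals the token immediately before it, normalized by the output length) is an assumption; the text does not spell out the exact counting rule.

```python
def repetition_rate(tokens):
    """Percentage of tokens that repeat the token immediately before them.

    This consecutive-duplicate definition is an assumption of the sketch.
    """
    if len(tokens) < 2:
        return 0.0
    repeats = sum(1 for prev, cur in zip(tokens, tokens[1:]) if prev == cur)
    return 100.0 * repeats / len(tokens)


if __name__ == "__main__":
    print(repetition_rate("it it tastes pretty good good though".split()))
```

Applied to a corpus, the counts of repeats and tokens would be accumulated over all sentences before dividing.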

6 Related Work

Advances in neural machine translation techniques in recent years have brought an increasing interest in breaking the autoregressive generation bottleneck in translation models.

Semi-autoregressive models introduce partial parallelism into the decoding process. Some of these techniques include iterative refinement of translations based on previous predictions (Lee et al., 2018; Ghazvininejad et al., 2019, 2020; Gu et al., 2019; Kasai et al., 2020) and combining a lighter autoregressive decoder with a non-autoregressive one (Sun et al., 2019).

Building a fully non-autoregressive machine translation model is a much more challenging task. One branch of prior work approaches this problem by modeling with latent variables. Gu et al. (2018) introduce word fertility as a latent variable to model the number of generated tokens for each source word. Ma et al. (2019) use generative flow to model a complex distribution of latent variables for parallel decoding of the target. Shu et al. (2019) propose a latent-variable non-autoregressive model with continuous latent variables and a deterministic inference procedure.

There is also work that develops other alternative loss functions for non-autoregressive machine translation. Libovický and Helcl (2018) use the Connectionist Temporal Classification training objective, a loss function from the speech recognition literature that is designed to eliminate repetitions. Li et al. (2019) use the learning signal provided by hidden states and attention distributions of an autoregressive teacher. Wang et al. (2019) improve the decoder hidden representations by adding the reconstruction error of the source sentence from these representations as an auxiliary regularization term to the loss function. Finally, Shao et al. (2019) introduce the bag-of-ngrams training objective to encourage the model to capture target-side sequential dependencies.

7 Conclusion

We introduced Aligned Cross Entropy (AXE) as an alternative loss function for training non-autoregressive models. AXE focuses on relative order and lexical matching instead of relying on absolute positions. We showed that, in the context of machine translation, a conditional masked language model (CMLM) trained with AXE significantly outperforms cross entropy trained models, setting a new state-of-the-art for non-autoregressive models.


We thank Abdelrahman Mohamed for sharing his expertise on non-autoregressive models, and our colleagues at FAIR for valuable feedback.


  • Y. Bengio, J. Louradour, R. Collobert, and J. Weston (2009) Curriculum learning. In Proceedings of the 26th Annual International Conference on Machine Learning, pp. 41–48.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • M. Ghazvininejad, O. Levy, Y. Liu, and L. Zettlemoyer (2019) Mask-predict: parallel decoding of conditional masked language models. In Proc. of EMNLP-IJCNLP.
  • M. Ghazvininejad, O. Levy, and L. Zettlemoyer (2020) Semi-autoregressive training improves mask-predict decoding. arXiv preprint arXiv:2001.08785.
  • J. Gu, J. Bradbury, C. Xiong, V. O. Li, and R. Socher (2018) Non-autoregressive neural machine translation. In Proc. of ICLR.
  • J. Gu, C. Wang, and J. Zhao (2019) Levenshtein transformer. In Proc. of NeurIPS.
  • J. Kasai, J. Cross, M. Ghazvininejad, and J. Gu (2020) Parallel machine translation with disentangled context transformer. arXiv preprint arXiv:2001.05136.
  • Y. Kim and A. M. Rush (2016) Sequence-level knowledge distillation. In Proc. of EMNLP.
  • D. P. Kingma and J. Ba (2015) Adam: a method for stochastic optimization. In International Conference for Learning Representations.
  • J. D. Lee, E. Mansimov, and K. Cho (2018) Deterministic non-autoregressive neural sequence modeling by iterative refinement. In Proc. of EMNLP.
  • Z. Li, Z. Lin, D. He, F. Tian, T. Qin, L. Wang, and T. Liu (2019) Hint-based training for non-autoregressive machine translation. arXiv preprint arXiv:1909.06708.
  • J. Libovický and J. Helcl (2018) End-to-end non-autoregressive neural machine translation with connectionist temporal classification. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, Brussels, Belgium, pp. 3016–3021.
  • X. Ma, C. Zhou, X. Li, G. Neubig, and E. Hovy (2019) FlowSeq: non-autoregressive conditional sequence generation with generative flow. arXiv preprint arXiv:1909.02480.
  • G. Neubig, Z. Dou, J. Hu, P. Michel, D. Pruthi, and X. Wang (2019) Compare-mt: a tool for holistic comparison of language generation systems. In Meeting of the North American Chapter of the Association for Computational Linguistics (NAACL) Demo Track, Minneapolis, USA.
  • K. Papineni, S. Roukos, T. Ward, and W. Zhu (2002) Bleu: a method for automatic evaluation of machine translation. In Proceedings of the 40th Annual Meeting of the Association for Computational Linguistics.
  • M. Post (2018) A call for clarity in reporting BLEU scores. In Proceedings of the Third Conference on Machine Translation: Research Papers, External Links: Link Cited by: §4.1.
  • H. Sakoe and S. Chiba (1978) Dynamic programming algorithm optimization for spoken word recognition. IEEE transactions on acoustics, speech, and signal processing 26 (1), pp. 43–49. Cited by: §2.
  • R. Sennrich, B. Haddow, and A. Birch (2016) Neural machine translation of rare words with subword units. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), External Links: Link Cited by: §4.1.
  • C. Shao, J. Zhang, Y. Feng, F. Meng, and J. Zhou (2019) Minimizing the bag-of-ngrams difference for non-autoregressive neural machine translation. arXiv preprint arXiv:1911.09320. Cited by: Table 3, §4.2, §6.
  • R. Shu, J. Lee, H. Nakayama, and K. Cho (2019) Latent-variable non-autoregressive neural machine translation with deterministic inference using a delta posterior. arXiv preprint arXiv:1908.07181. Cited by: §6.
  • M. Stern, W. Chan, J. R. Kiros, and J. Uszkoreit (2019) Insertion transformer: flexible sequence generation via insertion operations. In Proc. of ICML, External Links: Link Cited by: §1, §4.1.
  • Z. Sun, Z. Li, H. Wang, D. He, Z. Lin, and Z. Deng (2019) Fast structured decoding for sequence models. In Advances in Neural Information Processing Systems, pp. 3011–3020. Cited by: §1, Table 3, §4.2, §6.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008. Cited by: §3.1, §4.1, §4.1.
  • Y. Wang, F. Tian, D. He, T. Qin, C. Zhai, and T. Liu (2019) Non-autoregressive machine translation with auxiliary regularization. In Proc. of AAAI, External Links: Link Cited by: Table 3, §4.2.
  • F. Wu, A. Fan, A. Baevski, Y. N. Dauphin, and M. Auli (2019) Pay less attention with lightweight and dynamic convolutions. International Conference on Learning Representations. Cited by: §4.1.
  • B. Yang, F. Liu, and Y. Zou (2019) Non-autoregressive video captioning with iterative refinement. arXiv preprint arXiv:1911.12018. Cited by: §6.
  • C. Zhou, G. Neubig, and J. Gu (2019) Understanding knowledge distillation in non-autoregressive machine translation. arXiv preprint arXiv:1911.02727. Cited by: §4.2.