Most machine translation systems generate text autoregressively, by sequentially predicting tokens from left to right. We, instead, use a masked language modeling objective to train a model to predict any subset of the target words, conditioned on both the input text and a partially masked target translation. This approach allows for efficient iterative decoding, where we first predict all of the target words non-autoregressively, and then repeatedly mask out and regenerate the subset of words that the model is least confident about. By applying this strategy for a constant number of iterations, our model improves on state-of-the-art performance levels for constant-time translation models by over 3 BLEU on average. It is also able to reach 92-95% of the performance of a typical left-to-right transformer model, while decoding significantly faster.
Most machine translation systems use linear time decoding strategies where words are predicted one-by-one. In this paper, we present a model and a parallel decoding algorithm which, for a relatively small sacrifice in performance, can be used to generate translations in a constant number of decoding iterations.
We introduce conditional masked language models (CMLMs), which are encoder-decoder architectures trained with a masked language model objective Devlin et al. (2018); Lample and Conneau (2019). This change allows the model to learn to predict, in parallel, any arbitrary subset of masked words in the target translation. We use transformer CMLMs, where the decoder’s self attention Vaswani et al. (2017) can attend to the entire sequence (left and right context) to predict each masked word. We train with a simple masking scheme where the number of masked target tokens is distributed uniformly, presenting the model with both easy (single mask) and difficult (completely masked) examples. Unlike recently proposed insertion models Gu et al. (2019); Stern et al. (2019), which treat each token as a separate training instance, CMLMs can train from the entire sequence in parallel, resulting in much faster training.
We also introduce a new decoding algorithm, mask-predict, which uses the order-agnostic nature of CMLMs to support highly parallel, constant-time decoding. Mask-predict repeatedly masks out and re-predicts the subset of words in the current translation that the model is least confident about, in contrast to recent constant-time translation approaches that repeatedly predict the entire sequence Lee et al. (2018). Decoding starts with a completely masked target text, to predict all of the words in parallel, and ends after a constant number of mask-predict cycles. This overall strategy allows the model to repeatedly reconsider word choices within a rich bi-directional context and, as we will show, produce high-quality translations in just a few cycles.
Experiments on benchmark machine translation datasets show the strengths of mask-predict decoding for transformer CMLMs. With just 4 iterations, BLEU scores already surpass the performance of the best non-autoregressive models. With 10 iterations, the approach outperforms the current state-of-the-art constant-time model Lee et al. (2018) by gaps of 4-5 BLEU points on the WMT’14 English-German translation benchmark, and up to 3 BLEU points on WMT’16 English-Romanian. When compared to standard autoregressive transformer models, CMLMs with mask-predict offer a trade-off between speed and performance, trading a 5-10% reduction in translation quality for a 200-300% speed-up.
A conditional masked language model (CMLM) predicts a set of target tokens $Y_{mask}$ given a source text $X$ and part of the target text $Y_{obs}$. It makes the strong assumption that the tokens $Y_{mask}$ are conditionally independent of each other (given $X$ and $Y_{obs}$), and predicts the individual probabilities $P(y \mid X, Y_{obs})$ for each $y \in Y_{mask}$. Since the number of tokens in $Y_{mask}$ is given in advance, the model is also implicitly conditioning on the length of the target sequence $N = |Y_{mask}| + |Y_{obs}|$.
We adopt the standard encoder-decoder transformer for machine translation Vaswani et al. (2017): a source-language encoder that does self-attention, and a target-language decoder that has one set of attention heads over the encoder’s output and another set for the target language (self-attention). In terms of parameters, our architecture is identical to the standard one. We deviate from the standard decoder by removing the self-attention mask that prevents left-to-right decoders from attending to future tokens. In other words, our decoder is bi-directional, in the sense that it can use both left and right contexts to predict each token.
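The only architectural change is the decoder's self-attention mask. As a minimal sketch (pure Python; the function name `attention_mask` is hypothetical and not from any library), a left-to-right decoder uses a causal, lower-triangular mask, whereas the CMLM decoder described above lets every target position attend to every other position.

```python
def attention_mask(n: int, causal: bool) -> list[list[bool]]:
    """Return an n x n self-attention mask where mask[i][j] is True
    if target position i is allowed to attend to target position j."""
    if causal:
        # Standard left-to-right decoder: position i sees only positions j <= i.
        return [[j <= i for j in range(n)] for i in range(n)]
    # CMLM decoder: no causal mask, so each position sees left and right context.
    return [[True] * n for _ in range(n)]
```

The parameter count is unaffected: only which attention scores are allowed changes.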
During training, we randomly select $Y_{mask}$ among the target tokens (we first sample the number of masked tokens from a uniform distribution between one and the sequence’s length, and then randomly choose that number of tokens). Following Devlin et al. (2018), we replace the inputs of the tokens $Y_{mask}$ with a special MASK token.
We optimize the CMLM for cross-entropy loss over every token in $Y_{mask}$. This can be done in parallel, since the model assumes that the tokens in $Y_{mask}$ are conditionally independent of each other. While the architecture can technically make predictions over all target-language tokens (including $Y_{obs}$), we only compute the loss for the tokens in $Y_{mask}$.
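The masking scheme and the masked-only loss can be sketched in a few lines of standard-library Python. This is an illustration under stated assumptions, not the authors' code: the string `"<mask>"` stands in for the MASK token, and `mask_target` / `masked_nll` are hypothetical helper names. The loss is summed over masked positions.

```python
import random

MASK = "<mask>"  # hypothetical stand-in for the special MASK token

def mask_target(target, rng=random):
    """Build one CMLM training input from a target sentence.

    Returns (decoder_input, masked_positions): decoder_input is the target with
    the chosen positions replaced by MASK (these positions form Y_mask).
    """
    n = len(target)
    k = rng.randint(1, n)                              # number of masks ~ Uniform{1, ..., n}
    masked_positions = sorted(rng.sample(range(n), k))  # which k positions, uniformly
    decoder_input = list(target)
    for i in masked_positions:
        decoder_input[i] = MASK
    return decoder_input, masked_positions

def masked_nll(token_log_probs, masked_positions):
    """Cross-entropy over Y_mask only; token_log_probs[i] is the model's
    log-probability of the gold token at position i. Y_obs contributes nothing."""
    return -sum(token_log_probs[i] for i in masked_positions)
```

Because every masked position is scored in the same forward pass, one target sentence yields a full training signal at once, unlike insertion models that treat each token as a separate example.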
In traditional left-to-right machine translation, where the target sequence is predicted token by token, it is natural to determine the length of the sequence dynamically by simply predicting a special “end of sentence” token. However, for CMLMs to predict the entire sequence in parallel, they must know its length in advance. This problem was recognized by prior work in non-autoregressive translation, where the length is predicted with a fertility model Gu et al. (2018) or by pooling the encoder’s outputs into a length classifier Lee et al. (2018).
We follow Devlin et al. (2018) and add a special LENGTH token to the encoder, akin to the CLS token in BERT. The model is trained to predict the length of the target sequence as the LENGTH token’s output, similar to predicting another token from a different vocabulary, and its loss is added to the cross-entropy loss from the target sequence.
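A minimal sketch of the combined objective, assuming the LENGTH token's output has already been projected to one score per possible target length (index `n` scores length `n`). The names `cmlm_loss` and `length_logits` are hypothetical, and the two terms are added without weighting, as the text describes.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def cmlm_loss(length_logits, gold_length, token_log_probs, masked_positions):
    """Target-token cross-entropy over Y_mask plus the LENGTH token's
    cross-entropy over a separate 'length vocabulary'."""
    token_loss = -sum(token_log_probs[i] for i in masked_positions)
    length_loss = -log_softmax(length_logits)[gold_length]
    return token_loss + length_loss
```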
We introduce the mask-predict algorithm, which does constant-time decoding when run for a fixed number of mask-predict cycles. At each iteration, the algorithm selects a subset of tokens to mask, and then predicts them (in parallel) using an underlying CMLM. Masking the tokens where the model has doubts while conditioning on previous high-confidence predictions lets the model re-predict the more challenging cases, but with more information. At the same time, the ability to make large parallel changes at each step allows mask-predict to converge on a high quality output sequence in sub-linear time.
Iteration | Text
---|---
Source | Der Abzug der französischen Kampftruppen wurde am 20. November abgeschlossen .
$t = 0$ | The departure of the French combat completed completed on 20 November .
$t = 1$ | The departure of French combat troops was completed on 20 November .
$t = 2$ | The withdrawal of French combat troops was completed on November 20th .
Figure 1: Mask-predict translations of a German source sentence after each iteration $t$.
Given the target sequence’s length $N$ (see Section 3.3), we define two variables: the target sequence $(y_1, \ldots, y_N)$ and the probability of each token $(p_1, \ldots, p_N)$. The algorithm runs for a predetermined number of iterations $T$, which is either a constant or a simple function of $N$. At each iteration, we perform a mask operation, followed by predict.
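For the first iteration ($t = 0$), we mask all the tokens. For later iterations, we mask the $n$ tokens with the lowest probability scores, where $\arg\min_i(p_i, n)$ denotes the indices of the $n$ smallest $p_i$:
$$Y_{mask}^{(t)} = \arg\min_i (p_i, n), \qquad Y_{obs}^{(t)} = Y \setminus Y_{mask}^{(t)}$$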
The number of masked tokens $n$ is a function of the iteration $t$; specifically, we use linear decay $n = N \cdot \frac{T - t}{T}$, where $T$ is the total number of iterations. For example, if $T = 10$, we will mask 90% of the tokens at $t = 1$, 80% at $t = 2$, and so forth.
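The linear-decay schedule is easy to check numerically. In this sketch, rounding $n$ down to an integer is our assumption (the text does not specify rounding), and `num_masked` is a hypothetical helper name. With $N = 12$ and $T = 3$ it reproduces the 12, 8, 4 masks of the worked example in Figure 1.

```python
def num_masked(N: int, T: int, t: int) -> int:
    """Tokens to mask at iteration t (0-indexed) of T: n = N * (T - t) / T.
    Integer floor is an assumption; t = 0 always masks all N tokens."""
    return (N * (T - t)) // T

schedule = [num_masked(12, 3, t) for t in range(3)]  # 12 tokens, 3 iterations
print(schedule)
```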
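After masking, the CMLM predicts the masked tokens $Y_{mask}^{(t)}$, conditioned on the source text $X$ and the unmasked target tokens $Y_{obs}^{(t)}$. We select the prediction with the highest probability for each masked token $y_i \in Y_{mask}^{(t)}$ and update its probability score accordingly:
$$y_i^{(t)} = \arg\max_w P(y_i = w \mid X, Y_{obs}^{(t)}), \qquad p_i^{(t)} = \max_w P(y_i = w \mid X, Y_{obs}^{(t)})$$
The values and the probabilities of unmasked tokens $y_i \in Y_{obs}^{(t)}$ remain unchanged:
$$y_i^{(t)} = y_i^{(t-1)}, \qquad p_i^{(t)} = p_i^{(t-1)}$$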
We tried updating or decaying these probabilities in preliminary experiments, but found that this heuristic works well despite the fact that some probabilities are stale.
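Putting the mask and predict steps together, the loop below is a minimal sketch of mask-predict for a single length candidate. It assumes a hypothetical CMLM interface: a callable that takes the source and the partially masked target and returns, for every position, the most probable token and its probability. Floor rounding of $n$ is again our assumption. Only masked positions are overwritten, so observed tokens keep their possibly stale probabilities, as in the heuristic above.

```python
from typing import Callable, List, Sequence, Tuple

MASK = "<mask>"  # hypothetical MASK symbol

# Hypothetical CMLM interface: (source, partially masked target) ->
# (argmax token, its probability) for every target position.
Predictor = Callable[[Sequence[str], List[str]], List[Tuple[str, float]]]

def mask_predict(source: Sequence[str], N: int, T: int, cmlm: Predictor):
    tokens = [MASK] * N
    probs = [0.0] * N
    for t in range(T):
        if t == 0:
            masked = list(range(N))  # first iteration: mask everything
        else:
            n = (N * (T - t)) // T   # linear decay of the mask count
            # positions with the n lowest current probabilities
            masked = sorted(range(N), key=lambda i: probs[i])[:n]
        for i in masked:
            tokens[i] = MASK
        predictions = cmlm(source, tokens)  # one parallel forward pass
        for i in masked:                    # update only Y_mask; Y_obs is untouched
            tokens[i], probs[i] = predictions[i]
    return tokens, probs
```

Each iteration costs one decoder pass regardless of $N$, which is what makes the total decoding cost constant in the number of iterations.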
Figure 1 illustrates how mask-predict can generate a good translation in just three iterations.
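In the first iteration ($t = 0$), the entire target sequence is masked ($Y_{mask}^{(0)} = Y$ and $Y_{obs}^{(0)} = \emptyset$), and is thus generated by the CMLM in a purely non-autoregressive process:
$$P(Y_{mask}^{(0)} \mid X, Y_{obs}^{(0)}) = P(Y \mid X)$$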
This produces an ungrammatical translation with repetitions (“completed completed”), which is typical of non-autoregressive machine translations due to the multi-modality problem Gu et al. (2018).
In the second iteration ($t = 1$), we select 8 of the 12 tokens generated in the previous step; these tokens were predicted with the lowest probabilities at $t = 0$. We mask them and repredict with the CMLM, while conditioning on the 4 unmasked tokens $Y_{obs}^{(1)}$. This results in a more grammatical and accurate translation. Our analysis shows that this second iteration often removes most repetitions, perhaps because conditioning on even a small portion of the target sequence is enough to collapse the multi-modal target distribution into a single output (Section 5.1).
In the last iteration ($t = 2$), we select the 4 of the 12 tokens that had the lowest probabilities. Two of those tokens were predicted at the first step ($t = 0$), and not repredicted at the second step ($t = 1$). It is quite common for earlier predictions to be masked at later iterations because they were predicted with less information and thus tend to have lower probabilities. Now that the model is conditioning on 8 tokens, it is able to produce a more fluent translation; “withdrawal” is a better fit for describing troop movement, and “November 20th” is a more common date format in English.
During decoding, we first run the CMLM’s encoder, and then use the encoding of the special LENGTH token to predict a distribution over the target sequence’s length (see Section 2.3). One simple solution is to select the length with the highest probability. However, since much of the CMLM’s computation can be batched, we select the top $\ell$ length candidates and decode the same example with different lengths in parallel (this idea is somewhat analogous to beam search, hence the choice of $\ell$ for beam size). We then select the sequence with the highest average log-probability as our result:
$$\hat{Y} = \arg\max_{Y} \frac{1}{N_Y} \sum_{i=1}^{N_Y} \log p_i$$
where $Y$ ranges over the $\ell$ candidate translations, $N_Y$ is the length of $Y$, and $p_i$ are the final token probabilities of $Y$.
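The length-beam selection can be sketched with the standard library. Here the LENGTH token's distribution is assumed to be a dict mapping each candidate length to its probability, and each decoded candidate is a `(tokens, probs)` pair such as a mask-predict run would return; `top_lengths` and `pick_best` are hypothetical names.

```python
import heapq
import math

def top_lengths(length_probs: dict, ell: int) -> list:
    """The ell most probable target lengths, most probable first."""
    return heapq.nlargest(ell, length_probs, key=length_probs.get)

def pick_best(candidates):
    """candidates: list of (tokens, probs), one per length candidate.
    Returns the candidate with the highest average log-probability."""
    def avg_log_prob(cand):
        _, probs = cand
        return sum(math.log(p) for p in probs) / len(probs)
    return max(candidates, key=avg_log_prob)
```

Averaging rather than summing log-probabilities keeps longer candidates from being penalized merely for having more terms.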
Our analysis reveals that translating multiple candidate sequences of different lengths significantly increases the performance (see Section 5.3).
We evaluate CMLMs with mask-predict decoding on standard machine translation benchmarks. We find that our approach significantly outperforms prior constant-time machine translation methods and even approaches the performance of standard autoregressive models (Section 4.2), while decoding significantly faster (Section 4.3).
We evaluate on two standard datasets, WMT’14 EN-DE (4.5M sentence pairs) and WMT’16 EN-RO (610k pairs), in both directions. The datasets are tokenized into subword units using BPE Sennrich et al. (2016). We use the same preprocessed data as Vaswani et al. (2017) for WMT’14 EN-DE, and use the data from Lee et al. (2018) for WMT’16 EN-RO.
We follow most of the standard hyperparameters for transformers in the base configuration Vaswani et al. (2017): 6 layers per stack, 8 attention heads per layer, 512 model dimensions, 2048 hidden dimensions. We also experiment with 512 hidden dimensions, for comparison with previous constant-time translation models Gu et al. (2018); Lee et al. (2018). We follow the weight initialization scheme from BERT Devlin et al. (2018), which samples weights from a normal distribution with standard deviation 0.02, initializes biases to zero, and sets layer normalization parameters to $\beta = 0$, $\gamma = 1$. For regularization, we use dropout, weight decay, and label-smoothed cross-entropy loss. We train batches of 128k tokens using Adam Kingma and Ba (2015). The learning rate warms up to its peak value within 10,000 steps, and then decays with the inverse square-root schedule. We trained the WMT’14 EN-DE models for 300k steps, and the WMT’16 EN-RO models for 100k steps. We measured the validation loss at the end of each epoch and averaged the 5 best checkpoints to create the final model. During decoding, we use beam search for the autoregressive baselines and multiple length candidates for mask-predict (Section 3.3).
Following previous work on non-autoregressive and insertion-based machine translation Gu et al. (2018); Lee et al. (2018); Stern et al. (2019), we train CMLMs on translations produced by a standard left-to-right transformer model (large for EN-DE, base for EN-RO). We analyze the impact of distillation in Section 5.4.
Model | Dimensions (Model/Hidden) | Iterations | WMT’14 EN-DE | WMT’14 DE-EN | WMT’16 EN-RO | WMT’16 RO-EN
---|---|---|---|---|---|---
NAT w/ Fertility Gu et al. (2018) | 512/512 | 1 | 19.17 | 23.20 | 29.79 | 31.44
CTC Loss Libovický and Helcl (2018) | 512/4096 | 1 | 17.68 | 19.80 | 19.93 | 24.71
Iterative Refinement Lee et al. (2018) | 512/512 | 1 | 13.91 | 16.77 | 24.45 | 25.73
Iterative Refinement Lee et al. (2018) | 512/512 | 10 | 21.61 | 25.48 | 29.32 | 30.19
Iterative Refinement (Dynamic #Iterations) | 512/512 | ? | 21.54 | 25.43 | 29.66 | 30.30
Small CMLM with Mask-Predict | 512/512 | 1 | 15.06 | 19.26 | 20.12 | 20.36
Small CMLM with Mask-Predict | 512/512 | 4 | 24.17 | 28.55 | 30.00 | 30.43
Small CMLM with Mask-Predict | 512/512 | 10 | 25.51 | 29.47 | 31.65 | 32.27
Base CMLM with Mask-Predict | 512/2048 | 1 | 18.12 | 22.66 | 23.65 | 22.78
Base CMLM with Mask-Predict | 512/2048 | 4 | 26.08 | 30.11 | 31.78 | 31.76
Base CMLM with Mask-Predict | 512/2048 | 10 | 26.92 | 30.86 | 32.53 | 33.06
Base Transformer Vaswani et al. (2017) | 512/2048 | N | 27.30 | — | — | —
Base Transformer (Our Implementation) | 512/2048 | N | 27.87 | 31.20 | 34.01 | 33.79
Large Transformer Vaswani et al. (2017) | 1024/4096 | N | 28.40 | — | — | —
Large Transformer (Our Implementation) | 1024/4096 | N | 28.65 | 31.87 | — | —
We compare our approach to three other constant-time translation methods: the fertility-based sequence-to-sequence model of Gu et al. (2018), the CTC-loss transformer of Libovický and Helcl (2018), and the iterative refinement approach of Lee et al. (2018). The first two methods are purely non-autoregressive, while the iterative refinement approach is only non-autoregressive in the first decoding iteration, similar to our approach.
Table 1 shows that among the constant-time methods, our approach yields the highest BLEU scores by a considerable margin. When controlling for the number of parameters (i.e. considering only the smaller CMLM configuration), CMLMs score roughly 4 BLEU points higher than the previous state of the art on WMT’14 EN-DE, in both directions. Another striking result is that a CMLM with only 4 mask-predict iterations yields higher scores than 10 iterations of the iterative refinement model; in fact, only 3 mask-predict iterations are necessary for achieving a new state of the art on both directions of WMT’14 EN-DE (not shown).
The translations produced by CMLMs with mask-predict also score competitively when compared to strong transformer-based autoregressive models. In all 4 benchmarks, our base CMLM reaches within 0.34–1.48 BLEU points from a well-tuned base transformer, a relative decrease of less than 5% in translation quality. In many scenarios, this is an acceptable price to pay for a significant speedup from constant-time decoding.
Because CMLMs can predict the entire sequence in parallel, mask-predict can translate an entire sequence in a constant number of decoding iterations. Does this appealing theoretical property translate into a wall-time speed-up in practice? By comparing the actual decoding times, we show that, indeed, our method translates much faster than standard linear-time transformers.
As the baseline system, we use the base transformer with beam search to translate WMT’14 EN-DE. For CMLMs, we vary the number of mask-predict iterations ($T$) and length candidates ($\ell$). We use a decoding batch of 10 sentences for both models. For each decoding run, we measure the performance (BLEU) and wall time (seconds) from when the model and data have been loaded until the last example has been translated. We calculate the relative translation quality (CMLM BLEU / baseline BLEU) and relative decoding speed-up (baseline time / CMLM time) to assess the speed-performance trade-off.
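The two trade-off ratios can be written as a tiny helper (hypothetical name `tradeoff`). For illustration we plug in the base transformer's 27.87 BLEU from Table 1 and the 27.00 BLEU run mentioned below; the CMLM wall time of 67 seconds is a made-up value, paired with the 134-second cached baseline.

```python
def tradeoff(cmlm_bleu, base_bleu, cmlm_seconds, base_seconds):
    """Relative quality (CMLM / baseline BLEU) and speed-up (baseline / CMLM time).
    A speed-up above 1 means the CMLM decodes faster than the baseline."""
    relative_quality = cmlm_bleu / base_bleu
    speedup = base_seconds / cmlm_seconds
    return relative_quality, speedup

quality, speedup = tradeoff(27.00, 27.87, 67.0, 134.0)  # 67.0 s is hypothetical
print(f"{quality:.1%} of baseline quality at {speedup:.1f}x speed")
```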
The implementation of both the baseline transformer and our CMLM is based on the fairseq implementation Gehring et al. (2017), which efficiently decodes left-to-right transformers by caching the state. Caching reduces the baseline’s decoding time from 210 seconds to 134; CMLMs do not use cached decoding. All experiments used exactly the same machine and GPU.
Figure 2 shows the speed-performance trade-off. We see that mask-predict is versatile: on one hand, we can translate almost 5 times faster than the baseline (at a significant cost to quality); on the other, we can retain 97% of the quality (BLEU 27.00) while gaining a speed-up of 150%. More balanced configurations can be found for runs with 90-95% relative translation quality, which yield a 2-3 times speed-up. In particular, we note that running mask-predict for five iterations ($T = 5$) with only one length candidate ($\ell = 1$) translates 3 times faster than the baseline, while keeping translation quality at 92.6%.
To complement the quantitative results in Section 4, we present qualitative analysis that provides some intuition as to why our approach works and where future work could potentially improve it.
Various non-autoregressive translation models, including our own CMLM, make the strong assumption that the individual token predictions are conditionally independent of each other. Such a model might consider two or more possible translations, A and B, but because there is no coordination mechanism between the token predictions, it could predict one token from A and another token from B. This problem, known as the multi-modality problem Gu et al. (2018), often manifests as token repetitions in the output when the model has multiple hypotheses that predict the same word with high confidence, but at different positions.
We hypothesize that multiple mask-predict iterations alleviate the multi-modality problem by allowing the model to condition on parts of the target sequence, thus collapsing the multi-modal distribution into a sharper uni-modal distribution. To test our hypothesis, we measure the percentage of repetitive tokens produced by each iteration of mask-predict as a proxy metric for multi-modality.
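The text does not define "repetitive token" precisely. A minimal sketch, assuming a token counts as repetitive when it is identical to the token immediately before it, pools the rate over a set of tokenized outputs (`repetition_rate` is a hypothetical name). On the $t = 0$ output from Figure 1, "completed completed" gives 1 repetition out of 12 tokens.

```python
def repetition_rate(sentences):
    """Fraction of tokens equal to the immediately preceding token,
    pooled over a list of tokenized sentences (our assumed definition)."""
    total = reps = 0
    for toks in sentences:
        total += len(toks)
        reps += sum(1 for prev, cur in zip(toks, toks[1:]) if prev == cur)
    return reps / total if total else 0.0

t0 = "The departure of the French combat completed completed on 20 November .".split()
print(repetition_rate([t0]))
```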
Table 2 shows that, indeed, the proportion of repetitive tokens drops drastically during the first 2-3 iterations. This finding suggests that the first few iterations are critical for converging into a uni-modal distribution. The decrease in repetitions also correlates with the steep rise in translation quality (BLEU), supporting the conjecture of Gu et al. (2018) that multi-modality is a major roadblock for purely non-autoregressive machine translation.
Iterations ($T$) | WMT’14 EN-DE BLEU | WMT’14 EN-DE Reps | WMT’16 EN-RO BLEU | WMT’16 EN-RO Reps
---|---|---|---|---
1 | 18.12 | 16.89% | 23.65 | 13.42%
2 | 23.12 | 5.44% | 29.28 | 4.03%
3 | 25.21 | 2.00% | 31.08 | 1.77%
4 | 26.08 | 1.00% | 31.78 | 1.02%
5 | 26.30 | 0.65% | 31.90 | 0.72%
A potential concern with using a constant number of decoding iterations is that it may be effective for short sequences (where the number of iterations $T$ is closer to the output’s length $N$), but insufficient for longer sequences. To determine whether this is the case, we use compare-mt Neubig et al. (2019) to bucket the evaluation data by target sentence length and compute the performance with different values of $T$.
Table 3 shows that increasing the number of decoding iterations ($T$) appears to mainly improve the performance on longer sequences. Having said that, the performance differences across length buckets are not very large, and it seems that even 4 mask-predict iterations are enough to produce decent translations for long sequences.
22.3 | 21.1 | 21.7 | |
26.3 | 26.7 | 26.5 | |
26.4 | 27.2 | 27.7 | |
25.9 | 27.0 | 27.0 | |
25.8 | 27.2 | 28.2 |
Traditional autoregressive models can dynamically decide the length of the target sequence by generating a special END token when they are done, but that is not true for models that decode multiple tokens in parallel, such as CMLMs. To address this problem, our model predicts the length of the target sequence (Section 2.3) and decodes multiple length candidates in parallel (Section 3.3). We compare our model’s performance with a varying number of length candidates to its performance when conditioned on the reference (gold) target length in order to determine how accurate it is at predicting the correct length and assess the relative contribution of decoding with multiple length candidates.
Table 4 shows that having multiple candidates can increase performance, even beyond conditioning on the gold length. In fact, peak performance is achieved when roughly half of the examples do not contain the correct length as a candidate.
Surprisingly, adding too many candidates can even degrade performance. We suspect that because CMLMs are implicitly conditioned on the target length, producing a translation that is too short (i.e. high precision, low recall) will have a high average log probability. In preliminary experiments, we tried to address this issue by weighting the different candidates according to the model’s length prediction, but this approach gave too much weight to the top candidate and resulted in lower performance.
Length Candidates ($\ell$) | WMT’14 EN-DE BLEU | WMT’14 EN-DE LP | WMT’16 EN-RO BLEU | WMT’16 EN-RO LP
---|---|---|---|---
1 | 26.55 | 16.0% | 31.58 | 13.2%
2 | 26.87 | 31.8% | 32.34 | 27.4%
3 | 27.04 | 44.1% | 32.49 | 39.6%
4 | 27.01 | 53.8% | 32.63 | 50.1%
5 | 26.92 | 62.5% | 32.53 | 57.9%
6 | 26.97 | 70.3% | 32.57 | 65.1%
7 | 26.74 | 76.1% | 32.49 | 71.1%
8 | 26.64 | 81.0% | 32.48 | 75.8%
9 | 26.40 | 84.9% | 32.41 | 79.8%
Gold | 26.78 | — | 31.91 | —
LP: percentage of examples whose gold target length is among the $\ell$ length candidates.
Previous work on non-autoregressive and insertion-based machine translation reported that it was necessary to train their models on text generated by an autoregressive teacher model, a process known as distillation. To determine CMLM’s dependence on this process, we train models on both raw and distilled data, and compare their performance.
Table 5 shows that in the majority of cases, training with model distillation substantially outperforms training on raw data. The gaps are especially large when decoding with a single iteration (purely non-autoregressive). Overall, it appears as though CMLMs are heavily dependent on model distillation.
On the English-Romanian benchmark, the differences are smaller, and after 10 iterations the raw-data model can perform comparably with the distilled model; when translating in the opposite direction (RO-EN), the raw-data model even outperforms the distilled model (not shown). A possible explanation is that our teacher model was weaker for this dataset due to insufficient hyperparameter tuning. Alternatively, it could also be the case that the English-German dataset is much noisier than the English-Romanian one, and that the teacher model essentially cleans the training data. Unfortunately, we do not have enough evidence to support or refute either hypothesis at this time.
Iterations ($T$) | WMT’14 EN-DE Raw | WMT’14 EN-DE Dist | WMT’16 EN-RO Raw | WMT’16 EN-RO Dist
---|---|---|---|---
1 | 10.88 | 18.12 | 20.24 | 23.65
4 | 22.06 | 26.08 | 30.89 | 31.78
10 | 24.65 | 26.92 | 32.53 | 32.42
Recent work by Lample and Conneau (2019) shows that training a masked language model on sentence-pair translation data, as a pre-training step, can improve performance on cross-lingual tasks, including autoregressive machine translation. Our training scheme builds on their work, with the following differences: we use separate model parameters for source and target texts (encoder and decoder), and we also use a different masking scheme (specifically, we mask a varying percentage of tokens, only from the target, and do not replace input tokens with noise). Most importantly, the goal of our work is different; we do not use CMLMs for pre-training, but to directly generate text with mask-predict decoding.
One approach for generating text from a masked language model casts BERT Devlin et al. (2018), a non-conditional masked language model, as a Markov random field Wang and Cho (2019). By masking a sequence of length $N$ and then iteratively sampling a single token at each time from the model (either sequentially or in arbitrary order), one can produce grammatical examples. This sampling process requires $N$ forward passes of the model, while mask-predict decoding can produce text in constant time.
There have been several advances in constant-time machine translation by training non-autoregressive models. Gu et al. (2018) introduce a transformer-based approach with explicit word fertility, and identify the multi-modality problem. Libovický and Helcl (2018) approach the multi-modality problem by collapsing repetitions with the Connectionist Temporal Classification training objective Graves et al. (2006). Perhaps most similar to our work is the iterative refinement approach of Lee et al. (2018), in which the model corrects the original non-autoregressive prediction by passing it multiple times through a denoising autoencoder. A major difference is that Lee et al. (2018) train their noisy autoencoder to deal with corrupt inputs by applying stochastic corruption heuristics on the training data, while we simply mask a random number of input tokens. We also show that our approach outperforms all of these models by wide margins.
Finally, recent work has developed insertion-based transformers for arbitrary, but fixed, word order generation Gu et al. (2019); Stern et al. (2019). While they do not provide constant-time decoding, Stern et al. (2019) show strong results in logarithmic time. Both models treat each token insertion as a separate training example, which cannot be computed in parallel with every other insertion in the same sequence. This makes training significantly more expensive than standard transformers (which use causal attention masking) and our CMLMs (which can predict all of the masked tokens in parallel).
This work introduces conditional masked language models and a novel mask-predict decoding algorithm that leverages their parallelism to generate text in constant time. We show that, in the context of machine translation, our approach substantially outperforms previous constant-time methods, and can approach the performance of linear-time autoregressive models while decoding much faster. While there are still open problems, such as the need to condition on the target’s length and the dependence on knowledge distillation, our results provide a significant step forward in non-autoregressive machine translation. In a broader sense, this paper shows that masked language models are useful not only for representing text, but also for generating text efficiently.
We thank Abdelrahman Mohamed for sharing his expertise on non-autoregressive models, and our colleagues at FAIR for valuable feedback.
Graves et al. (2006). Connectionist temporal classification: Labelling unsegmented sequence data with recurrent neural networks. In Proceedings of the 23rd International Conference on Machine Learning, pages 369–376. ACM.
Gu et al. (2018). Non-autoregressive neural machine translation. In ICLR.
Lee et al. (2018). Deterministic non-autoregressive neural sequence modeling by iterative refinement. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pages 1173–1182, Brussels, Belgium. Association for Computational Linguistics.