SwitchOut: an Efficient Data Augmentation Algorithm for Neural Machine Translation

08/22/2018 ∙ by Xinyi Wang, et al.

In this work, we examine methods for data augmentation for text-based tasks such as neural machine translation (NMT). We formulate the design of a data augmentation policy with desirable properties as an optimization problem, and derive a generic analytic solution. This solution not only subsumes some existing augmentation schemes, but also leads to an extremely simple data augmentation strategy for NMT: randomly replacing words in both the source sentence and the target sentence with other random words from their corresponding vocabularies. We name this method SwitchOut. Experiments on three translation datasets of different scales show that SwitchOut yields consistent improvements of about 0.5 BLEU, achieving better or comparable performances to strong alternatives such as word dropout (Sennrich et al., 2016a). Code to implement this method is included in the appendix.






1 Introduction and Related Work

Data augmentation algorithms generate extra data points from the empirically observed training set to train subsequent machine learning algorithms. While these extra data points may be of lower quality than those in the training set, their quantity and diversity have proven to benefit various learning algorithms (DeVries and Taylor, 2017; Amodei et al., 2016). In image processing, simple augmentation techniques such as flipping, cropping, or increasing and decreasing the contrast of the image are both widely utilized and highly effective (Huang et al., 2016; Zagoruyko and Komodakis, 2016).

However, it is nontrivial to find simple equivalences for NLP tasks like machine translation, because even slight modifications of sentences can result in significant changes in their semantics, or require corresponding changes in the translations in order to keep the data consistent. In fact, indiscriminate modifications of data in NMT can introduce noise that makes NMT systems brittle (Belinkov and Bisk, 2018).

Due to such difficulties, the literature on data augmentation for NMT is relatively scarce. To our knowledge, data augmentation techniques for NMT fall into two categories. The first category is based on back-translation (Sennrich et al., 2016b; Poncelas et al., 2018), which utilizes monolingual data to augment a parallel training corpus. While effective, back-translation is often vulnerable to errors in initial models, a common problem of self-training algorithms (Chapelle et al., 2009). The second category is based on word replacements. For instance, Fadaee et al. (2017) propose to replace words in the target sentences with rare words in the target vocabulary according to a language model, and then modify the aligned source words accordingly. While this method generates augmented data with relatively high quality, it requires several complicated preprocessing steps, and is only shown to be effective for low-resource datasets. Other generic word replacement methods include word dropout (Sennrich et al., 2016a; Gal and Ghahramani, 2016), which sets some word embeddings to $0$ uniformly at random, and Reward Augmented Maximum Likelihood (RAML; Norouzi et al. (2016)), whose implementation essentially replaces some words in the target sentences with other words from the target vocabulary.

In this paper, we derive an extremely simple and efficient data augmentation technique for NMT. First, we formulate the design of a data augmentation algorithm as an optimization problem, where we seek the data augmentation policy that maximizes an objective that encourages two desired properties: smoothness and diversity. This optimization problem has a tractable analytic solution, which describes a generic framework of which both word dropout and RAML are instances. Second, we interpret the aforementioned solution and propose a novel method: independently replacing words in both the source sentence and the target sentence by other words uniformly sampled from the source and the target vocabularies, respectively. Experiments show that this method, which we name SwitchOut, consistently improves over strong baselines on datasets of different scales, including the large-scale WMT 15 English-German dataset, and two medium-scale datasets: IWSLT 2016 German-English and IWSLT 2015 English-Vietnamese.

2 Method

2.1 Notations

We use uppercase letters, such as $X$, $Y$, etc., to denote random variables and lowercase letters such as $x$, $y$, etc., to denote the corresponding actual values. Additionally, since we will discuss a data augmentation algorithm, we will use a hat to denote augmented variables and their values, e.g. $\hat{X}$, $\hat{Y}$, $\hat{x}$, $\hat{y}$, etc. We will also use boldfaced characters, such as $\mathbf{p}$, $\mathbf{q}$, etc., to denote probability distributions.

2.2 Data Augmentation

We facilitate our discussion with a probabilistic framework that motivates data augmentation algorithms. With $X$, $Y$ being the sequences of words in the source and target languages (e.g. in machine translation), the canonical MLE framework maximizes the objective

$$J_{\text{MLE}}(\theta) = \mathbb{E}_{x, y \sim \hat{p}(X, Y)} \left[ \log p_\theta(y \mid x) \right].$$

Here $\hat{p}(X, Y)$ is the empirical distribution over all training data pairs $(x, y)$ and $p_\theta(y \mid x)$ is a parameterized distribution that we aim to learn, e.g. a neural network. A potential weakness of MLE is the mismatch between $\hat{p}(X, Y)$ and the true data distribution $p(X, Y)$. Specifically, $\hat{p}(X, Y)$ is usually a bootstrap distribution defined only on the observed training pairs, while $p(X, Y)$ has a much larger support, i.e. the entire space of valid pairs. This issue can be dramatic when the empirical observations are insufficient to cover the data space.

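In practice, data augmentation is often used to remedy this support discrepancy by supplying additional training pairs. Formally, let $q(\hat{X}, \hat{Y})$ be the augmented distribution defined on a larger support than the empirical distribution $\hat{p}(X, Y)$. Then, MLE training with data augmentation maximizes

$$J_{\text{AUG}}(\theta) = \mathbb{E}_{\hat{x}, \hat{y} \sim q(\hat{X}, \hat{Y})} \left[ \log p_\theta(\hat{y} \mid \hat{x}) \right].$$

In this work, we focus on a specific family of $q$, which depends on the empirical observations by

$$q(\hat{X}, \hat{Y}) = \mathbb{E}_{x, y \sim \hat{p}(X, Y)} \left[ q(\hat{X}, \hat{Y} \mid x, y) \right].$$

This particular choice follows the intuition that an augmented pair $(\hat{x}, \hat{y})$ that diverges too far from any observed data is more likely to be invalid and thus harmful for training. The reason for this choice becomes more evident in Section 2.3.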

2.3 Diverse and Smooth Augmentation

Certainly, not all $q$ are equally good, and the more similar $q$ is to $p$, the more desirable $q$ will be. Unfortunately, we only have access to limited observations captured by $\hat{p}$. Hence, in order to use $q$ to bridge the gap between $\hat{p}$ and $p$, it is necessary to utilize some assumptions about $p$. Here, we exploit two highly generic assumptions, namely:

  • Diversity: $p(X, Y)$ has a wider support set, which includes samples that are more diverse than those in the empirical observation set.

  • Smoothness: $p(X, Y)$ is smooth, and similar $(x, y)$ pairs will have similar probabilities.

To formalize both assumptions, let $s(\hat{x}, \hat{y}; x, y)$ be a similarity function that measures how similar an augmented pair $(\hat{x}, \hat{y})$ is to an observed data pair $(x, y)$. Then, an ideal augmentation policy $q(\hat{X}, \hat{Y} \mid x, y)$ should have two properties. First, based on the smoothness assumption, if an augmented pair $(\hat{x}, \hat{y})$ is more similar to an empirical pair $(x, y)$, it is more likely that $(\hat{x}, \hat{y})$ is sampled under the true data distribution $p(X, Y)$, and thus $q(\hat{X}, \hat{Y} \mid x, y)$ should assign a significant amount of probability mass to $(\hat{x}, \hat{y})$. Second, to quantify the diversity assumption, we propose that the entropy $\mathbb{H}\left[ q(\hat{X}, \hat{Y} \mid x, y) \right]$ should be large, so that the support of $q(\hat{X}, \hat{Y})$ is larger than the support of $\hat{p}$ and thus is closer to the support of $p(X, Y)$. Combining these assumptions implies that $q(\hat{X}, \hat{Y} \mid x, y)$ should maximize the objective

$$J(q; x, y) = \mathbb{E}_{\hat{x}, \hat{y} \sim q(\hat{X}, \hat{Y} \mid x, y)} \left[ s(\hat{x}, \hat{y}; x, y) \right] + \tau \, \mathbb{H}\left[ q(\hat{X}, \hat{Y} \mid x, y) \right], \tag{1}$$

where $\tau$ controls the strength of the diversity objective. The first term in (1) instantiates the smoothness assumption, which encourages $q$ to draw samples that are similar to $(x, y)$. Meanwhile, the second term in (1) encourages more diverse samples from $q$. Together, the objective $J(q; x, y)$ extends the information in the “pivotal” empirical sample $(x, y)$ to a diverse set of similar cases. This echoes our particular parameterization of $q$ in Section 2.2.

The objective $J(q; x, y)$ in (1) is the canonical maximum entropy problem that one often encounters in deriving a max-ent model (Berger et al., 1996), which has the analytic solution:

$$q^*(\hat{x}, \hat{y} \mid x, y) = \frac{\exp\left\{ s(\hat{x}, \hat{y}; x, y) / \tau \right\}}{\sum_{\hat{x}', \hat{y}'} \exp\left\{ s(\hat{x}', \hat{y}'; x, y) / \tau \right\}} \tag{2}$$

Note that (2) is a fairly generic solution which is agnostic to the choice of the similarity measure $s$. Obviously, not all similarity measures are equally good. Next, we will show that some existing algorithms can be seen as specific instantiations under our framework. Moreover, this leads us to propose a novel and effective data augmentation algorithm.
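As a small numerical illustration of (2), the sketch below computes the policy as a softmax of similarities over a hand-picked candidate set and compares its value on objective (1) with two alternative distributions. The candidate similarities, the temperature, and the helper names `maxent_policy` and `objective` are illustrative assumptions and do not come from the paper.

```python
import math


def maxent_policy(sims, tau):
    """Softmax of s / tau over a finite candidate set, as in (2)."""
    m = max(sims)  # subtract the max for numerical stability
    weights = [math.exp((s - m) / tau) for s in sims]
    z = sum(weights)
    return [w / z for w in weights]


def objective(q, sims, tau):
    """Expected similarity plus tau times the entropy of q, as in (1)."""
    expected_sim = sum(p * s for p, s in zip(q, sims))
    entropy = -sum(p * math.log(p) for p in q if p > 0)
    return expected_sim + tau * entropy


# Hypothetical similarities of four augmented candidates to one observed pair
sims = [0.0, -1.0, -1.0, -2.0]
tau = 0.8
q_star = maxent_policy(sims, tau)

# Two alternative policies to compare against
uniform = [0.25] * 4
greedy = [1.0, 0.0, 0.0, 0.0]

for name, q in [("q_star", q_star), ("uniform", uniform), ("greedy", greedy)]:
    print(name, round(objective(q, sims, tau), 4))
```

A larger `tau` flattens `q_star` toward the uniform policy (more diversity), while a smaller `tau` concentrates it on the most similar candidate.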

2.4 Existing and New Algorithms

Word Dropout.

In the context of machine translation, Sennrich et al. (2016a) propose to randomly choose some words in the source and/or target sentence, and set their embeddings to $0$ vectors. Intuitively, it regards every new data pair generated by this procedure as similar enough and then includes them in the augmented training set. Formally, word dropout can be seen as an instantiation of our framework with a particular similarity function (see Appendix A.1).


RAML.

From the perspective of reinforcement learning, Norouzi et al. (2016) propose to train the model distribution to match a target distribution proportional to an exponentiated reward. Despite the difference in motivation, it can be shown (cf. Appendix A.2) that RAML can be viewed as an instantiation of our generic framework, where the similarity measure is $s(\hat{x}, \hat{y}; x, y) = r(\hat{y}; y)$ if $\hat{x} = x$ and $-\infty$ otherwise. Here, $r$ is a task-specific reward function which measures the similarity between $\hat{y}$ and $y$. Intuitively, this means that RAML only exploits the smoothness property on the target side while keeping the source side intact.


SwitchOut.

After reviewing the two existing augmentation schemes, there are two immediate insights. Firstly, augmentation should not be restricted to only the source side or the target side. Secondly, being able to incorporate prior knowledge, such as the task-specific reward function in RAML, can lead to a better similarity measure.

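Motivated by these observations, we propose to perform augmentation in both source and target domains. For simplicity, we separately measure the similarity between the pair $(\hat{x}, x)$ and the pair $(\hat{y}, y)$ and then sum them together, i.e.

$$\frac{s(\hat{x}, \hat{y}; x, y)}{\tau} \approx \frac{r_x(\hat{x}, x)}{\tau_x} + \frac{r_y(\hat{y}, y)}{\tau_y}, \tag{3}$$

where $r_x$ and $r_y$ are domain specific similarity functions and $\tau_x$, $\tau_y$ are hyper-parameters that absorb the temperature parameter $\tau$. This allows us to factor $q^*(\hat{x}, \hat{y} \mid x, y)$ into:

$$q^*(\hat{x}, \hat{y} \mid x, y) = \frac{\exp\left\{ r_x(\hat{x}, x) / \tau_x \right\}}{\sum_{\hat{x}'} \exp\left\{ r_x(\hat{x}', x) / \tau_x \right\}} \times \frac{\exp\left\{ r_y(\hat{y}, y) / \tau_y \right\}}{\sum_{\hat{y}'} \exp\left\{ r_y(\hat{y}', y) / \tau_y \right\}} \tag{4}$$

In addition, notice that this factored formulation allows $\hat{x}$ and $\hat{y}$ to be sampled independently.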
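The factorization in (4) can be checked by brute force on a toy problem. The sketch below enumerates every candidate pair over two tiny vocabularies and compares the joint softmax of the summed similarities with the product of the two per-side softmaxes. Negative Hamming distance is used as an assumed stand-in for both similarity functions, and the vocabularies, sentences, and temperatures are made up for illustration.

```python
import itertools
import math


def softmax(scores):
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [w / z for w in weights]


def neg_hamming(a, b):
    """Negative number of positions at which two equal-length sequences differ."""
    return -sum(u != v for u, v in zip(a, b))


# Toy vocabularies and one observed pair (illustrative values only)
vocab_x, vocab_y = ["a", "b", "c"], ["u", "v"]
x, y = ("a", "b"), ("u", "v")
tau_x, tau_y = 0.9, 1.3

cands_x = list(itertools.product(vocab_x, repeat=len(x)))
cands_y = list(itertools.product(vocab_y, repeat=len(y)))

# Per-side policies: the two factors on the right-hand side of (4)
q_x = softmax([neg_hamming(c, x) / tau_x for c in cands_x])
q_y = softmax([neg_hamming(c, y) / tau_y for c in cands_y])

# Joint policy: softmax over all pairs of the summed, temperature-scaled similarities
pairs = [(cx, cy) for cx in cands_x for cy in cands_y]
q_joint = softmax([neg_hamming(cx, x) / tau_x + neg_hamming(cy, y) / tau_y
                   for cx, cy in pairs])

# Product of the per-side policies, enumerated in the same order as `pairs`
q_product = [px * py for px in q_x for py in q_y]
max_gap = max(abs(a - b) for a, b in zip(q_joint, q_product))
print(max_gap)
```

Because the two distributions agree up to floating-point error, drawing the source and target sides separately is equivalent to drawing the pair jointly.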

Sampling Procedure.

To complete our method, we still need to define $r_x$ and $r_y$, and then design a practical sampling scheme from each factor in (4). Though non-trivial, both problems have been (partially) encountered in RAML (Norouzi et al., 2016; Ma et al., 2017). For simplicity, we follow previous work to use the negative Hamming distance for both $r_x$ and $r_y$. For a more parallelized implementation, we sample an augmented sentence $\hat{s}$ from a true sentence $s$ as follows:

  1. Sample $\hat{n} \in \{0, 1, \dots, |s|\}$ by $p(\hat{n}) \propto e^{-\hat{n} / \tau}$.

  2. For each $i \in \{1, 2, \dots, |s|\}$, with probability $\hat{n} / |s|$, we can replace $s_i$ by a uniform $\hat{s}_i \neq s_i$.

This procedure guarantees that any two sentences $\hat{s}_1$ and $\hat{s}_2$ with the same Hamming distance to $s$ have the same probability, but slightly changes the relative odds of sentences with different Hamming distances to $s$ relative to the exact distribution defined by negative Hamming distance, and thus is an approximation of the actual distribution. However, this efficient sampling procedure is much easier to implement while achieving good performance.
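The first guarantee above (equal probability for sentences at equal Hamming distance) can be verified exhaustively on a toy example. The sketch below computes, in closed form, the probability that the two-step procedure produces each candidate sentence and groups the results by Hamming distance. The vocabulary size, sentence, and temperature are arbitrary illustrative choices, and `procedure_prob` is a hypothetical helper name.

```python
import itertools
import math


def count_probs(length, tau):
    """Step 1: p(n) proportional to exp(-n / tau) for n = 0..length."""
    weights = [math.exp(-n / tau) for n in range(length + 1)]
    z = sum(weights)
    return [w / z for w in weights]


def procedure_prob(cand, sent, vocab_size, tau):
    """Hamming distance of `cand` to `sent` and its probability under the procedure."""
    length = len(sent)
    h = sum(c != s for c, s in zip(cand, sent))
    total = 0.0
    for n, p_n in enumerate(count_probs(length, tau)):
        rate = n / length  # step 2: per-position replacement probability
        # each changed position picks one of the (vocab_size - 1) other ids uniformly
        total += p_n * (rate / (vocab_size - 1)) ** h * (1 - rate) ** (length - h)
    return h, total


vocab = [0, 1, 2, 3]
sent = (0, 2, 1)
by_distance = {}
for cand in itertools.product(vocab, repeat=len(sent)):
    h, p = procedure_prob(cand, sent, len(vocab), tau=1.0)
    by_distance.setdefault(h, []).append(p)

total_mass = sum(sum(ps) for ps in by_distance.values())
for h in sorted(by_distance):
    print(h, len(by_distance[h]), by_distance[h][0])
```

Each printed row gives a Hamming distance, the number of sentences at that distance, and their shared probability.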

Algorithm 1 illustrates this sampling procedure, which can be applied independently and in parallel for each batch of source sentences and target sentences. Additionally, we open source our implementation in TensorFlow and in PyTorch (respectively in Appendix A.5 and A.6).

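Input: $s$: a sentence represented by vocab integral ids, $\tau$: the temperature, $V$: the vocabulary
Output: $\hat{s}$: a sentence with words replaced
Function HammingDistanceSample($s$, $\tau$, $|V|$):
    Let $Z(\tau) \leftarrow \sum_{n=0}^{|s|} e^{-n/\tau}$ be the partition function.
    Let $p(n) \leftarrow e^{-n/\tau} / Z(\tau)$ for $n = 0, 1, \dots, |s|$.
    Sample $\hat{n} \sim p(n)$.
    In parallel, for each $i = 1, \dots, |s|$, do:
        Sample $a_i \sim \text{Bernoulli}(\hat{n} / |s|)$.
        if $a_i = 1$ then
            $\hat{s}_i \leftarrow \text{Uniform}(V \setminus \{s_i\})$.
        else
            $\hat{s}_i \leftarrow s_i$.
        end if
    return $\hat{s} = (\hat{s}_1, \dots, \hat{s}_{|s|})$
Algorithm 1 Sampling with SwitchOut.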
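For readers who prefer plain code, the following is a minimal single-sentence sketch of Algorithm 1 using only the Python standard library. It is not the authors' implementation (see Appendix A.5 and A.6 for that): the function name `switchout_sample` and the `rng` argument are assumptions made for illustration, and, unlike the appendix code, it does not protect special tokens such as sentence boundaries or padding.

```python
import math
import random


def switchout_sample(sent, tau, vocab_size, rng=random):
    """Return a copy of `sent` (a list of integer ids) with some ids replaced."""
    length = len(sent)
    if length == 0:
        return []
    # Sample the number of words to replace, p(n) proportional to exp(-n / tau)
    weights = [math.exp(-n / tau) for n in range(length + 1)]
    n_hat = rng.choices(range(length + 1), weights=weights, k=1)[0]
    rate = n_hat / length
    out = []
    for word in sent:
        if rng.random() < rate:
            # Adding an offset in [1, vocab_size - 1] modulo vocab_size gives a
            # uniformly random id that is guaranteed to differ from `word`
            out.append((word + rng.randrange(1, vocab_size)) % vocab_size)
        else:
            out.append(word)
    return out


rng = random.Random(0)
source = [5, 1, 7, 3]
print(switchout_sample(source, tau=1.0, vocab_size=10, rng=rng))
```

With a very small `tau` almost all probability mass falls on replacing zero words, so the sentence is returned unchanged; larger values of `tau` corrupt more positions on average.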

3 Experiments


We benchmark SwitchOut on three translation tasks of different scales: 1) IWSLT 2015 English-Vietnamese (en-vi); 2) IWSLT 2016 German-English (de-en); and 3) WMT 2015 English-German (en-de). All translations are word-based. These tasks and pre-processing steps are standard, used in several previous works. Detailed statistics and pre-processing schemes are in Appendix A.3.

Models and Experimental Procedures.

Our translation model, i.e. $p_\theta(y \mid x)$, is a Transformer network (Vaswani et al., 2017). For each dataset, we first train a standard Transformer model without SwitchOut and tune the hyper-parameters on the dev set to achieve competitive results (w.r.t. Luong and Manning (2015); Gu et al. (2018); Vaswani et al. (2017)). Then, fixing all other hyper-parameters, we tune the SwitchOut temperature, which controls how far we are willing to let the augmented sentences deviate from the observed ones. Our hyper-parameters are listed in Appendix A.4.


While the Transformer network without SwitchOut is already a strong baseline, we also compare SwitchOut against two other baselines that further use existing varieties of data augmentation: 1) word dropout on the source side with a fixed dropping probability; and 2) RAML on the target side, as in Section 2.4. Additionally, on the en-de task, we compare SwitchOut against back-translation (Sennrich et al., 2016b).

SwitchOut vs. Word Dropout and RAML.

We report the BLEU scores of SwitchOut, word dropout, and RAML on the test sets of the tasks in Table 1. To account for variance, we run each experiment multiple times and report the median BLEU. Experiments with SwitchOut are repeated more times than those without it because of their inherently higher variance. We also conduct pairwise statistical significance tests using paired bootstrap (Clark et al., 2011), and record the results in Table 1. For 4 of the 6 settings, SwitchOut delivers significant improvements over the best baseline without SwitchOut. For the remaining two settings, the differences are not statistically significant. The gains in BLEU with SwitchOut over the best baseline on WMT 15 en-de are all significant. Notably, SwitchOut on the source demonstrates as large gains as those obtained by RAML on the target side, and SwitchOut delivers further improvements when combined with RAML.

Method en-de de-en en-vi
Transformer 21.73 29.81 27.97
+WordDropout 20.63 29.97 28.56
+SwitchOut 22.78 29.94 28.67
+RAML 22.83 30.66 28.88
+RAML +WordDropout 20.69 30.79 28.86
+RAML +SwitchOut 23.13 30.98 29.09
Table 1: Test BLEU scores of SwitchOut and other baselines (median of multiple runs). Statistical significance is assessed against the best result without SwitchOut. For example, for en-de results in the first column, +SwitchOut has significant gain over Transformer; +RAML +SwitchOut has significant gain over +RAML.
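The evaluation protocol above (median over repeated runs, plus a paired bootstrap test) can be sketched as follows. This is only an illustration of the resampling idea, not the exact test of Clark et al. (2011): it averages hypothetical per-sentence scores instead of recomputing corpus-level BLEU on each resample, and all numbers and the function name `paired_bootstrap` are made up.

```python
import random
import statistics


def paired_bootstrap(scores_a, scores_b, n_resamples=1000, seed=0):
    """Fraction of resampled test sets on which system A outscores system B."""
    assert len(scores_a) == len(scores_b)
    rng = random.Random(seed)
    n = len(scores_a)
    wins = 0
    for _ in range(n_resamples):
        # Resample test sentences with replacement, using the same indices for both systems
        idx = [rng.randrange(n) for _ in range(n)]
        if sum(scores_a[i] for i in idx) > sum(scores_b[i] for i in idx):
            wins += 1
    return wins / n_resamples


# Hypothetical per-sentence quality scores for two systems on the same test set
system_b = [0.30, 0.25, 0.40, 0.35, 0.20, 0.45, 0.33, 0.28]
system_a = [s + 0.05 for s in system_b]
win_rate = paired_bootstrap(system_a, system_b)

# Median over repeated training runs (hypothetical BLEU values)
median_bleu = statistics.median([21.0, 22.0, 23.0])
print(win_rate, median_bleu)
```

A win rate close to 1 across resamples suggests that the difference between the two systems is unlikely to be an artifact of the particular test set.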

SwitchOut vs. Back Translation.

Traditionally, data augmentation is viewed as a method to enlarge the training datasets (Krizhevsky et al., 2012; Szegedy et al., 2014). In the context of neural MT, Sennrich et al. (2016b) propose to use artificial data generated from a weak back-translation model, effectively utilizing monolingual data to enlarge the bilingual training datasets. Accordingly, we compare SwitchOut against back translation. We only make this comparison on the en-de task, where the amount of bilingual training data is already sufficiently large (we add the extra monolingual data from http://data.statmt.org/rsennrich/wmt16_backtranslations/en-de/). The BLEU scores with back-translation are reported in Table 2. These results provide two insights. First, the gain delivered by back translation is less significant than the gain delivered by SwitchOut. Second, SwitchOut and back translation are not mutually exclusive, as one can additionally apply SwitchOut on the additional data obtained from back translation to further improve BLEU scores.

Method en-de
Transformer 21.73
+SwitchOut 22.78
+BT 21.82
+BT +RAML 21.53
+BT +SwitchOut 22.93
+BT +RAML +SwitchOut 23.76
Table 2: Test BLEU scores of back translation (BT) compared to and combined with SwitchOut (median of multiple runs).

Effects of $\tau_x$ and $\tau_y$.

We empirically study the effect of these temperature parameters. During the tuning process, we translate the dev set of the tasks and report the BLEU scores in Figure 1. We observe that when fixing , the best performance is always achieved with a non-zero .

Figure 1: Dev BLEU scores with different $\tau_x$ and $\tau_y$. Top left: WMT 15 en-de. Top right: IWSLT 16 de-en. Bottom: IWSLT 15 en-vi.

Where does SwitchOut Help the Most?

Intuitively, because SwitchOut is expanding the support of the training distribution, we would expect that it would help the most on test sentences that are far from those in the training set and would thus benefit most from this expanded support. To test this hypothesis, for each test sentence we find its most similar training sample (i.e. nearest neighbor), then bucket the instances by the distance to their nearest neighbor and measure the gain in BLEU afforded by SwitchOut for each bucket. Specifically, we use (negative) word error rate (WER) as the similarity measure, and plot the bucket-by-bucket performance gain for each group in Figure 2. As we can see, SwitchOut improves increasingly more as the WER increases, indicating that SwitchOut is indeed helping on examples that are far from the sentences that the model sees during training. This is the desirable effect of data augmentation techniques.

Figure 2: Gains in BLEU of RAML+SwitchOut over RAML. The x-axis is ordered by the WER between a test sentence and its nearest neighbor in the training set. Left: IWSLT 16 de-en. Right: IWSLT 15 en-vi.
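The bucketing analysis can be sketched in a few lines. The code below computes word-level edit distance, turns it into a WER against each training sentence, takes the minimum as the nearest-neighbor distance, and assigns each test sentence to a bucket. The toy sentences, the bucket edges, and the choice of treating the training sentence as the reference in the WER are all assumptions for illustration; the paper does not specify these details here.

```python
def edit_distance(a, b):
    """Word-level Levenshtein distance between token lists a and b."""
    prev = list(range(len(b) + 1))
    for i, wa in enumerate(a, 1):
        cur = [i]
        for j, wb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (wa != wb)))  # substitution or match
        prev = cur
    return prev[-1]


def wer(hyp, ref):
    """Word error rate: edit distance divided by the reference length."""
    return edit_distance(hyp, ref) / max(len(ref), 1)


def nearest_neighbor_wer(test_sent, train_sents):
    """WER between a test sentence and its most similar training sentence."""
    return min(wer(test_sent, t) for t in train_sents)


def bucket(value, edges):
    """Index of the first edge that `value` does not exceed."""
    for k, edge in enumerate(edges):
        if value <= edge:
            return k
    return len(edges)


train = [s.split() for s in ["the cat sat", "a dog ran home"]]
test = [s.split() for s in ["the cat sat", "the cat ran", "birds fly south today"]]
distances = [nearest_neighbor_wer(t, train) for t in test]
buckets = [bucket(d, edges=[0.0, 0.5]) for d in distances]
print(distances, buckets)
```

The BLEU gain would then be measured separately on the test sentences falling in each bucket.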

4 Conclusion

In this paper, we propose a method to design data augmentation algorithms by solving an optimization problem. These solutions subsume a few existing augmentation schemes and inspire a novel augmentation method, SwitchOut. SwitchOut delivers improvements over translation tasks at different scales. Additionally, SwitchOut is efficient and easy to implement, and thus has the potential for wide application.


Acknowledgements

We thank Quoc Le, Minh-Thang Luong, Qizhe Xie, and the anonymous EMNLP reviewers for their suggestions to improve the paper.

This material is based upon work supported in part by the Defense Advanced Research Projects Agency Information Innovation Office (I2O) Low Resource Languages for Emergent Incidents (LORELEI) program under Contract No. HR0011-15-C0114. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation here on.


References

  • Amodei et al. (2016) Dario Amodei, Sundaram Ananthanarayanan, Rishita Anubhai, and more authors. 2016. Deep speech 2: End-to-end speech recognition in English and Mandarin. In ICML.
  • Belinkov and Bisk (2018) Yonatan Belinkov and Yonatan Bisk. 2018. Synthetic and natural noise both break neural machine translation. In ICLR.
  • Berger et al. (1996) Adam L Berger, Vincent J Della Pietra, and Stephen A Della Pietra. 1996. A maximum entropy approach to natural language processing. Computational linguistics, 22(1):39–71.
  • Chapelle et al. (2009) Olivier Chapelle, Bernhard Scholkopf, and Alexander Zien. 2009. Semi-supervised learning (chapelle, o. et al., eds.; 2006)[book reviews]. IEEE Transactions on Neural Networks, 20(3):542–542.
  • Clark et al. (2011) Jonathan Clark, Chris Dyer, Alon Lavie, and Noah Smith. 2011. Better hypothesis testing for statistical machine translation: Controlling for optimizer instability. In ACL.
  • DeVries and Taylor (2017) Terrance DeVries and Graham W. Taylor. 2017. Improved regularization of convolutional neural networks with cutout. Arxiv, 1708.04552.
  • Fadaee et al. (2017) Marzieh Fadaee, Arianna Bisazza, and Christof Monz. 2017. Data augmentation for low-resource neural machine translation. In ACL.
  • Gal and Ghahramani (2016) Yarin Gal and Zoubin Ghahramani. 2016. A theoretically grounded application of dropout in recurrent neural networks. In NIPS.
  • Gu et al. (2018) Jiatao Gu, James Bradbury, Caiming Xiong, Victor O.K. Li, and Richard Socher. 2018. Non-autoregressive neural machine translation. In ICLR.
  • Huang et al. (2016) Gao Huang, Zhuang Liu, Laurens van der Maaten, and Kilian Q. Weinberger. 2016. Densely connected convolutional networks. In CVPR.
  • Kingma and Ba (2015) Diederik P. Kingma and Jimmy Lei Ba. 2015. Adam: A method for stochastic optimization. In ICLR.
  • Krizhevsky et al. (2012) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. 2012. Imagenet classification with deep convolutional neural networks. In NIPS.
  • Luong and Manning (2015) Minh-Thang Luong and Christopher D. Manning. 2015. Stanford neural machine translation systems for spoken language domain. In IWSLT.
  • Luong et al. (2015) Minh-Thang Luong, Hieu Pham, and Christopher D. Manning. 2015. Effective approaches to attention-based neural machine translation. In EMNLP.
  • Ma et al. (2017) Xuezhe Ma, Pengcheng Yin, Jingzhou Liu, Graham Neubig, and Eduard Hovy. 2017. Softmax q-distribution estimation for structured prediction: A theoretical interpretation for raml. Arxiv, 1705.07136.
  • Norouzi et al. (2016) Mohammad Norouzi, Samy Bengio, Zhifeng Chen, Navdeep Jaitly, Mike Schuster, Yonghui Wu, and Dale Schuurmans. 2016. Reward augmented maximum likelihood for neural structured prediction. In NIPS.
  • Poncelas et al. (2018) Alberto Poncelas, Dimitar Shterionov, Andy Way, Gideon Maillette de Buy Wenniger, and Peyman Passban. 2018. Investigating backtranslation in neural machine translation. Arxiv, 1804.06189.
  • Ranzato et al. (2016) Marc’Aurelio Ranzato, Sumit Chopra, Michael Auli, and Wojciech Zaremba. 2016. Sequence level training with recurrent neural networks. In ICLR.
  • Sennrich et al. (2016a) Rico Sennrich, Barry Haddow, and Alexandra Birch. 2016a. Edinburgh neural machine translation systems for wmt 16. In WMT.
  • Sennrich et al. (2016b) Rico Sennrich, Barry Haddow, and Alexandra Birch. 2016b. Improving neural machine translation models with monolingual data. In ACL.
  • Szegedy et al. (2014) Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. 2014. Going deeper with convolutions. In CVPR.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In NIPS.
  • Zagoruyko and Komodakis (2016) Sergey Zagoruyko and Nikos Komodakis. 2016. Wide residual networks. In BMVC.

Appendix A

A.1 Word Dropout as a Special Case

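Here, we derive word dropout as an instance of our framework. First, let us introduce a new token, $\langle\text{null}\rangle$, into both the source vocabulary and the target vocabulary. $\langle\text{null}\rangle$ has the embedding of an all-$0$ vector and is never trained. For a sequence of words $w$ in a vocabulary with $\langle\text{null}\rangle$, we define the neighborhood $N(w)$ to be:

$$N(w) = \left\{ w' : |w'| = |w| \text{ and } w'_i \in \{ w_i, \langle\text{null}\rangle \} \ \forall i \right\}$$

In other words, $N(w)$ consists of $w$ and all the sentences obtained by replacing a few words in $w$ by $\langle\text{null}\rangle$. Clearly, all augmented sentences that are sampled from $w$ using word dropout fall into $N(w)$.

In (4), the augmentation policy $q^*(\hat{x}, \hat{y} \mid x, y)$ was decomposed into two independent terms, one of which samples the augmented source sentence $\hat{x}$ and the other samples the augmented target sentence $\hat{y}$:

$$q^*(\hat{x}, \hat{y} \mid x, y) = q^*(\hat{x} \mid x) \cdot q^*(\hat{y} \mid y)$$

Word dropout is an instance of this decomposition, where $q^*(\hat{y} \mid y)$ takes the same form as $q^*(\hat{x} \mid x)$, given by:

$$q^*(\hat{x} \mid x) = \frac{\exp\left\{ s(\hat{x}, x) / \tau \right\}}{\sum_{\hat{x}'} \exp\left\{ s(\hat{x}', x) / \tau \right\}},$$

where $s(\hat{x}, x)$ is the negative Hamming distance between $\hat{x}$ and $x$ if $\hat{x} \in N(x)$, and $-\infty$ otherwise. To see this is indeed the case, let $h$ be the Hamming distance for $\hat{x} \in N(x)$ and set $\gamma = 1 / (1 + e^{1/\tau})$, then we have:

$$q^*(\hat{x} \mid x) = \frac{e^{-h/\tau}}{\left( 1 + e^{-1/\tau} \right)^{|x|}} = \gamma^{h} (1 - \gamma)^{|x| - h},$$

which is precisely the probability of dropping out $h$ words in $x$, where each word is dropped with the distribution $\text{Bernoulli}(\gamma)$.

The difference between word dropout and SwitchOut comes from the fact that $N(x)$ is much smaller than the support of $\hat{x}$ that SwitchOut can sample from, which is $V^{|x|}$ where $V$ is the vocabulary. Word dropout concentrates all augmentation probability mass into $N(x)$ while SwitchOut spreads the mass into a larger support, leading to a larger entropy. Meanwhile, both word dropout and SwitchOut are exponentially less likely to diverge away from $x$, ensuring the smoothness desideratum of a good data augmentation policy, as we discussed in Section 2.3.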
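The correspondence between word dropout and a temperature-scaled softmax over the neighborhood can be checked numerically. The sketch below enumerates every sentence obtained by replacing a subset of words with a null token, normalizes the weights `exp(-h / tau)` (with `h` the Hamming distance), and compares the result with independent per-word dropout at rate `gamma = 1 / (1 + exp(1 / tau))`. The token string `"<null>"`, the symbol names `h` and `gamma`, and the example sentence are assumptions made for this illustration.

```python
import itertools
import math

NULL = "<null>"  # assumed name for the never-trained, all-zero token


def neighborhood(sent):
    """All sentences obtained by replacing any subset of words with NULL."""
    return list(itertools.product(*[(w, NULL) for w in sent]))


def softmax_policy(sent, tau):
    """Softmax of negative Hamming distance / tau, restricted to the neighborhood."""
    cands = neighborhood(sent)
    weights = [math.exp(-sum(c != w for c, w in zip(cand, sent)) / tau)
               for cand in cands]
    z = sum(weights)
    return {cand: w / z for cand, w in zip(cands, weights)}


def dropout_prob(cand, sent, gamma):
    """Probability of `cand` when each word is dropped independently with rate gamma."""
    h = sum(c != w for c, w in zip(cand, sent))
    return gamma ** h * (1 - gamma) ** (len(sent) - h)


sent = ("we", "like", "tea")
tau = 0.7
gamma = 1.0 / (1.0 + math.exp(1.0 / tau))
policy = softmax_policy(sent, tau)
max_gap = max(abs(p - dropout_prob(c, sent, gamma)) for c, p in policy.items())
print(gamma, max_gap)
```

The two distributions coincide up to floating-point error, so a dropout rate can be read as a temperature and vice versa.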

A.2 RAML as a Special Case

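Here, we present a detailed description of how RAML is a special case of our proposed framework. For each empirical observation $(x, y) \sim \hat{p}$, RAML defines a reward-aware target distribution $p_{\text{RAML}}(Y \mid x, y)$ for the model distribution $p_\theta(Y \mid x)$ to match. Concretely, the target distribution in RAML has the form

$$p_{\text{RAML}}(\hat{y} \mid x, y) = \frac{\exp\left\{ r(\hat{y}; y) / \tau \right\}}{\sum_{\hat{y}'} \exp\left\{ r(\hat{y}'; y) / \tau \right\}},$$

where $r$ is the task reward function. With this definition, RAML amounts to minimizing the expected KL divergence between $p_{\text{RAML}}$ and $p_\theta$, i.e.

$$\min_\theta \mathbb{E}_{x, y \sim \hat{p}} \left[ \text{KL}\left( p_{\text{RAML}}(Y \mid x, y) \,\|\, p_\theta(Y \mid x) \right) \right] \iff \max_\theta \mathbb{E}_{x \sim \hat{p}} \, \mathbb{E}_{\hat{y} \sim p_{\text{RAML}}(Y \mid x)} \left[ \log p_\theta(\hat{y} \mid x) \right],$$

where $p_{\text{RAML}}(Y \mid x)$ is the marginalized target distribution, i.e. $p_{\text{RAML}}(Y \mid x) = \mathbb{E}_{y \sim \hat{p}(Y \mid x)} \left[ p_{\text{RAML}}(Y \mid x, y) \right]$. Now, notice that $\hat{p}(X) \, p_{\text{RAML}}(Y \mid X)$ is a member of the augmentation distribution family in consideration (cf. Section 2.2). Specifically, it is equivalent to a data augmentation distribution where

$$q(\hat{x}, \hat{y} \mid x, y) = \mathbb{1}[\hat{x} = x] \cdot p_{\text{RAML}}(\hat{y} \mid x, y) = \frac{\exp\left\{ s(\hat{x}, \hat{y}; x, y) / \tau \right\}}{\sum_{\hat{x}', \hat{y}'} \exp\left\{ s(\hat{x}', \hat{y}'; x, y) / \tau \right\}}, \quad s(\hat{x}, \hat{y}; x, y) = \begin{cases} r(\hat{y}; y) & \text{if } \hat{x} = x \\ -\infty & \text{otherwise} \end{cases} \tag{7}$$

The last equality reveals an immediate connection between RAML and our proposed framework. In summary, RAML can be seen as a special case of our data augmentation framework, where the similarity function is defined by (7). Practically, this means RAML only considers pairs with source sentences from the empirical set for data augmentation.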
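To make the practical consequence concrete, the sketch below builds the augmentation policy induced by a RAML-style similarity on a toy problem and confirms that every candidate with a modified source sentence receives zero probability. Negative Hamming distance is used as an assumed stand-in for the task reward, and the vocabularies, sentences, and function names are illustrative only.

```python
import itertools
import math


def raml_similarity(x_hat, y_hat, x, y, reward):
    """Reward on the target side when the source is untouched, -inf otherwise."""
    return reward(y_hat, y) if x_hat == x else -math.inf


def neg_hamming(a, b):
    """Assumed stand-in for the task reward: negative Hamming distance."""
    return -sum(u != v for u, v in zip(a, b))


def policy(x, y, vocab_x, vocab_y, tau, reward):
    """Softmax of similarity / tau over all candidate (source, target) pairs."""
    pairs = [(xh, yh)
             for xh in itertools.product(vocab_x, repeat=len(x))
             for yh in itertools.product(vocab_y, repeat=len(y))]
    # math.exp(-inf) evaluates to 0.0, so modified sources get zero weight
    weights = [math.exp(raml_similarity(xh, yh, x, y, reward) / tau)
               for xh, yh in pairs]
    z = sum(weights)
    return {pair: w / z for pair, w in zip(pairs, weights)}


x, y = ("a", "b"), ("u", "v")
q = policy(x, y, ["a", "b"], ["u", "v"], tau=1.0, reward=neg_hamming)
mass_on_original_source = sum(p for (xh, _), p in q.items() if xh == x)
print(mass_on_original_source)
```

All probability mass stays on pairs whose source equals the observed one, which is the sense in which RAML augments only the target side.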

A.3 Datasets Descriptions

vocab (K) #sents
src tgt train dev test
en-vi 17.2 7.7 133.3K 1.6K 1.3K
de-en 32.0 22.8 153.3K 7.0K 6.8K
en-de 50.0 50.0 4.5M 2.7K 2.2K
Table 3: Statistics of the datasets.

Table 3 summarizes the statistics of the datasets in our experiments. The WMT 15 en-de dataset is one order of magnitude larger than the IWSLT 16 de-en dataset and the IWSLT 15 en-vi dataset. For the en-vi task, we use the data pre-processed by Luong and Manning (2015). For the en-de task, we use the data pre-processed by Luong et al. (2015), with newstest2014 for validation and newstest2015 for testing. For the de-en task, we use the data pre-processed by Ranzato et al. (2016).

A.4 Hyper-parameters

Task , init clip
Table 4: Hyper-parameters for our experiments.

The hyper-parameters used in our experiments are in Table 4. All models are initialized uniformly at random in the range as reported in Table 4. All models are trained with Adam (Kingma and Ba, 2015). Gradients are clipped at the threshold as specified in Table 4. For the WMT en-de task, we use the legacy learning rate schedule as specified by Vaswani et al. (2017). For the de-en task and the en-vi task, the learning rate is initially , and is decreased by a factor of for every steps, starting at step . All models are trained for 100,000 steps, during which one checkpoint is saved for each steps and the final evaluation is performed on the checkpoint with lowest perplexity on the dev set.

Multiple GPUs are used for each experiment. For the de-en and the en-vi experiments, if we use GPUs, where , then we only perform updates to the models’ parameters. We find that this is sufficient to make the models converge.

A.5 Source Code for Sampling in TensorFlow

import tensorflow as tf  # written against the TensorFlow 1.x API


def hamming_distance_sample(sents, tau, bos_id, eos_id, pad_id, vocab_size):
  """Sample a batch of corrupted examples from sents.

  Args:
    sents: Tensor [batch_size, n_steps]. The input sentences.
    tau: temperature.
    bos_id, eos_id, pad_id: ids of the special tokens, which are never replaced.
    vocab_size: to create valid samples.

  Returns:
    sents: Tensor [batch_size, n_steps]. The corrupted sentences.
  """

  # mask: True at <bos>, <eos> and <pad> positions
  mask = [
      tf.equal(sents, bos_id),
      tf.equal(sents, eos_id),
      tf.equal(sents, pad_id),
  ]
  mask = tf.stack(mask, axis=0)
  mask = tf.reduce_any(mask, axis=0)

  # first, sample the number of words to corrupt for each sentence
  # note: tau multiplies the logits here, so it acts as an inverse temperature
  batch_size, n_steps = tf.unstack(tf.shape(sents))
  logits = -tf.range(tf.to_float(n_steps), dtype=tf.float32) * tau
  logits = tf.expand_dims(logits, axis=0)
  logits = tf.tile(logits, [batch_size, 1])
  logits = tf.where(mask,
                    x=tf.fill([batch_size, n_steps], -float("inf")), y=logits)

  # sample the number of words to corrupt at each sentence
  num_words = tf.multinomial(logits, num_samples=1)
  num_words = tf.reshape(num_words, [batch_size])
  num_words = tf.to_float(num_words)

  # <bos> and <eos> should never be replaced!
  lengths = tf.reduce_sum(1.0 - tf.to_float(mask), axis=1)

  # sample corrupted positions
  probs = num_words / lengths
  probs = tf.expand_dims(probs, axis=1)
  probs = tf.tile(probs, [1, n_steps])
  probs = tf.where(mask, x=tf.zeros_like(probs), y=probs)
  bernoulli = tf.distributions.Bernoulli(probs=probs, dtype=tf.int32)

  pos = bernoulli.sample()
  pos = tf.cast(pos, tf.bool)

  # sample the corrupted values: a non-zero offset added modulo vocab_size
  # always yields an id different from the original one
  val = tf.random_uniform(
      [batch_size, n_steps], minval=1, maxval=vocab_size, dtype=tf.int32)
  val = tf.where(pos, x=val, y=tf.zeros_like(val))
  sents = tf.mod(sents + val, vocab_size)

  return sents
Hamming distance sampling in TensorFlow

A.6 Source Code for Sampling in PyTorch

import torch


def hamming_distance_sample(sents, tau, bos_id, eos_id, pad_id, vocab_size):
  """
  Sample a batch of corrupted examples from sents.

  Args:
      sents: LongTensor [batch_size, n_steps]. The input sentences.
      tau: Temperature.
      bos_id, eos_id, pad_id: ids of the special tokens, which are never replaced.
      vocab_size: to create valid samples.
  Returns:
      sampled_sents: LongTensor [batch_size, n_steps]. The corrupted sentences.
  """
  # mask is True at <bos>, <eos> and <pad> positions
  mask = torch.eq(sents, bos_id) | torch.eq(sents, eos_id) | torch.eq(sents, pad_id)
  # number of replaceable words per sentence (special tokens excluded)
  lengths = (~mask).float().sum(dim=1)
  batch_size, n_steps = sents.size()

  # first, sample the number of words to corrupt for each sentence
  # note: tau multiplies the logits here, so it acts as an inverse temperature
  logits = torch.arange(n_steps, dtype=torch.float, device=sents.device)
  logits = logits.mul(-1).unsqueeze(0).expand_as(sents).contiguous()
  # scale before masking so that tau = 0 does not turn -inf into NaN
  logits = logits.mul(tau).masked_fill(mask, -float("inf"))
  probs = torch.nn.functional.softmax(logits, dim=1)
  num_words = torch.distributions.Categorical(probs).sample()

  # sample the corrupted positions.
  corrupt_prob = num_words.float().div(lengths).unsqueeze(1).expand_as(sents)
  corrupt_prob = corrupt_prob.contiguous().masked_fill(mask, 0)
  corrupt_pos = torch.bernoulli(corrupt_prob).bool()
  total_words = int(corrupt_pos.sum())

  # sample the corrupted values, which will be added to sents
  corrupt_val = torch.randint(1, vocab_size, (total_words,), device=sents.device)
  corrupts = torch.zeros_like(sents)
  corrupts = corrupts.masked_scatter_(corrupt_pos, corrupt_val)
  sampled_sents = sents.add(corrupts).remainder_(vocab_size)

  return sampled_sents
Hamming distance sampling in PyTorch