Many algorithms in machine translation use an encoder-decoder model, in which input sentences are encoded into feature vectors and translated sentences are generated word by word by the decoder. For each output word, there is a conditional distribution over that word given the input sentence and the previous output words. These models are often trained by conditioning on the true output so far (teacher forcing) in order to maximize the probability of the correct next word. For example, if we have $t$ output words $y_1, \dots, y_t$ and an input sentence $x = (x_1, \dots, x_n)$ of length $n$, our distribution over the next output is $p(y_{t+1} \mid y_1, \dots, y_t, x_1, \dots, x_n)$. Training this way can favor words that look good in the immediate future rather than sentences that are optimal in the long run.
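For concreteness, the teacher-forced objective described above can be written as follows, assuming a reference translation $y = (y_1, \dots, y_T)$ for an input $x$ and model parameters $\theta$ (notation introduced here for illustration). Each factor conditions on the ground-truth prefix, so the loss rewards accurate next-word prediction rather than a sequence-level score such as BLEU.

```latex
\log p_\theta(y \mid x) = \sum_{t=1}^{T} \log p_\theta(y_t \mid y_1, \dots, y_{t-1}, x),
\qquad
\mathcal{L}_{\mathrm{MLE}}(\theta) = -\sum_{t=1}^{T} \log p_\theta(y_t \mid y_1, \dots, y_{t-1}, x)
```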
Despite this limitation, current state-of-the-art neural machine translation systems such as Edunov et al. (2018) use only a policy network. In the past, algorithms in machine translation Bahdanau et al. (2017); He et al. (2017) as well as image captioning Ren et al. (2017) managed to improve upon policy-only methods by adding a value network. This suggests that current state-of-the-art models could likely be improved by adding a value network that is trained jointly with the policy. The policy network is our conditional language model, such as a transformer Vaswani et al. (2017), which outputs a distribution over next words given what has been generated so far at each time step. The value network predicts the expected reward (the BLEU score) that we would obtain from the current partial output if we continue following the policy until the sentence is complete. One benefit of jointly training the policy and value networks is that it helps guide the policy network to optimize for longer-term rewards such as the final BLEU score. Furthermore, recent papers Bahdanau et al. (2017); Ren et al. (2017) have used actor-critic methods (see Appendix 8.1) to jointly improve the policy and value networks.
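However, these methods required several additional tweaks, such as target networks that are updated only periodically and curriculum training, to converge on a quality model. This is likely due to the high-variance gradient estimates used in long sequence prediction tasks such as machine translation, where the target sentence can be over 60 tokens long.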
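Using assumed notation ($x$ the input, $\hat{y}_1, \dots, \hat{y}_T$ a translation sampled from the policy $\pi$, $y$ the reference, and $s_t$ the state after $t$ generated tokens), the quantity the value network introduced above is meant to approximate can be sketched as:

```latex
V^{\pi}(s_t) = \mathbb{E}_{\hat{y}_{t+1}, \dots, \hat{y}_{T} \sim \pi}
\left[ \mathrm{BLEU}(\hat{y}_1, \dots, \hat{y}_T;\, y) \;\middle|\; s_t \right],
\qquad s_t = (x, \hat{y}_1, \dots, \hat{y}_t)
```

In practice a value network $V_\phi(s_t)$ is trained to regress toward observed BLEU scores, which serve as samples of this expectation.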
We improve on several existing reinforcement learning methods in neural machine translation Bahdanau et al. (2017); Wu et al. (2018) by using Monte Carlo Tree Search (MCTS) in a way similar to AlphaZero Silver et al. (2017), so that our model receives more stable updates. MCTS in AlphaZero achieved excellent performance in Go, where long sequences of moves are predicted, which suggested its potential as a method for producing more accurate model updates than actor-critic methods such as those used in Ren et al. (2017). We conduct several experiments to analyze where these benefits come from, and we discuss limitations of and possible improvements to our current model.
2 Related Work
Previous works have leveraged a value network to complement the policy network in machine translation Bahdanau et al. (2017); He et al. (2017), image captioning Ren et al. (2017), and playing Go Silver et al. (2017). Ren et al. (2017) design a model for image captioning and train a policy and value network in a supervised manner using an embedding reward. The model is then updated through an actor-critic reinforcement learning method similar to A2C (see Appendix 8.1). The authors found that the global guidance introduced by the value network greatly improves performance over using a policy network alone. Another paper that uses an actor-critic method in machine translation is Bahdanau et al. (2017), which trains its policy and value networks with reinforcement learning but allows the value network to also take the true output as input. This method helps improve the policy by allowing it to optimize directly for BLEU score. He et al. (2017) train a value network to assist the policy in the decoding stage of machine translation but do not jointly train the value and policy networks.
As a significant milestone in reinforcement learning, Silver et al. (2017) use self-play and MCTS to train policy and value networks with a shared body, which led to efficient optimization of move predictions in the game of Go. The MCTS algorithm was shown to be a powerful policy improvement and policy evaluation method.
One of the main differences between our method and previous work is the addition of a value network with reinforcement learning methods to jointly update the policy and value networks for neural machine translation. The only papers we found that did this were Bahdanau et al. (2017) and the image captioning paper Ren et al. (2017), where the authors used an actor-critic method similar to A2C. In contrast, we use MCTS, which we found to have some advantages over the method from Ren et al. (2017) in neural machine translation; we expand upon this in section 5.2. In Silver et al. (2017), MCTS is applied to the game of Go rather than to machine translation, the domain we are interested in. One of our contributions is showing that MCTS can be used successfully in this domain, outperforming a model trained with the actor-critic method from Ren et al. (2017). Another contribution is our investigation into where the benefits of MCTS come from (section 5.2), which may help guide future research.
4 Formal Description
Tree node attributes Each node contains a dictionary called ‘Edges’ that maps an action to an edge with the following attributes: $N$, the number of times the edge has been traversed; $W$, the sum of values backed up to this edge over all simulations; and $P$, the prior probability of taking this edge, which is assigned by the network when the node is expanded. Each node also has a ‘children’ dictionary that maps an action to a child tree node, and it stores the translation from the root to this node. The attributes $N$ and $W$ of each edge are initialized to 0.
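A minimal Python sketch of this node structure follows. The class and field names, the `expand`/`select`/`child` methods, and the `top_k` argument are assumptions for illustration. In particular, the selection rule shown is AlphaZero's PUCT formula with a hypothetical constant `c_puct`, since this section does not state the exact rule used in Algorithms 1-3.

```python
import math
from dataclasses import dataclass, field
from typing import Dict, List, Sequence


@dataclass
class Edge:
    prior: float              # P: prior probability from the policy at expansion time
    visits: int = 0           # N: times this edge has been traversed
    value_sum: float = 0.0    # W: sum of values backed up through this edge

    @property
    def mean_value(self) -> float:
        # Q = W / N, treated as 0 for an unvisited edge
        return self.value_sum / self.visits if self.visits else 0.0


@dataclass
class Node:
    translation: List[int]                                      # tokens from the root to this node
    edges: Dict[int, Edge] = field(default_factory=dict)        # action -> Edge
    children: Dict[int, "Node"] = field(default_factory=dict)   # action -> child Node

    def expand(self, priors: Sequence[float], top_k: int = 50) -> None:
        """Create edges only for the top_k actions ranked by prior probability."""
        ranked = sorted(range(len(priors)), key=lambda a: priors[a], reverse=True)
        for action in ranked[:top_k]:
            self.edges[action] = Edge(prior=priors[action])

    def select(self, c_puct: float = 1.0) -> int:
        """AlphaZero-style PUCT choice: Q + c_puct * P * sqrt(sum N) / (1 + N)."""
        total_visits = sum(e.visits for e in self.edges.values())

        def score(item):
            _, e = item
            return e.mean_value + c_puct * e.prior * math.sqrt(total_visits) / (1 + e.visits)

        # With no visits yet all scores tie, and max() returns the first-inserted
        # (highest-prior) edge.
        return max(self.edges.items(), key=score)[0]

    def child(self, action: int) -> "Node":
        """Return the child for `action`, creating it with the extended translation."""
        if action not in self.children:
            self.children[action] = Node(translation=self.translation + [action])
        return self.children[action]
```

For example, `Node([]).expand([0.1, 0.5, 0.05, 0.35], top_k=2)` keeps only actions 1 and 3, mirroring the top-50 pruning discussed below.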
Key differences from AlphaZero First, there are not two players, so each simulation is from a single perspective. Second, at each node we store only the top 50 edges ranked by the prior probability from our network. This means that once we obtain visitation frequencies over those edges, we do not want to update our policy network directly on probabilities proportional to the visitation frequencies: these probabilities sum to 1, yet during the simulations the algorithm never had the possibility of visiting the other 6515 branches (the 6565 output tokens minus the 50 stored edges). To compensate, we multiply each probability in ‘simVisitationProbs’ by the sum of the prior probabilities of the 50 edges leaving that node, and then calculate the cross-entropy loss only on these 50 actions. We presume this does not drastically change the probabilities of the actions that the tree search never had the chance to visit. The lines of the pseudo-code that correspond to these differences are highlighted in yellow.
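The target adjustment above can be sketched as follows, assuming visit counts and priors are stored in dictionaries keyed by action over the kept edges. The function names are hypothetical and this is not the authors' Algorithm 3; it only illustrates rescaling the visitation probabilities by the stored prior mass and restricting the cross-entropy to the stored actions.

```python
import math
from typing import Dict


def scaled_visit_targets(visits: Dict[int, int], priors: Dict[int, float]) -> Dict[int, float]:
    """Turn visit counts over the stored edges into training targets.

    `visits` and `priors` cover only the top-k stored edges. The visit
    distribution sums to 1 over those edges, so it is rescaled by the prior
    mass those edges held, leaving the rest for the unstored actions.
    """
    total_visits = sum(visits.values())
    stored_mass = sum(priors.values())
    return {a: (n / total_visits) * stored_mass for a, n in visits.items()}


def partial_cross_entropy(targets: Dict[int, float], log_probs: Dict[int, float]) -> float:
    """Cross-entropy computed only over the actions the tree search could visit."""
    return -sum(t * log_probs[a] for a, t in targets.items())
```

For instance, with priors `{0: 0.4, 1: 0.2}` and visits `{0: 30, 1: 70}`, the targets are approximately `{0: 0.18, 1: 0.42}`, which sum to the stored mass of 0.6 rather than to 1, so the loss leaves room for probability on actions the search never explored.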
In this section, we conduct experiments to gauge the effectiveness of using MCTS as a policy improvement operator in neural machine translation as well as to explore where the benefits of the MCTS stem from.
5.1 Dataset, Implementation, and Results
Dataset The dataset we use for our experiments is the IWSLT14 German-to-English translation dataset, with 149184 training, 6784 validation, and 6400 test sentences. Our evaluation metric is BLEU as calculated by SacreBLEU Post (2018), one of the main metrics used to compare machine translation models. Tokenization is done with the preprocessing script that FairSeq provides for this dataset.
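SacreBLEU is the scorer actually used here. As a rough illustration of what BLEU measures (and of the reward that the MCTS backs up), the following is a simplified, unsmoothed, single-reference, sentence-level BLEU over pre-tokenized input. It is a sketch for intuition with hypothetical function names, not a substitute for SacreBLEU, which computes corpus-level scores with its own standardized tokenization.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    """Count all n-grams of length n in a token sequence."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def simple_bleu(hyp: List[str], ref: List[str], max_n: int = 4) -> float:
    """Unsmoothed sentence BLEU (0-100) against a single reference."""
    if not hyp:
        return 0.0
    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        hyp_counts, ref_counts = ngrams(hyp, n), ngrams(ref, n)
        total = sum(hyp_counts.values())
        if total == 0:
            return 0.0  # hypothesis too short to contain any n-gram of this order
        # Clip each hypothesis n-gram count by its count in the reference.
        matches = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        if matches == 0:
            return 0.0  # without smoothing, any zero precision makes BLEU zero
        log_precision_sum += math.log(matches / total)
    # Brevity penalty punishes hypotheses shorter than the reference.
    bp = 1.0 if len(hyp) > len(ref) else math.exp(1 - len(ref) / len(hyp))
    return 100.0 * bp * math.exp(log_precision_sum / max_n)
```

Without smoothing, any hypothesis with no matching 4-gram (including any hypothesis shorter than four tokens) scores 0; smoothed BLEU variants exist to avoid this.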
Implementation We used the transformer architecture from Vaswani et al. (2017) for the policy network but reduced its size to allow for faster computation and training. The layer sizes are as follows: the word embedding dimension is 128, the number of attention heads is 8, the number of encoder and decoder layers is 4, and the dimension of the feedforward network is 512. The value network uses the same architecture except that its output layer maps to a single scalar instead of to the size of the target vocabulary. The supervised training of the policy was done using cross-entropy loss with teacher forcing and the Adam optimizer with the learning rate schedule from Vaswani et al. (2017). This model is referred to below as ‘Supervised Policy’. The initial training of the value network was done in the same manner as in Ren et al. (2017), which we describe in the appendix. We used a dropout of 0.2 in the encoder and decoder layers of our models during training. As one of our baselines, we used the ‘Supervised Policy’ updated with policy gradients estimated through REINFORCE, as done in Wu et al. (2018), which achieved SOTA performance in 2018. We refer to this model as ‘Policy+RL’. We then jointly trained the supervised policy and value network using an algorithm similar to A2C, as done in Ren et al. (2017). We refer to this model as ‘Actor-Critic’. Further details on the training of ‘Policy+RL’ and ‘Actor-Critic’ are in Appendix 8.1. Finally, we also jointly trained the policy and value network using MCTS as described in Algorithms 1-3. Further specifics of the implementation are in Appendix 8.2. The best hyperparameters we found were , and 100 rollouts per action. Because of computational constraints, the number of rollouts per action was not increased past 100. We would expect improved performance if this number were increased, since it would lead to lower-variance updates to the model.
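The stated model sizes and the learning rate schedule can be summarized in code as follows. This is a sketch: the configuration class and function names are hypothetical, and the warmup length of 4000 steps is assumed from Vaswani et al. (2017) because the text does not give it. The value network would share these sizes but end in a single scalar output rather than a vocabulary-sized softmax.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SmallTransformerConfig:
    # Sizes stated in the text for the reduced transformer.
    embed_dim: int = 128
    attention_heads: int = 8
    encoder_layers: int = 4
    decoder_layers: int = 4
    ffn_dim: int = 512
    dropout: float = 0.2
    # Assumption: warmup length is not stated; 4000 is the value from Vaswani et al. (2017).
    warmup_steps: int = 4000


def inverse_sqrt_lr(step: int, cfg: SmallTransformerConfig) -> float:
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), as in Vaswani et al. (2017)."""
    step = max(step, 1)  # guard against step 0
    return cfg.embed_dim ** -0.5 * min(step ** -0.5, step * cfg.warmup_steps ** -1.5)
```

The rate rises linearly for `warmup_steps` steps and then decays in proportion to the inverse square root of the step number, so at four times the warmup length it is half of its peak.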
Results We compare the results of the models below. For each model, after training we take the policy from that model and decode the test set greedily to get the SacreBLEU Post (2018) score. We use greedy decoding rather than MCTS at test time because running our MCTS on the entire test set would take a large amount of time.
In our experiments, we were not able to improve the policy using the actor-critic method of Ren et al. (2017). The actor-critic algorithms in Bahdanau et al. (2017); Ren et al. (2017) required several tweaks to converge on a quality model, such as target networks that are updated only periodically and curriculum training. We did not implement these additions, which could be why our actor-critic model did not improve past the ‘Policy+RL’ model.
5.2 Where do the benefits of MCTS come from?
We now explore where the benefits of MCTS over these baseline models come from. We investigate two plausible sources of these benefits.
Is it simply that we run so many simulations at each step, so that a baseline model could match MCTS just by increasing its batch size? In other words, are we getting a lower-variance gradient estimate purely from increased computation?
To perform a fair comparison, we increased the batch size for the policy gradient method until the time to process a batch matched the time per batch of our MCTS, which varied but was around 22 minutes. For MCTS this corresponds to 256 sentences, and for the policy gradient method it corresponds to 27200 sentences. This is a rough comparison, since the MCTS could be sped up significantly, as discussed in the next section. Even with the batch size increased to 27200 sentences, the validation BLEU score for ‘Policy+RL’ was 27.79, lower than the BLEU of 28.37 obtained by training the policy with MCTS. This suggests that the advantage of MCTS over the ‘Policy+RL’ model does not come solely from the increased computation.
Is it the addition of the value network that is improving the performance over the ’Policy + RL’ model?
To test this, we modified the MCTS algorithm to run without a value network, backing up values only from states that correspond to a translation ending in an end-of-sentence (EOS) token. The value backed up was the BLEU score of the translation, and the visitation counts of the edges on the path were incremented only during these backups. This keeps the comparison fair, since we did the same during training when a value network is present. Results for the policy obtained with this modified training algorithm are in Appendix 8.3: it achieves better performance than ‘Policy + RL’ (BLEU of 28.26 vs 27.78) and performance similar to the version of MCTS that uses a value network (BLEU of 28.26 vs 28.37). We hypothesize that this improvement arises because the MCTS algorithm updates only one state at a time, and the visitation frequencies used to update the policy are guided both by accurate Q-values and by our prior probabilities. We expect this to yield low-variance gradients, whereas our policy gradients estimated with REINFORCE try to update the probabilities of many states using only a sparse reward for the final state. These gradient estimates have high variance because many of the sentences contain more than 40 tokens, which makes it hard for the algorithm to determine the effect each word in the sequence had on the reward.
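The policy-only backup rule can be sketched as below. The `Edge` type, the function name, and the `<eos>` token string are hypothetical; the sketch only shows that an unfinished rollout backs up nothing, while a finished one adds its BLEU score to $W$ and increments $N$ on every edge of the path.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Edge:
    visits: int = 0         # N
    value_sum: float = 0.0  # W


def backup_terminal_only(path: List[Edge], translation: List[str], bleu: float,
                         eos: str = "<eos>") -> bool:
    """Back up a BLEU score along the path only if the rollout finished the sentence.

    Returns True if a backup happened. Visit counts are incremented only here,
    so unfinished rollouts leave the tree statistics untouched.
    """
    if not translation or translation[-1] != eos:
        return False  # unfinished translation: no value available without a value network
    for edge in path:
        edge.visits += 1
        edge.value_sum += bleu
    return True
```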
6 Limitations and Future Improvement
Computational Cost Running this algorithm requires substantial computational resources and time. For example, translating 16 sentences of length 19 took approximately 2.5 minutes using only 100 simulations per step. This means it would be quite time-consuming to train a policy and value network from scratch with our implementation. Instead, we found that starting from pre-trained policy and value networks, we can quickly improve these networks within a couple of gradient steps, which makes the approach computationally feasible. The main change that we think would improve the speed of our algorithm is more efficient inter-process communication. Currently, most of the time is spent by processes waiting for results from the model. For example, each time a node in a tree needs to be expanded, the process in charge of that tree sends the current state to a process that runs the neural network on that state and then returns the results. The model itself currently takes about one tenth of the time that the requesting process spends waiting for the results.
Search Diversity In the case of machine translation, there were 6565 tokens we could output at any one step. This large action space would make the memory storage of each node very large (each array would have size 6565), so we stored only a small subset of these actions as children of a tree node. In our case, since the model was initialized with pre-trained policy and value networks, the policy had already narrowed down the good actions at each step reasonably well through the probability mass it assigned to each word. This allowed us to use only the top 50 children of each state based on the prior policy probabilities. This is not a perfect solution, since at times the true word will fall outside the top 50 predicted by the policy, and our algorithm relies on this not happening too often. Our current implementation could not train from scratch without some modifications, since we discard the majority of the possibilities at each step. Even with modifications, training from scratch would be difficult because the massive action space would be hard to explore fully. It might be possible to use a computational oracle constructed from the ground truth to better select the candidate words at each tree expansion.
In the training phase, we have access to the ground-truth translation tokens at each time step, but the gold translations are not used until we simulate an entire sentence to compute BLEU scores. This can cause undesirable paths to be taken during MCTS while the model is training. There is potential to use imitation learning techniques that interpolate between pure supervision and reinforcement learning to better guide MCTS training of the policy along simulation paths.
In this project, we presented a modification of the MCTS algorithm from AlphaZero Silver et al. (2017) to jointly train policy and value networks in neural machine translation. We compared the performance of this algorithm with a model updated using estimated policy gradients and with a model that updates value and policy networks with an actor-critic method similar to A2C. Our results on the IWSLT14 dataset showed increases in test set BLEU score over both of these methods. This suggests that many current SOTA models that use only a policy, such as Edunov et al. (2018), could likely be improved by updating the policy with our MCTS algorithm. This could be done either with our modified MCTS, which uses only a policy, or with our original MCTS, in which case a value network must first be pre-trained for that policy. The algorithm appears to be a promising alternative to other actor-critic methods such as those used in Bahdanau et al. (2017); He et al. (2017). Lastly, our experiments were confined to language translation, but it may be possible to apply this method to other sequence prediction tasks such as image captioning, which is closely related to neural machine translation.
- Bahdanau et al. (2017). An actor-critic algorithm for sequence prediction. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings.
- Edunov et al. (2018). Understanding back-translation at scale. CoRR abs/1808.09381.
- He et al. (2017). Decoding with value networks for neural machine translation. In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, 4-9 December 2017, Long Beach, CA, USA, pp. 178–187.
- Post (2018). A call for clarity in reporting BLEU scores. In Proceedings of the Third Conference on Machine Translation: Research Papers, Belgium, Brussels, pp. 186–191.
- Ren et al. (2017). Deep reinforcement learning-based image captioning with embedding reward. pp. 1151–1159.
- Silver et al. (2017). Mastering the game of Go without human knowledge. Nature 550, pp. 354–.
- Vaswani et al. (2017). Attention is all you need. CoRR abs/1706.03762.
- Williams (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8 (3), pp. 229–256.
- Wu et al. (2018). A study of reinforcement learning for neural machine translation. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, Brussels, Belgium, October 31 - November 4, 2018, pp. 3612–3621.
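In the model above called ‘Policy + RL’, the policy is updated using estimates of the policy gradient obtained through REINFORCE Williams (1992). To implement this, we first simulate a translation $\hat{y}_1, \dots, \hat{y}_T$ of length $T$ for the input sentence $x$, which gives us states $s_1, \dots, s_T$ during the translation (where $s_t$ consists of $x$ and the tokens generated before step $t$) and a final BLEU score $b$. Let $\theta$ denote the parameters of the policy $\pi_\theta$. The gradient of the expected reward with respect to the policy parameters is then estimated as
$$\nabla_\theta J(\theta) \approx b \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(\hat{y}_t \mid s_t).$$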
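To make the estimator concrete, the following sketch applies it to a toy softmax policy and differentiates with respect to each step's logits instead of the network parameters $\theta$ (backpropagating into $\theta$ would follow by the chain rule). All names are hypothetical. It uses the fact that the gradient of $\log \mathrm{softmax}(z)_a$ with respect to $z$ is $\mathrm{onehot}(a) - \mathrm{softmax}(z)$, and it makes visible that every step is scaled by the same sentence-level reward $b$.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_logit_grads(step_logits: List[List[float]], actions: List[int],
                          bleu: float) -> List[List[float]]:
    """Gradient of b * sum_t log pi(a_t | s_t) with respect to each step's logits."""
    grads = []
    for logits, a in zip(step_logits, actions):
        probs = softmax(logits)
        # d log softmax(z)_a / dz_k = 1[k == a] - softmax(z)_k, scaled by the shared reward b
        grads.append([bleu * ((1.0 if k == a else 0.0) - p) for k, p in enumerate(probs)])
    return grads
```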
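In the model ‘Actor-Critic’, the policy and value network are jointly updated as in Ren et al. (2017). The algorithm works as follows:
1. We simulate a translation $\hat{y}_1, \dots, \hat{y}_T$ of length $T$ for the input sentence, which gives us states $s_1, \dots, s_T$ during the translation and a final BLEU score $b$.
2. Let $\theta$ be the parameters of the policy $\pi_\theta$ and $\phi$ the parameters of the value network $V_\phi$. The gradient estimate for the policy is
$$\nabla_\theta J(\theta) \approx \sum_{t=1}^{T} \big(b - V_\phi(s_t)\big)\, \nabla_\theta \log \pi_\theta(\hat{y}_t \mid s_t),$$
and the update direction for the value network is
$$\sum_{t=1}^{T} \big(b - V_\phi(s_t)\big)\, \nabla_\phi V_\phi(s_t),$$
which is the negative gradient of $\tfrac{1}{2}\sum_{t=1}^{T} \big(b - V_\phi(s_t)\big)^2$.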
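The same toy setting, extended with per-state value predictions, sketches the actor-critic update. Names are hypothetical, and gradients are again taken with respect to the logits and the scalar value outputs rather than $\theta$ and $\phi$. Each step is now weighted by its own advantage $b - V_\phi(s_t)$ instead of by $b$.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def actor_critic_grads(step_logits: List[List[float]], actions: List[int],
                       values: List[float], bleu: float) -> Tuple[List[List[float]], List[float]]:
    """Return (policy logit gradients, value-output update directions) for one translation."""
    policy_grads, value_grads = [], []
    for logits, a, v in zip(step_logits, actions, values):
        advantage = bleu - v  # b - V(s_t)
        probs = softmax(logits)
        policy_grads.append([advantage * ((1.0 if k == a else 0.0) - p)
                             for k, p in enumerate(probs)])
        # Descent direction on 0.5 * (b - V)^2 with respect to the scalar output V(s_t)
        value_grads.append(advantage)
    return policy_grads, value_grads
```

When the value prediction equals the final BLEU, the advantage is zero and that step contributes no policy gradient.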
8.2 Additional Implementation Details
Initial Value Network Training To initially train the value network, we used the same training scheme as in Ren et al. (2017): for each input sentence $x$, a translation $\hat{y} = (\hat{y}_1, \dots, \hat{y}_T)$ is simulated using the trained policy with greedy decoding, and its BLEU score $b$ is computed. Then an index $t$ is picked uniformly within this translation, and the prefix $\hat{y}_{1:t}$ is used as input to the value network decoder, where the squared error $\big(V_\phi(x, \hat{y}_{1:t}) - b\big)^2$ is minimized. Not every prefix of the translation is used, because successive prefixes (and hence the states they correspond to) are strongly correlated, and updating the model on similar states in succession could lead to over-fitting.
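The sampling scheme can be sketched as follows. The function names are hypothetical, and drawing the prefix length from $1, \dots, T$ (never an empty prefix) is an assumption about what "uniformly picked within this translation" means.

```python
import random
from typing import List, Tuple


def sample_value_example(source: List[str], translation: List[str], bleu: float,
                         rng: random.Random) -> Tuple[List[str], List[str], float]:
    """Pick one prefix of the simulated translation uniformly at random.

    Only one prefix per sentence is returned because successive prefixes
    correspond to highly correlated states.
    """
    t = rng.randint(1, len(translation))  # assumption: prefix lengths 1..T, both inclusive
    return source, translation[:t], bleu


def value_loss(predicted_value: float, bleu: float) -> float:
    """Squared error between the value network's prediction and the final BLEU."""
    return (predicted_value - bleu) ** 2
```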
We processed batches of 64 sentences at a time and created one process per sentence in the batch so that the tree searches for the sentences could run in parallel. One main process controls the model. Each time a rollout in one of the trees reaches a leaf node and needs to expand a state, it sends the state tensor to the main process. This lets the neural network run on a batch of state tensors and then send the expansion results back to each of these processes. Inter-process communication took the majority of the time, since we were able to write the rollout code in optimized Cython, which sped up the rollouts themselves drastically.
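The batching pattern can be sketched in Python as below. To keep the sketch self-contained and runnable, threads stand in for the separate processes described above, integer states stand in for state tensors, and `model_fn` stands in for the batched neural network call; all of these names are hypothetical.

```python
import queue
import threading
from typing import Callable, List, Sequence, Tuple


def run_batched_expansions(
    states_per_worker: Sequence[Sequence[int]],
    model_fn: Callable[[List[int]], List[int]],
    max_batch: int = 64,
) -> Tuple[List[List[int]], List[int]]:
    """Serve expansion requests from many tree-search workers in batches.

    Each worker sends one state at a time and blocks until its result arrives;
    the main loop groups whatever requests are pending into one model call.
    """
    request_q = queue.Queue()
    response_qs = [queue.Queue() for _ in states_per_worker]
    results: List[List[int]] = [[] for _ in states_per_worker]

    def worker(wid: int) -> None:
        for state in states_per_worker[wid]:
            request_q.put((wid, state))                  # a leaf needs expanding
            results[wid].append(response_qs[wid].get())  # wait for the model's answer

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(len(states_per_worker))]
    for t in threads:
        t.start()

    total = sum(len(s) for s in states_per_worker)
    served, batch_sizes = 0, []
    while served < total:
        batch = [request_q.get()]        # block until at least one request arrives
        while len(batch) < max_batch:    # then drain whatever else is already queued
            try:
                batch.append(request_q.get_nowait())
            except queue.Empty:
                break
        outputs = model_fn([state for _, state in batch])  # one batched forward pass
        for (wid, _), out in zip(batch, outputs):
            response_qs[wid].put(out)
        served += len(batch)
        batch_sizes.append(len(batch))

    for t in threads:
        t.join()
    return results, batch_sizes
```

Each worker has at most one outstanding request, so the main loop never waits on a request that will not arrive; the batch size it achieves depends on how many workers happen to be waiting at the same time.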
From each of these sentences, training data (input sentence, current translation, MCTS visitation probabilities, BLEU) were gathered as described in the algorithm section and in Figure 1. To update the model, we first processed 4 batches (256 sentences) and then performed 8 iterations of the following: randomly draw 256 samples from the data collected while processing these batches, run the model to get the loss described in Algorithm 3, backpropagate, and take a gradient step.
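The update schedule can be sketched as follows. `train_step` is a hypothetical callback standing in for running the model, computing the loss from Algorithm 3, backpropagating, and taking an optimizer step; drawing each minibatch without replacement from the collected data is an assumption.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

# One training example: (source tokens, current translation, visit probabilities, BLEU).
Example = Tuple[List[str], List[str], List[float], float]


def update_from_buffer(
    buffer: Sequence[Example],
    train_step: Callable[[List[Example]], float],
    iterations: int = 8,
    minibatch: int = 256,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Run `iterations` gradient steps, each on a random minibatch drawn from the buffer."""
    rng = rng or random.Random()
    pool = list(buffer)
    losses = []
    for _ in range(iterations):
        # Without replacement within a minibatch; the same data is reused across iterations.
        batch = rng.sample(pool, min(minibatch, len(pool)))
        losses.append(train_step(batch))  # hypothetical: forward, loss, backprop, optimizer step
    return losses
```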
8.3 Is a Shared Architecture Necessary? Is the Value Network Necessary?
We ran several experiments to see whether a shared architecture for the policy and value networks would lead to performance benefits, as it did in Silver et al. (2017). To do this, we used the policy encoder as the encoder for both the policy and value networks. We hypothesized that the encoder's job is similar in nature for the value and policy networks, so a shared architecture could help performance: the encoder learns to create embeddings of the input sentences that are useful both for learning the long-term value of choosing the next word and for learning what the probability of choosing that next word should be. To test this hypothesis, we ran several experiments in which everything was kept constant except whether the encoder was shared.
Another question we investigated was whether the value network is necessary at all. We describe how we modified the MCTS algorithm to use only a policy in section 5.2. This model is ‘No Value’ in the table below.
To investigate both of these questions, we ran several experiments in which the hyperparameter is varied while the temperature, the number of rollouts per action, and the Adam learning rate are held fixed. Each run took close to 5 hours per hyperparameter setting, so we tested only a few settings. The validation set BLEU scores of these experiments are as follows:
|Joint Network||Disjoint Network||No Value|
The results for shared and disjoint architectures look very similar, and more extensive experiments will be required to determine whether there is a true difference. The results from running the MCTS algorithm with only a policy also look promising, with only slightly worse performance than when a value network is included. These results suggest that there may be some benefit in including the value network, but more extensive experimentation would be required to draw firm conclusions.