
Learning Hard Alignments with Variational Inference

by   Dieterich Lawson, et al.

There has recently been significant interest in hard attention models for tasks such as object recognition, visual captioning and speech recognition. Hard attention can offer benefits over soft attention such as decreased computational cost, but training hard attention models can be difficult because of the discrete latent variables they introduce. Previous work used REINFORCE and Q-learning to approach these issues, but those methods can provide high-variance gradient estimates and be slow to train. In this paper, we tackle the problem of learning hard attention for a sequential task using variational inference methods, specifically the recently introduced VIMCO and NVIL. Furthermore, we propose a novel baseline that adapts VIMCO to this setting. We demonstrate our method on a phoneme recognition task in clean and noisy environments and show that our method outperforms REINFORCE, with the difference being greater for a more complicated task.



1 Introduction

Attention models have gained widespread traction from their successful use in tasks such as object recognition, machine translation, and speech recognition, where they are used to integrate information from different parts of the input before producing outputs. Soft attention does this by weighting and combining all input elements into a context vector, while hard attention selects specific inputs and discards others, leading to computational gains and greater interpretability. While soft attention models are differentiable end-to-end and thus easy to train, hard attention models introduce discrete latent variables that often require reinforcement-learning-style approaches.

Classic reinforcement learning methods such as REINFORCE [1] and Q-learning [2] have been used to train hard attention models, but these methods can provide high-variance gradient estimates, making training slow and providing inferior solutions. An alternative to reinforcement learning is variational inference, which trains a second model, called the approximate posterior, to be close to the true posterior over the latent variables. The approximate posterior uses information about both the input and its labels to produce settings of the latent variables used to train the original model. This can provide lower-variance gradient estimates and better solutions.

In this paper, we leverage recent developments in variational inference to fit hard attention models in a sequential setting. We specialize these methods to sequences and develop a model for the approximate posterior. In response to issues applying variational inference techniques to long sequences, we develop new variance control methods. Finally, we show experimentally that our approach improves performance and substantially reduces training time for speech recognition on the TIMIT dataset as well as on a challenging noisy, multi-speaker version of TIMIT that we call Multi-TIMIT.

2 Methods

Figure 3: A diagram of our models: (a) the model and (b) the approximate posterior. The $b_t$ denote the Bernoulli emission decision variables, the $x_t$ are inputs, the $y_t$ are targets, and the remaining nodes are the hidden states of the recurrent neural networks (RNNs) that parameterize the conditional distributions of the models. Square nodes are deterministic, round nodes are stochastic. A shaded $b_t$ indicates that the model chose to consume an input and not emit an output, while an unshaded $b_t$ means that the model chose to produce an output and not consume an input. For example, in (a) note that $b_1$ is shaded, so the model did not produce an output on timestep 1 and instead consumes the input $x_2$ on the next timestep. $b_2$ is unshaded, so on the second timestep the model produced output $y_1$.

2.1 Model

In this paper we use the online sequence-to-sequence model described in [3] to demonstrate our methods. We model $p(y \mid x)$, where $y$ is a sequence of observed target tokens and $x$ is a sequence of observed inputs. The Bernoulli latent variables $b$ define when the model outputs tokens, i.e. $b_t = 1$ implies the model emitted a token at timestep $t$, and $b_t = 0$ implies the model did not emit a token at timestep $t$. If $b_t = 1$, the model is forced to dwell on the same input at the next time step, i.e. the observation fed in at timestep $t$ is fed in again at timestep $t+1$ when $b_t = 1$. Let $T_y$ be the number of target tokens, $T_x$ the number of inputs, and $T$ the number of steps the model is run for. Our model assumes $p(y, b \mid x)$ factorizes as

$$p(y, b \mid x) = \prod_{t=1}^{T} p\big(b_t \mid x_{1:I(t)}, b_{1:t-1}, y_{1:O(t-1)}\big)\, p\big(y_{O(t)} \mid x_{1:I(t)}, b_{1:t}, y_{1:O(t-1)}\big)^{b_t} \tag{1}$$

where $O(t) = \sum_{t'=1}^{t} b_{t'}$ is the position in the output at time $t$ and $I(t) = 1 + \sum_{t'=1}^{t-1} (1 - b_{t'})$ is the input position at time $t$. Intuitively, this expression is the product over time of the probability assigned to the current ground truth given that the model emitted, multiplied by the probability that the model emitted. When $b_t = 0$ the model did not emit at time $t$, so there is no probability assigned to the ground truth on that timestep. For brevity, we will use $y_t$ to implicitly mean $y_{O(t)}$ (i.e., the target at step $t$). Similarly, we will refer to $x_{I(t)}$ as $x_t$, and similarly for ranges over time for these variables.
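The bookkeeping behind this factorization can be sketched in a few lines of Python. This is a minimal illustration rather than the authors' code: it assumes the convention that an emission decision of 1 means "emit a token and stay on the same input" and 0 means "emit nothing and advance to the next input". The function names `positions` and `log_joint`, and the per-step log-probability lists, are hypothetical stand-ins for quantities the RNN would produce.

```python
def positions(b):
    """Return (O, I): 1-indexed output and input positions for each timestep.

    Assumed convention: b[t] = 1 emits a token and dwells on the same input,
    b[t] = 0 emits nothing and moves on to the next input.
    """
    O, I = [], []
    emitted, consumed = 0, 0
    for bt in b:
        I.append(consumed + 1)   # input read at this step
        emitted += bt
        O.append(emitted)        # tokens emitted up to and including this step
        consumed += 1 - bt       # advance the input only when nothing was emitted
    return O, I


def log_joint(b, log_p_emit, log_p_token):
    """Log of the factorized joint for one trajectory.

    log_p_emit[t]  : log-probability the model gave to the decision b[t]
    log_p_token[t] : log-probability of the current target, counted only if b[t] = 1
    """
    return sum(le + bt * lt for bt, le, lt in zip(b, log_p_emit, log_p_token))
```

For the decisions `[0, 1, 1, 0, 1]`, `positions` returns output positions `[0, 1, 2, 2, 3]` and input positions `[1, 2, 2, 2, 3]`: the first step consumes an input without emitting, and the second step emits the first token while reading the second input.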

2.2 Learning

To fit the model (1) with maximum likelihood we are concerned with maximizing the probability of the observed variables $y$ given $x$. However, (1) is written in terms of the unobserved latents, $b$, so we must marginalize over them. We maximize

$$\mathbb{E}_{b}\Big[\sum_{t=1}^{T} b_t \log p(y_t \mid s_t, b_t)\Big] \le \log \mathbb{E}_{b}\Big[\prod_{t=1}^{T} p(y_t \mid s_t, b_t)^{b_t}\Big] = \log p(y \mid x)$$

where $s_t = (x_{1:t}, b_{1:t-1}, y_{1:t-1})$ is the state at time $t$ and the expectations are over $b_t \sim p(b_t \mid s_t)$. Note that this is a lower bound on the log probability of the observed $y$, so maximizing this bound will hopefully increase the likelihood of the observed data. Differentiating this objective gives

$$\mathbb{E}_{b}\Big[\sum_{t=1}^{T} \Big( b_t \nabla \log p(y_t \mid s_t, b_t) + R_t \nabla \log p(b_t \mid s_t) \Big)\Big] \tag{2}$$

where $R_t = \sum_{t'=t}^{T} b_{t'} \log p(y_{t'} \mid s_{t'}, b_{t'})$ is the return at timestep $t$, understood intuitively as the log probability the model assigns to observed data for a given series of emission decisions. The first gradient term can be estimated with a single Monte Carlo sample, but the second term exhibits high variance because it involves an unbounded log probability. To reduce variance, [3] subtracts a learned baseline from the return, which does not change the expectation as long as it is independent of $b_t$.
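To make the return concrete, the sketch below computes a return-to-go for every timestep of one sampled trajectory and subtracts a baseline from it. It is an illustration under assumptions, not the authors' implementation: the reward at a step is taken to be the target's log-probability when the model emits and zero otherwise, and the names `returns` and `learning_signal` are hypothetical.

```python
def returns(b, log_p_token):
    """Return-to-go at each step: sum of b[t'] * log_p_token[t'] for t' >= t."""
    out, running = [], 0.0
    for bt, lt in zip(reversed(b), reversed(log_p_token)):
        running += bt * lt       # a step contributes reward only if it emitted
        out.append(running)
    return out[::-1]


def learning_signal(R, baseline):
    """Centered return that multiplies the score of each emission decision."""
    return [r - c for r, c in zip(R, baseline)]
```

With decisions `[0, 1, 1]` and token log-probabilities `[-9.0, -1.0, -2.0]`, the returns are `[-3.0, -3.0, -2.0]`: the first step emitted nothing, so its own log-probability is ignored and it inherits the reward of the later emissions.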

Performing stochastic gradient ascent with this gradient estimator is the standard REINFORCE algorithm where the reward is the log-likelihood. Unfortunately, this requires sampling $b$ from the model $p$ during training, which can lead to gradient estimates with high variance when settings of $b$ that assign high likelihood to $y$ are rare [4]. Variational inference is a family of techniques that use importance sampling to instead sample from a different model, called the approximate posterior or $q$, which approximates the true posterior over $b$, $p(b \mid x, y)$. We factorize the approximate posterior as

$$q(b \mid x, y) = \prod_{t=1}^{T} q\big(b_t \mid x_{1:T_x}, y_{1:T_y}, b_{1:t-1}\big) \tag{3}$$

The approximate posterior has access to all past and future $x$ and $y$, as well as past $b$, and leverages this information to assign high probability to $b$ that produce large values of $p(y, b \mid x)$. Intuitively, in speech recognition, knowing the token the model must emit is helpful in deciding when to emit.

Using $q$ and an importance sampling identity we obtain a lower bound on the log-likelihood

$$\log p(y \mid x) = \log \mathbb{E}_{b \sim q}\left[\frac{p(y, b \mid x)}{q(b \mid x, y)}\right] \ge \mathbb{E}_{b \sim q}\big[\log p(y, b \mid x) - \log q(b \mid x, y)\big] \tag{4}$$

where we can simultaneously optimize $q$ and the parameters of the model $p$ to improve the lower bound. Optimizing this bound via stochastic gradient ascent can be thought of as training $p$ with maximum likelihood to reproduce $b$s sampled from $q$. $q$ is then updated with REINFORCE-style gradients where the reward is the log-probability $p$ assigns to $y$ given $b$, similar to (2); see [4] for details. Setting $q = p$ recovers the REINFORCE objective.

2.2.1 Multi-sample Objectives

Both the REINFORCE and the variational inference objectives admit multi-sample versions that give tighter bounds on the log-likelihood [5]. In particular, the multi-sample variational lower bound is

$$\log p(y \mid x) \ge \mathbb{E}_{b^{1:K} \sim q}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(y, b^k \mid x)}{q(b^k \mid x, y)}\right] \tag{5}$$

where $K$ is the number of samples and $b^k$ denotes the $k$th sample of the latent variables. Setting $q = p$ recovers the multi-sample analogue to REINFORCE.

The gradient of (5) takes a similar form to (2), with one low-variance term and one REINFORCE-style term with high variance; for details see [4]. Similarly to the REINFORCE objective, we can use a baseline to reduce the variance of the gradient as long as it does not depend on $b_t^k$. Notably, the baseline for trajectory $k$ is allowed to depend on all timesteps of other trajectories, i.e. $b_{1:T}^j$ for $j \neq k$.
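As a rough numerical illustration of the multi-sample bound, the sketch below evaluates the log of the average importance weight for a set of sampled trajectories, using the log-sum-exp trick for stability. It assumes the per-trajectory log-probabilities under the model and under the approximate posterior have already been computed; the function name `multi_sample_bound` and its arguments are hypothetical.

```python
import math


def multi_sample_bound(log_p_joint, log_q):
    """Log of the mean importance weight over K sampled trajectories.

    log_p_joint[k] : log-probability of targets and decisions under the model
    log_q[k]       : log-probability of the decisions under the posterior
    """
    log_w = [lp - lq for lp, lq in zip(log_p_joint, log_q)]
    m = max(log_w)               # subtract the max before exponentiating
    total = sum(math.exp(w - m) for w in log_w)
    return m + math.log(total) - math.log(len(log_w))
```

With a single sample the value reduces to the one-sample estimate `log_p_joint[0] - log_q[0]`. By Jensen's inequality the multi-sample value is never below the average of the individual log-weights.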

2.3 Variance Reduction

Training these models is challenging due to high-variance gradient estimates. We can reduce the variance of the estimators by using information from multiple trajectories to construct baselines. In particular, for REINFORCE, we can write the gradient update as

$$\frac{1}{K} \sum_{k=1}^{K} \sum_{t=1}^{T} \Big( b_t^k \nabla \log p(y_t \mid s_t^k, b_t^k) + \big(R_t^k - c_t^k\big) \nabla \log p(b_t^k \mid s_t^k) \Big)$$

where $c_t^k$ is a baseline for sample $k$ that is a function of the $k$th trajectory’s state up to time $t$ as well as the returns produced by all other trajectories. The goal is to pick a $c_t^k$ that is a good estimate of the return $R_t^k$, and a straightforward choice of $c_t^k$ is the average return from the other samples

$$c_t^k = \frac{1}{K-1} \sum_{j \neq k} R_t^j$$

This ignores the fact that $s_t^j \neq s_t^k$, which can make this standard baseline unusable. For example, in our setting different trajectories may have emitted different numbers of tokens on a given timestep, resulting in substantial differences in return between trajectories that do not indicate the relative merit of those trajectories. Ideally, we would average over multiple trajectories starting from $s_t^k$, but this is computationally expensive. In [4] the authors propose the following baseline, which adds a residual term to address this. Let $r_t^k$ be the instantaneous reward at timestep $t$, so that $R_t^k = \sum_{t'=t}^{T} r_{t'}^k$; then the baseline at timestep $t$ can be written

$$c_t^k = \frac{1}{K-1} \sum_{j \neq k} R_1^j - \sum_{t'=1}^{t-1} r_{t'}^k$$

This baseline results in a learning signal $R_t^k - c_t^k$ that is the same across all timesteps, potentially increasing variance as all decisions in a trajectory are rewarded or punished together. We will call this the leave-one-out (LOO) baseline because the baseline for a given sample is constructed using an average of the return of the other samples. Note that VIMCO optimizes the multi-sample variational lower bound in equation (5) with the leave-one-out baseline, and NVIL optimizes the single-sample variational lower bound in equation (4) with a baseline that can be learned or computed from averages [6].
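The leave-one-out idea can be sketched as follows: each trajectory's total return is compared against the mean total return of the other trajectories drawn for the same utterance. This is a minimal sketch under that reading of the text; the function name `loo_baseline` is hypothetical and the returns are assumed to be plain floats, one per trajectory.

```python
def loo_baseline(total_returns):
    """For each sample, the mean total return of the other K - 1 samples."""
    K = len(total_returns)
    if K < 2:
        raise ValueError("a leave-one-out baseline needs at least two samples")
    total = sum(total_returns)
    # Removing sample k from the sum avoids recomputing K separate averages.
    return [(total - r) / (K - 1) for r in total_returns]
```

For total returns `[-3.0, -5.0, -4.0]` the baselines are `[-4.5, -3.5, -4.0]`, so the best trajectory gets a learning signal of `1.5`, the worst gets `-1.5`, and the middle one gets `0.0`. Every decision within a trajectory shares that single signal, which is the coupling described above.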

As the return strongly depends on the number of emitted tokens at time $t$, we can instead average the return of the other samples from when they have emitted the same number of tokens as sample $k$. In particular, let $t_j$ be the first timestep when sample $j$ has emitted the same number of tokens as sample $k$ at timestep $t$, then

$$c_t^k = \frac{1}{K-1} \sum_{j \neq k} R_{t_j}^j$$

We call this new baseline the temporal leave-one-out baseline because it takes into account the temporal reward structure of our setting. This baseline can be combined with the parametric baseline, and is applicable to both variational inference and REINFORCE objectives in single- and multi-sample settings. We explore the performance of these baselines empirically in the experiments section.
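A possible implementation of the temporal leave-one-out baseline is sketched below. Several details are assumptions rather than statements from the text: "tokens emitted at a timestep" is read as the number emitted before that step, trajectories that never reach the required count are skipped, and a baseline of zero is used when no other trajectory matches. The name `temporal_loo_baseline` is hypothetical.

```python
def temporal_loo_baseline(bs, Rs):
    """bs[k][t]: emission decisions, Rs[k][t]: returns-to-go. Gives c[k][t]."""
    K = len(bs)
    # first_step[j][n]: first step at which sample j has already emitted n tokens
    first_step = []
    for b in bs:
        first, emitted = {}, 0
        for t, bt in enumerate(b):
            first.setdefault(emitted, t)
            emitted += bt
        first_step.append(first)

    baselines = []
    for k, b in enumerate(bs):
        row, emitted = [], 0
        for bt in b:
            # Returns of the other samples, aligned by emitted-token count.
            matched = [Rs[j][first_step[j][emitted]]
                       for j in range(K)
                       if j != k and emitted in first_step[j]]
            row.append(sum(matched) / len(matched) if matched else 0.0)
            emitted += bt
        baselines.append(row)
    return baselines
```

Take two trajectories with decisions `[1, 0, 1]` and `[0, 1, 1]` that both earn a reward of -1 per emitted token, so their returns are `[-2.0, -1.0, -1.0]` and `[-2.0, -2.0, -1.0]`. The temporal baseline reproduces each trajectory's own returns, so the learning signal is zero everywhere. A per-timestep average would instead compare -1.0 with -2.0 at the second step, even though neither trajectory is better.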

3 Related Work

In this section we first highlight the relationship between our model and other models for attention. Tang et al. [7] proposed visual attention within the context of generative models, while Mnih et al. [8] proposed using recurrent models of visual attention for discriminative tasks. Subsequently, visual attention was used in an image captioning model [9]. These forms of attention use discrete variables for attention location. Recently, ‘soft-attention’ models were proposed for neural machine translation and speech recognition [10, 11]. Unlike the hard-attention models mentioned earlier, these models pay attention to the entire input and compute features by blending spatial features with an attention vector that is normalized over the entire input. Our paper is most similar to the hard attention models in that features at discrete locations are used to compute predictions. However, it differs from the above models in the training method: while the hard attention models use REINFORCE for training, we follow variational techniques. We also differ from the above models in the specific application: attention in our models is over temporal locations only, rather than visual and temporal locations. As a result, we additionally propose the temporal leave-one-out baseline.

Because we use hard attention, our model has parallels to prior work on online sequence-to-sequence models [12, 3]. The neural transducer model [12] can use either hard attention or a combination of hard attention with local soft attention. However, it explicitly splits the input sequences into chunks, and it is trained with an approximate maximum likelihood procedure that is similar to a policy search. The model of Luo et al. [3] is most similar to our model. Both models use the same architecture; however, while they use REINFORCE for training, we explore VIMCO for training the attention model. We also propose the novel temporal LOO baseline. A similar model with REINFORCE has also been used for training an online translation model [13] and for training Neural Turing Machines [14]. Our work would be equally valid for these domains, which we leave for future work.

There has also been work using reweighted wake-sleep to train sequential models. In [15], Ba et al. optimize a variational lower bound with the prior instead of using a variational posterior. In this work, we refer to this as REINFORCE to distinguish it from variational inference with an inference network. In [16] the authors revisit this topic, using reweighted wake-sleep to train similar models. Their algorithm makes use of an inference network but does not optimize a variational lower bound. Instead they optimize separate objectives for the model and the inference network that produce a biased estimate of the gradient of the log marginal likelihood.

Figure 4: Test set phoneme error rate (PER) curves for models trained with REINFORCE, NVIL, and VIMCO on the TIMIT dataset (left), the Multi-TIMIT mixing proportion dataset (middle), and sample emission decisions for different methods on a TIMIT utterance (right). We evaluated three independent trials for each method. VIMCO converged more quickly than REINFORCE on both datasets. Furthermore, the performance gap between REINFORCE and VIMCO increases with Multi-TIMIT. We hypothesize that because Multi-TIMIT is a more challenging task, having a strong approximation to the posterior lets the model draw attention to the correct positions. NVIL performed well on TIMIT, but struggled with the more challenging Multi-TIMIT (note that only a single trial performs reasonably).

4 Experiments

For our experiments we used the standard TIMIT phoneme recognition task. The TIMIT dataset has 3696 training utterances, 400 validation utterances, and 182 test utterances. The audio waveforms were processed into 25 ms frames of log mel filterbank spectrograms with a stride of 10 ms. Each frame had 40 mel frequency channels and one energy channel; deltas and accelerations of the features were appended to each frame. As a result each frame was a 123-dimensional input (41 static features plus their deltas and accelerations). The targets for each utterance were the sequence of phonemes. We used the 61 phoneme labels provided with TIMIT for training and decoding. To compute the phone error rate (PER) we collapsed the 61 phonemes to 39, as is standard on this task.
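For readers unfamiliar with the metric, PER is conventionally the edit distance between the collapsed reference and hypothesis phone sequences, divided by the reference length. The sketch below follows that convention; it is not taken from the paper, the function names are hypothetical, and the two-entry `collapse` map in the usage note is only an illustrative fragment of the standard 61-to-39 folding.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two label sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution or match
        prev = cur
    return prev[-1]


def phone_error_rate(ref, hyp, collapse=None):
    """PER in percent, after folding labels through an optional mapping."""
    collapse = collapse or {}
    ref = [collapse.get(p, p) for p in ref]
    hyp = [collapse.get(p, p) for p in hyp]
    return 100.0 * edit_distance(ref, hyp) / len(ref)
```

For example, with `collapse = {"ix": "ih", "ax": "ah"}`, the reference `["sil", "dh", "ix", "s"]` and hypothesis `["sil", "dh", "ih", "z"]` differ in one phone after folding, giving a PER of `25.0`.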


To model $p$ we used a 2-layer LSTM with 256 units in each layer. For the variational posterior $q$ we first processed the inputs $x$ with a 4-layer bidirectional LSTM and then fed the final layer’s hidden state into a 2-layer unidirectional LSTM along with the current target $y_t$ and the previous emission decision $b_{t-1}$. Each layer had 256 units. Note that in this case the approximate posterior does not have access to future targets at timestep $t$; in practice we found that giving $q$ access to targets far in the future did not improve performance.

We regularized the models with variational noise [18] and performed a grid search over the standard deviation of the noise. We also used L2 regularization and grid searched over the weight of the regularization.

Method                                                PER
REINFORCE with leave-one-out (LOO) baseline           20.5
NVIL with LOO baseline                                21.1
VIMCO with LOO baseline                               20.0
REINFORCE with temporal LOO baseline                  20.0
NVIL with temporal LOO baseline                       21.4
VIMCO with temporal LOO baseline                      20.0
-----------------------------------------------------------
Online Alignment RNN (stacked LSTM) [3]               21.5
Neural Transducer with unsupervised alignments [12]   20.8
Online Alignment RNN (grid LSTM) [3]                  20.5
Monotonic Alignment Decoder [19]                      20.4
Neural Transducer with supervised alignments [12]     19.8
Connectionist Temporal Classification [20]            19.6

Table 1: PER results on the TIMIT test set for various models. This shows that REINFORCE performs comparably to the variational inference methods and that our novel baselines improve training for REINFORCE. It also shows that our baselines improve performance over [3], which uses the same model with parametric baselines. Each number is the average of three runs. Our methods are above the horizontal line, while methods from the literature are listed below it.
Method                                    Mixing proportion
                                          0.50    0.25    0.1
Connectionist Temporal Classification     43.8    33.3    27.3
RNN Transducer                            48.9    32.2    25.7
REINFORCE with LOO baseline               42.9    32.5    25.9
NVIL with LOO baseline                    70.1    71.8    55.2
VIMCO with LOO baseline                   41.7    30.7    25.4
REINFORCE with temporal LOO baseline      43.5    31.6    25.6
NVIL with temporal LOO baseline           74.3    71.9    54.9
VIMCO with temporal LOO baseline          41.7    30.75   25.2

Table 2: PER results on Multi-TIMIT for various algorithms. For this task VIMCO outperforms REINFORCE, and both VIMCO and REINFORCE significantly outperform the RNN trained with Connectionist Temporal Classification. The benefit of VIMCO increases as the second speaker’s volume increases.

4.1 Multi-TIMIT

We generated a multi-speaker dataset by mixing male and female voices from TIMIT. Each utterance in the original TIMIT dataset was paired with an utterance from a speaker of the opposite gender. The waveforms of both utterances were first scaled to lie within the same range, and then the second utterance was scaled down to a smaller volume before the two utterances were mixed. We used three different scales for the second utterance: 50%, 25%, and 10%. The new raw utterances were processed in the same manner as the original TIMIT utterances, resulting in a 123-dimensional input per frame. The transcript of the first speaker was used as the ground truth transcript for the new utterance. Multi-TIMIT has the same number of train, dev, and test utterances as the original TIMIT, as well as the same target phonemes.
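The mixing procedure can be sketched as below. The text does not say how the waveforms were brought into the same range or how utterances of different lengths were handled, so this sketch assumes peak normalization and zero-padding of the shorter waveform. The function names `peak_normalize` and `mix` are hypothetical.

```python
def peak_normalize(wave):
    """Scale a waveform so its largest absolute sample is 1 (assumed scheme)."""
    peak = max((abs(s) for s in wave), default=0.0)
    return [s / peak for s in wave] if peak > 0 else list(wave)


def mix(primary, secondary, scale):
    """Add a quieter second speaker to the primary utterance.

    scale is the relative volume of the second utterance, e.g. 0.5, 0.25, 0.1.
    """
    a = peak_normalize(primary)
    b = peak_normalize(secondary)
    n = max(len(a), len(b))
    a = a + [0.0] * (n - len(a))     # zero-pad the shorter waveform (assumption)
    b = b + [0.0] * (n - len(b))
    return [x + scale * y for x, y in zip(a, b)]
```

Calling `mix(primary, secondary, 0.25)` would correspond to the 25% condition, with the primary speaker's transcript kept as the target.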

We trained models with the same configuration described above on the different mixing scales, and also trained 2-layer unidirectional LSTM models with Connectionist Temporal Classification for comparison. The results are shown in Table 2.

5 Results

Figure 4 shows the training curves for the different training methods and datasets. The variational methods (VIMCO and NVIL) require far fewer training steps than REINFORCE on both datasets. All methods used the same batch size and number of samples, so training steps are comparable. NVIL performs well enough on a simple task like TIMIT, but struggles with Multi-TIMIT. The gap between REINFORCE and VIMCO increases on Multi-TIMIT (see also Table 2).

The right panel of Figure 4 shows that, compared to VIMCO, REINFORCE tends to wait to emit outputs until more information has come in. This is presumably because it requires more information during learning. VIMCO, on the other hand, leverages the variational posterior, which can access future inputs and find the optimal place to emit.

In our experiments the difference between the performance of VIMCO and REINFORCE was larger for the more complicated task of Multi-TIMIT than for the simpler task of TIMIT. This can be explained by considering the samples that the models learn from. In the simpler problem of single-speaker TIMIT, Monte Carlo samples generated by REINFORCE have very high likelihood: there are only a small number of samples that explain the entire probability mass, and these are sampled easily by a left-to-right ancestral pass (in time) of the model. These are very similar to the samples generated by the approximate posterior from VIMCO. As a result both methods perform approximately the same. In the case of Multi-TIMIT, however, the probabilities for individual emissions in the ancestral pass are much lower. Thus the likelihood is less ‘peaked’, and a large diversity of samples is chosen, leading to higher variance and poor learning. VIMCO, on the other hand, does not face this problem because it samples from the approximate posterior, which is close to the true posterior and so very peaked around the ‘correct’ samples of experience.

6 Conclusion

In this paper we have shown how to adapt VIMCO to perform hard attention for temporal problems and introduced a new variance-reducing baseline. Our method outperforms other methods of training online sequence-to-sequence models, and the improvements are greater for more difficult problems such as noisy mixed speech. In the future we will apply these techniques to other challenging domains, such as visual attention.