Predictive Sampling with Forecasting Autoregressive Models

02/23/2020 · Auke J. Wiggers, et al.

Autoregressive models (ARMs) currently hold state-of-the-art performance in likelihood-based modeling of image and audio data. Generally, neural network based ARMs are designed to allow fast inference, but sampling from these models is impractically slow. In this paper, we introduce the predictive sampling algorithm: a procedure that exploits the fast inference property of ARMs in order to speed up sampling, while keeping the model intact. We propose two variations of predictive sampling, namely sampling with ARM fixed-point iteration and learned forecasting modules. Their effectiveness is demonstrated in two settings: i) explicit likelihood modeling on binary MNIST, SVHN and CIFAR10, and ii) discrete latent modeling in an autoencoder trained on SVHN, CIFAR10 and Imagenet32. Empirically, we show considerable improvements over baselines in number of ARM inference calls and sampling speed.


1 Introduction

Deep generative models aim to approximate the joint distribution p(x) of high-dimensional objects, such as images, video and audio. When a model of the distribution is available, it may be used for numerous applications such as anomaly detection, inpainting, super-resolution and denoising. However, modeling high-dimensional objects remains a notoriously challenging task.

Figure 1: Overview of predictive sampling. A sequence-so-far of ARM samples is extended with forecasts and given as input to the ARM. As the ARM has strict triangular dependence, its first output is valid: its conditioning consists only of ARM samples. If the forecast at that position equals the sampled value, the next output is also valid. Note that some connections are omitted for aesthetic purposes.

A powerful class of distribution models called deep autoregressive models (ARMs) (Bengio and Bengio, 2000; Larochelle and Murray, 2011) decomposes the high-dimensional joint distribution into single-dimensional conditional distributions, using the chain rule from probability theory. Neural network based ARMs currently hold state-of-the-art likelihood in image and audio domains (van den Oord et al., 2016a, b; Salimans et al., 2017; Chen et al., 2018; Menick and Kalchbrenner, 2018; Child et al., 2019).

A major limitation of ARMs is that the autoregressive computation can be parallelized in only a single direction: either evaluation or sampling. Because these models are trained using likelihood evaluation, and training is computationally orders of magnitude more expensive than sampling, ARMs are designed to allow for fast evaluation; sampling from these models is consequently prohibitively slow. In the literature, there are methods that accelerate sampling by breaking the autoregressive structure a priori, but these consequently suffer from a decrease in likelihood performance (Reed et al., 2017). Another approach approximates an autoregressive density model via distillation (van den Oord et al., 2018), but this method provides no guarantees that samples from the distilled model originate from the original model distribution.

(a) Given is a sequence-so-far x_{<t} for a position t, here one-hot encoded. A future sequence of forecasts is predicted by the forecasting function and is appended to x_{<t}.
(b) An inference pass of the model conditioned on this input gives the distribution parameters for every time step, which allows sampling the outputs of the ARM for all positions at once.
(c) For each forecast that equals the sampled value at that position, the position t is incremented, without any additional inference calls. This process repeats until a forecast and a sampled value disagree. If at some point t exceeds the sequence length, the algorithm returns.
(d) Although the forecast was wrong, the output at the first mismatching position is still a valid sample, as its distribution depends only on the valid samples before it. The value is written to the forecast vector, t is incremented by one, and the process repeats from (a).
Figure 2: One iteration of predictive sampling with forecasting autoregressive models.

This paper proposes a new algorithm termed predictive sampling, which 1) accelerates discrete ARM sampling, 2) keeps autoregressive structure intact, and 3) samples from the true model distribution. Predictive sampling forecasts which values are likely to be sampled, and uses the parallel inference property of the ARM to reduce the total number of required ARM forward passes. To forecast future values, we introduce two methods: ARM fixed-point iteration and learned forecasting. These methods rely on two insights: i) the ARM sampling procedure can be reparametrized into a deterministic function and independent noise, and ii) activations of the penultimate layer of the ARM can be utilized for computationally efficient forecasting. We demonstrate a considerable reduction in the number of forward passes, and consequently, sampling time, on binary MNIST, SVHN and CIFAR10. Additionally, we show on the SVHN, CIFAR10 and Imagenet32 datasets that predictive sampling can be used to speed up ancestral sampling from a discrete latent autoencoder, when an ARM is used to model the latent space. For a visual overview of the method, see Figure 1.

2 Methodology

Consider a variable x ∈ X^d, where X is a discrete space, for example X = {0, 1, …, 255} for 8 bit images, and d is the dimensionality of the data. An autoregressive model views x as a sequence of 1-dimensional variables x_1, …, x_d, which suggests the following universal probability model (Bengio and Bengio, 2000; Larochelle and Murray, 2011):

p(x) = ∏_{t=1}^{d} p(x_t | x_{<t}).   (1)

Universality follows from the chain rule of probability theory, and x_{<t} denotes the values x_1, …, x_{t−1}. Samples from the model are typically obtained using ancestral sampling:

x_t ∼ p(x_t | x_{<t}),   for t = 1, …, d.   (2)

In practice, ARMs can be implemented efficiently using deep neural networks. Let f be a strictly autoregressive function such that, when h = f(x), the representation h_t depends only on the input values x_{<t}. The parameters θ_t for the distribution over x_t are then an autoregressive function of this representation, so that θ_t depends only on x_{<t}. Using this formulation it is possible to parallelize ARM inference, i.e., to obtain a log-likelihood for every variable in parallel. However, in this setting, naïve sampling from an ARM requires d forward calls that cannot be parallelized.
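As an illustration of this asymmetry, the sketch below shows the naïve ancestral sampling loop for a generic PyTorch ARM; likelihood evaluation needs only a single such forward pass, while sampling needs d of them. The `arm` network, its causal masking, and all names are assumptions for illustration rather than the paper's implementation.

```python
import torch

@torch.no_grad()
def naive_ancestral_sample(arm, d, num_classes, device="cpu"):
    """Sample one sequence of length d from an ARM with d sequential forward passes.

    `arm` is assumed to map a one-hot sequence of shape (1, d, num_classes) to logits
    of the same shape, where logits[:, t] depend only on inputs at positions < t.
    """
    x = torch.zeros(1, d, num_classes, device=device)   # one-hot input, initially empty
    for t in range(d):
        logits = arm(x)                                  # full forward pass, only position t is used
        x_t = torch.distributions.Categorical(logits=logits[:, t]).sample()
        x[:, t] = torch.nn.functional.one_hot(x_t, num_classes).float()
    return x.argmax(dim=-1)                              # integer-valued sample of shape (1, d)
```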

2.1 Predictive Sampling

Consider the naïve sampling approach. First, it computes an intermediate ARM representation h_1 and distribution parameters θ_1, and then samples the first value x_1. Only then can the next representation h_2 be computed.

In the setting of predictive sampling, suppose now that we can obtain a forecast x̂_1, which is equal to x_1 with high probability. In this case h_1 and h_2 can be computed in parallel, where h_2 is valid (i.e., equal to the representation computed from the true x_1) if x̂_1 equals x_1. We proceed as before and sample x_1 using θ_1. If the sampled x_1 is indeed equal to the forecast x̂_1, the representation h_2 is valid, and we can immediately sample x_2 without additional calls to the network. More generally, for a sequence of n correct forecasts, n re-computations of the representation are saved. A general description of these steps can be found in Algorithm 1, where [·, ·] denotes vector concatenation and forecast(·) denotes a function that outputs forecasts starting at location t. A corresponding visualization is shown in Figure 2.

  Input: ARM, forecasting function forecast(·)
  Output: sample x̂
  let t = 1, x̂ = 0
  while t ≤ d do
     x̂_{≥t} ← forecast(t)            // Forecast
     θ ← ARM([x̂_{<t}, x̂_{≥t}])       // One parallel inference pass
     x ← sample using θ
     // While forecast is correct, output is valid
     while x̂_t = x_t and t ≤ d do
        t ← t + 1
     end while
     if t ≤ d then
        x̂_t ← x_t                     // Overwrite the input vector
        t ← t + 1
     end if
  end while
Algorithm 1 Predictive Sampling

2.2 Forecasting

The forecasting function can be formalized as follows. Consider the vector x̂ in Algorithm 1, which contains valid samples until position t, i.e., the variables x̂_{<t} are valid samples x_{<t} from the ARM. Let x̂_s denote a forecast for the variable x_s. A forecasting function aims to infer the most likely future sequence starting from position t, given all available information thus far:

x̂_{≥t} = argmax_{x_{≥t}} p(x_{≥t} | x_{<t}).   (3)

Using this notion, we can define the predictive sampling algorithm, given in Algorithm 1. It uses the forecasting function to compute a future sequence x̂_{≥t}, and combines the known values x_{<t} with this sequence to form the ARM input. Using this input, the ARM can compute an output for every future time step s ≥ t. Note that the first ARM output x_t is a sample from the true model distribution, as its conditioning does not contain forecasts. For each consecutive step where the forecast x̂_s equals the ARM output x_s, the subsequent output is valid as well. When the forecasting function does not agree with the ARM, we write the last valid output to the input vector x̂, and proceed to the next iteration of predictive sampling. This process is repeated until all d variables have been sampled.
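To make the control flow of Algorithm 1 concrete, the following is a minimal Python sketch. The helpers `arm_sample_all` (a single parallel inference pass that returns a sample for every position) and `forecast` are assumed, all names are illustrative, and the reparametrization of the next subsection is left implicit.

```python
import numpy as np

def predictive_sample(arm_sample_all, forecast, d):
    """Predictive sampling (Algorithm 1), a minimal sketch.

    arm_sample_all(x): one parallel ARM inference pass on the length-d input x,
        returning a sample for every position t (position t conditions on x[:t]).
    forecast(x, t): returns forecasts for positions t, ..., d-1 given x[:t].
    """
    x_hat = np.zeros(d, dtype=np.int64)   # valid samples followed by forecasts
    t, num_calls = 0, 0
    while t < d:
        x_hat[t:] = forecast(x_hat, t)    # forecast the future part of the input
        y = arm_sample_all(x_hat)         # one ARM inference call
        num_calls += 1
        # while the forecast matches the ARM output, the next output is valid too
        while t < d and x_hat[t] == y[t]:
            t += 1
        if t < d:
            x_hat[t] = y[t]               # overwrite the first wrong forecast with the valid sample
            t += 1
    return x_hat, num_calls
```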

Isolating stochasticity via reparametrization

The sampling step introduces unpredictability for each dimension, which may fundamentally limit the number of subsequent variables that can be predicted correctly. For example, even if every forecasted variable has a probability p of being correct, the expected length of a correctly forecast run is only p/(1−p), which is less than ten steps even when p = 0.9.
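To make this limit explicit, if forecast successes were independent with per-step probability p, the run length before the first mismatch would be geometric:

```latex
\mathbb{E}[\text{run length}]
  \;=\; \sum_{n \geq 1} n \, p^{n} (1-p)
  \;=\; \frac{p}{1-p},
\qquad \text{e.g. } \frac{0.9}{0.1} = 9 \text{ steps for } p = 0.9 .
```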

To solve this issue, we reparametrize the sampling procedure of the ARM using a deterministic function g and a stochastic noise variable ε. A sample can equivalently be computed using the deterministic function conditioned on random noise ε:

x_t = g(x_{<t}, ε_t),   ε_t ∼ p(ε).   (4)

Such a reparametrization always exists for discrete distributions, and as a consequence the sampling procedure of the ARM becomes deterministic conditioned on the random noise ε. This is an important insight, because the reparametrization resolves the aforementioned fundamental limit on predicting a stochastic sequence. For instance, consider an ARM that models categorical distributions over K classes using log-probabilities log π. One method to reparametrize categorical distributions is the Gumbel-Max trick (Gumbel, 1954), which has recently become popular in machine learning (Maddison et al., 2014, 2017; Jang et al., 2017; Kool et al., 2019). By sampling standard Gumbel noise G_{t,k} ∼ Gumbel(0, 1), the categorical sample can be computed using:

x_t = argmax_k ( log π_{t,k} + G_{t,k} ),   (5)

where k represents a category and log π_{t,k} is the log probability of dimension t for class k.
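A minimal sketch of this reparametrization in PyTorch; the shape (d, K) for `log_probs` and the helper names are illustrative.

```python
import torch

def sample_gumbel(shape):
    """Standard Gumbel(0, 1) noise via inverse transform sampling."""
    u = torch.rand(shape).clamp(min=1e-10, max=1.0 - 1e-10)
    return -torch.log(-torch.log(u))

def gumbel_max_sample(log_probs, gumbel_noise):
    """Deterministic categorical sample given fixed Gumbel noise (Equation 5).

    log_probs:    (d, K) log-probabilities per dimension and category.
    gumbel_noise: (d, K) standard Gumbel noise, fixed up front for the whole sequence.
    """
    return torch.argmax(log_probs + gumbel_noise, dim=-1)

# fixing the noise once makes the whole sampling procedure a deterministic function of it:
# eps = sample_gumbel((d, K)); x = gumbel_max_sample(log_probs, eps)
```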

Shared Representation

In theory, the future sequence x_{≥t} can be predicted perfectly using x_{<t} and ε, as the reparametrization turns the ARM into a deterministic function of these inputs. In practice, however, the amount of computation required to predict the future perfectly may exceed the computational cost of the ARM itself. The ARM representation h from the previous iteration of predictive sampling is therefore shared with the forecasting function for efficiency:

x̂_{≥t} = forecast(x_{<t}, h, ε).   (6)

When forecasts starting from location t are required, the variables x_{<t} are already valid model outputs. In the previous iteration of predictive sampling, the input up to that location was valid, and therefore the corresponding representation h is valid as well. Although it is possible to obtain an unconditional forecast for the very first iteration, we simply use a zero vector as the initial forecast.

Note that, theoretically, conditioning on h does not add anything, as the data processing inequality states that no post-processing function can increase the information content. From a practical perspective, however, h is a convenient representation that summarizes the input in a way that is easy to reuse.

2.3 ARM Fixed-Point Iteration

The first method we introduce for forecasting is ARM Fixed-Point Iteration (FPI), which utilizes the ARM itself as the forecasting function. Specifically, a forecast at step s ≥ t is obtained using the ARM reparametrization g, where the noise ε is isolated:

x̂_s = g([x_{<t}, x̂_{≥t}], ε)_s,   for s ≥ t.   (7)

Note that the input [x_{<t}, x̂_{≥t}] is a concatenation of the valid samples thus far and the forecasts from the previous iteration of predictive sampling (as in Algorithm 1). In other words, current forecasts are obtained using ARM inputs that may turn out to be invalid. Nevertheless, the method is compelling because it is computationally inexpensive and requires no additional learned components: it simply reuses the ARM output.

Interestingly, the combination of forecasting with Equation 7 and Algorithm 1 is equivalent to a reformulation as a fixed-point iteration of the function g applied over all indices of x:

x^{(i+1)} = g(x^{(i)}, ε),   (8)

where i denotes the iteration number of predictive sampling. We show this reformulation in Algorithm 2. The equivalence follows because ARM outputs are fixed once their conditioning consists of valid samples, i.e., for valid inputs the outputs equal the inputs. The future outputs are automatically used as forecasts. The algorithm is guaranteed to converge in at most d steps because the system has strictly triangular dependence, and it may converge much faster if variables do not depend strongly on adjacent previous variables.

  Input: noise ε
  Output: sample x
  let x^(0) = 0, i = 0
  repeat
     x^(i+1) ← g(x^(i), ε)
     i ← i + 1
  until x^(i) = x^(i−1)
Algorithm 2 ARM Fixed-Point iteration
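Under the Gumbel-Max reparametrization, Algorithm 2 can be sketched as follows; `arm_logits` (one parallel inference pass returning log-probabilities for every position, with strictly triangular dependence) is an assumed helper and all names are illustrative.

```python
import torch

@torch.no_grad()
def fixed_point_iteration_sample(arm_logits, d, num_classes):
    """ARM fixed-point iteration (Algorithm 2), a minimal sketch.

    arm_logits(x): one parallel ARM pass mapping an integer sequence x of shape (d,)
        to log-probabilities of shape (d, num_classes), with strict triangular dependence.
    """
    u = torch.rand(d, num_classes).clamp(1e-10, 1 - 1e-10)
    eps = -torch.log(-torch.log(u))                           # fixed Gumbel noise
    x = torch.zeros(d, dtype=torch.long)                      # initial guess
    num_calls = 0
    while True:
        x_new = torch.argmax(arm_logits(x) + eps, dim=-1)     # x^(i+1) = g(x^(i), eps)
        num_calls += 1
        if torch.equal(x_new, x):                             # converged: x is a valid sample
            return x, num_calls
        x = x_new
```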

2.4 Learned forecasting

ARM fixed-point iteration makes use of the fact that the ARM outputs distributions for every location in parallel. However, many of these output distributions are conditioned on forecasts from the previous iteration of predictive sampling, and these may turn out to be incorrect. For example, if in the first iteration of the algorithm we find that forecast x̂_1 does not match x_1, the procedure will still reuse the outputs sampled beyond position 1, which were conditioned on the incorrect x̂_1, as forecasts in the second iteration. In turn, this may result in an incorrect forecast x̂_2. In the worst case, this leads to cascading errors, and d ARM inference calls are required.

To address this problem, we introduce learned forecasting as an addition to ARM fixed-point iteration. We construct forecasting modules: small neural networks that are trained to match the distributions predicted by the ARM. As the forecasting networks are learned, they can only utilize information that will be available during sampling. For that reason, they are conditioned on the available valid information, namely the shared ARM representation h and the reparametrization noise ε.

In particular, a forecasting module for a future timestep s is designed to output a distribution q_s that is trained to match the ARM distribution p(x_s | x_{<s}) at that location, with the important difference that q_s is conditioned only on the information valid at position t (and the noise ε), whereas the ARM is conditioned on x_{<s}. In our experiments we minimize the distance between the corresponding distributions p and q_s using the KL divergence:

D_KL( p(x_s | x_{<s}) ‖ q_s(x_s) )   (9)

with respect to the forecasting module q_s for each future step s. The gradient path from the model is detached in this divergence.

After training, forecasts can be obtained via the forecasting distributions and the reparametrization noise. For example, when p and q_s are categorical distributions:

x̂_s = argmax_k ( log q_{s,k} + G_{s,k} ),   (10)

where log q_{s,k} is the log-probability that x_s = k according to the forecasting distribution, and G_{s,k} is the same Gumbel noise used in Equation 5. In practice, a sequence of forecasts is obtained by concatenating the outputs of n forecasting modules, where n is the window in which we forecast future values.

We find that explicitly conditioning on the noise ε in combination with h does not have a noticeable effect on performance for the forecasting module capacity we consider; instead, it suffices to condition solely on h. The representation h is shared and trained jointly for both the ARM and the forecasting modules, but the forecasting objective is down-weighted with a factor of 0.01 (see Table 4) so that the final log-likelihood performance is not affected. While it is possible to train forecasting modules on samples from the model distribution, we only train on samples from the data distribution, as the sampling process is relatively slow.
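A sketch of the forecasting objective under the assumptions above: the ARM provides per-position log-probabilities `log_p` for the forecast window, the forecasting module provides `log_q`, and the ARM side is detached; tensor names, shapes, and the surrounding training loop are illustrative.

```python
import torch

def forecasting_kl_loss(log_p, log_q, weight=0.01):
    """KL between the (detached) ARM distribution and the forecasting distribution.

    log_p: (B, n, K) ARM log-probabilities for the n forecast positions (teacher).
    log_q: (B, n, K) forecasting-module log-probabilities for the same positions.
    """
    log_p = log_p.detach()                 # gradient path through the ARM is cut
    p = log_p.exp()
    kl = (p * (log_p - log_q)).sum(-1)     # D_KL(p || q) per position
    return weight * kl.mean()

# during training this term is added to the ARM negative log-likelihood, down-weighted
# so that it does not affect the final likelihood performance:
# loss = nll + forecasting_kl_loss(log_p_future, log_q_future)
```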

3 Related work

Neural network based likelihood methods in generative modelling can broadly be divided into VAEs (Kingma and Welling, 2014; Rezende et al., 2014), Flow based models (Dinh et al., 2017) and autoregressive models (Bengio and Bengio, 2000; Larochelle and Murray, 2011). VAEs and Flows are attractive when fast sampling is important, as they can be constructed without autoregressive components that need inverses. However, in terms of likelihood performance, ARMs currently outperform VAEs and Flows and hold state-of-the-art in image and audio domains (van den Oord et al., 2016b, a; Salimans et al., 2017; Chen et al., 2018; Child et al., 2019).

One of the earliest neural network architectures for autoregressive probability estimation of image data is NADE (Larochelle and Murray, 2011). This model employs a causal structure, i.e., nodes of the network are connected in such a way that a layer output for position t only depends on the set of inputs x_{<t}. Numerous follow-up works by Germain et al. (2015); van den Oord et al. (2016b); Akoury and Nguyen (2017); Salimans et al. (2017); Menick and Kalchbrenner (2018); Sadeghi et al. (2019) improve on this idea and increase likelihood performance by refining training objectives and improving network architectures.

There are various approaches that aim to capture the performance of the ARM while keeping sampling time low. The autoregressive dependencies can be broken between some of the dimensions, which allows some parts of the sampling to run in parallel, but comes at the cost of decreased likelihood performance (Reed et al., 2017). It is possible to train a student network using distillation (van den Oord et al., 2018), but in this case samples from the student network will not come from the (teacher) model distribution.

An alternative method that does preserve the model structure relies on caching layer activations to avoid duplicate computation (Ramachandran et al., 2017). To determine when to read from and write to the cache, the algorithm must be specified beforehand in accordance with the architecture of the ARM. In addition, the activations of the network need to be stored, which incurs a larger memory cost when sampling. In contrast, our method does not require knowledge of model-specific details and does not require additional memory to store activations.

4 Experiments

| | | Batch size 1 | | | Batch size 32 | | |
| Dataset | Method | ARM calls | Time (s) | Speedup | ARM calls | Time (s) | Speedup |
| MNIST (1 bit) | Baseline | 100.0% ± 0.0 | 16.6 ± 0.1 | | 100.0% ± 0.0 | 24.1 ± 0.4 | |
| MNIST (1 bit) | Forecast zeros | 14.5% ± 5.0 | 2.4 ± 0.8 | | 25.0% ± 0.1 | 7.6 ± 1.0 | |
| MNIST (1 bit) | Predict last | 7.8% ± 1.5 | 1.5 ± 0.4 | | 10.0% ± 0.6 | 3.8 ± 0.3 | |
| MNIST (1 bit) | Fixed-point iteration | 3.3% ± 0.9 | 0.6 ± 0.1 | 27.6× | 5.2% ± 0.4 | 2.8 ± 0.2 | 8.6× |
| MNIST (1 bit) | + Forecasting () | 3.3% ± 0.6 | 0.7 ± 0.2 | | 4.3% ± 0.3 | 2.8 ± 0.5 | 8.6× |
| SVHN (8 bit) | Baseline | 100.0% ± 0.0 | 145.7 ± 0.8 | | 100.0% ± 0.0 | 1174 ± 5.7 | |
| SVHN (8 bit) | Fixed-point iteration | 22.0% ± 1.2 | 32.2 ± 1.7 | 4.5× | 28.0% ± 1.8 | 327 ± 19.9 | 3.6× |
| SVHN (8 bit) | + Forecasting () | 36.9% ± 2.7 | 57.3 ± 4.2 | | 46.5% ± 1.9 | 547 ± 22.4 | |
| CIFAR10 (5 bit) | Baseline | 100.0% ± 0.0 | 148.2 ± 0.5 | | 100.0% ± 0.0 | 1114 ± 3.6 | |
| CIFAR10 (5 bit) | Fixed-point iteration | 15.6% ± 2.1 | 23.3 ± 3.0 | 6.4× | 16.7% ± 0.4 | 239 ± 0.6 | 4.7× |
| CIFAR10 (5 bit) | + Forecasting () | 23.2% ± 2.9 | 35.6 ± 4.3 | | 27.5% ± 1.1 | 311 ± 10.4 | |
| CIFAR10 (8 bit) | Baseline | 100.0% ± 0.0 | 145.7 ± 0.8 | | 100.0% ± 0.0 | 1174 ± 5.7 | |
| CIFAR10 (8 bit) | Fixed-point iteration | 22.0% ± 2.0 | 32.0 ± 2.9 | 4.6× | 25.9% ± 1.1 | 305 ± 11.4 | 3.8× |
| CIFAR10 (8 bit) | + Forecasting () | 43.1% ± 5.5 | 65.1 ± 8.2 | | 50.9% ± 1.8 | 597 ± 21.6 | |
| CIFAR10 (8 bit) | + Forecasting () | 59.8% ± 2.9 | 94.5 ± 4.4 | | 67.2% ± 0.6 | 842 ± 6.2 | |
Table 1: Performance of predictive sampling for ARMs trained on explicit likelihood modeling tasks, in terms of the percentage of forward passes with respect to the original sampling procedure and the total time to sample. All reported times are based on our own implementation. Reported means and (Bessel-corrected) standard deviations are based on sampling 10 batches with different random seeds.

Predictive sampling is evaluated in two settings. First, an ARM is trained directly on images; we refer to this task as explicit likelihood modeling. Second, an ARM is trained on the discrete latent space of an autoencoder.

The datasets used are Binary MNIST (Larochelle and Murray, 2011), SVHN (Netzer et al., 2011), CIFAR10 (Krizhevsky et al., 2009), and ImageNet32 (van den Oord et al., 2016b). We use the standard test split as test data, except for Imagenet32, for which no test split is available and the validation split is used as test data instead. As validation data, we use the last 5000 images of the train split for MNIST and CIFAR10, a random selection of 8527 images from the train split for SVHN, and a random selection of images from the train split for Imagenet32. For all datasets, the remainder of the train split is used as training data.

The ARM architecture is based on (Salimans et al., 2017), with the fully autoregressive categorical output distribution of (van den Oord et al., 2016b). The categorical output distribution allows us to scale to an arbitrary number of channels without substantial changes to the network architecture. The autoregressive order is a raster-scan order, and in each spatial location an output channel is dependent on all preceding input channels.

All experiments were performed using PyTorch version 1.1.0 (Paszke et al., 2019). Training took place on Nvidia Tesla V100 GPUs. To obtain sampling times, measurements were taken on a single Nvidia GTX 1080Ti GPU, with Nvidia driver 410.104, CUDA 10.0, and cuDNN v7.5.1. For a full list of hyperparameters, data preprocessing steps and machine specifications, see Appendix A.

4.1 Predictive sampling of image data

Setting

In this section the performance of predictive sampling for explicit likelihood modeling tasks is tested on binary MNIST, SVHN and CIFAR10. We use the same architecture for all datasets except binary MNIST, for which we decrease the number of layers and filters to prevent overfitting. Each ARM is optimized using the log-likelihood objective, and performance is reported in bits per dimension (bpd), which is the negative log-likelihood in base two divided by the number of dimensions. After 200000 training iterations, the ARMs reach a test set performance of 1.81 bpd on SVHN and 3.05 bpd on 8-bit CIFAR10. Further details on the architecture and optimization procedure are described in Appendix A.

For the forecasting modules, we choose a lightweight network architecture that forecasts a small number of future timesteps. A strictly triangular convolution is applied to h, the hidden representation of the ARM. This is followed by a convolution with a number of output channels equal to the number of timesteps to forecast multiplied by the number of input categories. The number of forecasting modules is 20 for binary MNIST (see Table 4) and varies for the other datasets (the exact number is specified in brackets in the results). Forecasts for all remaining future timesteps are taken from the ARM output, as this does not require additional computation.

(a) Samples from the model distribution.
(b) Forecasting mistakes by the forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 3: Samples from the 1-bit ARM. Forecasting mistakes are shown in red.

Performance

Sampling performance for ARMs is presented in Table 1. For each dataset, we list the percentage of ARM calls with respect to the default sampling procedure, as well as the total runtime during sampling of a batch. Results are reported for batch sizes 1 and 32. In this implementation, the slowest image determines the number of ARM inference passes. We leave the implementation of a scheduling system to future work, which would allow sampling at an average rate equal to the batch size 1 setting.

Fixed-point iteration and learned forecasting greatly outperform the standard baseline on all datasets. To put the improvements in perspective, we introduce two additional baselines for binary MNIST: forecast zeros and predict last. The first baseline simply forecasts zeros for all future timesteps, and the second baseline repeats the last observed value. On binary MNIST, both fixed-point iteration and learned forecasting outperform these baselines.

| | | Batch size 1 | | | Batch size 32 | | |
| Dataset | Method | ARM calls | Time (s) | Speedup | ARM calls | Time (s) | Speedup |
| SVHN | Baseline | 100.0% ± 0.0 | 12.1 ± 0.0 | 1.0× | 100.0% ± 0.0 | 12.6 ± 0.2 | 1.0× |
| SVHN | Fixed-point iteration | 15.0% ± 2.6 | 1.9 ± 0.3 | 6.4× | 20.3% ± 1.2 | 3.1 ± 0.2 | 4.1× |
| SVHN | + Forecasting () | 16.9% ± 2.8 | 2.2 ± 0.3 | 5.5× | 24.9% ± 2.9 | 3.8 ± 0.4 | 3.3× |
| CIFAR10 | Baseline | 100.0% ± 0.0 | 12.1 ± 0.0 | 1.0× | 100.0% ± 0.0 | 12.7 ± 0.1 | 1.0× |
| CIFAR10 | Fixed-point iteration | 17.6% ± 2.9 | 2.2 ± 0.4 | 5.5× | 24.3% ± 2.0 | 3.6 ± 0.3 | 3.6× |
| CIFAR10 | + Forecasting () | 19.7% ± 3.7 | 2.6 ± 0.4 | 4.6× | 26.4% ± 1.6 | 4.0 ± 0.2 | 3.2× |
| ImageNet32 | Baseline | 100.0% ± 0.0 | 12.1 ± 0.0 | 1.0× | 100.0% ± 0.0 | 12.9 ± 0.0 | 1.0× |
| ImageNet32 | Fixed-point iteration | 13.8% ± 3.1 | 1.8 ± 0.3 | 6.7× | 20.9% ± 2.6 | 3.1 ± 0.3 | 4.2× |
| ImageNet32 | + Forecasting () | 14.2% ± 2.0 | 1.9 ± 0.4 | 6.4× | 23.0% ± 2.3 | 3.5 ± 0.4 | 3.7× |
Table 2: Performance of predictive sampling for ARMs trained on the latent space of an autoencoder, in terms of the percentage of forward passes with respect to the original sampling procedure and the total time to sample. All reported times are based on our own implementation. Reported means and (Bessel-corrected) standard deviations are based on sampling 10 batches with different random seeds.
(a) Samples from the model distribution.
(b) Forecasting mistakes by the forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 4: Samples from the 5-bit ARM. The shade of red indicates the number of forecasting mistakes for that location.

Comparing the sampling speed for 5-bit and 8-bit CIFAR10, we observe that when data has a lower bit depth, it is generally easier to predict future variables. This can likely be attributed to the lower number of categories. Typically, SVHN is considered an easier dataset to model than CIFAR10, a claim that is also supported by the negative log-likelihood of 1.81 for SVHN versus 3.05 for CIFAR10. Interestingly, we find that SVHN is not necessarily easier in the case of predictive sampling: when using fixed-point iteration, both models require approximately 22% of the ARM calls. This suggests that the performance of predictive sampling depends mostly on the number of categories and less on the modeling difficulty of the data.

Furthermore, while learned forecasting works well for binary MNIST, the results do not transfer to the more complicated datasets. For CIFAR10, we observe that increasing the number of forecasting modules decreases performance. Note also that for binary MNIST the runtime overhead of the forecasting modules negates the effect of the reduced number of ARM inference calls.

To aid qualitative analysis, model samples and corresponding forecasting mistakes are visualized in Figures 3 and 4. In these figures, red pixels highlight the locations in the image where the forecast was incorrect, for both the forecasting modules and fixed-point iteration. As color images consist of three channels, mistakes are visualized using a shade of red that depends on the number of channels that were predicted correctly. For the binary MNIST samples (Figure 3) it is noticeable that mistaken forecasts do not necessarily lie on the edges of the digits. Instead, we observe that the forecasting methods are able to correctly transition over edges in certain areas, which indicates that more sophisticated patterns are used than, for example, simply repeating the last observed value. For the more complicated 5-bit CIFAR data (Figure 4) there are generally more mistakes in the top row and on the right side of the images. An explanation for this may be that the ARM dependency structure is from left to right, and top to bottom. The left-most pixels are conditioned more strongly on pixels directly above, which are generally further away in the sequence. Hence, even if the last pixel of the preceding row contains a wrong value, pixels in the left-most column can be predicted with high accuracy.

4.2 Predictive sampling of latent variables

Setting

In this section we explore autoencoders with a probabilistic latent space (Theis et al., 2017; Ballé et al., 2017). Typically these methods weigh a distortion component d and a rate component R:

L = d(x, D(E(x))) + β · R(E(x)),   (11)

where E is an encoder, D is a decoder and β is a tunable trade-off parameter. In our experiments we use the Mean Squared Error (MSE) as the distortion metric and keep β fixed. Following (van den Oord et al., 2017; Habibian et al., 2019; Razavi et al., 2019) we model the latent distribution using an ARM. The (deterministic) encoder architecture consists of two convolutional layers, two strided convolutions and two residual blocks, following the PyTorch BasicBlock implementation (He et al., 2016). The decoder mirrors this architecture with two residual blocks, two transposed convolutions and two standard convolutional layers. The latent space is quantized using an argmax of a softmax, where the gradient is obtained using the straight-through estimator. We use a latent space with spatial resolution 8×8 (the input resolution reduced by the two stride-2 convolutions), multiple channels, and a fixed number of categories per latent variable. Further details on the architecture are given in Appendix A.

Following van den Oord et al. (2017), we separate the training of the autoencoder and the ARM. We first train the discrete autoencoder for 50000 iterations, then freeze its weights, and train an ARM on the latents generated by the encoder for another 200000 iterations (see Appendix A). We find that this scheme results in more stability than joint training. The obtained MSE and bits per image dimension differ per dataset; to obtain the bits per latent dimension, the latter is multiplied by the dimensionality reduction factor. Note that the prior likelihood depends on the latent variables produced by the encoder, and can therefore not be compared directly with results from explicit likelihood modeling.

Performance

The sampling performance for the discrete-latent autoencoder is presented in Table 2. Again, predictive sampling with fixed-point iteration and learned forecasting modules both outperform the baseline, and fixed-point iteration outperforms learned forecasting across all three datasets.

Samples and predictive sampling mistakes of the forecasting methods are depicted in Figure 5 for an autoencoder trained on CIFAR10 (8 bit). Samples z are generated in the latent representation and the decoded image D(z) is subsequently visualized. In addition, the latent representation is visualized on a scale from black to red, where the amount of red indicates the number of mistakes at that location, averaged over the channel dimension. The latent representation has an 8×8 resolution and is resized to match the images.

(a) Decoded samples D(z), where z is sampled from the ARM prior.
(b) Forecasting mistakes by fixed-point iteration.
(c) Forecasting mistakes by learned forecasting modules.
Figure 5: Samples from the VAE and corresponding forecasting mistakes in the latent space.

Finally, the convergence behavior of fixed-point iteration is visualized in Figure 6. In this figure, the color indicates the sampling iteration from which the variable remained the same, i.e., the iteration at which that variable converged. For example, because there is strict triangular dependence and the top-left variable in the first channel is at the beginning of the sequence, this variable converges at step one. The convergence iterations are averaged over channels and over a batch of 32 images. The right image of Figure 6 shows the baseline, where the total number of iterations is exactly equal to the number of dimensions. The left image shows the convergence of the ARM fixed-point iteration procedure, which on average needs far fewer iterations than the number of dimensions for this batch of data. We observe that pixels on the left of the image tend to converge earlier than those on the right. This matches the conditioning structure of the ARM, where values in the left-most column depend strongly on the pixel values directly above, while variables further to the right also depend on pixels to their left.

4.3 Ablations

We perform ablations on 8-bit CIFAR10 data to show the effect of isolating stochasticity via reparametrization, and of sharing the ARM representation. First, to quantify the effect of reparametrization, the sampling procedure is run again for an ARM without learned forecasting modules, where the most likely value according to the forecasting distribution is used as the forecast. For categorical distributions, this corresponds to removing the Gumbel noise term from Equation 10. In addition, we show the importance of sharing the ARM representation by training forecasting modules conditioned only on the input x_{<t} and the reparametrization noise ε, i.e., Equation 6 where h is removed. Results are shown in Table 3. These experiments indicate that both the reparametrization and the shared representation improve performance considerably, with reparametrization having the larger effect.

Figure 6: Comparison of convergence for fixed-point iteration (left) and the baseline (right) for the latent space of an autoencoder trained on CIFAR10. Each spatial location shows the iteration at which the final value was determined, averaged over all latent channels and over 32 samples. Note that a log-scale colormap is used to emphasize differences for low values.
| CIFAR10 | ARM calls | Time (s) |
| Fixed-point iteration | 25.9% ± 1.1 | 305 ± 11.4 |
| – without reparametrization | 97.2% ± 0.4 | 1122 ± 6.4 |
| Learned forecasting | 50.9% ± 1.8 | 597 ± 21.6 |
| – without representation sharing | 67.1% ± 3.3 | 802 ± 19.5 |
Table 3: Ablation showing the effect of reparametrization and representation sharing on CIFAR10. Means and (Bessel-corrected) standard deviations are based on 10 sampled batches of size 32, with different random seeds.

5 Conclusion

We introduce predictive sampling, an algorithm that speeds up sampling for autoregressive models (ARMs) while keeping the model intact. The algorithm forecasts likely future values and exploits the parallel inference property of neural network based ARMs. We propose two variations to obtain forecasts, namely ARM fixed-point iteration and learned forecasting modules. In both cases, the sampling procedure is reduced to a deterministic function via a reparametrization. We train ARMs on image data and on the latent space of a discrete autoencoder, and show in both settings that predictive sampling provides a considerable increase in sampling speed. ARM fixed-point iteration, a method that requires no training, obtains the best performance overall.

References

  • N. Akoury and A. Nguyen (2017) Spatial PixelCNN: generating images from patches. arXiv preprint arXiv:1712.00714.
  • J. Ballé, V. Laparra, and E. P. Simoncelli (2017) End-to-end optimized image compression. In 5th International Conference on Learning Representations, ICLR.
  • S. Bengio and Y. Bengio (2000) Taking on the curse of dimensionality in joint distributions using neural networks. IEEE Trans. Neural Netw. Learning Syst. 11 (3), pp. 550–557.
  • X. Chen, N. Mishra, M. Rohaninejad, and P. Abbeel (2018) PixelSNAIL: an improved autoregressive generative model. In Proceedings of the 35th International Conference on Machine Learning, ICML, pp. 863–871.
  • R. Child, S. Gray, A. Radford, and I. Sutskever (2019) Generating long sequences with sparse transformers. CoRR abs/1904.10509.
  • L. Dinh, J. Sohl-Dickstein, and S. Bengio (2017) Density estimation using Real NVP. In 5th International Conference on Learning Representations, ICLR.
  • M. Germain, K. Gregor, I. Murray, and H. Larochelle (2015) MADE: masked autoencoder for distribution estimation. In International Conference on Machine Learning, pp. 881–889.
  • E. J. Gumbel (1954) Statistical theory of extreme values and some practical applications. NBS Applied Mathematics Series 33.
  • A. Habibian, T. v. Rozendaal, J. M. Tomczak, and T. S. Cohen (2019) Video compression with rate-distortion autoencoders. In The IEEE International Conference on Computer Vision (ICCV).
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778.
  • E. Jang, S. Gu, and B. Poole (2017) Categorical reparameterization with Gumbel-Softmax. In 5th International Conference on Learning Representations, ICLR.
  • D. P. Kingma and M. Welling (2014) Auto-encoding variational Bayes. In 2nd International Conference on Learning Representations, ICLR.
  • W. Kool, H. van Hoof, and M. Welling (2019) Stochastic beams and where to find them: the Gumbel-top-k trick for sampling sequences without replacement. In Proceedings of the 36th International Conference on Machine Learning, ICML, pp. 3499–3508.
  • A. Krizhevsky, G. Hinton, et al. (2009) Learning multiple layers of features from tiny images.
  • H. Larochelle and I. Murray (2011) The neural autoregressive distribution estimator. In Proceedings of the 14th International Conference on Artificial Intelligence and Statistics (AISTATS), pp. 29–37.
  • C. J. Maddison, D. Tarlow, and T. Minka (2014) A* sampling. In Advances in Neural Information Processing Systems, pp. 3086–3094.
  • C. J. Maddison, A. Mnih, and Y. W. Teh (2017) The Concrete distribution: a continuous relaxation of discrete random variables. In 5th International Conference on Learning Representations, ICLR.
  • J. Menick and N. Kalchbrenner (2018) Generating high fidelity images with subscale pixel networks and multidimensional upscaling. arXiv preprint arXiv:1812.01608.
  • Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng (2011) Reading digits in natural images with unsupervised feature learning.
  • A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Kopf, E. Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner, L. Fang, J. Bai, and S. Chintala (2019) PyTorch: an imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems 32, pp. 8024–8035.
  • P. Ramachandran, T. L. Paine, P. Khorrami, M. Babaeizadeh, S. Chang, Y. Zhang, M. A. Hasegawa-Johnson, R. H. Campbell, and T. S. Huang (2017) Fast generation for convolutional autoregressive models. In 5th International Conference on Learning Representations, ICLR.
  • A. Razavi, A. van den Oord, and O. Vinyals (2019) Generating diverse high-fidelity images with VQ-VAE-2. In Advances in Neural Information Processing Systems 32, pp. 14837–14847.
  • S. E. Reed, A. van den Oord, N. Kalchbrenner, S. G. Colmenarejo, Z. Wang, Y. Chen, D. Belov, and N. de Freitas (2017) Parallel multiscale autoregressive density estimation. In Proceedings of the 34th International Conference on Machine Learning, ICML, pp. 2912–2921.
  • D. J. Rezende, S. Mohamed, and D. Wierstra (2014) Stochastic backpropagation and approximate inference in deep generative models. arXiv preprint arXiv:1401.4082.
  • H. Sadeghi, E. Andriyash, W. Vinci, L. Buffoni, and M. H. Amin (2019) PixelVAE++: improved PixelVAE with discrete prior. arXiv preprint arXiv:1908.09948.
  • T. Salimans, A. Karpathy, X. Chen, and D. P. Kingma (2017) PixelCNN++: improving the PixelCNN with discretized logistic mixture likelihood and other modifications. arXiv preprint arXiv:1701.05517.
  • L. Theis, W. Shi, A. Cunningham, and F. Huszár (2017) Lossy image compression with compressive autoencoders. In 5th International Conference on Learning Representations, ICLR.
  • A. van den Oord, S. Dieleman, H. Zen, K. Simonyan, O. Vinyals, A. Graves, N. Kalchbrenner, A. Senior, and K. Kavukcuoglu (2016a) WaveNet: a generative model for raw audio. arXiv preprint arXiv:1609.03499.
  • A. van den Oord, N. Kalchbrenner, and K. Kavukcuoglu (2016b) Pixel recurrent neural networks. arXiv preprint arXiv:1601.06759.
  • A. van den Oord, Y. Li, I. Babuschkin, K. Simonyan, O. Vinyals, K. Kavukcuoglu, G. van den Driessche, E. Lockhart, L. C. Cobo, F. Stimberg, N. Casagrande, D. Grewe, S. Noury, S. Dieleman, E. Elsen, N. Kalchbrenner, H. Zen, A. Graves, H. King, T. Walters, D. Belov, and D. Hassabis (2018) Parallel WaveNet: fast high-fidelity speech synthesis. In Proceedings of the 35th International Conference on Machine Learning, ICML, pp. 3915–3923.
  • A. van den Oord, O. Vinyals, and K. Kavukcuoglu (2017) Neural discrete representation learning. In Advances in Neural Information Processing Systems 30, NeurIPS, pp. 6306–6315.

Appendix

Appendix A Architecture and hyperparameters

A.1 Autoregressive model architecture

We base our PixelCNN implementation on a PyTorch implementation of PixelCNN++ (Salimans et al., 2017) by GitHub user pclucas14 (https://github.com/pclucas14/pixel-cnn-pp, commit 16c8b2f). We make the following modifications.

Instead of the discretized mixture of logistics loss as described in (Salimans et al., 2017), we utilize the categorical distributions described in (van den Oord et al., 2016b), which allows us to model distributions with full autoregressive dependence. This is particularly useful when training the PixelCNN for discrete-latent autoencoders, as the number of input channels can be altered without substantial changes to the implementation. We model dependencies between channels as in the PixelCNN architecture, by masking convolutions so that the causal structure is preserved. That is, the output corresponding to an input location is conditioned on all previous rows, on all previous columns of the same row, and on all previous channels of the same spatial location.

The original implementation normalizes the input data to a fixed range. Instead, we follow (van den Oord et al., 2016b) and use a one-hot encoding for inputs. Additionally, we do not use weight normalization.

A.2 Forecasting module architecture

The forecasting module used in this work consists of a single strictly triangular convolution followed by a convolution whose number of output channels is equal to the number of data channels multiplied by the number of categories. The masked convolution is applied to the last activation of the up-left stack of the PixelCNN, h. We set the number of channels for this layer to 162 for image-space experiments, and 160 for latent-space experiments.
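The following is a sketch of such a module. The strictly triangular convolution is approximated here by a spatially masked convolution over the raster-scan order, and the layer sizes, the ReLU, and the 1×1 output convolution are illustrative assumptions rather than the exact implementation.

```python
import torch
import torch.nn as nn

class MaskedConv2d(nn.Conv2d):
    """2D convolution masked to the raster-scan past; with strict=True the current
    position is excluded, giving a strictly triangular dependence."""
    def __init__(self, *args, strict=True, **kwargs):
        super().__init__(*args, **kwargs)
        kH, kW = self.kernel_size
        mask = torch.zeros(1, 1, kH, kW)
        mask[:, :, :kH // 2] = 1.0                      # rows strictly above
        mask[:, :, kH // 2, :kW // 2] = 1.0             # same row, columns to the left
        if not strict:
            mask[:, :, kH // 2, kW // 2] = 1.0          # optionally include the current position
        self.register_buffer("mask", mask)

    def forward(self, x):
        return nn.functional.conv2d(x, self.weight * self.mask, self.bias,
                                    self.stride, self.padding)

class ForecastingHead(nn.Module):
    """Strictly triangular convolution on the ARM representation h, followed by a
    1x1 convolution with (steps_ahead * num_categories) output channels."""
    def __init__(self, h_channels=160, steps_ahead=1, num_categories=256):
        super().__init__()
        self.masked = MaskedConv2d(h_channels, h_channels, kernel_size=3, padding=1)
        self.out = nn.Conv2d(h_channels, steps_ahead * num_categories, kernel_size=1)

    def forward(self, h):
        return self.out(torch.relu(self.masked(h)))     # forecasting logits
```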

We experimented with variations of forecasting modules that use the (one-hot) input or the truncated Gumbel noise, obtained as described in Appendix B, as additional inputs. For the forecasting module capacity we considered, this did not lead to improved sampling performance.

A.3 Default hyperparameters

Explicit likelihood modeling

Hyperparameter settings for the PixelCNNs trained on image data are given in Table 4. We use the same parameters across all datasets to maintain consistency, and did not alter the architecture for likelihood performance. The exception is binary MNIST, where we observed strong overfitting if the model size was not reduced.

| Hyperparameter | Binary MNIST | Default |
| Learning rate | 0.0002 | 0.0002 |
| Learning rate decay | 0.999995 | 0.999995 |
| Batch size | 64 | 64 |
| Max iterations | 200000 | 200000 |
| Weight decay | 1e-6 | 1e-6 |
| Optimizer | Adam | Adam |
| Number of gated resnets | 2 | 5 |
| Filters per layer | 60 | 162 |
| Dropout rate | 0.5 | 0.5 |
| Nonlinearity | concat_elu | concat_elu |
| Forecasting modules | 20 | 1 |
| Forecasting filters | 60 | 162 |
| Forecasting loss weight | 0.01 | 0.01 |
Table 4: Hyperparameters for the trained PixelCNN models.

Latent space modeling

For the latent space experiments, we use an encoder and decoder with a bottleneck structure, and a PixelCNN to model the resulting latent space. The width of the encoder and decoder, i.e., the parameter that controls the number of channels at every layer, is 512 for all experiments. The loss function used is the Mean Squared Error, and the input data is normalized to a fixed range.

The encoder consists of the following layers. First, two convolutional layers with padding 1 and half width. Then, one strided convolution of half width with padding 1 and stride 2, followed by a similar layer of full width. We then apply two residual blocks (PyTorch BasicBlock implementation (He et al., 2016)). Finally, a convolution layer maps to the desired number of latent channels.

The decoder architecture mirrors the encoder architecture. First, a convolution layer maps from the (one-hot) latents to the desired width. Two residual blocks are applied, followed by a full width transpose convolution and a half width transpose convolution, both having the same parameters as their counterparts in the encoder. Lastly, two convolution layers of half width are applied, where the last layer has three output channels.

The latent space is quantized by taking the argmax over a softmax, and one-hot encoding the resulting latent variable. As quantization is non-differentiable, the gradient is obtained using a straight-through estimator, i.e., the softmax gradient is used in the backward pass. We use a latent space with spatial resolution 8×8, multiple channels, and a fixed number of categories per latent variable.
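A minimal sketch of this quantization step, assuming encoder logits of shape (B, K, H, W); names and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def quantize_straight_through(logits):
    """Quantize encoder logits (B, K, H, W) to one-hot latents with a straight-through gradient."""
    soft = F.softmax(logits, dim=1)                          # used in the backward pass
    index = soft.argmax(dim=1)                               # discrete latent code (B, H, W)
    hard = F.one_hot(index, num_classes=logits.shape[1])     # (B, H, W, K)
    hard = hard.permute(0, 3, 1, 2).float()                  # back to (B, K, H, W)
    # forward: hard one-hot; backward: gradient of soft (straight-through estimator)
    return hard + soft - soft.detach(), index
```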

Optimization parameters and the parameters of the PixelCNN that is used to model the latent space are kept the same as in the explicit likelihood setting, see Table 4. We do not train the autoencoder and ARM jointly. Instead, we train an autoencoder for 50000 iterations, then freeze the autoencoder weights and train an ARM on the latent space for an additional 200000 iterations.

A.4 Infrastructure

Software used includes PyTorch (Paszke et al., 2019) version 1.1.0, CUDA 10.0, and cuDNN 7.5.1. All sampling time measurements were obtained on a single Nvidia 1080Ti GPU using CUDA events, and runtime is only computed after calling torch.cuda.synchronize. Training was performed on Nvidia Tesla V100 GPUs, with the same software stack as the evaluation system.

Appendix B Posterior Reparametrization Noise

To condition forecasting modules on reparametrization noise while training on the data distribution, sample-noise pairs (x, ε) are needed. In principle, these can be created by sampling ε and computing the corresponding x using the ARM, i.e., by computing the autoregressive inverse. However, this process may be slow, and it does not allow joint training of the ARM and the forecasting modules. Alternatively, one can use the assumption that the model distribution sufficiently approximates the data distribution. In this case, (x, ε) pairs can be sampled using the data distribution and the posterior of the noise given x:

(12)

where δ denotes a Dirac delta peak on the output of the reparametrization, and p(ε | x) denotes the posterior of the noise given a sample x:

(13)

In the case of the Gumbel-Max reparametrization, the posterior Gumbel noise can be computed straightforwardly by using the notion that the maximum and the location of the maximum are independent (Maddison et al., 2014; Kool et al., 2019). First, we sample from the Gumbel distribution for the argmax locations, i.e., the locations that correspond to the observed sample x:

(14)

Subsequently, the remaining values can be sampled using truncated Gumbel distributions (Maddison et al., 2014; Kool et al., 2019). The truncation point is located at the maximum value, i.e., the Gumbel value at the argmax location:

(15)

Here, φ_{t,k} denotes the logit from the model distribution of dimension t for category k.

To summarize, a sample pair (x, ε) is created by first sampling x from the data, and then sampling ε using the Gumbel and truncated Gumbel distributions as described above. This technique allows simultaneous training of the ARM and of forecasting modules conditioned on Gumbel noise, without the need to create an additional sampling dataset.
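A sketch of this procedure, using the inverse-CDF form of the truncated Gumbel distribution; tensor names and shapes are illustrative, and numerical stabilization is kept minimal.

```python
import torch

def posterior_gumbel_noise(log_probs, x):
    """Sample Gumbel noise consistent with observed categories x (Appendix B sketch).

    log_probs: (d, K) log-probabilities (or unnormalized logits) of the model distribution.
    x:         (d,)   observed categories, e.g. sampled from the data distribution.
    Returns standard-Gumbel-form noise eps of shape (d, K) such that
    argmax_k(log_probs[t, k] + eps[t, k]) == x[t] for every t.
    """
    d, K = log_probs.shape
    u = torch.rand(d, K).clamp(1e-10, 1 - 1e-10)
    g0 = -torch.log(-torch.log(u))                               # standard Gumbel noise
    # the maximum of the shifted Gumbels is Gumbel(logsumexp(log_probs)) and sits at x
    top = torch.logsumexp(log_probs, dim=-1) + g0.gather(1, x[:, None]).squeeze(1)
    # remaining shifted values are truncated Gumbels with truncation point `top`
    shifted = log_probs - torch.log(
        torch.exp(log_probs - top[:, None]) + torch.exp(-g0)
    )
    shifted.scatter_(1, x[:, None], top[:, None])                # place the maximum at x
    eps = shifted - log_probs                                    # back to standard Gumbel noise
    return eps
```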

Appendix C Generated samples

We show 16 samples for each of the models trained with forecasting modules, as well as forecasting mistakes. To find forecasting mistakes made by ARM fixed-point iteration, we simply disable the forecasting modules during sampling. All samples were generated using the same random seed (10) and were not cherry-picked for perceptual quality or sampling performance.

(a) Samples from the model distribution.
(b) Forecasting mistakes by the forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 7: Samples from the 1-bit ARM and forecasting mistakes.
(a) Samples from the model distribution.
(b) Forecasting mistakes by the forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 8: Samples from the 8-bit ARM and forecasting mistakes.

The datasets used are Binary MNIST (Larochelle and Murray, 2011), SVHN (Netzer et al., 2011), CIFAR10 (Krizhevsky et al., 2009), and ImageNet32 (van den Oord et al., 2016b). Samples generated by the model trained on binary MNIST are shown in Figure 7, Figure 8 shows SVHN 8-bit samples, and Figures 9 and 10 show samples generated by the ARMs trained on CIFAR10 for 5-bit and 8-bit data, respectively.

(a) Samples from the model distribution.
(b) Forecasting mistakes by the forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 9: Samples from the 5-bit ARM and forecasting mistakes.
(a) Samples from the model distribution.
(b) Forecasting mistakes by the forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 10: Samples from the 8-bit ARM and forecasting mistakes.

Samples from the VAE are generated by decoding a sample from the ARM trained on the latent space. That is, we first generate a latent variable z from the trained ARM prior. This sample is then decoded to image space using the decoder as D(z). We show samples for SVHN in Figure 11, for CIFAR10 in Figure 12, and for Imagenet32 in Figure 13.

(a) Decoded samples D(z), where z is sampled from the ARM prior.
(b) Forecasting mistakes by learned forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 11: VAE samples, and forecasting mistakes in latent space.
(a) Decoded samples D(z), where z is sampled from the ARM prior.
(b) Forecasting mistakes by learned forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 12: VAE samples, and forecasting mistakes in latent space.
(a) Decoded samples D(z), where z is sampled from the ARM prior.
(b) Forecasting mistakes by learned forecasting modules.
(c) Forecasting mistakes by fixed-point iteration.
Figure 13: VAE samples, and forecasting mistakes in latent space.