1 Introduction
Autoregressive models are a family of exact likelihood-based generative models that represent the joint distribution of data $x = (x_1, \dots, x_N)$ as a product of conditionals $p_\theta(x) = \prod_{i=1}^{N} p_\theta(x_i \mid x_{<i})$. Neural network models in this family have achieved state-of-the-art log likelihoods on high-dimensional image and video datasets
(van den Oord et al., 2016a; Chen et al., 2018; Menick and Kalchbrenner, 2018; Parmar et al., 2018; Child et al., 2019; Weissenborn et al., 2019; Salimans et al., 2017; Kalchbrenner et al., 2017; Uria et al., 2016; Parikh et al., 2016; Theis and Bethge, 2015; van den Oord et al., 2016b) due to architectural innovations that enable the following capabilities:
Large, high-information-bandwidth receptive fields for each pixel $x_i$, capable of expressing long-range dependencies over previous pixels $x_{<i}$, and

Computationally efficient, vectorizable computation of the log likelihood and its gradient.
Autoregressive model architectures that can read long-range dependencies over large receptive fields are able to express all joint distributions over the data. Meanwhile, architectures that admit fast log likelihood gradient computation are suitable for training using a stochastic gradient method on a maximum likelihood objective, a straightforward and stable training procedure for generative models.
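As a concrete illustration of the maximum likelihood objective, the chain rule turns $\log p(x)$ into a sum of per-position log conditionals. The sketch below assumes a model has already produced one 256-way logit vector per pixel under the raster-scan ordering (the `logits` input and the helper names are hypothetical) and converts the summed negative log-likelihood into bits per dimension, the unit used in the result tables.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def bits_per_dim(logits: np.ndarray, x: np.ndarray) -> float:
    """Negative log-likelihood of x in bits per dimension.

    logits: shape x.shape + (256,); logits[idx] is assumed to parameterize
            p(x[idx] | x_{<idx}) under the raster-scan ordering.
    x:      integer pixel intensities in [0, 255].
    """
    logp = log_softmax(logits)
    # Chain rule: log p(x) = sum_i log p(x_i | x_{<i}); gather the observed values.
    observed = np.take_along_axis(logp, x[..., None], axis=-1)[..., 0]
    nll_nats = -observed.sum()
    return float(nll_nats / (x.size * np.log(2.0)))

rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(4, 4))
uniform_logits = np.zeros(x.shape + (256,))
print(bits_per_dim(uniform_logits, x))  # a uniform model costs log2(256) = 8 bits per dimension
```

Because every term is a log conditional read off a single forward pass, the whole objective is one vectorized reduction, which is what makes its gradient cheap to compute when the architecture evaluates all conditionals in parallel.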
These desiderata make self-attention a compelling building block for autoregressive model architectures. Self-attention is a neural network operation that transforms a sequence $(x_1, \dots, x_N)$ into a sequence $(y_1, \dots, y_N)$, where each $y_i$ depends on all $x_j$, by way of a single vectorizable computation (Vaswani et al., 2017). Self-attention is remarkably effective at learning long-range dependencies between data dimensions, and neural networks that incorporate self-attention in their designs are state-of-the-art on many tasks, from language modelling and machine translation to image and video modelling (Parmar et al., 2018; Child et al., 2019).
But the power of self-attention comes at the price of computational complexity. The memory and computation it consumes grow quadratically with the sequence length, making it prohibitively expensive to apply self-attention directly to long sequences. In the case of autoregressive models of multidimensional tensors such as images or videos, the aim to capture large receptive fields in multiple dimensions further exacerbates the problem, as even a modest number of receptive field steps in each dimension can encompass a large total number of locations. Various approaches have been proposed to alleviate this difficulty at the cost of either limiting the receptive field or requiring operations that may not be broadly available on GPUs or TPUs.
We propose the Axial Transformer, a simple yet effective self-attention-based autoregressive model for data organized as multidimensional tensors. Rather than applying attention to a flattened string of tensor elements, our model applies attention along a single axis of the tensor without flattening; we refer to this as “axial attention.” Since the length of any single axis (that is, the height or width of an image) is typically much smaller than the total number of elements, an axial attention operation enjoys a significant saving in computation and memory over standard self-attention: for a $d$-dimensional tensor with shape $N = N^{1/d} \times \cdots \times N^{1/d}$, axial attention saves a factor of $O(N^{(d-1)/d})$ in resources over standard self-attention.
Our Axial Transformer architecture allows the majority of the context to be embedded with a high degree of parallelism without introducing conditional independence assumptions among any of the locations, and it is also amenable to a simple-to-implement fast sampling procedure. To sample one row of an image, the Axial Transformer runs an autoregressive Transformer over that row only, without re-embedding pixels from previous rows. We structure the Axial Transformer, however, so that it always defines a fully expressive joint distribution: no dependencies on previous pixels are ever lost.
We evaluate Axial Transformers on image and video modelling benchmarks. We show that the Axial Transformer achieves state-of-the-art results on ImageNet32 and ImageNet64. We also show that, simply by stacking a video along the channel dimension, the Axial Transformer can be applied directly to the channel-stacked video with almost no modification. On the BAIR Robot Pushing benchmark, the Axial Transformer significantly outperforms previous results without using an architecture specially designed for videos. The generated samples on these datasets are of the expected high quality.
Axial Transformers do not require subroutines for GPUs or TPUs that may exhibit unfavorable memory bandwidth and computation tradeoffs. Axial Transformers are simple to implement using efficient operations that are widely available in deep learning frameworks (primarily dense-dense MatMuls). An open source implementation of our models is available at anonymized URL.
2 Background
To set the stage for our discussion, we first review self-attention and its computational resource requirements in the context of autoregressive modeling. A self-attention layer takes as input a length-$N$ sequence of $D$-dimensional embeddings $X$ (an $N \times D$ matrix) and produces an output sequence $Y$ (also an $N \times D$ matrix) via:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V, \qquad A = \mathrm{softmax}\left(QK^\top / \sqrt{D}\right), \quad Y = AV.$$

$W_Q$, $W_K$, and $W_V$ are $D \times D$ parameter matrices responsible for projecting the entries of the sequence $X$ into queries, keys, and values, respectively. Each entry of the output sequence $Y$ is a linear combination of values in $V$ weighted by the attention matrix $A$, which itself is computed from similarities between all pairs of query and key vectors. Both the expressive power and the resource cost of self-attention come from computing $A$ and $Y$: it takes $O(N^2)$ time and space to compute the pairwise similarities between $Q$ and $K$ and to compute the linear combination of $V$ vectors. This quadratic complexity makes it impractical to apply self-attention to images and videos directly as flattened vectors: a small $32 \times 32 \times 3$ image has 3072 dimensions. Sequences such as these are too long for self-attention, so attempts to scale self-attention to these modalities generally involve restricting these sequence lengths in a modality-aware manner while attempting to preserve modeling performance.
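A minimal NumPy sketch of the single-head layer this paragraph describes, with random weights, no masking, and no multi-head split. The $1/\sqrt{D}$ scaling is the standard Transformer choice and is an assumption here; the function names are illustrative.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """Row-wise softmax, shifted by the max for numerical stability."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention.

    X: (N, D) input sequence; W_q, W_k, W_v: (D, D) projections.
    Returns the output Y (N, D) and the attention matrix A (N, N).
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    A = softmax(Q @ K.T / np.sqrt(X.shape[1]))  # N x N pairwise similarities: the O(N^2) term
    return A @ V, A

rng = np.random.default_rng(0)
N, D = 6, 8
X = rng.normal(size=(N, D))
W_q, W_k, W_v = (rng.normal(size=(D, D)) for _ in range(3))
Y, A = self_attention(X, W_q, W_k, W_v)
print(Y.shape, A.shape)  # (6, 8) (6, 6)
```

For a flattened $32 \times 32 \times 3$ image, $N = 3072$ and $A$ alone holds $3072^2 = 9{,}437{,}184$ entries per head per layer, which is the cost that restricting or factorizing attention tries to avoid.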
One strategy is to restrict the conditioning context to a carefully designed small subset of the data dimensions. While this reduces the cost of attention, which is only performed over these small subsets instead of the full data, the model can no longer express all joint distributions over the data. Parmar et al. (2018) propose image models with conditioning context restricted to a small window of the full image, but the implementation requires redundant data copies to extract and process these windows. Weissenborn et al. (2019) similarly scale video autoregressive models by restricting the context, again preventing their model from expressing all joint distributions over pixels. Our models do not restrict context and hence we obtain better log likelihoods, as we will see in section 4.
A different strategy is to stack multiple sparse attention layers, each with restricted context for computational efficiency, but arranged so that stacking these layers yields a full-context model. Child et al. (2019) propose two sparse attention patterns with this property. However, the architecture they propose that works best for images (the Strided Sparse Transformer) requires custom sparse attention GPU kernels to implement a specific block-sparse variant of matrix-matrix multiplication. The model cannot easily be implemented on other hardware such as TPUs.
See table 1 for a summary of these architecture design tradeoffs. Our goal in this paper is to design attention-based autoregressive models that attain the best of all worlds. Our Axial Transformer, described in subsequent sections, has a full conditioning context, so its ability to express joint distributions is never limited. The Axial Transformer also does not require any redundant data copies or custom kernels to implement efficiently. Indeed, we designed, and will make open source, an efficient implementation that uses only standard operations in deep learning libraries.
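Table 1: Summary of architecture design tradeoffs.

Model | Full receptive field | Attention faster than $O(N^2)$ | Needs no custom kernels | Semi-parallel context aggregation
Transformer (Vaswani et al., 2017) | yes | no | yes | no
Image Transformer (Parmar et al., 2018) | no | yes | yes | no
Block Transformer (Weissenborn et al., 2019) | no | yes | yes | no
Strided Sparse Transformer (Child et al., 2019) | yes | yes | no | no
Axial Transformer (ours) | yes | yes | yes | yes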
3 Axial Transformers
We now describe Axial Transformers, our self-attention-based autoregressive models for high-dimensional data tensors. We describe the basic building block in section 3.1 and then complete the description into a full autoregressive model in section 3.2.
3.1 Axial attention
We first introduce our basic building block for developing self-attention-based autoregressive models for high-dimensional data tensors.
The proposed approach does not change the original shape of the multidimensional data tensor and performs a masked or unmasked attention over a single axis of the tensor at a time. We call this operation axial attention, denoted by $\mathrm{Attention}_k(x)$. It performs attention over axis $k$ of the tensor $x$, mixing information along axis $k$ while keeping information along other axes independent. It is straightforward to implement: axial attention over axis $k$ can be implemented by transposing all axes except $k$ to the batch axis, calling standard attention as a subroutine, then undoing the transpose (an alternative is to use the einsum operation available in most deep learning libraries).
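A minimal NumPy sketch of the transpose-to-batch recipe, assuming an embedding tensor of shape `(..., D)` with the feature axis last and Python's 0-based axis numbering (so `axis=0` of an $H \times W \times D$ image corresponds to the column attention $\mathrm{Attention}_1$ of the text). The function names are illustrative, and an einsum version of column attention is included to show the alternative agrees.

```python
import numpy as np

def softmax(z, axis=-1):
    """Softmax shifted by the max for numerical stability."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def batched_attention(x, W_q, W_k, W_v):
    """Standard single-head self-attention over axis -2 of a (B, L, D) batch."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(x.shape[-1])  # (B, L, L)
    return softmax(scores) @ v

def axial_attention(x, axis, W_q, W_k, W_v):
    """Attention along one non-feature axis of x (shape (..., D)); all other axes act as batch."""
    moved = np.moveaxis(x, axis, -2)                        # transpose: attended axis next to D
    batch_shape, seq_shape = moved.shape[:-2], moved.shape[-2:]
    out = batched_attention(moved.reshape(-1, *seq_shape), W_q, W_k, W_v)
    return np.moveaxis(out.reshape(*batch_shape, *seq_shape), -2, axis)  # undo the transpose

def column_attention_einsum(x, W_q, W_k, W_v):
    """Einsum alternative for axis 0 (columns) of an (H, W, D) tensor, with no explicit transpose."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = np.einsum("iwd,jwd->wij", q, k) / np.sqrt(x.shape[-1])  # one H x H map per column w
    return np.einsum("wij,jwd->iwd", softmax(scores), v)

rng = np.random.default_rng(0)
height, width, dim = 5, 7, 4
x = rng.normal(size=(height, width, dim))
W_q, W_k, W_v = (rng.normal(size=(dim, dim)) for _ in range(3))
cols = axial_attention(x, 0, W_q, W_k, W_v)   # column attention (Attention_1 in the text)
rows = axial_attention(x, 1, W_q, W_k, W_v)   # row attention (Attention_2 in the text)
print(np.allclose(cols, column_attention_einsum(x, W_q, W_k, W_v)))  # True
```

The transpose version reuses any existing attention routine unchanged, while the einsum version avoids materializing the transposed tensor; which is faster depends on the framework and hardware.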
When the data is an image, we call $\mathrm{Attention}_1$ column attention, as it mixes information within columns while keeping separate columns independent. We call $\mathrm{Attention}_2$ row attention for analogous reasons. Axial attention on a square image of size $N = S \times S$ performs attention on $S$ sequences of length $S$, a total of $O(S \cdot S^2) = O(N\sqrt{N})$ computation, which is an $O(\sqrt{N})$ savings in computation over standard self-attention. In general, for a $d$-dimensional tensor with $N = S^d$, axial attention saves $O(N^{(d-1)/d})$ computation over standard attention. Of course, a single layer of axial attention along some axis $k$ does not have the full receptive field since it covers a single axis, but we will see in section 3.2 that stacking two axial attention layers allows the model to obtain a global receptive field.
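The savings can be checked by counting query-key pairs, a simple cost model that ignores the linear projections and constant factors (an assumption of this sketch). For a cube with side $S$ and $N = S^d$ elements the ratio comes out to $N/S = N^{(d-1)/d}$, i.e. $\sqrt{N}$ for square images.

```python
def attention_pairs(shape, axis=None):
    """Query-key pairs scored by one attention layer over a tensor of the given shape.

    axis=None: standard attention over the flattened tensor, N^2 pairs.
    axis=k:    axial attention along axis k: N / L independent sequences of
               length L = shape[k], each scoring L^2 pairs, i.e. N * L in total.
    """
    n = 1
    for size in shape:
        n *= size
    if axis is None:
        return n * n
    length = shape[axis]
    return (n // length) * length * length

for shape in [(32, 32), (64, 64), (16, 16, 16)]:
    full, axial = attention_pairs(shape), attention_pairs(shape, axis=0)
    print(shape, full, axial, full // axial)
# (32, 32) 1048576 32768 32
# (64, 64) 16777216 262144 64
# (16, 16, 16) 16777216 65536 256
```

The $64 \times 64$ image and the $16 \times 16 \times 16$ volume have similar total cost under full attention, but the three-dimensional tensor benefits more from axial attention because its axes are shorter relative to $N$.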
It will be important for us to also define $\mathrm{MaskedAttention}_k$ to be the causally masked variant of $\mathrm{Attention}_k$: component $i$ of the result of $\mathrm{MaskedAttention}_k(x)$ along axis $k$ depends only on components $1, \dots, i$ of $x$ along axis $k$. The receptive fields of these attention patterns, both unmasked and masked, are illustrated in fig. 2. We will use these masked blocks to build our autoregressive model in section 3.2.
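A minimal sketch of the masked variant, assuming the same `(..., D)` layout and a lower-triangular mask that lets position $i$ attend to positions $0, \dots, i$ (0-based, i.e. components $1, \dots, i$ in the text's notation). The check at the end edits one column and confirms that earlier positions along a row cannot see it.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def masked_axial_attention(x, axis, W_q, W_k, W_v):
    """Causal single-head attention along `axis` of x (shape (..., D)).

    Output position i along `axis` attends only to input positions 0..i.
    """
    moved = np.moveaxis(x, axis, -2)                              # (..., L, D)
    q, k, v = moved @ W_q, moved @ W_k, moved @ W_v
    length = moved.shape[-2]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(x.shape[-1])    # (..., L, L)
    causal = np.tril(np.ones((length, length), dtype=bool))       # True where key j <= query i
    scores = np.where(causal, scores, -np.inf)                    # masked keys get zero weight
    return np.moveaxis(softmax(scores) @ v, -2, axis)

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 6, 3))                                    # (H, W, D)
W_q, W_k, W_v = (rng.normal(size=(3, 3)) for _ in range(3))
out = masked_axial_attention(x, 1, W_q, W_k, W_v)                 # masked row attention
x_edit = x.copy()
x_edit[:, 4] += 1.0                                               # change column 4 in every row
out_edit = masked_axial_attention(x_edit, 1, W_q, W_k, W_v)
print(np.allclose(out[:, :4], out_edit[:, :4]))                   # True: positions 0..3 cannot see column 4
```

The diagonal is always unmasked, so every softmax row has at least one finite score and the `-inf` entries simply receive zero weight.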
Axial attention can be used within standard Transformer layers in a straightforward manner to produce Axial Transformer layers. The basic building blocks are the same as those found in the standard Transformer architecture:

$\mathrm{LayerNorm}$: layer normalization (Ba et al., 2016), and

$\mathrm{Dense}_D$: a dense layer operating over the last axis of the input $x$. The letter $D$ denotes the dimension of the output activations. If the input has shape $H \times W \times C$, then this operation is identical to a $1 \times 1$ convolution, and the output has shape $H \times W \times D$.
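The equivalence between a dense layer over the last axis and a $1 \times 1$ convolution can be seen directly: both apply the same $C \to D$ linear map independently at every pixel. A small NumPy check, with hypothetical helper names:

```python
import numpy as np

def dense(x, kernel, bias):
    """Dense over the last axis: (H, W, C) @ (C, D) + (D,) -> (H, W, D)."""
    return x @ kernel + bias

def conv_1x1(x, kernel, bias):
    """Explicit 1x1 convolution: the same C -> D linear map applied at every pixel."""
    height, width, _ = x.shape
    out = np.empty((height, width, kernel.shape[1]))
    for i in range(height):
        for j in range(width):
            out[i, j] = x[i, j] @ kernel + bias
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5, 3))                        # H x W x C
kernel, bias = rng.normal(size=(3, 8)), rng.normal(size=8)
y = dense(x, kernel, bias)
print(y.shape, np.allclose(y, conv_1x1(x, kernel, bias)))  # (4, 5, 8) True
```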
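We use these to define ResNet axial attention blocks operating on tensors of $D$-dimensional embeddings (Vaswani et al., 2017; Child et al., 2019):

$$
\begin{aligned}
\mathrm{FeedforwardBlock}(x) &= x + \mathrm{Dense}_D\big(\mathrm{Nonlinearity}(\mathrm{Dense}_{D'}(\mathrm{LayerNorm}(x)))\big)\\
\mathrm{AttentionBlock}_k(x) &= x + \mathrm{Dense}_D\big(\mathrm{Attention}_k(\mathrm{LayerNorm}(x))\big)\\
\mathrm{TransformerBlock}_k(x) &= \mathrm{FeedforwardBlock}\big(\mathrm{AttentionBlock}_k(x)\big)
\end{aligned}
$$

$D'$ is chosen to be some constant factor larger than $D$, from 1 to 4 (Vaswani et al., 2017). We also define a $\mathrm{MaskedTransformerBlock}_k$ using $\mathrm{MaskedAttention}_k$ in place of $\mathrm{Attention}_k$.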
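The ResNet axial attention blocks can be sketched in NumPy as follows. This is an illustrative single-head, bias-free version under stated assumptions: ReLU stands in for the nonlinearity, LayerNorm has no learned gain or bias, and `init_params` with its random weights is hypothetical. The block formulas are repeated in the comments.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis; the learned gain and bias are omitted for brevity."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def axial_attention(x, axis, p, masked=False):
    """(Masked) single-head attention along `axis` of x with shape (..., D)."""
    moved = np.moveaxis(x, axis, -2)
    q, k, v = moved @ p["W_q"], moved @ p["W_k"], moved @ p["W_v"]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(x.shape[-1])
    if masked:
        length = moved.shape[-2]
        scores = np.where(np.tril(np.ones((length, length), dtype=bool)), scores, -np.inf)
    return np.moveaxis(softmax(scores) @ v, -2, axis)

def init_params(d, d_ff, rng):
    """Hypothetical bias-free parameters for one block; d_ff plays the role of D'."""
    def w(m, n):
        return rng.normal(scale=1.0 / np.sqrt(m), size=(m, n))
    return {"W_q": w(d, d), "W_k": w(d, d), "W_v": w(d, d), "W_o": w(d, d),
            "W_1": w(d, d_ff), "W_2": w(d_ff, d)}

def attention_block(x, axis, p, masked=False):
    # AttentionBlock_k(x) = x + Dense_D(Attention_k(LayerNorm(x)))
    return x + axial_attention(layer_norm(x), axis, p, masked) @ p["W_o"]

def feedforward_block(x, p):
    # FeedforwardBlock(x) = x + Dense_D(ReLU(Dense_D'(LayerNorm(x))))
    return x + np.maximum(layer_norm(x) @ p["W_1"], 0.0) @ p["W_2"]

def transformer_block(x, axis, p, masked=False):
    # TransformerBlock_k(x) = FeedforwardBlock(AttentionBlock_k(x));
    # masked=True gives MaskedTransformerBlock_k.
    return feedforward_block(attention_block(x, axis, p, masked), p)

rng = np.random.default_rng(0)
D = 8
params = init_params(D, 4 * D, rng)                   # D' = 4D, the upper end of the 1-4x range
h = rng.normal(size=(4, 6, D))                        # (H, W, D) embeddings
out = transformer_block(h, 1, params, masked=True)    # a masked row block
print(out.shape)  # (4, 6, 8)
```

Because LayerNorm and both dense layers act on each position separately, the only cross-pixel mixing happens inside the attention call, so a masked block inherits exactly the causal receptive field of its masked attention.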
Operations similar to unmasked axial attention have been proposed in other contexts in computer vision (Huang et al., 2019). Our focus in forthcoming sections is the use of masked axial attention and its utility in autoregressive image modeling, which is not explored in that work.
3.2 Axial Transformers
We now describe Axial Transformers, our axial attention-based autoregressive models for images and videos. We will use the axial attention operations described in section 3.1 as building blocks in a multi-layer autoregressive model of the form $p_\theta(x) = \prod_{i=1}^{N} p_\theta(x_i \mid x_{<i})$ following the raster scan ordering of pixels. We will accomplish this by building an autoregressive model over rows (section 3.2.1), then conditioning each row on previous rows (section 3.2.1), then further conditioning on previous channels and frames (section 3.2.2). Decomposing the model in this manner also leads to a simple, fast, and partly parallel sampling procedure (section 3.2.1).
3.2.1 A model for single-channel images
We begin with an autoregressive model for a single-channel image $x$ with shape $H \times W$, with each pixel taking an integer value in $[0, 255]$ representing its intensity. As is standard practice with Transformers, pixel intensities are first embedded into an $H \times W \times D$ tensor of $D$-dimensional embeddings, which we call $h$. The architecture’s responsibility is to transform $h$ into an $H \times W \times 256$ tensor of logits suitable for classification or sampling. These logits must depend only on previous pixels in the input $x$ along the raster scan ordering to ensure that the architecture defines a valid autoregressive model.
Inner Decoder: a row-wise model
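Our idea is to begin with masked row attention layers to create a “row-wise” model:

$$
\begin{aligned}
h &\leftarrow \mathrm{Embed}(x)\\
h &\leftarrow \mathrm{ShiftRight}(h) + \mathrm{PositionEmbeddings}\\
h &\leftarrow \mathrm{MaskedTransformerBlock}_2(h) \times L_{\mathrm{row}}
\end{aligned}
$$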
Here, $L_{\mathrm{row}}$ is the number of masked row attention blocks applied to $h$. PositionEmbeddings is an $H \times W \times D$ tensor of position embeddings that inform the attention layers of the position. For parameter efficiency we use “additively factorized” position embeddings, meaning that we parameterize them as a broadcasted sum of $H \times D$ embeddings for rows and $W \times D$ embeddings for columns.
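The additively factorized position embeddings amount to a single broadcasted addition. A minimal sketch with hypothetical random tables, which also compares the parameter count with a full $H \times W \times D$ table:

```python
import numpy as np

def factorized_position_embeddings(row_table, col_table):
    """Additively factorized embeddings: (H, D) rows + (W, D) columns -> (H, W, D)."""
    return row_table[:, None, :] + col_table[None, :, :]

rng = np.random.default_rng(0)
H, W, D = 32, 32, 16
row_table = rng.normal(size=(H, D))
col_table = rng.normal(size=(W, D))
pe = factorized_position_embeddings(row_table, col_table)
print(pe.shape)                                    # (32, 32, 16)
print(row_table.size + col_table.size, H * W * D)  # 1024 16384: factorized vs. full table
```

The saving grows with image size, since the factorized form needs $(H + W) \cdot D$ parameters instead of $H \cdot W \cdot D$.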
The operation ShiftRight shifts the input right by one pixel, which has the effect of shifting the receptive field left by one pixel. This ensures that the masked row attention layers exclude the current pixel from their receptive field, which is crucial for the architecture to define a correct autoregressive model.
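A small NumPy check of this argument, assuming zero-filling for the vacated first column (the padding value is an assumption) and using a cumulative sum as a hypothetical stand-in for masked row attention, since both let pixel $j$ see columns $0, \dots, j$ of its row. After the shift, an edit to pixel $(1, 2)$ only reaches later pixels in row 1.

```python
import numpy as np

def shift_right(h):
    """Shift an (H, W, D) tensor one pixel right along W, zero-filling the first column."""
    return np.pad(h, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]

def causal_row_mix(h):
    """Stand-in for masked row attention: pixel (i, j) mixes columns 0..j of row i (inclusive)."""
    return np.cumsum(h, axis=1)

rng = np.random.default_rng(0)
h = rng.normal(size=(3, 5, 2))
h_edit = h.copy()
h_edit[1, 2] += 1.0                                  # perturb pixel (1, 2)
before = causal_row_mix(shift_right(h))
after = causal_row_mix(shift_right(h_edit))
changed = ~np.isclose(before, after).all(axis=-1)    # which output pixels saw the edit
print(changed.astype(int))
# [[0 0 0 0 0]
#  [0 0 0 1 1]
#  [0 0 0 0 0]]
```

Without the shift, the inclusive mask would let output $(1, 2)$ see its own input, which would leak the pixel being predicted.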
As this model employs row attention only, it enjoys the computational efficiency benefits described in section 3.1. However, it clearly does not define a full-context model, because each location in the output does not depend on input pixels in previous rows. If we were to use the resulting $h$ as logits for pixel intensity prediction, we would obtain a set of $H$ independent autoregressive models $p(x_{i,j} \mid x_{i,1}, \dots, x_{i,j-1})$, one for each row $i \in [1, H]$, not a single autoregressive model with full context. We address this issue next.
Outer Decoder: capturing the rows above
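Each pixel $x_{i,j}$ in the aforementioned model already depends on previous pixels in its own row, $x_{i,<j}$. We just need to make it depend on all previous rows $x_{<i,:}$ too. So, we insert unmasked row and masked column layers at the beginning of the model as follows (the newly inserted operations are those that compute and use $u$):

$$
\begin{aligned}
h &\leftarrow \mathrm{Embed}(x)\\
u &\leftarrow h + \mathrm{PositionEmbeddings}\\
u &\leftarrow \mathrm{MaskedTransformerBlock}_1(\mathrm{TransformerBlock}_2(u)) \times L_{\mathrm{upper}}/2\\
h &\leftarrow \mathrm{ShiftDown}(u) + \mathrm{ShiftRight}(h) + \mathrm{PositionEmbeddings}\\
h &\leftarrow \mathrm{MaskedTransformerBlock}_2(h) \times L_{\mathrm{row}}
\end{aligned}
$$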
The tensor $u$ represents context captured above the current pixel. It is computed by unmasked row and masked column attention layers, repeated to a total of $L_{\mathrm{upper}}$ layers to increase model capacity, which make $u$ cover the receptive field at all rows above and including the current pixel. The ShiftDown operation shifts $u$ down one pixel, which shifts its receptive field up one pixel. Thus we have a context that captures all pixels above while excluding the current row, which we add to $h$ as input to the masked row layers. We have thus converted the row-wise model into a fully expressive autoregressive model that captures not only pixels in the current row but also those above.
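To see that ShiftDown and ShiftRight together give exactly the raster-scan receptive field, the sketch below replaces every attention layer with a hypothetical linear stand-in that has the same receptive field (a row sum for unmasked row attention, cumulative sums for the masked layers) and finds, by unit perturbations, which input pixels influence one output pixel. The stand-ins are not attention; they only mirror which pixels each layer can see, and zero-filling in the shifts is assumed.

```python
import numpy as np

def shift_right(h):
    """Shift an (H, W, D) tensor one pixel right along W, zero-filling the first column."""
    return np.pad(h, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]

def shift_down(h):
    """Shift an (H, W, D) tensor one pixel down along H, zero-filling the first row."""
    return np.pad(h, ((1, 0), (0, 0), (0, 0)))[:-1, :, :]

# Hypothetical linear stand-ins with the same receptive fields as the attention layers.
def unmasked_row_mix(h):   # pixel (i, j) sees all of row i
    return np.broadcast_to(h.sum(axis=1, keepdims=True), h.shape)

def masked_col_mix(h):     # pixel (i, j) sees rows 0..i of column j
    return np.cumsum(h, axis=0)

def masked_row_mix(h):     # pixel (i, j) sees columns 0..j of row i
    return np.cumsum(h, axis=1)

def toy_axial_decoder(x):
    h = x                                       # stand-in for Embed(x)
    u = masked_col_mix(unmasked_row_mix(h))     # outer decoder: rows 0..i
    h = shift_down(u) + shift_right(h)          # rows < i, plus pixels (i, < j)
    return masked_row_mix(h)                    # inner decoder

def receptive_field(f, shape, i, j):
    """Which input pixels change output pixel (i, j), found by unit perturbations."""
    base = f(np.zeros(shape))[i, j]
    deps = np.zeros(shape[:2], dtype=bool)
    for a in range(shape[0]):
        for b in range(shape[1]):
            x = np.zeros(shape)
            x[a, b] = 1.0
            deps[a, b] = not np.allclose(f(x)[i, j], base)
    return deps

print(receptive_field(toy_axial_decoder, (4, 5, 1), 2, 3).astype(int))
# [[1 1 1 1 1]
#  [1 1 1 1 1]
#  [1 1 1 0 0]
#  [0 0 0 0 0]]
```

Output $(2, 3)$ depends on every pixel in rows 0 and 1 (through the shifted-down context) and on pixels $(2, 0)$ to $(2, 2)$ (through the shifted-right row), which is exactly the set of pixels preceding it in raster order.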
Following standard practice, we pass the final $h$ through layer normalization and a final dense layer to produce logits with shape $H \times W \times 256$. The logits at each location depend on all previous pixel locations in the raster scan ordering.
Semi-Parallel Sampling
Naive implementations of sampling from sequential models are notoriously slow because they require re-evaluating the entire network to sample each location. In the case of our model for a square image of size $\sqrt{N} \times \sqrt{N}$, each network evaluation takes $O(N\sqrt{N})$ time, so sampling the whole image would take $O(N^2\sqrt{N})$, which is far too large.
Fortunately, our architecture is amenable to a particularly simple implementation of faster sampling that computes large sections of the model in parallel (see Figure 1). Pseudocode is as follows:

For each row $i \in [1, H]$:
    Compute the upper context $u$ including information about all $x_{<i,*}$ using the upper layers
    For each column $j \in [1, W]$:
        Sample $x_{i,j}$ conditioned on $u$ and prior elements of row $i$ ($x_{i,<j}$).

Because the row-wise layers are independent over rows (they depend on other rows only through the upper context, as explained in section 3.2.1), sampling one row can be accomplished by evaluating the row-wise layers for that one row only, completely ignoring other rows. Thus, in one row of $\sqrt{N}$ pixels, each pixel can be sampled in $O(N)$ (the cost of evaluating the row-wise layers on a row of length $\sqrt{N}$), so all $N$ pixels can be sampled in $O(N^2)$. Before each of the $\sqrt{N}$ rows can be sampled, the upper context must be computed in $O(N\sqrt{N})$, for a total of $O(N^2)$ over the course of all rows. Thus we arrive at $O(N^2)$ in total, which is $\sqrt{N}$ faster than the naive implementation. To our knowledge, sampling speedups of this type are not possible with contemporary work on scaling Transformers to images and videos (Child et al., 2019; Weissenborn et al., 2019).
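The sampling loop can be exercised with a runnable sketch, assuming a hypothetical model object that exposes the outer decoder (`upper_context`) and the row-wise inner decoder (`row_logits`) as separate calls. The toy `ToyAxialModel` uses fixed random weights and cumulative sums in place of attention, so only the control flow and the two shifts mirror the Axial Transformer; the call counter makes the cost structure visible.

```python
import numpy as np

class ToyAxialModel:
    """Hypothetical stand-in exposing the two sub-networks the sampling loop needs."""

    def __init__(self, height, width, levels=256, seed=0):
        self.height, self.width, self.levels = height, width, levels
        rng = np.random.default_rng(seed)
        self.w_ctx = rng.normal(scale=1e-3, size=levels)   # toy weights, not trained
        self.w_row = rng.normal(scale=1e-3, size=levels)
        self.calls = {"upper": 0, "row": 0}

    def upper_context(self, x):
        """Outer decoder + ShiftDown: context for row i built from rows < i only."""
        self.calls["upper"] += 1
        above = np.cumsum(x, axis=0, dtype=float)                    # rows <= i
        above = np.vstack([np.zeros((1, self.width)), above[:-1]])   # ShiftDown -> rows < i
        return above[..., None] * self.w_ctx                         # (H, W, levels)

    def row_logits(self, row, ctx_row):
        """Inner decoder on a single row; ShiftRight makes position j see row[:j] only."""
        self.calls["row"] += 1
        shifted = np.concatenate([[0.0], row[:-1].astype(float)])
        return ctx_row + np.cumsum(shifted)[:, None] * self.w_row    # (W, levels)

def sample(model, rng, temperature=1.0):
    x = np.zeros((model.height, model.width), dtype=np.int64)
    for i in range(model.height):
        ctx_row = model.upper_context(x)[i]          # one outer-decoder pass per row
        for j in range(model.width):
            logits = model.row_logits(x[i], ctx_row)[j] / temperature  # row i only
            p = np.exp(logits - logits.max())
            x[i, j] = rng.choice(model.levels, p=p / p.sum())
    return x

model = ToyAxialModel(height=4, width=6)
img = sample(model, np.random.default_rng(0))
print(img.shape, model.calls)  # (4, 6) {'upper': 4, 'row': 24}
```

The expensive outer decoder runs once per row ($H$ times), while only the single-row inner decoder runs once per pixel; a production implementation could additionally cache keys and values inside the row, which this sketch does not attempt.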
Table 2: ImageNet results (bits/dim).

Model | ImageNet 32x32 | ImageNet 64x64
Multiscale PixelCNN (Reed et al., 2017) | 3.95 | 3.70
PixelCNN/RNN (van den Oord et al., 2016a) | 3.86 | 3.63
Gated PixelCNN (van den Oord et al., 2016b) | 3.83 | 3.57
PixelSNAIL (Chen et al., 2018) | 3.80 | 3.52
SPN (Menick and Kalchbrenner, 2018) | 3.79 | 3.52
Image Transformer (Parmar et al., 2018) | 3.77 | –
Strided Sparse Transformer (Child et al., 2019) | – | 3.44
Axial Transformer + LSTM inner decoder | 3.77 | 3.46
Axial Transformer | 3.76 (3.758) | 3.44 (3.439)
Table 3: BAIR Robot Pushing results.

Model | bits/dim (next 15 frames)
VideoFlow (Kumar et al., 2019) | 1.87
Video Transformer (Weissenborn et al., 2019) | 1.35
Axial Transformer (ours) | 1.29
3.2.2 Channel Encoder for Multi-Channel Images and Videos
We have just described an architecture for a single-channel image of shape $H \times W$. Here, we show how to extend the architecture to multi-channel images or videos of shape $H \times W \times C$ (here $C$ is either the number of channels in a multi-channel image, or the product of the number of channels and timesteps in a video). One way to model such data of shape $H \times W \times C$ is to simply stack the channels on top of each other into a single-channel image of shape $(H \cdot C) \times W$ or $H \times (W \cdot C)$. This is simple to implement, but does increase the sequence length for column attention or row attention, which can be undesirable for large $C$. We instead opt to model one channel at a time as a single-channel image, but now conditioned on previous channels using an extra set of unmasked row and unmasked column attention layers. This means that we have a model of the form $p_\theta(x) = \prod_{c=1}^{C} p_\theta(x_{:,:,c} \mid x_{:,:,<c})$, where previous channels $x_{:,:,<c}$ are processed into an $H \times W \times D$ tensor of context information, which is then added into the first encoding blocks of the model in section 3.2.1 (Figure 3).
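The channel-wise factorization pairs with the random-slice training procedure described in the next paragraph. In this sketch the per-slice log-likelihoods are hypothetical numbers standing in for what the channel encoder and decoder would produce; multiplying the sampled term by $C$ is an explicit choice made here so that the single-slice estimate is unbiased for the full $\log p(x)$ rather than for $\log p(x) / C$.

```python
import numpy as np

def random_slice_objective(channel_logliks, rng):
    """One-sample training objective from a uniformly chosen channel slice.

    channel_logliks[c] stands for log p(x[:, :, c] | x[:, :, :c]); in the model this
    would come from the channel encoder (context from slices < c) plus the decoder.
    Scaling by C makes the estimate unbiased for log p(x) = sum_c channel_logliks[c].
    """
    C = len(channel_logliks)
    c = rng.integers(C)                  # random channel slice
    return C * channel_logliks[c]

# Hypothetical per-slice log-likelihoods (nats) for one image with C = 3 channels.
ll = np.array([-2100.0, -1700.0, -1500.0])
rng = np.random.default_rng(0)
estimates = [random_slice_objective(ll, rng) for _ in range(20000)]
print(ll.sum(), np.mean(estimates))  # exact log p(x) vs. Monte Carlo average
```

Since $C$ is a constant, dropping it rescales the objective without moving its optimum, so either form yields the same maximum likelihood solution.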
We do not share any parameters among any of these layers. At training time, we train on a random channel slice of each image: we process the previous slices using these unmasked attention layers to produce a context tensor, and maximize the likelihood of the randomly chosen slice conditioned on this context. This amounts to training on an unbiased estimate of the log likelihood for the whole data tensor (up to the constant factor $C$). See fig. 1 for an illustration of this complete model.
4 Experiments
We benchmarked our models on standard datasets for generative image and video models: downsampled ImageNet (van den Oord et al., 2016a) and BAIR Robot Pushing (Ebert et al., 2017). All Axial Transformers have 8 total layers in the encoder, 8 layers in the outer decoder and 4 layers in the inner decoder. We use a hidden size of 2048 neurons throughout, for all setups, and 16 heads with 128 neurons each for the attention component. We train for approximately 200k steps on ImageNet32 and ImageNet64 and for 200k steps on BAIR Robot Pushing. Our models can overfit on ImageNet32, but on the other datasets the models keep on gradually improving with more steps. See table 2 and table 3 for our results.
4.1 Ablation study
To push the limits of the semi-parallel sampling by making the inner decoder as small as possible, we train an Axial Transformer with the inner decoder replaced by a single LSTM layer of 2048 units. This slows down training by about 20% on ImageNet32 and about 80% on ImageNet64 when the number of steps and all else are held fixed. We find that the Axial Transformer + LSTM inner decoder performs rather well on the ImageNet32 and ImageNet64 benchmarks (table 2), which also shows the effectiveness of the remaining parts of the Axial Transformer that capture the context of the rows above. We also find, however, that the full four layers of the inner decoder of the Axial Transformer provide an additional boost in performance as well as significantly faster training. The Axial Transformer + LSTM inner decoder has the advantage of requiring only a couple of matrix-vector products to compute the layers at each autoregressive step, comparing favourably with the roughly 12 matrix-vector products required by the Axial Transformer, but the slower training time would make the LSTM inner decoder quickly impractical for larger tensors.
4.2 Samples
In fig. 4 and fig. 5, we show samples from our $32 \times 32$ and $64 \times 64$ ImageNet models. The samples are globally coherent and show visibly recognizable scenes, meaning that our Axial Transformer architecture successfully captures long-range dependencies across thousands of data dimensions in these image datasets. The samples also do not show any architecture-correlated artefacts. In addition, in fig. 6 we show samples from the BAIR Robot Pushing dataset. The first frame in each row is given by the dataset, and the rest are continuations. We note the high quality and exactness of details and the very large diversity (at temperature 1.0).
5 Conclusion
We proposed the Axial Transformer, a self-attention-based autoregressive model for data organized as high-dimensional tensors. It is based on axial attention, a simple generalization of self-attention that scales better with the dimension of input data, achieving an $O(N^{(d-1)/d})$ savings in computation and memory for a $d$-dimensional input tensor with $N$ elements. Axial attention is easy to implement and does not require custom kernels to run efficiently on modern accelerators. Axial Transformers use axial self-attention layers and a shift operation to naturally and efficiently build full receptive fields of multidimensional tensors. Our model matches or outperforms the state-of-the-art on ImageNet32 and ImageNet64 image benchmarks and sets a significant new state-of-the-art on the BAIR Robot Pushing video benchmark.
References
Ba et al. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.
Chen et al. (2018). PixelSNAIL: an improved autoregressive generative model. In International Conference on Machine Learning, pp. 863–871.
Child et al. (2019). Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509.
Ebert et al. (2017). Self-supervised visual planning with temporal skip connections. In Conference on Robot Learning, pp. 344–356.
Huang et al. (2019). CCNet: criss-cross attention for semantic segmentation. In Proceedings of the IEEE International Conference on Computer Vision, pp. 603–612.
Kalchbrenner et al. (2017). Video pixel networks. In International Conference on Machine Learning, pp. 1771–1779.
Kumar et al. (2019). VideoFlow: a flow-based generative model for video. arXiv preprint arXiv:1903.01434.
Menick and Kalchbrenner (2018). Generating high fidelity images with subscale pixel networks and multidimensional upscaling. arXiv preprint arXiv:1812.01608.
Parikh et al. (2016). A decomposable attention model for natural language inference. arXiv preprint arXiv:1606.01933.
Parmar et al. (2018). Image transformer. In International Conference on Machine Learning, pp. 4052–4061.
Reed et al. (2017). Parallel multiscale autoregressive density estimation. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 2912–2921.
Salimans et al. (2017). PixelCNN++: improving the PixelCNN with discretized logistic mixture likelihood and other modifications. In International Conference on Learning Representations (ICLR).
Theis and Bethge (2015). Generative image modeling using spatial LSTMs. In Advances in Neural Information Processing Systems, pp. 1927–1935.
Uria et al. (2016). Neural autoregressive distribution estimation. The Journal of Machine Learning Research 17 (1), pp. 7184–7220.
van den Oord et al. (2016a). Pixel recurrent neural networks. International Conference on Machine Learning (ICML).
van den Oord et al. (2016b). Conditional image generation with PixelCNN decoders. arXiv preprint arXiv:1606.05328.
Vaswani et al. (2017). Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008.
Weissenborn et al. (2019). Scaling autoregressive video models. arXiv preprint arXiv:1906.02634.