The extensive adoption of neural networks and, in general, learning models has been raising concerns regarding our chances, as humans, to explain their behavior. Interpretability would be a highly desirable feature for neural networks, especially in those applications like autonomous driving [grigorescu2020survey], healthcare [miotto2018deep], and finance [sezer2020financial], where safety, life, and security are at stake.
Deep neural networks have achieved superhuman performance in many domains, from computer vision [lecun2015deep, he2016deep] to natural language processing [vaswani2017attention, devlin2018bert] and data analysis [sezer2020financial]. However, this performance has come at the expense of model complexity, making it difficult to interpret how neural networks work [linardatos2021explainable]. These networks are usually deployed as “black boxes”, with millions of parameters to be tuned, mostly according to experience and rules of thumb. Tracing how a single trainable parameter directly affects the output produced for a given input is nearly impossible.
According to the literature, interpretability is defined as “the degree to which a human can understand the cause of a decision” [miller2019explanation]. When a machine learning model reaches high accuracy on a task such as classification or prediction, can we trust the model without understanding why such a decision has been taken? The decision process is complex, and we tend to evaluate the performance of a system in solving a given task using metrics computed at the end of the processing chain. While single metrics, such as the classification accuracy, reach super-human results, they provide an incomplete description of the real-world task [doshi2017towards]. As humans, when looking at an object that has eyes and limbs, we can infer via reasoning and intuition that these are elements (parts) that belong to the same entity (whole) [biederman1987recognition], say an animal, and we can explain and motivate why such a decision is taken, generally based on past experiences, beliefs, and attitudes [albarracin2000cognitive]. Moreover, even in the presence of an animal never seen before, we can probably tell from the visual features, our frames of reference [hawkins2021thousand], and our hierarchical organization of objects in the world [miller1995wordnet] whether it is a fish or a mammal. We would like neural networks to display the same behavior, so that objects that are close in their conceptual-semantic and lexical relations are adjacent in the feature space as well (as shown in Fig. 1e). By doing so, it would be intuitive to identify hierarchical relations between samples and how the model has learned to build a topology describing each sample. Consequently, we can agree on the definition of interpretability in deep learning as the “extraction of relevant knowledge from a machine-learning model concerning relationships either contained in data or learned by the model” [murdoch2019definitions].
In the image classification field, available techniques, such as transformers [vaswani2017attention, dosovitskiy2020image, devlin2018bert], neural fields [mildenhall2020nerf], contrastive learning representation [chen2020simple], distillation [hinton2015distilling] and capsules [sabour2017dynamic], have achieved state-of-the-art performances, introducing a number of novelties, such as powerful attention-based features and per-patch analysis, positional encoding, similarity-based self-supervised pre-training, model compression and deep modeling of part-whole relationships.
Taken individually, these methods have contributed to improving the interpretability of networks, while still lacking direct emphasis on either data relationships [vaswani2017attention, dosovitskiy2020image, devlin2018bert, chen2020simple, mildenhall2020nerf] (e.g. conceptual-semantic relationships) or model-learned relationships [sabour2017dynamic, hinton2015distilling] (e.g. part-whole relationships). Retrieving the part-whole hierarchy is not a new task per se, as it has been exploited in different research areas such as scene parsing [bear2020learning, deng2020generative] and multi-level scene decomposition [zhu2007stochastic, hong2021ptr]. Instead of aiming at learning the part-whole hierarchy as the final goal of our architecture, we focus on learning the part-whole representation as a means to interpret the network behavior at different levels.
In [hinton2021represent], a conceptual idea on how to represent part-whole hierarchies in neural networks is introduced, which attempts to merge the advantages of the above state-of-the-art frameworks into a single theoretical system (known as GLOM). GLOM aims at mimicking the human ability to learn to parse visual scenes. Inspired by the theoretical concepts described in [hinton2021represent, hawkins2021thousand], we build a working system, called Agglomerator, which achieves part-whole agreement [hinton1990mapping] at different levels of the model (relationships learned by the model) and hierarchical organization of the feature space (relationships contained in data), as shown in Fig. 1.
Our contribution is summarised as follows:
we introduce a novel model, called Agglomerator (the code and the pre-trained models can be found at https://github.com/mmlab-cv/Agglomerator), mimicking the functioning of the cortical columns in the human brain [hawkins2017theory];
we explain how our architecture provides interpretability of relationships learned by the model, specifically part-whole relationships;
we show how our architecture provides interpretability of relationships contained in data, namely the hierarchical organization of the feature space;
we provide results outperforming or on par with current methods on multiple common datasets, such as SmallNORB [lecun2004learning], MNIST [lecun1998gradient], FashionMNIST [xiao2017/online], CIFAR-10 and CIFAR-100 [krizhevsky2009learning], also relying on fewer parameters.
2 Related work
Convolutional Neural Networks (CNNs) [he2016deep, simonyan2014very] rose to a prominent role in computer vision when they started to outperform the existing literature in the image classification task of the ImageNet challenge [krizhevsky2012imagenet]. The convolution operator can effectively describe spatially-correlated data resulting in a feature map, while the pooling operation down-samples the obtained feature map by summarizing the presence of certain features in patches of the image. The pooling operation in CNNs has been the subject of criticism since it does not preserve the information related to the part-whole relationship [sitzmann2019scene] between features belonging to the same object [sabour2017dynamic].
Transformers [liu2021swin, dosovitskiy2020image, khan2021transformers] have proven able to outperform CNNs, thanks to their ability to encode powerful features using self-attention and patch-based analysis of images. Multi-headed transformers [devlin2018bert] require the query, key, and value weights to be trained differently for each head, which is more costly than training a CNN. The main advantage compared to CNNs is the ability of the multiple heads to combine information from different locations in the image with fewer losses than the pooling operation [lee2019set]. However, when compared with CNNs, Transformer-like models usually require intensive pre-training on large datasets, to achieve state-of-the-art performances.
Multi Layer Perceptrons (MLPs) [tolstikhin2021mlp, li2021convmlp] are characterised by fully connected layers, in which each node is connected to every node of the next layer. Even though they are easier to train and have a simpler architecture compared to CNNs, the fully connected layers may cause the network to grow too fast in size and number of parameters, limiting scalability. MLPs have experienced a resurgence thanks to patch-based approaches [tolstikhin2021mlp, li2021convmlp], which allowed reaching state-of-the-art performances. They can also be seen as 1x1 convolutions [hinton2021represent, tolstikhin2021mlp, li2021convmlp], which do not require the pooling operation.
Capsule networks [sabour2017dynamic, hinton2018matrix, kosiorek2019stacked, ribeiro2020capsule, mukhometzianov2018capsnet, mazzia2021efficient] try to mimic the way the human brain creates a parse tree of parts and wholes by dynamically allocating groups of neurons (capsules) that can model objects at different levels of the part-whole hierarchy. The routing algorithm determines which capsules are activated to describe an object in the image, with lower-level capsules describing the parts (e.g. eyes and limbs), and higher-level capsules describing wholes (e.g. mammals and fish). While effectively routing information from different locations in the image, activated capsules cannot describe every single possible object in the image, thus limiting their effectiveness on more complex datasets (e.g. ImageNet, CIFAR-100), while achieving state-of-the-art results on simpler ones (e.g. MNIST). While part-whole hierarchies have been investigated in other fields like scene parsing [bear2020learning, deng2020generative] and multi-level scene decomposition [zhu2007stochastic, hong2021ptr], capsule networks aim at building an internal representation of the hierarchy, which allows for better interpretability of the final task (e.g. classification).
There has been a recent push toward the so-called biologically inspired Artificial Intelligence (AI) [hole2021thousand, hawkins2021thousand], which tries to build deep learning networks able to mimic the structure and functions of the human brain. In [hawkins2021thousand], the authors propose a column-like structure, similar to hyper-columns typical of the human neocortex. In [van2021disentangling], the authors build upon cortical columns implemented as separate neural networks called Cortical Column Networks (CCN). Their framework aims at representing part-whole relationships in scenes to learn object-centric representations for classification.
The author in [hinton2021represent] proposes a conceptual framework, called GLOM, based on inter-connected columns, each of which is connected to a patch of the image and is composed of auto-encoders stacked in levels. Weight sharing among MLP-based [li2021convmlp] auto-encoders allows for an easily trainable architecture with fewer weights, while knowledge distillation [hinton2015distilling] allows for a reduction of the training parameters. The patch-based approach combined with the spatial distribution of columns allows for a sort of positional encoding and viewpoint estimation, similarly to what is used in neural fields [mildenhall2020nerf, sitzmann2019scene]. At training time, the author recommends that GLOM should be trained using a contrastive loss function [chen2020simple]. This procedure, combined with a Transformer-like self-attention [vaswani2017attention] mechanism on each layer of the columns, aims at reaching a consensus between columns. Routing the information with layer-based attention and stacked auto-encoders would theoretically allow GLOM to learn a different level of abstraction of the input at a different location and level in the columns, creating a part-whole structure with a richer representation compared to capsule networks [sabour2017dynamic].
While GLOM is presented in [hinton2021represent] more as an intuition rather than a proper architecture, in this work we develop its foundational concepts and turn them into a fully working system, with application to image classification.
The framework we propose aims at replicating the column-like pattern, similar to hyper-columns typical of the human visual cortex [hawkins2021thousand]. An overview is shown in Fig. 1.
Agglomerator brings together concepts and building blocks from multiple methods, such as CNNs [li2021convmlp], transformers [vaswani2017attention, dosovitskiy2020image, devlin2018bert], neural fields [mildenhall2020nerf], contrastive learning representation [chen2020simple], distillation [hinton2015distilling], and capsules [sabour2017dynamic]. Here, we introduce the mathematical notation needed to explain the details of the main building blocks of the architecture.
Each input image is transformed into a feature map divided into patches. The -th patch, with is fed to the corresponding column , spatially located at coordinates . The subscript is omitted in the next equations for better readability. As shown in Fig. 2, each column consists of embedding levels connected by a stack of auto-encoders at location at time , as suggested in [hinton2021represent]. The superscript is omitted in the next instances of for better readability. Each level
of the column is an embedding vector representation of size. Levels and represent consecutive levels; represents a part of the whole . We indicate as all the levels in all columns sharing the same value and belonging to the same layer . Being the last layer of our architecture at the last time step , it is represented as .
3.1 Patches embedding
At the embedding stage, as in [li2021convmlp], we apply a convolutional Tokenizer to extract the feature map of each image of size pixels, which provides a richer representation compared to the original image. Following the implementation in [li2021convmlp], the obtained feature map has size where and . We then embed each of the -dimensional embedding vectors into the bottom levels at the corresponding coordinates of the corresponding column . Feeding the -th each patch to a spatially located column resembles the positional encoding of neural fields [mildenhall2020nerf], where each -sized embedding represents at the same time the sample and its relative observation viewpoint. At each time step , we embed each image sample into the first layer of the columns, which is represented as the bottom layer .
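As an illustration of the embedding stage, the following minimal sketch assigns each spatial location of a feature map to its own column. The function name `to_columns`, the nested-list layout, and the toy sizes are assumptions made for this example; the actual system obtains the feature map from a convolutional Tokenizer, which is not reproduced here.

```python
def to_columns(feature_map):
    """Map a feature map laid out as [height][width][channels]
    to a dictionary {(h, w): embedding}.

    Each (h, w) key identifies one column; the embedding is the vector
    that would be written into the bottom level of that column.
    """
    columns = {}
    for h, row in enumerate(feature_map):
        for w, embedding in enumerate(row):
            columns[(h, w)] = list(embedding)
    return columns


# Toy 2x3 feature map with 4-dimensional embeddings (hypothetical sizes).
feature_map = [
    [[float(h + w + c) for c in range(4)] for w in range(3)]
    for h in range(2)
]
columns = to_columns(feature_map)
```

Because the key of each entry is the spatial coordinate of the patch, the position of a patch is carried implicitly by the column that receives it.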
Consecutive levels in time and space in a column are connected by an auto-encoder. The auto-encoders are based on an MLP, which allows for model reduction [hinton2015distilling] and faster training time. Each auto-encoder computes the top-down contribution of a level to the value of the level below at the next time step using a top-down decoder. Similarly, each auto-encoder computes the bottom-up contribution of a level to the value of the level above at the next time step using a bottom-up encoder. The top-down decoder and the bottom-up encoder share a similar structure, except for the activation functions, as described in Fig. 2(e). The top-down network uses GELU activation functions [hendrycks2016gaussian], while the bottom-up network relies on sinusoidal activation functions [sitzmann2020implicit, sopena1999neural, wong2002handwritten]. All the top-down decoders connecting to the same layer share the same weights. The same is true for the bottom-up encoders connecting to the same layer.
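To make the difference between the two activation functions concrete, the sketch below applies one fully connected layer to the same input, once with GELU and once with a sine activation. The helper `mlp_layer` and the toy weights are assumptions for illustration; the same weights are reused for both calls only for brevity, whereas in the architecture the two networks have their own parameters.

```python
import math


def gelu(x):
    """Exact GELU: x times the standard normal CDF evaluated at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def sine(x):
    """Sinusoidal activation."""
    return math.sin(x)


def mlp_layer(inputs, weights, biases, activation):
    """One fully connected layer: activation(W x + b), element by element."""
    return [
        activation(sum(w * x for w, x in zip(row, inputs)) + b)
        for row, b in zip(weights, biases)
    ]


# Hypothetical toy parameters and a 2-dimensional level vector.
weights = [[0.5, -0.25], [1.0, 1.0]]
biases = [0.0, 0.1]
level = [1.0, 2.0]

top_down_out = mlp_layer(level, weights, biases, gelu)   # GELU branch
bottom_up_out = mlp_layer(level, weights, biases, sine)  # sine branch
```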
The key element of our architecture is how the information is routed to obtain a representation of the input data where the part-whole hierarchies emerge.
Before computing the loss, we need to iteratively propagate each batch through the network, obtaining a deep representation of each image. This procedure, called the propagation phase, encourages the network to reach consensus between neighbor levels. Ideally, this means that all neighbor levels in the last layer should have similar values, representing the same whole; neighbor levels at bottom layers should instead share the value among smaller groups, each group representing the same part. Groups of vectors that “agree” on a similar value have reached the consensus on the image representation at that level, and they are called islands of agreement [hinton2021represent]. An example of such representation is shown in Fig. 1(d). In capsule-based approaches [sabour2017dynamic], groups of neurons are activated to represent the part-whole hierarchy with limited expressive power. Our -dimensional layers provide a richer representation of the same hierarchy.
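One simple way to picture islands of agreement is to group neighboring vectors whose directions are nearly identical. The sketch below does this along a single row of levels using cosine similarity; the function `islands_1d`, the one-dimensional neighborhood, and the threshold value are assumptions chosen for illustration and are not part of the described method.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def islands_1d(levels, threshold=0.9):
    """Group consecutive vectors whose cosine similarity exceeds threshold.

    Returns a list of islands, each a list of indices into `levels`.
    The threshold is a hypothetical value, not taken from the paper.
    """
    islands = [[0]]
    for i in range(1, len(levels)):
        if cosine(levels[i - 1], levels[i]) > threshold:
            islands[-1].append(i)   # neighbor agrees: extend current island
        else:
            islands.append([i])     # neighbor disagrees: start a new island
    return islands


# Four neighboring level vectors: the first two and the last two agree.
row = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
groups = islands_1d(row)
```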
To obtain such representation, at time step , we randomly initialise all the values and we embed a batch of samples into the bottom layer . Once the values are initialized, we compute the attention . Instead of the self-attention mechanism used in Transformers [vaswani2017attention, dosovitskiy2020image, devlin2018bert], a standard attention weighting is deployed as in [xu2015show]. Each attention weight is computed as
where represents each possible level belonging to the same layer as , is an indicator function which indexes all the neighbors levels of belonging to the same layer and is a parameter that determines the sharpness of the attention.
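The description above can be sketched as a softmax over dot products between a level and its neighbors in the same layer, scaled by a sharpness parameter. This is a sketch under stated assumptions: the function name `attention_weights`, the argument `beta`, and the exact form of the score are illustrative choices and may differ from the original equation.

```python
import math


def attention_weights(center, neighbors, beta=1.0):
    """Softmax attention of one level vector over its neighbor vectors.

    center:    the level vector whose attention is being computed
    neighbors: level vectors in the same layer that count as neighbors
    beta:      sharpness; larger values concentrate the attention
    """
    scores = [beta * sum(a * b for a, b in zip(center, n)) for n in neighbors]
    peak = max(scores)                       # subtract max for stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


center = [1.0, 0.0]
neighbors = [[1.0, 0.0], [0.0, 1.0]]
soft = attention_weights(center, neighbors, beta=2.0)
sharp = attention_weights(center, neighbors, beta=10.0)
flat = attention_weights(center, neighbors, beta=0.0)
```

With a sharpness of zero every neighbor receives the same weight, while larger values push the weight toward the neighbors most similar to the center.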
At each time step , a batch with samples is fed to the bottom layer network as described in Sec. 3.1. We compute the values as
where indicates the arithmetical average, and are trainable weights. For layer , contribution is not included, as does not exist. The propagation phase takes time steps to reach the final representation of each image at each layer .
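A minimal sketch of one update step, assuming that the new value of a level is the arithmetic average of weighted contributions from its previous value, the bottom-up network, the attention, and, when a level above exists, the top-down network. The function `update_level` and the choice of contributions are assumptions for illustration; in the real system the weights are trainable.

```python
def update_level(previous, bottom_up, attention, top_down=None,
                 weights=(1.0, 1.0, 1.0, 1.0)):
    """Average the weighted contributions to one level vector.

    top_down is None for the last layer, where no level above exists,
    so that contribution is left out of the average.
    """
    w_prev, w_bu, w_att, w_td = weights
    terms = [
        [w_prev * x for x in previous],
        [w_bu * x for x in bottom_up],
        [w_att * x for x in attention],
    ]
    if top_down is not None:
        terms.append([w_td * x for x in top_down])
    return [sum(values) / len(terms) for values in zip(*terms)]


inner = update_level([1.0, 1.0], [2.0, 2.0], [3.0, 3.0], top_down=[4.0, 4.0])
top = update_level([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
```

Repeating this step for every level over several time steps corresponds to the propagation phase.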
The training procedure of our architecture is shown in Fig. 3. It is divided into two steps: (i) a pre-training phase using a supervised contrastive loss function [chen2020simple] and (ii) a training phase for image classification using a cross-entropy loss.
We first pre-train our network using an image-based contrastive loss [chen2020simple]. Given a batch with samples, we duplicate each image to obtain pairs of samples , for a total of data points. We then apply data augmentation RandAugment [cubuk2020randaugment] to both . Both samples are fed to the network as described in Sec. 3.1, and we perform the propagation phase in Sec. 3.3 to obtain the representation at the last layer . Then we rearrange the levels to obtain a vector of dimensions , given as input to the contrastive head , as described in Fig. 2. At the output of the contrastive head, each sample is described by a feature vector of dimension . We take all the possible sample pairs from the batch and we compute the contrastive loss defined as:
where indicates the dot product between the normalised version of and , is a indicator function valued if and belong to the same class, and otherwise.
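The sketch below shows one common form of a supervised contrastive loss built from the ingredients named above: dot products between normalised feature vectors and an indicator of whether two samples share a class. The temperature parameter and the averaging over anchors are assumptions of this example, so it should be read as an illustration rather than the exact loss used here.

```python
import math


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def supervised_contrastive_loss(features, labels, temperature=0.5):
    """Mean over anchors of the negative log-probability of their positives.

    Positives of an anchor are the other samples with the same label.
    The temperature is a hypothetical value chosen for this example.
    """
    z = [normalize(f) for f in features]
    n = len(z)
    total, anchors = 0.0, 0
    for i in range(n):
        positives = [j for j in range(n) if j != i and labels[j] == labels[i]]
        if not positives:
            continue  # an anchor without positives contributes nothing
        sims = {
            j: sum(a * b for a, b in zip(z[i], z[j])) / temperature
            for j in range(n) if j != i
        }
        log_denominator = math.log(sum(math.exp(s) for s in sims.values()))
        total += -sum(sims[p] - log_denominator for p in positives) / len(positives)
        anchors += 1
    return total / anchors if anchors else 0.0


features = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]]
loss_matching = supervised_contrastive_loss(features, [0, 0, 1, 1])
loss_mismatched = supervised_contrastive_loss(features, [0, 1, 0, 1])
```

The loss is lower when samples with the same label are already close in the feature space, which is the behavior the pre-training phase rewards.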
Once the network is pre-trained using the contrastive loss, the weights are frozen. We apply augmentation [cubuk2020randaugment] to each sample in a batch of size , which is then fed to the network for the propagation phase to obtain for each sample the representation . Then, the cross-entropy head is added on top of the contrastive head . A linear layer resizes -dimensional features to dimension , which corresponds to the number of classes to be predicted for each dataset. The new layers are then trained using the cross-entropy function:
where is the label of a sample taken from the batch and is the label to be predicted.
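For reference, the standard cross-entropy for a single sample can be written in a few lines. The sketch assumes that the linear layer outputs one unnormalised score (logit) per class; the function name and the toy values are illustrative.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class.

    Uses the log-sum-exp trick so large logits do not overflow.
    """
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


uniform = cross_entropy([0.0, 0.0, 0.0, 0.0], target=2)   # no preference
confident = cross_entropy([10.0, 0.0, 0.0], target=0)     # correct and sure
wrong = cross_entropy([10.0, 0.0, 0.0], target=1)         # sure but wrong
```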
We perform our experiments on the following datasets:
SmallNorb (S-NORB) [lecun2004learning] is a dataset for 3D object recognition from shape. It consists of roughly 200000 images of size pixels of 5 classes of toys.
MNIST [lecun1998gradient] and FashionMNIST [xiao2017/online], consist of 60000 training images and 10000 test images of grayscale handwritten digits and Zalando’s articles of size pixels.
CIFAR-10 and CIFAR-100 [krizhevsky2009learning] both consist of 50000 training images and 10000 test images of size pixels, with 10 and 100 classes, respectively.
Our network is trained in an end-to-end fashion using PyTorch Lightning on a single NVIDIA GeForce RTX 3090. Input images for each dataset are normalized using each dataset’s standard normalization. We train our network on each dataset’s native resolution, except for SmallNorb, which is resized to pixels, following the standard procedure as in [ribeiro2020capsule, hinton2018matrix]. The Tokenizer embedding creates patches represented by -dimensional vectors, where and are the pixel dimensions of the input image. Thus the corresponding number of columns is for CIFAR-10, CIFAR-100, and SmallNorb, and for MNIST and FashionMNIST. During the pre-training, we deploy the following hyper-parameters: epochs, cyclic learning rate [smith2017cyclical] in the range , batch size , levels embedding , number of levels , number of iterations , dropout value , contrastive features dimension , and weight decay . During the training phase, we resume the network training with the same hyper-parameters, being the number of classes corresponding to each dataset.
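The triangular policy of the cyclic learning rate [smith2017cyclical] can be sketched as follows. The bounds and the step size used in the example are hypothetical values chosen for illustration and are not the ones used in our experiments.

```python
import math


def triangular_lr(iteration, base_lr, max_lr, step_size):
    """Triangular cyclical learning rate.

    The rate rises linearly from base_lr to max_lr over step_size
    iterations, then falls back to base_lr over the next step_size.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


# Hypothetical bounds and step size, for illustration only.
schedule = [triangular_lr(i, 0.001, 0.01, 100) for i in (0, 50, 100, 150, 200)]
```

In practice, an equivalent schedule is available in common deep learning frameworks, so the function above is only meant to show the shape of the policy.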
|ResNet-110||[he2016deep, huang2016deep, assunccao2019denser]||Conv||-||2.10||5.10||6.41*||27.76*||1.7||GPU|
5 Quantitative results
We report the quantitative results for each dataset in Tab. 2. Capsule-based models [hinton2018matrix, sabour2017dynamic, mazzia2021efficient, mukhometzianov2018capsnet, ribeiro2020capsule] can achieve good performances on simple datasets (SmallNorb, MNIST, and FashionMNIST), but they fail to generalize to datasets with a higher number of classes (CIFAR-100). Convolutional-based models [he2016deep, assunccao2019denser, huang2016deep, simonyan2014very] can generalize to different datasets, at the expense of weak model interpretability, mainly due to the max-pooling operation. Transformer-based [dosovitskiy2020image] and MLP-based methods [li2021convmlp, tolstikhin2021mlp] are able to achieve the best performances on more complex datasets, but they do not provide tests for smaller datasets. However, to achieve such levels of accuracy they rely on long pre-training (thousands of TPU days) on expensive computational architectures, implementing data augmentation on ImageNet [krizhevsky2012imagenet] or on the JFT-300M [sun2017revisiting] dataset, which is not publicly available. As can be seen, our method performs on par with capsule-based methods on simpler datasets, while achieving better generalization on more complex ones. In addition, our method has fewer parameters than most transformer-based and MLP-based methods, and it can be trained in less time on a much smaller architecture.
2D representation of the latent space for multiple methods trained only on the CIFAR-10 dataset, obtained using Principal Component Analysis (PCA) [wold1987principal]. The PCA provides a deterministic change of basis for the data from a multidimensional space into a 2D space. The legend (f) displays the classes, which are divided between the super-classes Vehicles and Animals following the WordNet hierarchy [miller1995wordnet]. The different methods (a,b,c,d,e) are all able to cluster the samples between the two super-classes. However, while (a,b,e) display a latent space where classes are close to each other, the two MLP-based methods (c,d) are able to provide a clearer separation between the super-classes. Both methods show conceptual-semantically close samples on the edge of each super-class, such as airplanes and birds. Inside each super-class, semantically close samples are represented contiguously, such as deer and horses, or cars and trucks. Our method (c) provides better inter-class and intra-class separability. The overlap percentage is reported for each method. The overlap area is the area where a mistake with a higher hierarchical severity [bertinetto2020making] has a higher probability of occurring.
Ablation study. We analyze the contribution of the different components of our architecture evaluating their influence on the validation loss after epochs. The considered parameters, in descending order of correlation with the validation loss value are: the embedding dimension , the contrastive feature vector , learning rate, weight decay, dropout, and the number of levels . The results are reported in Fig. 4. We perform different training on CIFAR-10 with different combinations of parameters.
In Tab. 1 we show how our network configuration (I) performs similarly with (II) and (III). Both sinusoidal activations and shared attention in (I) are key to providing interpretable results, allowing islands of agreement to emerge. Simplified versions using only a linear layer instead of column layers (IV), of the contrastive head (V), or of the linear embedding (VI) lead to a decrease in performance.
6 Qualitative results: interpretability
Our method provides interpretability of the relationships learned by the model by explicitly modeling the part-whole hierarchy, and of the relationships contained in data through the hierarchical organization of the feature space.
Islands of agreement as a representation of the multi-level part-whole hierarchy. During the propagation phase, neighbor levels on the same layer are encouraged to reach a consensus by forming islands of agreement. The islands of agreement represent the part-whole hierarchies at different levels. In Fig. 5, we provide a few examples of the islands of agreement obtained on MNIST and CIFAR-10 trained with levels. Each arrow represents the value of a level at location , reduced from -dimensional to 2D using a linear layer. As the level increases, neighbors tend to agree on a common representation of the whole represented in the image sample. At lower levels, smaller islands emerge, each representing a part of the whole. Samples of MNIST present fewer changes in the islands across levels because the data is much simpler, indicating that fewer levels in the hierarchy can be sufficient to obtain similar results. Our Agglomerator is thus able to represent a patch differently at different levels of abstraction. At the same level, spatially adjacent patches take the same value, agreeing on the representation of parts and wholes.
Latent space organization as the representation of conceptual-semantic relationship in data. Recent networks aim at maximizing inter-class distances and minimizing intra-class distances between samples in the latent space. While the accuracy is high, they provide little interpretability in their data representation. As a result, mistakes are less likely to happen, but the mistake severity, defined as the distance between two classes in WordNet lexical hierarchy [miller1995wordnet], does not decrease [bertinetto2020making]. As shown in Fig. 6, our network semantically organizes the input data resembling the human lexical hierarchy.
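To illustrate the notion of mistake severity, the sketch below measures the distance between two classes as the number of steps from the true class up to the lowest ancestor it shares with the predicted class. The two-level hierarchy is a deliberately simplified stand-in for WordNet, built from the CIFAR-10 classes and the two super-classes Vehicles and Animals; the node names and the function `mistake_severity` are assumptions of this example.

```python
# Simplified, hypothetical hierarchy: class -> parent.
PARENT = {
    "airplane": "vehicle", "automobile": "vehicle",
    "ship": "vehicle", "truck": "vehicle",
    "bird": "animal", "cat": "animal", "deer": "animal",
    "dog": "animal", "frog": "animal", "horse": "animal",
    "vehicle": "entity", "animal": "entity",
}


def ancestors(node):
    """Path from a node up to the root, starting with the node itself."""
    path = [node]
    while path[-1] in PARENT:
        path.append(PARENT[path[-1]])
    return path


def mistake_severity(true_class, predicted_class):
    """Steps from the true class up to the lowest common ancestor."""
    predicted_ancestors = set(ancestors(predicted_class))
    for height, node in enumerate(ancestors(true_class)):
        if node in predicted_ancestors:
            return height
    raise ValueError("classes do not share a root")


same = mistake_severity("cat", "cat")        # correct prediction
mild = mistake_severity("cat", "dog")        # wrong class, same super-class
severe = mistake_severity("cat", "truck")    # wrong super-class
```

Under this measure, confusing a cat with a dog is a milder mistake than confusing a cat with a truck, which is the distinction a semantically organized latent space is meant to preserve.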
Our method introduces new types of hyper-parameters in the network structure, such as embedding dimensions, number of levels, and size of patches, which need to be tuned. We believe better parameter settings can be found for all the datasets, increasing accuracy while still retaining interpretability. Moreover, a higher number of parameters generally causes architectures to be more prone to over-fitting and more difficult to train. To improve the accuracy of our network, we would need pre-training on large datasets (e.g., on ImageNet), which requires large computational resources to be performed in a reasonable time frame. While hoping that powerful TPU architectures become publicly available in the future, we are currently investigating efficient pre-training strategies for our network.
We presented Agglomerator, a method that makes a step forward towards representing interpretable part-whole hierarchies and conceptual-semantic relationships in neural networks. We believe that interpretable networks are key to the success of artificial intelligence and deep learning. With this work, we intend to promote a preliminary implementation and the corresponding results on the image classification task, and we hope to inspire other researchers to adjust our solution to solve more complex and diverse tasks.