Analysis methods which allow us to better understand the representations and functioning of neural models of language are increasingly needed as deep learning becomes the dominant approach to natural language processing. A popular technique for analyzing neural representations involves predicting information of interest from the activation patterns, typically using a simple predictive model such as a linear classifier or regressor. If the model is able to predict this information with high accuracy, the inference is that the neural representation encodes it. We refer to these as diagnostic models.
One important limitation of this method of analysis is that it is only easily applicable to relatively simple types of target information, which are amenable to prediction via linear regression or classification. Should we wish to decode activation patterns into a structured target such as a syntax tree, we would need to resort to complex structure prediction algorithms, running the risk that the analytic method becomes no simpler than the actual neural model.
Here we introduce an alternative approach based on correlating neural representations of sentences and structured symbolic representations commonly used in linguistics. Crucially, the correlation is in similarity space rather than in the original representation space, removing most constraints on the types of representations we can use. Our approach is an extension of the Representational Similarity Analysis (RSA) method, initially introduced by kriegeskorte2008representational in the context of understanding neural activation patterns in human brains.
In this work we propose to specifically apply RSA to neural representations of strings from a language on one side, and to structured symbolic representations of these strings on the other side. To capture the similarities between these symbolic representations, we use a tree kernel, a metric which computes the proportion of common substructures between trees. This approach enables straightforward comparison of neural and symbolic-linguistic representations. Furthermore, we introduce RSA_REGRESS, a similarity-based analytic method which combines features of RSA and of diagnostic models.
We validate both techniques on neural models which process a synthetic language of arithmetic expressions with a simple syntax and semantics, and show that they behave as expected in this controlled setting. We further apply our techniques to two neural models trained on English text, Infersent (conneau-EtAl:2017:EMNLP2017) and BERT (DBLP:journals/corr/abs-1810-04805), and show that both models encode a substantial amount of syntactic information compared to random models and a simple bag-of-words representation; we also show that, according to our metrics, syntax is most salient in the intermediate layers of BERT.
2 Related work
2.1 Analytic methods
The dominance of deep learning models in NLP has brought increasing interest in techniques to analyze these models and gain insight into how they encode linguistic information. For an overview of analysis techniques, see belinkov2018analysis. The most widespread family of techniques are diagnostic models, which use the internal activations of neural networks trained on a particular task as input to another predictive model. The success of such a predictive model is then interpreted as evidence that the predicted information has been encoded by the original neural model. The approach has also been called auxiliary task adi2016fine, decoding alishahi2017encoding, diagnostic classifier hupkes2018visualisation or probing P18-1198.
Diagnostic models have used a range of predictive tasks, but since their main purpose is to help us better understand the dynamics of a complex model, they themselves need to be kept simple and interpretable. This means that the predicted information in these techniques is typically limited to simple class labels or values, as opposed to symbolic, structured representations such as syntactic trees, which are of interest to linguists. To work around this limitation, tenney2018you present a method for probing complex structures via a formulation named edge probing, where classifiers are trained to predict various lexical, syntactic and semantic relations between representations of word spans within a sentence.
Another important consideration when analyzing neural encodings is the fact that a randomly initialized network will often show non-random activation patterns. The reasons for this depend on each particular case, but may involve the dynamics of the network itself as well as features of the input data. For a discussion of this issue in the context of diagnostic models, see zhang2018language.
Alternative approaches have been proposed for analyzing neural models of language. For example, saphra2018understanding train parallel recurrent models to perform POS, semantic and topic tagging, and measure the correlation between the neural representations of the original model and the trained tagger.
Others modify the neural architecture itself to make it more interpretable: W18-5403 adapt layerwise relevance propagation bach2015pixel to Kernel-based Deep Architectures croce2017deep in order to retrieve examples which motivate model decisions. A vector representation for a given structured symbolic input is built based on kernel evaluations between the input and a subset of training examples known as landmarks, and the network decision is then traced back to the landmarks which had the most influence on it. In our work we also use kernels between symbolic structures, but rather than building a particular interpretable model we propose a general analytical framework.
2.2 Representational Similarity Analysis
kriegeskorte2008representational present RSA as a variant of pattern-information analysis, intended for understanding neural activation patterns in human brains, for example syntactic computations (tyler2013syntactic) or sensory cortical processing (yamins2016using). The core idea is to find connections between data from neuroimaging, behavioral experiments and computational modeling by correlating representations of stimuli in each of these representation spaces via their pairwise (dis)similarities.
2.3 Tree kernels
For extending RSA to a structured representation space, we need a metric for measuring (dis)similarity between two structured representations. Kernels provide a suitable framework for this purpose: collins2002convolution introduce convolutional kernels for syntactic parse trees as a metric which quantifies similarity between trees as the number of overlapping tree fragments between them, and introduce a polynomial time algorithm to compute these kernels; moschitti2006making propose an efficient algorithm for computing tree kernels in linear average running time.
2.4 Synthetic languages
When developing techniques for analyzing neural network models of language, several studies have used synthetic data from artificial languages. Using a synthetic language has the advantage that its structure is well understood, and the complexity of the language and the statistical characteristics of the generated data can be carefully controlled. The tradition goes back to the first generation of connectionist models of language elman1990finding; hochreiter1997long. More recently, W18-5414 and W18-5425 both use context-free grammars to generate data, and train RNN-based models to identify matching numbers of opening and closing brackets (so-called Dyck languages). The task can be learned, but W18-5414 report that the models fail to generalize to longer sentences. W18-5456 also show that with extensive training and an appropriate curriculum, LSTMs trained on a synthetic language can learn compositional interpretation rules.
Nested arithmetic languages are also appealing choices since they have an unambiguous hierarchical structure and a clear compositional semantic interpretation (i.e. the value of the arithmetic expression). hupkes2018visualisation train RNNs to calculate the value of such expressions and show that they perform and generalize well to unseen strings. They apply diagnostic classifiers to analyze the strategy employed by the RNN model.
3 Similarity-based analytical methods
RSA finds connections between data from two different representation spaces. Specifically, for each representation type we compute a matrix of similarities between pairs of stimuli. Pairs of these matrices are then subjected to second-order analysis by extracting their upper triangles and computing a correlation coefficient between them.
Thus for a set of objects $X$, given a similarity function $\mathrm{sim}_k$ for a representation $k$, the function $S_k$ which computes the representational similarity matrix is defined as:
$$S_k(X)_{ij} = \mathrm{sim}_k(x_i, x_j)$$
and the RSA score between representations $k$ and $l$ for data $X$ is the correlation (such as Pearson's correlation coefficient $r$) between the upper triangles $\mathrm{triu}(S_k(X))$ and $\mathrm{triu}(S_l(X))$, excluding the diagonals.
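As a minimal sketch of the RSA score just defined, assuming NumPy and plain Python callables standing in for the similarity functions (the names `similarity_matrix` and `rsa_score` are illustrative, not taken from a released toolkit), the computation reduces to building two square matrices and correlating their strict upper triangles:

```python
import numpy as np


def similarity_matrix(objects, sim):
    """Representational similarity matrix: S[i, j] = sim(objects[i], objects[j])."""
    n = len(objects)
    S = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            S[i, j] = sim(objects[i], objects[j])
    return S


def rsa_score(S_k, S_l):
    """Pearson's r between the upper triangles of two similarity matrices.

    k=1 excludes the diagonal, whose entries are trivially maximal.
    """
    iu = np.triu_indices_from(S_k, k=1)
    return np.corrcoef(S_k[iu], S_l[iu])[0, 1]


# Usage: two hypothetical similarity functions over the same four stimuli.
X = [0, 3, 5, 9]
S_abs = similarity_matrix(X, lambda a, b: -abs(a - b))
S_sq = similarity_matrix(X, lambda a, b: -(a - b) ** 2)
print(rsa_score(S_abs, S_sq))
```

If both sides use dissimilarities of the form 1 − similarity instead, the score is unchanged, since negating both inputs leaves Pearson's r the same.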
We apply RSA to neural representations of strings from a language on one side, and to structured symbolic representations of these strings on the other side. The structural properties are captured by defining appropriate similarity functions for these symbolic representations; we use tree kernels for this purpose.
A tree kernel measures the similarity between a pair of tree structures by computing the number of tree fragments they share. collins2002convolution introduce an algorithm for efficiently computing this quantity; see the supplementary material for details. A tree fragment in their formulation is a set of connected nodes subject to the constraint that only complete production rules are included. Following collins2002convolution, we work with normalized tree kernels: given a function $K$ which computes the raw count of tree fragments in common between trees $T_1$ and $T_2$, the normalized tree kernel is defined as:
$$K'(T_1, T_2) = \frac{K(T_1, T_2)}{\sqrt{K(T_1, T_1)\, K(T_2, T_2)}}$$
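A short sketch of the normalization step, assuming the raw kernel values for a set of trees have already been collected into a symmetric Gram matrix with positive diagonal (every tree shares at least its own fragments with itself); `normalize_kernel` is a hypothetical helper name:

```python
import numpy as np


def normalize_kernel(K):
    """Normalize a Gram matrix of raw tree-kernel counts.

    K'[i, j] = K[i, j] / sqrt(K[i, i] * K[j, j]), so every tree has
    similarity 1 with itself and off-diagonal values lie in [0, 1].
    """
    K = np.asarray(K, dtype=float)
    d = np.sqrt(np.diag(K))
    return K / np.outer(d, d)


# Usage: raw counts 4 and 9 on the diagonal, 2 shared fragments.
print(normalize_kernel([[4.0, 2.0], [2.0, 9.0]]))
```

Normalizing removes the dependence of the raw count on tree size: without it, two large trees would look similar simply because each contains many fragments.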
Basic RSA measures correlation between similarities in two different representations globally, i.e. how close they are in their totality. In contrast, diagnostic models answer a more specific question: to what extent a particular type of information can be extracted from a given representation. For example, while for a particular neural encoding of sentences it may be possible to predict the length of the sentence with high accuracy, the RSA between this representation and the strings represented only by their length may be relatively small in magnitude, since the neural representation may be encoding many other aspects of the input in addition to its length.
We introduce RSA_REGRESS, a method which shares features of both classic RSA and the diagnostic model approach. Like RSA, it is based on two similarity functions $\mathrm{sim}_k$ and $\mathrm{sim}_l$ specific to two different representations $k$ and $l$. But rather than computing the square matrices $S_k(X)$ and $S_l(X)$ for a set of objects $X$, we sample a reference set of objects $R$ to act as anchor points, and then embed the objects of interest $X$ in the representation space $k$ via the representational similarity function defined as:
$$\mathcal{R}_k(X, R)_{ij} = \mathrm{sim}_k(x_i, r_j)$$
Likewise, for representation $l$ we calculate $\mathcal{R}_l(X, R)$ for the same set of objects $X$. The rows of the two resulting matrices contain two different views of the objects of interest, where each dimension of a view indicates the degree of similarity to a particular reference anchor point. We can now fit a multivariate linear regression model to map between the two views:
$$\hat{A} = \operatorname*{arg\,min}_{A} \; \mathrm{MSE}\big(\mathcal{R}_k(X, R)\, A,\; \mathcal{R}_l(X, R)\big)$$
where $\mathcal{R}_k(X, R)$ is the source view, $\mathcal{R}_l(X, R)$ is the target view, and MSE is the mean squared error. The success of this model can be seen as an indication of how predictable representation $l$ is from representation $k$. Specifically, we use a cross-validated Pearson's correlation between predicted and true targets for an L2-penalized model.
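A minimal NumPy sketch of RSA_REGRESS, under stated assumptions: ridge regression is solved in closed form with an intercept, folds are drawn from a seeded permutation, and the score is the mean per-column Pearson's r between cross-validated predictions and true target views. The function names, the fixed `alpha`, and the per-column averaging are our illustrative choices, not necessarily the authors' exact protocol.

```python
import numpy as np


def embed(X, R, sim):
    """Similarity embedding: row i holds sim(x_i, r_j) for every anchor r_j."""
    return np.array([[sim(x, r) for r in R] for x in X], dtype=float)


def ridge_fit(A, B, alpha):
    """Closed-form L2-penalized least squares with intercept: B ~ A @ W + b."""
    A_mean, B_mean = A.mean(axis=0), B.mean(axis=0)
    Ac, Bc = A - A_mean, B - B_mean
    W = np.linalg.solve(Ac.T @ Ac + alpha * np.eye(A.shape[1]), Ac.T @ Bc)
    b = B_mean - A_mean @ W
    return W, b


def rsa_regress(source, target, alpha=1.0, folds=5, seed=0):
    """Cross-validated Pearson's r between predicted and true target views."""
    n = source.shape[0]
    idx = np.random.default_rng(seed).permutation(n)
    pred = np.empty_like(target, dtype=float)
    for test in np.array_split(idx, folds):
        train = np.setdiff1d(idx, test)
        W, b = ridge_fit(source[train], target[train], alpha)
        pred[test] = source[test] @ W + b  # predict held-out rows only
    # Average the correlation over target dimensions (one per anchor).
    rs = [np.corrcoef(pred[:, j], target[:, j])[0, 1] for j in range(target.shape[1])]
    return float(np.mean(rs))
```

Averaging per-column correlations, rather than correlating the flattened matrices, is a deliberate choice in this sketch: anchors typically have different mean similarities, and a flattened correlation would reward a model for merely reproducing those per-anchor means through its intercept.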
4 Synthetic language
Evaluation of analysis methods for neural network models is an open problem. One frequently resorts to largely qualitative evaluation: checking whether the conclusions reached via a particular approach have face validity and match pre-existing intuitions. However, pre-existing intuitions are often unreliable when it comes to complex neural models applied to equally complex natural language data. It is therefore helpful to simplify one part of the overall system and apply the analytic technique of interest to a neural model which processes a simple and well-understood synthetic language. As our first case study, we use a simple language of arithmetic expressions. Here we first describe the language and its syntax and semantics, and then introduce recurrent neural models which process these expressions.
4.1 Arithmetic expressions
Our language consists of expressions which encode addition and subtraction modulo 10. Consider the example expression ((6+2)-(3+7)). In order to evaluate the whole expression, each parenthesized sub-expression is evaluated modulo 10: in this case the left sub-expression evaluates to 8, the right one to 0 and the whole expression to 8. Table 1 gives the context-free grammar which generates this language, and the rules for semantic evaluation. Figure 1 shows the syntax tree for the example expression according to this grammar. This language is unambiguous and has a small vocabulary (14 symbols) and simple semantics, while at the same time requiring processing of hierarchical structure to evaluate its expressions.
To recursively generate expressions in this language, we need to decide which grammar rule to expand. We use the recursive function Generate defined in Algorithm 1. The function receives two input parameters: the branching probability and the decay factor. In the recursive calls to Generate in lines 4 and 5, the branching probability is divided by the decay factor, so larger values of the decay factor lead to the generation of smaller expressions. Within the branching path in line 6 the operator (grammar rules 1–2) is selected uniformly at random, and likewise in the non-branching path in line 9 the digit (rules 3–13) is sampled uniformly.
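The sampling procedure can be sketched as below. This is a hypothetical reimplementation, not Algorithm 1 itself: the parameter names `p_branch` and `decay` are ours, and we assume the output is rendered as a fully parenthesized string like the examples in the text. With a decay factor above 1, the branching probability shrinks at every level, which keeps expressions finite in practice.

```python
import random


def generate(p_branch: float, decay: float, rng: random.Random) -> str:
    """Sample one arithmetic expression by recursive rule expansion."""
    if rng.random() < p_branch:
        # Branching path: two sub-expressions, each with a reduced branching probability.
        left = generate(p_branch / decay, decay, rng)
        right = generate(p_branch / decay, decay, rng)
        op = rng.choice("+-")  # operator chosen uniformly
        return f"({left}{op}{right})"
    # Non-branching path: a single digit chosen uniformly.
    return rng.choice("0123456789")


rng = random.Random(0)
print([generate(0.8, 1.5, rng) for _ in range(3)])
```

Dividing rather than, say, subtracting keeps the probability positive at every depth while still making deep nesting progressively rarer.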
4.2 Neural models of arithmetic expressions
We define three recurrent models which process the arithmetic expressions from language . Each of them is trained to predict a different target, related either to the syntax of the language or to its semantics. We use these models as a testbed for validating our analytical approaches. All these models share the same recurrent encoder architecture, based on LSTM (hochreiter1997long).
The encoder consists of a trainable embedding lookup table for the input symbols, and a single-layer LSTM. The state of the hidden layer of the LSTM at the last step in the sequence is used as a representation of the input expression.
Semantic evaluation
This model consists of the encoder as described above, which passes its representation of the input to a multi-layer perceptron component with a single output neuron. It is trained to predict the value of the input expression, with mean squared error as the loss function. To perform this task, we would expect the model to encode the hierarchical structure of the expression to some extent, while also encoding the result of actually carrying out the operations of semantic evaluation.
Tree depth
This model is similar to the semantic evaluation model but is trained to predict the depth of the syntax tree corresponding to the expression instead of its value. We expect this model to need to encode a fair amount of hierarchical information, but it can completely ignore the semantics of the language, including the identity of the digit symbols.
Infix-to-prefix
This model uses the encoder to create a representation of the input expression, which it then decodes in its prefix form. For example, the expression ((6+2)-(3+7)) is converted to (-(+62)(+37)). The decoder is an LSTM trained as a conditional language model, i.e. its initial hidden state is the output of the encoder and its input at each step is the embedding of the previous output symbol. The loss function is categorical cross-entropy. We would expect this model to encode the hierarchical structure in some form as well as the identity of the digit symbols, but it can ignore the compositional semantics of the language.
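The target sequences for this model can be produced with a small recursive conversion. The sketch below assumes the same whitespace-free, fully parenthesized input format as the example, and `infix_to_prefix` is a hypothetical helper rather than part of the original code:

```python
def infix_to_prefix(expr: str) -> str:
    """Rewrite each "(L op R)" as "(op L R)", recursively; digits are unchanged."""
    pos = 0

    def parse() -> str:
        nonlocal pos
        ch = expr[pos]
        pos += 1
        if ch.isdigit():
            return ch
        if ch != "(":
            raise ValueError(f"unexpected symbol {ch!r}")
        left = parse()
        op = expr[pos]
        pos += 1
        right = parse()
        if expr[pos] != ")":
            raise ValueError("unbalanced parentheses")
        pos += 1
        return f"({op}{left}{right})"

    return parse()


print(infix_to_prefix("((6+2)-(3+7))"))  # (-(+62)(+37))
```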
4.3 Reference representations
We use RSA to correlate the neural encoders from Section 4.2 with reference syntactic and semantic information about the arithmetic expressions. For the neural representations we use cosine distance as the dissimilarity metric. The reference representations and their associated dissimilarity metrics are described below.
Semantic value
This is simply the value to which each expression evaluates, also used as the target of the semantic evaluation model. As a measure of dissimilarity we use the absolute difference between values, which ranges from 0 to 9.
Tree depth
This is the depth of the syntax tree for each expression, also used as the target of the tree depth model. We use the absolute difference as the dissimilarity measure. The dissimilarity has a minimum of 0 and no upper bound, but in our data the typical maximum value is around 7.
Tree Kernel
This is an estimate of similarity between two syntax trees based on the number of tree fragments they share, as described in Section 3 (see supplementary material for a full example). The normalized tree kernel metric ranges between 0 and 1, which we convert to dissimilarity by subtracting it from 1.
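The dissimilarity measures above, together with the cosine distance used for the neural representations, can be sketched in NumPy as follows. The function names are hypothetical, hidden states are assumed to be stacked as rows of a matrix, and the tree-kernel input is assumed to be an already normalized Gram matrix.

```python
import numpy as np


def cosine_distance_matrix(H):
    """1 - cosine similarity between every pair of rows of H (e.g. final LSTM states)."""
    H = np.asarray(H, dtype=float)
    U = H / np.linalg.norm(H, axis=1, keepdims=True)
    return 1.0 - U @ U.T


def abs_difference_matrix(values):
    """|v_i - v_j|, used for both semantic value and tree depth."""
    v = np.asarray(values, dtype=float)
    return np.abs(v[:, None] - v[None, :])


def tree_kernel_dissimilarity(K_normalized):
    """Turn normalized tree-kernel similarities in [0, 1] into dissimilarities."""
    return 1.0 - np.asarray(K_normalized, dtype=float)


print(abs_difference_matrix([8, 0, 3]))
```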
The semantic value and tree depth correlates are easy to investigate with a variety of analytic methods including diagnostic models; we include them in our experiments as a point of comparison. We use the tree kernel representation to evaluate structured RSA for a simple synthetic language.
4.4 Experimental settings
We implement the neural models in PyTorch 1.0.0. We use the following model architecture: encoder embedding layer size 64 and encoder LSTM size 128; for the regression models, an MLP hidden layer of size 256; for the sequence-to-sequence model, decoder hyper-parameters identical to those of the encoder. The symbols are predicted via a linear projection layer from the hidden state, followed by a softmax. Training proceeds following a curriculum: we first train on 100,000 batches of size 32 of random expressions sampled with one value of the decay factor, followed by 200,000 batches with a second value and finally 400,000 batches with a third. We optimize with Adam. We report results on expressions sampled with a fixed decay value. See supplementary material for the distribution of expression sizes for these decay values.
We report all results for two conditions, randomly initialized and trained, in order to quantify the effect of learning on the activation patterns as the difference between the score for the trained model and that for the random model. The trained model is chosen by saving model weights during training every 10,000 batches and selecting the weights with the smallest loss on 1,000 held-out validation expressions. Results are reported on separate test data consisting of 2,000 expressions, plus 200 reference expressions for the RSA_REGRESS embedding.
Table 2 shows the results of our experiments, where each row shows a different encoder type and each column a different target task. We will discuss these results in detail below.
Semantic value and tree depth
As a first sanity check, we would like to see whether the RSA techniques show the same pattern as the diagnostic models. As expected, both diagnostic and RSA_REGRESS scores are the highest when the objective function used to train the encoder and the analytical reference representations match: for example, the semantic evaluation encoder scores high on the semantic value reference, both for the diagnostic model and for RSA_REGRESS. Furthermore, the scores for the value and depth reference representations according to the diagnostic model and according to RSA_REGRESS are in agreement. The scores according to RSA in some cases show a different picture. This is expected, as RSA answers a substantially different question than the other two approaches: it looks at how the whole representations match in their similarity structure, whereas both the diagnostic model and RSA_REGRESS focus on the part of the representation that encodes the target information most strongly.
We can use both RSA and RSA_REGRESS to explore whether the hidden activations encode any structural representation of syntax: the relevant evidence is in the scores for the TK reference representation. As expected, the highest scores for both methods are obtained with the Infix-to-prefix encoder, whose task relies the most on the hierarchical structure of an input string. RSA yields the second-highest score for the Tree depth encoder, which also depends on aspects of tree structure. What is unexpected are the results for the random encoder, which we turn to next.
The non-random nature of the activation patterns of randomly initialized models (e.g., zhang2018language) is also strongly in evidence in our results. For example, the random encoder has quite a high score for diagnostic regression on tree depth. Even more striking is the fact that the random encoder has a substantial negative RSA score for the Tree Kernel: expression pairs which are more similar according to the Tree Kernel are less similar according to the random encoder, and vice versa.
When applying RSA we can inspect the full correlation pattern via a scatter-plot of the dissimilarities in the reference and encoder representations. Figure 2 shows the data for the random encoder and the Tree Kernel representations. As can be seen, the negative correlation for the random encoder is due to the fact that according to the Tree Kernel, expression pairs tend to have high dissimilarities, while according to the random encoder’s activations they tend to have overall low dissimilarities. For the trained Infix-to-prefix encoder the dissimilarities are clearly positively correlated with the TK dissimilarities.
Thus the raw correlation value for the trained encoder is a biased estimate of the effect of learning, as learning has to overcome the initially substantial negative correlation; a better estimate is the difference between the scores for the learned and random models. It is worth noting that the same approach would be less informative for the diagnostic model approach or for RSA_REGRESS. For a regression model the correlation scores will be positive, and when taking the difference between learned and random scores they may cancel out, even though a particular piece of information may be predictable from the random activations in a completely different way than from the learned activations. This is what we see for the RSA_REGRESS scores for the random vs. infix-to-prefix encoder: the scores partially cancel out, and given the pattern in Figure 2 it is clear that subtracting them is misleading. It is thus a good idea to complement the RSA_REGRESS score with the plain RSA correlation score in order to obtain a full picture of how learning affects the neural representations.
Overall, these results show that RSA_REGRESS can be used to answer the same sort of questions as the diagnostic model. It has the added advantage of being easily applicable to structured symbolic representations, while the RSA scores and the full RSA correlation pattern provide a complementary source of insight into neural representations. Encouraged by these findings, we next apply both RSA and RSA_REGRESS to representations of natural language sentences.
5 Natural language
Here we use our proposed RSA-based techniques to compare tree-structure representations of natural language sentences with their neural representations captured by sentence embeddings. Such embeddings are often provided by NLP systems trained on unlabeled text, using variants of a language modeling objective (e.g. peters2018deep), next and previous sentence prediction (kiros2015skip; logeswaran2018efficient), or discourse-based objectives (nie2017dissent; jernite2017discourse). Alternatively, they can be fully trained or fine-tuned on annotated data using a task such as natural language inference (conneau-EtAl:2017:EMNLP2017). In our experiments we use one encoder of each type: BERT, trained on unlabeled text, and Infersent, trained on annotated data.
Bag of words
As a baseline we use a classic bag of words model where a sentence is represented by a vector of word counts. We do not exclude any words and use raw, unweighted word counts.
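A sketch of this baseline, assuming simple whitespace tokenization with case preserved (the source does not specify the tokenizer) and a vocabulary built from the sentences being encoded; `bag_of_words` is a hypothetical helper:

```python
from collections import Counter

import numpy as np


def bag_of_words(sentences):
    """Raw, unweighted word-count vectors over a shared sorted vocabulary."""
    tokenized = [s.split() for s in sentences]
    vocab = sorted({w for toks in tokenized for w in toks})
    index = {w: i for i, w in enumerate(vocab)}
    X = np.zeros((len(sentences), len(vocab)))
    for row, toks in enumerate(tokenized):
        for w, c in Counter(toks).items():
            X[row, index[w]] = c  # no stop-word removal, no TF-IDF weighting
    return X, vocab


X, vocab = bag_of_words(["the cat saw the dog", "a dog"])
print(vocab)  # ['a', 'cat', 'dog', 'saw', 'the']
```

These count vectors can then be compared with the same cosine distance used for the neural encoders.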
Infersent
This is the supervised model described in conneau-EtAl:2017:EMNLP2017, based on a bidirectional LSTM trained on natural language inference. We use the infersent2 model (available at https://github.com/facebookresearch/InferSent) with pre-trained fastText (DBLP:journals/corr/BojanowskiGJM16) word embeddings. We also test a randomly initialized version of this model, including random word embeddings.
BERT
This is an unsupervised model based on the Transformer architecture (vaswani2017attention) trained on a cloze task and next-sentence prediction (DBLP:journals/corr/abs-1810-04805). We use the PyTorch version of the large 24-layer model (bert-large-uncased), available at https://github.com/huggingface/pytorch-pretrained-BERT. We also test a randomly initialized version of this model.
5.2 Experimental settings
We use a sample of data from the English Web Treebank (EWT) (ewt), which contains a mix of English weblogs, newsgroups, email, reviews and question-answers manually annotated for syntactic constituency structure. We use the 2,002 sentences corresponding to the development section of the EWT Universal Dependencies (silveira14gold), plus 200 sentences from the training section as reference sentences when fitting RSA_REGRESS.
We compute Tree Kernel scores on the original constituency trees without any preprocessing, for two values of the decay parameter λ. Lower values of λ discount larger tree fragments in the computation of the kernel; λ = 1 does not do any discounting. See the supplementary material for details, and Figure 3 for an illustration of the effect.
For the BERT embeddings we use the vector associated with the first token ([CLS]) for a given layer. For Infersent, we use the default max-pooled representation.
When fitting RSA_REGRESS we use L2-penalized multivariate linear regression. We report the results for the value of the penalty coefficient, chosen from a grid of candidate values, that yields the highest cross-validated Pearson's r between target and predicted similarity-embedded vectors.
Table 3 shows the results of applying RSA and RSA_REGRESS to five different sentence encoders, using the Tree Kernel reference. Results are reported for two different values of the Tree Kernel parameter λ. For BERT, we report the scores for the topmost layer as well as for the layer which maximizes the given score.
As can be seen, with the smaller value of λ, all the encoders show a substantial RSA correlation with the parse trees. The highest scores are achieved by the trained Infersent and BERT, but even Bag of Words and the untrained versions of Infersent and BERT show a sizeable correlation with syntactic trees according to both RSA and RSA_REGRESS.
When structure matching is strict (λ = 1), only trained BERT and Infersent capture syntactic information according to RSA; however, RSA_REGRESS still shows moderate correlation for BoW and the untrained versions of BERT and Infersent. Thus RSA_REGRESS is less sensitive to the value of λ than RSA, since changing λ between the two values does not alter its results in a qualitative sense.
Figure 4 shows how RSA and RSA_REGRESS scores change when correlating Tree Kernel estimates with embeddings from different layers of BERT. For trained models, scores peak between layers 15–23 (depending on metric and λ) and decline thereafter, which indicates that the final layers are increasingly dedicated to encoding aspects of sentences other than pure syntax.
We present two RSA-based methods for correlating neural and syntactic representations of language, using tree kernels as a measure of similarity between syntactic trees. Our results on arithmetic expressions confirm that both versions of structured RSA capture correlations between different representation spaces, while providing complementary insights. We apply the same techniques to English sentence embeddings, and show where and to what extent each representation encodes syntactic information. The proposed methods are general: given a similarity metric, they are applicable not just to constituency trees but to any symbolic representation of linguistic structure, including dependency trees or Abstract Meaning Representations. We plan to explore these options in future work. Upon publication we will release a toolkit with an implementation of our methods.
Size distribution of arithmetic expressions
Tree Kernel algorithm
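Following collins2002convolution, we calculate the tree kernel between two trees $T_1$ and $T_2$ as:
$$K(T_1, T_2) = \sum_{n_1 \in N_1} \sum_{n_2 \in N_2} C(n_1, n_2)$$
where $N_1$ and $N_2$ are the sets of nodes in $T_1$ and $T_2$, respectively, and the function $C$, which counts the (λ-weighted) common tree fragments rooted at both $n_1$ and $n_2$, is calculated as:
$$C(n_1, n_2) = \begin{cases} 0 & \text{if } \mathrm{prod}(n_1) \neq \mathrm{prod}(n_2) \\ \lambda & \text{if } \mathrm{prod}(n_1) = \mathrm{prod}(n_2) \text{ and } \mathrm{preterm}(n_1) \\ \lambda \prod_{i=1}^{\mathrm{nc}(n_1)} \big(1 + C(\mathrm{ch}(n_1, i), \mathrm{ch}(n_2, i))\big) & \text{otherwise} \end{cases}$$
Here $\mathrm{nc}(n)$ is the number of children of node $n$ and $\mathrm{ch}(n, i)$ is its $i$-th child; $\mathrm{prod}(n)$ is the production at node $n$, and $\mathrm{preterm}(n)$ is true if $n$ is a preterminal node. Figure 6 shows the complete set of tree fragments which the tree kernel implicitly computes for the expression (3+7).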
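The recursion above can be sketched directly in Python. This is a naive quadratic version for illustration, not the linear-average-time algorithm of moschitti2006making. It assumes a hypothetical tree encoding as nested tuples `(label, child1, child2, ...)` with words as plain strings, and treats any terminal attached directly to a non-preterminal node as a fixed part of the production (contributing no extra fragment choices).

```python
from itertools import product


def nodes(tree):
    """All internal nodes (tuples) of a tree, in pre-order."""
    out = [tree]
    for child in tree[1:]:
        if isinstance(child, tuple):
            out.extend(nodes(child))
    return out


def production(node):
    """Node label plus its children's labels (nonterminals) or words (terminals)."""
    kids = tuple(("NT", c[0]) if isinstance(c, tuple) else ("T", c) for c in node[1:])
    return (node[0], kids)


def is_preterminal(node):
    return len(node) == 2 and isinstance(node[1], str)


def C(n1, n2, lam):
    """Weighted number of common fragments rooted at both n1 and n2."""
    if production(n1) != production(n2):
        return 0.0
    if is_preterminal(n1):
        return lam
    result = lam
    for c1, c2 in zip(n1[1:], n2[1:]):
        if isinstance(c1, tuple):  # terminals in the production are already matched
            result *= 1.0 + C(c1, c2, lam)
    return result


def tree_kernel(t1, t2, lam=1.0):
    """K(T1, T2): sum of C over all pairs of nodes."""
    return sum(C(a, b, lam) for a, b in product(nodes(t1), nodes(t2)))


t = ("NP", ("D", "the"), ("N", "man"))
print(tree_kernel(t, t))  # 6.0: four fragments rooted at NP, one each at D and N
```

Setting `lam` below 1 shrinks the contribution of every fragment by one factor of λ per node, so larger fragments are discounted more heavily, matching the role of λ in the experiments.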