The official code for paper "R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Modeling".
Recently, CKY-based models have shown great potential in unsupervised grammar induction thanks to their human-like encoding paradigm, which runs recursively and hierarchically but requires O(n^3) time complexity. Recursive Transformer based on Differentiable Trees (R2D2) makes it possible to scale to large language model pre-training even with a complex tree encoder by introducing a heuristic pruning method. However, the rule-based pruning approach suffers from local optima and slow inference. In this paper, we fix both issues in a unified method. We propose to use a top-down parser as a model-based pruning method, which also enables parallel encoding during inference. Specifically, our parser casts parsing as a split point scoring task: it first scores all split points for a given sentence, and then recursively splits a span into two by picking the split point with the highest score in the current span. The reverse order of the splits is used as the order of pruning in the R2D2 encoder. Besides the bi-directional language model loss, we also optimize the parser by minimizing the KL divergence between the tree probabilities from the parser and from R2D2. Our experiments show that Fast-R2D2 improves performance significantly in grammar induction and achieves competitive results in downstream classification tasks.
Compositional, hierarchical and recursive processing are widely believed to be essential traits of human language across diverse linguistic theories DBLP:journals/tit/Chomsky56; chomsky2014aspects. Chart-based parsing models DBLP:journals/corr/MaillardCY17; kim-etal-2019-compound; dblp:conf/naacl/drozdovvyim19; hu-etal-2021-r2d2 have made promising progress in both grammar induction and hierarchical encoding in recent years. The differentiable CKY encoding architecture of DBLP:journals/corr/MaillardCY17 simulates the hierarchical and recursive process explicitly by introducing an energy function to combine all possible derivations when constructing each cell representation. However, this entails cubic time complexity, which makes it infeasible to scale to large language model training like BERT devlin2018. Likewise, the cubic memory cost limits the tree encoder's ability to use large-parameter models as a backbone.
hu-etal-2021-r2d2 introduced a heuristic pruning method, successfully reducing the time complexity to a linear number of compositions. Their experiments show that the chart-based model has great potential for grammar induction and representation learning when a sophisticated tree encoder such as a Transformer is combined with large-corpus pretraining, leading to a Recursive Transformer based on Differentiable Trees, or R2D2 for short. However, their heuristic pruning approach is a rule-based algorithm that only considers local composition probabilities. Thus, trees constructed in this way are not guaranteed to be globally optimal. Moreover, as each pruning step depends on previous decisions, the entire encoding process is sequential and thus slow at inference time.
In this work, we resolve these issues by proposing a unified method with a new global pruning strategy based on a light-weight and fast top-down parser. We cast parsing as split point scoring, where we first encode the input sentence with a bi-directional LSTM, and score all split points in parallel. Specifically, for a given sentence, the parser first scores each split point between words in parallel by looking at its left and right contexts, and then recursively splits a span (starting with the whole sentence) into two sub-spans by picking a split point with the highest score among the current split candidates. During training, we incorporate sampling in the recursive splitting process, where, in each step, we sample a split point with respect to the score distribution in the current span and simplify the process as a sorting problem. Thus, the reverse order of the sorted split points can serve as the merge order to guide the pruning of the CKY encoder. As the gradient of the pretrained component cannot be back-propagated to the parser, inspired by URNNG dblp:conf/naacl/kimrykdm19, we optimize the parser by sampling trees over the CKY chart table generated by R2D2. Additionally, the pretrained tree encoder can compose sequences recursively in parallel according to the trees generated by the parser, which makes Fast-R2D2 a Recursive Neural Network DBLP:journals/ai/Pollack90; DBLP:conf/emnlp/SocherPWCMNP13 variant.
In this paper, we make the following main contributions:
We propose a model-based pruning method based on a top-down parser and corresponding unsupervised training objective. Experiments show that our parser outperforms models custom-tailored for grammar induction.
By encoding in parallel all trees generated by the top-down parser, Fast-R2D2 speeds up inference by a factor of 30 to 50 compared to R2D2.
We pre-train Fast-R2D2 on a large corpus and evaluate it on downstream tasks. The experiments demonstrate that a pretrained recursive model based on an unsupervised parser significantly outperforms pretrained sequential Transformers DBLP:conf/nips/VaswaniSPUJGKP17 with the same parameter size in single sentence classification tasks.
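R2D2 follows the work of DBLP:journals/corr/MaillardCY17 in defining a CKY-style 10.5555/1097042; kasami1966efficient; younger1967recognition encoder. Informally, given a sentence $S = \{s_1, s_2, \ldots, s_n\}$ with $n$ words or word-pieces, R2D2 defines a chart table as shown in Figure 1. In such a chart, each cell $\mathcal{T}_{i,j}$ is a tuple $\langle e_{i,j}, p_{i,j}, \tilde{p}_{i,j} \rangle$, where
$e_{i,j}$ is a vector representation, $p_{i,j}$ is the probability of a single composition step, and $\tilde{p}_{i,j}$ is the probability of the subtree for the span $[i, j]$ over the sub-string $s_{i:j}$. At the lowest level, the table has terminal nodes $\mathcal{T}_{i,i}$ with $e_{i,i}$ initialized as the embeddings of the input tokens $s_i$, while $p_{i,i}$ and $\tilde{p}_{i,i}$ are set to one. When $j > i$, the representation $e_{i,j}$ is a weighted sum of intermediate combinations $c^k_{i,j}$, defined as:

$$
\begin{aligned}
c^k_{i,j},\ p^k_{i,j} &= f(e_{i,k}, e_{k+1,j}) \\
\tilde{p}^k_{i,j} &= p^k_{i,j}\, \tilde{p}_{i,k}\, \tilde{p}_{k+1,j} \\
\alpha_{i,j} &= \mathrm{GUMBEL}(\log \tilde{\mathbf{p}}_{i,j}) \\
e_{i,j} &= [c^i_{i,j}, c^{i+1}_{i,j}, \ldots, c^{j-1}_{i,j}]\, \alpha_{i,j} \\
[p_{i,j}, \tilde{p}_{i,j}] &= \alpha_{i,j}^\top [\mathbf{p}_{i,j}, \tilde{\mathbf{p}}_{i,j}]
\end{aligned}
$$

Here, $k$ is a split point from $i$ to $j-1$, and $f(\cdot)$ is a multi-layer Transformer encoder. $p^k_{i,j}$ and $\tilde{p}^k_{i,j}$ denote the single-step combination probability and the subtree probability, respectively, at split point $k$; $\mathbf{p}_{i,j}$ and $\tilde{\mathbf{p}}_{i,j}$ are the concatenations of all $p^k_{i,j}$ or $\tilde{p}^k_{i,j}$ values; and GUMBEL is the Straight-Through Gumbel-Softmax operation of DBLP:conf/iclr/JangGP17 with temperature set to one. As GUMBEL picks the optimal split point at each cell in practice, it is straightforward to recover the complete derivation tree recursively, top-down from the root node.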
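To make the chart recursion concrete, here is a minimal pure-Python sketch of an unpruned forward pass. It assumes a toy `compose` callable in place of the Transformer encoder (anything returning a vector and a probability in (0, 1) works), stores the chart in a dict keyed by 0-based inclusive spans, and models only the forward behaviour of Straight-Through Gumbel-Softmax, i.e. a hard one-hot choice of split; no gradients are involved. `build_chart`, `recover_tree` and `toy_compose` are hypothetical names, not part of the R2D2 code.

```python
import math
import random


def build_chart(embeddings, compose, rng=None):
    """Fill an unpruned CKY chart bottom-up (O(n^3) compositions in total).

    chart[(i, j)] = (vector, p, p_tilde, k) for tokens i..j (0-based, inclusive),
    where k is the chosen split (left child covers i..k, right child k+1..j).
    """
    n = len(embeddings)
    chart = {}
    for i, e in enumerate(embeddings):
        chart[(i, i)] = (e, 1.0, 1.0, None)          # terminals: p = p_tilde = 1
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            best = None
            for k in range(i, j):
                left, right = chart[(i, k)], chart[(k + 1, j)]
                c, p = compose(left[0], right[0])    # candidate vector and step probability
                p_tilde = p * left[2] * right[2]     # subtree probability at split k
                score = math.log(p_tilde)
                if rng is not None:                  # add Gumbel(0, 1) noise
                    score -= math.log(-math.log(max(rng.random(), 1e-12)))
                if best is None or score > best[0]:
                    best = (score, c, p, p_tilde, k)
            _, c, p, p_tilde, k = best               # hard one-hot alpha keeps split k
            chart[(i, j)] = (c, p, p_tilde, k)
    return chart


def recover_tree(chart, i, j):
    """Read the derivation tree top-down from the root cell."""
    if i == j:
        return i
    k = chart[(i, j)][3]
    return (recover_tree(chart, i, k), recover_tree(chart, k + 1, j))


def toy_compose(a, b):
    """Hypothetical stand-in for the Transformer: sum vectors, sigmoid of a product."""
    vec = [x + y for x, y in zip(a, b)]
    prob = 1.0 / (1.0 + math.exp(-a[0] * b[0]))
    return vec, prob


emb = [[1.0], [2.0], [-3.0]]
print(recover_tree(build_chart(emb, toy_compose), 0, 2))   # (0, (1, 2))
sampled = recover_tree(build_chart(emb, toy_compose, rng=random.Random(0)), 0, 2)
```

Without an `rng`, the sketch always keeps the split with the highest subtree probability; with one, Gumbel noise is added first, so a split is sampled instead. The triple loop makes visible the cubic cost that motivates pruning.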
As shown in Figure 2, R2D2 starts to prune once all cells below height $m$ have been encoded. The heuristic rules work as follows:
1. Recover the maximum sub-tree for each cell at the $m$-th level, and collect all cells at the 2nd level that appear in any sub-tree;
2. Rank the candidates from Step 1 by the composition probability $p_{i,j}$, and pick the highest-scoring cell as a non-splittable span;
3. Remove any invalid cells that would break the now non-splittable span from Step 2, e.g., the dark cells in (c), and reorganize the chart table much like in the game of Tetris, as in (d);
4. Encode the blank cells at the $m$-th level, e.g., the cell highlighted with stripes in (d), and go back to Step 1 until the root cell has been encoded.
To learn meaningful structures without gold trees, hu-etal-2021-r2d2 propose a self-supervised pretraining objective. Similar to the bidirectional masked language model task, R2D2 reconstructs a given token $s_i$ based on its context representations $e_{1,i-1}$ and $e_{i+1,n}$. The probability of each token is estimated by the tree encoder defined in R2D2. The final objective is:

$$\mathcal{L}_{bilm} = -\frac{1}{n}\sum_{i=1}^{n} \log p(s_i \mid e_{1,i-1}, e_{i+1,n})$$
We propose a top-down parser to evaluate scores for all split points in a sentence and generate a merge order according to the scores.
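Given a sentence $S = \{s_1, s_2, \ldots, s_n\}$, there are $n-1$ split points between words. We define a top-down parser by giving confidence scores to all split points as follows:

$$\mathbf{v} = [v_1, v_2, \ldots, v_{n-1}] = \mathrm{Parser}(S)$$

As shown in Figure 3, a bi-directional LSTM first encodes the sentence, and then, for each split point, an MLP over the concatenation of the left and right context representations yields the final split score. Formally, we have:

$$
\begin{aligned}
\overrightarrow{h}_1, \ldots, \overrightarrow{h}_n &= \overrightarrow{\mathrm{LSTM}}(e_1, \ldots, e_n) \\
\overleftarrow{h}_1, \ldots, \overleftarrow{h}_n &= \overleftarrow{\mathrm{LSTM}}(e_1, \ldots, e_n) \\
v_i &= \mathrm{MLP}([\overrightarrow{h}_i; \overleftarrow{h}_{i+1}])
\end{aligned}
$$

Here, $e = \{e_1, \ldots, e_n\}$ is the embedding of the input sentence $S$, while $\overrightarrow{h}$ and $\overleftarrow{h}$ denote the forward and reverse representations, respectively. $v_i$ is the score of the $i$-th split point, whose left and right context representations are $\overrightarrow{h}_i$ and $\overleftarrow{h}_{i+1}$. Given the scores $\mathbf{v}$, one can easily recover the binary tree shown in Figure 3: we recursively split a span (starting with the entire sentence) into two sub-spans by picking the split point with the highest score in the current span. Taking the sentence in Figure 3 (a) as an example, we first split the whole sentence at its highest-scoring split point, which yields two sub-trees; each sub-span is then split at its own highest-scoring split point, and we continue this procedure until the complete tree has been recovered.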
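The recursive splitting step can be sketched as follows, taking the split scores as given (the BiLSTM and MLP are omitted). Indices are 0-based here, with `scores[k]` standing for the split point between tokens `k` and `k + 1`; the function name and the example scores are hypothetical.

```python
def split_tree(scores, i, j):
    """Recursively split tokens i..j (0-based, inclusive) at the best split point.

    scores[k] scores the split point between token k and token k + 1, so a
    sentence of n tokens has n - 1 scores.
    """
    if i == j:
        return i                                        # a single token is a leaf
    k = max(range(i, j), key=lambda m: scores[m])       # highest score in this span
    return (split_tree(scores, i, k), split_tree(scores, k + 1, j))


scores = [0.1, 0.9, 0.3, 0.5]                # hypothetical parser output for 5 tokens
print(split_tree(scores, 0, len(scores)))    # ((0, 1), ((2, 3), 4))
```

Only the split points inside the current span compete, which is why a globally high-scoring point can still end up deep in the tree if an even higher one separates it from the rest of the sentence.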
In the training stage, we perform sampling over the computed scores in order to increase the robustness and exploration of our model. We denote as $\mathcal{C}_t$ the list of split points at time $t$ in ascending order, which is $\{1, 2, \ldots, n-1\}$ in the first step. Then a particular split point $a_t$ is selected from $\mathcal{C}_t$ by sampling from the probabilities given by a softmax over the stacked scores of the split points in $\mathcal{C}_t$. The sampled $a_1, \ldots, a_{n-1}$ together form the final split point sequence $A$. At each time step, we remove $a_t$ from $\mathcal{C}_t$ once it is selected, then sample the next split point until the set of remaining split points is empty. Formally, we have:

$$
\begin{aligned}
a_t &\sim \mathrm{softmax}(\mathbf{v}_{\mathcal{C}_t}) \\
\mathcal{C}_{t+1} &= \mathcal{C}_t \setminus \{a_t\}
\end{aligned}
$$

As the Gumbel-Max trick GUMBEL; DBLP:conf/nips/MaddisonTM14 provides a simple and efficient way to draw samples from a categorical distribution with given class probabilities, we can obtain $a_t$ via the Gumbel-Max trick as:

$$a_t = \arg\max_{i \in \mathcal{C}_t} \, (v_i + g_i)$$

where $g_i$ is the Gumbel noise for the $i$-th split point, drawn once and reused at every step. Therefore, the aforementioned process is equivalent to sorting the original sequence of split point scores with added Gumbel noise. Figure 3 (b) shows the sampled tree with respect to the split point scores. The split point sequence $A$ can hence be obtained simply as:

$$A = \mathrm{argsort}(\mathbf{v} + \mathbf{g})$$

Here, $\mathrm{argsort}$ sorts the array in descending order and returns the indices of the original array. Figure 3 (b) shows the resulting sampled $A$.
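A small sketch of why sorting suffices, under the assumption that each split point receives one Gumbel sample that is reused at every step: the sequential argmax over the shrinking candidate set then yields exactly the descending argsort of the perturbed scores. The helper names are illustrative only.

```python
import math
import random


def gumbel_perturb(scores, rng):
    """Add i.i.d. Gumbel(0, 1) noise g_i = -log(-log(u)) to every split score once."""
    return [v - math.log(-math.log(max(rng.random(), 1e-12))) for v in scores]


def order_by_sorting(perturbed):
    """A = argsort(v + g), descending."""
    return sorted(range(len(perturbed)), key=lambda k: perturbed[k], reverse=True)


def order_by_sequential_argmax(perturbed):
    """Reference loop: take the argmax over the remaining set, then remove it."""
    remaining = list(range(len(perturbed)))
    order = []
    while remaining:
        a_t = max(remaining, key=lambda k: perturbed[k])
        order.append(a_t)
        remaining.remove(a_t)
    return order


perturbed = gumbel_perturb([0.1, 0.9, 0.3, 0.5], random.Random(0))
assert order_by_sorting(perturbed) == order_by_sequential_argmax(perturbed)
```

The sort costs O(n log n) and needs no loop over time steps, which is what makes the sampled order cheap to produce on a GPU as a single batched operation.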
As word-pieces wu2016google and Byte-Pair Encoding (BPE) are commonly used in pretrained language models, it is straightforward to incorporate multiple word-piece constraints into the top-down parser to reduce word-level parsing errors. We denote as $\mathcal{S}$ a list of span constraints composed of the beginning and end positions of non-splittable spans, defined as $\mathcal{S} = \{(b_1, e_1), (b_2, e_2), \ldots\}$. For each $(b_i, e_i)$ in $\mathcal{S}$, there should be a sub-tree for a span covering the sub-string $s_{b_i:e_i}$. This goal can be achieved by ensuring that the scores of the split points covered by $(b_i, e_i)$ are lower than those of all others, adjusting them following Algorithm 1.
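Algorithm 1 itself is not reproduced here. As one simple way to satisfy the stated condition (an assumption, not necessarily the authors' procedure), the sketch below subtracts a single offset large enough to push every constrained split point below every unconstrained one. A constrained span is then only split internally after everything around it has been split, so it ends up as one subtree. During training, such an adjustment would have to be applied to the Gumbel-perturbed scores so that the noise cannot undo it. Function names are hypothetical.

```python
def apply_span_constraints(scores, spans):
    """Lower the scores of split points inside non-splittable spans below all others.

    spans: (b, e) token positions (0-based, inclusive) that must stay one subtree.
    Split point k lies between tokens k and k + 1, so (b, e) covers points b..e-1.
    The relative order inside a span is kept, so the parser still decides its shape.
    """
    inside = {k for b, e in spans for k in range(b, e)}
    if not inside or len(inside) == len(scores):
        return list(scores)
    offset = max(scores) - min(scores) + 1.0    # makes every inside score < every outside score
    return [v - offset if k in inside else v for k, v in enumerate(scores)]


def split_tree(scores, i, j):
    """Top-down splitting at the highest-scoring split point in each span."""
    if i == j:
        return i
    k = max(range(i, j), key=lambda m: scores[m])
    return (split_tree(scores, i, k), split_tree(scores, k + 1, j))


# Tokens 1 and 2 are hypothetical word-pieces of one word (e.g. "disc" + "##re").
adjusted = apply_span_constraints([0.1, 0.9, 0.3, 0.5], [(1, 2)])
print(split_tree(adjusted, 0, 4))    # (((0, (1, 2)), 3), 4)
```

Without the constraint, the same scores would split between tokens 1 and 2 first and break the word apart.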
We denote as $\bar{A}$ the reverse order of the split point sequence $A$ and treat $\bar{A}$ as a bottom-up merge order inferred by the top-down parser based on the global context. Subsequently, we simply replace Algorithm 2 in hu-etal-2021-r2d2 by our Algorithm 2. As shown in Figure 2, we still retain the threshold and the pruning logic of R2D2, but we select cells to merge according to $\bar{A}$ instead of following heuristic rules. Specifically, given a shrinking chart table, we select the next merge index in the second row by popping the first element of $\bar{A}$ and updating the remaining indices of $\bar{A}$ accordingly, as in Algorithm 2.
Take the example in Figure 3 (b): $\bar{A}$ starts as the reverse of the sampled split sequence. We first merge the first cell in the second row in Figure 2 (b) and obtain a new $\bar{A}$ whose remaining indices are shifted to account for the merged cell. In the next round, we treat the 4th cell as a non-splittable cell in (e), and $\bar{A}$ is updated in the same way.
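The index bookkeeping for the reversed split sequence can be illustrated in isolation. The sketch below is a hypothetical helper, not the paper's Algorithm 2: it converts original split-point indices into positions in the shrinking row, using the fact that every earlier merge to the left of a split point fuses two cells into one and therefore moves that point one position left. It ignores the pruning threshold and simply merges token tuples, to show that the reversed order rebuilds the parser's tree bottom-up.

```python
def merge_positions(split_order):
    """Map a top-down split sequence A to merge positions in a shrinking row.

    split_order holds original split-point indices (0-based; point k lies
    between tokens k and k + 1). Merges happen in reversed order; for each
    one we return the index of its left cell in the current row.
    """
    merged = []
    positions = []
    for k in reversed(split_order):
        shift = sum(1 for m in merged if m < k)   # cells already fused to the left
        positions.append(k - shift)
        merged.append(k)
    return positions


def apply_merges(tokens, split_order):
    """Merge adjacent cells in reversed split order; the last cell is the tree."""
    row = list(tokens)
    for pos in merge_positions(split_order):
        row[pos:pos + 2] = [(row[pos], row[pos + 1])]
    return row[0]


order = [1, 3, 2, 0]                          # descending argsort of scores [0.1, 0.9, 0.3, 0.5]
print(merge_positions(order))                 # [0, 1, 1, 0]
print(apply_merges([0, 1, 2, 3, 4], order))   # ((0, 1), ((2, 3), 4))
```

The reversed order is a valid bottom-up schedule because a split chosen top-down always outscores every split inside its two sub-spans, so both children are complete before their parent is merged.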
We denote the tree probabilities estimated by the top-down parser and by R2D2 as $p_\theta(z \mid S)$ and $q_\phi(z \mid S)$, respectively. The difficulty here is that $q_\phi$ can be optimized by the objective defined in Equation 6, but there is no gradient feedback for $p_\theta$. To make $p_\theta$ learnable, an intuitive solution is to fit $p_\theta$ to $q_\phi$ by minimizing their Kullback–Leibler divergence. As both distributions are defined over discrete trees that are too numerous to enumerate exhaustively, inspired by URNNG dblp:conf/naacl/kimrykdm19, a Monte Carlo estimate for the gradient with respect to $\theta$ can be defined as:

$$\nabla_\theta D_{\mathrm{KL}}(q_\phi \,\|\, p_\theta) \approx -\frac{1}{K}\sum_{k=1}^{K} \nabla_\theta \log p_\theta(z^{(k)} \mid S)$$

with $K$ samples $z^{(1)}, \ldots, z^{(K)}$ from $q_\phi(z \mid S)$. Algorithm 3 shows the complete sampling process from $q_\phi(z \mid S)$. Specifically, we sample split points recursively as in previous work DBLP:journals/corr/cmp-lg-9805007; DBLP:conf/emnlp/FinkelMN06; dblp:conf/naacl/kimrykdm19 with respect to the intermediate tree probabilities calculated during hierarchical encoding.
The sampler returns a sequence of split points and their corresponding spans. For the $k$-th sample $z^{(k)}$, we denote as $p_\theta(a^{(k)}_t)$ the probability of taking $a^{(k)}_t$ as the split of span $(i_t, j_t)$ at the $t$-th step. Formally, we have:

$$
\begin{aligned}
p_\theta(z^{(k)} \mid S) &= \prod_{t=1}^{n-1} p_\theta(a^{(k)}_t) \\
p_\theta(a^{(k)}_t) &= \frac{\exp(v_{a^{(k)}_t})}{\sum_{m=i_t}^{j_t - 1} \exp(v_m)}
\end{aligned}
$$

where $i_t$ and $j_t$ denote the start and end of the corresponding span. Note that the scores $v$ here are not adjusted by span constraints.
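The sampling step and the Monte Carlo objective can be sketched together in plain Python. Assumptions: the chart is exposed as a dict mapping each span to its per-split subtree probabilities (a hypothetical layout; a pruned chart would only contain the cells that were actually encoded), a split is drawn in proportion to those probabilities, and the parser scores are plain, unconstrained floats. In an autograd framework, the gradient of `kl_surrogate` with respect to the scores would be the Monte Carlo estimate described above; here the function only returns its value. All names are illustrative.

```python
import math
import random


def sample_tree(split_weights, i, j, rng, out=None):
    """Sample one tree top-down from R2D2's chart (hypothetical data layout).

    split_weights[(i, j)] lists the subtree probabilities for splits k = i..j-1;
    a split is drawn in proportion to them and both halves are sampled
    recursively. Returns the (split, i, j) decisions in depth-first order.
    """
    if out is None:
        out = []
    if i == j:
        return out
    split = rng.choices(range(i, j), weights=split_weights[(i, j)])[0]
    out.append((split, i, j))
    sample_tree(split_weights, i, split, rng, out)
    sample_tree(split_weights, split + 1, j, rng, out)
    return out


def parser_log_prob(scores, decisions):
    """log p_theta(z | S): each split is a softmax over its own span's split points."""
    total = 0.0
    for a, i, j in decisions:
        span = scores[i:j]                       # split points i .. j-1
        m = max(span)
        log_norm = m + math.log(sum(math.exp(v - m) for v in span))
        total += scores[a] - log_norm
    return total


def kl_surrogate(scores, samples):
    """-(1/K) * sum_k log p_theta(z_k | S) over K trees sampled from R2D2."""
    return -sum(parser_log_prob(scores, z) for z in samples) / len(samples)


weights = {(0, 2): [0.2, 0.8], (0, 1): [1.0], (1, 2): [1.0]}   # toy 3-token chart
rng = random.Random(0)
samples = [sample_tree(weights, 0, 2, rng) for _ in range(4)]
print(kl_surrogate([0.0, 0.0], samples))   # log(2) ≈ 0.693 for every tree here
```

With equal parser scores, every 3-token tree has probability 1/2 under the parser, so the surrogate is log 2 regardless of which trees the sampler returns.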
In this paper, we mainly focus on classification tasks as downstream tasks. We take the root representation $e_{1,n}$ as the representation of the entire sentence. As we have two models pretrained in our framework, the R2D2 encoder and a top-down parser, we have two ways of generating the representation:
Run forced encoding over the binary tree from the top-down parser with the R2D2 encoder;
Use the binary tree to guide the pruning of the R2D2 encoder, and take the root representation $e_{1,n}$.
The first approach is much faster than the second, as the R2D2 encoder only runs $n-1$ times in forced encoding, and these compositions can run in parallel layer by layer; e.g., in Figure 3 (b), all compositions whose children are already available can be run at the same time. We explore both approaches in our experiments.
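Parallel forced encoding can be organized by grouping the internal nodes of the parser's tree by height: a node only depends on its children, whose heights are strictly smaller, so each group can be sent to the tree encoder as one batch. This grouping scheme is an illustrative assumption, not a description of the released implementation, and `composition_layers` is a hypothetical name.

```python
def composition_layers(tree):
    """Group the internal nodes of a binary tree (nested tuples) by height.

    A node of height h only needs its children, whose heights are < h, so all
    nodes in one group can be composed in a single batched encoder call.
    """
    layers = {}

    def visit(node):
        if not isinstance(node, tuple):
            return 0                                  # leaf: token already embedded
        height = 1 + max(visit(node[0]), visit(node[1]))
        layers.setdefault(height, []).append(node)
        return height

    visit(tree)
    return [layers[h] for h in sorted(layers)]


for height, group in enumerate(composition_layers(((0, 1), ((2, 3), 4))), start=1):
    print(height, group)
# 1 [(0, 1), (2, 3)]
# 2 [((2, 3), 4)]
# 3 [((0, 1), ((2, 3), 4))]
```

The four compositions of this five-token tree collapse into three batched steps; for balanced trees the number of steps grows only logarithmically with sentence length.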
As suggested in previous work radford2018improving; howard-ruder-2018-universal; gururangan-etal-2020-dont, given a pretrained model, continued pretraining on an in-domain corpus with the same pretraining objective can yield better generalization. Thus, we simply add our pretraining objectives to the loss of every downstream task. At the same time, as the downstream task may guide R2D2 toward more reasonable tree structures, we keep the KL loss so that the parser continues to be updated. For both inference methods, we uniformly take the root representation as the representation of a given sentence, followed by an MLP, and compute the cross-entropy losses, denoted as $\mathcal{L}_{cls}^{forced}$ and $\mathcal{L}_{cls}^{pruned}$, respectively. The final loss is:

$$\mathcal{L} = \mathcal{L}_{bilm} + \mathcal{L}_{KL} + \mathcal{L}_{cls}^{forced} + \mathcal{L}_{cls}^{pruned}$$

where $\mathcal{L}_{bilm}$ is the bi-directional language modeling objective and $\mathcal{L}_{KL}$ is the KL objective used to train the parser.
For comparison, we include six recent strong models for unsupervised parsing with available open-source implementations: StructFormer DBLP:conf/acl/ShenTZBMC20, Ordered Neurons DBLP:conf/iclr/ShenTSC19, URNNG dblp:conf/naacl/kimrykdm19, DIORA dblp:conf/naacl/drozdovvyim19, C-PCFG kim-etal-2019-compound, and R2D2 hu-etal-2021-r2d2. Following htut-etal-2018-grammar, we train all systems on a training set consisting only of raw text, and evaluate and report the results on an annotated test set. As an evaluation metric, we adopt sentence-level unlabeled $F_1$ computed using the script from kim-etal-2019-compound. Per convention, we compare against the non-binarized gold trees. The best checkpoint for each system is selected based on $F_1$ scores on the validation set.
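For reference, sentence-level unlabeled $F_1$ averages a per-sentence $F_1$ over constituent spans across the test set. The sketch below is a simplified stand-in, not the script from kim-etal-2019-compound: it assumes that single-token spans and the whole-sentence span are discarded and that duplicate spans are counted once, and its handling of empty span sets may differ from the official script. Helper names are hypothetical.

```python
def tree_spans(tree, start=0):
    """Return (end, spans) for a binary tree of nested tuples over token leaves.

    Spans are half-open (start, end) token intervals of internal nodes.
    """
    if not isinstance(tree, tuple):
        return start + 1, set()
    mid, left = tree_spans(tree[0], start)
    end, right = tree_spans(tree[1], mid)
    return end, left | right | {(start, end)}


def sentence_f1(pred_spans, gold_spans, length):
    """Unlabeled F1 for one sentence, ignoring single-token and whole-sentence spans."""
    def keep(spans):
        return {(s, e) for s, e in spans if e - s > 1 and (s, e) != (0, length)}

    pred, gold = keep(pred_spans), keep(gold_spans)
    if not pred or not gold:
        return float(pred == gold)       # simplification: 1.0 only if both are empty
    overlap = len(pred & gold)
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(gold)
    return 2 * precision * recall / (precision + recall)


_, pred = tree_spans(((0, 1), ((2, 3), 4)))
gold = {(0, 2), (2, 5), (3, 5), (0, 5)}           # hypothetical non-binarized gold spans
print(round(sentence_f1(pred, gold, 5), 4))       # 0.6667
```

Because the gold trees are not binarized, even a perfect binary prediction can have precision below one, which is why sentence-level $F_1$ rather than recall alone is reported as the main metric.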
As our model is a pretrained model based on word-pieces, for a fair comparison, we test all models with two types of input: word level (W) and word-piece level (WP). Fast-R2D2 is pretrained with span constraints for the word level but without span constraints for the word-piece level. To support word-piece level evaluation, we convert gold trees to word-piece level trees by simply breaking each terminal node into a non-terminal node with its word-pieces as terminals, e.g., (NN discrepancy) into (NN (WP disc) (WP ##re) (WP ##pan) (WP ##cy)).
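The word-piece conversion of gold trees can be sketched with a regular expression over bracketed trees: every preterminal `(TAG word)` is rewritten as `(TAG (WP piece) ...)`. The `tokenize` argument is an assumed stand-in for the model's word-piece tokenizer; the toy dictionary in the example only reproduces the split of "discrepancy" given above, and the function name is hypothetical.

```python
import re


def to_wordpiece_tree(tree_str, tokenize):
    """Replace each preterminal '(TAG word)' by '(TAG (WP p1) (WP p2) ...)'."""
    def expand(match):
        tag, word = match.group(1), match.group(2)
        pieces = " ".join(f"(WP {p})" for p in tokenize(word))
        return f"({tag} {pieces})"

    # A preterminal is a bracket holding exactly a tag and a token, no nested brackets.
    return re.sub(r"\(([^\s()]+) ([^\s()]+)\)", expand, tree_str)


toy_pieces = {"discrepancy": ["disc", "##re", "##pan", "##cy"]}
tokenize = lambda w: toy_pieces.get(w, [w])
print(to_wordpiece_tree("(NP (DT the) (NN discrepancy))", tokenize))
# (NP (DT (WP the)) (NN (WP disc) (WP ##re) (WP ##pan) (WP ##cy)))
```

Single-piece words only gain a unary node, which leaves the set of multi-token spans unchanged, so word-level constituents are preserved in the word-piece gold trees.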
EFLOPS DBLP:conf/hpca/DongCZYWFZLSPGJ20 is a highly scalable distributed training system designed by Alibaba. With its optimized hardware architecture and co-designed supporting software tools, including ACCL DBLP:journals/micro/DongWFCPTLLRGGL21 and KSpeed (a high-speed data-loading service), it can easily scale to 10K nodes (GPUs) with linear scalability.
The tree encoder of our model uses 4-layer Transformers with 768-dimensional embeddings, 3,072-dimensional hidden layer representations, and 12 attention heads. The top-down parser uses a 4-layer bidirectional LSTM with 128-dimensional embeddings and a 256-dimensional hidden layer. The number of sampled trees is set to 256. Training is conducted using Adam with weight decay, with separate learning rates for the tree encoder and the top-down parser. The batch size is set to 96, and we also limit the maximum total length of each batch to 1,536, such that excess sentences are moved to the next batch. Training for 60 epochs takes about 120 hours on 8 A100 GPUs.
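The length-capped batching can be sketched as a greedy loop that closes a batch as soon as the next sentence would exceed either the sentence cap or the total-length cap. The function name and the interpretation of "length" as a token count are assumptions.

```python
def make_batches(lengths, max_sentences=96, max_total_length=1536):
    """Greedily group sentence indices; overflow sentences move to the next batch."""
    batches, current, total = [], [], 0
    for idx, n in enumerate(lengths):
        if current and (len(current) == max_sentences or total + n > max_total_length):
            batches.append(current)                  # close the full batch
            current, total = [], 0
        current.append(idx)
        total += n
    if current:
        batches.append(current)
    return batches


print(make_batches([1000, 600, 500, 40]))   # [[0], [1, 2, 3]]
```

A single sentence longer than the cap still forms its own batch, which is harmless here because sentences longer than 200 tokens are discarded during preprocessing.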
Fast-R2D2 is pretrained on a mixed corpus of WikiText103 DBLP:conf/iclr/MerityX0S17 and the training set of Penn Treebank (PTB) marcus-etal-1993-building. WikiText103 is split at the sentence level, and sentences longer than 200 tokens after tokenization are discarded (about 0.04‰ of the original data). The total number of sentences is 4,089,500, and the average sentence length is 26.97. For Chinese, we pretrain on a subset of Chinese Wikipedia, specifically the first 10,000,000 sentences shorter than 150 characters.
We test our approach on PTB with the standard splits (2–21 for training, 22 for validation, 23 for test) and the same preprocessing as in recent work kim-etal-2019-compound, where we discard punctuation and lower-case all tokens. To explore the universality of the model across languages, we also run experiments on Chinese Penn Treebank (CTB) 8 ctb8, on which we also remove punctuation.
Note that in all settings, the training is conducted entirely on raw unannotated text.
|Model||WSJ||CTB|
|Left Branching (W)||8.15||11.28|
|Right Branching (W)||39.62||27.53|
|Random Trees (W)||17.76||20.17|
Table 1 shows the results of all systems with words (W) and word-pieces (WP) as input on the WSJ and CTB test sets. When all systems are evaluated on word-level gold trees, Fast-R2D2 performs substantially better than R2D2 on both datasets. As mentioned in Section 3.3, Fast-R2D2 has two options for obtaining the final tree and representation of a given input sentence: we denote as Fast-R2D2 (forced) the system that runs forced encoding over the output tree of the top-down parser, and as Fast-R2D2 (pruned) the method that uses the tree to guide the pruning of the R2D2 encoder and selects the best tree from the chart table. The results suggest that Fast-R2D2 (forced) outperforms Fast-R2D2 (pruned) on both the WSJ and CTB test sets. More interestingly, Fast-R2D2, a model without human-designed grammar constraints operating at word-piece granularity, outperforms models specifically designed for grammar induction operating at word-level granularity. If all systems take word-pieces as input and are measured against word-piece level gold trees (the bottom five results), Fast-R2D2 obtains state-of-the-art results on both WSJ and CTB.
Following dblp:conf/naacl/kimrykdm19 and drozdov-etal-2020-unsupervised, we also compute the recall of constituents when evaluating on word-piece level gold trees. Besides standard constituents, we also compare the recall of word-piece chunks and proper noun chunks. Proper noun chunks are extracted by finding adjacent unary nodes with the same parent and tag NNP. Table 2 reports the recall scores for constituents and words on the WSJ and CTB test sets. Compared with the R2D2 baseline, Fast-R2D2 performs slightly worse on small semantic units, but significantly better on larger semantic units (such as VP and SBAR) on the WSJ test set. On the CTB test set, Fast-R2D2 outperforms R2D2 on all constituents.
From Tables 1 and 2, we conclude that Fast-R2D2 overall obtains better results than R2D2 on CTB, while faring slightly worse than R2D2 only on small semantic units on WSJ. We conjecture that this difference stems from differences in tokenization between Chinese and English. Chinese is a character-based language without complex morphology, where character collocations are consistent across the language, making it easier for the top-down parser to learn them well. In contrast, English word-pieces are built from corpus statistics, and individual word-pieces are not necessarily natural semantic units. They may therefore lack sufficient semantic self-consistency, making it harder for a top-down parser with a small number of parameters to fit them well.
We next consider the effectiveness of Fast-R2D2 in downstream tasks. This experiment is not intended to advance the state of the art on the GLUE benchmark, but rather to assess how well our approach performs compared with the dominant inductive bias of conventional sequential Transformers.
We fine-tune pretrained models on several datasets from the GLUE benchmark wang2018glue, including SST-2, CoLA, QQP, and MNLI. As sequential Transformers with their dominant inductive bias remain the norm for numerous NLP tasks, we mainly compare Fast-R2D2 with Bert devlin2018 as a representative pretrained model based on a sequential Transformer. To compare the two forms of inductive bias fairly and efficiently, we pretrain 4-layer and 12-layer Bert models as well as Fast-R2D2 from scratch on the WikiText103 corpus following Section 4.1.1. For simplicity, Fast-R2D2 is fine-tuned without span constraints. Following common practice, we add an MLP layer over the root representation of the R2D2 encoder for single-sentence classification. For cross-sentence tasks such as QQP and MNLI, we feed the root representations of the two sentences into the pretrained tree encoder of R2D2 as left and right inputs, and also add a new task id as an additional input to the R2D2 encoder. We then feed the hidden output for the task id into another MLP layer to predict the final label. We train all systems on the four datasets for 10 epochs with the same learning rate, batch size, and maximum input length. We validate each model after every epoch and report the best results on the development sets.
|Model||Para.||Single sent.||Cross sent.|
Table 3 shows the corresponding scores on SST-2, CoLA, QQP, and MNLI. In terms of parameter size, Fast-R2D2 has 52M parameters in the R2D2 encoder and 10M in the top-down parser. As expected, 12-layer Bert is significantly better than 4-layer Bert. Fast-R2D2 (forced) uses the output tree from the top-down parser, while Fast-R2D2 (pruned) uses the best tree inferred by the R2D2 encoder. Similar to the results for unsupervised parsing, Fast-R2D2 (forced) again outperforms Fast-R2D2 (pruned) in classification tasks. We hypothesize that trees generated by the top-down parser without Gumbel noise are more stable and reasonable. Fast-R2D2 significantly outperforms 4-layer Bert and achieves results competitive with 12-layer Bert in single-sentence classification tasks like SST-2 and CoLA, but still performs significantly worse in the cross-sentence tasks. We believe this is expected, as there is no cross-attention mechanism in the inductive bias of Fast-R2D2. However, the performance of Fast-R2D2 on classification tasks shows that the inductive bias of R2D2 makes better use of its parameters than sequentially applied Transformers. Importantly, we demonstrate that a Recursive Neural Network variant with an unsupervised parser can achieve results comparable to pretrained sequential Transformers on single-sentence classification, even with fewer parameters and with interpretable intermediate results. Hence, Fast-R2D2 provides an alternative choice for NLP tasks.
As DBLP:conf/icml/ChowdhuryC21 argued, parallelism is also an important feature for an RvNN. As the tree encoder of Fast-R2D2 is based on Transformers, we mainly compare the time cost of sequential Transformers and Fast-R2D2 in forced encoding across various sequence length ranges. We randomly select 1,000 sentences from WikiText103 for each range and report the average time consumption on a single A100 GPU. Bert is based on the open-source Transformers library (https://github.com/huggingface/transformers), and R2D2 is based on the official code (https://github.com/alipay/StructuredLM_RTDT/tree/r2d2) of hu-etal-2021-r2d2.
|Model||Sequence Length Ranges|
Table 4 shows the inference time in seconds for different systems to process 1,000 sentences with a batch size of 50. Running R2D2 is time-consuming, since the heuristic pruning method involves heavy memory exchange between GPU and CPU. Fast-R2D2 alleviates this problem by using model-guided pruning to accelerate chart table processing; combined with a CUDA implementation, this reduces the inference time significantly. Fast-R2D2 (forced) further improves inference speed by running forced encoding in parallel over the binary tree generated by the parser, making it about 30 to 50 times faster than R2D2 across the length ranges. Although it is still slower than sequential Transformers, Fast-R2D2 is fast enough for most NLP tasks while producing interpretable intermediate representations.
Many attempts have been made in grammar induction and hierarchical encoding. DBLP:conf/conll/Clark01 and DBLP:conf/acl/KleinM02 provided some of the first successful statistical approaches to grammar induction. Multiple recent papers focus on structure induction based on language modeling objectives DBLP:conf/nips/ShenTHLSC19; DBLP:conf/iclr/ShenTSC19; DBLP:conf/acl/ShenTZBMC20; kim-etal-2019-compound. DBLP:journals/ai/Pollack90 proposes RvNNs as a recursive architecture to encode text hierarchically, and DBLP:conf/emnlp/SocherPWCMNP13 show the effectiveness of RvNNs with gold trees for sentiment analysis. In this work, we focus on models that are capable of learning meaningful structures in an unsupervised way and encoding text over the induced tree hierarchically.
In the line of work on learning sentence representations with structures, DBLP:conf/iclr/YogatamaBDGL17 jointly train their shift-reduce parser and sentence embedding components without gold trees. As their parser is not differentiable, they have to resort to reinforcement learning, resulting in increased variance, and the parser may easily collapse to trivial left- or right-branching trees. Gumbel-Tree-LSTMs DBLP:conf/aaai/ChoiYL18 construct trees by recursively selecting two adjacent nodes to merge and learn the composition probability via downstream tasks. CRvNN DBLP:conf/icml/ChowdhuryC21 makes the whole process end-to-end differentiable and parallel by introducing a continuous relaxation. URNNG dblp:conf/naacl/kimrykdm19 applies variational inference over latent trees to perform unsupervised optimization of the RNNG dyer-etal-2016-recurrent, an RNN model that estimates a joint distribution over sentences and trees based on shift-reduce operations. However, such models have difficulty inducing meaningful structures when trained from scratch. DBLP:journals/corr/MaillardCY17 propose an alternative approach based on a differentiable CKY encoding. The algorithm is made differentiable by a soft-gating approach, which approximates discrete candidate selection by a probabilistic mixture of the constituents available in a given cell of the chart. While their work relies on annotated downstream tasks to learn structures, dblp:conf/naacl/drozdovvyim19 propose a novel auto-encoder-like pretraining objective based on the inside–outside algorithm Baker1979TrainableGF; DBLP:conf/icgi/Casacuberta94. As mentioned above, CKY-based models have cubic time complexity. hu-etal-2021-r2d2 propose a pruned differentiable CKY encoding architecture with a simple pretraining objective related to bi-directional language modeling. They reduce the time complexity to a linear number of composition steps, making it possible to apply sophisticated tree encoders and to scale to large-corpus pretraining.
In this paper, we have presented Fast-R2D2, which improves the performance and inference speed of R2D2 by introducing a fast top-down parser to guide the pruning of the R2D2 encoder. Pretrained on the same corpus, Fast-R2D2 significantly outperforms sequential Transformers with a similar number of parameters on single-sentence classification tasks. Experimental results show that Fast-R2D2 is a promising and feasible way to learn hierarchical text representations, which differs from layer-stacking models and can also generate interpretable intermediate representations. As future work, we are investigating how to leverage the intermediate representations in additional downstream tasks.
We would like to thank the Aliyun EFLOPS team for their substantial support in designing and providing a cutting-edge training platform to facilitate fast experimentation in this work.