StructuredLM_RTDT
The official code for paper "R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Modeling".
Recently, CKY-based models have shown great potential in unsupervised grammar induction thanks to their human-like encoding paradigm, which runs recursively and hierarchically but requires O(n^3) time complexity. The Recursive Transformer based on Differentiable Trees (R2D2) makes it possible to scale to large language model pretraining even with a complex tree encoder by introducing a heuristic pruning method. However, the rule-based pruning approach suffers from local optima and slow inference. In this paper, we fix both issues in a unified method. We propose to use a top-down parser as a model-based pruning method, which also enables parallel encoding during inference. Specifically, our parser casts parsing as a split point scoring task: it first scores all split points for a given sentence, and then recursively splits a span into two by picking the split point with the highest score in the current span. The reverse order of the splits serves as the order of pruning in the R2D2 encoder. Besides the bidirectional language model loss, we also optimize the parser by minimizing the KL distance between the tree probabilities of the parser and those of R2D2. Our experiments show that Fast-R2D2 improves performance significantly in grammar induction and achieves competitive results in downstream classification tasks.
Compositional, hierarchical, and recursive processing is widely believed to be an essential trait of human language across diverse linguistic theories DBLP:journals/tit/Chomsky56; chomsky2014aspects. Chart-based parsing models DBLP:journals/corr/MaillardCY17; kimetal2019compound; dblp:conf/naacl/drozdovvyim19; huetal2021r2d2 have made promising progress in both grammar induction and hierarchical encoding in recent years. The differentiable CKY encoding architecture of DBLP:journals/corr/MaillardCY17 simulates the hierarchical and recursive process explicitly by introducing an energy function to combine all possible derivations when constructing each cell representation. However, this entails cubic time complexity, which makes it impossible to scale to large language model training like BERT devlin2018. Analogously, the cubic memory cost prevents the tree encoder from drawing on models with huge numbers of parameters as a backbone.
huetal2021r2d2 introduced a heuristic pruning method that successfully reduces the time complexity to a linear number of compositions. Their experiments show that chart-based models have great potential for grammar induction and representation learning when a sophisticated tree encoder, such as Transformers with large-corpus pretraining, is applied, leading to the Recursive Transformer based on Differentiable Trees, or R2D2 for short. However, their heuristic pruning approach is a rule-based algorithm that only considers certain composition probabilities. Thus, trees constructed in this way are not guaranteed to be globally optimal. Moreover, as each pruning step depends on previous decisions, the entire encoding process is sequential and thus slow at inference time.
In this work, we resolve these issues by proposing a unified method with a new global pruning strategy based on a lightweight and fast top-down parser. We cast parsing as split point scoring: for a given sentence, the parser first encodes the input with a bidirectional LSTM and scores each split point between words in parallel by looking at its left and right contexts, and then recursively splits a span (starting with the whole sentence) into two subspans by picking the split point with the highest score among the current candidates. During training, we incorporate sampling into the recursive splitting process: in each step, we sample a split point with respect to the score distribution in the current span, which simplifies the process to a sorting problem. The reverse order of the sorted split points can then serve as the merge order guiding the pruning of the CKY encoder. As the gradient of the pretrained component cannot be backpropagated to the parser, we follow URNNG dblp:conf/naacl/kimrykdm19 and optimize the parser by sampling trees over the CKY chart table generated by R2D2. Additionally, the pretrained tree encoder can compose sequences recursively in parallel according to the trees generated by the parser, which makes Fast-R2D2 a Recursive Neural Network DBLP:journals/ai/Pollack90; DBLP:conf/emnlp/SocherPWCMNP13 variant.
In this paper, we make the following main contributions:
We propose a model-based pruning method based on a top-down parser, together with a corresponding unsupervised training objective. Experiments show that our parser outperforms models custom-tailored for grammar induction.
By encoding in parallel all trees generated by the top-down parser, Fast-R2D2 improves inference speed 30- to 50-fold compared to R2D2.
We pretrain Fast-R2D2 on a large corpus and evaluate it on downstream tasks. The experiments demonstrate that a pretrained recursive model based on an unsupervised parser significantly outperforms pretrained sequential Transformers DBLP:conf/nips/VaswaniSPUJGKP17 of the same parameter size on single-sentence classification tasks.
R2D2 follows the work of DBLP:journals/corr/MaillardCY17 in defining a CKY-style 10.5555/1097042; kasami1966efficient; younger1967recognition encoder. Informally, given a sentence of words or wordpieces, R2D2 defines a chart table as shown in Figure 1. In such a chart, each cell is a tuple consisting of a vector representation, the probability of a single composition step, and the probability of the subtree for the span over the corresponding substring. At the lowest level, the table's terminal nodes are initialized with the embeddings of the input tokens, while both probabilities are set to one. Above the terminal level, each cell representation is a weighted sum of intermediate combinations, defined as: (1)
(2)  
(3)  
(4)  
(5) 
Here, the split point ranges over the positions inside the span, and the composition function is a multi-layer Transformer encoder. At each split point, the encoder yields a single-step combination probability and a subtree probability, and these values are concatenated over all split points of the span. Gumbel denotes the Straight-Through Gumbel-Softmax operation of DBLP:conf/iclr/JangGP17 with the temperature set to one. As Gumbel picks the optimal split point at each cell in practice, it is straightforward to recover the complete derivation tree from the root node recursively in a top-down manner.
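The Straight-Through Gumbel-Softmax selection can be sketched in plain Python. This is a forward-pass illustration only; the function name and toy logits are ours, and in the real model the soft probabilities carry the gradient through an autograd framework:

```python
import math
import random

def st_gumbel_softmax_pick(logits, tau=1.0, seed=0):
    """Forward pass of Straight-Through Gumbel-Softmax: perturb logits
    with Gumbel(0, 1) noise, apply a softmax with temperature tau, and
    return the one-hot argmax together with the soft probabilities
    (which would carry the gradient in an autograd framework)."""
    rng = random.Random(seed)
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)
    noised = [l - math.log(-math.log(rng.random())) for l in logits]
    scaled = [n / tau for n in noised]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    soft = [e / z for e in exps]
    hard = [0.0] * len(logits)
    hard[max(range(len(soft)), key=soft.__getitem__)] = 1.0
    return hard, soft
```

The hard one-hot vector selects a single split point in the forward pass, while the soft distribution would be used for the backward pass.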
As shown in Figure LABEL:fig:pruning, R2D2 starts to prune once all cells beneath a threshold height have been encoded. The heuristic rules work as follows:
Recover the maximum subtree for each cell at the top encoded level, and collect all cells one level below that appear in any such subtree;
Rank the candidates from Step 1 by their composition probability, and pick the cell with the highest score as a non-splittable span;
Remove any invalid cells that would break the now non-splittable span from Step 2, e.g., the dark cells in (c), and reorganize the chart table much like in the game of Tetris, as in (d);
Encode the blank cells at the top level, e.g., the cell highlighted with stripes in (d), and return to Step 1 until the root cell has been encoded.
To learn meaningful structures without gold trees, huetal2021r2d2 propose a self-supervised pretraining objective. Similar to the bidirectional masked language model task, R2D2 reconstructs a given token based on the representations of its left and right contexts. The probability of each token is estimated by the tree encoder defined in R2D2. The final objective is:
(6) 
We propose a top-down parser that assigns scores to all split points in a sentence and generates a merge order according to those scores.
Given a sentence, there is a split point between each pair of adjacent words. We define a top-down parser that assigns confidence scores to all split points as follows:
(7) 
As shown in Figure 3, a bidirectional LSTM first encodes the sentence; then, for each split point, an MLP over the concatenation of its left and right context representations yields the final split score. Formally, we have:
(8)  
(9) 
Here, the input sentence is first embedded, and the forward and backward LSTM passes yield contextual representations; the score of each split point is computed from the representations of its left and right contexts. Given these scores, one can easily recover the binary tree shown in Figure 3: we recursively split a span (starting with the entire sentence) into two subspans by picking the split point with the highest score in the current span. Taking the sentence in Figure 3 (a) as an example, we split the overall sentence at the highest-scoring split point in the first step, which yields two subtrees; we then split each of them at its own highest-scoring split point, and continue this procedure until the complete tree has been recovered.
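The recursive splitting procedure can be sketched as follows. This is a minimal illustration with hypothetical scores; the function name and index conventions are ours, not the paper's:

```python
from typing import List, Union

Tree = Union[int, tuple]

def build_tree(scores: List[float], left: int, right: int) -> Tree:
    """Recursively split the token span [left, right] (inclusive) at
    the highest-scoring split point. Split point k sits between tokens
    k and k + 1, so the valid splits for the span are left..right-1."""
    if left == right:
        return left  # a single token is a leaf
    # pick the split point with the highest score inside the span
    k = max(range(left, right), key=lambda i: scores[i])
    return (build_tree(scores, left, k), build_tree(scores, k + 1, right))

# a toy sentence of 5 tokens has 4 split points
tree = build_tree([0.1, 0.9, 0.3, 0.7], 0, 4)
# splitting greedily by score yields ((0, 1), ((2, 3), 4))
```

Each recursion level is independent of its siblings, which is what later allows the compositions to be batched.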
In the training stage, we perform sampling over the computed scores in order to increase the robustness and exploration of our model. We maintain the list of remaining split points in ascending order; initially, it contains all split points of the sentence. At each time step, a particular split point is selected by sampling from the distribution obtained from the scores of the remaining split points, and is then removed from the list; we repeat this until no split points remain. The sampled split points together form the final split point sequence. Formally, we have:
(10)  
(11) 
As the Gumbel-Max trick GUMBEL; DBLP:conf/nips/MaddisonTM14 provides a simple and efficient way to draw samples from a categorical distribution with given class probabilities, we can obtain each sample via the Gumbel-Max trick as:
(12) 
where the Gumbel noise is drawn independently for each split point. Therefore, the aforementioned process is equivalent to sorting the original sequence of split point scores with added Gumbel noise. Figure 3 (b) shows the sampled tree with respect to the split point scores. The split point sequence can hence be obtained simply as:
(13) 
Here, the sort operation orders the array in descending order and returns the indices into the original array. The sampled split point sequence is shown in Figure 3 (b).
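The equivalence between sequential sampling without replacement and a single sort of Gumbel-noised scores can be illustrated with a small sketch (the function name and seeding scheme are ours):

```python
import math
import random

def sample_split_order(scores, seed=0):
    """Sample a split-point order via the Gumbel-Max trick: adding
    i.i.d. Gumbel(0, 1) noise to each score and sorting in descending
    order is equivalent to sampling split points one by one, without
    replacement, from the softmax distribution over the remaining
    scores."""
    rng = random.Random(seed)
    noised = []
    for i, s in enumerate(scores):
        g = -math.log(-math.log(rng.random()))  # Gumbel(0, 1) noise
        noised.append((s + g, i))
    # argsort of the noised scores, highest first
    return [i for _, i in sorted(noised, reverse=True)]

order = sample_split_order([2.0, 0.5, 1.5, 1.0])
# `order` is a permutation of the split point indices 0..3
```

A single sort thus replaces the step-by-step sampling loop, which is what makes the sampling procedure cheap.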
As wordpieces wu2016google and Byte-Pair Encoding (BPE) are commonly used in pretrained language models, it is straightforward to incorporate wordpiece constraints into the top-down parser to reduce word-level parsing errors. We define a list of span constraints, each composed of the beginning and end positions of a non-splittable span. For each constraint, there should be a subtree covering exactly the corresponding substring. This goal can be achieved by adjusting the scores of split points covered by a constraint so that they are lower than all others, following Algorithm 1.

We treat the reverse order of the split point sequence as a bottom-up merge order inferred by the top-down parser based on the global context. Subsequently, we simply replace the pruning algorithm of huetal2021r2d2 with our Algorithm 2. As shown in Figure LABEL:fig:pruning, we retain the height threshold and the pruning logic of R2D2, but select cells to merge according to the merge order instead of following heuristic rules. Specifically, given a shrinking chart table, we select the next merge index in the second row by popping from and updating the merge order as in Algorithm 2.
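One simple way to realize the score adjustment for span constraints is to subtract a large penalty from every split point that would break a constrained span. This is a sketch of the idea only; the paper's Algorithm 1 may differ in detail:

```python
def apply_span_constraints(scores, constraints, penalty=1e6):
    """Lower the score of every split point that would break a
    non-splittable span. A split point k lies between tokens k and
    k + 1, so it breaks the span (st, ed) when st <= k < ed.
    Subtracting a large constant keeps the relative order of the
    penalized scores while ranking them below all others."""
    adjusted = list(scores)
    for st, ed in constraints:
        for k in range(st, ed):
            adjusted[k] -= penalty
    return adjusted

# split points inside the span (1, 3) are pushed below the rest
adjusted = apply_span_constraints([1.0, 2.0, 3.0, 4.0], [(1, 3)])
```

Because the penalized scores keep their relative order, the top-down parser still splits inside a constrained span last, i.e., the span is merged first in the reversed (bottom-up) order.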
Take the example in Figure 3 (b) for instance: after merging the first cell in the second row in Figure LABEL:fig:pruning (b), the merge order is updated. In the next round, we treat the 4th cell as a non-splittable cell, as in (e), and update the merge order again accordingly.
We denote the tree probabilities estimated by the top-down parser and by R2D2 as two distributions. The difficulty here is that the R2D2 distribution can be optimized by the objective defined in Equation 6, but there is no gradient feedback for the parser. To make the parser learnable, an intuitive solution is to fit its distribution to that of R2D2 by minimizing their Kullback–Leibler divergence. As both tree distributions are discrete and cannot be enumerated exhaustively, inspired by URNNG dblp:conf/naacl/kimrykdm19, a Monte Carlo estimate for the gradient with respect to the parser can be defined as:
(14)  
with tree samples drawn from the R2D2 distribution. Algorithm 3 shows the complete sampling process. Specifically, we sample split points recursively, as in previous work DBLP:journals/corr/cmplg9805007; DBLP:conf/emnlp/FinkelMN06; dblp:conf/naacl/kimrykdm19, with respect to the intermediate tree probabilities calculated during hierarchical encoding. The sampler returns a sequence of split points and their corresponding spans. For each sample, we denote the probability of taking a particular split point for a given span at each step. Formally, we have:
(15)  
where the indices denote the start and end of the corresponding span. Please note that these probabilities are not adjusted by span constraints.
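The Monte Carlo objective for the parser can be sketched as follows, assuming a hypothetical `parser_log_prob` interface that returns the log probability of a tree under the top-down parser:

```python
import math

def parser_kl_loss(sampled_trees, parser_log_prob):
    """Monte Carlo surrogate for minimizing KL(p_R2D2 || q_parser)
    with respect to the parser: given trees t_1..t_K sampled from the
    R2D2 chart, the surrogate loss is -(1/K) * sum_k log q(t_k), whose
    gradient matches the Monte Carlo estimate in Equation 14."""
    k = len(sampled_trees)
    return -sum(parser_log_prob(t) for t in sampled_trees) / k

# with a constant log probability of log(0.5), the loss is log(2)
loss = parser_kl_loss([None] * 4, lambda t: math.log(0.5))
```

Only the parser's log probabilities appear in the surrogate, so no gradient needs to flow through the R2D2 sampler, which matches the setup described above.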
In this paper, we mainly focus on classification as the downstream task. We take the root representation to represent the entire sentence. As our framework pretrains two models, the R2D2 encoder and a top-down parser, we have two ways of generating representations:
Run forced encoding with the R2D2 encoder over the binary tree produced by the top-down parser;
Use the binary tree to guide the pruning of the R2D2 encoder, and take the resulting root representation.
It is obvious that the first approach is much faster than the second, as in forced encoding the R2D2 encoder performs only a linear number of compositions and can run in parallel layer by layer; e.g., all compositions at the same tree height in Figure 3 (b) can run simultaneously. We explore both approaches in our experiments.
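Grouping compositions by tree height is what enables this layer-by-layer parallelism. A minimal sketch (the tree representation and function name are ours):

```python
def levels_bottom_up(tree):
    """Group the internal nodes of a binary tree by height, so that
    all compositions at the same height can run in one batched call.
    Leaves are ints; internal nodes are (left, right) tuples."""
    levels = {}

    def height(node):
        if isinstance(node, int):
            return 0
        h = 1 + max(height(node[0]), height(node[1]))
        levels.setdefault(h, []).append(node)
        return h

    height(tree)
    return [levels[h] for h in sorted(levels)]

# both height-1 nodes of this tree can be composed in parallel
batches = levels_bottom_up(((0, 1), ((2, 3), 4)))
```

In the actual model, each batch would correspond to one batched Transformer call over all spans at that height.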
As suggested in previous work radford2018improving; howardruder2018universal; gururanganetal2020dont, given a pretrained model, continued pretraining on an in-domain corpus with the same pretraining objective can yield better generalization. Thus, we simply add our pretraining objectives to the loss in all downstream tasks. At the same time, as the downstream task may guide R2D2 toward more reasonable tree structures, we also maintain the KL loss so that the parser continues to update. For the two inference methods, we uniformly select the root representation as the sentence representation, follow it with an MLP, and estimate the cross-entropy losses for the two methods, respectively. The final loss is:
(16) 
For comparison, we include six recent strong models for unsupervised parsing with available open-source implementations: StructFormer DBLP:conf/acl/ShenTZBMC20, Ordered Neurons DBLP:conf/iclr/ShenTSC19, URNNG dblp:conf/naacl/kimrykdm19, DIORA dblp:conf/naacl/drozdovvyim19, C-PCFG kimetal2019compound, and R2D2 huetal2021r2d2. Following htutetal2018grammar, we train all systems on a training set consisting only of raw text, and evaluate and report results on an annotated test set. As an evaluation metric, we adopt sentence-level unlabeled F1, computed using the script from kimetal2019compound. Per convention, we compare against the non-binarized gold trees. The best checkpoint for each system is picked based on the score on the validation set.
As our model is pretrained on wordpieces, for a fair comparison, we test all models with two types of input: word level (W) and wordpiece level (WP). Fast-R2D2 is pretrained with span constraints for the word level, but without span constraints for the wordpiece level. To support wordpiece-level evaluation, we convert gold trees to wordpiece-level trees by simply breaking each terminal node into a non-terminal node with its wordpieces as terminals, e.g., (NN discrepancy) into (NN (WP disc) (WP ##re) (WP ##pan) (WP ##cy)).
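The conversion can be sketched as follows, with a toy tokenizer standing in for a real wordpiece tokenizer:

```python
def to_wordpiece_tree(tree, tokenize):
    """Convert a word-level tree to wordpiece level by replacing each
    terminal word with a node whose children are its wordpieces.
    Trees are nested tuples with string leaves; `tokenize` maps a
    word to its list of wordpieces."""
    if isinstance(tree, str):
        pieces = tokenize(tree)
        return pieces[0] if len(pieces) == 1 else tuple(pieces)
    return tuple(to_wordpiece_tree(child, tokenize) for child in tree)

# toy tokenizer for illustration only
toy = {"discrepancy": ["disc", "##re", "##pan", "##cy"]}
wp_tree = to_wordpiece_tree(("the", "discrepancy"),
                            lambda w: toy.get(w, [w]))
# wp_tree == ("the", ("disc", "##re", "##pan", "##cy"))
```

Words that tokenize to a single piece stay terminal, so the tree structure above the word level is unchanged.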
EFLOPS DBLP:conf/hpca/DongCZYWFZLSPGJ20 is a highly scalable distributed training system designed by Alibaba. With its optimized hardware architecture and co-designed supporting software tools, including ACCL DBLP:journals/micro/DongWFCPTLLRGGL21 and KSpeed (a high-speed data-loading service), it can easily be extended to 10K nodes (GPUs) with linear scalability.
The tree encoder of our model uses 4-layer Transformers with 768-dimensional embeddings, 3,072-dimensional hidden layer representations, and 12 attention heads. The top-down parser uses a 4-layer bidirectional LSTM with 128-dimensional embeddings and a 256-dimensional hidden layer. The sampling number is set to 256. Training is conducted using Adam with weight decay, with separate learning rates for the tree encoder and the top-down parser. The batch size is set to 96, though we also limit the maximum total length for each batch, such that excess sentences are moved to the next batch; the limit is set to 1,536. Training for 60 epochs takes about 120 hours on 8 A100 GPUs.

Fast-R2D2 is pretrained on a mixed corpus of WikiText-103 DBLP:conf/iclr/MerityX0S17 and the training set of the Penn Treebank (PTB) marcusetal1993building. WikiText-103 is split at the sentence level, and sentences longer than 200 tokens after tokenization are discarded (about 0.04‰ of the original data). The total number of sentences is 4,089,500, and the average sentence length is 26.97. For Chinese, we use a subset of Chinese Wikipedia for pretraining, specifically the first 10,000,000 sentences shorter than 150 characters.
We test our approach on the PTB with the standard splits (sections 2–21 for training, 22 for validation, 23 for testing) and the same preprocessing as in recent work kimetal2019compound, i.e., we discard punctuation and lowercase all tokens. To explore the universality of the model across languages, we also run experiments on Chinese Penn Treebank (CTB) 8 ctb8, from which we also remove punctuation.
Note that in all settings, the training is conducted entirely on raw unannotated text.
Model                 WSJ    CTB
Left Branching (W)    8.15   11.28
Right Branching (W)   39.62  27.53
Random Trees (W)      17.76  20.17
ON-LSTM (W)           47.7   24.73
DIORA (W)             51.42  –
URNNG (W)             52.4   –
StructFormer (WP)     54.0   –
C-PCFG (W)            55.2   49.95
R2D2 (WP)             48.11  44.85
Fast-R2D2 (WP)        56.24  51.04
DIORA (WP)            43.94  –
C-PCFG (WP)           49.76  60.34
R2D2 (WP)             52.28  63.94
Fast-R2D2 (WP)        50.20  67.79
Fast-R2D2 (WP)        53.88  67.74
Table 1 shows results of all systems with words (W) and wordpieces (WP) as input on the WSJ and CTB test sets. When we evaluate all systems against word-level gold trees, our Fast-R2D2 performs substantially better than R2D2 across both datasets. As mentioned in Section 3.3, Fast-R2D2 has two options for obtaining the final tree and representation for a given input sentence: running forced encoding over the output tree of the top-down parser, or using that tree to guide the pruning of the R2D2 encoder and selecting the best tree from the chart table. The results suggest that the variant based on the parser's trees outperforms the chart-guided variant on both the WSJ and CTB test sets. More interestingly, Fast-R2D2, a model without human-designed grammar constraints at the wordpiece-level granularity, outperforms models specifically designed for grammar induction at the word-level granularity. If all systems take wordpieces as input and are measured against wordpiece-level gold trees (the bottom five results), our Fast-R2D2 obtains state-of-the-art results on both WSJ and CTB.
(WP)        WD     NNP    VP     SBAR
WSJ
DIORA       94.63  77.83  17.30  22.16
C-PCFG      87.35  66.44  23.63  40.40
R2D2        99.76  86.76  24.74  39.81
Fast-R2D2   97.67  83.44  63.80  65.68
CTB
C-PCFG      89.34  46.74  39.53  –
R2D2        97.16  67.19  37.90  –
Fast-R2D2   97.80  68.57  46.59  –
Following dblp:conf/naacl/kimrykdm19 and drozdovetal2020unsupervised, we also compute the recall of constituents when evaluating against wordpiece-level gold trees. Besides standard constituents, we also compare the recall of wordpiece chunks and proper noun chunks. Proper noun chunks are extracted by finding adjacent unary nodes with the same parent and tag NNP. Table 2 reports the recall scores for constituents and words on the WSJ and CTB test sets. Compared with the R2D2 baseline, our Fast-R2D2 performs slightly worse on small semantic units, but significantly better on larger semantic units (such as VP and SBAR) on the WSJ test set. On the CTB test set, our Fast-R2D2 outperforms R2D2 on all constituents.
From Tables 1 and 2, we conclude that Fast-R2D2 overall obtains better results than R2D2 on CTB, while faring slightly worse than R2D2 only on small semantic units on WSJ. We conjecture that this difference stems from differences in tokenization between Chinese and English. Chinese is a character-based language without complex morphology, in which character collocations are consistent, making it easier for the top-down parser to learn them well. In contrast, English wordpieces are built from corpus statistics, and individual wordpieces are not necessarily natural semantic units; they may lack semantic self-consistency, making it harder for a top-down parser with a small number of parameters to fit them well.
We next consider the effectiveness of Fast-R2D2 in downstream tasks. This experiment is not intended to advance the state of the art on the GLUE benchmark, but rather to assess to what extent our approach performs respectably against the dominant inductive bias of conventional sequential Transformers.
We finetune pretrained models on several datasets, including SST-2, CoLA, QQP, and MNLI from the GLUE benchmark wang2018glue. As sequential Transformers with their dominant inductive bias remain the norm for numerous NLP tasks, we mainly compare Fast-R2D2 with BERT devlin2018 as a representative pretrained model based on a sequential Transformer. In order to compare the two forms of inductive bias fairly and efficiently, we pretrain BERT models with 4 layers and 12 layers, as well as our Fast-R2D2, from scratch on the WikiText-103 corpus following Section 4.1.1. For simplicity, Fast-R2D2 is finetuned without span constraints. Following common practice, we add an MLP layer over the root representation of the R2D2 encoder for single-sentence classification. For cross-sentence tasks such as QQP and MNLI, we feed the root representations of the two sentences into the pretrained tree encoder of R2D2 as left and right inputs, and also add a new task id as another input term to the R2D2 encoder. We then feed the hidden output of the new task id into another MLP layer to predict the final label. We train all systems on the four datasets for 10 epochs with the same learning rate, batch size, and maximum input length. We validate each model after every epoch, and report the best results on the development sets.

Model        Para.  SST-2  CoLA   QQP    MNLI-m/mm
BERT (4L)    52M    84.98  17.07  84.01  73.73/74.63
BERT (12L)   116M   90.25  40.72  87.13  80.00/80.41
Fast-R2D2    –      87.50  8.67   83.97  69.53/69.50
Fast-R2D2    –      88.30  10.14  84.07  69.36/69.11
Fast-R2D2    –      90.25  38.45  84.35  69.36/68.80
Fast-R2D2    –      90.71  40.11  84.32  69.64/69.57
Table 3 shows the corresponding scores on SST-2, CoLA, QQP, and MNLI. In terms of parameter size, our Fast-R2D2 model has 52M and 10M parameters for the R2D2 encoder and the top-down parser, respectively. It is clear that 12-layer BERT is significantly better than 4-layer BERT. Among the Fast-R2D2 variants, one uses the output tree from the top-down parser, while the other uses the best tree inferred by the R2D2 encoder. Similar to the results for unsupervised parsing, the variant based on the parser's trees again outperforms the chart-guided variant on classification tasks. We hypothesize that trees generated by the top-down parser without Gumbel noise are more stable and reasonable. Fast-R2D2 significantly outperforms 4-layer BERT and achieves competitive results compared to 12-layer BERT on single-sentence classification tasks like SST-2 and CoLA, but still performs significantly worse on the cross-sentence tasks. We believe this is an expected result, as there is no cross-attention mechanism in the inductive bias of Fast-R2D2. However, the performance of Fast-R2D2 on classification tasks shows that the inductive bias of R2D2 has higher parameter utilization than sequentially applied Transformers. Importantly, we demonstrate that a Recursive Neural Network variant with an unsupervised parser can achieve results comparable to pretrained sequential Transformers, even with fewer parameters and with interpretable intermediate results. Hence, our Fast-R2D2 framework provides an alternative choice for NLP tasks.
As DBLP:conf/icml/ChowdhuryC21 argued, parallelism is also an important feature for an RvNN. As the tree encoder of Fast-R2D2 is based on Transformers, we mainly compare the time cost of sequential Transformers and Fast-R2D2 in forced encoding over various sequence length ranges. We randomly select 1,000 sentences for each range from WikiText-103 and report the average time consumption on a single A100 GPU. BERT is based on the open-source Transformers library (https://github.com/huggingface/transformers), and R2D2 is based on the official code (https://github.com/alipay/StructuredLM_RTDT/tree/r2d2) of huetal2021r2d2.
Model        0–50   50–100  100–200  200–500
BERT (12L)   1.36   1.46    1.62     2.38
R2D2         38.06  173.74  555.95   –
Fast-R2D2    4.67   14.91   39.73    150.26
Fast-R2D2*   1.83   4.16    7.88     14.84
Table 4 shows the inference time in seconds for the different systems to process 1,000 sentences with a batch size of 50. Running R2D2 is time-consuming, since the heuristic pruning method involves huge memory exchanges between GPU and CPU. In Fast-R2D2, we alleviate this problem by using model-guided pruning to accelerate chart table processing; in conjunction with a CUDA implementation, Fast-R2D2 reduces the inference time significantly. Fast-R2D2* further improves inference speed by running forced encoding in parallel over the binary tree generated by the parser, making it about 30–50 times faster than R2D2 across the various ranges. Although there is still a speed gap compared to sequential Transformers, Fast-R2D2 is sufficiently fast for most NLP tasks while producing interpretable intermediate representations.
Many attempts have been made at grammar induction and hierarchical encoding. DBLP:conf/conll/Clark01 and DBLP:conf/acl/KleinM02 provided some of the first successful statistical approaches to grammar induction. Multiple recent papers focus on structure induction based on language modeling objectives DBLP:conf/nips/ShenTHLSC19; DBLP:conf/iclr/ShenTSC19; DBLP:conf/acl/ShenTZBMC20; kimetal2019compound. DBLP:journals/ai/Pollack90 proposed RvNNs as a recursive architecture to encode text hierarchically, and DBLP:conf/emnlp/SocherPWCMNP13 showed the effectiveness of RvNNs with gold trees for sentiment analysis. In this work, we focus on models that are capable of learning meaningful structures in an unsupervised way and of encoding text hierarchically over the induced trees.
In the line of work on learning sentence representations with structure, DBLP:conf/iclr/YogatamaBDGL17 jointly train their shift-reduce parser and sentence embedding components without gold trees. As their parser is not differentiable, they have to resort to reinforcement learning, which increases variance and may easily collapse to trivial left- or right-branching trees. Gumbel Tree-LSTMs DBLP:conf/aaai/ChoiYL18 construct trees by recursively selecting two terminal nodes to merge and learning the composition probability via downstream tasks. CRvNN DBLP:conf/icml/ChowdhuryC21 makes the whole process end-to-end differentiable and parallel by introducing a continuous relaxation. URNNG dblp:conf/naacl/kimrykdm19 applies variational inference over latent trees to perform unsupervised optimization of the RNNG dyeretal2016recurrent, an RNN model that estimates a joint distribution over sentences and trees based on shift-reduce operations; however, meaningful trees are hard to induce when training RNNGs from scratch. DBLP:journals/corr/MaillardCY17 propose an alternative approach based on differentiable CKY encoding. The algorithm is made differentiable through a soft-gating approach, which approximates discrete candidate selection by a probabilistic mixture of the constituents available in a given cell of the chart. While their work relies on annotated downstream tasks to learn structures, dblp:conf/naacl/drozdovvyim19 propose a novel autoencoder-like pretraining objective based on the inside–outside algorithm Baker1979TrainableGF; DBLP:conf/icgi/Casacuberta94. As mentioned above, CKY-based models have cubic time complexity. huetal2021r2d2 propose a pruned differentiable CKY encoding architecture with a simple pretraining objective related to bidirectional language modeling; they reduce the time complexity to a linear number of composition steps, making it possible to apply sophisticated tree encoders and to scale to large-corpus pretraining.

In this paper, we have presented Fast-R2D2, which improves the performance and inference speed of R2D2 by introducing a fast top-down parser to guide the pruning of the R2D2 encoder. Pretrained on the same corpus, Fast-R2D2 significantly outperforms sequential Transformers with a similar number of parameters on classification tasks. Our experimental results show that Fast-R2D2 is a promising and feasible way to learn hierarchical text representations that differs from layer-stacking models and can also produce interpretable intermediate representations. As future work, we are investigating leveraging the intermediate representations in additional downstream tasks.
We would like to thank the Aliyun EFLOPS team for their substantial support in designing and providing a cuttingedge training platform to facilitate fast experimentation in this work.