Fast-R2D2: A Pretrained Recursive Neural Network based on Pruned CKY for Grammar Induction and Text Representation

by Xiang Hu, et al.
Shandong University
Gerard de Melo

Recently, CKY-based models have shown great potential in unsupervised grammar induction thanks to their human-like encoding paradigm, which runs recursively and hierarchically but requires O(n^3) time complexity. Recursive Transformer based on Differentiable Trees (R2D2) makes it possible to scale to large language model pre-training even with a complex tree encoder by introducing a heuristic pruning method. However, the rule-based pruning approach suffers from local optima and slow inference. In this paper, we fix those issues in a unified method. We propose to use a top-down parser as a model-based pruning method, which also enables parallel encoding during inference. Specifically, our parser casts parsing as a split point scoring task: it first scores all split points for a given sentence, and then recursively splits a span into two by picking the split point with the highest score in the current span. The reverse order of the splits is taken as the order of pruning in the R2D2 encoder. Besides the bi-directional language model loss, we also optimize the parser by minimizing the KL distance between the tree probabilities from the parser and R2D2. Our experiments show that Fast-R2D2 significantly improves performance in grammar induction and achieves competitive results in downstream classification tasks.






1 Introduction

Compositional, hierarchical and recursive processing are widely believed to be essential traits of human language across diverse linguistic theories DBLP:journals/tit/Chomsky56; chomsky2014aspects. Chart-based parsing models DBLP:journals/corr/MaillardCY17; kim-etal-2019-compound; dblp:conf/naacl/drozdovvyim19; hu-etal-2021-r2d2 have made promising progress in both grammar induction and hierarchical encoding in recent years. The differentiable CKY encoding architecture of DBLP:journals/corr/MaillardCY17 simulates the hierarchical and recursive process explicitly by introducing an energy function to combine all possible derivations when constructing each cell representation. However, this entails a cubic time complexity, which makes it impossible to scale to large language model training like BERT devlin2018. Analogously, the cubic memory cost also limits the tree encoder's ability to draw on models with a huge number of parameters as a backbone.

hu-etal-2021-r2d2 introduced a heuristic pruning method, successfully reducing the time complexity to a linear number of compositions. Their experiments show that the chart-based model has great potential for grammar induction and representation learning when applying a sophisticated tree encoder such as Transformers with large corpus pretraining, leading to a Recursive Transformer based on Differentiable Trees, or R2D2 for short. However, their heuristic pruning approach is a rule-based algorithm that only considers certain composition probabilities. Thus, trees constructed in this way are not guaranteed to be globally optimal. Moreover, as each step during pruning is based on previous decisions, the entire encoding process is sequential and thus slow in the inference stage.

In this work, we resolve these issues by proposing a unified method with a new global pruning strategy based on a light-weight and fast top-down parser. We cast parsing as split point scoring, where we first encode the input sentence with a bi-directional LSTM, and score all split points in parallel. Specifically, for a given sentence, the parser first scores each split point between words in parallel by looking at its left and right contexts, and then recursively splits a span (starting with the whole sentence) into two sub-spans by picking a split point with the highest score among the current split candidates. During training, we incorporate sampling in the recursive splitting process, where, in each step, we sample a split point with respect to the score distribution in the current span and simplify the process as a sorting problem. Thus, the reverse order of the sorted split points can serve as the merge order to guide the pruning of the CKY encoder. As the gradient of the pretrained component cannot be back-propagated to the parser, inspired by URNNG dblp:conf/naacl/kimrykdm19, we optimize the parser by sampling trees over the CKY chart table generated by R2D2. Additionally, the pretrained tree encoder can compose sequences recursively in parallel according to the trees generated by the parser, which makes Fast-R2D2 a Recursive Neural Network DBLP:journals/ai/Pollack90; DBLP:conf/emnlp/SocherPWCMNP13 variant.

In this paper, we make the following main contributions:

  1. We propose a model-based pruning method based on a top-down parser and corresponding unsupervised training objective. Experiments show that our parser outperforms models custom-tailored for grammar induction.

  2. By encoding in parallel all trees generated by the top-down parser, Fast-R2D2 significantly improves the inference speed 30 to 50 fold compared to R2D2.

  3. We pre-train Fast-R2D2 on a large corpus and evaluate it on downstream tasks. The experiments demonstrate that a pretrained recursive model based on an unsupervised parser significantly outperforms pretrained sequential Transformers DBLP:conf/nips/VaswaniSPUJGKP17 with the same parameter size in single sentence classification tasks.

2 Background

2.1 R2D2 Architecture

Differentiable Trees.

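R2D2 follows the work of DBLP:journals/corr/MaillardCY17 in defining a CKY-style 10.5555/1097042; kasami1966efficient; younger1967recognition encoder. Informally, given a sentence $S = \{s_1, s_2, \dots, s_n\}$ with $n$ words or word-pieces, R2D2 defines a chart table as shown in Figure 1. In such a chart, each cell $\mathcal{T}_{i,j}$ is a tuple $\langle e_{i,j}, p_{i,j}, \tilde{p}_{i,j} \rangle$, where $e_{i,j}$ is a vector representation, $p_{i,j}$ is the probability of a single composition step, and $\tilde{p}_{i,j}$ is the probability of the subtree for the span $[i, j]$ over the sub-string $s_{i:j}$. At the lowest level, the table has terminal nodes $\mathcal{T}_{i,i}$ with $e_{i,i}$ initialized as the embeddings of input tokens $s_i$, while $p_{i,i}$ and $\tilde{p}_{i,i}$ are set to one. When $j > i$, the representation $e_{i,j}$ is a weighted sum of intermediate combinations $c_{i,j}^{k}$, defined as:

$$c_{i,j}^{k},\; p_{i,j}^{k} = f(e_{i,k}, e_{k+1,j}) \tag{1}$$

$$\tilde{p}_{i,j}^{k} = p_{i,j}^{k}\, \tilde{p}_{i,k}\, \tilde{p}_{k+1,j} \tag{2}$$

$$\alpha_{i,j} = \mathrm{Gumbel}(\log(\tilde{\mathbf{p}}_{i,j})) \tag{3}$$

$$e_{i,j} = [c_{i,j}^{i}, c_{i,j}^{i+1}, \dots, c_{i,j}^{j-1}]\, \alpha_{i,j} \tag{4}$$

$$[p_{i,j}, \tilde{p}_{i,j}] = \alpha_{i,j}^{\top} [\mathbf{p}_{i,j}, \tilde{\mathbf{p}}_{i,j}] \tag{5}$$

Here, $k$ is a split point from $i$ to $j-1$, and $f(\cdot)$ is a multi-layer Transformer encoder. $p_{i,j}^{k}$ and $\tilde{p}_{i,j}^{k}$ denote the single step combination probability and the subtree probability, respectively, at split point $k$; $\mathbf{p}_{i,j}$ and $\tilde{\mathbf{p}}_{i,j}$ are the concatenations of all $p_{i,j}^{k}$ or $\tilde{p}_{i,j}^{k}$ values; and Gumbel is the Straight-Through Gumbel-Softmax operation of DBLP:conf/iclr/JangGP17 with temperature set to one. As Gumbel picks the optimal splitting point at each cell in practice, it is straightforward to recover the complete derivation tree from the root node recursively in a top-down manner.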
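
The chart recursion described above can be illustrated with a small sketch. It is only a toy under stated assumptions: `toy_compose` is a hypothetical stand-in for the Transformer-based composition function, representations are plain floats, positions are 0-based, and the Straight-Through Gumbel-Softmax is replaced by a hard argmax, so the sketch shows only how subtree probabilities propagate and how the best tree is read off the chart.

```python
def toy_compose(left, right):
    """Hypothetical composition: returns (representation, single-step probability)."""
    return left + right, 1.0 / (1.0 + abs(left - right))


def build_chart(tokens, compose=toy_compose):
    """Fill a CKY-style chart; each cell is (e, p, p_tilde, best_split)."""
    n = len(tokens)
    chart = {}
    for i in range(n):
        # Terminal cells: both probabilities are set to one.
        chart[(i, i)] = (tokens[i], 1.0, 1.0, -1)
    for width in range(1, n):
        for i in range(n - width):
            j = i + width
            best = None
            for k in range(i, j):  # candidate split points
                e_left, _, pt_left, _ = chart[(i, k)]
                e_right, _, pt_right, _ = chart[(k + 1, j)]
                c, p = compose(e_left, e_right)
                p_tilde = p * pt_left * pt_right  # subtree probability
                if best is None or p_tilde > best[2]:
                    best = (c, p, p_tilde, k)
            chart[(i, j)] = best  # hard argmax instead of Gumbel-Softmax
    return chart


def best_tree(chart, i, j):
    """Recover the derivation tree top-down from the stored best splits."""
    if i == j:
        return i
    k = chart[(i, j)][3]
    return (best_tree(chart, i, k), best_tree(chart, k + 1, j))
```

For example, `best_tree(build_chart([1.0, 1.0, 5.0]), 0, 2)` groups the two similar tokens first under this toy composition function.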

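Figure 1: Chart data structure. A cell spanning three tokens, such as $\mathcal{T}_{1,3}$, can be generated in two alternative ways: by combining either ($\mathcal{T}_{1,1}$, $\mathcal{T}_{2,3}$) or ($\mathcal{T}_{1,2}$, $\mathcal{T}_{3,3}$).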
Heuristic pruning.

As shown in Figure 2, R2D2 starts to prune once all cells beneath height $m$ have been encoded. The heuristic rules work as follows:

  1. Recover the maximum sub-tree for each cell at the $m$-th level, and collect all cells at the 2nd level that appear in any sub-tree;

  2. Rank the candidates from Step 1 by the composition probability $p_{i,j}$, and pick the cell with the highest score as a non-splittable span;

  3. Remove any invalid cells that would break the now non-splittable span from Step 2, e.g., the dark cells in (c), and reorganize the chart table much like in the game Tetris, as in (d);

  4. Encode the blank cells at the $m$-th level, e.g., the cell highlighted with stripes in (d), and go back to Step 1 until the root cell has been encoded.


To learn meaningful structures without gold trees, hu-etal-2021-r2d2 propose a self-supervised pretraining objective. Similar to the bidirectional masked language model task, R2D2 reconstructs a given token $s_i$ based on its context representations $e_{1:i-1}$ and $e_{i+1:n}$. The probability of each token is estimated by the tree encoder defined in R2D2. The final objective is:

$$\mathcal{L}_{bilm} = -\sum_{i=1}^{n} \log p(s_i \mid e_{1:i-1}, e_{i+1:n}) \tag{6}$$


3 Methodology

3.1 Global Pruning Strategy

We propose a top-down parser to evaluate scores for all split points in a sentence and generate a merge order according to the scores.

Figure 3: (a) A parsed tree obtained by sorting the split scores ($v$). (b) A sampled tree obtained by adding Gumbel noise ($g$, shown as dark vertical bars).
Top-down parser.

Given a sentence $S = \{s_1, s_2, \dots, s_n\}$, there are $n-1$ split points between words. We define a top-down parser by giving confidence scores to all split points as follows:

$$\mathbf{v} = [v_1, v_2, \dots, v_{n-1}]$$

As shown in Figure 3, a bi-directional LSTM first encodes the sentence, and then, for each split point, an MLP over the concatenation of the left and right context representations yields the final split score. Formally, we have:

$$\overrightarrow{h}, \overleftarrow{h} = \mathrm{BiLSTM}(E; \theta), \qquad v_i = \mathrm{MLP}(\overrightarrow{h}_i \oplus \overleftarrow{h}_{i+1})$$

Here, $E$ is the embedding of the input sentence $S$, while $\overrightarrow{h}$ and $\overleftarrow{h}$ denote the forward and reverse representations, respectively. $v_i$ is the score of the $i$-th split point, whose left and right context representations are $\overrightarrow{h}_i$ and $\overleftarrow{h}_{i+1}$. Given the scores $[v_1, v_2, \dots, v_{n-1}]$, one can easily recover the binary tree shown in Figure 3: we recursively split a span (starting with the entire sentence) into two sub-spans by picking the split point with the highest score in the current span. Taking the sentence in Figure 3 (a) as an example, we first split the whole sentence at its highest-scoring split point, which leads to two sub-trees, and then split each of the two resulting spans at its own highest-scoring split point. We can continue this procedure until the complete tree has been recovered.
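
The greedy recovery of a tree from split scores can be sketched in a few lines. This is a minimal illustration rather than the authors' implementation: tokens and split points are 0-based here (split point `i` lies between token `i` and token `i + 1`), and the function name `parse_top_down` is introduced only for this example.

```python
def parse_top_down(scores, start=0, end=None):
    """Build a binary tree over tokens start..end (inclusive) from split scores.

    scores[i] is the score of the split between token i and token i + 1,
    so a sentence with n tokens has n - 1 scores.
    """
    if end is None:
        end = len(scores)  # index of the last token
    if start == end:
        return start  # a single token is a leaf
    # Pick the highest-scoring split point inside the current span.
    k = max(range(start, end), key=lambda i: scores[i])
    return (parse_top_down(scores, start, k),
            parse_top_down(scores, k + 1, end))
```

With scores `[0.1, 0.9, 0.3, 0.5]` over five tokens, the sentence is first split after token 1, and the right part is then split after token 3.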

Tree sampling.

In the training stage, we perform sampling over the computed scores in order to increase the robustness and exploration of our model. We denote $\mathcal{P}^{t}$ as the list of split points at time $t$ in ascending order, which is $\{1, 2, \dots, n-1\}$ in the first step. Then a particular split point $a_t$ is selected from $\mathcal{P}^{t}$ by sampling based on the probabilities estimated from the stacked scores of the split points in $\mathcal{P}^{t}$. The sampled $\{a_1, a_2, \dots, a_{n-1}\}$ together form the final split point sequence $\mathcal{A}$. At each time step, we remove $a_t$ from $\mathcal{P}^{t}$ when $a_t$ is selected, then sample the next split point until the set of remaining split points is empty. Formally, we have:

$$a_t \sim \mathrm{softmax}(\mathbf{v}^{t}), \qquad \mathcal{P}^{t+1} = \mathcal{P}^{t} \setminus \{a_t\}$$

where $\mathbf{v}^{t}$ stacks the scores of the split points in $\mathcal{P}^{t}$. As the Gumbel-Max trick GUMBEL; DBLP:conf/nips/MaddisonTM14 provides a simple and efficient way to draw samples from a categorical distribution with class probabilities, we can obtain $a_t$ via the Gumbel-Max trick as:

$$a_t = \arg\max_{i \in \mathcal{P}^{t}} \left[ v_i + g_i \right]$$

where $g_i$ is the Gumbel noise for the $i$-th split point. Therefore, the aforementioned process is equivalent to sorting the original sequence of split point scores with added Gumbel noise. Figure 3 (b) shows the sampled tree with respect to the split point scores. The split point sequence $\mathcal{A}$ can hence be obtained simply as:

$$\mathcal{A} = \operatorname{argsort}(\mathbf{v} + \mathbf{g})$$

Here, $\operatorname{argsort}$ sorts the array in descending order and returns the indices of the original array. The sampled $\mathcal{A}$ for the running example is shown in Figure 3 (b).
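
The equivalence between sequential sampling without replacement and a single sort over noisy scores can be sketched as follows. This is an illustrative sketch using only the Python standard library; the function name `sample_split_order` and the clamping constant are assumptions made for the example, and split points are 0-based.

```python
import math
import random


def sample_split_order(scores, rng):
    """Sample a split order by sorting scores perturbed with Gumbel noise.

    Returns the indices of the split points in descending order of
    (score + Gumbel noise), i.e. the order in which spans are split.
    """
    noisy = []
    for v in scores:
        u = max(rng.random(), 1e-12)   # avoid log(0)
        g = -math.log(-math.log(u))    # standard Gumbel(0, 1) sample
        noisy.append(v + g)
    return sorted(range(len(scores)), key=lambda i: noisy[i], reverse=True)
```

Calling `sample_split_order([2.0, 0.5, 1.0], random.Random(0))` returns one permutation of the split points; repeated calls with different seeds explore different trees, with higher-scoring split points tending to come first.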

Span Constraints.

As word-pieces wu2016google and Byte-Pair Encoding (BPE) are commonly used in pretrained language models, it is straightforward to incorporate multiple word-piece constraints into the top-down parser to reduce word-level parsing errors. We denote as $\mathcal{C}$ a list of span constraints composed of the beginning and end positions of non-splittable spans, defined as $\mathcal{C} = \{(b_1, e_1), (b_2, e_2), \dots\}$. For each $(b_i, e_i)$ in $\mathcal{C}$, there should be a sub-tree for a span covering the sub-string $s_{b_i:e_i}$. This goal can be achieved by ensuring that the scores for splits covered by $\mathcal{C}$ are lower than those of all other splits, which we do by adjusting them following Algorithm 1.

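Input: v = [v_1, ..., v_{n-1}], scores for all split points.
Input: C = [(b_1, e_1), (b_2, e_2), ...], non-splittable spans.
Input: c, a constant, positive value.
function Adjust-scores(v, C)
    δ ← max(v) − min(v) + c
    for (b, e) ∈ C do
        for i ← b to e − 1 do
            v_i ← v_i − δ    ▷ Adjusted scores will be smaller than the scores of span boundaries.
    return v
Algorithm 1 Adjust split scores according to C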
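
The score adjustment for span constraints can be sketched as below. This is a sketch under assumptions: the exact update is not fully recoverable from the listing, so the example uses one offset (the score range plus a positive constant `c`) that satisfies the stated property that adjusted scores fall below all unadjusted ones. Positions are 0-based, and a span `(begin, end)` with inclusive token positions covers the split points `begin` to `end - 1`.

```python
def adjust_scores(scores, spans, c=1.0):
    """Lower the scores of split points that lie inside non-splittable spans.

    scores must be non-empty; spans is a list of (begin, end) token positions.
    """
    offset = max(scores) - min(scores) + c  # larger than any score gap
    adjusted = list(scores)
    for begin, end in spans:
        for i in range(begin, end):
            # Subtract from the original score so overlapping spans
            # do not apply the offset twice.
            adjusted[i] = scores[i] - offset
    return adjusted
```

For instance, with scores `[0.2, 0.9, 0.1, 0.4]` and the constraint `(1, 2)`, the split between tokens 1 and 2 drops below every other split, so a greedy top-down parser would only split there last, keeping the two word-pieces in one sub-tree.
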
Model-based Pruning.

We denote $\mathcal{B}$ as the reverse order of the split point sequence $\mathcal{A}$ and then treat $\mathcal{B}$ as a bottom-up merge order inferred by the top-down parser based on the global context. Subsequently, we simply replace Algorithm 2 in hu-etal-2021-r2d2 by our Algorithm 2. As shown in Figure 2, we still retain the threshold $m$ and the pruning logic of R2D2, but we select cells to merge according to $\mathcal{B}$ instead of following heuristic rules. Specifically, given a shrinking chart table, we select the next merge index in the second row by popping and modifying $\mathcal{B}$ as in Algorithm 2.

function Next-Index(B)
    i ← pop(B)    ▷ Index
    for j ← 1 to len(B) do
        if B_j > i then    ▷ Merging at left
            B_j ← B_j − 1    ▷ Shift left
    return i
Algorithm 2 Next merge index in the second row

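Take the example in Figure 3 (b): $\mathcal{B}$ starts as the reverse of the sampled split point sequence $\mathcal{A}$. We first merge the first cell in the second row in Figure 2 (b) and obtain a new $\mathcal{B}$ in which all indices to the right of the merge are shifted left by one. In the next round, we treat the 4th cell as a non-splittable cell in (e), and $\mathcal{B}$ is updated again in the same way.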
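
The index bookkeeping of the merge order can be sketched as follows. This is an illustrative sketch, not the official code: positions are 0-based, the merge order is popped from the front, and the helper `merge_by_order` is a hypothetical addition that replays the merges on a plain list to show that the reversed split order rebuilds the same tree bottom-up.

```python
def next_index(merge_order):
    """Pop the next merge position and shift later positions to the left."""
    i = merge_order.pop(0)
    for j in range(len(merge_order)):
        if merge_order[j] > i:   # the merge happened to the left of this position
            merge_order[j] -= 1  # so it moves one step left in the shrunken row
    return i


def merge_by_order(n_tokens, split_order):
    """Rebuild the tree bottom-up from a top-down split order."""
    merge_order = list(reversed(split_order))
    row = list(range(n_tokens))
    while merge_order:
        i = next_index(merge_order)
        row[i:i + 2] = [(row[i], row[i + 1])]  # merge two adjacent cells
    return row[0]
```

For five tokens and the split order `[2, 0, 3, 1]`, the merge order is `[1, 3, 0, 2]`; after the first merge at position 1, the remaining entries become `[2, 0, 1]`.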

3.2 Optimization

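We denote the tree probabilities estimated by the top-down parser and R2D2 as $p_\theta(z \mid S)$ and $q_\phi(z \mid S)$, respectively. The difficulty here is that $q_\phi$ can be optimized by the objective defined in Equation 6, but there is no gradient feedback for $p_\theta$. To make $p_\theta$ learnable, an intuitive solution is to fit $p_\theta$ to $q_\phi$ by minimizing their Kullback–Leibler distance. As the tree probabilities of both distributions are discrete and cannot be enumerated exhaustively, inspired by URNNG dblp:conf/naacl/kimrykdm19, a Monte Carlo estimate for the gradient with respect to $\theta$ can be defined as:

$$\nabla_\theta \mathrm{KL}\left[q_\phi(z \mid S) \,\|\, p_\theta(z \mid S)\right] = \nabla_\theta \mathbb{E}_{z \sim q_\phi(z \mid S)}\left[\log \frac{q_\phi(z \mid S)}{p_\theta(z \mid S)}\right] \approx -\frac{1}{K} \sum_{k=1}^{K} \nabla_\theta \log p_\theta(z^{(k)} \mid S)$$

with samples $z^{(1)}, \dots, z^{(K)}$ from $q_\phi(z \mid S)$. Algorithm 3 shows the complete sampling process from $q_\phi$. Specifically, we sample split points recursively as in previous work DBLP:journals/corr/cmp-lg-9805007; DBLP:conf/emnlp/FinkelMN06; dblp:conf/naacl/kimrykdm19 with respect to the intermediate tree probabilities calculated during hierarchical encoding.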

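function Sample(T_{1,n})    ▷ Root cell
    Q ← [T_{1,n}];  K ← []
    while Q is not empty do
        c ← pop(Q)
        i, j ← c.i, c.j    ▷ Start/end indices
        for each candidate split k stored in c do    ▷ m splits at most
            w_k ← p̃^k_{i,j}    ▷ Using Equation 2
        k ← sample a candidate split with probability proportional to w_k    ▷ Sampled point
        append (k, i, j) to K
        if k > i then push T_{i,k} to Q    ▷ Add left child
        if k + 1 < j then push T_{k+1,j} to Q    ▷ Add right child
    return K
Algorithm 3 Top-down tree sampling for R2D2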
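
The top-down sampling procedure can be sketched as below. The data layout is an assumption made for illustration: `split_probs[(i, j)]` is a hypothetical dictionary mapping each candidate split `k` of the span of tokens `i..j` (0-based, inclusive) to its unnormalized subtree probability, standing in for the values stored in the chart cells.

```python
import random


def sample_tree(split_probs, n_tokens, rng):
    """Sample split points top-down; returns a list of (k, i, j) triples."""
    if n_tokens < 2:
        return []
    stack = [(0, n_tokens - 1)]  # start from the root span
    sampled = []
    while stack:
        i, j = stack.pop()
        candidates = split_probs[(i, j)]
        ks = sorted(candidates)
        # Sample a split with probability proportional to its weight.
        k = rng.choices(ks, weights=[candidates[x] for x in ks])[0]
        sampled.append((k, i, j))
        if k > i:          # left child spans more than one token
            stack.append((i, k))
        if k + 1 < j:      # right child spans more than one token
            stack.append((k + 1, j))
    return sampled
```

Each returned triple records which split was taken for which span, which is the information needed to score the sampled tree under the parser.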

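The sampler returns a sequence of split points and their corresponding spans. For the $k$-th sample $z^{(k)}$, we denote $p_\theta(a_t^{k} \mid S)$ as the probability of taking $a_t^{k}$ as the split from span $(i_t^{k}, j_t^{k})$ at the $t$-th step. Formally, we have:

$$p_\theta(a_t^{k} \mid S) = \frac{\exp(v_{a_t^{k}})}{\sum_{i=i_t^{k}}^{j_t^{k}-1} \exp(v_i)}, \qquad \log p_\theta(z^{(k)} \mid S) = \sum_{t=1}^{n-1} \log p_\theta(a_t^{k} \mid S)$$

where $i_t^{k}$ and $j_t^{k}$ denote the start and end of the corresponding span. Please note that the scores $v_i$ here are not adjusted by span constraints.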
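
The probability of a sampled tree under the parser, and the resulting Monte Carlo surrogate loss, can be sketched as follows. This is a plain-Python illustration without automatic differentiation; the function names are assumptions, positions are 0-based, and each tree is given as a list of `(k, i, j)` triples meaning that split `k` was chosen for the span of tokens `i..j` (inclusive), whose candidate splits are `i` to `j - 1`.

```python
import math


def split_log_prob(scores, k, i, j):
    """Log-softmax of split k among the candidate splits i..j-1."""
    candidates = scores[i:j]
    peak = max(candidates)  # subtract the maximum for numerical stability
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in candidates))
    return scores[k] - log_norm


def tree_log_prob(scores, splits):
    """Log-probability of a whole tree as the sum over its split decisions."""
    return sum(split_log_prob(scores, k, i, j) for k, i, j in splits)


def parser_loss(scores, sampled_trees):
    """Negative mean log-likelihood of trees sampled from the chart."""
    total = sum(tree_log_prob(scores, tree) for tree in sampled_trees)
    return -total / len(sampled_trees)
```

Minimizing `parser_loss` with respect to the parser parameters follows the sampled-tree approximation of the KL gradient, since the sampling distribution itself does not depend on those parameters.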

3.3 Downstream tasks


In this paper, we mainly focus on classification tasks as downstream tasks. We consider the root representation as representing the entire sentence. As we have two models pre-trained in our framework – R2D2 encoder and a top-down parser – we have two ways of generating the representations:

  • Run forced encoding over the binary tree from the top-down parser with the R2D2 encoder;

  • Use the binary tree to guide the pruning of the R2D2 encoder, and take the root representation $e_{1,n}$.

The first approach is much faster than the second, as the R2D2 encoder only runs $n-1$ times in forced encoding and can run in parallel layer by layer; e.g., in Figure 3 (b), compositions over disjoint spans whose children have already been encoded can be run in parallel. We explore both of these approaches in our experiments.
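
The layer-by-layer parallelism of forced encoding can be made concrete with a short sketch. It is an illustration under assumptions: trees are nested tuples over 0-based token positions, and the function name `composition_layers` is introduced only for this example. Internal nodes are grouped by height, so all compositions in one group depend only on earlier groups and could be batched together.

```python
def composition_layers(tree):
    """Group internal nodes, given as (start, end) spans, by tree height."""
    layers = {}

    def visit(node):
        if isinstance(node, int):
            return 0, node, node  # leaf: height 0, span covers one token
        left_height, start, _ = visit(node[0])
        right_height, _, end = visit(node[1])
        height = max(left_height, right_height) + 1
        layers.setdefault(height, []).append((start, end))
        return height, start, end

    visit(tree)
    return [layers[h] for h in sorted(layers)]
```

For the tree `((0, (1, 2)), (3, 4))`, the spans `(1, 2)` and `(3, 4)` fall into the first layer and can be composed in one batch, followed by `(0, 2)` and finally the root `(0, 4)`, for four compositions in total over five tokens.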

Training objectives.

As suggested in previous work radford2018improving; howard-ruder-2018-universal; gururangan-etal-2020-dont, given a pretrained model, continued pretraining on an in-domain corpus with the same pretraining objective can yield a better generalization ability. Thus, we simply add our pretraining objectives to the training loss in all downstream tasks. At the same time, as the downstream task may guide R2D2 to more reasonable tree structures, we still maintain the KL loss to enable the parser to update continuously. For the two inference methods, we uniformly select the root representation as the representation for a given sentence, followed by an MLP, and estimate the cross-entropy loss, denoted as $\mathcal{L}_{forward}$ and $\mathcal{L}_{cky}$, respectively. The final loss is:

$$\mathcal{L} = \mathcal{L}_{bilm} + \mathcal{L}_{KL} + \mathcal{L}_{forward} + \mathcal{L}_{cky}$$


4 Experiments

4.1 Unsupervised Grammar Induction

4.1.1 Setup

Baselines and Evaluation.

For comparison, we include six recent strong models for unsupervised parsing with available open source implementations: StructFormer, Ordered Neurons DBLP:conf/iclr/ShenTSC19, URNNG dblp:conf/naacl/kimrykdm19, DIORA dblp:conf/naacl/drozdovvyim19, C-PCFG kim-etal-2019-compound, and R2D2 hu-etal-2021-r2d2. Following htut-etal-2018-grammar, we train all systems on a training set consisting only of raw text, and evaluate and report the results on an annotated test set. As an evaluation metric, we adopt sentence-level unlabeled $F_1$ computed using the script from kim-etal-2019-compound. We compare against the non-binarized gold trees per convention. The best checkpoint for each system is picked based on its score on the validation set.
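
The metric can be illustrated with a small sketch of sentence-level unlabeled F1. It is not the referenced evaluation script: trees are nested tuples over 0-based token positions (gold trees may have more than two children per node), single-token spans are ignored, and other conventions of the actual script, such as the treatment of the whole-sentence span, may differ.

```python
def tree_spans(tree):
    """Collect the (start, end) spans of all multi-token constituents."""
    spans = set()

    def visit(node):
        if isinstance(node, int):
            return node, node
        child_spans = [visit(child) for child in node]
        start, end = child_spans[0][0], child_spans[-1][1]
        if end > start:
            spans.add((start, end))
        return start, end

    visit(tree)
    return spans


def unlabeled_f1(pred_tree, gold_tree):
    """F1 between predicted and gold span sets for one sentence."""
    pred, gold = tree_spans(pred_tree), tree_spans(gold_tree)
    if not pred and not gold:
        return 1.0
    overlap = len(pred & gold)
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(gold)
    return 2 * precision * recall / (precision + recall)


def sentence_level_f1(pairs):
    """Average the per-sentence F1 over (predicted, gold) tree pairs."""
    return sum(unlabeled_f1(p, g) for p, g in pairs) / len(pairs)
```

Because gold trees are not binarized, a binary prediction usually contains spans that are absent from the gold tree, which lowers precision even when every gold span is recovered.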

As our model is a pretrained model based on word-pieces, for a fair comparison, we test all models with two types of input: word level (W) and word-piece level (WP). Fast-R2D2 is pretrained with span constraints for the word level but without span constraints for the word-piece level. To support word-piece level evaluation, we convert gold trees to word-piece level trees by simply breaking each terminal node into a non-terminal node with its word-pieces as terminals, e.g., (NN discrepancy) into (NN (WP disc) (WP ##re) (WP ##pan) (WP ##cy)).
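
The conversion of gold trees to the word-piece level can be sketched as follows. The tree encoding (nested lists of the form `[label, child, ...]` with `[tag, word]` leaves) and the `tokenize` callback are assumptions made for this example; any word-piece tokenizer that maps a word to a list of pieces could be plugged in. Words that stay a single piece are left unchanged here, which is also an assumption.

```python
def to_wordpiece_tree(tree, tokenize):
    """Expand each terminal into a non-terminal over its word-pieces."""
    label, *children = tree
    if len(children) == 1 and isinstance(children[0], str):
        pieces = tokenize(children[0])
        if len(pieces) == 1:
            return [label, pieces[0]]  # nothing to split
        return [label] + [["WP", piece] for piece in pieces]
    return [label] + [to_wordpiece_tree(child, tokenize) for child in children]


# Toy vocabulary standing in for a real word-piece tokenizer.
TOY_PIECES = {"discrepancy": ["disc", "##re", "##pan", "##cy"]}


def toy_tokenize(word):
    return TOY_PIECES.get(word, [word])
```

Applied to `["NN", "discrepancy"]`, this yields `["NN", ["WP", "disc"], ["WP", "##re"], ["WP", "##pan"], ["WP", "##cy"]]`, mirroring the example in the text.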


EFLOPS DBLP:conf/hpca/DongCZYWFZLSPGJ20 is a highly scalable distributed training system designed by Alibaba. With its optimized hardware architecture and co-designed supporting software tools, including ACCL DBLP:journals/micro/DongWFCPTLLRGGL21 and KSpeed (the high-speed data-loading service), it could easily be extended to 10K nodes (GPUs) with linear scalability.


The tree encoder of our model uses 4-layer Transformers with 768-dimensional embeddings, 3,072-dimensional hidden layer representations, and 12 attention heads. The top-down parser of our model uses a 4-layer bidirectional LSTM with 128-dimensional embeddings and a 256-dimensional hidden layer. The sampling number is set to 256. Training is conducted using Adam optimization with weight decay, with separate learning rates for the tree encoder and the top-down parser. The batch size is set to 96 for the pruning threshold $m$ used in training, though we also limit the maximum total length for each batch to 1,536, such that excess sentences are moved to the next batch. It takes about 120 hours for 60 epochs of training with this setting on 8 A100 GPUs.
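
The batching rule with a cap on the total length can be sketched as follows. The greedy strategy and the function name `make_batches` are assumptions for illustration; only the two limits, 96 sentences and a total length of 1,536, come from the text.

```python
def make_batches(lengths, max_sentences=96, max_total_length=1536):
    """Greedily group sentence indices into batches under both limits."""
    batches, current, total = [], [], 0
    for index, length in enumerate(lengths):
        too_many = len(current) >= max_sentences
        too_long = total + length > max_total_length
        if current and (too_many or too_long):
            batches.append(current)  # excess sentences go to the next batch
            current, total = [], 0
        current.append(index)
        total += length
    if current:
        batches.append(current)
    return batches
```

With sentence lengths `[1000, 600, 100]`, the second sentence would push the first batch past the length cap, so it starts a new batch together with the third sentence.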


Fast-R2D2 is pretrained on a mixed corpus of WikiText103 DBLP:conf/iclr/MerityX0S17 and the training set of Penn Treebank (PTB) marcus-etal-1993-building. WikiText103 is split at the sentence level, and sentences longer than 200 after tokenization are discarded (about 0.04‰ of the original data). The total number of sentences is 4,089,500, and the average sentence length is 26.97. For Chinese, we use a subset of Chinese Wikipedia for pretraining, specifically the first 10,000,000 sentences shorter than 150 characters.

We test our approach on PTB with the standard splits (2–21 for training, 22 for validation, 23 for test) and the same preprocessing as in recent work kim-etal-2019-compound, where we discard punctuation and lower-case all tokens. To explore the universality of the model across languages, we also run experiments on Chinese Penn Treebank (CTB) 8 ctb8, on which we also remove punctuation.

Note that in all settings, the training is conducted entirely on raw unannotated text.

4.1.2 Results and Discussion

Model                  WSJ F1   CTB F1
Left Branching (W)     8.15     11.28
Right Branching (W)    39.62    27.53
Random Trees (W)       17.76    20.17
ON-LSTM (W)            47.7     24.73
DIORA (W)              51.42    -
URNNG (W)              52.4     -
StructFormer (WP)      54.0     -
C-PCFG (W)             55.2     49.95
R2D2 (WP)              48.11    44.85
Fast-R2D2 (WP)         56.24    51.04
Measured against word-piece level gold trees:
DIORA (WP)             43.94    -
C-PCFG (WP)            49.76    60.34
R2D2 (WP)              52.28    63.94
Fast-R2D2 (WP)         50.20    67.79
Fast-R2D2 (WP)         53.88    67.74
Table 1: Unsupervised parsing results ($F_1$) with words (W) or word-pieces (WP) as input on the WSJ and CTB test sets. Some baseline values are taken from kim-etal-2019-compound and DBLP:conf/acl/ShenTZBMC20. The bottom five systems are all pre-trained or trained at the word-piece level without span constraints, and are measured against word-piece level gold trees; the two Fast-R2D2 rows among them correspond to the two inference options described in Section 3.3.

Table 1 shows results of all systems with words (W) and word-pieces (WP) as input on the WSJ and CTB test sets. When we evaluate all systems on word-level gold trees, our Fast-R2D2 performs substantially better than R2D2 across both datasets. As mentioned in Section 3.3, Fast-R2D2 has two options to obtain the final tree and representation for a given input sentence: the forced-encoding variant runs forced encoding over the output tree of the top-down parser, while the chart-based variant uses the tree to guide the pruning of the R2D2 encoder and selects the best tree using the chart table. The results suggest that the forced-encoding variant outperforms the chart-based variant on both the WSJ and CTB test sets. More interestingly, Fast-R2D2, a model without human-designed grammar constraints at word-piece level granularity, outperforms models specifically designed for grammar induction at word-level granularity. If all systems take word-pieces as input and are measured against word-piece level gold trees (the bottom five results), our Fast-R2D2 obtains state-of-the-art results on both WSJ and CTB.



WSJ test set
DIORA        94.63   77.83   17.30   22.16
C-PCFG       87.35   66.44   23.63   40.40
R2D2         99.76   86.76   24.74   39.81
Fast-R2D2    97.67   83.44   63.80   65.68
CTB test set
C-PCFG       89.34   46.74   39.53   -
R2D2         97.16   67.19   37.90   -
Fast-R2D2    97.80   68.57   46.59   -
Table 2: Recall of constituents and words at word-piece level. WD means word.

Following dblp:conf/naacl/kimrykdm19 and drozdov-etal-2020-unsupervised, we also compute the recall of constituents when evaluating on word-piece level gold trees. Besides standard constituents, we also compare the recall of word-piece chunks and proper noun chunks. Proper noun chunks are extracted by finding adjacent unary nodes with the same parent and tag NNP. Table 2 reports the recall scores for constituents and words on the WSJ and CTB test sets. Compared with the R2D2 baseline, our Fast-R2D2 performs slightly worse for small semantic units, but significantly better over larger semantic units (such as VP and SBAR) on the WSJ test set. On the CTB test set, our Fast-R2D2 outperforms R2D2 on all constituents.

From Tables 1 and 2, we conclude that Fast-R2D2 overall obtains better results than R2D2 on CTB, while faring slightly worse than R2D2 only for small semantic units on WSJ. We conjecture that this difference stems from differences in tokenization between Chinese and English. Chinese is a character-based language without complex morphology, where collocations of characters are consistent with the language, making it easier for the top-down parser to learn them well. In contrast, word-pieces for English are built based on statistics, and individual word-pieces are not necessarily natural semantic units. Thus, there may not be sufficient semantic self-consistency, which makes it harder for a top-down parser with a small number of parameters to fit them well.

4.2 Downstream Tasks

We next consider the effectiveness of Fast-R2D2 in downstream tasks. This experiment is not intended to advance the state-of-the-art on the GLUE benchmark but rather to assess to what extent our approach performs respectably against the dominating inductive bias as in conventional sequential Transformers.

4.2.1 Setup

Data and Baseline

We fine-tune pretrained models on several datasets, including SST-2, CoLA, QQP, and MNLI from the GLUE benchmark wang2018glue. As sequential Transformers with their dominating inductive bias remain the norm for numerous NLP tasks, we mainly compare Fast-R2D2 with Bert devlin2018 as a representative pretrained model based on a sequential Transformer. In order to compare the two forms of inductive bias fairly and efficiently, we pretrain Bert models with 4 layers and 12 layers as well as our Fast-R2D2 from scratch on the WikiText103 corpus following Section 4.1.1. For simplicity, Fast-R2D2 is fine-tuned without span constraints. Following the common settings, we add an MLP layer over the root representation of the R2D2 encoder for single sentence classification. For cross-sentence tasks such as QQP and MNLI, we feed the root representations of the two sentences into the pretrained tree encoder of R2D2 as left and right inputs, and also add a new task id as another input term to the R2D2 encoder. Then we feed the hidden output of the new task id into another MLP layer to predict the final label. We train all systems on the four datasets for 10 epochs with the same learning rate, batch size, and maximum input length. We validate each model in each epoch, and report the best results on the development sets.

Model        Para.     SST-2   CoLA (Mcc.)   QQP     MNLI
Bert (4L)    52M       84.98   17.07         84.01   73.73/74.63
Bert (12L)   116M      90.25   40.72         87.13   80.00/80.41
Fast-R2D2    52M+10M   87.50   8.67          83.97   69.53/69.50
Fast-R2D2    52M+10M   88.30   10.14         84.07   69.36/69.11
Fast-R2D2    52M+10M   90.25   38.45         84.35   69.36/68.80
Fast-R2D2    52M+10M   90.71   40.11         84.32   69.64/69.57
Table 3: Results on single sentence tasks (SST-2, CoLA) and cross-sentence tasks (QQP, MNLI). All systems are pretrained from scratch on WikiText103. Para. describes the number of parameters for each model. Fast-R2D2 contains the R2D2 encoder and top-down parser, two components with 52M and 10M parameters, respectively. Mcc. stands for Matthews correlation coefficient. Some Fast-R2D2 rows are models fine-tuned without part of the training objective, as an ablation study.

4.2.2 Results and Discussion

Table 3 shows the corresponding scores on SST-2, CoLA, QQP, and MNLI. In terms of the parameter size, our Fast-R2D2 model has 52M and 10M parameters for the R2D2 encoder and top-down parser, respectively. It is clear that 12-layer Bert is significantly better than 4-layer Bert. The forced-encoding variant of Fast-R2D2 uses the output tree from the top-down parser, while the chart-based variant uses the best tree inferred by the R2D2 encoder. Similar to the results for unsupervised parsing, the forced-encoding variant again outperforms the chart-based variant in classification tasks. We hypothesize that trees generated by the top-down parser without Gumbel noise are more stable and reasonable. Fast-R2D2 significantly outperforms 4-layer Bert and achieves competitive results compared to 12-layer Bert in single sentence classification tasks like SST-2 and CoLA, but still performs significantly worse in the cross-sentence tasks. We believe this is an expected result, as there is no cross-attention mechanism in the inductive bias of Fast-R2D2. However, the performance of Fast-R2D2 on classification tasks shows that the inductive bias of R2D2 has higher parameter utilization than sequentially applied Transformers. Importantly, we demonstrate that a Recursive Neural Network variant with an unsupervised parser can achieve results comparable to pretrained sequential Transformers even with fewer parameters, while producing interpretable intermediate results. Hence, our Fast-R2D2 framework provides an alternative choice for NLP tasks.

4.3 Speed Test

As DBLP:conf/icml/ChowdhuryC21 argued, parallelism is also an important feature for an RvNN. As the tree encoder of Fast-R2D2 is based on Transformers, we mainly compare the time cost of sequential Transformers and Fast-R2D2 in forced encoding on various sequence length ranges. We randomly select 1,000 sentences for each range from WikiText103 and report the average time consumption on a single A100 GPU. Bert is based on the open source Transformers library, and R2D2 is based on the official code of hu-etal-2021-r2d2.

Model Sequence Length Ranges
0-50 50-100 100-200 200-500
Bert (12L) 1.36 1.46 1.62 2.38
R2D2 38.06 173.74 555.95 -
Fast-R2D2 4.67 14.91 39.73 150.26
Fast-R2D2* 1.83 4.16 7.88 14.84
Table 4: Inference time in seconds for various systems to process 1,000 sentences with a batch size of 50.

Table 4 shows the inference time in seconds for different systems to process 1,000 sentences with a batch size of 50. Running R2D2 is time-consuming since the heuristic pruning method involves huge memory exchanges between GPU and CPU. In Fast-R2D2, we alleviate this problem by using model-guided pruning to accelerate the chart table processing. In conjunction with a CUDA implementation, this reduces the inference time significantly. Running forced encoding in parallel over the binary tree generated by the parser (the last row of Table 4) improves the inference speed further, and is about 30-50 times faster than R2D2 in various ranges. Although there is still a gap in speed compared to sequential Transformers, Fast-R2D2 is sufficiently fast for most NLP tasks while producing interpretable intermediate representations.

5 Related Work

Many attempts have been made at grammar induction and hierarchical encoding. DBLP:conf/conll/Clark01 and DBLP:conf/acl/KleinM02 provided some of the first successful statistical approaches to grammar induction. There have been multiple recent papers that focus on structure induction based on language modeling objectives DBLP:conf/nips/ShenTHLSC19; DBLP:conf/iclr/ShenTSC19; DBLP:conf/acl/ShenTZBMC20; kim-etal-2019-compound. DBLP:journals/ai/Pollack90 propose to use RvNN as a recursive architecture to encode text hierarchically, and DBLP:conf/emnlp/SocherPWCMNP13 show the effectiveness of RvNNs with gold trees for sentiment analysis. In this work, we focus on models that are capable of learning meaningful structures in an unsupervised way and encoding text over the induced tree hierarchically.

In the line of work on learning a sentence representation with structures, DBLP:conf/iclr/YogatamaBDGL17 jointly train their shift-reduce parser and sentence embedding components without gold trees. As their parser is not differentiable, they have to resort to reinforcement training, resulting in increased variance, and the model may easily collapse to trivial left/right branching trees. Gumbel-Tree-LSTMs DBLP:conf/aaai/ChoiYL18 construct trees by recursively selecting two terminal nodes to merge and learning the composition probability via downstream tasks. CRvNN DBLP:conf/icml/ChowdhuryC21 makes the whole process end-to-end differentiable and parallel by introducing a continuous relaxation. The work of URNNG dblp:conf/naacl/kimrykdm19 applies variational inference over latent trees to perform unsupervised optimization of the RNNG dyer-etal-2016-recurrent, an RNN model that estimates a joint distribution over sentences and trees based on shift-reduce operations. However, it is hard to induce tree structures when the model is trained from scratch. DBLP:journals/corr/MaillardCY17 propose an alternative approach, based on a differentiable CKY encoding. The algorithm is made differentiable by using a soft-gating approach, which approximates discrete candidate selection by a probabilistic mixture of the constituents available in a given cell of the chart. While their work relies on annotated downstream tasks to learn structures, dblp:conf/naacl/drozdovvyim19 propose a novel auto-encoder-like pretraining objective based on the inside–outside algorithm Baker1979TrainableGF; DBLP:conf/icgi/Casacuberta94. As mentioned above, CKY-based models have cubic time complexity. hu-etal-2021-r2d2 propose a pruned differentiable CKY encoding architecture with a simple pretraining objective related to bi-directional language modeling. They reduce the time complexity to a linear number of composition steps and make it possible to apply sophisticated tree encoders and to scale to large corpus pretraining.

6 Conclusion

In this paper, we have presented Fast-R2D2, which improves the performance and inference speed of R2D2 by introducing a fast top-down parser to guide the pruning of R2D2 encoder. Pretrained on the same corpus, Fast-R2D2 significantly outperforms sequential Transformers with a similar scale of parameters on classification tasks. Experimental results show that Fast-R2D2 is a promising and feasible way to learn hierarchical text representations, which is different from layer stacking models and can also generate interpretable intermediate representations. As future work, we are investigating leveraging the intermediate representations in additional downstream tasks.

7 Acknowledgement

We would like to thank the Aliyun EFLOPS team for their substantial support in designing and providing a cutting-edge training platform to facilitate fast experimentation in this work.