Recent advances in deep learning have led to major progress in many domains, with neural models sometimes achieving or even surpassing human performance (wang2019superglue). However, these methods often struggle in out-of-distribution (OOD) settings, where training and test examples are drawn from different distributions. In particular, unlike humans, conventional sequence-to-sequence (seq2seq) models, widely used in natural language processing (NLP), fail to generalize systematically (bahdanau2018systematic; lake2018generalization; loula2018rearranging), i.e., to correctly interpret sentences representing novel combinations of concepts seen in training. Our goal is to provide a mechanism for encouraging systematic generalization in seq2seq models.
To get an intuition about our method, consider the semantic parsing task shown in Figure 1. A learner needs to map a natural language (NL) utterance to a program which can then be executed on a knowledge base. To process the test utterance, the learner needs to first decompose it into two segments previously observed in training (shown in green and blue), and then combine their corresponding program fragments to create a new program. Current seq2seq models fail in this systematic generalization setting (finegan-dollak-etal-2018-improving; keysers2019measuring). In contrast, traditional grammar formalisms decompose correspondences between utterances and programs into compositional mappings of substructures (steedman2000syntactic), enabling grammar-based parsers to recombine rules acquired during training, as needed for systematic generalization. Grammars proved essential in statistical semantic parsing in the pre-neural era (zettlemoyer2012learning; wong-mooney-2006-learning), and have gained renewed interest as a means of achieving systematic generalization (herzig2020span; shaw2020compositional). However, grammars are hard to create and maintain (e.g., requiring grammar engineering or grammar induction stages) and do not scale well to NLP problems beyond semantic parsing (e.g., machine translation). In this work, we argue that the key property of grammar-based models, giving rise to their improved OOD performance, is that a grammar implicitly encodes alignments between input and output segments. For example, in Figure 1, the expected segment-level alignments are 'the length' ↔ len and 'the longest river' ↔ longest(river(all)). Instead of developing a full-fledged grammar-based method, we directly model segment-level alignments as structured latent variables. The resulting alignment-driven seq2seq model remains end-to-end differentiable and, in principle, applicable to any sequence transduction problem.
Modeling segment-level alignments requires simultaneously inducing a segmentation of input and output sequences and discovering correspondences between the input and output segments. While segment-level alignments have been previously incorporated in neural models (yu2016online; wang2017sequence), to maintain tractability, these approaches support only monotonic alignments. The monotonicity assumption is reasonable for certain tasks (e.g., summarization), but it is generally overly restrictive (e.g., consider semantic parsing and machine translation). To relax this assumption, we complement monotonic alignments with an extra reordering step. That is, we first permute the source sequence so that segments within the reordered sequence can be aligned monotonically to segments of the target sequence. Coupling latent permutations with monotonic alignments dramatically increases the space of admissible segment alignments.
The space of general permutations is exceedingly large, so, to allow for efficient training, we restrict ourselves to separable permutations (bose1998pattern). We model separable permutations as hierarchical reordering of segments using permutation trees. This hierarchical way of modeling permutations reflects the hierarchical nature of language and hence is arguably more appropriate than 'flat' alternatives (mena2018learning). Interestingly, a recent study (steedman2020formal) demonstrated that separable permutations are sufficient for capturing the variability of permutations in linguistic constructions across natural languages, providing further motivation for our modeling choice.
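To make the notion concrete, separability can be checked directly from the recursive definition: a contiguous range of positions is split into two blocks whose values form contiguous ranges, kept in order or swapped. A small illustrative sketch (not part of the proposed model):

```python
def is_separable(perm):
    """Check whether a permutation (a list of 0..n-1) is separable, i.e.
    expressible by a permutation tree whose internal nodes either keep
    (Straight) or swap (Inverted) the order of their two children."""
    vals = list(perm)
    n = len(vals)
    if n <= 1:
        return True
    lo = min(vals)
    for k in range(1, n):
        left, right = vals[:k], vals[k:]
        # Straight: the left block holds the lower contiguous value range.
        straight = min(left) == lo and max(left) == lo + k - 1
        # Inverted: the right block holds the lower contiguous value range.
        inverted = min(right) == lo and max(right) == lo + (n - k) - 1
        if (straight or inverted) and is_separable(left) and is_separable(right):
            return True
    return False
```

For instance, [2, 3, 0, 1] is separable (swap two straight blocks), whereas [1, 3, 0, 2], an instance of the classic 2413 pattern, is not.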
Simply marginalizing over all possible separable permutations remains intractable. Instead, inspired by recent work on modeling latent discrete structures (corro-titov-2019-learning; fu2020latent), we introduce a continuous relaxation of the reordering problem. The key ingredient of the relaxation is marginal inference. In this work, we propose an efficient dynamic programming algorithm to perform exact marginal inference with separable permutations, resulting in an effective differentiable neural module producing relaxed separable permutations. By plugging this module into an existing module supporting monotonic segment alignments (yu2016online), we obtain an end-to-end differentiable seq2seq model supporting non-monotonic segment-level alignments. In summary, our contributions are:
A general seq2seq model for NLP tasks that accounts for latent non-monotonic segment-level alignments.
A novel and efficient algorithm for exact marginal inference with separable permutations, allowing for end-to-end training using a continuous relaxation. (Footnote: Our code and data will be available upon acceptance.)
Experiments on synthetic problems and NLP tasks (semantic parsing and machine translation) showing that modeling segment alignments is beneficial for systematic generalization.
2 Background and Related Work
2.1 Systematic Generalization
Human learners exhibit systematic generalization, i.e., the ability to generalize from training data to novel situations. This is possible due to the compositionality of natural language: to a large degree, sentences are built from an inventory of primitive concepts using finite structure-building mechanisms (chomsky_aspects_1965). For example, if one understands 'John loves the girl', they should also understand 'The girl loves John' (fodor1988connectionism), by 'knowing' the meaning of the individual words and the grammatical principle of subject-verb-object composition. As pointed out by goodwin2020probing, systematicity entails that primitive units have consistent meaning across different contexts. In contrast, in seq2seq models, the representation of a word is highly influenced by its context (see the experiments in lake2018generalization). This is also consistent with the observation that seq2seq models tend to memorize large chunks rather than discover underlying compositional principles (hupkes_compositionality_2019). Memorizing large chunks lets a model fit the training distribution but harms out-of-distribution generalization.
2.2 Discrete Alignments as Conditional Computation Graphs
Latent discrete structures enable the incorporation of inductive biases into neural models and have proven beneficial for a range of problems. For example, input-dependent module layouts (andreas2016neural) and graphs (norcliffe2018learning) have been explored in visual question answering. There is also a large body of work on inducing task-specific discrete representations (usually trees) for NL sentences (yogatama2016learning; niculae2018sparsemap; havrylov-etal-2019-cooperative; corro-titov-2019-learning). The trees are induced simultaneously with learning a model that performs computation relying on the tree (typically a recursive neural network, socher2011semi), while optimizing a task-specific loss. Given the role the structures play in these approaches (i.e., defining the computation flow), we can think of the structures as conditional computation graphs.
In this work, we induce discrete alignments as conditional computation graphs to guide seq2seq models. Given a source sequence $x$ with $n$ tokens and a target sequence $y$ with $m$ tokens, we optimize the following objective:

$$\max_{\theta, \phi}\;\mathbb{E}_{A \sim p_\phi(A \mid x)}\big[\log p_\theta(y \mid A, E)\big], \qquad E = \mathrm{Encode}(x) \quad (1)$$

where Encode is a function that embeds $x$ into $E \in \mathbb{R}^{n \times d}$, with $d$ being the hidden size, and $A \in \{0,1\}^{m \times n}$ is the alignment matrix between input and output tokens. In this framework, alignments are separately predicted by $p_\phi$ to guide the computation $p_\theta$ that maps $x$ to $y$. The parameters of the two model components ($\phi$ and $\theta$) are disjoint.
Standard encoder-decoder models (bahdanau2014neural) rely on continuous attention weights rather than discrete alignments, i.e., a distribution over source tokens is computed for each target token. Discrete alignments have been considered in previous work (xu2015show; deng2018latent); in this case, the alignment is a sequence of categorical random variables ('hard attention'). Though discrete, hard attention only considers word-level alignments, i.e., it assumes that each target token is aligned with a single source token. This is a limiting assumption; for example, in traditional statistical machine translation, word-based models (e.g., brown1993mathematics) are known to achieve dramatically weaker results than phrase-based models (e.g., koehn-etal-2007-moses). In this work, we aim to bring the power of phrase-level (aka segment-level) alignments to neural seq2seq models.
3 Latent Segment Alignments via Separable Permutations
Our method integrates a layer of segment-level alignments with a seq2seq model. The architecture of our model is shown in Figure 2. Central to this model is the alignment network, which decomposes the alignment problem into two stages: (i) input reordering and (ii) monotonic alignment between the reordered sequence and the output. Formally, we decompose the alignment matrix $A$ from Eq 1 into two parts:

$$A = M P \quad (2)$$

where $P \in \{0,1\}^{n \times n}$ is a permutation matrix and $M \in \{0,1\}^{m \times n}$ represents monotonic alignments. With this decomposition, we can rewrite the objective in Eq 1 as follows:

$$\max_{\theta, \phi}\;\mathbb{E}_{P \sim p_\phi(P \mid x)}\Big[\log \mathbb{E}_{M \sim p_\theta(M \mid E')}\big[p_\theta(y \mid M, E')\big]\Big] \quad (3)$$

where $E' = P E$ denotes the reordered representation. With a slight abuse of notation, $\phi$ now denotes the parameters of the model generating permutations, and $\theta$ denotes the parameters used to produce monotonic alignments. Given the permutation matrix $P$, the second expectation $\mathbb{E}_{M \sim p_\theta(M \mid E')}\big[p_\theta(y \mid M, E')\big]$, which we denote as $p_\theta(y \mid E')$, can be handled by existing methods, such as SSNT (yu2016online) and SWAN (wang2017sequence). In the rest of the paper, we choose SSNT as the module for handling monotonic alignments. (Footnote: In our initial experiments, we found that SWAN works as well as SSNT but is considerably slower.) We can rewrite the objective we optimize in the following compact form:

$$\max_{\theta, \phi}\;\mathbb{E}_{P \sim p_\phi(P \mid x)}\big[\log p_\theta(y \mid P E)\big] \quad (4)$$
Algorithmically, in both $M$ and $P$, segments are the basic elements being manipulated: (i) SSNT considers all possible monotonic segment-to-segment alignments (refer to yu2016online); (ii) our reordering module, presented in the following sections, creates the final permutation matrix by hierarchically reordering input segments. Modeling segments provides a strong inductive bias, reflecting the intuition that sequence transduction in NLP can largely be accomplished by manipulations at the level of segments. In contrast, conventional seq2seq methods have no explicit notion of segments.
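As a toy illustration of the decomposition of the alignment matrix into a permutation followed by monotonic alignments, consider three source tokens whose last token must be moved to the front before being aligned monotonically; all matrices here are hypothetical and chosen only for illustration:

```python
import numpy as np

# P moves the last of 3 source tokens to the front (a separable permutation).
P = np.array([[0, 0, 1],
              [1, 0, 0],
              [0, 1, 0]])
# M monotonically aligns 3 target tokens to the *reordered* source: target
# tokens 1-2 align to reordered token 1; target token 3 to reordered tokens 2-3.
M = np.array([[1, 0, 0],
              [1, 0, 0],
              [0, 1, 1]])
# Composing the two yields a non-monotonic segment-level alignment.
A = M @ P
```

Note how neither factor alone is non-monotonic: P only permutes, M only aligns monotonically, yet their product aligns the first two target tokens to the last source token.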
3.1 Structured Latent Reordering by Permutation Trees
Inspired by steedman2020formal, we restrict word reorderings to separable permutations. Formally, separable permutations are defined in terms of permutation trees (aka separating trees, bose1998pattern): if a permutation can be represented by a permutation tree, it is separable. A permutation tree over a permutation of a sequence is a binary tree in which each node represents the ordering of a segment $(i, k)$; the children exhaustively split their parent into sub-segments $(i, j)$ and $(j, k)$. Each node has a binary label that decides whether the segment of the left child precedes that of the right child. Bracketing transduction grammar (BTG, wu1997stochastic), originally proposed in the context of machine translation, is the corresponding context-free grammar; it has one non-terminal ($N$) and three anchored rules:

$$N_{i,k} \rightarrow [\, N_{i,j}\; N_{j,k} \,] \qquad\quad N_{i,k} \rightarrow \langle N_{i,j}\; N_{j,k} \rangle \qquad\quad N_{i,i+1} \rightarrow w_i$$

where $N_{i,k}$ is the anchored non-terminal covering the segment from $i$ to $k$ (excluding $k$). The first two rules (Straight and Inverted) decide whether to keep or invert two segments when constructing a larger segment; the last rule states that every word $w_i$ in an utterance is associated with a non-terminal. An example is shown in Figure 3. This hierarchical approach to generating separable permutations reflects the compositional nature of language and thus appears more appealing than 'flat' alternatives (mena2018learning). Moreover, with BTGs we can incorporate segment-level features to model separable permutations and design tractable algorithms.
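To illustrate how a BTG derivation encodes a separable permutation, a permutation tree can be evaluated recursively; in this sketch, 'S' and 'I' mark Straight and Inverted nodes, and leaves hold source-token indices:

```python
def tree_to_perm(tree):
    """Evaluate a permutation tree: a leaf is a token index; an internal
    node is ('S', left, right) for Straight or ('I', left, right) for
    Inverted. Returns the source indices in output order."""
    if isinstance(tree, int):
        return [tree]
    label, left, right = tree
    lp, rp = tree_to_perm(left), tree_to_perm(right)
    # Straight keeps the children's order; Inverted swaps them.
    return lp + rp if label == 'S' else rp + lp
```

For instance, the tree ('I', ('S', 0, 1), 2) keeps tokens 0 and 1 in order but swaps the two top-level segments, yielding the order [2, 0, 1].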
By assigning a score to each anchored rule using segment-level features, we obtain a distribution over all possible derivations, and use it to compute the objective in Eq 4:

$$\mathbb{E}_{P \sim p_\phi(P \mid x)}\big[\log p_\theta(y \mid P E)\big] \;=\; \sum_{d} p_\phi(d \mid x)\, \log p_\theta(y \mid P_d E), \qquad p_\phi(d \mid x) \;=\; \frac{\prod_{r \in d} w_\phi(r)}{Z(x)} \quad (5)$$

where $w_\phi$ is a score function assigning a (non-negative) weight to an anchored rule $r$, $Z(x)$ is the partition function, which can be computed using the inside algorithm, and $P_d$ is the permutation matrix corresponding to the derivation $d$. BTG, along with the weights assigned to its rules, is a weighted context-free grammar (WCFG). In this WCFG, the weight is normalized only at the derivation level. As we will see in Algorithm 1, we need normalized weights for individual production rules, so we convert the WCFG to an equivalent PCFG following smith2007weighted; the probability of a derivation can then be computed as follows:

$$p_\phi(d \mid x) \;=\; \prod_{r \in d} \hat{p}_\phi(r) \quad (6)$$

where $\hat{p}_\phi(r)$ is the probability of the production rule $r$ under the transformed PCFG. The details of the conversion are provided in the Appendix.
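The conversion can be sketched as follows, using inside scores to locally normalize each anchored rule (a minimal sketch in the spirit of smith2007weighted; the exact details are in the Appendix, and the rule encoding below is hypothetical, chosen only for illustration):

```python
def wcfg_to_pcfg(weights, n):
    """Convert anchored-rule weights of a BTG-style WCFG into locally
    normalized rule probabilities via inside scores.

    Hypothetical rule keys:
      ('W', i, i+1)   terminal rule for the word spanning (i, i+1)
      ('S', i, j, k)  Straight split of span (i, k) at j
      ('I', i, j, k)  Inverted split of span (i, k) at j
    """
    inside = {}
    for i in range(n):
        inside[i, i + 1] = weights[('W', i, i + 1)]
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            k = i + width
            # Inside score: total weight of all sub-derivations of (i, k).
            inside[i, k] = sum(
                weights[(r, i, j, k)] * inside[i, j] * inside[j, k]
                for j in range(i + 1, k) for r in ('S', 'I'))
    probs = {}
    for (r, *span), w in weights.items():
        if r == 'W':
            probs[(r, *span)] = 1.0  # terminal rules are deterministic
        else:
            i, j, k = span
            probs[(r, i, j, k)] = w * inside[i, j] * inside[j, k] / inside[i, k]
    return probs
```

With this local normalization, the product of rule probabilities along a derivation equals its globally normalized WCFG probability, which is exactly what Eq 6 requires.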
3.2 Soft Reordering: Computing Marginal Permutations
The first strategy is to use the deterministic expectation of permutations to softly reorder a sentence, analogous to the way standard attention approximates categorical random variables. Specifically, we use the following approximation:

$$\mathbb{E}_{P \sim p_\phi(P \mid x)}\big[\log p_\theta(y \mid P E)\big] \;\approx\; \log p_\theta(y \mid \bar{P} E), \qquad \bar{P} = \mathbb{E}_{P \sim p_\phi(P \mid x)}[P]$$

where $\bar{P}$ is the marginal permutation matrix, which can be treated as structured attention (kim2017structured). Methods for performing marginal inference of anchored rules, i.e., computing the marginal distribution of production rules, are well-known in NLP (jurafsky2000speech). However, we are interested in the marginal permutation matrix (equivalently, the expectation of the matrix components), as this matrix is the data structure ultimately used in our model. As a key contribution of this work, we propose an efficient algorithm to exactly compute the marginal permutation matrix using dynamic programming.
In order to compute the marginal permutation matrix we need to marginalize over the exponentially many derivations of permutation trees. We propose to map a derivation of BTG into its corresponding permutation matrix in a recursive manner. Specifically, we first associate each word with a trivial $1 \times 1$ identity permutation matrix; then we associate the Straight and Inverted rules with direct and skew sums of permutation matrices, respectively:

$$P_1 \oplus P_2 = \begin{bmatrix} P_1 & 0 \\ 0 & P_2 \end{bmatrix} \qquad\quad P_1 \ominus P_2 = \begin{bmatrix} 0 & P_2 \\ P_1 & 0 \end{bmatrix}$$

For example, the permutation matrix of the derivation tree shown in Figure 3 can be obtained by recursively applying these two operations to the permutation matrices of its subtrees.
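In matrix form, the two composition operations can be sketched as follows, under the assumption that rows index output positions and columns index input positions (so the reordered sequence is obtained by left-multiplying with the permutation matrix):

```python
import numpy as np

def direct_sum(P_left, P_right):
    """Straight rule: the (reordered) left segment stays before the right."""
    n1, n2 = len(P_left), len(P_right)
    out = np.zeros((n1 + n2, n1 + n2))
    out[:n1, :n1] = P_left
    out[n1:, n1:] = P_right
    return out

def skew_sum(P_left, P_right):
    """Inverted rule: the (reordered) right segment is emitted first."""
    n1, n2 = len(P_left), len(P_right)
    out = np.zeros((n1 + n2, n1 + n2))
    out[:n2, n1:] = P_right  # right segment moves to the front
    out[n2:, :n1] = P_left   # left segment moves to the back
    return out
```

For example, skew-summing identity matrices for the segments [a, b] and [c] produces a permutation that maps (a, b, c) to (c, a, b).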
Intuitively, the permutation matrix of long segments can be constructed by composing permutation matrices of short segments. Motivated by this, we propose a dynamic programming algorithm, which takes advantage of the observation that we can reuse the permutation matrices of short segments when computing permutation matrices of long segments, as shown in Algorithm 1. While the above equation is defined over discrete permutation matrices encoding a single derivation, the algorithm applies recursive rules to expected permutation matrices. Central to the algorithm is the following recursion:
$$\bar{P}_{i,k} \;=\; \sum_{j=i+1}^{k-1} \Big[\, \hat{p}_\phi\big(N_{i,k} \rightarrow [\, N_{i,j}\, N_{j,k} \,]\big)\, \big(\bar{P}_{i,j} \oplus \bar{P}_{j,k}\big) \;+\; \hat{p}_\phi\big(N_{i,k} \rightarrow \langle N_{i,j}\, N_{j,k} \rangle\big)\, \big(\bar{P}_{i,j} \ominus \bar{P}_{j,k}\big) \Big]$$

where $\bar{P}_{i,k}$ is the expected permutation matrix for the segment from $i$ to $k$, and $\hat{p}_\phi$ is the probability of the associated production rule, defined in Eq 6. Overall, Algorithm 1 is a bottom-up method that constructs expected permutation matrices incrementally in Steps 13 and 14, relying on the probabilities of the associated production rules. We prove the correctness of this algorithm by induction in the Appendix.
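The recursion can be sketched as a bottom-up dynamic program over spans; this illustrative version (not the paper's exact Algorithm 1) assumes PCFG rule probabilities keyed by a hypothetical ('S'|'I', i, j, k) encoding that sum to one over rule type and split point for each span:

```python
import numpy as np

def expected_permutation(rule_probs, n):
    """Compute the expected (marginal) permutation matrix for a sentence
    of n tokens under locally normalized BTG rule probabilities."""
    M = {(i, i + 1): np.eye(1) for i in range(n)}  # single words: identity
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            k = i + width
            acc = np.zeros((width, width))
            for j in range(i + 1, k):
                L, R = M[i, j], M[j, k]
                nl = j - i  # size of the left sub-segment
                straight = np.zeros((width, width))
                straight[:nl, :nl] = L          # direct sum: left stays first
                straight[nl:, nl:] = R
                inverted = np.zeros((width, width))
                inverted[:width - nl, nl:] = R  # skew sum: right goes first
                inverted[width - nl:, :nl] = L
                # Reuse expected matrices of sub-spans, weighted by rule probs.
                acc += (rule_probs[('S', i, j, k)] * straight
                        + rule_probs[('I', i, j, k)] * inverted)
            M[i, k] = acc
    return M[0, n]
```

The result is doubly stochastic: each row and column of the expected permutation matrix sums to one, which is what makes it usable as a relaxed (soft) reordering.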
3.3 Hard Reordering: Gumbel-Permutation by Differentiable Sampling
During inference, for efficiency, it is convenient to rely on the most probable derivation $d^{*}$ and its corresponding permutation matrix $P_{d^{*}}$:

$$d^{*} = \operatorname*{arg\,max}_{d}\; p_\phi(d \mid x)$$
The use of discrete permutations during inference but soft reorderings during training leads to a training-inference gap, which may be problematic. Inspired by recent work (Gumbel-Softmax, jang2016categorical; maddison2016concrete) that relaxes the sampling procedure of a categorical distribution using the Gumbel-Max trick, we use a differentiable procedure to sample a permutation matrix from $p_\phi(P \mid x)$. Recall that the WCFG is converted to a PCFG (Eq 6), and sampling from this PCFG can be done by sampling a production rule for each segment and then composing all the production rules together. Specifically, a production rule for a given segment is sampled using the Straight-Through Gumbel-Softmax (Step 8 of Algorithm 1).
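A minimal numpy sketch of the Straight-Through Gumbel-Softmax used to sample one rule per span (forward pass only; a real implementation would use an autograd framework so that gradients flow through the soft sample):

```python
import numpy as np

def st_gumbel_softmax(logits, tau=1.0, rng=None):
    """Relaxed sampling of one categorical choice, e.g. which (rule, split
    point) to apply to a span. Returns a hard one-hot sample plus the soft
    probabilities; in PyTorch-style pseudocode the straight-through output
    would be `hard + soft - soft.detach()`."""
    rng = rng or np.random.default_rng(0)
    # Gumbel(0, 1) noise via the inverse-CDF trick.
    gumbel = -np.log(-np.log(rng.uniform(1e-10, 1.0, size=logits.shape)))
    soft = np.exp((logits + gumbel) / tau)
    soft = soft / soft.sum()
    hard = np.zeros_like(soft)
    hard[np.argmax(soft)] = 1.0
    return hard, soft
```

As the temperature tau decreases, the soft probabilities approach the hard one-hot sample, shrinking the bias of the straight-through estimator at the cost of higher gradient variance.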
We propose two efficient algorithms for computing marginals and obtaining samples of separable permutations, with their distribution parameterized via BTG. In both algorithms, the PCFG plays the important role of decomposing a global problem into sub-problems, which is why we convert the WCFG into a PCFG in Eq 6. Relying on the proposed algorithms, we present two relaxations of the discrete permutations that let us induce latent reorderings with end-to-end training. We refer to the resulting system as ReMoto, short for a seq2seq model with Reordered-then-Monotone alignments. Soft-ReMoto and Hard-ReMoto denote the versions that use soft marginal permutations and hard Gumbel permutations, respectively.
Reordering in Previous Work
In traditional statistical machine translation (SMT), reorderings are typically handled by a distortion model (e.g., al2006distortion) in a pipeline manner. neubig-etal-2012-inducing and nakagawa-2015-efficient also use BTG for modeling reorderings, but they assume the alignments have been produced in a preprocessing step by an alignment tool (och-ney-2003-systematic); relying on these alignments, they induce reorderings. We do the inverse: we rely on latent reorderings to induce the underlying alignments.
Reordering modules have been previously used in neural models and fall into two categories. First, reordering components (huang2017towards; chen2019neural) were proposed for neural machine translation. However, they are not structured or sufficiently constrained, in the sense that they may produce invalid reorderings (e.g., a word may be moved to more than one new position). In contrast, our module is a principled way of dealing with latent reorderings. Second, generic permutations (i.e., one-to-one matchings), used previously in NLP (e.g., lyu-titov-2018-amr), do not suit our needs, as they are defined in terms of tokens rather than segments; they also require extra approximations to enable differentiable learning (mena2018learning). We compare with both classes of methods in our experiments.
4 Experiments

First, we consider two diagnostic tasks where we can test the neural reordering module on its own. Then we further assess our general seq2seq model ReMoto on two real-world NLP tasks.
4.1 Diagnostic Tasks
We design a task of converting an arithmetic expression in infix format into postfix format. An example is shown in Table 1. We create a synthetic dataset by sampling data from a PCFG. (Footnote: The dataset is provided in the supplementary material and will be released.) In order to generalize, a system needs to learn how to manipulate internal sub-structures (i.e., segments) while respecting well-formedness constraints. The task can be solved symbolically by the shunting-yard algorithm, but we are interested in whether neural networks can solve it and generalize OOD by learning from raw infix-postfix pairs. For the standard split (IID), we randomly sample 20k infix-postfix pairs whose nesting depth is between 1 and 6; 10k, 5k and 5k of these pairs are used as the train, dev and test sets, respectively. To test systematic generalization, we create a Length split (Len) where the training and dev examples remain the same as in the IID split, but test examples have a nesting depth of 7. In this way, we test whether a system can generalize to unseen, longer inputs.
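For reference, the symbolic solution that the models must discover implicitly is the shunting-yard algorithm; a minimal sketch for binary operators with the usual precedence (an assumption about the dataset's exact operator inventory):

```python
def infix_to_postfix(tokens):
    """Shunting-yard conversion of a tokenized infix expression
    (operands, binary operators + - * /, and parentheses) to postfix."""
    prec = {'+': 1, '-': 1, '*': 2, '/': 2}
    out, stack = [], []
    for t in tokens:
        if t == '(':
            stack.append(t)
        elif t == ')':
            while stack[-1] != '(':  # flush operators inside the parens
                out.append(stack.pop())
            stack.pop()              # discard the '('
        elif t in prec:
            # Pop operators of greater-or-equal precedence (left-assoc).
            while stack and stack[-1] != '(' and prec[stack[-1]] >= prec[t]:
                out.append(stack.pop())
            stack.append(t)
        else:
            out.append(t)            # operand goes straight to the output
    out.extend(reversed(stack))
    return out
```

For example, 'a + b * c' becomes 'a b c * +', while '( a + b ) * c' becomes 'a b + c *': the output is a segment-level reordering of the input, which is exactly the structure the reordering module is meant to capture.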
We use the SCAN dataset (lake2018generalization), which consists of simple English commands coupled with sequences of discrete actions. Here we use the semantic parsing version, SCAN-SP (herzig2020span), where the goal is to predict programs corresponding to the action sequences. An example is shown in Table 1. As our goal in these experiments is to test the reordering component alone, we remove parentheses and commas from the programs. For example, the program after (twice (jump), thrice(walk (around, left))) is converted to the sequence: after twice jump thrice walk around left. The resulting parentheses-free sequence can be viewed as a reordered version of the NL utterance 'jump twice after walk around left thrice'. (Footnote: The grammar of the programs is known, so we can reconstruct the original program from the intermediate parentheses-free sequence using the grammar.) Apart from the standard split (IID, aka the simple split of lake2018generalization), we create a Length split (Len) where the training set contains NL utterances with maximum length 5, while utterances in the dev and test sets have a minimum length of 6. (Footnote: Since we use the program form, the original length split (lake2018generalization), which is based on the length of the action sequence, is not very suitable for our experiments.)
Baselines and Results
In both diagnostic tasks, we use ReMoto with a trivial monotonic alignment matrix $M$ (an identity matrix) in Eq 3. Essentially, ReMoto becomes a sequence tagging model. We consider three baselines: (1) vanilla seq2seq models with Luong attention (luong2015effective); (2) an LSTM-based tagging model, which learns the reordering implicitly and can be viewed as a version of ReMoto where both the permutation and monotonic alignment matrices are trivial; and (3) Sinkhorn attention, which replaces the permutation matrix of Soft-ReMoto in Eq 4 with Gumbel-Sinkhorn networks (mena2018learning).
We report results averaged over three runs in Table 2. On both datasets, almost all methods achieve perfect accuracy on the IID splits. However, the baseline systems cannot generalize well to the challenging Len splits. In contrast, our methods, both Soft-ReMoto and Hard-ReMoto, perform very well on the Len splits, surpassing the best baseline system by large margins. The results indicate that ReMoto, particularly its neural reordering module, has the right inductive bias for learning reorderings. We also test a variant of Soft-ReMoto whose two components share input embeddings. This variant does not generalize well to the Len split of the arithmetic task, showing that it is beneficial to separate the models of 'syntax' (i.e., alignment) and 'semantics', confirming previous observations (havrylov-etal-2019-cooperative; russin2019compositional).
4.2 Semantic Parsing
Our second experiment is on semantic parsing, where ReMoto models the latent alignment between NL utterances and their corresponding programs. We use the GeoQuery dataset (zelle1996learning), which contains 880 utterance-program pairs. The programs are in variable-free form (kate2005learning); an example is shown in Table 1. Similarly to SCAN-SP, we transform the programs into a parentheses-free form, which has better structural correspondence with the utterances. (Footnote: Again, we can reconstruct the original programs based on the grammar.) An example of such a parentheses-free form is shown in Figure 2. Apart from the standard English version, we also experiment with the Chinese and German versions of GeoQuery (jones-etal-2012-semantic; susanto2017semantic). Since different languages exhibit divergent word orders (steedman2020formal), the results in the multilingual setting tell us whether our model can deal with this variability.
In addition to standard IID splits, we create a Len split where the training examples have parentheses-free programs with a maximum length 4; the dev and test examples have programs with a minimum length 5. We also experiment with the Temp split herzig2020span where training and test examples have programs with disjoint templates.
Baselines and Results
Apart from conventional seq2seq models, for comparison we also implemented syntactic attention (russin2019compositional). Our model ReMoto is similar in spirit to syntactic attention: 'syntax' (i.e., alignment) and 'semantics' (i.e., producing representations relying on the alignment) are modeled separately. However, in contrast to our structured mechanism for modeling alignments, syntactic attention still relies on the conventional attention mechanism. Results are shown in Table 3.
For the challenging Temp and Len splits, our best performing model Hard-ReMoto achieves consistently stronger performance than seq2seq and syntactic attention.
Thus, our model bridges the gap between conventional seq2seq models and specialized state-of-the-art grammar-based models (shaw2020compositional; herzig2020span). (Footnote: NQG (shaw2020compositional) achieves 35.0% on the English Len split, and SBSP (herzig2020span), without a lexicon, achieves 65.9% on the English Temp split in execution accuracy. Both models are augmented with pre-trained representations, i.e., BERT.)
4.3 Machine Translation
Our final experiment is on small-scale machine translation tasks, where ReMoto models the latent alignments between parallel sentences from two different languages. To probe systematic generalization, we also create a LEN split for each language pair in addition to the standard IID splits.
We use the small en-ja dataset extracted from the Tanaka Corpus (https://github.com/odashi/small_parallel_enja). The original split (IID) has 50k/500/500 examples for train/dev/test, with sentence lengths of 4-16 words. We create a Len split where the English sentences of training examples have a maximum length of 12, whereas the English sentences in dev/test have a minimum length of 13. The Len split has 50k/538/538 examples for train/dev/test, respectively.
We extract a subset from FBIS corpus (LDC2003E14) by filtering English sentences with length 4-30. We randomly shuffle the resulting data to obtain an IID split which has 141k/3k/3k examples for train/dev/test, respectively. In addition, we create a LEN split where English sentences of training examples have a maximum length 29 whereas the English sentences of dev/test examples have a length 30. The LEN split has 140k/4k/4k examples as train/dev/test sets respectively.
Baselines and Results
In addition to a conventional seq2seq model, we compare with the original SSNT model, which only accounts for monotonic alignments. We also implemented a variant that combines SSNT with the local reordering module of huang2017towards as a baseline, to show the advantage of our structured reordering module.
Results are shown in Table 4. Our model, especially Hard-ReMoto, consistently outperforms the baselines on both splits. In EN-JA translation, the advantage of our best-performing Hard-ReMoto is slightly more pronounced on the Len split than on the IID split. In ZH-EN translation, while SSNT and its variant do not outperform seq2seq on the Len split, ReMoto still achieves better results than seq2seq. These results show that our model is better than the alternatives at generalizing to longer sentences in machine translation.
Apart from promoting systematic generalization, latent alignments also lead to better interpretability, as discrete alignments reveal the internal process of generating the output. For example, in Table 5 we show a few examples from our model. Each output segment is associated with an underlying rationale, i.e., a segment of the reordered input.
5 Conclusion and Future Work
In this work, we propose a new general seq2seq model that accounts for latent segment-level alignments. Central to this model is a novel structured reordering module which is coupled with existing modules to handle non-monotonic segment alignments. We model reorderings as separable permutations and propose an efficient dynamic programming algorithm to perform marginal inference and sampling. It allows latent reorderings to be induced with end-to-end training. Empirical results on both synthetic and real-world datasets show that our model can achieve better systematic generalization than conventional seq2seq models.
The strong inductive bias introduced by modeling alignments in this work could be potentially beneficial in weakly-supervised settings, such as weakly-supervised semantic parsing, where conventional seq2seq models usually do not perform well.