1 Introduction
Recent advances in deep learning have led to major progress in many domains, with neural models sometimes achieving or even surpassing human performance (wang2019superglue). However, these methods often struggle in out-of-distribution (OOD) settings, where training and test examples are drawn from different distributions. In particular, unlike humans, conventional sequence-to-sequence (seq2seq) models, widely used in natural language processing (NLP), fail to generalize systematically (bahdanau2018systematic; lake2018generalization; loula2018rearranging), i.e., to correctly interpret sentences representing novel combinations of concepts seen in training. Our goal is to provide a mechanism for encouraging systematic generalization in seq2seq models.

To get an intuition about our method, consider the semantic parsing task shown in Figure 1. A learner needs to map a natural language (NL) utterance to a program which can then be executed against a knowledge base. To process the test utterance, the learner needs to first decompose it into two segments previously observed in training (shown in green and blue), and then combine their corresponding program fragments to create a new program. Current seq2seq models fail in this systematic generalization setting (finegandollaketal2018improving; keysers2019measuring). In contrast, traditional grammar formalisms decompose correspondences between utterances and programs into compositional mappings of substructures (steedman2000syntactic), enabling grammar-based parsers to recombine rules acquired during training, as needed for systematic generalization. Grammars proved essential in statistical semantic parsing in the pre-neural era (zettlemoyer2012learning; wongmooney2006learning), and have gained renewed interest as a means of achieving systematic generalization (herzig2020span; shaw2020compositional). However, grammars are hard to create and maintain (e.g., requiring grammar engineering or grammar induction stages) and do not scale well to NLP problems beyond semantic parsing (e.g., machine translation). In this work, we argue that the key property of grammar-based models, giving rise to their improved OOD performance, is that a grammar implicitly encodes alignments between input and output segments. For example, in Figure 1, the expected segment-level alignments are ‘the length ↔ len’ and ‘the longest river ↔ longest(river(all))’.
Instead of developing a full-fledged grammar-based method, we directly model segment-level alignments as structured latent variables. The resulting alignment-driven seq2seq model remains end-to-end differentiable and, in principle, applicable to any sequence transduction problem.
Modeling segment-level alignments requires simultaneously inducing a segmentation of the input and output sequences and discovering correspondences between the input and output segments. While segment-level alignments have previously been incorporated in neural models (yu2016online; wang2017sequence), to maintain tractability, these approaches support only monotonic alignments. The monotonicity assumption is reasonable for certain tasks (e.g., summarization), but it is generally overly restrictive (e.g., consider semantic parsing and machine translation). To relax this assumption, we complement monotonic alignments with an extra reordering step. That is, we first permute the source sequence so that segments within the reordered sequence can be aligned monotonically to segments of the target sequence. Coupling latent permutations with monotonic alignments dramatically increases the space of admissible segment alignments.
The space of general permutations is exceedingly large, so, to allow for efficient training, we restrict ourselves to separable permutations (bose1998pattern). We model separable permutations as hierarchical reorderings of segments using permutation trees. This hierarchical way of modeling permutations reflects the hierarchical nature of language and hence is arguably more appropriate than ‘flat’ alternatives (mena2018learning). Interestingly, a recent study (steedman2020formal) demonstrated that separable permutations are sufficient for capturing the variability of permutations in linguistic constructions across natural languages, providing further motivation for our modeling choice.
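Separability can be tested directly from the recursive definition: a permutation admits a permutation tree iff, at some split point, the left block of positions covers either the lowest or the highest values of the span, and both halves are themselves separable. A minimal brute-force sketch (purely illustrative, not part of the model):

```python
from itertools import permutations

def is_separable(p):
    """Check whether permutation p (0-indexed) admits a permutation tree."""
    n = len(p)
    if n <= 1:
        return True
    for k in range(1, n):
        left, right = p[:k], p[k:]
        if set(left) == set(range(k)):            # Straight: left keeps low values
            if is_separable(left) and is_separable([v - k for v in right]):
                return True
        elif set(left) == set(range(n - k, n)):   # Inverted: left takes high values
            if is_separable([v - (n - k) for v in left]) and is_separable(right):
                return True
    return False

# The two forbidden patterns (2413 and 3142, here 0-indexed) are not separable,
# and the number of length-4 separable permutations is 22 (a Schroeder number).
assert not is_separable([1, 3, 0, 2])
assert not is_separable([2, 0, 3, 1])
assert sum(is_separable(list(p)) for p in permutations(range(4))) == 22
```

The counts 1, 2, 6, 22, 90, … for growing lengths are the large Schröder numbers, which gives a convenient correctness check for any implementation.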
Simply marginalizing over all possible separable permutations remains intractable. Instead, inspired by recent work on modeling latent discrete structures (corrotitov2019learning; fu2020latent), we introduce a continuous relaxation of the reordering problem. The key ingredient of the relaxation is marginal inference. In this work, we propose an efficient dynamic programming algorithm to perform exact marginal inference with separable permutations, resulting in an effective differentiable neural module producing relaxed separable permutations. By plugging this module into an existing module supporting monotonic segment alignments (yu2016online), we obtain an end-to-end differentiable seq2seq model supporting non-monotonic segment-level alignments. In summary, our contributions are:

- A general seq2seq model for NLP tasks that accounts for latent non-monotonic segment-level alignments.
- A novel and efficient algorithm for exact marginal inference with separable permutations, allowing for end-to-end training using a continuous relaxation. (Our code and data will be available upon acceptance.)
- Experiments on synthetic problems and NLP tasks (semantic parsing and machine translation) showing that modeling segment alignments is beneficial for systematic generalization.
2 Background and Related Work
2.1 Systematic Generalization
Human learners exhibit systematic generalization, which refers to their ability to generalize from training data to novel situations. This is possible due to the compositionality of natural languages: to a large degree, sentences are built from an inventory of primitive concepts using finite structure-building mechanisms (chomsky_aspects_1965). For example, if one understands ‘John loves the girl’, they should also understand ‘The girl loves John’ (fodor1988connectionism). This is done by ‘knowing’ the meaning of individual words and the grammatical principle of subject-verb-object composition. As pointed out by goodwin2020probing, systematicity entails that primitive units have consistent meanings across different contexts. In contrast, in seq2seq models, the representations of a word are highly influenced by context (see experiments in lake2018generalization). This is also consistent with the observation that seq2seq models tend to memorize large chunks rather than discover underlying compositional principles (hupkes_compositionality_2019). The memorization of large sequences lets the model fit the training distribution but harms out-of-distribution generalization.
2.2 Discrete Alignments as Conditional Computation Graphs
Latent discrete structures enable the incorporation of inductive biases into neural models and have been beneficial for a range of problems. For example, input-dependent module layouts (andreas2016neural) or graphs (norcliffe2018learning) have been explored in visual question answering. There is also a large body of work on inducing task-specific discrete representations (usually trees) for NL sentences (yogatama2016learning; niculae2018sparsemap; havrylovetal2019cooperative; corrotitov2019learning). The trees are induced simultaneously with learning a model that performs a computation relying on the tree (typically a recursive neural network; socher2011semi), while optimizing a task-specific loss. Given the role the structures play in these approaches, i.e., defining the computation flow, we can think of the structures as conditional computation graphs.

In this work, we induce discrete alignments as conditional computation graphs to guide seq2seq models. Given a source sequence x with n tokens and a target sequence y with m tokens, we optimize the following objective:
(1)  max_{θ,φ}  log E_{A ∼ p(A|X; φ)} [ p(y | X, A; θ) ]
where X = Encode(x) ∈ ℝ^{n×d} embeds the source sequence, with d being the hidden size, and A is the alignment matrix between input and output tokens. In this framework, alignments are separately predicted by p(A | X; φ) to guide the computation p(y | X, A; θ) that maps X to y. The parameters of the two model components (φ and θ) are disjoint.
Standard encoder-decoder models (bahdanau2014neural) rely on continuous attention weights rather than discrete alignments, i.e., each target token attends to all source tokens with weights summing to one. Discrete alignments have been considered in previous work (xu2015show; deng2018latent); in this case A is a sequence of categorical random variables (‘hard attention’). Though discrete, hard attention only considers word-level alignments, i.e., it assumes that each target token is aligned with a single source token. This is a limiting assumption; for example, in traditional statistical machine translation, word-based models (e.g., brown1993mathematics) are known to achieve dramatically weaker results than phrase-based models (e.g., koehnetal2007moses). In this work, we aim to bring the power of phrase-level (aka segment-level) alignments to neural seq2seq models.

3 Latent Segment Alignments via Separable Permutations
Our method integrates a layer of segment-level alignments with a seq2seq model. The architecture of our model is shown in Figure 2. Central to this model is the alignment network, which decomposes the alignment problem into two stages: (i) input reordering and (ii) monotonic alignment between the reordered sequence and the output. Formally, we decompose the alignment matrix A from Eq 1 into two parts:
(2)  A = M P
where P ∈ {0,1}^{n×n} is a permutation matrix and M ∈ {0,1}^{m×n} represents monotonic alignments. With this decomposition, we can rewrite the objective in Eq 1 as follows:
(3)  max_{θ,φ}  log E_{P ∼ p(P|X; φ)} [ E_{M} [ p(y | P X, M; θ) ] ]
where X′ = P X denotes the reordered representation. With a slight abuse of notation, φ now denotes the parameters of the model generating permutations, and θ denotes the parameters used to produce monotonic alignments. Given the permutation matrix P, the second expectation E_M [ p(y | X′, M; θ) ], which we denote as p_mono(y | X′; θ), can be handled by existing methods, such as SSNT (yu2016online) and SWAN (wang2017sequence). In the rest of the paper, we choose SSNT as the module for handling monotonic alignments. (In our initial experiments, we found that SWAN works as well as SSNT but is considerably slower.) We can rewrite the objective we optimize in the following compact form:
(4)  max_{θ,φ}  log E_{P ∼ p(P|X; φ)} [ p_mono(y | P X; θ) ]
Algorithmically, both in p_mono and in the reordering distribution p(P | X; φ), segments are the basic elements being manipulated: (i) SSNT considers all possible monotonic segment-to-segment alignments (yu2016online); (ii) our reordering module, presented in the following sections, creates the final permutation matrix by hierarchically reordering input segments. Modeling segments provides a strong inductive bias, reflecting the intuition that sequence transduction in NLP can be largely accomplished by manipulations at the level of segments. In contrast, conventional seq2seq methods have no explicit notion of segments.
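The two-stage decomposition can be illustrated with a small NumPy sketch: a permutation matrix P first reorders the source encodings, and a monotonic alignment matrix M then maps contiguous segments of the reordered sequence to target positions, so their product A = M P acts as a non-monotonic segment alignment. All shapes and values below are toy assumptions:

```python
import numpy as np

# Toy shapes: n = 4 source tokens with d = 3 dims, m = 3 target tokens.
n, d = 4, 3
X = np.arange(n * d, dtype=float).reshape(n, d)

# P: a separable permutation putting the second half of the source first.
order = [2, 3, 0, 1]          # reordered position i takes source token order[i]
P = np.eye(n)[order]

# M: a monotonic segment alignment over the *reordered* source:
# target 0 <- segment [0], target 1 <- segment [1, 2], target 2 <- segment [3].
M = np.array([[1, 0, 0, 0],
              [0, 1, 1, 0],
              [0, 0, 0, 1]], dtype=float)

A = M @ P                               # non-monotonic segment alignment
assert (M @ (P @ X) == A @ X).all()     # aligning the reordered X equals using A on X
assert A[0].argmax() == 2               # target 0 is aligned to original source token 2
```

Although M is monotonic by construction, composing it with P lets target tokens attend to source segments in an arbitrary separable order.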
3.1 Structured Latent Reordering by Permutation Trees
Inspired by steedman2020formal, we restrict word reorderings to separable permutations. Formally, separable permutations are defined in terms of permutation trees (aka separating trees; bose1998pattern): if a permutation can be represented by a permutation tree, it is separable. A permutation tree over a permutation of a sequence is a binary tree in which each node represents the ordering of a segment (i, k); the children exhaustively split their parent into sub-segments (i, j) and (j, k). Each node has a binary label that decides whether the segment of the left child precedes that of the right child. Bracketing transduction grammar (BTG; wu1997stochastic), proposed in the context of machine translation, is the corresponding context-free grammar; it has one nonterminal (S) and three anchored rules:

- Straight: S_{i,k} → [S_{i,j} S_{j,k}] (the left segment precedes the right one)
- Inverted: S_{i,k} → ⟨S_{i,j} S_{j,k}⟩ (the right segment precedes the left one)
- Terminal: S_{i,i+1} → w_i
where S_{i,k} is the anchored nonterminal covering the segment from position i to position k (excluding k). The first two rules decide whether to keep or invert two segments when constructing a larger segment; the last rule states that every word in an utterance is associated with a nonterminal S. An example is shown in Figure 3. This hierarchical approach to generating separable permutations reflects the compositional nature of language and thus appears more appealing than ‘flat’ alternatives (mena2018learning). Moreover, with BTGs, we can incorporate segment-level features to model separable permutations and design tractable algorithms.
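The rules can be exercised directly: composing segments bottom-up with Straight and Inverted decisions generates exactly the separable permutations, whose counts follow the large Schröder numbers. A small illustrative sketch (the set-based deduplication hides the fact that one permutation may have several derivations, which is precisely what the marginalization discussed later must account for):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def separable_perms(n):
    """All separable permutations of length n, built bottom-up with BTG-style rules."""
    if n == 1:
        return frozenset({(0,)})
    out = set()
    for k in range(1, n):                       # split S_{0,n} -> S_{0,k} S_{k,n}
        for L in separable_perms(k):
            for R in separable_perms(n - k):
                R_shift = tuple(r + k for r in R)
                out.add(L + R_shift)            # Straight: left segment first
                out.add(R_shift + L)            # Inverted: right segment first
    return frozenset(out)

# Sizes follow the large Schroeder numbers: 1, 2, 6, 22, 90, ...
assert [len(separable_perms(n)) for n in range(1, 6)] == [1, 2, 6, 22, 90]
```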
By assigning a score to each anchored rule using segment-level features, we obtain a distribution over all possible derivations, and use it to compute the objective in Eq 4:
(5)  max_{θ,φ}  log Σ_z p(z | X; φ) p_mono(y | P_z X; θ),   p(z | X; φ) = (1/Z) ∏_{r ∈ z} s(r)
where s(r) is a score function assigning a non-negative weight to an anchored rule r, Z is the partition function, which can be computed using the inside algorithm, and P_z is the permutation matrix corresponding to the derivation z. BTG, along with the weight assigned to each rule, is a weighted context-free grammar (WCFG). In this WCFG, the weights are normalized only at the derivation level. As we will see in Algorithm 1, we are interested in normalizing the weights of production rules, and we therefore convert the WCFG into an equivalent PCFG following smith2007weighted, so that the probability of a derivation can be computed as follows:
(6)  p(z | X; φ) = ∏_{r ∈ z} G(r)
where G(r) is the probability of the production rule r under the transformed PCFG. The details of the conversion are provided in the Appendix.
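The conversion can be sketched as follows: run the inside algorithm on the rule scores, then renormalize each anchored rule by the inside weights of its children and parent, which yields locally normalized rule probabilities. This is an illustrative sketch with random scores and terminal weights fixed to 1, not the paper's exact implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5                                    # sentence length
# Non-negative scores s(r) for Straight/Inverted anchored rules, indexed [i, j, k].
sS = rng.uniform(0.5, 2.0, (n + 1, n + 1, n + 1))
sI = rng.uniform(0.5, 2.0, (n + 1, n + 1, n + 1))

# Inside algorithm: beta[i, k] sums the weights of all derivations of span (i, k).
beta = np.zeros((n + 1, n + 1))
for i in range(n):
    beta[i, i + 1] = 1.0                 # terminal rules get weight 1 here
for width in range(2, n + 1):
    for i in range(n - width + 1):
        k = i + width
        for j in range(i + 1, k):
            beta[i, k] += (sS[i, j, k] + sI[i, j, k]) * beta[i, j] * beta[j, k]

def rule_prob(scores, i, j, k):
    """PCFG probability of an anchored rule via inside-weight renormalization."""
    return scores[i, j, k] * beta[i, j] * beta[j, k] / beta[i, k]

# Locally normalized: for every span, the rule probabilities sum to one.
for width in range(2, n + 1):
    for i in range(n - width + 1):
        k = i + width
        total = sum(rule_prob(s, i, j, k) for j in range(i + 1, k) for s in (sS, sI))
        assert np.isclose(total, 1.0)
```

Multiplying the renormalized rule probabilities along a derivation telescopes to the original WCFG weight divided by the partition function beta[0, n], so the two grammars define the same distribution over derivations.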
The challenge with optimizing the objective in Eq 5 is that the search space of possible derivations is exponential, making the estimation of the gradients with respect to the parameters φ of the reordering component non-trivial. We now present the two differentiable surrogates we use.

3.2 Soft Reordering: Computing Marginal Permutations
The first strategy is to use the deterministic expectation of permutations to softly reorder a sentence, analogous to the way standard attention approximates categorical random variables. Specifically, we use the following approximation:
log E_{z} [ p_mono(y | P_z X; θ) ] ≈ log p_mono(y | P̄ X; θ),   P̄ = E_{z ∼ p(z|X; φ)} [ P_z ]
where P̄ is the marginal permutation matrix, which can be treated as structured attention (kim2017structured). Methods for performing marginal inference of anchored rules, i.e., computing the marginal distribution of production rules, are well known in NLP (jurafsky2000speech). However, we are interested in the marginal permutation matrix (or, equivalently, the expectation of the matrix components), as this matrix is the data structure ultimately used in our model. As a key contribution of this work, we propose an efficient algorithm to exactly compute the marginal permutation matrix using dynamic programming.
In order to compute the marginal permutation matrix, we need to marginalize over the exponentially many derivations of permutation trees. We propose to map a derivation of BTG into its corresponding permutation matrix in a recursive manner. Specifically, we first associate each word with the 1×1 identity permutation matrix; then we associate Straight and Inverted rules with the direct and skew sums, respectively, of their children’s permutation matrices: the direct sum P₁ ⊕ P₂ places the two matrices as diagonal blocks (left segment first), while the skew sum P₁ ⊖ P₂ places them as anti-diagonal blocks (right segment first). For example, the permutation matrix of the derivation tree shown in Figure 3 can be obtained by:
(7) 
Intuitively, the permutation matrix of long segments can be constructed by composing permutation matrices of short segments. Motivated by this, we propose a dynamic programming algorithm, which takes advantage of the observation that we can reuse the permutation matrices of short segments when computing permutation matrices of long segments, as shown in Algorithm 1. While the above equation is defined over discrete permutation matrices encoding a single derivation, the algorithm applies recursive rules to expected permutation matrices. Central to the algorithm is the following recursion:
(8)  P̄_{i,k} = Σ_{i<j<k} [ G(S_{i,k} → [S_{i,j} S_{j,k}]) (P̄_{i,j} ⊕ P̄_{j,k}) + G(S_{i,k} → ⟨S_{i,j} S_{j,k}⟩) (P̄_{i,j} ⊖ P̄_{j,k}) ]
where P̄_{i,k} is the expected permutation matrix for the segment from i to k, ⊕ and ⊖ denote the direct and skew sums, and G(·) is the probability of the corresponding production rule, defined in Eq 6. Overall, Algorithm 1 is a bottom-up method that constructs expected permutation matrices incrementally in Steps 13 and 14, relying on the probabilities of the associated production rules. We prove the correctness of this algorithm by induction in the Appendix.
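The recursion can be checked on a toy example: with direct and skew sums defined over blocks (rows indexing reordered positions, columns indexing original positions), the bottom-up dynamic program over expected matrices should match brute-force enumeration of all derivations. A sketch with an arbitrary random PCFG over spans (all sizes and parameterizations here are illustrative):

```python
import numpy as np

def direct_sum(A, B):                    # Straight: left block keeps its positions
    a, b = A.shape[0], B.shape[0]
    return np.block([[A, np.zeros((a, b))], [np.zeros((b, a)), B]])

def skew_sum(A, B):                      # Inverted: right segment moves to the front
    a, b = A.shape[0], B.shape[0]
    return np.block([[np.zeros((b, a)), B], [A, np.zeros((a, b))]])

swap = skew_sum(np.eye(1), np.eye(1))    # 2-word inversion
assert (swap @ np.array([10.0, 20.0]) == np.array([20.0, 10.0])).all()

# Random PCFG over spans of a length-4 sentence: prob[(i, k)][(j, kind)] is the
# probability of splitting span (i, k) at j with a Straight ('S') or Inverted
# ('I') rule; per-span probabilities sum to one.
rng = np.random.default_rng(1)
n = 4
prob = {}
for width in range(2, n + 1):
    for i in range(n - width + 1):
        k = i + width
        keys = [(j, kind) for j in range(i + 1, k) for kind in "SI"]
        w = rng.random(len(keys)); w /= w.sum()
        prob[(i, k)] = dict(zip(keys, w))

# Dynamic program: expected permutation matrices, bottom-up.
E = {(i, i + 1): np.eye(1) for i in range(n)}
for width in range(2, n + 1):
    for i in range(n - width + 1):
        k = i + width
        M = np.zeros((width, width))
        for (j, kind), p in prob[(i, k)].items():
            op = direct_sum if kind == "S" else skew_sum
            M += p * op(E[(i, j)], E[(j, k)])
        E[(i, k)] = M

# Brute force: enumerate every derivation with its probability.
def derivations(i, k):
    if k - i == 1:
        yield np.eye(1), 1.0
        return
    for (j, kind), p in prob[(i, k)].items():
        for A, pa in derivations(i, j):
            for B, pb in derivations(j, k):
                op = direct_sum if kind == "S" else skew_sum
                yield op(A, B), p * pa * pb

marginal = np.zeros((n, n))
for P, p in derivations(0, n):
    marginal += p * P
assert np.allclose(marginal, E[(0, n)])          # DP matches exact marginalization
assert np.allclose(E[(0, n)].sum(axis=0), 1.0)   # doubly stochastic
assert np.allclose(E[(0, n)].sum(axis=1), 1.0)
```

The key fact the DP exploits is that, in a PCFG, the two sub-derivations of a span are independent given the chosen rule, so the expectation distributes over the (bilinear) direct and skew sums.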
3.3 Hard Reordering: Gumbel-Permutation by Differentiable Sampling
During inference, for efficiency, it is convenient to rely on the most probable derivation z* and its corresponding permutation matrix P_{z*}:
(9)  z* = argmax_z p(z | X; φ)
The use of discrete permutations during inference and soft reorderings during training leads to a training-inference gap, which may be problematic. Inspired by recent work (Gumbel-Softmax; jang2016categorical; maddison2016concrete) that relaxes the sampling procedure of a categorical distribution using the Gumbel-Max trick, we use a differentiable procedure to sample a permutation matrix P_z from p(z | X; φ). Recall that the WCFG is converted to a PCFG, and sampling from this PCFG can be done by sampling a production rule for each segment and then composing all the production rules. Specifically, a production rule for a given segment is sampled using the Straight-Through Gumbel-Softmax, used in Step 8 of Algorithm 1.
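The sampling step rests on the Gumbel-Max trick: perturbing rule logits with Gumbel noise and taking an argmax draws exact samples from the corresponding softmax; in an autograd framework, the hard one-hot is combined with the tempered softmax of the perturbed logits to obtain the straight-through gradient. A forward-only NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_argmax_onehot(logits):
    """Sample a one-hot rule choice via the Gumbel-Max trick.

    In an autograd framework, the backward pass would use the tempered
    softmax of (logits + g) (Straight-Through Gumbel-Softmax); here we
    only illustrate the forward sampling.
    """
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    hard = np.zeros_like(logits)
    hard[np.argmax(logits + g)] = 1.0
    return hard

logits = np.array([1.0, 0.0, -1.0])
samples = np.stack([gumbel_argmax_onehot(logits) for _ in range(20000)])
freq = samples.mean(axis=0)
target = np.exp(logits) / np.exp(logits).sum()
assert np.allclose(freq, target, atol=0.02)   # Gumbel-Max reproduces the softmax
```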
Summary
We propose two efficient algorithms for computing marginals and obtaining samples of separable permutations, with their distribution parameterized via BTG. In both algorithms, the PCFG plays an important role in decomposing a global problem into sub-problems, which explains why we convert the WCFG into a PCFG in Eq 6. Relying on the proposed algorithms, we present two relaxations of the discrete permutations that let us induce latent reorderings with end-to-end training. We refer to the resulting system as ReMoto, short for a seq2seq model with Reordered-then-Monotone alignments. Soft-ReMoto and Hard-ReMoto denote the versions that use soft marginal permutations and hard Gumbel permutations, respectively.
Reordering in Previous Work
In traditional statistical machine translation (SMT), reorderings are typically handled by a distortion model (e.g., al2006distortion) in a pipeline manner. neubigetal2012inducing and nakagawa2015efficient also use BTG for modeling reorderings, but they assume the alignments have been produced in a preprocessing step using an alignment tool (ochney2003systematic). Relying on these alignments, they induce reorderings. Inversely, we rely on latent reorderings to induce the underlying alignments.
Reordering modules have previously been used in neural models and can be assigned to two categories. First, reordering components (huang2017towards; chen2019neural) were proposed for neural machine translation. However, they are not structured or sufficiently constrained, in the sense that they may produce invalid reorderings (e.g., a word is likely to be moved to more than one new position). In contrast, our module is a principled way of dealing with latent reorderings. Second, generic permutations (i.e., one-to-one matchings), used previously in NLP (e.g., lyutitov2018amr), do not suit our needs as they are defined in terms of tokens rather than segments. They also require extra approximations to enable differentiable learning (mena2018learning). We compare with both classes of methods in our experiments.

4 Experiments
First, we consider two diagnostic tasks where we can test the neural reordering module on its own. Then we further assess our general seq2seq model ReMoto on two real-world NLP tasks.
4.1 Diagnostic Tasks
Table 1: Examples from the three datasets.

Dataset     | Input                                    | Output
Arithmetic  | …                                        | …
SCAN-SP     | jump twice after walk around left thrice | after(twice(jump), thrice(walk(around, left)))
GeoQuery    | how many states do not have rivers ?     | count(exclude(state(all), loc_1(river(all))))
Table 2: Accuracy on the diagnostic tasks.

                           | Arithmetic    | SCAN-SP
Model                      | IID   | Len   | IID   | Len
Seq2Seq                    | 100.0 | 0.0   | 100.0 | 13.9
LSTM-based Tagging         | 100.0 | 20.6  | 100.0 | 57.7
Sinkhorn-Attention Tagging | 99.5  | 8.8   | 100.0 | 48.2
Soft-ReMoto                | 100.0 | 86.9  | 100.0 | 100.0
  with shared parameters   | 100.0 | 40.9  | 100.0 | 100.0
Hard-ReMoto                | 100.0 | 83.3  | 100.0 | 100.0
Arithmetic
We design a task of converting an arithmetic expression in infix format to one in postfix format. An example is shown in Table 1. We create a synthetic dataset by sampling data from a PCFG. (The dataset is provided in the supplementary material and will be released.) In order to generalize, a system needs to learn how to manipulate internal substructures (i.e., segments) while respecting well-formedness constraints. This task can be solved by the shunting-yard algorithm, but we are interested in whether neural networks can solve it and generalize OOD by learning from raw infix-postfix pairs. For the standard split (IID), we randomly sample 20k infix-postfix pairs whose nesting depth is between 1 and 6; 10k, 5k and 5k of these pairs are used as the train, dev and test sets, respectively. To test systematic generalization, we create a Length split (Len) where the training and dev examples remain the same as in the IID split, but test examples have a nesting depth of 7. In this way, we test whether a system can generalize to unseen longer input.
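For reference, the infix-to-postfix mapping that the model must learn from raw pairs can be computed symbolically with the shunting-yard algorithm; a minimal sketch (the operator inventory and whitespace tokenization are assumptions, since the actual data is sampled from a PCFG):

```python
def to_postfix(expr):
    """Convert a whitespace-tokenized infix expression to postfix (shunting-yard)."""
    prec = {'+': 1, '-': 1, '*': 2, '/': 2}
    out, stack = [], []
    for tok in expr.split():
        if tok.isdigit():
            out.append(tok)
        elif tok == '(':
            stack.append(tok)
        elif tok == ')':
            while stack[-1] != '(':
                out.append(stack.pop())
            stack.pop()                           # discard '('
        else:                                     # binary operator
            while stack and stack[-1] != '(' and prec[stack[-1]] >= prec[tok]:
                out.append(stack.pop())
            stack.append(tok)
    while stack:
        out.append(stack.pop())
    return ' '.join(out)

assert to_postfix("3 + 4 * 5") == "3 4 5 * +"
assert to_postfix("( 3 + 4 ) * 5") == "3 4 + 5 *"
```

Note that the mapping keeps operands in order while moving operators, so the task is naturally expressible as a segment reordering.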
SCAN-SP
We use the SCAN dataset (lake2018generalization), which consists of simple English commands coupled with sequences of discrete actions. Here we use the semantic parsing version, SCAN-SP (herzig2020span), where the goal is to predict programs corresponding to the action sequences. An example is shown in Table 1. As in these experiments our goal is to test the reordering component alone, we remove parentheses and commas in programs. For example, the program after (twice (jump), thrice(walk (around, left))) is converted to the sequence: after twice jump thrice walk around left. In this way, the resulting parentheses-free sequence can be viewed as a reordered sequence of the NL utterance ‘jump twice after walk around left thrice’. (The grammar of the programs is known, so we can reconstruct the original program from the intermediate parentheses-free sequence using the grammar.) Apart from the standard split (IID, aka the simple split; lake2018generalization), we create a Length split (Len) where the training set contains NL utterances with a maximum length of 5, while utterances in the dev and test sets have a minimum length of 6. (Since we use the program form, the original length split (lake2018generalization), which is based on the length of the action sequence, is not very suitable for our experiments.)
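The conversion between programs and parentheses-free sequences can be sketched as follows; the arity table here is hypothetical and stands in for the known grammar that makes the mapping invertible:

```python
import re

def linearize(program):
    """Drop parentheses and commas, keeping only the symbol sequence."""
    return ' '.join(re.findall(r'\w+', program))

def rebuild(tokens, arity):
    """Invert linearize, given each symbol's arity (prefix traversal)."""
    sym = tokens.pop(0)
    if arity.get(sym, 0) == 0:
        return sym
    args = [rebuild(tokens, arity) for _ in range(arity[sym])]
    return f"{sym}({', '.join(args)})"

program = "after(twice(jump), thrice(walk(around, left)))"
# Hypothetical arities for this example; the real mapping comes from the grammar.
arity = {'after': 2, 'twice': 1, 'thrice': 1, 'walk': 2}

flat = linearize(program)
assert flat == "after twice jump thrice walk around left"
assert rebuild(flat.split(), arity) == program
```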
Baselines and Results
In both diagnostic tasks, we use ReMoto with a trivial monotonic alignment matrix M (an identity matrix) in Eq 3. Essentially, ReMoto becomes a sequence tagging model. We consider three baselines: (1) a vanilla seq2seq model with Luong attention (luong2015effective); (2) an LSTM-based tagging model, which learns the reordering implicitly and can be viewed as a version of ReMoto where both P and M are trivial; (3) Sinkhorn Attention, which replaces the permutation matrix of Soft-ReMoto in Eq 4 with Gumbel-Sinkhorn networks (mena2018learning).

We report results averaged over three runs in Table 2. On both datasets, almost all methods achieve perfect accuracy on the IID splits. However, the baseline systems cannot generalize well to the challenging Len splits. In contrast, our methods, both Soft-ReMoto and Hard-ReMoto, perform very well on the Len splits, surpassing the best baseline system by large margins. The results indicate that ReMoto, particularly its neural reordering module, has the right inductive bias for learning reorderings. We also test a variant of Soft-ReMoto whose two components share input embeddings. This variant does not generalize well to the Len split of the arithmetic task, showing that it is beneficial to separate the models of ‘syntax’ (i.e., alignment) and ‘semantics’, confirming what has been observed previously (havrylovetal2019cooperative; russin2019compositional).
4.2 Semantic Parsing
Our second experiment is on semantic parsing, where ReMoto models the latent alignment between NL utterances and their corresponding programs. We use the GeoQuery dataset (zelle1996learning), which contains 880 utterance-program pairs. The programs are in variable-free form (kate2005learning); an example is shown in Table 1. Similarly to SCAN-SP, we transform the programs into a parentheses-free form which has better structural correspondence with the utterances. (Again, we can reconstruct the original programs based on the grammar.) An example of such a parentheses-free form is shown in Figure 2. Apart from the standard English version, we also experiment with the Chinese and German versions of GeoQuery (jonesetal2012semantic; susanto2017semantic). Since different languages exhibit divergent word orders (steedman2020formal), the results in the multilingual setting tell us whether our model can deal with this variability.
In addition to the standard IID splits, we create a Len split where the training examples have parentheses-free programs with a maximum length of 4, while the dev and test examples have programs with a minimum length of 5. We also experiment with the Temp split (herzig2020span), where training and test examples have programs with disjoint templates.
Table 3: Accuracy on GeoQuery for English (EN), Chinese (ZH) and German (DE).

                                              | EN                 | ZH                 | DE
Model                                         | IID  | Temp | Len  | IID  | Temp | Len  | IID  | Temp | Len
Seq2Seq                                       | 75.7 | 38.8 | 21.8 | 72.5 | 25.4 | 19.8 | 56.1 | 18.8 | 15.2
Syntactic Attention (russin2019compositional) | 74.3 | 39.1 | 18.3 | 70.2 | 27.9 | 18.7 | 54.3 | 19.3 | 14.2
Soft-ReMoto                                   | 74.5 | 39.3 | 19.8 | 73.4 | 30.3 | 17.3 | 55.8 | 19.5 | 13.4
Hard-ReMoto                                   | 75.2 | 43.2 | 23.2 | 74.3 | 45.7 | 22.3 | 55.6 | 22.3 | 16.6
Baselines and Results
Apart from conventional seq2seq models, for comparison we also implemented syntactic attention (russin2019compositional). ReMoto is similar in spirit to syntactic attention: ‘syntax’ (i.e., alignment) and ‘semantics’ (i.e., producing the representation relying on the alignment) are modeled separately. In contrast to our structured mechanism for modeling alignments, syntactic attention still relies on the conventional attention mechanism. Results are shown in Table 3.

For the challenging Temp and Len splits, our best-performing model, Hard-ReMoto, achieves consistently stronger performance than seq2seq and syntactic attention. Thus, our model narrows the gap between conventional seq2seq models and specialized state-of-the-art grammar-based models (shaw2020compositional; herzig2020span). (NQG (shaw2020compositional) achieves 35.0% on the English Len split, and SBSP (herzig2020span) without lexicon achieves 65.9% on the English Temp split in execution accuracy; both models are augmented with pretrained representations, i.e., BERT.)

4.3 Machine Translation
Our final experiment is on small-scale machine translation tasks, where ReMoto models the latent alignments between parallel sentences in two different languages. To probe systematic generalization, we also create a Len split for each language pair in addition to the standard IID splits.
English-Japanese
We use the small en-ja dataset extracted from the Tanaka Corpus. The original split (IID) has 50k/500/500 examples for train/dev/test, with sentence lengths of 4-16 words (https://github.com/odashi/small_parallel_enja). We create a Len split where the English sentences of training examples have a maximum length of 12, whereas the English sentences in dev/test have a minimum length of 13. The Len split has 50k/538/538 examples for train/dev/test, respectively.
Chinese-English
We extract a subset of the FBIS corpus (LDC2003E14) by keeping English sentences of length 4-30. We randomly shuffle the resulting data to obtain an IID split with 141k/3k/3k examples for train/dev/test, respectively. In addition, we create a Len split where the English sentences of training examples have a maximum length of 29, whereas the English sentences of dev/test examples have a minimum length of 30. The Len split has 140k/4k/4k examples for the train/dev/test sets, respectively.
Table 4: Results on the machine translation tasks.

                                    | EN-JA       | ZH-EN
Model                               | IID  | Len  | IID  | Len
Seq2Seq                             | 35.6 | 25.3 | 21.4 | 18.1
SSNT (yu2016online)                 | 36.3 | 26.5 | 20.5 | 17.3
Local Reordering (huang2017towards) | 36.0 | 27.1 | 21.8 | 17.8
Soft-ReMoto                         | 36.6 | 27.5 | 22.3 | 19.2
Hard-ReMoto                         | 37.4 | 28.7 | 22.6 | 19.5
Baselines and Results
In addition to a conventional seq2seq model, we compare with the original SSNT model, which accounts only for monotonic alignments. We also implemented a variant that combines SSNT with the local reordering module of huang2017towards as a baseline to show the advantage of our structured reordering module.
Results are shown in Table 4. Our model, especially Hard-ReMoto, consistently outperforms the baselines on both splits. In EN-JA translation, the advantage of our best-performing Hard-ReMoto is slightly more pronounced in the Len split than in the IID split. In ZH-EN translation, while SSNT and its variant do not outperform seq2seq in the Len split, ReMoto still achieves better results than seq2seq. These results show that our model is better than the alternatives at generalizing to longer sentences in machine translation.
Table 5: Examples of reordered inputs and predictions.

original input:  在 美国 哪些 州 与 最长 的 河流 接壤
reordered input: 州 接壤 最长 的 河流 与 哪些 美国 在
prediction:      state next_to_2 longest river loc_2 countryid_ENTITY
ground truth:    state next_to_2 longest river loc_2 countryid_ENTITY

original input:  according to the newspaper , there was a big fire last night
reordered input: according to the newspaper , night last big fire a there was
prediction:      新聞 に よ れ ば 、 昨夜 大 火事 が あ っ た
ground truth:    新聞 に よ る と 昨夜 大 火事 が あ っ た
Interpretability
Latent alignments, apart from promoting systematic generalization, also lead to better interpretability, as discrete alignments reveal the internal process of generating the output. For example, Table 5 shows a few examples from our model. Each output segment is associated with an underlying rationale, i.e., a segment of the reordered input.
5 Conclusion and Future Work
In this work, we propose a new general seq2seq model that accounts for latent segment-level alignments. Central to this model is a novel structured reordering module, which is coupled with existing modules to handle non-monotonic segment alignments. We model reorderings as separable permutations and propose an efficient dynamic programming algorithm to perform marginal inference and sampling, which allows latent reorderings to be induced with end-to-end training. Empirical results on both synthetic and real-world datasets show that our model achieves better systematic generalization than conventional seq2seq models.
The strong inductive bias introduced by modeling alignments could potentially be beneficial in weakly-supervised settings, such as weakly-supervised semantic parsing, where conventional seq2seq models usually do not perform well.