Adversarial Transfer Learning for Punctuation Restoration

Jiangyan Yi et al., 04/01/2020

Previous studies have demonstrated that word embeddings and part-of-speech (POS) tags are helpful for punctuation restoration. However, two drawbacks remain. First, the word embeddings are pre-trained with unidirectional language modeling objectives, so they only carry left-to-right context information. Second, the POS tags are provided by an external POS tagger, which increases computation cost during decoding and may inject incorrect tags that hurt punctuation restoration. This paper proposes adversarial transfer learning to address these problems. A pre-trained bidirectional encoder representations from transformers (BERT) model is used to initialize a punctuation model, so the transferred parameters carry both left-to-right and right-to-left representations. Furthermore, adversarial multi-task learning is introduced to learn task-invariant knowledge for punctuation prediction: an auxiliary POS tagging task assists the training of the punctuation prediction task, and adversarial training prevents the shared parameters from containing task-specific information. Only the punctuation prediction task is used to restore marks at decoding time, so no extra computation is required and no incorrect tags are introduced by a POS tagger. Experiments are conducted on the IWSLT2011 datasets. The results demonstrate that the punctuation prediction models obtain further performance improvement with task-invariant knowledge from the POS tagging task. Our best model outperforms the previous state-of-the-art model trained only with lexical features by up to 9.2% absolute overall F1-score on the test set.


I Introduction

Generally, the output sequences of automatic speech recognition (ASR) systems do not contain punctuation marks. This degrades the readability of the generated words and leads to poor user experiences in real-world scenarios [32]. It is therefore necessary to restore punctuation marks for speech transcripts.

Many attempts have been made to predict punctuation marks automatically. These approaches can be roughly divided into three categories according to the features they use: methods based on prosody features, methods based on lexical features, and methods based on a combination of the two.

Prosody-features-based methods have been explored by previous researchers [9, 17]. Christensen et al. [9] use hidden Markov models to restore punctuation marks from acoustic data. Kim et al. [17] perform punctuation prediction and speech recognition jointly with prosody features. Their results show that prosody features are useful, but they do not work well when speakers pause in unnatural places.

Methods based on a combination of prosody and lexical features have been proposed to resolve this problem [29, 26]. Che et al. [5] propose to train deep neural networks (DNN) on parallel lexical and acoustic features. Tilk et al. [30] use a long short-term memory (LSTM) based punctuation prediction model trained with text and speech data in two stages. Klejch et al. [20, 21] propose a recurrent neural network (RNN) encoder-decoder architecture with an attention layer that restores punctuation marks by fusing lexical and prosody features. However, these models need lexical data paired with the corresponding speech data, which limits the text and speech data they can use. Yi et al. [35] propose to train self-attention based models using word and speech embeddings. This method can use any kind of text and speech data, and it obtains a clear performance improvement with pre-trained vectors. However, it is still limited in how much information it can extract from the text data.

In fact, it is not difficult to obtain a large amount of text data. Therefore, this paper focuses only on lexical-features-based methods. Many studies have attempted to restore punctuation marks using text data alone.

One kind of method treats punctuation marks as hidden inter-word events [22]. Beeferman et al. [3] propose to train an n-gram language model (LM) on punctuated text data. An n-gram LM is also used to predict punctuation marks and perform capitalization jointly by Gravano et al. [14].

The other kind of method views punctuation prediction as a sequence labeling task [32, 40], in which a punctuation mark is assigned to each word. Previous studies [23, 32, 15] show that conditional random fields (CRFs) are better suited to predicting punctuation marks than n-gram LM based methods. Lu et al. [23] train CRF-based models with token features only. Ueffing et al. [32] propose to combine syntactic features with LM scores, token features and sentence length to predict punctuation marks using CRF models. Part-of-speech (POS) tags of several consecutive words are used as features to train a combined CRF and DNN lexical model by Cho et al. [8]. The results show that POS tags help to improve the performance of punctuation prediction. Recently, neural network based models have been used to predict punctuation marks. Unlike the earlier lexical features, which include n-grams, LM statistics, tokens, POS tags and other syntactic information, the lexical features of the neural networks are pre-trained word vectors. Che et al. [6] propose to train DNN and convolutional neural network (CNN) based models using word embeddings. The results show that the neural network based methods outperform the CRF-based method on purely textual data. More recently, Tilk et al. use a bidirectional recurrent neural network with an attention mechanism (T-BRNN) [31] to improve the performance. Yi et al. [39] propose to use a bidirectional LSTM (BLSTM) with a CRF layer (BLSTM-CRF) and an ensemble of models to predict punctuation. Most recently, Kim [18] uses deep recurrent neural networks with layer-wise multi-head attentions for punctuation restoration. The best model in [18] achieves the state-of-the-art performance with purely lexical features on the IWSLT2011 datasets [6], with an overall F1-score of up to 68.6%.

The aforementioned methods show that any kind of text data can be utilized through pre-trained word vectors. They also demonstrate that POS tags are useful lexical features. However, they still have two limitations. (1) The word vectors are trained with left-to-right language modeling objective functions [25, 24], so the word embeddings carry only unidirectional knowledge. (2) An extra POS tagger is needed to provide tag information for the input sequence at prediction time, which not only increases computation cost but also introduces errors from the POS tagger. Therefore, this paper proposes adversarial transfer learning to alleviate these problems.

Inspired by the promising results of the pre-trained bidirectional encoder representations from transformers (BERT) model on many natural language processing (NLP) tasks [10], this paper transfers the parameters of a pre-trained BERT model to initialize a punctuation prediction model, as shown in Fig. 1. The BERT model is trained by fusing context from both the left and the right direction. Unlike word embeddings, the transferred parameters contain both left-to-right and right-to-left representations.

Furthermore, motivated by the success of adversarial learning [13] on domain adaptation [12], Chinese word segmentation [7], environment and speaker adaptation [28, 27] and low-resource speech recognition [37, 36, 38], this paper proposes to combine multi-task learning and adversarial training to address the second limitation, as shown in Fig. 2. Multi-task learning [4] is a special instance of transfer learning. The conclusions drawn by Caruana in [4] show that multi-task learning is effective for improving the performance of a single task, owing to the extra information contained in the training signals of the other related tasks. Therefore, a POS tagging task is used as an auxiliary task to improve the performance of the punctuation prediction task. The model for the punctuation prediction task and the POS tagging task consists of shared and private layers. The shared layers contain task-independent information, while task-specific features are learned in the private layers of each task. However, the shared layers may still learn some unnecessary task-specific information. Adversarial learning is therefore used to ensure that the shared layers learn more task-invariant knowledge. Only the punctuation prediction task is used to output punctuation marks during decoding, so the method can exploit the syntactic features of the POS tagging task without adding any extra computation or introducing errors from a POS tagger.

There has been no work, to the best of our knowledge, that combines transfer learning and an adversarial strategy to improve punctuation restoration. The main contributions of this paper are as follows. (1) A pre-trained BERT model is used to transfer bidirectional representations to punctuation prediction models. (2) Adversarial multi-task learning is used to learn task-invariant information for the punctuation prediction task with an extra POS tagging task. Experiments are conducted on the IWSLT2011 datasets. The results demonstrate that the punctuation prediction models initialized by the pre-trained BERT model obtain significant performance improvement over randomly initialized models, by up to 9.4% absolute overall F1-score on the test set. The results also show that the punctuation prediction models obtain further performance gains from the task-invariant knowledge of the POS tagging task. Our best model achieves better results than the previous state-of-the-art models trained with purely textual data [18] and with a combination of lexical and acoustic features [35], by up to 9.2% and 4.9% absolute overall F1-score on the test set, respectively.

Fig. 1: The architecture of the BERT-BLSTM-CRF model. The BERT layers are initialized from a pre-trained language representation model; the BLSTM-CRF layers are initialized randomly.
Fig. 2: The architecture of the proposed adversarial BERT-BLSTM-CRF model. The task-shared layers come from the pre-trained BERT model, which has a stack of 12 identical layers. The task-specific classifiers are used for the punctuation prediction task and the POS tagging task, respectively, and both consist of BLSTM-CRF layers. FC denotes the fully connected layer. The gradient reversal layer (GRL) is introduced to make the feature distributions over all tasks as indistinguishable as possible for the task discriminator. The outputs of the task discriminator are the task labels PUN and POS, where PUN denotes the punctuation prediction task and POS denotes the POS tagging task.

The rest of this paper is organized as follows. Section II briefly introduces how to transfer parameters from a pre-trained BERT model to a punctuation prediction model. How to transfer task invariant knowledge from a POS tagging task is presented in Section III. Experiments are described in Section IV. The results are discussed in Section V. The conclusions are drawn in Section VI.

Input words:    Susan   where   is   the   national   library
Output labels:  COMMA   O       O    O     O          QUESTION
TABLE I: An example of inputs and outputs for the punctuation prediction task.
Input words:    Oh   it    is    a    beautiful   morning
Output labels:  UH   PRP   VBZ   DT   JJ          NN
TABLE II: An example of inputs and outputs for the POS tagging task.

II Transfer parameters from a pre-trained model

Inspired by the state-of-the-art results of pre-trained BERT on many NLP tasks [10], we initialize a punctuation prediction model with the parameters of a pre-trained BERT model. The BERT model is trained with bidirectional context information, so, unlike word embeddings, the transferred parameters contain both left-to-right and right-to-left information. The model architecture used to predict punctuation marks is shown in Fig. 1. It consists of BERT layers and BLSTM-CRF layers. The BERT layers come from the pre-trained BERT model proposed by Devlin et al. [10], and the BLSTM-CRF layers are motivated by the work in [16]. The resulting punctuation prediction model is called BERT-BLSTM-CRF.

In this paper, punctuation prediction is viewed as a sequence labeling task. An example of inputs and outputs for the punctuation prediction task is listed in Table I. The inputs are unpunctuated words, e.g. “Susan where is the national library”. The corresponding outputs are punctuation marks, such as “COMMA O O O O QUESTION”. The details of the punctuation marks are described in Section IV-A.

II-A BERT layers

BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers. The pre-trained BERT model can be fine-tuned with just one additional output layer for many tasks, such as named entity recognition and question answering, and it achieves state-of-the-art results on several NLP tasks [10]. The architecture of a BERT model is a multi-layer bidirectional Transformer encoder, as proposed by Vaswani et al. [33].

The encoder consists of a stack of identical layers, as shown at the bottom of Fig. 1. Each layer has two sub-layers: the first is a multi-head self-attention mechanism and the second is a fully connected feed-forward network. A residual connection is employed around each of the two sub-layers, followed by layer normalization.

Positional encodings are utilized to make use of the order of the input or output sequence. The input embeddings are learned from input tokens, similarly to other sequence transduction models. The dimension of the embeddings is denoted by d_model.

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values and output are all vectors. Instead of performing a single attention function with d_model-dimensional keys, values and queries, Vaswani et al. find it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to d_k, d_k and d_v dimensions, respectively. Multi-head attention allows these projected versions of queries, keys and values to perform the attention function in parallel, yielding d_v-dimensional output values. Please see [33, 10] for more details.
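To make the dimension bookkeeping concrete, the following NumPy sketch runs one multi-head self-attention step with the BERT-Base shapes used later in this paper (h = 12, d_model = 768, d_k = d_v = 64). It is purely illustrative: the random weights, the sequence length and the variable names are assumptions, not the pre-trained model.

```python
# Illustrative multi-head self-attention shapes (not the authors' model).
# Assumed BERT-Base values: d_model = 768, h = 12 heads, d_k = d_v = 64.
import numpy as np

d_model, h = 768, 12
d_k = d_v = d_model // h          # 64
seq_len = 16                      # arbitrary example sequence length

rng = np.random.default_rng(0)
x = rng.standard_normal((seq_len, d_model))        # token representations

# one set of learned projections per head (random here, learned in practice)
W_q = rng.standard_normal((h, d_model, d_k)) * 0.02
W_k = rng.standard_normal((h, d_model, d_k)) * 0.02
W_v = rng.standard_normal((h, d_model, d_v)) * 0.02
W_o = rng.standard_normal((h * d_v, d_model)) * 0.02

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

heads = []
for i in range(h):
    q, k, v = x @ W_q[i], x @ W_k[i], x @ W_v[i]   # (seq_len, 64) each
    attn = softmax(q @ k.T / np.sqrt(d_k))          # scaled dot-product attention
    heads.append(attn @ v)                          # (seq_len, d_v)

out = np.concatenate(heads, axis=-1) @ W_o          # back to (seq_len, 768)
print(out.shape)                                    # (16, 768)
```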

II-B BLSTM-CRF layers

Motivated by the work in [16], a BLSTM layer and a CRF layer are combined to form BLSTM-CRF layers as shown at the top of Fig. 1. The two layers can efficiently use past and future input features via the BLSTM layer and sentence level tag information via the CRF layer. The input features of the BLSTM layer are output encodings of pre-trained BERT layers. The CRF layer is represented by lines which connect consecutive output layers. It has a state transition matrix as parameters.

Given training examples {(x^(n), y^(n))}, n = 1, …, N, where x^(n) is an input sequence and y^(n) denotes the corresponding ground-truth label sequence, a negative log-likelihood objective is used as the loss function. Thus the loss function L_PUN can be defined as follows:

L_PUN = − Σ_{n=1}^{N} log p(y^(n) | x^(n))    (1)

where p(y^(n) | x^(n)) is the probability of the ground-truth label sequence.

At the training stage, gradient back-propagation is used to minimize the loss function. At the decoding stage, the Viterbi algorithm [34] is utilized to find the most probable predicted tag sequence.
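For concreteness, the following NumPy sketch shows Viterbi decoding over per-token emission scores (as produced by the BLSTM layer) and a CRF transition matrix. The shapes, random scores and label names are illustrative assumptions, not the authors' implementation.

```python
# Minimal Viterbi decoding sketch for a linear-chain CRF (illustrative only).
import numpy as np

def viterbi(emissions, transitions):
    """emissions: (T, K) per-position label scores from the BLSTM.
    transitions: (K, K) CRF score for moving from label i to label j.
    Returns the highest-scoring label sequence."""
    T, K = emissions.shape
    score = emissions[0].copy()          # best score ending in each label
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        # cand[i, j] = score[i] + transitions[i, j] + emissions[t, j]
        cand = score[:, None] + transitions + emissions[t][None, :]
        backptr[t] = cand.argmax(axis=0)
        score = cand.max(axis=0)
    # follow back-pointers from the best final label
    best = [int(score.argmax())]
    for t in range(T - 1, 0, -1):
        best.append(int(backptr[t, best[-1]]))
    return best[::-1]

labels = ["O", "COMMA", "PERIOD", "QUESTION"]
rng = np.random.default_rng(0)
em = rng.standard_normal((6, len(labels)))    # scores for 6 input words
tr = rng.standard_normal((len(labels), len(labels)))
print([labels[i] for i in viterbi(em, tr)])
```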

III Transfer task invariant knowledge from the POS tagging task

Motivated by the success of adversarial training on many tasks [13, 12, 7, 28, 27, 37, 36, 38], multi-task learning and adversarial training are combined to learn task-invariant information from an extra POS tagging task. The POS tagging task is used as an auxiliary task through multi-task learning to further improve the performance of the punctuation prediction task. An adversarial loss is used to prevent the shared space from containing task-specific information.

III-A Proposed adversarial BERT-BLSTM-CRF model

The architecture of our proposed model is shown in Fig. 2. The model is called adversarial BERT-BLSTM-CRF model. It consists of task shared layers, two task specific classifiers and an adversarial task discriminator.

The task shared layers are from a pre-trained BERT model, which has a stack of 12 identical layers. The task specific classifiers are used for a punctuation predicting task and a POS tagging task, respectively. Both of them consist of BLSTM-CRF layers.

For the punctuation prediction task, an example of inputs and outputs is listed in Table I. The inputs of the BERT layers are words, e.g. “Susan where is the national library”, while the outputs of the punctuation prediction task are punctuation marks, such as “COMMA O O O O QUESTION”. More details of the punctuation marks are presented in Section IV-A.

For the POS tagging task, an example of inputs and outputs is listed in Table II. The inputs of the BERT layers are words, e.g. “Oh it is a beautiful morning”, while the outputs of the POS tagging task are POS tags, e.g. “UH PRP VBZ DT JJ NN”. More details of the POS tags are introduced in Section IV-B.

However, the shared layers may learn some unnecessary task-specific information. An adversarial strategy is used to prevent the shared parameters from learning task-dependent information. This idea is implemented by an adversarial task discriminator.

The adversarial task discriminator is used to recognize the task label of each sequence from the task-shared features. The outputs of the shared layers are converted into a fixed-size real-valued vector by a max-pooling layer. The fixed-size vector is fed into the task discriminator through a gradient reversal layer (GRL) [12, 11]. The task discriminator is implemented as a fully connected (FC) neural network with a single hidden layer. Rectified linear units (ReLU) [2] are used as the activation functions of the hidden layer.

The GRL is introduced to make the feature distributions as indistinguishable as possible for the task discriminator. The adversarial BERT-BLSTM-CRF model therefore learns a representation that generalizes well from one task to another, and it ensures that the internal representation of the shared parameters contains no task-discriminative information.

III-B Multi-task learning

Dataset #TED Talks #Sentences #Tokens #COMMA #PERIOD #QUESTION #O
Training 1,690 143,991 2,102,417 158,499 132,680 11,311 1,799,927
Development 20 20,635 295,800 22,475 18,940 1,695 252,690
Test (Ref.) 8 861 12,626 830 808 53 10,935
Test (ASR) 8 852 12,822 798 810 42 11,172
TABLE III: Overall data distribution of the IWSLT datasets for the punctuation prediction task.

Multi-task learning is one instance of transfer learning [4]. The model is trained simultaneously on the training data of multiple tasks. Each task has its own private layers to estimate the posterior probabilities of its task-specific labels.

For the m-th task, given a dataset D_m with N_m training samples {(x_m^(n), y_m^(n))}, where (x_m^(n), y_m^(n)) is the n-th training sample, x_m^(n) is an input sequence, y_m^(n) is the corresponding label sequence for the input sequence, and K_m is the total number of labels, the multi-task model is trained to minimize the negative log-likelihood over all tasks. So the loss function of the multi-task model is defined as:

L_task = − Σ_{m=1}^{M} Σ_{n=1}^{N_m} log p(y_m^(n) | x_m^(n))    (2)

where m denotes the index of the m-th task, M is the total number of tasks, and p(y_m^(n) | x_m^(n)) is computed with a parametric classifier.

In this paper, we use only two tasks, so M is equal to 2. The two tasks are the punctuation prediction task and the POS tagging task. The punctuation prediction task is denoted by PUN and the POS tagging task by POS, so the output labels of the task classifier are PUN and POS.

III-C Adversarial training

A task discriminator is used to recognize the task label during the adversarial training stage. The gradients minimizing the task classification error are passed back with an opposite sign to the shared layers through the GRL. This ensures that the feature distributions over all tasks are as indistinguishable as possible for the task discriminator.

Given an additional task label t^(n) ∈ {1, …, M} for each training sample x^(n), where t^(n) denotes the task label of the sequence and M is the total number of tasks, the loss function of the task discriminator is formulated as:

L_adv = − Σ_{n=1}^{N} log p(t^(n) | x^(n))    (3)

Although the task discriminator is optimized to minimize the task classification error, its gradient is reversed before reaching the shared layers, so the bottom shared parameters are trained to be task invariant.

III-D The final loss function

Adversarial multi-task learning is used to jointly optimize the two loss functions L_task and L_adv. In standard multi-task learning, the shared representations are optimized to minimize the losses of the primary and auxiliary tasks. Adversarial multi-task learning is different: the shared parameters are trained to maximize the classification accuracies of both the punctuation prediction task and the POS tagging task, but to minimize the classification accuracy of the task discriminator. The adversarial objective works against the task discriminator through the GRL and encourages task-independent features to emerge in the course of the optimization. The shared features thus become discriminative for punctuation marks and POS tags but task invariant, and this improved task invariance leads to improved performance on the punctuation prediction task. So the final loss function of adversarial multi-task learning is defined as:

L = L_task + λ · L_adv    (4)

where λ is the loss weight; λ is gradually increased from 0 to 1 as the number of epochs increases, so that the model is trained stably [12].

There are no parameters associated with the GRL. The GRL acts as an identity transformation during the forward pass. At the back-propagation stage, however, the GRL takes the gradient from the subsequent layer, changes its sign and scales it, i.e. the gradient is multiplied by −λ before being passed to the preceding layer. Thus the shared layers can learn task-invariant knowledge from the POS tagging task.
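This mechanism can be written in a few lines. The sketch below implements a GRL in PyTorch; the paper's models are built with TensorFlow, so this is only an illustration of the identity forward pass and the sign-flipped, λ-scaled backward pass, not the authors' code.

```python
# Minimal PyTorch sketch of a gradient reversal layer (illustration only).
import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lam):
        ctx.lam = lam
        return x.view_as(x)                  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # reverse and scale the gradient flowing back to the shared layers
        return -ctx.lam * grad_output, None  # no gradient w.r.t. lam

def grad_reverse(x, lam=1.0):
    return GradReverse.apply(x, lam)

# toy check: the gradient reaching x is multiplied by -lam
x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, lam=0.5).sum()
y.backward()
print(x.grad)        # tensor([-0.5000, -0.5000, -0.5000])
```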

III-E Decoding

At the decoding stage, the POS tagging task and the adversarial task discriminator are removed, and only the punctuation prediction task is used to predict marks. Therefore, the model can utilize task-invariant syntactic features learned from the POS tagging task without increasing computation or introducing extra errors from a POS tagger. The Viterbi algorithm [34] is also used to find the most probable punctuation sequence.

IV Experiments

A series of experiments is conducted in this section to evaluate the proposed method. Our experiments are conducted on the English IWSLT [6] and Penn Treebank (PTB) datasets (https://catalog.ldc.upenn.edu/LDC99T42). The IWSLT datasets are used for the punctuation prediction task, and the PTB datasets are used for the POS tagging task. The results are reported on the two test sets of the IWSLT datasets.

Dataset #Sentences #Tokens
Training 39,831 950,011
Development 1,699 40,068
Test 2,415 56,671
TABLE IV: Overall data distributions of PTB datasets for the POS tagging task.
Test Model Initialization COMMA (P R F1) PERIOD (P R F1) QUESTION (P R F1) Overall (P R F1)
Ref. BERT-CRF Random 61.1 59.6 60.3 72.1 70.7 71.4 71.3 60.2 65.3 68.2 63.5 65.7
BERT-CRF Pre-trained BERT 73.7 69.1 71.3 83.7 78.8 81.2 75.1 70.1 72.5 77.5 72.7 75.0
BERT-BLSTM-CRF Random 61.9 59.9 60.9 72.4 71.1 71.7 71.5 61.0 65.8 68.6 64.0 66.2
BERT-BLSTM-CRF Pre-trained BERT 74.2 69.7 71.9 84.6 79.2 81.8 76.0 70.4 73.1 78.3 73.1 75.6
ASR BERT-CRF Random 56.1 57.1 56.6 69.1 71.1 70.1 64.0 53.6 58.3 63.1 60.6 61.7
BERT-CRF Pre-trained BERT 70.2 67.5 68.8 76.6 77.1 76.8 67.5 65.7 66.6 71.4 70.1 70.8
BERT-BLSTM-CRF Random 56.3 57.4 56.8 69.4 71.2 70.3 64.3 54.1 58.8 63.3 60.9 62.0
BERT-BLSTM-CRF Pre-trained BERT 70.7 68.1 69.4 77.6 77.5 77.5 68.4 66.0 67.2 72.2 70.5 71.4
TABLE V: Transferring parameters from pre-trained BERT to punctuation prediction models. Results of the punctuation prediction models in terms of precision (P), recall (R) and F1-score (F1) on the test sets of the IWSLT2011 datasets.

IV-A IWSLT datasets

IWSLT datasets are from TED Talks, which are reorganized for predicting punctuation marks by Che et al. [6]. There are three kinds of datasets: training set, development set and test set.

The training and development sets come from the training data of the IWSLT2012 machine translation track, which consists of 1,710 TED Talks. Che et al. [6] further split it into training and development sets according to the IDs of the TED Talks. The two test sets, Ref. and ASR, are provided by the test data of the IWSLT2011 ASR track. Ref. consists of manual transcripts of the audio files, and ASR consists of transcripts produced by the ASR system. More details of the datasets can be found in [6].

The datasets have four kinds of labels: O, COMMA, PERIOD and QUESTION. O denotes a non-punctuation position. COMMA covers commas as well as colons and dashes, PERIOD covers periods as well as exclamation marks and semicolons, and QUESTION covers question marks. Table III gives the data statistics of the IWSLT datasets.

IV-B PTB datasets

PTB datasets consist of three annotation schemes: POS tagging, syntactic bracketing, and disfluency annotation. We only use PTB POS tagging datasets in our experiments.

The PTB tagset is based on that of the Brown Corpus, but it differs from it in a number of important ways. The PTB tagset takes syntactic context into account: it encodes a word's syntactic function in its POS tag whenever possible, assigning each word a unique tag that indicates its syntactic role. It contains 36 POS tags (https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html), such as UH, PRP, VBZ, DT, JJ and NN. UH means interjection, PRP denotes a personal pronoun, VBZ is a third-person singular present verb, DT means determiner, JJ denotes an adjective and NN means a singular or mass noun. Table IV gives the data statistics of the PTB datasets.

IV-C Metrics

All models are evaluated in terms of precision (P), recall (R) and F1-score (F1) in our experiments. We focus on the performance on the punctuation marks, so correctly predicted non-punctuation marks O are ignored. We evaluate the performance only for COMMA, PERIOD and QUESTION on the two test sets, Ref. and ASR, respectively. “Overall” denotes the performance over all three punctuation marks. The results of all experiments are reported on the two test sets of the IWSLT datasets, Ref. and ASR. More details of the metrics can be found in [6].
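A small sketch of this evaluation is given below, under the assumption that scoring is done per token with O ignored; the official scoring setup of [6] may differ in details such as tokenization, and the function name and example sequences are ours.

```python
# Sketch of per-class precision / recall / F1 that ignores the label O.
from collections import Counter

def punctuation_prf(reference, hypothesis,
                    classes=("COMMA", "PERIOD", "QUESTION")):
    tp, fp, fn = Counter(), Counter(), Counter()
    for ref, hyp in zip(reference, hypothesis):
        for c in classes:
            if hyp == c and ref == c:
                tp[c] += 1       # correctly predicted punctuation mark
            elif hyp == c:
                fp[c] += 1       # predicted c where the reference differs
            elif ref == c:
                fn[c] += 1       # missed a reference c
    scores = {}
    for c in list(classes) + ["Overall"]:
        t = tp[c] if c != "Overall" else sum(tp.values())
        p_den = t + (fp[c] if c != "Overall" else sum(fp.values()))
        r_den = t + (fn[c] if c != "Overall" else sum(fn.values()))
        p = t / p_den if p_den else 0.0
        r = t / r_den if r_den else 0.0
        f1 = 2 * p * r / (p + r) if p + r else 0.0
        scores[c] = (p, r, f1)
    return scores

ref = ["O", "COMMA", "O", "PERIOD", "O", "QUESTION"]
hyp = ["O", "COMMA", "COMMA", "PERIOD", "O", "O"]
print(punctuation_prf(ref, hyp))
```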

IV-D Experimental Setup

Pre-trained BERT models are released by Google (https://github.com/google-research/bert) and implemented with the TensorFlow toolkit [1]. Two kinds of pre-trained models are available (https://github.com/google-research/bert#pre-trained-models): BERT-Large and BERT-Base. The size of our experimental data is not large, so we use the Uncased BERT-Base model to initialize the BERT-BLSTM-CRF models. Uncased means that case and accent markers are stripped out.

The basic architecture of the BERT-Base model is shown at the bottom of Fig. 1 or Fig. 2. The encoder has a stack of 12 identical layers and 12 parallel self-attention heads. For each of these heads, we set d_k = d_v = 64. Because of the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality. In order to use residual connections, we set d_model = 768. The positional encodings have the same dimension as the embedding layers, so the two can be summed. The BERT-Base model has 110M parameters in total. Please see [10] for pre-training details of the BERT-Base model.
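For reference, the BERT-Base hyper-parameters mentioned above can be collected in a single configuration sketch; the intermediate feed-forward size of 3,072 is taken from Devlin et al. [10] rather than restated in this paper, and the key names are ours.

```python
# Published BERT-Base hyper-parameters, gathered in one place for reference
# (values follow Devlin et al. [10]; key names are illustrative).
bert_base_config = {
    "num_hidden_layers": 12,     # stack of identical Transformer layers
    "num_attention_heads": 12,   # parallel self-attention heads
    "hidden_size": 768,          # d_model, also the embedding dimension
    "head_size": 64,             # d_k = d_v = 768 / 12
    "intermediate_size": 3072,   # feed-forward inner dimension (from [10])
    "total_parameters": "110M",
}
```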

The configuration of the BLSTM-CRF layers is based on the work in [39], where there is one BLSTM layer and one CRF layer. The BLSTM layer has peephole connections and a recurrent projection layer. There are two directions in the BLSTM layer, forward and backward, and each direction is a regular LSTM layer. Each LSTM layer consists of 240 memory cells, and the recurrent projection layer projects the output to 120 dimensions. We initialize the BLSTM layer uniformly in the range (-0.02, 0.02). Apart from clipping the activations of the memory cells to the range [-50, 50], the activations of other units, the weights and the estimated gradients are not limited.

The development sets are used for validation. Training terminates when only a small improvement on the development set is observed between two consecutive epochs.

Test Model COMMA (P R F1) PERIOD (P R F1) QUESTION (P R F1) Overall (P R F1)
Ref. BERT-BLSTM-CRF (Pre-trained BERT) 74.2 69.7 71.9 84.6 79.2 81.8 76.0 70.4 73.1 78.3 73.1 75.6
+ POS tagging task 75.0 70.3 72.6 85.5 79.7 82.5 76.9 71.3 74.0 79.1 73.8 76.4
+ Task classifier 75.1 70.3 72.6 85.9 79.9 82.8 77.0 71.5 74.1 79.3 73.9 76.5
+ Adversarial 76.2 71.2 73.6 87.3 81.1 84.1 79.1 72.7 75.8 80.9 75.0 77.8
ASR BERT-BLSTM-CRF (Pre-trained BERT) 70.7 68.1 69.4 77.6 77.5 77.5 68.4 66.0 67.2 72.2 70.5 71.4
+ POS tagging task 71.4 68.6 70.0 78.4 77.9 78.1 69.2 66.8 68.0 73.0 71.1 72.0
+ Task classifier 71.5 68.6 70.0 78.8 78.1 78.4 69.3 67.0 68.1 73.2 71.2 72.2
+ Adversarial 72.4 69.3 70.8 80.0 79.1 79.5 71.2 68.0 69.6 74.5 72.1 73.3
TABLE VI: Transferring knowledge from a POS tagging task to the punctuation prediction task. Results of the punctuation prediction models in terms of precision (P), recall (R) and F1-score (F1) on the test sets of the IWSLT2011 datasets.
Test Model Transferred Info. COMMA (P R F1) PERIOD (P R F1) QUESTION (P R F1) Overall (P R F1)
Ref. CRF Best [32] Lexical & Syntactic - - - - - - - - - 49.8 58.0 53.5
DNN-A [6] Word vectors 48.6 42.4 45.3 59.7 68.3 63.7 - - - 54.8 53.6 54.2
CNN-2A [6] Word vectors 48.1 44.5 46.2 57.6 69.0 62.8 - - - 53.4 55.0 54.2
T-LSTM [30] One-hot vectors 49.6 41.4 45.1 60.2 53.4 56.6 57.1 43.5 49.4 55.0 47.2 50.8
T-BRNN-pre [31] Word vectors 65.5 47.1 54.8 73.3 72.5 72.9 70.7 63.0 66.7 70.0 59.7 64.4
BLSTM-CRF [39] Word vectors 58.9 59.1 59.0 68.9 72.1 70.5 71.8 60.6 65.7 66.5 63.9 65.1
Teacher-Ensemble [39] Word vectors 66.2 59.9 62.9 75.1 73.7 74.4 72.3 63.8 67.8 71.2 65.8 68.4
DRNN-LWMA-pre [18] Word vectors 62.9 60.8 61.9 77.3 73.7 75.5 69.6 69.6 69.6 69.9 67.2 68.6
Self-attention [35] Word & Speech vectors 67.4 61.1 64.1 82.5 77.4 79.9 80.1 70.2 74.8 76.7 69.6 72.9
Our best model BERT & POS task 76.2 71.2 73.6 87.3 81.1 84.1 79.1 72.7 75.8 80.9 75.0 77.8
ASR CRF Best [32] Lexical & Syntactic - - - - - - - - - 47.8 54.8 51.0
DNN-A [6] Word vectors 41.0 40.9 40.9 56.2 64.5 60.1 - - - 49.2 51.6 50.4
CNN-2A [6] Word vectors 37.3 40.5 38.8 54.6 65.5 59.6 - - - 46.4 51.9 49.1
T-LSTM [30] One-hot vectors 41.8 37.8 39.7 56.4 49.3 52.6 55.6 42.9 48.4 49.1 43.6 46.2
T-BRNN-pre [31] Word vectors 59.6 42.9 49.9 70.7 72.0 71.4 60.7 48.6 54.0 66.0 57.3 61.4
BLSTM-CRF [39] Word vectors 55.7 56.8 56.2 68.7 71.5 70.1 63.8 53.4 58.1 62.7 60.6 61.5
Teacher-Ensemble [39] Word vectors 60.6 58.3 59.4 71.7 72.9 72.3 66.2 55.8 60.6 66.2 62.3 64.1
DRNN-LWMA-pre [18] Word vectors - - - - - - - - - - - -
Self-attention [35] Word & Speech vectors 64.0 59.6 61.7 75.5 75.8 75.6 72.6 65.9 69.1 70.7 67.1 68.8
Our best model BERT & POS task 72.4 69.3 70.8 80.0 79.1 79.5 71.2 68.0 69.6 74.5 72.1 73.3
TABLE VII: Comparison with other models on the IWSLT2011 datasets. Results of the punctuation prediction models in terms of precision (P), recall (R) and F1-score (F1) on the test sets.

IV-E Transferring parameters from pre-trained BERT

In this section, we evaluate the effectiveness of transferring parameters from pre-trained BERT on the IWSLT datasets. Two model architectures are designed for the punctuation prediction task: BERT-CRF and BERT-BLSTM-CRF. The output labels of the classification layer are the three punctuation marks and the non-punctuation mark O.

BERT-CRF: A simple classification layer is added to the BERT layers. The classification layer is a CRF layer.

BERT-BLSTM-CRF: We add two layers on top of the BERT layers, as shown in Fig. 1. They are BLSTM-CRF layers, which consist of a one-layer BLSTM and a CRF layer. A linear transformation is used to convert the 768-dimensional activations into the 120-dimensional input of the BLSTM layer.

In the first group of experiments, all the parameters of BERT-CRF and BERT-BLSTM-CRF are initialized randomly. The models are trained for 6 epochs over the training data. The Adam algorithm [19] with gradient clipping and warmup is used for optimization. The number of warmup steps is set to 4,000. The batch size is set to 32, which means that each batch contains 32 sentences. The dropout rate is set to 0.1. The initial learning rate is 5e-4, and the learning rate is varied over the course of training according to the formula in [33]; a sketch of this schedule is given below.
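The following sketch illustrates the warmup schedule of [33] with 4,000 warmup steps. How the stated 5e-4 learning rate enters the formula is not spelled out here, so rescaling the schedule to peak at 5e-4 is an assumption.

```python
# Sketch of the warmup learning-rate schedule of Vaswani et al. [33] with
# warmup_steps = 4,000. Peaking at 5e-4 is an assumed normalization.
import math

def transformer_lr(step, warmup_steps=4000, peak_lr=5e-4):
    step = max(step, 1)
    # original schedule: d_model**-0.5 * min(step**-0.5, step * warmup**-1.5);
    # here it is renormalized so that lr(warmup_steps) == peak_lr
    scale = peak_lr * math.sqrt(warmup_steps)
    return scale * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1, 1000, 4000, 20000):
    print(s, f"{transformer_lr(s):.2e}")
```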

In the second group of experiments, the parameters of the BERT layers of BERT-CRF and BERT-BLSTM-CRF are first initialized with the parameters of the pre-trained BERT model. Then all of the parameters are jointly fine-tuned on the training data of the punctuation prediction task. The models are fine-tuned for 3 epochs over the training data. The batch size is set to 32 and the dropout rate to 0.1. We select the best fine-tuning learning rate of 5e-5 on the development set.

The results of the models are reported in Table V. The results show that the BERT-BLSTM-CRF models obtain better results than the corresponding BERT-CRF models. The BERT-CRF models with random initialization achieve the worst results among our models on both test sets, while the BERT-BLSTM-CRF models with parameters transferred from the pre-trained BERT model obtain the best performance on both test sets, Ref. and ASR.

The results also show that the models initialized from the pre-trained BERT model significantly outperform the models with random initialization. The BERT-CRF model initialized from the pre-trained BERT model obtains better performance than its randomly initialized counterpart by 9.3% and 9.1% absolute overall F1-score on the Ref. and ASR test sets, respectively. The BERT-BLSTM-CRF model initialized from the pre-trained BERT model outperforms its randomly initialized counterpart by 9.4% absolute overall F1-score on both the Ref. and ASR test sets.

In the rest of our experiments, we use BLSTM-CRF layers as the task-specific classification layers on top of the BERT layers for both the punctuation prediction task and the POS tagging task. In addition, the BERT layers are initialized with the parameters of the pre-trained BERT model.

IV-F Transferring knowledge from the POS tagging task

A series of experiments is performed to evaluate the performance of the punctuation prediction task trained with knowledge transferred from a POS tagging task. The POS tagging task is trained on the PTB POS tagging datasets. We fine-tune the models for 3 epochs with a batch size of 32.

In the first group of experiments, we train the punctuation prediction task with the help of the POS tagging task via multi-task learning only. The architecture of the POS tagging specific layers is identical to that of the punctuation prediction specific layers: it consists of a one-layer BLSTM and a CRF layer. The two tasks share the pre-trained BERT layers, and the outputs of the BERT layers are the inputs of the two tasks. A linear transformation is used to convert the 768-dimensional output into the 120-dimensional input of the two task-specific BLSTM layers. The output labels of the POS tagging task are the 36 POS tags.

In the second group of experiments, we add a task discriminator to the aforementioned multi-task model. The task discriminator has a max-pooling layer, a FC layer and a softmax layer. ReLU activation functions are used in the 1024-dimensional FC layer. The task classifier has two output labels, PUN and POS, where PUN denotes the punctuation prediction task and POS denotes the POS tagging task. During training, we select a task from {PUN, POS} at each iteration and then use a batch of training samples from the selected task to update the parameters, as sketched below. The Adam algorithm [19] with gradient clipping and warmup is used to optimize the loss function. The punctuation prediction task and the POS tagging task may have different convergence rates, so we repeat the above iterations until early stopping based on the punctuation prediction performance.
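The alternating update can be sketched as follows. The model methods, batch iterators and task-sampling policy are placeholders used only for illustration; they are not the authors' code.

```python
# Illustrative alternating multi-task update; model.task_loss and
# model.discriminator_loss are placeholder methods, not an existing API.
import itertools
import random

def train_adversarial_multitask(model, optimizer, pun_batches, pos_batches,
                                num_iterations):
    # cycle so neither task's batches run out before training stops
    iters = {"PUN": itertools.cycle(pun_batches),
             "POS": itertools.cycle(pos_batches)}
    for _ in range(num_iterations):
        task = random.choice(["PUN", "POS"])   # pick one task per iteration
        batch = next(iters[task])
        optimizer.zero_grad()
        # task-specific BLSTM-CRF loss plus the discriminator loss routed
        # through the gradient reversal layer, as in Eq. (4)
        loss = (model.task_loss(batch, task)
                + model.discriminator_loss(batch, task))
        loss.backward()
        optimizer.step()
```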

In the third group of experiments, we further add a GRL to the above-mentioned task discriminator. The GRL is placed after the max-pooling layer and before the FC layer, as shown in Fig. 2. The GRL has no parameters. The dropout rate is fixed at 0.1. The loss weight λ is initialized at 0 and gradually changed to 1 with the following schedule [12]:

λ = 2 / (1 + exp(−γ · p)) − 1    (5)

where p is the training progress, linearly changing from 0 to 1, and γ is set to 10 in all experiments.

This strategy allows the task classifier to be less sensitive to noisy signals at the early stages of the training procedure. Note that this scheduled λ is used only for updating the shared BERT layers; for updating the task classification component we use a fixed λ, to ensure that the latter trains as fast as the two task-specific classifiers [12]. The Adam algorithm is also used to optimize the final loss function.
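A tiny helper reproducing this schedule is shown below; the form of Eq. (5) follows the schedule of Ganin et al. [12], with γ = 10 as stated above.

```python
# Lambda schedule of Eq. (5): lambda grows smoothly from 0 toward 1 as the
# training progress p goes from 0 to 1 (gamma = 10 in all experiments).
import math

def adversarial_lambda(p, gamma=10.0):
    """p is the training progress in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

print(adversarial_lambda(0.0), adversarial_lambda(0.5), adversarial_lambda(1.0))
# 0.0, ~0.987, ~0.9999
```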

The results of the above three groups of experiments are listed in Table VI. The results show that, with the help of the POS tagging task via multi-task learning, the punctuation prediction task obtains performance improvement on both the Ref. and ASR test sets. The results also demonstrate that when the multi-task model is trained with an additional task classifier, the performance of the punctuation prediction models improves moderately on the two test sets. The punctuation prediction models achieve further clear improvements on the Ref. and ASR test sets when the multi-task model is trained with the adversarial task discriminator.

IV-G Compared with other methods

We also compare our proposed models with previous models on the IWSLT2011 datasets. The previous results are listed in Table VII. “Transferred Info.” in Table VII denotes the type of information transferred from text data.

CRF Best is the best model proposed by Ueffing et al. [32]. DNN-A and CNN-2A are the best models from [6], in which Che et al. halve the value of the softmax output for class O. T-LSTM represents the first-stage model from [30], which Tilk et al. train on the English IWSLT2011 dataset. T-BRNN-pre is the best attention model proposed by Tilk et al. [31]. BLSTM-CRF denotes the best single model trained in [39]. Teacher-Ensemble is the best ensemble model proposed by Yi et al. [39]. DRNN-LWMA-pre represents the best multi-head attention based model from [18]. Self-attention [35] achieves the previous state-of-the-art results.

The CRF Best, DNN-A, CNN-2A, T-LSTM, T-BRNN-pre, BLSTM-CRF, Teacher-Ensemble and DRNN-LWMA-pre models in Table VII are trained with text data only, whereas the Self-attention model is trained using both lexical and prosody features. Our models are trained using text data only.

The results show that our best models with purely lexical features outperform all the previous state-of-the-art models. Compared with the best model in [18], the overall F1-score of our best model improves absolutely by 9.2% on the Ref. test set. Compared with the lexical and prosody model Self-attention [35], the overall F1-score of our best model improves absolutely by 4.9% and 4.5% on the Ref. and ASR test sets, respectively.

V Discussions

The above experimental results show that the proposed adversarial transfer learning is effective. Some interesting observations are made as follows.

The punctuation prediction models obtain significant performance improvement from the parameters transferred from the pre-trained BERT model. These transferred parameters are used to initialize the punctuation prediction models, which is helpful for at least three reasons. First, the punctuation prediction model has parameters for feature types observed in a large amount of external unlabeled text data as well as in the punctuated text data, so it has better feature coverage. Second, the training objective is non-convex, and this initialization helps to avoid bad local optima. Third, the pre-trained BERT model is a deep bidirectional language model trained on both left and right context, so the punctuation prediction model can use the left-to-right and right-to-left representations transferred from it. This bidirectional knowledge is useful for predicting punctuation marks.

The punctuation prediction task benefits from the POS tagging task. Syntactic features convey useful information when the input text is formal and well-structured. The POS tagging corpus encodes a word's syntactic function in its POS tag whenever possible, assigning each word a unique tag that indicates its syntactic role. So the punctuation prediction task can learn helpful syntactic knowledge from the POS tagging task.

The punctuation prediction models gain further clear performance improvement from task-invariant knowledge. Although the punctuation prediction task already improves with the help of the POS tagging task via multi-task learning, the models achieve further clear improvements when the multi-task model is trained with the adversarial task discriminator. The most likely reason is that the shared layers of the plain multi-task model may learn some unnecessary task-specific features, whereas the adversarial loss prevents the shared layers from learning task-dependent information. So the punctuation prediction models with an adversarial task classifier can learn more task-invariant features, and this transferred task-invariant knowledge helps to improve their performance.

In summary, all the punctuation prediction models benefit from better feature coverage and better initialization, as well as from syntactic knowledge obtained via transfer learning. Moreover, the adversarial strategy prevents the shared layers from containing task-dependent information, so the punctuation prediction models benefit from the task-invariant features learned by adversarial transfer learning.

VI Conclusion

This paper proposes adversarial transfer learning to improve the performance of punctuation prediction. Bidirectional representations are transferred from a pre-trained BERT model to the punctuation prediction models. Furthermore, task-invariant knowledge is learned for the punctuation prediction task with an auxiliary POS tagging task via adversarial multi-task learning. Experiments are conducted on the IWSLT2011 datasets. The results demonstrate that the punctuation prediction models with parameters transferred from the pre-trained BERT model significantly outperform randomly initialized models. The results also show that the punctuation prediction models with task-invariant knowledge obtain further performance improvement. Our best model outperforms the previous state-of-the-art models. Future work includes applying the proposed method to other speech signal processing tasks.

Acknowledgments

This work is supported by the National Key Research & Development Plan of China (No. 2017YFC0820602) and the National Natural Science Foundation of China (NSFC) (No.61425017, No. 61773379, No. 61603390, No. 61771472, No. 61901473), and Inria-CAS Joint Research Project (No. 173211KYSB20190049).

References

  • [1] M. Abadi, P. Barham, J. Chen, Z. Chen, A. Davis, J. Dean, M. Devin, S. Ghemawat, G. Irving, and M. Isard (2016) TensorFlow: a system for large-scale machine learning. In 12th USENIX Symposium on Operating Systems Design and Implementation, pp. 265–283. Cited by: §IV-D.
  • [2] A. L. Maas, A. Y. Hannun, and A. Y. Ng (2013) Rectifier nonlinearities improve neural network acoustic models. In Proceedings of the International Conference on Machine Learning, Cited by: §III-A.
  • [3] D. Beeferman, A. Berger, and J. Lafferty (1998) Cyberpunc: a lightweight punctuation annotation system for speech. In ICASSP, pp. 689–692 vol.2. Cited by: §I.
  • [4] R. Caruana (1997) Multitask learning. Machine Learning 28 (1), pp. 41–75. Cited by: §I, §III-B.
  • [5] X. Che, S. Luo, H. Yang, and C. Meinel (2016) Sentence boundary detection based on parallel lexical and acoustic models. In INTERSPEECH, pp. 2528–2532. Cited by: §I.
  • [6] X. Che, C. Wang, H. Yang, and C. Meinel (2016) Punctuation prediction for unsegmented transcript based on word vector. In LREC, pp. 654–658. Cited by: §I, §IV-A, §IV-A, §IV-C, §IV-G, TABLE VII, §IV.
  • [7] X. Chen, Z. Shi, X. Qiu, and X. Huang (2017) Adversarial multi-criteria learning for chinese word segmentation. In ACL, pp. 1193–1203. Cited by: §I, §III.
  • [8] E. Cho, K. Kilgour, J. Niehues, and A. Waibel (2015) Combination of nn and crf models for joint detection of punctuation and disfluencies. In INTERSPEECH, pp. 3650–3654. Cited by: §I.
  • [9] H. Christensen, Y. Gotoh, and S. Renals (2001) Punctuation annotation using statistical prosody models. Proc Isca Workshop on Prosody in Speech Recognition and Understanding, pp. 35–40. Cited by: §I.
  • [10] J. Devlin, M. Chang, K. Lee, and K. Toutanova (2019) BERT: pre-training of deep bidirectional transformers for language understanding. In arXiv preprint arXiv:1810.04805, Cited by: §I, §II-A, §II-A, §II, §IV-D.
  • [11] Y. Ganin and V. Lempitsky (2015) Unsupervised domain adaptation by backpropagation. In International Conference on Machine Learning, pp. 1180–1189. Cited by: §III-A.
  • [12] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, M. Marchand, and V. Lempitsky (2016) Domain-adversarial training of neural networks. Journal of Machine Learning Research 17 (1), pp. 2096–2030. Cited by: §I, §III-A, §III-D, §III, §IV-F, §IV-F.
  • [13] I. J. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio (2014) Generative adversarial networks. Advances in Neural Information Processing Systems 3, pp. 2672–2680. Cited by: §I, §III.
  • [14] A. Gravano, M. Jansche, and M. Bacchiani (2009) Restoring punctuation and capitalization in transcribed speech. In ICASSP, pp. 4741–4744. Cited by: §I.
  • [15] M. Hasan (2015) Noise-matched training of crf based sentence end detection models. In INTERSPEECH, pp. 349–353. Cited by: §I.
  • [16] Z. Huang, W. Xu, and K. Yu (2015) Bidirectional lstm-crf models for sequence tagging. In arXiv preprint arXiv:1508.01991, Cited by: §II-B, §II.
  • [17] J. Kim and P. C. Woodland (2003) A combined punctuation generation and speech recognition system and its performance enhancement using prosody. Speech Communication 41 (4), pp. 563–577. Cited by: §I.
  • [18] S. Kim (2019) Deep recurrent neural networks with layer-wise multi-head attentions for punctuation restoration. In ICASSP, pp. 7280–7284. Cited by: §I, §I, §IV-G, §IV-G, TABLE VII.
  • [19] D. Kingma and J. Ba (2015) Adam: a method for stochastic optimization. In ICLR, Cited by: §IV-E, §IV-F.
  • [20] O. Klejch, P. Bell, and S. Renals (2016) Punctuated transcription of multi-genre broadcasts using acoustic and lexical approaches. In Spoken Language Technology Workshop, pp. 433–440. Cited by: §I.
  • [21] O. Klejch, P. Bell, and S. Renals (2017) Sequence-to-sequence models for punctuated transcription combining lexical and acoustic features. In ICASSP, pp. 5700–5704. Cited by: §I.
  • [22] E. Liu, A. Stolcke, D. Hillard, M. Ostendorf, and M. Harper (2006) Enriching speech recognition with automatic detection of sentence boundaries and disfluencies. IEEE Trans Audio Speech Language Process 14 (5), pp. 1526–1540. Cited by: §I.
  • [23] W. Lu and H.T. Ng (2010) Better punctuation prediction with dynamic conditional random fields.. In EMNLP, pp. 177–186. Cited by: §I.
  • [24] T. Mikolov, I. Sutskever, K. Chen, G. Corrado, and J. Dean (2013) Distributed representations of words and phrases and their compositionality. In Proceedings of the Twenty-Second Annual Conference on Neural Information Processing Systems, Cited by: §I.
  • [25] A. Mnih and G. E. Hinton (2008) A scalable hierarchical distributed language model. In Proceedings of the Twenty-Second Annual Conference on Neural Information Processing Systems, Cited by: §I.
  • [26] A. Nanchen and P. N. Garner (2019) Empirical evaluation and combination of punctuation prediction models applied to broadcast news. In ICASSP, pp. 7275–7279. Cited by: §I.
  • [27] G. Saon, G. Kurata, T. Sercu, and et al. (2017) English conversational telephone speech recognition by humans and machines. In INTERSPEECH, pp. 132–136. Cited by: §I, §III.
  • [28] Y. Shinohara (2016) Adversarial multi-task learning of deep neural networks for robust speech recognition. In INTERSPEECH, pp. 2369–2372. Cited by: §I, §III.
  • [29] G. Szaszák and M. Tündik (2019) Leveraging a character, word and prosody triplet for an asr error robust and agglutination friendly punctuation approach. In INTERSPEECH, pp. 2988–2992. Cited by: §I.
  • [30] O. Tilk and T. Alumae (2015) LSTM for punctuation restoration in speech transcripts. In INTERSPEECH, pp. 683–687. Cited by: §I, §IV-G, TABLE VII.
  • [31] O. Tilk and T. Alumae (2016) Bidirectional recurrent neural network with attention mechanism for punctuation restoration. In INTERSPEECH, pp. 3047–3051. Cited by: §I, §IV-G, TABLE VII.
  • [32] N. Ueffing, M. Bisani, and P. Vozila (2013) Improved models for automatic punctuation prediction for spoken and written text. In INTERSPEECH, pp. 3097–3101. Cited by: §I, §I, §IV-G, TABLE VII.
  • [33] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NIPS, Cited by: §II-A, §II-A, §IV-E.
  • [34] (2010) Viterbi algorithm. In Encyclopedia of Machine Learning, C. Sammut and G. I. Webb (Eds.), pp. 1025–1025. External Links: ISBN 978-0-387-30164-8, Document, Link Cited by: §II-B, §III-E.
  • [35] J. Yi, J. Tao, Y. Bai, and Y. Li (2019) Self-attention based model for punctuation prediction using word and speech embeddings. In ICASSP, pp. 7270–7274. Cited by: §I, §I, TABLE VII.
  • [36] J. Yi, J. Tao, and Y. Bai (2019) Language-invariant bottleneck features from adversarial end-to-end acoustic models for low resource speech recognition. In ICASSP, pp. 6071–6075. Cited by: §I, §III, §IV-G, §IV-G.
  • [37] J. Yi, J. Tao, Z. Wen, and Y. Bai (2018) Adversarial multilingual training for low-resource speech recognition. In ICASSP, pp. 4899–4903. Cited by: §I, §III.
  • [38] J. Yi, J. Tao, Z. Wen, and Y. Bai (2019) Language-adversarial transfer learning for low-resource speech recognition. IEEE/ACM Transactions on Audio, Speech, and Language Processing 27 (3), pp. 621–630. Cited by: §I, §III.
  • [39] J. Yi, J. Tao, Z. Wen, and Y. Li (2017) Distilling knowledge from an ensemble of models for punctuation prediction. In INTERSPEECH, pp. 2779–2783. Cited by: §I, §IV-D, §IV-G, TABLE VII.
  • [40] P. Żelasko, P. Szymański, J. Mizgajski, A. Szymczak, Y. Carmiel, and N. Dehak (2018) Punctuation prediction model for conversational speech. In INTERSPEECH, pp. 2633–2637. Cited by: §I.