Classification as Decoder: Trading Flexibility for Control in Medical Dialogue

by Sam Shleifer et al.

Generative seq2seq dialogue systems are trained to predict the next word in dialogues that have already occurred. They can learn from large unlabeled conversation datasets, build a deeper understanding of conversational context, and generate a wide variety of responses. This flexibility comes at the cost of control, a concerning tradeoff in doctor/patient interactions: inaccuracies, typos, or undesirable content in the training data will be reproduced by the model at inference time. We trade a small amount of labeling effort and some loss of response variety in exchange for quality control. More specifically, a pretrained language model encodes the conversational context, and we finetune a classification head to map an encoded conversational context to a response class, where each class is a noisily labeled group of interchangeable responses. Experts can update these exemplar responses over time as best practices change, without retraining the classifier or invalidating old training data. Expert evaluation of 775 unseen doctor/patient conversations shows that only 12% of the discriminative model's responses are worse than what the doctor ended up writing, compared to 18% for the generative model.



1 Introduction

Modern chatbots are built around two paradigms: the more structured slot-filling paradigm, and the unstructured generative seq2seq paradigm.

The first, task-oriented group, exemplified by Budzianowski et al. (2018), tends to solve narrow tasks like restaurant and hotel reservations and requires access to a large external knowledge base. This setup is too cumbersome for primary care medical conversations because (a) building the external knowledge base would require enumerating the very large symptom, diagnosis and remedy spaces and (b) each module requires large volumes of separate training data. The seq2seq group, which we call generative models (GM), requires neither labeling nor structured representations of the dialogue state, yet, according to Petroni et al. (2019), such models learn strong representations of the conversational context with content similar to that of a knowledge base. They have a key drawback, however: there is no mechanism to ensure high-quality responses. See et al. (2019) find that GM "often repeat or contradict previous statements" and produce generic, boring text, and Wallace et al. (2019) show that GM can be attacked to "spew racist output". Even in a cooperative setting, typos, inaccuracies, and other frequent mistakes in the training data will be reproduced by the model at inference time. This drawback is even more important in medical settings, where giving patients bad advice is costly and potentially unsafe.

Our discriminative setup attempts to remedy this issue by restricting the chatbot to a manageable set of high-quality “exemplar” responses. We ensure that exemplars are all factual, sensible and grammatical by allowing experts to edit them before or after training. For example, if we were to update a response recommending that users sleep 6-8 hours per night to recommend 7-9 hours, we could simply update the message associated with the output class, and the discriminative model would immediately generate the new advice in the same contexts where it previously generated the old advice, without retraining.
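The update-without-retraining property follows from the classifier predicting a class id rather than text: the text is looked up afterwards. The sketch below illustrates this under stated assumptions; the class names (`sleep_advice`, `greet_pain_scale`), the exemplar texts, the `suggest` helper, and the hard-coded probabilities standing in for the trained classifier's output are all hypothetical.

```python
from typing import Dict, List

# Hypothetical exemplar bank: response-class name -> expert-approved text.
# Class names and texts are illustrative, not taken from the paper's data.
exemplars: Dict[str, str] = {
    "sleep_advice": "Try to get 6-8 hours of sleep per night.",
    "greet_pain_scale": "Hi! On a scale of 1-10, how bad is the pain?",
}

# The classifier's output units are indexed by class id; this list fixes the order.
class_names: List[str] = ["sleep_advice", "greet_pain_scale"]


def suggest(class_probs: List[float]) -> str:
    """Return the exemplar text of the highest-probability response class."""
    best = max(range(len(class_probs)), key=lambda i: class_probs[i])
    return exemplars[class_names[best]]


probs = [0.9, 0.1]  # stand-in for the trained classifier's output on some context
print(suggest(probs))  # Try to get 6-8 hours of sleep per night.

# An expert edits the exemplar; the classifier and its weights are untouched.
exemplars["sleep_advice"] = "Try to get 7-9 hours of sleep per night."
print(suggest(probs))  # Try to get 7-9 hours of sleep per night.
```

Because only the lookup table changes, previously collected (context, class) training pairs remain valid after the edit.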

We address a key difficulty in this setup – creating non-overlapping response groups that cover a wide range of situations – with weak supervision. A pretrained similarity model merges nearly identical responses into clusters, and a human merges the most frequently occurring of these clusters into larger response classes. This results in a system that leverages novel pretraining techniques to generate useful responses in a wide variety of contexts, while still restricting generations to a fixed set of high quality responses.

2 Related Work

Healthcare dialog models:

Most published dialog models in healthcare generate templated content supported by a knowledge graph (see Laranjo et al., 2018, for a comprehensive survey). Fitzpatrick et al. (2017) propose Woebot, a conversational agent designed to deliver cognitive behavioral therapy in the form of brief conversations with users. Underlying Woebot is a decision tree, where each node has a piece of content to send to the user, and (for some nodes) a proprietary NLP system determines which node to send the user to based on their most recent reply.

Minutolo et al. (2017) prototype a system for turning medical factoid questions into structured queries over a knowledge graph. The system covers a few example medical conditions and asks the patient for more information until enough slots are filled to execute a valid query. Patient utterances must match a specific set of templates and synonyms to ensure correct queries.

Our work departs from this stream along multiple dimensions. First, we do not assume access to an external knowledge base. Second, we cover a wider range of medical conditions. Third, our model does not require perfect user input to generate good responses.

Generative Dialog Models: Wolf et al. (2019) won the 2019 PersonaChat competition with "TransferTransfo", a generative transformer approach. The model starts training with pretrained weights from the GPT2 transformer, then finetunes on the PersonaChat data with a combination of two loss functions: next-utterance classification loss and language modeling (next word prediction) loss. Generation is performed in a typical generative manner: beam search with sampling, plus a blacklist to prevent copying from old utterances. We compare our architecture to this approach in Table 1 and Section 6.2.


Discriminative Dialog Models: The closest work to ours is Wan and Chen (2018), AirBNB’s customer service chatbot, which also uses a discriminative approach but differs architecturally and does not attempt to cover the whole response space. Whereas our approach restricts the output space to 187 responses that attempt to cover the whole output space, the AirBNB system chooses from 71 one-sentence investigative questions, each representing a cluster of questions, and leaves other response types, such as statements and courtesy questions, to a separate model.

3 Approach

We aim to use the last turns of conversational context to suggest a response for a doctor to send to a patient. Our process involves two stages: (1) creating groups of interchangeable doctor utterances, which serve as labels for (2) training a classifier to predict a response class given the context that preceded it.

Weak Supervision Procedure for Generating Response Classes: We aim to generate response classes, where each response class is a group of interchangeable responses observed in the conversation data, with the following characteristics: (1) low overlap between classes, (2) sufficient training examples in each class, and (3) coverage of a large number of unique responses. The figure below shows our five-stage procedure, and its caption briefly details each step. A more detailed explanation of the same procedure is presented in Appendix 6.1.

Figure 1: The leftmost steps (1) and (2) generate candidate pairs of responses within a semantic neighborhood, to avoid evaluating the similarity of all $O(n^2)$ pairs of $n$ responses. In step 3, we use BERT (Devlin et al., 2018) to evaluate the similarity between the two responses in each candidate pair. Step 4 runs agglomerative clustering, where the distance between two responses is their predicted probability of dissimilarity, or 1 if they were not generated as a candidate pair. Merging uses complete linkage, which means that two clusters are only merged if every response in each cluster is at least 75% similar to every response in the other cluster. Step 5 is manual and somewhat more involved. First, create a dataset of (centroid text, # occurrences of cluster constituents) rows, sorted by # occurrences, where the centroid is the most frequently occurring response in the cluster. For each row in the dataset, the labeler decides whether the cluster centroid text belongs in an existing response class. If the centroid could be used interchangeably with the centroids of an existing class (most of the time), the cluster is added to that response class. Otherwise, the labeler creates a new class with a memorable name.
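Step 4 can be sketched in pure Python as follows. This is a naive illustration, not the authors' implementation: the function name `complete_linkage_clusters`, the `pair_dist` dictionary format, and the toy distances are assumptions, and a real run over tens of thousands of responses would use a library implementation rather than this cubic loop. The 0.25 distance threshold encodes the "at least 75% similar" rule, and pairs missing from `pair_dist` get the maximum distance of 1.

```python
from itertools import combinations
from typing import Dict, FrozenSet, List, Set


def complete_linkage_clusters(
    n: int,
    pair_dist: Dict[FrozenSet[int], float],
    max_dist: float = 0.25,
) -> List[Set[int]]:
    """Cluster responses 0..n-1 with complete linkage.

    pair_dist holds predicted dissimilarities for candidate pairs only; any
    pair missing from it is treated as maximally distant (1.0).
    """

    def d(i: int, j: int) -> float:
        return pair_dist.get(frozenset((i, j)), 1.0)

    clusters: List[Set[int]] = [{i} for i in range(n)]
    while True:
        best = None  # (linkage distance, index a, index b)
        for a, b in combinations(range(len(clusters)), 2):
            # Complete linkage: the cluster distance is the LARGEST pairwise distance.
            link = max(d(i, j) for i in clusters[a] for j in clusters[b])
            if link <= max_dist and (best is None or link < best[0]):
                best = (link, a, b)
        if best is None:  # no pair of clusters is close enough to merge
            return clusters
        _, a, b = best
        clusters[a] |= clusters[b]
        del clusters[b]  # safe because a < b


pair_dist = {
    frozenset((0, 1)): 0.05,  # e.g. "take care" vs "take care!!"
    frozenset((1, 2)): 0.20,
    frozenset((0, 2)): 0.30,  # too dissimilar, so {0, 1} and {2} never merge
}
print(complete_linkage_clusters(4, pair_dist))  # [{0, 1}, {2}, {3}]
```

Under single linkage, response 2 would join {0, 1} through its 0.20 distance to response 1; complete linkage refuses because responses 0 and 2 are only 70% similar, which keeps clusters tight.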

Conversation Context → Response Class: Classification Training with ULMFit. We train our discriminative response suggestion model to classify a conversational context into one of the 187 response classes. (Context, ResponseClass) pairs are only included in the labeled training data if the true response is a member of one of the response classes created in the step above. We follow Howard and Ruder (2018)’s ULMFit approach with a few modifications, most notably adding label smoothing to the loss function. The appendix details all modifications and diagrams the training and inference pipelines.
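The inclusion rule for labeled (context, response class) pairs can be sketched as below. The helper names (`normalize`, `build_training_pairs`), the `(speaker, text)` message format, the class names, and the toy conversation are assumptions for illustration; normalization follows the lowercasing and punctuation removal described in Appendix 6.1, but the paper's exact preprocessing may differ.

```python
import string
from typing import Dict, List, Tuple

Message = Tuple[str, str]  # (speaker, text); speaker is "patient" or "doctor"

_PUNCT = str.maketrans("", "", string.punctuation)


def normalize(text: str) -> str:
    """Lowercase, strip ASCII punctuation, and collapse whitespace."""
    return " ".join(text.lower().translate(_PUNCT).split())


def build_training_pairs(
    conversations: List[List[Message]],
    response_to_class: Dict[str, str],
) -> List[Tuple[List[Message], str]]:
    """Keep (context, class) only when the doctor's reply belongs to a response class."""
    pairs = []
    for conv in conversations:
        for i, (speaker, text) in enumerate(conv):
            if speaker != "doctor":
                continue
            label = response_to_class.get(normalize(text))
            if label is not None:  # unlabeled doctor responses are skipped
                pairs.append((conv[:i], label))
    return pairs


# Hypothetical mapping from normalized response text to response-class name.
response_to_class = {
    "take care": "goodbye",
    "hoping for the best take care": "goodbye",
    "how long have you had the symptoms": "symptom_duration",
}
conv = [
    ("patient", "I have had a cough."),
    ("doctor", "How long have you had the symptoms?"),
    ("patient", "About a week."),
    ("doctor", "Try honey and rest; see you soon."),  # not in any class: dropped
    ("patient", "Thanks!"),
    ("doctor", "Take care!"),
]
pairs = build_training_pairs([conv], response_to_class)
print([label for _, label in pairs])  # ['symptom_duration', 'goodbye']
```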

4 Experiments

Data: For language model finetuning, we use 300,000 doctor/patient interactions containing 1.8 million rounds of doctor/patient exchanges, collected through a web and mobile application for primary care consultations. We use the most recent 100,000 interactions, which contain 500,000 rounds, as input to the response class generation process, which yields 72,981 (context, response) pairs for classification training. The number of turns per conversation (mean: 10.8, std: 7.85) and the length of each turn (mean: 20.4, std: 21.8) vary widely.

Clustering Statistics: Preprocessing and filtering yielded 60,000 responses. Step 2 yielded 1 million candidate pairs for evaluation. Step 4 yielded 40,000 response clusters with many overlapping groups; the largest cluster contains only 10 distinct responses. In step 5, one labeler created 187 groups from the 3,000 most frequently occurring clusters in 3 hours. This leaves 90% of all responses in the data unlabeled. One advantageous property of our approach is that the human merging step need not be fully completed. In other approaches, like that of Wan and Chen (2018), where there are fewer, larger automatically generated clusters and a human iterates through them and removes heterogeneous constituents, every response must be considered. Extrapolating linearly, this would have taken us 40 hours.

We hypothesize that fully automating the clustering process is difficult because the pretrained sentence encoders used in our candidate generation step are misaligned with our merge criterion, which is more permissive than pure semantic similarity. For example, none of our pretrained sentence encoders produce ("You’re welcome. Hoping for the best.", "Take care, my pleasure.") as a candidate pair.

4.1 Evaluation criteria

Expert evaluations for end-to-end comparisons: To compare discriminative and generative approaches, we construct a test set of (conversation, response) pairs held out from the training and validation data. Roughly 91% of test set responses are unlabeled; we call a response unlabeled if it is not an exact duplicate of any response in our 187 response class clusters.

Given the low correlation between automated metrics such as BLEU score and human judgments of response quality reported by Liu et al. (2016), a group of medical doctors evaluated the quality of generated responses on the test data. For a given conversational context, evaluators compared the doctor response observed in the data to a model’s suggested response. Evaluators reported whether a model’s response is (a) equivalent to the observed response, (b) different but higher quality, (c) different but equal quality, or (d) different but lower quality. For example, “Hoping for the best, take care.” and “Take care!!!!” would be marked equivalent.

Accuracy on unseen labeled data is used to compare different classifiers on the same dataset.

4.2 Results

We find that on 775 test set conversations, the discriminative model compares favorably to the generative model, generating responses evaluated as worse than those observed in the data only 12% of the time, compared to 18% for the generative model.

                                  Generative   Discriminative
a. Equivalent to Dr.                 56%            71%
b. Different, higher quality          1%             6%
c. Different, equal quality          25%            11%
d. Different, lower quality          18%            12%

Architecture          4-Turn Accuracy   8-Turn Accuracy   Encoder Finetune Time   Train Time
ULMFit                56.70%            57.00%            12h                     40 min
QRNN                  49.30%            49.20%            0                       2h
Hierarchical ULMFit   53.80%            54.90%            12h                     18h
Hierarchical QRNN     47.80%            49.40%            0                       6h
Transformer           56.64%            56.82%            12h                     6h
Table 1: Architecture comparison; all experiments were completed on a single V100 GPU. The Transformer follows Wolf et al. (2019) and requires 10x slower inference than ULMFit. The hierarchical models follow Serban et al. (2015) and Wan and Chen (2018). Details and discussion of the tradeoffs of different approaches can be found in the appendix.

How much history is useful? We find, somewhat counterintuitively, that the ULMFit classifier does not benefit at all from using more than the last 6 turns of conversation history. A table showing the accuracy using different amounts of history can be found in the appendix.
Well-calibrated probabilities: Since the discriminative model is only trained on (context, response) pairs from a fixed bank of responses, it will occasionally see contexts that do not match any of the responses it was trained on. In these circumstances, it should not suggest a reply to the doctor. Figure 3 shows that if we restrict our evaluations to the 50% of situations where the model is most confident (as measured by its maximum predicted probability), the rate of bad suggested responses falls from 11% to below 2%.
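The opt-out rule and the two quantities plotted in Figure 3 can be sketched as follows. This is an illustrative sketch, not the authors' code: the function names, the fixed `threshold` argument, and the toy probabilities and expert labels are assumptions. Setting the threshold to the median of the maximum predicted probabilities would keep roughly the most confident 50% of contexts.

```python
from typing import List, Optional, Tuple


def suggest_or_opt_out(probs: List[float], threshold: float) -> Optional[int]:
    """Return the argmax class id, or None (no suggestion) if the model is unsure."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return best if probs[best] >= threshold else None


def opt_out_stats(
    max_probs: List[float], is_bad: List[bool], threshold: float
) -> Tuple[float, float]:
    """Return (opt-out frequency, bad-response rate among the suggestions kept)."""
    kept = [bad for p, bad in zip(max_probs, is_bad) if p >= threshold]
    opt_out_freq = 1 - len(kept) / len(max_probs)
    bad_rate = sum(kept) / len(kept) if kept else 0.0
    return opt_out_freq, bad_rate


# Toy data: maximum predicted probability and the expert "worse than doctor" label.
max_probs = [0.9, 0.8, 0.3, 0.2]
is_bad = [False, False, True, False]
print(opt_out_stats(max_probs, is_bad, threshold=0.0))  # (0.0, 0.25)
print(opt_out_stats(max_probs, is_bad, threshold=0.5))  # (0.5, 0.0)
```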
Comparing different labeling procedures: We compare the 187 response group approach described in Section 3 with two other approaches: one using full automation with KMeans (897 clusters), and another using the full procedure with only 20 minutes of manual labeling (40 clusters). According to expert evaluations, both of these approaches generate roughly 35% bad responses, compared to 11% for the 187-class approach, which requires 3 hours of labeling.

5 Conclusion

In this work, we propose a classification model that leverages advances in pretraining techniques to generate useful responses in a wide variety of contexts while restricting generations to a fixed, easy-to-update set of high-quality responses, thereby trading flexibility for control. We find that making this tradeoff also improves the average quality of suggested responses.

The key difficulty in this approach, and an opportunity for future work, is the grouping of responses into classes. We also intend to test whether trading flexibility for control provides similar quality improvements in other conversational domains.


References

  • P. Budzianowski, T. Wen, B. Tseng, I. Casanueva, S. Ultes, O. Ramadan, and M. Gašić (2018) MultiWOZ - a large-scale multi-domain wizard-of-oz dataset for task-oriented dialogue modelling. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP).
  • A. Conneau, D. Kiela, H. Schwenk, L. Barrault, and A. Bordes (2017) Supervised learning of universal sentence representations from natural language inference data. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, Copenhagen, Denmark, pp. 670–680.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • K. K. Fitzpatrick, A. Darcy, and M. Vierhile (2017) Delivering cognitive behavior therapy to young adults with symptoms of depression and anxiety using a fully automated conversational agent (Woebot): a randomized controlled trial. JMIR Mental Health 4 (2), pp. e19.
  • J. Howard and S. Ruder (2018) Universal language model fine-tuning for text classification. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 328–339.
  • L. Laranjo, A. G. Dunn, H. L. Tong, A. B. Kocaballi, J. Chen, R. Bashir, D. Surian, B. Gallego, F. Magrabi, A. Y. S. Lau, and E. Coiera (2018) Conversational agents in healthcare: a systematic review. Journal of the American Medical Informatics Association 25 (9), pp. 1248–1258.
  • C. Liu, R. Lowe, I. V. Serban, M. Noseworthy, L. Charlin, and J. Pineau (2016) How NOT to evaluate your dialogue system: an empirical study of unsupervised evaluation metrics for dialogue response generation. CoRR abs/1603.08023.
  • S. Merity, C. Xiong, J. Bradbury, and R. Socher (2016) Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843.
  • A. Minutolo, M. Esposito, and G. De Pietro (2017) A conversational chatbot based on kowledge-graphs for factoid medical questions. In SoMeT, pp. 139–152.
  • J. Pennington, R. Socher, and C. Manning (2014) GloVe: global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), pp. 1532–1543.
  • G. Pereyra, G. Tucker, J. Chorowski, L. Kaiser, and G. E. Hinton (2017) Regularizing neural networks by penalizing confident output distributions. CoRR abs/1701.06548.
  • F. Petroni et al. (2019) Language models as knowledge bases?
  • A. See, S. Roller, D. Kiela, and J. Weston (2019) What makes a good conversation? How controllable attributes affect human judgments. CoRR abs/1902.08654.
  • I. V. Serban, A. Sordoni, Y. Bengio, A. C. Courville, and J. Pineau (2015) Hierarchical neural network generative models for movie dialogues. CoRR abs/1507.04808.
  • E. Wallace, S. Feng, N. Kandpal, M. Gardner, and S. Singh (2019) Universal adversarial triggers for attacking and analyzing NLP. arXiv:1908.07125.
  • M. Wan and X. Chen (2018) Beyond "how may I help you?": assisting customer service agents with proactive responses. CoRR abs/1811.10686.
  • A. Wang, A. Singh, J. Michael, F. Hill, O. Levy, and S. Bowman (2018) GLUE: a multi-task benchmark and analysis platform for natural language understanding. In Proceedings of the 2018 EMNLP Workshop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP.
  • T. Wolf, V. Sanh, J. Chaumond, and C. Delangue (2019) TransferTransfo: a transfer learning approach for neural network based conversational agents. CoRR abs/1901.08149.

6 Appendix

6.1 Detailed Response Class Generation Procedure

  1. Automatically Cluster Similar Responses

    1. We lowercase all responses seen in the data, remove patient and doctor identifying information, and remove punctuation. We consider only preprocessed responses that occur more than once, to make subsequent steps computationally cheaper.

    2. Estimating the similarity of every pair of $n$ responses is an $O(n^2)$ operation, and most pairs are likely to have negligible similarity. Therefore, we compute similarities only for responses that are within a semantic neighborhood. More specifically, we encode each response as a vector using four sentence encoders: InferSent [Conneau et al., 2017], the finetuned AWD-LSTM language model, the average GloVe [Pennington et al., 2014] word vector for the response, and the TF-IDF weighted average of the GloVe vectors. For each encoder, we take the 10 nearest neighbors of each response as candidate pairs.

    3. For each candidate pair, we run a supervised similarity model, BERT [Devlin et al., 2018] finetuned on Quora Question Pairs [Wang et al., 2018], to predict the probability that the pair's two responses are semantically similar. We store the dissimilarity of each pair in a sparse distance matrix, with a distance of 1 (the maximum) if two responses were not generated as a candidate pair.

    4. The last step is agglomerative clustering on this sparse distance matrix, where the distance between two responses is their predicted probability of dissimilarity, or 1 if they were not generated as a candidate pair. Merging uses complete linkage, which means that two clusters are only merged if every response in each cluster is at least 75% similar to every response in the other cluster.

  2. Manually Merge Clusters into Response Classes

    1. Create a dataset of (centroid text, # occurrences of cluster constituents) rows, sorted by # occurrences, where the centroid is the most frequently occurring response in the cluster.

    2. For each row in the dataset, the labeler decides whether the cluster centroid text belongs in any existing response class.

      1. If centroid belongs in an existing group (most of the time): add the cluster to the existing response class.

      2. Otherwise: create a new group with a memorable name, e.g., “Greet + Pain Scale Question”.

    We merge all responses that have the same impact on the user, and could therefore be used interchangeably. For example, even though “How long have you had the symptoms?” and “When did the symptoms start?” do not mean the same thing, they are both members of the same response class.
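Steps 1.2 and 2.1 above can be sketched as follows. The sketch makes several assumptions not stated in the paper: neighbors are ranked by cosine similarity, the vectors are given as plain lists, and the function names `candidate_pairs` and `centroid_table` and all toy data are hypothetical. The neighbor search is brute force for clarity; the saving that matters is that only the resulting candidate pairs are scored by the expensive BERT similarity model.

```python
import math
from collections import Counter
from typing import List, Sequence, Set, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def candidate_pairs(encodings: List[List[Sequence[float]]], k: int = 10) -> Set[Tuple[int, int]]:
    """Union, over all encoders, of each response's k nearest neighbors.

    encodings[e][i] is encoder e's vector for response i.
    """
    pairs: Set[Tuple[int, int]] = set()
    for vectors in encodings:
        for i, u in enumerate(vectors):
            neighbors = sorted(
                (j for j in range(len(vectors)) if j != i),
                key=lambda j: cosine(u, vectors[j]),
                reverse=True,
            )
            for j in neighbors[:k]:
                pairs.add((min(i, j), max(i, j)))  # store each pair once
    return pairs


def centroid_table(clusters: List[Set[str]], counts: Counter) -> List[Tuple[str, int]]:
    """Rows of (centroid text, total occurrences), most frequent clusters first."""
    rows = []
    for cluster in clusters:
        centroid = max(cluster, key=lambda r: (counts[r], r))  # most frequent; ties by text
        rows.append((centroid, sum(counts[r] for r in cluster)))
    rows.sort(key=lambda row: row[1], reverse=True)
    return rows


# Two toy "encoders" over four responses; each finds different neighbors.
enc_a = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
enc_b = [[1.0, 0.0], [0.0, 1.0], [0.1, 1.0], [1.0, 0.1]]
print(sorted(candidate_pairs([enc_a, enc_b], k=1)))  # [(0, 1), (0, 3), (1, 2), (2, 3)]

counts = Counter({"take care": 50, "take care my pleasure": 5,
                  "how long have you had the symptoms": 30})
clusters = [{"take care", "take care my pleasure"}, {"how long have you had the symptoms"}]
print(centroid_table(clusters, counts))
# [('take care', 55), ('how long have you had the symptoms', 30)]
```

Taking the union across encoders is what lets one encoder recover pairs another misses, as in the toy example where each encoder contributes two different pairs.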

6.2 ULMFit Modifications

Like the original work, we start with an AWD-LSTM language model pretrained on the WikiText-103 dataset [Merity et al., 2016], finetune the language model on our interaction history, and attach a classifier head that predicts the response class from the concat-pooled representation of the language model's final hidden states.
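Concat pooling, as described by Howard and Ruder (2018), concatenates the last hidden state with an elementwise max pool and an elementwise mean pool over all time steps. A minimal sketch on plain lists follows; the function name `concat_pool` and the toy numbers are hypothetical, and which layer's outputs feed the pooling in this system is an assumption.

```python
from typing import List


def concat_pool(hidden_states: List[List[float]]) -> List[float]:
    """Concatenate [last hidden state, elementwise max, elementwise mean] over time."""
    if not hidden_states:
        raise ValueError("need at least one time step")
    last = hidden_states[-1]
    dims = range(len(last))
    max_pool = [max(h[d] for h in hidden_states) for d in dims]
    mean_pool = [sum(h[d] for h in hidden_states) / len(hidden_states) for d in dims]
    return last + max_pool + mean_pool  # 3x the hidden size


# Three time steps of a 2-dimensional hidden state (toy numbers).
states = [[1.0, -2.0], [3.0, 0.0], [2.0, 5.0]]
print(concat_pool(states))  # [2.0, 5.0, 3.0, 5.0, 2.0, 1.0]
```

The classifier head then maps this fixed-size vector to scores over the 187 response classes, regardless of how long the context was.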

To accommodate a larger batch size than the original work, which we found helps performance, we truncate context sequences to the last 304 tokens before passing them through the language model. This allows us to train with batches of 512 examples and to adjust the learning rate commensurately.

To encode information about speaker changes, we insert two special tokens: one that indicates the beginning of the user’s turn and one that indicates the beginning of the doctor’s turn.
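The two preprocessing choices above (speaker tokens and keeping only the most recent 304 tokens) can be sketched together. The token strings `<patient>` and `<doctor>`, the function name `encode_context`, and the whitespace tokenizer are hypothetical stand-ins; the paper does not name its special tokens, and ULMFit uses its own tokenizer.

```python
from typing import List, Tuple

# Hypothetical names for the two speaker-change tokens; the paper does not give them.
PATIENT_TOKEN = "<patient>"
DOCTOR_TOKEN = "<doctor>"


def encode_context(turns: List[Tuple[str, str]], max_tokens: int = 304) -> List[str]:
    """Prefix each turn with its speaker token, then keep only the last max_tokens tokens."""
    tokens: List[str] = []
    for speaker, text in turns:
        tokens.append(DOCTOR_TOKEN if speaker == "doctor" else PATIENT_TOKEN)
        tokens.extend(text.lower().split())  # whitespace split stands in for the real tokenizer
    return tokens[max(0, len(tokens) - max_tokens):]  # truncate from the left


context = [("patient", "I feel dizzy"), ("doctor", "Since when?")]
print(encode_context(context))
# ['<patient>', 'i', 'feel', 'dizzy', '<doctor>', 'since', 'when?']
print(encode_context(context, max_tokens=3))  # ['<doctor>', 'since', 'when?']
```

Truncating from the left keeps the turns closest to the response being predicted, which the history experiments in Section 6.3 suggest carry most of the signal.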

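Finally, we add label smoothing [Pereyra et al., 2017] with smoothing parameter $\varepsilon$ to the cross-entropy loss function. Label smoothing smooths one-hot encoded classification labels toward the uniform distribution over the $K$ response classes, replacing each target $y_k$ with $(1-\varepsilon)\,y_k + \varepsilon/K$, and reduces the impact of mislabeled examples on classification training.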
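A minimal sketch of a label-smoothed cross-entropy on plain lists follows. The function names are hypothetical, and the `eps=0.1` in the example is illustrative only: the value of $\varepsilon$ used in the paper is not recoverable from this text.

```python
import math
from typing import List


def smoothed_targets(label: int, num_classes: int, eps: float) -> List[float]:
    """Mix the one-hot target with the uniform distribution: (1 - eps) * y + eps / K."""
    return [(1 - eps if k == label else 0.0) + eps / num_classes for k in range(num_classes)]


def label_smoothed_ce(logits: List[float], label: int, eps: float) -> float:
    """Cross-entropy between the smoothed target and softmax(logits)."""
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    targets = smoothed_targets(label, len(logits), eps)
    return -sum(t * (x - log_z) for t, x in zip(targets, logits))


# With eps = 0 this reduces to ordinary cross-entropy.
print(label_smoothed_ce([2.0, 0.0], label=0, eps=0.0))
# Smoothing penalizes a very confident (even correct) prediction.
print(label_smoothed_ce([5.0, 0.0], label=0, eps=0.1) > label_smoothed_ce([5.0, 0.0], label=0, eps=0.0))  # True
```

Because every class keeps a target mass of at least $\varepsilon/K$, a mislabeled example cannot push the model toward arbitrarily confident wrong predictions, which is the robustness effect described above.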

Figure 2: Inference and training procedures, starting from a conversational context (left).

Classification architecture comparison: To facilitate comparison with the hierarchical encoding paradigm used by Wan and Chen [2018], we tested two hierarchical architectures: hierarchical ULMFit (pretrained) and hierarchical QRNN (trained from scratch). All QRNN-based experiments use random initialization and 3 layers with hidden size 64. In both settings, the higher-level context RNN was a randomly initialized QRNN. We found that non-hierarchical ULMFit significantly outperformed its hierarchical counterpart, while hierarchical and flat QRNN performed comparably. We attribute part of this discrepancy with previous work to the large variance in turn length in our data. Turns vary from 2 to 304 tokens after truncation, so models that consume 3D hierarchical encodings must process large amounts of padding and use smaller batch sizes. Hierarchical ULMFit on 8 turns could only be trained with batch size 32, while the non-hierarchical model fits 512 examples in each batch. To compare with Wolf et al. [2019], we finetune a pretrained double-headed transformer on our conversation data, discard the language modeling and multiple choice heads, and attach a one-layer classification head that is trained until convergence. As shown in Table 1, this results in similar accuracy to the ULMFit architecture but is much more computationally expensive (10x train time, 20x slower inference).
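The padding argument can be made concrete by counting padded cells for the same batch in both layouts. This is a simplified sketch: each example is reduced to a hypothetical list of per-turn token lengths, the function names are illustrative, and it ignores tokens that flat truncation might drop.

```python
from typing import List

# Each example is a list of per-turn token lengths (toy numbers).


def flat_padding(batch: List[List[int]]) -> int:
    """Pad cells when each example is one concatenated sequence: shape (batch, max_total)."""
    totals = [sum(turns) for turns in batch]
    return len(batch) * max(totals) - sum(totals)


def hierarchical_padding(batch: List[List[int]]) -> int:
    """Pad cells in a 3D (batch, max_turns, max_turn_len) tensor."""
    max_turns = max(len(turns) for turns in batch)
    max_turn_len = max(max(turns) for turns in batch)
    real_tokens = sum(sum(turns) for turns in batch)
    return len(batch) * max_turns * max_turn_len - real_tokens


batch = [[2, 290, 5], [10, 3, 8]]  # one very long turn dominates max_turn_len
print(flat_padding(batch))          # 276
print(hierarchical_padding(batch))  # 1422
```

In the 3D layout every turn slot is padded to the single longest turn in the batch, so one long turn inflates memory for all examples, which is consistent with the much smaller batch size the hierarchical model could fit.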

6.3 How much history is useful?

Max Turns of History   1       2       3       4       5       6       7       All
Accuracy               44.5%   53.3%   55.3%   56.7%   56.3%   57.7%   57.4%   57.0%
Table 2: One turn is all messages sent consecutively by one conversation participant. Observations are truncated to the most recent turns, up to the maximum shown in each column.
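The turn definition in Table 2 (all consecutive messages from one participant) and the truncation to the most recent turns can be sketched as below. The function names, the `(speaker, text)` message format, and joining merged messages with a single space are assumptions made for illustration.

```python
from itertools import groupby
from typing import List, Optional, Tuple

Message = Tuple[str, str]  # (speaker, text)


def to_turns(messages: List[Message]) -> List[Message]:
    """Merge consecutive messages from the same speaker into one turn."""
    return [
        (speaker, " ".join(text for _, text in group))
        for speaker, group in groupby(messages, key=lambda m: m[0])
    ]


def last_turns(messages: List[Message], max_turns: Optional[int]) -> List[Message]:
    """Keep only the most recent max_turns turns; None keeps the full history."""
    turns = to_turns(messages)
    if max_turns is None:
        return turns
    return turns[max(0, len(turns) - max_turns):]


messages = [
    ("patient", "hi"),
    ("patient", "my head hurts"),
    ("doctor", "sorry to hear that"),
    ("doctor", "since when?"),
    ("patient", "yesterday"),
]
print(last_turns(messages, 2))
# [('doctor', 'sorry to hear that since when?'), ('patient', 'yesterday')]
```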

6.4 Comparing Different Labeling Procedures

# Classes   Train Examples   Bad Responses   Unique per 100 Responses
40          19,300           38%             17
187         72,981           11%             28
879         86,941           34%             49
Table 3: The 40- and 187-class label sets were generated with the process described in Section 3, including the manual merge step. The 879-class label set was generated with KMeans, with no manual merging or review. The bad response percentage is calculated on 100 test set examples using the manual evaluation process outlined above. Unique per 100 responses measures how many unique responses are generated per 100 conversation contexts, and is computed on 1000 test set suggestions.
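The caption does not say exactly how 1000 suggestions are turned into a per-100 figure. One plausible reading, sketched below as an assumption, averages the number of distinct suggestions over consecutive blocks of 100 contexts; the function name `unique_per_100` and the toy class names are hypothetical.

```python
from typing import Sequence


def unique_per_100(suggestions: Sequence[str], block: int = 100) -> float:
    """Average number of distinct suggestions per full block of `block` contexts."""
    starts = range(0, len(suggestions) - block + 1, block)
    counts = [len(set(suggestions[s:s + block])) for s in starts]
    if not counts:
        raise ValueError("need at least one full block of suggestions")
    return sum(counts) / len(counts)


# 200 toy suggestions: the first block alternates two classes, the second repeats one.
toy = ["greet", "goodbye"] * 50 + ["goodbye"] * 100
print(unique_per_100(toy))  # 1.5
```

Under this reading, 1000 suggestions give ten blocks, and a higher value means the model's suggestions are more varied across contexts.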

6.5 Opting out at different thresholds

Figure 3: The rate of bad suggested responses falls if we "opt-out", and don’t suggest any response when the model’s predicted probability is low. "Opt Out Frequency" measures how often the model chooses not to suggest a response, while "Usable Suggestion Rate" measures how often the suggested response is not worse than the doctor response observed in the data.