Fast Nearest-Neighbor Classification using RNN in Domains with Large Number of Classes

12/11/2017 ∙ by Gautam Singh, et al. ∙ IBM

In text classification scenarios where the number of classes is large (tens of thousands) and the training samples for each class are few and often verbose, nearest-neighbor methods are effective but very slow in computing a similarity score with the training samples of every class. On the other hand, machine learning models are fast at runtime, but training them adequately is not feasible with the few training samples available per class. In this paper, we propose a hybrid approach that cascades 1) a fast but less-accurate recurrent neural network (RNN) model and 2) a slow but more-accurate nearest-neighbor model using a bag of syntactic features. Using the cascaded approach, our experiments, performed on a data set from IT support services where customer complaint text needs to be classified to return the top-N possible error codes, show that the query time of the slow system is reduced to 1/6th while its accuracy improves. Our approach outperforms an LSH-based baseline for query-time reduction. We also derive a lower bound on the accuracy of the cascaded model in terms of the accuracies of the individual models. In any two-stage approach, choosing the right number of candidates to pass on to the second stage is crucial. We prove a result that aids in choosing this cutoff number for the cascaded system.




1 Introduction

In the spectrum of text classification tasks, the number of class labels can be two (binary), more than two (multi-class), or very large (more than 10000). In industry, text classification tasks with a large number of classes arise naturally. In the domain of IT customer support, a user complaint text is classified to return the top-$N$ most likely error codes (from potentially tens of thousands of options) that the product could be having. Another example comes from health insurance, where patients ask whether their insurance covers a certain diagnosis or treatment. Such patient queries need to be classified into the top-$N$ appropriate medical codes, which are looked up against a database to serve an automated response to the patient.

In this setting, not enough training samples are available to adequately train an effective ML-based model [13]. To deal with this challenge, one work [6] proposes hierarchical classification, where a hierarchy among class labels is known beforehand. For example, an item is classified into top-level categories (“computer” or “sports”) and then further classified into sub-categories (“computer/hardware”, “computer/software”, etc.). In another approach [13], the hierarchy among classes is not known. Instead, a hierarchy is constructed from the “flat” class labels through repeated clustering of the classes.

In this paper, we adopt a different approach. As the number of class labels grows, the task of text classification increasingly resembles the task of document retrieval (or search), and our approach makes use of this observation. Retrieval methods using sophisticated features are effective but very slow at prediction time, whereas ML models are fast but imprecise in the given setting. A common approach in the retrieval domain uses two stages: 1) a filtering stage, a fast, imprecise and inexpensive stage that generates candidate documents, and 2) a ranking stage, a sophisticated retrieval module that uses complex features (phrase-level or syntactic-level) to re-rank the candidate documents. This two-stage approach mitigates the trade-off between speed and accuracy. By analogy, in this paper, we use a statistical, ML-based model as the first stage (i.e., candidate generation); this stage is fast but has low accuracy. Next, in the second stage, we use expensive syntactic NLP features and similarity scoring on the candidate classes to generate the final top-$N$ predicted classes; this stage is slow but more accurate.

A number of ML-based models exist for text classification, such as regression models [11], Bayesian models [8] and emerging deep neural networks [9, 7, 4, 17, 18]. On the other hand, many approaches use syntactic NLP-based features for nearest-neighbor text classification [12, 15]. Another approach uses word2vec to incorporate word similarity into nearest-neighbor-based text classification [16]. For candidate generation, hashing is a well-known technique. Hashing techniques can be either data-agnostic, such as locality sensitive hashing [1, 3], or data-dependent, such as learning to hash [14]. Candidate generation is also classified as conjunctive if the returned candidates contain all the terms in the query, and disjunctive if they contain at least one term from the query [5, 2].

The main contributions of this paper are as follows:

  1. We propose a hybrid model for text classification that cascades a fast but less-accurate recurrent neural network model and a slow but more-accurate retrieval model that uses a bag of syntactic features. We experimentally show that cascading reduces the query time of the slow retrieval model to 1/6th while improving upon its accuracy. (Section 4)

  2. We prove a meaningful lower bound on the accuracy of the cascaded model in terms of the accuracies of the individual models. The result is generic and can be applied to any cascaded retrieval model. (Section 4.1)

  3. Choosing the number $K$ of candidate classes to pass on to the second stage in any cascaded model is crucial to its performance. If $K$ is too small, the accuracy of the second stage suffers. If $K$ is too large, the speed of the model suffers. To choose $K$, past works typically perform a grid search or test values at regular intervals within a desirable range. In this paper, we prove that, in order to choose the best $K$, we need to test the accuracy of the cascaded model only on a few special values of $K$ rather than on all possible values within a desirable range. The result is generic and can be applied to any cascaded retrieval model. (Section 4.2)

  4. We show that our cascaded model outperforms a baseline for speeding-up the slow retrieval model using locality sensitive hashing (LSH). (Section 5)

2 Nearest-Neighbor Model using Bag of Syntactic Features

In this section, we describe the aforementioned slow nearest-neighbor based model. In this technique, we first perform dependency parsing on the text. A dependency parser takes a sentence as input and returns a tree

$$T = (V, E),$$

where $V$ is the set of all nodes (words) in the sentence and $E$ is the set of all edges, or 3-tuples $(w_i, r, w_j)$, in the tree.

Any directed edge $(w_i, r, w_j) \in E$ represents a grammatical relation $r$ between the connected words $w_i$ and $w_j$. These relations have labels such as nsubj, dobj, advmod, etc., which represent the grammatical function fulfilled by the connected word pair.

Next, we take the word pair on each edge of the dependency tree and concatenate their word vectors [10] to get a bag of syntactic feature vectors. This is shown in Algorithm 1. Notice that in the algorithm, weights are assigned to the words during concatenation. These weights are based on heuristics: nouns, adjectives and verbs are given higher weight than other parts of speech.

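procedure GenerateBagOfSyntacticFeatures(text)
    T = (V, E) ← DependencyParse(text)
    B ← ∅
    for each edge (w_i, r, w_j) ∈ E do
        λ_i, λ_j ← heuristic weights of w_i and w_j based on their parts of speech
        f ← [λ_i · vec(w_i) ; λ_j · vec(w_j)]    ▷ weighted concatenation of word vectors
        B ← B ∪ {f}
    return B
Algorithm 1 Generate Bag of Syntactic Features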
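Algorithm 1 can be sketched in Python as follows. This is a minimal sketch, assuming the dependency edges have already been produced by some parser as `(head, relation, dependent)` triples and that word vectors are available in a plain dictionary. The function name and the POS weight values are hypothetical; the text only says that nouns, adjectives and verbs receive higher weights than other parts of speech.

```python
import numpy as np

# Hypothetical heuristic weights: the text only states that nouns, adjectives
# and verbs are weighted higher than other parts of speech.
POS_WEIGHTS = {"NOUN": 2.0, "ADJ": 2.0, "VERB": 2.0}
DEFAULT_WEIGHT = 1.0


def bag_of_syntactic_features(edges, pos_tags, word_vectors):
    """Return one feature vector per dependency edge (hypothetical helper).

    edges        -- list of (head, relation, dependent) triples from a parser
    pos_tags     -- dict mapping each word to its part-of-speech tag
    word_vectors -- dict mapping each word to a 1-D numpy array (e.g. word2vec)
    """
    bag = []
    for head, _relation, dependent in edges:
        # Skip pairs that contain out-of-vocabulary words.
        if head not in word_vectors or dependent not in word_vectors:
            continue
        w_head = POS_WEIGHTS.get(pos_tags.get(head), DEFAULT_WEIGHT)
        w_dep = POS_WEIGHTS.get(pos_tags.get(dependent), DEFAULT_WEIGHT)
        # Weighted concatenation of the two word vectors.
        feature = np.concatenate([w_head * word_vectors[head],
                                  w_dep * word_vectors[dependent]])
        bag.append(feature)
    return bag


# Usage with toy 2-dimensional vectors for "tape drive failed".
vectors = {"failed": np.array([1.0, 0.0]),
           "drive": np.array([0.0, 1.0]),
           "tape": np.array([1.0, 1.0])}
tags = {"failed": "VERB", "drive": "NOUN", "tape": "NOUN"}
edges = [("failed", "nsubj", "drive"), ("drive", "compound", "tape")]
bag = bag_of_syntactic_features(edges, tags, vectors)
```

The relation label is ignored in the vector itself, since the text describes only a concatenation of the two word vectors.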

2.0.1 Computing similarity of text with a class in training set

Given a text query $q$, we next compute its similarity with a particular class $c$ in the training set. Let $D_c$ denote the set of texts in the training set corresponding to class $c$, let $\mathcal{B}_c$ denote the set of bags of syntactic features corresponding to the texts in $D_c$, and let $B_q$ denote the bag of syntactic features of the query text. The similarity of $q$ with class $c$ is

$$\mathrm{sim}(q, c) = \max_{B \in \mathcal{B}_c} S(B_q, B),$$

where $S$ is a cosine-based similarity between two bags of feature vectors.

In the above similarity metric, we compute the cosine similarity between the feature vectors of the query text and those of the texts corresponding to class $c$ in the training set. The similarity of the best-matching text is taken as the similarity score for class $c$. Finally, the $N$ highest-scoring classes for the given query are returned.
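The class-scoring step can be sketched as below. The text does not say how the cosines between the individual feature vectors of two bags are aggregated into $S(B_q, B)$, so the averaging of best-match cosines used here is an assumption; the max over the texts of a class follows the description above. All function names are hypothetical.

```python
import numpy as np


def cosine(u, v):
    """Cosine similarity of two feature vectors (0 if either is all zeros)."""
    denom = np.linalg.norm(u) * np.linalg.norm(v)
    return float(np.dot(u, v) / denom) if denom > 0 else 0.0


def bag_similarity(query_bag, text_bag):
    """S(B_q, B): ASSUMED aggregation, not specified in the text.

    Each query feature is matched to its most similar feature in the training
    text, and these best-match cosines are averaged.
    """
    if not query_bag or not text_bag:
        return 0.0
    return float(np.mean([max(cosine(q, t) for t in text_bag) for q in query_bag]))


def class_similarity(query_bag, class_bags):
    """sim(q, c): similarity of the best-matching training text of class c."""
    return max((bag_similarity(query_bag, b) for b in class_bags), default=0.0)


def top_n_classes(query_bag, bags_by_class, n):
    """Score every class in the training set and return the n best ones."""
    scores = {c: class_similarity(query_bag, bags) for c, bags in bags_by_class.items()}
    return sorted(scores, key=scores.get, reverse=True)[:n]
```

Note that `top_n_classes` touches every training text of every class, which is the cost the cascade in Section 4 is designed to avoid.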

3 Recurrent Models for Text Classification

For text classification using recurrent models, text is converted into a sequence of word-vectors and given as input to the model. In recurrent models, the words in text may be processed from left to right. In each iteration, previous hidden state and a word are processed to return a new hidden state. In this paper, we experiment with two kinds of recurrent models 1) GRU [4] and 2) LSTM [7]. We describe below the details only for the GRU model.

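The GRU model is parametric and is defined by six matrices $W_z$, $U_z$, $W_r$, $U_r$, $W_h$, $U_h$ and an output matrix $W_o$. The recurrence equations are given below.

Initialize the hidden state $h_0$ as a zero vector.

For iteration $t = 1, \dots, L$, compute the following:

$$z_t = \sigma(W_z x_t + U_z h_{t-1})$$
$$r_t = \sigma(W_r x_t + U_r h_{t-1})$$
$$\tilde{h}_t = \tanh\big(W_h x_t + U_h (r_t \odot h_{t-1})\big)$$
$$h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t$$

where $x_t$ is the word vector of the $t$-th word, $L$ is the number of words in the text, $\sigma$ refers to the sigmoid function and $\odot$ denotes element-wise multiplication.

Termination and Computing Output Probability Distribution

The latest hidden state $h_L$ is passed through a softmax layer to generate an output probability distribution

$$p = \mathrm{softmax}(W_o h_L).$$

We return the classes corresponding to the top-$N$ probability values in $p$.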
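The GRU forward pass and top-$N$ selection described above can be sketched in NumPy as follows. This is an inference-only sketch with hypothetical parameter names (`W_z`, `U_z`, ..., `W_o` as dictionary keys); bias terms are omitted because the text defines the model by its matrices only, and training is not shown.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x):
    e = np.exp(x - np.max(x))  # shift for numerical stability
    return e / e.sum()


def gru_top_n(word_vectors, params, n):
    """Run the GRU recurrence over a text and return the top-n class indices.

    word_vectors -- array of shape (L, d), one word vector per word
    params       -- dict with W_z, U_z, W_r, U_r, W_h, U_h and output matrix W_o
    """
    W_z, U_z = params["W_z"], params["U_z"]
    W_r, U_r = params["W_r"], params["U_r"]
    W_h, U_h = params["W_h"], params["U_h"]
    h = np.zeros(U_z.shape[0])  # h_0 is the zero vector
    for x in word_vectors:
        z = sigmoid(W_z @ x + U_z @ h)               # update gate z_t
        r = sigmoid(W_r @ x + U_r @ h)               # reset gate r_t
        h_tilde = np.tanh(W_h @ x + U_h @ (r * h))   # candidate state
        h = z * h + (1.0 - z) * h_tilde              # new hidden state h_t
    p = softmax(params["W_o"] @ h)                   # distribution over classes
    return list(np.argsort(-p)[:n]), p
```

With an empty text, the hidden state stays at zero and the output distribution is uniform, which is a convenient sanity check.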

4 Cascaded Model for Fast and Accurate Retrieval

A retrieval model using a bag of syntactic features is an example of nearest-neighbor classification. For a given query $q$, it requires the similarity score to be computed with every sample in the training set. If instead the first stage of the cascade filters a few candidate classes, the slowness of the retrieval model is overcome. We denote the recurrent machine learning model as $M_1$ and the slow nearest-neighbor classifier as $M_2$.


The correct class to which query $q$ belongs is denoted by $y_q$. We denote the number of candidate classes returned by the first stage by $K$ and the set of these candidate classes by $M_1^K(q)$. We denote the number of classes to be returned by the second stage as $N$; therefore $N \le K$. $M_{2|1}^{K,N}(q)$ denotes the set of $N$ classes returned by the second stage after inspecting only the classes in $M_1^K(q)$. We use $M_2^N(q)$ to denote the set of classes returned by the second stage if it were to inspect all classes in the training set without any cascading. We define an empirical accuracy metric over a validation set $Q$ containing $|Q|$ user queries as follows:

$$A_N(M) = \frac{\left|\{q \in Q : y_q \in M^N(q)\}\right|}{|Q|}$$

The numerator is the number of queries whose correct class is among the top-$N$ suggestions $M^N(q)$ returned by text classifier $M$. The denominator is the total number of queries. We write $A_N(M_{2|1}^{K})$ for the accuracy of the cascaded model that passes $K$ candidates to the second stage.
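The accuracy metric and a single cascaded prediction can be sketched as follows. This is a minimal sketch, assuming the first stage exposes a score for every class and the second stage can be queried class by class; the function names are hypothetical.

```python
from typing import Callable, Dict, List, Sequence


def top_n_accuracy(predictions: Sequence[List[str]], gold: Sequence[str]) -> float:
    """A_N(M): fraction of queries whose correct class is among the returned classes."""
    hits = sum(1 for returned, y in zip(predictions, gold) if y in returned)
    return hits / len(gold)


def cascade_predict(m1_scores: Dict[str, float],
                    m2_score: Callable[[str], float],
                    k: int, n: int) -> List[str]:
    """M_{2|1}^{K,N}(q) for one query q.

    m1_scores -- fast first-stage score for every class (e.g. RNN probabilities)
    m2_score  -- slow second-stage similarity, evaluated only on the candidates
    """
    assert n <= k, "the second stage cannot return more than K classes"
    candidates = sorted(m1_scores, key=m1_scores.get, reverse=True)[:k]  # M_1^K(q)
    return sorted(candidates, key=m2_score, reverse=True)[:n]


# Usage: the slow scorer is called only for the K candidates.
m1 = {"a": 0.5, "b": 0.3, "c": 0.2}
m2 = {"a": 0.1, "b": 0.9, "c": 0.95}
prediction = cascade_predict(m1, m2.get, k=2, n=1)
```

With $K = 2$, class `c` never reaches the second stage even though $M_2$ would rank it first, which illustrates why the choice of $K$ matters.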

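Before describing the proofs, we define two empirical quantities, $\alpha_K$ and $\beta_K$, associated with the cascaded model, which are easy to estimate using a validation set:

$$\alpha_K = P\big(y_q \in M_1^K(q) \mid y_q \in M_2^N(q)\big), \qquad \beta_K = P\big(y_q \in M_1^K(q)\big).$$

$\alpha_K$ is easy to compute as follows; $\beta_K$ is computed analogously.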

  1. Run both $M_1$ and $M_2$ on the validation set and store the match scores for each class.

  2. For each $K$, find the number of queries whose correct class is present both in the top-$K$ of $M_1$ and in the top-$N$ of $M_2$. Also find the number of queries whose correct class is present in the top-$N$ of $M_2$.

  3. Find the ratio of the above two numbers for each $K$.
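The three estimation steps above can be sketched as follows, assuming each model's output for a validation query is available as a full class ranking; the function name and input layout are hypothetical. $\beta_K$ is estimated analogously as the fraction of all queries whose correct class is in the top-$K$ of $M_1$.

```python
def estimate_alpha_beta(m1_rankings, m2_rankings, gold, n, k_values):
    """Estimate alpha_K and beta_K on a validation set.

    m1_rankings -- per query, all classes ranked by the first stage M_1
    m2_rankings -- per query, all classes ranked by the uncascaded second stage M_2
    gold        -- per query, the correct class y_q
    """
    in_m2 = [y in ranking[:n] for ranking, y in zip(m2_rankings, gold)]  # y_q in M_2^N(q)
    num_in_m2 = sum(in_m2)
    alpha, beta = {}, {}
    for k in k_values:
        in_m1 = [y in ranking[:k] for ranking, y in zip(m1_rankings, gold)]  # y_q in M_1^K(q)
        both = sum(a and b for a, b in zip(in_m1, in_m2))
        alpha[k] = both / num_in_m2 if num_in_m2 else 0.0
        beta[k] = sum(in_m1) / len(gold)
    return alpha, beta


# Usage on three toy validation queries.
m1_ranks = [["a", "b", "c"], ["c", "b", "a"], ["b", "a", "c"]]
m2_ranks = [["a", "c", "b"], ["a", "b", "c"], ["c", "b", "a"]]
alpha, beta = estimate_alpha_beta(m1_ranks, m2_ranks, ["a", "a", "b"], n=1, k_values=[1, 2, 3])
```

Both stages are run once per validation query, and all values of $K$ are then evaluated from the stored rankings without re-running the slow model.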

In this paper, we assume that empirical estimates of probability values using the validation set are good approximations of their actual values.

4.1 Lower bound on accuracy of cascaded model

The idea is to show that the accuracy of the cascaded model is lower bounded by the accuracy of the slow model times $\alpha_K$. This is stated in the following theorem.

Theorem 4.1

For a cascaded model consisting of stages $M_1$ and $M_2$,

$$A_N(M_{2|1}^{K}) \ge \alpha_K \cdot A_N(M_2).$$

In order to prove the above, we go through the following lemma.

Lemma 1

Let $q$ be any query such that $y_q \in M_1^K(q)$. Then

$$y_q \in M_2^N(q) \implies y_q \in M_{2|1}^{K,N}(q).$$

Proof (of Lemma 1)

Since $y_q \in M_1^K(q)$, the first stage removes only some incorrect classes and not the correct class $y_q$. Because $M_2$ scores each class independently of the other classes (Section 2), the score of $y_q$ is unchanged by cascading, while the candidate classes that $M_2$ inspects contain fewer incorrect classes. Hence, if the correct class is in the top-$N$ of $M_2$ without any cascading, cascading either improves or maintains its rank, and it remains in the top-$N$. This gives us the above lemma. ∎

Proof (of Theorem 4.1)

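Using Lemma 1,

$$\{q \in Q : y_q \in M_1^K(q) \wedge y_q \in M_2^N(q)\} \subseteq \{q \in Q : y_q \in M_{2|1}^{K,N}(q)\}.$$

This implies that

$$A_N(M_{2|1}^{K}) \ge P\big(y_q \in M_1^K(q) \wedge y_q \in M_2^N(q)\big) = P\big(y_q \in M_1^K(q) \mid y_q \in M_2^N(q)\big)\, P\big(y_q \in M_2^N(q)\big) = \alpha_K \cdot A_N(M_2).$$

∎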
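Theorem 4.1 can be sanity-checked numerically. The sketch below uses made-up synthetic scores (Gaussian noise plus a boost for the correct class), not the paper's data, and the function name is hypothetical. Because the second stage scores each class independently, the empirical cascaded accuracy should never fall below the empirical bound $\alpha_K \cdot A_N(M_2)$.

```python
import numpy as np


def check_theorem_4_1(num_queries=500, num_classes=200, k=20, n=5, seed=0):
    """Return (A_N of the cascade, alpha_K * A_N(M_2)) on synthetic scores."""
    rng = np.random.default_rng(seed)
    gold = rng.integers(num_classes, size=num_queries)
    # Hypothetical scores: noise plus a boost for the correct class.
    s1 = rng.normal(size=(num_queries, num_classes))
    s2 = rng.normal(size=(num_queries, num_classes))
    s1[np.arange(num_queries), gold] += 2.0
    s2[np.arange(num_queries), gold] += 3.0
    hits_cascade = hits_m2 = hits_both = 0
    for q in range(num_queries):
        y = gold[q]
        top_k = np.argsort(-s1[q])[:k]                        # M_1^K(q)
        cascade_top_n = top_k[np.argsort(-s2[q, top_k])][:n]  # M_{2|1}^{K,N}(q)
        m2_top_n = np.argsort(-s2[q])[:n]                     # M_2^N(q)
        in_m2 = y in m2_top_n
        hits_cascade += y in cascade_top_n
        hits_m2 += in_m2
        hits_both += in_m2 and (y in top_k)
    acc_cascade = hits_cascade / num_queries
    acc_m2 = hits_m2 / num_queries
    alpha = hits_both / hits_m2 if hits_m2 else 0.0
    return acc_cascade, alpha * acc_m2


acc, bound = check_theorem_4_1()
```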


4.2 Picking the best $K$

Given a cascaded model, choosing the number $K$ of candidates to pass on to the second stage is crucial. If $K$ is too small, accuracy suffers because it becomes more likely that the correct class has not passed the first stage. If $K$ is too large, query time suffers. The first stage also acts as an elimination round, and a large $K$ dilutes this elimination process, letting more incorrect classes crowd out the correct class.

The text classification models used in each stage are typically complex, and studying their combined behavior in a cascaded setting may not be straightforward. Thus, choosing $K$ is a challenge. Typically, the only reliable way to do this is to run the cascaded model on all possible values of $K$ within a desirable range and pick the $K$ that produces the highest accuracy on a validation set. This process can be time-consuming, as the slow model (as part of the cascaded model) needs to be re-run for every $K$ being checked. In the following theorem, we show that not all values of $K$ need to be checked. If $\beta_K$ has the same value for two distinct values of $K$, the theorem shows that choosing the smaller value of $K$ offers at least as much accuracy as choosing the larger one. This implies that we need to check only those values of $K$ where $\beta_K$ changes value.

Theorem 4.2

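Let $q$ be any query. For $K_1, K_2$ such that $K_1 < K_2$, if $\beta_{K_1} = \beta_{K_2}$, then

$$A_N(M_{2|1}^{K_1}) \ge A_N(M_{2|1}^{K_2}).$$

In other words, among all values of $K$ that share a given value $\beta$ of $\beta_K$, the accuracy is maximized when

$$K = \min\{K' : \beta_{K'} = \beta\}.$$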

For proof of above theorem, we go through the following lemma.

Lemma 2

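For any query $q$ and all $K_1, K_2$ such that $K_1 < K_2$,

$$y_q \in M_1^{K_1}(q) \implies y_q \in M_1^{K_2}(q).$$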

Proof (of Lemma 2)

If the correct class is returned in the top-$K_1$ by the first stage for a given query, then, since $M_1^{K_1}(q) \subset M_1^{K_2}(q)$ for $K_2 > K_1$, the correct class is also part of the top-$K_2$ classes returned by the first stage. ∎

Proof (of Theorem 4.2)

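We start from the condition given in the theorem, i.e., $\beta_{K_1} = \beta_{K_2}$. Using the definition of $\beta_K$ over the validation set $Q$, we get

$$\left|\{q \in Q : y_q \in M_1^{K_1}(q)\}\right| = \left|\{q \in Q : y_q \in M_1^{K_2}(q)\}\right|.$$

By Lemma 2, the first of these sets is a subset of the second; since the two finite sets have the same size, they are equal. This gives the equality of the sets of queries whose correct class has passed through the first stage:

$$\{q \in Q : y_q \in M_1^{K_1}(q)\} = \{q \in Q : y_q \in M_1^{K_2}(q)\}.$$

Now consider the set of queries which are correctly classified in the top-$N$ by the cascaded model using $K_2$,

$$\{q \in Q : y_q \in M_{2|1}^{K_2,N}(q)\}.$$

Every such query satisfies $y_q \in M_1^{K_2}(q)$. When $K$ is reduced from $K_2$ to $K_1$, on the one hand, the set of classes passing to the second stage is smaller, i.e., $M_1^{K_1}(q) \subset M_1^{K_2}(q)$. On the other hand, by the set equality above, the exact same queries contain their correct classes among the candidate classes being passed on. These two observations imply that only incorrect classes are removed by the first stage when going from $K_2$ to $K_1$. As in Lemma 1, this reduction in the number of classes passed on can either improve or keep the same the rank of the correct class returned in the top-$N$ by the second stage. Therefore,

$$\{q \in Q : y_q \in M_{2|1}^{K_2,N}(q)\} \subseteq \{q \in Q : y_q \in M_{2|1}^{K_1,N}(q)\}.$$

This implies that

$$A_N(M_{2|1}^{K_1}) \ge A_N(M_{2|1}^{K_2}).$$

∎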
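In practice, Theorem 4.2 means the cascaded model only has to be evaluated at the first $K$ of each run of equal $\beta_K$ values. A minimal sketch, assuming $\beta_K$ has already been estimated for a range of $K$ (the function name and the optional lower limit `k_min`, e.g. the threshold obtained from Theorem 4.1, are hypothetical):

```python
def k_values_to_check(beta, k_min=1):
    """Return the smallest K for each distinct value of beta_K (Theorem 4.2).

    beta  -- dict mapping K to the estimated beta_K
    k_min -- smallest K worth considering
    """
    selected = []
    previous = None
    for k in sorted(k for k in beta if k >= k_min):
        # beta_K is non-decreasing in K (Lemma 2), so equal values form
        # contiguous runs and the first K of each run is the one to test.
        if beta[k] != previous:
            selected.append(k)
            previous = beta[k]
    return selected


candidates = k_values_to_check({1: 0.1, 2: 0.1, 3: 0.2, 4: 0.2, 5: 0.3})
```

Exact equality comparison is safe here because every $\beta_K$ is a count divided by the same number of validation queries.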


4.3 Baseline for Query Time Improvement

This section describes the LSH-based baseline for candidate generation. In the training set, for every syntactic feature vector $v$, the $i$-th bit $h_i(v)$ of the hash code is computed by a hash function whose parameters are randomly picked. We create a hash table whose keys are the hash codes of the syntactic features in the training set and whose values are the sets of corresponding class labels. We create a similar hash table whose values are the texts corresponding to the hash codes instead of class labels. For candidate generation, we use two implementations of the conjunctive approach: 1) class-based, where the returned candidate classes are associated with all hash codes computed from the query text, and 2) text-based, where the returned candidate classes have at least one text that contains all hash codes computed from the query text.
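The two conjunctive candidate generators can be sketched as below. The hash family is an assumption: since the exact bit function is not given here, the sketch uses random-hyperplane hashing (one random Gaussian hyperplane per bit, a standard LSH scheme for cosine similarity), with the number of bits playing the role of the "number of permutations" in the experiments. All names are hypothetical.

```python
from collections import defaultdict

import numpy as np


def make_hyperplanes(num_bits, dim, seed=0):
    """ASSUMED hash family: one random Gaussian hyperplane per bit."""
    return np.random.default_rng(seed).normal(size=(num_bits, dim))


def hash_code(v, planes):
    # Bit i is 1 if v lies on the non-negative side of hyperplane i.
    return tuple(int(b) for b in (planes @ v >= 0))


def build_tables(bags_by_class, planes):
    """Class table: code -> classes; text table: code -> (class, text index)."""
    class_table, text_table = defaultdict(set), defaultdict(set)
    for label, bags in bags_by_class.items():
        for text_id, bag in enumerate(bags):
            for v in bag:
                code = hash_code(v, planes)
                class_table[code].add(label)
                text_table[code].add((label, text_id))
    return class_table, text_table


def _intersect(sets):
    return set.intersection(*sets) if sets else set()


def class_based_candidates(query_bag, class_table, planes):
    """Classes associated with every hash code of the query (conjunctive)."""
    return _intersect([class_table.get(hash_code(v, planes), set()) for v in query_bag])


def text_based_candidates(query_bag, text_table, planes):
    """Classes owning at least one text that contains every query hash code."""
    texts = _intersect([text_table.get(hash_code(v, planes), set()) for v in query_bag])
    return {label for label, _text_id in texts}
```

More bits make each code rarer, so fewer classes survive the intersection; this matches the trend in Table 6 of faster queries and lower accuracy as the number of permutations grows.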

5 Experiments and Inferences

In this section, we describe the experiments that demonstrate the performance of our proposed techniques and verify the bounds.

Data Set

Two kinds of documents from the IT support domain are used to generate the data set for our experiments: 1) product reference documents and 2) past problem requests. From 300MB of product reference documents, we extracted a total of 55K distinct error codes and 15K distinct error-code text descriptions. We combined the error codes corresponding to each of the 15K distinct descriptions to reduce data sparsity per class, obtaining 15K error-code classes. From the past problem requests, we extracted 40K problems with known error-code classes. Of these, 90% are used for training, while the remainder is set aside for validation and testing. Notice that the mean number of texts per error-code class is approximately 2-3, which is too few to adequately train statistical ML-based models.

User query | Top-$N$ error description suggestions
getting media err detected on device | system lic detected a program exception, a problem occurred during the ipl of a partition , partition firmware detected a data storage error , tape unit command timeout, interface error, tape unit detected a read or write error on tape medium, tape unit is not responding, an open port was detected on port 0 , contact was lost with the device indicated , destroy ipl task
hmc appears to be down | licensed internal code failure on the hardware management console hmc , system lic detected a program exception , service processor was reset due to kernel panic , the communication link between the service processor and the hardware management console hmc failed , platform lic detected an error, power supply failure, processor 1 pgood fault pluggable , system power interface firmware spif terminated the system because it detected a power fault , detected ac loss, a problem occurred during the ipl of a partition , platform lic failure
failed power supply | a fatal error occurred on power supply 1, power supply failure, the power supply fan on un e1 failed , detected ac loss, a fatal error occurred on power supply 2, power supply non power fault ps1 , the power supply fan on un e2 failed the power supply should be replaced as soon as possible , a non fatal error occurred on power supply 1, the power supply fan on un e2 experienced a short stoppage
Table 1: Examples of user-queries and error-code class descriptions returned by our models with highlighted correct response
Figure 1: Plot showing the dependence of $\alpha_K$ and $\beta_K$ on $K$ for the cascaded model on the described data set.
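$M_1$ | $A_N(M_1)/A_N(M_2)$ | $K$ where $\alpha_K$ crosses this value | Values of $K$ where $\beta_K$ changes value
GRU | 0.933 | 45 | 45, 55, 91, 179, 210
LSTM | 0.955 | 54 | 54, 82, 141, 337, 452
Table 2: Computing relevant values for cascaded models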
User queries shown: “My tape drive has been giving error once every week”, “Need to have adapter replaced”, “Flashing power button and warning light”
Table 3: Heat map showing word weights assigned by GRU model to user queries
$M_1$ | $K$ | Accuracy | Mean time (s) | Min time (s) | Max time (s)
LSTM | 54 | 63.33% | 9.09 | 0.73 | 49.98
LSTM | 82 | 64.76% | 12.23 | 1.12 | 63.60
LSTM | 141 | 65.23% | 14.51 | 1.23 | 75.00
LSTM | 337 | 64.29% | 23.35 | 2.09 | 114.4
LSTM | 452 | 62.86% | 23.46 | 2.35 | 108.2
GRU | 45 | 60.95% | 9.85 | 1.22 | 50.23
GRU | 179 | 63.33% | 18.93 | 2.25 | 89.97
GRU | 210 | 64.29% | 20.10 | 2.31 | 93.18
GRU | 400 | 63.81% | 25.89 | 2.51 | 117.1
GRU | 500 | 63.33% | 29.15 | 2.66 | 131.6
Table 4: Accuracies and CPU times of the cascaded model comprising $M_1$ followed by the syntactic-bigram vector model $M_2$, for varying $K$, for suggestion of top-10 error-code classes
Model | Accuracy | Mean time (s) | Min time (s) | Max time (s)
LSTM | 61.43% | 0.013 | 0.004 | 0.112
GRU | 60.00% | 0.014 | 0.005 | 0.062
sn-Vectors | 64.29% | 84.60 | 12.54 | 294.92
sn-Bigrams | 63.33% | 41.42 | 8.97 | 133.51
BOW | 43.33% | 12.80 | 7.32 | 28.19
Table 5: Accuracies and CPU times of various models for suggestion of top-10 error-code classes
LSH version | Permutations | Accuracy | Mean time (s) | Min time (s) | Max time (s)
Cluster based | 5 | 62.38% | 35.1 | 0.001 | 142.2
Cluster based | 10 | 60.47% | 10.3 | 0.001 | 27.12
Cluster based | 15 | 57.14% | 5.82 | 0.001 | 19.39
Cluster based | 20 | 56.66% | 5.79 | 0.001 | 19.75
Text based | 1 | 64.28% | 46.1 | 0.001 | 194.2
Text based | 3 | 61.90% | 13.5 | 0.001 | 41.23
Text based | 5 | 55.23% | 1.31 | 0.001 | 11.30
Table 6: Baseline accuracies (in top-10) and CPU times of LSH-based implementations (for reducing the query time of the bag of syntactic features technique) for a varying number of permutations.

In Figure 1, we show the plots of $\alpha_K$ and $\beta_K$ for cascaded models having $M_1$ as the GRU model and as the LSTM model. Notice that, to guarantee the usefulness of the cascaded model, its accuracy should be at least that of the less accurate model, $M_1$. The smallest value of $K$ that guarantees this can be found using the lower bound in Theorem 4.1:

$$\alpha_K \cdot A_N(M_2) \ge A_N(M_1) \iff \alpha_K \ge \frac{A_N(M_1)}{A_N(M_2)}.$$

Thus, as shown in Table 2, for the GRU model, $\alpha_K \ge 0.933$, or $K \ge 45$. Similarly, for the LSTM model, $\alpha_K \ge 0.955$, or $K \ge 54$. These thresholds on $K$ are shown on the plots in Figure 1.
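The thresholds in Table 2 follow directly from the top-10 accuracies in Table 5. The sketch below recomputes them; the helper that scans an $\alpha_K$ curve for the first crossing is hypothetical and is exercised only on toy values, since the plotted curve itself is not reproduced here.

```python
# Top-10 accuracies reported in Table 5.
ACCURACY = {"LSTM": 0.6143, "GRU": 0.6000, "sn-Vectors": 0.6429}


def alpha_threshold(first_stage, second_stage="sn-Vectors"):
    """Smallest alpha_K for which alpha_K * A_N(M_2) >= A_N(M_1) (Theorem 4.1)."""
    return ACCURACY[first_stage] / ACCURACY[second_stage]


def smallest_useful_k(alpha, threshold):
    """Smallest K whose estimated alpha_K reaches the threshold, or None."""
    for k in sorted(alpha):
        if alpha[k] >= threshold:
            return k
    return None


print(f"GRU:  {alpha_threshold('GRU'):.4f}")
print(f"LSTM: {alpha_threshold('LSTM'):.4f}")
```

The ratios come out as about 0.9333 for GRU and 0.9555 for LSTM, which agree with the 0.933 and 0.955 in Table 2 to within 0.001.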

In Table 4, we show the accuracies and CPU times of the cascaded model for varying $K$, with $M_1$ as the GRU and as the LSTM model. According to Theorem 4.2, we only need to check the accuracies for values of $K$ where $\beta_K$ (shown in Figure 1) changes value. Therefore, Table 4 shows accuracies for some of those values.

On comparing the accuracy and CPU times of the cascaded model (in Table 4) and the bag of syntactic features model (listed in Table 5 as sn-Vectors), we see that cascading reduces query time to using LSTM and using GRU model. Cascading using the LSTM model also improves the accuracy.

Table 5 shows the accuracies and CPU times of other, uncascaded models: 1) the fast, machine learning based LSTM and GRU models, 2) the bag of syntactic bigrams, which uses exact string matching to find similarities between syntactic bigrams after lemmatizing the words, and 3) the bag of words model.

In Table 6, we show the results of the two versions of the LSH-based baseline. The accuracy and CPU time are shown for a varying number of permutations (number of bits in the hash code). Increasing the number of permutations yields fewer nearest-neighbor candidates, which decreases accuracy and improves query time. Comparing the results in Tables 4 and 6, we infer that our proposed cascaded model outperforms the described baseline.

6 Conclusion and Future Work

We proposed a cascaded model using fast RNN-based text classifiers and a slow nearest-neighbor based model relying on sophisticated NLP features. We successfully addressed the challenges posed by a large number of classes, very few training samples per class and the slowness of the nearest-neighbor approach. We derived a generic lower bound on the accuracy of a 2-stage cascaded model in terms of the accuracies of the individual stages. We proved a result that eases the effort involved in finding the appropriate number of candidates to pass on to the second stage. We outperformed an LSH-based baseline for query-time reduction.

Some problems that need further work naturally emerge. One is investigating what insights carry over when 2-stage cascading is extended to multiple stages. Another is exploring how other machine learning models operate in a cascaded setting.


  • [1] Andoni, A., Indyk, P.: Near-optimal hashing algorithms for approximate nearest neighbor in high dimensions. In: Foundations of Computer Science, 2006. FOCS’06. 47th Annual IEEE Symposium on. pp. 459–468. IEEE (2006)
  • [2] Asadi, N., Lin, J.: Effectiveness/efficiency tradeoffs for candidate generation in multi-stage retrieval architectures. In: Proceedings of the 36th international ACM SIGIR conference on Research and development in information retrieval. pp. 997–1000. ACM (2013)
  • [3] Broder, A.Z., Charikar, M., Frieze, A.M., Mitzenmacher, M.: Min-wise independent permutations. Journal of Computer and System Sciences 60(3), 630–659 (2000)
  • [4] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., Bengio, Y.: Learning phrase representations using rnn encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078 (2014)
  • [5] Clarke, C.L., Culpepper, J.S., Moffat, A.: Assessing efficiency–effectiveness tradeoffs in multi-stage retrieval systems without using relevance judgments. Information Retrieval Journal 19(4), 351–377 (2016)
  • [6] Dumais, S., Chen, H.: Hierarchical classification of web content. In: Proceedings of the 23rd annual international ACM SIGIR conference on Research and development in information retrieval. pp. 256–263. ACM (2000)
  • [7] Hochreiter, S., Schmidhuber, J.: Long short-term memory. Neural computation 9(8), 1735–1780 (1997)
  • [8] McCallum, A., Nigam, K., et al.: A comparison of event models for naive bayes text classification. In: AAAI-98 workshop on learning for text categorization. vol. 752, pp. 41–48. Madison, WI (1998)
  • [9] Mikolov, T., Karafiát, M., Burget, L., Černocký, J., Khudanpur, S.: Recurrent neural network based language model. In: Interspeech. vol. 2, p. 3 (2010)
  • [10] Mikolov, T., Sutskever, I., Chen, K., Corrado, G.S., Dean, J.: Distributed representations of words and phrases and their compositionality. In: Advances in neural information processing systems. pp. 3111–3119 (2013)
  • [11] Schütze, H., Hull, D.A., Pedersen, J.O.: A comparison of classifiers and document representations for the routing problem. In: Proceedings of the 18th annual international ACM SIGIR conference on Research and development in information retrieval. pp. 229–237. ACM (1995)
  • [12] Sidorov, G., Velasquez, F., Stamatatos, E., Gelbukh, A., Chanona-Hernández, L.: Syntactic dependency-based n-grams as classification features. In: Mexican International Conference on Artificial Intelligence. pp. 1–11. Springer (2012)
  • [13] Tsoumakas, G., Katakis, I., Vlahavas, I.: Effective and efficient multilabel classification in domains with large number of labels. In: Proc. ECML/PKDD 2008 Workshop on Mining Multidimensional Data (MMD’08). pp. 30–44 (2008)
  • [14] Wang, J., Zhang, T., Sebe, N., Shen, H.T., et al.: A survey on learning to hash. IEEE Transactions on Pattern Analysis and Machine Intelligence (2017)
  • [15] Wang, S., Manning, C.D.: Baselines and bigrams: Simple, good sentiment and topic classification. In: Proceedings of the 50th Annual Meeting of the Association for Computational Linguistics: Short Papers-Volume 2. pp. 90–94. Association for Computational Linguistics (2012)
  • [16] Ye, X., Shen, H., Ma, X., Bunescu, R., Liu, C.: From word embeddings to document similarities for improved information retrieval in software engineering. In: Proceedings of the 38th International Conference on Software Engineering. pp. 404–415. ACM (2016)
  • [17] Zhang, X., Zhao, J., LeCun, Y.: Character-level convolutional networks for text classification. In: Advances in neural information processing systems. pp. 649–657 (2015)
  • [18] Zhou, P., Qi, Z., Zheng, S., Xu, J., Bao, H., Xu, B.: Text classification improved by integrating bidirectional lstm with two-dimensional max pooling. arXiv preprint arXiv:1611.06639 (2016)