coral-pytorch
CORAL and CORN implementations for ordinal regression with deep neural networks.
In recent times, deep neural networks achieved outstanding predictive performance on various classification and pattern recognition tasks. However, many real-world prediction problems have ordinal response variables, and this ordering information is ignored by conventional classification losses such as the multi-category cross-entropy. Ordinal regression methods for deep neural networks address this. One such method is the CORAL method, which is based on an earlier binary label extension framework and achieves rank consistency among its output layer tasks by imposing a weight-sharing constraint. However, while earlier experiments showed that CORAL's rank consistency is beneficial for performance, the weight-sharing constraint could severely restrict the expressiveness of a deep neural network. In this paper, we propose an alternative method for rank-consistent ordinal regression that does not require a weight-sharing constraint in a neural network's fully connected output layer. We achieve this rank consistency by a novel training scheme using conditional training sets to obtain the unconditional rank probabilities through applying the chain rule for conditional probability distributions. Experiments on various datasets demonstrate the efficacy of the proposed method to utilize the ordinal target information, and the absence of the weight-sharing restriction improves the performance substantially compared to the CORAL reference approach.
Many real-world prediction tasks involve ordinal target labels. Popular examples of such ordinal tasks are customer ratings (e.g., a product rating system from 1 to 5 stars) and medical diagnoses (e.g., disease severity labels such as none, mild, moderate, and severe). While we can apply conventional classification losses, such as the multi-category cross-entropy, to such problems, they are suboptimal since they ignore the intrinsic order among the ordinal targets. For example, for a patient with severe disease status, predicting none and moderate would incur the same loss even though the difference between none and severe is more significant than the difference between moderate and severe. Moreover, unlike in metric regression, we cannot quantify the distance between the ordinal ranks. For instance, the difference between a disease status of none and mild cannot be quantitatively compared to the difference between mild and moderate. Hence, ordinal regression (also called ordinal classification or ranking learning) can be considered as an intermediate problem between classification and regression.
Among the most common machine learning-based approaches to ordinal regression is Li and Lin’s extended binary classification framework [li2007ordinal] that was adopted for deep neural networks by Niu et al. [niu2016ordinal]. In this work, we solve the rank inconsistency problem (Fig. 1) of this ordinal regression framework without imposing constraints that could limit the expressiveness of the neural network and without substantially increasing the computational complexity.
The contributions of our paper are as follows:
A new rank-consistent ordinal regression framework, CORN (Conditional Ordinal Regression for Neural Networks), based on the chain rule for conditional probability distributions;
Rank consistency guarantees without imposing the weight-sharing constraint used in the CORAL reference framework [cao2020rank];
Experiments with different neural network architectures and datasets showing that CORN’s removal of the weight-sharing constraint improves the predictive performance compared to the more restrictive reference framework.
Ordinal regression is a classic problem in statistics, going back to early proportional hazards and proportional odds models [mccullagh1980regression]. To take advantage of well-studied and well-tuned binary classifiers, the machine learning field developed ordinal regression methods based on extending the rank prediction to multiple binary label classification subtasks [li2007ordinal]. This approach relies on three steps: (1) extending rank labels to binary vectors, (2) training binary classifiers on the extended labels, and (3) computing the predicted rank label from the binary classifiers. Modified versions of this approach have been proposed in connection with perceptrons [crammer2002pranking] [shashua2003ranking; rajaram2003classification; chu2005new]. In 2007, Li and Lin presented a reduction framework unifying these extended binary classification approaches [li2007ordinal].
In 2016, Niu et al. adapted Li and Lin’s extended binary classification framework to train deep neural networks for ordinal regression [niu2016ordinal]; we refer to this method as OR-NN. Across different image datasets, OR-NN was able to outperform other reference methods. However, Niu et al. pointed out that OR-NN suffers from rank inconsistencies among the binary tasks and that addressing this limitation might raise the training complexity substantially. Cao et al. [cao2020rank] recently addressed this rank inconsistency limitation via the CORAL method. To avoid increasing the training complexity, CORAL achieves rank consistency by imposing a weight-sharing constraint in the last layer, such that the binary classifiers only differ in their bias units. However, while CORAL outperformed the OR-NN method across several face image datasets for age prediction, the weight-sharing constraint may impose a severe limitation in terms of functions that the neural network can approximate. In this paper, we investigate an alternative approach to guarantee rank consistency without increasing the training complexity or restricting the neural network’s expressiveness and capacity.
Several deep neural networks for ordinal regression do not build on the extended binary classification framework. These methods include Zhu et al.’s [zhu2021convolutional] convolutional ordinal regression forest for image data, which combines a convolutional neural network with differentiable decision trees. Diaz and Marathe [diaz2019soft] proposed a soft ordinal label representation obtained from a softmax layer, which can be used for scenarios where interclass distances are known. Another method that does not rely on the extended binary classification framework is Suarez et al.’s distance metric learning algorithm [suarez2021ordinal]. Petersen et al. [Petersen2021-diffsort] developed a method that uses differentiable sorting networks built from pairwise swapping operations with relaxed sorting operations, which can be used for ranking where the relative ordering is known but the absolute target values are unknown.
This paper focuses on addressing the rank inconsistency in OR-NN without imposing the weight-sharing of CORAL, which is why the study of the methods mentioned above is outside the scope of this paper.
This section describes the details of our CORN method, which addresses the rank inconsistency in Niu et al.’s OR-NN [niu2016ordinal] without requiring CORAL’s [cao2020rank] weight-sharing constraint.
Let $D = \{\mathbf{x}^{[i]}, y^{[i]}\}_{i=1}^{N}$ denote a dataset for supervised learning consisting of $N$ training examples, where $\mathbf{x}^{[i]} \in \mathcal{X}$ denotes the inputs of the $i$-th training example and $y^{[i]}$ its corresponding class label. In an ordinal regression context, we refer to $y^{[i]}$ as the rank, where $y^{[i]} \in \mathcal{Y} = \{r_1, r_2, \ldots, r_K\}$ with rank order $r_K \succ r_{K-1} \succ \ldots \succ r_1$. The objective of an ordinal regression model is then to find a mapping $h: \mathcal{X} \rightarrow \mathcal{Y}$ that minimizes a loss function $L(h)$.
With CORAL, Cao et al. [cao2020rank] proposed a deep neural network for ordinal regression that addressed the rank inconsistency of Niu et al.’s OR-NN [niu2016ordinal], and experiments showed that addressing rank consistency had a positive effect on predictive performance.
Both CORAL and OR-NN build on an extended binary classification framework [li2007ordinal], where the rank labels are recast into a set of binary tasks, such that $y_k^{[i]} \in \{0, 1\}$ indicates whether $y^{[i]}$ exceeds rank $r_k$. The label predictions are then obtained via $h(\mathbf{x}^{[i]}) = r_q$, where $q$ is the rank index, which is computed as
$$q = 1 + \sum_{k=1}^{K-1} \mathbb{1}\{\hat{P}(y_k^{[i]} = 1) > 0.5\}. \quad (1)$$
Here, $\hat{P}(y_k^{[i]} = 1) \in [0, 1]$ is the probability prediction of the $k$-th binary classifier in the output layer, and $\mathbb{1}\{\cdot\}$ is an indicator function that returns 1 if the inner condition is true and 0 otherwise.
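The label extension and the rank index computation of Eq. 1 can be illustrated with a minimal sketch. The function names `extend_label` and `rank_index_from_probas` are hypothetical and chosen for illustration only; ranks are represented by their integer index (1 to K), and the 0.5 threshold follows the description above.

```python
def extend_label(rank_index, num_ranks):
    """Recast a rank index in {1, ..., K} into K-1 binary labels.

    The k-th binary label is 1 if the rank exceeds r_k and 0 otherwise.
    """
    return [1 if rank_index > k else 0 for k in range(1, num_ranks)]


def rank_index_from_probas(probas, threshold=0.5):
    """Rank index q: one plus the number of binary tasks above the threshold."""
    return 1 + sum(1 for p in probas if p > threshold)


K = 5  # e.g., a product rating from 1 to 5 stars
labels = extend_label(3, K)  # rank r_3 exceeds r_1 and r_2 only
q = rank_index_from_probas([0.9, 0.8, 0.3, 0.1])
print(labels, q)
```

In this toy example, two of the four binary task probabilities exceed 0.5, so the predicted rank index is 3.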
The CORAL method ensures that the predictions are rank-monotonic, that is, $\hat{P}(y_1^{[i]} = 1) \geq \hat{P}(y_2^{[i]} = 1) \geq \ldots \geq \hat{P}(y_{K-1}^{[i]} = 1)$, which provides rank consistency to the ordinal regression model. While the rank label calculation via Eq. 1 does not strictly require consistency among the $K-1$ task predictions, $\hat{P}(y_k^{[i]} = 1)$, it is intuitive to see why rank consistency can be theoretically beneficial and can lead to more interpretable results via the binary subtasks. While CORAL provides this rank consistency, CORAL’s limitation is a weight-sharing constraint in the output layer. Consequently, all binary classification tasks use the same weight parameters and only differ in their bias units, which may limit the flexibility and expressiveness of an ordinal regression neural network based on CORAL.
The proposed CORN model is a neural network for ordinal regression that exhibits rank consistency without any weight-sharing constraint in the output layer (Fig. 2). Instead, CORN uses a new training procedure with conditional training subsets that ensures rank consistency through applying the chain rule of probability.
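Given a training set $D = \{\mathbf{x}^{[i]}, y^{[i]}\}_{i=1}^{N}$, CORN applies a label extension to the rank labels $y^{[i]}$ similar to CORAL, such that the resulting binary label $y_k^{[i]} \in \{0, 1\}$ indicates whether $y^{[i]}$ exceeds rank $r_k$. Similar to CORAL, CORN also uses $K-1$ learning tasks associated with ranks $r_1, r_2, \ldots, r_{K-1}$ in the output layer as illustrated in Fig. 2.
However, in contrast to CORAL, CORN estimates a series of conditional probabilities using conditional training subsets (described in Section 3.4) such that the output of the $k$-th binary task $f_k(\mathbf{x}^{[i]})$ represents the conditional probability
$$f_k(\mathbf{x}^{[i]}) = \hat{P}(y^{[i]} > r_k \mid y^{[i]} > r_{k-1}), \quad (2)$$
where the events are nested: $\{y^{[i]} > r_k\} \subseteq \{y^{[i]} > r_{k-1}\}$. (When $k = 1$, $f_1(\mathbf{x}^{[i]})$ represents the initial unconditional probability $\hat{P}(y^{[i]} > r_1)$.)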
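The transformed, unconditional probabilities can then be computed by applying the chain rule for probabilities to the model outputs:
$$\hat{P}(y^{[i]} > r_k) = \prod_{j=1}^{k} f_j(\mathbf{x}^{[i]}). \quad (3)$$
Since $0 \leq f_j(\mathbf{x}^{[i]}) \leq 1$ for all $j$, we have
$$\hat{P}(y^{[i]} > r_1) \geq \hat{P}(y^{[i]} > r_2) \geq \ldots \geq \hat{P}(y^{[i]} > r_{K-1}), \quad (4)$$
which guarantees rank consistency among the $K-1$ binary tasks.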
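The chain-rule argument can be checked numerically with a small sketch. The helper name `unconditional_probas` and the example values are illustrative assumptions; the function multiplies the conditional task outputs cumulatively, so each unconditional probability is the previous one scaled by a factor in [0, 1].

```python
def unconditional_probas(cond_probas):
    """Cumulative product of conditional probabilities (chain rule)."""
    out, running = [], 1.0
    for f in cond_probas:
        running *= f
        out.append(running)
    return out


# Hypothetical outputs of K-1 = 4 conditional binary tasks; note that they
# are not monotonic themselves.
cond = [0.9, 0.5, 0.8, 0.95]
uncond = unconditional_probas(cond)

# The unconditional probabilities never increase with the rank index.
is_monotonic = all(a >= b for a, b in zip(uncond, uncond[1:]))
print(uncond, is_monotonic)
```

Because every factor lies between 0 and 1, the monotonicity holds for any set of conditional outputs, not only for this example.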
Our model aims to estimate $f_1(\mathbf{x}^{[i]})$ and the conditional probabilities $f_2(\mathbf{x}^{[i]}), \ldots, f_{K-1}(\mathbf{x}^{[i]})$. Estimating $f_1(\mathbf{x}^{[i]})$ is a classic binary classification task under the extended binary classification framework with the binary labels $y_1^{[i]}$. To estimate a conditional probability $f_k(\mathbf{x}^{[i]})$ for $k \geq 2$, we focus only on the subset of the training data where $y^{[i]} > r_{k-1}$. As a result, when we minimize the binary cross-entropy loss on these conditional subsets, for each binary task, the estimated output probability has a proper conditional probability interpretation. (When training a neural network using backpropagation, instead of minimizing the $K-1$ loss functions corresponding to the conditional probabilities on each conditional subset separately, we can minimize their sum, as shown in the loss function we propose in Section 3.5, to optimize the binary tasks simultaneously.)
In order to model the conditional probabilities in Eq. 3, we construct conditional training subsets for training, which are used in the loss function (Section 3.5) that is minimized via backpropagation. The conditional training subsets are obtained from the original training set as follows:
$$S_1: \text{all } \{(\mathbf{x}^{[i]}, y^{[i]})\}, \text{ for } i \in \{1, \ldots, N\},$$
$$S_2: \{(\mathbf{x}^{[i]}, y^{[i]}) \mid y^{[i]} > r_1\},$$
$$\ldots$$
$$S_{K-1}: \{(\mathbf{x}^{[i]}, y^{[i]}) \mid y^{[i]} > r_{K-2}\},$$
where $N = |S_1| \geq |S_2| \geq \ldots \geq |S_{K-1}|$, and $|S_k|$ denotes the size of $S_k$. Note that the labels $y^{[i]}$ are subject to the binary label extension as described in Section 3.3. Each conditional training subset $S_k$ is used for training the conditional probability prediction $\hat{P}(y^{[i]} > r_k \mid y^{[i]} > r_{k-1})$ for $k \geq 2$.
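A minimal sketch of how such conditional training subsets could be built is shown below. It assumes, following the description above, that the subset for task k keeps only the examples whose rank exceeds the previous rank and labels them by whether they also exceed the k-th rank. The function name `conditional_subsets` and the toy data are hypothetical; ranks are given by their integer index.

```python
def conditional_subsets(examples, num_ranks):
    """Build the K-1 conditional subsets from (x, rank_index) pairs.

    Subset k keeps examples with rank index > k-1 (all examples for k = 1)
    and attaches the binary label 1 if the rank index exceeds k, else 0.
    """
    subsets = []
    for k in range(1, num_ranks):
        s_k = [(x, 1 if q > k else 0) for x, q in examples if q > k - 1]
        subsets.append(s_k)
    return subsets


# Toy dataset with K = 3 ranks; the inputs are placeholder strings.
data = [("a", 1), ("b", 2), ("c", 3), ("d", 3)]
subsets = conditional_subsets(data, num_ranks=3)
sizes = [len(s) for s in subsets]
print(subsets, sizes)
```

The subset sizes can only shrink as k grows, since each subset is obtained by filtering on a stricter rank condition.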
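Let $f_j(\mathbf{x}^{[i]})$ denote the predicted value of the $j$-th node in the output layer of the network (Fig. 2), and let $|S_j|$ denote the size of the $j$-th conditional training set. To train a CORN neural network using backpropagation, we minimize the following loss function:
$$L(\mathbf{X}, \mathbf{y}) = -\frac{1}{\sum_{j=1}^{K-1} |S_j|} \sum_{j=1}^{K-1} \sum_{i=1}^{|S_j|} \Big[ \log\big(f_j(\mathbf{x}^{[i]})\big) \cdot \mathbb{1}\{y^{[i]} > r_j\} + \log\big(1 - f_j(\mathbf{x}^{[i]})\big) \cdot \mathbb{1}\{y^{[i]} \leq r_j\} \Big]. \quad (5)$$
We note that in $f_j(\mathbf{x}^{[i]})$, $\mathbf{x}^{[i]}$ represents the $i$-th training example in $S_j$. To simplify the notation, we omit an additional index $j$ to distinguish between $\mathbf{x}^{[i]}$ in different conditional training sets.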
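To improve the numerical stability of the loss gradients during training, we implement the following alternative formulation of the loss, where $\mathbf{Z}$ are the net inputs of the last layer (aka logits), as shown in Fig. 2, and $\log(\sigma(z^{[i]})) = \log(f_j(\mathbf{x}^{[i]}))$:
$$L(\mathbf{Z}, \mathbf{y}) = -\frac{1}{\sum_{j=1}^{K-1} |S_j|} \sum_{j=1}^{K-1} \sum_{i=1}^{|S_j|} \Big[ \log\big(\sigma(z^{[i]})\big) \cdot \mathbb{1}\{y^{[i]} > r_j\} + \big(\log\big(\sigma(z^{[i]})\big) - z^{[i]}\big) \cdot \mathbb{1}\{y^{[i]} \leq r_j\} \Big]. \quad (6)$$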
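The two loss formulations can be compared on a toy example. The following is a plain-Python sketch under stated assumptions, not the PyTorch implementation from the Supplementary Material: the function names are hypothetical, each example carries K-1 logits, ranks are integer indices, and the logit-based variant replaces log(1 - sigmoid(z)) by log(sigmoid(z)) - z.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def corn_loss_from_probas(logits, rank_indices, num_ranks):
    """Cross-entropy over the conditional subsets, computed from probabilities."""
    total, count = 0.0, 0
    for j in range(1, num_ranks):
        for z, q in zip(logits, rank_indices):
            if q <= j - 1:  # example is not part of the j-th subset
                continue
            f = 1.0 / (1.0 + math.exp(-z[j - 1]))
            total += math.log(f) if q > j else math.log(1.0 - f)
            count += 1
    return -total / count


def corn_loss_from_logits(logits, rank_indices, num_ranks):
    """Same loss, using log(1 - sigmoid(z)) = log(sigmoid(z)) - z."""
    total, count = 0.0, 0
    for j in range(1, num_ranks):
        for z, q in zip(logits, rank_indices):
            if q <= j - 1:
                continue
            zj = z[j - 1]
            total += log_sigmoid(zj) if q > j else log_sigmoid(zj) - zj
            count += 1
    return -total / count


# Three toy examples with K = 3 ranks, i.e., two logits per example.
toy_logits = [[2.0, -1.0], [0.5, 0.3], [-1.5, 0.2]]
toy_ranks = [3, 2, 1]
loss_p = corn_loss_from_probas(toy_logits, toy_ranks, num_ranks=3)
loss_z = corn_loss_from_logits(toy_logits, toy_ranks, num_ranks=3)
print(loss_p, loss_z)
```

Both functions average over the same five (example, task) pairs here and agree up to floating-point error; the logit-based version avoids evaluating the logarithm of a value close to zero.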
A derivation showing that the two loss equations are equivalent and a PyTorch implementation are included in the Supplementary Material. In addition, the Supplementary Material includes a visual illustration of the loss computation based on the conditional training subsets.
To obtain the rank index $q$ of the $i$-th training example, and any new data record during inference, we threshold the predicted probabilities corresponding to the $K-1$ binary tasks and sum the binary labels as follows:
$$q^{[i]} = 1 + \sum_{j=1}^{K-1} \mathbb{1}\{\hat{P}(y^{[i]} > r_j) > 0.5\},$$
where the predicted rank is $r_{q^{[i]}}$.
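As a sketch of the full prediction path, the snippet below maps the logits of one example to a rank index. It assumes that the unconditional probabilities are obtained from the sigmoid outputs via the cumulative product described earlier; the function name `predict_rank_index` is hypothetical.

```python
import math


def predict_rank_index(logits, threshold=0.5):
    """Sigmoid, cumulative product, threshold, and sum for one example."""
    q, running = 1, 1.0
    for z in logits:
        running *= 1.0 / (1.0 + math.exp(-z))  # unconditional probability
        if running > threshold:
            q += 1
    return q


q_a = predict_rank_index([3.0, 1.0, -2.0])
q_b = predict_rank_index([0.5, 0.5, 0.5])
print(q_a, q_b)
```

In the second call, every conditional probability is about 0.62, yet only the first unconditional probability exceeds 0.5, which shows why the threshold is applied after the chain rule rather than to the raw task outputs.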
The MORPH-2 dataset (https://www.faceaginggroup.com/morph/) [ricanek2006morph] contains 55,608 face images, which were processed as described in [cao2020rank]: facial landmark detection [sagonas2016300] was used to compute the average eye location, which was then used by the EyepadAlign function in MLxtend v0.14 [raschka2018mlxtend] to align the face images. The original MORPH-2 dataset contains age labels in the range of 16-70 years. In this study, we use a balanced version of the MORPH-2 dataset containing 20,625 face images with 33 evenly distributed age labels within the range of 16-48 years.
The Asian Face Database (AFAD) (https://github.com/afad-dataset/tarball) [niu2016ordinal] contains 165,501 faces in the age range of 15-40 years. No additional preprocessing was applied to this dataset since the faces were already centered. In this study, we use a balanced version of the AFAD dataset with 13 age labels in the age range of 18-30 years.
The Image Aesthetic Dataset (AES) (http://www.di.unito.it/~schifane/dataset/beauty-icwsm15/) [schifanella2015image] used in this study contains 13,868 images, each with a list of beauty scores ranging from 1 to 5. To create ordinal regression labels, we replaced the beauty score list of each image with its average score rounded to the nearest integer in the range 1-5. Compared to the other image datasets MORPH-2 and AFAD, the size of the AES dataset was relatively small, and we did not attempt to create a class-balanced version of this dataset for this study. Moreover, measures such as dropping scarce classes are challenging as the aesthetic standards are subtle and complicated. Removing any class might cause the models to be unable to learn the underlying standards. Therefore, even though a balanced dataset might be preferred for general method comparisons, we used the original dataset given these feasibility considerations.
The Fireman Dataset (Fireman) (https://github.com/gagolews/ordinal_regression_data) is a tabular dataset that contains 40,768 instances, 10 numeric features, and an ordinal response variable with 16 categories. We created a balanced version of this dataset consisting of 2,543 instances per class, that is, 40,688 instances across the 16 ordinal classes in total.
Each dataset was randomly divided into 75% training data, 5% validation data, and 20% test data. We share the partitions for all datasets, along with all preprocessing code used in this paper, in the code repository (see Section 4.4).
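A random 75/5/20 split of this kind could be sketched as follows. This is only an illustration with a hypothetical helper `split_indices` and an arbitrary seed; the partitions actually used in the experiments are the ones shared in the code repository.

```python
import random


def split_indices(n, seed=0, train_frac=0.75, valid_frac=0.05):
    """Shuffle the indices 0..n-1 and cut them into train/validation/test."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = round(train_frac * n)
    n_valid = round(valid_frac * n)
    train = idx[:n_train]
    valid = idx[n_train:n_train + n_valid]
    test = idx[n_train + n_valid:]  # remaining ~20%
    return train, valid, test


train_idx, valid_idx, test_idx = split_indices(1000)
print(len(train_idx), len(valid_idx), len(test_idx))
```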
For method comparisons on the image datasets (MORPH-2, AFAD, and AES), we used ResNet-34 [he2016deep] as the backbone architecture since it is an established architecture that is known to achieve good performance on a variety of image classification datasets.
For the tabular Fireman dataset, we used a simple multilayer perceptron architecture (MLP) with leaky ReLU [maas2013rectifier] activation functions (negative slope 0.01) and BatchNorm. Since the MLP architectures were prone to overfitting, a dropout layer with drop probability 0.2 was added after the leaky ReLU activations in each hidden layer. In addition, we used the AdamW [loshchilov2017decoupled] optimizer with a weight decay rate of 0.2. The number of hidden layers (one or two) and the number of units per hidden layer were determined by hyperparameter tuning (see Section 4.3 for more details).
In this paper, we compare the performance of a neural network trained via the rank-consistent CORN approach to both Niu et al.’s [niu2016ordinal] OR-NN method (no rank consistency) and CORAL (rank consistency by using identical weight parameters for all nodes in the output layer). In addition, we implement neural network classifiers trained with standard multicategory cross-entropy loss as a baseline, which we refer to as CE-NN. While all methods (CE-NN, OR-NN, CORAL, and CORN) use different loss functions during training, it is worth emphasizing that they can share similar backbone architectures and only require small changes in the output layer. For instance, to implement a neural network for ordinal regression using the proposed CORN method, we replaced the network’s output layer with the corresponding binary conditional probability task layer.
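All model evaluations and comparisons are based on the mean absolute error (MAE) and root mean squared error (RMSE), which are defined as follows:
$$\text{MAE} = \frac{1}{N} \sum_{i=1}^{N} \left| y^{[i]} - h(\mathbf{x}^{[i]}) \right|, \qquad \text{RMSE} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \left( y^{[i]} - h(\mathbf{x}^{[i]}) \right)^2},$$
where $y^{[i]}$ is the ground truth rank of the $i$-th test example and $h(\mathbf{x}^{[i]})$ is the predicted rank.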
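Both metrics are straightforward to compute on integer rank labels. The sketch below uses hypothetical function names and made-up labels purely for illustration.

```python
import math


def mae(y_true, y_pred):
    """Mean absolute error between true and predicted ranks."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true, y_pred):
    """Root mean squared error between true and predicted ranks."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


y_true = [1, 3, 5, 2]
y_pred = [1, 4, 3, 2]
mae_value = mae(y_true, y_pred)
rmse_value = rmse(y_true, y_pred)
print(mae_value, rmse_value)
```

The RMSE weights the single two-rank error more heavily than the MAE does, which is why the two metrics can rank methods differently.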
For each method, we used the validation set to determine the best hyperparameter setting and the best training epoch. Then, using the best hyperparameter setting for each method, we repeated the model training five times using different random seeds (0, 1, 2, 3, and 4) for the random weight initialization and dataset shuffling.
Across the different image datasets using the ResNet-34 architecture, we found the best hyperparameter setting for CORN was a batch size of 16 and a learning rate of while the best setting for all other methods was a batch size of 256 and a learning rate of .
On the tabular Fireman dataset, the hyperparameter tuning of the MLPs included the learning rate, batch size, number of hidden layers, and number of hidden units per layer. All MLPs performed best with two instead of one hidden layer. CORAL and CE-NN performed best with 300 units in the first hidden layer and 200 units in the second. CORN and OR-NN performed best with 300 units in each of the two hidden layers. The best learning rate was 0.001 for CORN and 0.0005 for all other methods. CORN and OR-NN performed best with a batch size of 128, while CORAL and CE-NN performed best when the batch size was set to 64.
All models were trained for 200 epochs with stochastic gradient descent via adaptive moment estimation [kingma2015adam] with the default decay rates. After the training, the model corresponding to the training epoch with the lowest validation set RMSE was chosen as the best model to be evaluated on the final test dataset. The complete training logs for all methods are provided in the repository (Section 4.4).
All neural networks were implemented in PyTorch 1.8 [paszke2019pytorch] and trained on NVIDIA GeForce RTX 2080Ti graphics cards. We make all source code used for the experiments available (https://github.com/Raschka-research-group/corn-ordinal-neuralnet) and provide a user-friendly implementation of CORN in the coral-pytorch Python package (https://github.com/Raschka-research-group/coral-pytorch).
To compare deep neural networks trained with our proposed CORN method to CORAL, Niu et al.’s OR-NN [niu2016ordinal], and the baseline cross-entropy loss (CE-NN), we conducted a series of experiments on three image datasets and one tabular dataset. As detailed in Section 4.2, the experiments on the image datasets were based on the ResNet-34 architecture, and we used a multilayer perceptron for the tabular dataset.
Table 1: Test set MAE and RMSE for each method, dataset, and random seed. The AVG ± SD rows give the average and standard deviation over the five seeds.

| Method | Seed | MORPH-2 MAE | MORPH-2 RMSE | AFAD MAE | AFAD RMSE | AES MAE | AES RMSE | Fireman MAE | Fireman RMSE |
|---|---|---|---|---|---|---|---|---|---|
| CE-NN | 0 | 3.81 | 5.19 | 3.31 | 4.27 | 0.43 | 0.68 | 0.80 | 1.14 |
| | 1 | 3.60 | 4.80 | 3.28 | 4.19 | 0.43 | 0.69 | 0.80 | 1.14 |
| | 2 | 3.61 | 4.84 | 3.32 | 4.22 | 0.45 | 0.71 | 0.79 | 1.13 |
| | 3 | 3.85 | 5.21 | 3.24 | 4.15 | 0.43 | 0.70 | 0.80 | 1.16 |
| | 4 | 3.80 | 5.14 | 3.24 | 4.13 | 0.42 | 0.68 | 0.80 | 1.15 |
| | AVG ± SD | 3.73 ± 0.12 | 5.04 ± 0.20 | 3.28 ± 0.04 | 4.19 ± 0.06 | 0.43 ± 0.01 | 0.69 ± 0.01 | 0.80 ± 0.01 | 1.14 ± 0.01 |
| OR-NN [niu2016ordinal] | 0 | 3.21 | 4.25 | 2.81 | 3.45 | 0.44 | 0.70 | 0.75 | 1.07 |
| | 1 | 3.16 | 4.25 | 2.87 | 3.54 | 0.43 | 0.69 | 0.76 | 1.08 |
| | 2 | 3.16 | 4.31 | 2.82 | 3.46 | 0.43 | 0.69 | 0.77 | 1.10 |
| | 3 | 2.98 | 4.05 | 2.89 | 3.49 | 0.44 | 0.70 | 0.76 | 1.08 |
| | 4 | 3.13 | 4.27 | 2.86 | 3.45 | 0.43 | 0.69 | 0.74 | 1.07 |
| | AVG ± SD | 3.13 ± 0.09 | 4.23 ± 0.10 | 2.85 ± 0.03 | 3.48 ± 0.04 | 0.43 ± 0.01 | 0.69 ± 0.01 | 0.76 ± 0.01 | 1.08 ± 0.01 |
| CORAL [cao2020rank] | 0 | 2.94 | 3.98 | 2.95 | 3.60 | 0.47 | 0.72 | 0.82 | 1.14 |
| | 1 | 2.97 | 4.03 | 2.99 | 3.69 | 0.47 | 0.72 | 0.83 | 1.16 |
| | 2 | 3.01 | 3.98 | 2.98 | 3.70 | 0.48 | 0.73 | 0.81 | 1.13 |
| | 3 | 2.98 | 4.01 | 3.00 | 3.78 | 0.44 | 0.70 | 0.82 | 1.16 |
| | 4 | 3.03 | 4.06 | 3.04 | 3.75 | 0.46 | 0.72 | 0.82 | 1.15 |
| | AVG ± SD | 2.99 ± 0.04 | 4.01 ± 0.03 | 2.99 ± 0.03 | 3.70 ± 0.07 | 0.46 ± 0.02 | 0.72 ± 0.01 | 0.82 ± 0.01 | 1.15 ± 0.01 |
| CORN (ours) | 0 | 2.98 | 4.00 | 2.80 | 3.45 | 0.41 | 0.67 | 0.75 | 1.07 |
| | 1 | 2.99 | 4.01 | 2.81 | 3.44 | 0.44 | 0.69 | 0.76 | 1.08 |
| | 2 | 2.97 | 3.97 | 2.84 | 3.48 | 0.42 | 0.68 | 0.77 | 1.10 |
| | 3 | 3.00 | 4.06 | 2.80 | 3.48 | 0.43 | 0.69 | 0.76 | 1.08 |
| | 4 | 2.95 | 3.92 | 2.79 | 3.45 | 0.43 | 0.69 | 0.74 | 1.07 |
| | AVG ± SD | 2.98 ± 0.02 | 3.99 ± 0.05 | 2.81 ± 0.02 | 3.46 ± 0.02 | 0.43 ± 0.01 | 0.68 ± 0.01 | 0.76 ± 0.01 | 1.08 ± 0.01 |
As the results in Table 1 show, CORN outperforms all other methods on the three image datasets, MORPH-2, AFAD, and AES. We repeated the experiments on different random seeds for model weight initialization and data shuffling, which ensures that the results are not coincidental.
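The per-seed aggregation reported in Table 1 can be reproduced with a few lines. The sketch below uses the five MORPH-2 MAE values of CE-NN and CORN from the table; that the reported standard deviation is the sample standard deviation (n-1 in the denominator) is an assumption, chosen here because it matches the rounded table entries.

```python
import statistics

# MORPH-2 test MAE for seeds 0-4, copied from Table 1.
ce_nn_mae = [3.81, 3.60, 3.61, 3.85, 3.80]
corn_mae = [2.98, 2.99, 2.97, 3.00, 2.95]


def avg_sd(values):
    """Mean and sample standard deviation, rounded to two decimals."""
    return round(statistics.mean(values), 2), round(statistics.stdev(values), 2)


ce_nn_summary = avg_sd(ce_nn_mae)
corn_summary = avg_sd(corn_mae)
print(ce_nn_summary, corn_summary)
```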
It is worth noting that even though CORAL’s rank consistency was found to be beneficial for model performance [cao2020rank], it performs noticeably worse than OR-NN on AES and the balanced AFAD dataset. This is likely due to CORAL’s weight-sharing constraint in the output layer, which could affect the expressiveness of the neural networks and thus limit the complexity of what they can learn. In contrast, the CORN method, which is also rank-consistent, performs better than OR-NN on MORPH-2 and AFAD.
For all methods, the overall performances on the AES dataset are within a similar range. One possible explanation is that the AES dataset contains a wide variety of photos depicting people, nature, various objects, and architecture that were ranked based on a subjective aesthetic score between 1 and 5. In contrast, MORPH-2 and AFAD depict face images with a relatively consistent image location. The much larger variety and complexity of images in the AES dataset, and the complicated nature of judging the aesthetics of an image, might explain why all methods show similar performances on the AES dataset. The low MAE and RMSE values on AES are due to the relatively small number of categories (5) compared to MORPH-2 (33) and AFAD (13).
We found that OR-NN and CORN have identical performances on the tabular Fireman dataset (Table 1), outperforming both CE-NN and CORAL in both test MAE and test RMSE. Here, similar to AES, the performances are relatively close, and the 16-category prediction task is relatively easy for a fully connected neural network regardless of the loss function.
In this paper, we developed the rank-consistent CORN framework for ordinal regression via conditional training datasets. We used CORN to train convolutional and fully connected neural architectures on ordinal response variables. Our experimental results showed that the CORN method improved the predictive performance compared to the rank-consistent reference framework CORAL. While our experiments focused on image and tabular datasets, the generality of our CORN method allows it to be readily applied to other types of datasets to solve ordinal regression problems with various neural network structures.
This research was supported by the Office of the Vice Chancellor for Research and Graduate Education at the University of Wisconsin-Madison with funding from the Wisconsin Alumni Research Foundation.
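We can convert the CORN loss function,
$$L(\mathbf{X}, \mathbf{y}) = -\frac{1}{\sum_{j=1}^{K-1} |S_j|} \sum_{j=1}^{K-1} \sum_{i=1}^{|S_j|} \Big[ \log\big(f_j(\mathbf{x}^{[i]})\big) \cdot \mathbb{1}\{y^{[i]} > r_j\} + \log\big(1 - f_j(\mathbf{x}^{[i]})\big) \cdot \mathbb{1}\{y^{[i]} \leq r_j\} \Big], \quad (7)$$
into an alternative version
$$L(\mathbf{Z}, \mathbf{y}) = -\frac{1}{\sum_{j=1}^{K-1} |S_j|} \sum_{j=1}^{K-1} \sum_{i=1}^{|S_j|} \Big[ \log\big(\sigma(z^{[i]})\big) \cdot \mathbb{1}\{y^{[i]} > r_j\} + \big(\log\big(\sigma(z^{[i]})\big) - z^{[i]}\big) \cdot \mathbb{1}\{y^{[i]} \leq r_j\} \Big], \quad (8)$$
where $\mathbf{Z}$ are the net inputs of the last layer (aka logits) and $\log(\sigma(z^{[i]})) = \log(f_j(\mathbf{x}^{[i]}))$, since
$$\log\big(1 - \sigma(z)\big) = \log\left(1 - \frac{1}{1 + e^{-z}}\right) = \log\left(\frac{e^{-z}}{1 + e^{-z}}\right) = \log\left(e^{-z}\right) + \log\left(\frac{1}{1 + e^{-z}}\right) = \log\big(\sigma(z)\big) - z.$$
This allows us to use the logsigmoid(z) function that is implemented in deep learning libraries such as PyTorch as opposed to using log(1-sigmoid(z)); the former yields numerically more stable gradients during backpropagation. A PyTorch implementation of the CORN loss function is shown in Fig. S1.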
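The numerical benefit of the rewritten term can be seen already in double precision. The following plain-Python sketch is not the PyTorch code referenced above; the function names are hypothetical, and `log_sigmoid` is a hand-written stable variant standing in for a library logsigmoid function.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def log_one_minus_sigmoid_naive(z):
    """Direct evaluation of log(1 - sigmoid(z))."""
    return math.log(1.0 - 1.0 / (1.0 + math.exp(-z)))


def log_one_minus_sigmoid_stable(z):
    """Equivalent form log(sigmoid(z)) - z."""
    return log_sigmoid(z) - z


# For a moderate logit, both forms agree.
gap = abs(log_one_minus_sigmoid_naive(2.0) - log_one_minus_sigmoid_stable(2.0))

# For a large logit, sigmoid(z) rounds to exactly 1.0 and the naive form
# attempts log(0), which raises an error in Python.
try:
    log_one_minus_sigmoid_naive(40.0)
    naive_failed = False
except ValueError:
    naive_failed = True

stable_value = log_one_minus_sigmoid_stable(40.0)
print(gap, naive_failed, stable_value)
```

The stable form returns a value very close to -40 for the large logit, whereas the naive form cannot be evaluated at all.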