FuCiTNet: Improving the generalization of deep learning networks by the fusion of learned class-inherent transformations

05/17/2020 ∙ by Manuel Rey-Area, et al.

It is widely known that very small datasets produce overfitting in Deep Neural Networks (DNNs), i.e., the network becomes highly biased to the data it has been trained on. This issue is often alleviated using transfer learning, regularization techniques and/or data augmentation. This work presents a new approach, independent of but complementary to the previously mentioned techniques, for improving the generalization of DNNs on very small datasets in which the involved classes share many visual features. The proposed methodology, called FuCiTNet (Fusion Class inherent Transformations Network), inspired by GANs, creates as many generators as classes in the problem. Each generator, k, learns the transformations that bring the input image into the k-class domain. We introduce a classification loss in the generators to drive the learning of specific k-class transformations. Our experiments demonstrate that the proposed transformations improve the generalization of the classification model in three diverse datasets.


1 Introduction

It is well known that building robust and efficient supervised deep learning networks requires large amounts of quality data, especially when the involved classes share many visual features. Recent technological advances in camera sensors and the potentially unlimited data provided by the internet have helped in gathering large volumes of images imageNet ; coco_dataset ; cifar . However, labelling such amounts of data is still manual and costly. In consequence, a large number of problems still have to deal with very small labeled datasets and hence, the resulting classification models are usually unable to generalise correctly to new unseen examples. This problem is known as overfitting.

In general, the issue of overfitting in supervised networks is addressed either from the model or from the data point of view. In the former, transfer learning or diverse regularization techniques are employed. In the latter, the volume of the training set is increased using data augmentation strategies alexnet ; intro_dataAug2 ; intro_dataAug3 . For highly sensitive applications, data augmentation strategies, if not chosen correctly, can potentially change meaningful information, resulting in ill-posed training data. Instead of manually selecting data augmentation techniques, recent works showed that learning these transformations from data can lead to significant improvements in the generalization of the models.

Some approaches train the generator of a GAN (Generative Adversarial Network) to learn suitable augmenting techniques from scratch learn_dataAug ; intro_learnAug3 , while others train the generator to learn to find the optimal set of augmenting techniques from an initial space of data augmentation strategies learn_dataAug_gans . The downside of the aforementioned approaches is that the selected transformations are applied to all samples in a training set without taking into account the particularities of each class.

The present work proposes the FuCiTNet approach for learning class-inherent transformations for each class within the dataset under evaluation. FuCiTNet is inspired by GANs: it creates a number N of generators equal to the number of classes. Each generator, G_k, with k = 1, ..., N, is entrusted to learn the features of a specific k-class space. When a sample is fed into the system, it is broadcast to every generator, producing N transformed images, each of which is fed to the classifier, which predicts a label with a certain error. The error is transferred back to the entrusted generator (specified by the input's ground-truth label), indicating how much the class transformation must be altered to meet the classifier requirements. The final prediction of the trained classification model is calculated by fusing the N different output predictions.

The contributions of this work can be summarized as follows:

  • We propose class-inherent transformation generators for improving the generalization capacity of image classification models, especially appropriate for problems in which the involved classes share many visual features. Our approach, FuCiTNet, creates as many generators as the number of classes N. Each generator, G_k, learns the inherent transformations that bring the input image from its space to the k-class space. We introduce a classification loss in the generators to drive the learning of specific k-class transformations.

  • The final prediction of the classification model is calculated as a fusion of the N output scores. The source code of FuCiTNet will be made available on GitHub after acceptance.

  • Our experiments demonstrate that class-inherent transformations produce a clearer discrimination for the classifier yielding better generalisation performance in three small datasets from two different fields.

This paper is organized as follows. A summary of the works most related to ours is given in Section 2. A description of the FuCiTNet model is provided in Section 3. The experimental framework is provided in Section 4. Results and analysis are given in Section 5, and finally conclusions and future work in Section 6.

2 Related Work

Improving the performance of supervised deep neural networks in image classification is still ongoing research. To reach high accuracy, the model needs to generalise robustly to unseen cases and thereby avoid overfitting. This is addressed using several approaches.

The most popular approach is data augmentation. It was first introduced by Y. LeCun et al. first_data_aug , who randomly applied the following transformations to the training dataset: shearing, random cropping and rotations. With the revolution of CNNs alexnet , novel transformations appeared, such as horizontal flipping, changes in intensity, zooming, noise injection, etc. With data augmentation, most CNN-based classifiers reduce overfitting and increase their classification accuracy.

Dropout, a well-known regularization technique, is also used for improving generalization. It was first introduced by Srivastava et al. dropout . The key idea is to randomly drop units (along with their connections) from the neural network throughout the training process. This prevents units from co-adapting too much to the data. It significantly reduces overfitting and gives major improvements over other regularization methods.

The emergence of GANs gans has led to promising results in image classification. They have proven useful for generating synthetic samples to enlarge small datasets where the number of samples per class was low, eventually introducing the variability needed for the generalisation of the classification models. The downside of GANs is that their latent space only converges when a fairly large amount of training images is available. Generating synthetic samples for small and very small datasets is still an open issue.

For small and very small datasets, the problem of overfitting is even greater. Inspired by GANs, the authors in learn_dataAug designed an augmentation network that learns the set of augmentations that best improves the classifier performance. The classifier tells the generator network which configuration of image transformations it prefers when distinguishing samples from different classes. In the same direction, the authors in learn_dataAug_rl ; learn_dataAug_gans address a similar problem using Reinforcement Learning and an adversarial strategy, respectively. The former chooses, through augmentation policies, which transformation from a list of potential transformations is best suited. The latter combines the list of transformations to synthesize a total image transformation. The present work differs from all the previously cited works in that it proposes a new approach for learning transformations inherent to each particular class based on the classifier requirements, eventually forcing the classes to be as distinguishable as possible from each other.

3 FuCiTNet approach

Inspired by GANs, FuCiTNet learns class-inherent transformations. Our aim is to build a generator that improves the discrimination capacity between the different classes. GANs use the so-called adversarial loss, which optimizes a min-max problem. The generator, G, tries to minimize the following function while the discriminator, D, tries to maximize it:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]    (1)

where D(x) is the discriminator's estimate of the probability that the real data instance x is real, \mathbb{E}_{x \sim p_{data}(x)} is the expected value over all real data instances, G(z) is the generator's output given a noise z, D(G(z)) is the discriminator's estimate of the probability that a fake instance is real, and \mathbb{E}_{z \sim p_z(z)} is the expected value over all random inputs to the generator. The formula derives from the cross-entropy between the real and generated distributions.

In this work, the adversarial concept is slightly changed. Instead of using a discriminator that focuses on maximizing the inter-class probability of the image belonging to the real or to the generated distribution, we use a classifier that minimizes the loss of the transformed image G_k(x) belonging to a particular dataset class. As the images produced by the generator do not need to be similar to their ground truth, we restrict the generator latent space in a different way, replacing the adversarial term with the classification loss of G_k(x), which will hallucinate samples with specific enhanced features to meet the classifier requirements, improving the classification accuracy.

Instead of building one single generator G, we create an array of N generators G_k, with k = 1, ..., N, where N is equal to the number of object classes. Each generator G_k is in charge of learning the inherent features of the specific k-class until G_k(x) becomes part of the k-class space. In other words, each generator will learn the transformations that map the input image from its own i-class domain to the k-class domain, with i ≠ k. The flowchart diagram of FuCiTNet is depicted in Figure 1.

Figure 1: Flowchart of FuCiTNet during training. The input image x is broadcast to every generator. Each generator produces a transformed image G_k(x). The classifier computes the cross-entropy loss, which is transferred back to the generator commissioned to enhance the features of the class given by the input's ground-truth label y.

Figure 2: Flowchart of FuCiTNet in inference.

The architecture of our generators G_k, with k = 1, ..., N, consists of 5 identical residual blocks residual_nets . Each block has two convolutional layers with 3×3 kernels and 64 feature maps, followed by batch-normalization layers batch_norm and ParametricReLU param_relu as activation function. The last residual block is followed by a final convolutional layer which reduces the output image channels to 3 to match the input's dimensions.
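A minimal PyTorch sketch of a generator with this layout is given below. The initial 3-to-64-channel convolution before the residual blocks is our assumption, added so that the blocks can operate on 64 feature maps, since the text does not describe how the input channels are expanded.

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with 64 feature maps, batch normalization and
    ParametricReLU, plus an identity skip connection."""
    def __init__(self, channels: int = 64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return x + self.body(x)


class Generator(nn.Module):
    """Class-inherent transformation generator G_k: 5 residual blocks
    followed by a final convolution that maps back to 3 image channels."""
    def __init__(self, n_blocks: int = 5, channels: int = 64):
        super().__init__()
        # Assumed entry convolution (not described in the text) so that the
        # residual blocks can operate on 64 feature maps.
        self.head = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1), nn.PReLU())
        self.blocks = nn.Sequential(*[ResidualBlock(channels)
                                      for _ in range(n_blocks)])
        self.tail = nn.Conv2d(channels, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.tail(self.blocks(self.head(x)))
```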

Our classifier is a ResNet-18, which consists of an initial convolutional layer with 7×7 kernels and 64 feature maps followed by a max-pooling layer, then 4 blocks of two convolutional layers with 3×3 kernels and 64, 128, 256 and 512 feature maps respectively, followed by average pooling and a fully connected layer which outputs an N-element vector. ReLU is used as the activation function relu .
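In practice this classifier can be instantiated from torchvision; a minimal sketch, assuming only the last fully connected layer is swapped for an N-way output:

```python
import torch.nn as nn
from torchvision import models

def build_classifier(num_classes: int) -> nn.Module:
    """ResNet-18 initialized with ImageNet weights; the final fully
    connected layer is replaced to output an N-element vector."""
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
```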

The classifier uses a standard cross-entropy loss (\ell_{CE}). Given a batch of m input images x_i with their respective ground-truth labels y_i:

\ell_{CE} = -\frac{1}{m} \sum_{i=1}^{m} \sum_{c=1}^{N} y_{i,c} \log \hat{y}_{i,c}    (2)

The true label distribution of the image x_i is denoted by y_i, and \hat{y}_i indicates the predicted label distribution for the transformed image G_k(x_i).

All generators use the same loss, inspired by SRGAN . For the generator G_k, the respective loss \ell^{G_k} is given by Formula (3). It is a multiple-term loss constituted by a pixel-wise MSE term, a perceptual MSE term and the classifier loss. Adding this classifier loss to the generators is one of the main novelties of our method. The classification loss is added to the generator loss with a weighting factor \lambda indicating how much the generator must change its outcome to suit the classifier. It is worth noting that, for each input image, the classification loss is only transferred to the particular generator indicated by the input's ground-truth label.

\ell^{G_k} = \ell_{MSE} + \ell_{VGG} + \lambda \, \ell_{CE}    (3)

The similarity term (the pixel-wise and perceptual MSE) keeps the generator from changing the image too much, while the classifier loss drives the output away from the input and closer to the k-class feature space.

The pixel-wise MSE is obtained by performing a regular L2-norm between each pixel in the input image and the generated image.

\ell_{MSE} = \frac{1}{WH} \sum_{u=1}^{W} \sum_{v=1}^{H} \left( x_{u,v} - G_k(x)_{u,v} \right)^2    (4)

The perceptual MSE assesses the similarity between x and G_k(x) by feeding each of them into a pretrained VGG-16 vgg . The Euclidean distance between the VGG-16 feature maps (\phi) defines the perceptual loss.

\ell_{VGG} = \frac{1}{W_{\phi} H_{\phi}} \sum_{u=1}^{W_{\phi}} \sum_{v=1}^{H_{\phi}} \left( \phi(x)_{u,v} - \phi(G_k(x))_{u,v} \right)^2    (5)

The cross-entropy loss coming from the classifier is added to the loss of the particular generator in the generator array specified by the input's ground-truth label, with the weighting factor \lambda controlling the impact of the classifier within the generator:

\lambda \, \ell_{CE} = -\frac{\lambda}{m} \sum_{i=1}^{m} \sum_{c=1}^{N} y_{i,c} \log \hat{y}_{i,c}    (6)

where m is the batch size, N the number of classes, and y_{i,c} is the element in the true label distribution of the image x_i belonging to class c. Likewise, \hat{y}_{i,c} depicts the classifier prediction for the transformed image G_k(x_i) belonging to class c.
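The sketch below shows how this composite loss could be assembled in PyTorch, reusing the Generator and classifier sketches above. The VGG-16 truncation point used for the perceptual term \phi is an assumption, as the text does not specify the layer.

```python
import torch.nn as nn
from torchvision import models

# Frozen VGG-16 feature extractor for the perceptual term (cut point assumed).
vgg_features = models.vgg16(pretrained=True).features[:16].eval()
for p in vgg_features.parameters():
    p.requires_grad = False

mse = nn.MSELoss()
cross_entropy = nn.CrossEntropyLoss()

def generator_loss(x, gx, logits, labels, lam):
    """Loss of generator G_k (Eq. 3): pixel-wise MSE (Eq. 4) + perceptual
    VGG MSE (Eq. 5) + lambda-weighted classification loss (Eq. 6)."""
    pixel_mse = mse(gx, x)                                    # Eq. (4)
    perceptual_mse = mse(vgg_features(gx), vgg_features(x))   # Eq. (5)
    cls_loss = cross_entropy(logits, labels)                  # classifier feedback
    return pixel_mse + perceptual_mse + lam * cls_loss        # Eq. (3)
```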

In this manner, each generator in the array is able to learn class-inherent features of a specific class, eventually accentuating the differentiation among classes.

At inference time, the general flowchart of the system is modified as shown in Figure 2. The final class prediction is given by concatenating the output distributions obtained for every G_k(x) in the logit domain, taking the arg max and computing the modulo with the number of classes N, as indicated in Formula (7):

\hat{y} = \left( \arg\max_{j} \left[ f(G_1(x)) \, \| \, f(G_2(x)) \, \| \, \dots \, \| \, f(G_N(x)) \right]_j \right) \bmod N    (7)

where f denotes the classifier logits and \| denotes concatenation.
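A minimal sketch of this fusion rule, assuming `generators` is the list of trained G_k and `classifier` the trained ResNet-18 from the sketches above:

```python
import torch

@torch.no_grad()
def predict(x, generators, classifier, num_classes):
    """Fusion rule of Eq. (7): concatenate the classifier logits obtained for
    every generator's output, take the argmax over the concatenated vector
    and map it back to a class index with modulo N."""
    logits = [classifier(g(x)) for g in generators]  # N tensors of shape (B, N)
    fused = torch.cat(logits, dim=1)                 # shape (B, N * N)
    return torch.argmax(fused, dim=1) % num_classes
```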

4 Experimental framework

To assess the capacity of FuCiTNet to increase the discrimination among classes, we created three datasets from two different fields using object-classes that are frequently confused by the best performing classification model. In particular, from Tiny ImageNet le2015tiny , we created two datasets made of two and three classes respectively, cat-vs-dog and cat-vs-dog-vs-goldfish. From the NWPU-RESISC45 remote sensing dataset, made of aerial ortho-images Cheng_2017 , we created the church-vs-palace dataset, which represents one of the most similar and frequently confused pairs of classes in this dataset. To further increase the complexity of this dataset, we downsampled the church and palace images. A brief description of the three datasets is provided in Table 1.

Dataset                   # classes   Train   Test   Total
cat-vs-dog                two         800     200    1000
cat-vs-dog-vs-goldfish    three       1200    300    1500
church-vs-palace          two         1120    280    1400
Table 1: Description of the three evaluated datasets together with the used 80%-20% train-test data distribution.

The generator networks were initialized randomly, while the classifier, ResNet-18, was initialized with the weights pretrained on ImageNet imageNet . As optimization algorithm, we used Adam adam for both the generators and the classifier. The adopted learning rates for the generators and the classifier are different, with the generators using a lower value than ResNet-18. The reason why the generators have a lower learning rate is that they need to be subtle when generating the transformed image, so that the classifier does not fall behind in capturing the rate of change in appearance. For ResNet-18 we used a learning-rate decay of 0.1 every 5 epochs. We also used weight decay to avoid overfitting. We applied early stopping based on the validation loss with a patience of 10 epochs.
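An illustrative optimizer and scheduler setup consistent with this description is shown below; the concrete learning-rate and weight-decay values did not survive extraction, so the numbers used here are placeholders rather than the values used in the paper.

```python
import torch

num_classes = 2                                             # e.g. cat-vs-dog
generators = [Generator() for _ in range(num_classes)]      # sketch above
classifier = build_classifier(num_classes)                  # sketch above

# Generators use a lower (placeholder) learning rate than the classifier.
gen_optimizers = [torch.optim.Adam(g.parameters(), lr=1e-4) for g in generators]
cls_optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3,
                                 weight_decay=1e-4)          # placeholder values

# Decay the ResNet-18 learning rate by 0.1 every 5 epochs, as stated above.
cls_scheduler = torch.optim.lr_scheduler.StepLR(cls_optimizer, step_size=5, gamma=0.1)
```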

The weighting factor \lambda in the generator loss of Eq. (3) is a hyperparameter to be tuned. We evaluated a space of 13 values: [1, 0.5, 0.1, 0.075, 0.05, 0.025, 0.01, 0.0075, 0.005, 0.0025, 0.001, 0.00075, 0.0005], assessing the effect from a high contribution to the loss down to a softer impact. For each value of \lambda, the system was trained for 100 epochs with a batch size of 32. We alternate updates between the generator and the classifier, i.e., for each batch we first update the classifier and then update the generator. We evaluate FuCiTNet on the validation set after each epoch. The chosen weights for both networks are the ones that minimize the classification loss on this set.
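A sketch of one alternating training step under these settings, reusing the objects from the previous sketches, is given below. The text does not fully specify whether the classifier is updated on all N transformed images or only on the one produced by the true-class generator; this sketch assumes the former, while the classification loss is propagated only to the generator selected by each sample's ground-truth label.

```python
import torch

def training_step(x, y, lam):
    """One batch: update the classifier first, then the designated generators."""
    # Classifier update on the transformed images (generators kept fixed).
    cls_optimizer.zero_grad()
    cls_loss = torch.stack(
        [cross_entropy(classifier(g(x).detach()), y) for g in generators]).mean()
    cls_loss.backward()
    cls_optimizer.step()

    # Generator update: each sample only trains the generator of its own class.
    for k, (g, opt) in enumerate(zip(generators, gen_optimizers)):
        mask = (y == k)
        if not mask.any():
            continue
        opt.zero_grad()
        gx = g(x[mask])
        loss = generator_loss(x[mask], gx, classifier(gx), y[mask], lam)
        loss.backward()
        opt.step()
```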

For a fair comparison, we compare the results of FuCiTNet with the two most closely related approaches, learn_dataAug and learn_dataAug_gans . As learn_dataAug_gans provides the source code, we analyze and show its results on the three considered datasets, cat-vs-dog, cat-vs-dog-vs-goldfish and church-vs-palace, following the same experimental protocol used in the rest of the experiments. However, as learn_dataAug does not provide the source code and considered only binary classes, we include only the results reported in that paper using the same experimental protocol on the cat-vs-dog dataset. The results of this approach on cat-vs-dog-vs-goldfish and church-vs-palace are not available to us. Both learn_dataAug and learn_dataAug_gans use a ResNet CNN architecture. In addition, we also compare our results with those of the best classification model obtained with the same network architecture by manually selecting the set of optimizations that reaches the highest performance.

In all the experiments, we used 3-fold cross validation following an 80:20 hold-out data distribution, as depicted in Table 1. All the experiments were executed on an NVIDIA Titan Xp. All the implementations were performed using the PyTorch DL framework pytorch .

5 Results and analysis

This section presents, compares and analyzes the quantitative and qualitative results of FuCiTNet with the state-of-the-art methods, learn_dataAug and learn_dataAug_gans , and with the best classification model.

Setup                       Accuracy   Cat mean confidence   Dog mean confidence
None                        0.872
FT                          0.865
Data aug                    0.892      1.673                 2.520
Data aug, FT                0.888
learn_dataAug               0.770
learn_dataAug_gans          0.800
FuCiTNet, None              0.870
FuCiTNet, FT                0.880
FuCiTNet, Data aug          0.883
FuCiTNet, Data aug, FT      0.912      2.734                 2.101
Table 2: Performance of the ResNet-18 classification model without (rows 1 to 4) and with (rows 7 to 10) FuCiTNet, using different configurations, on the cat-vs-dog dataset. Data aug consists of random horizontal flipping, random rotation and random affine. The accuracy of the state-of-the-art approaches learn_dataAug and learn_dataAug_gans is shown in rows 5 and 6, respectively.
Predicted/Actual   Cat   Dog
Cat                86    14
Dog                11    89
(a) Confusion matrix for the setup: Data aug
Predicted/Actual   Cat   Dog
Cat                97    3
Dog                10    90
(b) Confusion matrix for the setup: FuCiTNet, Data aug, FT
Table 5: Cat-vs-dog confusion matrices for fold #1 in 3FCV for the best reference model (a) and FuCiTNet (b).
Setup                       Accuracy   Church mean confidence   Palace mean confidence
None                        0.774
No data aug, FT             0.769
Data aug                    0.783      0.769                    1.452
Data aug, FT                0.779
learn_dataAug_gans          0.779
FuCiTNet, None              0.751
FuCiTNet, FT                0.767
FuCiTNet, Data aug          0.777
FuCiTNet, Data aug, FT      0.795      1.868                    1.580
Table 6: Performance of the ResNet-18 based classification model without (rows 1 to 4) and with (rows 6 to 9) FuCiTNet, using different configurations, on the church-vs-palace dataset. Data augmentation consists of random horizontal flipping, random rotation and random affine. The accuracy of the state-of-the-art approach learn_dataAug_gans is shown in row 5.

Predicted/Actual   Church   Palace
Church             114      26
Palace             30       110
(a) Confusion matrix for the setup: Data aug
Predicted/Actual   Church   Palace
Church             114      26
Palace             27       113
(b) Confusion matrix for the setup: FuCiTNet, Data aug, FT
Table 9: Church-vs-palace confusion matrices for fold #1 in 3FCV for the best reference model (a) and FuCiTNet (b).
Setup                       Accuracy   Cat mean confidence   Dog mean confidence   Goldfish mean confidence
None                        0.898
FT                          0.909
Data aug                    0.915      7.796                 4.778                 9.617
Data aug, FT                0.911
learn_dataAug_gans          0.790
FuCiTNet, None              0.887
FuCiTNet, FT                0.902
FuCiTNet, Data aug          0.900
FuCiTNet, Data aug, FT      0.920      3.484                 3.248                 6.613
Table 10: Performance of the ResNet-18 based classification model without (rows 1 to 4) and with (rows 6 to 9) FuCiTNet, using different configurations, on the cat-vs-dog-vs-goldfish dataset. Data augmentation consists of random horizontal flipping, random rotation and random affine. The accuracy of the state-of-the-art approach learn_dataAug_gans is shown in row 5.

Predicted/Actual   Cat   Dog   Goldfish
Cat                94    6     0
Dog                18    80    2
Goldfish           3     2     95
(a) Confusion matrix for the setup: Data aug
Predicted/Actual   Cat   Dog   Goldfish
Cat                94    6     0
Dog                9     85    6
Goldfish           1     1     98
(b) Confusion matrix for the setup: FuCiTNet, Data aug, FT
Table 13: Cat-vs-dog-vs-goldfish confusion matrices for fold #2 in 3FCV for the best reference model (a) and FuCiTNet (b).

Quantitative results

The performance, in terms of accuracy, of the best classification model with and without FuCiTNet on the three datasets is presented in Tables 2, 6 and 10. Several configurations were analyzed: 'None' indicates that the network was initialized with ImageNet weights and retrained on the dataset; 'Data aug' indicates that the set of optimal data augmentation techniques was applied; 'FT', for fine-tuning, indicates that only the last fully connected layer was retrained on the dataset while the remaining layers were frozen. If 'FT' is not indicated, all the layers of the network were retrained on the dataset. To compare FuCiTNet with the state of the art, we also include the accuracy of the method proposed in learn_dataAug on the cat-vs-dog dataset only (see Table 2) and of the method proposed in learn_dataAug_gans on the three datasets (see Tables 2, 6 and 10).

When applying FuCiTNet to the test images, the classification model reaches 2.0, 1.2 and 0.5 percentage points higher accuracy than the best classification model on cat-vs-dog, church-vs-palace and cat-vs-dog-vs-goldfish, respectively. The number of TP and TN also improves in all three datasets, as shown in the confusion matrices in Tables 5, 9 and 13.

With respect to the state of the art, FuCiTNet provides 14%, 2% and 16.4% better accuracy than the approach proposed in learn_dataAug_gans on cat-vs-dog, church-vs-palace and cat-vs-dog-vs-goldfish, respectively, and 18.44% better accuracy than the approach proposed in learn_dataAug on cat-vs-dog.

These results suggest that FuCiTNet makes the object-class in the input images more distinguishable to the model. Our transformation approach can be considered an incremental optimization on top of the other optimizations, since on all the datasets the best results were obtained by combining FuCiTNet with fine-tuning and data augmentation.

We have also analyzed the mean confidence of each class in the three problems, computed as:

\overline{conf}_c = \frac{1}{n_c} \sum_{i=1}^{n_c} s_c(x_i)

where n_c is the number of correctly classified test images, from the test split, with ground-truth class c, and s_c(x_i) is the confidence score of the model for class c on image x_i. The mean confidence of the best reference model and of FuCiTNet is shown in the confidence columns of Tables 2, 6 and 10. As we can observe from these results, FuCiTNet clearly improves the mean confidence of the model for all the classes in the cat-vs-dog and church-vs-palace datasets. Although FuCiTNet lowers the mean confidence in the cat-vs-dog-vs-goldfish dataset, the accuracy and the number of true positives and true negatives improve.
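A possible way to compute this per-class mean confidence for the FuCiTNet branch is sketched below; treating the fused maximum logit of a correctly classified image as its confidence score is our reading of the definition above, not necessarily the exact computation used in the paper.

```python
import torch
from collections import defaultdict

@torch.no_grad()
def per_class_mean_confidence(loader, generators, classifier, num_classes):
    """Average fused logit of correctly classified test images, per class."""
    sums, counts = defaultdict(float), defaultdict(int)
    for x, y in loader:
        fused = torch.cat([classifier(g(x)) for g in generators], dim=1)
        conf, idx = fused.max(dim=1)
        pred = idx % num_classes
        for i in range(x.size(0)):
            if pred[i] == y[i]:
                c = int(y[i])
                sums[c] += conf[i].item()
                counts[c] += 1
    return {c: sums[c] / counts[c] for c in counts}
```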

FuCiTNet provides several advantages over the methods proposed in learn_dataAug and learn_dataAug_gans . It requires neither paired datasets of high- and low-resolution input images nor a super-resolution network. Unlike learn_dataAug , in which the transformations were learnt regardless of the class of the sample, our results demonstrate that exploiting class-inherent features significantly improves the boundaries between visually similar classes.

Figure 9: Left) Original sample of dog class; middle) dog transformation of the original sample; right) cat transformation of the original sample. The confidence scores of the model for (dog class, cat class) are indicated below each image. The score for the left image is obtained using the best reference model, whereas the ones for the middle and right images are obtained with FuCiTNet.
Figure 16: Left) Original sample of cat class; middle) dog transformation of the original sample; right) cat transformation of the original sample. The confidence scores of the model for (dog class, cat class) are indicated below each image. The score for the left image is obtained using the best reference model, whereas the ones for the middle and right images are obtained with FuCiTNet.
Figure 23: Left) Original sample of palace class; middle) church transformation of the original sample; right) palace transformation of the original sample. The confidence scores of the model for (church class, palace class) are indicated below each image. The score for the left image is obtained using the best reference model, whereas the ones for the middle and right images are obtained with FuCiTNet.
Figure 30: Left) Original sample of church class; middle) church transformation of the original sample; right) palace transformation of the original sample. The confidence scores of the model for (church class, palace class) are indicated below each image. The score for the left image is obtained using the best reference model, whereas the ones for the middle and right images are obtained with FuCiTNet.
Figure 39: Left) Original sample of goldfish class; middle left) cat transformation of the original sample; middle right) dog transformation of the original sample; right) goldfish transformation of the original sample. The confidence scores of the model for (dog class, cat class, goldfish class) are indicated below each image. The score for the left image is obtained using the best reference model, whereas the others are obtained with FuCiTNet.

Qualitative results

Figures 9, 16, 23, 30 and 39 visually show the transformations applied by FuCiTNet to different original input images from different classes. As can be observed from these images, the transformations affect:

  • the contour or border of the object class in the transformed images, as can be seen in Figures 9(c) and (f) and Figures 16(c) and (f).

  • the pixels that constitute the body of the object-class, as is the case in Figure 9(f), Figures 23(b) and (c), Figures 30(b), (c) and (e), and Figures 39(c), (d), (f), (g) and (h).

  • or the background, as is the case in Figures 9(c) and (f) and Figures 16(b), (c), (e) and (f).

In some images of the church-vs-palace test set, the model does not add any transformation. This indicates that the classifier does not need any additional transformation to differentiate the object-class in those specific images. In general, in this dataset, most of the palaces have either a dome or a rectangular roof with an inner courtyard, whereas the churches generally do not include any of these features. In the cases where a church has an inner courtyard, the model needs to add a transformation so that the classifier can distinguish it better from a palace, as can be seen in Figure 30. Likewise, when a palace has a very distinguishable shape from a church, the model does not add any transformation, as can be seen in Figure 23(d), (e) and (f).

This effect does not occur in the cat-vs-dog and cat-vs-dog-vs-goldfish test sets: the background in these images is so diverse (it includes people, furniture, occlusions, etc.) that the model always needs some additional transformation to better differentiate between the involved classes.

6 Conclusions and future work

Our aim in this work was to reduce overfitting on very small datasets in which the involved object-classes share many visual features. We presented the FuCiTNet model, in which a novel array of generators independently learns the class-inherent transformations of each class of the problem. We introduced a classification loss in the generators to drive the learning of specific k-class transformations. The learnt transformations are then applied to the input test images to help the classifier distinguish better between the different classes.

Our experiments demonstrated that FuCiTNet increases the generalization capability of the classification model on three small datasets with very similar classes. With these benchmark datasets we demonstrated that FuCiTNet behaves robustly on data of diverse nature and properly handles different viewpoints (zenithal and frontal). We conclude that our method yields strong gains as an incremental optimization technique, additional to the standard ones, when searching for better model performance.

As future work, we are planning to explore and adapt FuCiTNet to small medical datasets.

References