Differentiable Approximation Bridges For Training Networks Containing Non-Differentiable Functions

05/09/2019 · by Jason Ramapuram, et al. · Apple Inc.

Modern neural network training relies on piece-wise (sub-)differentiable functions in order to use backpropagation for efficient calculation of gradients. In this work, we introduce a novel method to allow for non-differentiable functions at intermediary layers of deep neural networks. We do so through the introduction of a differentiable approximation bridge (DAB) neural network which provides smooth approximations to the gradient of the non-differentiable function. We present strong empirical results (performing over 600 experiments) in three different domains: unsupervised (image) representation learning, image classification, and sequence sorting, demonstrating that our proposed method improves state of the art performance. We demonstrate that utilizing non-differentiable functions in unsupervised (image) representation learning improves reconstruction quality and posterior linear separability by 10%. We also observe an accuracy improvement of 77% on the sorting task and a 25% improvement over the straight-through estimator in an image classification setting with the sort non-linearity. This work enables the usage of functions that were previously not usable in neural networks.


1 Introduction

Most state of the art neural networks [14, 29, 10] rely on some variant of Robbins-Monro [28] based stochastic optimization. A requirement for utilizing this algorithm is the assumption that the gradients of the objective functional be Lipschitz continuous. In this work we study approximate gradient pathways that allow for arbitrary non-linear functions as sub-modules of neural networks. We do so by introducing a smooth neural network approximation (DAB) to the non-differentiable function and utilizing its gradients during training. At inference, we drop the DAB network entirely, thus requiring no extra memory or compute.

2 Related Work

Method | Objective | Supports Non-Differentiable Functions | Scales to Large Dimensions | Works with Operators that Change Dimension
DNI [17] / DPG [16] / DGL [2] | Asynchronous network updates. | no | yes | yes
Backprop Alternatives [31, 19, 9, 1, 5, 6, 23] | Optimize arbitrary functions. | yes | no | yes
Score Function Estimator [8, 22] | Differentiate non-differentiable functions. | yes | no | yes
Straight-Through Estimator [3] | Ignore non-differentiable functions. | yes | yes | no
DAB (ours) | Differentiate non-differentiable functions. | yes | yes | yes

Traditional Solutions: Traditional solutions to handling non-differentiable functions in machine learning tend to cluster around the Score Function Estimator (SFE) [8, 22] (also known as REINFORCE [35]) or the Straight-Through Estimator (STE) [3]. While the SFE provides an unbiased estimate of the gradients, it generally suffers from high variance [11] and needs to be augmented with control variates [7] that require manual tuning and domain knowledge. The STE, on the other hand, simply copies gradients back, skipping the non-differentiable portion (i.e. treating it as an identity operation). Furthermore, the STE does not allow for operators that change dimension, i.e. mappings whose output dimensionality differs from that of their input, since it is unclear how the gradients of the larger/smaller output would be copied back.
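As an illustration (not code from this paper), the STE is commonly implemented in PyTorch with the value-plus-detach trick sketched below. Because the gradient is copied element-wise, the hard function must preserve the input's shape, which is exactly the dimension-change limitation noted above.

import torch

def straight_through(x, hard_fn):
    """Apply hard_fn in the forward pass; behave like the identity in the backward pass.
    hard_fn(x) must have the same shape as x so that gradients can be copied through."""
    y = hard_fn(x)
    # detach() removes hard_fn from the autograd graph; adding x back makes the
    # backward pass treat the whole operation as an identity.
    return x + (y - x).detach()

x = torch.randn(4, requires_grad=True)
out = straight_through(x, torch.sign)   # forward: sign(x); backward: identity
out.sum().backward()
print(x.grad)                           # all ones: the sign() op was bypassed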

Backpropagation Alternatives: Machine learning has a rich history of backpropagation alternatives, ranging from Simulated Annealing [31], Particle Swarm Optimization [19], Genetic Algorithms [9], and Evolutionary Strategies [1] to Bayesian approaches such as MCMC-based sampling algorithms [5, 6]. These algorithms have generally been shown not to scale to the complex, high-dimensional optimization problems [27] embodied in large neural network models. More recent analysis of backpropagation alternatives [23] has demonstrated that it is possible to learn weight updates through the use of random matrices, although no statement is made about training or convergence time.

Asynchronous Neural Network Updates: Recent work such as Decoupled Neural Interfaces (DNI) [17] and Decoupled Parallel Backpropagation (DPG) [16] introduced an auxiliary network to approximate gradients in RNN models. Similar approximation techniques have been introduced to allow for greedy layerwise CNN-based training (DGL) [2]. The central objective of these models is to enable asynchronous updates that speed up training. Our work differs from all of these solutions in that our objective is not to improve training speed or parallelism, but to learn a function approximator of a non-differentiable function such that it provides a meaningful training signal for the preceding layers of the network. This approach allows us to utilize complex, non-differentiable functions such as kmeans, sort, and signum as intermediary layers in neural network pipelines.

3 Preliminaries

In their seminal work [28], Robbins and Monro developed a stochastic approximation framework for solving for the roots of a function M(θ) = E[H(θ, X)], under the assumption of the existence of a unique solution. They characterized the iterative update rule as:

    θ_{t+1} = θ_t − α_t H(θ_t, X_t).    (1)

Given an observable random variable H(θ, X), parametrized by θ, the objective is defined as solving M(θ) = α for θ, wherein α is some constant. If we assume that the gradient of the objective f is K-Lipschitz continuous, we can replace H with the gradient ∇_θ f; this is due to the fact that we can bound the difference between consecutive iterative updates of θ (here with a constant step size α):

    ‖(θ_{t+1} − θ_t) − (θ_t − θ_{t−1})‖ = α ‖∇_θ f(θ_t) − ∇_θ f(θ_{t−1})‖ ≤ α K ‖θ_t − θ_{t−1}‖.    (2)

Given that we can upper bound the normed parameter difference by the normed functional gradients, and under the assumption of small iterates in parameter space (αK < 1), repeated application of this update rule converges to a fixed point. This derives from the Banach Fixed Point Theorem, which states:

Theorem (Banach Fixed Point). Given a non-empty complete metric space (X, d) and a contractive mapping T : X → X, T admits a unique fixed point x* with T(x*) = x*.

In the specific case of Equation 2, the metric is the one induced by the L2 norm. Note that a norm is a more rigid constraint than a metric: the metric induced by a norm is translation invariant and absolutely homogeneous, in addition to satisfying all the requirements of a metric.
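As a toy illustration of the contraction argument (our own example, not from the paper): iterating any K-Lipschitz map with K < 1 shrinks the distance between successive iterates geometrically, so the sequence converges to the unique fixed point.

import math

# T(x) = 0.5 * cos(x) is a contraction on the reals with K = 0.5 < 1.
x = 3.0
for t in range(20):
    x_next = 0.5 * math.cos(x)
    print(t, abs(x_next - x))   # successive differences shrink by roughly K per step
    x = x_next                  # converges to the unique fixed point x* = 0.5 * cos(x*)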

4 Model

Figure 1: Graphical model for our proposed framework; h denotes the non-differentiable function.

A graphical model depicting a generic version of our framework is shown in Figure 1. Given some true input data distribution, x ∼ P(x), and a set of L functional approximators, f_{θ_1}, …, f_{θ_L} (Figure 1), our learning objective is defined as maximizing the log-likelihood, coupled with a new regularizer, R_DAB, introduced in this work:

    max_{θ, φ}  E_{x ∼ P(x)} [ log p_θ(y | x) ]  −  γ R_DAB(φ),    (3)

    p_θ(y | x) = ∫ p_θ(y | z_L) ∏_{i=1}^{L} p(z_i | z_{i−1}) dz_{1:L},   z_0 = x,    (4)

    max_{θ, φ}  E_{x ∼ P(x)} [ log p_θ(y | z_L) ]  −  γ ‖f̂_φ(z) − h(z)‖²₂,   z_i = f_{θ_i}(z_{i−1}).    (5)

Since the latent representations z_i are simple functional transformations, we can represent the distributions p(z_i | z_{i−1}) (Equation 4) by Dirac distributions centered around their functional evaluations: p(z_i | z_{i−1}) = δ(z_i − f_{θ_i}(z_{i−1})). This allows us to rewrite our objective as shown in Equation 5, where γ is a problem-specific hyper-parameter, h is the non-differentiable function, z is its input, and f̂_φ is the DAB approximator. A key point is that during the forward pass of the model we use the output of the non-differentiable function, h(z).
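A hedged sketch of how Equation 5 could translate into a single training loss (the names model_head, dab, hard_fn and the cross-entropy task loss are illustrative assumptions; the paper's actual autograd-based implementation is given in the Appendix). The hard function supplies the forward value while the DAB supplies the gradient path, and the DAB itself is regressed onto the hard output.

import torch
import torch.nn.functional as F

def dab_step_loss(model_head, dab, hard_fn, z, target, gamma):
    """z: activations feeding the non-differentiable function hard_fn.
    model_head: the layers after the hard function; dab: the smooth approximator."""
    soft = dab(z)                               # differentiable DAB approximation
    hard = hard_fn(z).detach()                  # forward pass uses the hard output
    # forward value = hard, backward gradient = d(soft)/dz (routed through the DAB)
    h = soft + (hard - soft).detach()
    task_loss = F.cross_entropy(model_head(h), target)
    dab_loss = F.mse_loss(soft, hard)           # regularizer term of Equation 5
    return task_loss + gamma * dab_loss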

4.1 Choice of metric under simplifying assumptions

In this section we analyze the regularizer introduced in Equation 5 in the special case where the non-differentiable function output, h(z), is a (differentiable) linear transformation of the previous layer coupled with additive Gaussian noise (aleatoric uncertainty):

    h(z) = W z + ε,   ε ∼ N(0, σ² I),    (6)

    p(h(z) | z, W) = N(W z, σ² I).    (7)

Under these simplifying assumptions our model induces a Gaussian log-likelihood, as shown in Equation 7. At this point we can directly maximize the above likelihood using maximum likelihood estimation. Alternatively, if we have a-priori knowledge we can introduce it as a prior, p(W), over the weights W, and maximize the likelihood times the prior (equivalently, minimize the negative log-likelihood minus the log-prior) to evaluate the posterior mode, i.e. the MAP estimate. If we make a conjugate (Gaussian) prior assumption, p(W) = N(0, σ₀² I), then:

    W_MAP = argmax_W [ log p(h(z) | z, W) + log p(W) ]    (8)

          = argmin_W [ (1 / 2σ²) ‖h(z) − W z‖²₂ + (1 / 2σ₀²) ‖W‖²₂ ]    (9)

          ∝ argmin_W [ ‖h(z) − W z‖²₂ + λ ‖W‖²₂ ].    (10)

This analysis leads us to the well-known result that a linear transformation with aleatoric Gaussian noise results in a loss proportional to the L2 loss (Equation 10). However, what can we say about the case where h(z) is a non-linear, non-differentiable output? In practice we observe that using the L2 loss, coupled with a non-linear neural network transformation, produces strong results. To understand why, we appeal to the central limit theorem, which states that the scaled mean of a random variable converges to a Gaussian distribution as the sample size increases. Furthermore, if we can assume a zero mean, positive variance, and a finite absolute third moment, it can be shown that the rate of convergence to a Gaussian distribution is proportional to 1/√n, where n is the number of samples [4]. We explored alternatives such as the Huber loss [15], cosine loss, L1 loss and cross-entropy loss, but found the L2 loss to consistently produce strong results; we therefore utilize it for all presented experiments.
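The regression loss attaching the DAB to the hard function's output is a design choice; below is a minimal sketch of the alternatives mentioned above (the helper name is ours; the paper reports that L2 worked best across experiments).

import torch.nn.functional as F

def dab_regularizer(soft, hard, kind="l2"):
    """Regress the DAB output (soft) onto the non-differentiable output (hard)."""
    if kind == "l2":
        return F.mse_loss(soft, hard)
    if kind == "l1":
        return F.l1_loss(soft, hard)
    if kind == "huber":
        return F.smooth_l1_loss(soft, hard)   # Huber-style loss [15]
    if kind == "cosine":
        return 1.0 - F.cosine_similarity(soft.flatten(1), hard.flatten(1)).mean()
    raise ValueError(f"unknown loss: {kind}")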

5 Experiments

We quantify our proposed algorithm on three different benchmarks: sequence sorting, unsupervised representation learning, and image classification. For a full list of hyper-parameters, model specifications, and example PyTorch [25] code, see the Appendix.

5.1 Sequence Sorting

Figure 2: Left: Dense sorting model with a non-differentiable function. Right: LSTM model with a non-differentiable function.

Length (T) | ELU-Dense | Ptr-Net [33] | Read-Process-Write [32] | Signum-RNN (ours) | Signum-Dense (ours)
T=5 | 86.46 ± 4.7% (x5) | 90% | 94% | 99.3 ± 0.09% (x5) | 99.3 ± 0.25% (x5)
T=10 | 0 ± 0% (x5) | 28% | 57% | 92.4 ± 0.36% (x5) | 94.2 ± 0.1% (x5)
T=15 | 0 ± 0% (x5) | 4% | 10% | 87.2 ± 0.3% (x5) | 79.8 ± 0.8% (x5)
Table 1: All-or-none sorting test-accuracy (presented as mean ± std (replication)) for varying length (T) sequences.

In this experiment, we analyze sequence sorting with neural networks. Input sequences of length T are generated by sampling a uniform distribution. The objective of the model is to predict a categorical output distribution corresponding to the indices of the sorted input sequence. We follow [32] and evaluate the all-or-none (called out-of-sequence in [32]) accuracy for all presented models. This metric penalizes an output for not predicting the entire sequence in the correct order (no partial credit).
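A small sketch of the all-or-none metric as we read the description above (a hypothetical helper, not the authors' code): a sample only counts as correct if every position of the predicted permutation matches the true sorted ordering.

import torch

def all_or_none_accuracy(pred_indices, true_indices):
    """pred_indices, true_indices: [batch, T] integer tensors of sorting permutations.
    A sample scores 1 only if the entire predicted ordering is correct."""
    exact = (pred_indices == true_indices).all(dim=-1)
    return exact.float().mean().item()

x = torch.rand(8, 5)                     # batch of uniform sequences with T = 5
true_idx = torch.argsort(x, dim=-1)      # indices that sort each input sequence
print(all_or_none_accuracy(true_idx, true_idx))   # 1.0 by construction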

We develop two novel models to address the sorting problem: a simple feed-forward neural network (Figure 2, left) and a sequential RNN model (Figure 2, right). The central difference between a traditional model and the ones in Figure 2 is the incorporation of a non-differentiable (hard) function, shown in red in both model diagrams. During the forward pass of the model, we directly use the (hard) non-differentiable function's output for the subsequent layers. The DAB network receives the same input as the non-differentiable function and caches its output. This cached output is used in the regularizer presented in Section 4.1 in order to allow the DAB to approximate the non-differentiable function (Figure 2). During the backward pass (dashed lines), the gradients are routed through the DAB instead of the non-differentiable function. While it is possible to utilize any non-differentiable function, in this experiment we use the following ε-margin signum function:

    signum_ε(x) = { +1 if x > ε;  −1 if x < −ε;  0 otherwise }.    (11)

We contrast our models with the state of the art for sequence sorting ([33, 32]) and a baseline ELU-Dense multilayer neural network, and demonstrate (Table 1) that our model outperforms all baselines (in some cases by over 75%). These gains can be attributed to the choice of (non-differentiable) non-linearity used in our model. We believe that the logic of sequence sorting can be simplified using a function that directly allows binning of intermediary model outputs into {−1, 0, +1}, which in turn simplifies implementing a swap operation.
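A rough sketch of the Signum-Dense model of Figure 2 (left) under our reading of the text; the layer sizes and exact wiring are assumptions, and the forward-hard/backward-DAB routing uses the same value-plus-detach trick as before (the Appendix gives the paper's autograd.Function interface).

import torch
import torch.nn as nn

def eps_signum(x, eps=0.5):
    # epsilon-margin signum of Equation 11: -1 below -eps, +1 above +eps, 0 otherwise.
    sig = torch.zeros_like(x)
    sig[x < -eps] = -1
    sig[x > eps] = 1
    return sig

class SignumDenseSorter(nn.Module):
    """Hypothetical dense sorter: predicts, for each of the T inputs, its sorted position."""
    def __init__(self, T, hidden=256):
        super().__init__()
        self.T = T
        self.encoder = nn.Sequential(nn.Linear(T, hidden), nn.Tanh(), nn.Linear(hidden, hidden))
        self.dab = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh(), nn.Linear(hidden, hidden))
        self.head = nn.Linear(hidden, T * T)     # T categorical distributions over T positions

    def forward(self, x):
        z = self.encoder(x)
        hard, soft = eps_signum(z), self.dab(z)
        h = soft + (hard - soft).detach()        # forward: hard signum; backward: DAB gradient
        logits = self.head(h).view(-1, self.T, self.T)
        return logits, soft, hard                # soft/hard also feed the DAB regularizer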

5.1.1 Effect of Pondering

Figure 3: Effect of increasing ponder steps for the T=5 (left) and T=10 (right) sort problems. The mean and standard deviation of the maximum test all-or-none accuracy are reported over 5 runs per ponder length.

The model presented in [32] evaluates the effect of pondering, in which the LSTM is iterated with no further inputs. This pondering allows the model to learn to sort its internal representation. Traditional sorting algorithms run O(T log T) comparison operations on the T-dimensional input sequence; iterating the LSTM attempts to parallel this. We introduce a similar pondering loop into our model and show the performance benefit in Figure 3; we observe a similar performance gain, but notice that the benefits diminish after five pondering iterations.
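A hedged sketch of the pondering loop as we understand it (the shapes and sizes are assumptions): after reading the sequence, the LSTM is stepped several more times on an empty input so that its state can continue rearranging before the output is read.

import torch
import torch.nn as nn

lstm = nn.LSTMCell(input_size=1, hidden_size=64)
x = torch.rand(32, 5, 1)                 # batch of 32 length-5 sequences

h = torch.zeros(32, 64)
c = torch.zeros(32, 64)
for t in range(x.size(1)):               # read the input sequence
    h, c = lstm(x[:, t], (h, c))

empty = torch.zeros(32, 1)               # no further inputs while pondering
for _ in range(5):                       # extra iterations to "sort" the internal state
    h, c = lstm(empty, (h, c))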

5.2 Unsupervised Representations

In this experiment, we study the usefulness of the unsupervised representations learnt by traditional latent-variable models such as the Variational Autoencoder (VAE) [20]. Variational Autoencoders, coupled with discrete reparameterization methods [24, 18], enable learning of compact binary latent representations. Given an input random variable x, VAEs posit an approximate posterior, q(z|x), over a latent variable, z, and maximize the Evidence Lower BOund (ELBO). We contrast the VAE ELBO with our optimization objective below (note that the backward pass for the DAB follows the same logic as presented earlier):

    VAE:  max  E_{q(z|x)} [ log p(x | z) ]  −  KL( q(z | x) ‖ p(z) )

    DAB:  max  E_{x ∼ P(x)} [ log p(x | h(z)) ]  −  γ ‖f̂_φ(z) − h(z)‖²₂

We posit that good latent representations should not only be compact (in terms of bits-per-pixel), but should also linearly disentangle a complex input space and allow the original sample to be reconstructed well. Simple, disentangled latent representations are the ultimate goal of unsupervised learning, and we demonstrate the usefulness that non-differentiable functions bring to this goal. We do so through the use of two metrics: the MS-SSIM [34] and linear classification of posterior samples. The MS-SSIM is a metric typically used in compression-related studies and allows us to get a sense of how similar (in structure) the reconstructed sample is to the original. Linear classification of posterior samples provides an evaluation of disentangled latent representations: a quintessential feature of a good unsupervised representation. Importantly, we do not specifically train the model to induce better linear separability, as that would necessitate the use of supervision.

In Figures 4 and 5 we contrast our models (dab-*) against traditional Bernoulli and discrete Gumbel-reparameterized models [24, 18] and a naive downsample, binary-threshold and classify solution (threshold). We summarize the variants we utilize below:

Variant | Functional Form
dab-bernoulli | Sample from the non-reparameterized Bernoulli distribution.
dab-binary |
dab-signum | Equation 11. BPP is scaled by log2(3) due to the trinary representation.
threshold | bilinear(x, BPP), threshold(x, t) and linearly classify for the best t.

We begin by utilizing the training sets of Fashion MNIST, CIFAR10, and ImageNet to train the baseline Bernoulli and discrete VAEs as well as the models with the non-differentiable functions (dab-*) presented above. We train five models per level of BPP for FashionMNIST and CIFAR10 and evaluate the MS-SSIM and linear classification accuracy at each point. We repeat the same for ImageNet, but only at BPP=0.00097 due to computational restrictions. The linear classifier is trained on the same training dataset (using the encoded posterior representation as input) after the completion of training the main model. We present the mean and standard deviation results in Figures 4 and 5 for all three datasets. We observe that our models perform better in terms of test reconstruction (MS-SSIM) and also provide a more disentangled latent representation (in terms of linear test accuracy). We observe either dab-signum or dab-binary performing better than all variants across all datasets. Since only the activation is being changed, the benefit can be directly attributed to the use of the non-differentiable functions as activations.
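A sketch of the linear-separability probe as described above (scikit-learn is our assumption here, and encode stands in for the trained, frozen encoder/posterior): only a linear classifier is fit on the latent codes, so any accuracy gain reflects the representation rather than the probe.

import numpy as np
from sklearn.linear_model import LogisticRegression

def linear_probe_accuracy(encode, train_x, train_y, test_x, test_y):
    """encode: frozen function mapping an image to its (binary / trinary) latent code."""
    z_train = np.stack([encode(x) for x in train_x])
    z_test = np.stack([encode(x) for x in test_x])
    probe = LogisticRegression(max_iter=1000).fit(z_train, train_y)
    return probe.score(z_test, test_y)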


Figure 4: We sweep a range of bits-per-pixel (BPP) values for FashionMNIST and CIFAR10, performing 5 experiments at each BPP level per model type. Left: Test Multi-Scale Structural Similarity (MS-SSIM) [34]. Middle: Purely unsupervised linear posterior test-classification accuracy. Right: Test input images and their reconstructions at BPP=0.1.
Figure 5: Five trials (each) of ImageNet using only BPP=0.00097 due to computational restrictions. Left: Test MS-SSIM [34]. Right: Purely unsupervised linear posterior test-classification accuracy. Images are compressed to 786 bits (496 trinary units for dab-signum, since 496 · log2(3) ≈ 786 bits) and yield a 40x improvement over random guessing (0.001).

5.3 Image Classification

Model | CIFAR10 Test-Accuracy (mean ± std) | Functional Form
Baseline | 92.87% ± 0.06% | Identity(x)
Signum | 91.95% ± 0.07% | Equation 11
Sort | 92.93% ± 0.1% | sort-row(x) ⊕ sort-col(x)
Topk | 92.21% ± 0.14% | (sort-row(x) ⊕ sort-col(x))[0:k]
K-Means | 91.97% ± 0.16% | kmeans(x, k=10)
Table 2: CIFAR10 test-accuracy over five trials for each row. ⊕ denotes concatenation.

In this experiment we evaluate how well our model performs in classifying images of CIFAR10 using a Resnet18 model tailored to operate on 32x32 images. We evaluate a variety of non-differentiable functions and present their test accuracies and standard deviations in Table 2. We observe that utilizing Sort as the final activation in the Resnet18 model improves upon the vanilla model (Baseline) by a small margin (0.06% in Table 2). While these results are statistically significant, the difference is small. In contrast, when we used the same non-differentiable function in a simpler model for the same problem, we observed a larger difference (10%) between the test accuracies. We attribute this to the regularization effect induced by the choice of non-differentiable activation.
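A sketch of the Sort activation of Table 2 as we read it (the concatenation layout is an assumption): each feature map is sorted along its rows and along its columns, and the two results are concatenated on the channel dimension, doubling it.

import torch

def sort_row_col(x):
    """x: [batch, channels, H, W]. Used as the hard function in the classification experiments."""
    sorted_rows, _ = torch.sort(x, dim=-1)     # sort each row of every feature map
    sorted_cols, _ = torch.sort(x, dim=-2)     # sort each column of every feature map
    return torch.cat([sorted_rows, sorted_cols], dim=1)

x = torch.randn(2, 64, 4, 4)
print(sort_row_col(x).shape)                   # torch.Size([2, 128, 4, 4])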

5.3.1 Ablation / Case Studies

Figure 6: Left: Signum non-differentiable function evaluated at different sections of a Resnet18 model. Middle: Earth mover distance between input layer to non-differentiable function and output of non-differentiable function. Right: CIFAR10 test accuracy for DAB vs. Straight-Through-Estimator using Sort-1D.

Layer Placement: In order to determine where to place the non-differentiable function within the Resnet18 architecture, we perform an ablation study in which we train each model 5 times (Figure 6, left). Since the Resnet18 model has four residual blocks, we place the non-differentiable function at the output of each block in turn. We observe that the network remains stable throughout training when the non-differentiable function is placed at the fourth block, and we use this placement for all experiments presented in Table 2. We posit that this is because networks typically learn low-level, Haar-like filters in their initial layers, and applying a complex, non-differentiable function at an early layer destroys this coherence during learning.

Conditioning of Preceding Layer: We utilize the sort non-differentiable function shown in Table 2 to explore the effect of the regularizer introduced in Equation 5. We calculate the empirical earth mover distance (EMD) between the input to the non-differentiable function and its output (Figure 1). We repeat the experiment five times and report the mean and standard deviation in Figure 6 (middle). We observe that the regularizer conditions the input layer into being more amenable to sorting, as demonstrated by the decrease in the test EMD over time.
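A sketch of one way to estimate the empirical earth mover distance reported in Figure 6 (we use SciPy's 1-D Wasserstein distance on flattened activations as a stand-in for whatever estimator the authors used).

from scipy.stats import wasserstein_distance

def empirical_emd(layer_in, layer_out):
    """1-D EMD between the empirical distributions of the activations entering and
    leaving the non-differentiable function (tensors are flattened to 1-D)."""
    a = layer_in.detach().flatten().cpu().numpy()
    b = layer_out.detach().flatten().cpu().numpy()
    return wasserstein_distance(a, b)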

Contrasting with STE: We evaluate the test accuracy of the Straight-Through Estimator (STE) in contrast to DAB. The STE was originally utilized to bypass differentiating through a simple argmax operation [3]; here we analyze how well it performs when handling a more complex operator such as sorting. Since the STE cannot operate over transformations that change dimensionality, we use a simplified version of the sort operator from the previous experiment: instead of sorting the rows and columns as in Table 2, we simply flatten the feature map and run a single sort operation. This allows us to utilize the STE in this scenario. We observe in Figure 6 (right) that DAB clearly outperforms the STE.

6 Discussion

Extensive research in machine learning has focused on discovering new (sub-)differentiable non-linearities to use within neural networks [13, 21, 26]. In this work, we demonstrate a novel method to allow for the incorporation of generic, non-differentiable functions within neural networks and empirically demonstrate their benefit through a variety of experiments using a handful of non-differentiable operators such as kmeans, sort and signum. Rather than manually deriving sub-differentiable solutions (e.g. [12]), using the Straight-Through Estimator (e.g. [30]), or relying on REINFORCE, we directly use a neural network to learn a smooth approximation to the non-differentiable function. This work opens up the use of much more complex non-differentiable operators within neural network pipelines.

References

  • [1] T. Asselmeyer, W. Ebeling, and H. Rosé. Evolutionary strategies of optimization. Physical Review E, 56(1):1171, 1997.
  • [2] E. Belilovsky, M. Eickenberg, and E. Oyallon. Decoupled greedy learning of cnns. arXiv preprint arXiv:1901.08164, 2019.
  • [3] Y. Bengio, N. Léonard, and A. Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • [4] A. C. Berry. The accuracy of the gaussian approximation to the sum of independent variates. Transactions of the american mathematical society, 49(1):122–136, 1941.
  • [5] A. E. Gelfand and A. F. Smith. Sampling-based approaches to calculating marginal densities. Journal of the American statistical association, 85(410):398–409, 1990.
  • [6] W. R. Gilks, S. Richardson, and D. Spiegelhalter. Markov chain Monte Carlo in practice. Chapman and Hall/CRC, 1995.
  • [7] P. Glasserman. Monte Carlo methods in financial engineering, volume 53. Springer Science & Business Media, 2013.
  • [8] P. W. Glynn. Likelihood ratio gradient estimation for stochastic systems. Communications of the ACM, 33(10):75–84, 1990.
  • [9] D. E. Goldberg and J. H. Holland. Genetic algorithms and machine learning. Machine learning, 3(2):95–99, 1988.
  • [10] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative adversarial nets. In Advances in neural information processing systems, pages 2672–2680, 2014.
  • [11] W. Grathwohl, D. Choi, Y. Wu, G. Roeder, and D. Duvenaud. Backpropagation through the void: Optimizing control variates for black-box gradient estimation. ICLR, 2018.
  • [12] E. Grefenstette, K. M. Hermann, M. Suleyman, and P. Blunsom. Learning to transduce with unbounded memory. In Advances in neural information processing systems, pages 1828–1836, 2015.
  • [13] R. H. Hahnloser, R. Sarpeshkar, M. A. Mahowald, R. J. Douglas, and H. S. Seung. Digital selection and analogue amplification coexist in a cortex-inspired silicon circuit. Nature, 405(6789):947, 2000.
  • [14] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
  • [15] P. J. Huber. Robust estimation of a location parameter. In Breakthroughs in statistics, pages 492–518. Springer, 1992.
  • [16] Z. Huo, B. Gu, and H. Huang. Training neural networks using features replay. In Advances in Neural Information Processing Systems, pages 6660–6669, 2018.
  • [17] M. Jaderberg, W. M. Czarnecki, S. Osindero, O. Vinyals, A. Graves, D. Silver, and K. Kavukcuoglu. Decoupled neural interfaces using synthetic gradients. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1627–1635. JMLR. org, 2017.
  • [18] E. Jang, S. Gu, and B. Poole. Categorical reparameterization with gumbel-softmax. ICLR, 2017.
  • [19] J. Kennedy. Particle swarm optimization. Encyclopedia of machine learning, pages 760–766, 2010.
  • [20] D. P. Kingma and M. Welling. Auto-encoding variational bayes. ICLR, 2014.
  • [21] G. Klambauer, T. Unterthiner, A. Mayr, and S. Hochreiter. Self-normalizing neural networks. In Advances in neural information processing systems, pages 971–980, 2017.
  • [22] J. P. Kleijnen and R. Y. Rubinstein. Optimization and sensitivity analysis of computer simulation models by the score function method. European Journal of Operational Research, 88(3):413–427, 1996.
  • [23] T. P. Lillicrap, D. Cownden, D. B. Tweed, and C. J. Akerman. Random synaptic feedback weights support error backpropagation for deep learning. Nature Communications, 7:13276, 2016.
  • [24] C. J. Maddison, A. Mnih, and Y. W. Teh. The concrete distribution: A continuous relaxation of discrete random variables. ICLR, 2017.
  • [25] A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer. Automatic differentiation in pytorch. In NIPS-W, 2017.
  • [26] P. Ramachandran, B. Zoph, and Q. V. Le. Searching for activation functions. arXiv preprint arXiv:1710.05941, 2017.
  • [27] L. M. Rios and N. V. Sahinidis. Derivative-free optimization: a review of algorithms and comparison of software implementations. Journal of Global Optimization, 56(3):1247–1293, 2013.
  • [28] H. Robbins and S. Monro. A stochastic approximation method. The annals of mathematical statistics, pages 400–407, 1951.
  • [29] C. Szegedy, S. Ioffe, V. Vanhoucke, and A. Alemi. Inception-v4, inception-resnet and the impact of residual connections on learning. arXiv preprint arXiv:1602.07261, 2016.
  • [30] A. van den Oord, O. Vinyals, et al. Neural discrete representation learning. In Advances in Neural Information Processing Systems, pages 6306–6315, 2017.
  • [31] P. J. Van Laarhoven and E. H. Aarts. Simulated annealing. In Simulated annealing: Theory and applications, pages 7–15. Springer, 1987.
  • [32] O. Vinyals, S. Bengio, and M. Kudlur. Order matters: Sequence to sequence for sets. ICLR, 2016.
  • [33] O. Vinyals, M. Fortunato, and N. Jaitly. Pointer networks. In Advances in Neural Information Processing Systems, pages 2692–2700, 2015.
  • [34] Z. Wang, E. P. Simoncelli, and A. C. Bovik. Multiscale structural similarity for image quality assessment. In The Thirty-Seventh Asilomar Conference on Signals, Systems & Computers, 2003, volume 2, pages 1398–1402. IEEE, 2003.
  • [35] R. J. Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229–256, 1992.

7 Appendix

7.1 Simple Pytorch Implementation

We provide an example of the base class for any hard function, along with an example of the ε-margin signum operator (Equation 11), below. The BaseHardFn accepts the input tensor x along with the DAB approximation (soft_y). Coupling this with the DAB loss (the regularizer of Section 4.1) provides a basic interface for using DABs with any model.

import torch


class BaseHardFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, soft_y, hard_fn, *args):
        """ Runs the hard function for the forward pass, caches the output and returns.
            All hard functions should inherit from this; it implements the autograd override.
        :param ctx: pytorch context, automatically passed in.
        :param x: input tensor.
        :param soft_y: forward pass output (logits) of DAB approximator network.
        :param hard_fn: to be passed in from derived class.
        :param args: list of args to pass to hard function.
        :returns: hard_fn(tensor), backward pass using DAB.
        :rtype: torch.Tensor
        """
        hard = hard_fn(x, *args)
        saveable_args = list([a for a in args if isinstance(a, torch.Tensor)])
        ctx.save_for_backward(x, soft_y, *saveable_args)
        return hard
    @staticmethod
    def _hard_fn(x, *args):
        raise NotImplementedError("implement _hard_fn in derived class")
    @staticmethod
    def backward(ctx, grad_out):
        """ Returns DAB derivative.
        :param ctx: pytorch context, automatically passed in.
        :param grad_out: grads coming into layer
        :returns: dab_grad(tensor)
        :rtype: torch.Tensor
        """
        x, soft_y, *args = ctx.saved_tensors
        with torch.enable_grad():
            grad = torch.autograd.grad(outputs=soft_y, inputs=x,
                                       grad_outputs=grad_out,
                                       retain_graph=True)
        return grad[0], None, None, None
class SignumWithMargin(BaseHardFn):
    @staticmethod
    def _hard_fn(x, *args):
        """ x[x < -eps] = -1
            x[x > +eps] = 1
            else x = 0
        :param x: input tensor
        :param args: list of args with 0th element being eps
        :returns: signum(tensor)
        :rtype: torch.Tensor
        """
        eps = args[0] if len(args) > 0 else 0.5
        sig = torch.zeros_like(x)
        sig[x < -eps] = -1
        sig[x > eps] = 1
        return sig
    @staticmethod
    def forward(ctx, x, soft_y, *args):
        return BaseHardFn.forward(ctx, x, soft_y, SignumWithMargin._hard_fn, *args)
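
A minimal usage example of the SignumWithMargin class defined above (the two-layer dab network, the tensor sizes, and the stand-in task loss are assumptions): the hard signum is used in the forward pass while gradients reaching x are routed through the DAB logits, and the Section 4.1 regularizer trains the DAB itself.

import torch
import torch.nn as nn
import torch.nn.functional as F

dab = nn.Sequential(nn.Linear(16, 16), nn.Tanh(), nn.Linear(16, 16))   # smooth approximator
x = torch.randn(8, 16, requires_grad=True)

soft_y = dab(x)                                    # DAB forward pass (logits)
hard_y = SignumWithMargin.apply(x, soft_y, 0.5)    # forward: hard signum; backward: via DAB

task_loss = hard_y.sum()                           # stand-in for the downstream task loss
task_loss.backward(retain_graph=True)              # x.grad is computed through the DAB

dab_loss = F.mse_loss(soft_y, hard_y.detach())     # DAB regularizer (Section 4.1)
(10.0 * dab_loss).backward()                       # DAB-gamma = 10, as in the table below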

7.2 Model Hyper-Parameters

Hyper-parameter | FashionMNIST | CIFAR10 | ImageNet | Sorting | Classification
Optimizer | Adam | RMSProp | RMSProp | Adam | Adam
LR | 1e-3 | 1e-4 | 1e-4 | 1e-4 | 1e-4
Batch-Size | 128 | 128 | 192 | 1024 | 128
Activation | ELU | ReLU | ELU | Tanh | ELU
Normalization | Batchnorm | Batchnorm-Conv, None-Dense | Batchnorm-Conv, None-Dense | None | Batchnorm
Layer-Type | Similar to U-Net | Coord-Conv encoder, Dense decoder | Resnet18 encoder, Dense decoder | LSTM (gradclip 5) + Dense(256) | CifarResnet18
DAB-γ | 10 | 70 | 2 | 10 | 10