In this paper, we propose a method that alleviates or completely avoids the notorious problems of numerical instability and stiffness of the adjoint method for training neural ODEs. On the backward pass, we propose to use smooth function interpolation to restore the trajectory obtained during the forward integration. We show the viability of our approach both in theory and in practice.
In this study, we propose a novel method to train neural ordinary differential equations (neural ODEs) (Chen et al., 2018). This method performs stable and memory-efficient backpropagation through the solution of initial value problems (IVP). Throughout the study, we use the term neural ODEs for all neural networks with so-called ODE blocks. Each ODE block is a continuous analog of a residual neural network (He et al., 2016). Indeed, residual networks can be considered as an Euler discretization of ordinary differential equations. Neural ODEs were motivated by the relationship between residual neural networks and dynamical systems (Lu et al., 2017; Chang et al., 2017; Ruthotto and Haber, 2018). Such models have already been successfully applied to various machine learning problems, including classification, generative modeling, and time series prediction (Chen et al., 2018; Grathwohl et al., 2018; Rubanova et al., 2019).
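The relation between residual blocks and the Euler scheme can be illustrated with a minimal sketch. The residual branch `f` below (a single tanh layer with a shared weight matrix `theta`) and the function names are assumptions made for illustration only; they are not taken from the cited works.

```python
import numpy as np

def f(z, theta):
    # Hypothetical residual branch: one tanh layer with weight matrix theta.
    return np.tanh(theta @ z)

def resnet_forward(z0, theta, n_blocks):
    # Each residual block computes z <- z + f(z).
    z = z0.copy()
    for _ in range(n_blocks):
        z = z + f(z, theta)
    return z

def euler_forward(z0, theta, t0, t1, n_steps):
    # Explicit Euler scheme for dz/dt = f(z, theta) with step h = (t1 - t0) / n_steps.
    h = (t1 - t0) / n_steps
    z = z0.copy()
    for _ in range(n_steps):
        z = z + h * f(z, theta)
    return z

rng = np.random.default_rng(0)
theta = rng.normal(size=(3, 3))
z0 = rng.normal(size=3)
# With unit step size, the Euler scheme reproduces the residual network exactly.
same = np.allclose(resnet_forward(z0, theta, 4), euler_forward(z0, theta, 0.0, 4.0, 4))
```

With a step size of one, both functions perform identical updates; a smaller step corresponds to a finer discretization of the same ODE.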
In general, the core of an ODE block is the initial value problem (IVP)
$$\frac{dz}{dt} = f(z(t), t, \theta), \quad t \in [t_0, t_1], \qquad z(t_0) = z_0, \tag{1}$$
which depends on the parameter $\theta$ and takes the input features $z_0$ as the initial condition. The forward pass in neural ODEs is performed by standard numerical ODE solvers. In most cases, adaptive solvers (Dormand and Prince, 1980) are preferable because they automatically adjust both the number and the size of steps on every iteration. During the backward pass in neural ODEs, the loss function $L$, which depends on the solution of the IVP, must be differentiated with respect to the parameter $\theta$. A direct application of backpropagation to ODE solvers leads to memory issues because, for every intermediate time step $t_k$, the output $z(t_k)$ must be stored as a part of the computational graph. Moreover, using adaptive methods becomes risky, since we cannot explicitly control the number of steps and, as a result, the amount of allocated memory (Eriksson et al., 2004).
However, the adjoint method (Pontryagin et al., 1961; Giles and Pierce, 2000; Marchuk, 2013) performs backpropagation through the initial value problem with a small memory footprint. In this method, backpropagation through the IVP (1) is performed by solving another IVP, which we call the backward IVP. The backward IVP consists of several groups of equations, which are solved backward in time: from $t_1$ to $t_0$. Therefore, the adjoint method can be viewed as a “continuous backpropagation” (see Section 2 for details).
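The main property of the adjoint method is that after the forward pass we save $z(t_1)$ and do not store the intermediate values $z(t_k)$ for other $t_k$. Instead, in the backward pass, we recompute $z(t)$ at the required points $t$. Thus, the following IVP, solved backward in time from the saved value $z_1 = z(t_1)$,
$$\frac{dz}{dt} = f(z(t), t, \theta), \qquad z(t_1) = z_1, \tag{2}$$
is a part of the backward IVP.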
Gholami et al. (2019) showed that this approach can often lead to catastrophic numerical instabilities and introduced a method, called ANODE, to mitigate this problem. ANODE makes checkpoints at selected time points during the forward pass. After that, it performs the forward pass once again on the smaller time intervals between checkpoints to store the intermediate values of $z(t)$ for the backward pass. During the backward pass, instead of solving (2), ANODE uses the stored values. Thus, to perform the backward pass, ANODE requires additional evaluations of the right-hand side in (2). This method also requires more memory to store the intermediate activations. Moreover, storing intermediate activations at fixed points prevents the use of adaptive ODE solvers in gradient computations. Therefore, ANODE is more numerically stable than the adjoint method but has the following drawback: in the backward pass, one needs to solve as many IVPs as there are stored checkpoints.
To avoid this drawback, we propose the interpolated adjoint method (IAM), which is described in detail in Section 2. This method uses smooth function interpolation to approximate $z(t)$ and thereby get rid of IVP (2). Under mild conditions on the right-hand side $f$, the function $z(t)$ is smooth as a solution of an IVP. We use barycentric Lagrange interpolation (BLI) on a Chebyshev grid (Berrut and Trefethen, 2004) to approximate the solution of IVP (2). If the approximation is accurate enough, we can replace the solution of IVP (2) with the approximation of $z(t)$. The main requirement for our method to work correctly and efficiently is that $z(t)$ can be approximated by BLI with sufficient accuracy. This can only be verified experimentally. However, the accuracy of such an approximation is inherently related to the smoothness of the solution, which is also one of the natural motivations behind using neural ODEs.
Our main contributions are as follows.
We approximate activations $z(t)$ with the barycentric Lagrange interpolation on a Chebyshev grid instead of solving IVP (2). Thus, we eliminate the equations of IVP (2) and reduce the dimension of the backward IVP. As a result, the training becomes faster.
We study the numerical stability of the proposed IAM and give the error bounds for the adjoint variable.
We compare our approach with the adjoint method and ANODE in terms of forward/backward pass time, loss, and the number of adaptive solver steps on classification and density estimation problems.
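Most deep learning problems reduce to a minimization of the loss function $L$ with respect to the parameter $\theta$ by gradient-based methods. To train neural ODEs, one has to compute the gradient $\partial L / \partial \theta$ of the loss function, which depends on the solution $z(t_1)$ of the initial value problem (1). (For simplicity, we assume that only $z(t_1)$ affects the loss function.) To compute $\partial L / \partial \theta$, we can use the adjoint method. To use this method, one has to solve IVP (1) and save $z(t_1)$ during the forward pass. The core idea of the adjoint method is to solve the following IVP backward in time from $t_1$ to $t_0$
$$\frac{dz}{dt} = f(z(t), t, \theta), \qquad z(t_1) = z_1, \tag{3}$$
$$\frac{da}{dt} = -\left(\frac{\partial f(z(t), t, \theta)}{\partial z}\right)^{\top} a(t), \qquad a(t_1) = \frac{\partial L}{\partial z(t_1)}, \tag{4}$$
$$\frac{d}{dt}\frac{\partial L}{\partial \theta} = -\left(\frac{\partial f(z(t), t, \theta)}{\partial \theta}\right)^{\top} a(t), \qquad \frac{\partial L}{\partial \theta}\bigg|_{t = t_1} = 0, \tag{5}$$
where $a(t) = \partial L / \partial z(t)$ is the adjoint variable, and the value of the quantity integrated in (5) at $t = t_0$ is the required gradient. We call the system of (3), (4) and (5) the backward IVP. Since the adjoint method does not store intermediate activations $z(t)$ during the forward pass, IVP (3) is used to restore these activations at all required points during the backward pass. It is assumed that the solutions of (1) and (3) coincide. Nevertheless, these solutions may differ substantially (Gholami et al., 2019, see Fig. 7). Moreover, even if they coincide, it may take many ODE solver steps to solve (3). Since IVPs (3), (4) and (5) are solved simultaneously, small time steps required by (3) can slow down the entire adjoint method.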
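As an illustration of the backward IVP, the following minimal sketch applies the adjoint method to a toy scalar problem. Everything in it is an assumption made for the example: the right-hand side $f = \theta z$, the loss $L = z(t_1)^2 / 2$, the function names, and the fixed-step RK4 integrator, which stands in for the adaptive solvers discussed in the text. For this toy problem the exact gradient is $(t_1 - t_0)\, z(t_1)^2$, which makes the result easy to check.

```python
import numpy as np

def rk4(rhs, y, t_start, t_end, n_steps):
    # Classical fixed-step RK4; the step h is negative when integrating backward in time.
    h = (t_end - t_start) / n_steps
    t = t_start
    for _ in range(n_steps):
        k1 = rhs(t, y)
        k2 = rhs(t + 0.5 * h, y + 0.5 * h * k1)
        k3 = rhs(t + 0.5 * h, y + 0.5 * h * k2)
        k4 = rhs(t + h, y + h * k3)
        y = y + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        t += h
    return y

def adjoint_gradient(theta, z0, t0=0.0, t1=1.0, n_steps=100):
    # Forward pass: solve dz/dt = theta * z and keep only z(t1).
    z1 = rk4(lambda t, z: theta * z, np.array([z0]), t0, t1, n_steps)[0]

    # Backward IVP for the state (z, a, g), integrated from t1 to t0:
    #   dz/dt = f = theta * z
    #   da/dt = -(df/dz) * a     = -theta * a
    #   dg/dt = -(df/dtheta) * a = -z * a
    def backward(t, y):
        z, a, _ = y
        return np.array([theta * z, -theta * a, -z * a])

    # Terminal conditions: z(t1) = z1, a(t1) = dL/dz(t1) = z1, g(t1) = 0.
    y_end = rk4(backward, np.array([z1, z1, 0.0]), t1, t0, n_steps)
    return y_end[2], z1

grad, z1 = adjoint_gradient(theta=0.5, z0=2.0)
exact = (1.0 - 0.0) * z1 ** 2  # analytic gradient for this toy problem
```

Note that the first component of the backward state recomputes $z(t)$ from $z(t_1)$; this is the part of the backward IVP that may become unstable or expensive for less benign right-hand sides.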
In the proposed interpolated adjoint method (IAM), we propose eliminating (3) from the backward IVP. Instead of solving IVP (3), the IAM approximates the activations obtained in the forward pass through the barycentric Lagrange interpolation (BLI) on a Chebyshev grid (Berrut and Trefethen, 2004). The detailed description of the BLI is presented below. We urge readers not to confuse the Lagrange interpolation, which is mostly of theoretical interest, with the BLI, which is widely used in practice for polynomial interpolation (Higham, 2004).
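Before interpolating a vector function $z(t)$ on a segment $[t_0, t_1]$ with the BLI, one should construct a Chebyshev grid $\{t_n\}_{n=1}^{N}$ on this segment (Tyrtyshnikov, 2012) and compute weights $w_n$, which depend only on the grid. The function $z(t)$ is then approximated with
$$\tilde{z}(t) = \frac{\sum_{n=1}^{N} \dfrac{w_n}{t - t_n}\, z(t_n)}{\sum_{n=1}^{N} \dfrac{w_n}{t - t_n}}, \tag{7}$$
where the values $z(t_n)$ are stored during the forward pass. The computational complexity of this approximation, as well as the memory usage, is $\mathcal{O}(ND)$, where $D$ is the dimension of $z$. Since we approximate $z(t)$, only (4) and (5) have to be solved during the backward pass. Thus, the dimension of the backward IVP is reduced by the size of $z$.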
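The interpolation step can be sketched as follows. This is a minimal illustration under stated assumptions, not the implementation used in the experiments: it assumes Chebyshev points of the first kind with the closed-form barycentric weights from Berrut and Trefethen (2004), and the function names `chebyshev_grid` and `bli` are hypothetical. A known trajectory stands in for the activations that an ODE solver would store in the forward pass.

```python
import numpy as np

def chebyshev_grid(t0, t1, n_nodes):
    """Chebyshev points of the first kind on [t0, t1] and their barycentric weights."""
    k = np.arange(1, n_nodes + 1)
    angles = (2 * k - 1) * np.pi / (2 * n_nodes)
    nodes = 0.5 * (t0 + t1) + 0.5 * (t1 - t0) * np.cos(angles)
    # Closed-form weights for this grid; a common scale factor cancels in bli().
    weights = (-1.0) ** (k - 1) * np.sin(angles)
    return nodes, weights

def bli(t, nodes, weights, values):
    """Barycentric interpolant at a scalar time t; values has shape (n_nodes, D)."""
    diff = t - nodes
    exact = np.flatnonzero(diff == 0.0)
    if exact.size > 0:
        # t coincides with a grid node: return the stored value directly.
        return values[exact[0]]
    coeffs = weights / diff
    return coeffs @ values / coeffs.sum()

# Store the trajectory at the grid nodes (stand-in for the forward pass).
nodes, weights = chebyshev_grid(0.0, 1.0, 10)
values = np.stack([np.sin(nodes), np.cos(nodes)], axis=1)  # shape (10, 2)

# Restore the trajectory at an arbitrary time during the backward pass.
z_mid = bli(0.3, nodes, weights, values)
```

Each evaluation costs one pass over the stored nodes, so an adaptive solver in the backward pass can query the interpolant at whatever time points it chooses.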
Another method that addresses the instability issue of the adjoint method is ANODE, presented by Gholami et al. (2019). ANODE stores checkpoints in the forward pass and integrates IVP (2) on each subinterval using an ODE solver. In this case, the stability issues are resolved, but memory issues arise. Thus, it is not possible to train neural ODEs with a large batch size. The proposed method aims to resolve both stability and memory issues. In addition, in contrast to ANODE, our method allows using adaptive solvers in the backward pass.
In this section, we study the stability (Higham, 2002) of the adjoint variable $a(t)$. In other words, we aim to estimate how small perturbations of $z(t)$ affect the corresponding adjoint variable $a(t)$. A well-known stability measure is the norm of the Jacobian matrix. It is often bounded in the spectral normalization technique (Miyato et al., 2018). However, in the theory of ordinary differential equations, a much tighter estimate can be obtained by using the so-called logarithmic norm (Söderlind, 2006)
$$\mu(A) = \lim_{h \to 0^{+}} \frac{\|I + hA\| - 1}{h}, \tag{8}$$
where $A$ is an arbitrary square matrix. Note that the logarithmic norm can be negative. If we choose the spectral norm $\|\cdot\|_2$, then
$$\mu_2(A) = \lambda_{\max}\left(\frac{A + A^{\top}}{2}\right), \tag{9}$$
where $\lambda_{\max}$ is the maximum eigenvalue of a matrix. In this case, the logarithmic norm can be computed by the power method (Mises and Pollaczek-Geiringer, 1929).
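A small numerical example may help to see why the logarithmic norm can be more informative than the matrix norm. The sketch below computes the logarithmic norm induced by the spectral norm as the largest eigenvalue of the symmetric part of the matrix. For brevity it uses a dense symmetric eigensolver rather than the power method mentioned in the text; the matrix `A` and the function name `log_norm_2` are chosen for illustration only.

```python
import numpy as np

def log_norm_2(A):
    # Largest eigenvalue of the symmetric part (A + A^T) / 2.
    # eigvalsh returns eigenvalues in ascending order.
    return np.linalg.eigvalsh(0.5 * (A + A.T))[-1]

A = np.array([[-1.0, 1.0],
              [0.0, -2.0]])

mu = log_norm_2(A)                # negative for this matrix
spectral = np.linalg.norm(A, 2)   # the spectral norm is always non-negative
```

Here the spectral norm is positive, so a bound of the form $e^{\|A\| t}$ grows with $t$, while the negative logarithmic norm yields a decaying bound $e^{\mu(A) t}$.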
The following lemma helps to estimate the error of the adjoint variable approximation, which is computed with the interpolated $\tilde{z}(t)$ from (7) instead of the exact $z(t)$.
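Lemma 3.1 (Söderlind and Mattheij (1985), see Lemma 1)
Let $x(t)$ be a solution of the following non-autonomous linear system
$$\frac{dx}{dt} = A(t)\,x(t) + r(t). \tag{10}$$
Then $\|x(t)\| \le \xi(t)$, where the scalar function $\xi(t)$ satisfies the IVP
$$\frac{d\xi}{dt} = \mu(A(t))\,\xi(t) + \|r(t)\|, \qquad \xi(0) = \|x(0)\|. \tag{11}$$
The solution of IVP (11) looks as follows
$$\xi(t) = e^{\int_0^t \mu(A(\tau))\,d\tau}\,\|x(0)\| + \int_0^t e^{\int_s^t \mu(A(\tau))\,d\tau}\,\|r(s)\|\,ds. \tag{12}$$
If the logarithmic norm is negative, the upper bound on $\|x(t)\|$ does not grow exponentially with $t$.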
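The bound from the lemma can be checked numerically in the simplest setting. The sketch below is an illustration under assumptions that are not part of the source: a constant matrix `A`, no forcing term, the spectral norm, and a fixed-step RK4 integrator. In this case the bound reduces to the norm of the initial state multiplied by the exponential of the logarithmic norm times $t$.

```python
import numpy as np

def log_norm_2(A):
    # Logarithmic norm induced by the spectral norm.
    return np.linalg.eigvalsh(0.5 * (A + A.T))[-1]

def rk4_trajectory(A, x0, t_end, n_steps):
    """Integrate dx/dt = A x with classical RK4; return the time grid and all states."""
    h = t_end / n_steps
    xs = [x0]
    for _ in range(n_steps):
        x = xs[-1]
        k1 = A @ x
        k2 = A @ (x + 0.5 * h * k1)
        k3 = A @ (x + 0.5 * h * k2)
        k4 = A @ (x + h * k3)
        xs.append(x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4))
    return np.linspace(0.0, t_end, n_steps + 1), np.array(xs)

A = np.array([[-1.0, 1.0],
              [0.0, -2.0]])
mu = log_norm_2(A)

times, xs = rk4_trajectory(A, np.array([1.0, 1.0]), t_end=5.0, n_steps=500)
norms = np.linalg.norm(xs, axis=1)

# Upper bound for constant A and zero forcing: ||x(0)|| * exp(mu * t).
bound = norms[0] * np.exp(mu * times)
holds = bool(np.all(norms <= bound + 1e-9))
```

Since the logarithmic norm of this matrix is negative, both the solution norm and its bound decay with time.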
Let and be the exact and the perturbed solutions of IVP (1). Suppose and are the solutions of IVP (4), which uses and , respectively. Since lemma 3.1 considers forward-in-time ODEs, we change variables as and get
Let us denote partial derivatives in the right-hand sides of (13) by and , respectively, and the error by . Thus, we obtain the following IVP on error
Lemma 3.1 gives the following upper bound on the norm of perturbed adjoint variable
The norm of the error is bounded above as
The error norm naturally depends on the smoothness of the Jacobian matrix . It also depends on the logarithmic norm : the smaller the logarithmic norm, the smaller the error norm. If , then as .
In this section, we illustrate the performance of the IAM on density estimation and classification problems. The only additional hyperparameter in the IAM is the number of Chebyshev grid points. It is reported separately for each experiment, e.g., IAM(10) means that we use the IAM with ten nodes. Our method is implemented on top of the torchdiffeq package (https://github.com/rtqichen/torchdiffeq/). Initially, the DOPRI5 ODE solver from torchdiffeq produced steps that violated the considered time interval $[t_0, t_1]$; therefore, we have modified its implementation. The default ODE solver in our experiments is DOPRI5.
We trace the loss and the number of right-hand side function evaluations in the forward pass and in the backward pass. By construction of the IAM, we expect to achieve the same loss value with fewer function evaluations on the backward pass and approximately the same number of function evaluations on the forward pass. The number of function evaluations is the primary metric of computational complexity. In addition, we report time measurements in the supplementary materials. The values of hyperparameters are taken from the repositories mentioned above. All experiments were conducted on a single NVIDIA Tesla V100 GPU with 16 GB of memory.
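One simple way to trace the number of function evaluations is to wrap the right-hand side in a counter. The sketch below is a hypothetical illustration: the class `CountingRHS`, the fixed-step RK4 solver, and the toy right-hand side are assumptions and do not reproduce the code used in the experiments.

```python
import numpy as np

class CountingRHS:
    """Wrap a right-hand side f(t, z) and count how many times it is called."""

    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, z):
        self.nfe += 1
        return self.f(t, z)

def rk4_solve(rhs, z, t0, t1, n_steps):
    # Classical fixed-step RK4: four right-hand side evaluations per step.
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = rhs(t, z)
        k2 = rhs(t + 0.5 * h, z + 0.5 * h * k1)
        k3 = rhs(t + 0.5 * h, z + 0.5 * h * k2)
        k4 = rhs(t + h, z + h * k3)
        z = z + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        t += h
    return z

rhs = CountingRHS(lambda t, z: -z)
z1 = rk4_solve(rhs, np.array([1.0]), 0.0, 1.0, n_steps=10)
forward_nfe = rhs.nfe  # evaluations spent in this forward solve
```

With an adaptive solver the count is not known in advance, which is why it is a useful quantity to log separately for the forward and backward passes.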
The problem of density estimation is to reconstruct the probability density function using a set of given data points. We consider the FFJORD method (https://github.com/rtqichen/ffjord) (Grathwohl et al., 2018), which applies the neural ODE approach to the density estimation problem. To perform backpropagation, we test the IAM and the adjoint method on the miniboone dataset. This dataset is used in experiments since it is large enough and FFJORD with the compared methods converges in a reasonable time. Neural ODEs for the density estimation problem typically require a single ODE block. In this setting, ANODE is equal to a vanilla backpropagation with the entire computational graph stored in memory. Therefore, we do not consider ANODE in the context of the density estimation problem. Figure 1 shows the number of function evaluations on the backward pass (Figure 1(a)), the number of function evaluations on the forward pass (Figure 1(c)) and the train loss (Figure 1(b)) versus the number of iterations. These plots illustrate that, for the density estimation problem, the IAM requires fewer right-hand side function evaluations during the backward pass while preserving the same train loss value as the adjoint method. We also provide the dependence of the validation loss on the elapsed running time for the IAM and the adjoint method in the supplementary materials (see Figure 4(c)). A more detailed comparison of the IAM and the adjoint method in terms of the running time of forward and backward passes is presented in the supplementary materials (see Figure 6).
We test the proposed method on the classification problem with the CIFAR10 dataset. Since ANODE, proposed by Gholami et al. (2019), is superior to the adjoint method for classification problems, we compare the IAM with ANODE. In contrast to ANODE, the IAM can manage a larger number of grid points. For a model with a single convolutional layer, a single ODEBlock and a linear layer, IAM(16) obtained test accuracy with batch size equal to 512 and tolerance 1e-3. Such a batch size limits the number of checkpoints in ANODE to 4. More checkpoints did not fit into memory. ANODE with 4 checkpoints and the same hyperparameters gives similar test accuracy . A comparison of peak memory consumption for different numbers of checkpoints and batch sizes is given in Table 1.
|ODE solver||Batch size||Peak memory, Gb|
In the supplementary materials, we compare the asymptotic memory consumption and the number of right-hand side evaluations in ANODE and the IAM (see Table 2). We consider two settings: a single ODE block and sequentially stacked ODE blocks. The presented analysis shows the gain in memory consumption obtained with the IAM. To demonstrate this gain in practice, we plot peak memory for different numbers of checkpoints and for different models (see Figure 4). This plot illustrates that the IAM can be used with a larger number of checkpoints (which are similar to points in the Chebyshev grid) than ANODE. Therefore, an adaptive ODE solver (like DOPRI5) can solve the IVP in the backward pass more precisely.
emphasize the importance of using adaptive solvers and introduce a procedure to learn their tolerance with reinforcement learning. A simple procedure to improve the expressive power of neural ODEs by adding dummy variables is proposed by Dupont et al. (2019). Zhang et al. (2019) propose to learn time-dependent parameters with another IVP. In addition, the stiffness of an ODE is an important concept (Söderlind et al., 2015) for stable and fast training of neural ODEs, since it affects the number of steps in the ODE solver during both backward and forward passes (Wanner and Hairer, 1996).
We have developed the interpolated adjoint method (IAM) to improve the adjoint method for training neural ordinary differential equation models. The IAM is based on barycentric Lagrange interpolation on a Chebyshev grid. Our method resolves both stability and memory issues arising during the backward pass of neural ODEs. We have compared our approach with the existing methods for classification and density estimation problems. In particular, we have shown that the IAM requires fewer right-hand side function evaluations per epoch than the standard adjoint method for FFJORD training. At the same time, in classification problems, the IAM demonstrates the same accuracy and requires a smaller memory footprint than ANODE.
Sections 2 and 4 were supported by Ministry of Education and Science of the Russian Federation grant 14.756.31.0001. Sections 3 and 4 were supported by RFBR grant 19-29-09085. High-performance computations presented in the paper were carried out on Skoltech HPC cluster Zhores.
Ruthotto, L. and Haber, E. (2018). Deep neural networks motivated by partial differential equations. Journal of Mathematical Imaging and Vision, pp. 1–13.
We evaluate the performance of the proposed method on the CIFAR-10 classification problem. We train a ResNet-like architecture for 350 epochs with an initial learning rate equal to 0.1. The learning rate decays by a factor of 10 at epochs 150 and 300. Data augmentation is used. The batch size used for training is 512. For all experiments with different methods (the standard adjoint method, ANODE, and the IAM), we use the same setting. The ResNet-like architecture is a standard ResNet in which each ResNet block, except the first and downsample blocks, is replaced with an ODE block.
In Table 2, we summarize the number of function evaluations (right-hand side of equation (1)) and the asymptotic memory consumption in the forward and backward passes. We consider the case of a single ODE block (Table 2(a)) as well as sequentially stacked ODE blocks (Table 2(b)).
Assume that in ANODE we use the non-adaptive solver that makes function evaluations at each of fixed time steps. During the forward pass through the single ODE block, ANODE performs solver calls, i.e., it evaluates the function (represented with a neural network) times. ANODE does not store values at intermediate points of the interval. Therefore, during the backward pass, ANODE firstly recomputes values in intermediate points using function calls and then computes gradients using backpropagation. The latter operation requires storing the whole computational graph that requires memory, where denotes the number of features in the output of neural network and states for the number of network parameters.
In turn, during the forward pass, the IAM requires function evaluations, where is a number of points in the Chebyshev grid and is a number of function evaluations per solver step. The benefit of the IAM is that during the backward pass it does not require to use the same points as during the forward pass, since can be interpolated in any time point with activations pre-computed in Chebyshev grid points. Thus, in contrast to ANODE, the IAM does not recompute intermediate points using ODE solver during the backward pass. Illustration of the difference between the forward and backward passes in ANODE and the IAM is presented in Figure 2 (for single ODE block) and in Figure 3 (for stacked ODE blocks). Moreover, the interpolation technique makes possible to use adaptive solvers that are more general than non-adaptive (fixed step) ones.
From Table 2 follows that in contrast to ANODE, the number of function evaluations and memory consumption for IAM does not depend on the number of stacked ODE blocks . Assuming and , we can see that the IAM outperforms ANODE in terms of memory consumption in the backward pass. Although ANODE requires less amount of memory in the forward pass than the IAM, ANODE has to recompute many activations (values at intermediate points) during the backward pass. Therefore, memory-efficient forward pass in ANODE leads to extra computations and memory overhead in the backward pass.
In practice, the IAM requires a smaller memory footprint than ANODE if a small number of ODE blocks is used. Particularly, current neural ODEs for density estimation problems exploit a single ODE block. At the same time, we train neural ODEs with multiple ODE blocks for classification problems. The comparison of memory consumption in such models is provided below.
Figure 4 demonstrates memory requirements for the compared methods. ANODE uses non-adaptive solvers. Therefore, the time step is equal to , where is the number of checkpoints. If to make the number of checkpoints greater for the considered setting, it will not fit into 16Gb GPU memory.
To demonstrate the benefits of the proposed method in density estimation problem, we carry out experiments using the FFJORD approach. We compare FFJORD performance on miniboone dataset combining this approach with the adjoint method and the IAM. The results have been obtained for the same setting and model architecture as in the original paper (Grathwohl et al., 2018).
The IAM significantly reduces training time compared with the adjoint method due to a more efficient backward pass. Figure 5 demonstrates that FFJORD with the IAM requires a much smaller number of function evaluations during the backward pass (Figure 5(d)), while keeping the number of function calls during the forward pass on par with the adjoint method (Figure 5(b)). Thus, the IAM reduces the total number of function evaluations (Figure 5(f)) and the total training time (Figure 5(e)) of FFJORD. Moreover, the overall performance of the model remains the same (Figure 5).
To better illustrate the difference between the IAM and the adjoint method, we provide Figure 4(c), where we show validation loss depending on running time.