Paper ID: 2410.11648

Efficient, Accurate and Stable Gradients for Neural ODEs

Sam McCallum, James Foster

Neural ODEs are a recently developed model class that combine the strong model priors of differential equations with the high-capacity function approximation of neural networks. One advantage of Neural ODEs is the potential for memory-efficient training via the continuous adjoint method. However, memory-efficient training comes at the cost of approximate gradients. Therefore, in practice, gradients are often obtained by simply backpropagating through the internal operations of the forward ODE solve, which incurs a high memory cost. Interestingly, it is possible to construct algebraically reversible ODE solvers that allow for both exact gradients and the memory efficiency of the continuous adjoint method. Unfortunately, current reversible solvers are low-order and suffer from poor numerical stability, so their use in practice is limited. In this work, we present a class of algebraically reversible solvers that are both high-order and numerically stable. Moreover, any explicit numerical scheme can be made reversible by our method. This construction naturally extends to numerical schemes for Neural CDEs and SDEs.

Submitted: Oct 15, 2024
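To make "algebraically reversible" concrete, here is a minimal Python sketch under stated assumptions. It wraps an explicit base step (classical RK4, chosen here only as an example) in a coupled two-state update with a hypothetical mixing parameter `lam`. The coupled construction, the names `rk4_increment`, `forward_step`, `backward_step`, `solve` and `reconstruct`, and the test problem dy/dt = -y are illustrative assumptions, not the authors' code. The property it shows is that the backward pass recomputes every earlier state from the final one, so a gradient computation would not need to store the forward trajectory.

```python
import math
from typing import Callable, Tuple

# Vector field f(t, y) of a scalar ODE dy/dt = f(t, y).
VectorField = Callable[[float, float], float]


def rk4_increment(f: VectorField, t: float, y: float, h: float) -> float:
    """Increment Psi_h(t, y) of one classical RK4 step, so y_next = y + Psi_h."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def forward_step(f: VectorField, t: float, h: float, y: float, z: float,
                 lam: float) -> Tuple[float, float]:
    """One step of a (hypothetical) coupled reversible scheme from t to t + h."""
    y_next = lam * y + (1 - lam) * z + rk4_increment(f, t, z, h)
    # Step backwards (-h) from the new y to update the companion state z.
    z_next = z - rk4_increment(f, t + h, y_next, -h)
    return y_next, z_next


def backward_step(f: VectorField, t: float, h: float, y_next: float,
                  z_next: float, lam: float) -> Tuple[float, float]:
    """Algebraic inverse of forward_step: recovers (y, z) at t from t + h."""
    z = z_next + rk4_increment(f, t + h, y_next, -h)
    y = (y_next - (1 - lam) * z - rk4_increment(f, t, z, h)) / lam
    return y, z


def solve(f: VectorField, y0: float, t0: float, h: float, n_steps: int,
          lam: float = 0.5) -> Tuple[float, float]:
    """Integrate forward, keeping only the current state (no stored trajectory)."""
    y, z = y0, y0
    for n in range(n_steps):
        # Recompute t_n from n so the backward pass sees identical times.
        y, z = forward_step(f, t0 + n * h, h, y, z, lam)
    return y, z


def reconstruct(f: VectorField, y_final: float, z_final: float, t0: float,
                h: float, n_steps: int, lam: float = 0.5) -> Tuple[float, float]:
    """Walk backwards from the final state, recomputing every earlier state."""
    y, z = y_final, z_final
    for n in reversed(range(n_steps)):
        y, z = backward_step(f, t0 + n * h, h, y, z, lam)
    return y, z


if __name__ == "__main__":
    f = lambda t, y: -y  # test problem with exact solution y(t) = exp(-t)
    yN, zN = solve(f, 1.0, 0.0, 0.1, 10)
    y0_rec, z0_rec = reconstruct(f, yN, zN, 0.0, 0.1, 10)
    print(yN, math.exp(-1.0))  # forward solution vs exact value at t = 1
    print(y0_rec, z0_rec)      # initial state recovered up to round-off
```

The backward step divides by `lam`, so round-off can be amplified as the solve is walked backwards; this toy only illustrates exact algebraic inversion and says nothing definitive about the order or stability properties discussed in the abstract.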