torchdiffeq 是一个针对PyTorch的库,它提供了用于解决常微分方程(ODEs)的求解器。这个库主要用于实现神经常微分方程(Neural Ordinary Differential Equations,或简称 Neural ODEs),这是一种深度学习模型,它在传统的残差网络(ResNet)基础上,将层间的映射视作连续的动态系统。
首先,我们来解释一下什么是ODEs。
ODE方程描述的是一个或多个函数及其导数之间的关系,依赖于单一的独立变量,而时间是最常见的独立变量。可以用以下公式来表述:
d
x
d
f
(
t
)
=
g
(
f
(
x
))
这个表达式就是扩展成为NeuralODE的关键。
然后,我们来解释一下Neural ODEs的基本思想。实际上,对于让人头疼的微分方程,解析求解是一个很难的事情,因为我们需要数学表达式刚好就是能够套用解析求解的形式。更多的我们会采用数值求解的逼近法,知道初始值,给定一个步长,毕竟这个初始值附近的下一个值处的导数值,然后逐步逼近我们感兴趣的表达式。
其实神经网络本身就是一个大的函数逼近器。他是一个多层嵌套的函数,因为有多层嵌套的性质,所以可以使用链式法则传递梯度,传递的同时更新网络的参数。最终得到我们要求解的
最后我们介绍一下ResNET。
ResNet的特点在于引入了残差学习,当我们考虑ResNet的深度趋向无限大时,可以把ResNet的每一层看作时间上的一个无限小的步长,此时网络就变成了NeuralODE。