Ray's Blog

(五) 自动微分

自动微分的重要性

欢迎回到零基础深度学习之旅,今天探索自动微分。自动微分可以说是深度学习里最重要的技术。只要了解过深度学习,大概率都听过”反向传播(back propagation)“。反向传播算法是自动微分中反向模式的实现,也是本文要实现的模式。

简单回顾一下之前的内容,可以帮我们了解自动微分的重要性: 第一篇,我们学习了优化机制,我们知道要让模型学参数,得先有个目标函数。 第二篇,我们学习了梯度下降,我们知道为了有效地更新参数,需要求得目标函数在当前参数上的梯度。 上一篇,我们学习了各种优化器,可无论哪种,都需要梯度信息。

求梯度的方法可以分为三种:数值微分、符号微分、自动微分。我们已经在上文中见过数值微分的缺陷;符号微分也有它的麻烦。而本文要介绍的自动微分通过记录计算过程并应用链式法则(Chain Rule),使其能够高效准确地计算复杂函数的梯度。相比数值微分的精度问题和符号微分的表达式爆炸,自动微分既保证了数值精度,又能处理任意复杂的计算图(Computational graph)。其稳定、可扩展等优点,使成为 PyTorch、TensorFlow 等深度学习框架区别于 Numpy 的核心技术之一。

自动微分概念较为抽象,是我个人在学习过程中花时间最多的知识点。不过本文比我所学习时用的资料(主要是《the little learner》这本书)要少很多认知负担,哪怕是初学者,跟着本文的节奏也一定会有所收获。

梯度和链式法则

梯度的含义

在第二篇文章中,我们已经了解了梯度,这里先复习一下。

什么是梯度?简单来说,梯度就是一个函数对它的输入参数的变化的敏感度。

假设我们有一个函数 $f(a)$。我们把 $a $稍微调整一下,比如加上一个非常小的量 $Δa$,然后再计算一次函数的值 $f(a + Δa)$。很自然地,函数的输出也会跟着发生变化,这个变化我们记作 $Δf$ₐ。现在,我们就可以比较一下,输入改变$Δa$以后,输出的变化量,也就是$Δf$的值,之间的比:

\[Δf / Δa\]

当 $\Delta a$ 趋近于 0 时,这个比值就是函数在 $a$ 处的导数。处理多变量函数时,函数对单个变量的导数被称为偏导数,偏导数组成的向量被称为该函数的梯度。

我们的数值微分的实现正是基于这个定义。

复杂函数的梯度

下面看几个复杂一点的情况:

如果一个函数有不止一个输入,比如 $g(a, b)$,那我们就分别把 a 和 b 稍微调一下,然后观察函数值各自发生了多少变化。因为这个函数有两个输入,所以也有两个梯度信息,分别是:

  • 固定 $b$ 不变,只改变 $a$,观察输出的变化率:
    \(\frac{\partial g}{\partial a} = \lim_{\Delta a \to 0} \frac{g(a + \Delta a, b) - g(a, b)}{\Delta a}\)

  • 同理,固定 $a$ 不变,计算对 $b$ 的偏导数: \(\frac{\partial g}{\partial b} = \lim_{\Delta b \to 0} \frac{g(a, b + \Delta b) - g(a, b)}{\Delta b}\)

这两个偏导数组成的向量 $\left( \frac{\partial g}{\partial a}, \frac{\partial g}{\partial b} \right)$ 就是函数 $g$ 在 $(a, b)$ 处的梯度。

如果输入不只是单个数字,而是一个张量,那我们做的事情也是一样的:对张量里的每一个数值都稍微动一动,看看函数的输出变化了多少。最后,我们会得到一个结构和输入一样的张量,里面装的是每个位置上的梯度。

比如输入张量为: \(\mathbf{x} = \begin{bmatrix} x_{11} & x_{12} & x_{13} \\ x_{21} & x_{22} & x_{23} \end{bmatrix}\) 函数 $h(\mathbf{x})$ 输出一个标量。那么 $h$ 关于 $\mathbf{x}$ 的梯度也是一个 $2 \times 3$ 的张量,其第 $(i,j)$ 个元素是 $h$ 对 $x_{ij}$ 的偏导数: \(\nabla_\mathbf{x} h = \begin{bmatrix} \frac{\partial h}{\partial x_{11}} & \frac{\partial h}{\partial x_{12}} & \frac{\partial h}{\partial x_{13}} \\ \frac{\partial h}{\partial x_{21}} & \frac{\partial h}{\partial x_{22}} & \frac{\partial h}{\partial x_{23}} \end{bmatrix}\) 也就是说,无论你的输入结构多复杂,梯度也会有一样的形状。建议熟记此规则,最终达到看到输入结构即可直觉判断梯度形状的水平

链式法则

再看一个嵌套函数的情况。比如,我们有一个函数是先 $h$ 再 $g$ 再$f$,写成 $f(g(a))$,那我们怎么知道这个复杂函数对最初的 a 的梯度是多少呢?

答案是:我们可以一层一层地往回推。

我们让 b = g(a),再让 c = f(b)。那么 c 对 a 的变化率,也就是我们要的梯度,其实可以拆成两个部分的乘积:

\(\frac{dc}{da} = (\frac{dc}{db}) × (\frac{db}{da})\) 其中$f(g(a))$ 对 a 的导数,等于 f 对 g(a) 的导数乘上 g 对 a 的导数。也就是说我们不需要一次就搞清楚整个复杂表达式的梯度,而是可以把它拆成一段一段的子运算,每段只处理一个基本操作。

这就带来了一个非常关键的启发:如果我们知道每个最基本运算(比如加法、乘法、指数等等)的梯度计算方式,就可以像搭积木一样,把整个函数的梯度拼出来。这种计算梯度的方法就是链式法则,其具体代码实现就是自动微分。

注意形式上看起来像是 $\frac{dc}{db} \cdot \frac{db}{da}$ 中的 $db$ 被“约掉”了,不过这只是记号上的巧合,因为导数不是分数。

基本运算的梯度

现在我们知道了链式法则的原理,但要实现自动微分,我们还需要知道每个可导基本运算中的梯度计算。

我们需要预先定义一系列基本可导运算(如加法、乘法等)的梯度规则。这些规则构成了自动微分系统的“积木”。

例如,考虑加法运算 $z = a + b$:

  • 它对 $a$ 的偏导数是 1,
  • 对 $b$ 的偏导数也是 1, 也就是说如果你稍微动一动 a,结果的变化会和a的变化同样多;对 b 也是一样。所以加法的梯度就是 $(1.0,1.0)$——不管 a 和 b 的值具体是多少。

再看乘法 $z = a \cdot b$:

  • $\frac{\partial z}{\partial a} = b$
  • $\frac{\partial z}{\partial b} = a$ 如果只改动 a,那么它的变化量是 b × Δa;反过来改动 b,变化量是 a × Δb。所以它的梯度是 $(b, a)$。

有了加法了乘法,也就很容易得到减法和除法,因为减法就是加上负值,除法就是乘以倒数。

这些基本可导运算的梯度就是自动微分中的“积木块”,把他们搭建在一起就构成了自动微分。

计算图

有了积木块,接下来的问题就是怎么拼装这些积木。

这就引出了两个模式:正向模式(forward mode)反向模式(reverse mode)

假设有一个函数 f(g(h(x))),我们有两种走法:

正向模式:从输入出发,一路算上去 在正向模式中,我们从最里面的函数开始计算:先算 $h(x)$,把结果记为$v_1$,再算 $g(v_1)$,把结果记为$v_2$,最后算 $f(v_2)$,得到最终结果$v_3$。

把整个过程和中间结果都记录下来,会得到一个树状的结构,这个结构便是计算图。在本例中,由于没有分支,它表现为简单的链式结构:$x→v_1 →v_2 →v_3$。

在正向模式中,我们一边计算函数值,一边累积梯度:

  • 计算 $v_1 = h(x)$,同时计算 $\dot{v}_1 = \frac{\partial v_1}{\partial x}$
  • 计算 $v_2 = g(v_1)$,利用链式法则: \(\dot{v}_2 = \frac{\partial v_2}{\partial x} = \frac{\partial v_2}{\partial v_1} \cdot \frac{\partial v_1}{\partial x} = \frac{\partial g}{\partial v_1} \cdot \dot{v}_1\)
  • 类似地,计算: \(\dot{v}_3 = \frac{\partial f}{\partial v_2} \cdot \dot{v}_2\)

最终得到 $\frac{\partial v_3}{\partial x} = \dot{v}_3$。整个过程像“前向传播梯度”。

反向模式:从输出开始,往输入回推 反向模式的思路恰好相反:先完整地进行一次正向计算,记录下所有中间结果和计算图结构,然后从输出开始,沿着计算图反向传播梯度。

还是上面的例子,反向模式从输出开始反向传播“影响”。输出 $v_3$ 对自己的导数为 1,即: \(\dot{v}_3 = \frac{\partial v_3}{\partial v_3} = 1\) 然后根据链式法则,向上游传播:

  • $v_2$ 的梯度为:$\dot{v}_2 = \dot{v}_3 \cdot \frac{\partial v_3}{\partial v_2}$
  • $v_1$ 的梯度为:$\dot{v}_1 = \dot{v}_2 \cdot \frac{\partial v_2}{\partial v_1}$
  • 输入 $x$ 的梯度为:$\dot{x} = \dot{v}_1 \cdot \frac{\partial v_1}{\partial x}$

最终结果是一样的,但计算顺序不同:正向模式是边算函数值边算梯度,反向模式是先算完所有函数值,再统一算梯度。

这种反向计算梯度的方式就是深度学习中最常听说的反向传播(backpropagation),也就是反向模式自动微分的实现。

反向模式的一个巨大优势是:无论有多少个输入参数,只要最终结果是个标量(比如l2_loss),就只需要一次反向遍历就能得到所有参数的梯度。这使得它特别适合深度学习中参数量很多但输出结果为标量的场景。

实际应用

接下来,我们通过一个具体的数值例子来手动计算梯度,加深对自动微分的理解。

例子函数: $f(x) = \sin(x^2 + 3x)$,在 $x = 2$ 处计算函数值和梯度

第一步:构建计算图

要计算 $f(x) = \sin(x^2 + 3x)$,我们需要将其分解为基本操作:

正向:
输入 x
   │
   ▼
 [x²] ───┐
         ├──→ [+] ─→ [sin] ─→ 输出 f
 [×3] ──┘
   │
   └──→ (梯度同步传播)

这个计算图清晰地展示了数据如何从输入 $x=2$ 流向最终输出。

第二步:正向模式计算

目标: 在 $x=2$ 处计算 $f(x)$ 的值和 $\frac{df}{dx}$

在正向模式中,我们从输入开始,在计算每个中间值的同时,也计算其对 $x$ 的导数。

初始条件:$x = 2$,$\frac{dx}{dx} = 1$

操作 函数值计算 导数计算
输入 $x = 2$ $\frac{dx}{dx} = 1$
操作1: $v_1 = x^2$ $v_1 = 2^2 = 4$ $\frac{dv_1}{dx} = 2x \cdot 1 = 2 \times 2 = 4$
操作2: $v_2 = 3x$ $v_2 = 3 \times 2 = 6$ $\frac{dv_2}{dx} = 3 \times 1 = 3$
操作3: $v_3 = v_1 + v_2$ $v_3 = 4 + 6 = 10$ $\frac{dv_3}{dx} = 4 + 3 = 7$
操作4: $v_4 = \sin(v_3)$ $v_4 = \sin(10) \approx -0.544$ $\frac{dv_4}{dx} = \cos(10) \times 7 \approx -0.839 \times 7 \approx -5.873$

最终结果:

  • 函数值:$f(2) \approx -0.544$
  • 梯度:$\frac{df}{dx}\big _{x=2} \approx -5.873$
第三步:反向模式计算
反向:
正向(算值):
输入 x
   │
   ▼
 [x²] ───┐
         ├──→ [+] ─→ [sin] ─→ 输出 f
 [×3] ──┘

反向(传梯度):
输入 x ←─ [x²] ←───┐
                   ├──← [+] ←─ [sin] ←─ 输出 f
       ←─ [×3] ←───┘

目标: 同样在 $x=2$ 处计算梯度,但使用反向传播

正向计算(只记录值):

  • $v_1 = x^2 = 4$
  • $v_2 = 3x = 6$
  • $v_3 = v_1 + v_2 = 10$
  • $v_4 = \sin(v_3) \approx -0.544$

在这个过程中,系统会记录下所有操作及其输入输出,形成一条梯度记录(tape),供后续反向传播使用。例如:

Tape记录:
- (x, square, v₁):x² → v₁ = 4
- (x, multiply_by_3, v₂):3x → v₂ = 6  
- (v₁, v₂, add, v₃):v₁ + v₂ → v₃ = 10
- (v₃, sin, v₄):sin(v₃) → v₄ ≈ -0.544

反向传播(利用tape计算梯度):

从输出开始,初始梯度为 $\frac{\partial f}{\partial v_4} = 1$

  1. 从输出开始:$\frac{\partial f}{\partial v_4} = 1$
  2. 经 $\sin$ 操作:$\frac{\partial f}{\partial v_3} = 1 \cdot \cos(10) \approx -0.839$
  3. 经加法操作:梯度均等传给 $v_1$ 和 $v_2$,即 $\frac{\partial f}{\partial v_1} = \frac{\partial f}{\partial v_2} = -0.839$
  4. 回传到 $x$:
    • 通过 $v_1 = x^2$:贡献为 $-0.839 \times 2x = -0.839 \times 4 = -3.356$
    • 通过 $v_2 = 3x$:贡献为 $-0.839 \times 3 = -2.517$
  5. 总梯度:$-3.356 + (-2.517) = \boxed{-5.873}$

:最终梯度是两条路径($x \to v_1 \to v_3$ 和 $x \to v_2 \to v_3$)贡献的和。

最终结果: $\frac{df}{dx}\big _{x=2} \approx -5.873$
对比总结

两种模式都得到了相同的结果:

  • 函数值:$f(2) \approx -0.544$
  • 梯度:$\frac{df}{dx}\big _{x=2} \approx -5.873$

综上,反向模式通过:(1) 完整前向计算构建tape; (2) 从输出反向传播梯度; (3) 组合局部梯度,高效解决了多参数场景的梯度计算问题。

代码实现:System A

理解了概念,下面开始代码实践。在前面几篇文章中,我们已经实现了基础的张量和优化器,但那时的标量只是普通的数值。现在为了支持自动微分,我们需要让系统“记住”计算过程,并能反向传播梯度。

熟悉面向对象编程的朋友可能会立刻想到:把标量包装成一个类,包含 value 和 grad 属性,再重载 addmul 等运算方法,在每次计算时同时更新值和梯度。这确实是主流框架(如 PyTorch)的做法,清晰直观。

但本文延续《The Little Learner》的函数式风格,采用一种更“函数式”的方式:用闭包来构建计算图,并通过递归调用链式法则完成反向传播。这种方式虽然初看有些抽象,但它能让我们更贴近自动微分的本质——不是“存储梯度”,而是“记录如何传播梯度”。

我们把这种实现方式命名为System A。

数据结构

我们从定义一个核心数据结构开始:

来看具体代码:

from dataclasses import dataclass, field
from itertools import count

_id_generator = count()

@dataclass(frozen=True)
class Dual:
    r: Scalar        # 实际的数值(real part)
    k: Callable      # 闭包捕获了运算上下文,封装了梯度传播逻辑
    id: int = field(default_factory=lambda: next(_id_generator), init=False) 

Dual 类就是一个“增强版标量”,它不仅包含数值 r,还有一个关键的 k 属性——我们称之为 link 函数。它是一个闭包,捕获了当前运算的上下文(比如输入值、运算类型)。函数内部又调用了输入值的link函数,整个函数调用栈就是一个计算图。下面会详细解释。

每个 Dual 还有一个唯一 id,用于在反向传播时记录梯度,避免重复计算。

为了方便处理普通数值和 Dual 对象,我们定义几个辅助函数:

def is_dual(d: Any) -> bool:
    return isinstance(d, Dual)

def get_r(d: Scalar) -> float:
    return d.r if is_dual(d) else d  # 提取数值,兼容普通数和 Dual

def get_k(d: Scalar) -> Callable:
    return d.k if is_dual(d) else end_of_chain  # 提取 link 函数,普通数用空函数兜底

def is_scalar(d: Any) -> bool:
    return isinstance(d, (int, float)) or is_dual(d)
基本运算的自动微分实现

现在,我们以加法和乘法为例,看看如何用 Dual 构建自动微分。

先看加法:

def add_00(da, db):
    ra, rb = get_r(da), get_r(db)
    result_r = ra + rb  # 正向计算:a + b

    def link(_result_dual, z, grad_tape) -> None:
        # z 是从下游传来的梯度(∂L/∂output)
        # 加法的梯度规则:对 a 和 b 的偏导都是 1
        ga, gb = (z, z)  # 所以梯度原样传给两个输入

        # 记录当前节点的梯度到 tape
        grad_tape[da.id] = ga
        grad_tape[db.id] = gb

        # 递归调用上游的 link 函数,继续反向传播
        ka, kb = get_k(da), get_k(db)
        ka(da, ga, grad_tape)
        kb(db, gb, grad_tape)

    return Dual(result_r, link)

可以看到,首先和普通的加法运算一样,首先得到ra + rb的结果;然后定义link函数。

link函数有三个参数:_result_dual是当前运算结果(反向传播中通常不使用);z是上游传递过来的的梯度。 最后一个是grad_tape,根据变量名就知道这是记录梯度信息的地方,它是一个字典。

link函数执行三个关键步骤: (1) 根据当前运算的梯度规则计算中间梯度; (2) 将梯度记录到梯度字典(grad_tape)中; (3) 递归调用输入变量的link函数,继续向计算图上游传播梯度。

乘法的实现类似,但梯度规则不同:

def mul_00(da, db):
    ra, rb = get_r(da), get_r(db)
    result_r = ra * rb  # 正向计算:a * b

    def link(_result_dual, z, grad_tape) -> None:
        # 乘法的梯度规则:
        # ∂(a*b)/∂a = b, ∂(a*b)/∂b = a
        # 再乘上从下游传来的梯度 z(链式法则)
        ga, gb = (rb * z, ra * z)

        grad_tape[da.id] = ga
        grad_tape[db.id] = gb

        ka, kb = get_k(da), get_k(db)
        ka(da, ga, grad_tape)
        kb(db, gb, grad_tape)

    return Dual(result_r, link)

如果是叶节点,也就是计算到最后一步,则不需要继续传播,只记录梯度:

# 除了不继续传播,和其他的link函数没有本质区别
def end_of_chain(d: Scalar, z: float, grad_tape: dict) -> None:
    # 如果结果不是dual类,说明不需要记录其梯度
    if is_dual(d):

        dg= sigma.get(d.id, 0.0)
        sigma[d.id] = dg + z

def make_dual(d: Scalar) -> Dual:
    """创建一个叶节点(用end_of_chain作为k属性的Dual)"""
    return Dual(get_r(d), end_of_chain)

可以看到,link 函数的核心逻辑是:

根据当前运算的梯度规则,以及上游的梯度z,计算出中间结果的梯度。 将梯度写入 grad_tape(一个字典,用于记录和提取结果)。 调用输入变量的 link 函数,将梯度继续向上传播。

抽象通用模式:prim1 与 prim2

加法、乘法、指数、对数等运算,结构高度相似:都是先算值,再构造一个 link 函数。我们可以抽象出两个高阶函数,分别处理一元和二元运算:

def prim2(primal_fn: Callable, derivative_fn: Callable) -> Callable:
    """
    构造二元运算的自动微分版本
    primal_fn: 正向计算函数(如 lambda a,b: a+b)
    derivative_fn: 梯度函数,返回 (dL/da, dL/db)
    """
    def primitive(da, db):
        ra, rb = get_r(da), get_r(db)
        result_rho = primal_fn(ra, rb)

        # 如果两个输入都不是 Dual,说明不需要求梯度,直接返回数值
        # 创建 link 闭包开销很大,非必要不创建
        if not is_dual(da) and not is_dual(db):
            return result_rho

        def link(d, g, grad_tape) -> None:
            ga, gb = derivative_fn(ra, rb, g)  # 计算梯度
            # 不需要知道中间节点的梯度,可以省略记录步骤
            ka, kb = get_k(da), get_k(db)
            ka(da, ga, grad_tape)  # 继续向 a 传播
            kb(db, gb, grad_tape)  # 继续向 b 传播

        return Dual(result_rho, link)
    return primitive

# 一元操作逻辑相同
def prim1(primal_fn: Callable, derivative_fn: Callable) -> Callable:

    def primitive(da):

        ra = get_r(da)
        result = primal_fn(ra)

        if not is_dual(da):
            return result

        def link(d, g, grad_tape) -> None:
            ga = derivative_fn(ra, g)

            ka = get_k(da)
            ka(da, ga, grad_tape)
        return Dual(result, link)

    return primitive

有了这个抽象,我们就可以简洁地定义所有基本运算:

import math

# 二元运算
add_00 = prim2(lambda a,b: a+b, lambda a,b,z: (z, z))
mul_00 = prim2(lambda a,b: a*b, lambda a,b,z: (b*z, a*z))
div_00 = prim2(
    lambda ra, rb: ra / rb,
    lambda ra, rb, z: (z / rb,-ra * z / (rb**2)) # d/da (a/b) = 1/b, d/db (a/b) = -a/b^2
)
sub_00 = prim2(
    lambda ra, rb: ra - rb, 
    lambda ra, rb, z: (z, -z)  # d/da (a-b) = 1, d/db (a-b) = -1
)

# 一元运算
exp_0 = prim1(lambda a: math.exp(a), lambda a,z: math.exp(a)*z)
log_0 = prim1(lambda a: math.log(a), lambda a,z: z/a)

def pow_n(n: float):
    return prim1(
        lambda ra: ra**n,
        lambda ra, z: n * (ra ** (n - 1)) * z # d/da (a^n) = n*a^(n-1)
    )

square_0 = pow_n(2.0)

sqrt_0 = prim1(
    lambda ra: math.sqrt(ra),
    lambda ra, z: 0.5 * z / math.sqrt(ra),  # d/da sqrt(a) = 0.5 / sqrt(a)
)

最后,使用 ext1 和 ext2 将这些标量运算扩展到张量(嵌套列表)上:

tsqr = ext1(sqrt_0, 0)
tsqrt = ext1(sqrt_0, 0)
tlog = ext1(log_0, 0)
texp = ext1(exp_0, 0)

tsub = ext2(sub_00, 0, 0)
tmul = ext2(mul_00, 0, 0)
tadd = ext2(add_00, 0, 0)
tdiv = ext2(div_00, 0, 0)
tsqr = ext1(square_0, 0)
tsqrt = ext1(sqrt_0, 0)
获取梯度:nabla 与 get_grad

现在,只要我们用这些增强版运算构建一个计算过程,系统就会自动记录计算图。最后,我们通过 get_grad 来提取梯度

def get_grad(y: Dual, wrt: Theta) -> Theta:
    """从输出 y 中提取 wrt(关于哪些变量)的梯度"""
    sigma = {}               # 梯度字典(id到梯度的映射),即反向传播所需的计算图记录(tape)
    get_tape(y, sigma)          # 启动反向传播,填充 sigma

    return [sigma.get(d.id, 0.0) for d in wrt]   # 映射回原始结构(我们想知道的只有theta的梯度,已知theta是一个列表).

get_tape 是反向传播的入口:

def get_tape(y: Dual | list[Dual], grad_tape: dict) -> None:
    if is_scalar(y):
        k = get_k(y)
        k(y, 1.0, grad_tape)  # 从输出开始,初始梯度为 1.0
    elif isinstance(y, list):
        for item in y:
            get_tape(item, grad_tape)

新的nabla函数:

def nabla(f: Callable, theta: Theta):
    # 1. 将输入参数转为 Dual(开启微分)
    wrt = [make_dual(t) for t in theta]
    # 2. 执行函数,得到结果(构建计算图)
    result = f(wrt)
    # 3. 反向传播,提取梯度
    return get_grad(result, wrt)

总结 & 下一篇预告

在本文中我们理解了自动微分的核心思想是:将复杂函数分解为基本运算,通过链式法则组合这些基本运算的梯度。以及做出了我们的第一个代码实现——System A。

既然有System A,那么自然有System B。System B旨在解决System A的一些性能问题。本打算在本文中一并放出更高效的实现版本,但考虑到内容已经足够丰富,为了保证阅读的流畅性,我决定将这部分内容独立出来,作为下一篇《(五点五) 自动微分 2.0》。下一篇文章将介绍如何通过张量级操作和向量化计算提升性能。

完自动微分的进阶实现,我们立刻就会正式踏入神经网络的世界。届时,我们将用自己亲手打造的自动微分引擎,来训练真正的神经网络模型。敬请期待!