全书导航
大模型之路:从图灵、感知机到 ChatGPT · 卷 3

第 18 章:反向传播:错误如何变成学习

本章问题:模型如何知道每个参数该怎么调整?


第 17 章让我们有了一个能算任意函数的网络——但它的参数全是随机的。这一章解决:参数怎么从错误中学习。

18.1 从书页到黑板

第 8 章讲了反向传播的直觉和历史。第 17 章在代码里调用了 loss.backward()——一行魔法。这一章不做直觉,不做历史,不做魔法。

这一章用纸笔手算一遍。

读完本章,你不仅会知道反向传播"是什么"——你会亲手算出一个简单网络里梯度是怎样流过每一层的。然后你会发现:

反向传播不过是链式法则 + 耐心。


18.2 最小的训练样本

假设我们有一个极其简化的网络,只有一个神经元。输入 x 是一个标量(一个数),输出 ŷ 也是一个标量。只有一个参数 w:

ŷ = w × x

真实的正确答案是 y。损失使用均方误差(MSE):

loss = (ŷ - y)²

现在设具体数字:

  • x = 2(输入)
  • y = 10(正确答案)
  • w = 3(当前参数,初始猜测)

算一遍前向:

ŷ = w × x = 3 × 2 = 6loss = (6 - 10)² = 16

错得很厉害。问题:w 应该变大还是变小?变多少?


18.3 梯度:指向谷底的箭头

loss 是 w 的函数。画出来:横轴是 w,纵轴是 loss。当前的 loss=16,它在某个位置。如果你站在 w=3 的山坡上往周围看——哪边是下坡?

梯度回答的就是这个问题:loss 对 w 的导数 = 山坡在当前 w 位置的斜率。 符号告诉你方向(正号 = w 增大 loss 也会增大,所以 w 应该往反方向走;负号 = w 增大 loss 会减小),数值告诉你坡有多陡。

用链式法则拆解。loss 的计算分解为两步:

ŷ = w × x          (第一步:参数产生预测)loss = (ŷ - y)²    (第二步:预测和真值之间产生损失)

链式法则说:

∂loss/∂w = ∂loss/∂ŷ × ∂ŷ/∂w

先算第一项。loss = (ŷ - y)²,令 e = ŷ - y(误差):

∂loss/∂ŷ = 2(ŷ - y) = 2 × (6 - 10) = -8

这个数的含义:当前状态,如果 ŷ 在这里增加一点点,loss 会怎样变化?斜率为负——ŷ 增加 1,loss 会减少 8(在局部一阶线性近似的意义上)。

再算第二项。ŷ = w × x:

∂ŷ/∂w = x = 2

这个数的含义:w 增加 1,ŷ 就增加 x 倍。

连起来——链式法则的乘法,就是把两个"变化影响"叠加:

∂loss/∂w = ∂loss/∂ŷ × ∂ŷ/∂w = (-8) × 2 = -16

梯度 = -16。意思是:在当前 w=3 的位置,w 增加一点点,loss 会按 -16 的倍数变化(即减少约 16 倍的那个增量)。所以 w 应该往反方向——即正方向——走一点。


18.4 参数更新:往谷底走一步

有了梯度,怎么更新参数?梯度下降法则:

w_new = w_old - learning_rate × gradient

学习率是一个很小的数(比如 0.01),控制每一步走多远。

w_new = 3 - 0.01 × (-16) = 3 + 0.16 = 3.16

w 变大了。再算一次前向验证:

ŷ = 3.16 × 2 = 6.32loss = (6.32 - 10)² = 13.54

loss 从 16 降到了 13.54——确实在变好。如果你重复这个步骤几百次,w 会逐渐趋近于 5(因为 y=10, x=2,最优解是 w=5: ŷ=10, loss=0)。

这就是梯度下降的全部:沿着梯度的反方向,一小步一小步地走下山。


18.5 加入隐藏层:梯度的链条变长

现在把网络加深一层。两层网络,中间有一个隐藏神经元:

h = w₁ × x       (隐藏层)ŷ = w₂ × h       (输出层)loss = (ŷ - y)²

用同样的数据:x=2, y=10。初始化 w₁=1, w₂=1。

前向传播:

h = 1 × 2 = 2ŷ = 1 × 2 = 2loss = (2 - 10)² = 64

现在需要求两个梯度:∂loss/∂w₂ 和 ∂loss/∂w₁。

对于 w₂——它离 loss 很近,只有一个乘法:

∂loss/∂w₂ = ∂loss/∂ŷ × ∂ŷ/∂w₂           = 2(ŷ - y) × h           = 2(2 - 10) × 2           = -16 × 2           = -32

对于 w₁——它在更深的里面。梯度需要穿越两步:

∂loss/∂w₁ = ∂loss/∂ŷ × ∂ŷ/∂h × ∂h/∂w₁

第一项已经算过:∂loss/∂ŷ = -16。第二项 ∂ŷ/∂h = w₂ = 1。第三项 ∂h/∂w₁ = x = 2(如果没有激活函数)。

∂loss/∂w₁ = (-16) × 1 × 2 = -32

两个梯度都是 -32。更新(学习率=0.01):

w₁ = 1 - 0.01 × (-32) = 1.32w₂ = 1 - 0.01 × (-32) = 1.32

这就是反向传播的核心计算:误差信号从输出端一路往回走,每个参数拿到自己那一份"贡献度"。

如果中间有 ReLU 激活呢?h = ReLU(w₁×x)。那么:

∂h/∂w₁ = (∂ReLU/∂内部) × (∂内部/∂w₁)

对于 ReLU,当 w₁×x > 0 时 ∂ReLU = 1,当 w₁×x ≤ 0 时 ∂ReLU = 0。就是这样简单。


18.6 自动微分:为什么你不必手算

上面的计算在一个只有两个参数的网络上已经显得繁琐了。在一个 60,000,000 参数的网络上,手算是不可能的。

但如果你的网络是由有限几种基本运算组成(如矩阵乘法、逐元素 ReLU、逐元素指数运算),你可以把整个计算拆成一张运算图。这张图中的每一个节点代表一次运算,有向边代表数据流动。对于每一个运算,你不需要知道"全局的梯度公式"——你只需要知道这个运算本身的局部导数规则。

例如:

  • y = a + b → ∂y/∂a = 1, ∂y/∂b = 1
  • y = a × b → ∂y/∂a = b, ∂y/∂b = a
  • y = ReLU(a) → ∂y/∂a = 1(当 a > 0),反之为 0

然后整个梯度可以沿着运算图自动地从后往前算。这就是自动微分(autograd)——PyTorch、TensorFlow 和 JAX 都在做这件事。

当你在 PyTorch 里写 loss.backward() 时,引擎做的事情就是:回溯你从输入到 loss 的每一步计算,应用每步的局部导数规则,然后把梯度累积到每个可训练参数上。

你手算的那个"链式法则"被自动化了——但背后的数学,和你刚才在纸上算的一模一样。


18.7 本章小实验:换一个 loss 手算一遍

用同样的网络(一个神经元,x=2, y=10, w=3),但把损失函数换成平均绝对误差(MAE)

loss = |ŷ - y|    (绝对值误差,而不是平方误差)

在 w=3 时,ŷ 仍为 6。

问题是:∂loss/∂w 等于多少?

你需要处理绝对值函数的导数:

  • 当 ŷ > y 时,∂loss/∂ŷ = +1
  • 当 ŷ < y 时,∂loss/∂ŷ = -1
  • 当 ŷ = y 时,导数未定义(在实际训练中遇到这个点,算法通常把这个样本跳过或回传 0)

在这里,ŷ=6, y=10, 所以 ŷ < y → ∂loss/∂ŷ = -1。

∂loss/∂w = ∂loss/∂ŷ × ∂ŷ/∂w = (-1) × 2 = -2

对比 MSE 的梯度是 -16。同样的错误程度,MSE 让 w 往前走 16/2 = 8 倍的距离。因为 MSE 对"错得离谱"的惩罚是二次增长——偏差翻倍,乘法里的梯度分量也翻倍。

所以不同损失函数让训练朝不同方向、以不同加速度收敛。选择损失函数本质上是在定义——当网络犯错时,你认为哪种错"更不可接受"。

这也解释了大模型训练中常常看到的一个模式:交叉熵损失倾向于强烈惩罚"对正确类别的概率过低",这驱动模型在大规模分类问题中快速收敛。


18.8 本章地图

text
问题:模型如何知道每个参数该怎么调整?方法:从损失函数出发,沿网络的反方向应用链式法则(自动微分),算出每个参数对最终错误的梯度。手算示例:单神经元梯度 = 2(ŷ-y) × x;两层网络通过 ∂loss/∂ŷ × ∂ŷ/∂h × ∂h/∂w₁ 穿越所有层。引擎:PyTorch 的 autograd、TensorFlow 的 GradientTape 都是把手算的链式法则自动化。今天:从最小的两层网络到大规模语言模型,训练的原理没有变——梯度沿着运算图反向流回每个参数。

18.9 本章结语:梯度是世界上最小的老师

你在纸上算出的 -16——那个数,就是学习本身的最小粒子。

每一次训练迭代,每一个参数都收到这样一个信号:你低了,你高了,你往这个方向走多少。把这一条信号乘以几十亿个参数,乘以几十亿次训练迭代——你就得到了一个大型语言模型的训练过程。

所有让人赞叹的 AI 能力——写诗、编程、翻译、推理——在训练过程中,都是通过不断地把这些细碎的"梯度信号"注入参数空间而逐步形成的。每一个参数的每一次微调,分别单独看,意义为零。但把它们同时应用在足够多的数据上,它们就在压缩世界的统计规律。

下一章不离开神经网络,但换个角度看。卷积不是普通的全连接矩阵乘法——它是一种更有结构的视觉处理。为什么 CNN 特别擅长图像?我们来看卷积的直觉。

SECTION §02 · ENGAGE

Discussion

留言区 · GitHub-powered comments via Giscus