自动微分与深度学习框架


1. 学习问题与梯度下降

给定训练集 D = {(xₙ, yₙ)},目标是找到模型参数 θ,使损失函数最小化: \[ \min_{\theta} \mathcal{L}(\theta), \quad \mathcal{L} : \mathbb{R}^d \to \mathbb{R} \] 对于分类问题,标准损失是 交叉熵。PyTorch 的 CrossEntropyLoss 直接在原始 logits(softmax 之前)上应用负对数似然,这在数值上是稳定的。

为什么选择梯度下降?

在当前参数 xₜ 周围使用一阶泰勒展开: \[ f(x_t + \Delta x) \approx f(x_t) + \Delta x^T \nabla f \big|_{x_t} \] 为了使 f 减小,我们希望第二项尽可能为负。根据柯西-施瓦茨不等式,最佳方向是 与梯度相反,这给出了 SGD 更新规则: \[ x_{t+1} = x_t - \eta \cdot \nabla f \big|_{x_t} \] 其中 η 是学习率超参数。

SGD 算法

1
2
3
4
5
6
7
8
9
10
设置学习率 η
1. 初始化 θ ← θ₀
2. 对于 epoch = 1 到 maxEpoch:
3. 对于数据中的每个小批量:
4. total_g = 0
5. 对于批量中的每个 (x, y):
6. 计算误差: err(f(x; θ), y)
7. 计算梯度: g = ∂err/∂θ
8. total_g += g
9. 更新: θ = θ - η * total_g / N

关键问题: 如何高效计算 ∂ℒ/∂θ 对于任意深度网络中的 每个 参数?这正是自动微分所解决的问题。


2. 计算图

计算图 是一个有向无环图(DAG),其中:

  • 节点 代表变量(输入、参数、中间值)或操作。
  • 有向边 表示数据流 — 每个操作的输入。

通过拓扑排序进行前向评估

要评估图:

  1. 将所有节点放入未处理队列中。
  2. 重复查找所有输入已计算的节点。
  3. 评估该节点,将其移至已处理集合。
  4. 重复直到所有节点都被评估。

这确保每个节点在有效的依赖顺序中被计算一次。


3. 微分方法:分类

方法 工作原理 对 f 的成本: ℝⁿ → ℝ 实际应用
数值 在扰动点评估 f;有限差分 O(n) 前向传递 仅用于梯度检查 — 对训练来说太慢且不稳定
符号 代数地应用求和/乘积/链式法则 可能呈指数级膨胀 CAS 工具(Mathematica)。对于大型模型,表达式膨胀不切实际
前向模式 AD 在一次前向传递中传播导数“切线”与值 O(n) 前向传递 当 n 较小时效果良好(例如,Hessian-向量积)
反向模式 AD 先前向运行,然后向后传播“伴随” O(1) 后向传递 深度学习的主力 — 一次传递给出所有 ∂ℒ/∂θ

关键见解: 对于 f: ℝⁿ → ℝ(标量损失,多个参数),反向模式 AD 在与 单个 前向传递成比例的时间内计算完整梯度。前向模式需要 n 次传递。


4. 前向模式自动微分

前向模式 AD 通过携带每个中间值的 切线(导数值)来计算导数,沿着与前向计算相同的方向传播。

定义

对于每个中间值 vᵢ,定义: \[ \dot{v}_i = \frac{\partial v_i}{\partial x_1} \]

示例

y = ln(x₁) + x₁·x₂ − sin(x₂),计算 ∂y/∂x₁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 前向评估跟踪(值)
v1 = x1 # = 2
v2 = x2 # = 5
v3 = ln(v1) # = 0.693
v4 = v1 * v2 # = 10
v5 = sin(v2) # = −0.959
v6 = v3 + v4 # = 10.693
v7 = v6 − v5 # = 11.652 → y

# 前向 AD 跟踪(切线,种子用于 ∂/∂x₁)
v1_dot = 1 # 种子
v2_dot = 0 # 不对 x₂ 求导
v3_dot = v1_dot / v1 # = 0.5 (ln 的链式法则)
v4_dot = v1_dot*v2 + v2_dot*v1 # = 5.0 (乘积法则)
v5_dot = v2_dot * cos(v2) # = 0.0
v6_dot = v3_dot + v4_dot # = 5.5
v7_dot = v6_dot − v5_dot # = 5.5 → ∂y/∂x₁

限制

要获得 ∂y/∂x₂,我们需要使用不同的种子(v̇₁=0, v̇₂=1)重新运行整个跟踪。对于具有数百万参数的网络,这代价高昂 — 因此使用反向模式。


5. 反向模式自动微分

反向模式 AD 是 PyTorch 中 loss.backward() 的引擎。它在单次反向传递中计算相对于 所有 输入的梯度。

定义

定义每个节点 vᵢ 的 伴随

\[\bar{v}_i = \frac{\partial y}{\partial v_i}\]

种子 v̄_output = 1,然后按 反向拓扑顺序 遍历图。

示例(相同函数)

1
2
3
4
5
6
7
8
9
10
11
12
# 反向 AD 跟踪(伴随)
v7_bar = 1 # 种子: ∂y/∂y = 1
v6_bar = v7_bar * 1 # ∂v7/∂v6 = 1 (减法)
v5_bar = v7_bar * (-1) # ∂v7/∂v5 = −1
v4_bar = v6_bar * 1 # ∂v6/∂v4 = 1 (加法)
v3_bar = v6_bar * 1 # ∂v6/∂v3 = 1

# 有多个输出边的节点:求和贡献
v2_bar = v5_bar * cos(v2) + v4_bar * v1 # ≈ 1.716
v1_bar = v4_bar * v2 + v3_bar / v1 # = 5.5

# 结果:∂y/∂x₁ = 5.5, ∂y/∂x₂ ≈ 1.716 (一次传递中得到!)

多路径规则

当节点 vᵢ 输入多个下游节点时,其伴随是部分贡献的

\[\bar{v}*i = \sum*{j \in \text{next}(i)} \bar{v}_j \cdot \frac{\partial v_j}{\partial v_i}\]

反向 AD 算法

1
2
3
4
5
6
7
8
9
10
11
12
def gradient(out):
node_to_grad = {out: [1]} # 种子

for i in reverse_topo_order(out):
v_bar_i = sum(node_to_grad[i]) # 求和部分伴随

for k in inputs(i):
# 计算部分伴随: v̄ᵢ · ∂vᵢ/∂vₖ
v_k_to_i = v_bar_i * local_grad(i, k)
node_to_grad[k].append(v_k_to_i) # 累加

return 输入的伴随

6. 向量-雅可比乘积与张量伴随

在实践中,节点持有 张量(矩阵、向量),而不是标量。我们使用 雅可比VJP 进行推广。

雅可比

对于 y ∈ ℝᵐ,x ∈ ℝⁿ: \[ J = \frac{\partial y}{\partial x} = \begin{bmatrix} \partial y_1/\partial x_1 & \partial y_1/\partial x_2 \\ \partial y_2/\partial x_1 & \partial y_2/\partial x_2 \end{bmatrix} \] 我们 从不 显式形成 J(它可能非常庞大)。相反,我们计算 向量-雅可比乘积 (VJP)\[ \bar{x} = J^T \bar{y} \]

示例:线性层

对于 y = Wx

梯度 公式
相对于输入 x x̄ = Wᵀ ȳ
相对于权重 W W̄ = ȳ xᵀ

实现模式: 每个原始操作都配有 forward()vjp() 函数。框架只需知道这些原始操作;其他一切由图遍历处理。

7. 反向传播与反向模式 AD:图扩展

有两种实现策略,这一区别对高阶导数很重要。

反向传播(第一代)

  • 直接在前向图上运行反向操作。
  • 手动逐个节点计算伴随
  • 用于早期框架:Caffe、cuda-convnet。
  • 无法 自然支持梯度的梯度。

反向模式 AD 通过图扩展(现代)

  • 为每个伴随计算构建新图节点
  • 反向传递本身成为计算图。
  • 可以在其上运行另一次反向传递 → 高阶导数
  • 被 PyTorch、JAX、现代 TensorFlow 使用。
1
2
3
4
5
6
7
8
9
# 梯度函数返回一个计算图,而不仅仅是一个值。
# 我们可以组合并再次求导:

grad_fn = grad(loss_fn) # 一阶梯度(一个图)
hessian_fn = grad(grad_fn) # 二阶 — 对梯度求导!

# JAX 使这一过程特别简洁:
import jax
hessian = jax.grad(jax.grad(loss_fn))

关键收获: 通过图扩展的反向模式 AD 在功能上严格优于经典反向传播。这是现代框架能够计算 Hessian、高阶梯度和元学习目标的原因。


8. 整合:深度学习框架

深度学习框架需要具备:表达性(任何网络架构)、生产力(隐藏 CUDA,自动微分)和 效率(可扩展到大型模型,自动硬件加速)。

设计原则

  1. 定义 程序为一个符号数据流图,包含占位符、变量和操作。
  2. 执行 该图的优化版本在可用设备上。

基本组件(TensorFlow 风格)

组件 角色 示例
占位符 在运行时输入数据 tf.placeholder(tf.float32, (1, 784))
变量 用于参数的有状态节点;在执行之间保持状态 tf.Variable(tf.zeros((100,)))
常量 静态数据 tf.constant([[1, 2], [3, 4]])
操作 数学操作;必须定义前向 + 反向 tf.nn.relu(...), tf.matmul(...)
会话 将图绑定到设备(CPU/GPU)的执行上下文 tf.Session()

实现一个操作

1
2
3
4
5
6
7
8
9
10
11
class AddOperation(Operation):
"""定义加法操作:输出 = a + b"""
def __init__(self, a, b):
super().__init__([a, b]) # a, b 是输入节点

def forward(self, a, b):
return a + b

def backward(self, upstream_grad):
# ∂(a+b)/∂a = 1, ∂(a+b)/∂b = 1
return upstream_grad, upstream_grad

完整训练循环(使用 AutoGrad)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 定义训练目标
def objective(params, iter):
idx = batch_indices(iter)
return -log_posterior(params, train_images[idx], train_labels[idx], L2_reg)

# 通过自动微分获取梯度
objective_grad = grad(objective)

# 神经网络前向传递
def neural_net_predict(params, inputs):
"""params: (W, b) 元组的列表。inputs: (N x D) 矩阵。"""
for W, b in params:
outputs = np.dot(inputs, W) + b
inputs = np.tanh(outputs)
return outputs - logsumexp(outputs, axis=1, keepdims=True)

# 对数后验 = 对数先验 + 对数似然
def log_posterior(params, inputs, targets, L2_reg):
log_prior = -L2_reg * l2_norm(params)
log_lik = np.sum(neural_net_predict(params, inputs) * targets)
return log_prior + log_lik

实现会话(执行)

1
2
3
4
5
6
7
8
9
10
11
def run_session(end_node, feed_dict):
"""执行计算图。"""
for node in topological_sort(end_node):
if isinstance(node, Placeholder):
node.value = feed_dict[node]
elif isinstance(node, (Variable, Constant)):
pass # 值已设置
elif isinstance(node, Operation):
inputs = [n.value for n in node.input_nodes]
node.value = node.forward(*inputs)
return end_node.value

9. 框架比较:PyTorch、TensorFlow、JAX

方面 PyTorch TensorFlow JAX NumPy
范式 动态(急切) 静态图 / 急切 函数式转换 过程式
自动微分 动态计算图 静态计算图 函数式(grad/jit
硬件 CPU、GPU、TPU CPU、GPU、TPU CPU、GPU、TPU 仅 CPU
易用性 Pythonic 学习曲线陡峭 Pythonic + 函数式 非常简单
并行性 DataParallel / DDP tf.distribute pmap
生态系统 Lightning, TorchVision TensorBoard, TF Extended 与 NumPy 集成

动态与静态图

  • PyTorch(动态/急切): 计算图在编写 Python 代码时动态构建。使用标准打印语句易于调试。图可以在每次迭代中变化(例如,RNN 中不同的序列长度)。
  • TensorFlow v1(静态): 首先 定义 图,然后在会话中 执行。允许更激进的优化,但更难调试。
  • JAX(函数式): 没有显式的图对象。你编写纯函数并使用转换(jax.gradjax.jitjax.vmap)来获取梯度、JIT 编译和向量化。

10. 梯度检查与调试

在实现自定义反向传递时,数值梯度检查是你最好的朋友。

中心有限差分

\[\frac{\partial f(x_1, x_2)}{\partial x_1} \approx \frac{f(x_1 + h, x_2) - f(x_1 - h, x_2)}{2h}\]

这比单边公式更准确(误差为 O(h²) 而不是 O(h))。

最佳实践

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 梯度检查步骤:
# 1. 使用双精度(float64) — float32 的误差占主导
# 2. 选择一个小的 h(例如,1e-6)
# 3. 通过图计算前向差分两次
# 4. 与你的解析梯度比较

h = 1e-6

def numerical_grad(f, x, i):
"""通过中心差分计算 ∂f/∂xᵢ。"""
x_plus = x.copy(); x_plus[i] += h
x_minus = x.copy(); x_minus[i] -= h
return (f(x_plus) - f(x_minus)) / (2 * h)

# 更一般:选择一个随机方向 δ,检查方向导数
# δᵀ ∇f(θ) ≈ (f(θ + εδ) − f(θ − εδ)) / 2ε

提示: 如果你的解析梯度和数值梯度之间的差异超过 ~1e-5(相对误差),那么你的 VJP 实现中可能存在错误。检查零输入和边界条件等边缘情况。


总结

概念 功能
计算图 表示程序的 DAG;使前向评估和反向微分成为可能
拓扑排序 定义图的有效执行顺序
前向模式 AD 向前传播切线;对于 n 个输入,O(n) 次传递
反向模式 AD 向后传播伴随;对于标量输出,O(1) 次传递 — 高效训练的关键
VJP 反向 AD 在张量上的原始操作;避免形成完整的雅可比
图扩展 现代技术:反向传递构建新图,支持高阶导数
深度学习框架 结合:符号图定义 → 自动微分 → 在硬件上的优化执行