11868 LLM Sys & 15642 ML Sys: DL Frameworks and Auto Differentiation
自动微分与深度学习框架
1. 学习问题与梯度下降
给定训练集 D = {(xₙ, yₙ)},目标是找到模型参数
θ,使损失函数最小化: \[
\min_{\theta} \mathcal{L}(\theta), \quad \mathcal{L} : \mathbb{R}^d \to
\mathbb{R}
\] 对于分类问题,标准损失是 交叉熵。PyTorch 的
CrossEntropyLoss 直接在原始 logits(softmax
之前)上应用负对数似然,这在数值上是稳定的。
为什么选择梯度下降?
在当前参数 xₜ 周围使用一阶泰勒展开: \[ f(x_t + \Delta x) \approx f(x_t) + \Delta x^T \nabla f \big|_{x_t} \] 为了使 f 减小,我们希望第二项尽可能为负。根据柯西-施瓦茨不等式,最佳方向是 与梯度相反,这给出了 SGD 更新规则: \[ x_{t+1} = x_t - \eta \cdot \nabla f \big|_{x_t} \] 其中 η 是学习率超参数。
SGD 算法
1 | 设置学习率 η |
关键问题: 如何高效计算 ∂ℒ/∂θ 对于任意深度网络中的 每个 参数?这正是自动微分所解决的问题。
2. 计算图
计算图 是一个有向无环图(DAG),其中:
- 节点 代表变量(输入、参数、中间值)或操作。
- 有向边 表示数据流 — 每个操作的输入。
通过拓扑排序进行前向评估
要评估图:
- 将所有节点放入未处理队列中。
- 重复查找所有输入已计算的节点。
- 评估该节点,将其移至已处理集合。
- 重复直到所有节点都被评估。
这确保每个节点在有效的依赖顺序中被计算一次。
3. 微分方法:分类
| 方法 | 工作原理 | 对 f 的成本: ℝⁿ → ℝ | 实际应用 |
|---|---|---|---|
| 数值 | 在扰动点评估 f;有限差分 | O(n) 前向传递 | 仅用于梯度检查 — 对训练来说太慢且不稳定 |
| 符号 | 代数地应用求和/乘积/链式法则 | 可能呈指数级膨胀 | CAS 工具(Mathematica)。对于大型模型,表达式膨胀不切实际 |
| 前向模式 AD | 在一次前向传递中传播导数“切线”与值 | O(n) 前向传递 | 当 n 较小时效果良好(例如,Hessian-向量积) |
| 反向模式 AD | 先前向运行,然后向后传播“伴随” | O(1) 后向传递 | 深度学习的主力 — 一次传递给出所有 ∂ℒ/∂θ |
关键见解: 对于 f: ℝⁿ → ℝ(标量损失,多个参数),反向模式 AD 在与 单个 前向传递成比例的时间内计算完整梯度。前向模式需要 n 次传递。
4. 前向模式自动微分
前向模式 AD 通过携带每个中间值的 切线(导数值)来计算导数,沿着与前向计算相同的方向传播。
定义
对于每个中间值 vᵢ,定义: \[ \dot{v}_i = \frac{\partial v_i}{\partial x_1} \]
示例
y = ln(x₁) + x₁·x₂ − sin(x₂),计算 ∂y/∂x₁:
1 | # 前向评估跟踪(值) |
限制
要获得 ∂y/∂x₂,我们需要使用不同的种子(v̇₁=0, v̇₂=1)重新运行整个跟踪。对于具有数百万参数的网络,这代价高昂 — 因此使用反向模式。
5. 反向模式自动微分
反向模式 AD 是 PyTorch 中 loss.backward()
的引擎。它在单次反向传递中计算相对于 所有
输入的梯度。
定义
定义每个节点 vᵢ 的 伴随:
\[\bar{v}_i = \frac{\partial y}{\partial v_i}\]
种子 v̄_output = 1,然后按 反向拓扑顺序 遍历图。
示例(相同函数)
1 | # 反向 AD 跟踪(伴随) |
多路径规则
当节点 vᵢ 输入多个下游节点时,其伴随是部分贡献的 和:
\[\bar{v}*i = \sum*{j \in \text{next}(i)} \bar{v}_j \cdot \frac{\partial v_j}{\partial v_i}\]
反向 AD 算法
1 | def gradient(out): |
6. 向量-雅可比乘积与张量伴随
在实践中,节点持有 张量(矩阵、向量),而不是标量。我们使用 雅可比 和 VJP 进行推广。
雅可比
对于 y ∈ ℝᵐ,x ∈ ℝⁿ: \[ J = \frac{\partial y}{\partial x} = \begin{bmatrix} \partial y_1/\partial x_1 & \partial y_1/\partial x_2 \\ \partial y_2/\partial x_1 & \partial y_2/\partial x_2 \end{bmatrix} \] 我们 从不 显式形成 J(它可能非常庞大)。相反,我们计算 向量-雅可比乘积 (VJP): \[ \bar{x} = J^T \bar{y} \]
示例:线性层
对于 y = Wx:
| 梯度 | 公式 |
|---|---|
| 相对于输入 x | x̄ = Wᵀ ȳ |
| 相对于权重 W | W̄ = ȳ xᵀ |
实现模式: 每个原始操作都配有
forward()和vjp()函数。框架只需知道这些原始操作;其他一切由图遍历处理。
7. 反向传播与反向模式 AD:图扩展
有两种实现策略,这一区别对高阶导数很重要。
反向传播(第一代)
- 直接在前向图上运行反向操作。
- 手动逐个节点计算伴随 值。
- 用于早期框架:Caffe、cuda-convnet。
- 无法 自然支持梯度的梯度。
反向模式 AD 通过图扩展(现代)
- 为每个伴随计算构建新图节点。
- 反向传递本身成为计算图。
- 可以在其上运行另一次反向传递 → 高阶导数。
- 被 PyTorch、JAX、现代 TensorFlow 使用。
1 | # 梯度函数返回一个计算图,而不仅仅是一个值。 |
关键收获: 通过图扩展的反向模式 AD 在功能上严格优于经典反向传播。这是现代框架能够计算 Hessian、高阶梯度和元学习目标的原因。
8. 整合:深度学习框架
深度学习框架需要具备:表达性(任何网络架构)、生产力(隐藏 CUDA,自动微分)和 效率(可扩展到大型模型,自动硬件加速)。
设计原则
- 定义 程序为一个符号数据流图,包含占位符、变量和操作。
- 执行 该图的优化版本在可用设备上。
基本组件(TensorFlow 风格)
| 组件 | 角色 | 示例 |
|---|---|---|
| 占位符 | 在运行时输入数据 | tf.placeholder(tf.float32, (1, 784)) |
| 变量 | 用于参数的有状态节点;在执行之间保持状态 | tf.Variable(tf.zeros((100,))) |
| 常量 | 静态数据 | tf.constant([[1, 2], [3, 4]]) |
| 操作 | 数学操作;必须定义前向 + 反向 | tf.nn.relu(...), tf.matmul(...) |
| 会话 | 将图绑定到设备(CPU/GPU)的执行上下文 | tf.Session() |
实现一个操作
1 | class AddOperation(Operation): |
完整训练循环(使用 AutoGrad)
1 | # 定义训练目标 |
实现会话(执行)
1 | def run_session(end_node, feed_dict): |
9. 框架比较:PyTorch、TensorFlow、JAX
| 方面 | PyTorch | TensorFlow | JAX | NumPy |
|---|---|---|---|---|
| 范式 | 动态(急切) | 静态图 / 急切 | 函数式转换 | 过程式 |
| 自动微分 | 动态计算图 | 静态计算图 | 函数式(grad/jit) |
无 |
| 硬件 | CPU、GPU、TPU | CPU、GPU、TPU | CPU、GPU、TPU | 仅 CPU |
| 易用性 | Pythonic | 学习曲线陡峭 | Pythonic + 函数式 | 非常简单 |
| 并行性 | DataParallel / DDP |
tf.distribute |
pmap |
无 |
| 生态系统 | Lightning, TorchVision | TensorBoard, TF Extended | 与 NumPy 集成 | — |
动态与静态图
- PyTorch(动态/急切): 计算图在编写 Python 代码时动态构建。使用标准打印语句易于调试。图可以在每次迭代中变化(例如,RNN 中不同的序列长度)。
- TensorFlow v1(静态): 首先 定义 图,然后在会话中 执行。允许更激进的优化,但更难调试。
- JAX(函数式):
没有显式的图对象。你编写纯函数并使用转换(
jax.grad、jax.jit、jax.vmap)来获取梯度、JIT 编译和向量化。
10. 梯度检查与调试
在实现自定义反向传递时,数值梯度检查是你最好的朋友。
中心有限差分
\[\frac{\partial f(x_1, x_2)}{\partial x_1} \approx \frac{f(x_1 + h, x_2) - f(x_1 - h, x_2)}{2h}\]
这比单边公式更准确(误差为 O(h²) 而不是 O(h))。
最佳实践
1 | # 梯度检查步骤: |
提示: 如果你的解析梯度和数值梯度之间的差异超过 ~1e-5(相对误差),那么你的 VJP 实现中可能存在错误。检查零输入和边界条件等边缘情况。
总结
| 概念 | 功能 |
|---|---|
| 计算图 | 表示程序的 DAG;使前向评估和反向微分成为可能 |
| 拓扑排序 | 定义图的有效执行顺序 |
| 前向模式 AD | 向前传播切线;对于 n 个输入,O(n) 次传递 |
| 反向模式 AD | 向后传播伴随;对于标量输出,O(1) 次传递 — 高效训练的关键 |
| VJP | 反向 AD 在张量上的原始操作;避免形成完整的雅可比 |
| 图扩展 | 现代技术:反向传递构建新图,支持高阶导数 |
| 深度学习框架 | 结合:符号图定义 → 自动微分 → 在硬件上的优化执行 |




